Skip to content

Commit fa11039

Browse files
committed
[Proto] Implemented nested Serialization for ProtoPackable objects
1 parent 950a448 commit fa11039

13 files changed

+530
-41
lines changed

Lagrange.Proto.Generator/DiagnosticDescriptors.cs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,22 @@ public static class DiagnosticDescriptors
2121
defaultSeverity: DiagnosticSeverity.Error,
2222
isEnabledByDefault: true
2323
);
24+
25+
public static DiagnosticDescriptor InvalidNumberHandling { get; } = new(
26+
id: "PROTO003",
27+
title: "Invalid number handling for field {0} in class {1}",
28+
messageFormat: "Invalid number handling for field {0} in class {1}",
29+
category: "Usage",
30+
defaultSeverity: DiagnosticSeverity.Error,
31+
isEnabledByDefault: true
32+
);
33+
34+
public static DiagnosticDescriptor NestedTypeMustBeProtoPackable { get; } = new(
35+
id: "PROTO004",
36+
title: "Nested type {0} contained in {1} must be ProtoPackable",
37+
messageFormat: "Nested type {0} contained in {1} must be ProtoPackable",
38+
category: "Usage",
39+
defaultSeverity: DiagnosticSeverity.Error,
40+
isEnabledByDefault: true
41+
);
2442
}

Lagrange.Proto.Generator/Lagrange.Proto.Generator.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
<AnalyzerLanguage>cs</AnalyzerLanguage>
1414
<IncludeBuildOutput>false</IncludeBuildOutput>
1515
<DevelopmentDependency>true</DevelopmentDependency>
16+
<EnforceExtendedAnalyzerRules>true</EnforceExtendedAnalyzerRules>
1617
</PropertyGroup>
1718

1819
<ItemGroup>

Lagrange.Proto.Generator/ProtoSourceGenerator.Emitter.cs

Lines changed: 156 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ public void Emit(SourceProductionContext context)
2020
var classDeclaration = SF.ClassDeclaration(parser.Identifier)
2121
.AddModifiers(SF.Token(SK.PartialKeyword))
2222
.AddAttributeLists(SF.AttributeList().AddAttributes(EmitGeneratedCodeAttribute()))
23-
.AddMembers(EmitSerializeHandlerMethod());
23+
.AddMembers(EmitSerializeHandlerMethod(), EmitMeasureHandlerMethod())
24+
.AddBaseListTypes(SF.SimpleBaseType(SF.ParseName($"global::Lagrange.Proto.IProtoSerializer<{parser.Identifier}>")));
2425

2526
var namespaceDeclaration = SF.FileScopedNamespaceDeclaration(SF.ParseName(parser.Namespace ?? string.Empty))
2627
.AddMembers(classDeclaration);
@@ -41,15 +42,11 @@ public void Emit(SourceProductionContext context)
4142
string code = compilationUnit.NormalizeWhitespace().ToFullString();
4243
context.AddSource($"{parser.Identifier}.g.cs", code);
4344
}
44-
45+
46+
#region SerializeHandler
47+
4548
private MethodDeclarationSyntax EmitSerializeHandlerMethod()
4649
{
47-
string methodName = $"{parser.Identifier}SerializeHandler";
48-
string classFullName = $"global::{parser.Namespace}.{parser.Identifier}?";
49-
var parameters = SF.ParameterList()
50-
.AddParameters(SF.Parameter(SF.Identifier("obj")).WithType(SF.ParseTypeName(classFullName)))
51-
.AddParameters(SF.Parameter(SF.Identifier("writer")).WithType(SF.ParseTypeName(WriterFullName)));
52-
5350
var syntax = new List<StatementSyntax> { EmitNullableCheckStatement(true, "obj", SF.ReturnStatement()) };
5451

5552
foreach (var t in parser.Fields)
@@ -82,33 +79,40 @@ private MethodDeclarationSyntax EmitSerializeHandlerMethod()
8279
if (parser.Model.GetTypeSymbol(type).IsValueType && !type.IsNullableType())
8380
{
8481
syntax.Add(tag);
85-
syntax.Add(field.AddBlankLine());
82+
syntax.AddRange(field);
83+
syntax[syntax.Count - 1] = syntax[syntax.Count - 1].WithTrailingTrivia(SF.Comment("\n"));
8684
}
8785
else
8886
{
89-
var block = SF.Block(SF.List<StatementSyntax>([tag, field]));
87+
var block = SF.Block(SF.List<StatementSyntax>([tag, ..field]));
9088
syntax.Add(EmitNullableCheckStatement(false, $"obj.{name}", block, false));
9189
}
9290
}
9391

94-
return SF.MethodDeclaration(SF.PredefinedType(SF.Token(SK.VoidKeyword)), methodName)
95-
.AddModifiers(SF.Token(SK.PrivateKeyword), SF.Token(SK.StaticKeyword))
92+
string classFullName = $"global::{parser.Namespace}.{parser.Identifier}?";
93+
var parameters = SF.ParameterList()
94+
.AddParameters(SF.Parameter(SF.Identifier("obj")).WithType(SF.ParseTypeName(classFullName)))
95+
.AddParameters(SF.Parameter(SF.Identifier("writer")).WithType(SF.ParseTypeName(WriterFullName)));
96+
97+
return SF.MethodDeclaration(SF.PredefinedType(SF.Token(SK.VoidKeyword)), "SerializeHandler")
98+
.AddModifiers(SF.Token(SK.PublicKeyword), SF.Token(SK.StaticKeyword))
9699
.WithParameterList(parameters)
97100
.WithBody(SF.Block(SF.List(syntax)));
98101
}
99-
100-
private StatementSyntax EmitMemberStatement(WireType wireType, string identifier, TypeSyntax type, bool isSigned)
102+
103+
private StatementSyntax[] EmitMemberStatement(WireType wireType, string identifier, TypeSyntax type, bool isSigned)
101104
{
102-
bool isValueType = parser.Model.GetTypeSymbol(type).IsValueType;
103-
if (type.IsNullableType() && isValueType) identifier += ".Value";
105+
var symbol = parser.Model.GetTypeSymbol(type);
106+
if (type.IsNullableType() && symbol.IsValueType) identifier += ".Value";
104107

105108
return wireType switch
106109
{
107-
WireType.VarInt => EmitVarIntSerializeStatement(identifier, isSigned),
108-
WireType.Fixed32 => EmitFixed32SerializeStatement(identifier, isSigned),
109-
WireType.Fixed64 => EmitFixed64SerializeStatement(identifier, isSigned),
110-
WireType.LengthDelimited when type.IsStringType() => EmitStringSerializeStatement(identifier),
111-
WireType.LengthDelimited when type.IsByteArrayType() => EmitBytesSerializeStatement(identifier),
110+
WireType.VarInt => [EmitVarIntSerializeStatement(identifier, isSigned)],
111+
WireType.Fixed32 => [EmitFixed32SerializeStatement(identifier, isSigned)],
112+
WireType.Fixed64 => [EmitFixed64SerializeStatement(identifier, isSigned)],
113+
WireType.LengthDelimited when type.IsStringType() => [EmitStringSerializeStatement(identifier)],
114+
WireType.LengthDelimited when type.IsByteArrayType() => [EmitBytesSerializeStatement(identifier)],
115+
WireType.LengthDelimited when symbol.IsUserDefinedType() => EmitProtoPackableSerializeStatement(type.ToString(), identifier),
112116
_ => throw new Exception($"Unsupported wire type: {wireType} for {identifier}")
113117
};
114118
}
@@ -171,6 +175,19 @@ private static StatementSyntax EmitFixed64SerializeStatement(string name, bool i
171175
return SF.ExpressionStatement(SF.InvocationExpression(access).AddArgumentListArguments(SF.Argument(arg)));
172176
}
173177

178+
private static StatementSyntax[] EmitProtoPackableSerializeStatement(string typeName, string name)
179+
{
180+
var measure = SF.MemberAccessExpression(SK.SimpleMemberAccessExpression, SF.IdentifierName(typeName), SF.IdentifierName($"MeasureHandler"));
181+
var invocation = SF.InvocationExpression(measure).AddArgumentListArguments(SF.Argument(SF.MemberAccessExpression(SK.SimpleMemberAccessExpression, SF.IdentifierName("obj"), SF.IdentifierName(name))));
182+
var access = SF.MemberAccessExpression(SK.SimpleMemberAccessExpression, SF.IdentifierName("writer"), SF.IdentifierName("EncodeVarInt"));
183+
var serialize = SF.MemberAccessExpression(SK.SimpleMemberAccessExpression, SF.IdentifierName(typeName), SF.IdentifierName("SerializeHandler"));
184+
return
185+
[
186+
SF.ExpressionStatement(SF.InvocationExpression(access).AddArgumentListArguments(SF.Argument(invocation))),
187+
SF.ExpressionStatement(SF.InvocationExpression(serialize).AddArgumentListArguments(SF.Argument(SF.MemberAccessExpression(SK.SimpleMemberAccessExpression, SF.IdentifierName("obj"), SF.IdentifierName(name))), SF.Argument(SF.IdentifierName("writer"))))
188+
];
189+
}
190+
174191
private static StatementSyntax EmitBytesSerializeStatement(string name)
175192
{
176193
var arg = SF.MemberAccessExpression(SK.SimpleMemberAccessExpression, SF.IdentifierName("obj"), SF.IdentifierName(name));
@@ -195,5 +212,123 @@ private static AttributeSyntax EmitGeneratedCodeAttribute()
195212
SF.AttributeArgument(SF.LiteralExpression(SK.StringLiteralExpression, SF.Literal("Lagrange.Proto.Generator"))),
196213
SF.AttributeArgument(SF.LiteralExpression(SK.StringLiteralExpression, SF.Literal("1.0.0"))));
197214
}
215+
216+
#endregion
217+
218+
#region MeasureHandler
219+
220+
private MethodDeclarationSyntax EmitMeasureHandlerMethod()
221+
{
222+
ExpressionSyntax syntax;
223+
if (parser.Fields.Count == 0)
224+
{
225+
syntax = SF.LiteralExpression(SK.NumericLiteralExpression, SF.Literal(0));
226+
}
227+
else
228+
{
229+
int constant = 0;
230+
var expressions = new List<ExpressionSyntax>();
231+
232+
foreach (var kv in parser.Fields)
233+
{
234+
TypeSyntax type;
235+
string name;
236+
switch (kv.Value.Syntax)
237+
{
238+
case FieldDeclarationSyntax fieldDeclaration:
239+
{
240+
type = fieldDeclaration.Declaration.Type;
241+
name = fieldDeclaration.Declaration.Variables[0].Identifier.ToString();
242+
break;
243+
}
244+
case PropertyDeclarationSyntax propertyDeclaration:
245+
{
246+
type = propertyDeclaration.Type;
247+
name = propertyDeclaration.Identifier.ToString();
248+
break;
249+
}
250+
default:
251+
{
252+
throw new Exception($"Unsupported member type: {kv.Value.GetType()}");
253+
}
254+
}
255+
var symbol = parser.Model.GetTypeSymbol(type);
256+
string identifier = symbol.IsValueType && type.IsNullableType() ? name + ".Value" : name;
257+
258+
var expr = kv.Value.WireType switch
259+
{
260+
WireType.VarInt => EmitVarIntLengthExpression(identifier),
261+
WireType.Fixed32 => SF.LiteralExpression(SK.NumericLiteralExpression, SF.Literal(4)),
262+
WireType.Fixed64 => SF.LiteralExpression(SK.NumericLiteralExpression, SF.Literal(8)),
263+
WireType.LengthDelimited when type.IsStringType() => EmitStringLengthExpression(identifier),
264+
WireType.LengthDelimited when type.IsByteArrayType() => EmitBytesLengthExpression(identifier),
265+
WireType.LengthDelimited when symbol.IsUserDefinedType() => EmitProtoPackableLengthExpression(identifier),
266+
_ => throw new Exception($"Unsupported wire type: {kv.Value.WireType} for {type.ToString()}")
267+
};
268+
269+
if (symbol.IsValueType && !type.IsNullableType())
270+
{
271+
constant += ProtoHelper.EncodeVarInt((kv.Key << 3) | (byte)kv.Value.WireType).Length;
272+
}
273+
else // null check with obj.{identifier}
274+
{
275+
var tag = ProtoHelper.EncodeVarInt((kv.Key << 3) | (byte)kv.Value.WireType);
276+
var right = SF.BinaryExpression(SK.AddExpression, SF.LiteralExpression(SK.NumericLiteralExpression, SF.Literal(tag.Length)), expr);
277+
var left = SF.LiteralExpression(SK.NumericLiteralExpression, SF.Literal(0));
278+
expr = EmitNullableCheckExpression(name, left, right);
279+
}
280+
281+
expressions.Add(expr);
282+
}
283+
284+
syntax = SF.LiteralExpression(SK.NumericLiteralExpression, SF.Literal(constant));
285+
syntax = expressions.Aggregate(syntax, (current, expr) => SF.BinaryExpression(SK.AddExpression, current, expr));
286+
}
287+
288+
string classFullName = $"global::{parser.Namespace}.{parser.Identifier}";
289+
var parameters = SF.ParameterList()
290+
.AddParameters(SF.Parameter(SF.Identifier("obj")).WithType(SF.ParseTypeName(classFullName)));
291+
292+
return SF.MethodDeclaration(SF.PredefinedType(SF.Token(SK.IntKeyword)), "MeasureHandler")
293+
.AddModifiers(SF.Token(SK.PublicKeyword), SF.Token(SK.StaticKeyword))
294+
.WithParameterList(parameters)
295+
.WithBody(SF.Block(SF.ReturnStatement(syntax)));
296+
}
297+
298+
private static ExpressionSyntax EmitVarIntLengthExpression(string identifier)
299+
{
300+
var obj = SF.MemberAccessExpression(SK.SimpleMemberAccessExpression, SF.IdentifierName("obj"), SF.IdentifierName(identifier));
301+
var access = SF.MemberAccessExpression(SK.SimpleMemberAccessExpression, SF.IdentifierName("global::Lagrange.Proto.Utility.ProtoHelper"), SF.IdentifierName("GetVarIntLength"));
302+
return SF.InvocationExpression(access).AddArgumentListArguments(SF.Argument(obj));
303+
}
304+
305+
private static ExpressionSyntax EmitStringLengthExpression(string identifier)
306+
{
307+
var obj = SF.MemberAccessExpression(SK.SimpleMemberAccessExpression, SF.IdentifierName("obj"), SF.IdentifierName(identifier));
308+
var access = SF.MemberAccessExpression(SK.SimpleMemberAccessExpression, SF.IdentifierName("global::Lagrange.Proto.Utility.ProtoHelper"), SF.IdentifierName("CountString"));
309+
return SF.InvocationExpression(access).AddArgumentListArguments(SF.Argument(obj));
310+
}
311+
312+
private static ExpressionSyntax EmitBytesLengthExpression(string identifier)
313+
{
314+
var obj = SF.MemberAccessExpression(SK.SimpleMemberAccessExpression, SF.IdentifierName("obj"), SF.IdentifierName(identifier));
315+
var access = SF.MemberAccessExpression(SK.SimpleMemberAccessExpression, SF.IdentifierName("global::Lagrange.Proto.Utility.ProtoHelper"), SF.IdentifierName("CountBytes"));
316+
return SF.InvocationExpression(access).AddArgumentListArguments(SF.Argument(obj));
317+
}
318+
319+
private static ExpressionSyntax EmitProtoPackableLengthExpression(string name)
320+
{
321+
var obj = SF.MemberAccessExpression(SK.SimpleMemberAccessExpression, SF.IdentifierName("obj"), SF.IdentifierName(name));
322+
var access = SF.MemberAccessExpression(SK.SimpleMemberAccessExpression, SF.IdentifierName("global::Lagrange.Proto.Utility.ProtoHelper"), SF.IdentifierName("CountProtoPackable"));
323+
return SF.InvocationExpression(access).AddArgumentListArguments(SF.Argument(obj));
324+
}
325+
326+
/// <summary>
327+
/// (obj == null ? left : right)
328+
/// </summary>
329+
private static ExpressionSyntax EmitNullableCheckExpression(string identifier, ExpressionSyntax left, ExpressionSyntax right) => SF.ParenthesizedExpression(
330+
SF.ConditionalExpression(SF.BinaryExpression(SK.EqualsExpression, SF.MemberAccessExpression(SK.SimpleMemberAccessExpression, SF.IdentifierName("obj"), SF.IdentifierName(identifier)), SF.LiteralExpression(SK.NullLiteralExpression)), left, right)
331+
);
332+
#endregion
198333
}
199334
}

Lagrange.Proto.Generator/ProtoSourceGenerator.Parser.cs

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,14 @@
44
using Lagrange.Proto.Serialization;
55
using Microsoft.CodeAnalysis;
66
using Microsoft.CodeAnalysis.CSharp.Syntax;
7+
using static Lagrange.Proto.Generator.DiagnosticDescriptors;
78

89
namespace Lagrange.Proto.Generator;
910

1011
public partial class ProtoSourceGenerator
1112
{
13+
private const string ProtoPackableAttributeFullName = "Lagrange.Proto.ProtoPackableAttribute";
14+
1215
private class Parser(ClassDeclarationSyntax context, SemanticModel model)
1316
{
1417
public SemanticModel Model { get; } = model;
@@ -28,7 +31,7 @@ public void Parse(CancellationToken token = default)
2831

2932
if (!context.IsPartial())
3033
{
31-
ReportDiagnostics(DiagnosticDescriptors.MustBePartialClass, context.GetLocation(), context.Identifier.Text);
34+
ReportDiagnostics(MustBePartialClass, context.GetLocation(), context.Identifier.Text);
3235
return;
3336
}
3437

@@ -50,15 +53,37 @@ public void Parse(CancellationToken token = default)
5053
PropertyDeclarationSyntax propertyDeclaration => propertyDeclaration.Type,
5154
_ => throw new InvalidOperationException("Unsupported member type.")
5255
};
56+
var typeSymbol = symbol switch
57+
{
58+
IPropertySymbol propertySymbol => propertySymbol.Type,
59+
IFieldSymbol fieldSymbol => fieldSymbol.Type,
60+
_ => throw new InvalidOperationException("Unsupported member type.")
61+
};
5362
var wireType = type.GetWireType();
5463
bool signed = false;
64+
65+
if (wireType == WireType.LengthDelimited && typeSymbol.IsUserDefinedType())
66+
{
67+
var typeAttribute = typeSymbol.GetAttributes().FirstOrDefault(x => x.AttributeClass?.ToDisplayString() == ProtoPackableAttributeFullName);
68+
if (typeAttribute == null)
69+
{
70+
ReportDiagnostics(NestedTypeMustBeProtoPackable, member.GetLocation(), typeSymbol.Name, Identifier);
71+
continue;
72+
}
73+
}
5574

5675
foreach (var argument in attribute.NamedArguments)
5776
{
5877
switch (argument.Key)
5978
{
6079
case "NumberHandling":
6180
{
81+
if (wireType != WireType.VarInt)
82+
{
83+
ReportDiagnostics(InvalidNumberHandling, member.GetLocation(), field, Identifier);
84+
continue;
85+
}
86+
6287
var value = (ProtoNumberHandling)(argument.Value.Value ?? throw new InvalidOperationException("Unable to get number handling."));
6388
if (value.HasFlag(ProtoNumberHandling.Signed)) signed = true;
6489
if (value.HasFlag(ProtoNumberHandling.Fixed32)) wireType = WireType.Fixed32;
@@ -70,7 +95,7 @@ public void Parse(CancellationToken token = default)
7095

7196
if (Fields.ContainsKey(field))
7297
{
73-
ReportDiagnostics(DiagnosticDescriptors.DuplicateFieldNumber, member.GetLocation(), field, Identifier);
98+
ReportDiagnostics(DuplicateFieldNumber, member.GetLocation(), field, Identifier);
7499
continue;
75100
}
76101

Lagrange.Proto.Generator/Utility/Extension/SyntaxExtension.cs

Lines changed: 0 additions & 13 deletions
This file was deleted.

Lagrange.Proto.Generator/Utility/Extension/TypeExtension.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,18 @@ namespace Lagrange.Proto.Generator.Utility.Extension;
66

77
public static class TypeExtension
88
{
9+
private static readonly string[] SystemAssemblies = ["mscorlib", "System", "System.Core", "System.Private.CoreLib", "System.Runtime"];
10+
11+
public static bool IsUserDefinedType(this ITypeSymbol type)
12+
{
13+
return type.TypeKind is TypeKind.Class or TypeKind.Struct or TypeKind.Enum && !type.ContainingAssembly.IsSystemAssembly();
14+
}
15+
16+
private static bool IsSystemAssembly(this IAssemblySymbol assembly)
17+
{
18+
return SystemAssemblies.Contains(assembly.Name);
19+
}
20+
921
public static bool IsNumberType(this TypeSyntax type)
1022
{
1123
return type switch

0 commit comments

Comments
 (0)