From 4ae1bf08d05266c1d0f8060ca66eb066749ba462 Mon Sep 17 00:00:00 2001 From: Dima Enns Date: Fri, 26 Jan 2024 11:14:01 +0100 Subject: [PATCH] Implement generation of primary constructors for nested types --- PrimaryConstructor.Sample/Program.cs | 12 ++++ .../PrimaryConstructorGenerator.cs | 59 ++++++++++++++++--- 2 files changed, 63 insertions(+), 8 deletions(-) diff --git a/PrimaryConstructor.Sample/Program.cs b/PrimaryConstructor.Sample/Program.cs index 4e5911b..dd503c4 100644 --- a/PrimaryConstructor.Sample/Program.cs +++ b/PrimaryConstructor.Sample/Program.cs @@ -87,4 +87,16 @@ public string GetName() public class NotRegisteredDependency { } + + public partial class NestingGrandMother + { + public partial class NestingFather + { + [PrimaryConstructor] + public partial class Nested + { + private readonly int _dummy; + } + } + } } \ No newline at end of file diff --git a/PrimaryConstructor/PrimaryConstructorGenerator.cs b/PrimaryConstructor/PrimaryConstructorGenerator.cs index b6faa2f..6bec83d 100644 --- a/PrimaryConstructor/PrimaryConstructorGenerator.cs +++ b/PrimaryConstructor/PrimaryConstructorGenerator.cs @@ -37,8 +37,15 @@ public void Execute(GeneratorExecutionContext context) classNames.TryGetValue(classSymbol.Name, out var i); var name = i == 0 ? classSymbol.Name : $"{classSymbol.Name}{i + 1}"; classNames[classSymbol.Name] = i + 1; - context.AddSource($"{name}.PrimaryConstructor.g.cs", - SourceText.From(CreatePrimaryConstructor(classSymbol), Encoding.UTF8)); + + var source = CSharpSyntaxTree + .ParseText(SourceText.From(CreatePrimaryConstructor(classSymbol), Encoding.UTF8)) + .GetRoot() + .NormalizeWhitespace() + .SyntaxTree + .GetText(); + + context.AddSource($"{name}.PrimaryConstructor.g.cs", source); } } @@ -89,12 +96,21 @@ private static string CreatePrimaryConstructor(INamedTypeSymbol classSymbol) var memberList = GetMembers(classSymbol, false); var arguments = (baseClassConstructorArgs == null ? memberList : memberList.Concat(baseClassConstructorArgs)) .Select(it => $"{it.Type} {it.ParameterName}"); - var fullTypeName = classSymbol.ToDisplayString(TypeFormat); - var i = fullTypeName.IndexOf('<'); - var generic = i < 0 ? "" : fullTypeName.Substring(i); + var nestingStack = GetNestingAncestors(classSymbol); + var nestingCount = nestingStack.Count; var source = new StringBuilder($@"namespace {namespaceName} -{{ - partial class {classSymbol.Name}{generic} +{{"); + while (nestingStack.Any()) + { + var currentNestingAncestor = nestingStack.Pop(); + + source.Append($@" + partial class {currentNestingAncestor.Name}{GetGenericsNamePart(currentNestingAncestor)} + {{"); + } + + source.Append($@" + partial class {classSymbol.Name}{GetGenericsNamePart(classSymbol)} {{ public {classSymbol.Name}({string.Join(", ", arguments)}){baseConstructorInheritance} {{"); @@ -104,9 +120,18 @@ partial class {classSymbol.Name}{generic} source.Append($@" this.{item.Name} = {item.ParameterName};"); } + source.Append(@" } - } + }"); + + for (int i = 0; i < nestingCount; i++) + { + source.Append(@" + }"); + } + + source.Append(@" } "); @@ -177,5 +202,23 @@ select model.GetDeclaredSymbol(clazz)! into classSymbol where HasAttribute(classSymbol, nameof(PrimaryConstructorAttribute)) select classSymbol; } + + private static Stack GetNestingAncestors(INamedTypeSymbol classSymbol) + { + var stack = new Stack(); + var current = classSymbol.ContainingType; + while (current is not null) + { + stack.Push(current); + current = current.ContainingType; + } + + return stack; + } + + private static string GetGenericsNamePart(INamedTypeSymbol classSymbol) => + classSymbol.TypeArguments.Any() + ? $"<{string.Join(", ", classSymbol.TypeArguments.Select(s => s.ToDisplayString(TypeFormat)))}>" + : ""; } }