From ea35217a619a774807d9d9e4e1aafb05bc0896cb Mon Sep 17 00:00:00 2001 From: Daniel Cazzulino Date: Fri, 6 Dec 2024 19:37:10 -0300 Subject: [PATCH] Improve usage of common namespace, more flexible TId processing Allows an indirection in templates to that the value type can be a file local type that can be used to declare the interface to expose. Partial implementation, still not working with scenarios like IComparable. --- src/Sample/Common/Product.cs | 2 +- src/StructId.Analyzer/AnalysisExtensions.cs | 54 ++++++- src/StructId.Analyzer/BaseGenerator.cs | 3 +- .../NewtonsoftJsonGenerator.cs | 3 +- src/StructId.Analyzer/TemplatedGenerator.cs | 151 ++++++++++++++---- src/StructId.FunctionalTests/Functional.cs | 3 +- src/StructId/Templates/SpanFormattable.cs | 5 +- 7 files changed, 177 insertions(+), 44 deletions(-) diff --git a/src/Sample/Common/Product.cs b/src/Sample/Common/Product.cs index cc7ffc4..18ba1e5 100644 --- a/src/Sample/Common/Product.cs +++ b/src/Sample/Common/Product.cs @@ -2,4 +2,4 @@ public record Product(ProductId Id, string Name); -public readonly partial record struct ProductId : IStructId; \ No newline at end of file +public readonly partial record struct ProductId : IStructId; diff --git a/src/StructId.Analyzer/AnalysisExtensions.cs b/src/StructId.Analyzer/AnalysisExtensions.cs index a9fbf21..5f502d7 100644 --- a/src/StructId.Analyzer/AnalysisExtensions.cs +++ b/src/StructId.Analyzer/AnalysisExtensions.cs @@ -10,12 +10,20 @@ namespace StructId; public static class AnalysisExtensions { - public static SymbolDisplayFormat FullName { get; } = new( + public static SymbolDisplayFormat TypeName { get; } = new( + typeQualificationStyle: SymbolDisplayTypeQualificationStyle.NameOnly, + genericsOptions: SymbolDisplayGenericsOptions.IncludeTypeParameters); + + public static SymbolDisplayFormat NamespacedTypeName { get; } = new( + typeQualificationStyle: SymbolDisplayTypeQualificationStyle.NameAndContainingTypesAndNamespaces, + genericsOptions: SymbolDisplayGenericsOptions.IncludeTypeParameters); + + public static SymbolDisplayFormat FullNameNullable { get; } = new( typeQualificationStyle: SymbolDisplayTypeQualificationStyle.NameAndContainingTypesAndNamespaces, genericsOptions: SymbolDisplayGenericsOptions.IncludeTypeParameters, miscellaneousOptions: SymbolDisplayMiscellaneousOptions.ExpandNullable); - public static string ToFullName(this ISymbol symbol) => symbol.ToDisplayString(FullName); + public static string ToFullName(this ISymbol symbol) => symbol.ToDisplayString(FullNameNullable); /// /// Checks whether the type inherits or implements the @@ -50,6 +58,44 @@ @this is INamedTypeSymbol namedActual && public static string GetStructIdNamespace(this AnalyzerConfigOptions options) => options.TryGetValue("build_property.StructIdNamespace", out var ns) && !string.IsNullOrEmpty(ns) ? ns : "StructId"; + public static IncrementalValueProvider GetStructIdNamespace(this IncrementalValueProvider options) + => options.Select((x, _) => x.GlobalOptions.TryGetValue("build_property.StructIdNamespace", out var ns) ? ns : "StructId"); + + public static bool ImplementsExplicitly(this INamedTypeSymbol namedTypeSymbol, INamedTypeSymbol interfaceTypeSymbol) + { + if (interfaceTypeSymbol.IsUnboundGenericType && interfaceTypeSymbol.TypeParameters.Length == 1) + { + try + { + interfaceTypeSymbol = interfaceTypeSymbol.ConstructedFrom.Construct(namedTypeSymbol); + } + catch { } + } + + foreach (var interfaceMember in interfaceTypeSymbol.GetMembers()) + { + foreach (var classMember in namedTypeSymbol.GetMembers()) + { + switch (classMember) + { + case IMethodSymbol methodSymbol: + if (methodSymbol.ExplicitInterfaceImplementations.Contains(interfaceMember, SymbolEqualityComparer.Default)) + return true; + break; + case IPropertySymbol propertySymbol: + if (propertySymbol.ExplicitInterfaceImplementations.Contains(interfaceMember, SymbolEqualityComparer.Default)) + return true; + break; + case IEventSymbol eventSymbol: + if (eventSymbol.ExplicitInterfaceImplementations.Contains(interfaceMember, SymbolEqualityComparer.Default)) + return true; + break; + } + } + } + return false; + } + public static IEnumerable GetAllTypes(this Compilation compilation, bool includeReferenced = true) => compilation.Assembly .GetAllTypes() @@ -82,7 +128,7 @@ static IEnumerable GetAllTypes(INamespaceSymbol namespaceSymbo public static string GetTypeName(this ITypeSymbol type, string? containingNamespace) { - var typeName = type.ToDisplayString(FullName); + var typeName = type.ToDisplayString(FullNameNullable); if (containingNamespace == null) return typeName; @@ -92,7 +138,7 @@ public static string GetTypeName(this ITypeSymbol type, string? containingNamesp return typeName; } - public static string ToFileName(this ITypeSymbol type) => type.ToDisplayString(FullName).Replace('+', '.'); + public static string ToFileName(this ITypeSymbol type) => type.ToDisplayString(FullNameNullable).Replace('+', '.'); public static bool IsStructId(this ITypeSymbol type) => type.AllInterfaces.Any(x => x.Name == "IStructId"); diff --git a/src/StructId.Analyzer/BaseGenerator.cs b/src/StructId.Analyzer/BaseGenerator.cs index bfa6580..a933af7 100644 --- a/src/StructId.Analyzer/BaseGenerator.cs +++ b/src/StructId.Analyzer/BaseGenerator.cs @@ -26,8 +26,7 @@ protected record struct TemplateArgs(string StructIdNamespace, INamedTypeSymbol public virtual void Initialize(IncrementalGeneratorInitializationContext context) { - var targetNamespace = context.AnalyzerConfigOptionsProvider - .Select((x, _) => x.GlobalOptions.TryGetValue("build_property.StructIdNamespace", out var ns) ? ns : "StructId"); + var targetNamespace = context.AnalyzerConfigOptionsProvider.GetStructIdNamespace(); // Locate the required types var types = context.CompilationProvider diff --git a/src/StructId.Analyzer/NewtonsoftJsonGenerator.cs b/src/StructId.Analyzer/NewtonsoftJsonGenerator.cs index c2075d8..a3273c8 100644 --- a/src/StructId.Analyzer/NewtonsoftJsonGenerator.cs +++ b/src/StructId.Analyzer/NewtonsoftJsonGenerator.cs @@ -18,8 +18,7 @@ public override void Initialize(IncrementalGeneratorInitializationContext contex context.RegisterSourceOutput( context.CompilationProvider .Select((x, _) => x.GetTypeByMetadataName("Newtonsoft.Json.JsonConverter`1")) - .Combine(context.AnalyzerConfigOptionsProvider - .Select((x, _) => x.GlobalOptions.TryGetValue("build_property.StructIdNamespace", out var ns) ? ns : "StructId")), + .Combine(context.AnalyzerConfigOptionsProvider.GetStructIdNamespace()), (context, source) => { if (source.Left == null) diff --git a/src/StructId.Analyzer/TemplatedGenerator.cs b/src/StructId.Analyzer/TemplatedGenerator.cs index 4572463..5b41599 100644 --- a/src/StructId.Analyzer/TemplatedGenerator.cs +++ b/src/StructId.Analyzer/TemplatedGenerator.cs @@ -1,4 +1,7 @@ +using System; +using System.Collections.Concurrent; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; using System.Text; using System.Text.RegularExpressions; @@ -13,21 +16,70 @@ namespace StructId; [Generator(LanguageNames.CSharp)] public class TemplatedGenerator : IIncrementalGenerator { - record KnownTypes(string StructIdNamespace, INamedTypeSymbol String, INamedTypeSymbol? IStructId, INamedTypeSymbol? IStructIdT, INamedTypeSymbol? TStructId, INamedTypeSymbol? TStructIdT); - record IdTemplate(INamedTypeSymbol StructId, Template Template); - record Template(INamedTypeSymbol TSelf, ITypeSymbol TId, AttributeData Attribute, string StructIdNamespace, bool IsGenericTId) + /// + /// Provides access to some common types and properties used in the compilation. + /// + /// The compilation used to resolve the known types. + /// The namespace for StructId types. + record KnownTypes(Compilation Compilation, string StructIdNamespace) { + /// + /// System.String + /// + public INamedTypeSymbol String { get; } = Compilation.GetTypeByMetadataName("System.String")!; + /// + /// StructId.IStructId + /// + public INamedTypeSymbol? IStructId { get; } = Compilation.GetTypeByMetadataName($"{StructIdNamespace}.IStructId"); + /// + /// StructId.IStructId{T} + /// + public INamedTypeSymbol? IStructIdT { get; } = Compilation.GetTypeByMetadataName($"{StructIdNamespace}.IStructId`1"); + /// + /// StructId.TStructIdAttribute + /// + public INamedTypeSymbol? TStructId { get; } = Compilation.GetTypeByMetadataName($"{StructIdNamespace}.TStructIdAttribute"); + /// + /// StructId.TStructIdAttribute{T} + /// + public INamedTypeSymbol? TStructIdT { get; } = Compilation.GetTypeByMetadataName($"{StructIdNamespace}.TStructIdAttribute`1"); + } + + /// + /// Represents a template for struct ids. + /// + /// The struct id type, either IStructId or IStructId{T}. + /// The type of value the struct id holds, such as Guid or string. + /// The template to apply to it. + record IdTemplate(INamedTypeSymbol StructId, INamedTypeSymbol TId, Template Template); + + record Template(INamedTypeSymbol TSelf, INamedTypeSymbol TId, AttributeData Attribute, KnownTypes KnownTypes) + { + ConcurrentDictionary text = new(SymbolEqualityComparer.Default); + + public INamedTypeSymbol? CustomTId { get; init; } + + // A custom TId is a file-local type declaration. + public bool IsCustomTId => CustomTId?.DeclaringSyntaxReferences + .All(x => x.GetSyntax() is TypeDeclarationSyntax decl && decl.Modifiers.Any(m => m.IsKind(SyntaxKind.FileKeyword))) == true; + public Regex NameExpr { get; } = new Regex($@"\b{TSelf.Name}\b", RegexOptions.Compiled | RegexOptions.Multiline); - public string Text { get; } = GetTemplateCode(TSelf, TId, Attribute, StructIdNamespace); + public string GetText(INamedTypeSymbol tid) => text.GetOrAdd(tid, tid + => GetTemplateCode(TSelf, TId, CustomTId, tid, Attribute, KnownTypes)); - static string GetTemplateCode(INamedTypeSymbol self, ITypeSymbol tid, AttributeData attribute, string StructIdNamespace) + static string GetTemplateCode(INamedTypeSymbol self, + INamedTypeSymbol templateIdType, INamedTypeSymbol? customIdType, INamedTypeSymbol idTypeInstance, + AttributeData attribute, KnownTypes known) { if (self.DeclaringSyntaxReferences[0].GetSyntax() is not TypeDeclarationSyntax declaration) return ""; // Remove the TId/TValue if present in the same syntax tree. - var toremove = tid.DeclaringSyntaxReferences.Select(x => x.GetSyntax()).ToList(); + var toremove = templateIdType.DeclaringSyntaxReferences.Select(x => x.GetSyntax()).ToList(); + if (customIdType != null) + toremove.AddRange(customIdType.DeclaringSyntaxReferences.Select(x => x.GetSyntax())); + // Also the [TStructId] attribute applied to the template itself if (attribute.ApplicationSyntaxReference?.GetSyntax().FirstAncestorOrSelf() is { } attr) toremove.Add(attr); @@ -35,7 +87,8 @@ static string GetTemplateCode(INamedTypeSymbol self, ITypeSymbol tid, AttributeD if (declaration.ParameterList != null) toremove.Add(declaration.ParameterList); - var root = declaration.SyntaxTree.GetRoot() + var root = declaration.SyntaxTree + .GetRoot() .RemoveNodes(toremove, SyntaxRemoveOptions.KeepLeadingTrivia)!; var update = root.DescendantNodes().OfType().First(x => x.Identifier.Text == self.Name); @@ -57,9 +110,9 @@ static string GetTemplateCode(INamedTypeSymbol self, ITypeSymbol tid, AttributeD var nsname = ns?.Name.ToString(); if (nsname == "StructId") - root = root.ReplaceNode(ns!, ns!.WithName(ParseName(StructIdNamespace))); - else if (nsname != StructIdNamespace) - usings.Add(UsingDirective(ParseName(StructIdNamespace)).NormalizeWhitespace()); + root = root.ReplaceNode(ns!, ns!.WithName(ParseName(known.StructIdNamespace))); + else if (nsname != known.StructIdNamespace) + usings.Add(UsingDirective(ParseName(known.StructIdNamespace)).NormalizeWhitespace()); // deduplicate usings just in case var unique = new HashSet(); @@ -68,8 +121,8 @@ static string GetTemplateCode(INamedTypeSymbol self, ITypeSymbol tid, AttributeD // replace 'StructId' > StructIdNamespace if (old.Name?.ToString() == "StructId") { - unique.Add(StructIdNamespace); - return old.WithName(ParseName(StructIdNamespace)); + unique.Add(known.StructIdNamespace); + return old.WithName(ParseName(known.StructIdNamespace)); } if (unique.Add(old.Name?.ToString() ?? "")) @@ -78,8 +131,29 @@ static string GetTemplateCode(INamedTypeSymbol self, ITypeSymbol tid, AttributeD return null!; }); + if (idTypeInstance.Equals(templateIdType, SymbolEqualityComparer.Default)) + return root.SyntaxTree.GetRoot().ToFullString().Trim(); + + if (!idTypeInstance.ImplementsExplicitly(templateIdType)) + return root.SyntaxTree.GetRoot().ToFullString().Trim(); + // rewrite Value references to explicit casts just in case the // target type is implemented explicitly. + + var tid = templateIdType; + if (tid.IsUnboundGenericType && tid.TypeParameters.Length == 1) + { + try + { + // bind to namedTypeSymbol + tid = tid.ConstructedFrom.Construct(idTypeInstance); + } + catch (Exception ex) + { + Debug.WriteLine(ex.ToString()); + } + } + root = new ValueRewriter(tid).Visit(root); var code = root.SyntaxTree.GetRoot().ToFullString().Trim(); @@ -88,8 +162,6 @@ static string GetTemplateCode(INamedTypeSymbol self, ITypeSymbol tid, AttributeD } } - // create a syntax rewriter that replaces references to the Value property with an explicit - // cast of that property to a given INamedTypeSymbol class ValueRewriter(ITypeSymbol idType) : CSharpSyntaxRewriter { public override SyntaxNode? VisitMemberAccessExpression(MemberAccessExpressionSyntax node) @@ -108,19 +180,11 @@ class ValueRewriter(ITypeSymbol idType) : CSharpSyntaxRewriter public void Initialize(IncrementalGeneratorInitializationContext context) { - var structIdNamespace = context.AnalyzerConfigOptionsProvider - .Select((x, _) => x.GlobalOptions.TryGetValue("build_property.StructIdNamespace", out var ns) ? ns : "StructId"); + var structIdNamespace = context.AnalyzerConfigOptionsProvider.GetStructIdNamespace(); var known = context.CompilationProvider .Combine(structIdNamespace) - .Select((x, _) => new KnownTypes( - x.Right, - // get string known type - x.Left.GetTypeByMetadataName("System.String")!, - x.Left.GetTypeByMetadataName($"{x.Right}.IStructId"), - x.Left.GetTypeByMetadataName($"{x.Right}.IStructId`1"), - x.Left.GetTypeByMetadataName($"{x.Right}.TStructIdAttribute"), - x.Left.GetTypeByMetadataName($"{x.Right}.TStructIdAttribute`1"))); + .Select((x, _) => new KnownTypes(x.Left, x.Right)); var templates = context.CompilationProvider .SelectMany((x, _) => x.GetAllTypes(includeReferenced: true).OfType()) @@ -140,19 +204,40 @@ public void Initialize(IncrementalGeneratorInitializationContext context) x.Left.GetAttributes().Any(a => a.AttributeClass != null && // The attribute should either be the generic or regular TStructIdAttribute (a.AttributeClass.Is(x.Right.TStructId) || a.AttributeClass.Is(x.Right.TStructIdT)))) - .Select((x, _) => + .Select((x, cancellation) => { var (structId, known) = x; var attribute = structId.GetAttributes().FirstOrDefault(a => a.AttributeClass != null && a.AttributeClass.Is(known.TStructIdT)); - if (attribute != null) - return new Template(structId, attribute.AttributeClass!.TypeArguments[0], attribute, known.StructIdNamespace, true); + if (attribute != null && attribute.AttributeClass!.TypeArguments[0] is INamedTypeSymbol attrType) + return new Template(structId, attrType, attribute, known); // If we don't have the generic attribute, infer the idType from the required // primary constructor Value parameter type - var idType = structId.GetMembers().OfType().First(p => p.Name == "Value").Type; + var idType = (INamedTypeSymbol)structId.GetMembers().OfType().First(p => p.Name == "Value").Type; attribute = structId.GetAttributes().First(a => a.AttributeClass != null && a.AttributeClass.Is(known.TStructId)); - return new Template(structId, idType, attribute, known.StructIdNamespace, false); + // Otherwise, if idType is a local type, this should fail (but an analyzer will take care of that). + if (idType.DeclaringSyntaxReferences.Length == 0) + return new Template(structId, idType, attribute, known); + + // otherwise, the idType is a file-local type with a single interface + var type = idType.DeclaringSyntaxReferences[0].GetSyntax(cancellation) as TypeDeclarationSyntax; + var iface = type?.BaseList?.Types.FirstOrDefault()?.Type; + if (type == null || iface == null) + return new Template(structId, idType, attribute, known); + + if (x.Right.Compilation.GetSemanticModel(type.SyntaxTree).GetSymbolInfo(iface).Symbol is not INamedTypeSymbol ifaceType) + return new Template(structId, idType, attribute, known); + + // if the interface is a generic type with a single type argument that is the same as the idType + // make it an unbound generic type. We'll bind it to the actual idType later at template render time. + if (ifaceType.IsGenericType && ifaceType.TypeArguments.Length == 1 && ifaceType.TypeArguments[0].Equals(idType, SymbolEqualityComparer.Default)) + ifaceType = ifaceType.ConstructUnboundGenericType(); + + return new Template(structId, ifaceType, attribute, known) + { + CustomTId = idType + }; }) .Collect(); @@ -173,7 +258,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) var ((id, known), templates) = x; // Locate the IStructId interface implemented by the id var structId = id.AllInterfaces.First(i => i.Is(known.IStructId) || i.Is(known.IStructIdT)); - var tid = structId.IsGenericType ? structId.TypeArguments[0] : known.String; + var tid = structId.IsGenericType ? (INamedTypeSymbol)structId.TypeArguments[0] : known.String; // If the TId/Value implements or inherits from the template base type and/or its interfaces return templates // check struct id's value type against the template's TId for compatibility @@ -185,13 +270,13 @@ public void Initialize(IncrementalGeneratorInitializationContext context) // the struct id's value type, such as implementing multiple interfaces. In // this case, the tid would never equal or inherit from the template's TId, // but we want instead to check for base type compatibility plus all interfaces. - (template.IsGenericTId && + (template.IsCustomTId && // TId is a derived class of the template's TId base type (i.e. object or ValueType) tid.Is(template.TId.BaseType) && // All template provided TId interfaces must be implemented by the struct id's TId template.TId.AllInterfaces.All(iface => tid.AllInterfaces.Any(tface => tface.Is(iface))))) - .Select(template => new IdTemplate(id, template)); + .Select(template => new IdTemplate(id, tid, template)); }); context.RegisterSourceOutput(ids, GenerateCode); @@ -200,7 +285,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) void GenerateCode(SourceProductionContext context, IdTemplate source) { var hintName = $"{source.StructId.ToFileName()}-{source.Template.TSelf.Name}.cs"; - var output = source.Template.NameExpr.Replace(source.Template.Text, source.StructId.Name); + var output = source.Template.NameExpr.Replace(source.Template.GetText(source.TId), source.StructId.Name); if (source.StructId.ContainingNamespace.Equals(source.StructId.ContainingModule.GlobalNamespace, SymbolEqualityComparer.Default)) { diff --git a/src/StructId.FunctionalTests/Functional.cs b/src/StructId.FunctionalTests/Functional.cs index 943c874..bbfbc25 100644 --- a/src/StructId.FunctionalTests/Functional.cs +++ b/src/StructId.FunctionalTests/Functional.cs @@ -1,4 +1,5 @@ -using Dapper; +using System.Diagnostics.CodeAnalysis; +using Dapper; using Microsoft.Data.Sqlite; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyModel; diff --git a/src/StructId/Templates/SpanFormattable.cs b/src/StructId/Templates/SpanFormattable.cs index 4eb1876..31838be 100644 --- a/src/StructId/Templates/SpanFormattable.cs +++ b/src/StructId/Templates/SpanFormattable.cs @@ -1,4 +1,7 @@ -using System; +// +#nullable enable + +using System; using StructId; [TStructId]