diff --git a/src/Zomp.SyncMethodGenerator/AsyncToSyncRewriter.cs b/src/Zomp.SyncMethodGenerator/AsyncToSyncRewriter.cs index fb8731c..5eab43a 100644 --- a/src/Zomp.SyncMethodGenerator/AsyncToSyncRewriter.cs +++ b/src/Zomp.SyncMethodGenerator/AsyncToSyncRewriter.cs @@ -9,7 +9,8 @@ namespace Zomp.SyncMethodGenerator; /// Creates a new instance of . /// /// The semantic model. -internal sealed class AsyncToSyncRewriter(SemanticModel semanticModel) : CSharpSyntaxRewriter +/// The type of collection that should be used to replace IAsyncEnumerable. +internal sealed class AsyncToSyncRewriter(SemanticModel semanticModel, Dictionary replacementOverrides) : CSharpSyntaxRewriter { public const string SyncOnly = "SYNC_ONLY"; private const string SystemObject = "object"; @@ -58,6 +59,7 @@ internal sealed class AsyncToSyncRewriter(SemanticModel semanticModel) : CSharpS private readonly HashSet memoryToSpan = []; private readonly Dictionary renamedLocalFunctions = []; private readonly ImmutableArray.Builder diagnostics = ImmutableArray.CreateBuilder(); + private Dictionary parametersWithTypes = []; private enum SyncOnlyDirectiveType { @@ -213,7 +215,12 @@ static BinaryExpressionSyntax CheckNull(ExpressionSyntax expr) => BinaryExpressi var genericName = GetNameWithoutTypeParams(symbol); string? newType = null; - if (Replacements.TryGetValue(genericName, out var replacement)) + + //if (@base.Att) + // { + // } + if (replacementOverrides.TryGetValue(genericName, out var replacement) || + Replacements.TryGetValue(genericName, out replacement)) { if (replacement is not null) { @@ -867,10 +874,11 @@ bool ShouldRemoveAttribute(AttributeSyntax attributeSyntax) var attributeContainingTypeSymbol = attributeSymbol.ContainingType; // Is the attribute [CreateSyncVersion] attribute? - return IsCreateSyncVersionAttribute(attributeContainingTypeSymbol); + return IsCreateSyncVersionAttribute(attributeContainingTypeSymbol) || IsReplaceWithAttribute(attributeContainingTypeSymbol); } var @base = (AttributeListSyntax)base.VisitAttributeList(node)!; + node.Attributes.ToList().ForEach(a => AddParameterTypes(a)); var indices = node.Attributes.GetIndices((a, _) => ShouldRemoveAttribute(a)); var newList = RemoveAtRange(@base.Attributes, indices); return @base.WithAttributes(newList); @@ -878,6 +886,7 @@ bool ShouldRemoveAttribute(AttributeSyntax attributeSyntax) public override SyntaxNode? VisitAttribute(AttributeSyntax node) { + parametersWithTypes.ContainsKey("k"); var @base = (AttributeSyntax)base.VisitAttribute(node)!; if (GetSymbol(node.Name) is not IMethodSymbol ms) @@ -1013,6 +1022,9 @@ private static bool CanDropElse(ElseClauseSyntax @else) private static bool IsCreateSyncVersionAttribute(INamedTypeSymbol s) => s.ToDisplayString() == SyncMethodSourceGenerator.QualifiedCreateSyncVersionAttribute; + private static bool IsReplaceWithAttribute(INamedTypeSymbol s) + => s.ToDisplayString() == SyncMethodSourceGenerator.QualifiedReplaceWithAttribute; + private static SyntaxList RemoveAtRange(SyntaxList list, IEnumerable indices) where TNode : SyntaxNode { @@ -1188,6 +1200,22 @@ private static InvocationExpressionSyntax UnwrapExtension(InvocationExpressionSy return replacement; } + private void AddParameterTypes(AttributeSyntax attributeSyntax) + { + if (GetSymbol(attributeSyntax) is IMethodSymbol attributeSymbol) + { + var attributeContainingTypeSymbol = attributeSymbol.ContainingType; + if (IsReplaceWithAttribute(attributeContainingTypeSymbol) && attributeSyntax.Parent?.Parent is ParameterSyntax param + && GetSymbol(param) is IParameterSymbol parameterSymbol) + { + var variation = parameterSymbol.GetAttributes()[0].NamedArguments.FirstOrDefault(c => c.Key == "Variations").Value.Value; + + //.NamedArguments.FirstOrDefault(c => c.Key == "Variations").Value.Value; + parametersWithTypes.Add(param.Identifier.ValueText, param.Identifier.ValueText); + } + } + } + private bool PreProcess( SyntaxList statements, Dictionary extraNodeInfoList, diff --git a/src/Zomp.SyncMethodGenerator/CollectionTypes.cs b/src/Zomp.SyncMethodGenerator/CollectionTypes.cs new file mode 100644 index 0000000..e50dee6 --- /dev/null +++ b/src/Zomp.SyncMethodGenerator/CollectionTypes.cs @@ -0,0 +1,29 @@ +namespace Zomp.SyncMethodGenerator +{ + /// + /// All types that an IAsyncEnumerable can be converted into. + /// + [Flags] + public enum CollectionTypes + { + /// + /// Type for System.Collections.Generic.IEnumerable . + /// + IEnumerable = 1, + + /// + /// Type for System.Collections.Generic.IList . + /// + IList = 2, + + /// + /// Type for System.ReadOnlySpan . + /// + ReadOnlySpan = 4, + + /// + /// Type for System.Span . + /// + Span = 8, + } +} diff --git a/src/Zomp.SyncMethodGenerator/SourceGenerationHelper.cs b/src/Zomp.SyncMethodGenerator/SourceGenerationHelper.cs index a42cf50..5289ded 100644 --- a/src/Zomp.SyncMethodGenerator/SourceGenerationHelper.cs +++ b/src/Zomp.SyncMethodGenerator/SourceGenerationHelper.cs @@ -15,6 +15,22 @@ namespace Zomp.SyncMethodGenerator [System.AttributeUsage(System.AttributeTargets.Method)] internal class CreateSyncVersionAttribute : System.Attribute { + public virtual CollectionTypes Variations { get; set; } + } +} +"""; + + internal const string ReplaceWithAttributeSource = """ +// +namespace Zomp.SyncMethodGenerator +{ + /// + /// An attribute that can be used to specify the synchronous type a parameter should be converted into . + /// + [System.AttributeUsage(System.AttributeTargets.Parameter)] + internal class ReplaceWithAttribute : System.Attribute + { + public virtual CollectionTypes Variations { get; set; } } } """; diff --git a/src/Zomp.SyncMethodGenerator/SyncMethodSourceGenerator.cs b/src/Zomp.SyncMethodGenerator/SyncMethodSourceGenerator.cs index fae9618..7954b85 100644 --- a/src/Zomp.SyncMethodGenerator/SyncMethodSourceGenerator.cs +++ b/src/Zomp.SyncMethodGenerator/SyncMethodSourceGenerator.cs @@ -10,7 +10,13 @@ public class SyncMethodSourceGenerator : IIncrementalGenerator /// Create sync version attribute string. /// public const string CreateSyncVersionAttribute = "CreateSyncVersionAttribute"; + + /// + /// Replace with attribute string. + /// + public const string ReplaceWithAttribute = "ReplaceWithAttribute"; internal const string QualifiedCreateSyncVersionAttribute = $"{ThisAssembly.RootNamespace}.{CreateSyncVersionAttribute}"; + internal const string QualifiedReplaceWithAttribute = $"{ThisAssembly.RootNamespace}.{ReplaceWithAttribute}"; /// public void Initialize(IncrementalGeneratorInitializationContext context) @@ -22,9 +28,16 @@ public void Initialize(IncrementalGeneratorInitializationContext context) Debugger.Launch(); } #endif + + context.RegisterPostInitializationOutput(ctx => ctx.AddSource( + $"CollectionTypes.g.cs", SourceText.From(GetResourceText("Zomp.SyncMethodGenerator.tests.CollectionTypes.cs"), Encoding.UTF8))); + context.RegisterPostInitializationOutput(ctx => ctx.AddSource( $"{CreateSyncVersionAttribute}.g.cs", SourceText.From(SourceGenerationHelper.CreateSyncVersionAttributeSource, Encoding.UTF8))); + context.RegisterPostInitializationOutput(ctx => ctx.AddSource( + $"{ReplaceWithAttribute}.g.cs", SourceText.From(SourceGenerationHelper.ReplaceWithAttributeSource, Encoding.UTF8))); + IncrementalValuesProvider methodDeclarations = context.SyntaxProvider .ForAttributeWithMetadataName( QualifiedCreateSyncVersionAttribute, @@ -93,6 +106,7 @@ private static void Execute(Compilation compilation, ImmutableArray GetTypesToGenerate(SourceProductionContext context, Compilation compilation, IEnumerable methodDeclarations, CancellationToken ct) { var methodsToGenerate = new List(); + INamedTypeSymbol? attribute = compilation.GetTypeByMetadataName(QualifiedCreateSyncVersionAttribute); if (attribute == null) { @@ -120,6 +134,7 @@ private static List GetTypesToGenerate(SourceProductionContext var methodName = methodSymbol.ToString(); + var variations = CollectionTypes.IEnumerable; foreach (AttributeData attributeData in methodSymbol.GetAttributes()) { if (!attribute.Equals(attributeData.AttributeClass, SymbolEqualityComparer.Default)) @@ -127,6 +142,11 @@ private static List GetTypesToGenerate(SourceProductionContext continue; } + if (attributeData.NamedArguments.Length >= 1 && attributeData.NamedArguments.FirstOrDefault(c => c.Key == "Variations").Value.Value is int value) + { + variations = (CollectionTypes)value; + } + break; } @@ -161,47 +181,86 @@ private static List GetTypesToGenerate(SourceProductionContext continue; } - var rewriter = new AsyncToSyncRewriter(semanticModel); - var sn = rewriter.Visit(methodDeclarationSyntax); - var content = sn.ToFullString(); + var collections = new List(); - var diagnostics = rewriter.Diagnostics; + if ((variations & CollectionTypes.IList) == CollectionTypes.IList) + { + collections.Add("System.Collections.Generic.IList"); + } - var hasErrors = false; - foreach (var diagnostic in diagnostics) + if ((variations & CollectionTypes.Span) == CollectionTypes.Span) { - context.ReportDiagnostic(diagnostic); - hasErrors |= diagnostic.Severity == DiagnosticSeverity.Error; + collections.Add("System.Span"); } - if (hasErrors) + if ((variations & CollectionTypes.ReadOnlySpan) == CollectionTypes.ReadOnlySpan) { - continue; + collections.Add("System.ReadOnlySpan"); } - var isNamespaceFileScoped = false; - var namespaces = new List(); - while (node is not null && node is not CompilationUnitSyntax) + if ((variations & CollectionTypes.IEnumerable) == CollectionTypes.IEnumerable || collections.Count == 0) { - switch (node) + collections.Add("System.Collections.Generic.IEnumerable"); + } + + foreach (var collection in collections) + { + var replacementOverrides = new Dictionary { - case NamespaceDeclarationSyntax nds: - namespaces.Insert(0, nds.Name.ToString()); - break; - case FileScopedNamespaceDeclarationSyntax file: - namespaces.Add(file.Name.ToString()); - isNamespaceFileScoped = true; - break; - default: - throw new InvalidOperationException($"Cannot handle {node}"); + { "System.Collections.Generic.IAsyncEnumerable", collection }, + }; + var rewriter = new AsyncToSyncRewriter(semanticModel, replacementOverrides); + var sn = rewriter.Visit(methodDeclarationSyntax); + var content = sn.ToFullString(); + + var diagnostics = rewriter.Diagnostics; + + var hasErrors = false; + foreach (var diagnostic in diagnostics) + { + context.ReportDiagnostic(diagnostic); + hasErrors |= diagnostic.Severity == DiagnosticSeverity.Error; } - node = node.Parent; - } + if (hasErrors) + { + continue; + } - methodsToGenerate.Add(new(namespaces, isNamespaceFileScoped, classes, methodDeclarationSyntax.Identifier.ValueText, content)); + var isNamespaceFileScoped = false; + var namespaces = new List(); + while (node is not null && node is not CompilationUnitSyntax) + { + switch (node) + { + case NamespaceDeclarationSyntax nds: + namespaces.Insert(0, nds.Name.ToString()); + break; + case FileScopedNamespaceDeclarationSyntax file: + namespaces.Add(file.Name.ToString()); + isNamespaceFileScoped = true; + break; + default: + throw new InvalidOperationException($"Cannot handle {node}"); + } + + node = node.Parent; + } + + methodsToGenerate.Add(new(namespaces, isNamespaceFileScoped, classes, methodDeclarationSyntax.Identifier.ValueText, content)); + } } return methodsToGenerate; } + + private string GetResourceText(string name) + { + using var stream = GetType().Assembly.GetManifestResourceStream(name); + using var streamReader = new StreamReader(stream); + return $""" + // + {streamReader.ReadToEnd()} + """; + } } diff --git a/src/Zomp.SyncMethodGenerator/Zomp.SyncMethodGenerator.csproj b/src/Zomp.SyncMethodGenerator/Zomp.SyncMethodGenerator.csproj index a7c2bbc..19b4cac 100644 --- a/src/Zomp.SyncMethodGenerator/Zomp.SyncMethodGenerator.csproj +++ b/src/Zomp.SyncMethodGenerator/Zomp.SyncMethodGenerator.csproj @@ -24,6 +24,12 @@ + + + + Zomp.SyncMethodGenerator.tests.CollectionTypes.cs + + diff --git a/tests/GenerationSandbox.Tests/GenerationSandbox.Tests.csproj b/tests/GenerationSandbox.Tests/GenerationSandbox.Tests.csproj index d5d2288..28c0fd9 100644 --- a/tests/GenerationSandbox.Tests/GenerationSandbox.Tests.csproj +++ b/tests/GenerationSandbox.Tests/GenerationSandbox.Tests.csproj @@ -1,7 +1,7 @@  - $(NoWarn);CA1812;CA1852;SA1402;CA1822;SA1201;CA1303;SA1124;IDE0035;CS0162;RS1035;CS8619;CS8603;CA5394 + $(NoWarn);CA1812;CA1852;SA1402;CA1822;SA1201;CA1303;SA1124;IDE0035;CS0162;RS1035;SA1601;CS8619;CS8603;CA5394 false net7.0;net6.0 $(TargetFrameworks);net472 @@ -27,6 +27,10 @@ + + + + diff --git a/tests/Generator.Tests/Generator.Tests.csproj b/tests/Generator.Tests/Generator.Tests.csproj index 407aec1..693626d 100644 --- a/tests/Generator.Tests/Generator.Tests.csproj +++ b/tests/Generator.Tests/Generator.Tests.csproj @@ -32,6 +32,11 @@ + + + Generator.Tests.CollectionTypes.cs + + diff --git a/tests/Generator.Tests/IntegrationTesting.cs b/tests/Generator.Tests/IntegrationTesting.cs index 19cddeb..f652ea4 100644 --- a/tests/Generator.Tests/IntegrationTesting.cs +++ b/tests/Generator.Tests/IntegrationTesting.cs @@ -56,7 +56,24 @@ async Task EnumeratorTestAsync(IAsyncEnumerable range, CancellationToken ct [Fact] public Task CombineTwoLists() => """ -[CreateSyncVersion] +[CreateSyncVersion(Variations = (CollectionTypes.IList | CollectionTypes.Span))] +public static async IAsyncEnumerable<(TLeft Left, TRight Right)> CombineAsync(this IAsyncEnumerable list1, [ReplaceWith(Variations = CollectionTypes.Span)]IAsyncEnumerable list2, [EnumeratorCancellation] CancellationToken ct = default) +{ + await using var enumerator2 = list2.GetAsyncEnumerator(ct); + await foreach (var item in list1.WithCancellation(ct).ConfigureAwait(false)) + { + if (!(await enumerator2.MoveNextAsync().ConfigureAwait(false))) + { + throw new InvalidOperationException("Must have the same size"); + } + yield return (item, enumerator2.Current); + } +} +""".Verify(); + + [Fact] + public Task CombineTwoListsAll() => """ +[CreateSyncVersion(Variations = (CollectionTypes.IList | CollectionTypes.Span | CollectionTypes.ReadOnlySpan | CollectionTypes.IEnumerable))] public static async IAsyncEnumerable<(TLeft Left, TRight Right)> CombineAsync(this IAsyncEnumerable list1, IAsyncEnumerable list2, [EnumeratorCancellation] CancellationToken ct = default) { await using var enumerator2 = list2.GetAsyncEnumerator(ct); diff --git a/tests/Generator.Tests/Snapshots/IntegrationTesting.CombineTwoLists#.Class.CombineAsync.g.verified.cs b/tests/Generator.Tests/Snapshots/IntegrationTesting.CombineTwoLists#.Class.CombineAsync.g.verified.cs new file mode 100644 index 0000000..c7273b0 --- /dev/null +++ b/tests/Generator.Tests/Snapshots/IntegrationTesting.CombineTwoLists#.Class.CombineAsync.g.verified.cs @@ -0,0 +1,18 @@ +//HintName: .Class.CombineAsync.g.cs +// +#nullable enable +partial class Class +{ + public static global::System.Span<(TLeft Left, TRight Right)> Combine(this global::System.Span list1, global::System.Span list2) + { + using var enumerator2 = list2.GetEnumerator(); + foreach (var item in list1) + { + if (!enumerator2.MoveNext()) + { + throw new global::System.InvalidOperationException("Must have the same size"); + } + yield return (item, enumerator2.Current); + } + } +} diff --git a/tests/Generator.Tests/Snapshots/IntegrationTesting.CombineTwoLists#CombineAsync.g.verified.cs b/tests/Generator.Tests/Snapshots/IntegrationTesting.CombineTwoLists#CombineAsync.g.verified.cs index 73cb432..064a7ae 100644 --- a/tests/Generator.Tests/Snapshots/IntegrationTesting.CombineTwoLists#CombineAsync.g.verified.cs +++ b/tests/Generator.Tests/Snapshots/IntegrationTesting.CombineTwoLists#CombineAsync.g.verified.cs @@ -1,5 +1,5 @@ //HintName: Test.Class.CombineAsync.g.cs -public static global::System.Collections.Generic.IEnumerable<(TLeft Left, TRight Right)> Combine(this global::System.Collections.Generic.IEnumerable list1, global::System.Collections.Generic.IEnumerable list2) +public static global::System.Collections.Generic.IList<(TLeft Left, TRight Right)> Combine(this global::System.Collections.Generic.IList list1, global::System.Collections.Generic.IList list2) { using var enumerator2 = list2.GetEnumerator(); foreach (var item in list1) diff --git a/tests/Generator.Tests/Snapshots/IntegrationTesting.CombineTwoLists#ReplaceWithAttribute.g.verified.cs b/tests/Generator.Tests/Snapshots/IntegrationTesting.CombineTwoLists#ReplaceWithAttribute.g.verified.cs new file mode 100644 index 0000000..f8bb4bf --- /dev/null +++ b/tests/Generator.Tests/Snapshots/IntegrationTesting.CombineTwoLists#ReplaceWithAttribute.g.verified.cs @@ -0,0 +1,13 @@ +//HintName: ReplaceWithAttribute.g.cs +// +namespace Zomp.SyncMethodGenerator +{ + /// + /// An attribute that can be used to specify the synchronous type a parameter should be converted into . + /// + [System.AttributeUsage(System.AttributeTargets.Parameter)] + internal class ReplaceWithAttribute : System.Attribute + { + public virtual CollectionTypes Variations { get; set; } + } +} diff --git a/tests/Generator.Tests/Snapshots/IntegrationTesting.CombineTwoListsAll#.Class.CombineAsync.g.verified.cs b/tests/Generator.Tests/Snapshots/IntegrationTesting.CombineTwoListsAll#.Class.CombineAsync.g.verified.cs new file mode 100644 index 0000000..1b20b92 --- /dev/null +++ b/tests/Generator.Tests/Snapshots/IntegrationTesting.CombineTwoListsAll#.Class.CombineAsync.g.verified.cs @@ -0,0 +1,18 @@ +//HintName: .Class.CombineAsync.g.cs +// +#nullable enable +partial class Class +{ + public static global::System.Span<(TLeft Left, TRight Right)> Combine(this global::System.Span list1, global::System.Span list2) + { + using var enumerator2 = list2.GetEnumerator(); + foreach (var item in list1) + { + if (!enumerator2.MoveNext()) + { + throw new global::System.InvalidOperationException("Must have the same size"); + } + yield return (item, enumerator2.Current); + } + } +} diff --git a/tests/Generator.Tests/Snapshots/IntegrationTesting.CombineTwoListsAll#.Class.CombineAsync_2.g.verified.cs b/tests/Generator.Tests/Snapshots/IntegrationTesting.CombineTwoListsAll#.Class.CombineAsync_2.g.verified.cs new file mode 100644 index 0000000..e0315cf --- /dev/null +++ b/tests/Generator.Tests/Snapshots/IntegrationTesting.CombineTwoListsAll#.Class.CombineAsync_2.g.verified.cs @@ -0,0 +1,18 @@ +//HintName: .Class.CombineAsync_2.g.cs +// +#nullable enable +partial class Class +{ + public static global::System.ReadOnlySpan<(TLeft Left, TRight Right)> Combine(this global::System.ReadOnlySpan list1, global::System.ReadOnlySpan list2) + { + using var enumerator2 = list2.GetEnumerator(); + foreach (var item in list1) + { + if (!enumerator2.MoveNext()) + { + throw new global::System.InvalidOperationException("Must have the same size"); + } + yield return (item, enumerator2.Current); + } + } +} diff --git a/tests/Generator.Tests/Snapshots/IntegrationTesting.CombineTwoListsAll#.Class.CombineAsync_3.g.verified.cs b/tests/Generator.Tests/Snapshots/IntegrationTesting.CombineTwoListsAll#.Class.CombineAsync_3.g.verified.cs new file mode 100644 index 0000000..01e02b2 --- /dev/null +++ b/tests/Generator.Tests/Snapshots/IntegrationTesting.CombineTwoListsAll#.Class.CombineAsync_3.g.verified.cs @@ -0,0 +1,18 @@ +//HintName: .Class.CombineAsync_3.g.cs +// +#nullable enable +partial class Class +{ + public static global::System.Collections.Generic.IEnumerable<(TLeft Left, TRight Right)> Combine(this global::System.Collections.Generic.IEnumerable list1, global::System.Collections.Generic.IEnumerable list2) + { + using var enumerator2 = list2.GetEnumerator(); + foreach (var item in list1) + { + if (!enumerator2.MoveNext()) + { + throw new global::System.InvalidOperationException("Must have the same size"); + } + yield return (item, enumerator2.Current); + } + } +} diff --git a/tests/Generator.Tests/Snapshots/IntegrationTesting.CombineTwoListsAll#CombineAsync.g.verified.cs b/tests/Generator.Tests/Snapshots/IntegrationTesting.CombineTwoListsAll#CombineAsync.g.verified.cs new file mode 100644 index 0000000..064a7ae --- /dev/null +++ b/tests/Generator.Tests/Snapshots/IntegrationTesting.CombineTwoListsAll#CombineAsync.g.verified.cs @@ -0,0 +1,13 @@ +//HintName: Test.Class.CombineAsync.g.cs +public static global::System.Collections.Generic.IList<(TLeft Left, TRight Right)> Combine(this global::System.Collections.Generic.IList list1, global::System.Collections.Generic.IList list2) +{ + using var enumerator2 = list2.GetEnumerator(); + foreach (var item in list1) + { + if (!enumerator2.MoveNext()) + { + throw new global::System.InvalidOperationException("Must have the same size"); + } + yield return (item, enumerator2.Current); + } +} diff --git a/tests/Generator.Tests/TestHelper.cs b/tests/Generator.Tests/TestHelper.cs index 63cddf7..db274e9 100644 --- a/tests/Generator.Tests/TestHelper.cs +++ b/tests/Generator.Tests/TestHelper.cs @@ -132,7 +132,7 @@ partial class Class var target = new RunResultWithIgnoreList { Result = driver.GetRunResult(), - IgnoredFiles = { $"{SyncMethodSourceGenerator.CreateSyncVersionAttribute}.g.cs" }, + IgnoredFiles = { $"{SyncMethodSourceGenerator.CreateSyncVersionAttribute}.g.cs", $"{SyncMethodSourceGenerator.ReplaceWithAttribute}.g.cs", $"CollectionTypes.g.cs" }, }; var verifier = Verifier