From a402d5c66ff2f145082f3ff66f1c17afbecb8bb8 Mon Sep 17 00:00:00 2001 From: CodingFlow <3643313+CodingFlow@users.noreply.github.com> Date: Sat, 29 Nov 2025 23:09:45 -0500 Subject: [PATCH 1/5] - Implement generating interface file. --- .../IOrchestrator.cs | 9 ++++ .../Tests.cs | 6 ++- AsyncTaskOrchestratorGenerator/Main.cs | 9 ++-- .../OutputGenerator.cs | 46 ++++++++++++++++--- 4 files changed, 56 insertions(+), 14 deletions(-) create mode 100644 AsyncTaskOrchestratorGenerator.UnitTests/IOrchestrator.cs diff --git a/AsyncTaskOrchestratorGenerator.UnitTests/IOrchestrator.cs b/AsyncTaskOrchestratorGenerator.UnitTests/IOrchestrator.cs new file mode 100644 index 0000000..e1167e2 --- /dev/null +++ b/AsyncTaskOrchestratorGenerator.UnitTests/IOrchestrator.cs @@ -0,0 +1,9 @@ +// +#nullable restore + +namespace TestLibrary; + +internal interface IOrchestrator +{ + public Task Execute(); +} diff --git a/AsyncTaskOrchestratorGenerator.UnitTests/Tests.cs b/AsyncTaskOrchestratorGenerator.UnitTests/Tests.cs index 440f3c7..2e34a15 100644 --- a/AsyncTaskOrchestratorGenerator.UnitTests/Tests.cs +++ b/AsyncTaskOrchestratorGenerator.UnitTests/Tests.cs @@ -19,7 +19,8 @@ public void Setup() { [Test] public async Task OneInterface() { var source = await ReadCSharpFile(true); - var generated = await ReadCSharpFile(false); + var generatedClass = await ReadCSharpFile(false); + var generatedInterface = await ReadCSharpFile(false); await new VerifyCS.Test { @@ -35,7 +36,8 @@ public async Task OneInterface() { Sources = { source }, GeneratedSources = { - (typeof(Main), "Orchestrator.generated.cs", SourceText.From(generated, Encoding.UTF8, SourceHashAlgorithm.Sha256)), + (typeof(Main), "Orchestrator.generated.cs", SourceText.From(generatedClass, Encoding.UTF8, SourceHashAlgorithm.Sha256)), + (typeof(Main), "IOrchestrator.generated.cs", SourceText.From(generatedInterface, Encoding.UTF8, SourceHashAlgorithm.Sha256)), }, }, }.RunAsync(); diff --git a/AsyncTaskOrchestratorGenerator/Main.cs b/AsyncTaskOrchestratorGenerator/Main.cs index 640616a..32f4672 100644 --- a/AsyncTaskOrchestratorGenerator/Main.cs +++ b/AsyncTaskOrchestratorGenerator/Main.cs @@ -1,9 +1,6 @@ using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Text; -using System; -using System.Collections.Generic; -using System.Linq; using System.Text; using System.Threading; @@ -32,9 +29,11 @@ private static (INamedTypeSymbol, SemanticModel) GetSemanticTargetForGeneration( } private static void Execute(SourceProductionContext context, (INamedTypeSymbol typeSymbol, SemanticModel semanticModel) typeInfo) { - var (source, className) = OutputGenerator.GenerateOutputs(typeInfo); + var (classSource, className) = OutputGenerator.GenerateClassOutputs(typeInfo); + var (interfaceSource, interfaceName) = OutputGenerator.GenerateInterfaceOutputs(typeInfo.typeSymbol); - context.AddSource($"{className}.generated.cs", SourceText.From(source, Encoding.UTF8, SourceHashAlgorithm.Sha256)); + context.AddSource($"{className}.generated.cs", SourceText.From(classSource, Encoding.UTF8, SourceHashAlgorithm.Sha256)); + context.AddSource($"{interfaceName}.generated.cs", SourceText.From(interfaceSource, Encoding.UTF8, SourceHashAlgorithm.Sha256)); } } } \ No newline at end of file diff --git a/AsyncTaskOrchestratorGenerator/OutputGenerator.cs b/AsyncTaskOrchestratorGenerator/OutputGenerator.cs index 36bd00e..c9c5690 100644 --- a/AsyncTaskOrchestratorGenerator/OutputGenerator.cs +++ b/AsyncTaskOrchestratorGenerator/OutputGenerator.cs @@ -8,11 +8,10 @@ namespace AsyncTaskOrchestratorGenerator { internal static class OutputGenerator { - public static (string source, string className) GenerateOutputs((INamedTypeSymbol typeSymbol, SemanticModel semanticModel) typeInfo) { + public static (string source, string className) GenerateClassOutputs((INamedTypeSymbol typeSymbol, SemanticModel semanticModel) typeInfo) { var type = typeInfo.typeSymbol; var semanticModel = typeInfo.semanticModel; - - var constructorArguments = type.GetAttributes().First((a) => a.AttributeClass.Name == nameof(AsyncTaskOrchestratorAttribute)).ConstructorArguments; + var constructorArguments = GetAttributeConstructorArguments(type); var className = constructorArguments.First().Value.ToString(); var executeMethodName = constructorArguments.ElementAt(1).Value.ToString(); @@ -50,11 +49,37 @@ namespace {type.ContainingNamespace.ToDisplayString()}; return (source, className); } + private static System.Collections.Immutable.ImmutableArray GetAttributeConstructorArguments(INamedTypeSymbol type) { + return type.GetAttributes().First((a) => a.AttributeClass.Name == nameof(AsyncTaskOrchestratorAttribute)).ConstructorArguments; + } + + public static (string source, string interfaceName) GenerateInterfaceOutputs(INamedTypeSymbol type) { + var constructorArguments = GetAttributeConstructorArguments(type); + var className = constructorArguments.First().Value.ToString(); + var executeMethodName = constructorArguments.ElementAt(1).Value.ToString(); + var interfaceName = $"I{className}"; + + var accessModifier = type.DeclaredAccessibility.ToString().ToLower(); + var executeMethod = GetExecuteMethod(type); + var executeMethodAccessibility = executeMethod.DeclaredAccessibility.ToString().ToLower(); + var formattedExecuteMethod = $"{executeMethodAccessibility} {executeMethod.ReturnType} {executeMethodName}();"; + + var source = +$@"// +#nullable restore + +namespace {type.ContainingNamespace.ToDisplayString()}; + +{accessModifier} interface {interfaceName} +{{ + {formattedExecuteMethod} +}} +"; + return (source, interfaceName); + } + private static (ExecuteMethodSignatureData, Dictionary, TaskData) CreateExecuteMethodData(INamedTypeSymbol type, IEnumerable fields, string executeMethodName) { - var executeMethod = type - .GetMembers() - .Where(m => m is IMethodSymbol) - .First(m => (m as IMethodSymbol).MethodKind == MethodKind.Ordinary) as IMethodSymbol; + var executeMethod = GetExecuteMethod(type); var statements = (executeMethod.DeclaringSyntaxReferences.First().GetSyntax() as MethodDeclarationSyntax).Body.Statements; var variableStatements = statements.Remove(statements.Last()); @@ -110,6 +135,13 @@ private static (ExecuteMethodSignatureData, Dictionary, TaskDa }, variableData.ToDictionary(taskData => taskData.OutputName), finalTaskData); } + private static IMethodSymbol GetExecuteMethod(INamedTypeSymbol type) { + return type + .GetMembers() + .Where(m => m is IMethodSymbol) + .First(m => (m as IMethodSymbol).MethodKind == MethodKind.Ordinary) as IMethodSymbol; + } + private static string FormatExecuteMethod(ExecuteMethodSignatureData signatureData, Dictionary data, TaskData finalTaskData) { var formattedTaskDeclarations = data.Select(keyValue => { var item = keyValue.Value; From f288ec97650d8a64b1ce7d77d5035636363fc38b Mon Sep 17 00:00:00 2001 From: CodingFlow <3643313+CodingFlow@users.noreply.github.com> Date: Sat, 29 Nov 2025 23:11:58 -0500 Subject: [PATCH 2/5] Implement generating inhering from generated interface. --- AsyncTaskOrchestratorGenerator.UnitTests/Orchestrator.cs | 2 +- AsyncTaskOrchestratorGenerator/OutputGenerator.cs | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/AsyncTaskOrchestratorGenerator.UnitTests/Orchestrator.cs b/AsyncTaskOrchestratorGenerator.UnitTests/Orchestrator.cs index 23b0c9c..dc8f8f8 100644 --- a/AsyncTaskOrchestratorGenerator.UnitTests/Orchestrator.cs +++ b/AsyncTaskOrchestratorGenerator.UnitTests/Orchestrator.cs @@ -6,7 +6,7 @@ namespace TestLibrary; -internal class Orchestrator +internal class Orchestrator : IOrchestrator { private readonly TestLibrary.A a; private readonly TestLibrary.B b; diff --git a/AsyncTaskOrchestratorGenerator/OutputGenerator.cs b/AsyncTaskOrchestratorGenerator/OutputGenerator.cs index c9c5690..6814e56 100644 --- a/AsyncTaskOrchestratorGenerator/OutputGenerator.cs +++ b/AsyncTaskOrchestratorGenerator/OutputGenerator.cs @@ -14,6 +14,7 @@ public static (string source, string className) GenerateClassOutputs((INamedType var constructorArguments = GetAttributeConstructorArguments(type); var className = constructorArguments.First().Value.ToString(); var executeMethodName = constructorArguments.ElementAt(1).Value.ToString(); + var interfaceName = $"I{className}"; var accessModifier = type.DeclaredAccessibility.ToString().ToLower(); var typeMembers = type.GetMembers(); @@ -35,7 +36,7 @@ public static (string source, string className) GenerateClassOutputs((INamedType namespace {type.ContainingNamespace.ToDisplayString()}; -{accessModifier} class {className} +{accessModifier} class {className} : {interfaceName} {{ {string.Join(@" ", formattedFields)} From e8cfc8983d4f21ed97107568371b889a1443267d Mon Sep 17 00:00:00 2001 From: CodingFlow <3643313+CodingFlow@users.noreply.github.com> Date: Sat, 29 Nov 2025 23:15:22 -0500 Subject: [PATCH 3/5] Minor refactoring to remove passing SemanticModel where it is not used. --- AsyncTaskOrchestratorGenerator/Main.cs | 10 +++++----- AsyncTaskOrchestratorGenerator/OutputGenerator.cs | 12 +++++------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/AsyncTaskOrchestratorGenerator/Main.cs b/AsyncTaskOrchestratorGenerator/Main.cs index 32f4672..1d4c28d 100644 --- a/AsyncTaskOrchestratorGenerator/Main.cs +++ b/AsyncTaskOrchestratorGenerator/Main.cs @@ -23,14 +23,14 @@ private static bool IsSyntaxTargetForGeneration(SyntaxNode syntaxNode, Cancellat return syntaxNode is TypeDeclarationSyntax; } - private static (INamedTypeSymbol, SemanticModel) GetSemanticTargetForGeneration(GeneratorAttributeSyntaxContext context, CancellationToken cancellationToken) { + private static INamedTypeSymbol GetSemanticTargetForGeneration(GeneratorAttributeSyntaxContext context, CancellationToken cancellationToken) { - return (context.TargetSymbol as INamedTypeSymbol, context.SemanticModel); + return context.TargetSymbol as INamedTypeSymbol; } - private static void Execute(SourceProductionContext context, (INamedTypeSymbol typeSymbol, SemanticModel semanticModel) typeInfo) { - var (classSource, className) = OutputGenerator.GenerateClassOutputs(typeInfo); - var (interfaceSource, interfaceName) = OutputGenerator.GenerateInterfaceOutputs(typeInfo.typeSymbol); + private static void Execute(SourceProductionContext context, INamedTypeSymbol typeSymbol) { + var (classSource, className) = OutputGenerator.GenerateClassOutputs(typeSymbol); + var (interfaceSource, interfaceName) = OutputGenerator.GenerateInterfaceOutputs(typeSymbol); context.AddSource($"{className}.generated.cs", SourceText.From(classSource, Encoding.UTF8, SourceHashAlgorithm.Sha256)); context.AddSource($"{interfaceName}.generated.cs", SourceText.From(interfaceSource, Encoding.UTF8, SourceHashAlgorithm.Sha256)); diff --git a/AsyncTaskOrchestratorGenerator/OutputGenerator.cs b/AsyncTaskOrchestratorGenerator/OutputGenerator.cs index 6814e56..a8f5d7c 100644 --- a/AsyncTaskOrchestratorGenerator/OutputGenerator.cs +++ b/AsyncTaskOrchestratorGenerator/OutputGenerator.cs @@ -8,9 +8,7 @@ namespace AsyncTaskOrchestratorGenerator { internal static class OutputGenerator { - public static (string source, string className) GenerateClassOutputs((INamedTypeSymbol typeSymbol, SemanticModel semanticModel) typeInfo) { - var type = typeInfo.typeSymbol; - var semanticModel = typeInfo.semanticModel; + public static (string source, string className) GenerateClassOutputs(INamedTypeSymbol type) { var constructorArguments = GetAttributeConstructorArguments(type); var className = constructorArguments.First().Value.ToString(); var executeMethodName = constructorArguments.ElementAt(1).Value.ToString(); @@ -50,10 +48,6 @@ namespace {type.ContainingNamespace.ToDisplayString()}; return (source, className); } - private static System.Collections.Immutable.ImmutableArray GetAttributeConstructorArguments(INamedTypeSymbol type) { - return type.GetAttributes().First((a) => a.AttributeClass.Name == nameof(AsyncTaskOrchestratorAttribute)).ConstructorArguments; - } - public static (string source, string interfaceName) GenerateInterfaceOutputs(INamedTypeSymbol type) { var constructorArguments = GetAttributeConstructorArguments(type); var className = constructorArguments.First().Value.ToString(); @@ -79,6 +73,10 @@ namespace {type.ContainingNamespace.ToDisplayString()}; return (source, interfaceName); } + private static System.Collections.Immutable.ImmutableArray GetAttributeConstructorArguments(INamedTypeSymbol type) { + return type.GetAttributes().First((a) => a.AttributeClass.Name == nameof(AsyncTaskOrchestratorAttribute)).ConstructorArguments; + } + private static (ExecuteMethodSignatureData, Dictionary, TaskData) CreateExecuteMethodData(INamedTypeSymbol type, IEnumerable fields, string executeMethodName) { var executeMethod = GetExecuteMethod(type); var statements = (executeMethod.DeclaringSyntaxReferences.First().GetSyntax() as MethodDeclarationSyntax).Body.Statements; From 79f2414fe997f8004436b8f06724fd7af7e92765 Mon Sep 17 00:00:00 2001 From: CodingFlow <3643313+CodingFlow@users.noreply.github.com> Date: Sat, 29 Nov 2025 23:18:50 -0500 Subject: [PATCH 4/5] Downgrade main project from .NET Standard 2.1 to .NET Standard 2.0 to avoid warning that targeting other frameworks for source generators has unpredicable behavior. --- .../AsyncTaskOrchestratorGenerator.csproj | 4 +- .../OutputGenerator.cs | 94 ++++++++++--------- 2 files changed, 52 insertions(+), 46 deletions(-) diff --git a/AsyncTaskOrchestratorGenerator/AsyncTaskOrchestratorGenerator.csproj b/AsyncTaskOrchestratorGenerator/AsyncTaskOrchestratorGenerator.csproj index 67c30d2..f4082f9 100644 --- a/AsyncTaskOrchestratorGenerator/AsyncTaskOrchestratorGenerator.csproj +++ b/AsyncTaskOrchestratorGenerator/AsyncTaskOrchestratorGenerator.csproj @@ -1,7 +1,7 @@ - + - netstandard2.1 + netstandard2.0 true True Async Task Orchestrator Generator diff --git a/AsyncTaskOrchestratorGenerator/OutputGenerator.cs b/AsyncTaskOrchestratorGenerator/OutputGenerator.cs index a8f5d7c..1e22c8d 100644 --- a/AsyncTaskOrchestratorGenerator/OutputGenerator.cs +++ b/AsyncTaskOrchestratorGenerator/OutputGenerator.cs @@ -81,30 +81,7 @@ private static (ExecuteMethodSignatureData, Dictionary, TaskDa var executeMethod = GetExecuteMethod(type); var statements = (executeMethod.DeclaringSyntaxReferences.First().GetSyntax() as MethodDeclarationSyntax).Body.Statements; var variableStatements = statements.Remove(statements.Last()); - - var variableData = variableStatements - .Select(s => s as LocalDeclarationStatementSyntax) - .SelectMany(v => v.Declaration.Variables) - .Select(declarationSyntax => { - var invocation = declarationSyntax.Initializer.Value as InvocationExpressionSyntax; - var methodAccessExpression = invocation.Expression as MemberAccessExpressionSyntax; - var methodCallTypeName = methodAccessExpression.ToString().Split('.').First(); - var methodCallType = fields.First(f => f.Name == methodCallTypeName).Type; - var methodCallName = methodAccessExpression.ToString().Split('.').Last(); - var methodSymbol = methodCallType.GetMembers(methodCallName).First() as IMethodSymbol; - - var arguments = invocation.ArgumentList.Arguments; - var argumentTypeNames = arguments.Select(a => a.ToString().Split('.').First()); - - return new TaskData - { - OutputName = declarationSyntax.Identifier.Text, - MethodCallName = methodAccessExpression.ToString(), - MethodCallReturnType = methodSymbol.ReturnType.ToString(), - DependenciesOutputNames = argumentTypeNames, - TaskName = $"{declarationSyntax.Identifier.Text}Task" - }; - }); + var variableData = GetVariableData(fields, variableStatements); var lastStatement = statements.Last() as ReturnStatementSyntax; var invocation = lastStatement.Expression as InvocationExpressionSyntax; @@ -134,6 +111,32 @@ private static (ExecuteMethodSignatureData, Dictionary, TaskDa }, variableData.ToDictionary(taskData => taskData.OutputName), finalTaskData); } + private static IEnumerable GetVariableData(IEnumerable fields, SyntaxList variableStatements) { + return variableStatements + .Select(s => s as LocalDeclarationStatementSyntax) + .SelectMany(v => v.Declaration.Variables) + .Select(declarationSyntax => { + var invocation = declarationSyntax.Initializer.Value as InvocationExpressionSyntax; + var methodAccessExpression = invocation.Expression as MemberAccessExpressionSyntax; + var methodCallTypeName = methodAccessExpression.ToString().Split('.').First(); + var methodCallType = fields.First(f => f.Name == methodCallTypeName).Type; + var methodCallName = methodAccessExpression.ToString().Split('.').Last(); + var methodSymbol = methodCallType.GetMembers(methodCallName).First() as IMethodSymbol; + + var arguments = invocation.ArgumentList.Arguments; + var argumentTypeNames = arguments.Select(a => a.ToString().Split('.').First()); + + return new TaskData + { + OutputName = declarationSyntax.Identifier.Text, + MethodCallName = methodAccessExpression.ToString(), + MethodCallReturnType = methodSymbol.ReturnType.ToString(), + DependenciesOutputNames = argumentTypeNames, + TaskName = $"{declarationSyntax.Identifier.Text}Task" + }; + }); + } + private static IMethodSymbol GetExecuteMethod(INamedTypeSymbol type) { return type .GetMembers() @@ -145,32 +148,18 @@ private static string FormatExecuteMethod(ExecuteMethodSignatureData signatureDa var formattedTaskDeclarations = data.Select(keyValue => { var item = keyValue.Value; var hasDependencies = item.DependenciesOutputNames.Any(); - return hasDependencies ? - $@"var {item.TaskName} = new {item.MethodCallReturnType}(() => default);": + return hasDependencies ? + $@"var {item.TaskName} = new {item.MethodCallReturnType}(() => default);" : $@"var {item.TaskName} = {item.MethodCallName}();"; }); var taskNames = data.Where(keyValue => !keyValue.Value.DependenciesOutputNames.Any()).Select(keyValue => keyValue.Value.TaskName); var formattedTasksList = $@"var tasksToProcess = new List {{ {string.Join(@", ", taskNames)} }};"; - - var formattedHandleTaskCompletions = data.Where(keyValue => keyValue.Value.DependenciesOutputNames.Any()).Select(keyValue => { - var item = keyValue.Value; - var dependencyTaskNames = item.DependenciesOutputNames.Select(depName => data[depName].TaskName); - var formattedCompletedDependencyTaskNames = string.Join(" && ", dependencyTaskNames.Select(tn => $"{tn}.IsCompleted")); - var formattedResultDependencyTaskNames = string.Join(", ", dependencyTaskNames.Select(tn => $"{tn}.Result")); - var formattedCallDependencies = $@"{item.TaskName} = {item.MethodCallName}({formattedResultDependencyTaskNames});"; - var formattedAddTaskToList = $@"tasksToProcess.Add({item.TaskName});"; - - return $@"if (!{item.TaskName}.IsCompleted && {formattedCompletedDependencyTaskNames}) - {{ - {formattedCallDependencies} - {formattedAddTaskToList} - }}"; - }); + var formattedHandleTaskCompletions = CreateFormattedHandleTaskCompletions(data); var formattedWhenEach = $@"await foreach (var completed in Task.WhenEach(tasksToProcess)) {{ - { string.Join(@" + {string.Join(@" ", formattedHandleTaskCompletions)} }}"; @@ -180,11 +169,11 @@ private static string FormatExecuteMethod(ExecuteMethodSignatureData signatureDa var formattedFinalResult = $@"var finalResult = await {finalTaskData.MethodCallName}({formattedResultDependencyTaskNames}); return finalResult;"; - + return $@"{signatureData.AccessModifier} async {signatureData.ReturnType} {signatureData.Name}() {{ {string.Join(@" - ", formattedTaskDeclarations) } + ", formattedTaskDeclarations)} {formattedTasksList} @@ -194,6 +183,23 @@ private static string FormatExecuteMethod(ExecuteMethodSignatureData signatureDa }}"; } + private static IEnumerable CreateFormattedHandleTaskCompletions(Dictionary data) { + return data.Where(keyValue => keyValue.Value.DependenciesOutputNames.Any()).Select(keyValue => { + var item = keyValue.Value; + var dependencyTaskNames = item.DependenciesOutputNames.Select(depName => data[depName].TaskName); + var formattedCompletedDependencyTaskNames = string.Join(" && ", dependencyTaskNames.Select(tn => $"{tn}.IsCompleted")); + var formattedResultDependencyTaskNames = string.Join(", ", dependencyTaskNames.Select(tn => $"{tn}.Result")); + var formattedCallDependencies = $@"{item.TaskName} = {item.MethodCallName}({formattedResultDependencyTaskNames});"; + var formattedAddTaskToList = $@"tasksToProcess.Add({item.TaskName});"; + + return $@"if (!{item.TaskName}.IsCompleted && {formattedCompletedDependencyTaskNames}) + {{ + {formattedCallDependencies} + {formattedAddTaskToList} + }}"; + }); + } + private static string FormatConstructor(INamedTypeSymbol type, string className, IEnumerable typeMembers) { var constructor = typeMembers .Where(m => m.Kind == SymbolKind.Method) From 50e74bfd9753567689dc3af9931659f1971f336f Mon Sep 17 00:00:00 2001 From: CodingFlow <3643313+CodingFlow@users.noreply.github.com> Date: Sat, 29 Nov 2025 23:21:59 -0500 Subject: [PATCH 5/5] Remove unnecessary newline. --- AsyncTaskOrchestratorGenerator/Main.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/AsyncTaskOrchestratorGenerator/Main.cs b/AsyncTaskOrchestratorGenerator/Main.cs index 1d4c28d..7be2e64 100644 --- a/AsyncTaskOrchestratorGenerator/Main.cs +++ b/AsyncTaskOrchestratorGenerator/Main.cs @@ -24,7 +24,6 @@ private static bool IsSyntaxTargetForGeneration(SyntaxNode syntaxNode, Cancellat } private static INamedTypeSymbol GetSemanticTargetForGeneration(GeneratorAttributeSyntaxContext context, CancellationToken cancellationToken) { - return context.TargetSymbol as INamedTypeSymbol; }