diff --git a/AsyncTaskOrchestratorGenerator.UnitTests/IOrchestrator.cs b/AsyncTaskOrchestratorGenerator.UnitTests/IOrchestrator.cs new file mode 100644 index 0000000..e1167e2 --- /dev/null +++ b/AsyncTaskOrchestratorGenerator.UnitTests/IOrchestrator.cs @@ -0,0 +1,9 @@ +// +#nullable restore + +namespace TestLibrary; + +internal interface IOrchestrator +{ + public Task Execute(); +} diff --git a/AsyncTaskOrchestratorGenerator.UnitTests/Orchestrator.cs b/AsyncTaskOrchestratorGenerator.UnitTests/Orchestrator.cs index 23b0c9c..dc8f8f8 100644 --- a/AsyncTaskOrchestratorGenerator.UnitTests/Orchestrator.cs +++ b/AsyncTaskOrchestratorGenerator.UnitTests/Orchestrator.cs @@ -6,7 +6,7 @@ namespace TestLibrary; -internal class Orchestrator +internal class Orchestrator : IOrchestrator { private readonly TestLibrary.A a; private readonly TestLibrary.B b; diff --git a/AsyncTaskOrchestratorGenerator.UnitTests/Tests.cs b/AsyncTaskOrchestratorGenerator.UnitTests/Tests.cs index 440f3c7..2e34a15 100644 --- a/AsyncTaskOrchestratorGenerator.UnitTests/Tests.cs +++ b/AsyncTaskOrchestratorGenerator.UnitTests/Tests.cs @@ -19,7 +19,8 @@ public void Setup() { [Test] public async Task OneInterface() { var source = await ReadCSharpFile(true); - var generated = await ReadCSharpFile(false); + var generatedClass = await ReadCSharpFile(false); + var generatedInterface = await ReadCSharpFile(false); await new VerifyCS.Test { @@ -35,7 +36,8 @@ public async Task OneInterface() { Sources = { source }, GeneratedSources = { - (typeof(Main), "Orchestrator.generated.cs", SourceText.From(generated, Encoding.UTF8, SourceHashAlgorithm.Sha256)), + (typeof(Main), "Orchestrator.generated.cs", SourceText.From(generatedClass, Encoding.UTF8, SourceHashAlgorithm.Sha256)), + (typeof(Main), "IOrchestrator.generated.cs", SourceText.From(generatedInterface, Encoding.UTF8, SourceHashAlgorithm.Sha256)), }, }, }.RunAsync(); diff --git a/AsyncTaskOrchestratorGenerator/AsyncTaskOrchestratorGenerator.csproj b/AsyncTaskOrchestratorGenerator/AsyncTaskOrchestratorGenerator.csproj index 67c30d2..f4082f9 100644 --- a/AsyncTaskOrchestratorGenerator/AsyncTaskOrchestratorGenerator.csproj +++ b/AsyncTaskOrchestratorGenerator/AsyncTaskOrchestratorGenerator.csproj @@ -1,7 +1,7 @@ - + - netstandard2.1 + netstandard2.0 true True Async Task Orchestrator Generator diff --git a/AsyncTaskOrchestratorGenerator/Main.cs b/AsyncTaskOrchestratorGenerator/Main.cs index 640616a..7be2e64 100644 --- a/AsyncTaskOrchestratorGenerator/Main.cs +++ b/AsyncTaskOrchestratorGenerator/Main.cs @@ -1,9 +1,6 @@ using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Text; -using System; -using System.Collections.Generic; -using System.Linq; using System.Text; using System.Threading; @@ -26,15 +23,16 @@ private static bool IsSyntaxTargetForGeneration(SyntaxNode syntaxNode, Cancellat return syntaxNode is TypeDeclarationSyntax; } - private static (INamedTypeSymbol, SemanticModel) GetSemanticTargetForGeneration(GeneratorAttributeSyntaxContext context, CancellationToken cancellationToken) { - - return (context.TargetSymbol as INamedTypeSymbol, context.SemanticModel); + private static INamedTypeSymbol GetSemanticTargetForGeneration(GeneratorAttributeSyntaxContext context, CancellationToken cancellationToken) { + return context.TargetSymbol as INamedTypeSymbol; } - private static void Execute(SourceProductionContext context, (INamedTypeSymbol typeSymbol, SemanticModel semanticModel) typeInfo) { - var (source, className) = OutputGenerator.GenerateOutputs(typeInfo); + private static void Execute(SourceProductionContext context, INamedTypeSymbol typeSymbol) { + var (classSource, className) = OutputGenerator.GenerateClassOutputs(typeSymbol); + var (interfaceSource, interfaceName) = OutputGenerator.GenerateInterfaceOutputs(typeSymbol); - context.AddSource($"{className}.generated.cs", SourceText.From(source, Encoding.UTF8, SourceHashAlgorithm.Sha256)); + context.AddSource($"{className}.generated.cs", SourceText.From(classSource, Encoding.UTF8, SourceHashAlgorithm.Sha256)); + context.AddSource($"{interfaceName}.generated.cs", SourceText.From(interfaceSource, Encoding.UTF8, SourceHashAlgorithm.Sha256)); } } } \ No newline at end of file diff --git a/AsyncTaskOrchestratorGenerator/OutputGenerator.cs b/AsyncTaskOrchestratorGenerator/OutputGenerator.cs index 36bd00e..1e22c8d 100644 --- a/AsyncTaskOrchestratorGenerator/OutputGenerator.cs +++ b/AsyncTaskOrchestratorGenerator/OutputGenerator.cs @@ -8,13 +8,11 @@ namespace AsyncTaskOrchestratorGenerator { internal static class OutputGenerator { - public static (string source, string className) GenerateOutputs((INamedTypeSymbol typeSymbol, SemanticModel semanticModel) typeInfo) { - var type = typeInfo.typeSymbol; - var semanticModel = typeInfo.semanticModel; - - var constructorArguments = type.GetAttributes().First((a) => a.AttributeClass.Name == nameof(AsyncTaskOrchestratorAttribute)).ConstructorArguments; + public static (string source, string className) GenerateClassOutputs(INamedTypeSymbol type) { + var constructorArguments = GetAttributeConstructorArguments(type); var className = constructorArguments.First().Value.ToString(); var executeMethodName = constructorArguments.ElementAt(1).Value.ToString(); + var interfaceName = $"I{className}"; var accessModifier = type.DeclaredAccessibility.ToString().ToLower(); var typeMembers = type.GetMembers(); @@ -36,7 +34,7 @@ public static (string source, string className) GenerateOutputs((INamedTypeSymbo namespace {type.ContainingNamespace.ToDisplayString()}; -{accessModifier} class {className} +{accessModifier} class {className} : {interfaceName} {{ {string.Join(@" ", formattedFields)} @@ -50,37 +48,40 @@ namespace {type.ContainingNamespace.ToDisplayString()}; return (source, className); } - private static (ExecuteMethodSignatureData, Dictionary, TaskData) CreateExecuteMethodData(INamedTypeSymbol type, IEnumerable fields, string executeMethodName) { - var executeMethod = type - .GetMembers() - .Where(m => m is IMethodSymbol) - .First(m => (m as IMethodSymbol).MethodKind == MethodKind.Ordinary) as IMethodSymbol; - var statements = (executeMethod.DeclaringSyntaxReferences.First().GetSyntax() as MethodDeclarationSyntax).Body.Statements; - var variableStatements = statements.Remove(statements.Last()); + public static (string source, string interfaceName) GenerateInterfaceOutputs(INamedTypeSymbol type) { + var constructorArguments = GetAttributeConstructorArguments(type); + var className = constructorArguments.First().Value.ToString(); + var executeMethodName = constructorArguments.ElementAt(1).Value.ToString(); + var interfaceName = $"I{className}"; + + var accessModifier = type.DeclaredAccessibility.ToString().ToLower(); + var executeMethod = GetExecuteMethod(type); + var executeMethodAccessibility = executeMethod.DeclaredAccessibility.ToString().ToLower(); + var formattedExecuteMethod = $"{executeMethodAccessibility} {executeMethod.ReturnType} {executeMethodName}();"; - var variableData = variableStatements - .Select(s => s as LocalDeclarationStatementSyntax) - .SelectMany(v => v.Declaration.Variables) - .Select(declarationSyntax => { - var invocation = declarationSyntax.Initializer.Value as InvocationExpressionSyntax; - var methodAccessExpression = invocation.Expression as MemberAccessExpressionSyntax; - var methodCallTypeName = methodAccessExpression.ToString().Split('.').First(); - var methodCallType = fields.First(f => f.Name == methodCallTypeName).Type; - var methodCallName = methodAccessExpression.ToString().Split('.').Last(); - var methodSymbol = methodCallType.GetMembers(methodCallName).First() as IMethodSymbol; + var source = +$@"// +#nullable restore - var arguments = invocation.ArgumentList.Arguments; - var argumentTypeNames = arguments.Select(a => a.ToString().Split('.').First()); +namespace {type.ContainingNamespace.ToDisplayString()}; - return new TaskData - { - OutputName = declarationSyntax.Identifier.Text, - MethodCallName = methodAccessExpression.ToString(), - MethodCallReturnType = methodSymbol.ReturnType.ToString(), - DependenciesOutputNames = argumentTypeNames, - TaskName = $"{declarationSyntax.Identifier.Text}Task" - }; - }); +{accessModifier} interface {interfaceName} +{{ + {formattedExecuteMethod} +}} +"; + return (source, interfaceName); + } + + private static System.Collections.Immutable.ImmutableArray GetAttributeConstructorArguments(INamedTypeSymbol type) { + return type.GetAttributes().First((a) => a.AttributeClass.Name == nameof(AsyncTaskOrchestratorAttribute)).ConstructorArguments; + } + + private static (ExecuteMethodSignatureData, Dictionary, TaskData) CreateExecuteMethodData(INamedTypeSymbol type, IEnumerable fields, string executeMethodName) { + var executeMethod = GetExecuteMethod(type); + var statements = (executeMethod.DeclaringSyntaxReferences.First().GetSyntax() as MethodDeclarationSyntax).Body.Statements; + var variableStatements = statements.Remove(statements.Last()); + var variableData = GetVariableData(fields, variableStatements); var lastStatement = statements.Last() as ReturnStatementSyntax; var invocation = lastStatement.Expression as InvocationExpressionSyntax; @@ -110,36 +111,55 @@ private static (ExecuteMethodSignatureData, Dictionary, TaskDa }, variableData.ToDictionary(taskData => taskData.OutputName), finalTaskData); } + private static IEnumerable GetVariableData(IEnumerable fields, SyntaxList variableStatements) { + return variableStatements + .Select(s => s as LocalDeclarationStatementSyntax) + .SelectMany(v => v.Declaration.Variables) + .Select(declarationSyntax => { + var invocation = declarationSyntax.Initializer.Value as InvocationExpressionSyntax; + var methodAccessExpression = invocation.Expression as MemberAccessExpressionSyntax; + var methodCallTypeName = methodAccessExpression.ToString().Split('.').First(); + var methodCallType = fields.First(f => f.Name == methodCallTypeName).Type; + var methodCallName = methodAccessExpression.ToString().Split('.').Last(); + var methodSymbol = methodCallType.GetMembers(methodCallName).First() as IMethodSymbol; + + var arguments = invocation.ArgumentList.Arguments; + var argumentTypeNames = arguments.Select(a => a.ToString().Split('.').First()); + + return new TaskData + { + OutputName = declarationSyntax.Identifier.Text, + MethodCallName = methodAccessExpression.ToString(), + MethodCallReturnType = methodSymbol.ReturnType.ToString(), + DependenciesOutputNames = argumentTypeNames, + TaskName = $"{declarationSyntax.Identifier.Text}Task" + }; + }); + } + + private static IMethodSymbol GetExecuteMethod(INamedTypeSymbol type) { + return type + .GetMembers() + .Where(m => m is IMethodSymbol) + .First(m => (m as IMethodSymbol).MethodKind == MethodKind.Ordinary) as IMethodSymbol; + } + private static string FormatExecuteMethod(ExecuteMethodSignatureData signatureData, Dictionary data, TaskData finalTaskData) { var formattedTaskDeclarations = data.Select(keyValue => { var item = keyValue.Value; var hasDependencies = item.DependenciesOutputNames.Any(); - return hasDependencies ? - $@"var {item.TaskName} = new {item.MethodCallReturnType}(() => default);": + return hasDependencies ? + $@"var {item.TaskName} = new {item.MethodCallReturnType}(() => default);" : $@"var {item.TaskName} = {item.MethodCallName}();"; }); var taskNames = data.Where(keyValue => !keyValue.Value.DependenciesOutputNames.Any()).Select(keyValue => keyValue.Value.TaskName); var formattedTasksList = $@"var tasksToProcess = new List {{ {string.Join(@", ", taskNames)} }};"; - - var formattedHandleTaskCompletions = data.Where(keyValue => keyValue.Value.DependenciesOutputNames.Any()).Select(keyValue => { - var item = keyValue.Value; - var dependencyTaskNames = item.DependenciesOutputNames.Select(depName => data[depName].TaskName); - var formattedCompletedDependencyTaskNames = string.Join(" && ", dependencyTaskNames.Select(tn => $"{tn}.IsCompleted")); - var formattedResultDependencyTaskNames = string.Join(", ", dependencyTaskNames.Select(tn => $"{tn}.Result")); - var formattedCallDependencies = $@"{item.TaskName} = {item.MethodCallName}({formattedResultDependencyTaskNames});"; - var formattedAddTaskToList = $@"tasksToProcess.Add({item.TaskName});"; - - return $@"if (!{item.TaskName}.IsCompleted && {formattedCompletedDependencyTaskNames}) - {{ - {formattedCallDependencies} - {formattedAddTaskToList} - }}"; - }); + var formattedHandleTaskCompletions = CreateFormattedHandleTaskCompletions(data); var formattedWhenEach = $@"await foreach (var completed in Task.WhenEach(tasksToProcess)) {{ - { string.Join(@" + {string.Join(@" ", formattedHandleTaskCompletions)} }}"; @@ -149,11 +169,11 @@ private static string FormatExecuteMethod(ExecuteMethodSignatureData signatureDa var formattedFinalResult = $@"var finalResult = await {finalTaskData.MethodCallName}({formattedResultDependencyTaskNames}); return finalResult;"; - + return $@"{signatureData.AccessModifier} async {signatureData.ReturnType} {signatureData.Name}() {{ {string.Join(@" - ", formattedTaskDeclarations) } + ", formattedTaskDeclarations)} {formattedTasksList} @@ -163,6 +183,23 @@ private static string FormatExecuteMethod(ExecuteMethodSignatureData signatureDa }}"; } + private static IEnumerable CreateFormattedHandleTaskCompletions(Dictionary data) { + return data.Where(keyValue => keyValue.Value.DependenciesOutputNames.Any()).Select(keyValue => { + var item = keyValue.Value; + var dependencyTaskNames = item.DependenciesOutputNames.Select(depName => data[depName].TaskName); + var formattedCompletedDependencyTaskNames = string.Join(" && ", dependencyTaskNames.Select(tn => $"{tn}.IsCompleted")); + var formattedResultDependencyTaskNames = string.Join(", ", dependencyTaskNames.Select(tn => $"{tn}.Result")); + var formattedCallDependencies = $@"{item.TaskName} = {item.MethodCallName}({formattedResultDependencyTaskNames});"; + var formattedAddTaskToList = $@"tasksToProcess.Add({item.TaskName});"; + + return $@"if (!{item.TaskName}.IsCompleted && {formattedCompletedDependencyTaskNames}) + {{ + {formattedCallDependencies} + {formattedAddTaskToList} + }}"; + }); + } + private static string FormatConstructor(INamedTypeSymbol type, string className, IEnumerable typeMembers) { var constructor = typeMembers .Where(m => m.Kind == SymbolKind.Method)