diff --git a/AsyncTaskOrchestratorGenerator.UnitTests/IOrchestrator.cs b/AsyncTaskOrchestratorGenerator.UnitTests/IOrchestrator.cs
new file mode 100644
index 0000000..e1167e2
--- /dev/null
+++ b/AsyncTaskOrchestratorGenerator.UnitTests/IOrchestrator.cs
@@ -0,0 +1,9 @@
+//
+#nullable restore
+
+namespace TestLibrary;
+
+internal interface IOrchestrator
+{
+ public Task Execute();
+}
diff --git a/AsyncTaskOrchestratorGenerator.UnitTests/Orchestrator.cs b/AsyncTaskOrchestratorGenerator.UnitTests/Orchestrator.cs
index 23b0c9c..dc8f8f8 100644
--- a/AsyncTaskOrchestratorGenerator.UnitTests/Orchestrator.cs
+++ b/AsyncTaskOrchestratorGenerator.UnitTests/Orchestrator.cs
@@ -6,7 +6,7 @@
namespace TestLibrary;
-internal class Orchestrator
+internal class Orchestrator : IOrchestrator
{
private readonly TestLibrary.A a;
private readonly TestLibrary.B b;
diff --git a/AsyncTaskOrchestratorGenerator.UnitTests/Tests.cs b/AsyncTaskOrchestratorGenerator.UnitTests/Tests.cs
index 440f3c7..2e34a15 100644
--- a/AsyncTaskOrchestratorGenerator.UnitTests/Tests.cs
+++ b/AsyncTaskOrchestratorGenerator.UnitTests/Tests.cs
@@ -19,7 +19,8 @@ public void Setup() {
[Test]
public async Task OneInterface() {
var source = await ReadCSharpFile(true);
- var generated = await ReadCSharpFile(false);
+ var generatedClass = await ReadCSharpFile(false);
+ var generatedInterface = await ReadCSharpFile(false);
await new VerifyCS.Test
{
@@ -35,7 +36,8 @@ public async Task OneInterface() {
Sources = { source },
GeneratedSources =
{
- (typeof(Main), "Orchestrator.generated.cs", SourceText.From(generated, Encoding.UTF8, SourceHashAlgorithm.Sha256)),
+ (typeof(Main), "Orchestrator.generated.cs", SourceText.From(generatedClass, Encoding.UTF8, SourceHashAlgorithm.Sha256)),
+ (typeof(Main), "IOrchestrator.generated.cs", SourceText.From(generatedInterface, Encoding.UTF8, SourceHashAlgorithm.Sha256)),
},
},
}.RunAsync();
diff --git a/AsyncTaskOrchestratorGenerator/AsyncTaskOrchestratorGenerator.csproj b/AsyncTaskOrchestratorGenerator/AsyncTaskOrchestratorGenerator.csproj
index 67c30d2..f4082f9 100644
--- a/AsyncTaskOrchestratorGenerator/AsyncTaskOrchestratorGenerator.csproj
+++ b/AsyncTaskOrchestratorGenerator/AsyncTaskOrchestratorGenerator.csproj
@@ -1,7 +1,7 @@
-
+
- netstandard2.1
+ netstandard2.0
true
True
Async Task Orchestrator Generator
diff --git a/AsyncTaskOrchestratorGenerator/Main.cs b/AsyncTaskOrchestratorGenerator/Main.cs
index 640616a..7be2e64 100644
--- a/AsyncTaskOrchestratorGenerator/Main.cs
+++ b/AsyncTaskOrchestratorGenerator/Main.cs
@@ -1,9 +1,6 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Text;
-using System;
-using System.Collections.Generic;
-using System.Linq;
using System.Text;
using System.Threading;
@@ -26,15 +23,16 @@ private static bool IsSyntaxTargetForGeneration(SyntaxNode syntaxNode, Cancellat
return syntaxNode is TypeDeclarationSyntax;
}
- private static (INamedTypeSymbol, SemanticModel) GetSemanticTargetForGeneration(GeneratorAttributeSyntaxContext context, CancellationToken cancellationToken) {
-
- return (context.TargetSymbol as INamedTypeSymbol, context.SemanticModel);
+ private static INamedTypeSymbol GetSemanticTargetForGeneration(GeneratorAttributeSyntaxContext context, CancellationToken cancellationToken) {
+ return context.TargetSymbol as INamedTypeSymbol;
}
- private static void Execute(SourceProductionContext context, (INamedTypeSymbol typeSymbol, SemanticModel semanticModel) typeInfo) {
- var (source, className) = OutputGenerator.GenerateOutputs(typeInfo);
+ private static void Execute(SourceProductionContext context, INamedTypeSymbol typeSymbol) {
+ var (classSource, className) = OutputGenerator.GenerateClassOutputs(typeSymbol);
+ var (interfaceSource, interfaceName) = OutputGenerator.GenerateInterfaceOutputs(typeSymbol);
- context.AddSource($"{className}.generated.cs", SourceText.From(source, Encoding.UTF8, SourceHashAlgorithm.Sha256));
+ context.AddSource($"{className}.generated.cs", SourceText.From(classSource, Encoding.UTF8, SourceHashAlgorithm.Sha256));
+ context.AddSource($"{interfaceName}.generated.cs", SourceText.From(interfaceSource, Encoding.UTF8, SourceHashAlgorithm.Sha256));
}
}
}
\ No newline at end of file
diff --git a/AsyncTaskOrchestratorGenerator/OutputGenerator.cs b/AsyncTaskOrchestratorGenerator/OutputGenerator.cs
index 36bd00e..1e22c8d 100644
--- a/AsyncTaskOrchestratorGenerator/OutputGenerator.cs
+++ b/AsyncTaskOrchestratorGenerator/OutputGenerator.cs
@@ -8,13 +8,11 @@ namespace AsyncTaskOrchestratorGenerator
{
internal static class OutputGenerator
{
- public static (string source, string className) GenerateOutputs((INamedTypeSymbol typeSymbol, SemanticModel semanticModel) typeInfo) {
- var type = typeInfo.typeSymbol;
- var semanticModel = typeInfo.semanticModel;
-
- var constructorArguments = type.GetAttributes().First((a) => a.AttributeClass.Name == nameof(AsyncTaskOrchestratorAttribute)).ConstructorArguments;
+ public static (string source, string className) GenerateClassOutputs(INamedTypeSymbol type) {
+ var constructorArguments = GetAttributeConstructorArguments(type);
var className = constructorArguments.First().Value.ToString();
var executeMethodName = constructorArguments.ElementAt(1).Value.ToString();
+ var interfaceName = $"I{className}";
var accessModifier = type.DeclaredAccessibility.ToString().ToLower();
var typeMembers = type.GetMembers();
@@ -36,7 +34,7 @@ public static (string source, string className) GenerateOutputs((INamedTypeSymbo
namespace {type.ContainingNamespace.ToDisplayString()};
-{accessModifier} class {className}
+{accessModifier} class {className} : {interfaceName}
{{
{string.Join(@"
", formattedFields)}
@@ -50,37 +48,40 @@ namespace {type.ContainingNamespace.ToDisplayString()};
return (source, className);
}
- private static (ExecuteMethodSignatureData, Dictionary, TaskData) CreateExecuteMethodData(INamedTypeSymbol type, IEnumerable fields, string executeMethodName) {
- var executeMethod = type
- .GetMembers()
- .Where(m => m is IMethodSymbol)
- .First(m => (m as IMethodSymbol).MethodKind == MethodKind.Ordinary) as IMethodSymbol;
- var statements = (executeMethod.DeclaringSyntaxReferences.First().GetSyntax() as MethodDeclarationSyntax).Body.Statements;
- var variableStatements = statements.Remove(statements.Last());
+ public static (string source, string interfaceName) GenerateInterfaceOutputs(INamedTypeSymbol type) {
+ var constructorArguments = GetAttributeConstructorArguments(type);
+ var className = constructorArguments.First().Value.ToString();
+ var executeMethodName = constructorArguments.ElementAt(1).Value.ToString();
+ var interfaceName = $"I{className}";
+
+ var accessModifier = type.DeclaredAccessibility.ToString().ToLower();
+ var executeMethod = GetExecuteMethod(type);
+ var executeMethodAccessibility = executeMethod.DeclaredAccessibility.ToString().ToLower();
+ var formattedExecuteMethod = $"{executeMethodAccessibility} {executeMethod.ReturnType} {executeMethodName}();";
- var variableData = variableStatements
- .Select(s => s as LocalDeclarationStatementSyntax)
- .SelectMany(v => v.Declaration.Variables)
- .Select(declarationSyntax => {
- var invocation = declarationSyntax.Initializer.Value as InvocationExpressionSyntax;
- var methodAccessExpression = invocation.Expression as MemberAccessExpressionSyntax;
- var methodCallTypeName = methodAccessExpression.ToString().Split('.').First();
- var methodCallType = fields.First(f => f.Name == methodCallTypeName).Type;
- var methodCallName = methodAccessExpression.ToString().Split('.').Last();
- var methodSymbol = methodCallType.GetMembers(methodCallName).First() as IMethodSymbol;
+ var source =
+$@"//
+#nullable restore
- var arguments = invocation.ArgumentList.Arguments;
- var argumentTypeNames = arguments.Select(a => a.ToString().Split('.').First());
+namespace {type.ContainingNamespace.ToDisplayString()};
- return new TaskData
- {
- OutputName = declarationSyntax.Identifier.Text,
- MethodCallName = methodAccessExpression.ToString(),
- MethodCallReturnType = methodSymbol.ReturnType.ToString(),
- DependenciesOutputNames = argumentTypeNames,
- TaskName = $"{declarationSyntax.Identifier.Text}Task"
- };
- });
+{accessModifier} interface {interfaceName}
+{{
+ {formattedExecuteMethod}
+}}
+";
+ return (source, interfaceName);
+ }
+
+ private static System.Collections.Immutable.ImmutableArray GetAttributeConstructorArguments(INamedTypeSymbol type) {
+ return type.GetAttributes().First((a) => a.AttributeClass.Name == nameof(AsyncTaskOrchestratorAttribute)).ConstructorArguments;
+ }
+
+ private static (ExecuteMethodSignatureData, Dictionary, TaskData) CreateExecuteMethodData(INamedTypeSymbol type, IEnumerable fields, string executeMethodName) {
+ var executeMethod = GetExecuteMethod(type);
+ var statements = (executeMethod.DeclaringSyntaxReferences.First().GetSyntax() as MethodDeclarationSyntax).Body.Statements;
+ var variableStatements = statements.Remove(statements.Last());
+ var variableData = GetVariableData(fields, variableStatements);
var lastStatement = statements.Last() as ReturnStatementSyntax;
var invocation = lastStatement.Expression as InvocationExpressionSyntax;
@@ -110,36 +111,55 @@ private static (ExecuteMethodSignatureData, Dictionary, TaskDa
}, variableData.ToDictionary(taskData => taskData.OutputName), finalTaskData);
}
+ private static IEnumerable GetVariableData(IEnumerable fields, SyntaxList variableStatements) {
+ return variableStatements
+ .Select(s => s as LocalDeclarationStatementSyntax)
+ .SelectMany(v => v.Declaration.Variables)
+ .Select(declarationSyntax => {
+ var invocation = declarationSyntax.Initializer.Value as InvocationExpressionSyntax;
+ var methodAccessExpression = invocation.Expression as MemberAccessExpressionSyntax;
+ var methodCallTypeName = methodAccessExpression.ToString().Split('.').First();
+ var methodCallType = fields.First(f => f.Name == methodCallTypeName).Type;
+ var methodCallName = methodAccessExpression.ToString().Split('.').Last();
+ var methodSymbol = methodCallType.GetMembers(methodCallName).First() as IMethodSymbol;
+
+ var arguments = invocation.ArgumentList.Arguments;
+ var argumentTypeNames = arguments.Select(a => a.ToString().Split('.').First());
+
+ return new TaskData
+ {
+ OutputName = declarationSyntax.Identifier.Text,
+ MethodCallName = methodAccessExpression.ToString(),
+ MethodCallReturnType = methodSymbol.ReturnType.ToString(),
+ DependenciesOutputNames = argumentTypeNames,
+ TaskName = $"{declarationSyntax.Identifier.Text}Task"
+ };
+ });
+ }
+
+ private static IMethodSymbol GetExecuteMethod(INamedTypeSymbol type) {
+ return type
+ .GetMembers()
+ .Where(m => m is IMethodSymbol)
+ .First(m => (m as IMethodSymbol).MethodKind == MethodKind.Ordinary) as IMethodSymbol;
+ }
+
private static string FormatExecuteMethod(ExecuteMethodSignatureData signatureData, Dictionary data, TaskData finalTaskData) {
var formattedTaskDeclarations = data.Select(keyValue => {
var item = keyValue.Value;
var hasDependencies = item.DependenciesOutputNames.Any();
- return hasDependencies ?
- $@"var {item.TaskName} = new {item.MethodCallReturnType}(() => default);":
+ return hasDependencies ?
+ $@"var {item.TaskName} = new {item.MethodCallReturnType}(() => default);" :
$@"var {item.TaskName} = {item.MethodCallName}();";
});
var taskNames = data.Where(keyValue => !keyValue.Value.DependenciesOutputNames.Any()).Select(keyValue => keyValue.Value.TaskName);
var formattedTasksList = $@"var tasksToProcess = new List {{ {string.Join(@", ", taskNames)} }};";
-
- var formattedHandleTaskCompletions = data.Where(keyValue => keyValue.Value.DependenciesOutputNames.Any()).Select(keyValue => {
- var item = keyValue.Value;
- var dependencyTaskNames = item.DependenciesOutputNames.Select(depName => data[depName].TaskName);
- var formattedCompletedDependencyTaskNames = string.Join(" && ", dependencyTaskNames.Select(tn => $"{tn}.IsCompleted"));
- var formattedResultDependencyTaskNames = string.Join(", ", dependencyTaskNames.Select(tn => $"{tn}.Result"));
- var formattedCallDependencies = $@"{item.TaskName} = {item.MethodCallName}({formattedResultDependencyTaskNames});";
- var formattedAddTaskToList = $@"tasksToProcess.Add({item.TaskName});";
-
- return $@"if (!{item.TaskName}.IsCompleted && {formattedCompletedDependencyTaskNames})
- {{
- {formattedCallDependencies}
- {formattedAddTaskToList}
- }}";
- });
+ var formattedHandleTaskCompletions = CreateFormattedHandleTaskCompletions(data);
var formattedWhenEach = $@"await foreach (var completed in Task.WhenEach(tasksToProcess))
{{
- { string.Join(@"
+ {string.Join(@"
", formattedHandleTaskCompletions)}
}}";
@@ -149,11 +169,11 @@ private static string FormatExecuteMethod(ExecuteMethodSignatureData signatureDa
var formattedFinalResult = $@"var finalResult = await {finalTaskData.MethodCallName}({formattedResultDependencyTaskNames});
return finalResult;";
-
+
return $@"{signatureData.AccessModifier} async {signatureData.ReturnType} {signatureData.Name}()
{{
{string.Join(@"
- ", formattedTaskDeclarations) }
+ ", formattedTaskDeclarations)}
{formattedTasksList}
@@ -163,6 +183,23 @@ private static string FormatExecuteMethod(ExecuteMethodSignatureData signatureDa
}}";
}
+ private static IEnumerable CreateFormattedHandleTaskCompletions(Dictionary data) {
+ return data.Where(keyValue => keyValue.Value.DependenciesOutputNames.Any()).Select(keyValue => {
+ var item = keyValue.Value;
+ var dependencyTaskNames = item.DependenciesOutputNames.Select(depName => data[depName].TaskName);
+ var formattedCompletedDependencyTaskNames = string.Join(" && ", dependencyTaskNames.Select(tn => $"{tn}.IsCompleted"));
+ var formattedResultDependencyTaskNames = string.Join(", ", dependencyTaskNames.Select(tn => $"{tn}.Result"));
+ var formattedCallDependencies = $@"{item.TaskName} = {item.MethodCallName}({formattedResultDependencyTaskNames});";
+ var formattedAddTaskToList = $@"tasksToProcess.Add({item.TaskName});";
+
+ return $@"if (!{item.TaskName}.IsCompleted && {formattedCompletedDependencyTaskNames})
+ {{
+ {formattedCallDependencies}
+ {formattedAddTaskToList}
+ }}";
+ });
+ }
+
private static string FormatConstructor(INamedTypeSymbol type, string className, IEnumerable typeMembers) {
var constructor = typeMembers
.Where(m => m.Kind == SymbolKind.Method)