diff --git a/src/CompositeKey.SourceGeneration.UnitTests/CompilationHelper.cs b/src/CompositeKey.SourceGeneration.UnitTests/CompilationHelper.cs index b17dddf..8f7455b 100644 --- a/src/CompositeKey.SourceGeneration.UnitTests/CompilationHelper.cs +++ b/src/CompositeKey.SourceGeneration.UnitTests/CompilationHelper.cs @@ -19,6 +19,7 @@ public sealed record SourceGeneratorResult public static class CompilationHelper { private static readonly Assembly SystemRuntimeAssembly = Assembly.Load(new AssemblyName("System.Runtime")); + private static readonly Assembly CompositeKeyAssembly = Assembly.Load(new AssemblyName("CompositeKey")); private static readonly CSharpParseOptions DefaultParseOptions = CreateParseOptions(); @@ -33,7 +34,7 @@ public static Compilation CreateCompilation(string source, string assemblyName = MetadataReference.CreateFromFile(typeof(Console).Assembly.Location), MetadataReference.CreateFromFile(SystemRuntimeAssembly.Location), - MetadataReference.CreateFromFile(typeof(CompositeKeyAttribute).Assembly.Location), + MetadataReference.CreateFromFile(CompositeKeyAssembly.Location), ]; return CSharpCompilation.Create( diff --git a/src/CompositeKey.SourceGeneration/CompositeKey.SourceGeneration.csproj b/src/CompositeKey.SourceGeneration/CompositeKey.SourceGeneration.csproj index a7d6177..8ff5586 100644 --- a/src/CompositeKey.SourceGeneration/CompositeKey.SourceGeneration.csproj +++ b/src/CompositeKey.SourceGeneration/CompositeKey.SourceGeneration.csproj @@ -7,6 +7,9 @@ true $(DefineConstants);BUILDING_SOURCE_GENERATOR + + + System.Runtime.CompilerServices.ModuleInitializerAttribute @@ -26,6 +29,10 @@ + + + + diff --git a/src/CompositeKey.SourceGeneration/Emitter.cs b/src/CompositeKey.SourceGeneration/Emitter.cs new file mode 100644 index 0000000..4cdc03b --- /dev/null +++ b/src/CompositeKey.SourceGeneration/Emitter.cs @@ -0,0 +1,1071 @@ +using System.Diagnostics; +using System.Reflection; +using System.Text; +using CompositeKey.SourceGeneration.Core; +using CompositeKey.SourceGeneration.Core.Extensions; +using CompositeKey.SourceGeneration.Model; +using CompositeKey.SourceGeneration.Model.Key; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.Text; + +namespace CompositeKey.SourceGeneration; + +internal sealed class Emitter(SourceProductionContext context) +{ + private static readonly string AssemblyName = typeof(Emitter).Assembly.GetName().Name!; + private static readonly string AssemblyVersion = typeof(Emitter).Assembly.GetCustomAttribute()!.InformationalVersion; + + private const string NotNullWhen = "global::System.Diagnostics.CodeAnalysis.NotNullWhen"; + private const string MaybeNullWhen = "global::System.Diagnostics.CodeAnalysis.MaybeNullWhen"; + private const string InvariantCulture = "global::System.Globalization.CultureInfo.InvariantCulture"; + + private readonly SourceProductionContext _context = context; + + public void Emit(GenerationSpec generationSpec) + { + Debug.Assert(AssemblyName is not null); + Debug.Assert(AssemblyVersion is not null); + + var writer = CreateSourceWriterWithHeader(generationSpec); + + if (generationSpec.Key is PrimaryKeySpec primaryKeySpec) + EmitForPrimaryKey(writer, generationSpec.TargetType, primaryKeySpec); + else if (generationSpec.Key is CompositePrimaryKeySpec compositePrimaryKeySpec) + EmitForCompositePrimaryKey(writer, generationSpec.TargetType, compositePrimaryKeySpec); + + EmitCommonImplementations(writer, generationSpec.TargetType); + writer.EndBlock(); + + foreach (var enumSpec in generationSpec.TargetType.Properties.Select(p => p.EnumSpec).Where(es => es is not null)) + EnumGenerationHelper.EmitEnumHelperClass(writer, enumSpec!); + + string hintName = $"{generationSpec.TargetType.Type.FullyQualifiedName.Replace("global::", string.Empty)}.g.cs"; + AddSource(hintName, CompleteSourceFileAndReturnSourceText(writer)); + } + + private static void EmitForPrimaryKey(SourceWriter writer, TargetTypeSpec targetTypeSpec, PrimaryKeySpec keySpec) + { + var keyParts = keySpec.Parts.ToList(); + + WriteFormatMethodBodyForKeyParts(writer, "public override string ToString()", keyParts, keySpec.InvariantFormatting); + WriteFormatMethodBodyForKeyParts(writer, "public string ToPartitionKeyString()", keyParts, keySpec.InvariantFormatting); + WriteDynamicFormatMethodBodyForKeyParts(writer, "public string ToPartitionKeyString(int throughPartIndex, bool includeTrailingDelimiter = true)", keyParts, keySpec.InvariantFormatting); + + WriteParseMethodImplementation(); + WriteTryParseMethodImplementation(); + + return; + + void WriteParseMethodImplementation() + { + writer.WriteLines($$""" + public static {{targetTypeSpec.TypeName}} Parse(string primaryKey) + { + ArgumentNullException.ThrowIfNull(primaryKey); + + return Parse((ReadOnlySpan)primaryKey); + } + + public static {{targetTypeSpec.TypeName}} Parse(ReadOnlySpan primaryKey) + { + """); + writer.Indentation++; + + WriteLengthCheck(writer, keyParts, "primaryKey", true); + + Func getPrimaryKeyPartInputVariable = static _ => "primaryKey"; + string? primaryKeyPartCountVariable = null; + if (keyParts.Count > 1) + { + WriteSplitImplementation(writer, keyParts, "primaryKey", out string primaryKeyPartRangesVariable, true, out primaryKeyPartCountVariable); + getPrimaryKeyPartInputVariable = indexExpr => $"primaryKey[{primaryKeyPartRangesVariable}[{indexExpr}]]"; + } + + WriteParsePropertiesImplementation(writer, keyParts, getPrimaryKeyPartInputVariable, true, primaryKeyPartCountVariable); + + writer.WriteLine($"return {WriteConstructor(targetTypeSpec)};"); + + writer.EndBlock(); + writer.WriteLine(); + } + + void WriteTryParseMethodImplementation() + { + writer.WriteLines($$""" + public static bool TryParse([{{NotNullWhen}}(true)] string? primaryKey, [{{MaybeNullWhen}}(false)] out {{targetTypeSpec.TypeName}}? result) + { + if (primaryKey is null) + { + result = null; + return false; + } + + return TryParse((ReadOnlySpan)primaryKey, out result); + } + + public static bool TryParse(ReadOnlySpan primaryKey, [{{MaybeNullWhen}}(false)] out {{targetTypeSpec.TypeName}}? result) + { + result = null; + + """); + writer.Indentation++; + + WriteLengthCheck(writer, keyParts, "primaryKey", false); + + Func getPrimaryKeyPartInputVariable = static _ => "primaryKey"; + string? primaryKeyPartCountVariable = null; + if (keyParts.Count > 1) + { + WriteSplitImplementation(writer, keyParts, "primaryKey", out string primaryKeyPartRangesVariable, false, out primaryKeyPartCountVariable); + getPrimaryKeyPartInputVariable = indexExpr => $"primaryKey[{primaryKeyPartRangesVariable}[{indexExpr}]]"; + } + + WriteParsePropertiesImplementation(writer, keyParts, getPrimaryKeyPartInputVariable, false, primaryKeyPartCountVariable); + + writer.WriteLines($""" + result = {WriteConstructor(targetTypeSpec)}; + return true; + """); + + writer.EndBlock(); + writer.WriteLine(); + } + } + + private static void EmitForCompositePrimaryKey(SourceWriter writer, TargetTypeSpec targetTypeSpec, CompositePrimaryKeySpec keySpec) + { + var partitionKeyParts = keySpec.PartitionKeyParts.ToList(); + var sortKeyParts = keySpec.SortKeyParts.ToList(); + + WriteFormatMethodBodyForKeyParts(writer, "public override string ToString()", keySpec.AllParts, keySpec.InvariantFormatting); + WriteFormatMethodBodyForKeyParts(writer, "public string ToPartitionKeyString()", partitionKeyParts, keySpec.InvariantFormatting); + WriteDynamicFormatMethodBodyForKeyParts(writer, "public string ToPartitionKeyString(int throughPartIndex, bool includeTrailingDelimiter = true)", partitionKeyParts, keySpec.InvariantFormatting); + WriteFormatMethodBodyForKeyParts(writer, "public string ToSortKeyString()", sortKeyParts, keySpec.InvariantFormatting); + WriteDynamicFormatMethodBodyForKeyParts(writer, "public string ToSortKeyString(int throughPartIndex, bool includeTrailingDelimiter = true)", sortKeyParts, keySpec.InvariantFormatting); + + WriteParseMethodImplementation(); + WriteTryParseMethodImplementation(); + WriteCompositeParseMethodImplementation(); + WriteCompositeTryParseMethodImplementation(); + + return; + + void WritePrimaryKeySplit(bool shouldThrow) + { + writer.WriteLines($""" + const int expectedPrimaryKeyParts = 2; + Span primaryKeyPartRanges = stackalloc Range[expectedPrimaryKeyParts + 1]; + if (primaryKey.Split(primaryKeyPartRanges, '{keySpec.PrimaryDelimiterKeyPart.Value}', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries) != expectedPrimaryKeyParts) + {(shouldThrow ? "throw new FormatException(\"Unrecognized format.\")" : "return false")}; + + """); + } + + void WriteParseMethodImplementation() + { + writer.WriteLines($$""" + public static {{targetTypeSpec.TypeName}} Parse(string primaryKey) + { + ArgumentNullException.ThrowIfNull(primaryKey); + + return Parse((ReadOnlySpan)primaryKey); + } + + public static {{targetTypeSpec.TypeName}} Parse(ReadOnlySpan primaryKey) + { + """); + writer.Indentation++; + + WritePrimaryKeySplit(true); + + writer.WriteLine("return Parse(primaryKey[primaryKeyPartRanges[0]], primaryKey[primaryKeyPartRanges[1]]);"); + + writer.EndBlock(); + writer.WriteLine(); + } + + void WriteTryParseMethodImplementation() + { + writer.WriteLines($$""" + public static bool TryParse([{{NotNullWhen}}(true)] string? primaryKey, [{{MaybeNullWhen}}(false)] out {{targetTypeSpec.TypeName}}? result) + { + if (primaryKey is null) + { + result = null; + return false; + } + + return TryParse((ReadOnlySpan)primaryKey, out result); + } + + public static bool TryParse(ReadOnlySpan primaryKey, [{{MaybeNullWhen}}(false)] out {{targetTypeSpec.TypeName}}? result) + { + result = null; + + """); + writer.Indentation++; + + WritePrimaryKeySplit(false); + + writer.WriteLine("return TryParse(primaryKey[primaryKeyPartRanges[0]], primaryKey[primaryKeyPartRanges[1]], out result);"); + + writer.EndBlock(); + writer.WriteLine(); + } + + void WriteCompositeParseMethodImplementation() + { + writer.WriteLines($$""" + public static {{targetTypeSpec.TypeName}} Parse(string partitionKey, string sortKey) + { + ArgumentNullException.ThrowIfNull(partitionKey); + ArgumentNullException.ThrowIfNull(sortKey); + + return Parse((ReadOnlySpan)partitionKey, (ReadOnlySpan)sortKey); + } + + public static {{targetTypeSpec.TypeName}} Parse(ReadOnlySpan partitionKey, ReadOnlySpan sortKey) + { + """); + writer.Indentation++; + + WriteLengthCheck(writer, partitionKeyParts, "partitionKey", true); + WriteLengthCheck(writer, sortKeyParts, "sortKey", true); + + Func getPartitionKeyPartInputVariable = static _ => "partitionKey"; + string? partitionKeyPartCountVariable = null; + if (partitionKeyParts.Count > 1) + { + WriteSplitImplementation(writer, partitionKeyParts, "partitionKey", out string partitionKeyPartRangesVariable, true, out partitionKeyPartCountVariable); + getPartitionKeyPartInputVariable = indexExpr => $"partitionKey[{partitionKeyPartRangesVariable}[{indexExpr}]]"; + } + + Func getSortKeyPartInputVariable = static _ => "sortKey"; + string? sortKeyPartCountVariable = null; + if (sortKeyParts.Count > 1) + { + WriteSplitImplementation(writer, sortKeyParts, "sortKey", out string sortKeyPartRangesVariable, true, out sortKeyPartCountVariable); + getSortKeyPartInputVariable = indexExpr => $"sortKey[{sortKeyPartRangesVariable}[{indexExpr}]]"; + } + + var propertyNameCounts = partitionKeyParts.Concat(sortKeyParts).OfType().GroupBy(p => p.Property.CamelCaseName).ToDictionary(g => g.Key, _ => 0); + WriteParsePropertiesImplementation(writer, partitionKeyParts, getPartitionKeyPartInputVariable, true, propertyNameCounts, partitionKeyPartCountVariable); + WriteParsePropertiesImplementation(writer, sortKeyParts, getSortKeyPartInputVariable, true, propertyNameCounts, sortKeyPartCountVariable); + + writer.WriteLine($"return {WriteConstructor(targetTypeSpec)};"); + + writer.EndBlock(); + writer.WriteLine(); + } + + void WriteCompositeTryParseMethodImplementation() + { + writer.WriteLines($$""" + public static bool TryParse([{{NotNullWhen}}(true)] string partitionKey, string sortKey, [{{MaybeNullWhen}}(false)] out {{targetTypeSpec.TypeName}}? result) + { + if (partitionKey is null || sortKey is null) + { + result = null; + return false; + } + + return TryParse((ReadOnlySpan)partitionKey, (ReadOnlySpan)sortKey, out result); + } + + public static bool TryParse(ReadOnlySpan partitionKey, ReadOnlySpan sortKey, [{{MaybeNullWhen}}(false)] out {{targetTypeSpec.TypeName}}? result) + { + result = null; + + """); + writer.Indentation++; + + WriteLengthCheck(writer, partitionKeyParts, "partitionKey", false); + WriteLengthCheck(writer, sortKeyParts, "sortKey", false); + + Func getPartitionKeyPartInputVariable = static _ => "partitionKey"; + string? partitionKeyPartCountVariable = null; + if (partitionKeyParts.Count > 1) + { + WriteSplitImplementation(writer, partitionKeyParts, "partitionKey", out string partitionKeyPartRangesVariable, false, out partitionKeyPartCountVariable); + getPartitionKeyPartInputVariable = indexExpr => $"partitionKey[{partitionKeyPartRangesVariable}[{indexExpr}]]"; + } + + Func getSortKeyPartInputVariable = static _ => "sortKey"; + string? sortKeyPartCountVariable = null; + if (sortKeyParts.Count > 1) + { + WriteSplitImplementation(writer, sortKeyParts, "sortKey", out string sortKeyPartRangesVariable, false, out sortKeyPartCountVariable); + getSortKeyPartInputVariable = indexExpr => $"sortKey[{sortKeyPartRangesVariable}[{indexExpr}]]"; + } + + var propertyNameCounts = partitionKeyParts.Concat(sortKeyParts).OfType().GroupBy(p => p.Property.CamelCaseName).ToDictionary(g => g.Key, _ => 0); + WriteParsePropertiesImplementation(writer, partitionKeyParts, getPartitionKeyPartInputVariable, false, propertyNameCounts, partitionKeyPartCountVariable); + WriteParsePropertiesImplementation(writer, sortKeyParts, getSortKeyPartInputVariable, false, propertyNameCounts, sortKeyPartCountVariable); + + writer.WriteLines($""" + result = {WriteConstructor(targetTypeSpec)}; + return true; + """); + + writer.EndBlock(); + writer.WriteLine(); + } + } + + private static void EmitCommonImplementations(SourceWriter writer, TargetTypeSpec targetTypeSpec) + { + writer.WriteLines($$""" + /// + string IFormattable.ToString(string? format, IFormatProvider? formatProvider) => ToString(); + + /// + static {{targetTypeSpec.TypeName}} IParsable<{{targetTypeSpec.TypeName}}>.Parse(string s, IFormatProvider? provider) => Parse(s); + + /// + static bool IParsable<{{targetTypeSpec.TypeName}}>.TryParse([{{NotNullWhen}}(true)] string? s, IFormatProvider? provider, [{{MaybeNullWhen}}(false)] out {{targetTypeSpec.TypeName}} result) => TryParse(s, out result); + + /// + static {{targetTypeSpec.TypeName}} ISpanParsable<{{targetTypeSpec.TypeName}}>.Parse(ReadOnlySpan s, IFormatProvider? provider) => Parse(s); + + /// + static bool ISpanParsable<{{targetTypeSpec.TypeName}}>.TryParse(ReadOnlySpan s, IFormatProvider? provider, [{{MaybeNullWhen}}(false)] out {{targetTypeSpec.TypeName}} result) => TryParse(s, out result); + """); + } + + private static void WriteLengthCheck(SourceWriter writer, List parts, string inputName, bool shouldThrow) + { + int lengthRequired = parts.Select(p => p.LengthRequired).Sum(); + bool exactLengthRequirement = parts.All(p => p.ExactLengthRequirement); + + writer.WriteLines($""" + if ({inputName}.Length {(exactLengthRequirement ? "!=" : "<")} {lengthRequired}) + {(shouldThrow ? "throw new FormatException(\"Unrecognized format.\")" : "return false")}; + + """); + } + + private static void WriteSplitImplementation(SourceWriter writer, List parts, string inputName, out string partRangesVariableName, bool shouldThrow, out string? partCountVariableName) + { + var repeatingPart = parts.OfType().FirstOrDefault(); + var uniqueDelimiters = parts.OfType().Select(d => d.Value).Distinct().ToList(); + + partRangesVariableName = $"{inputName}PartRanges"; + partCountVariableName = null; + + if (repeatingPart is not null) + { + bool sameSeparator = uniqueDelimiters.Contains(repeatingPart.Separator); + + if (sameSeparator) + { + // Same separator as key delimiters: split produces variable number of parts + int fixedValueParts = parts.OfType().Count(p => p is not RepeatingPropertyKeyPart); + + (string method, string delimiters) = uniqueDelimiters switch + { + { Count: 1 } => ("Split", $"'{uniqueDelimiters[0]}'"), + { Count: > 1 } => ("SplitAny", $"\"{string.Join(string.Empty, uniqueDelimiters)}\""), + _ => throw new InvalidOperationException() + }; + + string minPartsVariable = $"minExpected{inputName.FirstToUpperInvariant()}Parts"; + partCountVariableName = $"{inputName}PartCount"; + + writer.WriteLines($""" + const int {minPartsVariable} = {fixedValueParts + 1}; + Span {partRangesVariableName} = stackalloc Range[128]; + int {partCountVariableName} = {inputName}.{method}({partRangesVariableName}, {delimiters}, StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries); + if ({partCountVariableName} < {minPartsVariable}) + {(shouldThrow ? "throw new FormatException(\"Unrecognized format.\")" : "return false")}; + + """); + } + else + { + // Different separator: split by fixed delimiters, last part contains the repeating section + int expectedParts = parts.OfType().Count(); + + (string method, string delimiters) = uniqueDelimiters switch + { + { Count: 1 } => ("Split", $"'{uniqueDelimiters[0]}'"), + { Count: > 1 } => ("SplitAny", $"\"{string.Join(string.Empty, uniqueDelimiters)}\""), + _ => throw new InvalidOperationException() + }; + + string expectedPartsVariableName = $"expected{inputName.FirstToUpperInvariant()}Parts"; + + writer.WriteLines($""" + const int {expectedPartsVariableName} = {expectedParts}; + Span {partRangesVariableName} = stackalloc Range[{expectedPartsVariableName} + 1]; + if ({inputName}.{method}({partRangesVariableName}, {delimiters}, StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries) != {expectedPartsVariableName}) + {(shouldThrow ? "throw new FormatException(\"Unrecognized format.\")" : "return false")}; + + """); + } + } + else + { + int expectedParts = parts.OfType().Count(); + + (string method, string delimiters) = uniqueDelimiters switch + { + { Count: 1 } => ("Split", $"'{uniqueDelimiters[0]}'"), + { Count: > 1 } => ("SplitAny", $"\"{string.Join(string.Empty, uniqueDelimiters)}\""), + _ => throw new InvalidOperationException() + }; + + string expectedPartsVariableName = $"expected{inputName.FirstToUpperInvariant()}Parts"; + + writer.WriteLines($""" + const int {expectedPartsVariableName} = {expectedParts}; + Span {partRangesVariableName} = stackalloc Range[{expectedPartsVariableName} + 1]; + if ({inputName}.{method}({partRangesVariableName}, {delimiters}, StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries) != {expectedPartsVariableName}) + {(shouldThrow ? "throw new FormatException(\"Unrecognized format.\")" : "return false")}; + + """); + } + } + + private static void WriteParsePropertiesImplementation( + SourceWriter writer, List parts, Func getPartInputVariable, bool shouldThrow, string? inputPartCountVariable = null) + { + var propertyNameCounts = parts.OfType().GroupBy(p => p.Property.CamelCaseName).ToDictionary(g => g.Key, _ => 0); + WriteParsePropertiesImplementation(writer, parts, getPartInputVariable, shouldThrow, propertyNameCounts, inputPartCountVariable); + } + + private static void WriteParsePropertiesImplementation( + SourceWriter writer, List parts, Func getPartInputVariable, bool shouldThrow, Dictionary propertyNameCounts, string? inputPartCountVariable = null) + { + var valueParts = parts.OfType().ToArray(); + for (int i = 0; i < valueParts.Length; i++) + { + var valueKeyPart = valueParts[i]; + string partInputVariable = getPartInputVariable($"{i}"); + + if (valueKeyPart is ConstantKeyPart c) + { + writer.WriteLines($""" + if (!{partInputVariable}.Equals("{c.Value}", StringComparison.Ordinal)) + {(shouldThrow ? "throw new FormatException(\"Unrecognized format.\")" : "return false")}; + + """); + continue; + } + + if (valueKeyPart is RepeatingPropertyKeyPart repeatingPart) + { + WriteRepeatingPropertyParse(repeatingPart, i); + continue; + } + + (string camelCaseName, string? originalCamelCaseName) = valueKeyPart is PropertyKeyPart propertyPart + ? GetCamelCaseName(propertyPart.Property, propertyNameCounts) + : throw new InvalidOperationException($"Expected a {nameof(PropertyKeyPart)} but got a {valueKeyPart.GetType().Name}"); + + switch (valueKeyPart) + { + case PropertyKeyPart { ParseType: ParseType.Guid } part: + writer.WriteLines($""" + if ({ToStrictLengthCheck(part, partInputVariable)}!Guid.TryParseExact({partInputVariable}, "{part.Format}", out var {camelCaseName})) + {(shouldThrow ? "throw new FormatException(\"Unrecognized format.\")" : "return false")}; + + """); + break; + + case PropertyKeyPart { ParseType: ParseType.String }: + writer.WriteLines($""" + if ({partInputVariable}.Length == 0) + {(shouldThrow ? "throw new FormatException(\"Unrecognized format.\")" : "return false")}; + + string {camelCaseName} = {partInputVariable}.ToString(); + + """); + break; + + case PropertyKeyPart { ParseType: ParseType.Enum } part: + if (part.Property.EnumSpec is null) + throw new InvalidOperationException($"{nameof(part.Property.EnumSpec)} is null"); + + writer.WriteLines($""" + if (!{part.Property.EnumSpec.Name}Helper.TryParse({partInputVariable}, out var {camelCaseName})) + {(shouldThrow ? "throw new FormatException(\"Unrecognized format.\")" : "return false")}; + + """); + break; + + case PropertyKeyPart { ParseType: ParseType.SpanParsable } part: + writer.WriteLines($""" + if (!{part.Property.Type.FullyQualifiedName}.TryParse({partInputVariable}, out var {camelCaseName})) + {(shouldThrow ? "throw new FormatException(\"Unrecognized format.\")" : "return false")}; + + """); + break; + } + + if (originalCamelCaseName is not null) + { + writer.WriteLines($""" + if (!{originalCamelCaseName}.Equals({camelCaseName})) + {(shouldThrow ? "throw new FormatException(\"Unrecognized format.\")" : "return false")}; + + """); + } + } + + return; + + static (string camelCaseName, string? originalCamelCaseName) GetCamelCaseName(PropertySpec property, Dictionary propertyNameCounts) + { + int propertyCount = propertyNameCounts[property.CamelCaseName]++; + return propertyCount == 0 + ? (property.CamelCaseName, null) + : ($"{property.CamelCaseName}{propertyCount}", property.CamelCaseName); + } + + static string ToStrictLengthCheck(KeyPart part, string input) => + part.ExactLengthRequirement ? $"{input}.Length != {part.LengthRequired} || " : string.Empty; + + void WriteRepeatingPropertyParse(RepeatingPropertyKeyPart repeatingPart, int valuePartIndex) + { + string camelCaseName = repeatingPart.Property.CamelCaseName; + string innerTypeName = repeatingPart.InnerType.FullyQualifiedName; + var uniqueDelimiters = parts.OfType().Select(d => d.Value).Distinct().ToList(); + bool sameSeparator = uniqueDelimiters.Contains(repeatingPart.Separator); + + string itemVar = $"{camelCaseName}Item"; + string listVar = camelCaseName; + + if (sameSeparator && inputPartCountVariable is not null) + { + // Same separator: repeating items are at indices valuePartIndex..partCount-1 + writer.WriteLines($""" + var {listVar} = new global::System.Collections.Generic.List<{innerTypeName}>(); + """); + + writer.StartBlock($"for (int ri = {valuePartIndex}; ri < {inputPartCountVariable}; ri++)"); + + string riAccess = getPartInputVariable("ri"); + + WriteRepeatingItemParse(repeatingPart, riAccess, itemVar, listVar); + + writer.EndBlock(); + writer.WriteLine(); + } + else + { + // Different separator: sub-split the part by the repeating separator + string partInputVariable = getPartInputVariable($"{valuePartIndex}"); + string repeatingRangesVar = $"{camelCaseName}Ranges"; + string repeatingCountVar = $"{camelCaseName}Count"; + + writer.WriteLines($""" + Span {repeatingRangesVar} = stackalloc Range[128]; + int {repeatingCountVar} = {partInputVariable}.Split({repeatingRangesVar}, '{repeatingPart.Separator}', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries); + if ({repeatingCountVar} < 1) + {(shouldThrow ? "throw new FormatException(\"Unrecognized format.\")" : "return false")}; + + var {listVar} = new global::System.Collections.Generic.List<{innerTypeName}>(); + """); + + writer.StartBlock($"for (int ri = 0; ri < {repeatingCountVar}; ri++)"); + + string riAccess = $"{partInputVariable}[{repeatingRangesVar}[ri]]"; + WriteRepeatingItemParse(repeatingPart, riAccess, itemVar, listVar); + + writer.EndBlock(); + writer.WriteLine(); + } + + // Validate at least 1 item + writer.WriteLines($""" + if ({listVar}.Count == 0) + {(shouldThrow ? "throw new FormatException(\"Unrecognized format.\")" : "return false")}; + + """); + } + + void WriteRepeatingItemParse(RepeatingPropertyKeyPart repeatingPart, string itemInput, string itemVar, string listVar) + { + string innerTypeName = repeatingPart.InnerType.FullyQualifiedName; + + switch (repeatingPart.InnerParseType) + { + case ParseType.Guid: + writer.WriteLines($""" + if (!Guid.TryParseExact({itemInput}, "{repeatingPart.Format}", out var {itemVar})) + {(shouldThrow ? "throw new FormatException(\"Unrecognized format.\")" : "return false")}; + {listVar}.Add({itemVar}); + """); + break; + + case ParseType.String: + writer.WriteLines($""" + if ({itemInput}.Length == 0) + {(shouldThrow ? "throw new FormatException(\"Unrecognized format.\")" : "return false")}; + {listVar}.Add({itemInput}.ToString()); + """); + break; + + case ParseType.Enum: + if (repeatingPart.Property.EnumSpec is null) + throw new InvalidOperationException($"{nameof(repeatingPart.Property.EnumSpec)} is null"); + + writer.WriteLines($""" + if (!{repeatingPart.Property.EnumSpec.Name}Helper.TryParse({itemInput}, out var {itemVar})) + {(shouldThrow ? "throw new FormatException(\"Unrecognized format.\")" : "return false")}; + {listVar}.Add({itemVar}); + """); + break; + + case ParseType.SpanParsable: + writer.WriteLines($""" + if (!{innerTypeName}.TryParse({itemInput}, out var {itemVar})) + {(shouldThrow ? "throw new FormatException(\"Unrecognized format.\")" : "return false")}; + {listVar}.Add({itemVar}); + """); + break; + } + } + } + + private static string WriteConstructor(TargetTypeSpec targetTypeSpec) + { + var builder = new StringBuilder(); + builder.Append($"new {targetTypeSpec.TypeName}("); + + if (targetTypeSpec.ConstructorParameters.Count > 0) + { + foreach (var parameter in targetTypeSpec.ConstructorParameters) + { + var property = targetTypeSpec.Properties.FirstOrDefault(p => p.CamelCaseName == parameter.CamelCaseName); + builder.Append(property?.CollectionType == CollectionType.ImmutableArray + ? $"global::System.Collections.Immutable.ImmutableArray.CreateRange({parameter.CamelCaseName}), " + : $"{parameter.CamelCaseName}, "); + } + + builder.Length -= 2; // Remove the last ", " + } + + builder.Append(')'); + + if (targetTypeSpec.PropertyInitializers.Count > 0) + { + builder.Append(" { "); + + foreach (var initializer in targetTypeSpec.PropertyInitializers) + { + var property = targetTypeSpec.Properties.FirstOrDefault(p => p.CamelCaseName == initializer.CamelCaseName); + builder.Append(property?.CollectionType == CollectionType.ImmutableArray + ? $"{initializer.Name} = global::System.Collections.Immutable.ImmutableArray.CreateRange({initializer.CamelCaseName}), " + : $"{initializer.Name} = {initializer.CamelCaseName}, "); + } + + builder.Length -= 2; // Remove the last ", " + builder.Append(" }"); + } + + return builder.ToString(); + } + + private static void WriteFormatMethodBodyForKeyParts( + SourceWriter writer, string methodDeclaration, IReadOnlyList keyParts, bool invariantFormatting) + { + writer.StartBlock(methodDeclaration); + + bool hasRepeatingPart = keyParts.Any(kp => kp is RepeatingPropertyKeyPart); + + if (hasRepeatingPart) + { + WriteRepeatingFormatBody(); + } + else if (keyParts.All(kp => kp is + DelimiterKeyPart + or ConstantKeyPart + or PropertyKeyPart { FormatType: FormatType.Guid, ExactLengthRequirement: true } + or PropertyKeyPart { FormatType: FormatType.Enum, Format: "g" } + or PropertyKeyPart { FormatType: FormatType.String })) + { + string lengthRequired = keyParts + .Where(kp => kp.ExactLengthRequirement) + .Sum(kp => kp switch + { + DelimiterKeyPart => 1, + ConstantKeyPart c => c.Value.Length, + PropertyKeyPart p => p.LengthRequired, + _ => throw new InvalidOperationException() + }) + .ToString(); + + foreach (var keyPart in keyParts.Where(kp => !kp.ExactLengthRequirement)) + { + if (lengthRequired.Length != 0) + lengthRequired += " + "; + + lengthRequired += keyPart switch + { + PropertyKeyPart { FormatType: FormatType.Enum, Property.EnumSpec: not null } p => $"{p.Property.EnumSpec.Name}Helper.GetFormattedLength({p.Property.Name})", + PropertyKeyPart { FormatType: FormatType.String } p => $"{p.Property.Name}.Length", + _ => throw new InvalidOperationException() + }; + } + + writer.StartBlock($"return string.Create({lengthRequired}, this, static (destination, state) =>"); + + writer.WriteLine("int position = 0;"); + writer.WriteLine(); + + for (int i = 0; i < keyParts.Count; i++) + { + var keyPart = keyParts[i]; + switch (keyPart) + { + case DelimiterKeyPart d: + writer.WriteLine($"destination[position] = '{d.Value}';"); + writer.WriteLine("position += 1;"); + break; + case ConstantKeyPart c: + writer.WriteLine($"\"{c.Value}\".CopyTo(destination[position..]);"); + writer.WriteLine($"position += {c.Value.Length};"); + break; + case PropertyKeyPart { FormatType: FormatType.Guid } p: + string formatProvider = invariantFormatting ? InvariantCulture : "null"; + writer.StartBlock(); + writer.WriteLine($"if (!((ISpanFormattable)state.{p.Property.Name}).TryFormat(destination[position..], out int {GetCharsWritten(p.Property)}, \"{p.Format ?? "d"}\", {formatProvider}))"); + writer.WriteLine("\tthrow new FormatException();\n"); + writer.WriteLine($"position += {GetCharsWritten(p.Property)};"); + writer.EndBlock(); + break; + case PropertyKeyPart { FormatType: FormatType.Enum, Property.EnumSpec: not null } p: + writer.StartBlock(); + writer.WriteLine($"if (!{p.Property.EnumSpec.Name}Helper.TryFormat(state.{p.Property.Name}, destination[position..], out int {GetCharsWritten(p.Property)}))"); + writer.WriteLine("\tthrow new FormatException();\n"); + writer.WriteLine($"position += {GetCharsWritten(p.Property)};"); + writer.EndBlock(); + break; + case PropertyKeyPart { FormatType: FormatType.String } p: + writer.WriteLine($"state.{p.Property.Name}.CopyTo(destination[position..]);"); + writer.WriteLine($"position += state.{p.Property.Name}.Length;"); + break; + default: + throw new InvalidOperationException(); + } + + if (i != keyParts.Count - 1) + writer.WriteLine(); + } + + writer.Indentation--; + writer.WriteLine("});"); + } + else + { + string formatString = BuildFormatStringForKeyParts(keyParts); + writer.WriteLine(invariantFormatting + ? $"return string.Create({InvariantCulture}, $\"{formatString}\");" + : $"return $\"{formatString}\";"); + } + + writer.EndBlock(); + writer.WriteLine(); + + return; + + static string GetCharsWritten(PropertySpec p) => $"{p.CamelCaseName}CharsWritten"; + + void WriteRepeatingFormatBody() + { + // Emit empty collection checks for all repeating parts + foreach (var keyPart in keyParts.OfType()) + { + string countExpression = GetRepeatingCountExpression(keyPart.Property); + + writer.WriteLines($""" + if ({countExpression} == 0) + throw new FormatException("Collection must contain at least one item."); + + """); + } + + // Count fixed literal lengths and variable parts for DefaultInterpolatedStringHandler + int fixedLiteralLength = 0; + int formattedCount = 0; + foreach (var keyPart in keyParts) + { + switch (keyPart) + { + case DelimiterKeyPart: + fixedLiteralLength += 1; + break; + case ConstantKeyPart c: + fixedLiteralLength += c.Value.Length; + break; + case PropertyKeyPart: + formattedCount++; + break; + case RepeatingPropertyKeyPart: + // Will be handled dynamically in the loop + break; + } + } + + string formatProvider = invariantFormatting ? InvariantCulture : "null"; + + writer.WriteLines($""" + var handler = new System.Runtime.CompilerServices.DefaultInterpolatedStringHandler({fixedLiteralLength}, {formattedCount}, {formatProvider}); + """); + + foreach (var keyPart in keyParts) + { + switch (keyPart) + { + case DelimiterKeyPart d: + writer.WriteLine($"handler.AppendLiteral(\"{d.Value}\");"); + break; + case ConstantKeyPart c: + writer.WriteLine($"handler.AppendLiteral(\"{c.Value}\");"); + break; + case PropertyKeyPart p: + if (p.Format is not null) + writer.WriteLine($"handler.AppendFormatted({p.Property.Name}, \"{p.Format}\");"); + else + writer.WriteLine($"handler.AppendFormatted({p.Property.Name});"); + break; + case RepeatingPropertyKeyPart rp: + WriteRepeatingPartFormatLoop(rp); + break; + } + } + + writer.WriteLine(); + writer.WriteLine("return handler.ToStringAndClear();"); + } + + void WriteRepeatingPartFormatLoop(RepeatingPropertyKeyPart rp) + { + string countExpression = GetRepeatingCountExpression(rp.Property); + + writer.StartBlock($"for (int i = 0; i < {countExpression}; i++)"); + + writer.WriteLines($""" + if (i > 0) + handler.AppendLiteral("{rp.Separator}"); + + """); + + if (rp.Format is not null) + writer.WriteLine($"handler.AppendFormatted({rp.Property.Name}[i], \"{rp.Format}\");"); + else + writer.WriteLine($"handler.AppendFormatted({rp.Property.Name}[i]);"); + + writer.EndBlock(); + } + } + + private static void WriteDynamicFormatMethodBodyForKeyParts( + SourceWriter writer, string methodDeclaration, IReadOnlyList keyParts, bool invariantFormatting) + { + var repeatingPart = keyParts.OfType().FirstOrDefault(); + + if (repeatingPart is null) + { + WriteDynamicFormatMethodBodyForFixedKeyParts(writer, methodDeclaration, keyParts, invariantFormatting); + return; + } + + // Find the index of the repeating part and count fixed value parts before it + int repeatingKeyPartIndex = keyParts.ToList().IndexOf(repeatingPart); + int fixedPartCount = keyParts.Take(repeatingKeyPartIndex).OfType().Count(); + var fixedKeyParts = keyParts.Take(repeatingKeyPartIndex).ToList(); + + writer.StartBlock(methodDeclaration); + + WriteFixedPartCases(); + WriteRepeatingPartHandler(); + + writer.EndBlock(); // end method + writer.WriteLine(); + + return; + + void WriteFixedPartCases() + { + if (fixedKeyParts.Count == 0) + return; + + writer.StartBlock("switch (throughPartIndex, includeTrailingDelimiter)"); + + for (int i = 0, keyPartIndex = -1; i < fixedKeyParts.Count; i++) + { + var keyPart = fixedKeyParts[i]; + + bool isDelimiter = keyPart is DelimiterKeyPart; + if (!isDelimiter) + keyPartIndex++; + + string switchCase = $"case ({keyPartIndex}, {(isDelimiter ? "true" : "false")}):"; + string formatString = BuildFormatStringForKeyParts(fixedKeyParts.Take(i + 1)); + + writer.WriteLine(invariantFormatting + ? $"{switchCase} return string.Create({InvariantCulture}, $\"{formatString}\");" + : $"{switchCase} return $\"{formatString}\";"); + } + + writer.EndBlock(); + writer.WriteLine(); + } + + void WriteRepeatingPartHandler() + { + string propName = repeatingPart.Property.Name; + char separator = repeatingPart.Separator; + string? format = repeatingPart.Format; + string countExpression = GetRepeatingCountExpression(repeatingPart.Property); + + writer.WriteLines($""" + int fixedPartCount = {fixedPartCount}; + int repeatIndex = throughPartIndex - fixedPartCount; + int repeatCount = Math.Min(repeatIndex + 1, {countExpression}); + if (repeatCount <= 0) + throw new InvalidOperationException("Invalid throughPartIndex for repeating section."); + + """); + + string fixedPrefix = BuildFormatStringForKeyParts(fixedKeyParts); + + writer.WriteLines($$""" + var handler = new System.Runtime.CompilerServices.DefaultInterpolatedStringHandler(0, 0{{(invariantFormatting ? $", {InvariantCulture}" : "")}}); + """); + + if (fixedPrefix.Length > 0) + { + writer.WriteLine(invariantFormatting + ? $"handler.AppendFormatted(string.Create({InvariantCulture}, $\"{fixedPrefix}\"));" + : $"handler.AppendFormatted($\"{fixedPrefix}\");"); + } + + writer.WriteLine(); + + writer.StartBlock("for (int i = 0; i < repeatCount; i++)"); + + writer.StartBlock("if (i > 0)"); + writer.WriteLine($"handler.AppendLiteral(\"{separator}\");"); + writer.EndBlock(); + + writer.WriteLine(); + + if (format is not null) + writer.WriteLine($"handler.AppendFormatted({propName}[i], \"{format}\");"); + else + writer.WriteLine($"handler.AppendFormatted({propName}[i]);"); + + writer.EndBlock(); // end for loop + writer.WriteLine(); + + writer.StartBlock("if (includeTrailingDelimiter)"); + writer.WriteLine($"handler.AppendLiteral(\"{separator}\");"); + writer.EndBlock(); + writer.WriteLine(); + + writer.WriteLine("return handler.ToStringAndClear();"); + } + } + + private static void WriteDynamicFormatMethodBodyForFixedKeyParts( + SourceWriter writer, string methodDeclaration, IReadOnlyList keyParts, bool invariantFormatting) + { + writer.StartBlock(methodDeclaration); + + writer.StartBlock("return (throughPartIndex, includeTrailingDelimiter) switch"); + + for (int i = 0, keyPartIndex = -1; i < keyParts.Count; i++) + { + var keyPart = keyParts[i]; + + bool isDelimiter = keyPart is DelimiterKeyPart; + if (!isDelimiter) + keyPartIndex++; + + string switchCase = $"({keyPartIndex}, {(isDelimiter ? "true" : "false")}) =>"; + string formatString = BuildFormatStringForKeyParts(keyParts.Take(i + 1)); + + writer.WriteLine(invariantFormatting + ? $"{switchCase} string.Create({InvariantCulture}, $\"{formatString}\")," + : $"{switchCase} $\"{formatString}\","); + } + + writer.WriteLine("_ => throw new InvalidOperationException(\"Invalid combination of throughPartIndex and includeTrailingDelimiter provided\")"); + + writer.EndBlock(withSemicolon: true); + + writer.EndBlock(); + writer.WriteLine(); + } + + private static string GetRepeatingCountExpression(PropertySpec property) => + property.CollectionType == CollectionType.ImmutableArray + ? $"{property.Name}.Length" + : $"{property.Name}.Count"; + + private static string BuildFormatStringForKeyParts(IEnumerable keyParts) + { + var builder = new StringBuilder(); + foreach (var keyPart in keyParts) + { + builder.Append(keyPart switch + { + DelimiterKeyPart d => d.Value, + ConstantKeyPart c => c.Value, + PropertyKeyPart p => $"{{{p.Property.Name}{(p.Format is not null ? $":{p.Format}" : string.Empty)}}}", + _ => throw new InvalidOperationException() + }); + } + + return builder.ToString(); + } + + private void AddSource(string hintName, SourceText sourceText) => _context.AddSource(hintName, sourceText); + + private static SourceWriter CreateSourceWriterWithHeader(GenerationSpec generationSpec) + { + var writer = new SourceWriter(); + + writer.WriteLines(""" + // + + #nullable enable annotations + #nullable disable warnings + + // Suppress warnings about [Obsolete] member usage in generated code. + #pragma warning disable CS0612, CS0618 + + using System; + using CompositeKey; + + """); + + if (generationSpec.TargetType.Namespace is not null) + writer.StartBlock($"namespace {generationSpec.TargetType.Namespace}"); + + var nestedTypeDeclarations = generationSpec.TargetType.TypeDeclarations; + Debug.Assert(nestedTypeDeclarations.Count > 0); + + for (int i = nestedTypeDeclarations.Count - 1; i > 0; i--) + writer.StartBlock(nestedTypeDeclarations[i]); + + // Annotate the context class with the GeneratedCodeAttribute + writer.WriteLine($"""[global::System.CodeDom.Compiler.GeneratedCodeAttribute("{AssemblyName}", "{AssemblyVersion}")]"""); + + // Emit the main class declaration + writer.StartBlock($"{nestedTypeDeclarations[0]} : {(generationSpec.Key is PrimaryKeySpec ? "IPrimaryKey" : "ICompositePrimaryKey")}<{generationSpec.TargetType.TypeName}>"); + + return writer; + } + + private static SourceText CompleteSourceFileAndReturnSourceText(SourceWriter writer) + { + while (writer.Indentation > 0) + writer.EndBlock(); + + return writer.ToSourceText(); + } +} diff --git a/src/CompositeKey.SourceGeneration/Parser.cs b/src/CompositeKey.SourceGeneration/Parser.cs new file mode 100644 index 0000000..99bd883 --- /dev/null +++ b/src/CompositeKey.SourceGeneration/Parser.cs @@ -0,0 +1,625 @@ +using System.Diagnostics; +using CompositeKey.Analyzers.Common.Diagnostics; +using CompositeKey.Analyzers.Common.Tokenization; +using CompositeKey.Analyzers.Common.Validation; +using CompositeKey.SourceGeneration.Core; +using CompositeKey.SourceGeneration.Core.Extensions; +using CompositeKey.SourceGeneration.Model; +using CompositeKey.SourceGeneration.Model.Key; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; + +namespace CompositeKey.SourceGeneration; + +public sealed record CompositeKeyAttributeValues(string TemplateString, char? PrimaryKeySeparator, bool InvariantCulture); + +internal sealed class Parser(KnownTypeSymbols knownTypeSymbols) +{ + private const LanguageVersion MinimumSupportedLanguageVersion = LanguageVersion.CSharp11; + + private readonly KnownTypeSymbols _knownTypeSymbols = knownTypeSymbols; + private readonly List _diagnostics = []; + + private Location? _location; + + public ImmutableEquatableArray Diagnostics => _diagnostics.ToImmutableEquatableArray(); + + public void ReportDiagnostic(DiagnosticDescriptor descriptor, Location? location, params object?[]? messageArgs) + { + Debug.Assert(_location != null); + + if (location is null || (location.SourceTree is not null && !_knownTypeSymbols.Compilation.ContainsSyntaxTree(location.SourceTree))) + location = _location; + + _diagnostics.Add(DiagnosticInfo.Create(descriptor, location, messageArgs)); + } + + public GenerationSpec? Parse( + TypeDeclarationSyntax typeDeclarationSyntax, SemanticModel semanticModel, CancellationToken cancellationToken) + { + var targetTypeSymbol = semanticModel.GetDeclaredSymbol(typeDeclarationSyntax, cancellationToken); + Debug.Assert(targetTypeSymbol != null); + + _location = targetTypeSymbol!.Locations.Length > 0 ? targetTypeSymbol.Locations[0] : null; + Debug.Assert(_location is not null); + + var languageVersion = _knownTypeSymbols.Compilation is CSharpCompilation csc ? csc.LanguageVersion : (LanguageVersion?)null; + if (languageVersion is null or < MinimumSupportedLanguageVersion) + { + ReportDiagnostic(DiagnosticDescriptors.UnsupportedLanguageVersion, _location, languageVersion?.ToDisplayString(), MinimumSupportedLanguageVersion.ToDisplayString()); + return null; + } + + // Validate type structure using comprehensive shared validation + var validationResult = TypeValidation.ValidateTypeForCompositeKey( + targetTypeSymbol, + typeDeclarationSyntax, + semanticModel, + _knownTypeSymbols.CompositeKeyConstructorAttributeType, + cancellationToken); + + if (!validationResult.IsSuccess) + { + ReportDiagnostic(validationResult.Descriptor, _location, validationResult.MessageArgs); + return null; + } + + // Use validated data from the validation result (guaranteed non-null due to MemberNotNullWhen on IsSuccess) + var targetTypeDeclarations = validationResult.TargetTypeDeclarations; + var constructor = validationResult.Constructor; + + var compositeKeyAttributeValues = ParseCompositeKeyAttributeValues(targetTypeSymbol); + Debug.Assert(compositeKeyAttributeValues is not null); + + var constructorParameters = ParseConstructorParameters(constructor, out var constructionStrategy, out bool constructorSetsRequiredMembers); + var properties = ParseProperties(targetTypeSymbol); + var propertyInitializers = ParsePropertyInitializers(constructorParameters, properties.Select(p => p.Spec).ToList(), ref constructionStrategy, constructorSetsRequiredMembers); + + var propertiesUsedInKey = new List<(PropertySpec Spec, ITypeSymbol TypeSymbol)>(); + var keyParts = ParseTemplateStringIntoKeyParts(compositeKeyAttributeValues!, properties, propertiesUsedInKey); + if (keyParts is null) + return null; // Should have already reported diagnostics by this point so just return null... + + var primaryDelimiterKeyPart = keyParts.OfType().FirstOrDefault(); + + KeySpec key; + if (primaryDelimiterKeyPart is null) + { + // If we reach this branch then it's just a Primary Key + key = new PrimaryKeySpec(compositeKeyAttributeValues!.InvariantCulture, keyParts.ToImmutableEquatableArray()); + } + else + { + // If we reach this branch then it's a "Composite" Primary Key + var (partitionKeyParts, sortKeyParts) = SplitKeyPartsIntoPartitionAndSortKey(keyParts); + key = new CompositePrimaryKeySpec( + compositeKeyAttributeValues!.InvariantCulture, + keyParts.ToImmutableEquatableArray(), + partitionKeyParts.ToImmutableEquatableArray(), + primaryDelimiterKeyPart, + sortKeyParts.ToImmutableEquatableArray()); + } + + return new GenerationSpec( + new TargetTypeSpec( + new TypeRef(targetTypeSymbol), + targetTypeSymbol.ContainingNamespace is { IsGlobalNamespace: false } ns ? ns.ToDisplayString() : null, + targetTypeDeclarations.ToImmutableEquatableArray(), + propertiesUsedInKey.Select(p => p.Spec).ToImmutableEquatableArray(), + constructorParameters.ToImmutableEquatableArray(), + (propertyInitializers?.Where(pi => propertiesUsedInKey.Any(p => p.Spec.Name == pi.Name))).ToImmutableEquatableArray(), + constructionStrategy), + key); + } + + private static (List PartitionKeyParts, List SortKeyParts) SplitKeyPartsIntoPartitionAndSortKey(List keyParts) + { + int indexOfPrimaryKeyDelimiter = keyParts.FindIndex(kp => kp is PrimaryDelimiterKeyPart); + Debug.Assert(indexOfPrimaryKeyDelimiter != -1); + + return ( + keyParts.Take(indexOfPrimaryKeyDelimiter).ToList(), + keyParts.Skip(indexOfPrimaryKeyDelimiter + 1).ToList()); + } + + private List? ParseTemplateStringIntoKeyParts( + CompositeKeyAttributeValues compositeKeyAttributeValues, + List<(PropertySpec Spec, ITypeSymbol TypeSymbol)> properties, + List<(PropertySpec Spec, ITypeSymbol TypeSymbol)> propertiesUsedInKey) + { + (string templateString, char? primaryKeySeparator, _) = compositeKeyAttributeValues; + + var (tokenizationSuccessful, templateTokens) = TemplateValidation.TokenizeTemplateString(templateString, primaryKeySeparator); + if (!tokenizationSuccessful) + { + ReportDiagnostic(DiagnosticDescriptors.EmptyOrInvalidTemplateString, _location, templateString); + return null; + } + + var separatorValidation = TemplateValidation.ValidatePrimaryKeySeparator(templateString, primaryKeySeparator, templateTokens); + if (!separatorValidation.IsSuccess) + { + ReportDiagnostic(separatorValidation.Descriptor, _location, separatorValidation.MessageArgs); + return null; + } + + if (!TemplateValidation.HasValidTemplateStructure(templateTokens)) + { + ReportDiagnostic(DiagnosticDescriptors.EmptyOrInvalidTemplateString, _location, templateString); + return null; + } + + if (primaryKeySeparator.HasValue && !TemplateValidation.ValidatePartitionAndSortKeyStructure(templateTokens, out _)) + { + ReportDiagnostic(DiagnosticDescriptors.EmptyOrInvalidTemplateString, _location, templateString); + return null; + } + + List keyParts = []; + foreach (var templateToken in templateTokens) + { + KeyPart? keyPart = templateToken switch + { + PrimaryDelimiterTemplateToken pd => new PrimaryDelimiterKeyPart(pd.Value) { LengthRequired = 1 }, + DelimiterTemplateToken d => new DelimiterKeyPart(d.Value) { LengthRequired = 1 }, + PropertyTemplateToken p => ToPropertyKeyPart(p), + RepeatingPropertyTemplateToken rp => ToRepeatingPropertyKeyPart(rp), + ConstantTemplateToken c => new ConstantKeyPart(c.Value) { LengthRequired = c.Value.Length }, + _ => null + }; + + if (keyPart is null) + { + ReportDiagnostic(DiagnosticDescriptors.EmptyOrInvalidTemplateString, _location, templateString); + return null; + } + + keyParts.Add(keyPart); + } + + // Validate: repeating type used without repeating syntax -> COMPOSITE0010 + foreach (var keyPart in keyParts) + { + if (keyPart is PropertyKeyPart pkp && pkp.Property.CollectionType != CollectionType.None) + { + ReportDiagnostic(DiagnosticDescriptors.RepeatingTypeMustUseRepeatingSyntax, _location, pkp.Property.Name); + return null; + } + } + + // Validate: repeating property must be the last value part in its key section -> COMPOSITE0011 + var valueParts = keyParts.Where(kp => kp is ValueKeyPart).ToList(); + if (valueParts.Count > 0 && valueParts[^1] is not RepeatingPropertyKeyPart) + { + // Only report if there's a repeating part that isn't last + if (valueParts.Any(kp => kp is RepeatingPropertyKeyPart)) + { + var repeatingPart = valueParts.First(kp => kp is RepeatingPropertyKeyPart) as RepeatingPropertyKeyPart; + ReportDiagnostic(DiagnosticDescriptors.RepeatingPropertyMustBeLastPart, _location, repeatingPart!.Property.Name); + return null; + } + } + + // For composite keys, also validate repeating position within each section + if (keyParts.Any(kp => kp is PrimaryDelimiterKeyPart)) + { + int delimiterIndex = keyParts.FindIndex(kp => kp is PrimaryDelimiterKeyPart); + + var partitionValueParts = keyParts.Take(delimiterIndex).Where(kp => kp is ValueKeyPart).ToList(); + if (partitionValueParts.Count > 0 && partitionValueParts.Any(kp => kp is RepeatingPropertyKeyPart) && partitionValueParts[^1] is not RepeatingPropertyKeyPart) + { + var repeatingPart = partitionValueParts.First(kp => kp is RepeatingPropertyKeyPart) as RepeatingPropertyKeyPart; + ReportDiagnostic(DiagnosticDescriptors.RepeatingPropertyMustBeLastPart, _location, repeatingPart!.Property.Name); + return null; + } + + var sortValueParts = keyParts.Skip(delimiterIndex + 1).Where(kp => kp is ValueKeyPart).ToList(); + if (sortValueParts.Count > 0 && sortValueParts.Any(kp => kp is RepeatingPropertyKeyPart) && sortValueParts[^1] is not RepeatingPropertyKeyPart) + { + var repeatingPart = sortValueParts.First(kp => kp is RepeatingPropertyKeyPart) as RepeatingPropertyKeyPart; + ReportDiagnostic(DiagnosticDescriptors.RepeatingPropertyMustBeLastPart, _location, repeatingPart!.Property.Name); + return null; + } + } + + return keyParts; + + PropertyKeyPart? ToPropertyKeyPart(PropertyTemplateToken templateToken) + { + var availableProperties = properties + .Select(p => new TemplateValidation.PropertyInfo(p.Spec.Name, p.Spec.HasGetter, p.Spec.HasSetter)) + .ToList(); + + var propertyValidation = TemplateValidation.ValidatePropertyReferences([templateToken], availableProperties); + if (!propertyValidation.IsSuccess) + { + ReportDiagnostic(propertyValidation.Descriptor, _location, propertyValidation.MessageArgs); + return null; + } + + var property = properties.First(p => p.Spec.Name == templateToken.Name); + + propertiesUsedInKey.Add(property); + var (propertySpec, typeSymbol) = property; + + // Repeating type properties must use repeating syntax + if (propertySpec.CollectionType != CollectionType.None) + { + ReportDiagnostic(DiagnosticDescriptors.RepeatingTypeMustUseRepeatingSyntax, _location, propertySpec.Name); + return null; + } + + var interfaces = typeSymbol.AllInterfaces; + bool isSpanParsable = interfaces.Any(i => i.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat).StartsWith("global::System.ISpanParsable")); + bool isSpanFormattable = interfaces.Any(i => i.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat).Equals("global::System.ISpanFormattable")); + + var typeInfo = new PropertyValidation.PropertyTypeInfo( + TypeName: propertySpec.Type.FullyQualifiedName, + IsGuid: SymbolEqualityComparer.Default.Equals(typeSymbol, _knownTypeSymbols.GuidType), + IsString: SymbolEqualityComparer.Default.Equals(typeSymbol, _knownTypeSymbols.StringType), + IsEnum: typeSymbol.TypeKind == TypeKind.Enum, + IsSpanParsable: isSpanParsable, + IsSpanFormattable: isSpanFormattable); + + var formatValidation = PropertyValidation.ValidatePropertyFormat( + propertySpec.Name, + typeInfo, + templateToken.Format); + + if (!formatValidation.IsSuccess) + { + ReportDiagnostic(formatValidation.Descriptor, _location, formatValidation.MessageArgs); + return null; + } + + var typeCompatibility = PropertyValidation.ValidatePropertyTypeCompatibility( + propertySpec.Name, + typeInfo); + + if (!typeCompatibility.IsSuccess) + { + throw new NotSupportedException($"Unsupported property of type '{propertySpec.Type.FullyQualifiedName}'"); + } + + var lengthInfo = PropertyValidation.GetFormattedLength(typeInfo, templateToken.Format); + int lengthRequired = lengthInfo?.length ?? 1; + bool exactLengthRequirement = lengthInfo?.isExact ?? false; + + ParseType parseType; + FormatType formatType; + string? format = templateToken.Format; + + if (typeInfo.IsGuid) + { + parseType = ParseType.Guid; + formatType = FormatType.Guid; + format = templateToken.Format?.ToLowerInvariant() ?? "d"; + } + else if (typeInfo.IsString) + { + parseType = ParseType.String; + formatType = FormatType.String; + format = null; + } + else if (typeInfo.IsEnum) + { + parseType = ParseType.Enum; + formatType = FormatType.Enum; + format = templateToken.Format?.ToLowerInvariant() ?? "g"; + } + else + { + parseType = ParseType.SpanParsable; + formatType = FormatType.SpanFormattable; + } + + return new PropertyKeyPart(propertySpec, format, parseType, formatType) + { + LengthRequired = lengthRequired, + ExactLengthRequirement = exactLengthRequirement + }; + } + + RepeatingPropertyKeyPart? ToRepeatingPropertyKeyPart(RepeatingPropertyTemplateToken templateToken) + { + var availableProperties = properties + .Select(p => new TemplateValidation.PropertyInfo(p.Spec.Name, p.Spec.HasGetter, p.Spec.HasSetter)) + .ToList(); + + var propertyValidation = TemplateValidation.ValidatePropertyReferences([templateToken], availableProperties); + if (!propertyValidation.IsSuccess) + { + ReportDiagnostic(propertyValidation.Descriptor, _location, propertyValidation.MessageArgs); + return null; + } + + var property = properties.First(p => p.Spec.Name == templateToken.Name); + var (propertySpec, typeSymbol) = property; + + // Validate that the property is a collection type + if (propertySpec.CollectionType == CollectionType.None) + { + ReportDiagnostic(DiagnosticDescriptors.RepeatingPropertyMustUseCollectionType, _location, propertySpec.Name); + return null; + } + + // Extract inner type from the collection + var namedTypeSymbol = (INamedTypeSymbol)typeSymbol; + var innerTypeSymbol = namedTypeSymbol.TypeArguments[0]; + + var innerInterfaces = innerTypeSymbol.AllInterfaces; + bool isSpanParsable = innerInterfaces.Any(i => i.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat).StartsWith("global::System.ISpanParsable")); + bool isSpanFormattable = innerInterfaces.Any(i => i.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat).Equals("global::System.ISpanFormattable")); + + var innerTypeInfo = new PropertyValidation.PropertyTypeInfo( + TypeName: innerTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), + IsGuid: SymbolEqualityComparer.Default.Equals(innerTypeSymbol, _knownTypeSymbols.GuidType), + IsString: SymbolEqualityComparer.Default.Equals(innerTypeSymbol, _knownTypeSymbols.StringType), + IsEnum: innerTypeSymbol.TypeKind == TypeKind.Enum, + IsSpanParsable: isSpanParsable, + IsSpanFormattable: isSpanFormattable); + + var formatValidation = PropertyValidation.ValidatePropertyFormat( + propertySpec.Name, + innerTypeInfo, + templateToken.Format); + + if (!formatValidation.IsSuccess) + { + ReportDiagnostic(formatValidation.Descriptor, _location, formatValidation.MessageArgs); + return null; + } + + var typeCompatibility = PropertyValidation.ValidatePropertyTypeCompatibility( + propertySpec.Name, + innerTypeInfo); + + if (!typeCompatibility.IsSuccess) + { + throw new NotSupportedException($"Unsupported inner type '{innerTypeInfo.TypeName}' for repeating property '{propertySpec.Name}'"); + } + + propertiesUsedInKey.Add(property); + + ParseType innerParseType; + FormatType innerFormatType; + string? format = templateToken.Format; + + if (innerTypeInfo.IsGuid) + { + innerParseType = ParseType.Guid; + innerFormatType = FormatType.Guid; + format = templateToken.Format?.ToLowerInvariant() ?? "d"; + } + else if (innerTypeInfo.IsString) + { + innerParseType = ParseType.String; + innerFormatType = FormatType.String; + format = null; + } + else if (innerTypeInfo.IsEnum) + { + innerParseType = ParseType.Enum; + innerFormatType = FormatType.Enum; + format = templateToken.Format?.ToLowerInvariant() ?? "g"; + } + else + { + innerParseType = ParseType.SpanParsable; + innerFormatType = FormatType.SpanFormattable; + } + + return new RepeatingPropertyKeyPart(propertySpec, templateToken.Separator, format, innerParseType, innerFormatType, new TypeRef(innerTypeSymbol)) + { + LengthRequired = 1, + ExactLengthRequirement = false + }; + } + } + + private static List? ParsePropertyInitializers( + ConstructorParameterSpec[] constructorParameters, + List properties, + ref ConstructionStrategy constructionStrategy, + bool constructorSetsRequiredParameters) + { + if (properties is []) + return []; + + HashSet? propertyInitializerNames = null; + List? propertyInitializers = null; + int parameterCount = constructorParameters.Length; + + foreach (var property in properties) + { + if (!property.HasSetter) + continue; + + if ((property.IsRequired || constructorSetsRequiredParameters) && !property.IsInitOnlySetter) + continue; + + if (!(propertyInitializerNames ??= []).Add(property.Name)) + continue; + + var matchingConstructorParameter = constructorParameters.FirstOrDefault(cp => cp.Name.Equals(property.Name, StringComparison.OrdinalIgnoreCase)); + if (!property.IsRequired && matchingConstructorParameter is not null) + continue; + + constructionStrategy = ConstructionStrategy.ParameterizedConstructor; + (propertyInitializers ??= []).Add(new PropertyInitializerSpec( + property.Type, + property.Name, + property.CamelCaseName, + matchingConstructorParameter?.ParameterIndex ?? parameterCount++, + matchingConstructorParameter is not null)); + } + + return propertyInitializers; + } + + private List<(PropertySpec Spec, ITypeSymbol TypeSymbol)> ParseProperties(INamedTypeSymbol typeSymbol) + { + List<(PropertySpec Spec, ITypeSymbol TypeSymbol)> properties = []; + foreach (var propertySymbol in typeSymbol.GetMembers().OfType()) + { + if (propertySymbol is { IsImplicitlyDeclared: true, Name: "EqualityContract" }) + continue; + + if (propertySymbol.IsStatic || propertySymbol.Parameters.Length > 0) + continue; + + // Detect collection types + var collectionType = CollectionType.None; + ITypeSymbol effectiveTypeSymbol = propertySymbol.Type; + + if (propertySymbol.Type is INamedTypeSymbol namedType) + { + var originalDefinition = namedType.OriginalDefinition; + if (SymbolEqualityComparer.Default.Equals(originalDefinition, _knownTypeSymbols.ListType)) + collectionType = CollectionType.List; + else if (SymbolEqualityComparer.Default.Equals(originalDefinition, _knownTypeSymbols.ReadOnlyListType)) + collectionType = CollectionType.IReadOnlyList; + else if (SymbolEqualityComparer.Default.Equals(originalDefinition, _knownTypeSymbols.ImmutableArrayType)) + collectionType = CollectionType.ImmutableArray; + } + + // For collection types, extract inner type for EnumSpec + EnumSpec? enumSpec = null; + if (collectionType != CollectionType.None) + { + var innerType = ((INamedTypeSymbol)propertySymbol.Type).TypeArguments[0]; + if (innerType is INamedTypeSymbol { TypeKind: TypeKind.Enum } innerEnumType) + enumSpec = ExtractEnumDefinition(innerEnumType); + } + else + { + if (propertySymbol.Type is INamedTypeSymbol { TypeKind: TypeKind.Enum } enumType) + enumSpec = ExtractEnumDefinition(enumType); + } + + var propertySpec = new PropertySpec( + new TypeRef(propertySymbol.Type), + propertySymbol.Name, + propertySymbol.Name.FirstToLowerInvariant(), + propertySymbol.IsRequired, + propertySymbol.GetMethod is not null, + propertySymbol.SetMethod is not null, + propertySymbol.SetMethod is { IsInitOnly: true }, + enumSpec, + collectionType); + + properties.Add((propertySpec, propertySymbol.Type)); + } + + return properties; + + EnumSpec ExtractEnumDefinition(INamedTypeSymbol enumSymbol) + { + string name = enumSymbol.Name; + string fullyQualifiedName = enumSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + string underlyingType = enumSymbol.EnumUnderlyingType?.ToString() ?? "int"; + + List members = []; + foreach (var enumMember in enumSymbol.GetMembers()) + { + if (enumMember is not IFieldSymbol { ConstantValue: not null } fieldSymbol) + continue; + + members.Add(new EnumSpec.Member(fieldSymbol.Name, fieldSymbol.ConstantValue)); + } + + bool isSequentialFromZero = true; + for (int i = 0; i < members.Count; i++) + { + if (Convert.ToUInt64(members[i].Value) == (uint)i) + continue; + + isSequentialFromZero = false; + break; + } + + return new EnumSpec( + name, + fullyQualifiedName, + underlyingType, + members.ToImmutableEquatableArray(), + isSequentialFromZero); + } + } + + private ConstructorParameterSpec[] ParseConstructorParameters( + IMethodSymbol constructor, out ConstructionStrategy constructionStrategy, out bool constructorSetsRequiredMembers) + { + constructorSetsRequiredMembers = constructor.GetAttributes().Any(a => + SymbolEqualityComparer.Default.Equals(a.AttributeClass, _knownTypeSymbols.SetsRequiredMembersAttributeType)); + + int parameterCount = constructor.Parameters.Length; + if (parameterCount == 0) + { + constructionStrategy = ConstructionStrategy.ParameterlessConstructor; + return []; + } + + constructionStrategy = ConstructionStrategy.ParameterizedConstructor; + + var constructorParameters = new ConstructorParameterSpec[parameterCount]; + for (int i = 0; i < parameterCount; i++) + { + var parameterSymbol = constructor.Parameters[i]; + constructorParameters[i] = new ConstructorParameterSpec(new TypeRef(parameterSymbol.Type), parameterSymbol.Name, parameterSymbol.Name.FirstToLowerInvariant(), i); + } + + return constructorParameters; + } + + + private CompositeKeyAttributeValues? ParseCompositeKeyAttributeValues(INamedTypeSymbol targetTypeSymbol) + { + Debug.Assert(_knownTypeSymbols.CompositeKeyAttributeType is not null); + + CompositeKeyAttributeValues? attributeValues = null; + + foreach (var attributeData in targetTypeSymbol.GetAttributes()) + { + Debug.Assert(attributeValues is null, $"There should only ever be one {nameof(CompositeKeyAttribute)} per type"); + + if (SymbolEqualityComparer.Default.Equals(attributeData.AttributeClass, _knownTypeSymbols.CompositeKeyAttributeType)) + attributeValues = TryExtractAttributeValues(attributeData); + } + + return attributeValues; + + static CompositeKeyAttributeValues? TryExtractAttributeValues(AttributeData attributeData) + { + Debug.Assert(attributeData.ConstructorArguments.Length is 1); + + string? templateString = (string?)attributeData.ConstructorArguments[0].Value; + if (templateString is null) + return null; + + char? primaryKeySeparator = null; + bool? invariantCulture = null; + foreach (var namedArgument in attributeData.NamedArguments) + { + (string? key, var value) = (namedArgument.Key, namedArgument.Value); + + switch (key) + { + case nameof(CompositeKeyAttribute.PrimaryKeySeparator): + primaryKeySeparator = (char?)value.Value; + break; + + case nameof(CompositeKeyAttribute.InvariantCulture): + invariantCulture = (bool?)value.Value; + break; + + default: + throw new InvalidOperationException(); + } + } + + return new CompositeKeyAttributeValues(templateString, primaryKeySeparator, invariantCulture ?? true); + } + } +} diff --git a/src/CompositeKey.SourceGeneration/SourceGenerator.Emitter.cs b/src/CompositeKey.SourceGeneration/SourceGenerator.Emitter.cs deleted file mode 100644 index 6ffb982..0000000 --- a/src/CompositeKey.SourceGeneration/SourceGenerator.Emitter.cs +++ /dev/null @@ -1,1074 +0,0 @@ -using System.Diagnostics; -using System.Reflection; -using System.Text; -using CompositeKey.SourceGeneration.Core; -using CompositeKey.SourceGeneration.Core.Extensions; -using CompositeKey.SourceGeneration.Model; -using CompositeKey.SourceGeneration.Model.Key; -using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.Text; - -namespace CompositeKey.SourceGeneration; - -public sealed partial class SourceGenerator -{ - private sealed class Emitter(SourceProductionContext context) - { - private static readonly string AssemblyName = typeof(Emitter).Assembly.GetName().Name!; - private static readonly string AssemblyVersion = typeof(Emitter).Assembly.GetCustomAttribute()!.InformationalVersion; - - private const string NotNullWhen = "global::System.Diagnostics.CodeAnalysis.NotNullWhen"; - private const string MaybeNullWhen = "global::System.Diagnostics.CodeAnalysis.MaybeNullWhen"; - private const string InvariantCulture = "global::System.Globalization.CultureInfo.InvariantCulture"; - - private readonly SourceProductionContext _context = context; - - public void Emit(GenerationSpec generationSpec) - { - Debug.Assert(AssemblyName is not null); - Debug.Assert(AssemblyVersion is not null); - - var writer = CreateSourceWriterWithHeader(generationSpec); - - if (generationSpec.Key is PrimaryKeySpec primaryKeySpec) - EmitForPrimaryKey(writer, generationSpec.TargetType, primaryKeySpec); - else if (generationSpec.Key is CompositePrimaryKeySpec compositePrimaryKeySpec) - EmitForCompositePrimaryKey(writer, generationSpec.TargetType, compositePrimaryKeySpec); - - EmitCommonImplementations(writer, generationSpec.TargetType); - writer.EndBlock(); - - foreach (var enumSpec in generationSpec.TargetType.Properties.Select(p => p.EnumSpec).Where(es => es is not null)) - EnumGenerationHelper.EmitEnumHelperClass(writer, enumSpec!); - - string hintName = $"{generationSpec.TargetType.Type.FullyQualifiedName.Replace("global::", string.Empty)}.g.cs"; - AddSource(hintName, CompleteSourceFileAndReturnSourceText(writer)); - } - - private static void EmitForPrimaryKey(SourceWriter writer, TargetTypeSpec targetTypeSpec, PrimaryKeySpec keySpec) - { - var keyParts = keySpec.Parts.ToList(); - - WriteFormatMethodBodyForKeyParts(writer, "public override string ToString()", keyParts, keySpec.InvariantFormatting); - WriteFormatMethodBodyForKeyParts(writer, "public string ToPartitionKeyString()", keyParts, keySpec.InvariantFormatting); - WriteDynamicFormatMethodBodyForKeyParts(writer, "public string ToPartitionKeyString(int throughPartIndex, bool includeTrailingDelimiter = true)", keyParts, keySpec.InvariantFormatting); - - WriteParseMethodImplementation(); - WriteTryParseMethodImplementation(); - - return; - - void WriteParseMethodImplementation() - { - writer.WriteLines($$""" - public static {{targetTypeSpec.TypeName}} Parse(string primaryKey) - { - ArgumentNullException.ThrowIfNull(primaryKey); - - return Parse((ReadOnlySpan)primaryKey); - } - - public static {{targetTypeSpec.TypeName}} Parse(ReadOnlySpan primaryKey) - { - """); - writer.Indentation++; - - WriteLengthCheck(writer, keyParts, "primaryKey", true); - - Func getPrimaryKeyPartInputVariable = static _ => "primaryKey"; - string? primaryKeyPartCountVariable = null; - if (keyParts.Count > 1) - { - WriteSplitImplementation(writer, keyParts, "primaryKey", out string primaryKeyPartRangesVariable, true, out primaryKeyPartCountVariable); - getPrimaryKeyPartInputVariable = indexExpr => $"primaryKey[{primaryKeyPartRangesVariable}[{indexExpr}]]"; - } - - WriteParsePropertiesImplementation(writer, keyParts, getPrimaryKeyPartInputVariable, true, primaryKeyPartCountVariable); - - writer.WriteLine($"return {WriteConstructor(targetTypeSpec)};"); - - writer.EndBlock(); - writer.WriteLine(); - } - - void WriteTryParseMethodImplementation() - { - writer.WriteLines($$""" - public static bool TryParse([{{NotNullWhen}}(true)] string? primaryKey, [{{MaybeNullWhen}}(false)] out {{targetTypeSpec.TypeName}}? result) - { - if (primaryKey is null) - { - result = null; - return false; - } - - return TryParse((ReadOnlySpan)primaryKey, out result); - } - - public static bool TryParse(ReadOnlySpan primaryKey, [{{MaybeNullWhen}}(false)] out {{targetTypeSpec.TypeName}}? result) - { - result = null; - - """); - writer.Indentation++; - - WriteLengthCheck(writer, keyParts, "primaryKey", false); - - Func getPrimaryKeyPartInputVariable = static _ => "primaryKey"; - string? primaryKeyPartCountVariable = null; - if (keyParts.Count > 1) - { - WriteSplitImplementation(writer, keyParts, "primaryKey", out string primaryKeyPartRangesVariable, false, out primaryKeyPartCountVariable); - getPrimaryKeyPartInputVariable = indexExpr => $"primaryKey[{primaryKeyPartRangesVariable}[{indexExpr}]]"; - } - - WriteParsePropertiesImplementation(writer, keyParts, getPrimaryKeyPartInputVariable, false, primaryKeyPartCountVariable); - - writer.WriteLines($""" - result = {WriteConstructor(targetTypeSpec)}; - return true; - """); - - writer.EndBlock(); - writer.WriteLine(); - } - } - - private static void EmitForCompositePrimaryKey(SourceWriter writer, TargetTypeSpec targetTypeSpec, CompositePrimaryKeySpec keySpec) - { - var partitionKeyParts = keySpec.PartitionKeyParts.ToList(); - var sortKeyParts = keySpec.SortKeyParts.ToList(); - - WriteFormatMethodBodyForKeyParts(writer, "public override string ToString()", keySpec.AllParts, keySpec.InvariantFormatting); - WriteFormatMethodBodyForKeyParts(writer, "public string ToPartitionKeyString()", partitionKeyParts, keySpec.InvariantFormatting); - WriteDynamicFormatMethodBodyForKeyParts(writer, "public string ToPartitionKeyString(int throughPartIndex, bool includeTrailingDelimiter = true)", partitionKeyParts, keySpec.InvariantFormatting); - WriteFormatMethodBodyForKeyParts(writer, "public string ToSortKeyString()", sortKeyParts, keySpec.InvariantFormatting); - WriteDynamicFormatMethodBodyForKeyParts(writer, "public string ToSortKeyString(int throughPartIndex, bool includeTrailingDelimiter = true)", sortKeyParts, keySpec.InvariantFormatting); - - WriteParseMethodImplementation(); - WriteTryParseMethodImplementation(); - WriteCompositeParseMethodImplementation(); - WriteCompositeTryParseMethodImplementation(); - - return; - - void WritePrimaryKeySplit(bool shouldThrow) - { - writer.WriteLines($""" - const int expectedPrimaryKeyParts = 2; - Span primaryKeyPartRanges = stackalloc Range[expectedPrimaryKeyParts + 1]; - if (primaryKey.Split(primaryKeyPartRanges, '{keySpec.PrimaryDelimiterKeyPart.Value}', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries) != expectedPrimaryKeyParts) - {(shouldThrow ? "throw new FormatException(\"Unrecognized format.\")" : "return false")}; - - """); - } - - void WriteParseMethodImplementation() - { - writer.WriteLines($$""" - public static {{targetTypeSpec.TypeName}} Parse(string primaryKey) - { - ArgumentNullException.ThrowIfNull(primaryKey); - - return Parse((ReadOnlySpan)primaryKey); - } - - public static {{targetTypeSpec.TypeName}} Parse(ReadOnlySpan primaryKey) - { - """); - writer.Indentation++; - - WritePrimaryKeySplit(true); - - writer.WriteLine("return Parse(primaryKey[primaryKeyPartRanges[0]], primaryKey[primaryKeyPartRanges[1]]);"); - - writer.EndBlock(); - writer.WriteLine(); - } - - void WriteTryParseMethodImplementation() - { - writer.WriteLines($$""" - public static bool TryParse([{{NotNullWhen}}(true)] string? primaryKey, [{{MaybeNullWhen}}(false)] out {{targetTypeSpec.TypeName}}? result) - { - if (primaryKey is null) - { - result = null; - return false; - } - - return TryParse((ReadOnlySpan)primaryKey, out result); - } - - public static bool TryParse(ReadOnlySpan primaryKey, [{{MaybeNullWhen}}(false)] out {{targetTypeSpec.TypeName}}? result) - { - result = null; - - """); - writer.Indentation++; - - WritePrimaryKeySplit(false); - - writer.WriteLine("return TryParse(primaryKey[primaryKeyPartRanges[0]], primaryKey[primaryKeyPartRanges[1]], out result);"); - - writer.EndBlock(); - writer.WriteLine(); - } - - void WriteCompositeParseMethodImplementation() - { - writer.WriteLines($$""" - public static {{targetTypeSpec.TypeName}} Parse(string partitionKey, string sortKey) - { - ArgumentNullException.ThrowIfNull(partitionKey); - ArgumentNullException.ThrowIfNull(sortKey); - - return Parse((ReadOnlySpan)partitionKey, (ReadOnlySpan)sortKey); - } - - public static {{targetTypeSpec.TypeName}} Parse(ReadOnlySpan partitionKey, ReadOnlySpan sortKey) - { - """); - writer.Indentation++; - - WriteLengthCheck(writer, partitionKeyParts, "partitionKey", true); - WriteLengthCheck(writer, sortKeyParts, "sortKey", true); - - Func getPartitionKeyPartInputVariable = static _ => "partitionKey"; - string? partitionKeyPartCountVariable = null; - if (partitionKeyParts.Count > 1) - { - WriteSplitImplementation(writer, partitionKeyParts, "partitionKey", out string partitionKeyPartRangesVariable, true, out partitionKeyPartCountVariable); - getPartitionKeyPartInputVariable = indexExpr => $"partitionKey[{partitionKeyPartRangesVariable}[{indexExpr}]]"; - } - - Func getSortKeyPartInputVariable = static _ => "sortKey"; - string? sortKeyPartCountVariable = null; - if (sortKeyParts.Count > 1) - { - WriteSplitImplementation(writer, sortKeyParts, "sortKey", out string sortKeyPartRangesVariable, true, out sortKeyPartCountVariable); - getSortKeyPartInputVariable = indexExpr => $"sortKey[{sortKeyPartRangesVariable}[{indexExpr}]]"; - } - - var propertyNameCounts = partitionKeyParts.Concat(sortKeyParts).OfType().GroupBy(p => p.Property.CamelCaseName).ToDictionary(g => g.Key, _ => 0); - WriteParsePropertiesImplementation(writer, partitionKeyParts, getPartitionKeyPartInputVariable, true, propertyNameCounts, partitionKeyPartCountVariable); - WriteParsePropertiesImplementation(writer, sortKeyParts, getSortKeyPartInputVariable, true, propertyNameCounts, sortKeyPartCountVariable); - - writer.WriteLine($"return {WriteConstructor(targetTypeSpec)};"); - - writer.EndBlock(); - writer.WriteLine(); - } - - void WriteCompositeTryParseMethodImplementation() - { - writer.WriteLines($$""" - public static bool TryParse([{{NotNullWhen}}(true)] string partitionKey, string sortKey, [{{MaybeNullWhen}}(false)] out {{targetTypeSpec.TypeName}}? result) - { - if (partitionKey is null || sortKey is null) - { - result = null; - return false; - } - - return TryParse((ReadOnlySpan)partitionKey, (ReadOnlySpan)sortKey, out result); - } - - public static bool TryParse(ReadOnlySpan partitionKey, ReadOnlySpan sortKey, [{{MaybeNullWhen}}(false)] out {{targetTypeSpec.TypeName}}? result) - { - result = null; - - """); - writer.Indentation++; - - WriteLengthCheck(writer, partitionKeyParts, "partitionKey", false); - WriteLengthCheck(writer, sortKeyParts, "sortKey", false); - - Func getPartitionKeyPartInputVariable = static _ => "partitionKey"; - string? partitionKeyPartCountVariable = null; - if (partitionKeyParts.Count > 1) - { - WriteSplitImplementation(writer, partitionKeyParts, "partitionKey", out string partitionKeyPartRangesVariable, false, out partitionKeyPartCountVariable); - getPartitionKeyPartInputVariable = indexExpr => $"partitionKey[{partitionKeyPartRangesVariable}[{indexExpr}]]"; - } - - Func getSortKeyPartInputVariable = static _ => "sortKey"; - string? sortKeyPartCountVariable = null; - if (sortKeyParts.Count > 1) - { - WriteSplitImplementation(writer, sortKeyParts, "sortKey", out string sortKeyPartRangesVariable, false, out sortKeyPartCountVariable); - getSortKeyPartInputVariable = indexExpr => $"sortKey[{sortKeyPartRangesVariable}[{indexExpr}]]"; - } - - var propertyNameCounts = partitionKeyParts.Concat(sortKeyParts).OfType().GroupBy(p => p.Property.CamelCaseName).ToDictionary(g => g.Key, _ => 0); - WriteParsePropertiesImplementation(writer, partitionKeyParts, getPartitionKeyPartInputVariable, false, propertyNameCounts, partitionKeyPartCountVariable); - WriteParsePropertiesImplementation(writer, sortKeyParts, getSortKeyPartInputVariable, false, propertyNameCounts, sortKeyPartCountVariable); - - writer.WriteLines($""" - result = {WriteConstructor(targetTypeSpec)}; - return true; - """); - - writer.EndBlock(); - writer.WriteLine(); - } - } - - private static void EmitCommonImplementations(SourceWriter writer, TargetTypeSpec targetTypeSpec) - { - writer.WriteLines($$""" - /// - string IFormattable.ToString(string? format, IFormatProvider? formatProvider) => ToString(); - - /// - static {{targetTypeSpec.TypeName}} IParsable<{{targetTypeSpec.TypeName}}>.Parse(string s, IFormatProvider? provider) => Parse(s); - - /// - static bool IParsable<{{targetTypeSpec.TypeName}}>.TryParse([{{NotNullWhen}}(true)] string? s, IFormatProvider? provider, [{{MaybeNullWhen}}(false)] out {{targetTypeSpec.TypeName}} result) => TryParse(s, out result); - - /// - static {{targetTypeSpec.TypeName}} ISpanParsable<{{targetTypeSpec.TypeName}}>.Parse(ReadOnlySpan s, IFormatProvider? provider) => Parse(s); - - /// - static bool ISpanParsable<{{targetTypeSpec.TypeName}}>.TryParse(ReadOnlySpan s, IFormatProvider? provider, [{{MaybeNullWhen}}(false)] out {{targetTypeSpec.TypeName}} result) => TryParse(s, out result); - """); - } - - private static void WriteLengthCheck(SourceWriter writer, List parts, string inputName, bool shouldThrow) - { - int lengthRequired = parts.Select(p => p.LengthRequired).Sum(); - bool exactLengthRequirement = parts.All(p => p.ExactLengthRequirement); - - writer.WriteLines($""" - if ({inputName}.Length {(exactLengthRequirement ? "!=" : "<")} {lengthRequired}) - {(shouldThrow ? "throw new FormatException(\"Unrecognized format.\")" : "return false")}; - - """); - } - - private static void WriteSplitImplementation(SourceWriter writer, List parts, string inputName, out string partRangesVariableName, bool shouldThrow, out string? partCountVariableName) - { - var repeatingPart = parts.OfType().FirstOrDefault(); - var uniqueDelimiters = parts.OfType().Select(d => d.Value).Distinct().ToList(); - - partRangesVariableName = $"{inputName}PartRanges"; - partCountVariableName = null; - - if (repeatingPart is not null) - { - bool sameSeparator = uniqueDelimiters.Contains(repeatingPart.Separator); - - if (sameSeparator) - { - // Same separator as key delimiters: split produces variable number of parts - int fixedValueParts = parts.OfType().Count(p => p is not RepeatingPropertyKeyPart); - - (string method, string delimiters) = uniqueDelimiters switch - { - { Count: 1 } => ("Split", $"'{uniqueDelimiters[0]}'"), - { Count: > 1 } => ("SplitAny", $"\"{string.Join(string.Empty, uniqueDelimiters)}\""), - _ => throw new InvalidOperationException() - }; - - string minPartsVariable = $"minExpected{inputName.FirstToUpperInvariant()}Parts"; - partCountVariableName = $"{inputName}PartCount"; - - writer.WriteLines($""" - const int {minPartsVariable} = {fixedValueParts + 1}; - Span {partRangesVariableName} = stackalloc Range[128]; - int {partCountVariableName} = {inputName}.{method}({partRangesVariableName}, {delimiters}, StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries); - if ({partCountVariableName} < {minPartsVariable}) - {(shouldThrow ? "throw new FormatException(\"Unrecognized format.\")" : "return false")}; - - """); - } - else - { - // Different separator: split by fixed delimiters, last part contains the repeating section - int expectedParts = parts.OfType().Count(); - - (string method, string delimiters) = uniqueDelimiters switch - { - { Count: 1 } => ("Split", $"'{uniqueDelimiters[0]}'"), - { Count: > 1 } => ("SplitAny", $"\"{string.Join(string.Empty, uniqueDelimiters)}\""), - _ => throw new InvalidOperationException() - }; - - string expectedPartsVariableName = $"expected{inputName.FirstToUpperInvariant()}Parts"; - - writer.WriteLines($""" - const int {expectedPartsVariableName} = {expectedParts}; - Span {partRangesVariableName} = stackalloc Range[{expectedPartsVariableName} + 1]; - if ({inputName}.{method}({partRangesVariableName}, {delimiters}, StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries) != {expectedPartsVariableName}) - {(shouldThrow ? "throw new FormatException(\"Unrecognized format.\")" : "return false")}; - - """); - } - } - else - { - int expectedParts = parts.OfType().Count(); - - (string method, string delimiters) = uniqueDelimiters switch - { - { Count: 1 } => ("Split", $"'{uniqueDelimiters[0]}'"), - { Count: > 1 } => ("SplitAny", $"\"{string.Join(string.Empty, uniqueDelimiters)}\""), - _ => throw new InvalidOperationException() - }; - - string expectedPartsVariableName = $"expected{inputName.FirstToUpperInvariant()}Parts"; - - writer.WriteLines($""" - const int {expectedPartsVariableName} = {expectedParts}; - Span {partRangesVariableName} = stackalloc Range[{expectedPartsVariableName} + 1]; - if ({inputName}.{method}({partRangesVariableName}, {delimiters}, StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries) != {expectedPartsVariableName}) - {(shouldThrow ? "throw new FormatException(\"Unrecognized format.\")" : "return false")}; - - """); - } - } - - private static void WriteParsePropertiesImplementation( - SourceWriter writer, List parts, Func getPartInputVariable, bool shouldThrow, string? inputPartCountVariable = null) - { - var propertyNameCounts = parts.OfType().GroupBy(p => p.Property.CamelCaseName).ToDictionary(g => g.Key, _ => 0); - WriteParsePropertiesImplementation(writer, parts, getPartInputVariable, shouldThrow, propertyNameCounts, inputPartCountVariable); - } - - private static void WriteParsePropertiesImplementation( - SourceWriter writer, List parts, Func getPartInputVariable, bool shouldThrow, Dictionary propertyNameCounts, string? inputPartCountVariable = null) - { - var valueParts = parts.OfType().ToArray(); - for (int i = 0; i < valueParts.Length; i++) - { - var valueKeyPart = valueParts[i]; - string partInputVariable = getPartInputVariable($"{i}"); - - if (valueKeyPart is ConstantKeyPart c) - { - writer.WriteLines($""" - if (!{partInputVariable}.Equals("{c.Value}", StringComparison.Ordinal)) - {(shouldThrow ? "throw new FormatException(\"Unrecognized format.\")" : "return false")}; - - """); - continue; - } - - if (valueKeyPart is RepeatingPropertyKeyPart repeatingPart) - { - WriteRepeatingPropertyParse(repeatingPart, i); - continue; - } - - (string camelCaseName, string? originalCamelCaseName) = valueKeyPart is PropertyKeyPart propertyPart - ? GetCamelCaseName(propertyPart.Property, propertyNameCounts) - : throw new InvalidOperationException($"Expected a {nameof(PropertyKeyPart)} but got a {valueKeyPart.GetType().Name}"); - - switch (valueKeyPart) - { - case PropertyKeyPart { ParseType: ParseType.Guid } part: - writer.WriteLines($""" - if ({ToStrictLengthCheck(part, partInputVariable)}!Guid.TryParseExact({partInputVariable}, "{part.Format}", out var {camelCaseName})) - {(shouldThrow ? "throw new FormatException(\"Unrecognized format.\")" : "return false")}; - - """); - break; - - case PropertyKeyPart { ParseType: ParseType.String }: - writer.WriteLines($""" - if ({partInputVariable}.Length == 0) - {(shouldThrow ? "throw new FormatException(\"Unrecognized format.\")" : "return false")}; - - string {camelCaseName} = {partInputVariable}.ToString(); - - """); - break; - - case PropertyKeyPart { ParseType: ParseType.Enum } part: - if (part.Property.EnumSpec is null) - throw new InvalidOperationException($"{nameof(part.Property.EnumSpec)} is null"); - - writer.WriteLines($""" - if (!{part.Property.EnumSpec.Name}Helper.TryParse({partInputVariable}, out var {camelCaseName})) - {(shouldThrow ? "throw new FormatException(\"Unrecognized format.\")" : "return false")}; - - """); - break; - - case PropertyKeyPart { ParseType: ParseType.SpanParsable } part: - writer.WriteLines($""" - if (!{part.Property.Type.FullyQualifiedName}.TryParse({partInputVariable}, out var {camelCaseName})) - {(shouldThrow ? "throw new FormatException(\"Unrecognized format.\")" : "return false")}; - - """); - break; - } - - if (originalCamelCaseName is not null) - { - writer.WriteLines($""" - if (!{originalCamelCaseName}.Equals({camelCaseName})) - {(shouldThrow ? "throw new FormatException(\"Unrecognized format.\")" : "return false")}; - - """); - } - } - - return; - - static (string camelCaseName, string? originalCamelCaseName) GetCamelCaseName(PropertySpec property, Dictionary propertyNameCounts) - { - int propertyCount = propertyNameCounts[property.CamelCaseName]++; - return propertyCount == 0 - ? (property.CamelCaseName, null) - : ($"{property.CamelCaseName}{propertyCount}", property.CamelCaseName); - } - - static string ToStrictLengthCheck(KeyPart part, string input) => - part.ExactLengthRequirement ? $"{input}.Length != {part.LengthRequired} || " : string.Empty; - - void WriteRepeatingPropertyParse(RepeatingPropertyKeyPart repeatingPart, int valuePartIndex) - { - string camelCaseName = repeatingPart.Property.CamelCaseName; - string innerTypeName = repeatingPart.InnerType.FullyQualifiedName; - var uniqueDelimiters = parts.OfType().Select(d => d.Value).Distinct().ToList(); - bool sameSeparator = uniqueDelimiters.Contains(repeatingPart.Separator); - - string itemVar = $"{camelCaseName}Item"; - string listVar = camelCaseName; - - if (sameSeparator && inputPartCountVariable is not null) - { - // Same separator: repeating items are at indices valuePartIndex..partCount-1 - writer.WriteLines($""" - var {listVar} = new global::System.Collections.Generic.List<{innerTypeName}>(); - """); - - writer.StartBlock($"for (int ri = {valuePartIndex}; ri < {inputPartCountVariable}; ri++)"); - - string riAccess = getPartInputVariable("ri"); - - WriteRepeatingItemParse(repeatingPart, riAccess, itemVar, listVar); - - writer.EndBlock(); - writer.WriteLine(); - } - else - { - // Different separator: sub-split the part by the repeating separator - string partInputVariable = getPartInputVariable($"{valuePartIndex}"); - string repeatingRangesVar = $"{camelCaseName}Ranges"; - string repeatingCountVar = $"{camelCaseName}Count"; - - writer.WriteLines($""" - Span {repeatingRangesVar} = stackalloc Range[128]; - int {repeatingCountVar} = {partInputVariable}.Split({repeatingRangesVar}, '{repeatingPart.Separator}', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries); - if ({repeatingCountVar} < 1) - {(shouldThrow ? "throw new FormatException(\"Unrecognized format.\")" : "return false")}; - - var {listVar} = new global::System.Collections.Generic.List<{innerTypeName}>(); - """); - - writer.StartBlock($"for (int ri = 0; ri < {repeatingCountVar}; ri++)"); - - string riAccess = $"{partInputVariable}[{repeatingRangesVar}[ri]]"; - WriteRepeatingItemParse(repeatingPart, riAccess, itemVar, listVar); - - writer.EndBlock(); - writer.WriteLine(); - } - - // Validate at least 1 item - writer.WriteLines($""" - if ({listVar}.Count == 0) - {(shouldThrow ? "throw new FormatException(\"Unrecognized format.\")" : "return false")}; - - """); - } - - void WriteRepeatingItemParse(RepeatingPropertyKeyPart repeatingPart, string itemInput, string itemVar, string listVar) - { - string innerTypeName = repeatingPart.InnerType.FullyQualifiedName; - - switch (repeatingPart.InnerParseType) - { - case ParseType.Guid: - writer.WriteLines($""" - if (!Guid.TryParseExact({itemInput}, "{repeatingPart.Format}", out var {itemVar})) - {(shouldThrow ? "throw new FormatException(\"Unrecognized format.\")" : "return false")}; - {listVar}.Add({itemVar}); - """); - break; - - case ParseType.String: - writer.WriteLines($""" - if ({itemInput}.Length == 0) - {(shouldThrow ? "throw new FormatException(\"Unrecognized format.\")" : "return false")}; - {listVar}.Add({itemInput}.ToString()); - """); - break; - - case ParseType.Enum: - if (repeatingPart.Property.EnumSpec is null) - throw new InvalidOperationException($"{nameof(repeatingPart.Property.EnumSpec)} is null"); - - writer.WriteLines($""" - if (!{repeatingPart.Property.EnumSpec.Name}Helper.TryParse({itemInput}, out var {itemVar})) - {(shouldThrow ? "throw new FormatException(\"Unrecognized format.\")" : "return false")}; - {listVar}.Add({itemVar}); - """); - break; - - case ParseType.SpanParsable: - writer.WriteLines($""" - if (!{innerTypeName}.TryParse({itemInput}, out var {itemVar})) - {(shouldThrow ? "throw new FormatException(\"Unrecognized format.\")" : "return false")}; - {listVar}.Add({itemVar}); - """); - break; - } - } - } - - private static string WriteConstructor(TargetTypeSpec targetTypeSpec) - { - var builder = new StringBuilder(); - builder.Append($"new {targetTypeSpec.TypeName}("); - - if (targetTypeSpec.ConstructorParameters.Count > 0) - { - foreach (var parameter in targetTypeSpec.ConstructorParameters) - { - var property = targetTypeSpec.Properties.FirstOrDefault(p => p.CamelCaseName == parameter.CamelCaseName); - builder.Append(property?.CollectionType == CollectionType.ImmutableArray - ? $"global::System.Collections.Immutable.ImmutableArray.CreateRange({parameter.CamelCaseName}), " - : $"{parameter.CamelCaseName}, "); - } - - builder.Length -= 2; // Remove the last ", " - } - - builder.Append(')'); - - if (targetTypeSpec.PropertyInitializers.Count > 0) - { - builder.Append(" { "); - - foreach (var initializer in targetTypeSpec.PropertyInitializers) - { - var property = targetTypeSpec.Properties.FirstOrDefault(p => p.CamelCaseName == initializer.CamelCaseName); - builder.Append(property?.CollectionType == CollectionType.ImmutableArray - ? $"{initializer.Name} = global::System.Collections.Immutable.ImmutableArray.CreateRange({initializer.CamelCaseName}), " - : $"{initializer.Name} = {initializer.CamelCaseName}, "); - } - - builder.Length -= 2; // Remove the last ", " - builder.Append(" }"); - } - - return builder.ToString(); - } - - private static void WriteFormatMethodBodyForKeyParts( - SourceWriter writer, string methodDeclaration, IReadOnlyList keyParts, bool invariantFormatting) - { - writer.StartBlock(methodDeclaration); - - bool hasRepeatingPart = keyParts.Any(kp => kp is RepeatingPropertyKeyPart); - - if (hasRepeatingPart) - { - WriteRepeatingFormatBody(); - } - else if (keyParts.All(kp => kp is - DelimiterKeyPart - or ConstantKeyPart - or PropertyKeyPart { FormatType: FormatType.Guid, ExactLengthRequirement: true } - or PropertyKeyPart { FormatType: FormatType.Enum, Format: "g" } - or PropertyKeyPart { FormatType: FormatType.String })) - { - string lengthRequired = keyParts - .Where(kp => kp.ExactLengthRequirement) - .Sum(kp => kp switch - { - DelimiterKeyPart => 1, - ConstantKeyPart c => c.Value.Length, - PropertyKeyPart p => p.LengthRequired, - _ => throw new InvalidOperationException() - }) - .ToString(); - - foreach (var keyPart in keyParts.Where(kp => !kp.ExactLengthRequirement)) - { - if (lengthRequired.Length != 0) - lengthRequired += " + "; - - lengthRequired += keyPart switch - { - PropertyKeyPart { FormatType: FormatType.Enum, Property.EnumSpec: not null } p => $"{p.Property.EnumSpec.Name}Helper.GetFormattedLength({p.Property.Name})", - PropertyKeyPart { FormatType: FormatType.String } p => $"{p.Property.Name}.Length", - _ => throw new InvalidOperationException() - }; - } - - writer.StartBlock($"return string.Create({lengthRequired}, this, static (destination, state) =>"); - - writer.WriteLine("int position = 0;"); - writer.WriteLine(); - - for (int i = 0; i < keyParts.Count; i++) - { - var keyPart = keyParts[i]; - switch (keyPart) - { - case DelimiterKeyPart d: - writer.WriteLine($"destination[position] = '{d.Value}';"); - writer.WriteLine("position += 1;"); - break; - case ConstantKeyPart c: - writer.WriteLine($"\"{c.Value}\".CopyTo(destination[position..]);"); - writer.WriteLine($"position += {c.Value.Length};"); - break; - case PropertyKeyPart { FormatType: FormatType.Guid } p: - string formatProvider = invariantFormatting ? InvariantCulture : "null"; - writer.StartBlock(); - writer.WriteLine($"if (!((ISpanFormattable)state.{p.Property.Name}).TryFormat(destination[position..], out int {GetCharsWritten(p.Property)}, \"{p.Format ?? "d"}\", {formatProvider}))"); - writer.WriteLine("\tthrow new FormatException();\n"); - writer.WriteLine($"position += {GetCharsWritten(p.Property)};"); - writer.EndBlock(); - break; - case PropertyKeyPart { FormatType: FormatType.Enum, Property.EnumSpec: not null } p: - writer.StartBlock(); - writer.WriteLine($"if (!{p.Property.EnumSpec.Name}Helper.TryFormat(state.{p.Property.Name}, destination[position..], out int {GetCharsWritten(p.Property)}))"); - writer.WriteLine("\tthrow new FormatException();\n"); - writer.WriteLine($"position += {GetCharsWritten(p.Property)};"); - writer.EndBlock(); - break; - case PropertyKeyPart { FormatType: FormatType.String } p: - writer.WriteLine($"state.{p.Property.Name}.CopyTo(destination[position..]);"); - writer.WriteLine($"position += state.{p.Property.Name}.Length;"); - break; - default: - throw new InvalidOperationException(); - } - - if (i != keyParts.Count - 1) - writer.WriteLine(); - } - - writer.Indentation--; - writer.WriteLine("});"); - } - else - { - string formatString = BuildFormatStringForKeyParts(keyParts); - writer.WriteLine(invariantFormatting - ? $"return string.Create({InvariantCulture}, $\"{formatString}\");" - : $"return $\"{formatString}\";"); - } - - writer.EndBlock(); - writer.WriteLine(); - - return; - - static string GetCharsWritten(PropertySpec p) => $"{p.CamelCaseName}CharsWritten"; - - void WriteRepeatingFormatBody() - { - // Emit empty collection checks for all repeating parts - foreach (var keyPart in keyParts.OfType()) - { - string countExpression = GetRepeatingCountExpression(keyPart.Property); - - writer.WriteLines($""" - if ({countExpression} == 0) - throw new FormatException("Collection must contain at least one item."); - - """); - } - - // Count fixed literal lengths and variable parts for DefaultInterpolatedStringHandler - int fixedLiteralLength = 0; - int formattedCount = 0; - foreach (var keyPart in keyParts) - { - switch (keyPart) - { - case DelimiterKeyPart: - fixedLiteralLength += 1; - break; - case ConstantKeyPart c: - fixedLiteralLength += c.Value.Length; - break; - case PropertyKeyPart: - formattedCount++; - break; - case RepeatingPropertyKeyPart: - // Will be handled dynamically in the loop - break; - } - } - - string formatProvider = invariantFormatting ? InvariantCulture : "null"; - - writer.WriteLines($""" - var handler = new System.Runtime.CompilerServices.DefaultInterpolatedStringHandler({fixedLiteralLength}, {formattedCount}, {formatProvider}); - """); - - foreach (var keyPart in keyParts) - { - switch (keyPart) - { - case DelimiterKeyPart d: - writer.WriteLine($"handler.AppendLiteral(\"{d.Value}\");"); - break; - case ConstantKeyPart c: - writer.WriteLine($"handler.AppendLiteral(\"{c.Value}\");"); - break; - case PropertyKeyPart p: - if (p.Format is not null) - writer.WriteLine($"handler.AppendFormatted({p.Property.Name}, \"{p.Format}\");"); - else - writer.WriteLine($"handler.AppendFormatted({p.Property.Name});"); - break; - case RepeatingPropertyKeyPart rp: - WriteRepeatingPartFormatLoop(rp); - break; - } - } - - writer.WriteLine(); - writer.WriteLine("return handler.ToStringAndClear();"); - } - - void WriteRepeatingPartFormatLoop(RepeatingPropertyKeyPart rp) - { - string countExpression = GetRepeatingCountExpression(rp.Property); - - writer.StartBlock($"for (int i = 0; i < {countExpression}; i++)"); - - writer.WriteLines($""" - if (i > 0) - handler.AppendLiteral("{rp.Separator}"); - - """); - - if (rp.Format is not null) - writer.WriteLine($"handler.AppendFormatted({rp.Property.Name}[i], \"{rp.Format}\");"); - else - writer.WriteLine($"handler.AppendFormatted({rp.Property.Name}[i]);"); - - writer.EndBlock(); - } - } - - private static void WriteDynamicFormatMethodBodyForKeyParts( - SourceWriter writer, string methodDeclaration, IReadOnlyList keyParts, bool invariantFormatting) - { - var repeatingPart = keyParts.OfType().FirstOrDefault(); - - if (repeatingPart is null) - { - WriteDynamicFormatMethodBodyForFixedKeyParts(writer, methodDeclaration, keyParts, invariantFormatting); - return; - } - - // Find the index of the repeating part and count fixed value parts before it - int repeatingKeyPartIndex = keyParts.ToList().IndexOf(repeatingPart); - int fixedPartCount = keyParts.Take(repeatingKeyPartIndex).OfType().Count(); - var fixedKeyParts = keyParts.Take(repeatingKeyPartIndex).ToList(); - - writer.StartBlock(methodDeclaration); - - WriteFixedPartCases(); - WriteRepeatingPartHandler(); - - writer.EndBlock(); // end method - writer.WriteLine(); - - return; - - void WriteFixedPartCases() - { - if (fixedKeyParts.Count == 0) - return; - - writer.StartBlock("switch (throughPartIndex, includeTrailingDelimiter)"); - - for (int i = 0, keyPartIndex = -1; i < fixedKeyParts.Count; i++) - { - var keyPart = fixedKeyParts[i]; - - bool isDelimiter = keyPart is DelimiterKeyPart; - if (!isDelimiter) - keyPartIndex++; - - string switchCase = $"case ({keyPartIndex}, {(isDelimiter ? "true" : "false")}):"; - string formatString = BuildFormatStringForKeyParts(fixedKeyParts.Take(i + 1)); - - writer.WriteLine(invariantFormatting - ? $"{switchCase} return string.Create({InvariantCulture}, $\"{formatString}\");" - : $"{switchCase} return $\"{formatString}\";"); - } - - writer.EndBlock(); - writer.WriteLine(); - } - - void WriteRepeatingPartHandler() - { - string propName = repeatingPart.Property.Name; - char separator = repeatingPart.Separator; - string? format = repeatingPart.Format; - string countExpression = GetRepeatingCountExpression(repeatingPart.Property); - - writer.WriteLines($""" - int fixedPartCount = {fixedPartCount}; - int repeatIndex = throughPartIndex - fixedPartCount; - int repeatCount = Math.Min(repeatIndex + 1, {countExpression}); - if (repeatCount <= 0) - throw new InvalidOperationException("Invalid throughPartIndex for repeating section."); - - """); - - string fixedPrefix = BuildFormatStringForKeyParts(fixedKeyParts); - - writer.WriteLines($$""" - var handler = new System.Runtime.CompilerServices.DefaultInterpolatedStringHandler(0, 0{{(invariantFormatting ? $", {InvariantCulture}" : "")}}); - """); - - if (fixedPrefix.Length > 0) - { - writer.WriteLine(invariantFormatting - ? $"handler.AppendFormatted(string.Create({InvariantCulture}, $\"{fixedPrefix}\"));" - : $"handler.AppendFormatted($\"{fixedPrefix}\");"); - } - - writer.WriteLine(); - - writer.StartBlock("for (int i = 0; i < repeatCount; i++)"); - - writer.StartBlock("if (i > 0)"); - writer.WriteLine($"handler.AppendLiteral(\"{separator}\");"); - writer.EndBlock(); - - writer.WriteLine(); - - if (format is not null) - writer.WriteLine($"handler.AppendFormatted({propName}[i], \"{format}\");"); - else - writer.WriteLine($"handler.AppendFormatted({propName}[i]);"); - - writer.EndBlock(); // end for loop - writer.WriteLine(); - - writer.StartBlock("if (includeTrailingDelimiter)"); - writer.WriteLine($"handler.AppendLiteral(\"{separator}\");"); - writer.EndBlock(); - writer.WriteLine(); - - writer.WriteLine("return handler.ToStringAndClear();"); - } - } - - private static void WriteDynamicFormatMethodBodyForFixedKeyParts( - SourceWriter writer, string methodDeclaration, IReadOnlyList keyParts, bool invariantFormatting) - { - writer.StartBlock(methodDeclaration); - - writer.StartBlock("return (throughPartIndex, includeTrailingDelimiter) switch"); - - for (int i = 0, keyPartIndex = -1; i < keyParts.Count; i++) - { - var keyPart = keyParts[i]; - - bool isDelimiter = keyPart is DelimiterKeyPart; - if (!isDelimiter) - keyPartIndex++; - - string switchCase = $"({keyPartIndex}, {(isDelimiter ? "true" : "false")}) =>"; - string formatString = BuildFormatStringForKeyParts(keyParts.Take(i + 1)); - - writer.WriteLine(invariantFormatting - ? $"{switchCase} string.Create({InvariantCulture}, $\"{formatString}\")," - : $"{switchCase} $\"{formatString}\","); - } - - writer.WriteLine("_ => throw new InvalidOperationException(\"Invalid combination of throughPartIndex and includeTrailingDelimiter provided\")"); - - writer.EndBlock(withSemicolon: true); - - writer.EndBlock(); - writer.WriteLine(); - } - - private static string GetRepeatingCountExpression(PropertySpec property) => - property.CollectionType == CollectionType.ImmutableArray - ? $"{property.Name}.Length" - : $"{property.Name}.Count"; - - private static string BuildFormatStringForKeyParts(IEnumerable keyParts) - { - var builder = new StringBuilder(); - foreach (var keyPart in keyParts) - { - builder.Append(keyPart switch - { - DelimiterKeyPart d => d.Value, - ConstantKeyPart c => c.Value, - PropertyKeyPart p => $"{{{p.Property.Name}{(p.Format is not null ? $":{p.Format}" : string.Empty)}}}", - _ => throw new InvalidOperationException() - }); - } - - return builder.ToString(); - } - - private void AddSource(string hintName, SourceText sourceText) => _context.AddSource(hintName, sourceText); - - private static SourceWriter CreateSourceWriterWithHeader(GenerationSpec generationSpec) - { - var writer = new SourceWriter(); - - writer.WriteLines(""" - // - - #nullable enable annotations - #nullable disable warnings - - // Suppress warnings about [Obsolete] member usage in generated code. - #pragma warning disable CS0612, CS0618 - - using System; - using CompositeKey; - - """); - - if (generationSpec.TargetType.Namespace is not null) - writer.StartBlock($"namespace {generationSpec.TargetType.Namespace}"); - - var nestedTypeDeclarations = generationSpec.TargetType.TypeDeclarations; - Debug.Assert(nestedTypeDeclarations.Count > 0); - - for (int i = nestedTypeDeclarations.Count - 1; i > 0; i--) - writer.StartBlock(nestedTypeDeclarations[i]); - - // Annotate the context class with the GeneratedCodeAttribute - writer.WriteLine($"""[global::System.CodeDom.Compiler.GeneratedCodeAttribute("{AssemblyName}", "{AssemblyVersion}")]"""); - - // Emit the main class declaration - writer.StartBlock($"{nestedTypeDeclarations[0]} : {(generationSpec.Key is PrimaryKeySpec ? "IPrimaryKey" : "ICompositePrimaryKey")}<{generationSpec.TargetType.TypeName}>"); - - return writer; - } - - private static SourceText CompleteSourceFileAndReturnSourceText(SourceWriter writer) - { - while (writer.Indentation > 0) - writer.EndBlock(); - - return writer.ToSourceText(); - } - } -} diff --git a/src/CompositeKey.SourceGeneration/SourceGenerator.Parser.cs b/src/CompositeKey.SourceGeneration/SourceGenerator.Parser.cs deleted file mode 100644 index 3dfea78..0000000 --- a/src/CompositeKey.SourceGeneration/SourceGenerator.Parser.cs +++ /dev/null @@ -1,631 +0,0 @@ -using System.Diagnostics; -using CompositeKey.Analyzers.Common.Diagnostics; -using CompositeKey.Analyzers.Common.Tokenization; -using CompositeKey.Analyzers.Common.Validation; -using CompositeKey.SourceGeneration.Core; -using CompositeKey.SourceGeneration.Core.Extensions; -using CompositeKey.SourceGeneration.Model; -using CompositeKey.SourceGeneration.Model.Key; -using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp; -using Microsoft.CodeAnalysis.CSharp.Syntax; - -namespace CompositeKey.SourceGeneration; - -public sealed record CompositeKeyAttributeValues(string TemplateString, char? PrimaryKeySeparator, bool InvariantCulture); - -public sealed partial class SourceGenerator -{ - private const string CompositeKeyAttributeFullName = "CompositeKey.CompositeKeyAttribute"; - - private sealed class Parser(KnownTypeSymbols knownTypeSymbols) - { - private const LanguageVersion MinimumSupportedLanguageVersion = LanguageVersion.CSharp11; - - private readonly KnownTypeSymbols _knownTypeSymbols = knownTypeSymbols; - private readonly List _diagnostics = []; - - private Location? _location; - - public ImmutableEquatableArray Diagnostics => _diagnostics.ToImmutableEquatableArray(); - - public void ReportDiagnostic(DiagnosticDescriptor descriptor, Location? location, params object?[]? messageArgs) - { - Debug.Assert(_location != null); - - if (location is null || (location.SourceTree is not null && !_knownTypeSymbols.Compilation.ContainsSyntaxTree(location.SourceTree))) - location = _location; - - _diagnostics.Add(DiagnosticInfo.Create(descriptor, location, messageArgs)); - } - - public GenerationSpec? Parse( - TypeDeclarationSyntax typeDeclarationSyntax, SemanticModel semanticModel, CancellationToken cancellationToken) - { - var targetTypeSymbol = semanticModel.GetDeclaredSymbol(typeDeclarationSyntax, cancellationToken); - Debug.Assert(targetTypeSymbol != null); - - _location = targetTypeSymbol!.Locations.Length > 0 ? targetTypeSymbol.Locations[0] : null; - Debug.Assert(_location is not null); - - var languageVersion = _knownTypeSymbols.Compilation is CSharpCompilation csc ? csc.LanguageVersion : (LanguageVersion?)null; - if (languageVersion is null or < MinimumSupportedLanguageVersion) - { - ReportDiagnostic(DiagnosticDescriptors.UnsupportedLanguageVersion, _location, languageVersion?.ToDisplayString(), MinimumSupportedLanguageVersion.ToDisplayString()); - return null; - } - - // Validate type structure using comprehensive shared validation - var validationResult = TypeValidation.ValidateTypeForCompositeKey( - targetTypeSymbol, - typeDeclarationSyntax, - semanticModel, - _knownTypeSymbols.CompositeKeyConstructorAttributeType, - cancellationToken); - - if (!validationResult.IsSuccess) - { - ReportDiagnostic(validationResult.Descriptor, _location, validationResult.MessageArgs); - return null; - } - - // Use validated data from the validation result (guaranteed non-null due to MemberNotNullWhen on IsSuccess) - var targetTypeDeclarations = validationResult.TargetTypeDeclarations; - var constructor = validationResult.Constructor; - - var compositeKeyAttributeValues = ParseCompositeKeyAttributeValues(targetTypeSymbol); - Debug.Assert(compositeKeyAttributeValues is not null); - - var constructorParameters = ParseConstructorParameters(constructor, out var constructionStrategy, out bool constructorSetsRequiredMembers); - var properties = ParseProperties(targetTypeSymbol); - var propertyInitializers = ParsePropertyInitializers(constructorParameters, properties.Select(p => p.Spec).ToList(), ref constructionStrategy, constructorSetsRequiredMembers); - - var propertiesUsedInKey = new List<(PropertySpec Spec, ITypeSymbol TypeSymbol)>(); - var keyParts = ParseTemplateStringIntoKeyParts(compositeKeyAttributeValues!, properties, propertiesUsedInKey); - if (keyParts is null) - return null; // Should have already reported diagnostics by this point so just return null... - - var primaryDelimiterKeyPart = keyParts.OfType().FirstOrDefault(); - - KeySpec key; - if (primaryDelimiterKeyPart is null) - { - // If we reach this branch then it's just a Primary Key - key = new PrimaryKeySpec(compositeKeyAttributeValues!.InvariantCulture, keyParts.ToImmutableEquatableArray()); - } - else - { - // If we reach this branch then it's a "Composite" Primary Key - var (partitionKeyParts, sortKeyParts) = SplitKeyPartsIntoPartitionAndSortKey(keyParts); - key = new CompositePrimaryKeySpec( - compositeKeyAttributeValues!.InvariantCulture, - keyParts.ToImmutableEquatableArray(), - partitionKeyParts.ToImmutableEquatableArray(), - primaryDelimiterKeyPart, - sortKeyParts.ToImmutableEquatableArray()); - } - - return new GenerationSpec( - new TargetTypeSpec( - new TypeRef(targetTypeSymbol), - targetTypeSymbol.ContainingNamespace is { IsGlobalNamespace: false } ns ? ns.ToDisplayString() : null, - targetTypeDeclarations.ToImmutableEquatableArray(), - propertiesUsedInKey.Select(p => p.Spec).ToImmutableEquatableArray(), - constructorParameters.ToImmutableEquatableArray(), - (propertyInitializers?.Where(pi => propertiesUsedInKey.Any(p => p.Spec.Name == pi.Name))).ToImmutableEquatableArray(), - constructionStrategy), - key); - } - - private static (List PartitionKeyParts, List SortKeyParts) SplitKeyPartsIntoPartitionAndSortKey(List keyParts) - { - int indexOfPrimaryKeyDelimiter = keyParts.FindIndex(kp => kp is PrimaryDelimiterKeyPart); - Debug.Assert(indexOfPrimaryKeyDelimiter != -1); - - return ( - keyParts.Take(indexOfPrimaryKeyDelimiter).ToList(), - keyParts.Skip(indexOfPrimaryKeyDelimiter + 1).ToList()); - } - - private List? ParseTemplateStringIntoKeyParts( - CompositeKeyAttributeValues compositeKeyAttributeValues, - List<(PropertySpec Spec, ITypeSymbol TypeSymbol)> properties, - List<(PropertySpec Spec, ITypeSymbol TypeSymbol)> propertiesUsedInKey) - { - (string templateString, char? primaryKeySeparator, _) = compositeKeyAttributeValues; - - var (tokenizationSuccessful, templateTokens) = TemplateValidation.TokenizeTemplateString(templateString, primaryKeySeparator); - if (!tokenizationSuccessful) - { - ReportDiagnostic(DiagnosticDescriptors.EmptyOrInvalidTemplateString, _location, templateString); - return null; - } - - var separatorValidation = TemplateValidation.ValidatePrimaryKeySeparator(templateString, primaryKeySeparator, templateTokens); - if (!separatorValidation.IsSuccess) - { - ReportDiagnostic(separatorValidation.Descriptor, _location, separatorValidation.MessageArgs); - return null; - } - - if (!TemplateValidation.HasValidTemplateStructure(templateTokens)) - { - ReportDiagnostic(DiagnosticDescriptors.EmptyOrInvalidTemplateString, _location, templateString); - return null; - } - - if (primaryKeySeparator.HasValue && !TemplateValidation.ValidatePartitionAndSortKeyStructure(templateTokens, out _)) - { - ReportDiagnostic(DiagnosticDescriptors.EmptyOrInvalidTemplateString, _location, templateString); - return null; - } - - List keyParts = []; - foreach (var templateToken in templateTokens) - { - KeyPart? keyPart = templateToken switch - { - PrimaryDelimiterTemplateToken pd => new PrimaryDelimiterKeyPart(pd.Value) { LengthRequired = 1 }, - DelimiterTemplateToken d => new DelimiterKeyPart(d.Value) { LengthRequired = 1 }, - PropertyTemplateToken p => ToPropertyKeyPart(p), - RepeatingPropertyTemplateToken rp => ToRepeatingPropertyKeyPart(rp), - ConstantTemplateToken c => new ConstantKeyPart(c.Value) { LengthRequired = c.Value.Length }, - _ => null - }; - - if (keyPart is null) - { - ReportDiagnostic(DiagnosticDescriptors.EmptyOrInvalidTemplateString, _location, templateString); - return null; - } - - keyParts.Add(keyPart); - } - - // Validate: repeating type used without repeating syntax -> COMPOSITE0010 - foreach (var keyPart in keyParts) - { - if (keyPart is PropertyKeyPart pkp && pkp.Property.CollectionType != CollectionType.None) - { - ReportDiagnostic(DiagnosticDescriptors.RepeatingTypeMustUseRepeatingSyntax, _location, pkp.Property.Name); - return null; - } - } - - // Validate: repeating property must be the last value part in its key section -> COMPOSITE0011 - var valueParts = keyParts.Where(kp => kp is ValueKeyPart).ToList(); - if (valueParts.Count > 0 && valueParts[^1] is not RepeatingPropertyKeyPart) - { - // Only report if there's a repeating part that isn't last - if (valueParts.Any(kp => kp is RepeatingPropertyKeyPart)) - { - var repeatingPart = valueParts.First(kp => kp is RepeatingPropertyKeyPart) as RepeatingPropertyKeyPart; - ReportDiagnostic(DiagnosticDescriptors.RepeatingPropertyMustBeLastPart, _location, repeatingPart!.Property.Name); - return null; - } - } - - // For composite keys, also validate repeating position within each section - if (keyParts.Any(kp => kp is PrimaryDelimiterKeyPart)) - { - int delimiterIndex = keyParts.FindIndex(kp => kp is PrimaryDelimiterKeyPart); - - var partitionValueParts = keyParts.Take(delimiterIndex).Where(kp => kp is ValueKeyPart).ToList(); - if (partitionValueParts.Count > 0 && partitionValueParts.Any(kp => kp is RepeatingPropertyKeyPart) && partitionValueParts[^1] is not RepeatingPropertyKeyPart) - { - var repeatingPart = partitionValueParts.First(kp => kp is RepeatingPropertyKeyPart) as RepeatingPropertyKeyPart; - ReportDiagnostic(DiagnosticDescriptors.RepeatingPropertyMustBeLastPart, _location, repeatingPart!.Property.Name); - return null; - } - - var sortValueParts = keyParts.Skip(delimiterIndex + 1).Where(kp => kp is ValueKeyPart).ToList(); - if (sortValueParts.Count > 0 && sortValueParts.Any(kp => kp is RepeatingPropertyKeyPart) && sortValueParts[^1] is not RepeatingPropertyKeyPart) - { - var repeatingPart = sortValueParts.First(kp => kp is RepeatingPropertyKeyPart) as RepeatingPropertyKeyPart; - ReportDiagnostic(DiagnosticDescriptors.RepeatingPropertyMustBeLastPart, _location, repeatingPart!.Property.Name); - return null; - } - } - - return keyParts; - - PropertyKeyPart? ToPropertyKeyPart(PropertyTemplateToken templateToken) - { - var availableProperties = properties - .Select(p => new TemplateValidation.PropertyInfo(p.Spec.Name, p.Spec.HasGetter, p.Spec.HasSetter)) - .ToList(); - - var propertyValidation = TemplateValidation.ValidatePropertyReferences([templateToken], availableProperties); - if (!propertyValidation.IsSuccess) - { - ReportDiagnostic(propertyValidation.Descriptor, _location, propertyValidation.MessageArgs); - return null; - } - - var property = properties.First(p => p.Spec.Name == templateToken.Name); - - propertiesUsedInKey.Add(property); - var (propertySpec, typeSymbol) = property; - - // Repeating type properties must use repeating syntax - if (propertySpec.CollectionType != CollectionType.None) - { - ReportDiagnostic(DiagnosticDescriptors.RepeatingTypeMustUseRepeatingSyntax, _location, propertySpec.Name); - return null; - } - - var interfaces = typeSymbol.AllInterfaces; - bool isSpanParsable = interfaces.Any(i => i.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat).StartsWith("global::System.ISpanParsable")); - bool isSpanFormattable = interfaces.Any(i => i.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat).Equals("global::System.ISpanFormattable")); - - var typeInfo = new PropertyValidation.PropertyTypeInfo( - TypeName: propertySpec.Type.FullyQualifiedName, - IsGuid: SymbolEqualityComparer.Default.Equals(typeSymbol, _knownTypeSymbols.GuidType), - IsString: SymbolEqualityComparer.Default.Equals(typeSymbol, _knownTypeSymbols.StringType), - IsEnum: typeSymbol.TypeKind == TypeKind.Enum, - IsSpanParsable: isSpanParsable, - IsSpanFormattable: isSpanFormattable); - - var formatValidation = PropertyValidation.ValidatePropertyFormat( - propertySpec.Name, - typeInfo, - templateToken.Format); - - if (!formatValidation.IsSuccess) - { - ReportDiagnostic(formatValidation.Descriptor, _location, formatValidation.MessageArgs); - return null; - } - - var typeCompatibility = PropertyValidation.ValidatePropertyTypeCompatibility( - propertySpec.Name, - typeInfo); - - if (!typeCompatibility.IsSuccess) - { - throw new NotSupportedException($"Unsupported property of type '{propertySpec.Type.FullyQualifiedName}'"); - } - - var lengthInfo = PropertyValidation.GetFormattedLength(typeInfo, templateToken.Format); - int lengthRequired = lengthInfo?.length ?? 1; - bool exactLengthRequirement = lengthInfo?.isExact ?? false; - - ParseType parseType; - FormatType formatType; - string? format = templateToken.Format; - - if (typeInfo.IsGuid) - { - parseType = ParseType.Guid; - formatType = FormatType.Guid; - format = templateToken.Format?.ToLowerInvariant() ?? "d"; - } - else if (typeInfo.IsString) - { - parseType = ParseType.String; - formatType = FormatType.String; - format = null; - } - else if (typeInfo.IsEnum) - { - parseType = ParseType.Enum; - formatType = FormatType.Enum; - format = templateToken.Format?.ToLowerInvariant() ?? "g"; - } - else - { - parseType = ParseType.SpanParsable; - formatType = FormatType.SpanFormattable; - } - - return new PropertyKeyPart(propertySpec, format, parseType, formatType) - { - LengthRequired = lengthRequired, - ExactLengthRequirement = exactLengthRequirement - }; - } - - RepeatingPropertyKeyPart? ToRepeatingPropertyKeyPart(RepeatingPropertyTemplateToken templateToken) - { - var availableProperties = properties - .Select(p => new TemplateValidation.PropertyInfo(p.Spec.Name, p.Spec.HasGetter, p.Spec.HasSetter)) - .ToList(); - - var propertyValidation = TemplateValidation.ValidatePropertyReferences([templateToken], availableProperties); - if (!propertyValidation.IsSuccess) - { - ReportDiagnostic(propertyValidation.Descriptor, _location, propertyValidation.MessageArgs); - return null; - } - - var property = properties.First(p => p.Spec.Name == templateToken.Name); - var (propertySpec, typeSymbol) = property; - - // Validate that the property is a collection type - if (propertySpec.CollectionType == CollectionType.None) - { - ReportDiagnostic(DiagnosticDescriptors.RepeatingPropertyMustUseCollectionType, _location, propertySpec.Name); - return null; - } - - // Extract inner type from the collection - var namedTypeSymbol = (INamedTypeSymbol)typeSymbol; - var innerTypeSymbol = namedTypeSymbol.TypeArguments[0]; - - var innerInterfaces = innerTypeSymbol.AllInterfaces; - bool isSpanParsable = innerInterfaces.Any(i => i.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat).StartsWith("global::System.ISpanParsable")); - bool isSpanFormattable = innerInterfaces.Any(i => i.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat).Equals("global::System.ISpanFormattable")); - - var innerTypeInfo = new PropertyValidation.PropertyTypeInfo( - TypeName: innerTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), - IsGuid: SymbolEqualityComparer.Default.Equals(innerTypeSymbol, _knownTypeSymbols.GuidType), - IsString: SymbolEqualityComparer.Default.Equals(innerTypeSymbol, _knownTypeSymbols.StringType), - IsEnum: innerTypeSymbol.TypeKind == TypeKind.Enum, - IsSpanParsable: isSpanParsable, - IsSpanFormattable: isSpanFormattable); - - var formatValidation = PropertyValidation.ValidatePropertyFormat( - propertySpec.Name, - innerTypeInfo, - templateToken.Format); - - if (!formatValidation.IsSuccess) - { - ReportDiagnostic(formatValidation.Descriptor, _location, formatValidation.MessageArgs); - return null; - } - - var typeCompatibility = PropertyValidation.ValidatePropertyTypeCompatibility( - propertySpec.Name, - innerTypeInfo); - - if (!typeCompatibility.IsSuccess) - { - throw new NotSupportedException($"Unsupported inner type '{innerTypeInfo.TypeName}' for repeating property '{propertySpec.Name}'"); - } - - propertiesUsedInKey.Add(property); - - ParseType innerParseType; - FormatType innerFormatType; - string? format = templateToken.Format; - - if (innerTypeInfo.IsGuid) - { - innerParseType = ParseType.Guid; - innerFormatType = FormatType.Guid; - format = templateToken.Format?.ToLowerInvariant() ?? "d"; - } - else if (innerTypeInfo.IsString) - { - innerParseType = ParseType.String; - innerFormatType = FormatType.String; - format = null; - } - else if (innerTypeInfo.IsEnum) - { - innerParseType = ParseType.Enum; - innerFormatType = FormatType.Enum; - format = templateToken.Format?.ToLowerInvariant() ?? "g"; - } - else - { - innerParseType = ParseType.SpanParsable; - innerFormatType = FormatType.SpanFormattable; - } - - return new RepeatingPropertyKeyPart(propertySpec, templateToken.Separator, format, innerParseType, innerFormatType, new TypeRef(innerTypeSymbol)) - { - LengthRequired = 1, - ExactLengthRequirement = false - }; - } - } - - private static List? ParsePropertyInitializers( - ConstructorParameterSpec[] constructorParameters, - List properties, - ref ConstructionStrategy constructionStrategy, - bool constructorSetsRequiredParameters) - { - if (properties is []) - return []; - - HashSet? propertyInitializerNames = null; - List? propertyInitializers = null; - int parameterCount = constructorParameters.Length; - - foreach (var property in properties) - { - if (!property.HasSetter) - continue; - - if ((property.IsRequired || constructorSetsRequiredParameters) && !property.IsInitOnlySetter) - continue; - - if (!(propertyInitializerNames ??= []).Add(property.Name)) - continue; - - var matchingConstructorParameter = constructorParameters.FirstOrDefault(cp => cp.Name.Equals(property.Name, StringComparison.OrdinalIgnoreCase)); - if (!property.IsRequired && matchingConstructorParameter is not null) - continue; - - constructionStrategy = ConstructionStrategy.ParameterizedConstructor; - (propertyInitializers ??= []).Add(new PropertyInitializerSpec( - property.Type, - property.Name, - property.CamelCaseName, - matchingConstructorParameter?.ParameterIndex ?? parameterCount++, - matchingConstructorParameter is not null)); - } - - return propertyInitializers; - } - - private List<(PropertySpec Spec, ITypeSymbol TypeSymbol)> ParseProperties(INamedTypeSymbol typeSymbol) - { - List<(PropertySpec Spec, ITypeSymbol TypeSymbol)> properties = []; - foreach (var propertySymbol in typeSymbol.GetMembers().OfType()) - { - if (propertySymbol is { IsImplicitlyDeclared: true, Name: "EqualityContract" }) - continue; - - if (propertySymbol.IsStatic || propertySymbol.Parameters.Length > 0) - continue; - - // Detect collection types - var collectionType = CollectionType.None; - ITypeSymbol effectiveTypeSymbol = propertySymbol.Type; - - if (propertySymbol.Type is INamedTypeSymbol namedType) - { - var originalDefinition = namedType.OriginalDefinition; - if (SymbolEqualityComparer.Default.Equals(originalDefinition, _knownTypeSymbols.ListType)) - collectionType = CollectionType.List; - else if (SymbolEqualityComparer.Default.Equals(originalDefinition, _knownTypeSymbols.ReadOnlyListType)) - collectionType = CollectionType.IReadOnlyList; - else if (SymbolEqualityComparer.Default.Equals(originalDefinition, _knownTypeSymbols.ImmutableArrayType)) - collectionType = CollectionType.ImmutableArray; - } - - // For collection types, extract inner type for EnumSpec - EnumSpec? enumSpec = null; - if (collectionType != CollectionType.None) - { - var innerType = ((INamedTypeSymbol)propertySymbol.Type).TypeArguments[0]; - if (innerType is INamedTypeSymbol { TypeKind: TypeKind.Enum } innerEnumType) - enumSpec = ExtractEnumDefinition(innerEnumType); - } - else - { - if (propertySymbol.Type is INamedTypeSymbol { TypeKind: TypeKind.Enum } enumType) - enumSpec = ExtractEnumDefinition(enumType); - } - - var propertySpec = new PropertySpec( - new TypeRef(propertySymbol.Type), - propertySymbol.Name, - propertySymbol.Name.FirstToLowerInvariant(), - propertySymbol.IsRequired, - propertySymbol.GetMethod is not null, - propertySymbol.SetMethod is not null, - propertySymbol.SetMethod is { IsInitOnly: true }, - enumSpec, - collectionType); - - properties.Add((propertySpec, propertySymbol.Type)); - } - - return properties; - - EnumSpec ExtractEnumDefinition(INamedTypeSymbol enumSymbol) - { - string name = enumSymbol.Name; - string fullyQualifiedName = enumSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - string underlyingType = enumSymbol.EnumUnderlyingType?.ToString() ?? "int"; - - List members = []; - foreach (var enumMember in enumSymbol.GetMembers()) - { - if (enumMember is not IFieldSymbol { ConstantValue: not null } fieldSymbol) - continue; - - members.Add(new EnumSpec.Member(fieldSymbol.Name, fieldSymbol.ConstantValue)); - } - - bool isSequentialFromZero = true; - for (int i = 0; i < members.Count; i++) - { - if (Convert.ToUInt64(members[i].Value) == (uint)i) - continue; - - isSequentialFromZero = false; - break; - } - - return new EnumSpec( - name, - fullyQualifiedName, - underlyingType, - members.ToImmutableEquatableArray(), - isSequentialFromZero); - } - } - - private ConstructorParameterSpec[] ParseConstructorParameters( - IMethodSymbol constructor, out ConstructionStrategy constructionStrategy, out bool constructorSetsRequiredMembers) - { - constructorSetsRequiredMembers = constructor.GetAttributes().Any(a => - SymbolEqualityComparer.Default.Equals(a.AttributeClass, _knownTypeSymbols.SetsRequiredMembersAttributeType)); - - int parameterCount = constructor.Parameters.Length; - if (parameterCount == 0) - { - constructionStrategy = ConstructionStrategy.ParameterlessConstructor; - return []; - } - - constructionStrategy = ConstructionStrategy.ParameterizedConstructor; - - var constructorParameters = new ConstructorParameterSpec[parameterCount]; - for (int i = 0; i < parameterCount; i++) - { - var parameterSymbol = constructor.Parameters[i]; - constructorParameters[i] = new ConstructorParameterSpec(new TypeRef(parameterSymbol.Type), parameterSymbol.Name, parameterSymbol.Name.FirstToLowerInvariant(), i); - } - - return constructorParameters; - } - - - private CompositeKeyAttributeValues? ParseCompositeKeyAttributeValues(INamedTypeSymbol targetTypeSymbol) - { - Debug.Assert(_knownTypeSymbols.CompositeKeyAttributeType is not null); - - CompositeKeyAttributeValues? attributeValues = null; - - foreach (var attributeData in targetTypeSymbol.GetAttributes()) - { - Debug.Assert(attributeValues is null, $"There should only ever be one {nameof(CompositeKeyAttribute)} per type"); - - if (SymbolEqualityComparer.Default.Equals(attributeData.AttributeClass, _knownTypeSymbols.CompositeKeyAttributeType)) - attributeValues = TryExtractAttributeValues(attributeData); - } - - return attributeValues; - - static CompositeKeyAttributeValues? TryExtractAttributeValues(AttributeData attributeData) - { - Debug.Assert(attributeData.ConstructorArguments.Length is 1); - - string? templateString = (string?)attributeData.ConstructorArguments[0].Value; - if (templateString is null) - return null; - - char? primaryKeySeparator = null; - bool? invariantCulture = null; - foreach (var namedArgument in attributeData.NamedArguments) - { - (string? key, var value) = (namedArgument.Key, namedArgument.Value); - - switch (key) - { - case nameof(CompositeKeyAttribute.PrimaryKeySeparator): - primaryKeySeparator = (char?)value.Value; - break; - - case nameof(CompositeKeyAttribute.InvariantCulture): - invariantCulture = (bool?)value.Value; - break; - - default: - throw new InvalidOperationException(); - } - } - - return new CompositeKeyAttributeValues(templateString, primaryKeySeparator, invariantCulture ?? true); - } - } - - } -} diff --git a/src/CompositeKey.SourceGeneration/SourceGenerator.cs b/src/CompositeKey.SourceGeneration/SourceGenerator.cs index 6813f7c..8f76c9b 100644 --- a/src/CompositeKey.SourceGeneration/SourceGenerator.cs +++ b/src/CompositeKey.SourceGeneration/SourceGenerator.cs @@ -7,8 +7,10 @@ namespace CompositeKey.SourceGeneration; [Generator] -public sealed partial class SourceGenerator : IIncrementalGenerator +public sealed class SourceGenerator : IIncrementalGenerator { + private const string CompositeKeyAttributeFullName = "CompositeKey.CompositeKeyAttribute"; + public const string GenerationSpecTrackingName = nameof(GenerationSpec); public void Initialize(IncrementalGeneratorInitializationContext context)