diff --git a/LinqGen.Generator/CodeGenUtils.cs b/LinqGen.Generator/CodeGenUtils.cs index 3b9b62c..d947979 100644 --- a/LinqGen.Generator/CodeGenUtils.cs +++ b/LinqGen.Generator/CodeGenUtils.cs @@ -1,6 +1,7 @@ // LinqGen.Generator, Maxwell Keonwoo Kang , 2022 using System; +using System.Collections.Immutable; using System.Linq; namespace Cathei.LinqGen.Generator; @@ -96,7 +97,7 @@ symbol.TypeArguments[1] is not INamedTypeSymbol resultSignatureSymbol || } public static bool TryParseStubMethod(IMethodSymbol methodSymbol, - out ITypeSymbol inputElementSymbol, out INamedTypeSymbol[] signatureSymbols) + out ITypeSymbol inputElementSymbol, out IEnumerable signatureSymbols) { if (methodSymbol.ReceiverType is not INamedTypeSymbol receiverTypeSymbol || !TryParseStubInterface(receiverTypeSymbol, out inputElementSymbol, out var receiverSignatureSymbol)) @@ -124,7 +125,7 @@ public static bool TryParseStubMethod(IMethodSymbol methodSymbol, signatureSymbolsList.Add(paramSignatureSymbol); } - signatureSymbols = signatureSymbolsList.ToArray(); + signatureSymbols = signatureSymbolsList; return true; } @@ -782,4 +783,9 @@ public static BlockSyntax AddStatements(this BlockSyntax block, SyntaxList, 2022 +using System; + namespace Cathei.LinqGen.Generator; -public readonly struct EvaluationKey : IEqualityComparer +public readonly struct EvaluationKey : IEquatable { private static readonly SymbolEqualityComparer SymbolComparer = SymbolEqualityComparer.Default; @@ -18,17 +20,18 @@ public EvaluationKey( InputElementSymbol = inputElementSymbol; } - public bool Equals(EvaluationKey x, EvaluationKey y) + public bool Equals(EvaluationKey other) { - return SymbolComparer.Equals(x.SignatureSymbol, y.SignatureSymbol) && - SymbolComparer.Equals(x.MethodSymbol, y.MethodSymbol) && - SymbolComparer.Equals(x.InputElementSymbol, y.InputElementSymbol); + return SymbolComparer.Equals(SignatureSymbol, other.SignatureSymbol) && + SymbolComparer.Equals(MethodSymbol, other.MethodSymbol) && + SymbolComparer.Equals(InputElementSymbol, other.InputElementSymbol); } - public int GetHashCode(EvaluationKey obj) + public override int GetHashCode() { - return SymbolComparer.GetHashCode(obj.SignatureSymbol) ^ - SymbolComparer.GetHashCode(obj.MethodSymbol) ^ - SymbolComparer.GetHashCode(obj.InputElementSymbol); + int hashCode = SymbolComparer.GetHashCode(SignatureSymbol); + hashCode = HashCombine(hashCode, SymbolComparer.GetHashCode(MethodSymbol)); + hashCode = HashCombine(hashCode, SymbolComparer.GetHashCode(InputElementSymbol)); + return hashCode; } } \ No newline at end of file diff --git a/LinqGen.Generator/Instructions/Instruction.cs b/LinqGen.Generator/Instructions/Instruction.cs index e0abb6f..5edcc2b 100644 --- a/LinqGen.Generator/Instructions/Instruction.cs +++ b/LinqGen.Generator/Instructions/Instruction.cs @@ -1,6 +1,7 @@ // LinqGen.Generator, Maxwell Keonwoo Kang , 2022 using System; +using System.Collections.Immutable; using System.Linq; namespace Cathei.LinqGen.Generator; @@ -13,10 +14,10 @@ namespace Cathei.LinqGen.Generator; /// public abstract class Instruction { - public INamedTypeSymbol[]? UpstreamSignatureSymbols { get; } + public ImmutableArray UpstreamSignatureSymbols { get; } public string Id { get; } - protected Instruction(in LinqGenExpression expression, int id) + protected Instruction(in LinqGenExpression expression, uint id) { UpstreamSignatureSymbols = expression.UpstreamSignatureSymbols; Id = Base62.Encode(id); diff --git a/LinqGen.Generator/Instructions/InstructionFactory.cs b/LinqGen.Generator/Instructions/InstructionFactory.cs index 79d1589..7df378a 100644 --- a/LinqGen.Generator/Instructions/InstructionFactory.cs +++ b/LinqGen.Generator/Instructions/InstructionFactory.cs @@ -11,7 +11,7 @@ public static class InstructionFactory /// /// The Instruction instance must be unique per signature (per generic arguments combination). /// - public static Generation? CreateGeneration(StringBuilder logBuilder, in LinqGenExpression expression, int id) + public static Generation? CreateGeneration(in LinqGenExpression expression, int id) { switch (expression.SignatureSymbol!.Name) { @@ -274,7 +274,7 @@ public static class InstructionFactory return null; } - public static Evaluation? CreateEvaluation(StringBuilder logBuilder, in LinqGenExpression expression, int id) + public static Evaluation? CreateEvaluation(in LinqGenExpression expression, int id) { switch (expression.MethodSymbol.Name) { diff --git a/LinqGen.Generator/LinqGen.Generator.csproj b/LinqGen.Generator/LinqGen.Generator.csproj index 0aade77..b80cebc 100644 --- a/LinqGen.Generator/LinqGen.Generator.csproj +++ b/LinqGen.Generator/LinqGen.Generator.csproj @@ -8,11 +8,12 @@ Alloc-free and fast replacement for Linq, with code generation. https://github.com/cathei/LinqGen MIT + true - - + + all runtime; build; native; contentfiles; analyzers; buildtransitive diff --git a/LinqGen.Generator/LinqGenExpression.cs b/LinqGen.Generator/LinqGenExpression.cs index b895293..e253a47 100644 --- a/LinqGen.Generator/LinqGenExpression.cs +++ b/LinqGen.Generator/LinqGenExpression.cs @@ -1,6 +1,7 @@ // LinqGen.Generator, Maxwell Keonwoo Kang , 2022 using System; +using System.Collections.Immutable; using System.Linq; namespace Cathei.LinqGen.Generator; @@ -10,10 +11,10 @@ public readonly struct LinqGenExpression public IMethodSymbol MethodSymbol { get; } public INamedTypeSymbol? SignatureSymbol { get; } public ITypeSymbol? InputElementSymbol { get; } - public INamedTypeSymbol[]? UpstreamSignatureSymbols { get; } + public ImmutableArray UpstreamSignatureSymbols { get; } private LinqGenExpression(IMethodSymbol methodSymbol, INamedTypeSymbol? signatureSymbol, - ITypeSymbol? inputElementSymbol, INamedTypeSymbol[]? upstreamSignatureSymbols) + ITypeSymbol? inputElementSymbol, ImmutableArray upstreamSignatureSymbols) { MethodSymbol = methodSymbol; SignatureSymbol = signatureSymbol; @@ -57,7 +58,7 @@ public static bool TryParse(SemanticModel semanticModel, } ITypeSymbol? inputElementSymbol = null; - INamedTypeSymbol[]? upstreamSignatureSymbols = null; + IEnumerable? upstreamSignatureSymbols = null; // this means it takes LinqGen enumerable as input, and upstream type is required if (methodSymbol.ReceiverType is INamedTypeSymbol receiverTypeSymbol && @@ -80,10 +81,11 @@ public static bool TryParse(SemanticModel semanticModel, signatureSymbol = NormalizeSignature(signatureSymbol); if (upstreamSignatureSymbols != null) - upstreamSignatureSymbols = upstreamSignatureSymbols.Select(NormalizeSignature).ToArray(); + upstreamSignatureSymbols = upstreamSignatureSymbols.Select(NormalizeSignature); result = new LinqGenExpression( - methodSymbol, signatureSymbol, inputElementSymbol, upstreamSignatureSymbols); + methodSymbol, signatureSymbol, inputElementSymbol, + upstreamSignatureSymbols?.ToImmutableArray() ?? ImmutableArray.Empty); return true; } @@ -129,7 +131,7 @@ public static bool TryParse(SemanticModel semanticModel, } result = new LinqGenExpression( - methodSymbol, null, inputElementSymbol, new[] { NormalizeSignature(upstreamSignatureSymbol) }); + methodSymbol, null, inputElementSymbol, ImmutableArray.Create(upstreamSignatureSymbol)); return true; } diff --git a/LinqGen.Generator/LinqGenIncrementalGenerator.cs b/LinqGen.Generator/LinqGenIncrementalGenerator.cs index abf52e5..6d39979 100644 --- a/LinqGen.Generator/LinqGenIncrementalGenerator.cs +++ b/LinqGen.Generator/LinqGenIncrementalGenerator.cs @@ -1,9 +1,260 @@ // LinqGen.Generator, Maxwell Keonwoo Kang , 2022 +using System.Collections.Immutable; +using System.Linq; + namespace Cathei.LinqGen.Generator; // Not supported in current Unity version -// [Generator] -// public class LinqGenIncrementalGenerator : IIncrementalGenerator -// { -// } \ No newline at end of file +[Generator] +public class LinqGenIncrementalGenerator : IIncrementalGenerator +{ + public void Initialize(IncrementalGeneratorInitializationContext context) + { + var expressions = context.SyntaxProvider.CreateSyntaxProvider( + static (node, _) => ExpressionPredicate(node), + static (ctx, _) => ExpressionTransform(ctx.SemanticModel, ctx.Node)) + .Where(static x => x != null) + .Select(static (x, _) => x!.Value); + + var generations = expressions + .Where(static x => x.IsCompilingGeneration()) + .Collect() + .Select(static (x, _) => CreateGenerationDictionary(x)); + + var evaluations = expressions + .Where(static x => !x.IsCompilingGeneration()) + .Collect() + .Select(static (x, _) => CreateEvaluationDictionary(x)); + + var downstream = generations.Combine(evaluations) + .Select(static (x, _) => CreateDownstreamDictionary(x.Left, x.Right)); + + var dependencies = generations.Combine(downstream) + .SelectMany(static (x, _) => CreateDependencies(x.Left, x.Right)); + + context.RegisterSourceOutput(dependencies, Render); + } + + public readonly struct LinqGenExpressionDependency + { + public readonly LinqGenExpression Expression; + public readonly ImmutableArray Dependencies; + + /// logic + /// consider all nested upstreams + /// consider direct downstream and its all additional nested upstreams + + public LinqGenExpressionDependency( + in LinqGenExpression expression, + in ImmutableArray dependencies) + { + Expression = expression; + Dependencies = dependencies; + } + } + + private static uint GenerateStableId(in LinqGenExpression expr) + { + return unchecked((uint)SymbolEqualityComparer.Default.GetHashCode(expr.SignatureSymbol)); + } + + private static bool ExpressionPredicate(SyntaxNode node) + { + return node is + InvocationExpressionSyntax { Expression: MemberAccessExpressionSyntax } or + CommonForEachStatementSyntax; + } + + private static LinqGenExpression? ExpressionTransform(SemanticModel model, SyntaxNode node) + { + if (node is InvocationExpressionSyntax invocationSyntax) + { + if (LinqGenExpression.TryParse(model, invocationSyntax, out var expression)) + { + return expression; + } + } + else if (node is CommonForEachStatementSyntax forEachSyntax) + { + if (LinqGenExpression.TryParse(model, forEachSyntax, out var expression)) + { + return expression; + } + } + + return null; + } + + private static ImmutableDictionary CreateGenerationDictionary( + in ImmutableArray expressions) + { + var builder = ImmutableDictionary.CreateBuilder(SymbolEqualityComparer.Default); + + foreach (var expr in expressions) + { + if (builder.ContainsKey(expr.SignatureSymbol!)) + { + // already registered + continue; + } + + builder.Add(expr.SignatureSymbol!, expr); + } + + return builder.ToImmutable(); + } + + private static ImmutableDictionary CreateEvaluationDictionary( + in ImmutableArray expressions) + { + var builder = ImmutableDictionary.CreateBuilder(); + + foreach (var expr in expressions) + { + var key = new EvaluationKey(expr.UpstreamSignatureSymbols[0], expr.MethodSymbol, expr.InputElementSymbol!); + + if (builder.ContainsKey(key)) + { + // already registered + continue; + } + + builder.Add(key, expr); + } + + return builder.ToImmutable(); + } + + private static ImmutableDictionary> CreateDownstreamDictionary( + ImmutableDictionary generations, + ImmutableDictionary evaluations) + { + var downstream = generations.ToDictionary( + static x => x.Key, + static _ => ImmutableArray.CreateBuilder(), + generations.KeyComparer); + + foreach (var generation in generations.Values) + { + if (generation.UpstreamSignatureSymbols.IsEmpty) + continue; + + // only first upstream depends on downstream + if (downstream.TryGetValue(generation.UpstreamSignatureSymbols[0], out var builder)) + builder.Add(generation); + } + + foreach (var evaluation in evaluations.Values) + { + // only first upstream depends on downstream + if (downstream.TryGetValue(evaluation.UpstreamSignatureSymbols[0], out var builder)) + builder.Add(evaluation); + } + + return downstream.ToImmutableDictionary( + x => x.Key, + x => x.Value.ToImmutable(), + downstream.Comparer); + } + + private static IEnumerable CreateDependencies( + ImmutableDictionary generations, + ImmutableDictionary> downstream) + { + var hashSet = new HashSet(); + + foreach (var pair in generations) + { + hashSet.Clear(); + hashSet.Add(pair.Value); + CollectUpwardDependencies(pair.Value, generations, hashSet); + + foreach (var expression in downstream[pair.Key]) + { + hashSet.Add(expression); + CollectUpwardDependencies(expression, generations, hashSet); + } + + yield return new LinqGenExpressionDependency(pair.Value, hashSet.ToImmutableArray()); + } + } + + private static void CollectUpwardDependencies( + in LinqGenExpression current, + ImmutableDictionary generations, + HashSet result) + { + foreach (var symbol in current.UpstreamSignatureSymbols) + { + // failed to find upstream + if (!generations.TryGetValue(symbol, out var upstream)) + continue; + + // result already been added + if (!result.Add(upstream)) + continue; + + // no more upstream to traverse + if (upstream.UpstreamSignatureSymbols.IsEmpty) + continue; + + CollectUpwardDependencies(upstream, generations, result); + } + } + + private static void Render(SourceProductionContext context, LinqGenExpressionDependency dependency) + { + var generations = new Dictionary(SymbolEqualityComparer.Default); + + foreach (var dep in dependency.Dependencies) + { + if (!dep.IsCompilingGeneration()) + continue; + + var generation = InstructionFactory.CreateGeneration(dep, GenerateStableId(dep)); + + // Something went wrong? + if (generation == null) + continue; + + generations.Add(dep.SignatureSymbol!, generation); + } + + foreach (var generation in generations.Values) + { + foreach (var upstreamSymbol in generation.UpstreamSignatureSymbols) + { + if (!generations.TryGetValue(upstreamSymbol, out var upstream)) + continue; + + generation.AddUpstream(upstream); + } + } + + foreach (var dep in dependency.Dependencies) + { + if (dep.IsCompilingGeneration()) + continue; + + var evaluation = InstructionFactory.CreateEvaluation(dep, GenerateStableId(dep)); + + // Something went wrong? + if (evaluation == null) + continue; + + foreach (var upstreamSymbol in evaluation.UpstreamSignatureSymbols) + { + if (!generations.TryGetValue(upstreamSymbol, out var upstream)) + continue; + + evaluation.AddUpstream(upstream); + } + } + + var generationToRender = generations[dependency.Expression.SignatureSymbol!]; + var sourceText = FileTemplate.Render(generationToRender.Render()); + + context.AddSource($"LinqGen.{generationToRender.FileName}.g.cs", sourceText); + } +} \ No newline at end of file diff --git a/LinqGen.Generator/LinqGenSourceGenerator.cs b/LinqGen.Generator/LinqGenSourceGenerator.cs index 8a47276..458e6e0 100644 --- a/LinqGen.Generator/LinqGenSourceGenerator.cs +++ b/LinqGen.Generator/LinqGenSourceGenerator.cs @@ -6,76 +6,76 @@ namespace Cathei.LinqGen.Generator; -[Generator] -public class LinqGenSourceGenerator : ISourceGenerator -{ - public void Initialize(GeneratorInitializationContext context) { } - - public void Execute(GeneratorExecutionContext context) - { - StringBuilder logBuilder = new(); - - logBuilder.AppendLine("/* Started */"); - - try - { - var syntaxReceiver = new LinqGenSyntaxReceiver(logBuilder); - - foreach (var syntaxTree in context.Compilation.SyntaxTrees) - { - var semanticModel = context.Compilation.GetSemanticModel(syntaxTree); - syntaxReceiver.VisitSyntaxTree(semanticModel, syntaxTree); - } - - syntaxReceiver.ResolveHierarchy(); - - var buffer = new List(); - int batch = 1; - int count = 0; - - foreach (var result in syntaxReceiver.Roots.SelectMany(RenderNodeRecursive)) - { - buffer.AddRange(result); - - if (count++ > 10) - { - context.AddSource($"LinqGen.{batch}.cs", FileTemplate.Render(buffer)); - batch++; - count = 0; - buffer.Clear(); - } - } - - // last batch - if (buffer.Count > 0) - context.AddSource($"LinqGen.{batch}.cs", FileTemplate.Render(buffer)); - } - catch (Exception ex) - { - logBuilder.AppendFormat("/* Exception found: {0} */\n", ex); -#if !DEBUG - throw; -#endif - } - finally - { - logBuilder.AppendLine("/* Ended */"); - - context.AddSource("Log.g.cs", logBuilder.ToString()); - } - } - - private IEnumerable> RenderNodeRecursive(Generation generation) - { - yield return generation.Render(); - - if (generation.Downstream != null) - { - foreach (var downstream in generation.Downstream) - { - foreach (var result in RenderNodeRecursive(downstream)) - yield return result; - } - } - } -} \ No newline at end of file +// [Generator] +// public class LinqGenSourceGenerator : ISourceGenerator +// { +// public void Initialize(GeneratorInitializationContext context) { } +// +// public void Execute(GeneratorExecutionContext context) +// { +// StringBuilder logBuilder = new(); +// +// logBuilder.AppendLine("/* Started */"); +// +// try +// { +// var syntaxReceiver = new LinqGenSyntaxReceiver(logBuilder); +// +// foreach (var syntaxTree in context.Compilation.SyntaxTrees) +// { +// var semanticModel = context.Compilation.GetSemanticModel(syntaxTree); +// syntaxReceiver.VisitSyntaxTree(semanticModel, syntaxTree); +// } +// +// syntaxReceiver.ResolveHierarchy(); +// +// var buffer = new List(); +// int batch = 1; +// int count = 0; +// +// foreach (var result in syntaxReceiver.Roots.SelectMany(RenderNodeRecursive)) +// { +// buffer.AddRange(result); +// +// if (count++ > 10) +// { +// context.AddSource($"LinqGen.{batch}.cs", FileTemplate.Render(buffer)); +// batch++; +// count = 0; +// buffer.Clear(); +// } +// } +// +// // last batch +// if (buffer.Count > 0) +// context.AddSource($"LinqGen.{batch}.cs", FileTemplate.Render(buffer)); +// } +// catch (Exception ex) +// { +// logBuilder.AppendFormat("/* Exception found: {0} */\n", ex); +// #if !DEBUG +// throw; +// #endif +// } +// finally +// { +// logBuilder.AppendLine("/* Ended */"); +// +// context.AddSource("Log.g.cs", logBuilder.ToString()); +// } +// } +// +// private IEnumerable> RenderNodeRecursive(Generation generation) +// { +// yield return generation.Render(); +// +// if (generation.Downstream != null) +// { +// foreach (var downstream in generation.Downstream) +// { +// foreach (var result in RenderNodeRecursive(downstream)) +// yield return result; +// } +// } +// } +// } \ No newline at end of file diff --git a/LinqGen.Generator/LinqGenSyntaxReceiver.cs b/LinqGen.Generator/LinqGenSyntaxReceiver.cs index 88ccbaa..e005bdb 100644 --- a/LinqGen.Generator/LinqGenSyntaxReceiver.cs +++ b/LinqGen.Generator/LinqGenSyntaxReceiver.cs @@ -65,7 +65,7 @@ private void AddGeneration(in LinqGenExpression expression) return; } - var generation = InstructionFactory.CreateGeneration(_logBuilder, expression, ++_idCounter); + var generation = InstructionFactory.CreateGeneration(expression, ++_idCounter); if (generation == null) { @@ -84,7 +84,7 @@ private void AddGeneration(in LinqGenExpression expression) private void AddEvaluation(in LinqGenExpression expression) { var key = new EvaluationKey( - expression.UpstreamSignatureSymbols![0], expression.MethodSymbol, expression.InputElementSymbol!); + expression.UpstreamSignatureSymbols[0], expression.MethodSymbol, expression.InputElementSymbol!); if (_evaluations.ContainsKey(key)) { @@ -92,7 +92,7 @@ private void AddEvaluation(in LinqGenExpression expression) return; } - var evaluation = InstructionFactory.CreateEvaluation(_logBuilder, expression, ++_idCounter); + var evaluation = InstructionFactory.CreateEvaluation(expression, ++_idCounter); if (evaluation == null) { @@ -114,7 +114,7 @@ public void ResolveHierarchy() { var upstreamSymbols = generation.UpstreamSignatureSymbols; - if (upstreamSymbols == null) + if (upstreamSymbols.IsDefaultOrEmpty) { Roots.Add(generation); continue;