Skip to content

Commit

Permalink
Add support for hierarchical record creation from generated factories
Browse files Browse the repository at this point in the history
We now invoke factories automatically whenever they are
found.

A new diagnostic is reported if a `Create(dynamic value)` factory
method is found in the record but it's not accessible to the
generated factory in the current assembly. The diagnostic is
nevertheless a warning since in that case we still generate our
custom factory, but users will not get the benefit of invoking
their custom factory (although this might be by design if it's
just a coincidence in the method name and signature they
chose for an entirely unrelated purpose).

Closes #47
  • Loading branch information
kzu committed Nov 10, 2022
1 parent 59b16a6 commit 47203c0
Show file tree
Hide file tree
Showing 8 changed files with 157 additions and 40 deletions.
12 changes: 12 additions & 0 deletions src/Merq.CodeAnalysis/Diagnostics.cs
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,16 @@ public static class Diagnostics
DiagnosticSeverity.Error,
isEnabledByDefault: true,
description: "Commands must implement the interface that matches the handler's.");

/// <summary>
/// MERQ006: Factory method is not accessible
/// </summary>
public static DiagnosticDescriptor CreateMethodNotAccessible { get; } = new(
"MERQ006",
"Factory method is not accessible",
"Factory method '{0}.Create' is not accessible within the current compilation to support hierarchical dynamic conversion.",
"Build",
DiagnosticSeverity.Warning,
isEnabledByDefault: true,
description: "In order to support automatic hierarchical dynamic conversion for records, the Create method must be accessible within the compilation.");
}
1 change: 1 addition & 0 deletions src/Merq.CodeAnalysis/Merq.CodeAnalysis.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
<PackageReference Include="Microsoft.CSharp" Version="4.7.0" />
<PackageReference Include="Scriban" Version="5.5.0" Pack="false" IncludeAssets="build" />
<PackageReference Include="Superpower" Version="3.0.0" PrivateAssets="all" />
<PackageReference Include="PolySharp" Version="1.7.1" Pack="false" />
</ItemGroup>

</Project>
19 changes: 16 additions & 3 deletions src/Merq.CodeAnalysis/RecordFactory.sbntxt
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
// <auto-generated />
{{~
func getValue(x)
if x.Factory
ret x.Factory + "(value." + x.Name + ")"
else
ret "value." + x.Name
end
end
~}}
// <auto-generated />
using System;
using Microsoft.CSharp.RuntimeBinder;

Expand All @@ -13,14 +22,18 @@ namespace {{ Namespace }}

try
{
return new {{ Name }}({{ Parameters | array.each @(do; ret "value." + $0; end) | array.join ', ' }}){{~ if !HasProperties ~}};{{~ end }}
{{~ if Factory ~}}
return {{ Factory }}(value);
{{~ else ~}}
return new {{ Name }}({{ Parameters | array.each @getValue | array.join ', ' }}){{~ if !HasProperties ~}};{{~ end }}
{{~ if HasProperties ~}}
{
{{~ for prop in Properties ~}}
{{ prop }} = value.{{ prop }},
{{ prop.Name }} = {{ getValue prop }},
{{~ end ~}}
};
{{~ end ~}}
{{~ end ~}}
}
catch (RuntimeBinderException e)
{
Expand Down
110 changes: 88 additions & 22 deletions src/Merq.CodeAnalysis/RecordFactoryGenerator.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Reflection;
Expand All @@ -14,19 +15,24 @@ namespace Merq;
[Generator(LanguageNames.CSharp)]
public class RecordFactoryGenerator : IIncrementalGenerator
{
static readonly SymbolDisplayFormat fullNameFormat = new SymbolDisplayFormat(
static readonly SymbolDisplayFormat fullNameFormat = new(
typeQualificationStyle: SymbolDisplayTypeQualificationStyle.NameAndContainingTypesAndNamespaces,
genericsOptions: SymbolDisplayGenericsOptions.IncludeTypeParameters,
miscellaneousOptions: SymbolDisplayMiscellaneousOptions.ExpandNullable);

public void Initialize(IncrementalGeneratorInitializationContext context)
{
var types = context.CompilationProvider.SelectMany((x, c) =>
TypesVisitor.Visit(x.GlobalNamespace, symbol => x.IsSymbolAccessibleWithin(symbol, x.Assembly) && symbol.IsRecord, c));
TypesVisitor.Visit(x.GlobalNamespace, symbol =>
x.IsSymbolAccessibleWithin(symbol, x.Assembly) &&
symbol.IsRecord &&
symbol.ContainingNamespace != null, c));

context.RegisterSourceOutput(types, (ctx, data) =>
context.RegisterSourceOutput(
types.Combine(context.CompilationProvider),
(ctx, data) =>
{
var ctor = data.InstanceConstructors
var ctor = data.Left.InstanceConstructors
.Where(x => x.DeclaredAccessibility == Accessibility.Public || x.DeclaredAccessibility == Accessibility.Internal)
.OrderByDescending(x => x.Parameters.Length).FirstOrDefault();
if (ctor == null)
Expand All @@ -35,40 +41,82 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
using var resource = Assembly.GetExecutingAssembly().GetManifestResourceStream("Merq.RecordFactory.sbntxt");
using var reader = new StreamReader(resource!);
var template = Template.Parse(reader.ReadToEnd());
var compilation = data.Right;
string? GetFactory(ITypeSymbol type)
{
if (type.SpecialType != SpecialType.None)
return null;
var ns = type.ContainingNamespace.Equals(data.Left.ContainingNamespace, SymbolEqualityComparer.Default) ?
"" : $"{type.ContainingNamespace.ToDisplayString(fullNameFormat)}.";
if (FindCreate(type) is IMethodSymbol create &&
compilation.IsSymbolAccessibleWithin(create, compilation.Assembly))
{
// We either had a custom Create factory method, or the type is partial,
// and we'll generate it ourselves.
return ns + type.Name + ".Create";
}
else if (!HasCreate(type) && IsPartial(type) && type.IsRecord)
{
// We'll generate a Create factory method.
return ns + type.Name + ".Create";
}
else if (type.IsRecord)
{
// If the type isn't partial or has a Create method, we will
// generate a factory class for it.
return ns + $"__{type.Name}Factory.Create";
}
return null;
};
// Get properties that can be set and are not named (case insensitive) as ctor parameters
var properties = data.GetMembers().OfType<IPropertySymbol>()
var properties = data.Left.GetMembers().OfType<IPropertySymbol>()
.Where(x => x.SetMethod != null && !ctor.Parameters.Any(y => string.Equals(y.Name, x.Name, StringComparison.OrdinalIgnoreCase)))
.Select(x => x.Name)
.OrderBy(x => x)
.Select(x => new
{
x.Name,
Factory = GetFactory(x.Type)
})
.OrderBy(x => x.Name)
.ToImmutableArray();
var output = template.Render(new
{
Namespace = data.ContainingNamespace.ToDisplayString(fullNameFormat),
Name = data.Name,
Parameters = ctor.Parameters.Select(x => x.Name).ToArray(),
Namespace = data.Left.ContainingNamespace.ToDisplayString(fullNameFormat),
Name = data.Left.Name,
Factory = FindCreate(data.Left) is IMethodSymbol create &&
compilation.IsSymbolAccessibleWithin(create, compilation.Assembly) ? data.Left.Name + ".Create" : null,
Parameters = ctor.Parameters.Select(x => new
{
x.Name,
Factory = GetFactory(x.Type)
}).ToArray(),
HasProperties = !properties.IsDefaultOrEmpty,
Properties = properties,
}, member => member.Name);
ctx.AddSource(data.Name + ".Factory.g", output.Replace("\r\n", "\n").Replace("\n", Environment.NewLine));
ctx.AddSource(data.Left.Name + ".Factory.g", output.Replace("\r\n", "\n").Replace("\n", Environment.NewLine));
if (FindCreate(data.Left) is IMethodSymbol factory &&
!compilation.IsSymbolAccessibleWithin(factory, compilation.Assembly))
{
ctx.ReportDiagnostic(Diagnostic.Create(
Diagnostics.CreateMethodNotAccessible,
factory.Locations.FirstOrDefault(),
data.Left.Name));
}
});

context.RegisterSourceOutput(
// Only generate a partial factory method for partial records
// NOTE: if there are no declaring syntax references, it's because the type is declared
// in another project so we cannot declare a partial.
types.Where(x => x.DeclaringSyntaxReferences.Any() && x.DeclaringSyntaxReferences.All(
r => r.GetSyntax() is RecordDeclarationSyntax c && c.Modifiers.Any(
m => m.IsKind(Microsoft.CodeAnalysis.CSharp.SyntaxKind.PartialKeyword))) &&
// Don't generate duplicate method names. We also don't generate if there's already a Create with
// a single parameter.
!x.GetMembers().OfType<IMethodSymbol>().Where(
x => x.Name == "Create" && x.IsStatic && x.Parameters.Length == 1 &&
(x.Parameters[0].Type.SpecialType == SpecialType.System_Object ||
x.Parameters[0].Type.TypeKind == TypeKind.Dynamic)).Any()),
// Don't generate duplicate method names. We also don't generate if there's already a Create with
// a single parameter.
types.Where(x => IsPartial(x) && !HasCreate(x)),
(ctx, data) =>
{
ctx.AddSource(data.Name + ".Create.g",
Expand All @@ -86,6 +134,24 @@ partial record {{data.Name}}
});
}

/// <summary>
/// Checks if there are declaring syntax references with the 'partial' keyword.
/// Types declared in another project will not have them.
/// </summary>
static bool IsPartial(ITypeSymbol type)
=> type.DeclaringSyntaxReferences.Any() && type.DeclaringSyntaxReferences.All(
r => r.GetSyntax() is RecordDeclarationSyntax c && c.Modifiers.Any(
m => m.IsKind(Microsoft.CodeAnalysis.CSharp.SyntaxKind.PartialKeyword)));

static bool HasCreate(ITypeSymbol type) => FindCreate(type) is IMethodSymbol;

static IMethodSymbol? FindCreate(ITypeSymbol type)
=> type.GetMembers()
.OfType<IMethodSymbol>()
.Where(x => x.Name == "Create" && x.IsStatic && x.Parameters.Length == 1 &&
(x.Parameters[0].Type.SpecialType == SpecialType.System_Object || x.Parameters[0].Type.TypeKind == TypeKind.Dynamic))
.FirstOrDefault();

class TypesVisitor : SymbolVisitor
{
readonly Func<INamedTypeSymbol, bool> shouldInclude;
Expand Down
4 changes: 2 additions & 2 deletions src/Merq.Tests/MessageBusSpec.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@

namespace Merq;

partial record Foo
partial record Foo(string Message, string Format)
{
static Foo Create(dynamic value) => new Foo();
internal static Foo Create(dynamic value) => new(value.Message, value.Format);
}

partial record Foo { }
Expand Down
20 changes: 20 additions & 0 deletions src/Merq.Tests/Records.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
using System.Collections.Generic;

namespace Merq.Records;

public partial record Point(int X, int Y);

public partial record Line(Point Start, Point End);

public record Buffer(List<Line> Lines)
{
public static Buffer Create(dynamic value)
{
var lines = new List<Line>();
foreach (var line in value.Lines)
{
//lines.Add(Line.Create(line));
}
return new Buffer(lines);
}
}
19 changes: 15 additions & 4 deletions src/Merq.Tests/TemplateTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,29 @@ public record TemplateTests(ITestOutputHelper Output)
Namespace: MyProject.MyNamespace
Name: MyRecord
Parameters:
- Message
- Format
- Name: Message
- Name: Format
Factory: Format.Create
""")]
[InlineData("../../../../Merq.CodeAnalysis/RecordFactory.sbntxt",
"""
Namespace: MyProject.MyNamespace
Name: MyRecord
Factory: MyRecord.Create
""")]
[InlineData("../../../../Merq.CodeAnalysis/RecordFactory.sbntxt",
"""
Namespace: MyProject.MyNamespace
Name: MyRecord
Parameters:
- Message
- Name: Message
- Name: Format
Factory: Format.Create
HasProperties: true
Properties:
- Timestamp
- Name: Timestamp
Factory: Timestamp.Create
- Name: Id
""")]
[Theory]
public void RenderTemplate(string templateFile, string modelYaml)
Expand Down
12 changes: 3 additions & 9 deletions src/Samples/Library2/Events.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,11 @@

public record DuckEvent(string Message);

public record Point(int X, int Y)
{
public static Point Create(dynamic value) => new Point(value.X, value.Y);
}
public partial record Point(int X, int Y);

public record Line(Point Start, Point End)
{
public static Line Create(dynamic value) => new Line(Point.Create(value.Start), Point.Create(value.End));
}
public partial record Line(Point Start, Point End);

public record OnDidDrawLine(Line Line)
{
public static OnDidDrawLine Create(dynamic value) => new OnDidDrawLine(Line.Create(value.Line));
public static OnDidDrawLine Create(dynamic value) => new(Line.Create(value.Line));
}

0 comments on commit 47203c0

Please sign in to comment.