Skip to content

Commit

Permalink
Fixed Tag Composition Tooling (#6783)
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelstaib authored Dec 18, 2023
1 parent d272875 commit 8568047
Show file tree
Hide file tree
Showing 13 changed files with 287 additions and 115 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,18 @@ namespace HotChocolate.Fusion.Composition.Features;
/// <summary>
/// Specifies behavior of the @tag directive.
/// </summary>
public sealed class TagDirectiveFeature : IFusionFeature
public sealed class TagDirectiveFeature(
IEnumerable<string>? exclude = null,
bool makeTagsPublic = false)
: IFusionFeature
{
public TagDirectiveFeature(
IEnumerable<string>? exclude = null,
bool makeTagsPublic = false)
{
Excluded = new HashSet<string>(exclude ?? Enumerable.Empty<string>());
MakeTagsPublic = makeTagsPublic;
}

/// <summary>
/// Gets the tags that shall be excluded from the public schema.
/// </summary>
public IReadOnlySet<string> Excluded { get; }
public IReadOnlySet<string> Excluded { get; } = new HashSet<string>(exclude ?? Enumerable.Empty<string>());

/// <summary>
/// Defines if the tag directives should be exported to the public schema.
/// </summary>
public bool MakeTagsPublic { get; }
public bool MakeTagsPublic { get; } = makeTagsPublic;
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ internal FusionGraphComposer(
.Use<MergeQueryAndMutationTypeMiddleware>()
.Use<MergeSubscriptionTypeMiddleware>()
.Use<NodeMiddleware>()
.Use<ApplyExcludeTagMiddleware>()
.Use<ApplyTagDirectiveMiddleware>()
.Use<ApplyExcludeTagMiddleware>()
.Use<RemoveFusionTypesMiddleware>()
.Build();
_logFactory = logFactory;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,17 @@ namespace HotChocolate.Fusion.Composition.Pipeline;

internal sealed class ApplyTagDirectiveMiddleware : IMergeMiddleware
{
public async ValueTask InvokeAsync(CompositionContext context, MergeDelegate next)
public ValueTask InvokeAsync(CompositionContext context, MergeDelegate next)
{
if (context.Features.MakeTagsPublic())
{
Rewrite(context);
}

if (!context.Log.HasErrors)
{
await next(context);
}
Rewrite(context, context.Features.MakeTagsPublic());
return !context.Log.HasErrors
? next(context)
: ValueTask.CompletedTask;
}

private static void Rewrite(
CompositionContext context)
CompositionContext context,
bool makePublic)
{
var needsDirectiveType = false;

Expand Down Expand Up @@ -54,7 +50,7 @@ private static void Rewrite(
}

var tags = new HashSet<string>();
Rewrite(context, tagDirectiveType, tags);
Rewrite(context, tagDirectiveType, tags, makePublic);

if (context.GetTagContext().HasTags && needsDirectiveType)
{
Expand All @@ -65,38 +61,39 @@ private static void Rewrite(
private static void Rewrite(
CompositionContext context,
DirectiveType tagDirectiveType,
HashSet<string> tags)
HashSet<string> tags,
bool makePublic)
{
var tagContext = context.GetTagContext();

ApplyDirectives(tagContext, context.FusionGraph, context.Subgraphs, tagDirectiveType, tags);
ApplyDirectives(tagContext, context.FusionGraph, context.Subgraphs, tagDirectiveType, tags, makePublic);

foreach (var type in context.FusionGraph.Types)
{
switch (type)
{
case ObjectType objectType:
Rewrite(context, tagContext, objectType, tagDirectiveType, tags);
Rewrite(context, tagContext, objectType, tagDirectiveType, tags, makePublic);
break;

case InterfaceType interfaceType:
Rewrite(context, tagContext, interfaceType, tagDirectiveType, tags);
Rewrite(context, tagContext, interfaceType, tagDirectiveType, tags, makePublic);
break;

case UnionType unionType:
Rewrite(context, tagContext, unionType, tagDirectiveType, tags);
Rewrite(context, tagContext, unionType, tagDirectiveType, tags, makePublic);
break;

case InputObjectType inputObjectType:
Rewrite(context, tagContext, inputObjectType, tagDirectiveType, tags);
Rewrite(context, tagContext, inputObjectType, tagDirectiveType, tags, makePublic);
break;

case EnumType enumType:
Rewrite(context, tagContext, enumType, tagDirectiveType, tags);
Rewrite(context, tagContext, enumType, tagDirectiveType, tags, makePublic);
break;

case ScalarType scalarType:
Rewrite(context, tagContext, scalarType, tagDirectiveType, tags);
Rewrite(context, tagContext, scalarType, tagDirectiveType, tags, makePublic);
break;

default:
Expand All @@ -106,7 +103,7 @@ private static void Rewrite(

foreach (var directiveType in context.FusionGraph.DirectiveTypes)
{
Rewrite(context, tagContext, directiveType, tagDirectiveType, tags);
Rewrite(context, tagContext, directiveType, tagDirectiveType, tags, makePublic);
}
}

Expand All @@ -115,15 +112,16 @@ private static void Rewrite(
TagContext tagContext,
ComplexType type,
DirectiveType tagDirectiveType,
HashSet<string> tags)
HashSet<string> tags,
bool makePublic)
{
var coordinate = new SchemaCoordinate(type.Name);

ApplyDirectives(context, tagContext, type, coordinate, tagDirectiveType, tags);
ApplyDirectives(context, tagContext, type, coordinate, tagDirectiveType, tags, makePublic);

foreach (var field in type.Fields)
{
Rewrite(context, tagContext, field, coordinate, tagDirectiveType, tags);
Rewrite(context, tagContext, field, coordinate, tagDirectiveType, tags, makePublic);
}
}

Expand All @@ -132,23 +130,25 @@ private static void Rewrite(
TagContext tagContext,
UnionType type,
DirectiveType tagDirectiveType,
HashSet<string> tags)
=> ApplyDirectives(context, tagContext, type, new SchemaCoordinate(type.Name), tagDirectiveType, tags);
HashSet<string> tags,
bool makePublic)
=> ApplyDirectives(context, tagContext, type, new(type.Name), tagDirectiveType, tags, makePublic);

private static void Rewrite(
CompositionContext context,
TagContext tagContext,
InputObjectType type,
DirectiveType tagDirectiveType,
HashSet<string> tags)
HashSet<string> tags,
bool makePublic)
{
var coordinate = new SchemaCoordinate(type.Name);

ApplyDirectives(context, tagContext, type, coordinate, tagDirectiveType, tags);
ApplyDirectives(context, tagContext, type, coordinate, tagDirectiveType, tags, makePublic);

foreach (var field in type.Fields)
{
Rewrite(context, tagContext, field, coordinate, tagDirectiveType, tags);
Rewrite(context, tagContext, field, coordinate, tagDirectiveType, tags, makePublic);
}
}

Expand All @@ -157,15 +157,16 @@ private static void Rewrite(
TagContext tagContext,
EnumType type,
DirectiveType tagDirectiveType,
HashSet<string> tags)
HashSet<string> tags,
bool makePublic)
{
var coordinate = new SchemaCoordinate(type.Name);

ApplyDirectives(context, tagContext, type, coordinate, tagDirectiveType, tags);
ApplyDirectives(context, tagContext, type, coordinate, tagDirectiveType, tags, makePublic);

foreach (var field in type.Values)
{
Rewrite(context, tagContext, field, coordinate, tagDirectiveType, tags);
Rewrite(context, tagContext, field, coordinate, tagDirectiveType, tags, makePublic);
}
}

Expand All @@ -174,21 +175,23 @@ private static void Rewrite(
TagContext tagContext,
ScalarType type,
DirectiveType tagDirectiveType,
HashSet<string> tags)
=> ApplyDirectives(context, tagContext, type, new SchemaCoordinate(type.Name), tagDirectiveType, tags);
HashSet<string> tags,
bool makePublic)
=> ApplyDirectives(context, tagContext, type, new(type.Name), tagDirectiveType, tags, makePublic);

private static void Rewrite(
CompositionContext context,
TagContext tagContext,
DirectiveType type,
DirectiveType tagDirectiveType,
HashSet<string> tags)
HashSet<string> tags,
bool makePublic)
{
var coordinate = new SchemaCoordinate(type.Name, ofDirective: true);

foreach (var field in type.Arguments)
{
Rewrite(context, tagContext, field, coordinate, tagDirectiveType, tags);
Rewrite(context, tagContext, field, coordinate, tagDirectiveType, tags, makePublic);
}
}

Expand All @@ -198,15 +201,16 @@ private static void Rewrite(
OutputField field,
SchemaCoordinate parent,
DirectiveType tagDirectiveType,
HashSet<string> tags)
HashSet<string> tags,
bool makePublic)
{
var coordinate = new SchemaCoordinate(parent.Name, field.Name);

ApplyDirectives(context, tagContext, field, coordinate, tagDirectiveType, tags);
ApplyDirectives(context, tagContext, field, coordinate, tagDirectiveType, tags, makePublic);

foreach (var argument in field.Arguments)
{
Rewrite(context, tagContext, argument, coordinate, tagDirectiveType, tags);
Rewrite(context, tagContext, argument, coordinate, tagDirectiveType, tags, makePublic);
}
}

Expand All @@ -216,7 +220,8 @@ private static void Rewrite(
InputField field,
SchemaCoordinate parent,
DirectiveType tagDirectiveType,
HashSet<string> tags)
HashSet<string> tags,
bool makePublic)
{
var coordinate = parent switch
{
Expand All @@ -225,7 +230,7 @@ private static void Rewrite(
{ MemberName: not null } => new SchemaCoordinate(parent.Name, parent.MemberName, field.Name),
};

ApplyDirectives(context, tagContext, field, coordinate, tagDirectiveType, tags);
ApplyDirectives(context, tagContext, field, coordinate, tagDirectiveType, tags, makePublic);
}

private static void Rewrite(
Expand All @@ -234,26 +239,29 @@ private static void Rewrite(
EnumValue value,
SchemaCoordinate parent,
DirectiveType tagDirectiveType,
HashSet<string> tags)
HashSet<string> tags,
bool makePublic)
=> ApplyDirectives(
context,
tagContext,
value,
new SchemaCoordinate(parent.Name, value.Name),
tagDirectiveType,
tags);
tags,
makePublic);

private static void ApplyDirectives<T>(
CompositionContext context,
TagContext tagContext,
T merged,
SchemaCoordinate coordinate,
DirectiveType tagDirectiveType,
HashSet<string> tags)
HashSet<string> tags,
bool makePublic)
where T : ITypeSystemMember, IHasDirectives
{
var parts = context.GetSubgraphMembers<T>(coordinate);
ApplyDirectives(tagContext, merged, parts, tagDirectiveType, tags);
ApplyDirectives(tagContext, merged, parts, tagDirectiveType, tags, makePublic);

foreach (var tag in tags)
{
Expand All @@ -266,7 +274,8 @@ private static void ApplyDirectives<T>(
T merged,
IEnumerable<T> parts,
DirectiveType tagDirectiveType,
HashSet<string> tags)
HashSet<string> tags,
bool makePublic)
where T : ITypeSystemMember, IHasDirectives
{
tags.Clear();
Expand All @@ -288,6 +297,11 @@ private static void ApplyDirectives<T>(
value is StringValueNode name &&
tags.Add(name.Value))
{
if (!makePublic)
{
continue;
}

merged.Directives.Add(
new Directive(
tagDirectiveType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ internal sealed class TagContext
private readonly Dictionary<string, HashSet<SchemaCoordinate>> _taggedTypes =
new(StringComparer.Ordinal);

public bool HasTags { get; set; } = false;
public bool HasTags { get; set; }

public void RegisterTagCoordinate(string name, SchemaCoordinate coordinate)
{
Expand All @@ -16,10 +16,10 @@ public void RegisterTagCoordinate(string name, SchemaCoordinate coordinate)
}
else
{
_taggedTypes.Add(name, new HashSet<SchemaCoordinate> { coordinate });
_taggedTypes.Add(name, [coordinate]);
}
}

public IReadOnlySet<SchemaCoordinate> GetTagCoordinates(string name)
=> _taggedTypes.TryGetValue(name, out var coordinates) ? coordinates : _empty;
=> _taggedTypes.GetValueOrDefault(name, _empty);
}
Loading

0 comments on commit 8568047

Please sign in to comment.