Skip to content

Commit

Permalink
Merge branch 'main' into gai/cookie-crumble-xunit-v3
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelstaib authored Dec 18, 2024
2 parents ce35817 + c31b5e5 commit 2d9de20
Show file tree
Hide file tree
Showing 23 changed files with 863 additions and 111 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using System.Collections;
using System.Diagnostics.CodeAnalysis;
using System.Collections.Immutable;
using HotChocolate.Fusion.Types;
using HotChocolate.Language;
Expand All @@ -14,7 +16,8 @@ public DocumentNode RewriteDocument(DocumentNode document, string? operationName
var fragmentLookup = CreateFragmentLookup(document);
var context = new Context(operationType, fragmentLookup);

RewriteFields(operation.SelectionSet, context);
CollectSelections(operation.SelectionSet, context);
RewriteSelections(context);

var newSelectionSet = new SelectionSetNode(
null,
Expand All @@ -31,14 +34,38 @@ public DocumentNode RewriteDocument(DocumentNode document, string? operationName
return new DocumentNode(ImmutableArray<IDefinitionNode>.Empty.Add(newOperation));
}

private void RewriteFields(SelectionSetNode selectionSet, Context context)
internal void CollectSelections(SelectionSetNode selectionSet, Context context)
{
foreach (var selection in selectionSet.Selections)
{
switch (selection)
{
case FieldNode field:
RewriteField(field, context);
context.AddField(field);
break;

case InlineFragmentNode inlineFragment:
CollectInlineFragment(inlineFragment, context);
break;

case FragmentSpreadNode fragmentSpread:
CollectFragmentSpread(fragmentSpread, context);
break;
}
}
}

internal void RewriteSelections(Context context)
{
var collectedSelections = context.Selections.ToImmutableArray();
context.Selections.Clear();

foreach (var selection in collectedSelections)
{
switch (selection)
{
case FieldNode field:
MergeField(field.ResponseName(), context);
break;

case InlineFragmentNode inlineFragment:
Expand All @@ -50,6 +77,23 @@ private void RewriteFields(SelectionSetNode selectionSet, Context context)
break;
}
}

void MergeField(string fieldName, Context ctx)
{
foreach (var field in ctx.Fields[fieldName].GroupBy(t => t, t => t, FieldComparer.Instance))
{
var mergedField = field.Key;

if (mergedField.SelectionSet is not null)
{
mergedField = mergedField.WithSelectionSet(
new SelectionSetNode(
field.SelectMany(t => t.SelectionSet!.Selections).ToList()));
}

RewriteField(mergedField, ctx);
}
}
}

private void RewriteField(FieldNode fieldNode, Context context)
Expand All @@ -65,10 +109,11 @@ private void RewriteField(FieldNode fieldNode, Context context)
}
else
{
var field = ((CompositeComplexType)context.Type).Fields[fieldNode.Name.Value];
var field = ((CompositeComplexType)context.Type).Fields[fieldNode.ResponseName()];
var fieldContext = context.Branch(field.Type.NamedType());

RewriteFields(fieldNode.SelectionSet, fieldContext);
CollectSelections(fieldNode.SelectionSet, fieldContext);
RewriteSelections(fieldContext);

var newSelectionSetNode = new SelectionSetNode(
null,
Expand All @@ -89,23 +134,29 @@ private void RewriteField(FieldNode fieldNode, Context context)
}
}

private void RewriteInlineFragment(InlineFragmentNode inlineFragment, Context context)
private void CollectInlineFragment(InlineFragmentNode inlineFragment, Context context)
{
if ((inlineFragment.TypeCondition is null
|| inlineFragment.TypeCondition.Name.Value.Equals(context.Type.Name, StringComparison.Ordinal))
if ((inlineFragment.TypeCondition is null
|| inlineFragment.TypeCondition.Name.Value.Equals(context.Type.Name, StringComparison.Ordinal))
&& inlineFragment.Directives.Count == 0)
{
RewriteFields(inlineFragment.SelectionSet, context);
CollectSelections(inlineFragment.SelectionSet, context);
return;
}

context.AddInlineFragment(inlineFragment);
}

private void RewriteInlineFragment(InlineFragmentNode inlineFragment, Context context)
{
var typeCondition = inlineFragment.TypeCondition is null
? context.Type
: schema.GetType(inlineFragment.TypeCondition.Name.Value);

var inlineFragmentContext = context.Branch(typeCondition);

RewriteFields(inlineFragment.SelectionSet, inlineFragmentContext);
CollectSelections(inlineFragment.SelectionSet, inlineFragmentContext);
RewriteSelections(inlineFragmentContext);

var newSelectionSetNode = new SelectionSetNode(
null,
Expand All @@ -120,7 +171,7 @@ private void RewriteInlineFragment(InlineFragmentNode inlineFragment, Context co
context.Selections.Add(newInlineFragment);
}

private void InlineFragmentDefinition(
private void CollectFragmentSpread(
FragmentSpreadNode fragmentSpread,
Context context)
{
Expand All @@ -130,28 +181,37 @@ private void InlineFragmentDefinition(
if (fragmentSpread.Directives.Count == 0
&& typeCondition.IsAssignableFrom(context.Type))
{
RewriteFields(fragmentDefinition.SelectionSet, context);
CollectSelections(fragmentDefinition.SelectionSet, context);
return;
}
else
{
var fragmentContext = context.Branch(typeCondition);

RewriteFields(fragmentDefinition.SelectionSet, fragmentContext);
context.AddFragmentSpread(fragmentSpread);
}

var selectionSet = new SelectionSetNode(
null,
fragmentContext.Selections.ToImmutable());
private void InlineFragmentDefinition(
FragmentSpreadNode fragmentSpread,
Context context)
{
var fragmentDefinition = context.GetFragmentDefinition(fragmentSpread.Name.Value);
var typeCondition = schema.GetType(fragmentDefinition.TypeCondition.Name.Value);
var fragmentContext = context.Branch(typeCondition);

var inlineFragment = new InlineFragmentNode(
null,
new NamedTypeNode(typeCondition.Name),
RewriteDirectives(fragmentSpread.Directives),
selectionSet);
CollectSelections(fragmentDefinition.SelectionSet, fragmentContext);
RewriteSelections(fragmentContext);

if (context.Visited.Add(inlineFragment))
{
context.Selections.Add(inlineFragment);
}
var selectionSet = new SelectionSetNode(
null,
fragmentContext.Selections.ToImmutable());

var inlineFragment = new InlineFragmentNode(
null,
new NamedTypeNode(typeCondition.Name),
RewriteDirectives(fragmentSpread.Directives),
selectionSet);

if (context.Visited.Add(inlineFragment))
{
context.Selections.Add(inlineFragment);
}
}

Expand All @@ -175,6 +235,7 @@ private IReadOnlyList<DirectiveNode> RewriteDirectives(IReadOnlyList<DirectiveNo
var directive = directives[i];
buffer[i] = new DirectiveNode(directive.Name.Value, RewriteArguments(directive.Arguments));
}

return ImmutableArray.Create(buffer);
}

Expand All @@ -195,6 +256,7 @@ private IReadOnlyList<ArgumentNode> RewriteArguments(IReadOnlyList<ArgumentNode>
{
buffer[i] = arguments[i].WithLocation(null);
}

return ImmutableArray.Create(buffer);
}

Expand Down Expand Up @@ -224,10 +286,103 @@ public readonly ref struct Context(

public HashSet<ISelectionNode> Visited { get; } = new(SyntaxComparer.BySyntax);

public Dictionary<string, List<FieldNode>> Fields { get; } = new(StringComparer.Ordinal);

public FragmentDefinitionNode GetFragmentDefinition(string name)
=> fragments[name];

public void AddField(FieldNode field)
{
var responseName = field.ResponseName();
if (!Fields.TryGetValue(responseName, out var fields))
{
fields = [];
Fields.Add(responseName, fields);
Selections.Add(field);
}

fields.Add(field);
}

public void AddInlineFragment(InlineFragmentNode inlineFragment)
{
Selections.Add(inlineFragment);
}

public void AddFragmentSpread(FragmentSpreadNode fragmentSpread)
{
Selections.Add(fragmentSpread);
}

public Context Branch(ICompositeNamedType type)
=> new(type, fragments);
}

private sealed class FieldComparer : IEqualityComparer<FieldNode>
{
public bool Equals(FieldNode? x, FieldNode? y)
{
if (ReferenceEquals(x, y))
{
return true;
}

if (x is null)
{
return false;
}

if (y is null)
{
return false;
}

return Equals(x.Alias, y.Alias)
&& x.Name.Equals(y.Name)
&& Equals(x.Directives, y.Directives)
&& Equals(x.Arguments, y.Arguments);
}

private bool Equals(IReadOnlyList<ISyntaxNode> a, IReadOnlyList<ISyntaxNode> b)
{
if (a.Count == 0 && b.Count == 0)
{
return true;
}

return a.SequenceEqual(b, SyntaxComparer.BySyntax);
}

public int GetHashCode(FieldNode obj)
{
var hashCode = new HashCode();

if (obj.Alias is not null)
{
hashCode.Add(obj.Alias.Value);
}

hashCode.Add(obj.Name.Value);

for (var i = 0; i < obj.Directives.Count; i++)
{
hashCode.Add(SyntaxComparer.BySyntax.GetHashCode(obj.Directives[i]));
}

for (var i = 0; i < obj.Arguments.Count; i++)
{
hashCode.Add(SyntaxComparer.BySyntax.GetHashCode(obj.Arguments[i]));
}

return hashCode.ToHashCode();
}

public static FieldComparer Instance { get; } = new();
}
}

file static class FileExtensions
{
public static string ResponseName(this FieldNode field)
=> field.Alias?.Value ?? field.Name.Value;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
using HotChocolate.Fusion.Types;
using HotChocolate.Language;

namespace HotChocolate.Fusion.Planning;

public class MergeSelectionSetRewriter(CompositeSchema schema)
{
private readonly InlineFragmentOperationRewriter _rewriter = new(schema);

public SelectionSetNode RewriteSelectionSets(
IReadOnlyList<SelectionSetNode> selectionSets,
ICompositeNamedType type)
{
var context = new InlineFragmentOperationRewriter.Context(
type,
new Dictionary<string, FragmentDefinitionNode>());

var merged = new SelectionSetNode(
null,
selectionSets.SelectMany(t => t.Selections).ToList());

_rewriter.CollectSelections(merged, context);
_rewriter.RewriteSelections(context);

return new SelectionSetNode(
null,
context.Selections.ToImmutable());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ public abstract class SelectionPlanNode : PlanNode
{
private List<CompositeDirective>? _directives;
private List<SelectionPlanNode>? _selections;
private List<SelectionSetNode>? _requirements;
private bool? _isConditional;
private string? _skipVariable;
private string? _includeVariable;
Expand Down Expand Up @@ -69,6 +70,12 @@ public IReadOnlyList<CompositeDirective> Directives
public IReadOnlyList<SelectionPlanNode> Selections
=> _selections ?? (IReadOnlyList<SelectionPlanNode>)Array.Empty<SelectionPlanNode>();

/// <summary>
/// Gets the requirements that are needed to execute this selection.
/// </summary>
public IReadOnlyList<SelectionSetNode> RequirementNodes
=> _requirements ?? (IReadOnlyList<SelectionSetNode>)Array.Empty<SelectionSetNode>();

/// <summary>
/// Defines if the selection is conditional.
/// </summary>
Expand Down Expand Up @@ -150,6 +157,12 @@ public void AddDirective(CompositeDirective directive)
public bool RemoveDirective(CompositeDirective directive)
=> _directives?.Remove(directive) == true;

public void AddRequirementNode(SelectionSetNode selectionSet)
{
ArgumentNullException.ThrowIfNull(selectionSet);
(_requirements ??= []).Add(selectionSet);
}

private void InitializeConditions()
{
if (_isConditional.HasValue)
Expand Down
Loading

0 comments on commit 2d9de20

Please sign in to comment.