From a276ba83cf7e2aa1c0f81454591a794d6efb8c2a Mon Sep 17 00:00:00 2001 From: Arda TANRIKULU Date: Sat, 13 Jul 2024 18:01:41 +0300 Subject: [PATCH] fix(pruneSchema): respect directive definitions in extensions --- .changeset/pretty-foxes-check.md | 5 +++ packages/utils/src/prune.ts | 51 ++++++++++++++++------------ packages/utils/tests/prune.test.ts | 54 +++++++++++++++++++++++++++++- 3 files changed, 88 insertions(+), 22 deletions(-) create mode 100644 .changeset/pretty-foxes-check.md diff --git a/.changeset/pretty-foxes-check.md b/.changeset/pretty-foxes-check.md new file mode 100644 index 00000000000..97514553c3f --- /dev/null +++ b/.changeset/pretty-foxes-check.md @@ -0,0 +1,5 @@ +--- +"@graphql-tools/utils": patch +--- + +Respect directive extensions on \`pruneSchema\` diff --git a/packages/utils/src/prune.ts b/packages/utils/src/prune.ts index 6e69e49e82b..7f32b00cb01 100644 --- a/packages/utils/src/prune.ts +++ b/packages/utils/src/prune.ts @@ -1,5 +1,4 @@ import { - ASTNode, getNamedType, GraphQLFieldMap, GraphQLSchema, @@ -11,6 +10,7 @@ import { isSpecifiedScalarType, isUnionType, } from 'graphql'; +import { DirectableGraphQLObject } from './get-directives.js'; import { getImplementingTypes } from './get-implementing-types.js'; import { MapperKind } from './Interfaces.js'; import { mapSchema } from './mapSchema.js'; @@ -152,12 +152,7 @@ function visitQueue( if (isEnumType(type)) { // Visit enum values directives argument types queue.push( - ...type.getValues().flatMap(value => { - if (value.astNode) { - return getDirectivesArgumentsTypeNames(schema, value.astNode); - } - return []; - }), + ...type.getValues().flatMap(value => getDirectivesArgumentsTypeNames(schema, value)), ); } // Visit interfaces this type is implementing if they haven't been visited yet @@ -180,9 +175,7 @@ function visitQueue( queue.push( ...field.args.flatMap(arg => { const typeNames = [getNamedType(arg.type).name]; - if (arg.astNode) { - typeNames.push(...getDirectivesArgumentsTypeNames(schema, arg.astNode)); - } + typeNames.push(...getDirectivesArgumentsTypeNames(schema, arg)); return typeNames; }), ); @@ -192,9 +185,7 @@ function visitQueue( queue.push(namedType.name); - if (field.astNode) { - queue.push(...getDirectivesArgumentsTypeNames(schema, field.astNode)); - } + queue.push(...getDirectivesArgumentsTypeNames(schema, field)); // Interfaces returned on fields need to be revisited to add their implementations if (isInterfaceType(namedType) && !(namedType.name in revisit)) { @@ -203,9 +194,7 @@ function visitQueue( } } - if (type.astNode) { - queue.push(...getDirectivesArgumentsTypeNames(schema, type.astNode)); - } + queue.push(...getDirectivesArgumentsTypeNames(schema, type)); visited.add(typeName); // Mark as visited (and therefore it is used and should be kept) } @@ -215,10 +204,30 @@ function visitQueue( function getDirectivesArgumentsTypeNames( schema: GraphQLSchema, - astNode: Extract, + directableObj: DirectableGraphQLObject, ) { - return (astNode.directives ?? []).flatMap( - directive => - schema.getDirective(directive.name.value)?.args.map(arg => getNamedType(arg.type).name) ?? [], - ); + const argTypeNames = new Set(); + if (directableObj.astNode?.directives) { + for (const directiveNode of directableObj.astNode.directives) { + const directive = schema.getDirective(directiveNode.name.value); + if (directive?.args) { + for (const arg of directive.args) { + const argType = getNamedType(arg.type); + argTypeNames.add(argType.name); + } + } + } + } + if (directableObj.extensions?.['directives']) { + for (const directiveName in directableObj.extensions['directives']) { + const directive = schema.getDirective(directiveName); + if (directive?.args) { + for (const arg of directive.args) { + const argType = getNamedType(arg.type); + argTypeNames.add(argType.name); + } + } + } + } + return [...argTypeNames]; } diff --git a/packages/utils/tests/prune.test.ts b/packages/utils/tests/prune.test.ts index 09c07eb1da8..312356fb454 100644 --- a/packages/utils/tests/prune.test.ts +++ b/packages/utils/tests/prune.test.ts @@ -1,4 +1,13 @@ -import { buildSchema, GraphQLNamedType } from 'graphql'; +import { + buildSchema, + DirectiveLocation, + GraphQLDirective, + GraphQLEnumType, + GraphQLNamedType, + GraphQLObjectType, + GraphQLSchema, + GraphQLString, +} from 'graphql'; import { PruneSchemaFilter } from '../src/index.js'; import { pruneSchema } from '../src/prune.js'; @@ -487,6 +496,49 @@ describe('pruneSchema', () => { } `); + const result = pruneSchema(schema); + expect(result.getType('DirectiveArg')).toBeDefined(); + }); + test('does not remove type used in argument definition directive argument from extensions', () => { + const enumType = new GraphQLEnumType({ + name: 'DirectiveArg', + values: { + VALUE: { + value: 'VALUE', + }, + }, + }); + const schema = new GraphQLSchema({ + query: new GraphQLObjectType({ + name: 'Query', + fields: { + foo: { + type: GraphQLString, + extensions: { + directives: { + bar: [ + { + arg: 'VALUE', + }, + ], + }, + }, + }, + }, + }), + directives: [ + new GraphQLDirective({ + name: 'bar', + locations: [DirectiveLocation.FIELD], + args: { + arg: { + type: enumType, + }, + }, + }), + ], + }); + const result = pruneSchema(schema); expect(result.getType('DirectiveArg')).toBeDefined(); });