From 1586e7df4107aa1a1e053db8f30affca19ffd192 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=BCdiger=20zu=20Dohna?= Date: Fri, 29 Oct 2021 16:42:38 +0200 Subject: [PATCH] #521: implement federation data fetcher and type resolver --- .../smallrye/graphql/bootstrap/Bootstrap.java | 52 ++++++----- .../bootstrap/FederationDataFetcher.java | 91 +++++++++++++++++++ 2 files changed, 119 insertions(+), 24 deletions(-) create mode 100644 server/implementation/src/main/java/io/smallrye/graphql/bootstrap/FederationDataFetcher.java diff --git a/server/implementation/src/main/java/io/smallrye/graphql/bootstrap/Bootstrap.java b/server/implementation/src/main/java/io/smallrye/graphql/bootstrap/Bootstrap.java index 65f34fcfa..de559433d 100644 --- a/server/implementation/src/main/java/io/smallrye/graphql/bootstrap/Bootstrap.java +++ b/server/implementation/src/main/java/io/smallrye/graphql/bootstrap/Bootstrap.java @@ -18,8 +18,6 @@ import java.util.Map.Entry; import java.util.Optional; import java.util.Set; -import java.util.function.Function; -import java.util.stream.Collectors; import java.util.stream.Stream; import javax.json.Json; @@ -28,8 +26,6 @@ import javax.json.bind.Jsonb; import com.apollographql.federation.graphqljava.Federation; -import com.apollographql.federation.graphqljava._Entity; - import graphql.introspection.Introspection.DirectiveLocation; import graphql.schema.DataFetcher; import graphql.schema.FieldCoordinates; @@ -144,10 +140,10 @@ private void verifyInjectionIsAvailable() { // This crazy stream operation basically collects all class names where we need to verify that // it belongs to an injectable bean Stream.of( - schema.getQueries().stream().map(Operation::getClassName), - schema.getMutations().stream().map(Operation::getClassName), - schema.getGroupedQueries().values().stream().flatMap(Collection::stream).map(Operation::getClassName), - schema.getGroupedMutations().values().stream().flatMap(Collection::stream).map(Operation::getClassName)) + schema.getQueries().stream().map(Operation::getClassName), + schema.getMutations().stream().map(Operation::getClassName), + schema.getGroupedQueries().values().stream().flatMap(Collection::stream).map(Operation::getClassName), + schema.getGroupedMutations().values().stream().flatMap(Collection::stream).map(Operation::getClassName)) .flatMap(stream -> stream) .distinct().forEach(beanClassName -> { // verify that the bean is injectable @@ -191,10 +187,9 @@ private void generateGraphQLSchema() { JsonInputRegistry.override(overrides); if (Config.get().isFederationEnabled()) { - this.graphQLSchema = Federation.transform(schemaBuilder.build()) - .fetchEntities(env -> env.>> getArgument(_Entity.argumentName).stream() - .map(fetchEntities()) - .collect(Collectors.toList())) + GraphQLSchema rawSchema = schemaBuilder.build(); + this.graphQLSchema = Federation.transform(rawSchema) + .fetchEntities(new FederationDataFetcher(rawSchema.getQueryType(), rawSchema.getCodeRegistry())) .resolveEntityType(fetchEntityType()) .build(); } else { @@ -202,12 +197,18 @@ private void generateGraphQLSchema() { } } - private Function, ?> fetchEntities() { - return reference -> null; // TODO federation: implement fetcher - } - private TypeResolver fetchEntityType() { - return env -> null; // TODO federation: implement type resolver + return env -> { + Object src = env.getObject(); + if (src == null) { + return null; + } + GraphQLObjectType result = env.getSchema().getObjectType(src.getClass().getSimpleName()); // TODO respect @Name, etc. + if (result == null) { + throw new RuntimeException("can't resolve federated entity type " + src.getClass().getName()); + } + return result; + }; } private void createGraphQLDirectiveTypes() { @@ -295,7 +296,7 @@ private void addSubscriptions(GraphQLSchema.Builder schemaBuilder) { } private void addRootObject(GraphQLObjectType.Builder rootBuilder, Set operations, - String rootName) { + String rootName) { for (Operation operation : operations) { operation = eventEmitter.fireCreateOperation(operation); @@ -306,7 +307,7 @@ private void addRootObject(GraphQLObjectType.Builder rootBuilder, Set } private void addGroupedRootObject(GraphQLObjectType.Builder rootBuilder, - Map> operationMap, String rootName) { + Map> operationMap, String rootName) { Set>> operationsSet = operationMap.entrySet(); for (Map.Entry> operationsEntry : operationsSet) { @@ -606,7 +607,7 @@ private GraphQLDirective createGraphQLDirectiveFrom(DirectiveInstance directiveI } private GraphQLFieldDefinition createGraphQLFieldDefinitionFromBatchOperation(String operationTypeName, - Operation operation) { + Operation operation) { // Fields GraphQLFieldDefinition.Builder fieldBuilder = GraphQLFieldDefinition.newFieldDefinition() .name(operation.getName()) @@ -790,7 +791,8 @@ private GraphQLInputType createGraphQLInputType(Field field) { graphQLInputType = GraphQLNonNull.nonNull(graphQLInputType); } // Collection depth - do { + do + { if (wrapper.isCollectionOrArrayOrMap()) { graphQLInputType = list(graphQLInputType); wrapper = wrapper.getWrapper(); @@ -820,7 +822,8 @@ private GraphQLOutputType createGraphQLOutputType(Field field, boolean isBatch) graphQLOutputType = GraphQLNonNull.nonNull(graphQLOutputType); } // Collection depth - do { + do + { if (wrapper.isCollectionOrArrayOrMap()) { graphQLOutputType = list(graphQLOutputType); wrapper = wrapper.getWrapper(); @@ -920,7 +923,8 @@ private GraphQLArgument createGraphQLArgument(Argument argument) { graphQLInputType = GraphQLNonNull.nonNull(graphQLInputType); } // Collection depth - do { + do + { if (wrapper.isCollectionOrArrayOrMap()) { graphQLInputType = list(graphQLInputType); wrapper = wrapper.getWrapper(); @@ -992,7 +996,7 @@ private Object sanitizeDefaultValue(Field field) { private boolean isJsonString(String string) { if (string != null && !string.isEmpty() && (string.contains("{") || string.contains("["))) { try (StringReader stringReader = new StringReader(string); - JsonReader jsonReader = jsonReaderFactory.createReader(stringReader)) { + JsonReader jsonReader = jsonReaderFactory.createReader(stringReader)) { jsonReader.readValue(); return true; diff --git a/server/implementation/src/main/java/io/smallrye/graphql/bootstrap/FederationDataFetcher.java b/server/implementation/src/main/java/io/smallrye/graphql/bootstrap/FederationDataFetcher.java new file mode 100644 index 000000000..adfc7d18b --- /dev/null +++ b/server/implementation/src/main/java/io/smallrye/graphql/bootstrap/FederationDataFetcher.java @@ -0,0 +1,91 @@ +package io.smallrye.graphql.bootstrap; + +import static java.util.stream.Collectors.toList; +import static java.util.stream.Collectors.toSet; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import com.apollographql.federation.graphqljava._Entity; + +import graphql.schema.DataFetcher; +import graphql.schema.DataFetchingEnvironment; +import graphql.schema.DelegatingDataFetchingEnvironment; +import graphql.schema.GraphQLArgument; +import graphql.schema.GraphQLCodeRegistry; +import graphql.schema.GraphQLFieldDefinition; +import graphql.schema.GraphQLObjectType; +import graphql.schema.GraphQLOutputType; + +class FederationDataFetcher implements DataFetcher> { + + private final GraphQLObjectType queryType; + private final GraphQLCodeRegistry codeRegistry; + + public FederationDataFetcher(GraphQLObjectType queryType, GraphQLCodeRegistry codeRegistry) { + this.queryType = queryType; + this.codeRegistry = codeRegistry; + } + + @Override + public List get(DataFetchingEnvironment environment) throws Exception { + return environment.>> getArgument(_Entity.argumentName).stream() + .map(representations -> fetchEntities(environment, representations)) + .collect(toList()); + } + + private Object fetchEntities(DataFetchingEnvironment env, Map representations) { + Map requestedArgs = new HashMap<>(representations); + requestedArgs.remove("__typename"); + String typename = (String) representations.get("__typename"); + for (GraphQLFieldDefinition field : queryType.getFields()) { + if (matchesReturnType(field, typename) && matchesArguments(requestedArgs, field)) { + return execute(field, env, requestedArgs); + } + } + throw new RuntimeException("no query found for " + typename + " by " + requestedArgs.keySet()); + } + + private boolean matchesReturnType(GraphQLFieldDefinition field, String typename) { + GraphQLOutputType returnType = field.getType(); + return returnType instanceof GraphQLObjectType && ((GraphQLObjectType) returnType).getName().equals(typename); + } + + private boolean matchesArguments(Map requestedArguments, GraphQLFieldDefinition field) { + Set argumentNames = field.getArguments().stream().map(GraphQLArgument::getName).collect(toSet()); + return argumentNames.equals(requestedArguments.keySet()); + } + + private Object execute(GraphQLFieldDefinition field, DataFetchingEnvironment env, Map requestedArgs) { + DataFetcher dataFetcher = codeRegistry.getDataFetcher(queryType, field); + DataFetchingEnvironment argsEnv = new DelegatingDataFetchingEnvironment(env) { + @Override + public Map getArguments() { + return requestedArgs; + } + + @Override + public boolean containsArgument(String name) { + return requestedArgs.containsKey(name); + } + + @Override + public T getArgument(String name) { + //noinspection unchecked + return (T) requestedArgs.get(name); + } + + @Override + public T getArgumentOrDefault(String name, T defaultValue) { + return containsArgument(name) ? getArgument(name) : defaultValue; + } + }; + try { + return dataFetcher.get(argsEnv); + } catch (Exception e) { + throw new RuntimeException("can't fetch data from " + field, e); + } + } +}