Skip to content

Commit

Permalink
smallrye#521: implement federation data fetcher and type resolver
Browse files Browse the repository at this point in the history
  • Loading branch information
t1 authored and jmartisk committed Oct 3, 2022
1 parent 79efa0a commit 1586e7d
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import javax.json.Json;
Expand All @@ -28,8 +26,6 @@
import javax.json.bind.Jsonb;

import com.apollographql.federation.graphqljava.Federation;
import com.apollographql.federation.graphqljava._Entity;

import graphql.introspection.Introspection.DirectiveLocation;
import graphql.schema.DataFetcher;
import graphql.schema.FieldCoordinates;
Expand Down Expand Up @@ -144,10 +140,10 @@ private void verifyInjectionIsAvailable() {
// This crazy stream operation basically collects all class names where we need to verify that
// it belongs to an injectable bean
Stream.of(
schema.getQueries().stream().map(Operation::getClassName),
schema.getMutations().stream().map(Operation::getClassName),
schema.getGroupedQueries().values().stream().flatMap(Collection::stream).map(Operation::getClassName),
schema.getGroupedMutations().values().stream().flatMap(Collection::stream).map(Operation::getClassName))
schema.getQueries().stream().map(Operation::getClassName),
schema.getMutations().stream().map(Operation::getClassName),
schema.getGroupedQueries().values().stream().flatMap(Collection::stream).map(Operation::getClassName),
schema.getGroupedMutations().values().stream().flatMap(Collection::stream).map(Operation::getClassName))
.flatMap(stream -> stream)
.distinct().forEach(beanClassName -> {
// verify that the bean is injectable
Expand Down Expand Up @@ -191,23 +187,28 @@ private void generateGraphQLSchema() {
JsonInputRegistry.override(overrides);

if (Config.get().isFederationEnabled()) {
this.graphQLSchema = Federation.transform(schemaBuilder.build())
.fetchEntities(env -> env.<List<Map<String, Object>>> getArgument(_Entity.argumentName).stream()
.map(fetchEntities())
.collect(Collectors.toList()))
GraphQLSchema rawSchema = schemaBuilder.build();
this.graphQLSchema = Federation.transform(rawSchema)
.fetchEntities(new FederationDataFetcher(rawSchema.getQueryType(), rawSchema.getCodeRegistry()))
.resolveEntityType(fetchEntityType())
.build();
} else {
this.graphQLSchema = schemaBuilder.build();
}
}

private Function<Map<String, Object>, ?> fetchEntities() {
return reference -> null; // TODO federation: implement fetcher
}

private TypeResolver fetchEntityType() {
return env -> null; // TODO federation: implement type resolver
return env -> {
Object src = env.getObject();
if (src == null) {
return null;
}
GraphQLObjectType result = env.getSchema().getObjectType(src.getClass().getSimpleName()); // TODO respect @Name, etc.
if (result == null) {
throw new RuntimeException("can't resolve federated entity type " + src.getClass().getName());
}
return result;
};
}

private void createGraphQLDirectiveTypes() {
Expand Down Expand Up @@ -295,7 +296,7 @@ private void addSubscriptions(GraphQLSchema.Builder schemaBuilder) {
}

private void addRootObject(GraphQLObjectType.Builder rootBuilder, Set<Operation> operations,
String rootName) {
String rootName) {

for (Operation operation : operations) {
operation = eventEmitter.fireCreateOperation(operation);
Expand All @@ -306,7 +307,7 @@ private void addRootObject(GraphQLObjectType.Builder rootBuilder, Set<Operation>
}

private void addGroupedRootObject(GraphQLObjectType.Builder rootBuilder,
Map<Group, Set<Operation>> operationMap, String rootName) {
Map<Group, Set<Operation>> operationMap, String rootName) {
Set<Map.Entry<Group, Set<Operation>>> operationsSet = operationMap.entrySet();

for (Map.Entry<Group, Set<Operation>> operationsEntry : operationsSet) {
Expand Down Expand Up @@ -606,7 +607,7 @@ private GraphQLDirective createGraphQLDirectiveFrom(DirectiveInstance directiveI
}

private GraphQLFieldDefinition createGraphQLFieldDefinitionFromBatchOperation(String operationTypeName,
Operation operation) {
Operation operation) {
// Fields
GraphQLFieldDefinition.Builder fieldBuilder = GraphQLFieldDefinition.newFieldDefinition()
.name(operation.getName())
Expand Down Expand Up @@ -790,7 +791,8 @@ private GraphQLInputType createGraphQLInputType(Field field) {
graphQLInputType = GraphQLNonNull.nonNull(graphQLInputType);
}
// Collection depth
do {
do
{
if (wrapper.isCollectionOrArrayOrMap()) {
graphQLInputType = list(graphQLInputType);
wrapper = wrapper.getWrapper();
Expand Down Expand Up @@ -820,7 +822,8 @@ private GraphQLOutputType createGraphQLOutputType(Field field, boolean isBatch)
graphQLOutputType = GraphQLNonNull.nonNull(graphQLOutputType);
}
// Collection depth
do {
do
{
if (wrapper.isCollectionOrArrayOrMap()) {
graphQLOutputType = list(graphQLOutputType);
wrapper = wrapper.getWrapper();
Expand Down Expand Up @@ -920,7 +923,8 @@ private GraphQLArgument createGraphQLArgument(Argument argument) {
graphQLInputType = GraphQLNonNull.nonNull(graphQLInputType);
}
// Collection depth
do {
do
{
if (wrapper.isCollectionOrArrayOrMap()) {
graphQLInputType = list(graphQLInputType);
wrapper = wrapper.getWrapper();
Expand Down Expand Up @@ -992,7 +996,7 @@ private Object sanitizeDefaultValue(Field field) {
private boolean isJsonString(String string) {
if (string != null && !string.isEmpty() && (string.contains("{") || string.contains("["))) {
try (StringReader stringReader = new StringReader(string);
JsonReader jsonReader = jsonReaderFactory.createReader(stringReader)) {
JsonReader jsonReader = jsonReaderFactory.createReader(stringReader)) {

jsonReader.readValue();
return true;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package io.smallrye.graphql.bootstrap;

import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toSet;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

import com.apollographql.federation.graphqljava._Entity;

import graphql.schema.DataFetcher;
import graphql.schema.DataFetchingEnvironment;
import graphql.schema.DelegatingDataFetchingEnvironment;
import graphql.schema.GraphQLArgument;
import graphql.schema.GraphQLCodeRegistry;
import graphql.schema.GraphQLFieldDefinition;
import graphql.schema.GraphQLObjectType;
import graphql.schema.GraphQLOutputType;

class FederationDataFetcher implements DataFetcher<List<Object>> {

private final GraphQLObjectType queryType;
private final GraphQLCodeRegistry codeRegistry;

public FederationDataFetcher(GraphQLObjectType queryType, GraphQLCodeRegistry codeRegistry) {
this.queryType = queryType;
this.codeRegistry = codeRegistry;
}

@Override
public List<Object> get(DataFetchingEnvironment environment) throws Exception {
return environment.<List<Map<String, Object>>> getArgument(_Entity.argumentName).stream()
.map(representations -> fetchEntities(environment, representations))
.collect(toList());
}

private Object fetchEntities(DataFetchingEnvironment env, Map<String, Object> representations) {
Map<String, Object> requestedArgs = new HashMap<>(representations);
requestedArgs.remove("__typename");
String typename = (String) representations.get("__typename");
for (GraphQLFieldDefinition field : queryType.getFields()) {
if (matchesReturnType(field, typename) && matchesArguments(requestedArgs, field)) {
return execute(field, env, requestedArgs);
}
}
throw new RuntimeException("no query found for " + typename + " by " + requestedArgs.keySet());
}

private boolean matchesReturnType(GraphQLFieldDefinition field, String typename) {
GraphQLOutputType returnType = field.getType();
return returnType instanceof GraphQLObjectType && ((GraphQLObjectType) returnType).getName().equals(typename);
}

private boolean matchesArguments(Map<String, Object> requestedArguments, GraphQLFieldDefinition field) {
Set<String> argumentNames = field.getArguments().stream().map(GraphQLArgument::getName).collect(toSet());
return argumentNames.equals(requestedArguments.keySet());
}

private Object execute(GraphQLFieldDefinition field, DataFetchingEnvironment env, Map<String, Object> requestedArgs) {
DataFetcher<?> dataFetcher = codeRegistry.getDataFetcher(queryType, field);
DataFetchingEnvironment argsEnv = new DelegatingDataFetchingEnvironment(env) {
@Override
public Map<String, Object> getArguments() {
return requestedArgs;
}

@Override
public boolean containsArgument(String name) {
return requestedArgs.containsKey(name);
}

@Override
public <T> T getArgument(String name) {
//noinspection unchecked
return (T) requestedArgs.get(name);
}

@Override
public <T> T getArgumentOrDefault(String name, T defaultValue) {
return containsArgument(name) ? getArgument(name) : defaultValue;
}
};
try {
return dataFetcher.get(argsEnv);
} catch (Exception e) {
throw new RuntimeException("can't fetch data from " + field, e);
}
}
}

0 comments on commit 1586e7d

Please sign in to comment.