Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow @typePolicy directive on interfaces and unions #4131

Merged
merged 2 commits into from
May 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package com.apollographql.apollo3.ast

import com.apollographql.apollo3.ast.internal.buffer
import com.apollographql.apollo3.annotations.ApolloDeprecatedSince

/**
* A wrapper around a schema GQLDocument that:
Expand All @@ -11,10 +11,16 @@ import com.apollographql.apollo3.ast.internal.buffer
* - has some helper functions to retrieve a type by name and/or possible types
*
* @param definitions a list of validated and merged definitions
* @param keyFields a Map containing the key fields for each type
*/
class Schema(
class Schema internal constructor(
private val definitions: List<GQLDefinition>,
private val keyFields: Map<String, Set<String>>
) {
@Deprecated("Use toSchema() to get a Schema")
@ApolloDeprecatedSince(ApolloDeprecatedSince.Version.v3_3_1)
constructor(definitions: List<GQLDefinition>): this(definitions, emptyMap())

val typeDefinitions: Map<String, GQLTypeDefinition> = definitions
.filterIsInstance<GQLTypeDefinition>()
.associateBy { it.name }
Expand Down Expand Up @@ -110,71 +116,17 @@ class Schema(
* Returns whether the `typePolicy` directive is present on at least one object in the schema
*/
fun hasTypeWithTypePolicy(): Boolean {
return typeDefinitions.values.filterIsInstance<GQLObjectTypeDefinition>().any { objectType ->
objectType.directives.any { it.name == TYPE_POLICY }
}
val directives = typeDefinitions.values.filterIsInstance<GQLObjectTypeDefinition>().flatMap { it.directives } +
typeDefinitions.values.filterIsInstance<GQLInterfaceTypeDefinition>().flatMap { it.directives } +
typeDefinitions.values.filterIsInstance<GQLUnionTypeDefinition>().flatMap { it.directives }
return directives.any { it.name == TYPE_POLICY }
}

/**
* Returns the key fields for the given type
*
* If this type has one or multiple @[TYPE_POLICY] annotation(s), they are used, else it recurses in implemented interfaces until it
* finds some.
*
* Returns the emptySet if this type has no key fields.
* Get the key fields for an object, interface or union type.
*/
fun keyFields(name: String): Set<String> {
val typeDefinition = typeDefinition(name)
return when (typeDefinition) {
is GQLObjectTypeDefinition -> {
val kf = typeDefinition.directives.toKeyFields()
if (kf != null) {
kf
} else {
val kfs = typeDefinition.implementsInterfaces.map { it to keyFields(it) }.filter { it.second.isNotEmpty() }
if (kfs.isNotEmpty()) {
check(kfs.size == 1) {
val candidates = kfs.map { "${it.first}: ${it.second}" }.joinToString("\n")
"Object '$name' inherits different keys from different interfaces:\n$candidates\nSpecify @$TYPE_POLICY explicitly"
}
}
kfs.singleOrNull()?.second ?: emptySet()
}
}
is GQLInterfaceTypeDefinition -> {
val kf = typeDefinition.directives.toKeyFields()
if (kf != null) {
kf
} else {
val kfs = typeDefinition.implementsInterfaces.map { it to keyFields(it) }.filter { it.second.isNotEmpty() }
if (kfs.isNotEmpty()) {
check(kfs.size == 1) {
val candidates = kfs.map { "${it.first}: ${it.second}" }.joinToString("\n")
"Interface '$name' inherits different keys from different interfaces:\n$candidates\nSpecify @$TYPE_POLICY explicitly"
}
}
kfs.singleOrNull()?.second ?: emptySet()
}
}
is GQLUnionTypeDefinition -> typeDefinition.directives.toKeyFields() ?: emptySet()
else -> error("Type '$name' cannot have key fields")
}
}

/**
* Returns the key Fields or null if there's no directive
*/
private fun List<GQLDirective>.toKeyFields(): Set<String>? {
val directives = filter { it.name == TYPE_POLICY }
if (directives.isEmpty()) {
return null
}
return directives.flatMap {
(it.arguments!!.arguments.first().value as GQLStringValue).value.buffer().parseAsGQLSelections().valueAssertNoErrors().map { gqlSelection ->
// No need to check here, this should be done during validation
(gqlSelection as GQLField).name
}
}.toSet()
return keyFields[name] ?: emptySet()
}

companion object {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import com.apollographql.apollo3.ast.internal.toGQLDocument
import com.apollographql.apollo3.ast.internal.toGQLSelection
import com.apollographql.apollo3.ast.internal.toGQLValue
import com.apollographql.apollo3.ast.internal.validateDocumentAndMergeExtensions
import com.apollographql.apollo3.ast.internal.validateKeyFields
import okio.BufferedSource

/**
Expand Down Expand Up @@ -83,7 +84,8 @@ fun GQLDocument.validateAsSchema(): GQLResult<Schema> {
val schema = if (scope.issues.containsError()) {
null
} else {
Schema(mergedDefinitions)
val keyFields = scope.validateKeyFields(mergedDefinitions)
Schema(mergedDefinitions, keyFields)
}
return GQLResult(schema, scope.issues)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,7 @@ data class GQLInterfaceTypeExtension(
override val sourceLocation: SourceLocation = SourceLocation.UNKNOWN,
override val name: String,
val implementsInterfaces: List<String>,
val directives: List<GQLDirective>,
val fields: List<GQLFieldDefinition>,
) : GQLDefinition, GQLTypeExtension, GQLNamed {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,13 @@ fun GQLTypeDefinition.implementsAbstractType(schema: Schema): Boolean {
schema.typeDefinition(it).isAbstract()
}
}

fun GQLTypeDefinition.canHaveKeyFields(): Boolean {
return when (this) {
is GQLObjectTypeDefinition,
is GQLInterfaceTypeDefinition,
is GQLUnionTypeDefinition
-> true
else -> false
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ internal class SchemaValidationScope(document: GQLDocument) : ValidationScope {
override val issues = mutableListOf<Issue>()

val documentDefinitions = document.definitions

/**
* The builtin definitions are required to validate directives amongst other
* things so add them early in the validation proccess.
Expand Down Expand Up @@ -97,7 +98,7 @@ internal fun SchemaValidationScope.validateDocumentAndMergeExtensions(): List<GQ

val schemaDefinition = schemaDefinition ?: syntheticSchemaDefinition()

return mergeExtensions(listOf(schemaDefinition) + allDefinitions.filter { it !is GQLSchemaDefinition } )
return mergeExtensions(listOf(schemaDefinition) + allDefinitions.filter { it !is GQLSchemaDefinition })
}

internal fun SchemaValidationScope.validateRootOperationTypes() {
Expand Down Expand Up @@ -156,7 +157,7 @@ private fun ValidationScope.validateInterfaces() {
}
}

private fun ValidationScope.validateObjects() {
private fun SchemaValidationScope.validateObjects() {
typeDefinitions.values.filterIsInstance<GQLObjectTypeDefinition>().forEach { o ->
if (o.fields.isEmpty()) {
registerIssue("Object must specify one or more fields", o.sourceLocation)
Expand Down Expand Up @@ -185,3 +186,79 @@ private fun SchemaValidationScope.validateNoIntrospectionNames() {
}
}

private fun SchemaValidationScope.keyFields(
typeDefinition: GQLTypeDefinition,
allTypeDefinition: Map<String, GQLTypeDefinition>,
keyFieldsCache: MutableMap<String, Set<String>>,
): Set<String> {
val cached = keyFieldsCache[typeDefinition.name]
if (cached != null) {
return cached
}

val (directives, interfaces) = when (typeDefinition) {
is GQLObjectTypeDefinition -> typeDefinition.directives to typeDefinition.implementsInterfaces
is GQLInterfaceTypeDefinition -> typeDefinition.directives to typeDefinition.implementsInterfaces
is GQLUnionTypeDefinition -> typeDefinition.directives to emptyList()
else -> error("Cannot get directives for $typeDefinition")
}

val interfacesKeyFields = interfaces.map { keyFields(allTypeDefinition[it]!!, allTypeDefinition, keyFieldsCache) }
val distinct = interfacesKeyFields.distinct()
if (distinct.size > 1) {
val extra = interfaces.indices.map {
"${interfaces[it]}: ${interfacesKeyFields[it]}"
}.joinToString("\n")
registerIssue(
message = "Apollo: Type '${typeDefinition.name}' cannot inherit different keys from different interfaces:\n$extra",
sourceLocation = typeDefinition.sourceLocation
)
}

val keyFields = directives.toKeyFields()
val ret = if (keyFields != null) {
if (!distinct.isEmpty()) {
val extra = interfaces.indices.map {
"${interfaces[it]}: ${interfacesKeyFields[it]}"
}.joinToString("\n")
registerIssue(
message = "Type '${typeDefinition.name}' cannot have key fields since it implements the following interfaces which also have key fields: $extra",
sourceLocation = typeDefinition.sourceLocation
)
}
keyFields
} else {
distinct.firstOrNull() ?: emptySet()
}

keyFieldsCache[typeDefinition.name] = ret

return ret
}

private fun List<GQLDirective>.toKeyFields(): Set<String>? {
val directives = filter { it.name == Schema.TYPE_POLICY }
if (directives.isEmpty()) {
return null
}
return directives.flatMap {
(it.arguments!!.arguments.first().value as GQLStringValue).value.buffer().parseAsGQLSelections().valueAssertNoErrors().map { gqlSelection ->
// No need to check here, this should be done during validation
(gqlSelection as GQLField).name
}
}.toSet()
}

/**
* To prevent surprising behaviour, objects that declare key fields that also implement interfaces that declare key fields are an error
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Nitpick) maybe the comment could be moved to the method above.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added another comment on the method above and kept this one as it's the entry point of validation.

*
* @see <a href="https://github.com/apollographql/apollo-kotlin/issues/3356#issuecomment-1134381986">Discussion</a>
*/
internal fun SchemaValidationScope.validateKeyFields(mergedDefinitions: List<GQLDefinition>): Map<String, Set<String>> {
val keyFieldsCache = mutableMapOf<String, Set<String>>()
val typeDefinitions = mergedDefinitions.filterIsInstance<GQLTypeDefinition>().filter { it.canHaveKeyFields() }
typeDefinitions.forEach {
keyFields(it, typeDefinitions.associateBy { it.name }, keyFieldsCache)
}
return keyFieldsCache
}
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ private fun ValidationScope.merge(
): GQLInterfaceTypeDefinition = with(interfaceTypeDefinition) {
return copy(
fields = mergeUniquesOrThrow(fields, extension.fields),
directives = mergeDirectives(directives, extension.directives),
implementsInterfaces = mergeUniqueInterfacesOrThrow(implementsInterfaces, extension.implementsInterfaces, extension.sourceLocation)
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ private class AntlrToGQLScope(val filePath: String?) {
sourceLocation = sourceLocation(start),
name = name().text,
implementsInterfaces = implementsInterfaces().parse(),
directives = directives().parse(),
fields = fieldsDefinition().parse()
)
}
Expand Down
2 changes: 1 addition & 1 deletion apollo-ast/src/main/resources/apollo.graphqls
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ directive @nonnull(fields: String! = "") on OBJECT | FIELD

# Marks fields as key fields. Key fields are used to compute the cache key of an object
# `keyFields` should contain a selection set. Composite fields are not supported yet.
directive @typePolicy(keyFields: String!) on OBJECT
directive @typePolicy(keyFields: String!) on OBJECT | INTERFACE | UNION

# Indicates how to compute a key from a field arguments.
# `keyArgs` should contain a selection set. Composite args are not supported yet.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,22 @@ import com.apollographql.apollo3.ast.GQLOperationDefinition
import com.apollographql.apollo3.ast.checkKeyFields
import com.apollographql.apollo3.ast.parseAsGQLDocument
import com.apollographql.apollo3.ast.transformation.addRequiredFields
import com.apollographql.apollo3.ast.validateAsSchema
import com.apollographql.apollo3.ast.withApolloDefinitions
import com.apollographql.apollo3.compiler.Options.Companion.defaultAddTypename
import com.apollographql.apollo3.compiler.introspection.toSchema
import com.apollographql.apollo3.compiler.introspection.toSchemaGQLDocument
import okio.buffer
import okio.source
import org.junit.Assert.fail
import org.junit.Test
import java.io.File
import kotlin.test.assertContains
import kotlin.test.assertEquals

class KeyFieldsTest {
@Test
fun test() {
fun testAddRequiredFields() {
val schema = File("src/test/kotlin/com/apollographql/apollo3/compiler/keyfields/schema.graphqls").toSchema()

val definitions = File("src/test/kotlin/com/apollographql/apollo3/compiler/keyfields/operations.graphql")
Expand All @@ -30,15 +35,56 @@ class KeyFieldsTest {
val operation = definitions
.filterIsInstance<GQLOperationDefinition>()
.first()
.let {
addRequiredFields(it, defaultAddTypename, schema, fragments)
}

try {
checkKeyFields(operation, schema, emptyMap())
fail("an exception was expected")
} catch (e: Exception) {
assert(e.message?.contains("are not queried") == true)
}

val operationWithKeyFields = addRequiredFields(operation, defaultAddTypename, schema, fragments)
checkKeyFields(operationWithKeyFields, schema, emptyMap())
}

@Test
fun testExtendInterfaceTypePolicyDirective() {
val schema = File("src/test/kotlin/com/apollographql/apollo3/compiler/keyfields/extendsSchema.graphqls").toSchema()
schema.toGQLDocument().validateAsSchema()
assertEquals(setOf("id"), schema.keyFields("Node"))
}

@Test
fun testExtendUnionTypePolicyDirective() {
val schema = File("src/test/kotlin/com/apollographql/apollo3/compiler/keyfields/extendsSchema.graphqls").toSchema()
assertEquals(setOf("x"), schema.keyFields("Foo"))
}

@Test
fun testObjectWithTypePolicyAndInterfaceTypePolicyErrors() {
val doc = File("src/test/kotlin/com/apollographql/apollo3/compiler/keyfields/objectAndInterfaceTypePolicySchema.graphqls")
.toSchemaGQLDocument()
.withApolloDefinitions()
val issue = doc.validateAsSchema().issues.first()
assertContains(issue.message, "Type 'Foo' cannot have key fields since it implements")
assertContains(issue.message, "Node")
assertEquals(11, issue.sourceLocation.line)
}

@Test
fun testObjectInheritingTwoInterfacesWithDifferentKeyFields() {
val doc = File("src/test/kotlin/com/apollographql/apollo3/compiler/keyfields/objectInheritingTwoInterfaces.graphqls")
.toSchemaGQLDocument()
.withApolloDefinitions()
val issue = doc.validateAsSchema().issues.first()
assertEquals(
"""
Apollo: Type 'Book' cannot inherit different keys from different interfaces:
Node: [id]
Product: [upc]
""".trimIndent(),
issue.message
)
assertEquals(13, issue.sourceLocation.line)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
type Query {
node: Node
}

interface Node {
id: String!
}

extend interface Node @typePolicy(keyFields: "id")

type Bar {
id: String!
}

extend type Bar implements Node

union Foo = A | B
type A {
x: String!
}
type B {
x: String!
}

extend union Foo @typePolicy(keyFields: "x")
Loading