diff --git a/graphql/execution/executor.py b/graphql/execution/executor.py index 472f95b9..c339455b 100644 --- a/graphql/execution/executor.py +++ b/graphql/execution/executor.py @@ -364,6 +364,9 @@ def resolve_field( executor = exe_context.executor result = resolve_or_error(resolve_fn_middleware, source, info, args, executor) + if result is Undefined: + return Undefined + return complete_value_catching_error( exe_context, return_type, field_asts, info, field_path, result ) diff --git a/graphql/execution/values.py b/graphql/execution/values.py index b4304cb4..be272532 100644 --- a/graphql/execution/values.py +++ b/graphql/execution/values.py @@ -20,6 +20,7 @@ from ..utils.is_valid_value import is_valid_value from ..utils.type_from_ast import type_from_ast from ..utils.value_from_ast import value_from_ast +from ..utils.undefined import Undefined # Necessary for static type checking if False: # flake8: noqa @@ -56,10 +57,11 @@ def get_variable_values( [def_ast], ) elif value is None: - if def_ast.default_value is not None: - values[var_name] = value_from_ast( - def_ast.default_value, var_type - ) # type: ignore + if def_ast.default_value is None: + values[var_name] = None + elif def_ast.default_value is not Undefined: + values[var_name] = value_from_ast(def_ast.default_value, var_type) + if isinstance(var_type, GraphQLNonNull): raise GraphQLError( 'Variable "${var_name}" of required type "{var_type}" was not provided.'.format( @@ -109,7 +111,7 @@ def get_argument_values( arg_type = arg_def.type arg_ast = arg_ast_map.get(name) if name not in arg_ast_map: - if arg_def.default_value is not None: + if arg_def.default_value is not Undefined: result[arg_def.out_name or name] = arg_def.default_value continue elif isinstance(arg_type, GraphQLNonNull): @@ -123,7 +125,7 @@ def get_argument_values( variable_name = arg_ast.value.name.value # type: ignore if variables and variable_name in variables: result[arg_def.out_name or name] = variables[variable_name] - elif arg_def.default_value is not None: + elif arg_def.default_value is not Undefined: result[arg_def.out_name or name] = arg_def.default_value elif isinstance(arg_type, GraphQLNonNull): raise GraphQLError( @@ -137,7 +139,7 @@ def get_argument_values( else: value = value_from_ast(arg_ast.value, arg_type, variables) # type: ignore if value is None: - if arg_def.default_value is not None: + if arg_def.default_value is not Undefined: value = arg_def.default_value result[arg_def.out_name or name] = value else: @@ -172,7 +174,7 @@ def coerce_value(type, value): obj = {} for field_name, field in fields.items(): if field_name not in value: - if field.default_value is not None: + if field.default_value is not Undefined: field_value = field.default_value obj[field.out_name or field_name] = field_value else: diff --git a/graphql/language/ast.py b/graphql/language/ast.py index 0c40d440..7a11e7bd 100644 --- a/graphql/language/ast.py +++ b/graphql/language/ast.py @@ -604,6 +604,48 @@ def __hash__(self): return id(self) +class NullValue(Value): + __slots__ = ("loc", "value") + _fields = ("value",) + + def __init__(self, value=None, loc=None): + self.value = None + self.loc = loc + + def __eq__(self, other): + return isinstance(other, NullValue) + + def __repr__(self): + return "NullValue" + + def __copy__(self): + return type(self)(self.value, self.loc) + + def __hash__(self): + return id(self) + + +class UndefinedValue(Value): + __slots__ = ("loc", "value") + _fields = ("value",) + + def __init__(self, value=None, loc=None): + self.value = None + self.loc = loc + + def __eq__(self, other): + return isinstance(other, UndefinedValue) + + def __repr__(self): + return "UndefinedValue" + + def __copy__(self): + return type(self)(self.value, self.loc) + + def __hash__(self): + return id(self) + + class EnumValue(Value): __slots__ = ("loc", "value") _fields = ("value",) diff --git a/graphql/language/parser.py b/graphql/language/parser.py index 959a630e..c0f77c20 100644 --- a/graphql/language/parser.py +++ b/graphql/language/parser.py @@ -4,6 +4,7 @@ from ..error import GraphQLSyntaxError from .lexer import Lexer, TokenKind, get_token_desc, get_token_kind_desc from .source import Source +from ..utils.undefined import Undefined # Necessary for static type checking if False: # flake8: noqa @@ -65,6 +66,12 @@ def parse(source, **kwargs): def parse_value(source, **kwargs): + if source is None: + return ast.NullValue() + + if source is Undefined: + return ast.UndefinedValue() + options = {"no_location": False, "no_source": False} options.update(kwargs) source_obj = source @@ -338,7 +345,7 @@ def parse_variable_definition(parser): type=expect(parser, TokenKind.COLON) and parse_type(parser), default_value=parse_value_literal(parser, True) if skip(parser, TokenKind.EQUALS) - else None, + else Undefined, loc=loc(parser, start), ) @@ -493,18 +500,21 @@ def parse_value_literal(parser, is_const): ) elif token.kind == TokenKind.NAME: + advance(parser) if token.value in ("true", "false"): - advance(parser) return ast.BooleanValue( value=token.value == "true", loc=loc(parser, token.start) ) - if token.value != "null": - advance(parser) - return ast.EnumValue( - value=token.value, loc=loc(parser, token.start) # type: ignore + if token.value == "null": + return ast.NullValue( + loc=loc(parser, token.start) # type: ignore ) + return ast.EnumValue( # type: ignore + value=token.value, loc=loc(parser, token.start) + ) + elif token.kind == TokenKind.DOLLAR: if not is_const: return parse_variable(parser) @@ -754,7 +764,7 @@ def parse_input_value_def(parser): type=expect(parser, TokenKind.COLON) and parse_type(parser), # type: ignore default_value=parse_const_value(parser) if skip(parser, TokenKind.EQUALS) - else None, + else Undefined, directives=parse_directives(parser), loc=loc(parser, start), ) diff --git a/graphql/language/printer.py b/graphql/language/printer.py index 6dcd81de..cebed944 100644 --- a/graphql/language/printer.py +++ b/graphql/language/printer.py @@ -1,6 +1,7 @@ import json from .visitor import Visitor, visit +from ..utils.undefined import Undefined # Necessary for static type checking if False: # flake8: noqa @@ -45,7 +46,7 @@ def leave_OperationDefinition(self, node, *args): def leave_VariableDefinition(self, node, *args): # type: (Any, *Any) -> str - return node.variable + ": " + node.type + wrap(" = ", node.default_value) + return node.variable + ": " + node.type + wrap(" = ", node.default_value, is_default_value=True) def leave_SelectionSet(self, node, *args): # type: (Any, *Any) -> str @@ -111,6 +112,12 @@ def leave_BooleanValue(self, node, *args): # type: (Any, *Any) -> str return json.dumps(node.value) + def leave_NullValue(self, node, *args): + return "null" + + def leave_UndefinedValue(self, node, *args): + return Undefined + def leave_EnumValue(self, node, *args): # type: (Any, *Any) -> str return node.value @@ -192,7 +199,7 @@ def leave_InputValueDefinition(self, node, *args): node.name + ": " + node.type - + wrap(" = ", node.default_value) + + wrap(" = ", node.default_value, is_default_value=True) + wrap(" ", join(node.directives, " ")) ) @@ -232,13 +239,14 @@ def leave_EnumValueDefinition(self, node, *args): def leave_InputObjectTypeDefinition(self, node, *args): # type: (Any, *Any) -> str - return ( + s = ( "input " + node.name + wrap(" ", join(node.directives, " ")) + " " + block(node.fields) ) + return s def leave_TypeExtensionDefinition(self, node, *args): # type: (Any, *Any) -> str @@ -268,8 +276,14 @@ def block(_list): return "{}" -def wrap(start, maybe_str, end=""): - # type: (str, Optional[str], str) -> str +def wrap(start, maybe_str, end="", is_default_value=False): + # type: (str, Optional[str], str, bool) -> str + if is_default_value: + if maybe_str is Undefined: + return "" + s = "null" if maybe_str is None else maybe_str + return start + s + end + if maybe_str: return start + maybe_str + end return "" diff --git a/graphql/language/tests/fixtures.py b/graphql/language/tests/fixtures.py index b16653c4..74b2ccda 100644 --- a/graphql/language/tests/fixtures.py +++ b/graphql/language/tests/fixtures.py @@ -53,7 +53,7 @@ } { - unnamed(truthy: true, falsey: false), + unnamed(truthy: true, falsey: false, nullish: null), query } """ diff --git a/graphql/language/tests/test_parser.py b/graphql/language/tests/test_parser.py index 6c9d5f47..e6ea183f 100644 --- a/graphql/language/tests/test_parser.py +++ b/graphql/language/tests/test_parser.py @@ -101,12 +101,72 @@ def test_does_not_accept_fragments_spread_of_on(): assert "Syntax Error GraphQL (1:9) Expected Name, found }" in excinfo.value.message -def test_does_not_allow_null_value(): +def test_allows_null_value(): # type: () -> None - with raises(GraphQLSyntaxError) as excinfo: - parse("{ fieldWithNullableStringInput(input: null) }") + parse("{ fieldWithNullableStringInput(input: null) }") + + +def test_parses_null_value_to_null(): + result = parse('{ fieldWithObjectInput(input: {a: null, b: null, c: "C", d: null}) }') + values = result.definitions[0].selection_set.selections[0].arguments[0].value.fields + expected = ( + (u"a", ast.NullValue()), + (u"b", ast.NullValue()), + (u"c", ast.StringValue(value=u"C")), + (u"d", ast.NullValue()), + ) + for name_value, actual in zip(expected, values): + assert name_value == (actual.name.value, actual.value) + + +def test_parses_null_value_in_list(): + result = parse('{ fieldWithObjectInput(input: {b: ["A", null, "C"], c: "C"}) }') + assert result == ast.Document( + definitions=[ + ast.OperationDefinition( + operation="query", name=None, variable_definitions=None, directives=[], + selection_set=ast.SelectionSet( + selections=[ + ast.Field( + alias=None, + name=ast.Name(value=u"fieldWithObjectInput"), + directives=[], + selection_set=None, + arguments=[ + ast.Argument( + name=ast.Name(value=u"input"), + value=ast.ObjectValue( + fields=[ + ast.ObjectField( + name=ast.Name(value=u"b"), + value=ast.ListValue( + values=[ + ast.StringValue(value=u"A"), + ast.NullValue(), + ast.StringValue(value=u"C"), + ], + ), + ), + ast.ObjectField( + name=ast.Name(value=u"c"), + value=ast.StringValue(value=u"C"), + ), + ] + ), + ), + ], + ), + ], + ), + ), + ], + ) + - assert 'Syntax Error GraphQL (1:39) Unexpected Name "null"' in excinfo.value.message +def test_null_as_name(): + result = parse('{ thingy(null: "stringcheese") }') + assert result.definitions[0].selection_set.selections[0].name.value == "thingy" + assert result.definitions[0].selection_set.selections[0].arguments[0].name.value == "null" def test_parses_multi_byte_characters(): @@ -158,6 +218,7 @@ def tesst_allows_non_keywords_anywhere_a_name_is_allowed(): "subscription", "true", "false", + "null", ] query_template = """ diff --git a/graphql/language/tests/test_printer.py b/graphql/language/tests/test_printer.py index 470991df..e1684935 100644 --- a/graphql/language/tests/test_printer.py +++ b/graphql/language/tests/test_printer.py @@ -85,6 +85,14 @@ def test_correctly_prints_mutation_with_artifacts(): ) +def test_correctly_prints_null(): + query_ast_shorthanded = parse('{ thingy(null: "wow", name: null) }') + assert print_ast(query_ast_shorthanded) == """{ + thingy(null: "wow", name: null) +} +""" + + def test_prints_kitchen_sink(): # type: () -> None ast = parse(KITCHEN_SINK) @@ -138,7 +146,7 @@ def test_prints_kitchen_sink(): } { - unnamed(truthy: true, falsey: false) + unnamed(truthy: true, falsey: false, nullish: null) query } """ diff --git a/graphql/language/tests/test_schema_parser.py b/graphql/language/tests/test_schema_parser.py index d5059a5e..583ac341 100644 --- a/graphql/language/tests/test_schema_parser.py +++ b/graphql/language/tests/test_schema_parser.py @@ -4,6 +4,7 @@ from graphql.error import GraphQLSyntaxError from graphql.language import ast from graphql.language.parser import Loc +from graphql.utils.undefined import Undefined from typing import Callable @@ -294,7 +295,7 @@ def test_parses_simple_field_with_arg(): name=ast.Name(value="Boolean", loc=loc(28, 35)), loc=loc(28, 35), ), - default_value=None, + default_value=Undefined, directives=[], loc=loc(22, 35), ) @@ -391,7 +392,7 @@ def test_parses_simple_field_with_list_arg(): ), loc=loc(30, 38), ), - default_value=None, + default_value=Undefined, directives=[], loc=loc(22, 38), ) @@ -436,7 +437,7 @@ def test_parses_simple_field_with_two_args(): name=ast.Name(value="Boolean", loc=loc(30, 37)), loc=loc(30, 37), ), - default_value=None, + default_value=Undefined, directives=[], loc=loc(22, 37), ), @@ -446,7 +447,7 @@ def test_parses_simple_field_with_two_args(): name=ast.Name(value="Int", loc=loc(47, 50)), loc=loc(47, 50), ), - default_value=None, + default_value=Undefined, directives=[], loc=loc(39, 50), ), @@ -554,7 +555,7 @@ def test_parses_simple_input_object(): name=ast.Name(value="String", loc=loc(24, 30)), loc=loc(24, 30), ), - default_value=None, + default_value=Undefined, directives=[], loc=loc(17, 30), ) diff --git a/graphql/language/tests/test_visitor.py b/graphql/language/tests/test_visitor.py index 1392fe3e..5819aa5a 100644 --- a/graphql/language/tests/test_visitor.py +++ b/graphql/language/tests/test_visitor.py @@ -600,6 +600,12 @@ def leave(self, node, key, parent, *args): ["enter", "BooleanValue", "value", "Argument"], ["leave", "BooleanValue", "value", "Argument"], ["leave", "Argument", 1, None], + ["enter", "Argument", 2, None], + ["enter", "Name", "name", "Argument"], + ["leave", "Name", "name", "Argument"], + ["enter", "NullValue", "value", "Argument"], + ["leave", "NullValue", "value", "Argument"], + ["leave", "Argument", 2, None], ["leave", "Field", 0, None], ["enter", "Field", 1, None], ["enter", "Name", "name", "Field"], diff --git a/graphql/language/visitor.py b/graphql/language/visitor.py index a3c3ff93..49256795 100644 --- a/graphql/language/visitor.py +++ b/graphql/language/visitor.py @@ -4,6 +4,7 @@ from . import ast from .visitor_meta import QUERY_DOCUMENT_KEYS, VisitorMeta +from ..utils.undefined import Undefined # Necessary for static type checking if False: # flake8: noqa @@ -107,6 +108,9 @@ def visit(root, visitor, key_map=None): key = None node = new_root # type: ignore + if node is Undefined: + continue + if node is REMOVE or node is None: continue @@ -124,6 +128,9 @@ def visit(root, visitor, key_map=None): else: result = enter(node, key, parent, path, ancestors) + if result is Undefined: + continue + if result is BREAK: break diff --git a/graphql/language/visitor_meta.py b/graphql/language/visitor_meta.py index 37372c48..cdd53d28 100644 --- a/graphql/language/visitor_meta.py +++ b/graphql/language/visitor_meta.py @@ -21,6 +21,7 @@ ast.FloatValue: (), ast.StringValue: (), ast.BooleanValue: (), + ast.NullValue: (), ast.EnumValue: (), ast.ListValue: ("values",), ast.ObjectValue: ("fields",), diff --git a/graphql/type/definition.py b/graphql/type/definition.py index 8e5d12f4..551cdc02 100644 --- a/graphql/type/definition.py +++ b/graphql/type/definition.py @@ -314,7 +314,7 @@ class GraphQLArgument(object): def __init__( self, type_, # type: Union[GraphQLInputObjectType, GraphQLNonNull, GraphQLList, GraphQLScalarType] - default_value=None, # type: Optional[Any] + default_value=Undefined, # type: Optional[Any] description=None, # type: Optional[Any] out_name=None, # type: Optional[str] ): @@ -666,7 +666,7 @@ class GraphQLInputObjectField(object): def __init__( self, type_, # type: Union[GraphQLInputObjectType, GraphQLScalarType] - default_value=None, # type: Optional[Any] + default_value=Undefined, # type: Optional[Any] description=None, # type: Optional[Any] out_name=None, # type: str ): diff --git a/graphql/type/introspection.py b/graphql/type/introspection.py index e392b82d..c4a30fed 100644 --- a/graphql/type/introspection.py +++ b/graphql/type/introspection.py @@ -2,6 +2,7 @@ from ..language.printer import print_ast from ..utils.ast_from_value import ast_from_value +from ..utils.undefined import Undefined from .definition import ( GraphQLArgument, GraphQLEnumType, @@ -519,6 +520,15 @@ def input_fields(type, info): ), ) + +def _resolve_default_value(input_value, *_): + if input_value.default_value is Undefined: + return Undefined + if input_value.default_value is None: + return None + return print_ast(ast_from_value(input_value.default_value, input_value)) + + __InputValue = GraphQLObjectType( "__InputValue", description="Arguments provided to Fields or Directives and the input fields of an " @@ -533,9 +543,7 @@ def input_fields(type, info): "defaultValue", GraphQLField( type_=GraphQLString, - resolver=lambda input_val, *_: None - if input_val.default_value is None - else print_ast(ast_from_value(input_val.default_value, input_val)), + resolver=_resolve_default_value, ), ), ] diff --git a/graphql/type/tests/test_introspection.py b/graphql/type/tests/test_introspection.py index 8085189b..d81a006a 100644 --- a/graphql/type/tests/test_introspection.py +++ b/graphql/type/tests/test_introspection.py @@ -854,7 +854,6 @@ def test_introspects_on_input_object(): "name": None, "ofType": {"kind": "SCALAR", "name": "String", "ofType": None}, }, - "defaultValue": None, }, ], } in result.data["__schema"]["types"] diff --git a/graphql/utils/ast_from_value.py b/graphql/utils/ast_from_value.py index 4e3b4d94..9251aa7d 100644 --- a/graphql/utils/ast_from_value.py +++ b/graphql/utils/ast_from_value.py @@ -12,14 +12,18 @@ ) from ..type.scalars import GraphQLFloat from .assert_valid_name import COMPILED_NAME_PATTERN +from ..utils.undefined import Undefined def ast_from_value(value, type=None): if isinstance(type, GraphQLNonNull): return ast_from_value(value, type.of_type) + if value is Undefined: + return ast.UndefinedValue() + if value is None: - return None + return ast.NullValue() if isinstance(value, list): item_type = type.of_type if isinstance(type, GraphQLList) else None diff --git a/graphql/utils/build_client_schema.py b/graphql/utils/build_client_schema.py index 152abaaf..389e1e7a 100644 --- a/graphql/utils/build_client_schema.py +++ b/graphql/utils/build_client_schema.py @@ -35,6 +35,7 @@ __TypeKind, ) from .value_from_ast import value_from_ast +from ..utils.undefined import Undefined def _false(*_): @@ -237,10 +238,7 @@ def build_field_def_map(type_introspection): ) def build_default_value(f): - default_value = f.get("defaultValue") - if default_value is None: - return None - + default_value = f.get("defaultValue", Undefined) return value_from_ast(parse_value(default_value), get_input_type(f["type"])) def build_input_value_def_map(input_value_introspection, argument_type): diff --git a/graphql/utils/is_valid_literal_value.py b/graphql/utils/is_valid_literal_value.py index 1b65d871..3ac2be10 100644 --- a/graphql/utils/is_valid_literal_value.py +++ b/graphql/utils/is_valid_literal_value.py @@ -20,15 +20,15 @@ def is_valid_literal_value(type, value_ast): # type: (Union[GraphQLInputObjectType, GraphQLScalarType, GraphQLNonNull, GraphQLList], Any) -> List if isinstance(type, GraphQLNonNull): of_type = type.of_type - if not value_ast: - return [u'Expected "{}", found null.'.format(type)] + if not value_ast or isinstance(value_ast, ast.NullValue): + return [u'Expected type "{}", found null.'.format(type)] return is_valid_literal_value(of_type, value_ast) # type: ignore if not value_ast: return _empty_list - if isinstance(value_ast, ast.Variable): + if isinstance(value_ast, (ast.Variable, ast.NullValue)): return _empty_list if isinstance(type, GraphQLList): diff --git a/graphql/utils/schema_printer.py b/graphql/utils/schema_printer.py index fd19ae42..61d6efbd 100644 --- a/graphql/utils/schema_printer.py +++ b/graphql/utils/schema_printer.py @@ -9,6 +9,7 @@ ) from ..type.directives import DEFAULT_DEPRECATION_REASON from .ast_from_value import ast_from_value +from ..language.ast import UndefinedValue # Necessary for static type checking @@ -201,10 +202,11 @@ def _print_args(field_or_directives): def _print_input_value(name, arg): # type: (str, GraphQLArgument) -> str - if arg.default_value is not None: - default_value = " = " + print_ast(ast_from_value(arg.default_value, arg.type)) - else: + _ast = ast_from_value(arg.default_value, arg.type) + if isinstance(_ast, UndefinedValue): default_value = "" + else: + default_value = " = " + print_ast(_ast) return "{}: {}{}".format(name, arg.type, default_value) diff --git a/graphql/utils/tests/test_ast_from_value.py b/graphql/utils/tests/test_ast_from_value.py index b25f1ed1..f82c2f57 100644 --- a/graphql/utils/tests/test_ast_from_value.py +++ b/graphql/utils/tests/test_ast_from_value.py @@ -41,6 +41,10 @@ def test_it_converts_string_values_to_asts(): assert ast_from_value("123") == ast.StringValue("123") +def test_it_converts_null_values_to_asts(): + assert ast_from_value(None) == ast.NullValue() + + my_enum = GraphQLEnumType( "MyEnum", {"HELLO": GraphQLEnumValue(1), "GOODBYE": GraphQLEnumValue(2)} ) diff --git a/graphql/utils/tests/test_ast_to_code.py b/graphql/utils/tests/test_ast_to_code.py index 29d62d46..2b39eb28 100644 --- a/graphql/utils/tests/test_ast_to_code.py +++ b/graphql/utils/tests/test_ast_to_code.py @@ -4,6 +4,7 @@ from graphql.utils.ast_to_code import ast_to_code from ...language.tests import fixtures +from ...utils.undefined import Undefined def test_ast_to_code_using_kitchen_sink(): @@ -14,5 +15,5 @@ def test_ast_to_code_using_kitchen_sink(): def loc(start, end): return Loc(start, end, source) - parsed_code_ast = eval(code_ast, {}, {"ast": ast, "loc": loc}) + parsed_code_ast = eval(code_ast, {}, {"ast": ast, "loc": loc, "Undefined": Undefined}) assert doc == parsed_code_ast diff --git a/graphql/utils/tests/test_build_client_schema.py b/graphql/utils/tests/test_build_client_schema.py index a6f73078..3fc8cbd2 100644 --- a/graphql/utils/tests/test_build_client_schema.py +++ b/graphql/utils/tests/test_build_client_schema.py @@ -73,6 +73,11 @@ def test_builds_a_simple_schema_with_both_operation_types(): GraphQLString, description="Set the string field", args={"value": GraphQLArgument(GraphQLString)}, + ), + "setStringDefault": GraphQLField( + GraphQLString, + description="Set the string field", + args={"default_value": GraphQLArgument(GraphQLString, default_value=None)}, ) }, ) @@ -459,6 +464,15 @@ def test_builds_a_schema_with_field_arguments_with_default_values(): }, ), ), + ( + "defaultNullInt", + GraphQLField( + GraphQLString, + args={ + "intArg": GraphQLArgument(GraphQLInt, default_value=None) + }, + ), + ), ( "defaultList", GraphQLField( diff --git a/graphql/utils/tests/test_schema_printer.py b/graphql/utils/tests/test_schema_printer.py index 9b7b001f..4f5b2ae1 100644 --- a/graphql/utils/tests/test_schema_printer.py +++ b/graphql/utils/tests/test_schema_printer.py @@ -184,6 +184,22 @@ def test_prints_string_field_with_int_arg_with_default(): ) +def test_prints_string_field_with_int_arg_with_default_null(): + output = print_single_field_schema(GraphQLField( + type=GraphQLString, + args={"argOne": GraphQLArgument(GraphQLInt, default_value=None)} + )) + assert output == """ +schema { + query: Root +} + +type Root { + singleField(argOne: Int = null): String +} +""" + + def test_prints_string_field_with_non_null_int_arg(): output = print_single_field_schema( GraphQLField( @@ -496,6 +512,72 @@ def test_prints_input_type(): ) +def test_prints_input_type_with_default(): + InputType = GraphQLInputObjectType( + name="InputType", + fields={ + "int": GraphQLInputObjectField(GraphQLInt, default_value=2) + } + ) + + Root = GraphQLObjectType( + name="Root", + fields={ + "str": GraphQLField(GraphQLString, args={"argOne": GraphQLArgument(InputType)}) + } + ) + + Schema = GraphQLSchema(Root) + output = print_for_test(Schema) + + assert output == """ +schema { + query: Root +} + +input InputType { + int: Int = 2 +} + +type Root { + str(argOne: InputType): String +} +""" + + +def test_prints_input_type_with_default_null(): + InputType = GraphQLInputObjectType( + name="InputType", + fields={ + "int": GraphQLInputObjectField(GraphQLInt, default_value=None) + } + ) + + Root = GraphQLObjectType( + name="Root", + fields={ + "str": GraphQLField(GraphQLString, args={"argOne": GraphQLArgument(InputType)}) + } + ) + + Schema = GraphQLSchema(Root) + output = print_for_test(Schema) + + assert output == """ +schema { + query: Root +} + +input InputType { + int: Int = null +} + +type Root { + str(argOne: InputType): String +} +""" + + def test_prints_custom_scalar(): OddType = GraphQLScalarType( name="Odd", serialize=lambda v: v if v % 2 == 1 else None diff --git a/graphql/utils/undefined.py b/graphql/utils/undefined.py index db8f80a9..213165f4 100644 --- a/graphql/utils/undefined.py +++ b/graphql/utils/undefined.py @@ -1,6 +1,9 @@ class _Undefined(object): """A representation of an Undefined value distinct from a None value""" + def __eq__(self, other): + return isinstance(other, _Undefined) + def __bool__(self): # type: () -> bool return False diff --git a/graphql/utils/value_from_ast.py b/graphql/utils/value_from_ast.py index 07bb4d00..9dfd467b 100644 --- a/graphql/utils/value_from_ast.py +++ b/graphql/utils/value_from_ast.py @@ -1,5 +1,6 @@ from ..language import ast from ..pyutils.ordereddict import OrderedDict +from ..utils.undefined import Undefined from ..type import ( GraphQLEnumType, GraphQLInputObjectType, @@ -24,7 +25,10 @@ def value_from_ast(value_ast, type, variables=None): # We're assuming that this query has been validated and the value used here is of the correct type. return value_from_ast(value_ast, type.of_type, variables) - if value_ast is None: + if value_ast is Undefined: + return value_ast + + if isinstance(value_ast, ast.NullValue): return None if isinstance(value_ast, ast.Variable): @@ -60,7 +64,7 @@ def value_from_ast(value_ast, type, variables=None): obj_items = [] for field_name, field in fields.items(): if field_name not in field_asts: - if field.default_value is not None: + if field.default_value is not Undefined: # We use out_name as the output name for the # dict if exists obj_items.append( diff --git a/graphql/validation/rules/unique_variable_names.py b/graphql/validation/rules/unique_variable_names.py index 6d7bfcbf..6b398c2b 100644 --- a/graphql/validation/rules/unique_variable_names.py +++ b/graphql/validation/rules/unique_variable_names.py @@ -1,3 +1,4 @@ +from six import text_type from ...error import GraphQLError from .base import ValidationRule @@ -41,4 +42,5 @@ def enter_VariableDefinition(self, node, key, parent, path, ancestors): @staticmethod def duplicate_variable_message(operation_name): - return 'There can be only one variable named "{}".'.format(operation_name) + s = 'There can be only one variable named "{}".'.format(operation_name) + return text_type(s) diff --git a/graphql/validation/rules/variables_are_input_types.py b/graphql/validation/rules/variables_are_input_types.py index 49397ad9..b27eeb31 100644 --- a/graphql/validation/rules/variables_are_input_types.py +++ b/graphql/validation/rules/variables_are_input_types.py @@ -1,3 +1,4 @@ +from six import text_type from ...error import GraphQLError from ...language.printer import print_ast from ...type.definition import is_input_type @@ -21,6 +22,7 @@ def enter_VariableDefinition(self, node, key, parent, path, ancestors): @staticmethod def non_input_type_on_variable_message(variable_name, type_name): - return 'Variable "${}" cannot be non-input type "{}".'.format( + s = 'Variable "${}" cannot be non-input type "{}".'.format( variable_name, type_name ) + return text_type(s) diff --git a/graphql/validation/rules/variables_in_allowed_position.py b/graphql/validation/rules/variables_in_allowed_position.py index 8a0209d1..c61d8d29 100644 --- a/graphql/validation/rules/variables_in_allowed_position.py +++ b/graphql/validation/rules/variables_in_allowed_position.py @@ -1,3 +1,4 @@ +from six import text_type from ...error import GraphQLError from ...type.definition import GraphQLNonNull from ...utils.type_comparators import is_type_sub_type_of @@ -84,6 +85,7 @@ def effective_type(var_type, var_def): @staticmethod def bad_var_pos_message(var_name, var_type, expected_type): - return 'Variable "{}" of type "{}" used in position expecting type "{}".'.format( + s = 'Variable "{}" of type "{}" used in position expecting type "{}".'.format( var_name, var_type, expected_type ) + return text_type(s) diff --git a/graphql/validation/tests/test_arguments_of_correct_type.py b/graphql/validation/tests/test_arguments_of_correct_type.py index db165669..84d621c8 100644 --- a/graphql/validation/tests/test_arguments_of_correct_type.py +++ b/graphql/validation/tests/test_arguments_of_correct_type.py @@ -114,6 +114,15 @@ def test_good_enum_value(self): """, ) + def test_null_nullable_int_value(self): + expect_passes_rule(ArgumentsOfCorrectType, """ + { + complicatedArgs { + intArgField(intArg: null) + } + } + """) + # noinspection PyMethodMayBeStatic class TestInvalidStringValues(object): @@ -237,6 +246,17 @@ def test_float_into_int(self): [bad_value("intArg", "Int", "3.333", 4, 37)], ) + def test_null_into_non_null_int(self): + expect_fails_rule(ArgumentsOfCorrectType, """ + { + complicatedArgs { + nonNullIntArgField(nonNullIntArg: null) + } + } + """, [ + bad_value("nonNullIntArg", "Int!", "null", 4, 51) + ]) + # noinspection PyMethodMayBeStatic class TestInvalidFloatValues(object): @@ -784,7 +804,7 @@ def test_partial_object_missing_required(self): "{intField: 4}", 4, 45, - ['In field "requiredField": Expected "Boolean!", found null.'], + ['In field "requiredField": Expected type "Boolean!", found null.'], ) ], ) diff --git a/graphql/validation/tests/test_default_values_of_correct_type.py b/graphql/validation/tests/test_default_values_of_correct_type.py index 9ea0b439..a8acbdcc 100644 --- a/graphql/validation/tests/test_default_values_of_correct_type.py +++ b/graphql/validation/tests/test_default_values_of_correct_type.py @@ -54,7 +54,9 @@ def test_variables_with_valid_default_values(): query WithDefaultValues( $a: Int = 1, $b: String = "ok", - $c: ComplexInput = { requiredField: true, intField: 3 } + $c: ComplexInput = { requiredField: true, intField: 3 }, + $d: ComplexInputl = null, + $d: [Int] = null ) { dog { name } } @@ -84,7 +86,8 @@ def test_variables_with_invalid_default_values(): query InvalidDefaultValues( $a: Int = "one", $b: String = 4, - $c: ComplexInput = "notverycomplex" + $c: ComplexInput = "notverycomplex", + $d: Int! = null ) { dog { name } } @@ -100,6 +103,8 @@ def test_variables_with_invalid_default_values(): 28, ['Expected "ComplexInput", found not an object.'], ), + bad_value("d", "Int!", "null", 6, 20), + default_for_non_null_arg("d", "Int!", "Int", 6, 20) ], ) @@ -119,7 +124,7 @@ def test_variables_missing_required_field(): "{intField: 3}", 2, 51, - ['In field "requiredField": Expected "Boolean!", found null.'], + ['In field "requiredField": Expected type "Boolean!", found null.'], ) ], ) diff --git a/graphql/validation/tests/test_variables_in_allowed_position.py b/graphql/validation/tests/test_variables_in_allowed_position.py index ffcb0ffb..2059f84c 100644 --- a/graphql/validation/tests/test_variables_in_allowed_position.py +++ b/graphql/validation/tests/test_variables_in_allowed_position.py @@ -219,7 +219,7 @@ def test_int_non_null_int(): "message": VariablesInAllowedPosition.bad_var_pos_message( "intArg", "Int", "Int!" ), - "locations": [SourceLocation(4, 45), SourceLocation(2, 19)], + "locations": [SourceLocation(2, 19), SourceLocation(4, 45)], } ], ) @@ -318,7 +318,7 @@ def test_string_string_fail(): ) -def test_boolean_non_null_boolean_in_directive(): +def test_boolean_null_boolean_in_directive(): expect_fails_rule( VariablesInAllowedPosition, """ diff --git a/graphql/validation/tests/utils.py b/graphql/validation/tests/utils.py index 7bec2763..3899036d 100644 --- a/graphql/validation/tests/utils.py +++ b/graphql/validation/tests/utils.py @@ -1,3 +1,5 @@ +from copy import deepcopy +from six import text_type from graphql.error import format_error from graphql.language.parser import parse from graphql.type import ( @@ -278,6 +280,12 @@ def sort_lists(value): return value +def format_message(error): + error = deepcopy(error) + error["message"] = text_type(error["message"]) + return error + + def expect_invalid(schema, rules, query, expected_errors, sort_list=True): errors = validate(schema, parse(query), rules) assert errors, "Should not validate" @@ -286,13 +294,17 @@ def expect_invalid(schema, rules, query, expected_errors, sort_list=True): {"line": loc.line, "column": loc.column} for loc in error["locations"] ] + errors = list(map(format_error, errors)) + msg = ("\nexpected errors: %s" + "\n got errors: %s" % (expected_errors, errors)) if sort_list: - assert sort_lists(list(map(format_error, errors))) == sort_lists( - expected_errors - ) + sorted_errors = sort_lists(list(map(format_error, errors))) + expected_errors = map(format_message, expected_errors) + sorted_expected = sort_lists(list(map(format_error, expected_errors))) + assert sorted_errors == sorted_expected, msg else: - assert list(map(format_error, errors)) == expected_errors + assert errors == expected_errors, msg def expect_passes_rule(rule, query):