diff --git a/internal/common/values.go b/internal/common/values.go index 2589664cd6..2fda883bae 100644 --- a/internal/common/values.go +++ b/internal/common/values.go @@ -9,7 +9,7 @@ import ( ) type InputValue struct { - Name string + Name lexer.Ident Type Type Default *ValueWithLoc Desc string @@ -19,7 +19,7 @@ type InputValueList []*InputValue func (l InputValueList) Get(name string) *InputValue { for _, v := range l { - if v.Name == name { + if v.Name.Name == name { return v } } @@ -34,7 +34,7 @@ type ValueWithLoc struct { func ParseInputValue(l *lexer.Lexer) *InputValue { p := &InputValue{} p.Desc = l.DescComment() - p.Name = l.ConsumeIdent() + p.Name = l.ConsumeIdentWithLoc() l.ConsumeToken(':') p.Type = ParseType(l) if l.Peek() == '=' { @@ -46,26 +46,34 @@ func ParseInputValue(l *lexer.Lexer) *InputValue { } type Argument struct { - Name string + Name lexer.Ident Value ValueWithLoc } type ArgumentList []Argument -func (l ArgumentList) Get(name string) ValueWithLoc { +func (l ArgumentList) Get(name string) (ValueWithLoc, bool) { for _, arg := range l { - if arg.Name == name { - return arg.Value + if arg.Name.Name == name { + return arg.Value, true } } - return ValueWithLoc{} + return ValueWithLoc{}, false +} + +func (l ArgumentList) MustGet(name string) ValueWithLoc { + value, ok := l.Get(name) + if !ok { + panic("argument not found") + } + return value } func ParseArguments(l *lexer.Lexer) ArgumentList { var args ArgumentList l.ConsumeToken('(') for l.Peek() != ')' { - name := l.ConsumeIdent() + name := l.ConsumeIdentWithLoc() l.ConsumeToken(':') value := ParseValue(l, false) args = append(args, Argument{Name: name, Value: value}) diff --git a/internal/exec/exec.go b/internal/exec/exec.go index 02b8bcf548..efd4fceadd 100644 --- a/internal/exec/exec.go +++ b/internal/exec/exec.go @@ -543,7 +543,7 @@ func (e *objectExec) execField(ctx context.Context, r *request, f *query.Field, case "__type": p := valuePacker{valueType: reflect.TypeOf("")} - v, err := p.pack(r, r.resolveVar(f.Arguments.Get("name").Value)) + v, err := p.pack(r, r.resolveVar(f.Arguments.MustGet("name").Value)) if err != nil { r.addError(errors.Errorf("%s", err)) addResult(f.Alias, nil) @@ -621,8 +621,8 @@ func (e *fieldExec) exec(ctx context.Context, r *request, f *query.Field, resolv if e.argsPacker != nil { args := make(map[string]interface{}) for _, arg := range f.Arguments { - args[arg.Name] = arg.Value.Value - span.SetTag(OpenTracingTagArgsPrefix+arg.Name, arg.Value.Value) + args[arg.Name.Name] = arg.Value.Value + span.SetTag(OpenTracingTagArgsPrefix+arg.Name.Name, arg.Value.Value) } packed, err := e.argsPacker.pack(r, args) if err != nil { @@ -662,7 +662,7 @@ type typeAssertExec struct { func skipByDirective(r *request, d map[string]common.ArgumentList) bool { if args, ok := d["skip"]; ok { p := valuePacker{valueType: reflect.TypeOf(false)} - v, err := p.pack(r, r.resolveVar(args.Get("if").Value)) + v, err := p.pack(r, r.resolveVar(args.MustGet("if").Value)) if err != nil { r.addError(errors.Errorf("%s", err)) } @@ -673,7 +673,7 @@ func skipByDirective(r *request, d map[string]common.ArgumentList) bool { if args, ok := d["include"]; ok { p := valuePacker{valueType: reflect.TypeOf(false)} - v, err := p.pack(r, r.resolveVar(args.Get("if").Value)) + v, err := p.pack(r, r.resolveVar(args.MustGet("if").Value)) if err != nil { r.addError(errors.Errorf("%s", err)) } diff --git a/internal/exec/packer.go b/internal/exec/packer.go index fc4af2adaa..9b558a5e34 100644 --- a/internal/exec/packer.go +++ b/internal/exec/packer.go @@ -120,7 +120,7 @@ func (b *execBuilder) makeStructPacker(values common.InputValueList, typ reflect for _, v := range values { fe := &structPackerField{field: v} - sf, ok := structType.FieldByNameFunc(func(n string) bool { return strings.EqualFold(n, v.Name) }) + sf, ok := structType.FieldByNameFunc(func(n string) bool { return strings.EqualFold(n, v.Name.Name) }) if !ok { return nil, fmt.Errorf("missing argument %q", v.Name) } @@ -171,7 +171,7 @@ func (p *structPacker) pack(r *request, value interface{}) (reflect.Value, error v := reflect.New(p.structType) v.Elem().Set(p.defaultStruct) for _, f := range p.fields { - if value, ok := values[f.field.Name]; ok { + if value, ok := values[f.field.Name.Name]; ok { packed, err := f.fieldPacker.pack(r, r.resolveVar(value)) if err != nil { return reflect.Value{}, err diff --git a/internal/schema/schema.go b/internal/schema/schema.go index 41d3a9eda7..0cd73265ff 100644 --- a/internal/schema/schema.go +++ b/internal/schema/schema.go @@ -278,12 +278,12 @@ func resolveDirectives(s *Schema, directives map[string]common.ArgumentList) err return errors.Errorf("directive %q not found", name) } for _, arg := range args { - if d.Args.Get(arg.Name) == nil { - return errors.Errorf("invalid argument %q for directive %q", arg.Name, name) + if d.Args.Get(arg.Name.Name) == nil { + return errors.Errorf("invalid argument %q for directive %q", arg.Name.Name, name) } } for _, arg := range d.Args { - if args.Get(arg.Name).Value == nil { + if _, ok := args.Get(arg.Name.Name); !ok { args = append(args, common.Argument{Name: arg.Name, Value: *arg.Default}) } } diff --git a/internal/tests/testdata/export.js b/internal/tests/testdata/export.js index 7539f2bdaa..afa0afda78 100644 --- a/internal/tests/testdata/export.js +++ b/internal/tests/testdata/export.js @@ -50,6 +50,7 @@ require('./src/validation/__tests__/ArgumentsOfCorrectType-test.js'); require('./src/validation/__tests__/DefaultValuesOfCorrectType-test.js'); require('./src/validation/__tests__/FieldsOnCorrectType-test.js'); require('./src/validation/__tests__/FragmentsOnCompositeTypes-test.js'); +require('./src/validation/__tests__/KnownArgumentNames-test.js'); let output = JSON.stringify(tests, null, 2) output = output.replace('{stringListField: [\\"one\\", 2], requiredField: true}', '{requiredField: true, stringListField: [\\"one\\", 2]}'); diff --git a/internal/tests/testdata/tests.json b/internal/tests/testdata/tests.json index f437b39b42..e6f61bbebd 100644 --- a/internal/tests/testdata/tests.json +++ b/internal/tests/testdata/tests.json @@ -1240,5 +1240,129 @@ ] } ] + }, + { + "name": "Validate: Known argument names/single arg is known", + "rule": "KnownArgumentNames", + "query": "\n fragment argOnRequiredArg on Dog {\n doesKnowCommand(dogCommand: SIT)\n }\n ", + "errors": [] + }, + { + "name": "Validate: Known argument names/multiple args are known", + "rule": "KnownArgumentNames", + "query": "\n fragment multipleArgs on ComplicatedArgs {\n multipleReqs(req1: 1, req2: 2)\n }\n ", + "errors": [] + }, + { + "name": "Validate: Known argument names/ignores args of unknown fields", + "rule": "KnownArgumentNames", + "query": "\n fragment argOnUnknownField on Dog {\n unknownField(unknownArg: SIT)\n }\n ", + "errors": [] + }, + { + "name": "Validate: Known argument names/multiple args in reverse order are known", + "rule": "KnownArgumentNames", + "query": "\n fragment multipleArgsReverseOrder on ComplicatedArgs {\n multipleReqs(req2: 2, req1: 1)\n }\n ", + "errors": [] + }, + { + "name": "Validate: Known argument names/no args on optional arg", + "rule": "KnownArgumentNames", + "query": "\n fragment noArgOnOptionalArg on Dog {\n isHousetrained\n }\n ", + "errors": [] + }, + { + "name": "Validate: Known argument names/args are known deeply", + "rule": "KnownArgumentNames", + "query": "\n {\n dog {\n doesKnowCommand(dogCommand: SIT)\n }\n human {\n pet {\n ... on Dog {\n doesKnowCommand(dogCommand: SIT)\n }\n }\n }\n }\n ", + "errors": [] + }, + { + "name": "Validate: Known argument names/directive args are known", + "rule": "KnownArgumentNames", + "query": "\n {\n dog @skip(if: true)\n }\n ", + "errors": [] + }, + { + "name": "Validate: Known argument names/undirective args are invalid", + "rule": "KnownArgumentNames", + "query": "\n {\n dog @skip(unless: true)\n }\n ", + "errors": [ + { + "message": "Unknown argument \"unless\" on directive \"@skip\".", + "locations": [ + { + "line": 3, + "column": 19 + } + ] + } + ] + }, + { + "name": "Validate: Known argument names/invalid arg name", + "rule": "KnownArgumentNames", + "query": "\n fragment invalidArgName on Dog {\n doesKnowCommand(unknown: true)\n }\n ", + "errors": [ + { + "message": "Unknown argument \"unknown\" on field \"doesKnowCommand\" of type \"Dog\".", + "locations": [ + { + "line": 3, + "column": 25 + } + ] + } + ] + }, + { + "name": "Validate: Known argument names/unknown args amongst known args", + "rule": "KnownArgumentNames", + "query": "\n fragment oneGoodArgOneInvalidArg on Dog {\n doesKnowCommand(whoknows: 1, dogCommand: SIT, unknown: true)\n }\n ", + "errors": [ + { + "message": "Unknown argument \"whoknows\" on field \"doesKnowCommand\" of type \"Dog\".", + "locations": [ + { + "line": 3, + "column": 25 + } + ] + }, + { + "message": "Unknown argument \"unknown\" on field \"doesKnowCommand\" of type \"Dog\".", + "locations": [ + { + "line": 3, + "column": 55 + } + ] + } + ] + }, + { + "name": "Validate: Known argument names/unknown args deeply", + "rule": "KnownArgumentNames", + "query": "\n {\n dog {\n doesKnowCommand(unknown: true)\n }\n human {\n pet {\n ... on Dog {\n doesKnowCommand(unknown: true)\n }\n }\n }\n }\n ", + "errors": [ + { + "message": "Unknown argument \"unknown\" on field \"doesKnowCommand\" of type \"Dog\".", + "locations": [ + { + "line": 4, + "column": 27 + } + ] + }, + { + "message": "Unknown argument \"unknown\" on field \"doesKnowCommand\" of type \"Dog\".", + "locations": [ + { + "line": 9, + "column": 31 + } + ] + } + ] } ] \ No newline at end of file diff --git a/internal/validation/validation.go b/internal/validation/validation.go index 6b21cd1df1..7c4e674e71 100644 --- a/internal/validation/validation.go +++ b/internal/validation/validation.go @@ -33,11 +33,11 @@ func Validate(s *schema.Schema, q *query.Document) (errs []*errors.QueryError) { } if nn, ok := t.(*common.NonNull); ok { - addErr(&errs, v.Default.Loc, "DefaultValuesOfCorrectType", "Variable %q of type %q is required and will not use the default value. Perhaps you meant to use type %q.", "$"+v.Name, t, nn.OfType) + addErr(&errs, v.Default.Loc, "DefaultValuesOfCorrectType", "Variable %q of type %q is required and will not use the default value. Perhaps you meant to use type %q.", "$"+v.Name.Name, t, nn.OfType) } if ok, reason := validateValue(v.Default.Value, t); !ok { - addErr(&errs, v.Default.Loc, "DefaultValuesOfCorrectType", "Variable %q of type %q has invalid default value %s.\n%s", "$"+v.Name, t, stringify(v.Default.Value), reason) + addErr(&errs, v.Default.Loc, "DefaultValuesOfCorrectType", "Variable %q of type %q has invalid default value %s.\n%s", "$"+v.Name.Name, t, stringify(v.Default.Value), reason) } } } @@ -94,13 +94,14 @@ func validateSelection(s *schema.Schema, sel query.Selection, t common.Type) (er if f != nil { for _, selArg := range sel.Arguments { - arg := f.Args.Get(selArg.Name) + arg := f.Args.Get(selArg.Name.Name) if arg == nil { + addErr(&errs, selArg.Name.Loc, "KnownArgumentNames", "Unknown argument %q on field %q of type %q.", selArg.Name.Name, sel.Name, t) continue } value := selArg.Value if ok, reason := validateValue(value.Value, arg.Type); !ok { - addErr(&errs, value.Loc, "ArgumentsOfCorrectType", "Argument %q has invalid value %s.\n%s", arg.Name, stringify(value.Value), reason) + addErr(&errs, value.Loc, "ArgumentsOfCorrectType", "Argument %q has invalid value %s.\n%s", arg.Name.Name, stringify(value.Value), reason) } } } @@ -158,10 +159,14 @@ func validateDirectives(s *schema.Schema, directives map[string]common.ArgumentL if !ok { continue } - for _, arg := range d.Args { - value := args.Get(arg.Name) - if ok, reason := validateValue(value.Value, arg.Type); !ok { - addErr(&errs, value.Loc, "ArgumentsOfCorrectType", "Argument %q has invalid value %s.\n%s", arg.Name, stringify(value.Value), reason) + for _, arg := range args { + iv := d.Args.Get(arg.Name.Name) + if iv == nil { + addErr(&errs, arg.Name.Loc, "KnownArgumentNames", "Unknown argument %q on directive %q.", arg.Name.Name, "@"+name) + continue + } + if ok, reason := validateValue(arg.Value.Value, iv.Type); !ok { + addErr(&errs, arg.Value.Loc, "ArgumentsOfCorrectType", "Argument %q has invalid value %s.\n%s", arg.Name.Name, stringify(arg.Value.Value), reason) } } } @@ -223,9 +228,9 @@ func validateValue(v interface{}, t common.Type) (bool, string) { } } for _, f := range t.Values { - if _, ok := v[f.Name]; !ok { + if _, ok := v[f.Name.Name]; !ok { if _, ok := f.Type.(*common.NonNull); ok && f.Default == nil { - return false, fmt.Sprintf("In field %q: Expected %q, found null.", f.Name, f.Type) + return false, fmt.Sprintf("In field %q: Expected %q, found null.", f.Name.Name, f.Type) } } } diff --git a/introspection/introspection.go b/introspection/introspection.go index c5e3beb19f..d9c72872cd 100644 --- a/introspection/introspection.go +++ b/introspection/introspection.go @@ -229,7 +229,7 @@ func (r *Field) DeprecationReason() *string { if !ok { return nil } - reason := common.UnmarshalLiteral(args.Get("reason").Value.(*lexer.Literal)).(string) + reason := common.UnmarshalLiteral(args.MustGet("reason").Value.(*lexer.Literal)).(string) return &reason } @@ -238,7 +238,7 @@ type InputValue struct { } func (r *InputValue) Name() string { - return r.value.Name + return r.value.Name.Name } func (r *InputValue) Description() *string { @@ -289,7 +289,7 @@ func (r *EnumValue) DeprecationReason() *string { if !ok { return nil } - reason := common.UnmarshalLiteral(args.Get("reason").Value.(*lexer.Literal)).(string) + reason := common.UnmarshalLiteral(args.MustGet("reason").Value.(*lexer.Literal)).(string) return &reason }