From 35a9ecf4f134bb4fe454d8c4cbbf492980f356e1 Mon Sep 17 00:00:00 2001 From: Yamashou <1230124fw@gmail.com> Date: Sat, 21 Dec 2024 21:55:44 +0900 Subject: [PATCH 1/2] new Encoder --- clientv2/client.go | 281 ++++++++++++++++++++------------ clientv2/client_test.go | 180 +++++++++++++++++++- querydocument/query_document.go | 2 +- 3 files changed, 360 insertions(+), 103 deletions(-) diff --git a/clientv2/client.go b/clientv2/client.go index 5be9655..599f10a 100644 --- a/clientv2/client.go +++ b/clientv2/client.go @@ -81,11 +81,11 @@ func UnsafeChainInterceptor(interceptors ...RequestInterceptor) RequestIntercept // Client is the http client wrapper type Client struct { - Client HttpClient - BaseURL string - RequestInterceptor RequestInterceptor - CustomDo RequestInterceptorFunc - ParseDataWhenErrors bool + Client HttpClient + BaseURL string + RequestInterceptor RequestInterceptor + CustomDo RequestInterceptorFunc + ParseDataWhenErrors bool IsUnsafeRequestInterceptor bool } @@ -229,7 +229,7 @@ func (c *Client) Post(ctx context.Context, operationName, query string, respData headers = append(headers, header{key: "Content-Type", value: contentType}) } else { - requestBody, err := MarshalJSON(r) + requestBody, err := MarshalJSON(ctx, r) if err != nil { return fmt.Errorf("encode: %w", err) } @@ -462,7 +462,40 @@ func (c *Client) unmarshal(data []byte, res any) error { return err } -func MarshalJSON(v any) ([]byte, error) { +// contextKey is a type for context keys +type contextKey string + +const ( + // EnableInputJsonOmitemptyTagKey is a context key for EnableInputJsonOmitemptyTag + EnableInputJsonOmitemptyTagKey contextKey = "enable_input_json_omitempty_tag" +) + +// WithEnableInputJsonOmitemptyTag returns a new context with EnableInputJsonOmitemptyTag value +func WithEnableInputJsonOmitemptyTag(ctx context.Context, enable bool) context.Context { + return context.WithValue(ctx, EnableInputJsonOmitemptyTagKey, enable) +} + +// getEnableInputJsonOmitemptyTagFromContext retrieves the EnableInputJsonOmitemptyTag value from context +func getEnableInputJsonOmitemptyTagFromContext(ctx context.Context) bool { + enableClientJsonOmitemptyTag := true + if ctx != nil { + enable, ok := ctx.Value(EnableInputJsonOmitemptyTagKey).(bool) + if ok { + enableClientJsonOmitemptyTag = enable + } + } + return enableClientJsonOmitemptyTag +} + +// WithEnableInputJsonOmitemptyTagInterceptor creates a RequestInterceptor that sets EnableInputJsonOmitemptyTag in context +func WithEnableInputJsonOmitemptyTagInterceptor(enable bool) RequestInterceptor { + return func(ctx context.Context, req *http.Request, gqlInfo *GQLRequestInfo, res any, next RequestInterceptorFunc) error { + newCtx := WithEnableInputJsonOmitemptyTag(ctx, enable) + return next(newCtx, req, gqlInfo, res) + } +} + +func MarshalJSON(ctx context.Context, v any) ([]byte, error) { if v == nil { return []byte("null"), nil } @@ -472,7 +505,11 @@ func MarshalJSON(v any) ([]byte, error) { return []byte("null"), nil } - return encode(val) + encoder := &Encoder{ + EnableInputJsonOmitemptyTag: getEnableInputJsonOmitemptyTagFromContext(ctx), + } + + return encoder.Encode(val) } func checkImplements[I any](v reflect.Value) bool { @@ -482,48 +519,61 @@ func checkImplements[I any](v reflect.Value) bool { return t.Implements(interfaceType) || (t.Kind() == reflect.Ptr && reflect.PointerTo(t).Implements(interfaceType)) } -// encode returns an appropriate encoder function for the provided value. -func encode(v reflect.Value) ([]byte, error) { +// Encoder is a struct for encoding GraphQL requests to JSON +type Encoder struct { + EnableInputJsonOmitemptyTag bool +} + +// fieldInfo holds field information of a struct +type fieldInfo struct { + name string // field name + jsonName string // field name in JSON + omitempty bool // omitempty flag + typ reflect.Type // field type +} + +// Encode encodes any value to JSON +func (e *Encoder) Encode(v reflect.Value) ([]byte, error) { if !v.IsValid() || (v.Kind() == reflect.Ptr && v.IsNil()) { return []byte("null"), nil } if checkImplements[graphql.Marshaler](v) { - return encodeGQLMarshaler(v.Interface()) + return e.encodeGQLMarshaler(v.Interface()) } if checkImplements[json.Marshaler](v) { - return encodeJsonMarshaler(v.Interface()) + return e.encodeJsonMarshaler(v.Interface()) } if checkImplements[encoding.TextMarshaler](v) { - return encodeTextMarshaler(v.Interface()) + return e.encodeTextMarshaler(v.Interface()) } - t := v.Type() // Get the type from the value + t := v.Type() switch t.Kind() { case reflect.Ptr: - return encodePtr(v) + return e.encodePtr(v) case reflect.Struct: - return encodeStruct(v) + return e.encodeStruct(v) case reflect.Map: - return encodeMap(v) + return e.encodeMap(v) case reflect.Slice: - return encodeSlice(v) + return e.encodeSlice(v) case reflect.Array: - return encodeArray(v) + return e.encodeArray(v) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return encodeInt(v) + return e.encodeInt(v) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - return encodeUint(v) + return e.encodeUint(v) case reflect.String: - return encodeString(v) + return e.encodeString(v) case reflect.Bool: - return encodeBool(v) + return e.encodeBool(v) case reflect.Float32, reflect.Float64: - return encodeFloat(v) + return e.encodeFloat(v) case reflect.Interface: - return encodeInterface(v) + return e.encodeInterface(v) case reflect.Invalid, reflect.Complex64, reflect.Complex128, reflect.Chan, reflect.Func, reflect.UnsafePointer: panic(fmt.Sprintf("unsupported type: %s", t)) default: @@ -531,7 +581,8 @@ func encode(v reflect.Value) ([]byte, error) { } } -func encodeGQLMarshaler(v any) ([]byte, error) { +// encodeGQLMarshaler encodes a value that implements graphql.Marshaler interface +func (e *Encoder) encodeGQLMarshaler(v any) ([]byte, error) { if v == nil { return []byte("null"), nil } @@ -546,24 +597,24 @@ func encodeGQLMarshaler(v any) ([]byte, error) { return buf.Bytes(), nil } -func encodeJsonMarshaler(v any) ([]byte, error) { +// encodeJsonMarshaler encodes a value that implements json.Marshaler interface +func (e *Encoder) encodeJsonMarshaler(v any) ([]byte, error) { if val, ok := v.(json.Marshaler); ok { return val.MarshalJSON() - } else { - return nil, fmt.Errorf("failed to encode json.Marshaler: %v", v) } + return nil, fmt.Errorf("failed to encode json.Marshaler: %v", v) } -func encodeTextMarshaler(v any) ([]byte, error) { +// encodeTextMarshaler encodes a value that implements encoding.TextMarshaler interface +func (e *Encoder) encodeTextMarshaler(v any) ([]byte, error) { if _, ok := v.(encoding.TextMarshaler); ok { - // json.Marshal uses encoding.TextMarshaler internally if the value implements it. return json.Marshal(v) - } else { - return nil, fmt.Errorf("failed to encode encoding.TextMarshaler: %v", v) } + return nil, fmt.Errorf("failed to encode encoding.TextMarshaler: %v", v) } -func encodeBool(v reflect.Value) ([]byte, error) { +// encodeBool encodes a boolean value +func (e *Encoder) encodeBool(v reflect.Value) ([]byte, error) { boolValue, err := json.Marshal(v.Bool()) if err != nil { return nil, fmt.Errorf("failed to encode bool: %v", v) @@ -571,19 +622,23 @@ func encodeBool(v reflect.Value) ([]byte, error) { return boolValue, nil } -func encodeInt(v reflect.Value) ([]byte, error) { +// encodeInt encodes an integer value +func (e *Encoder) encodeInt(v reflect.Value) ([]byte, error) { return []byte(fmt.Sprintf("%d", v.Int())), nil } -func encodeUint(v reflect.Value) ([]byte, error) { +// encodeUint encodes an unsigned integer value +func (e *Encoder) encodeUint(v reflect.Value) ([]byte, error) { return []byte(fmt.Sprintf("%d", v.Uint())), nil } -func encodeFloat(v reflect.Value) ([]byte, error) { +// encodeFloat encodes a floating-point value +func (e *Encoder) encodeFloat(v reflect.Value) ([]byte, error) { return []byte(fmt.Sprintf("%f", v.Float())), nil } -func encodeString(v reflect.Value) ([]byte, error) { +// encodeString encodes a string value +func (e *Encoder) encodeString(v reflect.Value) ([]byte, error) { stringValue, err := json.Marshal(v.String()) if err != nil { return nil, fmt.Errorf("failed to encode string: %v", v) @@ -591,62 +646,57 @@ func encodeString(v reflect.Value) ([]byte, error) { return stringValue, nil } -type fieldInfo struct { - name string - jsonName string - omitempty bool - typ reflect.Type +// trimQuotes removes double quotes from the beginning and end of a string +func (e *Encoder) trimQuotes(s string) string { + if len(s) > 1 && s[0] == '"' && s[len(s)-1] == '"' { + return s[1 : len(s)-1] + } + return s } -func prepareFields(t reflect.Type) []fieldInfo { - num := t.NumField() - fields := make([]fieldInfo, 0, num) - for i := range num { - f := t.Field(i) - if f.PkgPath != "" && !f.Anonymous { // Skip unexported fields unless they are embedded - continue - } - jsonTag := f.Tag.Get("json") - if jsonTag == "-" { - continue // Skip fields explicitly marked to be ignored - } - - jsonName := f.Name - if jsonTag != "" { - parts := strings.Split(jsonTag, ",") - jsonName = parts[0] // Use the name specified in the JSON tag - } +func (e *Encoder) isSkipOmitemptyField(v reflect.Value, field fieldInfo) bool { + if !e.EnableInputJsonOmitemptyTag { + return false + } - fi := fieldInfo{ - name: f.Name, - jsonName: jsonName, - typ: f.Type, - } + if !v.IsValid() { + return true + } - if strings.Contains(jsonTag, "omitempty") { - fi.omitempty = true - } + if v.Kind() == reflect.Ptr && v.IsNil() { + return true + } - fields = append(fields, fi) + if field.omitempty && v.IsZero() { + return true } - return fields + return false } -func encodeStruct(v reflect.Value) ([]byte, error) { - fields := prepareFields(v.Type()) +// encodeStruct encodes a struct value +func (e *Encoder) encodeStruct(v reflect.Value) ([]byte, error) { + fields := e.prepareFields(v.Type()) result := make(map[string]json.RawMessage) for _, field := range fields { fieldValue := v.FieldByName(field.name) - if !fieldValue.IsValid() || (fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil()) { - continue // Skip invalid or nil pointers to avoid panics + if e.isSkipOmitemptyField(fieldValue, field) { + continue } - if field.omitempty && fieldValue.IsZero() { - continue // Skip nil fields marked with omitempty + // omitemptyが無効な場合、nilスライスは空のスライス[]として扱う + if !e.EnableInputJsonOmitemptyTag && fieldValue.Kind() == reflect.Slice && fieldValue.IsNil() { + result[field.jsonName] = []byte("[]") + continue } - encodedValue, err := encode(fieldValue) + // omitemptyが無効な場合、nilポインタはnullとして扱い、出力に含める + if !e.EnableInputJsonOmitemptyTag && fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { + result[field.jsonName] = []byte("null") + continue + } + + encodedValue, err := e.Encode(fieldValue) if err != nil { return nil, err } @@ -655,26 +705,19 @@ func encodeStruct(v reflect.Value) ([]byte, error) { return json.Marshal(result) } -func trimQuotes(s string) string { - if len(s) > 1 && s[0] == '"' && s[len(s)-1] == '"' { - return s[1 : len(s)-1] - } - - return s -} - -func encodeMap(v reflect.Value) ([]byte, error) { +// encodeMap encodes a map value +func (e *Encoder) encodeMap(v reflect.Value) ([]byte, error) { result := make(map[string]json.RawMessage) for _, key := range v.MapKeys() { - encodedKey, err := encode(key) + encodedKey, err := e.Encode(key) if err != nil { return nil, err } keyStr := string(encodedKey) - keyStr = trimQuotes(keyStr) + keyStr = e.trimQuotes(keyStr) value := v.MapIndex(key) - encodedValue, err := encode(value) + encodedValue, err := e.Encode(value) if err != nil { return nil, err } @@ -683,10 +726,11 @@ func encodeMap(v reflect.Value) ([]byte, error) { return json.Marshal(result) } -func encodeSlice(v reflect.Value) ([]byte, error) { +// encodeSlice encodes a slice value +func (e *Encoder) encodeSlice(v reflect.Value) ([]byte, error) { result := make([]json.RawMessage, v.Len()) for i := range v.Len() { - encodedValue, err := encode(v.Index(i)) + encodedValue, err := e.Encode(v.Index(i)) if err != nil { return nil, err } @@ -695,10 +739,11 @@ func encodeSlice(v reflect.Value) ([]byte, error) { return json.Marshal(result) } -func encodeArray(v reflect.Value) ([]byte, error) { +// encodeArray encodes an array value +func (e *Encoder) encodeArray(v reflect.Value) ([]byte, error) { result := make([]json.RawMessage, v.Len()) for i := range v.Len() { - encodedValue, err := encode(v.Index(i)) + encodedValue, err := e.Encode(v.Index(i)) if err != nil { return nil, err } @@ -707,18 +752,54 @@ func encodeArray(v reflect.Value) ([]byte, error) { return json.Marshal(result) } -func encodePtr(v reflect.Value) ([]byte, error) { +// encodePtr encodes a pointer value +func (e *Encoder) encodePtr(v reflect.Value) ([]byte, error) { if v.IsNil() { return []byte("null"), nil } - - return encode(v.Elem()) + return e.Encode(v.Elem()) } -func encodeInterface(v reflect.Value) ([]byte, error) { +// encodeInterface encodes an interface value +func (e *Encoder) encodeInterface(v reflect.Value) ([]byte, error) { if v.IsNil() { return []byte("null"), nil } - actualValue := v.Elem() - return encode(actualValue) + return e.Encode(v.Elem()) +} + +// prepareFields collects field information from a struct type +func (e *Encoder) prepareFields(t reflect.Type) []fieldInfo { + num := t.NumField() + fields := make([]fieldInfo, 0, num) + for i := range num { + f := t.Field(i) + if f.PkgPath != "" && !f.Anonymous { + continue + } + jsonTag := f.Tag.Get("json") + if jsonTag == "-" { + continue + } + + jsonName := f.Name + if jsonTag != "" { + parts := strings.Split(jsonTag, ",") + jsonName = parts[0] + } + + fi := fieldInfo{ + name: f.Name, + jsonName: jsonName, + typ: f.Type, + } + + if strings.Contains(jsonTag, "omitempty") { + fi.omitempty = true + } + + fields = append(fields, fi) + } + + return fields } diff --git a/clientv2/client_test.go b/clientv2/client_test.go index dbd2971..dbfca09 100644 --- a/clientv2/client_test.go +++ b/clientv2/client_test.go @@ -611,7 +611,7 @@ func TestMarshalJSONValueType(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := MarshalJSON(tt.args.v) + got, err := MarshalJSON(context.Background(), tt.args.v) if (err != nil) != tt.wantErr { t.Errorf("MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) @@ -809,7 +809,7 @@ func TestMarshalJSON(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := MarshalJSON(tt.args.v) + got, err := MarshalJSON(context.Background(), tt.args.v) if (err != nil) != tt.wantErr { t.Errorf("MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) @@ -958,3 +958,179 @@ func TestUnsafeChainInterceptor(t *testing.T) { } }) } + +func TestEncoder_encodeStruct(t *testing.T) { + type Address struct { + City string `json:"city"` + Country string `json:"country,omitempty"` + Zip *string `json:"zip,omitempty"` + } + + type Person struct { + Name string `json:"name"` + Age int64 `json:"age,omitempty"` + Email *string `json:"email,omitempty"` + Address Address `json:"address"` + Tags []string `json:"tags,omitempty"` + Nickname string `json:"nickname,omitempty"` + Empty string `json:"-"` + unexposed string + } + + zip := "123-4567" + email := "test@example.com" + + tests := []struct { + name string + input Person + enableOmitemptyTag bool + want map[string]interface{} + wantErr bool + }{ + { + name: "all fields filled", + input: Person{ + Name: "John", + Age: 30, + Email: &email, + Address: Address{City: "Tokyo", Country: "Japan", Zip: &zip}, + Tags: []string{"tag1", "tag2"}, + Nickname: "Johnny", + }, + enableOmitemptyTag: true, + want: map[string]any{ + "name": "John", + "age": int64(30), + "email": "test@example.com", + "address": map[string]any{"city": "Tokyo", "country": "Japan", "zip": "123-4567"}, + "tags": []any{"tag1", "tag2"}, + "nickname": "Johnny", + }, + }, + { + name: "omitempty fields with zero values", + input: Person{ + Name: "John", + Address: Address{City: "Tokyo"}, + }, + enableOmitemptyTag: true, + want: map[string]any{ + "name": "John", + "address": map[string]any{"city": "Tokyo"}, + }, + }, + { + name: "omitempty disabled", + input: Person{ + Name: "John", + Address: Address{City: "Tokyo"}, + }, + enableOmitemptyTag: false, + want: map[string]any{ + "name": "John", + "age": int64(0), + "email": nil, + "address": map[string]any{"city": "Tokyo", "country": "", "zip": nil}, + "tags": []any{}, + "nickname": "", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + encoder := &Encoder{ + EnableInputJsonOmitemptyTag: tt.enableOmitemptyTag, + } + + got, err := encoder.encodeStruct(reflect.ValueOf(tt.input)) + if (err != nil) != tt.wantErr { + t.Errorf("encodeStruct() error = %v, wantErr %v", err, tt.wantErr) + return + } + + // 期待値をJSONに変換 + want, err := json.Marshal(tt.want) + if err != nil { + t.Errorf("failed to marshal want: %v", err) + return + } + + // JSONの文字列として比較 + if string(got) != string(want) { + t.Errorf("encodeStruct() = %s, want %s", got, want) + } + }) + } +} + +func TestEncoder_isSkipOmitemptyField(t *testing.T) { + type testStruct struct { + Required string `json:"required"` + Optional string `json:"optional,omitempty"` + Ptr *string `json:"ptr,omitempty"` + } + + str := "test" + tests := []struct { + name string + value reflect.Value + field fieldInfo + enableOmitemptyTag bool + want bool + }{ + { + name: "non-empty value with omitempty", + value: reflect.ValueOf("test"), + field: fieldInfo{omitempty: true}, + enableOmitemptyTag: true, + want: false, + }, + { + name: "empty value with omitempty", + value: reflect.ValueOf(""), + field: fieldInfo{omitempty: true}, + enableOmitemptyTag: true, + want: true, + }, + { + name: "nil pointer with omitempty", + value: reflect.ValueOf((*string)(nil)), + field: fieldInfo{omitempty: true}, + enableOmitemptyTag: true, + want: true, + }, + { + name: "non-nil pointer with omitempty", + value: reflect.ValueOf(&str), + field: fieldInfo{omitempty: true}, + enableOmitemptyTag: true, + want: false, + }, + { + name: "empty value without omitempty", + value: reflect.ValueOf(""), + field: fieldInfo{omitempty: false}, + enableOmitemptyTag: true, + want: false, + }, + { + name: "omitempty tag disabled", + value: reflect.ValueOf(""), + field: fieldInfo{omitempty: true}, + enableOmitemptyTag: false, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + encoder := &Encoder{ + EnableInputJsonOmitemptyTag: tt.enableOmitemptyTag, + } + if got := encoder.isSkipOmitemptyField(tt.value, tt.field); got != tt.want { + t.Errorf("isSkipOmitemptyField() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/querydocument/query_document.go b/querydocument/query_document.go index b1b8bdb..80d90a8 100644 --- a/querydocument/query_document.go +++ b/querydocument/query_document.go @@ -92,7 +92,7 @@ func CollectTypesFromQueryDocuments(schema *ast.Schema, queryDocuments []*ast.Qu return usedTypes } -func collectInputObjectFieldsWithCycle(def *ast.Definition, schema *ast.Schema, usedTypes map[string]bool, processedTypes map[string]bool) { +func collectInputObjectFieldsWithCycle(def *ast.Definition, schema *ast.Schema, usedTypes, processedTypes map[string]bool) { if processedTypes[def.Name] { return // この型は既に完全に処理済み } From 295f9af260e53732b077cf14668cf1942d639f97 Mon Sep 17 00:00:00 2001 From: Yamashou <1230124fw@gmail.com> Date: Sat, 21 Dec 2024 22:00:25 +0900 Subject: [PATCH 2/2] version 0.29.0 --- main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main.go b/main.go index 3c720e4..20b7547 100644 --- a/main.go +++ b/main.go @@ -9,7 +9,7 @@ import ( "github.com/urfave/cli/v2" ) -const version = "0.28.2" +const version = "0.29.0" var versionCmd = &cli.Command{ Name: "version",