diff --git a/aws/protocol/restjson/decoder_util.go b/aws/protocol/restjson/decoder_util.go index 3193d7c8142..ffaa0fbc3d2 100644 --- a/aws/protocol/restjson/decoder_util.go +++ b/aws/protocol/restjson/decoder_util.go @@ -2,12 +2,10 @@ package restjson import ( "encoding/json" - "fmt" "io" "strings" "github.com/awslabs/smithy-go" - smithyjson "github.com/awslabs/smithy-go/json" ) // GetErrorInfo util looks for code, __type, and message members in the @@ -15,55 +13,37 @@ import ( // returns the value of member if it is available. This function is useful to // identify the error code, msg in a REST JSON error response. func GetErrorInfo(decoder *json.Decoder) (errorType string, message string, err error) { - startToken, err := decoder.Token() - if err == io.EOF { - return "", "", nil + var errInfo struct { + Code string + Type string `json:"__type"` + Message string } - if err != nil { - return "", "", err - } - - if t, ok := startToken.(json.Delim); !ok || t.String() != "{" { - return "", "", fmt.Errorf("expected start token to be {") - } - - for decoder.More() { - var target *string - t, err := decoder.Token() - if err != nil { - return "", "", err - } - switch st := t.(string); { - case strings.EqualFold(st, "code"): - fallthrough - case strings.EqualFold(st, "__type"): - target = &errorType - case strings.EqualFold(st, "message"): - target = &message - default: - smithyjson.DiscardUnknownField(decoder) - continue - } - - v, err := decoder.Token() - if err != nil { - return errorType, message, err + err = decoder.Decode(&errInfo) + if err != nil { + if err == io.EOF { + return errorType, message, nil } - *target = v.(string) + return errorType, message, err } - endToken, err := decoder.Token() - if err != nil { - return "", "", err + // assign error type + if len(errInfo.Code) != 0 { + errorType = errInfo.Code + } else if len(errInfo.Type) != 0 { + errorType = errInfo.Type } - if t, ok := endToken.(json.Delim); !ok || t.String() != "}" { - return "", "", fmt.Errorf("expected end token to be }") + // assign error message + if len(errInfo.Message) != 0 { + message = errInfo.Message } // sanitize error - errorType = SanitizeErrorCode(errorType) + if len(errorType) != 0 { + errorType = SanitizeErrorCode(errorType) + } + return errorType, message, nil } diff --git a/aws/protocol/restjson/decoder_util_test.go b/aws/protocol/restjson/decoder_util_test.go new file mode 100644 index 00000000000..48e1555af4a --- /dev/null +++ b/aws/protocol/restjson/decoder_util_test.go @@ -0,0 +1,83 @@ +package restjson + +import ( + "bytes" + "encoding/json" + "io" + "strings" + "testing" +) + +func TestGetErrorInfo(t *testing.T) { + cases := map[string]struct { + errorResponse []byte + expectedErrorType string + expectedErrorMsg string + expectedDeserializationError string + }{ + "error with code": { + errorResponse: []byte(`{"code": "errorCode", "message": "message for errorCode"}`), + expectedErrorType: "errorCode", + expectedErrorMsg: "message for errorCode", + }, + "error with type": { + errorResponse: []byte(`{"__type": "errorCode", "message": "message for errorCode"}`), + expectedErrorType: "errorCode", + expectedErrorMsg: "message for errorCode", + }, + + "error with only message": { + errorResponse: []byte(`{"message": "message for errorCode"}`), + expectedErrorMsg: "message for errorCode", + }, + + "error with only code": { + errorResponse: []byte(`{"code": "errorCode"}`), + expectedErrorType: "errorCode", + }, + + "empty": { + errorResponse: []byte(``), + }, + + "unknownField": { + errorResponse: []byte(`{"xyz":"abc", "code": "errorCode"}`), + expectedErrorType: "errorCode", + }, + + "unexpectedEOF": { + errorResponse: []byte(`{"xyz":"abc"`), + expectedDeserializationError: io.ErrUnexpectedEOF.Error(), + }, + + "caseless compare": { + errorResponse: []byte(`{"Code": "errorCode", "Message": "errorMessage", "xyz": "abc"}`), + expectedErrorType: "errorCode", + expectedErrorMsg: "errorMessage", + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + decoder := json.NewDecoder(bytes.NewReader(c.errorResponse)) + actualType, actualMsg, err := GetErrorInfo(decoder) + if err != nil { + if len(c.expectedDeserializationError) == 0 { + t.Fatalf("expected no error, got %v", err.Error()) + } + + if e, a := c.expectedDeserializationError, err.Error(); !strings.Contains(a, e) { + t.Fatalf("expected error to be %v, got %v", e, a) + } + } + + if e, a := c.expectedErrorType, actualType; !strings.EqualFold(e, a) { + t.Fatalf("expected error type to be %v, got %v", e, a) + } + + if e, a := c.expectedErrorMsg, actualMsg; !strings.EqualFold(e, a) { + t.Fatalf("expected error message to be %v, got %v", e, a) + } + }) + } +}