diff --git a/api.go b/api.go index 19ac3bf2..a5190e0a 100644 --- a/api.go +++ b/api.go @@ -231,6 +231,10 @@ type API interface { Middlewares() Middlewares } +type InputParamConverter interface { + HumaInputParamConvert([]byte) (any, error) +} + // Format represents a request / response format. It is used to marshal and // unmarshal data. type Format struct { diff --git a/huma.go b/huma.go index 8d0a0a98..ab6f3f2f 100644 --- a/huma.go +++ b/huma.go @@ -868,253 +868,9 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I) if value != "" { var pv any - switch p.Type.Kind() { - case reflect.String: - f.SetString(value) - pv = value - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - v, err := strconv.ParseInt(value, 10, 64) - if err != nil { - res.Add(pb, value, "invalid integer") - return - } - f.SetInt(v) - pv = v - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - v, err := strconv.ParseUint(value, 10, 64) - if err != nil { - res.Add(pb, value, "invalid integer") - return - } - f.SetUint(v) - pv = v - case reflect.Float32, reflect.Float64: - v, err := strconv.ParseFloat(value, 64) - if err != nil { - res.Add(pb, value, "invalid float") - return - } - f.SetFloat(v) - pv = v - case reflect.Bool: - v, err := strconv.ParseBool(value) - if err != nil { - res.Add(pb, value, "invalid boolean") - return - } - f.SetBool(v) - pv = v - default: - if f.Type().Kind() == reflect.Slice { - switch f.Type().Elem().Kind() { - - case reflect.String: - values := strings.Split(value, ",") - f.Set(reflect.ValueOf(values)) - pv = values - - case reflect.Int: - values := strings.Split(value, ",") - vs, err := parseArrElement(values, func(s string) (int, error) { - val, err := strconv.ParseInt(s, 10, strconv.IntSize) - if err != nil { - return 0, err - } - return int(val), nil - }) - if err != nil { - res.Add(pb, value, "invalid integer") - return - } - f.Set(reflect.ValueOf(vs)) - pv = vs - - case reflect.Int8: - values := strings.Split(value, ",") - vs, err := parseArrElement(values, func(s string) (int8, error) { - val, err := strconv.ParseInt(s, 10, 8) - if err != nil { - return 0, err - } - return int8(val), nil - }) - if err != nil { - res.Add(pb, value, "invalid integer") - return - } - f.Set(reflect.ValueOf(vs)) - pv = vs - - case reflect.Int16: - values := strings.Split(value, ",") - vs, err := parseArrElement(values, func(s string) (int16, error) { - val, err := strconv.ParseInt(s, 10, 16) - if err != nil { - return 0, err - } - return int16(val), nil - }) - if err != nil { - res.Add(pb, value, "invalid integer") - return - } - f.Set(reflect.ValueOf(vs)) - pv = vs - - case reflect.Int32: - values := strings.Split(value, ",") - vs, err := parseArrElement(values, func(s string) (int32, error) { - val, err := strconv.ParseInt(s, 10, 32) - if err != nil { - return 0, err - } - return int32(val), nil - }) - if err != nil { - res.Add(pb, value, "invalid integer") - return - } - f.Set(reflect.ValueOf(vs)) - pv = vs - - case reflect.Int64: - values := strings.Split(value, ",") - vs, err := parseArrElement(values, func(s string) (int64, error) { - val, err := strconv.ParseInt(s, 10, 64) - if err != nil { - return 0, err - } - return int64(val), nil - }) - if err != nil { - res.Add(pb, value, "invalid integer") - return - } - f.Set(reflect.ValueOf(vs)) - pv = vs - - case reflect.Uint: - values := strings.Split(value, ",") - vs, err := parseArrElement(values, func(s string) (uint, error) { - val, err := strconv.ParseUint(s, 10, strconv.IntSize) - if err != nil { - return 0, err - } - return uint(val), nil - }) - if err != nil { - res.Add(pb, value, "invalid integer") - return - } - f.Set(reflect.ValueOf(vs)) - pv = vs - - case reflect.Uint16: - values := strings.Split(value, ",") - vs, err := parseArrElement(values, func(s string) (uint16, error) { - val, err := strconv.ParseUint(s, 10, 16) - if err != nil { - return 0, err - } - return uint16(val), nil - }) - if err != nil { - res.Add(pb, value, "invalid integer") - return - } - f.Set(reflect.ValueOf(vs)) - pv = vs - - case reflect.Uint32: - values := strings.Split(value, ",") - vs, err := parseArrElement(values, func(s string) (uint32, error) { - val, err := strconv.ParseUint(s, 10, 32) - if err != nil { - return 0, err - } - return uint32(val), nil - }) - if err != nil { - res.Add(pb, value, "invalid integer") - return - } - f.Set(reflect.ValueOf(vs)) - pv = vs - - case reflect.Uint64: - values := strings.Split(value, ",") - vs, err := parseArrElement(values, func(s string) (uint64, error) { - val, err := strconv.ParseUint(s, 10, 64) - if err != nil { - return 0, err - } - return uint64(val), nil - }) - if err != nil { - res.Add(pb, value, "invalid integer") - return - } - f.Set(reflect.ValueOf(vs)) - pv = vs - - case reflect.Float32: - values := strings.Split(value, ",") - vs, err := parseArrElement(values, func(s string) (float32, error) { - val, err := strconv.ParseFloat(s, 32) - if err != nil { - return 0, err - } - return float32(val), nil - }) - if err != nil { - res.Add(pb, value, "invalid floating value") - return - } - f.Set(reflect.ValueOf(vs)) - pv = vs - - case reflect.Float64: - values := strings.Split(value, ",") - vs, err := parseArrElement(values, func(s string) (float64, error) { - val, err := strconv.ParseFloat(s, 64) - if err != nil { - return 0, err - } - return float64(val), nil - }) - if err != nil { - res.Add(pb, value, "invalid floating value") - return - } - f.Set(reflect.ValueOf(vs)) - pv = vs - } - break - } - - // Special case: time.Time - if f.Type() == timeType { - t, err := time.Parse(p.TimeFormat, value) - if err != nil { - res.Add(pb, value, "invalid date/time for format "+p.TimeFormat) - return - } - f.Set(reflect.ValueOf(t)) - pv = value - break - } - - // Last resort: use the `encoding.TextUnmarshaler` interface. - if fn, ok := f.Addr().Interface().(encoding.TextUnmarshaler); ok { - if err := fn.UnmarshalText([]byte(value)); err != nil { - res.Add(pb, value, "invalid value: "+err.Error()) - return - } - pv = value - break - } - - panic("unsupported param type " + p.Type.String()) + pv, ok := getValue(f, p, value, res, pb) + if !ok { + return } if !op.SkipValidateParams { @@ -1583,3 +1339,263 @@ func Patch[I, O any](api API, path string, handler func(context.Context, *I) (*O func Delete[I, O any](api API, path string, handler func(context.Context, *I) (*O, error), operationHandlers ...func(o *Operation)) { convenience(api, http.MethodDelete, path, handler, operationHandlers...) } + +func getValue(f reflect.Value, p *paramFieldInfo, value string, res *ValidateResult, pb *PathBuffer) (any, bool) { + if fn, ok := f.Addr().Interface().(InputParamConverter); ok { + pv, err := fn.HumaInputParamConvert([]byte(value)) + if err != nil { + res.Add(pb, value, "invalid value for "+p.Type.String()) + return nil, false + } + return pv, true + } + + switch p.Type.Kind() { + case reflect.String: + f.SetString(value) + return value, true + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v, err := strconv.ParseInt(value, 10, 64) + if err != nil { + res.Add(pb, value, "invalid integer") + return nil, false + } + f.SetInt(v) + return v, true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v, err := strconv.ParseUint(value, 10, 64) + if err != nil { + res.Add(pb, value, "invalid integer") + return nil, false + } + f.SetUint(v) + return v, true + case reflect.Float32, reflect.Float64: + v, err := strconv.ParseFloat(value, 64) + if err != nil { + res.Add(pb, value, "invalid float") + return nil, false + } + f.SetFloat(v) + return v, true + case reflect.Bool: + v, err := strconv.ParseBool(value) + if err != nil { + res.Add(pb, value, "invalid boolean") + return nil, false + } + f.SetBool(v) + return v, true + default: + if f.Type().Kind() == reflect.Slice { + switch f.Type().Elem().Kind() { + + case reflect.String: + values := strings.Split(value, ",") + f.Set(reflect.ValueOf(values)) + return values, true + + case reflect.Int: + values := strings.Split(value, ",") + vs, err := parseArrElement(values, func(s string) (int, error) { + val, err := strconv.ParseInt(s, 10, strconv.IntSize) + if err != nil { + return 0, err + } + return int(val), nil + }) + if err != nil { + res.Add(pb, value, "invalid integer") + return nil, false + } + f.Set(reflect.ValueOf(vs)) + return vs, true + + case reflect.Int8: + values := strings.Split(value, ",") + vs, err := parseArrElement(values, func(s string) (int8, error) { + val, err := strconv.ParseInt(s, 10, 8) + if err != nil { + return 0, err + } + return int8(val), nil + }) + if err != nil { + res.Add(pb, value, "invalid integer") + return nil, false + } + f.Set(reflect.ValueOf(vs)) + return vs, true + + case reflect.Int16: + values := strings.Split(value, ",") + vs, err := parseArrElement(values, func(s string) (int16, error) { + val, err := strconv.ParseInt(s, 10, 16) + if err != nil { + return 0, err + } + return int16(val), nil + }) + if err != nil { + res.Add(pb, value, "invalid integer") + return nil, false + } + f.Set(reflect.ValueOf(vs)) + return vs, true + + case reflect.Int32: + values := strings.Split(value, ",") + vs, err := parseArrElement(values, func(s string) (int32, error) { + val, err := strconv.ParseInt(s, 10, 32) + if err != nil { + return 0, err + } + return int32(val), nil + }) + if err != nil { + res.Add(pb, value, "invalid integer") + return nil, false + } + f.Set(reflect.ValueOf(vs)) + return vs, true + + case reflect.Int64: + values := strings.Split(value, ",") + vs, err := parseArrElement(values, func(s string) (int64, error) { + val, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return 0, err + } + return int64(val), nil + }) + if err != nil { + res.Add(pb, value, "invalid integer") + return nil, false + } + f.Set(reflect.ValueOf(vs)) + return vs, true + + case reflect.Uint: + values := strings.Split(value, ",") + vs, err := parseArrElement(values, func(s string) (uint, error) { + val, err := strconv.ParseUint(s, 10, strconv.IntSize) + if err != nil { + return 0, err + } + return uint(val), nil + }) + if err != nil { + res.Add(pb, value, "invalid integer") + return nil, false + } + f.Set(reflect.ValueOf(vs)) + return vs, true + + case reflect.Uint16: + values := strings.Split(value, ",") + vs, err := parseArrElement(values, func(s string) (uint16, error) { + val, err := strconv.ParseUint(s, 10, 16) + if err != nil { + return 0, err + } + return uint16(val), nil + }) + if err != nil { + res.Add(pb, value, "invalid integer") + return nil, false + } + f.Set(reflect.ValueOf(vs)) + return vs, true + + case reflect.Uint32: + values := strings.Split(value, ",") + vs, err := parseArrElement(values, func(s string) (uint32, error) { + val, err := strconv.ParseUint(s, 10, 32) + if err != nil { + return 0, err + } + return uint32(val), nil + }) + if err != nil { + res.Add(pb, value, "invalid integer") + return nil, false + } + f.Set(reflect.ValueOf(vs)) + return vs, true + + case reflect.Uint64: + values := strings.Split(value, ",") + vs, err := parseArrElement(values, func(s string) (uint64, error) { + val, err := strconv.ParseUint(s, 10, 64) + if err != nil { + return 0, err + } + return uint64(val), nil + }) + if err != nil { + res.Add(pb, value, "invalid integer") + return nil, false + } + f.Set(reflect.ValueOf(vs)) + return vs, true + + case reflect.Float32: + values := strings.Split(value, ",") + vs, err := parseArrElement(values, func(s string) (float32, error) { + val, err := strconv.ParseFloat(s, 32) + if err != nil { + return 0, err + } + return float32(val), nil + }) + if err != nil { + res.Add(pb, value, "invalid floating value") + return nil, false + } + f.Set(reflect.ValueOf(vs)) + return vs, true + + case reflect.Float64: + values := strings.Split(value, ",") + vs, err := parseArrElement(values, func(s string) (float64, error) { + val, err := strconv.ParseFloat(s, 64) + if err != nil { + return 0, err + } + return float64(val), nil + }) + if err != nil { + res.Add(pb, value, "invalid floating value") + return nil, false + } + f.Set(reflect.ValueOf(vs)) + return vs, true + + default: + return nil, true + } + } + + // Special case: time.Time + if f.Type() == timeType { + t, err := time.Parse(p.TimeFormat, value) + if err != nil { + res.Add(pb, value, "invalid date/time for format "+p.TimeFormat) + return nil, false + } + f.Set(reflect.ValueOf(t)) + return value, true + } + + // Last resort: use the `encoding.TextUnmarshaler` interface. + if fn, ok := f.Addr().Interface().(encoding.TextUnmarshaler); ok { + if err := fn.UnmarshalText([]byte(value)); err != nil { + res.Add(pb, value, "invalid value: "+err.Error()) + return nil, false + } + return value, true + } + + panic("unsupported param type " + p.Type.String()) + } +} diff --git a/huma_test.go b/huma_test.go index 15333c7e..9a15b3f6 100644 --- a/huma_test.go +++ b/huma_test.go @@ -11,6 +11,7 @@ import ( "net/http" "net/http/httptest" "net/http/httputil" + "reflect" "strings" "testing" "time" @@ -1921,3 +1922,58 @@ func TestGenerateFuncsPanicWithDescriptiveMessage(t *testing.T) { }) } + +type Nullable[T any] struct { + Null bool + Value T +} + +func (o *Nullable[T]) HumaInputParamConvert(b []byte) (any, error) { + err := o.UnmarshalText(b) + return o.Value, err +} + +func (o Nullable[T]) Schema(r huma.Registry) *huma.Schema { + return r.Schema(reflect.TypeOf(o.Value), true, "") +} + +func (o *Nullable[T]) UnmarshalText(b []byte) error { + o.Null = true + if len(b) == 0 { + o.Null = true + return nil + } + var temp T // Create a temporary variable of type T + err := json.Unmarshal(b, &temp) + if err != nil { + return err + } + o.Value = temp + o.Null = false + return nil +} + +type TestCustomQueryParamInput struct { + Query Nullable[int] `query:"query"` + ID Nullable[int] `header:"id"` +} + +func TestCustomInputParam(t *testing.T) { + r, app := humatest.New(t, huma.DefaultConfig("Test API", "1.0.0")) + huma.Register(app, huma.Operation{ + OperationID: "test", + Method: http.MethodGet, + Path: "/custom_query_param", + }, func(ctx context.Context, input *TestCustomQueryParamInput) (*struct{}, error) { + assert.Equal(t, Nullable[int]{Null: false, Value: 12}, input.ID) + assert.Equal(t, Nullable[int]{Null: false, Value: 2}, input.Query) + return nil, nil + }) + + req, _ := http.NewRequest(http.MethodGet, "/custom_query_param?query=2", nil) + req.Header.Add("id", "12") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + resp, _ := io.ReadAll(w.Body) + assert.Equal(t, 204, w.Code, string(resp)) +}