Skip to content

Commit

Permalink
Merge pull request #520 from smacker/allow-override-errors
Browse files Browse the repository at this point in the history
feat: make validation error messages customizable
  • Loading branch information
danielgtaylor authored Jul 26, 2024
2 parents fd2d72d + 71fac32 commit 4b6a508
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 50 deletions.
3 changes: 3 additions & 0 deletions error.go
Original file line number Diff line number Diff line change
Expand Up @@ -362,3 +362,6 @@ func Error503ServiceUnavailable(msg string, errs ...error) StatusError {
func Error504GatewayTimeout(msg string, errs ...error) StatusError {
return NewError(http.StatusGatewayTimeout, msg, errs...)
}

// ErrorFormatter is a function that formats an error message
var ErrorFormatter func(format string, a ...any) string = fmt.Sprintf
36 changes: 19 additions & 17 deletions schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import (
"strconv"
"strings"
"time"

"github.com/danielgtaylor/huma/v2/validation"
)

// ErrSchemaInvalid is sent when there is a problem building the schema.
Expand Down Expand Up @@ -202,57 +204,57 @@ func (s *Schema) MarshalJSON() ([]byte, error) {
// PrecomputeMessages tries to precompute as many validation error messages
// as possible so that new strings aren't allocated during request validation.
func (s *Schema) PrecomputeMessages() {
s.msgEnum = "expected value to be one of \"" + strings.Join(mapTo(s.Enum, func(v any) string {
s.msgEnum = ErrorFormatter(validation.MsgExpectedOneOf, strings.Join(mapTo(s.Enum, func(v any) string {
return fmt.Sprintf("%v", v)
}), ", ") + "\""
}), ", "))
if s.Minimum != nil {
s.msgMinimum = fmt.Sprintf("expected number >= %v", *s.Minimum)
s.msgMinimum = ErrorFormatter(validation.MsgExpectedMinimumNumber, *s.Minimum)
}
if s.ExclusiveMinimum != nil {
s.msgExclusiveMinimum = fmt.Sprintf("expected number > %v", *s.ExclusiveMinimum)
s.msgExclusiveMinimum = ErrorFormatter(validation.MsgExpectedExclusiveMinimumNumber, *s.ExclusiveMinimum)
}
if s.Maximum != nil {
s.msgMaximum = fmt.Sprintf("expected number <= %v", *s.Maximum)
s.msgMaximum = ErrorFormatter(validation.MsgExpectedMaximumNumber, *s.Maximum)
}
if s.ExclusiveMaximum != nil {
s.msgExclusiveMaximum = fmt.Sprintf("expected number < %v", *s.ExclusiveMaximum)
s.msgExclusiveMaximum = ErrorFormatter(validation.MsgExpectedExclusiveMaximumNumber, *s.ExclusiveMaximum)
}
if s.MultipleOf != nil {
s.msgMultipleOf = fmt.Sprintf("expected number to be a multiple of %v", *s.MultipleOf)
s.msgMultipleOf = ErrorFormatter(validation.MsgExpectedNumberBeMultipleOf, *s.MultipleOf)
}
if s.MinLength != nil {
s.msgMinLength = fmt.Sprintf("expected length >= %d", *s.MinLength)
s.msgMinLength = ErrorFormatter(validation.MsgExpectedMinLength, *s.MinLength)
}
if s.MaxLength != nil {
s.msgMaxLength = fmt.Sprintf("expected length <= %d", *s.MaxLength)
s.msgMaxLength = ErrorFormatter(validation.MsgExpectedMaxLength, *s.MaxLength)
}
if s.Pattern != "" {
s.patternRe = regexp.MustCompile(s.Pattern)
if s.PatternDescription != "" {
s.msgPattern = "expected string to be " + s.PatternDescription
s.msgPattern = ErrorFormatter(validation.MsgExpectedBePattern, s.PatternDescription)
} else {
s.msgPattern = "expected string to match pattern " + s.Pattern
s.msgPattern = ErrorFormatter(validation.MsgExpectedMatchPattern, s.Pattern)
}
}
if s.MinItems != nil {
s.msgMinItems = fmt.Sprintf("expected array length >= %d", *s.MinItems)
s.msgMinItems = ErrorFormatter(validation.MsgExpectedMinItems, *s.MinItems)
}
if s.MaxItems != nil {
s.msgMaxItems = fmt.Sprintf("expected array length <= %d", *s.MaxItems)
s.msgMaxItems = ErrorFormatter(validation.MsgExpectedMaxItems, *s.MaxItems)
}
if s.MinProperties != nil {
s.msgMinProperties = fmt.Sprintf("expected object with at least %d properties", *s.MinProperties)
s.msgMinProperties = ErrorFormatter(validation.MsgExpectedMinProperties, *s.MinProperties)
}
if s.MaxProperties != nil {
s.msgMaxProperties = fmt.Sprintf("expected object with at most %d properties", *s.MaxProperties)
s.msgMaxProperties = ErrorFormatter(validation.MsgExpectedMaxProperties, *s.MaxProperties)
}

if s.Required != nil {
if s.msgRequired == nil {
s.msgRequired = map[string]string{}
}
for _, name := range s.Required {
s.msgRequired[name] = "expected required property " + name + " to be present"
s.msgRequired[name] = ErrorFormatter(validation.MsgExpectedRequiredProperty, name)
}
}

Expand All @@ -265,7 +267,7 @@ func (s *Schema) PrecomputeMessages() {
if s.msgDependentRequired[name] == nil {
s.msgDependentRequired[name] = map[string]string{}
}
s.msgDependentRequired[name][dependent] = "expected property " + dependent + " to be present when " + name + " is present"
s.msgDependentRequired[name][dependent] = ErrorFormatter(validation.MsgExpectedDependentRequiredProperty, dependent, name)
}
}
}
Expand Down
68 changes: 35 additions & 33 deletions validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import (
"time"
"unicode/utf8"
"unsafe"

"github.com/danielgtaylor/huma/v2/validation"
)

// ValidateMode describes the direction of validation (server -> client or
Expand Down Expand Up @@ -192,69 +194,69 @@ func validateFormat(path *PathBuffer, str string, s *Schema, res *ValidateResult
}
}
if !found {
res.Add(path, str, "expected string to be RFC 3339 date-time")
res.Add(path, str, validation.MsgExpectedRFC3339DateTime)
}
case "date-time-http":
if _, err := time.Parse(time.RFC1123, str); err != nil {
res.Add(path, str, "expected string to be RFC 1123 date-time")
res.Add(path, str, validation.MsgExpectedRFC1123DateTime)
}
case "date":
if _, err := time.Parse("2006-01-02", str); err != nil {
res.Add(path, str, "expected string to be RFC 3339 date")
res.Add(path, str, validation.MsgExpectedRFC3339Date)
}
case "time":
if _, err := time.Parse("15:04:05", str); err != nil {
if _, err := time.Parse("15:04:05Z07:00", str); err != nil {
res.Add(path, str, "expected string to be RFC 3339 time")
res.Add(path, str, validation.MsgExpectedRFC3339Time)
}
}
// TODO: duration
case "email", "idn-email":
if _, err := mail.ParseAddress(str); err != nil {
res.Addf(path, str, "expected string to be RFC 5322 email: %v", err)
res.Add(path, str, ErrorFormatter(validation.MsgExpectedRFC5322Email, err))
}
case "hostname":
if !(rxHostname.MatchString(str) && len(str) < 256) {
res.Add(path, str, "expected string to be RFC 5890 hostname")
res.Add(path, str, validation.MsgExpectedRFC5890Hostname)
}
// TODO: proper idn-hostname support... need to figure out how.
case "ipv4":
if ip := net.ParseIP(str); ip == nil || ip.To4() == nil {
res.Add(path, str, "expected string to be RFC 2673 ipv4")
res.Add(path, str, validation.MsgExpectedRFC2673IPv4)
}
case "ipv6":
if ip := net.ParseIP(str); ip == nil || ip.To16() == nil {
res.Add(path, str, "expected string to be RFC 2373 ipv6")
res.Add(path, str, validation.MsgExpectedRFC2373IPv6)
}
case "uri", "uri-reference", "iri", "iri-reference":
if _, err := url.Parse(str); err != nil {
res.Addf(path, str, "expected string to be RFC 3986 uri: %v", err)
res.Add(path, str, ErrorFormatter(validation.MsgExpectedRFC3986URI, err))
}
// TODO: check if it's actually a reference?
case "uuid":
if err := validateUUID(str); err != nil {
res.Addf(path, str, "expected string to be RFC 4122 uuid: %v", err)
res.Add(path, str, ErrorFormatter(validation.MsgExpectedRFC4122UUID, err))
}
case "uri-template":
u, err := url.Parse(str)
if err != nil {
res.Addf(path, str, "expected string to be RFC 3986 uri: %v", err)
res.Add(path, str, ErrorFormatter(validation.MsgExpectedRFC3986URI, err))
return
}
if !rxURITemplate.MatchString(u.Path) {
res.Add(path, str, "expected string to be RFC 6570 uri-template")
res.Add(path, str, validation.MsgExpectedRFC6570URITemplate)
}
case "json-pointer":
if !rxJSONPointer.MatchString(str) {
res.Add(path, str, "expected string to be RFC 6901 json-pointer")
res.Add(path, str, validation.MsgExpectedRFC6901JSONPointer)
}
case "relative-json-pointer":
if !rxRelJSONPointer.MatchString(str) {
res.Add(path, str, "expected string to be RFC 6901 relative-json-pointer")
res.Add(path, str, validation.MsgExpectedRFC6901RelativeJSONPointer)
}
case "regex":
if _, err := regexp.Compile(str); err != nil {
res.Addf(path, str, "expected string to be regex: %v", err)
res.Add(path, str, ErrorFormatter(validation.MsgExpectedRegexp, err))
}
}
}
Expand Down Expand Up @@ -289,7 +291,7 @@ func validateAnyOf(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v
}

if matches == 0 {
res.Add(path, v, "expected value to match at least one schema but matched none")
res.Add(path, v, validation.MsgExpectedMatchSchema)
}
}

Expand Down Expand Up @@ -333,7 +335,7 @@ func Validate(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any,
subRes := &ValidateResult{}
Validate(r, s.Not, path, mode, v, subRes)
if len(subRes.Errors) == 0 {
res.Add(path, v, "expected value to not match schema")
res.Add(path, v, validation.MsgExpectedNotMatchSchema)
}
}

Expand All @@ -344,7 +346,7 @@ func Validate(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any,
switch s.Type {
case TypeBoolean:
if _, ok := v.(bool); !ok {
res.Add(path, v, "expected boolean")
res.Add(path, v, validation.MsgExpectedBoolean)
return
}
case TypeNumber, TypeInteger:
Expand Down Expand Up @@ -376,18 +378,18 @@ func Validate(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any,
case uint64:
num = float64(v)
default:
res.Add(path, v, "expected number")
res.Add(path, v, validation.MsgExpectedNumber)
return
}

if s.Minimum != nil {
if num < *s.Minimum {
res.Addf(path, v, s.msgMinimum)
res.Add(path, v, s.msgMinimum)
}
}
if s.ExclusiveMinimum != nil {
if num <= *s.ExclusiveMinimum {
res.Addf(path, v, s.msgExclusiveMinimum)
res.Add(path, v, s.msgExclusiveMinimum)
}
}
if s.Maximum != nil {
Expand All @@ -397,12 +399,12 @@ func Validate(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any,
}
if s.ExclusiveMaximum != nil {
if num >= *s.ExclusiveMaximum {
res.Addf(path, v, s.msgExclusiveMaximum)
res.Add(path, v, s.msgExclusiveMaximum)
}
}
if s.MultipleOf != nil {
if math.Mod(num, *s.MultipleOf) != 0 {
res.Addf(path, v, s.msgMultipleOf)
res.Add(path, v, s.msgMultipleOf)
}
}
case TypeString:
Expand All @@ -411,14 +413,14 @@ func Validate(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any,
if b, ok := v.([]byte); ok {
str = *(*string)(unsafe.Pointer(&b))
} else {
res.Add(path, v, "expected string")
res.Add(path, v, validation.MsgExpectedString)
return
}
}

if s.MinLength != nil {
if utf8.RuneCountInString(str) < *s.MinLength {
res.Addf(path, str, s.msgMinLength)
res.Add(path, str, s.msgMinLength)
}
}
if s.MaxLength != nil {
Expand All @@ -438,7 +440,7 @@ func Validate(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any,

if s.ContentEncoding == "base64" {
if !rxBase64.MatchString(str) {
res.Add(path, str, "expected string to be base64 encoded")
res.Add(path, str, validation.MsgExpectedBase64String)
}
}
case TypeArray:
Expand Down Expand Up @@ -471,7 +473,7 @@ func Validate(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any,
case []float64:
handleArray(r, s, path, mode, res, arr)
default:
res.Add(path, v, "expected array")
res.Add(path, v, validation.MsgExpectedArray)
return
}
case TypeObject:
Expand All @@ -480,7 +482,7 @@ func Validate(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any,
} else if vv, ok := v.(map[any]any); ok {
handleMapAny(r, s, path, mode, vv, res)
} else {
res.Add(path, v, "expected object")
res.Add(path, v, validation.MsgExpectedObject)
return
}
}
Expand All @@ -502,20 +504,20 @@ func Validate(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, v any,
func handleArray[T any](r Registry, s *Schema, path *PathBuffer, mode ValidateMode, res *ValidateResult, arr []T) {
if s.MinItems != nil {
if len(arr) < *s.MinItems {
res.Addf(path, arr, s.msgMinItems)
res.Add(path, arr, s.msgMinItems)
}
}
if s.MaxItems != nil {
if len(arr) > *s.MaxItems {
res.Addf(path, arr, s.msgMaxItems)
res.Add(path, arr, s.msgMaxItems)
}
}

if s.UniqueItems {
seen := make(map[any]struct{}, len(arr))
for _, item := range arr {
if _, ok := seen[item]; ok {
res.Add(path, arr, "expected array items to be unique")
res.Add(path, arr, validation.MsgExpectedArrayItemsUnique)
}
seen[item] = struct{}{}
}
Expand Down Expand Up @@ -602,7 +604,7 @@ func handleMapString(r Registry, s *Schema, path *PathBuffer, mode ValidateMode,
// No additional properties allowed.
if _, ok := s.Properties[k]; !ok {
path.Push(k)
res.Add(path, m, "unexpected property")
res.Add(path, m, validation.MsgUnexpectedProperty)
path.Pop()
}
}
Expand Down Expand Up @@ -702,7 +704,7 @@ func handleMapAny(r Registry, s *Schema, path *PathBuffer, mode ValidateMode, m
}
if _, ok := s.Properties[kStr]; !ok {
path.Push(kStr)
res.Add(path, m, "unexpected property")
res.Add(path, m, validation.MsgUnexpectedProperty)
path.Pop()
}
}
Expand Down
22 changes: 22 additions & 0 deletions validate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1300,6 +1300,28 @@ func TestValidate(t *testing.T) {
}
}

func TestValidateCustomFormatter(t *testing.T) {
originalFormatter := huma.ErrorFormatter
defer func() {
huma.ErrorFormatter = originalFormatter
}()

huma.ErrorFormatter = func(format string, a ...any) string {
return fmt.Sprintf("custom: %v", a)
}

registry := huma.NewMapRegistry("#/components/schemas/", huma.DefaultSchemaNamer)
s := registry.Schema(reflect.TypeOf(struct {
Value string `json:"value" format:"email"`
}{}), true, "TestInput")
pb := huma.NewPathBuffer([]byte(""), 0)
res := &huma.ValidateResult{}

huma.Validate(registry, s, pb, huma.ModeReadFromServer, map[string]any{"value": "alice"}, res)
assert.Len(t, res.Errors, 1)
assert.Equal(t, "custom: [mail: missing '@' or angle-addr] (value: alice)", res.Errors[0].Error())
}

func ExampleModelValidator() {
// Define a type you want to validate.
type Model struct {
Expand Down
Loading

0 comments on commit 4b6a508

Please sign in to comment.