diff --git a/operators/validate_byte_range.go b/operators/validate_byte_range.go index a8bbc9742..9ab306010 100644 --- a/operators/validate_byte_range.go +++ b/operators/validate_byte_range.go @@ -13,13 +13,8 @@ import ( "github.com/corazawaf/coraza/v3/rules" ) -type byteRange struct { - start byte - end byte -} - type validateByteRange struct { - data []byteRange + validBytes [256]bool // array, not slice, so don't pass as-is to functions } var _ rules.Operator = (*validateByteRange)(nil) @@ -28,75 +23,65 @@ func newValidateByteRange(options rules.OperatorOptions) (rules.Operator, error) data := options.Arguments if data == "" { - return &validateByteRange{}, nil + return &unconditionalMatch{}, nil } - var ranges []byteRange - var err error + var validBytes [256]bool for _, br := range strings.Split(data, ",") { br = strings.TrimSpace(br) - spl := strings.SplitN(br, "-", 2) + start, end, ok := strings.Cut(br, "-") - var start, end uint64 - if len(spl) == 1 { - start, err = strconv.ParseUint(spl[0], 10, 8) - if err != nil { + if !ok { + if b, err := strconv.Atoi(start); err != nil { return nil, err - } - if ranges, err = addRange(ranges, start, start); err != nil { + } else if err := validateByte(b); err != nil { return nil, err + } else { + validBytes[b] = true } continue } - start, err = strconv.ParseUint(spl[0], 10, 8) + s, err := strconv.Atoi(start) if err != nil { return nil, err } - end, err = strconv.ParseUint(spl[1], 10, 8) + if err := validateByte(s); err != nil { + return nil, err + } + e, err := strconv.Atoi(end) if err != nil { return nil, err } - if ranges, err = addRange(ranges, start, end); err != nil { + if err := validateByte(e); err != nil { return nil, err } + for i := s; i <= e; i++ { + validBytes[i] = true + } } - return &validateByteRange{data: ranges}, nil + return &validateByteRange{validBytes: validBytes}, nil } -func (o *validateByteRange) Evaluate(tx rules.TransactionState, data string) bool { - lenData := len(o.data) - if lenData == 0 { - return true +func validateByte(b int) error { + if b < 0 || b > 255 { + return fmt.Errorf("invalid byte %d", b) } + return nil +} + +func (o *validateByteRange) Evaluate(tx rules.TransactionState, data string) bool { if data == "" { return false } // we must iterate each byte from input and check if it is in the range // if every byte is within the range we return false - matched := 0 for i := 0; i < len(data); i++ { c := data[i] - for _, r := range o.data { - if c >= r.start && c <= r.end { - matched++ - break - } + if !o.validBytes[c] { + return true } } - return len(data) != matched -} - -func addRange(ranges []byteRange, start uint64, end uint64) ([]byteRange, error) { - if start > 255 { - return nil, fmt.Errorf("invalid start byte %d", start) - } - if end > 255 { - return nil, fmt.Errorf("invalid end byte %d", end) - } - return append(ranges, byteRange{ - start: byte(start), - end: byte(end), - }), nil + return false } func init() {