Skip to content

Commit

Permalink
Match CEL and Go duration literal parsing, while preserving the full …
Browse files Browse the repository at this point in the history
…range of values (#38)

Matches the impl in time.ParseDuration, except uses a big.Int to support
the full range of protobuf Duration values.
  • Loading branch information
Alfus authored Jul 29, 2024
1 parent 77e564e commit 93c0894
Show file tree
Hide file tree
Showing 6 changed files with 255 additions and 107 deletions.
4 changes: 4 additions & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ linters:
- wsl # generous whitespace violates house style
- exhaustive
- exhaustruct
- nonamedreturns
- mnd
- err113
- gochecknoglobals
issues:
exclude:
# Don't ban use of fmt.Errorf to create new errors, but the remaining
Expand Down
263 changes: 197 additions & 66 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"errors"
"fmt"
"math"
"math/big"
"strconv"
"strings"
"time"
Expand All @@ -34,8 +35,6 @@ import (
"gopkg.in/yaml.v3"
)

// atTypeFieldName is the literal field name "@type".
// NOTE(review): presumably the key that carries the type URL of a
// google.protobuf.Any value; confirm against the places it is used.
const atTypeFieldName = "@type"

// Validator is an interface for validating a Protobuf message produced from a given YAML node.
type Validator interface {
// Validate the given message.
Expand All @@ -57,11 +56,6 @@ type UnmarshalOptions struct {
}
}

// protoResolver is the combination of protoregistry.MessageTypeResolver and
// protoregistry.ExtensionTypeResolver: a value that can resolve both message
// types and extension types.
type protoResolver interface {
protoregistry.MessageTypeResolver
protoregistry.ExtensionTypeResolver
}

// Unmarshal a Protobuf message from the given YAML data.
func Unmarshal(data []byte, message proto.Message) error {
return (UnmarshalOptions{}).Unmarshal(data, message)
Expand All @@ -76,6 +70,53 @@ func (o UnmarshalOptions) Unmarshal(data []byte, message proto.Message) error {
return o.unmarshalNode(&yamlFile, message, data)
}

// ParseDuration parses a duration string into a durationpb.Duration.
//
// The input is an optional sign followed by one or more segments, each made
// of a decimal number (with an optional fraction) and a unit, for example
// "300ms", "-1.5h" or "2h45m". The bare string "0", with an optional sign, is
// also accepted.
//
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
//
// This function supports the full range of durationpb.Duration values, including
// those outside the range of time.Duration.
//
// On success the returned duration is never nil. On failure the returned
// duration is nil and the error describes why the input was rejected.
func ParseDuration(str string) (*durationpb.Duration, error) {
	// [-+]?([0-9]*(\.[0-9]*)?[a-z]+)+
	neg := false

	// Consume [-+]?
	if str != "" && (str[0] == '-' || str[0] == '+') {
		neg = str[0] == '-'
		str = str[1:]
	}

	switch str {
	case "0":
		// Special case: a bare "0" needs no unit. Return a real (non-nil)
		// zero value so callers can safely dereference the result.
		return &durationpb.Duration{}, nil
	case "":
		return nil, errors.New("invalid duration")
	}

	// Accumulate every segment in nanoseconds; big.Int avoids overflow for
	// values that do not fit in a time.Duration.
	totalNanos := &big.Int{}
	for str != "" {
		var err error
		str, err = parseDurationNext(str, totalNanos)
		if err != nil {
			return nil, err
		}
	}
	if neg {
		totalNanos.Neg(totalNanos)
	}

	// QuoRem truncates toward zero, so for negative totals both the seconds
	// and the nanos carry the negative sign.
	seconds, nanos := totalNanos.QuoRem(totalNanos, nanosPerSecond, &big.Int{})
	if !seconds.IsInt64() {
		return nil, errors.New("invalid duration: out of range")
	}
	return &durationpb.Duration{
		Seconds: seconds.Int64(),
		Nanos:   int32(nanos.Int64()),
	}, nil
}

func (o UnmarshalOptions) unmarshalNode(node *yaml.Node, message proto.Message, data []byte) error {
if node.Kind == 0 {
return nil
Expand Down Expand Up @@ -121,6 +162,13 @@ func (o UnmarshalOptions) unmarshalNode(node *yaml.Node, message proto.Message,
return nil
}

// atTypeFieldName is the literal field name "@type".
// NOTE(review): presumably the key that carries the type URL of a
// google.protobuf.Any value; confirm against the places it is used.
const atTypeFieldName = "@type"

// protoResolver is the combination of protoregistry.MessageTypeResolver and
// protoregistry.ExtensionTypeResolver: a value that can resolve both message
// types and extension types.
type protoResolver interface {
protoregistry.MessageTypeResolver
protoregistry.ExtensionTypeResolver
}

type unmarshaler struct {
options UnmarshalOptions
errors []error
Expand Down Expand Up @@ -683,54 +731,6 @@ const (
minTimestampSeconds = -62135596800
)

// parseDuration parses txt into duration.
//
// Format is decimal seconds with up to 9 fractional digits, followed by an 's'.
// Leading and trailing whitespace is ignored. The fractional part must consist
// of digits only. duration is modified only when txt is valid; on error it is
// left untouched.
func parseDuration(txt string, duration *durationpb.Duration) error {
	// Remove trailing s.
	txt = strings.TrimSpace(txt)
	if len(txt) == 0 || txt[len(txt)-1] != 's' {
		return errors.New("missing trailing 's'")
	}
	value := txt[:len(txt)-1]

	// Split into seconds and nanos.
	parts := strings.Split(value, ".")
	if len(parts) > 2 {
		return errors.New("invalid duration: too many '.' characters")
	}
	seconds, err := strconv.ParseInt(parts[0], 10, 64)
	if err != nil {
		return err
	}

	var nanos int32
	if len(parts) == 2 {
		frac := parts[1]
		if len(frac) > 9 {
			return errors.New("too many fractional second digits")
		}
		// ParseUint rejects a sign prefix, so inputs such as "1.-5s" cannot
		// produce nanos whose sign disagrees with the seconds.
		digits, fracErr := strconv.ParseUint(frac, 10, 32)
		if fracErr != nil {
			return fracErr
		}
		// Scale the digits up to nanoseconds; at most 9 digits, so the
		// result always fits in an int32.
		nanos = int32(digits * uint64(math.Pow10(9-len(frac))))
		// The sign is taken from the text rather than from seconds so that
		// values like "-0.5s" keep their sign.
		if strings.HasPrefix(value, "-") {
			nanos = -nanos
		}
	}

	duration.Seconds = seconds
	duration.Nanos = nanos
	return nil
}

// Format is RFC3339Nano, limited to the range 0001-01-01T00:00:00Z to
// 9999-12-31T23:59:59Z inclusive.
func parseTimestamp(txt string, timestamp *timestamppb.Timestamp) error {
Expand Down Expand Up @@ -770,19 +770,21 @@ func unmarshalDurationMsg(unm *unmarshaler, node *yaml.Node, message proto.Messa
if node.Kind != yaml.ScalarNode || len(node.Value) == 0 || isNull(node) {
return false
}
duration, ok := message.(*durationpb.Duration)
if !ok {
duration = &durationpb.Duration{}
}
err := parseDuration(node.Value, duration)
duration, err := ParseDuration(node.Value)
if err != nil {
unm.addErrorf(node, "invalid duration: %v", err)
} else if !ok {
// Set the fields dynamically.
return setFieldByName(message, "seconds", protoreflect.ValueOfInt64(duration.GetSeconds())) &&
setFieldByName(message, "nanos", protoreflect.ValueOfInt32(duration.GetNanos()))
unm.addError(node, err)
return true
}
return true

if value, ok := message.(*durationpb.Duration); ok {
value.Seconds = duration.GetSeconds()
value.Nanos = duration.GetNanos()
return true
}

// Set the fields dynamically.
return setFieldByName(message, "seconds", protoreflect.ValueOfInt64(duration.GetSeconds())) &&
setFieldByName(message, "nanos", protoreflect.ValueOfInt32(duration.GetNanos()))
}

func unmarshalTimestampMsg(unm *unmarshaler, node *yaml.Node, message proto.Message) bool {
Expand Down Expand Up @@ -1184,3 +1186,132 @@ func findEntryByKey(cur *yaml.Node, key string) (*yaml.Node, *yaml.Node, bool) {
}
return nil, cur, false
}

// Lookup tables used when parsing duration literals.
var (
	// nanosPerSecond is the number of nanoseconds in one second.
	nanosPerSecond = big.NewInt(int64(time.Second))

	// nanosMap gives, for every accepted unit spelling, the length of that
	// unit in nanoseconds.
	nanosMap = map[string]*big.Int{
		"ns": big.NewInt(int64(time.Nanosecond)), // Identity for nanos.
		"us": big.NewInt(int64(time.Microsecond)),
		"µs": big.NewInt(int64(time.Microsecond)), // U+00B5 = micro symbol
		"μs": big.NewInt(int64(time.Microsecond)), // U+03BC = Greek letter mu
		"ms": big.NewInt(int64(time.Millisecond)),
		"s":  nanosPerSecond,
		"m":  big.NewInt(int64(time.Minute)),
		"h":  big.NewInt(int64(time.Hour)),
	}

	// unitsNames lists the normalized unit names, largest unit first.
	unitsNames = []string{"h", "m", "s", "ms", "us", "ns"}
)

// parseDurationNext parses a single "<number><unit>" segment from the front of
// str, adds its value in nanoseconds to totalNanos, and returns the unparsed
// remainder of str.
//
// The number is a run of decimal digits with an optional fractional part; at
// least one digit must appear before or after the dot. An error is returned
// when str is empty or does not start with a digit or '.', when the whole part
// overflows, when the unit is missing or unknown, or when the fraction does
// not amount to a whole number of nanoseconds. totalNanos may already have
// been partially updated when an error is returned.
func parseDurationNext(str string, totalNanos *big.Int) (string, error) {
	// The next character must be [0-9.]
	if str == "" || !(str[0] == '.' || '0' <= str[0] && str[0] <= '9') {
		return "", errors.New("invalid duration")
	}
	var err error
	var whole, frac uint64
	var pre bool // Whether we have seen a digit before the dot.
	whole, str, pre, err = leadingInt(str)
	if err != nil {
		return "", err
	}
	var scale *big.Int
	var post bool // Whether we have seen a digit after the dot.
	if str != "" && str[0] == '.' {
		str = str[1:]
		frac, scale, str, post = leadingFrac(str)
	}
	if !pre && !post {
		return "", errors.New("invalid duration")
	}

	end := unitEnd(str)
	if end == 0 {
		return "", fmt.Errorf("invalid duration: missing unit, expected one of %v", unitsNames)
	}
	unitName := str[:end]
	str = str[end:]
	nanosPerUnit, ok := nanosMap[unitName]
	if !ok {
		return "", fmt.Errorf("invalid duration: unknown unit, expected one of %v", unitsNames)
	}

	// Convert to nanos and add to total.
	// totalNanos += whole * nanosPerUnit + frac * nanosPerUnit / scale
	if whole > 0 {
		wholeNanos := &big.Int{}
		wholeNanos.SetUint64(whole)
		wholeNanos.Mul(wholeNanos, nanosPerUnit)
		totalNanos.Add(totalNanos, wholeNanos)
	}
	if frac > 0 {
		fracNanos := &big.Int{}
		fracNanos.SetUint64(frac)
		fracNanos.Mul(fracNanos, nanosPerUnit)
		rem := &big.Int{}
		fracNanos.QuoRem(fracNanos, scale, rem)
		// rem can be wider than 64 bits when scale is very large, and
		// Int.Uint64 is undefined for such values, so test the sign instead
		// of converting.
		if rem.Sign() != 0 {
			return "", errors.New("invalid duration: fractional nanos")
		}
		totalNanos.Add(totalNanos, fracNanos)
	}
	return str, nil
}

// unitEnd returns the length in bytes of the unit name at the start of str:
// the offset of the first '.', '-' or decimal digit, or len(str) when str
// contains none of those.
func unitEnd(str string) int {
	for pos := 0; pos < len(str); pos++ {
		switch str[pos] {
		case '.', '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
			return pos
		}
	}
	return len(str)
}

// leadingFrac consumes the run of decimal digits at the start of str and
// interprets it as the digits following a decimal point.
//
// It returns the accumulated digits as an integer, the power of ten that the
// integer must be divided by (scale), the remainder of str after the digits,
// and whether at least one digit was consumed. Once another digit would push
// the accumulated value past the supported limit, that digit and all later
// ones are still consumed but no longer contribute to result or scale.
func leadingFrac(str string) (result uint64, scale *big.Int, rem string, post bool) {
	// Largest accumulator value that can still be multiplied by ten.
	const maxBeforeShift = (1<<63 - 1) / 10

	ten := big.NewInt(10)
	scale = big.NewInt(1)
	saturated := false
	pos := 0
	for pos < len(str) && '0' <= str[pos] && str[pos] <= '9' {
		digit := uint64(str[pos] - '0')
		pos++
		if saturated {
			continue
		}
		if result > maxBeforeShift {
			saturated = true
			continue
		}
		next := result*10 + digit
		if next > 1<<63 {
			saturated = true
			continue
		}
		result = next
		scale.Mul(scale, ten)
	}
	return result, scale, str[pos:], pos > 0
}

// leadingInt consumes the run of decimal digits at the start of str.
//
// It returns the parsed value, the remainder of str after the digits, and
// whether at least one digit was consumed. If the digits do not fit in a
// uint64 it returns an "integer overflow" error together with a zero result
// and the original, unconsumed str.
func leadingInt(str string) (result uint64, rem string, pre bool, err error) {
	const maxUint64 = ^uint64(0)
	var i int
	for ; i < len(str); i++ {
		c := str[i]
		if c < '0' || c > '9' {
			break
		}
		digit := uint64(c - '0')
		// Check the bound before multiplying: a wrapped result*10 can still
		// be larger than result, so comparing after the fact misses some
		// overflows.
		if result > (maxUint64-digit)/10 {
			return 0, str, i > 0, errors.New("integer overflow")
		}
		result = result*10 + digit
	}
	return result, str[i:], i > 0, nil
}
Loading

0 comments on commit 93c0894

Please sign in to comment.