Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add optional tag to ignore fields when nil when rlp encoding/decoding #12

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions rlp/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ type Decoder interface {
// rules for the field such that input values of size zero decode as a nil
// pointer. This tag can be useful when decoding recursive types.
//
// type StructWithEmptyOK struct {
// Foo *[20]byte `rlp:"nil"`
// }
// type StructWithEmptyOK struct {
// Foo *[20]byte `rlp:"nil"`
// }
//
// To decode into a slice, the input must be a list and the resulting
// slice will contain the input elements in order. For byte slices,
Expand All @@ -113,8 +113,8 @@ type Decoder interface {
// To decode into an interface value, Decode stores one of these
// in the value:
//
// []interface{}, for RLP lists
// []byte, for RLP strings
// []interface{}, for RLP lists
// []byte, for RLP strings
//
// Non-empty interface types are not supported, nor are booleans,
// signed integers, floating point numbers, maps, channels and
Expand All @@ -124,7 +124,7 @@ type Decoder interface {
// and may be vulnerable to panics cause by huge value sizes. If
// you need an input limit, use
//
// NewStream(r, limit).Decode(val)
// NewStream(r, limit).Decode(val)
func Decode(r io.Reader, val interface{}) error {
// TODO: this could use a Stream from a pool.
return NewStream(r, 0).Decode(val)
Expand Down Expand Up @@ -438,9 +438,16 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) {
if _, err := s.List(); err != nil {
return wrapStreamError(err, typ)
}
for _, f := range fields {
for i, f := range fields {
err := f.info.decoder(s, val.Field(f.index))
if err == EOL {
if f.optional {
// The field is optional, so reaching the end of the list before
// reaching the last field is acceptable. All remaining undecoded
// fields are zeroed.
zeroFields(val, fields[i:])
break
}
return &decodeError{msg: "too few elements", typ: typ}
} else if err != nil {
return addErrorContext(err, "."+typ.Field(f.index).Name)
Expand All @@ -451,6 +458,13 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) {
return dec, nil
}

func zeroFields(structval reflect.Value, fields []field) {
for _, f := range fields {
fv := structval.Field(f.index)
fv.Set(reflect.Zero(fv.Type()))
}
}

// makePtrDecoder creates a decoder that decodes into
// the pointer's element type.
func makePtrDecoder(typ reflect.Type) (decoder, error) {
Expand Down
164 changes: 161 additions & 3 deletions rlp/decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,29 @@ var (
)
)

type hasIgnoredField struct {
type optionalFields struct {
A uint
B uint `rlp:"optional"`
C uint `rlp:"optional"`
}
type optionalAndTailField struct {
A uint
B uint `rlp:"optional"`
Tail []uint `rlp:"tail"`
}
type optionalBigIntField struct {
A uint
B *big.Int `rlp:"optional"`
}
type optionalPtrField struct {
A uint
B *[3]byte `rlp:"optional"`
}
type optionalPtrFieldNil struct {
A uint
B *[3]byte `rlp:"optional,nil"`
}
type ignoredField struct {
A uint
B uint `rlp:"-"`
C uint
Expand Down Expand Up @@ -514,8 +536,112 @@ var decodeTests = []decodeTest{
// struct tag "-"
{
input: "C20102",
ptr: new(hasIgnoredField),
value: hasIgnoredField{A: 1, C: 2},
ptr: new(ignoredField),
value: ignoredField{A: 1, C: 2},
},

// struct tag "optional"
{
input: "C101",
ptr: new(optionalFields),
value: optionalFields{1, 0, 0},
},
{
input: "C20102",
ptr: new(optionalFields),
value: optionalFields{1, 2, 0},
},
{
input: "C3010203",
ptr: new(optionalFields),
value: optionalFields{1, 2, 3},
},
{
input: "C401020304",
ptr: new(optionalFields),
error: "rlp: input list has too many elements for rlp.optionalFields",
},
{
input: "C101",
ptr: new(optionalAndTailField),
value: optionalAndTailField{A: 1},
},
{
input: "C20102",
ptr: new(optionalAndTailField),
value: optionalAndTailField{A: 1, B: 2, Tail: []uint{}},
},
{
input: "C401020304",
ptr: new(optionalAndTailField),
value: optionalAndTailField{A: 1, B: 2, Tail: []uint{3, 4}},
},
{
input: "C101",
ptr: new(optionalBigIntField),
value: optionalBigIntField{A: 1, B: nil},
},
{
input: "C20102",
ptr: new(optionalBigIntField),
value: optionalBigIntField{A: 1, B: big.NewInt(2)},
},
{
input: "C101",
ptr: new(optionalPtrField),
value: optionalPtrField{A: 1},
},
{
input: "C20180", // not accepted because "optional" doesn't enable "nil"
ptr: new(optionalPtrField),
error: "rlp: input string too short for [3]uint8, decoding into (rlp.optionalPtrField).B",
},
{
input: "C20102",
ptr: new(optionalPtrField),
error: "rlp: input string too short for [3]uint8, decoding into (rlp.optionalPtrField).B",
},
{
input: "C50183010203",
ptr: new(optionalPtrField),
value: optionalPtrField{A: 1, B: &[3]byte{1, 2, 3}},
},
{
input: "C101",
ptr: new(optionalPtrFieldNil),
value: optionalPtrFieldNil{A: 1},
},
{
input: "C20180", // accepted because "nil" tag allows empty input
ptr: new(optionalPtrFieldNil),
value: optionalPtrFieldNil{A: 1},
},
{
input: "C20102",
ptr: new(optionalPtrFieldNil),
error: "rlp: input string too short for [3]uint8, decoding into (rlp.optionalPtrFieldNil).B",
},

// struct tag "optional" field clearing
{
input: "C101",
ptr: &optionalFields{A: 9, B: 8, C: 7},
value: optionalFields{A: 1, B: 0, C: 0},
},
{
input: "C20102",
ptr: &optionalFields{A: 9, B: 8, C: 7},
value: optionalFields{A: 1, B: 2, C: 0},
},
{
input: "C20102",
ptr: &optionalAndTailField{A: 9, B: 8, Tail: []uint{7, 6, 5}},
value: optionalAndTailField{A: 1, B: 2, Tail: []uint{}},
},
{
input: "C101",
ptr: &optionalPtrField{A: 9, B: &[3]byte{8, 7, 6}},
value: optionalPtrField{A: 1},
},

// RawValue
Expand Down Expand Up @@ -817,3 +943,35 @@ func unhex(str string) []byte {
}
return b
}

// This tests the validity checks for fields with struct tag "optional".
func TestInvalidOptionalField(t *testing.T) {
type (
invalid1 struct {
A uint `rlp:"optional"`
B uint
}
invalid2 struct {
T []uint `rlp:"tail,optional"`
}
invalid3 struct {
T []uint `rlp:"optional,tail"`
}
)
tests := []struct {
v interface{}
err string
}{
{v: new(invalid1), err: `rlp: struct field rlp.invalid1.B needs "optional" tag`},
leszek-vechain marked this conversation as resolved.
Show resolved Hide resolved
{v: new(invalid2), err: `rlp: invalid struct tag "optional" for rlp.invalid2.T (cannot be used with "tail")`},
{v: new(invalid3), err: `rlp: invalid struct tag "tail" for rlp.invalid3.T (cannot be used with "optional")`},
}
for _, test := range tests {
err := DecodeBytes(unhex("C20102"), test.v)
if err == nil {
t.Errorf("no error for %T", test.v)
} else if err.Error() != test.err {
t.Errorf("wrong error for %T: %v", test.v, err.Error())
}
}
}
38 changes: 31 additions & 7 deletions rlp/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -529,15 +529,39 @@ func makeStructWriter(typ reflect.Type) (writer, error) {
if err != nil {
return nil, err
}
writer := func(val reflect.Value, w *encbuf) error {
lh := w.list()
for _, f := range fields {
if err := f.info.writer(val.Field(f.index), w); err != nil {
return err
var writer writer
firstOptionalField := firstOptionalField(fields)
if firstOptionalField == len(fields) {
// This is the writer function for structs without any optional fields.
writer = func(val reflect.Value, w *encbuf) error {
lh := w.list()
for _, f := range fields {
if err := f.info.writer(val.Field(f.index), w); err != nil {
return err
}
}
w.listEnd(lh)
return nil
}
} else {
// If there are any "optional" fields, the writer needs to perform additional
// checks to determine the output list length.
writer = func(val reflect.Value, w *encbuf) error {
lastField := len(fields) - 1
for ; lastField >= firstOptionalField; lastField-- {
if !val.Field(fields[lastField].index).IsZero() {
break
}
}
lh := w.list()
for i := 0; i <= lastField; i++ {
if err := fields[i].info.writer(val.Field(fields[i].index), w); err != nil {
return err
}
}
w.listEnd(lh)
return nil
}
w.listEnd(lh)
return nil
}
return writer, nil
}
Expand Down
16 changes: 15 additions & 1 deletion rlp/encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ var encTests = []encTest{
{val: &tailRaw{A: 1, Tail: []RawValue{unhex("02")}}, output: "C20102"},
{val: &tailRaw{A: 1, Tail: []RawValue{}}, output: "C101"},
{val: &tailRaw{A: 1, Tail: nil}, output: "C101"},
{val: &hasIgnoredField{A: 1, B: 2, C: 3}, output: "C20103"},
{val: &ignoredField{A: 1, B: 2, C: 3}, output: "C20103"},

// nil
{val: (*uint)(nil), output: "80"},
Expand All @@ -232,6 +232,20 @@ var encTests = []encTest{
{val: (*[]struct{ uint })(nil), output: "C0"},
{val: (*interface{})(nil), output: "C0"},

// struct tag "optional"
{val: &optionalFields{}, output: "C180"},
{val: &optionalFields{A: 1}, output: "C101"},
{val: &optionalFields{A: 1, B: 2}, output: "C20102"},
{val: &optionalFields{A: 1, B: 2, C: 3}, output: "C3010203"},
{val: &optionalFields{A: 1, B: 0, C: 3}, output: "C3018003"},
{val: &optionalAndTailField{A: 1}, output: "C101"},
{val: &optionalAndTailField{A: 1, B: 2}, output: "C20102"},
{val: &optionalAndTailField{A: 1, Tail: []uint{5, 6}}, output: "C401800506"},
{val: &optionalAndTailField{A: 1, Tail: []uint{5, 6}}, output: "C401800506"},
{val: &optionalBigIntField{A: 1}, output: "C101"},
{val: &optionalPtrField{A: 1}, output: "C101"},
{val: &optionalPtrFieldNil{A: 1}, output: "C101"},

// interfaces
{val: []io.Reader{reader}, output: "C3C20102"}, // the contained value is a struct

Expand Down
Loading
Loading