Skip to content

Commit

Permalink
Reduce encoding and decoding allocations
Browse files Browse the repository at this point in the history
I reduced the allocation count by

1. Adding a scratch buffer to the Encoder/decoder. The scratch
   buffer is used as a temporary buffer. Before that, temporary
   buffers were allocated (in the heap since they escaped the stack)
   for basic type encoding/decoding (e.g. `DecodeInt()` `EncodeInt()`

2. Making `DecodeFixedOpaque()` decode in-place instead of
   allocating the result. Apart from reducing allocations,
   this removes the need of copying (as shown by `decodeFixedArray()`.

The pre-existing (and admitedly very limited) benchmarks show a sizeable improvement:

Before:

```
goos: darwin
goarch: amd64
pkg: github.com/stellar/go-xdr/xdr3
cpu: Intel(R) Core(TM) i7-1068NG7 CPU @ 2.30GHz
BenchmarkUnmarshal
BenchmarkUnmarshal-8   	    3000	      1039 ns/op	  15.40 MB/s	     152 B/op	      11 allocs/op
BenchmarkMarshal
BenchmarkMarshal-8     	    3000	       806.7 ns/op	  19.83 MB/s	     104 B/op	      10 allocs/op
PASS
```

After:
```
goos: darwin
goarch: amd64
pkg: github.com/stellar/go-xdr/xdr3
cpu: Intel(R) Core(TM) i7-1068NG7 CPU @ 2.30GHz
BenchmarkUnmarshal
BenchmarkUnmarshal-8   	    3000	       679.6 ns/op	  23.54 MB/s	     144 B/op	       8 allocs/op
BenchmarkMarshal
BenchmarkMarshal-8     	    3000	       609.1 ns/op	  26.27 MB/s	     104 B/op	       8 allocs/op
PASS
```

Both decoding and encoding both go down to 8 allocations from 11 (decoding) and 10 (encoding) allocations per operation.

More context at stellar/go#4022 (comment)
  • Loading branch information
2opremio committed Nov 2, 2021
1 parent b95df30 commit e50eac4
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 112 deletions.
138 changes: 69 additions & 69 deletions xdr3/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ func Unmarshal(r io.Reader, v interface{}) (int, error) {
// necessary in complex scenarios where automatic reflection-based decoding
// won't work.
type Decoder struct {
r io.Reader
// used to minimize heap allocations during decoding
scratchBuf [8]byte
r io.Reader
}

// DecodeInt treats the next 4 bytes as an XDR encoded integer and returns the
Expand All @@ -102,16 +104,15 @@ type Decoder struct {
// RFC Section 4.1 - Integer
// 32-bit big-endian signed integer in range [-2147483648, 2147483647]
func (d *Decoder) DecodeInt() (int32, int, error) {
var buf [4]byte
n, err := io.ReadFull(d.r, buf[:])
n, err := io.ReadFull(d.r, d.scratchBuf[:4])
if err != nil {
msg := fmt.Sprintf(errIODecode, err.Error(), 4)
err := unmarshalError("DecodeInt", ErrIO, msg, buf[:n], err)
err := unmarshalError("DecodeInt", ErrIO, msg, d.scratchBuf[:n], err)
return 0, n, err
}

rv := int32(buf[3]) | int32(buf[2])<<8 |
int32(buf[1])<<16 | int32(buf[0])<<24
rv := int32(d.scratchBuf[3]) | int32(d.scratchBuf[2])<<8 |
int32(d.scratchBuf[1])<<16 | int32(d.scratchBuf[0])<<24
return rv, n, nil
}

Expand All @@ -124,16 +125,15 @@ func (d *Decoder) DecodeInt() (int32, int, error) {
// RFC Section 4.2 - Unsigned Integer
// 32-bit big-endian unsigned integer in range [0, 4294967295]
func (d *Decoder) DecodeUint() (uint32, int, error) {
var buf [4]byte
n, err := io.ReadFull(d.r, buf[:])
n, err := io.ReadFull(d.r, d.scratchBuf[:4])
if err != nil {
msg := fmt.Sprintf(errIODecode, err.Error(), 4)
err := unmarshalError("DecodeUint", ErrIO, msg, buf[:n], err)
err := unmarshalError("DecodeUint", ErrIO, msg, d.scratchBuf[:n], err)
return 0, n, err
}

rv := uint32(buf[3]) | uint32(buf[2])<<8 |
uint32(buf[1])<<16 | uint32(buf[0])<<24
rv := uint32(d.scratchBuf[3]) | uint32(d.scratchBuf[2])<<8 |
uint32(d.scratchBuf[1])<<16 | uint32(d.scratchBuf[0])<<24
return rv, n, nil
}

Expand Down Expand Up @@ -197,18 +197,17 @@ func (d *Decoder) DecodeBool() (bool, int, error) {
// RFC Section 4.5 - Hyper Integer
// 64-bit big-endian signed integer in range [-9223372036854775808, 9223372036854775807]
func (d *Decoder) DecodeHyper() (int64, int, error) {
var buf [8]byte
n, err := io.ReadFull(d.r, buf[:])
n, err := io.ReadFull(d.r, d.scratchBuf[:8])
if err != nil {
msg := fmt.Sprintf(errIODecode, err.Error(), 8)
err := unmarshalError("DecodeHyper", ErrIO, msg, buf[:n], err)
err := unmarshalError("DecodeHyper", ErrIO, msg, d.scratchBuf[:n], err)
return 0, n, err
}

rv := int64(buf[7]) | int64(buf[6])<<8 |
int64(buf[5])<<16 | int64(buf[4])<<24 |
int64(buf[3])<<32 | int64(buf[2])<<40 |
int64(buf[1])<<48 | int64(buf[0])<<56
rv := int64(d.scratchBuf[7]) | int64(d.scratchBuf[6])<<8 |
int64(d.scratchBuf[5])<<16 | int64(d.scratchBuf[4])<<24 |
int64(d.scratchBuf[3])<<32 | int64(d.scratchBuf[2])<<40 |
int64(d.scratchBuf[1])<<48 | int64(d.scratchBuf[0])<<56
return rv, n, err
}

Expand All @@ -222,18 +221,17 @@ func (d *Decoder) DecodeHyper() (int64, int, error) {
// RFC Section 4.5 - Unsigned Hyper Integer
// 64-bit big-endian unsigned integer in range [0, 18446744073709551615]
func (d *Decoder) DecodeUhyper() (uint64, int, error) {
var buf [8]byte
n, err := io.ReadFull(d.r, buf[:])
n, err := io.ReadFull(d.r, d.scratchBuf[:8])
if err != nil {
msg := fmt.Sprintf(errIODecode, err.Error(), 8)
err := unmarshalError("DecodeUhyper", ErrIO, msg, buf[:n], err)
err := unmarshalError("DecodeUhyper", ErrIO, msg, d.scratchBuf[:n], err)
return 0, n, err
}

rv := uint64(buf[7]) | uint64(buf[6])<<8 |
uint64(buf[5])<<16 | uint64(buf[4])<<24 |
uint64(buf[3])<<32 | uint64(buf[2])<<40 |
uint64(buf[1])<<48 | uint64(buf[0])<<56
rv := uint64(d.scratchBuf[7]) | uint64(d.scratchBuf[6])<<8 |
uint64(d.scratchBuf[5])<<16 | uint64(d.scratchBuf[4])<<24 |
uint64(d.scratchBuf[3])<<32 | uint64(d.scratchBuf[2])<<40 |
uint64(d.scratchBuf[1])<<48 | uint64(d.scratchBuf[0])<<56
return rv, n, nil
}

Expand All @@ -246,16 +244,15 @@ func (d *Decoder) DecodeUhyper() (uint64, int, error) {
// RFC Section 4.6 - Floating Point
// 32-bit single-precision IEEE 754 floating point
func (d *Decoder) DecodeFloat() (float32, int, error) {
var buf [4]byte
n, err := io.ReadFull(d.r, buf[:])
n, err := io.ReadFull(d.r, d.scratchBuf[:4])
if err != nil {
msg := fmt.Sprintf(errIODecode, err.Error(), 4)
err := unmarshalError("DecodeFloat", ErrIO, msg, buf[:n], err)
err := unmarshalError("DecodeFloat", ErrIO, msg, d.scratchBuf[:n], err)
return 0, n, err
}

val := uint32(buf[3]) | uint32(buf[2])<<8 |
uint32(buf[1])<<16 | uint32(buf[0])<<24
val := uint32(d.scratchBuf[3]) | uint32(d.scratchBuf[2])<<8 |
uint32(d.scratchBuf[1])<<16 | uint32(d.scratchBuf[0])<<24
return math.Float32frombits(val), n, nil
}

Expand All @@ -269,18 +266,17 @@ func (d *Decoder) DecodeFloat() (float32, int, error) {
// RFC Section 4.7 - Double-Precision Floating Point
// 64-bit double-precision IEEE 754 floating point
func (d *Decoder) DecodeDouble() (float64, int, error) {
var buf [8]byte
n, err := io.ReadFull(d.r, buf[:])
n, err := io.ReadFull(d.r, d.scratchBuf[:8])
if err != nil {
msg := fmt.Sprintf(errIODecode, err.Error(), 8)
err := unmarshalError("DecodeDouble", ErrIO, msg, buf[:n], err)
err := unmarshalError("DecodeDouble", ErrIO, msg, d.scratchBuf[:n], err)
return 0, n, err
}

val := uint64(buf[7]) | uint64(buf[6])<<8 |
uint64(buf[5])<<16 | uint64(buf[4])<<24 |
uint64(buf[3])<<32 | uint64(buf[2])<<40 |
uint64(buf[1])<<48 | uint64(buf[0])<<56
val := uint64(d.scratchBuf[7]) | uint64(d.scratchBuf[6])<<8 |
uint64(d.scratchBuf[5])<<16 | uint64(d.scratchBuf[4])<<24 |
uint64(d.scratchBuf[3])<<32 | uint64(d.scratchBuf[2])<<40 |
uint64(d.scratchBuf[1])<<48 | uint64(d.scratchBuf[0])<<56
return math.Float64frombits(val), n, nil
}

Expand All @@ -299,45 +295,51 @@ func (d *Decoder) DecodeDouble() (float64, int, error) {
// Reference:
// RFC Section 4.9 - Fixed-Length Opaque Data
// Fixed-length uninterpreted data zero-padded to a multiple of four
func (d *Decoder) DecodeFixedOpaque(size int32) ([]byte, int, error) {
func (d *Decoder) DecodeFixedOpaque(out []byte) (int, error) {
size := len(out)
// Nothing to do if size is 0.
if size == 0 {
return nil, 0, nil
return 0, nil
}

pad := (4 - (size % 4)) % 4
paddedSize := size + pad
if uint(paddedSize) > uint(maxInt32) {
err := unmarshalError("DecodeFixedOpaque", ErrOverflow,
errMaxSlice, paddedSize, nil)
return nil, 0, err
return 0, err
}

buf := make([]byte, paddedSize)
n, err := io.ReadFull(d.r, buf)
n, err := io.ReadFull(d.r, out)
if err != nil {
msg := fmt.Sprintf(errIODecode, err.Error(), paddedSize)
err := unmarshalError("DecodeFixedOpaque", ErrIO, msg, buf[:n],
msg := fmt.Sprintf(errIODecode, err.Error(), size)
err := unmarshalError("DecodeFixedOpaque", ErrIO, msg, out[:n],
err)
return nil, n, err
}

if !d.checkPadding(buf, size) {
msg := "non-zero padding"
err := unmarshalError("DecodeFixedOpaque", ErrIO, msg, buf[:n], nil)
return nil, n, err
return n, err
}

return buf[0:size], n, nil
}

func (d *Decoder) checkPadding(buf []byte, size int32) bool {
for _, pad := range buf[size:] {
if pad != 0x00 {
return false
if pad > 0 {
// the maximum value of pad is 3, so the scratch buffer should be enough
padding := d.scratchBuf[:pad]
n2, err := io.ReadFull(d.r, padding)
if err != nil {
msg := fmt.Sprintf(errIODecode, err.Error(), pad)
err := unmarshalError("DecodeFixedOpaque", ErrIO, msg, out[:n],
err)
return n, err
}
n += n2
// check all the padding bytes to be zero
for _, p := range padding {
if p != 0x00 {
msg := "non-zero padding"
err := unmarshalError("DecodeFixedOpaque", ErrIO, msg, padding[:n2], nil)
return n, err
}
}
}
return true

return n, nil
}

// DecodeOpaque treats the next bytes as variable length XDR encoded opaque
Expand Down Expand Up @@ -365,8 +367,8 @@ func (d *Decoder) DecodeOpaque(maxSize int) ([]byte, int, error) {
dataLen, nil)
return nil, n, err
}

rv, n2, err := d.DecodeFixedOpaque(int32(dataLen))
rv := make([]byte, dataLen)
n2, err := d.DecodeFixedOpaque(rv)
n += n2
if err != nil {
return nil, n, err
Expand Down Expand Up @@ -404,7 +406,8 @@ func (d *Decoder) DecodeString(maxSize int) (string, int, error) {
return "", n, err
}

opaque, n2, err := d.DecodeFixedOpaque(int32(dataLen))
opaque := make([]byte, dataLen)
n2, err := d.DecodeFixedOpaque(opaque)
n += n2
if err != nil {
return "", n, err
Expand All @@ -428,12 +431,8 @@ func (d *Decoder) decodeFixedArray(v reflect.Value, ignoreOpaque bool) (int, err
// Treat [#]byte (byte is alias for uint8) as opaque data unless
// ignored.
if !ignoreOpaque && v.Type().Elem().Kind() == reflect.Uint8 {
data, n, err := d.DecodeFixedOpaque(int32(v.Len()))
if err != nil {
return n, err
}
reflect.Copy(v, reflect.ValueOf(data))
return n, nil
dest := v.Slice(0, v.Len()).Bytes()
return d.DecodeFixedOpaque(dest)
}

// Decode each array element.
Expand Down Expand Up @@ -489,7 +488,8 @@ func (d *Decoder) decodeArray(v reflect.Value, ignoreOpaque bool, maxSize int) (

// Treat []byte (byte is alias for uint8) as opaque data unless ignored.
if !ignoreOpaque && v.Type().Elem().Kind() == reflect.Uint8 {
data, n2, err := d.DecodeFixedOpaque(int32(sliceLen))
data := make([]byte, sliceLen)
n2, err := d.DecodeFixedOpaque(data)
n += n2
if err != nil {
return n, err
Expand Down
7 changes: 5 additions & 2 deletions xdr3/decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,9 @@ func TestDecoder(t *testing.T) {
rv, n, err = dec.DecodeEnum(validEnums)
case fDecodeFixedOpaque:
want := test.wantVal.([]byte)
rv, n, err = dec.DecodeFixedOpaque(int32(len(want)))
buf := make([]byte, len(want))
n, err = dec.DecodeFixedOpaque(buf)
rv = buf
case fDecodeFloat:
rv, n, err = dec.DecodeFloat()
case fDecodeHyper:
Expand Down Expand Up @@ -1032,7 +1034,8 @@ func TestPaddedReads(t *testing.T) {

// opaque
dec := NewDecoder(bytes.NewReader([]byte{0x0, 0x0, 0x1, 0x1}))
_, _, err := dec.DecodeFixedOpaque(3)
out := make([]byte, 3, 3)
_, err := dec.DecodeFixedOpaque(out)
if err == nil {
t.Error("expected error when unmarshaling opaque with non-zero padding byte, got none")
}
Expand Down
Loading

0 comments on commit e50eac4

Please sign in to comment.