From e50eac4f7b6a1664cff3897e037077fda4535aa5 Mon Sep 17 00:00:00 2001 From: Alfonso Acosta Date: Tue, 2 Nov 2021 21:44:02 +0100 Subject: [PATCH] Reduce encoding and decoding allocations I reduced the allocation count by 1. Adding a scratch buffer to the Encoder/decoder. The scratch buffer is used as a temporary buffer. Before that, temporary buffers were allocated (in the heap since they escaped the stack) for basic type encoding/decoding (e.g. `DecodeInt()` `EncodeInt()` 2. Making `DecodeFixedOpaque()` decode in-place instead of allocating the result. Apart from reducing allocations, this removes the need of copying (as shown by `decodeFixedArray()`. The pre-existing (and admitedly very limited) benchmarks show a sizeable improvement: Before: ``` goos: darwin goarch: amd64 pkg: github.com/stellar/go-xdr/xdr3 cpu: Intel(R) Core(TM) i7-1068NG7 CPU @ 2.30GHz BenchmarkUnmarshal BenchmarkUnmarshal-8 3000 1039 ns/op 15.40 MB/s 152 B/op 11 allocs/op BenchmarkMarshal BenchmarkMarshal-8 3000 806.7 ns/op 19.83 MB/s 104 B/op 10 allocs/op PASS ``` After: ``` goos: darwin goarch: amd64 pkg: github.com/stellar/go-xdr/xdr3 cpu: Intel(R) Core(TM) i7-1068NG7 CPU @ 2.30GHz BenchmarkUnmarshal BenchmarkUnmarshal-8 3000 679.6 ns/op 23.54 MB/s 144 B/op 8 allocs/op BenchmarkMarshal BenchmarkMarshal-8 3000 609.1 ns/op 26.27 MB/s 104 B/op 8 allocs/op PASS ``` Both decoding and encoding both go down to 8 allocations from 11 (decoding) and 10 (encoding) allocations per operation. More context at https://github.com/stellar/go/issues/4022#issuecomment-957929343 --- xdr3/decode.go | 138 +++++++++++++++++++++---------------------- xdr3/decode_test.go | 7 ++- xdr3/encode.go | 76 ++++++++++++------------ xdr3/example_test.go | 4 +- 4 files changed, 113 insertions(+), 112 deletions(-) diff --git a/xdr3/decode.go b/xdr3/decode.go index e3627df..3083c26 100644 --- a/xdr3/decode.go +++ b/xdr3/decode.go @@ -90,7 +90,9 @@ func Unmarshal(r io.Reader, v interface{}) (int, error) { // necessary in complex scenarios where automatic reflection-based decoding // won't work. type Decoder struct { - r io.Reader + // used to minimize heap allocations during decoding + scratchBuf [8]byte + r io.Reader } // DecodeInt treats the next 4 bytes as an XDR encoded integer and returns the @@ -102,16 +104,15 @@ type Decoder struct { // RFC Section 4.1 - Integer // 32-bit big-endian signed integer in range [-2147483648, 2147483647] func (d *Decoder) DecodeInt() (int32, int, error) { - var buf [4]byte - n, err := io.ReadFull(d.r, buf[:]) + n, err := io.ReadFull(d.r, d.scratchBuf[:4]) if err != nil { msg := fmt.Sprintf(errIODecode, err.Error(), 4) - err := unmarshalError("DecodeInt", ErrIO, msg, buf[:n], err) + err := unmarshalError("DecodeInt", ErrIO, msg, d.scratchBuf[:n], err) return 0, n, err } - rv := int32(buf[3]) | int32(buf[2])<<8 | - int32(buf[1])<<16 | int32(buf[0])<<24 + rv := int32(d.scratchBuf[3]) | int32(d.scratchBuf[2])<<8 | + int32(d.scratchBuf[1])<<16 | int32(d.scratchBuf[0])<<24 return rv, n, nil } @@ -124,16 +125,15 @@ func (d *Decoder) DecodeInt() (int32, int, error) { // RFC Section 4.2 - Unsigned Integer // 32-bit big-endian unsigned integer in range [0, 4294967295] func (d *Decoder) DecodeUint() (uint32, int, error) { - var buf [4]byte - n, err := io.ReadFull(d.r, buf[:]) + n, err := io.ReadFull(d.r, d.scratchBuf[:4]) if err != nil { msg := fmt.Sprintf(errIODecode, err.Error(), 4) - err := unmarshalError("DecodeUint", ErrIO, msg, buf[:n], err) + err := unmarshalError("DecodeUint", ErrIO, msg, d.scratchBuf[:n], err) return 0, n, err } - rv := uint32(buf[3]) | uint32(buf[2])<<8 | - uint32(buf[1])<<16 | uint32(buf[0])<<24 + rv := uint32(d.scratchBuf[3]) | uint32(d.scratchBuf[2])<<8 | + uint32(d.scratchBuf[1])<<16 | uint32(d.scratchBuf[0])<<24 return rv, n, nil } @@ -197,18 +197,17 @@ func (d *Decoder) DecodeBool() (bool, int, error) { // RFC Section 4.5 - Hyper Integer // 64-bit big-endian signed integer in range [-9223372036854775808, 9223372036854775807] func (d *Decoder) DecodeHyper() (int64, int, error) { - var buf [8]byte - n, err := io.ReadFull(d.r, buf[:]) + n, err := io.ReadFull(d.r, d.scratchBuf[:8]) if err != nil { msg := fmt.Sprintf(errIODecode, err.Error(), 8) - err := unmarshalError("DecodeHyper", ErrIO, msg, buf[:n], err) + err := unmarshalError("DecodeHyper", ErrIO, msg, d.scratchBuf[:n], err) return 0, n, err } - rv := int64(buf[7]) | int64(buf[6])<<8 | - int64(buf[5])<<16 | int64(buf[4])<<24 | - int64(buf[3])<<32 | int64(buf[2])<<40 | - int64(buf[1])<<48 | int64(buf[0])<<56 + rv := int64(d.scratchBuf[7]) | int64(d.scratchBuf[6])<<8 | + int64(d.scratchBuf[5])<<16 | int64(d.scratchBuf[4])<<24 | + int64(d.scratchBuf[3])<<32 | int64(d.scratchBuf[2])<<40 | + int64(d.scratchBuf[1])<<48 | int64(d.scratchBuf[0])<<56 return rv, n, err } @@ -222,18 +221,17 @@ func (d *Decoder) DecodeHyper() (int64, int, error) { // RFC Section 4.5 - Unsigned Hyper Integer // 64-bit big-endian unsigned integer in range [0, 18446744073709551615] func (d *Decoder) DecodeUhyper() (uint64, int, error) { - var buf [8]byte - n, err := io.ReadFull(d.r, buf[:]) + n, err := io.ReadFull(d.r, d.scratchBuf[:8]) if err != nil { msg := fmt.Sprintf(errIODecode, err.Error(), 8) - err := unmarshalError("DecodeUhyper", ErrIO, msg, buf[:n], err) + err := unmarshalError("DecodeUhyper", ErrIO, msg, d.scratchBuf[:n], err) return 0, n, err } - rv := uint64(buf[7]) | uint64(buf[6])<<8 | - uint64(buf[5])<<16 | uint64(buf[4])<<24 | - uint64(buf[3])<<32 | uint64(buf[2])<<40 | - uint64(buf[1])<<48 | uint64(buf[0])<<56 + rv := uint64(d.scratchBuf[7]) | uint64(d.scratchBuf[6])<<8 | + uint64(d.scratchBuf[5])<<16 | uint64(d.scratchBuf[4])<<24 | + uint64(d.scratchBuf[3])<<32 | uint64(d.scratchBuf[2])<<40 | + uint64(d.scratchBuf[1])<<48 | uint64(d.scratchBuf[0])<<56 return rv, n, nil } @@ -246,16 +244,15 @@ func (d *Decoder) DecodeUhyper() (uint64, int, error) { // RFC Section 4.6 - Floating Point // 32-bit single-precision IEEE 754 floating point func (d *Decoder) DecodeFloat() (float32, int, error) { - var buf [4]byte - n, err := io.ReadFull(d.r, buf[:]) + n, err := io.ReadFull(d.r, d.scratchBuf[:4]) if err != nil { msg := fmt.Sprintf(errIODecode, err.Error(), 4) - err := unmarshalError("DecodeFloat", ErrIO, msg, buf[:n], err) + err := unmarshalError("DecodeFloat", ErrIO, msg, d.scratchBuf[:n], err) return 0, n, err } - val := uint32(buf[3]) | uint32(buf[2])<<8 | - uint32(buf[1])<<16 | uint32(buf[0])<<24 + val := uint32(d.scratchBuf[3]) | uint32(d.scratchBuf[2])<<8 | + uint32(d.scratchBuf[1])<<16 | uint32(d.scratchBuf[0])<<24 return math.Float32frombits(val), n, nil } @@ -269,18 +266,17 @@ func (d *Decoder) DecodeFloat() (float32, int, error) { // RFC Section 4.7 - Double-Precision Floating Point // 64-bit double-precision IEEE 754 floating point func (d *Decoder) DecodeDouble() (float64, int, error) { - var buf [8]byte - n, err := io.ReadFull(d.r, buf[:]) + n, err := io.ReadFull(d.r, d.scratchBuf[:8]) if err != nil { msg := fmt.Sprintf(errIODecode, err.Error(), 8) - err := unmarshalError("DecodeDouble", ErrIO, msg, buf[:n], err) + err := unmarshalError("DecodeDouble", ErrIO, msg, d.scratchBuf[:n], err) return 0, n, err } - val := uint64(buf[7]) | uint64(buf[6])<<8 | - uint64(buf[5])<<16 | uint64(buf[4])<<24 | - uint64(buf[3])<<32 | uint64(buf[2])<<40 | - uint64(buf[1])<<48 | uint64(buf[0])<<56 + val := uint64(d.scratchBuf[7]) | uint64(d.scratchBuf[6])<<8 | + uint64(d.scratchBuf[5])<<16 | uint64(d.scratchBuf[4])<<24 | + uint64(d.scratchBuf[3])<<32 | uint64(d.scratchBuf[2])<<40 | + uint64(d.scratchBuf[1])<<48 | uint64(d.scratchBuf[0])<<56 return math.Float64frombits(val), n, nil } @@ -299,10 +295,11 @@ func (d *Decoder) DecodeDouble() (float64, int, error) { // Reference: // RFC Section 4.9 - Fixed-Length Opaque Data // Fixed-length uninterpreted data zero-padded to a multiple of four -func (d *Decoder) DecodeFixedOpaque(size int32) ([]byte, int, error) { +func (d *Decoder) DecodeFixedOpaque(out []byte) (int, error) { + size := len(out) // Nothing to do if size is 0. if size == 0 { - return nil, 0, nil + return 0, nil } pad := (4 - (size % 4)) % 4 @@ -310,34 +307,39 @@ func (d *Decoder) DecodeFixedOpaque(size int32) ([]byte, int, error) { if uint(paddedSize) > uint(maxInt32) { err := unmarshalError("DecodeFixedOpaque", ErrOverflow, errMaxSlice, paddedSize, nil) - return nil, 0, err + return 0, err } - buf := make([]byte, paddedSize) - n, err := io.ReadFull(d.r, buf) + n, err := io.ReadFull(d.r, out) if err != nil { - msg := fmt.Sprintf(errIODecode, err.Error(), paddedSize) - err := unmarshalError("DecodeFixedOpaque", ErrIO, msg, buf[:n], + msg := fmt.Sprintf(errIODecode, err.Error(), size) + err := unmarshalError("DecodeFixedOpaque", ErrIO, msg, out[:n], err) - return nil, n, err - } - - if !d.checkPadding(buf, size) { - msg := "non-zero padding" - err := unmarshalError("DecodeFixedOpaque", ErrIO, msg, buf[:n], nil) - return nil, n, err + return n, err } - return buf[0:size], n, nil -} - -func (d *Decoder) checkPadding(buf []byte, size int32) bool { - for _, pad := range buf[size:] { - if pad != 0x00 { - return false + if pad > 0 { + // the maximum value of pad is 3, so the scratch buffer should be enough + padding := d.scratchBuf[:pad] + n2, err := io.ReadFull(d.r, padding) + if err != nil { + msg := fmt.Sprintf(errIODecode, err.Error(), pad) + err := unmarshalError("DecodeFixedOpaque", ErrIO, msg, out[:n], + err) + return n, err + } + n += n2 + // check all the padding bytes to be zero + for _, p := range padding { + if p != 0x00 { + msg := "non-zero padding" + err := unmarshalError("DecodeFixedOpaque", ErrIO, msg, padding[:n2], nil) + return n, err + } } } - return true + + return n, nil } // DecodeOpaque treats the next bytes as variable length XDR encoded opaque @@ -365,8 +367,8 @@ func (d *Decoder) DecodeOpaque(maxSize int) ([]byte, int, error) { dataLen, nil) return nil, n, err } - - rv, n2, err := d.DecodeFixedOpaque(int32(dataLen)) + rv := make([]byte, dataLen) + n2, err := d.DecodeFixedOpaque(rv) n += n2 if err != nil { return nil, n, err @@ -404,7 +406,8 @@ func (d *Decoder) DecodeString(maxSize int) (string, int, error) { return "", n, err } - opaque, n2, err := d.DecodeFixedOpaque(int32(dataLen)) + opaque := make([]byte, dataLen) + n2, err := d.DecodeFixedOpaque(opaque) n += n2 if err != nil { return "", n, err @@ -428,12 +431,8 @@ func (d *Decoder) decodeFixedArray(v reflect.Value, ignoreOpaque bool) (int, err // Treat [#]byte (byte is alias for uint8) as opaque data unless // ignored. if !ignoreOpaque && v.Type().Elem().Kind() == reflect.Uint8 { - data, n, err := d.DecodeFixedOpaque(int32(v.Len())) - if err != nil { - return n, err - } - reflect.Copy(v, reflect.ValueOf(data)) - return n, nil + dest := v.Slice(0, v.Len()).Bytes() + return d.DecodeFixedOpaque(dest) } // Decode each array element. @@ -489,7 +488,8 @@ func (d *Decoder) decodeArray(v reflect.Value, ignoreOpaque bool, maxSize int) ( // Treat []byte (byte is alias for uint8) as opaque data unless ignored. if !ignoreOpaque && v.Type().Elem().Kind() == reflect.Uint8 { - data, n2, err := d.DecodeFixedOpaque(int32(sliceLen)) + data := make([]byte, sliceLen) + n2, err := d.DecodeFixedOpaque(data) n += n2 if err != nil { return n, err diff --git a/xdr3/decode_test.go b/xdr3/decode_test.go index bc575f6..06b1310 100644 --- a/xdr3/decode_test.go +++ b/xdr3/decode_test.go @@ -686,7 +686,9 @@ func TestDecoder(t *testing.T) { rv, n, err = dec.DecodeEnum(validEnums) case fDecodeFixedOpaque: want := test.wantVal.([]byte) - rv, n, err = dec.DecodeFixedOpaque(int32(len(want))) + buf := make([]byte, len(want)) + n, err = dec.DecodeFixedOpaque(buf) + rv = buf case fDecodeFloat: rv, n, err = dec.DecodeFloat() case fDecodeHyper: @@ -1032,7 +1034,8 @@ func TestPaddedReads(t *testing.T) { // opaque dec := NewDecoder(bytes.NewReader([]byte{0x0, 0x0, 0x1, 0x1})) - _, _, err := dec.DecodeFixedOpaque(3) + out := make([]byte, 3, 3) + _, err := dec.DecodeFixedOpaque(out) if err == nil { t.Error("expected error when unmarshaling opaque with non-zero padding byte, got none") } diff --git a/xdr3/encode.go b/xdr3/encode.go index b8f78a2..b6d7cdc 100644 --- a/xdr3/encode.go +++ b/xdr3/encode.go @@ -80,7 +80,9 @@ func Marshal(w io.Writer, v interface{}) (int, error) { // An Encoder wraps an io.Writer that will receive the XDR encoded byte stream. // See NewEncoder. type Encoder struct { - w io.Writer + // used to minimize heap allocations during encoding + scratchBuf [8]byte + w io.Writer } // EncodeInt writes the XDR encoded representation of the passed 32-bit signed @@ -93,16 +95,15 @@ type Encoder struct { // RFC Section 4.1 - Integer // 32-bit big-endian signed integer in range [-2147483648, 2147483647] func (enc *Encoder) EncodeInt(v int32) (int, error) { - var b [4]byte - b[0] = byte(v >> 24) - b[1] = byte(v >> 16) - b[2] = byte(v >> 8) - b[3] = byte(v) + enc.scratchBuf[0] = byte(v >> 24) + enc.scratchBuf[1] = byte(v >> 16) + enc.scratchBuf[2] = byte(v >> 8) + enc.scratchBuf[3] = byte(v) - n, err := enc.w.Write(b[:]) + n, err := enc.w.Write(enc.scratchBuf[:4]) if err != nil { msg := fmt.Sprintf(errIOEncode, err.Error(), 4) - err := marshalError("EncodeInt", ErrIO, msg, b[:n], err) + err := marshalError("EncodeInt", ErrIO, msg, enc.scratchBuf[:n], err) return n, err } @@ -120,16 +121,15 @@ func (enc *Encoder) EncodeInt(v int32) (int, error) { // RFC Section 4.2 - Unsigned Integer // 32-bit big-endian unsigned integer in range [0, 4294967295] func (enc *Encoder) EncodeUint(v uint32) (int, error) { - var b [4]byte - b[0] = byte(v >> 24) - b[1] = byte(v >> 16) - b[2] = byte(v >> 8) - b[3] = byte(v) + enc.scratchBuf[0] = byte(v >> 24) + enc.scratchBuf[1] = byte(v >> 16) + enc.scratchBuf[2] = byte(v >> 8) + enc.scratchBuf[3] = byte(v) - n, err := enc.w.Write(b[:]) + n, err := enc.w.Write(enc.scratchBuf[:4]) if err != nil { msg := fmt.Sprintf(errIOEncode, err.Error(), 4) - err := marshalError("EncodeUint", ErrIO, msg, b[:n], err) + err := marshalError("EncodeUint", ErrIO, msg, enc.scratchBuf[:4], err) return n, err } @@ -184,20 +184,19 @@ func (enc *Encoder) EncodeBool(v bool) (int, error) { // RFC Section 4.5 - Hyper Integer // 64-bit big-endian signed integer in range [-9223372036854775808, 9223372036854775807] func (enc *Encoder) EncodeHyper(v int64) (int, error) { - var b [8]byte - b[0] = byte(v >> 56) - b[1] = byte(v >> 48) - b[2] = byte(v >> 40) - b[3] = byte(v >> 32) - b[4] = byte(v >> 24) - b[5] = byte(v >> 16) - b[6] = byte(v >> 8) - b[7] = byte(v) - - n, err := enc.w.Write(b[:]) + enc.scratchBuf[0] = byte(v >> 56) + enc.scratchBuf[1] = byte(v >> 48) + enc.scratchBuf[2] = byte(v >> 40) + enc.scratchBuf[3] = byte(v >> 32) + enc.scratchBuf[4] = byte(v >> 24) + enc.scratchBuf[5] = byte(v >> 16) + enc.scratchBuf[6] = byte(v >> 8) + enc.scratchBuf[7] = byte(v) + + n, err := enc.w.Write(enc.scratchBuf[:8]) if err != nil { msg := fmt.Sprintf(errIOEncode, err.Error(), 8) - err := marshalError("EncodeHyper", ErrIO, msg, b[:n], err) + err := marshalError("EncodeHyper", ErrIO, msg, enc.scratchBuf[:8], err) return n, err } @@ -215,20 +214,19 @@ func (enc *Encoder) EncodeHyper(v int64) (int, error) { // RFC Section 4.5 - Unsigned Hyper Integer // 64-bit big-endian unsigned integer in range [0, 18446744073709551615] func (enc *Encoder) EncodeUhyper(v uint64) (int, error) { - var b [8]byte - b[0] = byte(v >> 56) - b[1] = byte(v >> 48) - b[2] = byte(v >> 40) - b[3] = byte(v >> 32) - b[4] = byte(v >> 24) - b[5] = byte(v >> 16) - b[6] = byte(v >> 8) - b[7] = byte(v) - - n, err := enc.w.Write(b[:]) + enc.scratchBuf[0] = byte(v >> 56) + enc.scratchBuf[1] = byte(v >> 48) + enc.scratchBuf[2] = byte(v >> 40) + enc.scratchBuf[3] = byte(v >> 32) + enc.scratchBuf[4] = byte(v >> 24) + enc.scratchBuf[5] = byte(v >> 16) + enc.scratchBuf[6] = byte(v >> 8) + enc.scratchBuf[7] = byte(v) + + n, err := enc.w.Write(enc.scratchBuf[:8]) if err != nil { msg := fmt.Sprintf(errIOEncode, err.Error(), 8) - err := marshalError("EncodeUhyper", ErrIO, msg, b[:n], err) + err := marshalError("EncodeUhyper", ErrIO, msg, enc.scratchBuf[:n], err) return n, err } diff --git a/xdr3/example_test.go b/xdr3/example_test.go index ce2f52d..53d3029 100644 --- a/xdr3/example_test.go +++ b/xdr3/example_test.go @@ -113,8 +113,8 @@ func ExampleNewDecoder() { // Get a new decoder for manual decoding. dec := xdr.NewDecoder(bytes.NewReader(encodedData)) - signature, _, err := dec.DecodeFixedOpaque(3) - if err != nil { + signature := make([]byte, 3) + if _, err := dec.DecodeFixedOpaque(signature); err != nil { fmt.Println(err) return }