diff --git a/beacon-chain/p2p/encoder/ssz.go b/beacon-chain/p2p/encoder/ssz.go index f8872d9cc682..eb328e6680b8 100644 --- a/beacon-chain/p2p/encoder/ssz.go +++ b/beacon-chain/p2p/encoder/ssz.go @@ -1,6 +1,7 @@ package encoder import ( + "bytes" "fmt" "io" @@ -18,14 +19,7 @@ type SszNetworkEncoder struct { } func (e SszNetworkEncoder) doEncode(msg interface{}) ([]byte, error) { - b, err := ssz.Marshal(msg) - if err != nil { - return nil, err - } - if e.UseSnappyCompression { - b = snappy.Encode(nil /*dst*/, b) - } - return b, nil + return ssz.Marshal(msg) } // Encode the proto message to the io.Writer. @@ -33,11 +27,13 @@ func (e SszNetworkEncoder) Encode(w io.Writer, msg interface{}) (int, error) { if msg == nil { return 0, nil } - b, err := e.doEncode(msg) if err != nil { return 0, err } + if e.UseSnappyCompression { + return writeSnappyBuffer(w, b) + } return w.Write(b) } @@ -51,7 +47,14 @@ func (e SszNetworkEncoder) EncodeWithLength(w io.Writer, msg interface{}) (int, if err != nil { return 0, err } - b = append(proto.EncodeVarint(uint64(len(b))), b...) + // write varint first + _, err = w.Write(proto.EncodeVarint(uint64(len(b)))) + if err != nil { + return 0, err + } + if e.UseSnappyCompression { + return writeSnappyBuffer(w, b) + } return w.Write(b) } @@ -68,21 +71,34 @@ func (e SszNetworkEncoder) EncodeWithMaxLength(w io.Writer, msg interface{}, max if uint64(len(b)) > maxSize { return 0, fmt.Errorf("size of encoded message is %d which is larger than the provided max limit of %d", len(b), maxSize) } - b = append(proto.EncodeVarint(uint64(len(b))), b...) + // write varint first + _, err = w.Write(proto.EncodeVarint(uint64(len(b)))) + if err != nil { + return 0, err + } + if e.UseSnappyCompression { + return writeSnappyBuffer(w, b) + } return w.Write(b) } +func (e SszNetworkEncoder) doDecode(b []byte, to interface{}) error { + return ssz.Unmarshal(b, to) +} + // Decode the bytes to the protobuf message provided. func (e SszNetworkEncoder) Decode(b []byte, to interface{}) error { if e.UseSnappyCompression { - var err error - b, err = snappy.Decode(nil /*dst*/, b) + newBuffer := bytes.NewBuffer(b) + r := snappy.NewReader(newBuffer) + newObj := make([]byte, len(b)) + numOfBytes, err := r.Read(newObj) if err != nil { return err } + return e.doDecode(newObj[:numOfBytes], to) } - - return ssz.Unmarshal(b, to) + return e.doDecode(b, to) } // DecodeWithLength the bytes from io.Reader to the protobuf message provided. @@ -91,12 +107,15 @@ func (e SszNetworkEncoder) DecodeWithLength(r io.Reader, to interface{}) error { if err != nil { return err } - b := make([]byte, msgLen) - _, err = r.Read(b) + if e.UseSnappyCompression { + r = snappy.NewReader(r) + } + b := make([]byte, e.MaxLength(int(msgLen))) + numOfBytes, err := r.Read(b) if err != nil { return err } - return e.Decode(b, to) + return e.doDecode(b[:numOfBytes], to) } // DecodeWithMaxLength the bytes from io.Reader to the protobuf message provided. @@ -106,15 +125,18 @@ func (e SszNetworkEncoder) DecodeWithMaxLength(r io.Reader, to interface{}, maxS if err != nil { return err } + if e.UseSnappyCompression { + r = snappy.NewReader(r) + } if msgLen > maxSize { return fmt.Errorf("size of decoded message is %d which is larger than the provided max limit of %d", msgLen, maxSize) } - b := make([]byte, msgLen) - _, err = r.Read(b) + b := make([]byte, e.MaxLength(int(msgLen))) + numOfBytes, err := r.Read(b) if err != nil { return err } - return e.Decode(b, to) + return e.doDecode(b[:numOfBytes], to) } // ProtocolSuffix returns the appropriate suffix for protocol IDs. @@ -124,3 +146,19 @@ func (e SszNetworkEncoder) ProtocolSuffix() string { } return "/ssz" } + +// MaxLength specifies the maximum possible length of an encoded +// chunk of data. +func (e SszNetworkEncoder) MaxLength(length int) int { + if e.UseSnappyCompression { + return snappy.MaxEncodedLen(length) + } + return length +} + +// Writes a bytes value through a snappy buffered writer. +func writeSnappyBuffer(w io.Writer, b []byte) (int, error) { + bufWriter := snappy.NewBufferedWriter(w) + defer bufWriter.Close() + return bufWriter.Write(b) +}