diff --git a/pkg/streamable/streamable_test.go b/pkg/streamable/streamable_test.go index de0bb76..ac411d6 100644 --- a/pkg/streamable/streamable_test.go +++ b/pkg/streamable/streamable_test.go @@ -227,8 +227,10 @@ func TestUnmarshal_ResponseBlock(t *testing.T) { encodedBytes, err := hex.DecodeString(hexStr) assert.NoError(t, err) - handshake := &protocols.RespondBlock{} + respondBlock := &protocols.RespondBlock{} - err = streamable.Unmarshal(encodedBytes, handshake) + err = streamable.Unmarshal(encodedBytes, respondBlock) assert.NoError(t, err) + + assert.Equal(t, 5_109_110, int(respondBlock.Block.RewardChainBlock.Height)) } diff --git a/pkg/types/program.go b/pkg/types/program.go index 9ca879b..9c14ad8 100644 --- a/pkg/types/program.go +++ b/pkg/types/program.go @@ -1,18 +1,25 @@ package types import ( + "bytes" + "encoding/binary" "encoding/json" "errors" - "fmt" + "io" ) // SerializedProgram An opaque representation of a clvm program. It has a more limited interface than a full SExp // https://github.com/Chia-Network/chia-blockchain/blob/main/chia/types/blockchain_format/program.py#L232 type SerializedProgram Bytes -const MAX_SINGLE_BYTE byte = 0x7f -const BACK_REFERENCE byte = 0xfe -const CONS_BOX_MARKER byte = 0xff +// MaxSingleByte Max single byte +const MaxSingleByte byte = 0x7f + +// BackReference back referencee marker +const BackReference byte = 0xfe + +// ConsBoxMarker cons box marker +const ConsBoxMarker byte = 0xff const ( badEncErr = "bad encoding" @@ -37,68 +44,76 @@ func (g *SerializedProgram) UnmarshalJSON(data []byte) error { return nil } +// SerializedLengthFromBytesTrusted returns the length func SerializedLengthFromBytesTrusted(b []byte) (uint64, error) { + reader := bytes.NewReader(b) var opsCounter uint64 = 1 - var position uint64 = 0 - start := len(b) for opsCounter > 0 { opsCounter-- - if len(b) == 0 { - return 0, errors.New("unexpected end of input") + + var currentByte byte + err := binary.Read(reader, binary.BigEndian, ¤tByte) + if err != nil { + if err == io.EOF { + return 0, errors.New("unexpected end of input") + } + return 0, err } - currentByte := b[0] - b = b[1:] - position++ - if currentByte == CONS_BOX_MARKER { + if currentByte == ConsBoxMarker { opsCounter += 2 - } else if currentByte == BACK_REFERENCE { - if len(b) == 0 { + } else if currentByte == BackReference { + var firstByte byte + err = binary.Read(reader, binary.BigEndian, &firstByte) + if err != nil { return 0, errors.New("unexpected end of input") } - firstByte := b[0] - b = b[1:] - position++ - if firstByte > MAX_SINGLE_BYTE { - _, length, err := decodeSize(b, firstByte) + if firstByte > MaxSingleByte { + pathSize, err := decodeSize(reader, firstByte) if err != nil { return 0, err } - b = b[length:] - position += length + _, err = reader.Seek(int64(pathSize), io.SeekCurrent) + if err != nil { + return 0, errors.New("bad encoding") + } } - } else if currentByte == 0x80 || currentByte <= MAX_SINGLE_BYTE { - // This one byte we just read was the whole atom. - // or the special case of NIL + } else if currentByte == 0x80 || currentByte <= MaxSingleByte { + // This one byte we just read was the whole atom or the special case of NIL. } else { - _, length, err := decodeSize(b, currentByte) + blobSize, err := decodeSize(reader, currentByte) if err != nil { return 0, err } - b = b[length:] - position += length + _, err = reader.Seek(int64(blobSize), io.SeekCurrent) + if err != nil { + return 0, errors.New("bad encoding") + } } - } - fmt.Println("read bytes", start, start-len(b), position) - - return position, nil + position, err := reader.Seek(0, io.SeekCurrent) + if err != nil { + return 0, err + } + return uint64(position), nil } -func decodeSize(input []byte, initialB byte) (byte, uint64, error) { +func decodeSize(reader *bytes.Reader, initialB byte) (uint64, error) { + _, length, err := decodeSizeWithOffset(reader, initialB) + return length, err +} +func decodeSizeWithOffset(reader *bytes.Reader, initialB byte) (uint64, uint64, error) { bitMask := byte(0x80) if (initialB & bitMask) == 0 { return 0, 0, errors.New(internalErr) } - var atomStartOffset byte - + var atomStartOffset uint64 = 0 b := initialB - for (b & bitMask) != 0 { atomStartOffset++ b &= 0xff ^ bitMask @@ -109,7 +124,11 @@ func decodeSize(input []byte, initialB byte) (byte, uint64, error) { sizeBlob[0] = b if atomStartOffset > 1 { - copy(sizeBlob[1:], input) + // We need to read atomStartOffset-1 more bytes + _, err := io.ReadFull(reader, sizeBlob[1:]) + if err != nil { + return 0, 0, err + } } var atomSize uint64 = 0 diff --git a/pkg/types/program_test.go b/pkg/types/program_test.go index 31e49a4..9904859 100644 --- a/pkg/types/program_test.go +++ b/pkg/types/program_test.go @@ -1,6 +1,7 @@ package types import ( + "bytes" "encoding/hex" "github.com/stretchr/testify/assert" "testing" @@ -49,6 +50,20 @@ func TestSerializedLengthFromBytesTrusted(t *testing.T) { assert.NoError(t, err) assert.Equal(t, len(encodedBytes), int(length)) + hexStr = "900cecb8f27d268c2ac73fe5b520db3813" + encodedBytes, err = hex.DecodeString(hexStr) + assert.NoError(t, err) + length, err = SerializedLengthFromBytesTrusted(encodedBytes) + assert.NoError(t, err) + assert.Equal(t, len(encodedBytes), int(length)) + + hexStr = "c059697066733a2f2f62616679626569687478796637737462356b78787473756d77326e6f34326766736871676a687837646d696e706776627468697433616a70336f792f5065706542656172732d25323028333437292e706e67" + encodedBytes, err = hex.DecodeString(hexStr) + assert.NoError(t, err) + length, err = SerializedLengthFromBytesTrusted(encodedBytes) + assert.NoError(t, err) + assert.Equal(t, len(encodedBytes), int(length)) + length, err = SerializedLengthFromBytesTrusted([]byte{0x7f, 0x00, 0x00, 0x00}) assert.NoError(t, err) assert.Equal(t, 1, int(length)) @@ -77,27 +92,27 @@ func TestSerializedLengthFromBytesTrusted(t *testing.T) { func TestDecodeSize(t *testing.T) { - _, length, err := decodeSize([]byte{}, 0x80|0x20) + length, err := decodeSize(bytes.NewReader([]byte{}), 0x80|0x20) assert.NoError(t, err) assert.Equal(t, 32, int(length)) - _, length, err = decodeSize([]byte{0xaa}, 0b11001111) + length, err = decodeSize(bytes.NewReader([]byte{0xaa}), 0b11001111) assert.NoError(t, err) assert.Equal(t, 4010, int(length)) - _, length, err = decodeSize([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 0b11111110) + _, err = decodeSize(bytes.NewReader([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}), 0b11111110) assert.Error(t, err, "bad encoding") - _, length, err = decodeSize([]byte{0x4, 0, 0, 0, 0}, 0b11111100) + _, err = decodeSize(bytes.NewReader([]byte{0x4, 0, 0, 0, 0}), 0b11111100) assert.Error(t, err, "bad encoding") - _, length, err = decodeSize([]byte{0x3, 0xff, 0xff, 0xff, 0xff}, 0b11111100) + length, err = decodeSize(bytes.NewReader([]byte{0x3, 0xff, 0xff, 0xff, 0xff}), 0b11111100) assert.NoError(t, err) assert.Equal(t, 17179869183, int(length)) - _, _, err = decodeSize([]byte{0xff, 0xfe}, 0b11111100) + _, err = decodeSize(bytes.NewReader([]byte{0xff, 0xfe}), 0b11111100) assert.Error(t, err, "bad encoding") - _, _, err = decodeSize([]byte{0x4, 0, 0, 0}, 0b11111100) + _, err = decodeSize(bytes.NewReader([]byte{0x4, 0, 0, 0}), 0b11111100) assert.Error(t, err, "bad encoding") }