diff --git a/blockchain/fullblocktests/generate.go b/blockchain/fullblocktests/generate.go index 4c551c05e0..963926e781 100644 --- a/blockchain/fullblocktests/generate.go +++ b/blockchain/fullblocktests/generate.go @@ -384,6 +384,8 @@ func additionalCoinbase(amount btcutil.Amount) func(*wire.MsgBlock) { // Increase the first proof-of-work coinbase subsidy by the // provided amount. b.Transactions[0].TxOut[0].Value += int64(amount) + + b.Transactions[0].WipeCache() } } @@ -402,6 +404,8 @@ func additionalSpendFee(fee btcutil.Amount) func(*wire.MsgBlock) { fee)) } b.Transactions[1].TxOut[0].Value -= int64(fee) + + b.Transactions[1].WipeCache() } } @@ -410,6 +414,8 @@ func additionalSpendFee(fee btcutil.Amount) func(*wire.MsgBlock) { func replaceSpendScript(pkScript []byte) func(*wire.MsgBlock) { return func(b *wire.MsgBlock) { b.Transactions[1].TxOut[0].PkScript = pkScript + + b.Transactions[1].WipeCache() } } @@ -418,6 +424,8 @@ func replaceSpendScript(pkScript []byte) func(*wire.MsgBlock) { func replaceCoinbaseSigScript(script []byte) func(*wire.MsgBlock) { return func(b *wire.MsgBlock) { b.Transactions[0].TxIn[0].SignatureScript = script + + b.Transactions[0].WipeCache() } } diff --git a/btcutil/txsort/txsort_test.go b/btcutil/txsort/txsort_test.go index dd2149294e..d196d6d045 100644 --- a/btcutil/txsort/txsort_test.go +++ b/btcutil/txsort/txsort_test.go @@ -114,6 +114,9 @@ func TestSort(t *testing.T) { // Now sort the transaction using the mutable version and ensure // the resulting hash is the expected value. txsort.InPlaceSort(&tx) + + tx.WipeCache() + if got := tx.TxHash().String(); got != test.sortedHash { t.Errorf("SortMutate (%s): sorted hash does not match "+ "expected - got %v, want %v", test.name, got, diff --git a/wire/message_test.go b/wire/message_test.go index 7ba2e0639f..8b0187e4cd 100644 --- a/wire/message_test.go +++ b/wire/message_test.go @@ -137,6 +137,18 @@ func TestMessage(t *testing.T) { spew.Sdump(msg)) continue } + + // Blank out the cached encoding for transactions to ensure the + // deep equality check doesn't fail. + if tx, ok := msg.(*MsgTx); ok { + tx.cachedSeralizedNoWitness = nil + } + if block, ok := msg.(*MsgBlock); ok { + for _, tx := range block.Transactions { + tx.cachedSeralizedNoWitness = nil + } + } + if !reflect.DeepEqual(msg, test.out) { t.Errorf("ReadMessage #%d\n got: %v want: %v", i, spew.Sdump(msg), spew.Sdump(test.out)) @@ -170,6 +182,18 @@ func TestMessage(t *testing.T) { spew.Sdump(msg)) continue } + + // Blank out the cached encoding for transactions to ensure the + // deep equality check doesn't fail. + if tx, ok := msg.(*MsgTx); ok { + tx.cachedSeralizedNoWitness = nil + } + if block, ok := msg.(*MsgBlock); ok { + for _, tx := range block.Transactions { + tx.cachedSeralizedNoWitness = nil + } + } + if !reflect.DeepEqual(msg, test.out) { t.Errorf("ReadMessage #%d\n got: %v want: %v", i, spew.Sdump(msg), spew.Sdump(test.out)) diff --git a/wire/msgtx.go b/wire/msgtx.go index 7705504cc8..49291b422d 100644 --- a/wire/msgtx.go +++ b/wire/msgtx.go @@ -343,6 +343,12 @@ type MsgTx struct { TxIn []*TxIn TxOut []*TxOut LockTime uint32 + + // cachedSeralizedNoWitness is a cached version of the serialization of + // this transaction without witness data. When we decode a transaction, + // we'll write out the non-witness bytes to this so we can quickly + // calculate the TxHash later if needed. + cachedSeralizedNoWitness []byte } // AddTxIn adds a transaction input to the message. @@ -357,13 +363,26 @@ func (msg *MsgTx) AddTxOut(to *TxOut) { // TxHash generates the Hash for the transaction. func (msg *MsgTx) TxHash() chainhash.Hash { - // Encode the transaction and calculate double sha256 on the result. - // Ignore the error returns since the only way the encode could fail - // is being out of memory or due to nil pointers, both of which would - // cause a run-time panic. - buf := bytes.NewBuffer(make([]byte, 0, msg.SerializeSizeStripped())) - _ = msg.SerializeNoWitness(buf) - return chainhash.DoubleHashH(buf.Bytes()) + if msg.cachedSeralizedNoWitness == nil { + // Encode the transaction and calculate double sha256 on the + // result. Ignore the error returns since the only way the + // encode could fail is being out of memory or due to nil + // pointers, both of which would cause a run-time panic. + strippedSize := msg.SerializeSizeStripped() + buf := bytes.NewBuffer(make([]byte, 0, strippedSize)) + _ = msg.SerializeNoWitness(buf) + + msg.cachedSeralizedNoWitness = buf.Bytes() + } + + return chainhash.DoubleHashH(msg.cachedSeralizedNoWitness) +} + +// WipeCache removes the cached serialized bytes of the transaction. This is +// useful to be able to get the correct txid after mutating a transaction's +// state. +func (msg *MsgTx) WipeCache() { + msg.cachedSeralizedNoWitness = nil } // WitnessHash generates the hash of the transaction serialized according to @@ -461,7 +480,14 @@ func (msg *MsgTx) Copy() *MsgTx { // See Deserialize for decoding transactions stored to disk, such as in a // database, as opposed to decoding transactions from the wire. func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) error { - version, err := binarySerializer.Uint32(r, littleEndian) + // We'll use a tee reader in order to incrementally cache the raw + // non-witness serialization of this transaction. We'll then later + // cache this value as it allow to compute the TxHash more quickly, as + // we don't need to re-serialize the entire transaction. + var rawTxBuf bytes.Buffer + rawTxTeeReader := io.TeeReader(r, &rawTxBuf) + + version, err := binarySerializer.Uint32(rawTxTeeReader, littleEndian) if err != nil { return err } @@ -472,12 +498,15 @@ func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) error return err } - // A count of zero (meaning no TxIn's to the uninitiated) means that the - // value is a TxFlagMarker, and hence indicates the presence of a flag. - var flag [1]TxFlag + // A count of zero (meaning no TxIn's to the uninitiated) indicates + // this is a transaction with witness data. Notice that we don't use + // the rawTxTeeReader here, as these are segwit specific bytes. + var ( + flag [1]byte + hasWitneess bool + ) if count == TxFlagMarker && enc == WitnessEncoding { - // The count varint was in fact the flag marker byte. Next, we need to - // read the flag value, which is a single byte. + // Next, we need to read the flag, which is a single byte. if _, err = io.ReadFull(r, flag[:]); err != nil { return err } @@ -495,6 +524,15 @@ func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) error if err != nil { return err } + + hasWitneess = true + } + + // Write out the actual number of inputs as this won't be the very byte + // series after the versino of segwit transactions. + if WriteVarInt(&rawTxBuf, pver, count); err != nil { + str := fmt.Sprintf("unable to write txin count: %v", err) + return messageError("MsgTx.BtcDecode", str) } // Prevent more input transactions than could possibly fit into a @@ -545,7 +583,7 @@ func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) error // and needs to be returned to the pool on error. ti := &txIns[i] msg.TxIn[i] = ti - err = readTxIn(r, pver, msg.Version, ti) + err = readTxIn(rawTxTeeReader, pver, msg.Version, ti) if err != nil { returnScriptBuffers() return err @@ -553,7 +591,7 @@ func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) error totalScriptSize += uint64(len(ti.SignatureScript)) } - count, err = ReadVarInt(r, pver) + count, err = ReadVarInt(rawTxTeeReader, pver) if err != nil { returnScriptBuffers() return err @@ -578,7 +616,7 @@ func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) error // and needs to be returned to the pool on error. to := &txOuts[i] msg.TxOut[i] = to - err = ReadTxOut(r, pver, msg.Version, to) + err = ReadTxOut(rawTxTeeReader, pver, msg.Version, to) if err != nil { returnScriptBuffers() return err @@ -588,7 +626,7 @@ func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) error // If the transaction's flag byte isn't 0x00 at this point, then one or // more of its inputs has accompanying witness data. - if flag[0] != 0 && enc == WitnessEncoding { + if hasWitneess && enc == WitnessEncoding { for _, txin := range msg.TxIn { // For each input, the witness is encoded as a stack // with one or more items. Therefore, we first read a @@ -626,7 +664,9 @@ func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) error } } - msg.LockTime, err = binarySerializer.Uint32(r, littleEndian) + msg.LockTime, err = binarySerializer.Uint32( + rawTxTeeReader, littleEndian, + ) if err != nil { returnScriptBuffers() return err @@ -700,6 +740,11 @@ func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) error scriptPool.Return(pkScript) } + // Now that we've decoded the entire transaction without any issues, + // we'll cache the non-witness serialization so we can more quickly + // calculate the TxHash in the future. + msg.cachedSeralizedNoWitness = rawTxBuf.Bytes() + return nil } @@ -832,6 +877,10 @@ func (msg *MsgTx) Serialize(w io.Writer) error { // Serialize, however even if the source transaction has inputs with witness // data, the old serialization format will still be used. func (msg *MsgTx) SerializeNoWitness(w io.Writer) error { + if msg.cachedSeralizedNoWitness != nil { + w.Write(msg.cachedSeralizedNoWitness) + } + return msg.BtcEncode(w, 0, BaseEncoding) } diff --git a/wire/msgtx_test.go b/wire/msgtx_test.go index 5ec753b62d..77a6986453 100644 --- a/wire/msgtx_test.go +++ b/wire/msgtx_test.go @@ -181,6 +181,13 @@ func TestTxHash(t *testing.T) { t.Errorf("TxHash: wrong hash - got %v, want %v", spew.Sprint(txHash), spew.Sprint(wantHash)) } + + // Compute it again to ensure any cached elements, are valid. + txHash = msgTx.TxHash() + if !txHash.IsEqual(wantHash) { + t.Errorf("TxHash: wrong hash - got %v, want %v", + spew.Sprint(txHash), spew.Sprint(wantHash)) + } } // TestTxSha tests the ability to generate the wtxid, and txid of a transaction @@ -258,6 +265,18 @@ func TestWTxSha(t *testing.T) { t.Errorf("WTxSha: wrong hash - got %v, want %v", spew.Sprint(wtxid), spew.Sprint(wantHashWTxid)) } + + // Compute the values again to ensure any cached elements are valid. + txid = msgTx.TxHash() + if !txid.IsEqual(wantHashTxid) { + t.Errorf("TxSha: wrong hash - got %v, want %v", + spew.Sprint(txid), spew.Sprint(wantHashTxid)) + } + wtxid = msgTx.WitnessHash() + if !wtxid.IsEqual(wantHashWTxid) { + t.Errorf("WTxSha: wrong hash - got %v, want %v", + spew.Sprint(wtxid), spew.Sprint(wantHashWTxid)) + } } // TestTxWire tests the MsgTx wire encode and decode for various numbers @@ -393,6 +412,23 @@ func TestTxWire(t *testing.T) { t.Errorf("BtcDecode #%d error %v", i, err) continue } + + // If this is the base encoding, then ensure that the cached + // serialization properly matches the raw encoding. + if test.enc == BaseEncoding { + if !bytes.Equal( + test.buf, msg.cachedSeralizedNoWitness, + ) { + t.Errorf("BtcdDecode #%d: cached encoding "+ + "is wrong, expected %x got %x", i, + test.buf, + msg.cachedSeralizedNoWitness) + continue + } + } + + msg.cachedSeralizedNoWitness = nil + if !reflect.DeepEqual(&msg, test.out) { t.Errorf("BtcDecode #%d\n got: %s want: %s", i, spew.Sdump(&msg), spew.Sdump(test.out)) @@ -539,6 +575,23 @@ func TestTxSerialize(t *testing.T) { t.Errorf("Deserialize #%d error %v", i, err) continue } + + // Ensure that the raw non-witness encoding matches the cached + // non-witness encoding bytes. + var b bytes.Buffer + if err := tx.SerializeNoWitness(&b); err != nil { + t.Errorf("Deserialize #%d: unable to encode: %v", i, err) + } + if !bytes.Equal(b.Bytes(), tx.cachedSeralizedNoWitness) { + t.Errorf("Deserialize #%d: cached encoding "+ + "is wrong, expected %x got %x", i, + b.Bytes(), + tx.cachedSeralizedNoWitness) + continue + } + + tx.cachedSeralizedNoWitness = nil + if !reflect.DeepEqual(&tx, test.out) { t.Errorf("Deserialize #%d\n got: %s want: %s", i, spew.Sdump(&tx), spew.Sdump(test.out))