From a5f3278b8282ee3c92347fba2524415c3c86a51f Mon Sep 17 00:00:00 2001 From: Paul Bellamy Date: Mon, 20 Mar 2023 16:52:41 +0000 Subject: [PATCH] Soroban xdr value overhaul ScVal.Equals fixes (#4815) * Fix a bug in generated xdr for boolean unmarshalling * Output more helpful message when scval test fails * Fix equality for xdr ScVal overhaul * use bytes.equal --- xdr/scval.go | 183 +++++++++++++++++++++++-------------------- xdr/scval_test.go | 2 +- xdr/xdr_generated.go | 2 +- 3 files changed, 99 insertions(+), 88 deletions(-) diff --git a/xdr/scval.go b/xdr/scval.go index 7bf31e2f1d..6fa9c878f8 100644 --- a/xdr/scval.go +++ b/xdr/scval.go @@ -1,84 +1,20 @@ package xdr -import "bytes" +import ( + "bytes" +) -func (s Int128Parts) Equals(o Int128Parts) bool { - return s.Lo == o.Lo && s.Hi == o.Hi -} - -func (s ScContractCode) Equals(o ScContractCode) bool { +func (s ScContractExecutable) Equals(o ScContractExecutable) bool { if s.Type != o.Type { return false } switch s.Type { - case ScContractCodeTypeSccontractCodeToken: + case ScContractExecutableTypeSccontractExecutableToken: return true - case ScContractCodeTypeSccontractCodeWasmRef: + case ScContractExecutableTypeSccontractExecutableWasmRef: return s.MustWasmId().Equals(o.MustWasmId()) default: - panic("unknown ScContractCode type: " + s.Type.String()) - } -} - -func (s *ScObject) Equals(o *ScObject) bool { - if (s == nil) != (o == nil) { - return false - } - if s == nil { - return true - } - if s.Type != o.Type { - return false - } - - switch s.Type { - case ScObjectTypeScoI64: - return s.MustI64() == o.MustI64() - case ScObjectTypeScoContractCode: - return s.MustContractCode().Equals(o.MustContractCode()) - case ScObjectTypeScoU128: - return s.MustU128().Equals(o.MustU128()) - case ScObjectTypeScoI128: - return s.MustI128().Equals(o.MustI128()) - case ScObjectTypeScoBytes: - return bytes.Equal(s.MustBin(), o.MustBin()) - case ScObjectTypeScoMap: - myMap := s.MustMap() - otherMap := o.MustMap() - if len(myMap) != len(otherMap) { - return false - } - for i := range myMap { - if !myMap[i].Key.Equals(otherMap[i].Key) || - !myMap[i].Val.Equals(otherMap[i].Val) { - return false - } - } - return true - case ScObjectTypeScoU64: - return s.MustU64() == o.MustU64() - case ScObjectTypeScoVec: - myVec := s.MustVec() - otherVec := o.MustVec() - if len(myVec) != len(otherVec) { - return false - } - for i := range myVec { - if !myVec[i].Equals(otherVec[i]) { - return false - } - } - return true - case ScObjectTypeScoAddress: - myAddr := s.MustAddress() - otherAddr := o.MustAddress() - return myAddr.Equals(otherAddr) - case ScObjectTypeScoNonceKey: - myAddr := s.MustNonceAddress() - otherAddr := o.MustNonceAddress() - return myAddr.Equals(otherAddr) - default: - panic("unknown ScObject type: " + s.Type.String()) + panic("unknown ScContractExecutable type: " + s.Type.String()) } } @@ -119,27 +55,59 @@ func (s ScVal) Equals(o ScVal) bool { } switch s.Type { - case ScValTypeScvObject: - return s.MustObj().Equals(o.MustObj()) - case ScValTypeScvBitset: - return s.MustBits() == o.MustBits() - case ScValTypeScvStatic: - return s.MustIc() == o.MustIc() + case ScValTypeScvBool: + return s.MustB() == o.MustB() + case ScValTypeScvVoid: + return true case ScValTypeScvStatus: - return s.MustStatus().Equals(o.MustStatus()) - case ScValTypeScvSymbol: - return s.MustSym() == o.MustSym() - case ScValTypeScvI32: - return s.MustI32() == o.MustI32() + return s.MustError().Equals(o.MustError()) case ScValTypeScvU32: return s.MustU32() == o.MustU32() - case ScValTypeScvU63: - return s.MustU63() == o.MustU63() + case ScValTypeScvI32: + return s.MustI32() == o.MustI32() + case ScValTypeScvU64: + return s.MustU64() == o.MustU64() + case ScValTypeScvI64: + return s.MustI64() == o.MustI64() + case ScValTypeScvTimepoint: + return s.MustTimepoint() == o.MustTimepoint() + case ScValTypeScvDuration: + return s.MustDuration() == o.MustDuration() + case ScValTypeScvU128: + return s.MustU128() == o.MustU128() + case ScValTypeScvI128: + return s.MustI128() == o.MustI128() + case ScValTypeScvU256: + return s.MustU256() == o.MustU256() + case ScValTypeScvI256: + return s.MustI256() == o.MustI256() + case ScValTypeScvBytes: + return s.MustBytes().Equals(o.MustBytes()) + case ScValTypeScvString: + return s.MustStr() == o.MustStr() + case ScValTypeScvSymbol: + return s.MustSym() == o.MustSym() + case ScValTypeScvVec: + return s.MustVec().Equals(o.MustVec()) + case ScValTypeScvMap: + return s.MustMap().Equals(o.MustMap()) + case ScValTypeScvContractExecutable: + return s.MustExec().Equals(o.MustExec()) + case ScValTypeScvAddress: + return s.MustAddress().Equals(o.MustAddress()) + case ScValTypeScvLedgerKeyContractExecutable: + return true + case ScValTypeScvLedgerKeyNonce: + return s.MustNonceKey().Equals(o.MustNonceKey()) default: panic("unknown ScVal type: " + s.Type.String()) } } +func (s ScBytes) Equals(o ScBytes) bool { + return bytes.Equal([]byte(s), []byte(o)) +} + func (s ScAddress) Equals(o ScAddress) bool { if s.Type != o.Type { return false @@ -158,6 +126,49 @@ func (s ScAddress) Equals(o ScAddress) bool { // IsBool returns true if the given ScVal is a boolean func (s ScVal) IsBool() bool { - ic, ok := s.GetIc() - return ok && (ic == ScStaticScsTrue || ic == ScStaticScsFalse) + return s.Type == ScValTypeScvBool +} + +func (s *ScVec) Equals(o *ScVec) bool { + if s == nil && o == nil { + return true + } + if s == nil || o == nil { + return false + } + if len(*s) != len(*o) { + return false + } + for i := range *s { + if !(*s)[i].Equals((*o)[i]) { + return false + } + } + return true +} + +func (s *ScMap) Equals(o *ScMap) bool { + if s == nil && o == nil { + return true + } + if s == nil || o == nil { + return false + } + if len(*s) != len(*o) { + return false + } + for i, entry := range *s { + if !entry.Equals((*o)[i]) { + return false + } + } + return true +} + +func (s ScMapEntry) Equals(o ScMapEntry) bool { + return s.Key.Equals(o.Key) && s.Val.Equals(o.Val) +} + +func (s ScNonceKey) Equals(o ScNonceKey) bool { + return s.NonceAddress.Equals(o.NonceAddress) } diff --git a/xdr/scval_test.go b/xdr/scval_test.go index f237c42884..812a12942b 100644 --- a/xdr/scval_test.go +++ b/xdr/scval_test.go @@ -23,6 +23,6 @@ func TestScValEqualsCoverage(t *testing.T) { clonedScVal := ScVal{} assert.NoError(t, gxdr.Convert(shape, &clonedScVal)) - assert.True(t, scVal.Equals(clonedScVal)) + assert.True(t, scVal.Equals(clonedScVal), "scVal: %#v, clonedScVal: %#v", scVal, clonedScVal) } } diff --git a/xdr/xdr_generated.go b/xdr/xdr_generated.go index 8249f8f12f..e0963e5fcd 100644 --- a/xdr/xdr_generated.go +++ b/xdr/xdr_generated.go @@ -51046,7 +51046,7 @@ func (u *ScVal) DecodeFrom(d *xdr.Decoder) (int, error) { switch ScValType(u.Type) { case ScValTypeScvBool: u.B = new(bool) - nTmp, err = d.Decode((*u.B)) + nTmp, err = d.Decode(u.B) n += nTmp if err != nil { return n, fmt.Errorf("decoding Bool: %s", err)