From 6e5292224a6dbd51aa3d7a68f3f9be20b25f8315 Mon Sep 17 00:00:00 2001 From: Dhruba Basu <7675102+dhrubabasu@users.noreply.github.com> Date: Tue, 28 Nov 2023 07:55:27 -0800 Subject: [PATCH 1/3] `vms/platformvm`: Move `toEngine` channel to mempool (#2333) --- vms/platformvm/block/builder/builder.go | 24 ++++---------- vms/platformvm/block/builder/helpers_test.go | 3 +- vms/platformvm/block/executor/helpers_test.go | 8 +---- vms/platformvm/block/executor/rejector.go | 2 ++ .../block/executor/rejector_test.go | 2 ++ vms/platformvm/network/network.go | 2 ++ vms/platformvm/network/network_test.go | 1 + vms/platformvm/txs/mempool/mempool.go | 33 ++++++++++++------- vms/platformvm/txs/mempool/mempool_test.go | 18 +++------- vms/platformvm/txs/mempool/mock_mempool.go | 12 +++++++ vms/platformvm/vm.go | 5 +-- 11 files changed, 55 insertions(+), 55 deletions(-) diff --git a/vms/platformvm/block/builder/builder.go b/vms/platformvm/block/builder/builder.go index 8923857d579e..b8476e19b032 100644 --- a/vms/platformvm/block/builder/builder.go +++ b/vms/platformvm/block/builder/builder.go @@ -13,7 +13,6 @@ import ( "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/snow/consensus/snowman" - "github.com/ava-labs/avalanchego/snow/engine/common" "github.com/ava-labs/avalanchego/utils/timer" "github.com/ava-labs/avalanchego/utils/timer/mockable" "github.com/ava-labs/avalanchego/utils/units" @@ -40,7 +39,11 @@ var ( type Builder interface { mempool.Mempool - mempool.BlockTimer + + // ResetBlockTimer schedules a timer to notify the consensus engine once + // there is a block ready to be built. If a block is ready to be built when + // this function is called, the engine will be notified directly. + ResetBlockTimer() // BuildBlock is called on timer clock to attempt to create // next block @@ -58,9 +61,6 @@ type builder struct { txExecutorBackend *txexecutor.Backend blkManager blockexecutor.Manager - // channel to send messages to the consensus engine - toEngine chan<- common.Message - // This timer goes off when it is time for the next validator to add/leave // the validator set. When it goes off ResetTimer() is called, potentially // triggering creation of a new block. @@ -72,14 +72,12 @@ func New( txBuilder txbuilder.Builder, txExecutorBackend *txexecutor.Backend, blkManager blockexecutor.Manager, - toEngine chan<- common.Message, ) Builder { builder := &builder{ Mempool: mempool, txBuilder: txBuilder, txExecutorBackend: txExecutorBackend, blkManager: blkManager, - toEngine: toEngine, } builder.timer = timer.NewTimer(builder.setNextBuildBlockTime) @@ -192,7 +190,7 @@ func (b *builder) setNextBuildBlockTime() { if _, err := b.buildBlock(); err == nil { // We can build a block now - b.notifyBlockReady() + b.Mempool.RequestBuildBlock(true /*=emptyBlockPermitted*/) return } @@ -229,16 +227,6 @@ func (b *builder) setNextBuildBlockTime() { b.timer.SetTimeoutIn(waitTime) } -// notifyBlockReady tells the consensus engine that a new block is ready to be -// created -func (b *builder) notifyBlockReady() { - select { - case b.toEngine <- common.PendingTxs: - default: - b.txExecutorBackend.Ctx.Log.Debug("dropping message to consensus engine") - } -} - // [timestamp] is min(max(now, parent timestamp), next staker change time) func buildBlock( builder *builder, diff --git a/vms/platformvm/block/builder/helpers_test.go b/vms/platformvm/block/builder/helpers_test.go index 84778add2864..de37d08ff0dd 100644 --- a/vms/platformvm/block/builder/helpers_test.go +++ b/vms/platformvm/block/builder/helpers_test.go @@ -169,7 +169,7 @@ func newEnvironment(t *testing.T) *environment { metrics, err := metrics.New("", registerer) require.NoError(err) - res.mempool, err = mempool.New("mempool", registerer, res) + res.mempool, err = mempool.New("mempool", registerer, nil) require.NoError(err) res.blkManager = blockexecutor.NewManager( @@ -193,7 +193,6 @@ func newEnvironment(t *testing.T) *environment { res.txBuilder, &res.backend, res.blkManager, - nil, // toEngine, ) res.blkManager.SetPreference(genesisID) diff --git a/vms/platformvm/block/executor/helpers_test.go b/vms/platformvm/block/executor/helpers_test.go index ff0aa13a2ea1..778d9b203181 100644 --- a/vms/platformvm/block/executor/helpers_test.go +++ b/vms/platformvm/block/executor/helpers_test.go @@ -63,8 +63,6 @@ const ( ) var ( - _ mempool.BlockTimer = (*environment)(nil) - defaultMinStakingDuration = 24 * time.Hour defaultMaxStakingDuration = 365 * 24 * time.Hour defaultGenesisTime = time.Date(1997, 1, 1, 0, 0, 0, 0, time.UTC) @@ -131,10 +129,6 @@ type environment struct { backend *executor.Backend } -func (*environment) ResetBlockTimer() { - // dummy call, do nothing for now -} - func newEnvironment(t *testing.T, ctrl *gomock.Controller) *environment { res := &environment{ isBootstrapped: &utils.Atomic[bool]{}, @@ -199,7 +193,7 @@ func newEnvironment(t *testing.T, ctrl *gomock.Controller) *environment { metrics := metrics.Noop var err error - res.mempool, err = mempool.New("mempool", registerer, res) + res.mempool, err = mempool.New("mempool", registerer, nil) if err != nil { panic(fmt.Errorf("failed to create mempool: %w", err)) } diff --git a/vms/platformvm/block/executor/rejector.go b/vms/platformvm/block/executor/rejector.go index daa6939f05cd..cfc64b050be4 100644 --- a/vms/platformvm/block/executor/rejector.go +++ b/vms/platformvm/block/executor/rejector.go @@ -82,5 +82,7 @@ func (r *rejector) rejectBlock(b block.Block, blockType string) error { } } + r.Mempool.RequestBuildBlock(false) + return nil } diff --git a/vms/platformvm/block/executor/rejector_test.go b/vms/platformvm/block/executor/rejector_test.go index 1e1e5768618d..3ccd9c0d66b1 100644 --- a/vms/platformvm/block/executor/rejector_test.go +++ b/vms/platformvm/block/executor/rejector_test.go @@ -142,6 +142,8 @@ func TestRejectBlock(t *testing.T) { mempool.EXPECT().Add(tx).Return(nil).Times(1) } + mempool.EXPECT().RequestBuildBlock(false).Times(1) + require.NoError(tt.rejectFunc(rejector, blk)) // Make sure block and its parent are removed from the state map. require.NotContains(rejector.blkIDToState, blk.ID()) diff --git a/vms/platformvm/network/network.go b/vms/platformvm/network/network.go index 0bbfc4f86eaf..5f4945093d60 100644 --- a/vms/platformvm/network/network.go +++ b/vms/platformvm/network/network.go @@ -181,6 +181,8 @@ func (n *network) issueTx(tx *txs.Tx) error { return err } + n.mempool.RequestBuildBlock(false) + return nil } diff --git a/vms/platformvm/network/network_test.go b/vms/platformvm/network/network_test.go index 8c17bb0491b5..000cbda7e195 100644 --- a/vms/platformvm/network/network_test.go +++ b/vms/platformvm/network/network_test.go @@ -284,6 +284,7 @@ func TestNetworkIssueTx(t *testing.T) { mempool := mempool.NewMockMempool(ctrl) mempool.EXPECT().Has(gomock.Any()).Return(false) mempool.EXPECT().Add(gomock.Any()).Return(nil) + mempool.EXPECT().RequestBuildBlock(false) return mempool }, managerFunc: func(ctrl *gomock.Controller) executor.Manager { diff --git a/vms/platformvm/txs/mempool/mempool.go b/vms/platformvm/txs/mempool/mempool.go index 675ec3c5c763..ce0d6a96f071 100644 --- a/vms/platformvm/txs/mempool/mempool.go +++ b/vms/platformvm/txs/mempool/mempool.go @@ -12,6 +12,7 @@ import ( "github.com/ava-labs/avalanchego/cache" "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/snow/engine/common" "github.com/ava-labs/avalanchego/utils/linkedhashmap" "github.com/ava-labs/avalanchego/utils/set" "github.com/ava-labs/avalanchego/utils/units" @@ -43,13 +44,6 @@ var ( errCantIssueRewardValidatorTx = errors.New("can not issue a reward validator tx") ) -type BlockTimer interface { - // ResetBlockTimer schedules a timer to notify the consensus engine once - // there is a block ready to be built. If a block is ready to be built when - // this function is called, the engine will be notified directly. - ResetBlockTimer() -} - type Mempool interface { // we may want to be able to stop valid transactions // from entering the mempool, e.g. during blocks creation @@ -75,6 +69,13 @@ type Mempool interface { // TODO: Remove once [StartTime] field is ignored in staker txs DropExpiredStakerTxs(minStartTime time.Time) []ids.ID + // RequestBuildBlock notifies the consensus engine that a block should be + // built. If [emptyBlockPermitted] is true, the notification will be sent + // regardless of whether there are no transactions in the mempool. If not, + // a notification will only be sent if there is at least one transaction in + // the mempool. + RequestBuildBlock(emptyBlockPermitted bool) + // Note: dropped txs are added to droppedTxIDs but are not evicted from // unissued decision/staker txs. This allows previously dropped txs to be // possibly reissued. @@ -100,13 +101,13 @@ type mempool struct { consumedUTXOs set.Set[ids.ID] - blkTimer BlockTimer + toEngine chan<- common.Message } func New( namespace string, registerer prometheus.Registerer, - blkTimer BlockTimer, + toEngine chan<- common.Message, ) (Mempool, error) { bytesAvailableMetric := prometheus.NewGauge(prometheus.GaugeOpts{ Namespace: namespace, @@ -137,7 +138,7 @@ func New( droppedTxIDs: &cache.LRU[ids.ID, error]{Size: droppedTxIDsCacheSize}, consumedUTXOs: set.NewSet[ids.ID](initialConsumedUTXOsSize), dropIncoming: false, // enable tx adding by default - blkTimer: blkTimer, + toEngine: toEngine, }, nil } @@ -202,7 +203,6 @@ func (m *mempool) Add(tx *txs.Tx) error { // An explicitly added tx must not be marked as dropped. m.droppedTxIDs.Evict(txID) - m.blkTimer.ResetBlockTimer() return nil } @@ -259,6 +259,17 @@ func (m *mempool) GetDropReason(txID ids.ID) error { return err } +func (m *mempool) RequestBuildBlock(emptyBlockPermitted bool) { + if !emptyBlockPermitted && !m.HasTxs() { + return + } + + select { + case m.toEngine <- common.PendingTxs: + default: + } +} + // Drops all [txs.Staker] transactions whose [StartTime] is before // [minStartTime] from [mempool]. The dropped tx ids are returned. // diff --git a/vms/platformvm/txs/mempool/mempool_test.go b/vms/platformvm/txs/mempool/mempool_test.go index a56ae4702155..1d92132ebbcd 100644 --- a/vms/platformvm/txs/mempool/mempool_test.go +++ b/vms/platformvm/txs/mempool/mempool_test.go @@ -20,15 +20,7 @@ import ( "github.com/ava-labs/avalanchego/vms/secp256k1fx" ) -var ( - _ BlockTimer = (*noopBlkTimer)(nil) - - preFundedKeys = secp256k1.TestKeys() -) - -type noopBlkTimer struct{} - -func (*noopBlkTimer) ResetBlockTimer() {} +var preFundedKeys = secp256k1.TestKeys() // shows that valid tx is not added to mempool if this would exceed its maximum // size @@ -36,7 +28,7 @@ func TestBlockBuilderMaxMempoolSizeHandling(t *testing.T) { require := require.New(t) registerer := prometheus.NewRegistry() - mpool, err := New("mempool", registerer, &noopBlkTimer{}) + mpool, err := New("mempool", registerer, nil) require.NoError(err) decisionTxs, err := createTestDecisionTxs(1) @@ -60,7 +52,7 @@ func TestDecisionTxsInMempool(t *testing.T) { require := require.New(t) registerer := prometheus.NewRegistry() - mpool, err := New("mempool", registerer, &noopBlkTimer{}) + mpool, err := New("mempool", registerer, nil) require.NoError(err) decisionTxs, err := createTestDecisionTxs(2) @@ -112,7 +104,7 @@ func TestProposalTxsInMempool(t *testing.T) { require := require.New(t) registerer := prometheus.NewRegistry() - mpool, err := New("mempool", registerer, &noopBlkTimer{}) + mpool, err := New("mempool", registerer, nil) require.NoError(err) // The proposal txs are ordered by decreasing start time. This means after @@ -245,7 +237,7 @@ func TestDropExpiredStakerTxs(t *testing.T) { require := require.New(t) registerer := prometheus.NewRegistry() - mempool, err := New("mempool", registerer, &noopBlkTimer{}) + mempool, err := New("mempool", registerer, nil) require.NoError(err) tx1, err := generateAddValidatorTx(10, 20) diff --git a/vms/platformvm/txs/mempool/mock_mempool.go b/vms/platformvm/txs/mempool/mock_mempool.go index 8f8c90eb2d07..edc134a42ddf 100644 --- a/vms/platformvm/txs/mempool/mock_mempool.go +++ b/vms/platformvm/txs/mempool/mock_mempool.go @@ -184,3 +184,15 @@ func (mr *MockMempoolMockRecorder) Remove(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockMempool)(nil).Remove), arg0) } + +// RequestBuildBlock mocks base method. +func (m *MockMempool) RequestBuildBlock(arg0 bool) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "RequestBuildBlock", arg0) +} + +// RequestBuildBlock indicates an expected call of RequestBuildBlock. +func (mr *MockMempoolMockRecorder) RequestBuildBlock(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestBuildBlock", reflect.TypeOf((*MockMempool)(nil).RequestBuildBlock), arg0) +} diff --git a/vms/platformvm/vm.go b/vms/platformvm/vm.go index c312e4044e8b..d9898b873137 100644 --- a/vms/platformvm/vm.go +++ b/vms/platformvm/vm.go @@ -177,9 +177,7 @@ func (vm *VM) Initialize( Bootstrapped: &vm.bootstrapped, } - // Note: There is a circular dependency between the mempool and block - // builder which is broken by passing in the vm. - mempool, err := mempool.New("mempool", registerer, vm) + mempool, err := mempool.New("mempool", registerer, toEngine) if err != nil { return fmt.Errorf("failed to create mempool: %w", err) } @@ -203,7 +201,6 @@ func (vm *VM) Initialize( vm.txBuilder, txExecutorBackend, vm.manager, - toEngine, ) // Create all of the chains that the database says exist From d7e7ff5cc702d108701d325c38895c2642127124 Mon Sep 17 00:00:00 2001 From: Dhruba Basu <7675102+dhrubabasu@users.noreply.github.com> Date: Tue, 28 Nov 2023 10:10:49 -0800 Subject: [PATCH 2/3] `vms/avm`: Rename `states` pkg to `state` (#2381) --- scripts/mocks.mockgen.txt | 2 +- vms/avm/block/builder/builder.go | 12 +- vms/avm/block/builder/builder_test.go | 18 +-- vms/avm/block/executor/block.go | 4 +- vms/avm/block/executor/block_test.go | 42 +++---- vms/avm/block/executor/manager.go | 14 +-- vms/avm/block/executor/manager_test.go | 22 ++-- vms/avm/block/executor/mock_manager.go | 6 +- vms/avm/service_test.go | 18 +-- vms/avm/{states => state}/diff.go | 2 +- .../mock_states.go => state/mock_state.go} | 6 +- vms/avm/{states => state}/state.go | 2 +- vms/avm/{states => state}/state_test.go | 2 +- vms/avm/{states => state}/versions.go | 2 +- vms/avm/txs/executor/executor.go | 4 +- vms/avm/txs/executor/executor_test.go | 8 +- vms/avm/txs/executor/semantic_verifier.go | 4 +- .../txs/executor/semantic_verifier_test.go | 106 +++++++++--------- vms/avm/vm.go | 6 +- 19 files changed, 140 insertions(+), 140 deletions(-) rename vms/avm/{states => state}/diff.go (99%) rename vms/avm/{states/mock_states.go => state/mock_state.go} (99%) rename vms/avm/{states => state}/state.go (99%) rename vms/avm/{states => state}/state_test.go (99%) rename vms/avm/{states => state}/versions.go (95%) diff --git a/scripts/mocks.mockgen.txt b/scripts/mocks.mockgen.txt index ba55a16a4b71..76add90f3f7f 100644 --- a/scripts/mocks.mockgen.txt +++ b/scripts/mocks.mockgen.txt @@ -27,7 +27,7 @@ github.com/ava-labs/avalanchego/utils/logging=Logger=utils/logging/mock_logger.g github.com/ava-labs/avalanchego/utils/resource=User=utils/resource/mock_user.go github.com/ava-labs/avalanchego/vms/avm/block=Block=vms/avm/block/mock_block.go github.com/ava-labs/avalanchego/vms/avm/metrics=Metrics=vms/avm/metrics/mock_metrics.go -github.com/ava-labs/avalanchego/vms/avm/states=Chain,State,Diff=vms/avm/states/mock_states.go +github.com/ava-labs/avalanchego/vms/avm/state=Chain,State,Diff=vms/avm/state/mock_state.go github.com/ava-labs/avalanchego/vms/avm/txs/mempool=Mempool=vms/avm/txs/mempool/mock_mempool.go github.com/ava-labs/avalanchego/vms/components/avax=TransferableIn=vms/components/avax/mock_transferable_in.go github.com/ava-labs/avalanchego/vms/components/verify=Verifiable=vms/components/verify/mock_verifiable.go diff --git a/vms/avm/block/builder/builder.go b/vms/avm/block/builder/builder.go index 77bbf38a6a35..a3129d797808 100644 --- a/vms/avm/block/builder/builder.go +++ b/vms/avm/block/builder/builder.go @@ -13,7 +13,7 @@ import ( "github.com/ava-labs/avalanchego/utils/timer/mockable" "github.com/ava-labs/avalanchego/utils/units" "github.com/ava-labs/avalanchego/vms/avm/block" - "github.com/ava-labs/avalanchego/vms/avm/states" + "github.com/ava-labs/avalanchego/vms/avm/state" "github.com/ava-labs/avalanchego/vms/avm/txs" "github.com/ava-labs/avalanchego/vms/avm/txs/mempool" @@ -82,7 +82,7 @@ func (b *builder) BuildBlock(context.Context) (snowman.Block, error) { nextTimestamp = preferredTimestamp } - stateDiff, err := states.NewDiff(preferredID, b.manager) + stateDiff, err := state.NewDiff(preferredID, b.manager) if err != nil { return nil, err } @@ -168,15 +168,15 @@ func (b *builder) BuildBlock(context.Context) (snowman.Block, error) { } type stateGetter struct { - state states.Chain + state state.Chain } -func (s stateGetter) GetState(ids.ID) (states.Chain, bool) { +func (s stateGetter) GetState(ids.ID) (state.Chain, bool) { return s.state, true } -func wrapState(parentState states.Chain) (states.Diff, error) { - return states.NewDiff(ids.Empty, stateGetter{ +func wrapState(parentState state.Chain) (state.Diff, error) { + return state.NewDiff(ids.Empty, stateGetter{ state: parentState, }) } diff --git a/vms/avm/block/builder/builder_test.go b/vms/avm/block/builder/builder_test.go index fdab9d6cf064..7faeddbe71e6 100644 --- a/vms/avm/block/builder/builder_test.go +++ b/vms/avm/block/builder/builder_test.go @@ -29,7 +29,7 @@ import ( "github.com/ava-labs/avalanchego/vms/avm/block" "github.com/ava-labs/avalanchego/vms/avm/fxs" "github.com/ava-labs/avalanchego/vms/avm/metrics" - "github.com/ava-labs/avalanchego/vms/avm/states" + "github.com/ava-labs/avalanchego/vms/avm/state" "github.com/ava-labs/avalanchego/vms/avm/txs" "github.com/ava-labs/avalanchego/vms/avm/txs/mempool" "github.com/ava-labs/avalanchego/vms/components/avax" @@ -108,7 +108,7 @@ func TestBuilderBuildBlock(t *testing.T) { mempool, ) }, - expectedErr: states.ErrMissingParentState, + expectedErr: state.ErrMissingParentState, }, { name: "tx fails semantic verification", @@ -120,7 +120,7 @@ func TestBuilderBuildBlock(t *testing.T) { preferredBlock.EXPECT().Height().Return(preferredHeight) preferredBlock.EXPECT().Timestamp().Return(preferredTimestamp) - preferredState := states.NewMockChain(ctrl) + preferredState := state.NewMockChain(ctrl) preferredState.EXPECT().GetLastAccepted().Return(preferredID) preferredState.EXPECT().GetTimestamp().Return(preferredTimestamp) @@ -164,7 +164,7 @@ func TestBuilderBuildBlock(t *testing.T) { preferredBlock.EXPECT().Height().Return(preferredHeight) preferredBlock.EXPECT().Timestamp().Return(preferredTimestamp) - preferredState := states.NewMockChain(ctrl) + preferredState := state.NewMockChain(ctrl) preferredState.EXPECT().GetLastAccepted().Return(preferredID) preferredState.EXPECT().GetTimestamp().Return(preferredTimestamp) @@ -209,7 +209,7 @@ func TestBuilderBuildBlock(t *testing.T) { preferredBlock.EXPECT().Height().Return(preferredHeight) preferredBlock.EXPECT().Timestamp().Return(preferredTimestamp) - preferredState := states.NewMockChain(ctrl) + preferredState := state.NewMockChain(ctrl) preferredState.EXPECT().GetLastAccepted().Return(preferredID) preferredState.EXPECT().GetTimestamp().Return(preferredTimestamp) @@ -255,7 +255,7 @@ func TestBuilderBuildBlock(t *testing.T) { preferredBlock.EXPECT().Height().Return(preferredHeight) preferredBlock.EXPECT().Timestamp().Return(preferredTimestamp) - preferredState := states.NewMockChain(ctrl) + preferredState := state.NewMockChain(ctrl) preferredState.EXPECT().GetLastAccepted().Return(preferredID) preferredState.EXPECT().GetTimestamp().Return(preferredTimestamp) @@ -353,7 +353,7 @@ func TestBuilderBuildBlock(t *testing.T) { clock := &mockable.Clock{} clock.Set(preferredTimestamp.Add(-2 * time.Second)) - preferredState := states.NewMockChain(ctrl) + preferredState := state.NewMockChain(ctrl) preferredState.EXPECT().GetLastAccepted().Return(preferredID) preferredState.EXPECT().GetTimestamp().Return(preferredTimestamp) @@ -427,7 +427,7 @@ func TestBuilderBuildBlock(t *testing.T) { clock := &mockable.Clock{} clock.Set(now) - preferredState := states.NewMockChain(ctrl) + preferredState := state.NewMockChain(ctrl) preferredState.EXPECT().GetLastAccepted().Return(preferredID) preferredState.EXPECT().GetTimestamp().Return(preferredTimestamp) @@ -526,7 +526,7 @@ func TestBlockBuilderAddLocalTx(t *testing.T) { baseDB := versiondb.New(memdb.New()) - state, err := states.New(baseDB, parser, registerer, trackChecksums) + state, err := state.New(baseDB, parser, registerer, trackChecksums) require.NoError(err) clk := &mockable.Clock{} diff --git a/vms/avm/block/executor/block.go b/vms/avm/block/executor/block.go index 418ca0b539ca..5e643ad4ecc0 100644 --- a/vms/avm/block/executor/block.go +++ b/vms/avm/block/executor/block.go @@ -17,7 +17,7 @@ import ( "github.com/ava-labs/avalanchego/snow/choices" "github.com/ava-labs/avalanchego/snow/consensus/snowman" "github.com/ava-labs/avalanchego/vms/avm/block" - "github.com/ava-labs/avalanchego/vms/avm/states" + "github.com/ava-labs/avalanchego/vms/avm/state" "github.com/ava-labs/avalanchego/vms/avm/txs/executor" ) @@ -106,7 +106,7 @@ func (b *Block) Verify(context.Context) error { ) } - stateDiff, err := states.NewDiff(parentID, b.manager) + stateDiff, err := state.NewDiff(parentID, b.manager) if err != nil { return err } diff --git a/vms/avm/block/executor/block_test.go b/vms/avm/block/executor/block_test.go index 9d7f291a8f60..da965884ae58 100644 --- a/vms/avm/block/executor/block_test.go +++ b/vms/avm/block/executor/block_test.go @@ -24,7 +24,7 @@ import ( "github.com/ava-labs/avalanchego/utils/timer/mockable" "github.com/ava-labs/avalanchego/vms/avm/block" "github.com/ava-labs/avalanchego/vms/avm/metrics" - "github.com/ava-labs/avalanchego/vms/avm/states" + "github.com/ava-labs/avalanchego/vms/avm/state" "github.com/ava-labs/avalanchego/vms/avm/txs" "github.com/ava-labs/avalanchego/vms/avm/txs/executor" "github.com/ava-labs/avalanchego/vms/avm/txs/mempool" @@ -153,7 +153,7 @@ func TestBlockVerify(t *testing.T) { parentID := ids.GenerateTestID() mockBlock.EXPECT().Parent().Return(parentID).AnyTimes() - mockState := states.NewMockState(ctrl) + mockState := state.NewMockState(ctrl) mockState.EXPECT().GetBlock(parentID).Return(nil, errTest) return &Block{ Block: mockBlock, @@ -186,7 +186,7 @@ func TestBlockVerify(t *testing.T) { parentID := ids.GenerateTestID() mockBlock.EXPECT().Parent().Return(parentID).AnyTimes() - mockState := states.NewMockState(ctrl) + mockState := state.NewMockState(ctrl) mockParentBlock := block.NewMockBlock(ctrl) mockParentBlock.EXPECT().Height().Return(blockHeight) // Should be blockHeight - 1 mockState.EXPECT().GetBlock(parentID).Return(mockParentBlock, nil) @@ -226,7 +226,7 @@ func TestBlockVerify(t *testing.T) { mockParentBlock := block.NewMockBlock(ctrl) mockParentBlock.EXPECT().Height().Return(blockHeight - 1) - mockParentState := states.NewMockDiff(ctrl) + mockParentState := state.NewMockDiff(ctrl) mockParentState.EXPECT().GetLastAccepted().Return(parentID) mockParentState.EXPECT().GetTimestamp().Return(blockTimestamp.Add(1)) @@ -271,7 +271,7 @@ func TestBlockVerify(t *testing.T) { mockParentBlock := block.NewMockBlock(ctrl) mockParentBlock.EXPECT().Height().Return(blockHeight - 1) - mockParentState := states.NewMockDiff(ctrl) + mockParentState := state.NewMockDiff(ctrl) mockParentState.EXPECT().GetLastAccepted().Return(parentID) mockParentState.EXPECT().GetTimestamp().Return(blockTimestamp) @@ -321,7 +321,7 @@ func TestBlockVerify(t *testing.T) { mockParentBlock := block.NewMockBlock(ctrl) mockParentBlock.EXPECT().Height().Return(blockHeight - 1) - mockParentState := states.NewMockDiff(ctrl) + mockParentState := state.NewMockDiff(ctrl) mockParentState.EXPECT().GetLastAccepted().Return(parentID) mockParentState.EXPECT().GetTimestamp().Return(blockTimestamp) @@ -399,7 +399,7 @@ func TestBlockVerify(t *testing.T) { mockParentBlock := block.NewMockBlock(ctrl) mockParentBlock.EXPECT().Height().Return(blockHeight - 1) - mockParentState := states.NewMockDiff(ctrl) + mockParentState := state.NewMockDiff(ctrl) mockParentState.EXPECT().GetLastAccepted().Return(parentID) mockParentState.EXPECT().GetTimestamp().Return(blockTimestamp) @@ -461,7 +461,7 @@ func TestBlockVerify(t *testing.T) { mockParentBlock := block.NewMockBlock(ctrl) mockParentBlock.EXPECT().Height().Return(blockHeight - 1) - mockParentState := states.NewMockDiff(ctrl) + mockParentState := state.NewMockDiff(ctrl) mockParentState.EXPECT().GetLastAccepted().Return(parentID) mockParentState.EXPECT().GetTimestamp().Return(blockTimestamp) @@ -509,7 +509,7 @@ func TestBlockVerify(t *testing.T) { mockParentBlock := block.NewMockBlock(ctrl) mockParentBlock.EXPECT().Height().Return(blockHeight - 1) - mockParentState := states.NewMockDiff(ctrl) + mockParentState := state.NewMockDiff(ctrl) mockParentState.EXPECT().GetLastAccepted().Return(parentID) mockParentState.EXPECT().GetTimestamp().Return(blockTimestamp) @@ -616,11 +616,11 @@ func TestBlockAccept(t *testing.T) { mempool := mempool.NewMockMempool(ctrl) mempool.EXPECT().Remove(gomock.Any()).AnyTimes() - mockManagerState := states.NewMockState(ctrl) + mockManagerState := state.NewMockState(ctrl) mockManagerState.EXPECT().CommitBatch().Return(nil, errTest) mockManagerState.EXPECT().Abort() - mockOnAcceptState := states.NewMockDiff(ctrl) + mockOnAcceptState := state.NewMockDiff(ctrl) mockOnAcceptState.EXPECT().Apply(mockManagerState) return &Block{ @@ -654,7 +654,7 @@ func TestBlockAccept(t *testing.T) { mempool := mempool.NewMockMempool(ctrl) mempool.EXPECT().Remove(gomock.Any()).AnyTimes() - mockManagerState := states.NewMockState(ctrl) + mockManagerState := state.NewMockState(ctrl) // Note the returned batch is nil but not used // because we mock the call to shared memory mockManagerState.EXPECT().CommitBatch().Return(nil, nil) @@ -663,7 +663,7 @@ func TestBlockAccept(t *testing.T) { mockSharedMemory := atomic.NewMockSharedMemory(ctrl) mockSharedMemory.EXPECT().Apply(gomock.Any(), gomock.Any()).Return(errTest) - mockOnAcceptState := states.NewMockDiff(ctrl) + mockOnAcceptState := state.NewMockDiff(ctrl) mockOnAcceptState.EXPECT().Apply(mockManagerState) return &Block{ @@ -698,7 +698,7 @@ func TestBlockAccept(t *testing.T) { mempool := mempool.NewMockMempool(ctrl) mempool.EXPECT().Remove(gomock.Any()).AnyTimes() - mockManagerState := states.NewMockState(ctrl) + mockManagerState := state.NewMockState(ctrl) // Note the returned batch is nil but not used // because we mock the call to shared memory mockManagerState.EXPECT().CommitBatch().Return(nil, nil) @@ -707,7 +707,7 @@ func TestBlockAccept(t *testing.T) { mockSharedMemory := atomic.NewMockSharedMemory(ctrl) mockSharedMemory.EXPECT().Apply(gomock.Any(), gomock.Any()).Return(nil) - mockOnAcceptState := states.NewMockDiff(ctrl) + mockOnAcceptState := state.NewMockDiff(ctrl) mockOnAcceptState.EXPECT().Apply(mockManagerState) metrics := metrics.NewMockMetrics(ctrl) @@ -748,7 +748,7 @@ func TestBlockAccept(t *testing.T) { mempool := mempool.NewMockMempool(ctrl) mempool.EXPECT().Remove(gomock.Any()).AnyTimes() - mockManagerState := states.NewMockState(ctrl) + mockManagerState := state.NewMockState(ctrl) // Note the returned batch is nil but not used // because we mock the call to shared memory mockManagerState.EXPECT().CommitBatch().Return(nil, nil) @@ -758,7 +758,7 @@ func TestBlockAccept(t *testing.T) { mockSharedMemory := atomic.NewMockSharedMemory(ctrl) mockSharedMemory.EXPECT().Apply(gomock.Any(), gomock.Any()).Return(nil) - mockOnAcceptState := states.NewMockDiff(ctrl) + mockOnAcceptState := state.NewMockDiff(ctrl) mockOnAcceptState.EXPECT().Apply(mockManagerState) metrics := metrics.NewMockMetrics(ctrl) @@ -859,7 +859,7 @@ func TestBlockReject(t *testing.T) { mempool.EXPECT().Add(validTx).Return(nil) // Only add the one that passes verification preferredID := ids.GenerateTestID() - mockPreferredState := states.NewMockDiff(ctrl) + mockPreferredState := state.NewMockDiff(ctrl) mockPreferredState.EXPECT().GetLastAccepted().Return(ids.GenerateTestID()).AnyTimes() mockPreferredState.EXPECT().GetTimestamp().Return(time.Now()).AnyTimes() @@ -918,7 +918,7 @@ func TestBlockReject(t *testing.T) { mempool.EXPECT().Add(tx2).Return(nil) preferredID := ids.GenerateTestID() - mockPreferredState := states.NewMockDiff(ctrl) + mockPreferredState := state.NewMockDiff(ctrl) mockPreferredState.EXPECT().GetLastAccepted().Return(ids.GenerateTestID()).AnyTimes() mockPreferredState.EXPECT().GetTimestamp().Return(time.Now()).AnyTimes() @@ -1014,7 +1014,7 @@ func TestBlockStatus(t *testing.T) { mockBlock := block.NewMockBlock(ctrl) mockBlock.EXPECT().ID().Return(blockID).AnyTimes() - mockState := states.NewMockState(ctrl) + mockState := state.NewMockState(ctrl) mockState.EXPECT().GetBlock(blockID).Return(nil, nil) return &Block{ @@ -1034,7 +1034,7 @@ func TestBlockStatus(t *testing.T) { mockBlock := block.NewMockBlock(ctrl) mockBlock.EXPECT().ID().Return(blockID).AnyTimes() - mockState := states.NewMockState(ctrl) + mockState := state.NewMockState(ctrl) mockState.EXPECT().GetBlock(blockID).Return(nil, database.ErrNotFound) return &Block{ diff --git a/vms/avm/block/executor/manager.go b/vms/avm/block/executor/manager.go index dd9b8bfab400..48eea701bbd9 100644 --- a/vms/avm/block/executor/manager.go +++ b/vms/avm/block/executor/manager.go @@ -13,7 +13,7 @@ import ( "github.com/ava-labs/avalanchego/utils/timer/mockable" "github.com/ava-labs/avalanchego/vms/avm/block" "github.com/ava-labs/avalanchego/vms/avm/metrics" - "github.com/ava-labs/avalanchego/vms/avm/states" + "github.com/ava-labs/avalanchego/vms/avm/state" "github.com/ava-labs/avalanchego/vms/avm/txs" "github.com/ava-labs/avalanchego/vms/avm/txs/executor" "github.com/ava-labs/avalanchego/vms/avm/txs/mempool" @@ -27,7 +27,7 @@ var ( ) type Manager interface { - states.Versions + state.Versions // Returns the ID of the most recently accepted block. LastAccepted() ids.ID @@ -51,7 +51,7 @@ type Manager interface { func NewManager( mempool mempool.Mempool, metrics metrics.Metrics, - state states.State, + state state.State, backend *executor.Backend, clk *mockable.Clock, onAccept func(*txs.Tx) error, @@ -72,7 +72,7 @@ func NewManager( type manager struct { backend *executor.Backend - state states.State + state state.State metrics metrics.Metrics mempool mempool.Mempool clk *mockable.Clock @@ -93,12 +93,12 @@ type manager struct { type blockState struct { statelessBlock block.Block - onAcceptState states.Diff + onAcceptState state.Diff importedInputs set.Set[ids.ID] atomicRequests map[ids.ID]*atomic.Requests } -func (m *manager) GetState(blkID ids.ID) (states.Chain, bool) { +func (m *manager) GetState(blkID ids.ID) (state.Chain, bool) { // If the block is in the map, it is processing. if state, ok := m.blkIDToState[blkID]; ok { return state.onAcceptState, true @@ -155,7 +155,7 @@ func (m *manager) VerifyTx(tx *txs.Tx) error { return err } - stateDiff, err := states.NewDiff(m.preferred, m) + stateDiff, err := state.NewDiff(m.preferred, m) if err != nil { return err } diff --git a/vms/avm/block/executor/manager_test.go b/vms/avm/block/executor/manager_test.go index c21201417add..904154bf7030 100644 --- a/vms/avm/block/executor/manager_test.go +++ b/vms/avm/block/executor/manager_test.go @@ -15,7 +15,7 @@ import ( "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/utils/set" "github.com/ava-labs/avalanchego/vms/avm/block" - "github.com/ava-labs/avalanchego/vms/avm/states" + "github.com/ava-labs/avalanchego/vms/avm/state" "github.com/ava-labs/avalanchego/vms/avm/txs" "github.com/ava-labs/avalanchego/vms/avm/txs/executor" ) @@ -31,7 +31,7 @@ func TestManagerGetStatelessBlock(t *testing.T) { require := require.New(t) ctrl := gomock.NewController(t) - state := states.NewMockState(ctrl) + state := state.NewMockState(ctrl) m := &manager{ state: state, blkIDToState: map[ids.ID]*blockState{}, @@ -73,16 +73,16 @@ func TestManagerGetState(t *testing.T) { require := require.New(t) ctrl := gomock.NewController(t) - state := states.NewMockState(ctrl) + s := state.NewMockState(ctrl) m := &manager{ - state: state, + state: s, blkIDToState: map[ids.ID]*blockState{}, lastAccepted: ids.GenerateTestID(), } // Case: Block is in memory { - diff := states.NewMockDiff(ctrl) + diff := state.NewMockDiff(ctrl) blkID := ids.GenerateTestID() m.blkIDToState[blkID] = &blockState{ onAcceptState: diff, @@ -97,14 +97,14 @@ func TestManagerGetState(t *testing.T) { blkID := ids.GenerateTestID() gotState, ok := m.GetState(blkID) require.False(ok) - require.Equal(state, gotState) + require.Equal(s, gotState) } // Case: Block isn't in memory; block is last accepted { gotState, ok := m.GetState(m.lastAccepted) require.True(ok) - require.Equal(state, gotState) + require.Equal(s, gotState) } } @@ -164,7 +164,7 @@ func TestManagerVerifyTx(t *testing.T) { preferred := ids.GenerateTestID() // These values don't matter for this test - state := states.NewMockState(ctrl) + state := state.NewMockState(ctrl) state.EXPECT().GetLastAccepted().Return(preferred) state.EXPECT().GetTimestamp().Return(time.Time{}) @@ -197,7 +197,7 @@ func TestManagerVerifyTx(t *testing.T) { preferred := ids.GenerateTestID() // These values don't matter for this test - state := states.NewMockState(ctrl) + state := state.NewMockState(ctrl) state.EXPECT().GetLastAccepted().Return(preferred) state.EXPECT().GetTimestamp().Return(time.Time{}) @@ -237,7 +237,7 @@ func TestManagerVerifyTx(t *testing.T) { preferred.EXPECT().Parent().Return(lastAcceptedID).AnyTimes() // These values don't matter for this test - diffState := states.NewMockDiff(ctrl) + diffState := state.NewMockDiff(ctrl) diffState.EXPECT().GetLastAccepted().Return(preferredID) diffState.EXPECT().GetTimestamp().Return(time.Time{}) @@ -276,7 +276,7 @@ func TestManagerVerifyTx(t *testing.T) { preferred := ids.GenerateTestID() // These values don't matter for this test - state := states.NewMockState(ctrl) + state := state.NewMockState(ctrl) state.EXPECT().GetLastAccepted().Return(preferred) state.EXPECT().GetTimestamp().Return(time.Time{}) diff --git a/vms/avm/block/executor/mock_manager.go b/vms/avm/block/executor/mock_manager.go index b3560f2e8afa..5e27089b19fa 100644 --- a/vms/avm/block/executor/mock_manager.go +++ b/vms/avm/block/executor/mock_manager.go @@ -14,7 +14,7 @@ import ( snowman "github.com/ava-labs/avalanchego/snow/consensus/snowman" set "github.com/ava-labs/avalanchego/utils/set" block "github.com/ava-labs/avalanchego/vms/avm/block" - states "github.com/ava-labs/avalanchego/vms/avm/states" + state "github.com/ava-labs/avalanchego/vms/avm/state" txs "github.com/ava-labs/avalanchego/vms/avm/txs" gomock "go.uber.org/mock/gomock" ) @@ -58,10 +58,10 @@ func (mr *MockManagerMockRecorder) GetBlock(arg0 interface{}) *gomock.Call { } // GetState mocks base method. -func (m *MockManager) GetState(arg0 ids.ID) (states.Chain, bool) { +func (m *MockManager) GetState(arg0 ids.ID) (state.Chain, bool) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetState", arg0) - ret0, _ := ret[0].(states.Chain) + ret0, _ := ret[0].(state.Chain) ret1, _ := ret[1].(bool) return ret0, ret1 } diff --git a/vms/avm/service_test.go b/vms/avm/service_test.go index 67a92a663879..19cacb158d13 100644 --- a/vms/avm/service_test.go +++ b/vms/avm/service_test.go @@ -36,7 +36,7 @@ import ( "github.com/ava-labs/avalanchego/vms/avm/block" "github.com/ava-labs/avalanchego/vms/avm/block/executor" "github.com/ava-labs/avalanchego/vms/avm/config" - "github.com/ava-labs/avalanchego/vms/avm/states" + "github.com/ava-labs/avalanchego/vms/avm/state" "github.com/ava-labs/avalanchego/vms/avm/txs" "github.com/ava-labs/avalanchego/vms/components/avax" "github.com/ava-labs/avalanchego/vms/components/index" @@ -2266,7 +2266,7 @@ func TestServiceGetBlockByHeight(t *testing.T) { { name: "block height not found", serviceAndExpectedBlockFunc: func(_ *testing.T, ctrl *gomock.Controller) (*Service, interface{}) { - state := states.NewMockState(ctrl) + state := state.NewMockState(ctrl) state.EXPECT().GetBlockIDAtHeight(blockHeight).Return(ids.Empty, database.ErrNotFound) manager := executor.NewMockManager(ctrl) @@ -2286,7 +2286,7 @@ func TestServiceGetBlockByHeight(t *testing.T) { { name: "block not found", serviceAndExpectedBlockFunc: func(_ *testing.T, ctrl *gomock.Controller) (*Service, interface{}) { - state := states.NewMockState(ctrl) + state := state.NewMockState(ctrl) state.EXPECT().GetBlockIDAtHeight(blockHeight).Return(blockID, nil) manager := executor.NewMockManager(ctrl) @@ -2311,7 +2311,7 @@ func TestServiceGetBlockByHeight(t *testing.T) { block.EXPECT().InitCtx(gomock.Any()) block.EXPECT().Txs().Return(nil) - state := states.NewMockState(ctrl) + state := state.NewMockState(ctrl) state.EXPECT().GetBlockIDAtHeight(blockHeight).Return(blockID, nil) manager := executor.NewMockManager(ctrl) @@ -2336,7 +2336,7 @@ func TestServiceGetBlockByHeight(t *testing.T) { blockBytes := []byte("hi mom") block.EXPECT().Bytes().Return(blockBytes) - state := states.NewMockState(ctrl) + state := state.NewMockState(ctrl) state.EXPECT().GetBlockIDAtHeight(blockHeight).Return(blockID, nil) expected, err := formatting.Encode(formatting.Hex, blockBytes) @@ -2364,7 +2364,7 @@ func TestServiceGetBlockByHeight(t *testing.T) { blockBytes := []byte("hi mom") block.EXPECT().Bytes().Return(blockBytes) - state := states.NewMockState(ctrl) + state := state.NewMockState(ctrl) state.EXPECT().GetBlockIDAtHeight(blockHeight).Return(blockID, nil) expected, err := formatting.Encode(formatting.HexC, blockBytes) @@ -2392,7 +2392,7 @@ func TestServiceGetBlockByHeight(t *testing.T) { blockBytes := []byte("hi mom") block.EXPECT().Bytes().Return(blockBytes) - state := states.NewMockState(ctrl) + state := state.NewMockState(ctrl) state.EXPECT().GetBlockIDAtHeight(blockHeight).Return(blockID, nil) expected, err := formatting.Encode(formatting.HexNC, blockBytes) @@ -2470,7 +2470,7 @@ func TestServiceGetHeight(t *testing.T) { { name: "block not found", serviceFunc: func(ctrl *gomock.Controller) *Service { - state := states.NewMockState(ctrl) + state := state.NewMockState(ctrl) state.EXPECT().GetLastAccepted().Return(blockID) manager := executor.NewMockManager(ctrl) @@ -2490,7 +2490,7 @@ func TestServiceGetHeight(t *testing.T) { { name: "happy path", serviceFunc: func(ctrl *gomock.Controller) *Service { - state := states.NewMockState(ctrl) + state := state.NewMockState(ctrl) state.EXPECT().GetLastAccepted().Return(blockID) block := block.NewMockBlock(ctrl) diff --git a/vms/avm/states/diff.go b/vms/avm/state/diff.go similarity index 99% rename from vms/avm/states/diff.go rename to vms/avm/state/diff.go index 2ca6d58cd5ee..1d53fa37da3e 100644 --- a/vms/avm/states/diff.go +++ b/vms/avm/state/diff.go @@ -1,7 +1,7 @@ // Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. // See the file LICENSE for licensing terms. -package states +package state import ( "errors" diff --git a/vms/avm/states/mock_states.go b/vms/avm/state/mock_state.go similarity index 99% rename from vms/avm/states/mock_states.go rename to vms/avm/state/mock_state.go index 007b8622042e..3bb615283ce6 100644 --- a/vms/avm/states/mock_states.go +++ b/vms/avm/state/mock_state.go @@ -2,10 +2,10 @@ // See the file LICENSE for licensing terms. // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/ava-labs/avalanchego/vms/avm/states (interfaces: Chain,State,Diff) +// Source: github.com/ava-labs/avalanchego/vms/avm/state (interfaces: Chain,State,Diff) -// Package states is a generated GoMock package. -package states +// Package state is a generated GoMock package. +package state import ( reflect "reflect" diff --git a/vms/avm/states/state.go b/vms/avm/state/state.go similarity index 99% rename from vms/avm/states/state.go rename to vms/avm/state/state.go index 1167cdb37dce..e290f093aa22 100644 --- a/vms/avm/states/state.go +++ b/vms/avm/state/state.go @@ -1,7 +1,7 @@ // Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. // See the file LICENSE for licensing terms. -package states +package state import ( "bytes" diff --git a/vms/avm/states/state_test.go b/vms/avm/state/state_test.go similarity index 99% rename from vms/avm/states/state_test.go rename to vms/avm/state/state_test.go index b64fa3aa7933..c97836aee794 100644 --- a/vms/avm/states/state_test.go +++ b/vms/avm/state/state_test.go @@ -1,7 +1,7 @@ // Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. // See the file LICENSE for licensing terms. -package states +package state import ( "testing" diff --git a/vms/avm/states/versions.go b/vms/avm/state/versions.go similarity index 95% rename from vms/avm/states/versions.go rename to vms/avm/state/versions.go index 409c47becfff..da84182bb683 100644 --- a/vms/avm/states/versions.go +++ b/vms/avm/state/versions.go @@ -1,7 +1,7 @@ // Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. // See the file LICENSE for licensing terms. -package states +package state import "github.com/ava-labs/avalanchego/ids" diff --git a/vms/avm/txs/executor/executor.go b/vms/avm/txs/executor/executor.go index 040b1d9c816f..6a5991cade04 100644 --- a/vms/avm/txs/executor/executor.go +++ b/vms/avm/txs/executor/executor.go @@ -10,7 +10,7 @@ import ( "github.com/ava-labs/avalanchego/codec" "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/utils/set" - "github.com/ava-labs/avalanchego/vms/avm/states" + "github.com/ava-labs/avalanchego/vms/avm/state" "github.com/ava-labs/avalanchego/vms/avm/txs" "github.com/ava-labs/avalanchego/vms/components/avax" ) @@ -19,7 +19,7 @@ var _ txs.Visitor = (*Executor)(nil) type Executor struct { Codec codec.Manager - State states.Chain // state will be modified + State state.Chain // state will be modified Tx *txs.Tx Inputs set.Set[ids.ID] // imported inputs AtomicRequests map[ids.ID]*atomic.Requests // may be nil diff --git a/vms/avm/txs/executor/executor_test.go b/vms/avm/txs/executor/executor_test.go index 042ae39a9048..81d301ad6bae 100644 --- a/vms/avm/txs/executor/executor_test.go +++ b/vms/avm/txs/executor/executor_test.go @@ -19,7 +19,7 @@ import ( "github.com/ava-labs/avalanchego/utils/units" "github.com/ava-labs/avalanchego/vms/avm/block" "github.com/ava-labs/avalanchego/vms/avm/fxs" - "github.com/ava-labs/avalanchego/vms/avm/states" + "github.com/ava-labs/avalanchego/vms/avm/state" "github.com/ava-labs/avalanchego/vms/avm/txs" "github.com/ava-labs/avalanchego/vms/components/avax" "github.com/ava-labs/avalanchego/vms/components/verify" @@ -44,7 +44,7 @@ func TestBaseTxExecutor(t *testing.T) { db := memdb.New() vdb := versiondb.New(db) registerer := prometheus.NewRegistry() - state, err := states.New(vdb, parser, registerer, trackChecksums) + state, err := state.New(vdb, parser, registerer, trackChecksums) require.NoError(err) utxoID := avax.UTXOID{ @@ -149,7 +149,7 @@ func TestCreateAssetTxExecutor(t *testing.T) { db := memdb.New() vdb := versiondb.New(db) registerer := prometheus.NewRegistry() - state, err := states.New(vdb, parser, registerer, trackChecksums) + state, err := state.New(vdb, parser, registerer, trackChecksums) require.NoError(err) utxoID := avax.UTXOID{ @@ -292,7 +292,7 @@ func TestOperationTxExecutor(t *testing.T) { db := memdb.New() vdb := versiondb.New(db) registerer := prometheus.NewRegistry() - state, err := states.New(vdb, parser, registerer, trackChecksums) + state, err := state.New(vdb, parser, registerer, trackChecksums) require.NoError(err) outputOwners := secp256k1fx.OutputOwners{ diff --git a/vms/avm/txs/executor/semantic_verifier.go b/vms/avm/txs/executor/semantic_verifier.go index 0a8d59083255..1fd94fb0e131 100644 --- a/vms/avm/txs/executor/semantic_verifier.go +++ b/vms/avm/txs/executor/semantic_verifier.go @@ -9,7 +9,7 @@ import ( "reflect" "github.com/ava-labs/avalanchego/ids" - "github.com/ava-labs/avalanchego/vms/avm/states" + "github.com/ava-labs/avalanchego/vms/avm/state" "github.com/ava-labs/avalanchego/vms/avm/txs" "github.com/ava-labs/avalanchego/vms/components/avax" "github.com/ava-labs/avalanchego/vms/components/verify" @@ -26,7 +26,7 @@ var ( type SemanticVerifier struct { *Backend - State states.ReadOnlyChain + State state.ReadOnlyChain Tx *txs.Tx } diff --git a/vms/avm/txs/executor/semantic_verifier_test.go b/vms/avm/txs/executor/semantic_verifier_test.go index 72638762c39b..7c659f7ec227 100644 --- a/vms/avm/txs/executor/semantic_verifier_test.go +++ b/vms/avm/txs/executor/semantic_verifier_test.go @@ -22,7 +22,7 @@ import ( "github.com/ava-labs/avalanchego/utils/logging" "github.com/ava-labs/avalanchego/utils/timer/mockable" "github.com/ava-labs/avalanchego/vms/avm/fxs" - "github.com/ava-labs/avalanchego/vms/avm/states" + "github.com/ava-labs/avalanchego/vms/avm/state" "github.com/ava-labs/avalanchego/vms/avm/txs" "github.com/ava-labs/avalanchego/vms/components/avax" "github.com/ava-labs/avalanchego/vms/components/verify" @@ -117,14 +117,14 @@ func TestSemanticVerifierBaseTx(t *testing.T) { tests := []struct { name string - stateFunc func(*gomock.Controller) states.Chain + stateFunc func(*gomock.Controller) state.Chain txFunc func(*require.Assertions) *txs.Tx err error }{ { name: "valid", - stateFunc: func(ctrl *gomock.Controller) states.Chain { - state := states.NewMockChain(ctrl) + stateFunc: func(ctrl *gomock.Controller) state.Chain { + state := state.NewMockChain(ctrl) state.EXPECT().GetUTXO(utxoID.InputID()).Return(&utxo, nil) state.EXPECT().GetTx(asset.ID).Return(&createAssetTx, nil) @@ -147,8 +147,8 @@ func TestSemanticVerifierBaseTx(t *testing.T) { }, { name: "assetID mismatch", - stateFunc: func(ctrl *gomock.Controller) states.Chain { - state := states.NewMockChain(ctrl) + stateFunc: func(ctrl *gomock.Controller) state.Chain { + state := state.NewMockChain(ctrl) utxo := utxo utxo.Asset.ID = ids.GenerateTestID() @@ -173,8 +173,8 @@ func TestSemanticVerifierBaseTx(t *testing.T) { }, { name: "not allowed input feature extension", - stateFunc: func(ctrl *gomock.Controller) states.Chain { - state := states.NewMockChain(ctrl) + stateFunc: func(ctrl *gomock.Controller) state.Chain { + state := state.NewMockChain(ctrl) unsignedCreateAssetTx := unsignedCreateAssetTx unsignedCreateAssetTx.States = nil @@ -204,8 +204,8 @@ func TestSemanticVerifierBaseTx(t *testing.T) { }, { name: "invalid signature", - stateFunc: func(ctrl *gomock.Controller) states.Chain { - state := states.NewMockChain(ctrl) + stateFunc: func(ctrl *gomock.Controller) state.Chain { + state := state.NewMockChain(ctrl) state.EXPECT().GetUTXO(utxoID.InputID()).Return(&utxo, nil) state.EXPECT().GetTx(asset.ID).Return(&createAssetTx, nil) @@ -228,8 +228,8 @@ func TestSemanticVerifierBaseTx(t *testing.T) { }, { name: "missing UTXO", - stateFunc: func(ctrl *gomock.Controller) states.Chain { - state := states.NewMockChain(ctrl) + stateFunc: func(ctrl *gomock.Controller) state.Chain { + state := state.NewMockChain(ctrl) state.EXPECT().GetUTXO(utxoID.InputID()).Return(nil, database.ErrNotFound) @@ -251,8 +251,8 @@ func TestSemanticVerifierBaseTx(t *testing.T) { }, { name: "invalid UTXO amount", - stateFunc: func(ctrl *gomock.Controller) states.Chain { - state := states.NewMockChain(ctrl) + stateFunc: func(ctrl *gomock.Controller) state.Chain { + state := state.NewMockChain(ctrl) output := output output.Amt-- @@ -281,8 +281,8 @@ func TestSemanticVerifierBaseTx(t *testing.T) { }, { name: "not allowed output feature extension", - stateFunc: func(ctrl *gomock.Controller) states.Chain { - state := states.NewMockChain(ctrl) + stateFunc: func(ctrl *gomock.Controller) state.Chain { + state := state.NewMockChain(ctrl) unsignedCreateAssetTx := unsignedCreateAssetTx unsignedCreateAssetTx.States = nil @@ -317,8 +317,8 @@ func TestSemanticVerifierBaseTx(t *testing.T) { }, { name: "unknown asset", - stateFunc: func(ctrl *gomock.Controller) states.Chain { - state := states.NewMockChain(ctrl) + stateFunc: func(ctrl *gomock.Controller) state.Chain { + state := state.NewMockChain(ctrl) state.EXPECT().GetUTXO(utxoID.InputID()).Return(&utxo, nil) state.EXPECT().GetTx(asset.ID).Return(nil, database.ErrNotFound) @@ -341,8 +341,8 @@ func TestSemanticVerifierBaseTx(t *testing.T) { }, { name: "not an asset", - stateFunc: func(ctrl *gomock.Controller) states.Chain { - state := states.NewMockChain(ctrl) + stateFunc: func(ctrl *gomock.Controller) state.Chain { + state := state.NewMockChain(ctrl) tx := txs.Tx{ Unsigned: &baseTx, @@ -483,14 +483,14 @@ func TestSemanticVerifierExportTx(t *testing.T) { tests := []struct { name string - stateFunc func(*gomock.Controller) states.Chain + stateFunc func(*gomock.Controller) state.Chain txFunc func(*require.Assertions) *txs.Tx err error }{ { name: "valid", - stateFunc: func(ctrl *gomock.Controller) states.Chain { - state := states.NewMockChain(ctrl) + stateFunc: func(ctrl *gomock.Controller) state.Chain { + state := state.NewMockChain(ctrl) state.EXPECT().GetUTXO(utxoID.InputID()).Return(&utxo, nil) state.EXPECT().GetTx(asset.ID).Return(&createAssetTx, nil) @@ -513,8 +513,8 @@ func TestSemanticVerifierExportTx(t *testing.T) { }, { name: "assetID mismatch", - stateFunc: func(ctrl *gomock.Controller) states.Chain { - state := states.NewMockChain(ctrl) + stateFunc: func(ctrl *gomock.Controller) state.Chain { + state := state.NewMockChain(ctrl) utxo := utxo utxo.Asset.ID = ids.GenerateTestID() @@ -539,8 +539,8 @@ func TestSemanticVerifierExportTx(t *testing.T) { }, { name: "not allowed input feature extension", - stateFunc: func(ctrl *gomock.Controller) states.Chain { - state := states.NewMockChain(ctrl) + stateFunc: func(ctrl *gomock.Controller) state.Chain { + state := state.NewMockChain(ctrl) unsignedCreateAssetTx := unsignedCreateAssetTx unsignedCreateAssetTx.States = nil @@ -570,8 +570,8 @@ func TestSemanticVerifierExportTx(t *testing.T) { }, { name: "invalid signature", - stateFunc: func(ctrl *gomock.Controller) states.Chain { - state := states.NewMockChain(ctrl) + stateFunc: func(ctrl *gomock.Controller) state.Chain { + state := state.NewMockChain(ctrl) state.EXPECT().GetUTXO(utxoID.InputID()).Return(&utxo, nil) state.EXPECT().GetTx(asset.ID).Return(&createAssetTx, nil) @@ -594,8 +594,8 @@ func TestSemanticVerifierExportTx(t *testing.T) { }, { name: "missing UTXO", - stateFunc: func(ctrl *gomock.Controller) states.Chain { - state := states.NewMockChain(ctrl) + stateFunc: func(ctrl *gomock.Controller) state.Chain { + state := state.NewMockChain(ctrl) state.EXPECT().GetUTXO(utxoID.InputID()).Return(nil, database.ErrNotFound) @@ -617,8 +617,8 @@ func TestSemanticVerifierExportTx(t *testing.T) { }, { name: "invalid UTXO amount", - stateFunc: func(ctrl *gomock.Controller) states.Chain { - state := states.NewMockChain(ctrl) + stateFunc: func(ctrl *gomock.Controller) state.Chain { + state := state.NewMockChain(ctrl) output := output output.Amt-- @@ -647,8 +647,8 @@ func TestSemanticVerifierExportTx(t *testing.T) { }, { name: "not allowed output feature extension", - stateFunc: func(ctrl *gomock.Controller) states.Chain { - state := states.NewMockChain(ctrl) + stateFunc: func(ctrl *gomock.Controller) state.Chain { + state := state.NewMockChain(ctrl) unsignedCreateAssetTx := unsignedCreateAssetTx unsignedCreateAssetTx.States = nil @@ -683,8 +683,8 @@ func TestSemanticVerifierExportTx(t *testing.T) { }, { name: "unknown asset", - stateFunc: func(ctrl *gomock.Controller) states.Chain { - state := states.NewMockChain(ctrl) + stateFunc: func(ctrl *gomock.Controller) state.Chain { + state := state.NewMockChain(ctrl) state.EXPECT().GetUTXO(utxoID.InputID()).Return(&utxo, nil) state.EXPECT().GetTx(asset.ID).Return(nil, database.ErrNotFound) @@ -707,8 +707,8 @@ func TestSemanticVerifierExportTx(t *testing.T) { }, { name: "not an asset", - stateFunc: func(ctrl *gomock.Controller) states.Chain { - state := states.NewMockChain(ctrl) + stateFunc: func(ctrl *gomock.Controller) state.Chain { + state := state.NewMockChain(ctrl) tx := txs.Tx{ Unsigned: &baseTx, @@ -849,7 +849,7 @@ func TestSemanticVerifierExportTxDifferentSubnet(t *testing.T) { Unsigned: &unsignedCreateAssetTx, } - state := states.NewMockChain(ctrl) + state := state.NewMockChain(ctrl) state.EXPECT().GetUTXO(utxoID.InputID()).Return(&utxo, nil) state.EXPECT().GetTx(asset.ID).Return(&createAssetTx, nil) @@ -999,14 +999,14 @@ func TestSemanticVerifierImportTx(t *testing.T) { } tests := []struct { name string - stateFunc func(*gomock.Controller) states.Chain + stateFunc func(*gomock.Controller) state.Chain txFunc func(*require.Assertions) *txs.Tx expectedErr error }{ { name: "valid", - stateFunc: func(ctrl *gomock.Controller) states.Chain { - state := states.NewMockChain(ctrl) + stateFunc: func(ctrl *gomock.Controller) state.Chain { + state := state.NewMockChain(ctrl) state.EXPECT().GetUTXO(utxoID.InputID()).Return(&utxo, nil).AnyTimes() state.EXPECT().GetTx(asset.ID).Return(&createAssetTx, nil).AnyTimes() return state @@ -1018,8 +1018,8 @@ func TestSemanticVerifierImportTx(t *testing.T) { }, { name: "not allowed input feature extension", - stateFunc: func(ctrl *gomock.Controller) states.Chain { - state := states.NewMockChain(ctrl) + stateFunc: func(ctrl *gomock.Controller) state.Chain { + state := state.NewMockChain(ctrl) unsignedCreateAssetTx := unsignedCreateAssetTx unsignedCreateAssetTx.States = nil createAssetTx := txs.Tx{ @@ -1036,8 +1036,8 @@ func TestSemanticVerifierImportTx(t *testing.T) { }, { name: "invalid signature", - stateFunc: func(ctrl *gomock.Controller) states.Chain { - state := states.NewMockChain(ctrl) + stateFunc: func(ctrl *gomock.Controller) state.Chain { + state := state.NewMockChain(ctrl) state.EXPECT().GetUTXO(utxoID.InputID()).Return(&utxo, nil).AnyTimes() state.EXPECT().GetTx(asset.ID).Return(&createAssetTx, nil).AnyTimes() return state @@ -1058,8 +1058,8 @@ func TestSemanticVerifierImportTx(t *testing.T) { }, { name: "not allowed output feature extension", - stateFunc: func(ctrl *gomock.Controller) states.Chain { - state := states.NewMockChain(ctrl) + stateFunc: func(ctrl *gomock.Controller) state.Chain { + state := state.NewMockChain(ctrl) unsignedCreateAssetTx := unsignedCreateAssetTx unsignedCreateAssetTx.States = nil createAssetTx := txs.Tx{ @@ -1087,8 +1087,8 @@ func TestSemanticVerifierImportTx(t *testing.T) { }, { name: "unknown asset", - stateFunc: func(ctrl *gomock.Controller) states.Chain { - state := states.NewMockChain(ctrl) + stateFunc: func(ctrl *gomock.Controller) state.Chain { + state := state.NewMockChain(ctrl) state.EXPECT().GetUTXO(utxoID.InputID()).Return(&utxo, nil).AnyTimes() state.EXPECT().GetTx(asset.ID).Return(nil, database.ErrNotFound) return state @@ -1100,8 +1100,8 @@ func TestSemanticVerifierImportTx(t *testing.T) { }, { name: "not an asset", - stateFunc: func(ctrl *gomock.Controller) states.Chain { - state := states.NewMockChain(ctrl) + stateFunc: func(ctrl *gomock.Controller) state.Chain { + state := state.NewMockChain(ctrl) tx := txs.Tx{ Unsigned: &baseTx, } diff --git a/vms/avm/vm.go b/vms/avm/vm.go index cae4514ff3a6..36049befd07c 100644 --- a/vms/avm/vm.go +++ b/vms/avm/vm.go @@ -38,7 +38,7 @@ import ( "github.com/ava-labs/avalanchego/vms/avm/config" "github.com/ava-labs/avalanchego/vms/avm/metrics" "github.com/ava-labs/avalanchego/vms/avm/network" - "github.com/ava-labs/avalanchego/vms/avm/states" + "github.com/ava-labs/avalanchego/vms/avm/state" "github.com/ava-labs/avalanchego/vms/avm/txs" "github.com/ava-labs/avalanchego/vms/avm/txs/mempool" "github.com/ava-labs/avalanchego/vms/avm/utxo" @@ -91,7 +91,7 @@ type VM struct { appSender common.AppSender // State management - state states.State + state state.State // Set to true once this VM is marked as `Bootstrapped` by the engine bootstrapped bool @@ -220,7 +220,7 @@ func (vm *VM) Initialize( vm.AtomicUTXOManager = avax.NewAtomicUTXOManager(ctx.SharedMemory, codec) vm.Spender = utxo.NewSpender(&vm.clock, codec) - state, err := states.New( + state, err := state.New( vm.db, vm.parser, vm.registerer, From be422a0179bb920d20d5e4f547f07cc23671e3df Mon Sep 17 00:00:00 2001 From: Stephen Buttolph Date: Tue, 28 Nov 2023 13:58:45 -0500 Subject: [PATCH 3/3] Implement generic bimap (#2383) --- codec/hierarchycodec/codec.go | 26 +- codec/linearcodec/codec.go | 22 +- .../avalanche/bootstrap/bootstrapper.go | 40 ++- snow/engine/common/fetcher.go | 14 - snow/engine/common/request.go | 11 + snow/engine/common/requests.go | 110 ------ snow/engine/common/requests_test.go | 69 ---- snow/engine/snowman/bootstrap/bootstrapper.go | 37 +- .../snowman/bootstrap/bootstrapper_test.go | 3 +- snow/engine/snowman/transitive.go | 34 +- utils/bimap/bimap.go | 102 ++++++ utils/bimap/bimap_test.go | 328 ++++++++++++++++++ 12 files changed, 545 insertions(+), 251 deletions(-) delete mode 100644 snow/engine/common/fetcher.go create mode 100644 snow/engine/common/request.go delete mode 100644 snow/engine/common/requests.go delete mode 100644 snow/engine/common/requests_test.go create mode 100644 utils/bimap/bimap.go create mode 100644 utils/bimap/bimap_test.go diff --git a/codec/hierarchycodec/codec.go b/codec/hierarchycodec/codec.go index d1d03d879275..1b82380bc576 100644 --- a/codec/hierarchycodec/codec.go +++ b/codec/hierarchycodec/codec.go @@ -10,6 +10,7 @@ import ( "github.com/ava-labs/avalanchego/codec" "github.com/ava-labs/avalanchego/codec/reflectcodec" + "github.com/ava-labs/avalanchego/utils/bimap" "github.com/ava-labs/avalanchego/utils/wrappers" ) @@ -42,20 +43,18 @@ type typeID struct { type hierarchyCodec struct { codec.Codec - lock sync.RWMutex - currentGroupID uint16 - nextTypeID uint16 - typeIDToType map[typeID]reflect.Type - typeToTypeID map[reflect.Type]typeID + lock sync.RWMutex + currentGroupID uint16 + nextTypeID uint16 + registeredTypes *bimap.BiMap[typeID, reflect.Type] } // New returns a new, concurrency-safe codec func New(tagNames []string, maxSliceLen uint32) Codec { hCodec := &hierarchyCodec{ - currentGroupID: 0, - nextTypeID: 0, - typeIDToType: map[typeID]reflect.Type{}, - typeToTypeID: map[reflect.Type]typeID{}, + currentGroupID: 0, + nextTypeID: 0, + registeredTypes: bimap.New[typeID, reflect.Type](), } hCodec.Codec = reflectcodec.New(hCodec, tagNames, maxSliceLen) return hCodec @@ -88,7 +87,7 @@ func (c *hierarchyCodec) RegisterType(val interface{}) error { defer c.lock.Unlock() valType := reflect.TypeOf(val) - if _, exists := c.typeToTypeID[valType]; exists { + if c.registeredTypes.HasValue(valType) { return fmt.Errorf("%w: %v", codec.ErrDuplicateType, valType) } @@ -98,8 +97,7 @@ func (c *hierarchyCodec) RegisterType(val interface{}) error { } c.nextTypeID++ - c.typeIDToType[valTypeID] = valType - c.typeToTypeID[valType] = valTypeID + c.registeredTypes.Put(valTypeID, valType) return nil } @@ -112,7 +110,7 @@ func (c *hierarchyCodec) PackPrefix(p *wrappers.Packer, valueType reflect.Type) c.lock.RLock() defer c.lock.RUnlock() - typeID, ok := c.typeToTypeID[valueType] // Get the type ID of the value being marshaled + typeID, ok := c.registeredTypes.GetKey(valueType) // Get the type ID of the value being marshaled if !ok { return fmt.Errorf("can't marshal unregistered type %q", valueType) } @@ -136,7 +134,7 @@ func (c *hierarchyCodec) UnpackPrefix(p *wrappers.Packer, valueType reflect.Type typeID: typeIDShort, } // Get a type that implements the interface - implementingType, ok := c.typeIDToType[t] + implementingType, ok := c.registeredTypes.GetValue(t) if !ok { return reflect.Value{}, fmt.Errorf("couldn't unmarshal interface: unknown type ID %+v", t) } diff --git a/codec/linearcodec/codec.go b/codec/linearcodec/codec.go index 677c331b0366..07097aee79eb 100644 --- a/codec/linearcodec/codec.go +++ b/codec/linearcodec/codec.go @@ -10,6 +10,7 @@ import ( "github.com/ava-labs/avalanchego/codec" "github.com/ava-labs/avalanchego/codec/reflectcodec" + "github.com/ava-labs/avalanchego/utils/bimap" "github.com/ava-labs/avalanchego/utils/wrappers" ) @@ -36,19 +37,17 @@ type Codec interface { type linearCodec struct { codec.Codec - lock sync.RWMutex - nextTypeID uint32 - typeIDToType map[uint32]reflect.Type - typeToTypeID map[reflect.Type]uint32 + lock sync.RWMutex + nextTypeID uint32 + registeredTypes *bimap.BiMap[uint32, reflect.Type] } // New returns a new, concurrency-safe codec; it allow to specify // both tagNames and maxSlicelenght func New(tagNames []string, maxSliceLen uint32) Codec { hCodec := &linearCodec{ - nextTypeID: 0, - typeIDToType: map[uint32]reflect.Type{}, - typeToTypeID: map[reflect.Type]uint32{}, + nextTypeID: 0, + registeredTypes: bimap.New[uint32, reflect.Type](), } hCodec.Codec = reflectcodec.New(hCodec, tagNames, maxSliceLen) return hCodec @@ -78,12 +77,11 @@ func (c *linearCodec) RegisterType(val interface{}) error { defer c.lock.Unlock() valType := reflect.TypeOf(val) - if _, exists := c.typeToTypeID[valType]; exists { + if c.registeredTypes.HasValue(valType) { return fmt.Errorf("%w: %v", codec.ErrDuplicateType, valType) } - c.typeIDToType[c.nextTypeID] = valType - c.typeToTypeID[valType] = c.nextTypeID + c.registeredTypes.Put(c.nextTypeID, valType) c.nextTypeID++ return nil } @@ -97,7 +95,7 @@ func (c *linearCodec) PackPrefix(p *wrappers.Packer, valueType reflect.Type) err c.lock.RLock() defer c.lock.RUnlock() - typeID, ok := c.typeToTypeID[valueType] // Get the type ID of the value being marshaled + typeID, ok := c.registeredTypes.GetKey(valueType) // Get the type ID of the value being marshaled if !ok { return fmt.Errorf("can't marshal unregistered type %q", valueType) } @@ -114,7 +112,7 @@ func (c *linearCodec) UnpackPrefix(p *wrappers.Packer, valueType reflect.Type) ( return reflect.Value{}, fmt.Errorf("couldn't unmarshal interface: %w", p.Err) } // Get a type that implements the interface - implementingType, ok := c.typeIDToType[typeID] + implementingType, ok := c.registeredTypes.GetValue(typeID) if !ok { return reflect.Value{}, fmt.Errorf("couldn't unmarshal interface: unknown type ID %d", typeID) } diff --git a/snow/engine/avalanche/bootstrap/bootstrapper.go b/snow/engine/avalanche/bootstrap/bootstrapper.go index 59f421158fa2..967d65711abc 100644 --- a/snow/engine/avalanche/bootstrap/bootstrapper.go +++ b/snow/engine/avalanche/bootstrap/bootstrapper.go @@ -16,6 +16,7 @@ import ( "github.com/ava-labs/avalanchego/snow/choices" "github.com/ava-labs/avalanchego/snow/consensus/avalanche" "github.com/ava-labs/avalanchego/snow/engine/common" + "github.com/ava-labs/avalanchego/utils/bimap" "github.com/ava-labs/avalanchego/utils/heap" "github.com/ava-labs/avalanchego/utils/logging" "github.com/ava-labs/avalanchego/utils/set" @@ -57,10 +58,10 @@ func New( ChitsHandler: common.NewNoOpChitsHandler(config.Ctx.Log), AppHandler: config.VM, + outstandingRequests: bimap.New[common.Request, ids.ID](), + processedCache: &cache.LRU[ids.ID, struct{}]{Size: cacheSize}, - Fetcher: common.Fetcher{ - OnFinished: onFinished, - }, + onFinished: onFinished, } return b, b.metrics.Initialize("bs", config.Ctx.AvalancheRegisterer) } @@ -81,9 +82,11 @@ type bootstrapper struct { common.ChitsHandler common.AppHandler - common.Fetcher metrics + // tracks which validators were asked for which containers in which requests + outstandingRequests *bimap.BiMap[common.Request, ids.ID] + // IDs of vertices that we will send a GetAncestors request for once we are // not at the max number of outstanding requests needToFetch set.Set[ids.ID] @@ -93,6 +96,9 @@ type bootstrapper struct { // Tracks the last requestID that was used in a request requestID uint32 + + // Called when bootstrapping is done on a specific chain + onFinished func(ctx context.Context, lastReqID uint32) error } func (b *bootstrapper) Context() *snow.ConsensusContext { @@ -137,7 +143,10 @@ func (b *bootstrapper) Ancestors(ctx context.Context, nodeID ids.NodeID, request vtxs = vtxs[:b.Config.AncestorsMaxContainersReceived] } - requestedVtxID, requested := b.OutstandingRequests.Remove(nodeID, requestID) + requestedVtxID, requested := b.outstandingRequests.DeleteKey(common.Request{ + NodeID: nodeID, + RequestID: requestID, + }) vtx, err := b.Manager.ParseVtx(ctx, vtxs[0]) // first vertex should be the one we requested in GetAncestors request if err != nil { if !requested { @@ -177,7 +186,7 @@ func (b *bootstrapper) Ancestors(ctx context.Context, nodeID ids.NodeID, request ) return b.fetch(ctx, requestedVtxID) } - if !requested && !b.OutstandingRequests.Contains(vtxID) && !b.needToFetch.Contains(vtxID) { + if !requested && !b.outstandingRequests.HasValue(vtxID) && !b.needToFetch.Contains(vtxID) { b.Ctx.Log.Debug("received un-needed vertex", zap.Stringer("nodeID", nodeID), zap.Uint32("requestID", requestID), @@ -244,7 +253,10 @@ func (b *bootstrapper) Ancestors(ctx context.Context, nodeID ids.NodeID, request } func (b *bootstrapper) GetAncestorsFailed(ctx context.Context, nodeID ids.NodeID, requestID uint32) error { - vtxID, ok := b.OutstandingRequests.Remove(nodeID, requestID) + vtxID, ok := b.outstandingRequests.DeleteKey(common.Request{ + NodeID: nodeID, + RequestID: requestID, + }) if !ok { b.Ctx.Log.Debug("skipping GetAncestorsFailed call", zap.String("reason", "no matching outstanding request"), @@ -388,12 +400,12 @@ func (b *bootstrapper) HealthCheck(ctx context.Context) (interface{}, error) { // to fetch or we are at the maximum number of outstanding requests. func (b *bootstrapper) fetch(ctx context.Context, vtxIDs ...ids.ID) error { b.needToFetch.Add(vtxIDs...) - for b.needToFetch.Len() > 0 && b.OutstandingRequests.Len() < maxOutstandingGetAncestorsRequests { + for b.needToFetch.Len() > 0 && b.outstandingRequests.Len() < maxOutstandingGetAncestorsRequests { vtxID := b.needToFetch.CappedList(1)[0] b.needToFetch.Remove(vtxID) // Make sure we haven't already requested this vertex - if b.OutstandingRequests.Contains(vtxID) { + if b.outstandingRequests.HasValue(vtxID) { continue } @@ -409,7 +421,13 @@ func (b *bootstrapper) fetch(ctx context.Context, vtxIDs ...ids.ID) error { validatorID := validatorIDs[0] b.requestID++ - b.OutstandingRequests.Add(validatorID, b.requestID, vtxID) + b.outstandingRequests.Put( + common.Request{ + NodeID: validatorID, + RequestID: b.requestID, + }, + vtxID, + ) b.Config.Sender.SendGetAncestors(ctx, validatorID, b.requestID, vtxID) // request vertex and ancestors } return b.checkFinish(ctx) @@ -606,7 +624,7 @@ func (b *bootstrapper) checkFinish(ctx context.Context) error { } b.processedCache.Flush() - return b.OnFinished(ctx, b.requestID) + return b.onFinished(ctx, b.requestID) } // A vertex is less than another vertex if it is unknown. Ties are broken by diff --git a/snow/engine/common/fetcher.go b/snow/engine/common/fetcher.go deleted file mode 100644 index 9e90da3d325b..000000000000 --- a/snow/engine/common/fetcher.go +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -package common - -import "context" - -type Fetcher struct { - // tracks which validators were asked for which containers in which requests - OutstandingRequests Requests - - // Called when bootstrapping is done on a specific chain - OnFinished func(ctx context.Context, lastReqID uint32) error -} diff --git a/snow/engine/common/request.go b/snow/engine/common/request.go new file mode 100644 index 000000000000..d677a485c8f4 --- /dev/null +++ b/snow/engine/common/request.go @@ -0,0 +1,11 @@ +// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package common + +import "github.com/ava-labs/avalanchego/ids" + +type Request struct { + NodeID ids.NodeID + RequestID uint32 +} diff --git a/snow/engine/common/requests.go b/snow/engine/common/requests.go deleted file mode 100644 index ce66585e590d..000000000000 --- a/snow/engine/common/requests.go +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -package common - -import ( - "fmt" - "strings" - - "github.com/ava-labs/avalanchego/ids" -) - -const ( - minRequestsSize = 32 -) - -type req struct { - vdr ids.NodeID - id uint32 -} - -// Requests tracks pending container messages from a peer. -type Requests struct { - reqsToID map[ids.NodeID]map[uint32]ids.ID - idToReq map[ids.ID]req -} - -// Add a request. Assumes that requestIDs are unique. Assumes that containerIDs -// are only in one request at a time. -func (r *Requests) Add(vdr ids.NodeID, requestID uint32, containerID ids.ID) { - if r.reqsToID == nil { - r.reqsToID = make(map[ids.NodeID]map[uint32]ids.ID, minRequestsSize) - } - vdrReqs, ok := r.reqsToID[vdr] - if !ok { - vdrReqs = make(map[uint32]ids.ID) - r.reqsToID[vdr] = vdrReqs - } - vdrReqs[requestID] = containerID - - if r.idToReq == nil { - r.idToReq = make(map[ids.ID]req, minRequestsSize) - } - r.idToReq[containerID] = req{ - vdr: vdr, - id: requestID, - } -} - -// Get the containerID the request is expecting and if the request exists. -func (r *Requests) Get(vdr ids.NodeID, requestID uint32) (ids.ID, bool) { - containerID, ok := r.reqsToID[vdr][requestID] - return containerID, ok -} - -// Remove attempts to abandon a requestID sent to a validator. If the request is -// currently outstanding, the requested ID will be returned along with true. If -// the request isn't currently outstanding, false will be returned. -func (r *Requests) Remove(vdr ids.NodeID, requestID uint32) (ids.ID, bool) { - vdrReqs := r.reqsToID[vdr] - containerID, ok := vdrReqs[requestID] - if !ok { - return ids.ID{}, false - } - - if len(vdrReqs) == 1 { - delete(r.reqsToID, vdr) - } else { - delete(vdrReqs, requestID) - } - - delete(r.idToReq, containerID) - return containerID, true -} - -// RemoveAny outstanding requests for the container ID. True is returned if the -// container ID had an outstanding request. -func (r *Requests) RemoveAny(containerID ids.ID) bool { - req, ok := r.idToReq[containerID] - if !ok { - return false - } - - r.Remove(req.vdr, req.id) - return true -} - -// Len returns the total number of outstanding requests. -func (r *Requests) Len() int { - return len(r.idToReq) -} - -// Contains returns true if there is an outstanding request for the container -// ID. -func (r *Requests) Contains(containerID ids.ID) bool { - _, ok := r.idToReq[containerID] - return ok -} - -func (r Requests) String() string { - sb := strings.Builder{} - sb.WriteString(fmt.Sprintf("Requests: (Num Validators = %d)", len(r.reqsToID))) - for vdr, reqs := range r.reqsToID { - sb.WriteString(fmt.Sprintf("\n VDR[%s]: (Outstanding Requests %d)", vdr, len(reqs))) - for reqID, containerID := range reqs { - sb.WriteString(fmt.Sprintf("\n Request[%d]: %s", reqID, containerID)) - } - } - return sb.String() -} diff --git a/snow/engine/common/requests_test.go b/snow/engine/common/requests_test.go deleted file mode 100644 index 73a98e4ccb94..000000000000 --- a/snow/engine/common/requests_test.go +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -package common - -import ( - "testing" - - "github.com/stretchr/testify/require" - - "github.com/ava-labs/avalanchego/ids" -) - -func TestRequests(t *testing.T) { - require := require.New(t) - - req := Requests{} - - require.Empty(req) - - _, removed := req.Remove(ids.EmptyNodeID, 0) - require.False(removed) - - require.False(req.RemoveAny(ids.Empty)) - require.False(req.Contains(ids.Empty)) - - req.Add(ids.EmptyNodeID, 0, ids.Empty) - require.Equal(1, req.Len()) - - _, removed = req.Remove(ids.EmptyNodeID, 1) - require.False(removed) - - _, removed = req.Remove(ids.BuildTestNodeID([]byte{0x01}), 0) - require.False(removed) - - require.True(req.Contains(ids.Empty)) - require.Equal(1, req.Len()) - - req.Add(ids.EmptyNodeID, 10, ids.Empty.Prefix(0)) - require.Equal(2, req.Len()) - - _, removed = req.Remove(ids.EmptyNodeID, 1) - require.False(removed) - - _, removed = req.Remove(ids.BuildTestNodeID([]byte{0x01}), 0) - require.False(removed) - - require.True(req.Contains(ids.Empty)) - require.Equal(2, req.Len()) - - removedID, removed := req.Remove(ids.EmptyNodeID, 0) - require.True(removed) - require.Equal(ids.Empty, removedID) - - removedID, removed = req.Remove(ids.EmptyNodeID, 10) - require.True(removed) - require.Equal(ids.Empty.Prefix(0), removedID) - - require.Zero(req.Len()) - - req.Add(ids.EmptyNodeID, 0, ids.Empty) - require.Equal(1, req.Len()) - - require.True(req.RemoveAny(ids.Empty)) - require.Zero(req.Len()) - - require.False(req.RemoveAny(ids.Empty)) - require.Zero(req.Len()) -} diff --git a/snow/engine/snowman/bootstrap/bootstrapper.go b/snow/engine/snowman/bootstrap/bootstrapper.go index b575229954fb..0e2c7e0dab16 100644 --- a/snow/engine/snowman/bootstrap/bootstrapper.go +++ b/snow/engine/snowman/bootstrap/bootstrapper.go @@ -21,6 +21,7 @@ import ( "github.com/ava-labs/avalanchego/snow/consensus/snowman/bootstrapper" "github.com/ava-labs/avalanchego/snow/engine/common" "github.com/ava-labs/avalanchego/snow/engine/snowman/block" + "github.com/ava-labs/avalanchego/utils/bimap" "github.com/ava-labs/avalanchego/utils/set" "github.com/ava-labs/avalanchego/utils/timer" "github.com/ava-labs/avalanchego/version" @@ -91,7 +92,8 @@ type Bootstrapper struct { // Time that startSyncing was last called startTime time.Time - common.Fetcher + // tracks which validators were asked for which containers in which requests + outstandingRequests *bimap.BiMap[common.Request, ids.ID] // number of state transitions executed executedStateTransitions int @@ -112,6 +114,9 @@ type Bootstrapper struct { // bootstrappedOnce ensures that the [Bootstrapped] callback is only invoked // once, even if bootstrapping is retried. bootstrappedOnce sync.Once + + // Called when bootstrapping is done on a specific chain + onFinished func(ctx context.Context, lastReqID uint32) error } func New(config Config, onFinished func(ctx context.Context, lastReqID uint32) error) (*Bootstrapper, error) { @@ -129,10 +134,10 @@ func New(config Config, onFinished func(ctx context.Context, lastReqID uint32) e minority: bootstrapper.Noop, majority: bootstrapper.Noop, - Fetcher: common.Fetcher{ - OnFinished: onFinished, - }, + outstandingRequests: bimap.New[common.Request, ids.ID](), + executedStateTransitions: math.MaxInt, + onFinished: onFinished, }, err } @@ -425,7 +430,7 @@ func (b *Bootstrapper) startSyncing(ctx context.Context, acceptedContainerIDs [] // Get block [blkID] and its ancestors from a validator func (b *Bootstrapper) fetch(ctx context.Context, blkID ids.ID) error { // Make sure we haven't already requested this block - if b.OutstandingRequests.Contains(blkID) { + if b.outstandingRequests.HasValue(blkID) { return nil } @@ -444,7 +449,13 @@ func (b *Bootstrapper) fetch(ctx context.Context, blkID ids.ID) error { b.requestID++ - b.OutstandingRequests.Add(validatorID, b.requestID, blkID) + b.outstandingRequests.Put( + common.Request{ + NodeID: validatorID, + RequestID: b.requestID, + }, + blkID, + ) b.Config.Sender.SendGetAncestors(ctx, validatorID, b.requestID, blkID) // request block and ancestors return nil } @@ -453,7 +464,10 @@ func (b *Bootstrapper) fetch(ctx context.Context, blkID ids.ID) error { // response to a GetAncestors message to [nodeID] with request ID [requestID] func (b *Bootstrapper) Ancestors(ctx context.Context, nodeID ids.NodeID, requestID uint32, blks [][]byte) error { // Make sure this is in response to a request we made - wantedBlkID, ok := b.OutstandingRequests.Remove(nodeID, requestID) + wantedBlkID, ok := b.outstandingRequests.DeleteKey(common.Request{ + NodeID: nodeID, + RequestID: requestID, + }) if !ok { // this message isn't in response to a request we made b.Ctx.Log.Debug("received unexpected Ancestors", zap.Stringer("nodeID", nodeID), @@ -522,7 +536,10 @@ func (b *Bootstrapper) Ancestors(ctx context.Context, nodeID ids.NodeID, request } func (b *Bootstrapper) GetAncestorsFailed(ctx context.Context, nodeID ids.NodeID, requestID uint32) error { - blkID, ok := b.OutstandingRequests.Remove(nodeID, requestID) + blkID, ok := b.outstandingRequests.DeleteKey(common.Request{ + NodeID: nodeID, + RequestID: requestID, + }) if !ok { b.Ctx.Log.Debug("unexpectedly called GetAncestorsFailed", zap.Stringer("nodeID", nodeID), @@ -745,7 +762,7 @@ func (b *Bootstrapper) tryStartExecuting(ctx context.Context) error { return nil } b.fetchETA.Set(0) - return b.OnFinished(ctx, b.requestID) + return b.onFinished(ctx, b.requestID) } func (b *Bootstrapper) Timeout(ctx context.Context) error { @@ -758,7 +775,7 @@ func (b *Bootstrapper) Timeout(ctx context.Context) error { return b.restartBootstrapping(ctx) } b.fetchETA.Set(0) - return b.OnFinished(ctx, b.requestID) + return b.onFinished(ctx, b.requestID) } func (b *Bootstrapper) restartBootstrapping(ctx context.Context) error { diff --git a/snow/engine/snowman/bootstrap/bootstrapper_test.go b/snow/engine/snowman/bootstrap/bootstrapper_test.go index 6cf69d797ff5..83cbca730ba5 100644 --- a/snow/engine/snowman/bootstrap/bootstrapper_test.go +++ b/snow/engine/snowman/bootstrap/bootstrapper_test.go @@ -1131,7 +1131,8 @@ func TestRestartBootstrapping(t *testing.T) { require.Contains(requestIDs, blkID1) // Remove request, so we can restart bootstrapping via startSyncing - require.True(bs.OutstandingRequests.RemoveAny(blkID1)) + _, removed := bs.outstandingRequests.DeleteValue(blkID1) + require.True(removed) requestIDs = map[ids.ID]uint32{} require.NoError(bs.startSyncing(context.Background(), []ids.ID{blkID4})) diff --git a/snow/engine/snowman/transitive.go b/snow/engine/snowman/transitive.go index 8e2b98dc5a38..f0ce42ecf912 100644 --- a/snow/engine/snowman/transitive.go +++ b/snow/engine/snowman/transitive.go @@ -23,6 +23,7 @@ import ( "github.com/ava-labs/avalanchego/snow/event" "github.com/ava-labs/avalanchego/snow/validators" "github.com/ava-labs/avalanchego/utils/bag" + "github.com/ava-labs/avalanchego/utils/bimap" "github.com/ava-labs/avalanchego/utils/constants" "github.com/ava-labs/avalanchego/utils/logging" "github.com/ava-labs/avalanchego/utils/math" @@ -64,7 +65,7 @@ type Transitive struct { polls poll.Set // blocks that have we have sent get requests for but haven't yet received - blkReqs common.Requests + blkReqs *bimap.BiMap[common.Request, ids.ID] // blocks that are queued to be issued to consensus once missing dependencies are fetched // Block ID --> Block @@ -140,6 +141,7 @@ func newTransitive(config Config) (*Transitive, error) { nonVerifiedCache: nonVerifiedCache, acceptedFrontiers: acceptedFrontiers, polls: polls, + blkReqs: bimap.New[common.Request, ids.ID](), } return t, t.metrics.Initialize("", config.Ctx.Registerer) @@ -169,7 +171,10 @@ func (t *Transitive) Put(ctx context.Context, nodeID ids.NodeID, requestID uint3 } actualBlkID := blk.ID() - expectedBlkID, ok := t.blkReqs.Get(nodeID, requestID) + expectedBlkID, ok := t.blkReqs.GetValue(common.Request{ + NodeID: nodeID, + RequestID: requestID, + }) // If the provided block is not the requested block, we need to explicitly // mark the request as failed to avoid having a dangling dependency. if ok && actualBlkID != expectedBlkID { @@ -202,7 +207,10 @@ func (t *Transitive) Put(ctx context.Context, nodeID ids.NodeID, requestID uint3 func (t *Transitive) GetFailed(ctx context.Context, nodeID ids.NodeID, requestID uint32) error { // We don't assume that this function is called after a failed Get message. // Check to see if we have an outstanding request and also get what the request was for if it exists. - blkID, ok := t.blkReqs.Remove(nodeID, requestID) + blkID, ok := t.blkReqs.DeleteKey(common.Request{ + NodeID: nodeID, + RequestID: requestID, + }) if !ok { t.Ctx.Log.Debug("unexpected GetFailed", zap.Stringer("nodeID", nodeID), @@ -658,7 +666,7 @@ func (t *Transitive) issueFrom(ctx context.Context, nodeID ids.NodeID, blk snowm } // Remove any outstanding requests for this block - t.blkReqs.RemoveAny(blkID) + t.blkReqs.DeleteValue(blkID) issued := t.Consensus.Decided(blk) || t.Consensus.Processing(blkID) if issued { @@ -702,7 +710,7 @@ func (t *Transitive) issueWithAncestors(ctx context.Context, blk snowman.Block) // There's an outstanding request for this block. // We can just wait for that request to succeed or fail. - if t.blkReqs.Contains(blkID) { + if t.blkReqs.HasValue(blkID) { return false, nil } @@ -731,7 +739,7 @@ func (t *Transitive) issue(ctx context.Context, blk snowman.Block, push bool) er t.pending[blkID] = blk // Remove any outstanding requests for this block - t.blkReqs.RemoveAny(blkID) + t.blkReqs.DeleteValue(blkID) // Will add [blk] to consensus once its ancestors have been i := &issuer{ @@ -762,12 +770,18 @@ func (t *Transitive) issue(ctx context.Context, blk snowman.Block, push bool) er // Request that [vdr] send us block [blkID] func (t *Transitive) sendRequest(ctx context.Context, nodeID ids.NodeID, blkID ids.ID) { // There is already an outstanding request for this block - if t.blkReqs.Contains(blkID) { + if t.blkReqs.HasValue(blkID) { return } t.RequestID++ - t.blkReqs.Add(nodeID, t.RequestID, blkID) + t.blkReqs.Put( + common.Request{ + NodeID: nodeID, + RequestID: t.RequestID, + }, + blkID, + ) t.Ctx.Log.Verbo("sending Get request", zap.Stringer("nodeID", nodeID), zap.Uint32("requestID", t.RequestID), @@ -917,13 +931,13 @@ func (t *Transitive) deliver(ctx context.Context, blk snowman.Block, push bool) t.removeFromPending(blk) t.blocked.Fulfill(ctx, blkID) - t.blkReqs.RemoveAny(blkID) + t.blkReqs.DeleteValue(blkID) } for _, blk := range dropped { blkID := blk.ID() t.removeFromPending(blk) t.blocked.Abandon(ctx, blkID) - t.blkReqs.RemoveAny(blkID) + t.blkReqs.DeleteValue(blkID) } // If we should issue multiple queries at the same time, we need to repoll diff --git a/utils/bimap/bimap.go b/utils/bimap/bimap.go new file mode 100644 index 000000000000..28d20750bace --- /dev/null +++ b/utils/bimap/bimap.go @@ -0,0 +1,102 @@ +// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package bimap + +import "github.com/ava-labs/avalanchego/utils" + +type Entry[K, V any] struct { + Key K + Value V +} + +// BiMap is a bi-directional map. +type BiMap[K, V comparable] struct { + keyToValue map[K]V + valueToKey map[V]K +} + +// New creates a new empty bimap. +func New[K, V comparable]() *BiMap[K, V] { + return &BiMap[K, V]{ + keyToValue: make(map[K]V), + valueToKey: make(map[V]K), + } +} + +// Put the key value pair into the map. If either [key] or [val] was previously +// in the map, the previous entries will be removed and returned. +// +// Note: Unlike normal maps, it's possible that Put removes 0, 1, or 2 existing +// entries to ensure that mappings are one-to-one. +func (m *BiMap[K, V]) Put(key K, val V) []Entry[K, V] { + var removed []Entry[K, V] + oldVal, oldValDeleted := m.DeleteKey(key) + if oldValDeleted { + removed = append(removed, Entry[K, V]{ + Key: key, + Value: oldVal, + }) + } + oldKey, oldKeyDeleted := m.DeleteValue(val) + if oldKeyDeleted { + removed = append(removed, Entry[K, V]{ + Key: oldKey, + Value: val, + }) + } + m.keyToValue[key] = val + m.valueToKey[val] = key + return removed +} + +// GetKey that maps to the provided value. +func (m *BiMap[K, V]) GetKey(val V) (K, bool) { + key, ok := m.valueToKey[val] + return key, ok +} + +// GetValue that is mapped to the provided key. +func (m *BiMap[K, V]) GetValue(key K) (V, bool) { + val, ok := m.keyToValue[key] + return val, ok +} + +// HasKey returns true if [key] is in the map. +func (m *BiMap[K, _]) HasKey(key K) bool { + _, ok := m.keyToValue[key] + return ok +} + +// HasValue returns true if [val] is in the map. +func (m *BiMap[_, V]) HasValue(val V) bool { + _, ok := m.valueToKey[val] + return ok +} + +// DeleteKey removes [key] from the map and returns the value it mapped to. +func (m *BiMap[K, V]) DeleteKey(key K) (V, bool) { + val, ok := m.keyToValue[key] + if !ok { + return utils.Zero[V](), false + } + delete(m.keyToValue, key) + delete(m.valueToKey, val) + return val, true +} + +// DeleteValue removes [val] from the map and returns the key that mapped to it. +func (m *BiMap[K, V]) DeleteValue(val V) (K, bool) { + key, ok := m.valueToKey[val] + if !ok { + return utils.Zero[K](), false + } + delete(m.keyToValue, key) + delete(m.valueToKey, val) + return key, true +} + +// Len return the number of entries in this map. +func (m *BiMap[K, V]) Len() int { + return len(m.keyToValue) +} diff --git a/utils/bimap/bimap_test.go b/utils/bimap/bimap_test.go new file mode 100644 index 000000000000..9914578c6070 --- /dev/null +++ b/utils/bimap/bimap_test.go @@ -0,0 +1,328 @@ +// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package bimap + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBiMapPut(t *testing.T) { + tests := []struct { + name string + state *BiMap[int, int] + key int + value int + expectedRemoved []Entry[int, int] + expectedState *BiMap[int, int] + }{ + { + name: "none removed", + state: New[int, int](), + key: 1, + value: 2, + expectedRemoved: nil, + expectedState: &BiMap[int, int]{ + keyToValue: map[int]int{ + 1: 2, + }, + valueToKey: map[int]int{ + 2: 1, + }, + }, + }, + { + name: "key removed", + state: &BiMap[int, int]{ + keyToValue: map[int]int{ + 1: 2, + }, + valueToKey: map[int]int{ + 2: 1, + }, + }, + key: 1, + value: 3, + expectedRemoved: []Entry[int, int]{ + { + Key: 1, + Value: 2, + }, + }, + expectedState: &BiMap[int, int]{ + keyToValue: map[int]int{ + 1: 3, + }, + valueToKey: map[int]int{ + 3: 1, + }, + }, + }, + { + name: "value removed", + state: &BiMap[int, int]{ + keyToValue: map[int]int{ + 1: 2, + }, + valueToKey: map[int]int{ + 2: 1, + }, + }, + key: 3, + value: 2, + expectedRemoved: []Entry[int, int]{ + { + Key: 1, + Value: 2, + }, + }, + expectedState: &BiMap[int, int]{ + keyToValue: map[int]int{ + 3: 2, + }, + valueToKey: map[int]int{ + 2: 3, + }, + }, + }, + { + name: "key and value removed", + state: &BiMap[int, int]{ + keyToValue: map[int]int{ + 1: 2, + 3: 4, + }, + valueToKey: map[int]int{ + 2: 1, + 4: 3, + }, + }, + key: 1, + value: 4, + expectedRemoved: []Entry[int, int]{ + { + Key: 1, + Value: 2, + }, + { + Key: 3, + Value: 4, + }, + }, + expectedState: &BiMap[int, int]{ + keyToValue: map[int]int{ + 1: 4, + }, + valueToKey: map[int]int{ + 4: 1, + }, + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require := require.New(t) + + removed := test.state.Put(test.key, test.value) + require.Equal(test.expectedRemoved, removed) + require.Equal(test.expectedState, test.state) + }) + } +} + +func TestBiMapHasValueAndGetKey(t *testing.T) { + m := New[int, int]() + require.Empty(t, m.Put(1, 2)) + + tests := []struct { + name string + value int + expectedKey int + expectedExists bool + }{ + { + name: "fetch unknown", + value: 3, + expectedKey: 0, + expectedExists: false, + }, + { + name: "fetch known value", + value: 2, + expectedKey: 1, + expectedExists: true, + }, + { + name: "fetch known key", + value: 1, + expectedKey: 0, + expectedExists: false, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require := require.New(t) + + exists := m.HasValue(test.value) + require.Equal(test.expectedExists, exists) + + key, exists := m.GetKey(test.value) + require.Equal(test.expectedKey, key) + require.Equal(test.expectedExists, exists) + }) + } +} + +func TestBiMapHasKeyAndGetValue(t *testing.T) { + m := New[int, int]() + require.Empty(t, m.Put(1, 2)) + + tests := []struct { + name string + key int + expectedValue int + expectedExists bool + }{ + { + name: "fetch unknown", + key: 3, + expectedValue: 0, + expectedExists: false, + }, + { + name: "fetch known key", + key: 1, + expectedValue: 2, + expectedExists: true, + }, + { + name: "fetch known value", + key: 2, + expectedValue: 0, + expectedExists: false, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require := require.New(t) + + exists := m.HasKey(test.key) + require.Equal(test.expectedExists, exists) + + value, exists := m.GetValue(test.key) + require.Equal(test.expectedValue, value) + require.Equal(test.expectedExists, exists) + }) + } +} + +func TestBiMapDeleteKey(t *testing.T) { + tests := []struct { + name string + state *BiMap[int, int] + key int + expectedValue int + expectedRemoved bool + expectedState *BiMap[int, int] + }{ + { + name: "none removed", + state: New[int, int](), + key: 1, + expectedValue: 0, + expectedRemoved: false, + expectedState: New[int, int](), + }, + { + name: "key removed", + state: &BiMap[int, int]{ + keyToValue: map[int]int{ + 1: 2, + }, + valueToKey: map[int]int{ + 2: 1, + }, + }, + key: 1, + expectedValue: 2, + expectedRemoved: true, + expectedState: New[int, int](), + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require := require.New(t) + + value, removed := test.state.DeleteKey(test.key) + require.Equal(test.expectedValue, value) + require.Equal(test.expectedRemoved, removed) + require.Equal(test.expectedState, test.state) + }) + } +} + +func TestBiMapDeleteValue(t *testing.T) { + tests := []struct { + name string + state *BiMap[int, int] + value int + expectedKey int + expectedRemoved bool + expectedState *BiMap[int, int] + }{ + { + name: "none removed", + state: New[int, int](), + value: 1, + expectedKey: 0, + expectedRemoved: false, + expectedState: New[int, int](), + }, + { + name: "key removed", + state: &BiMap[int, int]{ + keyToValue: map[int]int{ + 1: 2, + }, + valueToKey: map[int]int{ + 2: 1, + }, + }, + value: 2, + expectedKey: 1, + expectedRemoved: true, + expectedState: New[int, int](), + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require := require.New(t) + + key, removed := test.state.DeleteValue(test.value) + require.Equal(test.expectedKey, key) + require.Equal(test.expectedRemoved, removed) + require.Equal(test.expectedState, test.state) + }) + } +} + +func TestBiMapLen(t *testing.T) { + require := require.New(t) + + m := New[int, int]() + require.Zero(m.Len()) + + m.Put(1, 2) + require.Equal(1, m.Len()) + + m.Put(2, 3) + require.Equal(2, m.Len()) + + m.Put(1, 3) + require.Equal(1, m.Len()) + + m.DeleteKey(1) + require.Zero(m.Len()) +}