Skip to content

Commit

Permalink
vms/platformvm: Move GetRewardUTXOs, GetSubnets, and `GetChains…
Browse files Browse the repository at this point in the history
…` to `State` interface (#2402)
  • Loading branch information
dhrubabasu authored Nov 30, 2023
1 parent de3b16c commit 9b85141
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 270 deletions.
97 changes: 1 addition & 96 deletions vms/platformvm/state/diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,8 @@ type diff struct {
subnetOwners map[ids.ID]fx.Owner
// Subnet ID --> Tx that transforms the subnet
transformedSubnets map[ids.ID]*txs.Tx
cachedSubnets []*txs.Tx

addedChains map[ids.ID][]*txs.Tx
cachedChains map[ids.ID][]*txs.Tx
addedChains map[ids.ID][]*txs.Tx

addedRewardUTXOs map[ids.ID][]*avax.UTXO

Expand Down Expand Up @@ -259,41 +257,8 @@ func (d *diff) GetPendingStakerIterator() (StakerIterator, error) {
return d.pendingStakerDiffs.GetStakerIterator(parentIterator), nil
}

func (d *diff) GetSubnets() ([]*txs.Tx, error) {
if len(d.addedSubnets) == 0 {
parentState, ok := d.stateVersions.GetState(d.parentID)
if !ok {
return nil, fmt.Errorf("%w: %s", ErrMissingParentState, d.parentID)
}
return parentState.GetSubnets()
}

if len(d.cachedSubnets) != 0 {
return d.cachedSubnets, nil
}

parentState, ok := d.stateVersions.GetState(d.parentID)
if !ok {
return nil, fmt.Errorf("%w: %s", ErrMissingParentState, d.parentID)
}
subnets, err := parentState.GetSubnets()
if err != nil {
return nil, err
}
newSubnets := make([]*txs.Tx, len(subnets)+len(d.addedSubnets))
copy(newSubnets, subnets)
for i, subnet := range d.addedSubnets {
newSubnets[i+len(subnets)] = subnet
}
d.cachedSubnets = newSubnets
return newSubnets, nil
}

func (d *diff) AddSubnet(createSubnetTx *txs.Tx) {
d.addedSubnets = append(d.addedSubnets, createSubnetTx)
if d.cachedSubnets != nil {
d.cachedSubnets = append(d.cachedSubnets, createSubnetTx)
}
}

func (d *diff) GetSubnetOwner(subnetID ids.ID) (fx.Owner, error) {
Expand Down Expand Up @@ -339,48 +304,6 @@ func (d *diff) AddSubnetTransformation(transformSubnetTxIntf *txs.Tx) {
}
}

func (d *diff) GetChains(subnetID ids.ID) ([]*txs.Tx, error) {
addedChains := d.addedChains[subnetID]
if len(addedChains) == 0 {
// No chains have been added to this subnet
parentState, ok := d.stateVersions.GetState(d.parentID)
if !ok {
return nil, fmt.Errorf("%w: %s", ErrMissingParentState, d.parentID)
}
return parentState.GetChains(subnetID)
}

// There have been chains added to the requested subnet

if d.cachedChains == nil {
// This is the first time we are going to be caching the subnet chains
d.cachedChains = make(map[ids.ID][]*txs.Tx)
}

cachedChains, cached := d.cachedChains[subnetID]
if cached {
return cachedChains, nil
}

// This chain wasn't cached yet
parentState, ok := d.stateVersions.GetState(d.parentID)
if !ok {
return nil, fmt.Errorf("%w: %s", ErrMissingParentState, d.parentID)
}
chains, err := parentState.GetChains(subnetID)
if err != nil {
return nil, err
}

newChains := make([]*txs.Tx, len(chains)+len(addedChains))
copy(newChains, chains)
for i, chain := range addedChains {
newChains[i+len(chains)] = chain
}
d.cachedChains[subnetID] = newChains
return newChains, nil
}

func (d *diff) AddChain(createChainTx *txs.Tx) {
tx := createChainTx.Unsigned.(*txs.CreateChainTx)
if d.addedChains == nil {
Expand All @@ -390,12 +313,6 @@ func (d *diff) AddChain(createChainTx *txs.Tx) {
} else {
d.addedChains[tx.SubnetID] = append(d.addedChains[tx.SubnetID], createChainTx)
}

cachedChains, cached := d.cachedChains[tx.SubnetID]
if !cached {
return
}
d.cachedChains[tx.SubnetID] = append(cachedChains, createChainTx)
}

func (d *diff) GetTx(txID ids.ID) (*txs.Tx, status.Status, error) {
Expand Down Expand Up @@ -425,18 +342,6 @@ func (d *diff) AddTx(tx *txs.Tx, status status.Status) {
}
}

func (d *diff) GetRewardUTXOs(txID ids.ID) ([]*avax.UTXO, error) {
if utxos, exists := d.addedRewardUTXOs[txID]; exists {
return utxos, nil
}

parentState, ok := d.stateVersions.GetState(d.parentID)
if !ok {
return nil, fmt.Errorf("%w: %s", ErrMissingParentState, d.parentID)
}
return parentState.GetRewardUTXOs(txID)
}

func (d *diff) AddRewardUTXO(txID ids.ID, utxo *avax.UTXO) {
if d.addedRewardUTXOs == nil {
d.addedRewardUTXOs = make(map[ids.ID][]*avax.UTXO)
Expand Down
164 changes: 83 additions & 81 deletions vms/platformvm/state/diff_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,15 +250,28 @@ func TestDiffSubnet(t *testing.T) {
require := require.New(t)
ctrl := gomock.NewController(t)

state := NewMockState(ctrl)
// Called in NewDiff
state.EXPECT().GetTimestamp().Return(time.Now()).Times(1)
state, _ := newInitializedState(require)

// Initialize parent with one subnet
parentStateCreateSubnetTx := &txs.Tx{
Unsigned: &txs.CreateSubnetTx{
Owner: fx.NewMockOwner(ctrl),
},
}
state.AddSubnet(parentStateCreateSubnetTx)

// Verify parent returns one subnet
subnets, err := state.GetSubnets()
require.NoError(err)
require.Equal([]*txs.Tx{
parentStateCreateSubnetTx,
}, subnets)

states := NewMockVersions(ctrl)
lastAcceptedID := ids.GenerateTestID()
states.EXPECT().GetState(lastAcceptedID).Return(state, true).AnyTimes()

d, err := NewDiff(lastAcceptedID, states)
diff, err := NewDiff(lastAcceptedID, states)
require.NoError(err)

// Put a subnet
Expand All @@ -267,60 +280,67 @@ func TestDiffSubnet(t *testing.T) {
Owner: fx.NewMockOwner(ctrl),
},
}
d.AddSubnet(createSubnetTx)
diff.AddSubnet(createSubnetTx)

// Assert that we get the subnet back
// [state] returns 1 subnet.
parentStateCreateSubnetTx := &txs.Tx{
Unsigned: &txs.CreateSubnetTx{
Owner: fx.NewMockOwner(ctrl),
},
}
state.EXPECT().GetSubnets().Return([]*txs.Tx{parentStateCreateSubnetTx}, nil).Times(1)
gotSubnets, err := d.GetSubnets()
// Apply diff to parent state
require.NoError(diff.Apply(state))

// Verify parent now returns two subnets
subnets, err = state.GetSubnets()
require.NoError(err)
require.Len(gotSubnets, 2)
require.Equal(gotSubnets[0], parentStateCreateSubnetTx)
require.Equal(gotSubnets[1], createSubnetTx)
require.Equal([]*txs.Tx{
parentStateCreateSubnetTx,
createSubnetTx,
}, subnets)
}

func TestDiffChain(t *testing.T) {
require := require.New(t)
ctrl := gomock.NewController(t)

state := NewMockState(ctrl)
// Called in NewDiff
state.EXPECT().GetTimestamp().Return(time.Now()).Times(1)
state, _ := newInitializedState(require)
subnetID := ids.GenerateTestID()

// Initialize parent with one chain
parentStateCreateChainTx := &txs.Tx{
Unsigned: &txs.CreateChainTx{
SubnetID: subnetID,
},
}
state.AddChain(parentStateCreateChainTx)

// Verify parent returns one chain
chains, err := state.GetChains(subnetID)
require.NoError(err)
require.Equal([]*txs.Tx{
parentStateCreateChainTx,
}, chains)

states := NewMockVersions(ctrl)
lastAcceptedID := ids.GenerateTestID()
states.EXPECT().GetState(lastAcceptedID).Return(state, true).AnyTimes()

d, err := NewDiff(lastAcceptedID, states)
diff, err := NewDiff(lastAcceptedID, states)
require.NoError(err)

// Put a chain
subnetID := ids.GenerateTestID()
createChainTx := &txs.Tx{
Unsigned: &txs.CreateChainTx{
SubnetID: subnetID,
SubnetID: subnetID, // note this is the same subnet as [parentStateCreateChainTx]
},
}
d.AddChain(createChainTx)
diff.AddChain(createChainTx)

// Assert that we get the chain back
// [state] returns 1 chain.
parentStateCreateChainTx := &txs.Tx{
Unsigned: &txs.CreateChainTx{
SubnetID: subnetID, // note this is the same subnet as [createChainTx]
},
}
state.EXPECT().GetChains(subnetID).Return([]*txs.Tx{parentStateCreateChainTx}, nil).Times(1)
gotChains, err := d.GetChains(subnetID)
// Apply diff to parent state
require.NoError(diff.Apply(state))

// Verify parent now returns two chains
chains, err = state.GetChains(subnetID)
require.NoError(err)
require.Len(gotChains, 2)
require.Equal(parentStateCreateChainTx, gotChains[0])
require.Equal(createChainTx, gotChains[1])
require.Equal([]*txs.Tx{
parentStateCreateChainTx,
createChainTx,
}, chains)
}

func TestDiffTx(t *testing.T) {
Expand Down Expand Up @@ -377,45 +397,46 @@ func TestDiffRewardUTXO(t *testing.T) {
require := require.New(t)
ctrl := gomock.NewController(t)

state := NewMockState(ctrl)
// Called in NewDiff
state.EXPECT().GetTimestamp().Return(time.Now()).Times(1)
state, _ := newInitializedState(require)

txID := ids.GenerateTestID()

// Initialize parent with one reward UTXO
parentRewardUTXO := &avax.UTXO{
UTXOID: avax.UTXOID{TxID: txID},
}
state.AddRewardUTXO(txID, parentRewardUTXO)

// Verify parent returns the reward UTXO
rewardUTXOs, err := state.GetRewardUTXOs(txID)
require.NoError(err)
require.Equal([]*avax.UTXO{
parentRewardUTXO,
}, rewardUTXOs)

states := NewMockVersions(ctrl)
lastAcceptedID := ids.GenerateTestID()
states.EXPECT().GetState(lastAcceptedID).Return(state, true).AnyTimes()

d, err := NewDiff(lastAcceptedID, states)
diff, err := NewDiff(lastAcceptedID, states)
require.NoError(err)

// Put a reward UTXO
txID := ids.GenerateTestID()
rewardUTXO := &avax.UTXO{
UTXOID: avax.UTXOID{TxID: txID},
}
d.AddRewardUTXO(txID, rewardUTXO)
diff.AddRewardUTXO(txID, rewardUTXO)

{
// Assert that we get the UTXO back
gotRewardUTXOs, err := d.GetRewardUTXOs(txID)
require.NoError(err)
require.Len(gotRewardUTXOs, 1)
require.Equal(rewardUTXO, gotRewardUTXOs[0])
}
// Apply diff to parent state
require.NoError(diff.Apply(state))

{
// Assert that we can get a UTXO from the parent state
// [state] returns 1 UTXO.
txID2 := ids.GenerateTestID()
parentRewardUTXO := &avax.UTXO{
UTXOID: avax.UTXOID{TxID: txID2},
}
state.EXPECT().GetRewardUTXOs(txID2).Return([]*avax.UTXO{parentRewardUTXO}, nil).Times(1)
gotParentRewardUTXOs, err := d.GetRewardUTXOs(txID2)
require.NoError(err)
require.Len(gotParentRewardUTXOs, 1)
require.Equal(parentRewardUTXO, gotParentRewardUTXOs[0])
}
// Verify parent now returns two reward UTXOs
rewardUTXOs, err = state.GetRewardUTXOs(txID)
require.NoError(err)
require.Equal([]*avax.UTXO{
parentRewardUTXO,
rewardUTXO,
}, rewardUTXOs)
}

func TestDiffUTXO(t *testing.T) {
Expand Down Expand Up @@ -496,25 +517,6 @@ func assertChainsEqual(t *testing.T, expected, actual Chain) {
require.NoError(err)

require.Equal(expectedCurrentSupply, actualCurrentSupply)

expectedSubnets, expectedErr := expected.GetSubnets()
actualSubnets, actualErr := actual.GetSubnets()
require.Equal(expectedErr, actualErr)
if expectedErr == nil {
require.Equal(expectedSubnets, actualSubnets)

for _, subnet := range expectedSubnets {
subnetID := subnet.ID()

expectedChains, expectedErr := expected.GetChains(subnetID)
actualChains, actualErr := actual.GetChains(subnetID)
require.Equal(expectedErr, actualErr)
if expectedErr != nil {
continue
}
require.Equal(expectedChains, actualChains)
}
}
}

func TestDiffSubnetOwner(t *testing.T) {
Expand Down
Loading

0 comments on commit 9b85141

Please sign in to comment.