diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/test.unit.yml similarity index 71% rename from .github/workflows/build-and-test.yml rename to .github/workflows/test.unit.yml index 2794efb30c80..714a966a3b8c 100644 --- a/.github/workflows/build-and-test.yml +++ b/.github/workflows/test.unit.yml @@ -1,4 +1,4 @@ -name: Build + Unit Tests +name: Unit Tests on: push: @@ -9,6 +9,11 @@ on: merge_group: types: [checks_requested] +# Cancel ongoing workflow runs if a new one is started +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + jobs: run_build_unit_tests: name: build_unit_test @@ -25,6 +30,4 @@ jobs: - name: fuzz_test shell: bash if: matrix.os == 'ubuntu-22.04' # Only run on Ubuntu 22.04 - run: | # Run each fuzz test 15 seconds - cd ${{ github.workspace }} - ./scripts/build_fuzz.sh 15 + run: ./scripts/build_fuzz.sh 15 # Run each fuzz test 15 seconds diff --git a/api/info/service.go b/api/info/service.go index a520756b63bb..9cd6b26b0bf8 100644 --- a/api/info/service.go +++ b/api/info/service.go @@ -27,7 +27,6 @@ import ( "github.com/ava-labs/avalanchego/network" "github.com/ava-labs/avalanchego/network/peer" "github.com/ava-labs/avalanchego/snow/networking/benchlist" - "github.com/ava-labs/avalanchego/snow/validators" "github.com/ava-labs/avalanchego/utils/constants" "github.com/ava-labs/avalanchego/utils/ips" "github.com/ava-labs/avalanchego/utils/json" @@ -47,7 +46,6 @@ type Info struct { networking network.Network chainManager chains.Manager vmManager vms.Manager - validators validators.Set benchlist benchlist.Manager } @@ -78,7 +76,6 @@ func NewService( vmManager vms.Manager, myIP ips.DynamicIPPort, network network.Network, - validators validators.Set, benchlist benchlist.Manager, ) (http.Handler, error) { server := rpc.NewServer() @@ -93,7 +90,6 @@ func NewService( vmManager: vmManager, myIP: myIP, networking: network, - validators: validators, benchlist: benchlist, }, "info", diff --git a/api/keystore/codec.go b/api/keystore/codec.go index df6c18ae9521..ebb196ccbfff 100644 --- a/api/keystore/codec.go +++ b/api/keystore/codec.go @@ -11,7 +11,7 @@ import ( const ( maxPackerSize = 1 * units.GiB // max size, in bytes, of something being marshalled by Marshal() - maxSliceLength = 256 * 1024 + maxSliceLength = linearcodec.DefaultMaxSliceLength codecVersion = 0 ) diff --git a/chains/manager.go b/chains/manager.go index 5a769953225a..2234e82aadcc 100644 --- a/chains/manager.go +++ b/chains/manager.go @@ -148,7 +148,7 @@ type ChainParameters struct { // The IDs of the feature extensions this chain is running. FxIDs []ids.ID // Invariant: Only used when [ID] is the P-chain ID. - CustomBeacons validators.Set + CustomBeacons validators.Manager } type chain struct { @@ -156,7 +156,7 @@ type chain struct { Context *snow.ConsensusContext VM common.VM Handler handler.Handler - Beacons validators.Set + Beacons validators.Manager } // ChainConfig is configuration settings for the current execution. @@ -531,24 +531,13 @@ func (m *manager) buildChain(chainParams ChainParameters, sb subnets.Subnet) (*c } } - var vdrs validators.Set // Validators validating this blockchain - var hasValidators bool - if m.SybilProtectionEnabled { - vdrs, hasValidators = m.Validators.Get(chainParams.SubnetID) - } else { // Sybil protection is disabled. Every peer validates every subnet. - vdrs, hasValidators = m.Validators.Get(constants.PrimaryNetworkID) - } - if !hasValidators { - return nil, fmt.Errorf("couldn't get validator set of subnet with ID %s. 
The subnet may not exist", chainParams.SubnetID) - } - var chain *chain switch vm := vm.(type) { case vertex.LinearizableVMWithEngine: chain, err = m.createAvalancheChain( ctx, chainParams.GenesisData, - vdrs, + m.Validators, vm, fxs, sb, @@ -557,7 +546,7 @@ func (m *manager) buildChain(chainParams ChainParameters, sb subnets.Subnet) (*c return nil, fmt.Errorf("error while creating new avalanche vm %w", err) } case block.ChainVM: - beacons := vdrs + beacons := m.Validators if chainParams.ID == constants.PlatformChainID { beacons = chainParams.CustomBeacons } @@ -565,7 +554,7 @@ func (m *manager) buildChain(chainParams ChainParameters, sb subnets.Subnet) (*c chain, err = m.createSnowmanChain( ctx, chainParams.GenesisData, - vdrs, + m.Validators, beacons, vm, fxs, @@ -594,7 +583,7 @@ func (m *manager) AddRegistrant(r Registrant) { func (m *manager) createAvalancheChain( ctx *snow.ConsensusContext, genesisData []byte, - vdrs validators.Set, + vdrs validators.Manager, vm vertex.LinearizableVMWithEngine, fxs []*common.Fx, sb subnets.Subnet, @@ -816,7 +805,10 @@ func (m *manager) createAvalancheChain( appSender: snowmanMessageSender, } - bootstrapWeight := vdrs.Weight() + bootstrapWeight, err := vdrs.TotalWeight(ctx.SubnetID) + if err != nil { + return nil, fmt.Errorf("error while fetching weight for subnet %s: %w", ctx.SubnetID, err) + } consensusParams := sb.Config().ConsensusParameters sampleK := consensusParams.K @@ -828,7 +820,7 @@ func (m *manager) createAvalancheChain( if err != nil { return nil, fmt.Errorf("error creating peer tracker: %w", err) } - vdrs.RegisterCallbackListener(connectedValidators) + vdrs.RegisterCallbackListener(ctx.SubnetID, connectedValidators) // Asynchronously passes messages from the network to the consensus engine h, err := handler.New( @@ -848,7 +840,7 @@ func (m *manager) createAvalancheChain( connectedBeacons := tracker.NewPeers() startupTracker := tracker.NewStartup(connectedBeacons, (3*bootstrapWeight+3)/4) - vdrs.RegisterCallbackListener(startupTracker) + vdrs.RegisterCallbackListener(ctx.SubnetID, startupTracker) snowmanCommonCfg := common.Config{ Ctx: ctx, @@ -998,8 +990,8 @@ func (m *manager) createAvalancheChain( func (m *manager) createSnowmanChain( ctx *snow.ConsensusContext, genesisData []byte, - vdrs validators.Set, - beacons validators.Set, + vdrs validators.Manager, + beacons validators.Manager, vm block.ChainVM, fxs []*common.Fx, sb subnets.Subnet, @@ -1164,7 +1156,10 @@ func (m *manager) createSnowmanChain( return nil, err } - bootstrapWeight := beacons.Weight() + bootstrapWeight, err := beacons.TotalWeight(ctx.SubnetID) + if err != nil { + return nil, fmt.Errorf("error while fetching weight for subnet %s: %w", ctx.SubnetID, err) + } consensusParams := sb.Config().ConsensusParameters sampleK := consensusParams.K @@ -1176,7 +1171,7 @@ func (m *manager) createSnowmanChain( if err != nil { return nil, fmt.Errorf("error creating peer tracker: %w", err) } - vdrs.RegisterCallbackListener(connectedValidators) + vdrs.RegisterCallbackListener(ctx.SubnetID, connectedValidators) // Asynchronously passes messages from the network to the consensus engine h, err := handler.New( @@ -1196,7 +1191,7 @@ func (m *manager) createSnowmanChain( connectedBeacons := tracker.NewPeers() startupTracker := tracker.NewStartup(connectedBeacons, (3*bootstrapWeight+3)/4) - beacons.RegisterCallbackListener(startupTracker) + beacons.RegisterCallbackListener(ctx.SubnetID, startupTracker) commonCfg := common.Config{ Ctx: ctx, @@ -1358,7 +1353,7 @@ func (m *manager) 
registerBootstrappedHealthChecks() error { if !m.IsBootstrapped(constants.PlatformChainID) { return "node is currently bootstrapping", nil } - if !validators.Contains(m.Validators, constants.PrimaryNetworkID, m.NodeID) { + if _, ok := m.Validators.GetValidator(constants.PrimaryNetworkID, m.NodeID); !ok { return "node is not a primary network validator", nil } diff --git a/codec/linearcodec/camino_codec.go b/codec/linearcodec/camino_codec.go index eb1c1996f861..be7e4dd55f9b 100644 --- a/codec/linearcodec/camino_codec.go +++ b/codec/linearcodec/camino_codec.go @@ -50,7 +50,7 @@ func NewCamino(tagNames []string, maxSliceLen uint32) CaminoCodec { // NewDefault is a convenience constructor; it returns a new codec with reasonable default values func NewCaminoDefault() CaminoCodec { - return NewCamino([]string{reflectcodec.DefaultTagName}, defaultMaxSliceLength) + return NewCamino([]string{reflectcodec.DefaultTagName}, DefaultMaxSliceLength) } // NewCustomMaxLength is a convenience constructor; it returns a new codec with custom max length and default tags diff --git a/codec/linearcodec/camino_codec_test.go b/codec/linearcodec/camino_codec_test.go index d6eb5172e98c..fa6fa54d052a 100644 --- a/codec/linearcodec/camino_codec_test.go +++ b/codec/linearcodec/camino_codec_test.go @@ -18,7 +18,7 @@ func TestVectorsCamino(t *testing.T) { func TestMultipleTagsCamino(t *testing.T) { for _, test := range codec.MultipleTagsTests { - c := NewCamino([]string{"tag1", "tag2"}, defaultMaxSliceLength) + c := NewCamino([]string{"tag1", "tag2"}, DefaultMaxSliceLength) test(c, t) } } diff --git a/codec/linearcodec/codec.go b/codec/linearcodec/codec.go index d488b3a1d4c0..677c331b0366 100644 --- a/codec/linearcodec/codec.go +++ b/codec/linearcodec/codec.go @@ -15,7 +15,7 @@ import ( const ( // default max length of a slice being marshalled by Marshal(). Should be <= math.MaxUint32. - defaultMaxSliceLength = 256 * 1024 + DefaultMaxSliceLength = 256 * 1024 ) var ( @@ -56,7 +56,7 @@ func New(tagNames []string, maxSliceLen uint32) Codec { // NewDefault is a convenience constructor; it returns a new codec with reasonable default values func NewDefault() Codec { - return New([]string{reflectcodec.DefaultTagName}, defaultMaxSliceLength) + return New([]string{reflectcodec.DefaultTagName}, DefaultMaxSliceLength) } // NewCustomMaxLength is a convenience constructor; it returns a new codec with custom max length and default tags diff --git a/codec/linearcodec/codec_test.go b/codec/linearcodec/codec_test.go index 920789da572a..db8a4e720dd6 100644 --- a/codec/linearcodec/codec_test.go +++ b/codec/linearcodec/codec_test.go @@ -18,7 +18,7 @@ func TestVectors(t *testing.T) { func TestMultipleTags(t *testing.T) { for _, test := range codec.MultipleTagsTests { - c := New([]string{"tag1", "tag2"}, defaultMaxSliceLength) + c := New([]string{"tag1", "tag2"}, DefaultMaxSliceLength) test(c, t) } } diff --git a/config/config.go b/config/config.go index c25a93d78250..fe23832c5cd4 100644 --- a/config/config.go +++ b/config/config.go @@ -69,12 +69,20 @@ const ( chainUpgradeFileName = "upgrade" subnetConfigFileExt = ".json" ipResolutionTimeout = 30 * time.Second + + ipcDeprecationMsg = "IPC API is deprecated" + keystoreDeprecationMsg = "keystore API is deprecated" ) var ( // Deprecated key --> deprecation message (i.e. 
which key replaces it) // TODO: deprecate "BootstrapIDsKey" and "BootstrapIPsKey" - deprecatedKeys = map[string]string{} + deprecatedKeys = map[string]string{ + IpcAPIEnabledKey: ipcDeprecationMsg, + IpcsChainIDsKey: ipcDeprecationMsg, + IpcsPathKey: ipcDeprecationMsg, + KeystoreAPIEnabledKey: keystoreDeprecationMsg, + } errSybilProtectionDisabledStakerWeights = errors.New("sybil protection disabled weights must be positive") errSybilProtectionDisabledOnPublicNetwork = errors.New("sybil protection disabled on public network") diff --git a/database/iterator.go b/database/iterator.go index 3cfd075cc9d3..c83ceac49639 100644 --- a/database/iterator.go +++ b/database/iterator.go @@ -34,10 +34,12 @@ type Iterator interface { // Key returns the key of the current key/value pair, or nil if done. // If the database is closed, must still report the current contents. + // Behavior is undefined after Release is called. Key() []byte // Value returns the value of the current key/value pair, or nil if done. // If the database is closed, must still report the current contents. + // Behavior is undefined after Release is called. Value() []byte // Release releases associated resources. Release should always succeed and diff --git a/database/pebble/batch.go b/database/pebble/batch.go new file mode 100644 index 000000000000..b6c9d283b64d --- /dev/null +++ b/database/pebble/batch.go @@ -0,0 +1,111 @@ +// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package pebble + +import ( + "fmt" + + "github.com/cockroachdb/pebble" + + "github.com/ava-labs/avalanchego/database" +) + +var _ database.Batch = (*batch)(nil) + +// Not safe for concurrent use. +type batch struct { + batch *pebble.Batch + db *Database + size int + + // True iff [batch] has been written to the database + // since the last time [Reset] was called. + written bool +} + +func (db *Database) NewBatch() database.Batch { + return &batch{ + db: db, + batch: db.pebbleDB.NewBatch(), + } +} + +func (b *batch) Put(key, value []byte) error { + b.size += len(key) + len(value) + pebbleByteOverHead + return b.batch.Set(key, value, pebble.Sync) +} + +func (b *batch) Delete(key []byte) error { + b.size += len(key) + pebbleByteOverHead + return b.batch.Delete(key, pebble.Sync) +} + +func (b *batch) Size() int { + return b.size +} + +// Assumes [b.db.lock] is not held. +func (b *batch) Write() error { + b.db.lock.RLock() + defer b.db.lock.RUnlock() + + // Committing to a closed database makes pebble panic + // so make sure [b.db] isn't closed. + if b.db.closed { + return database.ErrClosed + } + + if !b.written { + // This batch has not been written to the database yet. + if err := updateError(b.batch.Commit(pebble.Sync)); err != nil { + return err + } + b.written = true + return nil + } + + // pebble doesn't support writing a batch twice so we have to clone + // [b] and commit the clone. + batchClone := b.db.pebbleDB.NewBatch() + + // Copy the batch. + if err := batchClone.Apply(b.batch, nil); err != nil { + return err + } + + // Commit the new batch. 
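	// [b.batch] itself is left untouched here, so calling Write again later
	// will simply build and commit another clone of the same operations.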
+ return updateError(batchClone.Commit(pebble.Sync)) +} + +func (b *batch) Reset() { + b.batch.Reset() + b.written = false + b.size = 0 +} + +func (b *batch) Replay(w database.KeyValueWriterDeleter) error { + reader := b.batch.Reader() + for { + kind, k, v, ok := reader.Next() + if !ok { + return nil + } + switch kind { + case pebble.InternalKeyKindSet: + if err := w.Put(k, v); err != nil { + return err + } + case pebble.InternalKeyKindDelete: + if err := w.Delete(k); err != nil { + return err + } + default: + return fmt.Errorf("%w: %v", errInvalidOperation, kind) + } + } +} + +func (b *batch) Inner() database.Batch { + return b +} diff --git a/database/pebble/batch_test.go b/database/pebble/batch_test.go new file mode 100644 index 000000000000..a84134708956 --- /dev/null +++ b/database/pebble/batch_test.go @@ -0,0 +1,47 @@ +// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package pebble + +import ( + "testing" + + "github.com/prometheus/client_golang/prometheus" + + "github.com/stretchr/testify/require" + + "github.com/ava-labs/avalanchego/utils/logging" +) + +// Note: TestInterface tests other batch functionality. +func TestBatch(t *testing.T) { + require := require.New(t) + dirName := t.TempDir() + + db, err := New(dirName, DefaultConfigBytes, logging.NoLog{}, "", prometheus.NewRegistry()) + require.NoError(err) + + batchIntf := db.NewBatch() + batch, ok := batchIntf.(*batch) + require.True(ok) + + require.False(batch.written) + + key1, value1 := []byte("key1"), []byte("value1") + require.NoError(batch.Put(key1, value1)) + require.Equal(len(key1)+len(value1)+pebbleByteOverHead, batch.Size()) + + require.NoError(batch.Write()) + + require.True(batch.written) + + got, err := db.Get(key1) + require.NoError(err) + require.Equal(value1, got) + + batch.Reset() + require.False(batch.written) + require.Zero(batch.Size()) + + require.NoError(db.Close()) +} diff --git a/database/pebble/db.go b/database/pebble/db.go new file mode 100644 index 000000000000..13ab1db04977 --- /dev/null +++ b/database/pebble/db.go @@ -0,0 +1,292 @@ +// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package pebble + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "sync" + + "github.com/cockroachdb/pebble" + + "github.com/prometheus/client_golang/prometheus" + + "golang.org/x/exp/slices" + + "go.uber.org/zap" + + "github.com/ava-labs/avalanchego/database" + "github.com/ava-labs/avalanchego/utils/logging" + "github.com/ava-labs/avalanchego/utils/set" + "github.com/ava-labs/avalanchego/utils/units" +) + +// pebbleByteOverHead is the number of bytes of constant overhead that +// should be added to a batch size per operation. +const pebbleByteOverHead = 8 + +var ( + _ database.Database = (*Database)(nil) + + errInvalidOperation = errors.New("invalid operation") + + defaultCacheSize = 512 * units.MiB + DefaultConfig = Config{ + CacheSize: defaultCacheSize, + BytesPerSync: 512 * units.KiB, + WALBytesPerSync: 0, // Default to no background syncing. 
+ MemTableStopWritesThreshold: 8, + MemTableSize: defaultCacheSize / 4, + MaxOpenFiles: 4096, + MaxConcurrentCompactions: 1, + } + + DefaultConfigBytes []byte +) + +func init() { + var err error + DefaultConfigBytes, err = json.Marshal(DefaultConfig) + if err != nil { + panic(err) + } +} + +type Database struct { + lock sync.RWMutex + pebbleDB *pebble.DB + closed bool + openIterators set.Set[*iter] +} + +type Config struct { + CacheSize int `json:"cacheSize"` + BytesPerSync int `json:"bytesPerSync"` + WALBytesPerSync int `json:"walBytesPerSync"` // 0 means no background syncing + MemTableStopWritesThreshold int `json:"memTableStopWritesThreshold"` + MemTableSize int `json:"memTableSize"` + MaxOpenFiles int `json:"maxOpenFiles"` + MaxConcurrentCompactions int `json:"maxConcurrentCompactions"` +} + +// TODO: Add metrics +func New(file string, configBytes []byte, log logging.Logger, _ string, _ prometheus.Registerer) (*Database, error) { + var cfg Config + if err := json.Unmarshal(configBytes, &cfg); err != nil { + return nil, err + } + + opts := &pebble.Options{ + Cache: pebble.NewCache(int64(cfg.CacheSize)), + BytesPerSync: cfg.BytesPerSync, + Comparer: pebble.DefaultComparer, + WALBytesPerSync: cfg.WALBytesPerSync, + MemTableStopWritesThreshold: cfg.MemTableStopWritesThreshold, + MemTableSize: cfg.MemTableSize, + MaxOpenFiles: cfg.MaxOpenFiles, + MaxConcurrentCompactions: func() int { return cfg.MaxConcurrentCompactions }, + } + opts.Experimental.ReadSamplingMultiplier = -1 // Disable seek compaction + + log.Info( + "opening pebble", + zap.Reflect("config", cfg), + ) + + db, err := pebble.Open(file, opts) + return &Database{ + pebbleDB: db, + openIterators: set.Set[*iter]{}, + }, err +} + +func (db *Database) Close() error { + db.lock.Lock() + defer db.lock.Unlock() + + if db.closed { + return database.ErrClosed + } + + db.closed = true + + for iter := range db.openIterators { + iter.lock.Lock() + iter.release() + iter.lock.Unlock() + } + db.openIterators.Clear() + + return updateError(db.pebbleDB.Close()) +} + +func (db *Database) HealthCheck(_ context.Context) (interface{}, error) { + db.lock.RLock() + defer db.lock.RUnlock() + + if db.closed { + return nil, database.ErrClosed + } + return nil, nil +} + +func (db *Database) Has(key []byte) (bool, error) { + db.lock.RLock() + defer db.lock.RUnlock() + + if db.closed { + return false, database.ErrClosed + } + + _, closer, err := db.pebbleDB.Get(key) + if err == pebble.ErrNotFound { + return false, nil + } + if err != nil { + return false, updateError(err) + } + return true, closer.Close() +} + +func (db *Database) Get(key []byte) ([]byte, error) { + db.lock.RLock() + defer db.lock.RUnlock() + + if db.closed { + return nil, database.ErrClosed + } + + data, closer, err := db.pebbleDB.Get(key) + if err != nil { + return nil, updateError(err) + } + return slices.Clone(data), closer.Close() +} + +func (db *Database) Put(key []byte, value []byte) error { + db.lock.RLock() + defer db.lock.RUnlock() + + if db.closed { + return database.ErrClosed + } + + return updateError(db.pebbleDB.Set(key, value, pebble.Sync)) +} + +func (db *Database) Delete(key []byte) error { + db.lock.RLock() + defer db.lock.RUnlock() + + if db.closed { + return database.ErrClosed + } + + return updateError(db.pebbleDB.Delete(key, pebble.Sync)) +} + +func (db *Database) Compact(start []byte, end []byte) error { + db.lock.RLock() + defer db.lock.RUnlock() + + if db.closed { + return database.ErrClosed + } + + if end == nil { + // The database.Database spec treats a nil 
[limit] as a key after all keys + // but pebble treats a nil [limit] as a key before all keys in Compact. + // Use the greatest key in the database as the [limit] to get the desired behavior. + it := db.pebbleDB.NewIter(&pebble.IterOptions{}) + + if !it.Last() { + // The database is empty. + return it.Close() + } + + end = it.Key() + if err := it.Close(); err != nil { + return err + } + } + + if pebble.DefaultComparer.Compare(start, end) >= 1 { + // pebble requires [start] < [end] + return nil + } + + return updateError(db.pebbleDB.Compact(start, end, true /* parallelize */)) +} + +func (db *Database) NewIterator() database.Iterator { + return db.NewIteratorWithStartAndPrefix(nil, nil) +} + +func (db *Database) NewIteratorWithStart(start []byte) database.Iterator { + return db.NewIteratorWithStartAndPrefix(start, nil) +} + +func (db *Database) NewIteratorWithPrefix(prefix []byte) database.Iterator { + return db.NewIteratorWithStartAndPrefix(nil, prefix) +} + +func (db *Database) NewIteratorWithStartAndPrefix(start, prefix []byte) database.Iterator { + db.lock.Lock() + defer db.lock.Unlock() + + if db.closed { + return &iter{ + db: db, + closed: true, + err: database.ErrClosed, + } + } + + iter := &iter{ + db: db, + iter: db.pebbleDB.NewIter(keyRange(start, prefix)), + } + db.openIterators.Add(iter) + return iter +} + +// Converts a pebble-specific error to its Avalanche equivalent, if applicable. +func updateError(err error) error { + switch err { + case pebble.ErrClosed: + return database.ErrClosed + case pebble.ErrNotFound: + return database.ErrNotFound + default: + return err + } +} + +func keyRange(start, prefix []byte) *pebble.IterOptions { + opt := &pebble.IterOptions{ + LowerBound: prefix, + UpperBound: prefixToUpperBound(prefix), + } + if bytes.Compare(start, prefix) == 1 { + opt.LowerBound = start + } + return opt +} + +// Returns an upper bound that stops after all keys with the given [prefix]. +// Assumes the Database uses bytes.Compare for key comparison and not a custom +// comparer. +func prefixToUpperBound(prefix []byte) []byte { + for i := len(prefix) - 1; i >= 0; i-- { + if prefix[i] != 0xFF { + upperBound := make([]byte, i+1) + copy(upperBound, prefix) + upperBound[i]++ + return upperBound + } + } + return nil +} diff --git a/database/pebble/db_test.go b/database/pebble/db_test.go new file mode 100644 index 000000000000..c72a9d687c88 --- /dev/null +++ b/database/pebble/db_test.go @@ -0,0 +1,147 @@ +// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
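// keyRange (defined in db.go) clamps the iterator's lower bound to the larger
// of [start] and [prefix] and derives the upper bound from [prefix] via
// prefixToUpperBound; an all-0xFF prefix yields a nil (unbounded) upper bound.
// TestKeyRange below walks through these edge cases.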
+ +package pebble + +import ( + "testing" + + "github.com/prometheus/client_golang/prometheus" + + "github.com/stretchr/testify/require" + + "github.com/ava-labs/avalanchego/database" + "github.com/ava-labs/avalanchego/utils/logging" +) + +func newDB(t testing.TB) *Database { + folder := t.TempDir() + db, err := New(folder, DefaultConfigBytes, logging.NoLog{}, "pebble", prometheus.NewRegistry()) + require.NoError(t, err) + return db +} + +func TestInterface(t *testing.T) { + for _, test := range database.Tests { + db := newDB(t) + test(t, db) + _ = db.Close() + } +} + +func FuzzKeyValue(f *testing.F) { + db := newDB(f) + database.FuzzKeyValue(f, db) + _ = db.Close() +} + +func FuzzNewIteratorWithPrefix(f *testing.F) { + db := newDB(f) + database.FuzzNewIteratorWithPrefix(f, db) + _ = db.Close() +} + +func BenchmarkInterface(b *testing.B) { + for _, size := range database.BenchmarkSizes { + keys, values := database.SetupBenchmark(b, size[0], size[1], size[2]) + for _, bench := range database.Benchmarks { + db := newDB(b) + bench(b, db, "pebble", keys, values) + _ = db.Close() + } + } +} + +func TestKeyRange(t *testing.T) { + require := require.New(t) + + type test struct { + start []byte + prefix []byte + expectedLower []byte + expectedUpper []byte + } + + tests := []test{ + { + start: nil, + prefix: nil, + expectedLower: nil, + expectedUpper: nil, + }, + { + start: nil, + prefix: []byte{}, + expectedLower: []byte{}, + expectedUpper: nil, + }, + { + start: nil, + prefix: []byte{0x00}, + expectedLower: []byte{0x00}, + expectedUpper: []byte{0x01}, + }, + { + start: []byte{0x00, 0x02}, + prefix: []byte{0x00}, + expectedLower: []byte{0x00, 0x02}, + expectedUpper: []byte{0x01}, + }, + { + start: []byte{0x01}, + prefix: []byte{0x00}, + expectedLower: []byte{0x01}, + expectedUpper: []byte{0x01}, + }, + { + start: nil, + prefix: []byte{0x01}, + expectedLower: []byte{0x01}, + expectedUpper: []byte{0x02}, + }, + { + start: nil, + prefix: []byte{0xFF}, + expectedLower: []byte{0xFF}, + expectedUpper: nil, + }, + { + start: []byte{0x00}, + prefix: []byte{0xFF}, + expectedLower: []byte{0xFF}, + expectedUpper: nil, + }, + { + start: nil, + prefix: []byte{0x01, 0x02}, + expectedLower: []byte{0x01, 0x02}, + expectedUpper: []byte{0x01, 0x03}, + }, + { + start: []byte{0x01, 0x02}, + prefix: []byte{0x01, 0x02}, + expectedLower: []byte{0x01, 0x02}, + expectedUpper: []byte{0x01, 0x03}, + }, + { + start: []byte{0x01, 0x02, 0x05}, + prefix: []byte{0x01, 0x02}, + expectedLower: []byte{0x01, 0x02, 0x05}, + expectedUpper: []byte{0x01, 0x03}, + }, + { + start: nil, + prefix: []byte{0x01, 0x02, 0xFF}, + expectedLower: []byte{0x01, 0x02, 0xFF}, + expectedUpper: []byte{0x01, 0x03}, + }, + } + + for _, tt := range tests { + t.Run(string(tt.start)+" "+string(tt.prefix), func(t *testing.T) { + bounds := keyRange(tt.start, tt.prefix) + require.Equal(tt.expectedLower, bounds.LowerBound) + require.Equal(tt.expectedUpper, bounds.UpperBound) + }) + } +} diff --git a/database/pebble/iterator.go b/database/pebble/iterator.go new file mode 100644 index 000000000000..115c122e30f4 --- /dev/null +++ b/database/pebble/iterator.go @@ -0,0 +1,133 @@ +// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
+ +package pebble + +import ( + "errors" + "fmt" + "sync" + + "github.com/cockroachdb/pebble" + + "golang.org/x/exp/slices" + + "github.com/ava-labs/avalanchego/database" +) + +var ( + _ database.Iterator = (*iter)(nil) + + errCouldntGetValue = errors.New("couldnt get iterator value") +) + +type iter struct { + // [lock] ensures that only one goroutine can access [iter] at a time. + // Note that [Database.Close] calls [iter.Release] so we need [lock] to ensure + // that the user and [Database.Close] don't execute [iter.Release] concurrently. + // Invariant: [Database.lock] is never grabbed while holding [lock]. + lock sync.Mutex + + db *Database + iter *pebble.Iterator + + initialized bool + closed bool + err error + + hasNext bool + nextKey []byte + nextVal []byte +} + +// Must not be called with [db.lock] held. +func (it *iter) Next() bool { + it.lock.Lock() + defer it.lock.Unlock() + + switch { + case it.err != nil: + it.hasNext = false + return false + case it.closed: + it.hasNext = false + it.err = database.ErrClosed + return false + case !it.initialized: + it.hasNext = it.iter.First() + it.initialized = true + default: + it.hasNext = it.iter.Next() + } + + if !it.hasNext { + return false + } + + it.nextKey = it.iter.Key() + + var err error + it.nextVal, err = it.iter.ValueAndErr() + if err != nil { + it.hasNext = false + it.err = fmt.Errorf("%w: %w", errCouldntGetValue, err) + return false + } + + return true +} + +func (it *iter) Error() error { + it.lock.Lock() + defer it.lock.Unlock() + + if it.err != nil || it.closed { + return it.err + } + return updateError(it.iter.Error()) +} + +func (it *iter) Key() []byte { + it.lock.Lock() + defer it.lock.Unlock() + + if !it.hasNext { + return nil + } + return slices.Clone(it.nextKey) +} + +func (it *iter) Value() []byte { + it.lock.Lock() + defer it.lock.Unlock() + + if !it.hasNext { + return nil + } + return slices.Clone(it.nextVal) +} + +func (it *iter) Release() { + it.db.lock.Lock() + defer it.db.lock.Unlock() + + it.lock.Lock() + defer it.lock.Unlock() + + it.release() +} + +// Assumes [it.lock] and [it.db.lock] are held. +func (it *iter) release() { + if it.closed { + return + } + + // Remove the iterator from the list of open iterators. 
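	// This keeps [db.openIterators] from growing as short-lived iterators are
	// released; [Database.Close] then only has to release the iterators that
	// are still open.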
+ it.db.openIterators.Remove(it) + + it.closed = true + if err := it.iter.Close(); err != nil { + it.err = updateError(err) + } +} diff --git a/database/test_database.go b/database/test_database.go index 69fb1d2b7948..2e68f53341b8 100644 --- a/database/test_database.go +++ b/database/test_database.go @@ -933,7 +933,15 @@ func TestCompactNoPanic(t *testing.T, db Database) { require.NoError(db.Put(key2, value2)) require.NoError(db.Put(key3, value3)) + // Test compacting with nil bounds require.NoError(db.Compact(nil, nil)) + + // Test compacting when start > end + require.NoError(db.Compact([]byte{2}, []byte{1})) + + // Test compacting when start > largest key + require.NoError(db.Compact([]byte{255}, nil)) + require.NoError(db.Close()) err := db.Compact(nil, nil) require.ErrorIs(err, ErrClosed) diff --git a/go.mod b/go.mod index 5b44d0529009..6a3fba220d34 100644 --- a/go.mod +++ b/go.mod @@ -11,9 +11,10 @@ require ( github.com/DataDog/zstd v1.5.2 github.com/Microsoft/go-winio v0.5.2 github.com/NYTimes/gziphandler v1.1.1 - github.com/ava-labs/coreth v0.12.6-rc.2 + github.com/ava-labs/coreth v0.12.7-rc.1 github.com/ava-labs/ledger-avalanche/go v0.0.0-20230105152938-00a24d05a8c7 github.com/btcsuite/btcd/btcutil v1.1.3 + github.com/cockroachdb/pebble v0.0.0-20230209160836-829675f94811 github.com/decred/dcrd/dcrec/secp256k1/v4 v4.1.0 github.com/ethereum/go-ethereum v1.12.0 github.com/golang-jwt/jwt/v4 v4.3.0 @@ -58,7 +59,7 @@ require ( go.uber.org/mock v0.2.0 go.uber.org/zap v1.24.0 golang.org/x/crypto v0.14.0 - golang.org/x/exp v0.0.0-20230206171751-46f607a40771 + golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df golang.org/x/net v0.17.0 golang.org/x/sync v0.3.0 golang.org/x/term v0.13.0 @@ -71,6 +72,7 @@ require ( ) require ( + github.com/BurntSushi/toml v1.2.1 // indirect github.com/VictoriaMetrics/fastcache v1.10.0 // indirect github.com/benbjohnson/clock v1.3.0 // indirect github.com/beorn7/perks v1.0.1 // indirect @@ -79,7 +81,6 @@ require ( github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/cockroachdb/errors v1.9.1 // indirect github.com/cockroachdb/logtags v0.0.0-20230118201751-21c54148d20b // indirect - github.com/cockroachdb/pebble v0.0.0-20230209160836-829675f94811 // indirect github.com/cockroachdb/redact v1.1.3 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect github.com/davecgh/go-spew v1.1.1 // indirect @@ -87,6 +88,7 @@ require ( github.com/dlclark/regexp2 v1.7.0 // indirect github.com/dop251/goja v0.0.0-20230605162241-28ee0ee714f3 // indirect github.com/fjl/memsize v0.0.0-20190710130421-bcb5799ab5e5 // indirect + github.com/frankban/quicktest v1.14.4 // indirect github.com/fsnotify/fsnotify v1.6.0 // indirect github.com/gballet/go-libpcsclite v0.0.0-20191108122812-4678299bea08 // indirect github.com/getsentry/sentry-go v0.18.0 // indirect @@ -127,7 +129,7 @@ require ( github.com/prometheus/procfs v0.9.0 // indirect github.com/richardlehane/mscfb v1.0.4 // indirect github.com/richardlehane/msoleps v1.0.3 // indirect - github.com/rogpeppe/go-internal v1.9.0 // indirect + github.com/rogpeppe/go-internal v1.10.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/sanity-io/litter v1.5.1 // indirect github.com/spf13/afero v1.8.2 // indirect @@ -159,4 +161,4 @@ require ( replace github.com/ava-labs/avalanche-ledger-go => github.com/chain4travel/camino-ledger-go v0.0.13-c4t -replace github.com/ava-labs/coreth => github.com/chain4travel/caminoethvm v1.1.13-rc0 +replace github.com/ava-labs/coreth => github.com/chain4travel/caminoethvm 
v1.1.14-rc0 diff --git a/go.sum b/go.sum index 35fbdeffbc3d..c1aab72e97ec 100644 --- a/go.sum +++ b/go.sum @@ -38,7 +38,8 @@ cloud.google.com/go/storage v1.14.0/go.mod h1:GrKmX003DSIwi9o29oFT7YDnHYwZoctc3f dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= github.com/AndreasBriese/bbloom v0.0.0-20190306092124-e2d15f34fcf9/go.mod h1:bOvUY6CB00SOBii9/FifXqc0awNKxLFCL/+pkDPuyl8= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= -github.com/BurntSushi/toml v1.2.0 h1:Rt8g24XnyGTyglgET/PRUNlrUeu9F5L+7FilkXfZgs0= +github.com/BurntSushi/toml v1.2.1 h1:9F2/+DoOYIOksmaJFPw1tGFy1eDnIJXg+UHjuD8lTak= +github.com/BurntSushi/toml v1.2.1/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/CloudyKit/fastprinter v0.0.0-20200109182630-33d98a066a53/go.mod h1:+3IMCy2vIlbG1XG/0ggNQv0SvxCAIpPM5b1nCz56Xno= github.com/CloudyKit/jet/v3 v3.0.0/go.mod h1:HKQPgSJmdK8hdoAbKUUWajkHyHo4RaU5rMdUywE7VMo= @@ -103,8 +104,8 @@ github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/chain4travel/caminoethvm v1.1.13-rc0 h1:G4dMGmzo+U2syA/noQ27HljaDx1zjMqPwmpkPXtJjIU= -github.com/chain4travel/caminoethvm v1.1.13-rc0/go.mod h1:fYkXHddsMAlknJoAs8t1ISdyGvtlnWXfqHJztJt43ik= +github.com/chain4travel/caminoethvm v1.1.14-rc0 h1:MqCeMYzPSiaJnlaapLL2RYxcnZXh34g7P0WlaUdy50E= +github.com/chain4travel/caminoethvm v1.1.14-rc0/go.mod h1:vaF3LIgjGW1povg11ZAsnY+iWMuBrhsNWz5JBtpOu80= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/logex v1.2.0/go.mod h1:9+9sk7u7pGNWYMkh0hdiL++6OeibzJccyQU4p4MedaY= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= @@ -186,7 +187,8 @@ github.com/fasthttp-contrib/websocket v0.0.0-20160511215533-1f3b11f56072/go.mod github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M= github.com/fjl/memsize v0.0.0-20190710130421-bcb5799ab5e5 h1:FtmdgXiUlNeRsoNMFlKLDt+S+6hbjVMEW6RGQ7aUf7c= github.com/fjl/memsize v0.0.0-20190710130421-bcb5799ab5e5/go.mod h1:VvhXpOYNQvB+uIk2RvXzuaQtkQJzzIx6lSBe1xv7hi0= -github.com/frankban/quicktest v1.14.3 h1:FJKSZTDHjyhriyC81FLQ0LY93eSai0ZyR/ZIkd3ZUKE= +github.com/frankban/quicktest v1.14.4 h1:g2rn0vABPOOXmZUj+vbmUp0lPoXEMuhTpIluN0XL9UY= +github.com/frankban/quicktest v1.14.4/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/fsnotify/fsnotify v1.5.4/go.mod h1:OVB6XrOHzAwXMpEM7uPOzcehqUV2UqJxmVXmkdnm1bU= @@ -529,8 +531,9 @@ github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6L github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.8.1/go.mod h1:JeRgkft04UBgHMgCIwADu4Pn6Mtm5d4nPKWu0nJ5d+o= -github.com/rogpeppe/go-internal v1.9.0 
h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= +github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/rs/cors v1.7.0 h1:+88SsELBHx5r+hZ8TCkggzSstaWNbDvThkVK8H6f9ik= github.com/rs/cors v1.7.0/go.mod h1:gFx+x8UowdsKA9AchylcLynDq+nNFfI8FkUZdN/jGCU= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= @@ -711,8 +714,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= -golang.org/x/exp v0.0.0-20230206171751-46f607a40771 h1:xP7rWLUr1e1n2xkK5YB4LI0hPEy3LJC6Wk+D4pGlOJg= -golang.org/x/exp v0.0.0-20230206171751-46f607a40771/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= +golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df h1:UA2aFVmmsIlefxMk29Dp2juaUSth8Pyn3Tq5Y5mJGME= +golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/image v0.11.0 h1:ds2RoQvBvYTiJkwpSFDwCcDFNX7DqjL2WsUgTNk0Ooo= diff --git a/network/config.go b/network/config.go index 54b6e56b5220..1ca1addc0117 100644 --- a/network/config.go +++ b/network/config.go @@ -134,8 +134,8 @@ type Config struct { TLSKey crypto.Signer `json:"-"` // TrackedSubnets of the node. 
- TrackedSubnets set.Set[ids.ID] `json:"-"` - Beacons validators.Set `json:"-"` + TrackedSubnets set.Set[ids.ID] `json:"-"` + Beacons validators.Manager `json:"-"` // Validators are the current validators in the Avalanche network Validators validators.Manager `json:"-"` diff --git a/network/example_test.go b/network/example_test.go index ca059a391d71..8f20900a7e8d 100644 --- a/network/example_test.go +++ b/network/example_test.go @@ -58,11 +58,11 @@ func (t *testExternalHandler) Disconnected(nodeID ids.NodeID) { ) } -type testAggressiveValidatorSet struct { - validators.Set +type testAggressiveValidatorManager struct { + validators.Manager } -func (*testAggressiveValidatorSet) Contains(ids.NodeID) bool { +func (*testAggressiveValidatorManager) Contains(ids.ID, ids.NodeID) bool { return true } @@ -78,8 +78,8 @@ func ExampleNewTestNetwork() { // Needs to be periodically updated by the caller to have the latest // validator set - validators := &testAggressiveValidatorSet{ - Set: validators.NewSet(), + validators := &testAggressiveValidatorManager{ + Manager: validators.NewManager(), } // If we want to be able to communicate with non-primary network subnets, we diff --git a/network/network.go b/network/network.go index 25690a3fdfe0..69f2d5a369b1 100644 --- a/network/network.go +++ b/network/network.go @@ -41,7 +41,6 @@ import ( "github.com/ava-labs/avalanchego/proto/pb/p2p" "github.com/ava-labs/avalanchego/snow/networking/router" "github.com/ava-labs/avalanchego/snow/networking/sender" - "github.com/ava-labs/avalanchego/snow/validators" "github.com/ava-labs/avalanchego/subnets" "github.com/ava-labs/avalanchego/utils/constants" "github.com/ava-labs/avalanchego/utils/ips" @@ -65,12 +64,10 @@ var ( _ sender.ExternalSender = (*network)(nil) _ Network = (*network)(nil) - errMissingPrimaryValidators = errors.New("missing primary validator set") - errNotValidator = errors.New("node is not a validator") - errNotTracked = errors.New("subnet is not tracked") - errSubnetNotExist = errors.New("subnet does not exist") - errExpectedProxy = errors.New("expected proxy") - errExpectedTCPProtocol = errors.New("expected TCP protocol") + errNotValidator = errors.New("node is not a validator") + errNotTracked = errors.New("subnet is not tracked") + errExpectedProxy = errors.New("expected proxy") + errExpectedTCPProtocol = errors.New("expected TCP protocol") ) // Network defines the functionality of the networking library. @@ -128,6 +125,14 @@ type UptimeResult struct { WeightedAveragePercentage float64 } +// To avoid potential deadlocks, we maintain that locks must be grabbed in the +// following order: +// +// 1. peersLock +// 2. manuallyTrackedIDsLock +// +// If a higher lock (e.g. manuallyTrackedIDsLock) is held when trying to grab a +// lower lock (e.g. peersLock) a deadlock could occur. type network struct { config *Config peerConfig *peer.Config @@ -167,11 +172,14 @@ type network struct { // connect to. An entry is added to this set when we first start attempting // to connect to the peer. An entry is deleted from this set once we have // finished the handshake. 
- trackedIPs map[ids.NodeID]*trackedIP - manuallyTrackedIDs set.Set[ids.NodeID] - connectingPeers peer.Set - connectedPeers peer.Set - closing bool + trackedIPs map[ids.NodeID]*trackedIP + connectingPeers peer.Set + connectedPeers peer.Set + closing bool + + // Tracks special peers that the network should always track + manuallyTrackedIDsLock sync.RWMutex + manuallyTrackedIDs set.Set[ids.NodeID] // router is notified about all peer [Connected] and [Disconnected] events // as well as all non-handshake peer messages. @@ -199,11 +207,6 @@ func NewNetwork( dialer dialer.Dialer, router router.ExternalHandler, ) (Network, error) { - primaryNetworkValidators, ok := config.Validators.Get(constants.PrimaryNetworkID) - if !ok { - return nil, errMissingPrimaryValidators - } - if config.ProxyEnabled { // Wrap the listener to process the proxy header. listener = &proxyproto.Listener{ @@ -230,7 +233,7 @@ func NewNetwork( log, config.Namespace, metricsRegisterer, - primaryNetworkValidators, + config.Validators, config.ThrottlerConfig.InboundMsgThrottlerConfig, config.ResourceTracker, config.CPUTargeter, @@ -244,7 +247,7 @@ func NewNetwork( log, config.Namespace, metricsRegisterer, - primaryNetworkValidators, + config.Validators, config.ThrottlerConfig.OutboundMsgThrottlerConfig, ) if err != nil { @@ -477,9 +480,11 @@ func (n *network) Connected(nodeID ids.NodeID) { // of peers, then it should only connect if this node is a validator, or the // peer is a validator/beacon. func (n *network) AllowConnection(nodeID ids.NodeID) bool { - return !n.config.RequireValidatorToConnect || - validators.Contains(n.config.Validators, constants.PrimaryNetworkID, n.config.MyNodeID) || - n.WantsConnection(nodeID) + if !n.config.RequireValidatorToConnect { + return true + } + _, isValidator := n.config.Validators.GetValidator(constants.PrimaryNetworkID, n.config.MyNodeID) + return isValidator || n.WantsConnection(nodeID) } func (n *network) Track(peerID ids.NodeID, claimedIPPorts []*ips.ClaimedIPPort) ([]*p2p.PeerAck, error) { @@ -810,23 +815,24 @@ func (n *network) Dispatch() error { } func (n *network) WantsConnection(nodeID ids.NodeID) bool { - n.peersLock.RLock() - defer n.peersLock.RUnlock() + if _, ok := n.config.Validators.GetValidator(constants.PrimaryNetworkID, nodeID); ok { + return true + } - return n.wantsConnection(nodeID) -} + n.manuallyTrackedIDsLock.RLock() + defer n.manuallyTrackedIDsLock.RUnlock() -func (n *network) wantsConnection(nodeID ids.NodeID) bool { - return validators.Contains(n.config.Validators, constants.PrimaryNetworkID, nodeID) || - n.manuallyTrackedIDs.Contains(nodeID) + return n.manuallyTrackedIDs.Contains(nodeID) } func (n *network) ManuallyTrack(nodeID ids.NodeID, ip ips.IPPort) { + n.manuallyTrackedIDsLock.Lock() + n.manuallyTrackedIDs.Add(nodeID) + n.manuallyTrackedIDsLock.Unlock() + n.peersLock.Lock() defer n.peersLock.Unlock() - n.manuallyTrackedIDs.Add(nodeID) - _, connected := n.connectedPeers.GetByID(nodeID) if connected { // If I'm currently connected to [nodeID] then they will have told me @@ -872,7 +878,7 @@ func (n *network) getPeers( continue } - isValidator := validators.Contains(n.config.Validators, subnetID, nodeID) + _, isValidator := n.config.Validators.GetValidator(subnetID, nodeID) // check if the peer is allowed to connect to the subnet if !allower.IsAllowed(nodeID, isValidator) { continue @@ -891,14 +897,9 @@ func (n *network) samplePeers( numPeersToSample int, allower subnets.Allower, ) []peer.Peer { - subnetValidators, ok := n.config.Validators.Get(subnetID) - if 
!ok { - return nil - } - // If there are fewer validators than [numValidatorsToSample], then only // sample [numValidatorsToSample] validators. - subnetValidatorsLen := subnetValidators.Len() + subnetValidatorsLen := n.config.Validators.Count(subnetID) if subnetValidatorsLen < numValidatorsToSample { numValidatorsToSample = subnetValidatorsLen } @@ -916,7 +917,7 @@ func (n *network) samplePeers( } peerID := p.ID() - isValidator := subnetValidators.Contains(peerID) + _, isValidator := n.config.Validators.GetValidator(subnetID, peerID) // check if the peer is allowed to connect to the subnet if !allower.IsAllowed(peerID, isValidator) { return false @@ -972,7 +973,7 @@ func (n *network) disconnectedFromConnecting(nodeID ids.NodeID) { // The peer that is disconnecting from us didn't finish the handshake tracked, ok := n.trackedIPs[nodeID] if ok { - if n.wantsConnection(nodeID) { + if n.WantsConnection(nodeID) { tracked := tracked.trackNewIP(tracked.ip) n.trackedIPs[nodeID] = tracked n.dial(nodeID, tracked) @@ -995,7 +996,7 @@ func (n *network) disconnectedFromConnected(peer peer.Peer, nodeID ids.NodeID) { n.connectedPeers.Remove(nodeID) // The peer that is disconnecting from us finished the handshake - if n.wantsConnection(nodeID) { + if n.WantsConnection(nodeID) { prevIP := n.peerIPs[nodeID] tracked := newTrackedIP(prevIP.IPPort) n.trackedIPs[nodeID] = tracked @@ -1058,7 +1059,7 @@ func (n *network) authenticateIPs(ips []*ips.ClaimedIPPort) ([]*ipAuth, error) { func (n *network) peerIPStatus(nodeID ids.NodeID, ip *ips.ClaimedIPPort) (*ips.ClaimedIPPort, bool, bool, bool) { prevIP, previouslyTracked := n.peerIPs[nodeID] shouldUpdateOurIP := previouslyTracked && prevIP.Timestamp < ip.Timestamp - shouldDial := !previouslyTracked && n.wantsConnection(nodeID) + shouldDial := !previouslyTracked && n.WantsConnection(nodeID) return prevIP, previouslyTracked, shouldUpdateOurIP, shouldDial } @@ -1104,7 +1105,7 @@ func (n *network) dial(nodeID ids.NodeID, ip *trackedIP) { // trackedIPs and this goroutine. This prevents a memory leak when // the tracked nodeID leaves the validator set and is never able to // be connected to. - if !n.wantsConnection(nodeID) { + if !n.WantsConnection(nodeID) { // Typically [n.trackedIPs[nodeID]] will already equal [ip], but // the reference to [ip] is refreshed to avoid any potential // race conditions before removing the entry. @@ -1360,18 +1361,18 @@ func (n *network) NodeUptime(subnetID ids.ID) (UptimeResult, error) { return UptimeResult{}, errNotTracked } - validators, ok := n.config.Validators.Get(subnetID) - if !ok { - return UptimeResult{}, errSubnetNotExist - } - - myStake := validators.GetWeight(n.config.MyNodeID) + myStake := n.config.Validators.GetWeight(subnetID, n.config.MyNodeID) if myStake == 0 { return UptimeResult{}, errNotValidator } + totalWeightInt, err := n.config.Validators.TotalWeight(subnetID) + if err != nil { + return UptimeResult{}, fmt.Errorf("error while fetching weight for subnet %s: %w", subnetID, err) + } + var ( - totalWeight = float64(validators.Weight()) + totalWeight = float64(totalWeightInt) totalWeightedPercent = 100 * float64(myStake) rewardingStake = float64(myStake) ) @@ -1383,7 +1384,7 @@ func (n *network) NodeUptime(subnetID ids.ID) (UptimeResult, error) { peer, _ := n.connectedPeers.GetByIndex(i) nodeID := peer.ID() - weight := validators.GetWeight(nodeID) + weight := n.config.Validators.GetWeight(subnetID, nodeID) if weight == 0 { // this is not a validator skip it. 
continue diff --git a/network/network_test.go b/network/network_test.go index 80f20846bab0..65d4809101fb 100644 --- a/network/network_test.go +++ b/network/network_test.go @@ -135,12 +135,13 @@ func init() { func newDefaultTargeter(t tracker.Tracker) tracker.Targeter { return tracker.NewTargeter( + logging.NoLog{}, &tracker.TargeterConfig{ VdrAlloc: 10, MaxNonVdrUsage: 10, MaxNonVdrNodeUsage: 10, }, - validators.NewSet(), + validators.NewManager(), t, ) } @@ -223,18 +224,15 @@ func newFullyConnectedTestNetwork(t *testing.T, handlers []router.InboundHandler GossipTracker: g, } - beacons := validators.NewSet() - require.NoError(beacons.Add(nodeIDs[0], nil, ids.GenerateTestID(), 1)) + beacons := validators.NewManager() + require.NoError(beacons.AddStaker(constants.PrimaryNetworkID, nodeIDs[0], nil, ids.GenerateTestID(), 1)) - primaryVdrs := validators.NewSet() - primaryVdrs.RegisterCallbackListener(&gossipTrackerCallback) + vdrs := validators.NewManager() + vdrs.RegisterCallbackListener(constants.PrimaryNetworkID, &gossipTrackerCallback) for _, nodeID := range nodeIDs { - require.NoError(primaryVdrs.Add(nodeID, nil, ids.GenerateTestID(), 1)) + require.NoError(vdrs.AddStaker(constants.PrimaryNetworkID, nodeID, nil, ids.GenerateTestID(), 1)) } - vdrs := validators.NewManager() - _ = vdrs.Add(constants.PrimaryNetworkID, primaryVdrs) - config := config config.GossipTracker = g @@ -405,7 +403,7 @@ func TestTrackVerifiesSignatures(t *testing.T) { network := networks[0] nodeID, tlsCert, _ := getTLS(t, 1) - require.NoError(validators.Add(network.config.Validators, constants.PrimaryNetworkID, nodeID, nil, ids.Empty, 1)) + require.NoError(network.config.Validators.AddStaker(constants.PrimaryNetworkID, nodeID, nil, ids.Empty, 1)) _, err := network.Track(ids.EmptyNodeID, []*ips.ClaimedIPPort{{ Cert: staking.CertificateFromX509(tlsCert.Leaf), @@ -448,18 +446,15 @@ func TestTrackDoesNotDialPrivateIPs(t *testing.T) { GossipTracker: g, } - beacons := validators.NewSet() - require.NoError(beacons.Add(nodeIDs[0], nil, ids.GenerateTestID(), 1)) + beacons := validators.NewManager() + require.NoError(beacons.AddStaker(constants.PrimaryNetworkID, nodeIDs[0], nil, ids.GenerateTestID(), 1)) - primaryVdrs := validators.NewSet() - primaryVdrs.RegisterCallbackListener(&gossipTrackerCallback) + vdrs := validators.NewManager() + vdrs.RegisterCallbackListener(constants.PrimaryNetworkID, &gossipTrackerCallback) for _, nodeID := range nodeIDs { - require.NoError(primaryVdrs.Add(nodeID, nil, ids.GenerateTestID(), 1)) + require.NoError(vdrs.AddStaker(constants.PrimaryNetworkID, nodeID, nil, ids.GenerateTestID(), 1)) } - vdrs := validators.NewManager() - _ = vdrs.Add(constants.PrimaryNetworkID, primaryVdrs) - config := config config.GossipTracker = g @@ -527,9 +522,9 @@ func TestDialDeletesNonValidators(t *testing.T) { dialer, listeners, nodeIDs, configs := newTestNetwork(t, 2) - primaryVdrs := validators.NewSet() + vdrs := validators.NewManager() for _, nodeID := range nodeIDs { - require.NoError(primaryVdrs.Add(nodeID, nil, ids.GenerateTestID(), 1)) + require.NoError(vdrs.AddStaker(constants.PrimaryNetworkID, nodeID, nil, ids.GenerateTestID(), 1)) } networks := make([]Network, len(configs)) @@ -546,13 +541,10 @@ func TestDialDeletesNonValidators(t *testing.T) { GossipTracker: g, } - beacons := validators.NewSet() - require.NoError(beacons.Add(nodeIDs[0], nil, ids.GenerateTestID(), 1)) - - primaryVdrs.RegisterCallbackListener(&gossipTrackerCallback) + beacons := validators.NewManager() + 
require.NoError(beacons.AddStaker(constants.PrimaryNetworkID, nodeIDs[0], nil, ids.GenerateTestID(), 1)) - vdrs := validators.NewManager() - _ = vdrs.Add(constants.PrimaryNetworkID, primaryVdrs) + vdrs.RegisterCallbackListener(constants.PrimaryNetworkID, &gossipTrackerCallback) config := config @@ -613,7 +605,7 @@ func TestDialDeletesNonValidators(t *testing.T) { time.Sleep(50 * time.Millisecond) network := networks[1].(*network) - require.NoError(primaryVdrs.RemoveWeight(nodeIDs[0], 1)) + require.NoError(vdrs.RemoveWeight(constants.PrimaryNetworkID, nodeIDs[0], 1)) require.Eventually( func() bool { network.peersLock.RLock() @@ -657,8 +649,8 @@ func TestDialContext(t *testing.T) { } ) - network.manuallyTrackedIDs.Add(neverDialedNodeID) - network.manuallyTrackedIDs.Add(dialedNodeID) + network.ManuallyTrack(neverDialedNodeID, neverDialedIP.ip) + network.ManuallyTrack(dialedNodeID, dialedIP.ip) // Sanity check that when a non-cancelled context is given, // we actually dial the peer. @@ -691,3 +683,88 @@ func TestDialContext(t *testing.T) { network.StartClose() wg.Wait() } + +func TestAllowConnectionAsAValidator(t *testing.T) { + require := require.New(t) + + dialer, listeners, nodeIDs, configs := newTestNetwork(t, 2) + + networks := make([]Network, len(configs)) + for i, config := range configs { + msgCreator := newMessageCreator(t) + registry := prometheus.NewRegistry() + + g, err := peer.NewGossipTracker(registry, "foobar") + require.NoError(err) + + log := logging.NoLog{} + gossipTrackerCallback := peer.GossipTrackerCallback{ + Log: log, + GossipTracker: g, + } + + beacons := validators.NewManager() + require.NoError(beacons.AddStaker(constants.PrimaryNetworkID, nodeIDs[0], nil, ids.GenerateTestID(), 1)) + + vdrs := validators.NewManager() + vdrs.RegisterCallbackListener(constants.PrimaryNetworkID, &gossipTrackerCallback) + require.NoError(vdrs.AddStaker(constants.PrimaryNetworkID, nodeIDs[0], nil, ids.GenerateTestID(), 1)) + + config := config + + config.GossipTracker = g + config.Beacons = beacons + config.Validators = vdrs + config.RequireValidatorToConnect = true + + net, err := NewNetwork( + config, + msgCreator, + registry, + log, + listeners[i], + dialer, + &testHandler{ + InboundHandler: nil, + ConnectedF: nil, + DisconnectedF: nil, + }, + ) + require.NoError(err) + networks[i] = net + } + + wg := sync.WaitGroup{} + wg.Add(len(networks)) + for i, net := range networks { + if i != 0 { + config := configs[0] + net.ManuallyTrack(config.MyNodeID, config.MyIPPort.IPPort()) + } + + go func(net Network) { + defer wg.Done() + + require.NoError(net.Dispatch()) + }(net) + } + + network := networks[1].(*network) + require.Eventually( + func() bool { + network.peersLock.RLock() + defer network.peersLock.RUnlock() + + nodeID := nodeIDs[0] + _, contains := network.connectedPeers.GetByID(nodeID) + return contains + }, + 10*time.Second, + 50*time.Millisecond, + ) + + for _, net := range networks { + net.StartClose() + } + wg.Wait() +} diff --git a/network/peer/config.go b/network/peer/config.go index 2ad13a19a0be..b4fd03db2166 100644 --- a/network/peer/config.go +++ b/network/peer/config.go @@ -34,7 +34,7 @@ type Config struct { Router router.InboundHandler VersionCompatibility version.Compatibility MySubnets set.Set[ids.ID] - Beacons validators.Set + Beacons validators.Manager NetworkID uint32 PingFrequency time.Duration PongTimeout time.Duration diff --git a/network/peer/peer.go b/network/peer/peer.go index 6ff3c9b1b6ee..503f97262882 100644 --- a/network/peer/peer.go +++ b/network/peer/peer.go 
@@ -853,7 +853,7 @@ func (p *peer) handleVersion(msg *p2p.Version) { p.Metrics.ClockSkew.Observe(clockDifference) if clockDifference > p.MaxClockDifference.Seconds() { - if p.Beacons.Contains(p.id) { + if _, ok := p.Beacons.GetValidator(constants.PrimaryNetworkID, p.id); ok { p.Log.Warn("beacon reports out of sync time", zap.Stringer("nodeID", p.id), zap.Uint64("peerTime", msg.MyTime), @@ -882,7 +882,7 @@ func (p *peer) handleVersion(msg *p2p.Version) { p.version = peerVersion if p.VersionCompatibility.Version().Before(peerVersion) { - if p.Beacons.Contains(p.id) { + if _, ok := p.Beacons.GetValidator(constants.PrimaryNetworkID, p.id); ok { p.Log.Info("beacon attempting to connect with newer version. You may want to update your client", zap.Stringer("nodeID", p.id), zap.Stringer("beaconVersion", peerVersion), diff --git a/network/peer/peer_test.go b/network/peer/peer_test.go index 4702e781f433..3e43dab3b2c1 100644 --- a/network/peer/peer_test.go +++ b/network/peer/peer_test.go @@ -114,7 +114,7 @@ func makeRawTestPeers(t *testing.T, trackedSubnets set.Set[ids.ID]) (*rawTestPee VersionCompatibility: version.GetCompatibility(constants.LocalID), MySubnets: trackedSubnets, UptimeCalculator: uptime.NoOpCalculator, - Beacons: validators.NewSet(), + Beacons: validators.NewManager(), NetworkID: constants.LocalID, PingFrequency: constants.DefaultPingFrequency, PongTimeout: constants.DefaultPingPongTimeout, diff --git a/network/peer/test_peer.go b/network/peer/test_peer.go index 5bcb1bd57f81..62717e27dca1 100644 --- a/network/peer/test_peer.go +++ b/network/peer/test_peer.go @@ -115,7 +115,7 @@ func StartTestPeer( Router: router, VersionCompatibility: version.GetCompatibility(networkID), MySubnets: set.Set[ids.ID]{}, - Beacons: validators.NewSet(), + Beacons: validators.NewManager(), NetworkID: networkID, PingFrequency: constants.DefaultPingFrequency, PongTimeout: constants.DefaultPingPongTimeout, diff --git a/network/test_network.go b/network/test_network.go index 937496003fd2..d8795e14e044 100644 --- a/network/test_network.go +++ b/network/test_network.go @@ -18,6 +18,7 @@ import ( "github.com/ava-labs/avalanchego/network/dialer" "github.com/ava-labs/avalanchego/network/peer" "github.com/ava-labs/avalanchego/network/throttling" + "github.com/ava-labs/avalanchego/snow" "github.com/ava-labs/avalanchego/snow/networking/router" "github.com/ava-labs/avalanchego/snow/networking/tracker" "github.com/ava-labs/avalanchego/snow/uptime" @@ -73,7 +74,7 @@ func (*noopListener) Addr() net.Addr { func NewTestNetwork( log logging.Logger, networkID uint32, - currentValidators validators.Set, + currentValidators validators.Manager, trackedSubnets set.Set[ids.ID], router router.ExternalHandler, ) (Network, error) { @@ -186,10 +187,9 @@ func NewTestNetwork( networkConfig.TLSConfig = tlsConfig networkConfig.TLSKey = tlsCert.PrivateKey.(crypto.Signer) - validatorManager := validators.NewManager() - beacons := validators.NewSet() - networkConfig.Validators = validatorManager - networkConfig.Validators.Add(constants.PrimaryNetworkID, currentValidators) + ctx := snow.DefaultConsensusContextTest() + beacons := validators.NewManager() + networkConfig.Validators = currentValidators networkConfig.Beacons = beacons // This never actually does anything because we never initialize the P-chain networkConfig.UptimeCalculator = uptime.NoOpCalculator @@ -207,6 +207,7 @@ func NewTestNetwork( return nil, err } networkConfig.CPUTargeter = tracker.NewTargeter( + ctx.Log, &tracker.TargeterConfig{ VdrAlloc: float64(runtime.NumCPU()), 
MaxNonVdrUsage: .8 * float64(runtime.NumCPU()), @@ -216,6 +217,7 @@ func NewTestNetwork( networkConfig.ResourceTracker.CPUTracker(), ) networkConfig.DiskTargeter = tracker.NewTargeter( + ctx.Log, &tracker.TargeterConfig{ VdrAlloc: 1000 * units.GiB, MaxNonVdrUsage: 1000 * units.GiB, diff --git a/network/throttling/common.go b/network/throttling/common.go index c2a92db31c57..9350fb4f684c 100644 --- a/network/throttling/common.go +++ b/network/throttling/common.go @@ -22,8 +22,7 @@ type MsgByteThrottlerConfig struct { type commonMsgThrottler struct { log logging.Logger lock sync.Mutex - // Primary network validator set - vdrs validators.Set + vdrs validators.Manager // Max number of bytes that can be taken from the // at-large byte allocation by a given node. nodeMaxAtLargeBytes uint64 diff --git a/network/throttling/inbound_msg_byte_throttler.go b/network/throttling/inbound_msg_byte_throttler.go index 66efa79ba0c2..659d9f398309 100644 --- a/network/throttling/inbound_msg_byte_throttler.go +++ b/network/throttling/inbound_msg_byte_throttler.go @@ -13,6 +13,7 @@ import ( "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/snow/validators" + "github.com/ava-labs/avalanchego/utils/constants" "github.com/ava-labs/avalanchego/utils/linkedhashmap" "github.com/ava-labs/avalanchego/utils/logging" "github.com/ava-labs/avalanchego/utils/math" @@ -26,7 +27,7 @@ func newInboundMsgByteThrottler( log logging.Logger, namespace string, registerer prometheus.Registerer, - vdrs validators.Set, + vdrs validators.Manager, config MsgByteThrottlerConfig, ) (*inboundMsgByteThrottler, error) { t := &inboundMsgByteThrottler{ @@ -96,7 +97,7 @@ func (t *inboundMsgByteThrottler) Acquire(ctx context.Context, msgSize uint64, n t.lock.Lock() - // If there is already a message waiting, log the error but continue + // If there is already a message waiting, log the error and return if existingID, exists := t.nodeToWaitingMsgID[nodeID]; exists { t.log.Error("node already waiting on message", zap.Stringer("nodeID", nodeID), @@ -131,9 +132,16 @@ func (t *inboundMsgByteThrottler) Acquire(ctx context.Context, msgSize uint64, n // Take as many bytes as we can from [nodeID]'s validator allocation. 
// Calculate [nodeID]'s validator allocation size based on its weight vdrAllocationSize := uint64(0) - weight := t.vdrs.GetWeight(nodeID) + weight := t.vdrs.GetWeight(constants.PrimaryNetworkID, nodeID) if weight != 0 { - vdrAllocationSize = uint64(float64(t.maxVdrBytes) * float64(weight) / float64(t.vdrs.Weight())) + totalWeight, err := t.vdrs.TotalWeight(constants.PrimaryNetworkID) + if err != nil { + t.log.Error("couldn't get total weight of primary network", + zap.Error(err), + ) + } else { + vdrAllocationSize = uint64(float64(t.maxVdrBytes) * float64(weight) / float64(totalWeight)) + } } vdrBytesAlreadyUsed := t.nodeToVdrBytesUsed[nodeID] // [vdrBytesAllowed] is the number of bytes this node diff --git a/network/throttling/inbound_msg_byte_throttler_test.go b/network/throttling/inbound_msg_byte_throttler_test.go index fa21f7baf387..e71f0abba238 100644 --- a/network/throttling/inbound_msg_byte_throttler_test.go +++ b/network/throttling/inbound_msg_byte_throttler_test.go @@ -14,6 +14,7 @@ import ( "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/snow/validators" + "github.com/ava-labs/avalanchego/utils/constants" "github.com/ava-labs/avalanchego/utils/logging" ) @@ -24,9 +25,9 @@ func TestInboundMsgByteThrottlerCancelContextDeadlock(t *testing.T) { AtLargeAllocSize: 1, NodeMaxAtLargeBytes: 1, } - vdrs := validators.NewSet() + vdrs := validators.NewManager() vdr := ids.GenerateTestNodeID() - require.NoError(vdrs.Add(vdr, nil, ids.Empty, 1)) + require.NoError(vdrs.AddStaker(constants.PrimaryNetworkID, vdr, nil, ids.Empty, 1)) throttler, err := newInboundMsgByteThrottler( logging.NoLog{}, @@ -52,11 +53,11 @@ func TestInboundMsgByteThrottlerCancelContext(t *testing.T) { AtLargeAllocSize: 512, NodeMaxAtLargeBytes: 1024, } - vdrs := validators.NewSet() + vdrs := validators.NewManager() vdr1ID := ids.GenerateTestNodeID() vdr2ID := ids.GenerateTestNodeID() - require.NoError(vdrs.Add(vdr1ID, nil, ids.Empty, 1)) - require.NoError(vdrs.Add(vdr2ID, nil, ids.Empty, 1)) + require.NoError(vdrs.AddStaker(constants.PrimaryNetworkID, vdr1ID, nil, ids.Empty, 1)) + require.NoError(vdrs.AddStaker(constants.PrimaryNetworkID, vdr2ID, nil, ids.Empty, 1)) throttler, err := newInboundMsgByteThrottler( logging.NoLog{}, @@ -110,11 +111,11 @@ func TestInboundMsgByteThrottler(t *testing.T) { AtLargeAllocSize: 1024, NodeMaxAtLargeBytes: 1024, } - vdrs := validators.NewSet() + vdrs := validators.NewManager() vdr1ID := ids.GenerateTestNodeID() vdr2ID := ids.GenerateTestNodeID() - require.NoError(vdrs.Add(vdr1ID, nil, ids.Empty, 1)) - require.NoError(vdrs.Add(vdr2ID, nil, ids.Empty, 1)) + require.NoError(vdrs.AddStaker(constants.PrimaryNetworkID, vdr1ID, nil, ids.Empty, 1)) + require.NoError(vdrs.AddStaker(constants.PrimaryNetworkID, vdr2ID, nil, ids.Empty, 1)) throttler, err := newInboundMsgByteThrottler( logging.NoLog{}, @@ -328,9 +329,9 @@ func TestSybilMsgThrottlerMaxNonVdr(t *testing.T) { AtLargeAllocSize: 100, NodeMaxAtLargeBytes: 10, } - vdrs := validators.NewSet() + vdrs := validators.NewManager() vdr1ID := ids.GenerateTestNodeID() - require.NoError(vdrs.Add(vdr1ID, nil, ids.Empty, 1)) + require.NoError(vdrs.AddStaker(constants.PrimaryNetworkID, vdr1ID, nil, ids.Empty, 1)) throttler, err := newInboundMsgByteThrottler( logging.NoLog{}, "", @@ -375,9 +376,9 @@ func TestMsgThrottlerNextMsg(t *testing.T) { AtLargeAllocSize: 1024, NodeMaxAtLargeBytes: 1024, } - vdrs := validators.NewSet() + vdrs := validators.NewManager() vdr1ID := ids.GenerateTestNodeID() - require.NoError(vdrs.Add(vdr1ID, nil, 
ids.Empty, 1)) + require.NoError(vdrs.AddStaker(constants.PrimaryNetworkID, vdr1ID, nil, ids.Empty, 1)) nonVdrNodeID := ids.GenerateTestNodeID() maxVdrBytes := config.VdrAllocSize + config.AtLargeAllocSize diff --git a/network/throttling/inbound_msg_throttler.go b/network/throttling/inbound_msg_throttler.go index b76a7a345ada..3d79f640ae1a 100644 --- a/network/throttling/inbound_msg_throttler.go +++ b/network/throttling/inbound_msg_throttler.go @@ -56,7 +56,7 @@ func NewInboundMsgThrottler( log logging.Logger, namespace string, registerer prometheus.Registerer, - vdrs validators.Set, + vdrs validators.Manager, throttlerConfig InboundMsgThrottlerConfig, resourceTracker tracker.ResourceTracker, cpuTargeter tracker.Targeter, diff --git a/network/throttling/outbound_msg_throttler.go b/network/throttling/outbound_msg_throttler.go index 8b46cb2c00fc..62e8821660bf 100644 --- a/network/throttling/outbound_msg_throttler.go +++ b/network/throttling/outbound_msg_throttler.go @@ -6,9 +6,12 @@ package throttling import ( "github.com/prometheus/client_golang/prometheus" + "go.uber.org/zap" + "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/message" "github.com/ava-labs/avalanchego/snow/validators" + "github.com/ava-labs/avalanchego/utils/constants" "github.com/ava-labs/avalanchego/utils/logging" "github.com/ava-labs/avalanchego/utils/math" "github.com/ava-labs/avalanchego/utils/wrappers" @@ -43,7 +46,7 @@ func NewSybilOutboundMsgThrottler( log logging.Logger, namespace string, registerer prometheus.Registerer, - vdrs validators.Set, + vdrs validators.Manager, config MsgByteThrottlerConfig, ) (OutboundMsgThrottler, error) { t := &outboundMsgThrottler{ @@ -85,9 +88,16 @@ func (t *outboundMsgThrottler) Acquire(msg message.OutboundMessage, nodeID ids.N // Take as many bytes as we can from [nodeID]'s validator allocation. 
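A quick sketch of the arithmetic behind the throttler changes above and below: both the inbound and the outbound message throttlers now size a validator's byte allocation in proportion to its primary network stake, using GetWeight and TotalWeight on the Manager. The numbers here are made up for illustration; only the final expression mirrors the changed code.

    package main

    import "fmt"

    func main() {
        // Illustrative values, not taken from this diff.
        maxVdrBytes := uint64(1_000) // bytes reserved for validators as a whole
        weight := uint64(250)        // this node's primary network weight
        totalWeight := uint64(1_000) // total primary network weight

        // Same proportional split as the updated Acquire code.
        vdrAllocationSize := uint64(float64(maxVdrBytes) * float64(weight) / float64(totalWeight))
        fmt.Println(vdrAllocationSize) // 250, i.e. a quarter of the validator allocation
    }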
// Calculate [nodeID]'s validator allocation size based on its weight vdrAllocationSize := uint64(0) - weight := t.vdrs.GetWeight(nodeID) + weight := t.vdrs.GetWeight(constants.PrimaryNetworkID, nodeID) if weight != 0 { - vdrAllocationSize = uint64(float64(t.maxVdrBytes) * float64(weight) / float64(t.vdrs.Weight())) + totalWeight, err := t.vdrs.TotalWeight(constants.PrimaryNetworkID) + if err != nil { + t.log.Error("Failed to get total weight of primary network validators", + zap.Error(err), + ) + } else { + vdrAllocationSize = uint64(float64(t.maxVdrBytes) * float64(weight) / float64(totalWeight)) + } } vdrBytesAlreadyUsed := t.nodeToVdrBytesUsed[nodeID] // [vdrBytesAllowed] is the number of bytes this node diff --git a/network/throttling/outbound_msg_throttler_test.go b/network/throttling/outbound_msg_throttler_test.go index a17cb3b974bc..09d8b6f272ef 100644 --- a/network/throttling/outbound_msg_throttler_test.go +++ b/network/throttling/outbound_msg_throttler_test.go @@ -15,6 +15,7 @@ import ( "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/message" "github.com/ava-labs/avalanchego/snow/validators" + "github.com/ava-labs/avalanchego/utils/constants" "github.com/ava-labs/avalanchego/utils/logging" ) @@ -26,11 +27,11 @@ func TestSybilOutboundMsgThrottler(t *testing.T) { AtLargeAllocSize: 1024, NodeMaxAtLargeBytes: 1024, } - vdrs := validators.NewSet() + vdrs := validators.NewManager() vdr1ID := ids.GenerateTestNodeID() vdr2ID := ids.GenerateTestNodeID() - require.NoError(vdrs.Add(vdr1ID, nil, ids.Empty, 1)) - require.NoError(vdrs.Add(vdr2ID, nil, ids.Empty, 1)) + require.NoError(vdrs.AddStaker(constants.PrimaryNetworkID, vdr1ID, nil, ids.Empty, 1)) + require.NoError(vdrs.AddStaker(constants.PrimaryNetworkID, vdr2ID, nil, ids.Empty, 1)) throttlerIntf, err := NewSybilOutboundMsgThrottler( logging.NoLog{}, "", @@ -170,9 +171,9 @@ func TestSybilOutboundMsgThrottlerMaxNonVdr(t *testing.T) { AtLargeAllocSize: 100, NodeMaxAtLargeBytes: 10, } - vdrs := validators.NewSet() + vdrs := validators.NewManager() vdr1ID := ids.GenerateTestNodeID() - require.NoError(vdrs.Add(vdr1ID, nil, ids.Empty, 1)) + require.NoError(vdrs.AddStaker(constants.PrimaryNetworkID, vdr1ID, nil, ids.Empty, 1)) throttlerIntf, err := NewSybilOutboundMsgThrottler( logging.NoLog{}, "", @@ -217,9 +218,9 @@ func TestBypassThrottling(t *testing.T) { AtLargeAllocSize: 100, NodeMaxAtLargeBytes: 10, } - vdrs := validators.NewSet() + vdrs := validators.NewManager() vdr1ID := ids.GenerateTestNodeID() - require.NoError(vdrs.Add(vdr1ID, nil, ids.Empty, 1)) + require.NoError(vdrs.AddStaker(constants.PrimaryNetworkID, vdr1ID, nil, ids.Empty, 1)) throttlerIntf, err := NewSybilOutboundMsgThrottler( logging.NoLog{}, "", diff --git a/node/beacon_manager.go b/node/beacon_manager.go index 3e19824195c7..af088f3b4845 100644 --- a/node/beacon_manager.go +++ b/node/beacon_manager.go @@ -19,14 +19,15 @@ var _ router.Router = (*beaconManager)(nil) type beaconManager struct { router.Router timer *timer.Timer - beacons validators.Set + beacons validators.Manager requiredConns int64 numConns int64 } func (b *beaconManager) Connected(nodeID ids.NodeID, nodeVersion *version.Application, subnetID ids.ID) { - if constants.PrimaryNetworkID == subnetID && - b.beacons.Contains(nodeID) && + _, isBeacon := b.beacons.GetValidator(constants.PrimaryNetworkID, nodeID) + if isBeacon && + constants.PrimaryNetworkID == subnetID && atomic.AddInt64(&b.numConns, 1) >= b.requiredConns { b.timer.Cancel() } @@ -34,7 +35,7 @@ func (b *beaconManager) 
Connected(nodeID ids.NodeID, nodeVersion *version.Applic } func (b *beaconManager) Disconnected(nodeID ids.NodeID) { - if b.beacons.Contains(nodeID) { + if _, isBeacon := b.beacons.GetValidator(constants.PrimaryNetworkID, nodeID); isBeacon { atomic.AddInt64(&b.numConns, -1) } b.Router.Disconnected(nodeID) diff --git a/node/beacon_manager_test.go b/node/beacon_manager_test.go index 50347fb43072..82be435e92f1 100644 --- a/node/beacon_manager_test.go +++ b/node/beacon_manager_test.go @@ -27,11 +27,11 @@ func TestBeaconManager_DataRace(t *testing.T) { require := require.New(t) validatorIDs := make([]ids.NodeID, 0, numValidators) - validatorSet := validators.NewSet() + validatorSet := validators.NewManager() for i := 0; i < numValidators; i++ { nodeID := ids.GenerateTestNodeID() - require.NoError(validatorSet.Add(nodeID, nil, ids.Empty, 1)) + require.NoError(validatorSet.AddStaker(constants.PrimaryNetworkID, nodeID, nil, ids.Empty, 1)) validatorIDs = append(validatorIDs, nodeID) } diff --git a/node/insecure_validator_manager.go b/node/insecure_validator_manager.go index a171c52287fd..bd69529619dc 100644 --- a/node/insecure_validator_manager.go +++ b/node/insecure_validator_manager.go @@ -4,16 +4,20 @@ package node import ( + "go.uber.org/zap" + "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/snow/networking/router" "github.com/ava-labs/avalanchego/snow/validators" "github.com/ava-labs/avalanchego/utils/constants" + "github.com/ava-labs/avalanchego/utils/logging" "github.com/ava-labs/avalanchego/version" ) type insecureValidatorManager struct { router.Router - vdrs validators.Set + log logging.Logger + vdrs validators.Manager weight uint64 } @@ -25,10 +29,14 @@ func (i *insecureValidatorManager) Connected(vdrID ids.NodeID, nodeVersion *vers dummyTxID := ids.Empty copy(dummyTxID[:], vdrID[:]) - // Add will only error here if the total weight of the set would go over - // [math.MaxUint64]. In this case, we will just not mark this new peer - // as a validator. - _ = i.vdrs.Add(vdrID, nil, dummyTxID, i.weight) + err := i.vdrs.AddStaker(constants.PrimaryNetworkID, vdrID, nil, dummyTxID, i.weight) + if err != nil { + i.log.Error("failed to add validator", + zap.Stringer("nodeID", vdrID), + zap.Stringer("subnetID", constants.PrimaryNetworkID), + zap.Error(err), + ) + } } i.Router.Connected(vdrID, nodeVersion, subnetID) } @@ -36,6 +44,13 @@ func (i *insecureValidatorManager) Connected(vdrID ids.NodeID, nodeVersion *vers func (i *insecureValidatorManager) Disconnected(vdrID ids.NodeID) { // RemoveWeight will only error here if there was an error reported during // Add. 
- _ = i.vdrs.RemoveWeight(vdrID, i.weight) + err := i.vdrs.RemoveWeight(constants.PrimaryNetworkID, vdrID, i.weight) + if err != nil { + i.log.Error("failed to remove weight", + zap.Stringer("nodeID", vdrID), + zap.Stringer("subnetID", constants.PrimaryNetworkID), + zap.Error(err), + ) + } i.Router.Disconnected(vdrID) } diff --git a/node/node.go b/node/node.go index 68511c592ae3..89c9f6d85abd 100644 --- a/node/node.go +++ b/node/node.go @@ -99,7 +99,9 @@ import ( ) var ( - genesisHashKey = []byte("genesisID") + genesisHashKey = []byte("genesisID") + ungracefulShutdown = []byte("ungracefulShutdown") + indexerDBPrefix = []byte{0x00} errInvalidTLSKey = errors.New("invalid TLS key") @@ -138,6 +140,9 @@ type Node struct { // Build and parse messages, for both network layer and chain manager msgCreator message.Creator + // Manages network timeouts + timeoutManager timeout.Manager + // Manages creation of blockchains and routing messages to them chainManager chains.Manager @@ -167,7 +172,7 @@ type Node struct { tlsKeyLogWriterCloser io.WriteCloser // this node's initial connections to the network - bootstrappers validators.Set + bootstrappers validators.Manager // current validators of the network vdrs validators.Manager @@ -229,8 +234,8 @@ type Node struct { */ // Initialize the networking layer. -// Assumes [n.CPUTracker] and [n.CPUTargeter] have been initialized. -func (n *Node) initNetworking(primaryNetVdrs validators.Set) error { +// Assumes [n.vdrs], [n.CPUTracker], and [n.CPUTargeter] have been initialized. +func (n *Node) initNetworking() error { currentIPPort := n.Config.IPPort.IPPort() // Providing either loopback address - `::1` for ipv6 and `127.0.0.1` for ipv4 - as the listen @@ -297,7 +302,6 @@ func (n *Node) initNetworking(primaryNetVdrs validators.Set) error { // Configure benchlist n.Config.BenchlistConfig.Validators = n.vdrs n.Config.BenchlistConfig.Benchable = n.Config.ConsensusRouter - n.Config.BenchlistConfig.SybilProtectionEnabled = n.Config.SybilProtectionEnabled n.benchlistManager = benchlist.NewManager(&n.Config.BenchlistConfig) n.uptimeCalculator = uptime.NewLockedCalculator() @@ -310,7 +314,8 @@ func (n *Node) initNetworking(primaryNetVdrs validators.Set) error { dummyTxID := ids.Empty copy(dummyTxID[:], n.ID[:]) - err := primaryNetVdrs.Add( + err := n.vdrs.AddStaker( + constants.PrimaryNetworkID, n.ID, bls.PublicFromSecretKey(n.Config.StakingSigningKey), dummyTxID, @@ -321,13 +326,14 @@ func (n *Node) initNetworking(primaryNetVdrs validators.Set) error { } consensusRouter = &insecureValidatorManager{ + log: n.Log, Router: consensusRouter, - vdrs: primaryNetVdrs, + vdrs: n.vdrs, weight: n.Config.SybilProtectionDisabledWeight, } } - numBootstrappers := n.bootstrappers.Len() + numBootstrappers := n.bootstrappers.Count(constants.PrimaryNetworkID) requiredConns := (3*numBootstrappers + 3) / 4 if requiredConns > 0 { @@ -362,7 +368,7 @@ func (n *Node) initNetworking(primaryNetVdrs validators.Set) error { } // keep gossip tracker synchronized with the validator set - primaryNetVdrs.RegisterCallbackListener(&peer.GossipTrackerCallback{ + n.vdrs.RegisterCallbackListener(constants.PrimaryNetworkID, &peer.GossipTrackerCallback{ Log: n.Log, GossipTracker: gossipTracker, }) @@ -562,17 +568,34 @@ func (n *Node) initDatabase() error { if genesisHash != expectedGenesisHash { return fmt.Errorf("db contains invalid genesis hash. 
DB Genesis: %s Generated Genesis: %s", genesisHash, expectedGenesisHash) } + + ok, err := n.DB.Has(ungracefulShutdown) + if err != nil { + return fmt.Errorf("failed to read ungraceful shutdown key: %w", err) + } + + if ok { + n.Log.Warn("detected previous ungraceful shutdown") + } + + if err := n.DB.Put(ungracefulShutdown, nil); err != nil { + return fmt.Errorf( + "failed to write ungraceful shutdown key: %w", + err, + ) + } + return nil } // Set the node IDs of the peers this node should first connect to func (n *Node) initBootstrappers() error { - n.bootstrappers = validators.NewSet() + n.bootstrappers = validators.NewManager() for _, bootstrapper := range n.Config.Bootstrappers { // Note: The beacon connection manager will treat all beaconIDs as // equal. // Invariant: We never use the TxID or BLS keys populated here. - if err := n.bootstrappers.Add(bootstrapper.ID, nil, ids.Empty, 1); err != nil { + if err := n.bootstrappers.AddStaker(constants.PrimaryNetworkID, bootstrapper.ID, nil, ids.Empty, 1); err != nil { return err } } @@ -778,8 +801,7 @@ func (n *Node) initChainManager(avaxAssetID ids.ID) error { cChainID, ) - // Manages network timeouts - timeoutManager, err := timeout.NewManager( + n.timeoutManager, err = timeout.NewManager( &n.Config.AdaptiveTimeoutConfig, n.benchlistManager, "requests", @@ -788,13 +810,13 @@ func (n *Node) initChainManager(avaxAssetID ids.ID) error { if err != nil { return err } - go n.Log.RecoverAndPanic(timeoutManager.Dispatch) + go n.Log.RecoverAndPanic(n.timeoutManager.Dispatch) // Routes incoming messages from peers to the appropriate chain err = n.Config.ConsensusRouter.Initialize( n.ID, n.Log, - timeoutManager, + n.timeoutManager, n.Config.ConsensusShutdownTimeout, criticalChains, n.Config.SybilProtectionEnabled, @@ -833,7 +855,7 @@ func (n *Node) initChainManager(avaxAssetID ids.ID) error { XChainID: xChainID, CChainID: cChainID, CriticalChains: criticalChains, - TimeoutManager: timeoutManager, + TimeoutManager: n.timeoutManager, Health: n.health, RetryBootstrap: n.Config.RetryBootstrap, RetryBootstrapWarnFrequency: n.Config.RetryBootstrapWarnFrequency, @@ -872,8 +894,6 @@ func (n *Node) initVMs() error { // allows the node's validator sets to be determined by network connections. if !n.Config.SybilProtectionEnabled { vdrs = validators.NewManager() - primaryVdrs := validators.NewSet() - _ = vdrs.Add(constants.PrimaryNetworkID, primaryVdrs) } vmRegisterer := registry.NewVMRegisterer(registry.VMRegistererConfig{ @@ -1088,7 +1108,6 @@ func (n *Node) initInfoAPI() error { n.Log.Info("initializing info API") - primaryValidators, _ := n.vdrs.Get(constants.PrimaryNetworkID) service, err := info.NewService( info.Parameters{ Version: version.CurrentApp, @@ -1114,7 +1133,6 @@ func (n *Node) initInfoAPI() error { n.VMManager, n.Config.NetworkConfig.MyIPPort, n.Net, - primaryValidators, n.benchlistManager, ) if err != nil { @@ -1286,14 +1304,6 @@ func (n *Node) initAPIAliases(genesisBytes []byte) error { return nil } -// Initializes [n.vdrs] and returns the Primary Network validator set. -func (n *Node) initVdrs() validators.Set { - n.vdrs = validators.NewManager() - vdrSet := validators.NewSet() - _ = n.vdrs.Add(constants.PrimaryNetworkID, vdrSet) - return vdrSet -} - // Initialize [n.resourceManager].
func (n *Node) initResourceManager(reg prometheus.Registerer) error { resourceManager, err := resource.NewManager( @@ -1318,11 +1328,11 @@ func (n *Node) initResourceManager(reg prometheus.Registerer) error { // Assumes [n.resourceTracker] is already initialized. func (n *Node) initCPUTargeter( config *tracker.TargeterConfig, - vdrs validators.Set, ) { n.cpuTargeter = tracker.NewTargeter( + n.Log, config, - vdrs, + n.vdrs, n.resourceTracker.CPUTracker(), ) } @@ -1331,11 +1341,11 @@ func (n *Node) initCPUTargeter( // Assumes [n.resourceTracker] is already initialized. func (n *Node) initDiskTargeter( config *tracker.TargeterConfig, - vdrs validators.Set, ) { n.diskTargeter = tracker.NewTargeter( + n.Log, config, - vdrs, + n.vdrs, n.resourceTracker.DiskTracker(), ) } @@ -1430,13 +1440,16 @@ func (n *Node) Initialize( return fmt.Errorf("problem initializing message creator: %w", err) } - primaryNetVdrs := n.initVdrs() + n.vdrs = validators.NewManager() + if !n.Config.SybilProtectionEnabled { + n.vdrs = newOverriddenManager(constants.PrimaryNetworkID, n.vdrs) + } if err := n.initResourceManager(n.MetricsRegisterer); err != nil { return fmt.Errorf("problem initializing resource manager: %w", err) } - n.initCPUTargeter(&config.CPUTargeterConfig, primaryNetVdrs) - n.initDiskTargeter(&config.DiskTargeterConfig, primaryNetVdrs) - if err := n.initNetworking(primaryNetVdrs); err != nil { // Set up networking layer. + n.initCPUTargeter(&config.CPUTargeterConfig) + n.initDiskTargeter(&config.DiskTargeterConfig) + if err := n.initNetworking(); err != nil { // Set up networking layer. return fmt.Errorf("problem initializing networking: %w", err) } @@ -1532,6 +1545,7 @@ func (n *Node) shutdown() { ) } } + n.timeoutManager.Stop() if n.chainManager != nil { n.chainManager.Shutdown() } @@ -1557,6 +1571,13 @@ func (n *Node) shutdown() { n.runtimeManager.Stop(context.TODO()) if n.DBManager != nil { + if err := n.DB.Delete(ungracefulShutdown); err != nil { + n.Log.Error( + "failed to delete ungraceful shutdown key", + zap.Error(err), + ) + } + if err := n.DBManager.Close(); err != nil { n.Log.Warn("error during DB shutdown", zap.Error(err), diff --git a/node/overridden_manager.go b/node/overridden_manager.go new file mode 100644 index 000000000000..91d8c198a4c3 --- /dev/null +++ b/node/overridden_manager.go @@ -0,0 +1,85 @@ +// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package node + +import ( + "fmt" + + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/snow/validators" + "github.com/ava-labs/avalanchego/utils/crypto/bls" + "github.com/ava-labs/avalanchego/utils/set" +) + +var _ validators.Manager = (*overriddenManager)(nil) + +// newOverriddenManager returns a Manager that overrides all calls to the +// underlying Manager to only operate on the validators in [subnetID]. +func newOverriddenManager(subnetID ids.ID, manager validators.Manager) *overriddenManager { + return &overriddenManager{ + subnetID: subnetID, + manager: manager, + } +} + +// overriddenManager is a wrapper around a Manager that overrides all calls +// to the underlying Manager to only operate on the validators in [subnetID]. +// subnetID here is typically the primary network ID, as it has the superset of +// all subnet validators.
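A minimal usage sketch for the wrapper documented above and defined in the lines that follow. It assumes it lives inside package node with the imports already present in this file plus utils/constants; the function name and the IDs are hypothetical.

    func overriddenManagerSketch() {
        vdrs := validators.NewManager()
        nodeID := ids.GenerateTestNodeID()
        _ = vdrs.AddStaker(constants.PrimaryNetworkID, nodeID, nil, ids.Empty, 1)

        om := newOverriddenManager(constants.PrimaryNetworkID, vdrs)

        // The subnetID argument is ignored; every lookup is redirected to the
        // primary network set, so the node is reported as a validator of any
        // subnet that is asked about.
        otherSubnetID := ids.GenerateTestID()
        _, ok := om.GetValidator(otherSubnetID, nodeID)
        _ = ok // true
    }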
+type overriddenManager struct { + manager validators.Manager + subnetID ids.ID +} + +func (o *overriddenManager) AddStaker(_ ids.ID, nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) error { + return o.manager.AddStaker(o.subnetID, nodeID, pk, txID, weight) +} + +func (o *overriddenManager) AddWeight(_ ids.ID, nodeID ids.NodeID, weight uint64) error { + return o.manager.AddWeight(o.subnetID, nodeID, weight) +} + +func (o *overriddenManager) GetWeight(_ ids.ID, nodeID ids.NodeID) uint64 { + return o.manager.GetWeight(o.subnetID, nodeID) +} + +func (o *overriddenManager) GetValidator(_ ids.ID, nodeID ids.NodeID) (*validators.Validator, bool) { + return o.manager.GetValidator(o.subnetID, nodeID) +} + +func (o *overriddenManager) SubsetWeight(_ ids.ID, nodeIDs set.Set[ids.NodeID]) (uint64, error) { + return o.manager.SubsetWeight(o.subnetID, nodeIDs) +} + +func (o *overriddenManager) RemoveWeight(_ ids.ID, nodeID ids.NodeID, weight uint64) error { + return o.manager.RemoveWeight(o.subnetID, nodeID, weight) +} + +func (o *overriddenManager) Count(ids.ID) int { + return o.manager.Count(o.subnetID) +} + +func (o *overriddenManager) TotalWeight(ids.ID) (uint64, error) { + return o.manager.TotalWeight(o.subnetID) +} + +func (o *overriddenManager) Sample(_ ids.ID, size int) ([]ids.NodeID, error) { + return o.manager.Sample(o.subnetID, size) +} + +func (o *overriddenManager) GetMap(ids.ID) map[ids.NodeID]*validators.GetValidatorOutput { + return o.manager.GetMap(o.subnetID) +} + +func (o *overriddenManager) RegisterCallbackListener(_ ids.ID, listener validators.SetCallbackListener) { + o.manager.RegisterCallbackListener(o.subnetID, listener) +} + +func (o *overriddenManager) String() string { + return fmt.Sprintf("Overridden Validator Manager (SubnetID = %s): %s", o.subnetID, o.manager) +} + +func (o *overriddenManager) GetValidatorIDs(ids.ID) []ids.NodeID { + return o.manager.GetValidatorIDs(o.subnetID) +} diff --git a/node/overridden_manager_test.go b/node/overridden_manager_test.go new file mode 100644 index 000000000000..79f03579a5d0 --- /dev/null +++ b/node/overridden_manager_test.go @@ -0,0 +1,75 @@ +// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
+ +package node + +import ( + "math" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/snow/validators" +) + +func TestOverriddenManager(t *testing.T) { + require := require.New(t) + + nodeID0 := ids.GenerateTestNodeID() + nodeID1 := ids.GenerateTestNodeID() + subnetID0 := ids.GenerateTestID() + subnetID1 := ids.GenerateTestID() + + m := validators.NewManager() + require.NoError(m.AddStaker(subnetID0, nodeID0, nil, ids.Empty, 1)) + require.NoError(m.AddStaker(subnetID1, nodeID1, nil, ids.Empty, 1)) + + om := newOverriddenManager(subnetID0, m) + _, ok := om.GetValidator(subnetID0, nodeID0) + require.True(ok) + _, ok = om.GetValidator(subnetID0, nodeID1) + require.False(ok) + _, ok = om.GetValidator(subnetID1, nodeID0) + require.True(ok) + _, ok = om.GetValidator(subnetID1, nodeID1) + require.False(ok) + + require.NoError(om.RemoveWeight(subnetID1, nodeID0, 1)) + _, ok = om.GetValidator(subnetID0, nodeID0) + require.False(ok) + _, ok = om.GetValidator(subnetID0, nodeID1) + require.False(ok) + _, ok = om.GetValidator(subnetID1, nodeID0) + require.False(ok) + _, ok = om.GetValidator(subnetID1, nodeID1) + require.False(ok) +} + +func TestOverriddenString(t *testing.T) { + require := require.New(t) + + nodeID0 := ids.EmptyNodeID + nodeID1, err := ids.NodeIDFromString("NodeID-QLbz7JHiBTspS962RLKV8GndWFwdYhk6V") + require.NoError(err) + + subnetID0, err := ids.FromString("TtF4d2QWbk5vzQGTEPrN48x6vwgAoAmKQ9cbp79inpQmcRKES") + require.NoError(err) + subnetID1, err := ids.FromString("2mcwQKiD8VEspmMJpL1dc7okQQ5dDVAWeCBZ7FWBFAbxpv3t7w") + require.NoError(err) + + m := validators.NewManager() + require.NoError(m.AddStaker(subnetID0, nodeID0, nil, ids.Empty, 1)) + require.NoError(m.AddStaker(subnetID0, nodeID1, nil, ids.Empty, math.MaxInt64-1)) + require.NoError(m.AddStaker(subnetID1, nodeID1, nil, ids.Empty, 1)) + + om := newOverriddenManager(subnetID0, m) + expected := "Overridden Validator Manager (SubnetID = TtF4d2QWbk5vzQGTEPrN48x6vwgAoAmKQ9cbp79inpQmcRKES): Validator Manager: (Size = 2)\n" + + " Subnet[TtF4d2QWbk5vzQGTEPrN48x6vwgAoAmKQ9cbp79inpQmcRKES]: Validator Set: (Size = 2, Weight = 9223372036854775807)\n" + + " Validator[0]: NodeID-111111111111111111116DBWJs, 1\n" + + " Validator[1]: NodeID-QLbz7JHiBTspS962RLKV8GndWFwdYhk6V, 9223372036854775806\n" + + " Subnet[2mcwQKiD8VEspmMJpL1dc7okQQ5dDVAWeCBZ7FWBFAbxpv3t7w]: Validator Set: (Size = 1, Weight = 1)\n" + + " Validator[0]: NodeID-QLbz7JHiBTspS962RLKV8GndWFwdYhk6V, 1" + result := om.String() + require.Equal(expected, result) +} diff --git a/proto/p2p/p2p.proto b/proto/p2p/p2p.proto index 89019523284e..96f674f2fa74 100644 --- a/proto/p2p/p2p.proto +++ b/proto/p2p/p2p.proto @@ -61,297 +61,356 @@ message Message { } } -// Message that a node sends to its peers in order to periodically check -// responsivness and report the local node's uptime measurements of the peer. +// Ping reports a peer's perceived uptime percentage. // -// On receiving a "ping", the peer should respond with a "pong". +// Peers should respond to Ping with a Pong. message Ping { - // uptime is the primary network uptime percentage. + // Uptime percentage on the primary network [0, 100] uint32 uptime = 1; - // subnet_uptimes contains subnet uptime percentages. + // Uptime percentage on subnets repeated SubnetUptime subnet_uptimes = 2; } -// Contains subnet id and the related observed subnet uptime of the message -// receiver (remote peer). 
+// SubnetUptime is a descriptor for a peer's perceived uptime on a subnet. message SubnetUptime { + // Subnet the peer is validating bytes subnet_id = 1; + // Uptime percentage on the subnet [0, 100] uint32 uptime = 2; } -// Contains the uptime percentage of the message receiver (remote peer) -// from the sender's point of view, in response to "ping" message. -// Uptimes are expected to be provided as integers ranging in [0, 100]. +// Pong is sent in response to a Ping with the perceived uptime of the +// peer. message Pong { - // Deprecated: remove all these fields in the future, but keep the message. - // uptime is the primary network uptime percentage. + // Deprecated: uptime is now sent in Ping + // Uptime percentage on the primary network [0, 100] uint32 uptime = 1; - // subnet_uptimes contains subnet uptime percentages. + // Deprecated: uptime is now sent in Ping + // Uptime percentage on subnets repeated SubnetUptime subnet_uptimes = 2; } -// The first outbound message that the local node sends to its remote peer -// when the connection is established. In order for the local node to be -// tracked as a valid peer by the remote peer, the fields must be valid. -// For instance, the network ID must be matched and timestamp should be in-sync. -// Otherwise, the remote peer closes the connection. -// ref. "avalanchego/network/peer#handleVersion" -// ref. https://pkg.go.dev/github.com/ava-labs/avalanchego/network#Network "Dispatch" +// Version is the first outbound message sent to a peer when a connection is +// established to start the p2p handshake. +// +// Peers must respond to a Version message with a PeerList message to allow the +// peer to connect to other peers in the network. +// +// Peers should drop connections to peers with incompatible versions. message Version { + // Network the peer is running on (e.g. local, testnet, mainnet) uint32 network_id = 1; + // Unix timestamp when this Version message was created uint64 my_time = 2; + // IP address of the peer bytes ip_addr = 3; + // IP port of the peer uint32 ip_port = 4; + // Avalanche client version string my_version = 5; + // Timestamp of the IP uint64 my_version_time = 6; + // Signature of the peer IP port pair at a provided timestamp bytes sig = 7; + // Subnets the peer is tracking repeated bytes tracked_subnets = 8; } -// ref. https://pkg.go.dev/github.com/ava-labs/avalanchego/utils/ips#ClaimedIPPort +// ClaimedIpPort contains metadata needed to connect to a peer message ClaimedIpPort { + // X509 certificate of the peer bytes x509_certificate = 1; + // IP address of the peer bytes ip_addr = 2; + // IP port of the peer uint32 ip_port = 3; + // Timestamp of the IP address + port pair uint64 timestamp = 4; + // Signature of the IP port pair at a provided timestamp bytes signature = 5; + // P-Chain transaction that added this peer to the validator set bytes tx_id = 6; } -// Message that contains a list of peer information (IP, certs, etc.) -// in response to "version" message, and sent periodically to a set of -// validators. -// ref. "avalanchego/network/network#Dispatch.runtTimers" +// PeerList contains network-level metadata for a set of validators. +// +// PeerList must be sent in response to an inbound Version message from a +// remote peer that a peer wants to connect to. Once a PeerList is received after +// a version message, the p2p handshake is complete and the connection is +// established. + +// Peers should periodically send PeerList messages to allow peers to +// discover each other.
// -// On receiving "peer_list", the engine starts/updates the tracking information -// of the remote peer. +// PeerListAck should be sent in response to a PeerList. message PeerList { repeated ClaimedIpPort claimed_ip_ports = 1; } -// "peer_ack" is sent in response to a "peer_list" message. The "tx_id" should -// correspond to a "tx_id" in the "peer_list" message. The sender should set -// "timestamp" to be the latest known timestamp of a signed IP corresponding to -// the nodeID of "tx_id". -// -// Upon receipt, the "tx_id" and "timestamp" will determine if the receiptent -// can forgo future gossip of the node's IP to the sender of this message. +// PeerAck acknowledges that a gossiped peer in a PeerList message will be +// tracked by the remote peer. message PeerAck { + // P-Chain transaction that added the acknowledged peer to the validator + // set bytes tx_id = 1; + // Timestamp of the signed ip of the peer uint64 timestamp = 2; } -// Message that responds to a peer_list message containing the AddValidatorTxIDs -// from the peer_list message that we currently have in our validator set. +// PeerListAck is sent in response to PeerList to acknowledge the subset of +// peers that the peer will attempt to connect to. message PeerListAck { reserved 1; // deprecated; used to be tx_ids - repeated PeerAck peer_acks = 2; } +// GetStateSummaryFrontier requests a peer's most recently accepted state +// summary message GetStateSummaryFrontier { + // Chain being requested from bytes chain_id = 1; + // Unique identifier for this request uint32 request_id = 2; + // Timeout (ns) for this request uint64 deadline = 3; } +// StateSummaryFrontier is sent in response to a GetStateSummaryFrontier request message StateSummaryFrontier { + // Chain being responded from bytes chain_id = 1; + // Request id of the original GetStateSummaryFrontier request uint32 request_id = 2; + // The requested state summary bytes summary = 3; } +// GetAcceptedStateSummary requests a set of state summaries at a set of +// block heights message GetAcceptedStateSummary { + // Chain being requested from bytes chain_id = 1; + // Unique identifier for this request uint32 request_id = 2; + // Timeout (ns) for this request uint64 deadline = 3; + // Heights being requested repeated uint64 heights = 4; } +// AcceptedStateSummary is sent in response to GetAcceptedStateSummary message AcceptedStateSummary { + // Chain being responded from bytes chain_id = 1; + // Request id of the original GetAcceptedStateSummary request uint32 request_id = 2; + // State summary ids repeated bytes summary_ids = 3; } +// The consensus engine that should be used when handling a consensus request. enum EngineType { ENGINE_TYPE_UNSPECIFIED = 0; + // Only the X-Chain uses avalanche consensus ENGINE_TYPE_AVALANCHE = 1; ENGINE_TYPE_SNOWMAN = 2; } -// Message to request for the accepted frontier of the "remote" peer. -// For instance, the accepted frontier of X-chain DAG is the set of -// accepted vertices that do not have any accepted descendants (i.e., frontier). +// GetAcceptedFrontier requests the accepted frontier from a peer. // -// During bootstrap, the local node sends out "get_accepted_frontier" to validators -// (see "avalanchego/snow/engine/common/bootstrapper.Startup"). -// And the expected response is "accepted_frontier". -// -// See "snow/engine/common/bootstrapper.go#AcceptedFrontier". +// Peers should respond to GetAcceptedFrontier with AcceptedFrontier.
message GetAcceptedFrontier { + // Chain being requested from bytes chain_id = 1; + // Unique identifier for this request uint32 request_id = 2; + // Timeout (ns) for this request uint64 deadline = 3; + // Consensus type the remote peer should use to handle this message EngineType engine_type = 4; } -// Message that contains the list of accepted frontier in response to -// "get_accepted_frontier". For instance, on receiving "get_accepted_frontier", -// the X-chain engine responds with the accepted frontier of X-chain DAG. +// AcceptedFrontier contains the remote peer's last accepted frontier. // -// See "snow/engine/common/bootstrapper.go#AcceptedFrontier". +// AcceptedFrontier is sent in response to GetAcceptedFrontier. message AcceptedFrontier { reserved 4; // Until Cortina upgrade is activated - + // Chain being responded from bytes chain_id = 1; + // Request id of the original GetAcceptedFrontier request uint32 request_id = 2; + // The id of the last accepted frontier bytes container_id = 3; } -// Message to request for the accepted blocks/vertices of the "remote" peer. -// The local node sends out this message during bootstrap, following "get_accepted_frontier". -// Basically, sending the list of the accepted frontier and expects the response of -// the accepted IDs from the remote peer. +// GetAccepted sends a request with the sender's accepted frontier to a remote +// peer. // -// See "avalanchego/snow/engine/common/bootstrapper.Startup" and "sendGetAccepted". -// See "snow/engine/common/bootstrapper.go#AcceptedFrontier". +// Peers should respond to GetAccepted with an Accepted message. message GetAccepted { + // Chain being requested from bytes chain_id = 1; + // Unique identifier for this message uint32 request_id = 2; + // Timeout (ns) for this request uint64 deadline = 3; + // The sender's accepted frontier repeated bytes container_ids = 4; + // Consensus type to handle this message EngineType engine_type = 5; } -// Message that contains the list of accepted block/vertex IDs in response to -// "get_accepted". For instance, on receiving "get_accepted" that contains -// the sender's accepted frontier IDs, the X-chain engine responds only with -// the accepted vertex IDs of the X-chain DAG. -// -// See "snow/engine/avalanche#GetAccepted" and "SendAccepted". -// See "snow/engine/common/bootstrapper.go#Accepted". +// Accepted is sent in response to GetAccepted. The sending peer responds with +// a subset of container ids from the GetAccepted request that the sending peer +// has accepted. message Accepted { reserved 4; // Until Cortina upgrade is activated - + // Chain being responded from bytes chain_id = 1; + // Request id of the original GetAccepted request uint32 request_id = 2; + // Subset of container ids from the GetAccepted request that the sender has + // accepted repeated bytes container_ids = 3; } -// Message that requests for the ancestors (parents) of the specified container ID. -// The engine bootstrapper sends this message to fetch all accepted containers -// in its transitive path. +// GetAncestors requests the ancestors for a given container. // -// On receiving "get_ancestors", it responds with the ancestors' container bytes -// in "ancestors" message. +// The remote peer should respond with an Ancestors message. 
message GetAncestors { + // Chain being requested from bytes chain_id = 1; + // Unique identifier for this request uint32 request_id = 2; + // Timeout (ns) for this request uint64 deadline = 3; + // Container for which ancestors are being requested bytes container_id = 4; + // Consensus type to handle this message EngineType engine_type = 5; } -// Message that contains the container bytes of the ancestors -// in response to "get_ancestors". +// Ancestors is sent in response to GetAncestors. // -// On receiving "ancestors", the engine parses the containers and queues them -// to be accepted once we've received the entire chain history. +// Ancestors contains a contiguous ancestry of containers for the requested +// container in order of increasing block height. message Ancestors { reserved 4; // Until Cortina upgrade is activated - + // Chain being responded from bytes chain_id = 1; + // Request id of the original GetAncestors request uint32 request_id = 2; + // Ancestry for the requested container repeated bytes containers = 3; } -// Message that requests for the container data. +// Get requests a container from a remote peer. // -// On receiving "get", the engine looks up the container from the storage. -// If the container is found, it sends out the container data in "put" message. +// Remote peers should respond with a Put message if they have the container. message Get { + // Chain being requested from bytes chain_id = 1; + // Unique identifier for this request uint32 request_id = 2; + // Timeout (ns) for this request uint64 deadline = 3; + // Container being requested bytes container_id = 4; + // Consensus type to handle this message EngineType engine_type = 5; } -// Message that contains the container ID and its bytes in response to "get". -// -// On receiving "put", the engine parses the container and tries to issue it to consensus. +// Put is sent in response to Get with the requested block. message Put { + // Chain being responded from bytes chain_id = 1; + // Request id of the original Get request uint32 request_id = 2; + // Requested container bytes container = 3; + // Consensus type to handle this message EngineType engine_type = 4; } -// Message that contains a preferred container ID and its container bytes -// in order to query other peers for their preferences of the container. -// For example, when a new container is issued, the engine sends out -// "push_query" and "pull_query" queries to ask other peers their preferences. -// See "avalanchego/snow/engine/common#SendMixedQuery". +// PushQuery requests the preferences of a remote peer given a container. // -// On receiving the "push_query", the engine parses the incoming container -// and tries to issue the container and all of its parents to the consensus, -// and calls "pull_query" handler to send "chits" for voting. +// Remote peers should respond to a PushQuery with a Chits message message PushQuery { + // Chain being requested from bytes chain_id = 1; + // Unique identifier for this request uint32 request_id = 2; + // Timeout (ns) for this request uint64 deadline = 3; + // Container being gossiped bytes container = 4; + // Consensus type to handle this message EngineType engine_type = 5; + // Requesting peer's last accepted height uint64 requested_height = 6; } -// Message that contains a preferred container ID to query other peers -// for their preferences of the container. 
-// For example, when a new container is issued, the engine sends out -// "push_query" and "pull_query" queries to ask other peers their preferences. -// See "avalanchego/snow/engine/common#SendMixedQuery". +// PullQuery requests the preferences of a remote peer given a container id. +// +// Remote peers should respond to a PullQuery with a Chits message message PullQuery { + // Chain being requested from bytes chain_id = 1; + // Unique identifier for this request uint32 request_id = 2; + // Timeout (ns) for this request uint64 deadline = 3; + // Container id being gossiped bytes container_id = 4; + // Consensus type to handle this message EngineType engine_type = 5; + // Requesting peer's last accepted height uint64 requested_height = 6; } -// Message that contains the votes/preferences of the node. It is sent in -// response to a "push_query" or "pull_query" request. -// -// Upon receiving "chits", the engine will attempt to issue the preferred block -// into consensus. If the referenced block is not locally available, the engine -// will respond with a "get" message to fetch the missing block from the remote -// peer. +// Chits contains the preferences of a peer in response to a PushQuery or +// PullQuery message. message Chits { + // Chain being responded from bytes chain_id = 1; + // Request id of the original PushQuery/PullQuery request uint32 request_id = 2; - // Represents the current preferred block. + // Currently preferred block bytes preferred_id = 3; - // Represents the last accepted block. + // Last accepted block bytes accepted_id = 4; - // Represents the current preferred block at the requested height. + // Currently preferred block at the requested height bytes preferred_id_at_height = 5; } +// AppRequest is a VM-defined request. +// +// Remote peers must respond to AppRequest with corresponding AppResponse message AppRequest { + // Chain being requested from bytes chain_id = 1; + // Unique identifier for this request uint32 request_id = 2; + // Timeout (ns) for this request uint64 deadline = 3; + // Request body bytes app_bytes = 4; } +// AppResponse is a VM-defined response sent in response to AppRequest message AppResponse { + // Chain being responded from bytes chain_id = 1; + // Request id of the original AppRequest uint32 request_id = 2; + // Response body bytes app_bytes = 3; } +// AppGossip is a VM-defined message message AppGossip { + // Chain the message is for bytes chain_id = 1; + // Message body bytes app_bytes = 2; } diff --git a/proto/pb/p2p/p2p.pb.go b/proto/pb/p2p/p2p.pb.go index fba9670e2643..2a2bf8bb3137 100644 --- a/proto/pb/p2p/p2p.pb.go +++ b/proto/pb/p2p/p2p.pb.go @@ -20,12 +20,14 @@ const ( _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) +// The consensus engine that should be used when handling a consensus request. type EngineType int32 const ( EngineType_ENGINE_TYPE_UNSPECIFIED EngineType = 0 - EngineType_ENGINE_TYPE_AVALANCHE EngineType = 1 - EngineType_ENGINE_TYPE_SNOWMAN EngineType = 2 + // Only the X-Chain uses avalanche consensus + EngineType_ENGINE_TYPE_AVALANCHE EngineType = 1 + EngineType_ENGINE_TYPE_SNOWMAN EngineType = 2 ) // Enum value maps for EngineType. @@ -489,18 +491,17 @@ func (*Message_AppGossip) isMessage_Message() {} func (*Message_PeerListAck) isMessage_Message() {} -// Message that a node sends to its peers in order to periodically check -// responsivness and report the local node's uptime measurements of the peer. +// Ping reports a peer's perceived uptime percentage. 
// -// On receiving a "ping", the peer should respond with a "pong". +// Peers should respond to Ping with a Pong. type Ping struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - // uptime is the primary network uptime percentage. + // Uptime percentage on the primary network [0, 100] Uptime uint32 `protobuf:"varint,1,opt,name=uptime,proto3" json:"uptime,omitempty"` - // subnet_uptimes contains subnet uptime percentages. + // Uptime percentage on subnets SubnetUptimes []*SubnetUptime `protobuf:"bytes,2,rep,name=subnet_uptimes,json=subnetUptimes,proto3" json:"subnet_uptimes,omitempty"` } @@ -550,15 +551,16 @@ func (x *Ping) GetSubnetUptimes() []*SubnetUptime { return nil } -// Contains subnet id and the related observed subnet uptime of the message -// receiver (remote peer). +// SubnetUptime is a descriptor for a peer's perceived uptime on a subnet. type SubnetUptime struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields + // Subnet the peer is validating SubnetId []byte `protobuf:"bytes,1,opt,name=subnet_id,json=subnetId,proto3" json:"subnet_id,omitempty"` - Uptime uint32 `protobuf:"varint,2,opt,name=uptime,proto3" json:"uptime,omitempty"` + // Uptime percentage on the subnet [0, 100] + Uptime uint32 `protobuf:"varint,2,opt,name=uptime,proto3" json:"uptime,omitempty"` } func (x *SubnetUptime) Reset() { @@ -607,18 +609,18 @@ func (x *SubnetUptime) GetUptime() uint32 { return 0 } -// Contains the uptime percentage of the message receiver (remote peer) -// from the sender's point of view, in response to "ping" message. -// Uptimes are expected to be provided as integers ranging in [0, 100]. +// Pong is sent in response to a Ping with the perceived uptime of the +// peer. type Pong struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - // Deprecated: remove all these fields in the future, but keep the message. - // uptime is the primary network uptime percentage. + // Deprecated: uptime is now sent in Ping + // Uptime percentage on the primary network [0, 100] Uptime uint32 `protobuf:"varint,1,opt,name=uptime,proto3" json:"uptime,omitempty"` - // subnet_uptimes contains subnet uptime percentages. + // Deprecated: uptime is now sent in Ping + // Uptime percentage on subnets SubnetUptimes []*SubnetUptime `protobuf:"bytes,2,rep,name=subnet_uptimes,json=subnetUptimes,proto3" json:"subnet_uptimes,omitempty"` } @@ -668,25 +670,33 @@ func (x *Pong) GetSubnetUptimes() []*SubnetUptime { return nil } -// The first outbound message that the local node sends to its remote peer -// when the connection is established. In order for the local node to be -// tracked as a valid peer by the remote peer, the fields must be valid. -// For instance, the network ID must be matched and timestamp should be in-sync. -// Otherwise, the remote peer closes the connection. -// ref. "avalanchego/network/peer#handleVersion" -// ref. https://pkg.go.dev/github.com/ava-labs/avalanchego/network#Network "Dispatch" +// Version is the first outbound message sent to a peer when a connection is +// established to start the p2p handshake. +// +// Peers must respond to a Version message with a PeerList message to allow the +// peer to connect to other peers in the network. +// +// Peers should drop connections to peers with incompatible versions. 
type Version struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - NetworkId uint32 `protobuf:"varint,1,opt,name=network_id,json=networkId,proto3" json:"network_id,omitempty"` - MyTime uint64 `protobuf:"varint,2,opt,name=my_time,json=myTime,proto3" json:"my_time,omitempty"` - IpAddr []byte `protobuf:"bytes,3,opt,name=ip_addr,json=ipAddr,proto3" json:"ip_addr,omitempty"` - IpPort uint32 `protobuf:"varint,4,opt,name=ip_port,json=ipPort,proto3" json:"ip_port,omitempty"` - MyVersion string `protobuf:"bytes,5,opt,name=my_version,json=myVersion,proto3" json:"my_version,omitempty"` - MyVersionTime uint64 `protobuf:"varint,6,opt,name=my_version_time,json=myVersionTime,proto3" json:"my_version_time,omitempty"` - Sig []byte `protobuf:"bytes,7,opt,name=sig,proto3" json:"sig,omitempty"` + // Network the peer is running on (e.g local, testnet, mainnet) + NetworkId uint32 `protobuf:"varint,1,opt,name=network_id,json=networkId,proto3" json:"network_id,omitempty"` + // Unix timestamp when this Version message was created + MyTime uint64 `protobuf:"varint,2,opt,name=my_time,json=myTime,proto3" json:"my_time,omitempty"` + // IP address of the peer + IpAddr []byte `protobuf:"bytes,3,opt,name=ip_addr,json=ipAddr,proto3" json:"ip_addr,omitempty"` + // IP port of the peer + IpPort uint32 `protobuf:"varint,4,opt,name=ip_port,json=ipPort,proto3" json:"ip_port,omitempty"` + // Avalanche client version + MyVersion string `protobuf:"bytes,5,opt,name=my_version,json=myVersion,proto3" json:"my_version,omitempty"` + // Timestamp of the IP + MyVersionTime uint64 `protobuf:"varint,6,opt,name=my_version_time,json=myVersionTime,proto3" json:"my_version_time,omitempty"` + // Signature of the peer IP port pair at a provided timestamp + Sig []byte `protobuf:"bytes,7,opt,name=sig,proto3" json:"sig,omitempty"` + // Subnets the peer is tracking TrackedSubnets [][]byte `protobuf:"bytes,8,rep,name=tracked_subnets,json=trackedSubnets,proto3" json:"tracked_subnets,omitempty"` } @@ -778,18 +788,24 @@ func (x *Version) GetTrackedSubnets() [][]byte { return nil } -// ref. 
https://pkg.go.dev/github.com/ava-labs/avalanchego/utils/ips#ClaimedIPPort +// ClaimedIpPort contains metadata needed to connect to a peer type ClaimedIpPort struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields + // X509 certificate of the peer X509Certificate []byte `protobuf:"bytes,1,opt,name=x509_certificate,json=x509Certificate,proto3" json:"x509_certificate,omitempty"` - IpAddr []byte `protobuf:"bytes,2,opt,name=ip_addr,json=ipAddr,proto3" json:"ip_addr,omitempty"` - IpPort uint32 `protobuf:"varint,3,opt,name=ip_port,json=ipPort,proto3" json:"ip_port,omitempty"` - Timestamp uint64 `protobuf:"varint,4,opt,name=timestamp,proto3" json:"timestamp,omitempty"` - Signature []byte `protobuf:"bytes,5,opt,name=signature,proto3" json:"signature,omitempty"` - TxId []byte `protobuf:"bytes,6,opt,name=tx_id,json=txId,proto3" json:"tx_id,omitempty"` + // IP address of the peer + IpAddr []byte `protobuf:"bytes,2,opt,name=ip_addr,json=ipAddr,proto3" json:"ip_addr,omitempty"` + // IP port of the peer + IpPort uint32 `protobuf:"varint,3,opt,name=ip_port,json=ipPort,proto3" json:"ip_port,omitempty"` + // Timestamp of the IP address + port pair + Timestamp uint64 `protobuf:"varint,4,opt,name=timestamp,proto3" json:"timestamp,omitempty"` + // Signature of the IP port pair at a provided timestamp + Signature []byte `protobuf:"bytes,5,opt,name=signature,proto3" json:"signature,omitempty"` + // P-Chain transaction that added this peer to the validator set + TxId []byte `protobuf:"bytes,6,opt,name=tx_id,json=txId,proto3" json:"tx_id,omitempty"` } func (x *ClaimedIpPort) Reset() { @@ -866,13 +882,10 @@ func (x *ClaimedIpPort) GetTxId() []byte { return nil } -// Message that contains a list of peer information (IP, certs, etc.) -// in response to "version" message, and sent periodically to a set of -// validators. -// ref. "avalanchego/network/network#Dispatch.runtTimers" +// Peers should periodically send PeerList messages to allow peers to +// discover each other. // -// On receiving "peer_list", the engine starts/updates the tracking information -// of the remote peer. +// PeerListAck should be sent in response to a PeerList. type PeerList struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -920,19 +933,17 @@ func (x *PeerList) GetClaimedIpPorts() []*ClaimedIpPort { return nil } -// "peer_ack" is sent in response to a "peer_list" message. The "tx_id" should -// correspond to a "tx_id" in the "peer_list" message. The sender should set -// "timestamp" to be the latest known timestamp of a signed IP corresponding to -// the nodeID of "tx_id". -// -// Upon receipt, the "tx_id" and "timestamp" will determine if the receiptent -// can forgo future gossip of the node's IP to the sender of this message. +// PeerAck acknowledges that a gossiped peer in a PeerList message will be +// tracked by the remote peer. 
type PeerAck struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - TxId []byte `protobuf:"bytes,1,opt,name=tx_id,json=txId,proto3" json:"tx_id,omitempty"` + // P-Chain transaction that added the acknowledged peer to the validator + // set + TxId []byte `protobuf:"bytes,1,opt,name=tx_id,json=txId,proto3" json:"tx_id,omitempty"` + // Timestamp of the signed ip of the peer Timestamp uint64 `protobuf:"varint,2,opt,name=timestamp,proto3" json:"timestamp,omitempty"` } @@ -982,8 +993,8 @@ func (x *PeerAck) GetTimestamp() uint64 { return 0 } -// Message that responds to a peer_list message containing the AddValidatorTxIDs -// from the peer_list message that we currently have in our validator set. +// PeerListAck is sent in response to PeerList to acknowledge the subset of +// peers that the peer will attempt to connect to. type PeerListAck struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -1031,14 +1042,19 @@ func (x *PeerListAck) GetPeerAcks() []*PeerAck { return nil } +// GetStateSummaryFrontier requests a peer's most recently accepted state +// summary type GetStateSummaryFrontier struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - ChainId []byte `protobuf:"bytes,1,opt,name=chain_id,json=chainId,proto3" json:"chain_id,omitempty"` + // Chain being requested from + ChainId []byte `protobuf:"bytes,1,opt,name=chain_id,json=chainId,proto3" json:"chain_id,omitempty"` + // Unique identifier for this request RequestId uint32 `protobuf:"varint,2,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` - Deadline uint64 `protobuf:"varint,3,opt,name=deadline,proto3" json:"deadline,omitempty"` + // Timeout (ns) for this request + Deadline uint64 `protobuf:"varint,3,opt,name=deadline,proto3" json:"deadline,omitempty"` } func (x *GetStateSummaryFrontier) Reset() { @@ -1094,14 +1110,18 @@ func (x *GetStateSummaryFrontier) GetDeadline() uint64 { return 0 } +// StateSummaryFrontier is sent in response to a GetStateSummaryFrontier request type StateSummaryFrontier struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - ChainId []byte `protobuf:"bytes,1,opt,name=chain_id,json=chainId,proto3" json:"chain_id,omitempty"` + // Chain being responded from + ChainId []byte `protobuf:"bytes,1,opt,name=chain_id,json=chainId,proto3" json:"chain_id,omitempty"` + // Request id of the original GetStateSummaryFrontier request RequestId uint32 `protobuf:"varint,2,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` - Summary []byte `protobuf:"bytes,3,opt,name=summary,proto3" json:"summary,omitempty"` + // The requested state summary + Summary []byte `protobuf:"bytes,3,opt,name=summary,proto3" json:"summary,omitempty"` } func (x *StateSummaryFrontier) Reset() { @@ -1157,15 +1177,21 @@ func (x *StateSummaryFrontier) GetSummary() []byte { return nil } +// GetAcceptedStateSummary requests a set of state summaries at a set of +// block heights type GetAcceptedStateSummary struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - ChainId []byte `protobuf:"bytes,1,opt,name=chain_id,json=chainId,proto3" json:"chain_id,omitempty"` - RequestId uint32 `protobuf:"varint,2,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` - Deadline uint64 `protobuf:"varint,3,opt,name=deadline,proto3" json:"deadline,omitempty"` - Heights []uint64 
`protobuf:"varint,4,rep,packed,name=heights,proto3" json:"heights,omitempty"` + // Chain being requested from + ChainId []byte `protobuf:"bytes,1,opt,name=chain_id,json=chainId,proto3" json:"chain_id,omitempty"` + // Unique identifier for this request + RequestId uint32 `protobuf:"varint,2,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` + // Timeout (ns) for this request + Deadline uint64 `protobuf:"varint,3,opt,name=deadline,proto3" json:"deadline,omitempty"` + // Heights being requested + Heights []uint64 `protobuf:"varint,4,rep,packed,name=heights,proto3" json:"heights,omitempty"` } func (x *GetAcceptedStateSummary) Reset() { @@ -1228,13 +1254,17 @@ func (x *GetAcceptedStateSummary) GetHeights() []uint64 { return nil } +// AcceptedStateSummary is sent in response to GetAcceptedStateSummary type AcceptedStateSummary struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - ChainId []byte `protobuf:"bytes,1,opt,name=chain_id,json=chainId,proto3" json:"chain_id,omitempty"` - RequestId uint32 `protobuf:"varint,2,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` + // Chain being responded from + ChainId []byte `protobuf:"bytes,1,opt,name=chain_id,json=chainId,proto3" json:"chain_id,omitempty"` + // Request id of the original GetAcceptedStateSummary request + RequestId uint32 `protobuf:"varint,2,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` + // State summary ids SummaryIds [][]byte `protobuf:"bytes,3,rep,name=summary_ids,json=summaryIds,proto3" json:"summary_ids,omitempty"` } @@ -1291,23 +1321,21 @@ func (x *AcceptedStateSummary) GetSummaryIds() [][]byte { return nil } -// Message to request for the accepted frontier of the "remote" peer. -// For instance, the accepted frontier of X-chain DAG is the set of -// accepted vertices that do not have any accepted descendants (i.e., frontier). +// GetAcceptedFrontier requests the accepted frontier from a peer. // -// During bootstrap, the local node sends out "get_accepted_frontier" to validators -// (see "avalanchego/snow/engine/common/bootstrapper.Startup"). -// And the expected response is "accepted_frontier". -// -// See "snow/engine/common/bootstrapper.go#AcceptedFrontier". +// Peers should respond to GetAcceptedFrontier with AcceptedFrontier.
type GetAcceptedFrontier struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - ChainId []byte `protobuf:"bytes,1,opt,name=chain_id,json=chainId,proto3" json:"chain_id,omitempty"` - RequestId uint32 `protobuf:"varint,2,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` - Deadline uint64 `protobuf:"varint,3,opt,name=deadline,proto3" json:"deadline,omitempty"` + // Chain being requested from + ChainId []byte `protobuf:"bytes,1,opt,name=chain_id,json=chainId,proto3" json:"chain_id,omitempty"` + // Unique identifier for this request + RequestId uint32 `protobuf:"varint,2,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` + // Timeout (ns) for this request + Deadline uint64 `protobuf:"varint,3,opt,name=deadline,proto3" json:"deadline,omitempty"` + // Consensus type the remote peer should use to handle this message EngineType EngineType `protobuf:"varint,4,opt,name=engine_type,json=engineType,proto3,enum=p2p.EngineType" json:"engine_type,omitempty"` } @@ -1371,18 +1399,19 @@ func (x *GetAcceptedFrontier) GetEngineType() EngineType { return EngineType_ENGINE_TYPE_UNSPECIFIED } -// Message that contains the list of accepted frontier in response to -// "get_accepted_frontier". For instance, on receiving "get_accepted_frontier", -// the X-chain engine responds with the accepted frontier of X-chain DAG. +// AcceptedFrontier contains the remote peer's last accepted frontier. // -// See "snow/engine/common/bootstrapper.go#AcceptedFrontier". +// AcceptedFrontier is sent in response to GetAcceptedFrontier. type AcceptedFrontier struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - ChainId []byte `protobuf:"bytes,1,opt,name=chain_id,json=chainId,proto3" json:"chain_id,omitempty"` - RequestId uint32 `protobuf:"varint,2,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` + // Chain being responded from + ChainId []byte `protobuf:"bytes,1,opt,name=chain_id,json=chainId,proto3" json:"chain_id,omitempty"` + // Request id of the original GetAcceptedFrontier request + RequestId uint32 `protobuf:"varint,2,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` + // The id of the last accepted frontier ContainerId []byte `protobuf:"bytes,3,opt,name=container_id,json=containerId,proto3" json:"container_id,omitempty"` } @@ -1439,23 +1468,25 @@ func (x *AcceptedFrontier) GetContainerId() []byte { return nil } -// Message to request for the accepted blocks/vertices of the "remote" peer. -// The local node sends out this message during bootstrap, following "get_accepted_frontier". -// Basically, sending the list of the accepted frontier and expects the response of -// the accepted IDs from the remote peer. +// GetAccepted sends a request with the sender's accepted frontier to a remote +// peer. // -// See "avalanchego/snow/engine/common/bootstrapper.Startup" and "sendGetAccepted". -// See "snow/engine/common/bootstrapper.go#AcceptedFrontier". +// Peers should respond to GetAccepted with an Accepted message. 
type GetAccepted struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - ChainId []byte `protobuf:"bytes,1,opt,name=chain_id,json=chainId,proto3" json:"chain_id,omitempty"` - RequestId uint32 `protobuf:"varint,2,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` - Deadline uint64 `protobuf:"varint,3,opt,name=deadline,proto3" json:"deadline,omitempty"` - ContainerIds [][]byte `protobuf:"bytes,4,rep,name=container_ids,json=containerIds,proto3" json:"container_ids,omitempty"` - EngineType EngineType `protobuf:"varint,5,opt,name=engine_type,json=engineType,proto3,enum=p2p.EngineType" json:"engine_type,omitempty"` + // Chain being requested from + ChainId []byte `protobuf:"bytes,1,opt,name=chain_id,json=chainId,proto3" json:"chain_id,omitempty"` + // Unique identifier for this message + RequestId uint32 `protobuf:"varint,2,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` + // Timeout (ns) for this request + Deadline uint64 `protobuf:"varint,3,opt,name=deadline,proto3" json:"deadline,omitempty"` + // The sender's accepted frontier + ContainerIds [][]byte `protobuf:"bytes,4,rep,name=container_ids,json=containerIds,proto3" json:"container_ids,omitempty"` + // Consensus type to handle this message + EngineType EngineType `protobuf:"varint,5,opt,name=engine_type,json=engineType,proto3,enum=p2p.EngineType" json:"engine_type,omitempty"` } func (x *GetAccepted) Reset() { @@ -1525,20 +1556,20 @@ func (x *GetAccepted) GetEngineType() EngineType { return EngineType_ENGINE_TYPE_UNSPECIFIED } -// Message that contains the list of accepted block/vertex IDs in response to -// "get_accepted". For instance, on receiving "get_accepted" that contains -// the sender's accepted frontier IDs, the X-chain engine responds only with -// the accepted vertex IDs of the X-chain DAG. -// -// See "snow/engine/avalanche#GetAccepted" and "SendAccepted". -// See "snow/engine/common/bootstrapper.go#Accepted". +// Accepted is sent in response to GetAccepted. The sending peer responds with +// a subset of container ids from the GetAccepted request that the sending peer +// has accepted. type Accepted struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - ChainId []byte `protobuf:"bytes,1,opt,name=chain_id,json=chainId,proto3" json:"chain_id,omitempty"` - RequestId uint32 `protobuf:"varint,2,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` + // Chain being responded from + ChainId []byte `protobuf:"bytes,1,opt,name=chain_id,json=chainId,proto3" json:"chain_id,omitempty"` + // Request id of the original GetAccepted request + RequestId uint32 `protobuf:"varint,2,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` + // Subset of container ids from the GetAccepted request that the sender has + // accepted ContainerIds [][]byte `protobuf:"bytes,3,rep,name=container_ids,json=containerIds,proto3" json:"container_ids,omitempty"` } @@ -1595,22 +1626,24 @@ func (x *Accepted) GetContainerIds() [][]byte { return nil } -// Message that requests for the ancestors (parents) of the specified container ID. -// The engine bootstrapper sends this message to fetch all accepted containers -// in its transitive path. +// GetAncestors requests the ancestors for a given container. // -// On receiving "get_ancestors", it responds with the ancestors' container bytes -// in "ancestors" message. +// The remote peer should respond with an Ancestors message. 
type GetAncestors struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - ChainId []byte `protobuf:"bytes,1,opt,name=chain_id,json=chainId,proto3" json:"chain_id,omitempty"` - RequestId uint32 `protobuf:"varint,2,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` - Deadline uint64 `protobuf:"varint,3,opt,name=deadline,proto3" json:"deadline,omitempty"` - ContainerId []byte `protobuf:"bytes,4,opt,name=container_id,json=containerId,proto3" json:"container_id,omitempty"` - EngineType EngineType `protobuf:"varint,5,opt,name=engine_type,json=engineType,proto3,enum=p2p.EngineType" json:"engine_type,omitempty"` + // Chain being requested from + ChainId []byte `protobuf:"bytes,1,opt,name=chain_id,json=chainId,proto3" json:"chain_id,omitempty"` + // Unique identifier for this request + RequestId uint32 `protobuf:"varint,2,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` + // Timeout (ns) for this request + Deadline uint64 `protobuf:"varint,3,opt,name=deadline,proto3" json:"deadline,omitempty"` + // Container for which ancestors are being requested + ContainerId []byte `protobuf:"bytes,4,opt,name=container_id,json=containerId,proto3" json:"container_id,omitempty"` + // Consensus type to handle this message + EngineType EngineType `protobuf:"varint,5,opt,name=engine_type,json=engineType,proto3,enum=p2p.EngineType" json:"engine_type,omitempty"` } func (x *GetAncestors) Reset() { @@ -1680,18 +1713,20 @@ func (x *GetAncestors) GetEngineType() EngineType { return EngineType_ENGINE_TYPE_UNSPECIFIED } -// Message that contains the container bytes of the ancestors -// in response to "get_ancestors". +// Ancestors is sent in response to GetAncestors. // -// On receiving "ancestors", the engine parses the containers and queues them -// to be accepted once we've received the entire chain history. +// Ancestors contains a contiguous ancestry of containers for the requested +// container in order of increasing block height. type Ancestors struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - ChainId []byte `protobuf:"bytes,1,opt,name=chain_id,json=chainId,proto3" json:"chain_id,omitempty"` - RequestId uint32 `protobuf:"varint,2,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` + // Chain being responded from + ChainId []byte `protobuf:"bytes,1,opt,name=chain_id,json=chainId,proto3" json:"chain_id,omitempty"` + // Request id of the original GetAncestors request + RequestId uint32 `protobuf:"varint,2,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` + // Ancestry for the requested container Containers [][]byte `protobuf:"bytes,3,rep,name=containers,proto3" json:"containers,omitempty"` } @@ -1748,20 +1783,24 @@ func (x *Ancestors) GetContainers() [][]byte { return nil } -// Message that requests for the container data. +// Get requests a container from a remote peer. // -// On receiving "get", the engine looks up the container from the storage. -// If the container is found, it sends out the container data in "put" message. +// Remote peers should respond with a Put message if they have the container. 
type Get struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - ChainId []byte `protobuf:"bytes,1,opt,name=chain_id,json=chainId,proto3" json:"chain_id,omitempty"` - RequestId uint32 `protobuf:"varint,2,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` - Deadline uint64 `protobuf:"varint,3,opt,name=deadline,proto3" json:"deadline,omitempty"` - ContainerId []byte `protobuf:"bytes,4,opt,name=container_id,json=containerId,proto3" json:"container_id,omitempty"` - EngineType EngineType `protobuf:"varint,5,opt,name=engine_type,json=engineType,proto3,enum=p2p.EngineType" json:"engine_type,omitempty"` + // Chain being requested from + ChainId []byte `protobuf:"bytes,1,opt,name=chain_id,json=chainId,proto3" json:"chain_id,omitempty"` + // Unique identifier for this request + RequestId uint32 `protobuf:"varint,2,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` + // Timeout (ns) for this request + Deadline uint64 `protobuf:"varint,3,opt,name=deadline,proto3" json:"deadline,omitempty"` + // Container being requested + ContainerId []byte `protobuf:"bytes,4,opt,name=container_id,json=containerId,proto3" json:"container_id,omitempty"` + // Consensus type to handle this message + EngineType EngineType `protobuf:"varint,5,opt,name=engine_type,json=engineType,proto3,enum=p2p.EngineType" json:"engine_type,omitempty"` } func (x *Get) Reset() { @@ -1831,17 +1870,19 @@ func (x *Get) GetEngineType() EngineType { return EngineType_ENGINE_TYPE_UNSPECIFIED } -// Message that contains the container ID and its bytes in response to "get". -// -// On receiving "put", the engine parses the container and tries to issue it to consensus. +// Put is sent in response to Get with the requested block. type Put struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - ChainId []byte `protobuf:"bytes,1,opt,name=chain_id,json=chainId,proto3" json:"chain_id,omitempty"` - RequestId uint32 `protobuf:"varint,2,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` - Container []byte `protobuf:"bytes,3,opt,name=container,proto3" json:"container,omitempty"` + // Chain being responded from + ChainId []byte `protobuf:"bytes,1,opt,name=chain_id,json=chainId,proto3" json:"chain_id,omitempty"` + // Request id of the original Get request + RequestId uint32 `protobuf:"varint,2,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` + // Requested container + Container []byte `protobuf:"bytes,3,opt,name=container,proto3" json:"container,omitempty"` + // Consensus type to handle this message EngineType EngineType `protobuf:"varint,4,opt,name=engine_type,json=engineType,proto3,enum=p2p.EngineType" json:"engine_type,omitempty"` } @@ -1905,26 +1946,26 @@ func (x *Put) GetEngineType() EngineType { return EngineType_ENGINE_TYPE_UNSPECIFIED } -// Message that contains a preferred container ID and its container bytes -// in order to query other peers for their preferences of the container. -// For example, when a new container is issued, the engine sends out -// "push_query" and "pull_query" queries to ask other peers their preferences. -// See "avalanchego/snow/engine/common#SendMixedQuery". +// PushQuery requests the preferences of a remote peer given a container. 
// -// On receiving the "push_query", the engine parses the incoming container -// and tries to issue the container and all of its parents to the consensus, -// and calls "pull_query" handler to send "chits" for voting. +// Remote peers should respond to a PushQuery with a Chits message type PushQuery struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - ChainId []byte `protobuf:"bytes,1,opt,name=chain_id,json=chainId,proto3" json:"chain_id,omitempty"` - RequestId uint32 `protobuf:"varint,2,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` - Deadline uint64 `protobuf:"varint,3,opt,name=deadline,proto3" json:"deadline,omitempty"` - Container []byte `protobuf:"bytes,4,opt,name=container,proto3" json:"container,omitempty"` - EngineType EngineType `protobuf:"varint,5,opt,name=engine_type,json=engineType,proto3,enum=p2p.EngineType" json:"engine_type,omitempty"` - RequestedHeight uint64 `protobuf:"varint,6,opt,name=requested_height,json=requestedHeight,proto3" json:"requested_height,omitempty"` + // Chain being requested from + ChainId []byte `protobuf:"bytes,1,opt,name=chain_id,json=chainId,proto3" json:"chain_id,omitempty"` + // Unique identifier for this request + RequestId uint32 `protobuf:"varint,2,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` + // Timeout (ns) for this request + Deadline uint64 `protobuf:"varint,3,opt,name=deadline,proto3" json:"deadline,omitempty"` + // Container being gossiped + Container []byte `protobuf:"bytes,4,opt,name=container,proto3" json:"container,omitempty"` + // Consensus type to handle this message + EngineType EngineType `protobuf:"varint,5,opt,name=engine_type,json=engineType,proto3,enum=p2p.EngineType" json:"engine_type,omitempty"` + // Requesting peer's last accepted height + RequestedHeight uint64 `protobuf:"varint,6,opt,name=requested_height,json=requestedHeight,proto3" json:"requested_height,omitempty"` } func (x *PushQuery) Reset() { @@ -2001,22 +2042,26 @@ func (x *PushQuery) GetRequestedHeight() uint64 { return 0 } -// Message that contains a preferred container ID to query other peers -// for their preferences of the container. -// For example, when a new container is issued, the engine sends out -// "push_query" and "pull_query" queries to ask other peers their preferences. -// See "avalanchego/snow/engine/common#SendMixedQuery". +// PullQuery requests the preferences of a remote peer given a container id. 
+// +// Remote peers should respond to a PullQuery with a Chits message type PullQuery struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - ChainId []byte `protobuf:"bytes,1,opt,name=chain_id,json=chainId,proto3" json:"chain_id,omitempty"` - RequestId uint32 `protobuf:"varint,2,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` - Deadline uint64 `protobuf:"varint,3,opt,name=deadline,proto3" json:"deadline,omitempty"` - ContainerId []byte `protobuf:"bytes,4,opt,name=container_id,json=containerId,proto3" json:"container_id,omitempty"` - EngineType EngineType `protobuf:"varint,5,opt,name=engine_type,json=engineType,proto3,enum=p2p.EngineType" json:"engine_type,omitempty"` - RequestedHeight uint64 `protobuf:"varint,6,opt,name=requested_height,json=requestedHeight,proto3" json:"requested_height,omitempty"` + // Chain being requested from + ChainId []byte `protobuf:"bytes,1,opt,name=chain_id,json=chainId,proto3" json:"chain_id,omitempty"` + // Unique identifier for this request + RequestId uint32 `protobuf:"varint,2,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` + // Timeout (ns) for this request + Deadline uint64 `protobuf:"varint,3,opt,name=deadline,proto3" json:"deadline,omitempty"` + // Container id being gossiped + ContainerId []byte `protobuf:"bytes,4,opt,name=container_id,json=containerId,proto3" json:"container_id,omitempty"` + // Consensus type to handle this message + EngineType EngineType `protobuf:"varint,5,opt,name=engine_type,json=engineType,proto3,enum=p2p.EngineType" json:"engine_type,omitempty"` + // Requesting peer's last accepted height + RequestedHeight uint64 `protobuf:"varint,6,opt,name=requested_height,json=requestedHeight,proto3" json:"requested_height,omitempty"` } func (x *PullQuery) Reset() { @@ -2093,25 +2138,22 @@ func (x *PullQuery) GetRequestedHeight() uint64 { return 0 } -// Message that contains the votes/preferences of the node. It is sent in -// response to a "push_query" or "pull_query" request. -// -// Upon receiving "chits", the engine will attempt to issue the preferred block -// into consensus. If the referenced block is not locally available, the engine -// will respond with a "get" message to fetch the missing block from the remote -// peer. +// Chits contains the preferences of a peer in response to a PushQuery or +// PullQuery message. type Chits struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - ChainId []byte `protobuf:"bytes,1,opt,name=chain_id,json=chainId,proto3" json:"chain_id,omitempty"` + // Chain being responded from + ChainId []byte `protobuf:"bytes,1,opt,name=chain_id,json=chainId,proto3" json:"chain_id,omitempty"` + // Request id of the original PushQuery/PullQuery request RequestId uint32 `protobuf:"varint,2,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` - // Represents the current preferred block. + // Currently preferred block PreferredId []byte `protobuf:"bytes,3,opt,name=preferred_id,json=preferredId,proto3" json:"preferred_id,omitempty"` - // Represents the last accepted block. + // Last accepted block AcceptedId []byte `protobuf:"bytes,4,opt,name=accepted_id,json=acceptedId,proto3" json:"accepted_id,omitempty"` - // Represents the current preferred block at the requested height. 
+ // Currently preferred block at the requested height PreferredIdAtHeight []byte `protobuf:"bytes,5,opt,name=preferred_id_at_height,json=preferredIdAtHeight,proto3" json:"preferred_id_at_height,omitempty"` } @@ -2182,15 +2224,22 @@ func (x *Chits) GetPreferredIdAtHeight() []byte { return nil } +// AppRequest is a VM-defined request. +// +// Remote peers must respond to AppRequest with corresponding AppResponse type AppRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - ChainId []byte `protobuf:"bytes,1,opt,name=chain_id,json=chainId,proto3" json:"chain_id,omitempty"` + // Chain being requested from + ChainId []byte `protobuf:"bytes,1,opt,name=chain_id,json=chainId,proto3" json:"chain_id,omitempty"` + // Unique identifier for this request RequestId uint32 `protobuf:"varint,2,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` - Deadline uint64 `protobuf:"varint,3,opt,name=deadline,proto3" json:"deadline,omitempty"` - AppBytes []byte `protobuf:"bytes,4,opt,name=app_bytes,json=appBytes,proto3" json:"app_bytes,omitempty"` + // Timeout (ns) for this request + Deadline uint64 `protobuf:"varint,3,opt,name=deadline,proto3" json:"deadline,omitempty"` + // Request body + AppBytes []byte `protobuf:"bytes,4,opt,name=app_bytes,json=appBytes,proto3" json:"app_bytes,omitempty"` } func (x *AppRequest) Reset() { @@ -2253,14 +2302,18 @@ func (x *AppRequest) GetAppBytes() []byte { return nil } +// AppResponse is a VM-defined response sent in response to AppRequest type AppResponse struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - ChainId []byte `protobuf:"bytes,1,opt,name=chain_id,json=chainId,proto3" json:"chain_id,omitempty"` + // Chain being responded from + ChainId []byte `protobuf:"bytes,1,opt,name=chain_id,json=chainId,proto3" json:"chain_id,omitempty"` + // Request id of the original AppRequest RequestId uint32 `protobuf:"varint,2,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` - AppBytes []byte `protobuf:"bytes,3,opt,name=app_bytes,json=appBytes,proto3" json:"app_bytes,omitempty"` + // Response body + AppBytes []byte `protobuf:"bytes,3,opt,name=app_bytes,json=appBytes,proto3" json:"app_bytes,omitempty"` } func (x *AppResponse) Reset() { @@ -2316,12 +2369,15 @@ func (x *AppResponse) GetAppBytes() []byte { return nil } +// AppGossip is a VM-defined message type AppGossip struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - ChainId []byte `protobuf:"bytes,1,opt,name=chain_id,json=chainId,proto3" json:"chain_id,omitempty"` + // Chain the message is for + ChainId []byte `protobuf:"bytes,1,opt,name=chain_id,json=chainId,proto3" json:"chain_id,omitempty"` + // Message body AppBytes []byte `protobuf:"bytes,2,opt,name=app_bytes,json=appBytes,proto3" json:"app_bytes,omitempty"` } diff --git a/proto/pb/sync/sync.pb.go b/proto/pb/sync/sync.pb.go index dc0368bbf70f..92cd3d88351e 100644 --- a/proto/pb/sync/sync.pb.go +++ b/proto/pb/sync/sync.pb.go @@ -1197,7 +1197,7 @@ type ProofNode struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - Key *Path `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"` + Key *Key `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"` ValueOrHash *MaybeBytes `protobuf:"bytes,2,opt,name=value_or_hash,json=valueOrHash,proto3" json:"value_or_hash,omitempty"` Children map[uint32][]byte `protobuf:"bytes,3,rep,name=children,proto3" 
json:"children,omitempty" protobuf_key:"varint,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` } @@ -1234,7 +1234,7 @@ func (*ProofNode) Descriptor() ([]byte, []int) { return file_sync_sync_proto_rawDescGZIP(), []int{18} } -func (x *ProofNode) GetKey() *Path { +func (x *ProofNode) GetKey() *Key { if x != nil { return x.Key } @@ -1310,7 +1310,7 @@ func (x *KeyChange) GetValue() *MaybeBytes { return nil } -type Path struct { +type Key struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields @@ -1319,8 +1319,8 @@ type Path struct { Value []byte `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"` } -func (x *Path) Reset() { - *x = Path{} +func (x *Key) Reset() { + *x = Key{} if protoimpl.UnsafeEnabled { mi := &file_sync_sync_proto_msgTypes[20] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -1328,13 +1328,13 @@ func (x *Path) Reset() { } } -func (x *Path) String() string { +func (x *Key) String() string { return protoimpl.X.MessageStringOf(x) } -func (*Path) ProtoMessage() {} +func (*Key) ProtoMessage() {} -func (x *Path) ProtoReflect() protoreflect.Message { +func (x *Key) ProtoReflect() protoreflect.Message { mi := &file_sync_sync_proto_msgTypes[20] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -1346,19 +1346,19 @@ func (x *Path) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use Path.ProtoReflect.Descriptor instead. -func (*Path) Descriptor() ([]byte, []int) { +// Deprecated: Use Key.ProtoReflect.Descriptor instead. +func (*Key) Descriptor() ([]byte, []int) { return file_sync_sync_proto_rawDescGZIP(), []int{20} } -func (x *Path) GetLength() uint64 { +func (x *Key) GetLength() uint64 { if x != nil { return x.Length } return 0 } -func (x *Path) GetValue() []byte { +func (x *Key) GetValue() []byte { if x != nil { return x.Value } @@ -1638,72 +1638,72 @@ var file_sync_sync_proto_rawDesc = []byte{ 0x12, 0x2d, 0x0a, 0x0a, 0x6b, 0x65, 0x79, 0x5f, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x73, 0x79, 0x6e, 0x63, 0x2e, 0x4b, 0x65, 0x79, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x52, 0x09, 0x6b, 0x65, 0x79, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x22, - 0xd7, 0x01, 0x0a, 0x09, 0x50, 0x72, 0x6f, 0x6f, 0x66, 0x4e, 0x6f, 0x64, 0x65, 0x12, 0x1c, 0x0a, - 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0a, 0x2e, 0x73, 0x79, 0x6e, - 0x63, 0x2e, 0x50, 0x61, 0x74, 0x68, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x34, 0x0a, 0x0d, 0x76, - 0x61, 0x6c, 0x75, 0x65, 0x5f, 0x6f, 0x72, 0x5f, 0x68, 0x61, 0x73, 0x68, 0x18, 0x02, 0x20, 0x01, - 0x28, 0x0b, 0x32, 0x10, 0x2e, 0x73, 0x79, 0x6e, 0x63, 0x2e, 0x4d, 0x61, 0x79, 0x62, 0x65, 0x42, - 0x79, 0x74, 0x65, 0x73, 0x52, 0x0b, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x4f, 0x72, 0x48, 0x61, 0x73, - 0x68, 0x12, 0x39, 0x0a, 0x08, 0x63, 0x68, 0x69, 0x6c, 0x64, 0x72, 0x65, 0x6e, 0x18, 0x03, 0x20, - 0x03, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x73, 0x79, 0x6e, 0x63, 0x2e, 0x50, 0x72, 0x6f, 0x6f, 0x66, - 0x4e, 0x6f, 0x64, 0x65, 0x2e, 0x43, 0x68, 0x69, 0x6c, 0x64, 0x72, 0x65, 0x6e, 0x45, 0x6e, 0x74, - 0x72, 0x79, 0x52, 0x08, 0x63, 0x68, 0x69, 0x6c, 0x64, 0x72, 0x65, 0x6e, 0x1a, 0x3b, 0x0a, 0x0d, - 0x43, 0x68, 0x69, 0x6c, 0x64, 0x72, 0x65, 0x6e, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, - 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, - 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 
0x52, 0x05, - 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x45, 0x0a, 0x09, 0x4b, 0x65, 0x79, - 0x43, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x0c, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x26, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, - 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x10, 0x2e, 0x73, 0x79, 0x6e, 0x63, 0x2e, 0x4d, - 0x61, 0x79, 0x62, 0x65, 0x42, 0x79, 0x74, 0x65, 0x73, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, - 0x22, 0x34, 0x0a, 0x04, 0x50, 0x61, 0x74, 0x68, 0x12, 0x16, 0x0a, 0x06, 0x6c, 0x65, 0x6e, 0x67, - 0x74, 0x68, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x06, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, - 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, - 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x22, 0x41, 0x0a, 0x0a, 0x4d, 0x61, 0x79, 0x62, 0x65, 0x42, - 0x79, 0x74, 0x65, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x0c, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x12, 0x1d, 0x0a, 0x0a, 0x69, 0x73, - 0x5f, 0x6e, 0x6f, 0x74, 0x68, 0x69, 0x6e, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, - 0x69, 0x73, 0x4e, 0x6f, 0x74, 0x68, 0x69, 0x6e, 0x67, 0x22, 0x32, 0x0a, 0x08, 0x4b, 0x65, 0x79, - 0x56, 0x61, 0x6c, 0x75, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x0c, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x32, 0x8a, 0x04, - 0x0a, 0x02, 0x44, 0x42, 0x12, 0x44, 0x0a, 0x0d, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x72, 0x6b, 0x6c, - 0x65, 0x52, 0x6f, 0x6f, 0x74, 0x12, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1b, 0x2e, - 0x73, 0x79, 0x6e, 0x63, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x72, 0x6b, 0x6c, 0x65, 0x52, 0x6f, - 0x6f, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x39, 0x0a, 0x08, 0x47, 0x65, - 0x74, 0x50, 0x72, 0x6f, 0x6f, 0x66, 0x12, 0x15, 0x2e, 0x73, 0x79, 0x6e, 0x63, 0x2e, 0x47, 0x65, - 0x74, 0x50, 0x72, 0x6f, 0x6f, 0x66, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, - 0x73, 0x79, 0x6e, 0x63, 0x2e, 0x47, 0x65, 0x74, 0x50, 0x72, 0x6f, 0x6f, 0x66, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x4b, 0x0a, 0x0e, 0x47, 0x65, 0x74, 0x43, 0x68, 0x61, 0x6e, - 0x67, 0x65, 0x50, 0x72, 0x6f, 0x6f, 0x66, 0x12, 0x1b, 0x2e, 0x73, 0x79, 0x6e, 0x63, 0x2e, 0x47, - 0x65, 0x74, 0x43, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x50, 0x72, 0x6f, 0x6f, 0x66, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x73, 0x79, 0x6e, 0x63, 0x2e, 0x47, 0x65, 0x74, 0x43, - 0x68, 0x61, 0x6e, 0x67, 0x65, 0x50, 0x72, 0x6f, 0x6f, 0x66, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, - 0x73, 0x65, 0x12, 0x54, 0x0a, 0x11, 0x56, 0x65, 0x72, 0x69, 0x66, 0x79, 0x43, 0x68, 0x61, 0x6e, - 0x67, 0x65, 0x50, 0x72, 0x6f, 0x6f, 0x66, 0x12, 0x1e, 0x2e, 0x73, 0x79, 0x6e, 0x63, 0x2e, 0x56, - 0x65, 0x72, 0x69, 0x66, 0x79, 0x43, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x50, 0x72, 0x6f, 0x6f, 0x66, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1f, 0x2e, 0x73, 0x79, 0x6e, 0x63, 0x2e, 0x56, - 0x65, 0x72, 0x69, 0x66, 0x79, 0x43, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x50, 0x72, 0x6f, 0x6f, 0x66, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x4b, 0x0a, 0x11, 0x43, 0x6f, 0x6d, 0x6d, - 0x69, 0x74, 0x43, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x50, 0x72, 0x6f, 0x6f, 0x66, 0x12, 0x1e, 0x2e, - 0x73, 0x79, 
0x6e, 0x63, 0x2e, 0x43, 0x6f, 0x6d, 0x6d, 0x69, 0x74, 0x43, 0x68, 0x61, 0x6e, 0x67, - 0x65, 0x50, 0x72, 0x6f, 0x6f, 0x66, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, - 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, - 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x48, 0x0a, 0x0d, 0x47, 0x65, 0x74, 0x52, 0x61, 0x6e, 0x67, - 0x65, 0x50, 0x72, 0x6f, 0x6f, 0x66, 0x12, 0x1a, 0x2e, 0x73, 0x79, 0x6e, 0x63, 0x2e, 0x47, 0x65, - 0x74, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x50, 0x72, 0x6f, 0x6f, 0x66, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x73, 0x79, 0x6e, 0x63, 0x2e, 0x47, 0x65, 0x74, 0x52, 0x61, 0x6e, - 0x67, 0x65, 0x50, 0x72, 0x6f, 0x6f, 0x66, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, - 0x49, 0x0a, 0x10, 0x43, 0x6f, 0x6d, 0x6d, 0x69, 0x74, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x50, 0x72, - 0x6f, 0x6f, 0x66, 0x12, 0x1d, 0x2e, 0x73, 0x79, 0x6e, 0x63, 0x2e, 0x43, 0x6f, 0x6d, 0x6d, 0x69, - 0x74, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x50, 0x72, 0x6f, 0x6f, 0x66, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x42, 0x2f, 0x5a, 0x2d, 0x67, 0x69, - 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x61, 0x76, 0x61, 0x2d, 0x6c, 0x61, 0x62, - 0x73, 0x2f, 0x61, 0x76, 0x61, 0x6c, 0x61, 0x6e, 0x63, 0x68, 0x65, 0x67, 0x6f, 0x2f, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x2f, 0x70, 0x62, 0x2f, 0x73, 0x79, 0x6e, 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x33, + 0xd6, 0x01, 0x0a, 0x09, 0x50, 0x72, 0x6f, 0x6f, 0x66, 0x4e, 0x6f, 0x64, 0x65, 0x12, 0x1b, 0x0a, + 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x09, 0x2e, 0x73, 0x79, 0x6e, + 0x63, 0x2e, 0x4b, 0x65, 0x79, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x34, 0x0a, 0x0d, 0x76, 0x61, + 0x6c, 0x75, 0x65, 0x5f, 0x6f, 0x72, 0x5f, 0x68, 0x61, 0x73, 0x68, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x10, 0x2e, 0x73, 0x79, 0x6e, 0x63, 0x2e, 0x4d, 0x61, 0x79, 0x62, 0x65, 0x42, 0x79, + 0x74, 0x65, 0x73, 0x52, 0x0b, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x4f, 0x72, 0x48, 0x61, 0x73, 0x68, + 0x12, 0x39, 0x0a, 0x08, 0x63, 0x68, 0x69, 0x6c, 0x64, 0x72, 0x65, 0x6e, 0x18, 0x03, 0x20, 0x03, + 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x73, 0x79, 0x6e, 0x63, 0x2e, 0x50, 0x72, 0x6f, 0x6f, 0x66, 0x4e, + 0x6f, 0x64, 0x65, 0x2e, 0x43, 0x68, 0x69, 0x6c, 0x64, 0x72, 0x65, 0x6e, 0x45, 0x6e, 0x74, 0x72, + 0x79, 0x52, 0x08, 0x63, 0x68, 0x69, 0x6c, 0x64, 0x72, 0x65, 0x6e, 0x1a, 0x3b, 0x0a, 0x0d, 0x43, + 0x68, 0x69, 0x6c, 0x64, 0x72, 0x65, 0x6e, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, + 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, + 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x05, 0x76, + 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x45, 0x0a, 0x09, 0x4b, 0x65, 0x79, 0x43, + 0x68, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x0c, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x26, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x10, 0x2e, 0x73, 0x79, 0x6e, 0x63, 0x2e, 0x4d, 0x61, + 0x79, 0x62, 0x65, 0x42, 0x79, 0x74, 0x65, 0x73, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x22, + 0x33, 0x0a, 0x03, 0x4b, 0x65, 0x79, 0x12, 0x16, 0x0a, 0x06, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x06, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, 0x12, 0x14, + 0x0a, 0x05, 0x76, 
0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x05, 0x76, + 0x61, 0x6c, 0x75, 0x65, 0x22, 0x41, 0x0a, 0x0a, 0x4d, 0x61, 0x79, 0x62, 0x65, 0x42, 0x79, 0x74, + 0x65, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x0c, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x12, 0x1d, 0x0a, 0x0a, 0x69, 0x73, 0x5f, 0x6e, + 0x6f, 0x74, 0x68, 0x69, 0x6e, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x69, 0x73, + 0x4e, 0x6f, 0x74, 0x68, 0x69, 0x6e, 0x67, 0x22, 0x32, 0x0a, 0x08, 0x4b, 0x65, 0x79, 0x56, 0x61, + 0x6c, 0x75, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, + 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x0c, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x32, 0x8a, 0x04, 0x0a, 0x02, + 0x44, 0x42, 0x12, 0x44, 0x0a, 0x0d, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x72, 0x6b, 0x6c, 0x65, 0x52, + 0x6f, 0x6f, 0x74, 0x12, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1b, 0x2e, 0x73, 0x79, + 0x6e, 0x63, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x72, 0x6b, 0x6c, 0x65, 0x52, 0x6f, 0x6f, 0x74, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x39, 0x0a, 0x08, 0x47, 0x65, 0x74, 0x50, + 0x72, 0x6f, 0x6f, 0x66, 0x12, 0x15, 0x2e, 0x73, 0x79, 0x6e, 0x63, 0x2e, 0x47, 0x65, 0x74, 0x50, + 0x72, 0x6f, 0x6f, 0x66, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x73, 0x79, + 0x6e, 0x63, 0x2e, 0x47, 0x65, 0x74, 0x50, 0x72, 0x6f, 0x6f, 0x66, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x12, 0x4b, 0x0a, 0x0e, 0x47, 0x65, 0x74, 0x43, 0x68, 0x61, 0x6e, 0x67, 0x65, + 0x50, 0x72, 0x6f, 0x6f, 0x66, 0x12, 0x1b, 0x2e, 0x73, 0x79, 0x6e, 0x63, 0x2e, 0x47, 0x65, 0x74, + 0x43, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x50, 0x72, 0x6f, 0x6f, 0x66, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x73, 0x79, 0x6e, 0x63, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x68, 0x61, + 0x6e, 0x67, 0x65, 0x50, 0x72, 0x6f, 0x6f, 0x66, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x12, 0x54, 0x0a, 0x11, 0x56, 0x65, 0x72, 0x69, 0x66, 0x79, 0x43, 0x68, 0x61, 0x6e, 0x67, 0x65, + 0x50, 0x72, 0x6f, 0x6f, 0x66, 0x12, 0x1e, 0x2e, 0x73, 0x79, 0x6e, 0x63, 0x2e, 0x56, 0x65, 0x72, + 0x69, 0x66, 0x79, 0x43, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x50, 0x72, 0x6f, 0x6f, 0x66, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1f, 0x2e, 0x73, 0x79, 0x6e, 0x63, 0x2e, 0x56, 0x65, 0x72, + 0x69, 0x66, 0x79, 0x43, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x50, 0x72, 0x6f, 0x6f, 0x66, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x4b, 0x0a, 0x11, 0x43, 0x6f, 0x6d, 0x6d, 0x69, 0x74, + 0x43, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x50, 0x72, 0x6f, 0x6f, 0x66, 0x12, 0x1e, 0x2e, 0x73, 0x79, + 0x6e, 0x63, 0x2e, 0x43, 0x6f, 0x6d, 0x6d, 0x69, 0x74, 0x43, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x50, + 0x72, 0x6f, 0x6f, 0x66, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x67, 0x6f, + 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6d, + 0x70, 0x74, 0x79, 0x12, 0x48, 0x0a, 0x0d, 0x47, 0x65, 0x74, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x50, + 0x72, 0x6f, 0x6f, 0x66, 0x12, 0x1a, 0x2e, 0x73, 0x79, 0x6e, 0x63, 0x2e, 0x47, 0x65, 0x74, 0x52, + 0x61, 0x6e, 0x67, 0x65, 0x50, 0x72, 0x6f, 0x6f, 0x66, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x1a, 0x1b, 0x2e, 0x73, 0x79, 0x6e, 0x63, 0x2e, 0x47, 0x65, 0x74, 0x52, 0x61, 0x6e, 0x67, 0x65, + 0x50, 0x72, 0x6f, 0x6f, 0x66, 0x52, 0x65, 
0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x49, 0x0a, + 0x10, 0x43, 0x6f, 0x6d, 0x6d, 0x69, 0x74, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x50, 0x72, 0x6f, 0x6f, + 0x66, 0x12, 0x1d, 0x2e, 0x73, 0x79, 0x6e, 0x63, 0x2e, 0x43, 0x6f, 0x6d, 0x6d, 0x69, 0x74, 0x52, + 0x61, 0x6e, 0x67, 0x65, 0x50, 0x72, 0x6f, 0x6f, 0x66, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x1a, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, + 0x75, 0x66, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x42, 0x2f, 0x5a, 0x2d, 0x67, 0x69, 0x74, 0x68, + 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x61, 0x76, 0x61, 0x2d, 0x6c, 0x61, 0x62, 0x73, 0x2f, + 0x61, 0x76, 0x61, 0x6c, 0x61, 0x6e, 0x63, 0x68, 0x65, 0x67, 0x6f, 0x2f, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x2f, 0x70, 0x62, 0x2f, 0x73, 0x79, 0x6e, 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x33, } var ( @@ -1740,7 +1740,7 @@ var file_sync_sync_proto_goTypes = []interface{}{ (*RangeProof)(nil), // 17: sync.RangeProof (*ProofNode)(nil), // 18: sync.ProofNode (*KeyChange)(nil), // 19: sync.KeyChange - (*Path)(nil), // 20: sync.Path + (*Key)(nil), // 20: sync.Key (*MaybeBytes)(nil), // 21: sync.MaybeBytes (*KeyValue)(nil), // 22: sync.KeyValue nil, // 23: sync.ProofNode.ChildrenEntry @@ -1777,7 +1777,7 @@ var file_sync_sync_proto_depIdxs = []int32{ 18, // 27: sync.RangeProof.start_proof:type_name -> sync.ProofNode 18, // 28: sync.RangeProof.end_proof:type_name -> sync.ProofNode 22, // 29: sync.RangeProof.key_values:type_name -> sync.KeyValue - 20, // 30: sync.ProofNode.key:type_name -> sync.Path + 20, // 30: sync.ProofNode.key:type_name -> sync.Key 21, // 31: sync.ProofNode.value_or_hash:type_name -> sync.MaybeBytes 23, // 32: sync.ProofNode.children:type_name -> sync.ProofNode.ChildrenEntry 21, // 33: sync.KeyChange.value:type_name -> sync.MaybeBytes @@ -2049,7 +2049,7 @@ func file_sync_sync_proto_init() { } } file_sync_sync_proto_msgTypes[20].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Path); i { + switch v := v.(*Key); i { case 0: return &v.state case 1: diff --git a/proto/sync/sync.proto b/proto/sync/sync.proto index 32ae2f8a4849..4c4c6f434722 100644 --- a/proto/sync/sync.proto +++ b/proto/sync/sync.proto @@ -139,7 +139,7 @@ message RangeProof { } message ProofNode { - Path key = 1; + Key key = 1; MaybeBytes value_or_hash = 2; map children = 3; } @@ -149,7 +149,7 @@ message KeyChange { MaybeBytes value = 2; } -message Path { +message Key { uint64 length = 1; bytes value = 2; } diff --git a/scripts/camino_mocks.mockgen.txt b/scripts/camino_mocks.mockgen.txt index c3b0616eba0c..287a9722c852 100644 --- a/scripts/camino_mocks.mockgen.txt +++ b/scripts/camino_mocks.mockgen.txt @@ -1,5 +1,7 @@ // TODO @evlekht -// add this to 'scripts/mocks.mockgen.txt' when mockgen will be able to process this files correctly (some generics issues) + +// add this to 'scripts/mocks.mockgen.txt' when mockgen will +// be able to process this files correctly (some generics issues) github.com/ava-labs/avalanchego/cache=Cacher=cache/mock_cacher.go github.com/ava-labs/avalanchego/vms/components/avax=AtomicUTXOManager=vms/components/avax/mock_atomic_utxos.go @@ -8,9 +10,13 @@ github.com/ava-labs/avalanchego/vms/platformvm/state=Diff=vms/platformvm/state/m github.com/ava-labs/avalanchego/vms/platformvm/state=State=vms/platformvm/state/mock_state.go -// avax also have their own mocks excluded from list, though there is no comment about not forgetting to add them back: +// avax also have their own mocks excluded from list, +// though there is no comment 
about not forgetting +// to add them back or why they were excluded: + github.com/ava-labs/avalanchego/snow/networking/router=Router=snow/networking/router/mock_router.go github.com/ava-labs/avalanchego/snow/networking/sender=ExternalSender=snow/networking/sender/mock_external_sender.go github.com/ava-labs/avalanchego/snow/validators=Set=snow/validators/mock_set.go +github.com/ava-labs/avalanchego/snow/validators=Manager=snow/validators/mock_manager.go github.com/ava-labs/avalanchego/vms/platformvm/txs=Staker=vms/platformvm/txs/mock_staker.go github.com/ava-labs/avalanchego/vms/platformvm/txs=UnsignedTx=vms/platformvm/txs/mock_unsigned_tx.go \ No newline at end of file diff --git a/scripts/constants.sh b/scripts/constants.sh index 19fa2048c7cf..867d8085739f 100755 --- a/scripts/constants.sh +++ b/scripts/constants.sh @@ -41,7 +41,7 @@ fi # # We use "export" here instead of just setting a bash variable because we need # to pass this flag to all child processes spawned by the shell. -export CGO_CFLAGS="-O -D__BLST_PORTABLE__" +export CGO_CFLAGS="-O2 -D__BLST_PORTABLE__" # While CGO_ENABLED doesn't need to be explicitly set, it produces a much more # clear error due to the default value change in go1.20. export CGO_ENABLED=1 diff --git a/scripts/mock.gen.sh b/scripts/mock.gen.sh index ddf92d116adb..95a801caf2a6 100755 --- a/scripts/mock.gen.sh +++ b/scripts/mock.gen.sh @@ -14,6 +14,8 @@ then go install -v go.uber.org/mock/mockgen@v0.2.0 fi +source ./scripts/constants.sh + # tuples of (source interface import path, comma-separated interface names, output file path) input="scripts/mocks.mockgen.txt" while IFS= read -r line diff --git a/scripts/mocks.mockgen.txt b/scripts/mocks.mockgen.txt index 2b39e0db3c78..bc0a18495971 100644 --- a/scripts/mocks.mockgen.txt +++ b/scripts/mocks.mockgen.txt @@ -18,7 +18,6 @@ github.com/ava-labs/avalanchego/snow/networking/timeout=Manager=snow/networking/ github.com/ava-labs/avalanchego/snow/networking/tracker=Targeter=snow/networking/tracker/mock_targeter.go github.com/ava-labs/avalanchego/snow/networking/tracker=Tracker=snow/networking/tracker/mock_resource_tracker.go github.com/ava-labs/avalanchego/snow/uptime=Calculator=snow/uptime/mock_calculator.go -github.com/ava-labs/avalanchego/snow/validators=Manager=snow/validators/mock_manager.go github.com/ava-labs/avalanchego/snow/validators=State=snow/validators/mock_state.go github.com/ava-labs/avalanchego/snow/validators=SubnetConnector=snow/validators/mock_subnet_connector.go github.com/ava-labs/avalanchego/utils/crypto/keychain=Ledger=utils/crypto/keychain/mock_ledger.go diff --git a/snow/engine/avalanche/bootstrap/bootstrapper.go b/snow/engine/avalanche/bootstrap/bootstrapper.go index 2de048c49009..0f8a2484a4e2 100644 --- a/snow/engine/avalanche/bootstrap/bootstrapper.go +++ b/snow/engine/avalanche/bootstrap/bootstrapper.go @@ -15,8 +15,8 @@ import ( "github.com/ava-labs/avalanchego/snow" "github.com/ava-labs/avalanchego/snow/choices" "github.com/ava-labs/avalanchego/snow/consensus/avalanche" - "github.com/ava-labs/avalanchego/snow/engine/avalanche/vertex" "github.com/ava-labs/avalanchego/snow/engine/common" + "github.com/ava-labs/avalanchego/utils/heap" "github.com/ava-labs/avalanchego/utils/logging" "github.com/ava-labs/avalanchego/utils/set" "github.com/ava-labs/avalanchego/version" @@ -89,7 +89,10 @@ type bootstrapper struct { processedCache *cache.LRU[ids.ID, struct{}] } -func (b *bootstrapper) Clear() error { +func (b *bootstrapper) Clear(context.Context) error { + b.Ctx.Lock.Lock() + defer 
b.Ctx.Lock.Unlock() + if err := b.VtxBlocked.Clear(); err != nil { return err } @@ -283,6 +286,10 @@ func (*bootstrapper) Gossip(context.Context) error { func (b *bootstrapper) Shutdown(ctx context.Context) error { b.Ctx.Log.Info("shutting down bootstrapper") + + b.Ctx.Lock.Lock() + defer b.Ctx.Lock.Unlock() + return b.VM.Shutdown(ctx) } @@ -364,6 +371,9 @@ func (b *bootstrapper) Start(ctx context.Context, startReqID uint32) error { } func (b *bootstrapper) HealthCheck(ctx context.Context) (interface{}, error) { + b.Ctx.Lock.Lock() + defer b.Ctx.Lock.Unlock() + vmIntf, vmErr := b.VM.HealthCheck(ctx) intf := map[string]interface{}{ "consensus": struct{}{}, @@ -395,7 +405,7 @@ func (b *bootstrapper) fetch(ctx context.Context, vtxIDs ...ids.ID) error { continue } - validatorIDs, err := b.Config.Beacons.Sample(1) // validator to send request to + validatorIDs, err := b.Config.Beacons.Sample(b.Ctx.SubnetID, 1) // validator to send request to if err != nil { return fmt.Errorf("dropping request for %s as there are no validators", vtxID) } @@ -410,14 +420,15 @@ func (b *bootstrapper) fetch(ctx context.Context, vtxIDs ...ids.ID) error { // Process the vertices in [vtxs]. func (b *bootstrapper) process(ctx context.Context, vtxs ...avalanche.Vertex) error { - // Vertices that we need to process. Store them in a heap for deduplication - // and so we always process vertices further down in the DAG first. This helps - // to reduce the number of repeated DAG traversals. - toProcess := vertex.NewHeap() + // Vertices that we need to process prioritized by vertices that are unknown + // or the furthest down the DAG. Unknown vertices are prioritized to ensure + // that once we have made it below a certain height in DAG traversal we do + // not need to reset and repeat DAG traversals. + toProcess := heap.NewMap[ids.ID, avalanche.Vertex](vertexLess) for _, vtx := range vtxs { vtxID := vtx.ID() if _, ok := b.processedCache.Get(vtxID); !ok { // only process a vertex if we haven't already - toProcess.Push(vtx) + _, _ = toProcess.Push(vtxID, vtx) } else { b.VtxBlocked.RemoveMissingID(vtxID) } @@ -426,13 +437,15 @@ func (b *bootstrapper) process(ctx context.Context, vtxs ...avalanche.Vertex) er vtxHeightSet := set.Set[ids.ID]{} prevHeight := uint64(0) - for toProcess.Len() > 0 { // While there are unprocessed vertices + for { if b.Halted() { return nil } - vtx := toProcess.Pop() // Get an unknown vertex or one furthest down the DAG - vtxID := vtx.ID() + vtxID, vtx, ok := toProcess.Pop() + if !ok { + break + } switch vtx.Status() { case choices.Unknown: @@ -504,7 +517,7 @@ func (b *bootstrapper) process(ctx context.Context, vtxs ...avalanche.Vertex) er parentID := parent.ID() if _, ok := b.processedCache.Get(parentID); !ok { // But only if we haven't processed the parent if !vtxHeightSet.Contains(parentID) { - toProcess.Push(parent) + toProcess.Push(parentID, parent) } } } @@ -626,3 +639,26 @@ func (b *bootstrapper) checkFinish(ctx context.Context) error { b.processedCache.Flush() return b.OnFinished(ctx, b.Config.SharedCfg.RequestID) } + +// A vertex is less than another vertex if it is unknown. Ties are broken by +// prioritizing vertices that have a greater height. 
+func vertexLess(i, j avalanche.Vertex) bool { + if !i.Status().Fetched() { + return true + } + if !j.Status().Fetched() { + return false + } + + // Treat errors on retrieving the height as if the vertex is not fetched + heightI, errI := i.Height() + if errI != nil { + return true + } + heightJ, errJ := j.Height() + if errJ != nil { + return false + } + + return heightI > heightJ +} diff --git a/snow/engine/avalanche/bootstrap/bootstrapper_test.go b/snow/engine/avalanche/bootstrap/bootstrapper_test.go index 988d42549fa3..d5bd51a9f78c 100644 --- a/snow/engine/avalanche/bootstrap/bootstrapper_test.go +++ b/snow/engine/avalanche/bootstrap/bootstrapper_test.go @@ -25,6 +25,7 @@ import ( "github.com/ava-labs/avalanchego/snow/engine/common/queue" "github.com/ava-labs/avalanchego/snow/engine/common/tracker" "github.com/ava-labs/avalanchego/snow/validators" + "github.com/ava-labs/avalanchego/utils/constants" "github.com/ava-labs/avalanchego/utils/set" ) @@ -53,7 +54,7 @@ func newConfig(t *testing.T) (Config, ids.NodeID, *common.SenderTest, *vertex.Te ctx := snow.DefaultConsensusContextTest() - peers := validators.NewSet() + vdrs := validators.NewManager() db := memdb.New() sender := &common.SenderTest{T: t} manager := vertex.NewTestManager(t) @@ -78,7 +79,7 @@ func newConfig(t *testing.T) (Config, ids.NodeID, *common.SenderTest, *vertex.Te sender.CantSendGetAcceptedFrontier = false peer := ids.GenerateTestNodeID() - require.NoError(peers.Add(peer, nil, ids.Empty, 1)) + require.NoError(vdrs.AddStaker(constants.PrimaryNetworkID, peer, nil, ids.Empty, 1)) vtxBlocker, err := queue.NewWithMissing(prefixdb.New([]byte("vtx"), db), "vtx", ctx.AvalancheRegisterer) require.NoError(err) @@ -87,14 +88,16 @@ func newConfig(t *testing.T) (Config, ids.NodeID, *common.SenderTest, *vertex.Te require.NoError(err) peerTracker := tracker.NewPeers() - startupTracker := tracker.NewStartup(peerTracker, peers.Weight()/2+1) - peers.RegisterCallbackListener(startupTracker) + totalWeight, err := vdrs.TotalWeight(constants.PrimaryNetworkID) + require.NoError(err) + startupTracker := tracker.NewStartup(peerTracker, totalWeight/2+1) + vdrs.RegisterCallbackListener(constants.PrimaryNetworkID, startupTracker) commonConfig := common.Config{ Ctx: ctx, - Beacons: peers, - SampleK: peers.Len(), - Alpha: peers.Weight()/2 + 1, + Beacons: vdrs, + SampleK: vdrs.Count(constants.PrimaryNetworkID), + Alpha: totalWeight/2 + 1, StartupTracker: startupTracker, Sender: sender, BootstrapTracker: bootstrapTracker, diff --git a/snow/engine/avalanche/getter/getter_test.go b/snow/engine/avalanche/getter/getter_test.go index abf3381cf511..93694ed5bba0 100644 --- a/snow/engine/avalanche/getter/getter_test.go +++ b/snow/engine/avalanche/getter/getter_test.go @@ -17,14 +17,15 @@ import ( "github.com/ava-labs/avalanchego/snow/engine/avalanche/vertex" "github.com/ava-labs/avalanchego/snow/engine/common" "github.com/ava-labs/avalanchego/snow/validators" + "github.com/ava-labs/avalanchego/utils/constants" ) var errUnknownVertex = errors.New("unknown vertex") func testSetup(t *testing.T) (*vertex.TestManager, *common.SenderTest, common.Config) { - peers := validators.NewSet() + vdrs := validators.NewManager() peer := ids.GenerateTestNodeID() - require.NoError(t, peers.Add(peer, nil, ids.Empty, 1)) + require.NoError(t, vdrs.AddStaker(constants.PrimaryNetworkID, peer, nil, ids.Empty, 1)) sender := &common.SenderTest{T: t} sender.Default(true) @@ -41,11 +42,14 @@ func testSetup(t *testing.T) (*vertex.TestManager, *common.SenderTest, common.Co }, } + 
totalWeight, err := vdrs.TotalWeight(constants.PrimaryNetworkID) + require.NoError(t, err) + commonConfig := common.Config{ Ctx: snow.DefaultConsensusContextTest(), - Beacons: peers, - SampleK: peers.Len(), - Alpha: peers.Weight()/2 + 1, + Beacons: vdrs, + SampleK: vdrs.Count(constants.PrimaryNetworkID), + Alpha: totalWeight/2 + 1, Sender: sender, BootstrapTracker: bootstrapTracker, Timer: &common.TimerTest{}, diff --git a/snow/engine/avalanche/vertex/heap.go b/snow/engine/avalanche/vertex/heap.go deleted file mode 100644 index fa9a0a83d920..000000000000 --- a/snow/engine/avalanche/vertex/heap.go +++ /dev/null @@ -1,135 +0,0 @@ -// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -package vertex - -import ( - "container/heap" - - "github.com/ava-labs/avalanchego/ids" - "github.com/ava-labs/avalanchego/snow/consensus/avalanche" - "github.com/ava-labs/avalanchego/utils/set" -) - -var ( - _ Heap = (*maxHeightVertexHeap)(nil) - _ heap.Interface = (*priorityQueue)(nil) -) - -type priorityQueue []avalanche.Vertex - -func (pq priorityQueue) Len() int { - return len(pq) -} - -// Returns true if the vertex at index i has greater height than the vertex at -// index j. -func (pq priorityQueue) Less(i, j int) bool { - statusI := pq[i].Status() - statusJ := pq[j].Status() - - // Put unknown vertices at the front of the heap to ensure once we have made - // it below a certain height in DAG traversal we do not need to reset - if !statusI.Fetched() { - return true - } - if !statusJ.Fetched() { - return false - } - - // Treat errors on retrieving the height as if the vertex is not fetched - heightI, errI := pq[i].Height() - if errI != nil { - return true - } - heightJ, errJ := pq[j].Height() - if errJ != nil { - return false - } - return heightI > heightJ -} - -func (pq priorityQueue) Swap(i, j int) { - pq[i], pq[j] = pq[j], pq[i] -} - -// Push adds an item to this priority queue. x must have type *vertexItem -func (pq *priorityQueue) Push(x interface{}) { - item := x.(avalanche.Vertex) - *pq = append(*pq, item) -} - -// Pop returns the last item in this priorityQueue -func (pq *priorityQueue) Pop() interface{} { - old := *pq - n := len(old) - item := old[n-1] - old[n-1] = nil - *pq = old[0 : n-1] - return item -} - -// Heap defines the functionality of a heap of vertices with unique VertexIDs -// ordered by height -type Heap interface { - // Empty the heap. - Clear() - - // Add the provided vertex to the heap. Vertices are de-duplicated, returns - // true if the vertex was added, false if it was dropped. - Push(avalanche.Vertex) bool - - // Remove the top vertex. Assumes that there is at least one element. - Pop() avalanche.Vertex - - // Returns if a vertex with the provided ID is currently in the heap. - Contains(ids.ID) bool - - // Returns the number of vertices in the heap. - Len() int -} - -// NewHeap returns an empty Heap -func NewHeap() Heap { - return &maxHeightVertexHeap{} -} - -type maxHeightVertexHeap struct { - heap priorityQueue - elementIDs set.Set[ids.ID] -} - -func (vh *maxHeightVertexHeap) Clear() { - vh.heap = priorityQueue{} - vh.elementIDs.Clear() -} - -// Push adds an element to this heap. Returns true if the element was added. -// Returns false if it was already in the heap. 
-func (vh *maxHeightVertexHeap) Push(vtx avalanche.Vertex) bool { - vtxID := vtx.ID() - if vh.elementIDs.Contains(vtxID) { - return false - } - - vh.elementIDs.Add(vtxID) - heap.Push(&vh.heap, vtx) - return true -} - -// If there are any vertices in this heap with status Unknown, removes one such -// vertex and returns it. Otherwise, removes and returns the vertex in this heap -// with the greatest height. -func (vh *maxHeightVertexHeap) Pop() avalanche.Vertex { - vtx := heap.Pop(&vh.heap).(avalanche.Vertex) - vh.elementIDs.Remove(vtx.ID()) - return vtx -} - -func (vh *maxHeightVertexHeap) Len() int { - return vh.heap.Len() -} - -func (vh *maxHeightVertexHeap) Contains(vtxID ids.ID) bool { - return vh.elementIDs.Contains(vtxID) -} diff --git a/snow/engine/avalanche/vertex/heap_test.go b/snow/engine/avalanche/vertex/heap_test.go deleted file mode 100644 index 1301f3fda96b..000000000000 --- a/snow/engine/avalanche/vertex/heap_test.go +++ /dev/null @@ -1,138 +0,0 @@ -// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -package vertex - -import ( - "testing" - - "github.com/stretchr/testify/require" - - "github.com/ava-labs/avalanchego/ids" - "github.com/ava-labs/avalanchego/snow/choices" - "github.com/ava-labs/avalanchego/snow/consensus/avalanche" -) - -// This example inserts several ints into an IntHeap, checks the minimum, -// and removes them in order of priority. -func TestUniqueVertexHeapReturnsOrdered(t *testing.T) { - require := require.New(t) - - h := NewHeap() - - vtx0 := &avalanche.TestVertex{ - TestDecidable: choices.TestDecidable{ - IDV: ids.GenerateTestID(), - StatusV: choices.Processing, - }, - HeightV: 0, - } - vtx1 := &avalanche.TestVertex{ - TestDecidable: choices.TestDecidable{ - IDV: ids.GenerateTestID(), - StatusV: choices.Processing, - }, - HeightV: 1, - } - vtx2 := &avalanche.TestVertex{ - TestDecidable: choices.TestDecidable{ - IDV: ids.GenerateTestID(), - StatusV: choices.Processing, - }, - HeightV: 1, - } - vtx3 := &avalanche.TestVertex{ - TestDecidable: choices.TestDecidable{ - IDV: ids.GenerateTestID(), - StatusV: choices.Processing, - }, - HeightV: 3, - } - vtx4 := &avalanche.TestVertex{ - TestDecidable: choices.TestDecidable{ - IDV: ids.GenerateTestID(), - StatusV: choices.Unknown, - }, - HeightV: 0, - } - - vts := []avalanche.Vertex{vtx0, vtx1, vtx2, vtx3, vtx4} - - for _, vtx := range vts { - h.Push(vtx) - } - - vtxZ := h.Pop() - require.Equal(vtx4.ID(), vtxZ.ID()) - - vtxA := h.Pop() - height, err := vtxA.Height() - require.NoError(err) - require.Equal(uint64(3), height) - require.Equal(vtx3.ID(), vtxA.ID()) - - vtxB := h.Pop() - height, err = vtxB.Height() - require.NoError(err) - require.Equal(uint64(1), height) - require.Contains([]ids.ID{vtx1.ID(), vtx2.ID()}, vtxB.ID()) - - vtxC := h.Pop() - height, err = vtxC.Height() - require.NoError(err) - require.Equal(uint64(1), height) - require.Contains([]ids.ID{vtx1.ID(), vtx2.ID()}, vtxC.ID()) - - require.NotEqual(vtxB.ID(), vtxC.ID()) - - vtxD := h.Pop() - height, err = vtxD.Height() - require.NoError(err) - require.Zero(height) - require.Equal(vtx0.ID(), vtxD.ID()) - - require.Zero(h.Len()) -} - -func TestUniqueVertexHeapRemainsUnique(t *testing.T) { - require := require.New(t) - - h := NewHeap() - - vtx0 := &avalanche.TestVertex{ - TestDecidable: choices.TestDecidable{ - IDV: ids.GenerateTestID(), - StatusV: choices.Processing, - }, - HeightV: 0, - } - vtx1 := &avalanche.TestVertex{ - TestDecidable: choices.TestDecidable{ - IDV: ids.GenerateTestID(), 
- StatusV: choices.Processing, - }, - HeightV: 1, - } - - sharedID := ids.GenerateTestID() - vtx2 := &avalanche.TestVertex{ - TestDecidable: choices.TestDecidable{ - IDV: sharedID, - StatusV: choices.Processing, - }, - HeightV: 1, - } - vtx3 := &avalanche.TestVertex{ - TestDecidable: choices.TestDecidable{ - IDV: sharedID, - StatusV: choices.Processing, - }, - HeightV: 2, - } - - require.True(h.Push(vtx0)) - require.True(h.Push(vtx1)) - require.True(h.Push(vtx2)) - require.False(h.Push(vtx3)) - require.Equal(3, h.Len()) -} diff --git a/snow/engine/common/bootstrapable.go b/snow/engine/common/bootstrapable.go index f18b3295eae9..a4abcc59a880 100644 --- a/snow/engine/common/bootstrapable.go +++ b/snow/engine/common/bootstrapable.go @@ -21,5 +21,5 @@ type Bootstrapable interface { ForceAccepted(ctx context.Context, acceptedContainerIDs []ids.ID) error // Clear removes all containers to be processed upon bootstrapping - Clear() error + Clear(ctx context.Context) error } diff --git a/snow/engine/common/bootstrapper.go b/snow/engine/common/bootstrapper.go index 38543f228c74..aca219130478 100644 --- a/snow/engine/common/bootstrapper.go +++ b/snow/engine/common/bootstrapper.go @@ -5,6 +5,7 @@ package common import ( "context" + "fmt" "math" "go.uber.org/zap" @@ -46,7 +47,7 @@ type bootstrapper struct { Halter // Holds the beacons that were sampled for the accepted frontier - sampledBeacons validators.Set + sampledBeacons validators.Manager // IDs of validators we should request an accepted frontier from pendingSendAcceptedFrontier set.Set[ids.NodeID] // IDs of validators we requested an accepted frontier from but haven't @@ -143,16 +144,27 @@ func (b *bootstrapper) markAcceptedFrontierReceived(ctx context.Context, nodeID // Ask each bootstrap validator to filter the list of containers that we were // told are on the accepted frontier such that the list only contains containers // they think are accepted. 
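The hunk that follows rescales the bootstrap quorum from the full beacon set down to the sampled subset, then retries or fails once the timed-out sampled weight makes that rescaled quorum unreachable. A minimal, self-contained sketch of the same arithmetic may help; the function name, parameters, and the numbers used are illustrative stand-ins, not code from this change:

```go
package main

import "fmt"

// enoughFrontiersPossible mirrors the check below: alpha is defined over the
// full beacon set, so it is first rescaled to the sampled subset; if the
// weight of sampled beacons that already failed to respond leaves less than
// that rescaled quorum, the frontier round cannot succeed. All names and the
// numbers in main are illustrative, not taken from this change.
func enoughFrontiersPossible(sampledWeight, totalWeight, alpha, failedWeight uint64) bool {
	if totalWeight == 0 {
		return false
	}
	newAlpha := float64(sampledWeight*alpha) / float64(totalWeight)
	return float64(sampledWeight)-newAlpha >= float64(failedWeight)
}

func main() {
	// Sampled 60 of 100 total weight, quorum alpha of 70, 20 weight timed out:
	// the rescaled quorum is 42, and 60-42 = 18 < 20, so the round cannot complete.
	fmt.Println(enoughFrontiersPossible(60, 100, 70, 20)) // false
}
```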
- newAlpha := float64(b.sampledBeacons.Weight()*b.Alpha) / float64(b.Beacons.Weight()) + totalSampledWeight, err := b.sampledBeacons.TotalWeight(b.Ctx.SubnetID) + if err != nil { + return fmt.Errorf("failed to get total weight of sampled beacons for subnet %s: %w", b.Ctx.SubnetID, err) + } + beaconsTotalWeight, err := b.Beacons.TotalWeight(b.Ctx.SubnetID) + if err != nil { + return fmt.Errorf("failed to get total weight of beacons for subnet %s: %w", b.Ctx.SubnetID, err) + } + newAlpha := float64(totalSampledWeight*b.Alpha) / float64(beaconsTotalWeight) - failedBeaconWeight := b.Beacons.SubsetWeight(b.failedAcceptedFrontier) + failedBeaconWeight, err := b.Beacons.SubsetWeight(b.Ctx.SubnetID, b.failedAcceptedFrontier) + if err != nil { + return fmt.Errorf("failed to get total weight of failed beacons: %w", err) + } // fail the bootstrap if the weight is not enough to bootstrap - if float64(b.sampledBeacons.Weight())-newAlpha < float64(failedBeaconWeight) { + if float64(totalSampledWeight)-newAlpha < float64(failedBeaconWeight) { if b.Config.RetryBootstrap { b.Ctx.Log.Debug("restarting bootstrap", zap.String("reason", "not enough frontiers received"), - zap.Int("numBeacons", b.Beacons.Len()), + zap.Int("numBeacons", b.Beacons.Count(b.Ctx.SubnetID)), zap.Int("numFailedBootstrappers", b.failedAcceptedFrontier.Len()), zap.Int("numBootstrapAttemps", b.bootstrapAttempts), ) @@ -192,7 +204,7 @@ func (b *bootstrapper) Accepted(ctx context.Context, nodeID ids.NodeID, requestI // Mark that we received a response from [nodeID] b.pendingReceiveAccepted.Remove(nodeID) - weight := b.Beacons.GetWeight(nodeID) + weight := b.Beacons.GetWeight(b.Ctx.SubnetID, nodeID) for _, containerID := range containerIDs { previousWeight := b.acceptedVotes[containerID] newWeight, err := safemath.Add64(weight, previousWeight) @@ -226,18 +238,25 @@ func (b *bootstrapper) Accepted(ctx context.Context, nodeID ids.NodeID, requestI // if we don't have enough weight for the bootstrap to be accepted then // retry or fail the bootstrap size := len(accepted) - if size == 0 && b.Beacons.Len() > 0 { + if size == 0 && b.Beacons.Count(b.Ctx.SubnetID) > 0 { // if we had too many timeouts when asking for validator votes, we // should restart bootstrap hoping for the network problems to go away; // otherwise, we received enough (>= b.Alpha) responses, but no frontier // was supported by a majority of validators (i.e. votes are split // between minorities supporting different frontiers). 
- failedBeaconWeight := b.Beacons.SubsetWeight(b.failedAccepted) - votingStakes := b.Beacons.Weight() - failedBeaconWeight + beaconTotalWeight, err := b.Beacons.TotalWeight(b.Ctx.SubnetID) + if err != nil { + return fmt.Errorf("failed to get total weight of beacons for subnet %s: %w", b.Ctx.SubnetID, err) + } + failedBeaconWeight, err := b.Beacons.SubsetWeight(b.Ctx.SubnetID, b.failedAccepted) + if err != nil { + return fmt.Errorf("failed to get total weight of failed beacons for subnet %s: %w", b.Ctx.SubnetID, err) + } + votingStakes := beaconTotalWeight - failedBeaconWeight if b.Config.RetryBootstrap && votingStakes < b.Alpha { b.Ctx.Log.Debug("restarting bootstrap", zap.String("reason", "not enough votes received"), - zap.Int("numBeacons", b.Beacons.Len()), + zap.Int("numBeacons", b.Beacons.Count(b.Ctx.SubnetID)), zap.Int("numFailedBootstrappers", b.failedAccepted.Len()), zap.Int("numBootstrapAttempts", b.bootstrapAttempts), ) @@ -277,19 +296,19 @@ func (b *bootstrapper) GetAcceptedFailed(ctx context.Context, nodeID ids.NodeID, } func (b *bootstrapper) Startup(ctx context.Context) error { - beaconIDs, err := b.Beacons.Sample(b.Config.SampleK) + beaconIDs, err := b.Beacons.Sample(b.Ctx.SubnetID, b.Config.SampleK) if err != nil { return err } - b.sampledBeacons = validators.NewSet() + b.sampledBeacons = validators.NewManager() b.pendingSendAcceptedFrontier.Clear() for _, nodeID := range beaconIDs { - if !b.sampledBeacons.Contains(nodeID) { + if _, ok := b.sampledBeacons.GetValidator(b.Ctx.SubnetID, nodeID); !ok { // Invariant: We never use the TxID or BLS keys populated here. - err = b.sampledBeacons.Add(nodeID, nil, ids.Empty, 1) + err = b.sampledBeacons.AddStaker(b.Ctx.SubnetID, nodeID, nil, ids.Empty, 1) } else { - err = b.sampledBeacons.AddWeight(nodeID, 1) + err = b.sampledBeacons.AddWeight(b.Ctx.SubnetID, nodeID, 1) } if err != nil { return err @@ -302,7 +321,7 @@ func (b *bootstrapper) Startup(ctx context.Context) error { b.acceptedFrontierSet.Clear() b.pendingSendAccepted.Clear() - for nodeID := range b.Beacons.Map() { + for _, nodeID := range b.Beacons.GetValidatorIDs(b.Ctx.SubnetID) { b.pendingSendAccepted.Add(nodeID) } diff --git a/snow/engine/common/config.go b/snow/engine/common/config.go index 57507e7eae86..05eb3602f876 100644 --- a/snow/engine/common/config.go +++ b/snow/engine/common/config.go @@ -15,7 +15,7 @@ import ( // engine type Config struct { Ctx *snow.ConsensusContext - Beacons validators.Set + Beacons validators.Manager SampleK int Alpha uint64 diff --git a/snow/engine/common/test_bootstrapable.go b/snow/engine/common/test_bootstrapable.go index e10d3692308b..625070616377 100644 --- a/snow/engine/common/test_bootstrapable.go +++ b/snow/engine/common/test_bootstrapable.go @@ -26,7 +26,7 @@ type BootstrapableTest struct { CantForceAccepted, CantClear bool - ClearF func() error + ClearF func(ctx context.Context) error ForceAcceptedF func(ctx context.Context, acceptedContainerIDs []ids.ID) error } @@ -35,9 +35,9 @@ func (b *BootstrapableTest) Default(cant bool) { b.CantForceAccepted = cant } -func (b *BootstrapableTest) Clear() error { +func (b *BootstrapableTest) Clear(ctx context.Context) error { if b.ClearF != nil { - return b.ClearF() + return b.ClearF(ctx) } if b.CantClear && b.T != nil { require.FailNow(b.T, errClear.Error()) diff --git a/snow/engine/common/test_config.go b/snow/engine/common/test_config.go index ceca80f2768d..d39e6078fd01 100644 --- a/snow/engine/common/test_config.go +++ b/snow/engine/common/test_config.go @@ -8,6 +8,7 @@ import ( 
"github.com/ava-labs/avalanchego/snow" "github.com/ava-labs/avalanchego/snow/engine/common/tracker" "github.com/ava-labs/avalanchego/snow/validators" + "github.com/ava-labs/avalanchego/utils/constants" ) // DefaultConfigTest returns a test configuration @@ -22,11 +23,11 @@ func DefaultConfigTest() Config { }, } - beacons := validators.NewSet() + beacons := validators.NewManager() connectedPeers := tracker.NewPeers() startupTracker := tracker.NewStartup(connectedPeers, 0) - beacons.RegisterCallbackListener(startupTracker) + beacons.RegisterCallbackListener(constants.PrimaryNetworkID, startupTracker) return Config{ Ctx: snow.DefaultConsensusContextTest(), diff --git a/snow/engine/common/traced_bootstrapable_engine.go b/snow/engine/common/traced_bootstrapable_engine.go index c2379799055a..ba7a0d89228d 100644 --- a/snow/engine/common/traced_bootstrapable_engine.go +++ b/snow/engine/common/traced_bootstrapable_engine.go @@ -38,6 +38,9 @@ func (e *tracedBootstrapableEngine) ForceAccepted(ctx context.Context, acceptedC return e.bootstrapableEngine.ForceAccepted(ctx, acceptedContainerIDs) } -func (e *tracedBootstrapableEngine) Clear() error { - return e.bootstrapableEngine.Clear() +func (e *tracedBootstrapableEngine) Clear(ctx context.Context) error { + ctx, span := e.tracer.Start(ctx, "tracedBootstrapableEngine.Clear") + defer span.End() + + return e.bootstrapableEngine.Clear(ctx) } diff --git a/snow/engine/snowman/bootstrap/bootstrapper.go b/snow/engine/snowman/bootstrap/bootstrapper.go index 8f8336022b25..f6725aa00ba5 100644 --- a/snow/engine/snowman/bootstrap/bootstrapper.go +++ b/snow/engine/snowman/bootstrap/bootstrapper.go @@ -251,7 +251,7 @@ func (b *bootstrapper) Connected(ctx context.Context, nodeID ids.NodeID, nodeVer return err } // Ensure fetchFrom reflects proper validator list - if b.Beacons.Contains(nodeID) { + if _, ok := b.Beacons.GetValidator(b.Ctx.SubnetID, nodeID); ok { b.fetchFrom.Add(nodeID) } @@ -295,6 +295,10 @@ func (*bootstrapper) Gossip(context.Context) error { func (b *bootstrapper) Shutdown(ctx context.Context) error { b.Ctx.Log.Info("shutting down bootstrapper") + + b.Ctx.Lock.Lock() + defer b.Ctx.Lock.Unlock() + return b.VM.Shutdown(ctx) } @@ -311,6 +315,9 @@ func (b *bootstrapper) Notify(_ context.Context, msg common.Message) error { } func (b *bootstrapper) HealthCheck(ctx context.Context) (interface{}, error) { + b.Ctx.Lock.Lock() + defer b.Ctx.Lock.Unlock() + vmIntf, vmErr := b.VM.HealthCheck(ctx) intf := map[string]interface{}{ "consensus": struct{}{}, @@ -405,7 +412,10 @@ func (b *bootstrapper) markUnavailable(nodeID ids.NodeID) { } } -func (b *bootstrapper) Clear() error { +func (b *bootstrapper) Clear(context.Context) error { + b.Ctx.Lock.Lock() + defer b.Ctx.Lock.Unlock() + if err := b.Config.Blocked.Clear(); err != nil { return err } diff --git a/snow/engine/snowman/bootstrap/bootstrapper_test.go b/snow/engine/snowman/bootstrap/bootstrapper_test.go index 07ff91646dc3..620d85b0ba80 100644 --- a/snow/engine/snowman/bootstrap/bootstrapper_test.go +++ b/snow/engine/snowman/bootstrap/bootstrapper_test.go @@ -39,7 +39,7 @@ func newConfig(t *testing.T) (Config, ids.NodeID, *common.SenderTest, *block.Tes ctx := snow.DefaultConsensusContextTest() - peers := validators.NewSet() + vdrs := validators.NewManager() sender := &common.SenderTest{} vm := &block.TestVM{} @@ -64,19 +64,21 @@ func newConfig(t *testing.T) (Config, ids.NodeID, *common.SenderTest, *block.Tes sender.CantSendGetAcceptedFrontier = false peer := ids.GenerateTestNodeID() - 
require.NoError(peers.Add(peer, nil, ids.Empty, 1)) + require.NoError(vdrs.AddStaker(ctx.SubnetID, peer, nil, ids.Empty, 1)) peerTracker := tracker.NewPeers() - startupTracker := tracker.NewStartup(peerTracker, peers.Weight()/2+1) - peers.RegisterCallbackListener(startupTracker) + totalWeight, err := vdrs.TotalWeight(ctx.SubnetID) + require.NoError(err) + startupTracker := tracker.NewStartup(peerTracker, totalWeight/2+1) + vdrs.RegisterCallbackListener(ctx.SubnetID, startupTracker) require.NoError(startupTracker.Connected(context.Background(), peer, version.CurrentApp)) commonConfig := common.Config{ Ctx: ctx, - Beacons: peers, - SampleK: peers.Len(), - Alpha: peers.Weight()/2 + 1, + Beacons: vdrs, + SampleK: vdrs.Count(ctx.SubnetID), + Alpha: totalWeight/2 + 1, StartupTracker: startupTracker, Sender: sender, BootstrapTracker: bootstrapTracker, @@ -108,19 +110,19 @@ func TestBootstrapperStartsOnlyIfEnoughStakeIsConnected(t *testing.T) { sender.Default(true) vm.Default(true) - + ctx := snow.DefaultConsensusContextTest() // create boostrapper configuration - peers := validators.NewSet() + peers := validators.NewManager() sampleK := 2 alpha := uint64(10) startupAlpha := alpha peerTracker := tracker.NewPeers() startupTracker := tracker.NewStartup(peerTracker, startupAlpha) - peers.RegisterCallbackListener(startupTracker) + peers.RegisterCallbackListener(ctx.SubnetID, startupTracker) commonCfg := common.Config{ - Ctx: snow.DefaultConsensusContextTest(), + Ctx: ctx, Beacons: peers, SampleK: sampleK, Alpha: alpha, @@ -191,7 +193,7 @@ func TestBootstrapperStartsOnlyIfEnoughStakeIsConnected(t *testing.T) { // attempt starting bootstrapper with not enough stake connected. Bootstrapper should stall. vdr0 := ids.GenerateTestNodeID() - require.NoError(peers.Add(vdr0, nil, ids.Empty, startupAlpha/2)) + require.NoError(peers.AddStaker(commonCfg.Ctx.SubnetID, vdr0, nil, ids.Empty, startupAlpha/2)) require.NoError(bs.Connected(context.Background(), vdr0, version.CurrentApp)) require.NoError(bs.Start(context.Background(), 0)) @@ -199,7 +201,7 @@ func TestBootstrapperStartsOnlyIfEnoughStakeIsConnected(t *testing.T) { // finally attempt starting bootstrapper with enough stake connected. Frontiers should be requested. 
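Because every test below repeats the same migration from the old single-set helpers to the subnet-keyed `validators.Manager` calls, a compact usage sketch may help when reading the remaining hunks. Only the method names and shapes are taken from this diff; the subnet choice, node IDs, and weights are invented for illustration:

```go
package main

import (
	"fmt"

	"github.com/ava-labs/avalanchego/ids"
	"github.com/ava-labs/avalanchego/snow/validators"
	"github.com/ava-labs/avalanchego/utils/constants"
)

// Rough sketch of the per-subnet validators.Manager calls these tests now use
// instead of the old single-set API. Treat this as orientation, not as code
// from the change itself.
func main() {
	subnetID := constants.PrimaryNetworkID
	vdrs := validators.NewManager()

	for i := 0; i < 3; i++ {
		nodeID := ids.GenerateTestNodeID()
		// Invariant noted in the diff: the BLS key and TxID are unused here.
		if err := vdrs.AddStaker(subnetID, nodeID, nil, ids.Empty, 1); err != nil {
			panic(err)
		}
	}

	totalWeight, err := vdrs.TotalWeight(subnetID)
	if err != nil {
		panic(err)
	}
	fmt.Println("validators:", vdrs.Count(subnetID), "total weight:", totalWeight)

	// Sampling and per-node weights are also keyed by subnet ID now.
	sampled, err := vdrs.Sample(subnetID, 2)
	if err != nil {
		panic(err)
	}
	for _, nodeID := range sampled {
		fmt.Println(nodeID, "weight:", vdrs.GetWeight(subnetID, nodeID))
	}
}
```

GetValidatorIDs, SubsetWeight, and RegisterCallbackListener follow the same pattern in the hunks below: the subnet ID is now always the first argument.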
vdr := ids.GenerateTestNodeID() - require.NoError(peers.Add(vdr, nil, ids.Empty, startupAlpha)) + require.NoError(peers.AddStaker(commonCfg.Ctx.SubnetID, vdr, nil, ids.Empty, startupAlpha)) require.NoError(bs.Connected(context.Background(), vdr, version.CurrentApp)) require.True(frontierRequested) } @@ -1323,7 +1325,7 @@ func TestBootstrapNoParseOnNew(t *testing.T) { require := require.New(t) ctx := snow.DefaultConsensusContextTest() - peers := validators.NewSet() + peers := validators.NewManager() sender := &common.SenderTest{} vm := &block.TestVM{} @@ -1348,18 +1350,20 @@ func TestBootstrapNoParseOnNew(t *testing.T) { sender.CantSendGetAcceptedFrontier = false peer := ids.GenerateTestNodeID() - require.NoError(peers.Add(peer, nil, ids.Empty, 1)) + require.NoError(peers.AddStaker(ctx.SubnetID, peer, nil, ids.Empty, 1)) peerTracker := tracker.NewPeers() - startupTracker := tracker.NewStartup(peerTracker, peers.Weight()/2+1) - peers.RegisterCallbackListener(startupTracker) + totalWeight, err := peers.TotalWeight(ctx.SubnetID) + require.NoError(err) + startupTracker := tracker.NewStartup(peerTracker, totalWeight/2+1) + peers.RegisterCallbackListener(ctx.SubnetID, startupTracker) require.NoError(startupTracker.Connected(context.Background(), peer, version.CurrentApp)) commonConfig := common.Config{ Ctx: ctx, Beacons: peers, - SampleK: peers.Len(), - Alpha: peers.Weight()/2 + 1, + SampleK: peers.Count(ctx.SubnetID), + Alpha: totalWeight/2 + 1, StartupTracker: startupTracker, Sender: sender, BootstrapTracker: bootstrapTracker, diff --git a/snow/engine/snowman/config.go b/snow/engine/snowman/config.go index 07d9609e854e..ed63af2f4936 100644 --- a/snow/engine/snowman/config.go +++ b/snow/engine/snowman/config.go @@ -19,7 +19,7 @@ type Config struct { Ctx *snow.ConsensusContext VM block.ChainVM Sender common.Sender - Validators validators.Set + Validators validators.Manager Params snowball.Parameters Consensus snowman.Consensus PartialSync bool diff --git a/snow/engine/snowman/config_test.go b/snow/engine/snowman/config_test.go index c0b31298891a..54d9536a4884 100644 --- a/snow/engine/snowman/config_test.go +++ b/snow/engine/snowman/config_test.go @@ -16,7 +16,7 @@ func DefaultConfigs() Config { return Config{ Ctx: commonCfg.Ctx, Sender: commonCfg.Sender, - Validators: validators.NewSet(), + Validators: validators.NewManager(), VM: &block.TestVM{}, Params: snowball.Parameters{ K: 1, diff --git a/snow/engine/snowman/getter/getter_test.go b/snow/engine/snowman/getter/getter_test.go index b2891d36e387..35a0e11f9ebb 100644 --- a/snow/engine/snowman/getter/getter_test.go +++ b/snow/engine/snowman/getter/getter_test.go @@ -35,7 +35,7 @@ func testSetup( ) (StateSyncEnabledMock, *common.SenderTest, common.Config) { ctx := snow.DefaultConsensusContextTest() - peers := validators.NewSet() + peers := validators.NewManager() sender := &common.SenderTest{} vm := StateSyncEnabledMock{ TestVM: &block.TestVM{}, @@ -60,13 +60,15 @@ func testSetup( sender.CantSendGetAcceptedFrontier = false peer := ids.GenerateTestNodeID() - require.NoError(t, peers.Add(peer, nil, ids.Empty, 1)) + require.NoError(t, peers.AddStaker(ctx.SubnetID, peer, nil, ids.Empty, 1)) + totalWeight, err := peers.TotalWeight(ctx.SubnetID) + require.NoError(t, err) commonConfig := common.Config{ Ctx: ctx, Beacons: peers, - SampleK: peers.Len(), - Alpha: peers.Weight()/2 + 1, + SampleK: peers.Count(ctx.SubnetID), + Alpha: totalWeight/2 + 1, Sender: sender, BootstrapTracker: bootstrapTracker, Timer: &common.TimerTest{}, diff --git 
a/snow/engine/snowman/syncer/config.go b/snow/engine/snowman/syncer/config.go index 7b2d59f549ed..4e10d412f38a 100644 --- a/snow/engine/snowman/syncer/config.go +++ b/snow/engine/snowman/syncer/config.go @@ -4,6 +4,8 @@ package syncer import ( + "fmt" + "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/snow/engine/common" "github.com/ava-labs/avalanchego/snow/engine/snowman/block" @@ -25,7 +27,7 @@ type Config struct { // StateSyncBeacons are the nodes that will be used to sample and vote over // state summaries. - StateSyncBeacons validators.Set + StateSyncBeacons validators.Manager VM block.ChainVM } @@ -47,14 +49,17 @@ func NewConfig( // If the user has manually provided state syncer IDs, then override the // state sync beacons to them. if len(stateSyncerIDs) != 0 { - stateSyncBeacons = validators.NewSet() + stateSyncBeacons = validators.NewManager() for _, peerID := range stateSyncerIDs { // Invariant: We never use the TxID or BLS keys populated here. - if err := stateSyncBeacons.Add(peerID, nil, ids.Empty, 1); err != nil { + if err := stateSyncBeacons.AddStaker(commonCfg.Ctx.SubnetID, peerID, nil, ids.Empty, 1); err != nil { return Config{}, err } } - stateSyncingWeight := stateSyncBeacons.Weight() + stateSyncingWeight, err := stateSyncBeacons.TotalWeight(commonCfg.Ctx.SubnetID) + if err != nil { + return Config{}, fmt.Errorf("failed to calculate total weight of state sync beacons for subnet %s: %w", commonCfg.Ctx.SubnetID, err) + } if uint64(syncSampleK) > stateSyncingWeight { syncSampleK = int(stateSyncingWeight) } diff --git a/snow/engine/snowman/syncer/state_syncer.go b/snow/engine/snowman/syncer/state_syncer.go index 264b5a2b8d96..87e6d1786173 100644 --- a/snow/engine/snowman/syncer/state_syncer.go +++ b/snow/engine/snowman/syncer/state_syncer.go @@ -59,7 +59,7 @@ type stateSyncer struct { // Holds the beacons that were sampled for the accepted frontier // Won't be consumed as seeders are reached out. Used to rescale // alpha for frontiers - frontierSeeders validators.Set + frontierSeeders validators.Manager // IDs of validators we should request state summary frontier from. // Will be consumed seeders are reached out for frontier. targetSeeders set.Set[ids.NodeID] @@ -191,10 +191,21 @@ func (ss *stateSyncer) receivedStateSummaryFrontier(ctx context.Context) error { // If we got too many timeouts, we restart state syncing hoping that network // problems will go away and we can collect a qualified frontier. 
// We assume the frontier is qualified after an alpha proportion of frontier seeders have responded - frontierAlpha := float64(ss.frontierSeeders.Weight()*ss.Alpha) / float64(ss.StateSyncBeacons.Weight()) - failedBeaconWeight := ss.StateSyncBeacons.SubsetWeight(ss.failedSeeders) + frontiersTotalWeight, err := ss.frontierSeeders.TotalWeight(ss.Ctx.SubnetID) + if err != nil { + return fmt.Errorf("failed to get total weight of frontier seeders for subnet %s: %w", ss.Ctx.SubnetID, err) + } + beaconsTotalWeight, err := ss.StateSyncBeacons.TotalWeight(ss.Ctx.SubnetID) + if err != nil { + return fmt.Errorf("failed to get total weight of state sync beacons for subnet %s: %w", ss.Ctx.SubnetID, err) + } + frontierAlpha := float64(frontiersTotalWeight*ss.Alpha) / float64(beaconsTotalWeight) + failedBeaconWeight, err := ss.StateSyncBeacons.SubsetWeight(ss.Ctx.SubnetID, ss.failedSeeders) + if err != nil { + return fmt.Errorf("failed to get total weight of failed beacons: %w", err) + } - frontierStake := ss.frontierSeeders.Weight() - failedBeaconWeight + frontierStake := frontiersTotalWeight - failedBeaconWeight if float64(frontierStake) < frontierAlpha { ss.Ctx.Log.Debug("didn't receive enough frontiers", zap.Int("numFailedValidators", ss.failedSeeders.Len()), @@ -233,9 +244,10 @@ func (ss *stateSyncer) AcceptedStateSummary(ctx context.Context, nodeID ids.Node // Mark that we received a response from [nodeID] ss.pendingVoters.Remove(nodeID) - nodeWeight := ss.StateSyncBeacons.GetWeight(nodeID) + nodeWeight := ss.StateSyncBeacons.GetWeight(ss.Ctx.SubnetID, nodeID) ss.Ctx.Log.Debug("adding weight to summaries", zap.Stringer("nodeID", nodeID), + zap.Stringer("subnetID", ss.Ctx.SubnetID), zap.Stringers("summaryIDs", summaryIDs), zap.Uint64("nodeWeight", nodeWeight), ) @@ -299,18 +311,25 @@ func (ss *stateSyncer) AcceptedStateSummary(ctx context.Context, nodeID ids.Node size := len(ss.weightedSummaries) if size == 0 { // retry the state sync if the weight is not enough to state sync - failedBeaconWeight := ss.StateSyncBeacons.SubsetWeight(ss.failedVoters) + failedVotersWeight, err := ss.StateSyncBeacons.SubsetWeight(ss.Ctx.SubnetID, ss.failedVoters) + if err != nil { + return fmt.Errorf("failed to get total weight of failed voters: %w", err) + } // if we had too many timeouts when asking for validator votes, we should restart // state sync hoping for the network problems to go away; otherwise, we received // enough (>= ss.Alpha) responses, but no state summary was supported by a majority // of validators (i.e. 
votes are split between minorities supporting different state // summaries), so there is no point in retrying state sync; we should move ahead to bootstrapping - votingStakes := ss.StateSyncBeacons.Weight() - failedBeaconWeight + beaconsTotalWeight, err := ss.StateSyncBeacons.TotalWeight(ss.Ctx.SubnetID) + if err != nil { + return fmt.Errorf("failed to get total weight of state sync beacons for subnet %s: %w", ss.Ctx.SubnetID, err) + } + votingStakes := beaconsTotalWeight - failedVotersWeight if ss.Config.RetryBootstrap && votingStakes < ss.Alpha { ss.Ctx.Log.Debug("restarting state sync", zap.String("reason", "not enough votes received"), - zap.Int("numBeacons", ss.StateSyncBeacons.Len()), + zap.Int("numBeacons", ss.StateSyncBeacons.Count(ss.Ctx.SubnetID)), zap.Int("numFailedSyncers", ss.failedVoters.Len()), zap.Int("numAttempts", ss.attempts), ) @@ -445,18 +464,18 @@ func (ss *stateSyncer) startup(ctx context.Context) error { ss.failedVoters.Clear() // sample K beacons to retrieve frontier from - beaconIDs, err := ss.StateSyncBeacons.Sample(ss.Config.SampleK) + beaconIDs, err := ss.StateSyncBeacons.Sample(ss.Ctx.SubnetID, ss.Config.SampleK) if err != nil { return err } - ss.frontierSeeders = validators.NewSet() + ss.frontierSeeders = validators.NewManager() for _, nodeID := range beaconIDs { - if !ss.frontierSeeders.Contains(nodeID) { + if _, ok := ss.frontierSeeders.GetValidator(ss.Ctx.SubnetID, nodeID); !ok { // Invariant: We never use the TxID or BLS keys populated here. - err = ss.frontierSeeders.Add(nodeID, nil, ids.Empty, 1) + err = ss.frontierSeeders.AddStaker(ss.Ctx.SubnetID, nodeID, nil, ids.Empty, 1) } else { - err = ss.frontierSeeders.AddWeight(nodeID, 1) + err = ss.frontierSeeders.AddWeight(ss.Ctx.SubnetID, nodeID, 1) } if err != nil { return err @@ -465,7 +484,7 @@ func (ss *stateSyncer) startup(ctx context.Context) error { } // list all beacons, to reach them for voting on frontier - for nodeID := range ss.StateSyncBeacons.Map() { + for _, nodeID := range ss.StateSyncBeacons.GetValidatorIDs(ss.Ctx.SubnetID) { ss.targetVoters.Add(nodeID) } @@ -590,6 +609,10 @@ func (*stateSyncer) Gossip(context.Context) error { func (ss *stateSyncer) Shutdown(ctx context.Context) error { ss.Config.Ctx.Log.Info("shutting down state syncer") + + ss.Ctx.Lock.Lock() + defer ss.Ctx.Lock.Unlock() + return ss.VM.Shutdown(ctx) } @@ -600,6 +623,9 @@ func (*stateSyncer) Timeout(context.Context) error { } func (ss *stateSyncer) HealthCheck(ctx context.Context) (interface{}, error) { + ss.Ctx.Lock.Lock() + defer ss.Ctx.Lock.Unlock() + vmIntf, vmErr := ss.VM.HealthCheck(ctx) intf := map[string]interface{}{ "consensus": struct{}{}, @@ -618,5 +644,8 @@ func (ss *stateSyncer) IsEnabled(ctx context.Context) (bool, error) { return false, nil } + ss.Ctx.Lock.Lock() + defer ss.Ctx.Lock.Unlock() + return ss.stateSyncVM.StateSyncEnabled(ctx) } diff --git a/snow/engine/snowman/syncer/state_syncer_test.go b/snow/engine/snowman/syncer/state_syncer_test.go index 26401e625385..47a00e744471 100644 --- a/snow/engine/snowman/syncer/state_syncer_test.go +++ b/snow/engine/snowman/syncer/state_syncer_test.go @@ -97,19 +97,20 @@ func TestStateSyncerIsEnabledIfVMSupportsStateSyncing(t *testing.T) { func TestStateSyncingStartsOnlyIfEnoughStakeIsConnected(t *testing.T) { require := require.New(t) - - vdrs := buildTestPeers(t) - alpha := vdrs.Weight() + ctx := snow.DefaultConsensusContextTest() + vdrs := buildTestPeers(t, ctx.SubnetID) + alpha, err := vdrs.TotalWeight(ctx.SubnetID) + require.NoError(err) startupAlpha := alpha 
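The connected-stake gate set up in this test, and repeated in the ones that follow, composes the per-subnet manager with the startup tracker. A rough sketch of how those pieces interact, assuming the tracker semantics these tests rely on (the threshold, weights, and node IDs below are invented):

```go
package main

import (
	"context"
	"fmt"

	"github.com/ava-labs/avalanchego/ids"
	"github.com/ava-labs/avalanchego/snow/engine/common/tracker"
	"github.com/ava-labs/avalanchego/snow/validators"
	"github.com/ava-labs/avalanchego/utils/constants"
	"github.com/ava-labs/avalanchego/version"
)

// Sketch of the startup gating these tests exercise: a Startup tracker learns
// validator weights through the manager's callback listener (now registered
// per subnet) and reports ShouldStart only once enough connected stake has
// been observed. The weight threshold of 3 is made up for illustration.
func main() {
	subnetID := constants.PrimaryNetworkID
	vdrs := validators.NewManager()

	peers := tracker.NewPeers()
	startup := tracker.NewStartup(peers, 3) // require at least 3 connected weight
	vdrs.RegisterCallbackListener(subnetID, startup)

	nodeA := ids.GenerateTestNodeID()
	nodeB := ids.GenerateTestNodeID()
	if err := vdrs.AddStaker(subnetID, nodeA, nil, ids.Empty, 2); err != nil {
		panic(err)
	}
	if err := vdrs.AddStaker(subnetID, nodeB, nil, ids.Empty, 2); err != nil {
		panic(err)
	}

	if err := startup.Connected(context.Background(), nodeA, version.CurrentApp); err != nil {
		panic(err)
	}
	fmt.Println(startup.ShouldStart()) // false: only 2 of the required 3 weight connected

	if err := startup.Connected(context.Background(), nodeB, version.CurrentApp); err != nil {
		panic(err)
	}
	fmt.Println(startup.ShouldStart()) // true: 4 connected weight meets the threshold
}
```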
peers := tracker.NewPeers() startup := tracker.NewStartup(peers, startupAlpha) - vdrs.RegisterCallbackListener(startup) + vdrs.RegisterCallbackListener(ctx.SubnetID, startup) commonCfg := common.Config{ - Ctx: snow.DefaultConsensusContextTest(), + Ctx: ctx, Beacons: vdrs, - SampleK: vdrs.Len(), + SampleK: vdrs.Count(ctx.SubnetID), Alpha: alpha, StartupTracker: startup, } @@ -126,7 +127,7 @@ func TestStateSyncingStartsOnlyIfEnoughStakeIsConnected(t *testing.T) { // attempt starting bootstrapper with not enough stake connected. Bootstrapper should stall. vdr0 := ids.GenerateTestNodeID() - require.NoError(vdrs.Add(vdr0, nil, ids.Empty, startupAlpha/2)) + require.NoError(vdrs.AddStaker(ctx.SubnetID, vdr0, nil, ids.Empty, startupAlpha/2)) require.NoError(syncer.Connected(context.Background(), vdr0, version.CurrentApp)) require.False(commonCfg.StartupTracker.ShouldStart()) @@ -135,7 +136,7 @@ func TestStateSyncingStartsOnlyIfEnoughStakeIsConnected(t *testing.T) { // finally attempt starting bootstrapper with enough stake connected. Frontiers should be requested. vdr := ids.GenerateTestNodeID() - require.NoError(vdrs.Add(vdr, nil, ids.Empty, startupAlpha)) + require.NoError(vdrs.AddStaker(ctx.SubnetID, vdr, nil, ids.Empty, startupAlpha)) require.NoError(syncer.Connected(context.Background(), vdr, version.CurrentApp)) require.True(commonCfg.StartupTracker.ShouldStart()) @@ -145,19 +146,21 @@ func TestStateSyncingStartsOnlyIfEnoughStakeIsConnected(t *testing.T) { func TestStateSyncLocalSummaryIsIncludedAmongFrontiersIfAvailable(t *testing.T) { require := require.New(t) - - vdrs := buildTestPeers(t) - startupAlpha := (3*vdrs.Weight() + 3) / 4 + ctx := snow.DefaultConsensusContextTest() + vdrs := buildTestPeers(t, ctx.SubnetID) + totalWeight, err := vdrs.TotalWeight(ctx.SubnetID) + require.NoError(err) + startupAlpha := (3*totalWeight + 3) / 4 peers := tracker.NewPeers() startup := tracker.NewStartup(peers, startupAlpha) - vdrs.RegisterCallbackListener(startup) + vdrs.RegisterCallbackListener(ctx.SubnetID, startup) commonCfg := common.Config{ - Ctx: snow.DefaultConsensusContextTest(), + Ctx: ctx, Beacons: vdrs, - SampleK: vdrs.Len(), - Alpha: (vdrs.Weight() + 1) / 2, + SampleK: vdrs.Count(ctx.SubnetID), + Alpha: (totalWeight + 1) / 2, StartupTracker: startup, } syncer, fullVM, _ := buildTestsObjects(t, &commonCfg) @@ -174,7 +177,7 @@ func TestStateSyncLocalSummaryIsIncludedAmongFrontiersIfAvailable(t *testing.T) } // Connect enough stake to start syncer - for nodeID := range vdrs.Map() { + for _, nodeID := range vdrs.GetValidatorIDs(ctx.SubnetID) { require.NoError(syncer.Connected(context.Background(), nodeID, version.CurrentApp)) } @@ -187,19 +190,21 @@ func TestStateSyncLocalSummaryIsIncludedAmongFrontiersIfAvailable(t *testing.T) func TestStateSyncNotFoundOngoingSummaryIsNotIncludedAmongFrontiers(t *testing.T) { require := require.New(t) - - vdrs := buildTestPeers(t) - startupAlpha := (3*vdrs.Weight() + 3) / 4 + ctx := snow.DefaultConsensusContextTest() + vdrs := buildTestPeers(t, ctx.SubnetID) + totalWeight, err := vdrs.TotalWeight(ctx.SubnetID) + require.NoError(err) + startupAlpha := (3*totalWeight + 3) / 4 peers := tracker.NewPeers() startup := tracker.NewStartup(peers, startupAlpha) - vdrs.RegisterCallbackListener(startup) + vdrs.RegisterCallbackListener(ctx.SubnetID, startup) commonCfg := common.Config{ - Ctx: snow.DefaultConsensusContextTest(), + Ctx: ctx, Beacons: vdrs, - SampleK: vdrs.Len(), - Alpha: (vdrs.Weight() + 1) / 2, + SampleK: vdrs.Count(ctx.SubnetID), + Alpha: (totalWeight + 1) 
/ 2, StartupTracker: startup, } syncer, fullVM, _ := buildTestsObjects(t, &commonCfg) @@ -211,7 +216,7 @@ func TestStateSyncNotFoundOngoingSummaryIsNotIncludedAmongFrontiers(t *testing.T } // Connect enough stake to start syncer - for nodeID := range vdrs.Map() { + for _, nodeID := range vdrs.GetValidatorIDs(ctx.SubnetID) { require.NoError(syncer.Connected(context.Background(), nodeID, version.CurrentApp)) } @@ -222,18 +227,21 @@ func TestStateSyncNotFoundOngoingSummaryIsNotIncludedAmongFrontiers(t *testing.T func TestBeaconsAreReachedForFrontiersUponStartup(t *testing.T) { require := require.New(t) - vdrs := buildTestPeers(t) - startupAlpha := (3*vdrs.Weight() + 3) / 4 + ctx := snow.DefaultConsensusContextTest() + vdrs := buildTestPeers(t, ctx.SubnetID) + totalWeight, err := vdrs.TotalWeight(ctx.SubnetID) + require.NoError(err) + startupAlpha := (3*totalWeight + 3) / 4 peers := tracker.NewPeers() startup := tracker.NewStartup(peers, startupAlpha) - vdrs.RegisterCallbackListener(startup) + vdrs.RegisterCallbackListener(ctx.SubnetID, startup) commonCfg := common.Config{ - Ctx: snow.DefaultConsensusContextTest(), + Ctx: ctx, Beacons: vdrs, - SampleK: vdrs.Len(), - Alpha: (vdrs.Weight() + 1) / 2, + SampleK: vdrs.Count(ctx.SubnetID), + Alpha: (totalWeight + 1) / 2, StartupTracker: startup, } syncer, _, sender := buildTestsObjects(t, &commonCfg) @@ -246,12 +254,12 @@ func TestBeaconsAreReachedForFrontiersUponStartup(t *testing.T) { } // Connect enough stake to start syncer - for nodeID := range vdrs.Map() { + for _, nodeID := range vdrs.GetValidatorIDs(ctx.SubnetID) { require.NoError(syncer.Connected(context.Background(), nodeID, version.CurrentApp)) } // check that vdrs are reached out for frontiers - require.Len(contactedFrontiersProviders, safemath.Min(vdrs.Len(), common.MaxOutstandingBroadcastRequests)) + require.Len(contactedFrontiersProviders, safemath.Min(vdrs.Count(ctx.SubnetID), common.MaxOutstandingBroadcastRequests)) for beaconID := range contactedFrontiersProviders { // check that beacon is duly marked as reached out require.Contains(syncer.pendingSeeders, beaconID) @@ -264,18 +272,21 @@ func TestBeaconsAreReachedForFrontiersUponStartup(t *testing.T) { func TestUnRequestedStateSummaryFrontiersAreDropped(t *testing.T) { require := require.New(t) - vdrs := buildTestPeers(t) - startupAlpha := (3*vdrs.Weight() + 3) / 4 + ctx := snow.DefaultConsensusContextTest() + vdrs := buildTestPeers(t, ctx.SubnetID) + totalWeight, err := vdrs.TotalWeight(ctx.SubnetID) + require.NoError(err) + startupAlpha := (3*totalWeight + 3) / 4 peers := tracker.NewPeers() startup := tracker.NewStartup(peers, startupAlpha) - vdrs.RegisterCallbackListener(startup) + vdrs.RegisterCallbackListener(ctx.SubnetID, startup) commonCfg := common.Config{ - Ctx: snow.DefaultConsensusContextTest(), + Ctx: ctx, Beacons: vdrs, - SampleK: vdrs.Len(), - Alpha: (vdrs.Weight() + 1) / 2, + SampleK: vdrs.Count(ctx.SubnetID), + Alpha: (totalWeight + 1) / 2, StartupTracker: startup, } syncer, fullVM, sender := buildTestsObjects(t, &commonCfg) @@ -290,7 +301,7 @@ func TestUnRequestedStateSummaryFrontiersAreDropped(t *testing.T) { } // Connect enough stake to start syncer - for nodeID := range vdrs.Map() { + for _, nodeID := range vdrs.GetValidatorIDs(ctx.SubnetID) { require.NoError(syncer.Connected(context.Background(), nodeID, version.CurrentApp)) } @@ -351,24 +362,27 @@ func TestUnRequestedStateSummaryFrontiersAreDropped(t *testing.T) { // other listed vdrs are reached for data require.True( len(contactedFrontiersProviders) > 
initiallyReachedOutBeaconsSize || - len(contactedFrontiersProviders) == vdrs.Len()) + len(contactedFrontiersProviders) == vdrs.Count(ctx.SubnetID)) } func TestMalformedStateSummaryFrontiersAreDropped(t *testing.T) { require := require.New(t) - vdrs := buildTestPeers(t) - startupAlpha := (3*vdrs.Weight() + 3) / 4 + ctx := snow.DefaultConsensusContextTest() + vdrs := buildTestPeers(t, ctx.SubnetID) + totalWeight, err := vdrs.TotalWeight(ctx.SubnetID) + require.NoError(err) + startupAlpha := (3*totalWeight + 3) / 4 peers := tracker.NewPeers() startup := tracker.NewStartup(peers, startupAlpha) - vdrs.RegisterCallbackListener(startup) + vdrs.RegisterCallbackListener(ctx.SubnetID, startup) commonCfg := common.Config{ - Ctx: snow.DefaultConsensusContextTest(), + Ctx: ctx, Beacons: vdrs, - SampleK: vdrs.Len(), - Alpha: (vdrs.Weight() + 1) / 2, + SampleK: vdrs.Count(ctx.SubnetID), + Alpha: (totalWeight + 1) / 2, StartupTracker: startup, } syncer, fullVM, sender := buildTestsObjects(t, &commonCfg) @@ -383,7 +397,7 @@ func TestMalformedStateSummaryFrontiersAreDropped(t *testing.T) { } // Connect enough stake to start syncer - for nodeID := range vdrs.Map() { + for _, nodeID := range vdrs.GetValidatorIDs(ctx.SubnetID) { require.NoError(syncer.Connected(context.Background(), nodeID, version.CurrentApp)) } @@ -423,24 +437,27 @@ func TestMalformedStateSummaryFrontiersAreDropped(t *testing.T) { // are reached for data require.True( len(contactedFrontiersProviders) > initiallyReachedOutBeaconsSize || - len(contactedFrontiersProviders) == vdrs.Len()) + len(contactedFrontiersProviders) == vdrs.Count(ctx.SubnetID)) } func TestLateResponsesFromUnresponsiveFrontiersAreNotRecorded(t *testing.T) { require := require.New(t) - vdrs := buildTestPeers(t) - startupAlpha := (3*vdrs.Weight() + 3) / 4 + ctx := snow.DefaultConsensusContextTest() + vdrs := buildTestPeers(t, ctx.SubnetID) + totalWeight, err := vdrs.TotalWeight(ctx.SubnetID) + require.NoError(err) + startupAlpha := (3*totalWeight + 3) / 4 peers := tracker.NewPeers() startup := tracker.NewStartup(peers, startupAlpha) - vdrs.RegisterCallbackListener(startup) + vdrs.RegisterCallbackListener(ctx.SubnetID, startup) commonCfg := common.Config{ - Ctx: snow.DefaultConsensusContextTest(), + Ctx: ctx, Beacons: vdrs, - SampleK: vdrs.Len(), - Alpha: (vdrs.Weight() + 1) / 2, + SampleK: vdrs.Count(ctx.SubnetID), + Alpha: (totalWeight + 1) / 2, StartupTracker: startup, } syncer, fullVM, sender := buildTestsObjects(t, &commonCfg) @@ -455,7 +472,7 @@ func TestLateResponsesFromUnresponsiveFrontiersAreNotRecorded(t *testing.T) { } // Connect enough stake to start syncer - for nodeID := range vdrs.Map() { + for _, nodeID := range vdrs.GetValidatorIDs(ctx.SubnetID) { require.NoError(syncer.Connected(context.Background(), nodeID, version.CurrentApp)) } @@ -488,7 +505,7 @@ func TestLateResponsesFromUnresponsiveFrontiersAreNotRecorded(t *testing.T) { // are reached for data require.True( len(contactedFrontiersProviders) > initiallyReachedOutBeaconsSize || - len(contactedFrontiersProviders) == vdrs.Len()) + len(contactedFrontiersProviders) == vdrs.Count(ctx.SubnetID)) // mock VM to simulate a valid but late summary is returned fullVM.CantParseStateSummary = true @@ -515,18 +532,21 @@ func TestLateResponsesFromUnresponsiveFrontiersAreNotRecorded(t *testing.T) { func TestStateSyncIsRestartedIfTooManyFrontierSeedersTimeout(t *testing.T) { require := require.New(t) - vdrs := buildTestPeers(t) - startupAlpha := (3*vdrs.Weight() + 3) / 4 + ctx := snow.DefaultConsensusContextTest() + vdrs 
:= buildTestPeers(t, ctx.SubnetID) + totalWeight, err := vdrs.TotalWeight(ctx.SubnetID) + require.NoError(err) + startupAlpha := (3*totalWeight + 3) / 4 peers := tracker.NewPeers() startup := tracker.NewStartup(peers, startupAlpha) - vdrs.RegisterCallbackListener(startup) + vdrs.RegisterCallbackListener(ctx.SubnetID, startup) commonCfg := common.Config{ Ctx: snow.DefaultConsensusContextTest(), Beacons: vdrs, - SampleK: vdrs.Len(), - Alpha: (vdrs.Weight() + 1) / 2, + SampleK: vdrs.Count(ctx.SubnetID), + Alpha: (totalWeight + 1) / 2, StartupTracker: startup, RetryBootstrap: true, RetryBootstrapWarnFrequency: 1, @@ -568,7 +588,7 @@ func TestStateSyncIsRestartedIfTooManyFrontierSeedersTimeout(t *testing.T) { } // Connect enough stake to start syncer - for nodeID := range vdrs.Map() { + for _, nodeID := range vdrs.GetValidatorIDs(ctx.SubnetID) { require.NoError(syncer.Connected(context.Background(), nodeID, version.CurrentApp)) } require.NotEmpty(syncer.pendingSeeders) @@ -609,18 +629,21 @@ func TestStateSyncIsRestartedIfTooManyFrontierSeedersTimeout(t *testing.T) { func TestVoteRequestsAreSentAsAllFrontierBeaconsResponded(t *testing.T) { require := require.New(t) - vdrs := buildTestPeers(t) - startupAlpha := (3*vdrs.Weight() + 3) / 4 + ctx := snow.DefaultConsensusContextTest() + vdrs := buildTestPeers(t, ctx.SubnetID) + totalWeight, err := vdrs.TotalWeight(ctx.SubnetID) + require.NoError(err) + startupAlpha := (3*totalWeight + 3) / 4 peers := tracker.NewPeers() startup := tracker.NewStartup(peers, startupAlpha) - vdrs.RegisterCallbackListener(startup) + vdrs.RegisterCallbackListener(ctx.SubnetID, startup) commonCfg := common.Config{ - Ctx: snow.DefaultConsensusContextTest(), + Ctx: ctx, Beacons: vdrs, - SampleK: vdrs.Len(), - Alpha: (vdrs.Weight() + 1) / 2, + SampleK: vdrs.Count(ctx.SubnetID), + Alpha: (totalWeight + 1) / 2, StartupTracker: startup, } syncer, fullVM, sender := buildTestsObjects(t, &commonCfg) @@ -654,7 +677,7 @@ func TestVoteRequestsAreSentAsAllFrontierBeaconsResponded(t *testing.T) { } // Connect enough stake to start syncer - for nodeID := range vdrs.Map() { + for _, nodeID := range vdrs.GetValidatorIDs(ctx.SubnetID) { require.NoError(syncer.Connected(context.Background(), nodeID, version.CurrentApp)) } require.NotEmpty(syncer.pendingSeeders) @@ -683,18 +706,21 @@ func TestVoteRequestsAreSentAsAllFrontierBeaconsResponded(t *testing.T) { func TestUnRequestedVotesAreDropped(t *testing.T) { require := require.New(t) - vdrs := buildTestPeers(t) - startupAlpha := (3*vdrs.Weight() + 3) / 4 + ctx := snow.DefaultConsensusContextTest() + vdrs := buildTestPeers(t, ctx.SubnetID) + totalWeight, err := vdrs.TotalWeight(ctx.SubnetID) + require.NoError(err) + startupAlpha := (3*totalWeight + 3) / 4 peers := tracker.NewPeers() startup := tracker.NewStartup(peers, startupAlpha) - vdrs.RegisterCallbackListener(startup) + vdrs.RegisterCallbackListener(ctx.SubnetID, startup) commonCfg := common.Config{ - Ctx: snow.DefaultConsensusContextTest(), + Ctx: ctx, Beacons: vdrs, - SampleK: vdrs.Len(), - Alpha: (vdrs.Weight() + 1) / 2, + SampleK: vdrs.Count(ctx.SubnetID), + Alpha: (totalWeight + 1) / 2, StartupTracker: startup, } syncer, fullVM, sender := buildTestsObjects(t, &commonCfg) @@ -727,7 +753,7 @@ func TestUnRequestedVotesAreDropped(t *testing.T) { } // Connect enough stake to start syncer - for nodeID := range vdrs.Map() { + for _, nodeID := range vdrs.GetValidatorIDs(ctx.SubnetID) { require.NoError(syncer.Connected(context.Background(), nodeID, version.CurrentApp)) } 
require.NotEmpty(syncer.pendingSeeders) @@ -791,30 +817,33 @@ func TestUnRequestedVotesAreDropped(t *testing.T) { // responsiveBeacon not pending anymore require.NotContains(syncer.pendingSeeders, responsiveVoterID) - voterWeight := vdrs.GetWeight(responsiveVoterID) + voterWeight := vdrs.GetWeight(ctx.SubnetID, responsiveVoterID) require.Equal(voterWeight, syncer.weightedSummaries[summaryID].weight) // other listed voters are reached out require.True( len(contactedVoters) > initiallyContactedVotersSize || - len(contactedVoters) == vdrs.Len()) + len(contactedVoters) == vdrs.Count(ctx.SubnetID)) } func TestVotesForUnknownSummariesAreDropped(t *testing.T) { require := require.New(t) - vdrs := buildTestPeers(t) - startupAlpha := (3*vdrs.Weight() + 3) / 4 + ctx := snow.DefaultConsensusContextTest() + vdrs := buildTestPeers(t, ctx.SubnetID) + totalWeight, err := vdrs.TotalWeight(ctx.SubnetID) + require.NoError(err) + startupAlpha := (3*totalWeight + 3) / 4 peers := tracker.NewPeers() startup := tracker.NewStartup(peers, startupAlpha) - vdrs.RegisterCallbackListener(startup) + vdrs.RegisterCallbackListener(ctx.SubnetID, startup) commonCfg := common.Config{ - Ctx: snow.DefaultConsensusContextTest(), + Ctx: ctx, Beacons: vdrs, - SampleK: vdrs.Len(), - Alpha: (vdrs.Weight() + 1) / 2, + SampleK: vdrs.Count(ctx.SubnetID), + Alpha: (totalWeight + 1) / 2, StartupTracker: startup, } syncer, fullVM, sender := buildTestsObjects(t, &commonCfg) @@ -847,7 +876,7 @@ func TestVotesForUnknownSummariesAreDropped(t *testing.T) { } // Connect enough stake to start syncer - for nodeID := range vdrs.Map() { + for _, nodeID := range vdrs.GetValidatorIDs(ctx.SubnetID) { require.NoError(syncer.Connected(context.Background(), nodeID, version.CurrentApp)) } require.NotEmpty(syncer.pendingSeeders) @@ -903,24 +932,27 @@ func TestVotesForUnknownSummariesAreDropped(t *testing.T) { // on unknown summary require.True( len(contactedVoters) > initiallyContactedVotersSize || - len(contactedVoters) == vdrs.Len()) + len(contactedVoters) == vdrs.Count(ctx.SubnetID)) } func TestStateSummaryIsPassedToVMAsMajorityOfVotesIsCastedForIt(t *testing.T) { require := require.New(t) - vdrs := buildTestPeers(t) - startupAlpha := (3*vdrs.Weight() + 3) / 4 + ctx := snow.DefaultConsensusContextTest() + vdrs := buildTestPeers(t, ctx.SubnetID) + totalWeight, err := vdrs.TotalWeight(ctx.SubnetID) + require.NoError(err) + startupAlpha := (3*totalWeight + 3) / 4 peers := tracker.NewPeers() startup := tracker.NewStartup(peers, startupAlpha) - vdrs.RegisterCallbackListener(startup) + vdrs.RegisterCallbackListener(ctx.SubnetID, startup) commonCfg := common.Config{ - Ctx: snow.DefaultConsensusContextTest(), + Ctx: ctx, Beacons: vdrs, - SampleK: vdrs.Len(), - Alpha: (vdrs.Weight() + 1) / 2, + SampleK: vdrs.Count(ctx.SubnetID), + Alpha: (totalWeight + 1) / 2, StartupTracker: startup, } syncer, fullVM, sender := buildTestsObjects(t, &commonCfg) @@ -969,7 +1001,7 @@ func TestStateSummaryIsPassedToVMAsMajorityOfVotesIsCastedForIt(t *testing.T) { } // Connect enough stake to start syncer - for nodeID := range vdrs.Map() { + for _, nodeID := range vdrs.GetValidatorIDs(ctx.SubnetID) { require.NoError(syncer.Connected(context.Background(), nodeID, version.CurrentApp)) } require.NotEmpty(syncer.pendingSeeders) @@ -1028,7 +1060,7 @@ func TestStateSummaryIsPassedToVMAsMajorityOfVotesIsCastedForIt(t *testing.T) { reqID, []ids.ID{summaryID, minoritySummaryID}, )) - cumulatedWeight += vdrs.GetWeight(voterID) + cumulatedWeight += vdrs.GetWeight(ctx.SubnetID, voterID) case 
cumulatedWeight < commonCfg.Alpha: require.NoError(syncer.AcceptedStateSummary( @@ -1037,7 +1069,7 @@ func TestStateSummaryIsPassedToVMAsMajorityOfVotesIsCastedForIt(t *testing.T) { reqID, []ids.ID{summaryID}, )) - cumulatedWeight += vdrs.GetWeight(voterID) + cumulatedWeight += vdrs.GetWeight(ctx.SubnetID, voterID) default: require.NoError(syncer.GetAcceptedStateSummaryFailed( @@ -1056,18 +1088,21 @@ func TestStateSummaryIsPassedToVMAsMajorityOfVotesIsCastedForIt(t *testing.T) { func TestVotingIsRestartedIfMajorityIsNotReachedDueToTimeouts(t *testing.T) { require := require.New(t) - vdrs := buildTestPeers(t) - startupAlpha := (3*vdrs.Weight() + 3) / 4 + ctx := snow.DefaultConsensusContextTest() + vdrs := buildTestPeers(t, ctx.SubnetID) + totalWeight, err := vdrs.TotalWeight(ctx.SubnetID) + require.NoError(err) + startupAlpha := (3*totalWeight + 3) / 4 peers := tracker.NewPeers() startup := tracker.NewStartup(peers, startupAlpha) - vdrs.RegisterCallbackListener(startup) + vdrs.RegisterCallbackListener(ctx.SubnetID, startup) commonCfg := common.Config{ Ctx: snow.DefaultConsensusContextTest(), Beacons: vdrs, - SampleK: vdrs.Len(), - Alpha: (vdrs.Weight() + 1) / 2, + SampleK: vdrs.Count(ctx.SubnetID), + Alpha: (totalWeight + 1) / 2, StartupTracker: startup, RetryBootstrap: true, // this sets RetryStateSyncing too RetryBootstrapWarnFrequency: 1, // this sets RetrySyncingWarnFrequency too @@ -1104,7 +1139,7 @@ func TestVotingIsRestartedIfMajorityIsNotReachedDueToTimeouts(t *testing.T) { } // Connect enough stake to start syncer - for nodeID := range vdrs.Map() { + for _, nodeID := range vdrs.GetValidatorIDs(ctx.SubnetID) { require.NoError(syncer.Connected(context.Background(), nodeID, version.CurrentApp)) } require.NotEmpty(syncer.pendingSeeders) @@ -1144,7 +1179,7 @@ func TestVotingIsRestartedIfMajorityIsNotReachedDueToTimeouts(t *testing.T) { voterID, reqID, )) - timedOutWeight += vdrs.GetWeight(voterID) + timedOutWeight += vdrs.GetWeight(ctx.SubnetID, voterID) } else { require.NoError(syncer.AcceptedStateSummary( context.Background(), @@ -1166,18 +1201,21 @@ func TestVotingIsRestartedIfMajorityIsNotReachedDueToTimeouts(t *testing.T) { func TestStateSyncIsStoppedIfEnoughVotesAreCastedWithNoClearMajority(t *testing.T) { require := require.New(t) - vdrs := buildTestPeers(t) - startupAlpha := (3*vdrs.Weight() + 3) / 4 + ctx := snow.DefaultConsensusContextTest() + vdrs := buildTestPeers(t, ctx.SubnetID) + totalWeight, err := vdrs.TotalWeight(ctx.SubnetID) + require.NoError(err) + startupAlpha := (3*totalWeight + 3) / 4 peers := tracker.NewPeers() startup := tracker.NewStartup(peers, startupAlpha) - vdrs.RegisterCallbackListener(startup) + vdrs.RegisterCallbackListener(ctx.SubnetID, startup) commonCfg := common.Config{ - Ctx: snow.DefaultConsensusContextTest(), + Ctx: ctx, Beacons: vdrs, - SampleK: vdrs.Len(), - Alpha: (vdrs.Weight() + 1) / 2, + SampleK: vdrs.Count(ctx.SubnetID), + Alpha: (totalWeight + 1) / 2, StartupTracker: startup, } syncer, fullVM, sender := buildTestsObjects(t, &commonCfg) @@ -1226,7 +1264,7 @@ func TestStateSyncIsStoppedIfEnoughVotesAreCastedWithNoClearMajority(t *testing. } // Connect enough stake to start syncer - for nodeID := range vdrs.Map() { + for _, nodeID := range vdrs.GetValidatorIDs(ctx.SubnetID) { require.NoError(syncer.Connected(context.Background(), nodeID, version.CurrentApp)) } require.NotEmpty(syncer.pendingSeeders) @@ -1292,7 +1330,7 @@ func TestStateSyncIsStoppedIfEnoughVotesAreCastedWithNoClearMajority(t *testing. 
reqID, []ids.ID{minoritySummary1.ID(), minoritySummary2.ID()}, )) - votingWeightStake += vdrs.GetWeight(voterID) + votingWeightStake += vdrs.GetWeight(ctx.SubnetID, voterID) default: require.NoError(syncer.AcceptedStateSummary( @@ -1301,7 +1339,7 @@ func TestStateSyncIsStoppedIfEnoughVotesAreCastedWithNoClearMajority(t *testing. reqID, []ids.ID{{'u', 'n', 'k', 'n', 'o', 'w', 'n', 'I', 'D'}}, )) - votingWeightStake += vdrs.GetWeight(voterID) + votingWeightStake += vdrs.GetWeight(ctx.SubnetID, voterID) } } @@ -1314,18 +1352,21 @@ func TestStateSyncIsStoppedIfEnoughVotesAreCastedWithNoClearMajority(t *testing. func TestStateSyncIsDoneOnceVMNotifies(t *testing.T) { require := require.New(t) - vdrs := buildTestPeers(t) - startupAlpha := (3*vdrs.Weight() + 3) / 4 + ctx := snow.DefaultConsensusContextTest() + vdrs := buildTestPeers(t, ctx.SubnetID) + totalWeight, err := vdrs.TotalWeight(ctx.SubnetID) + require.NoError(err) + startupAlpha := (3*totalWeight + 3) / 4 peers := tracker.NewPeers() startup := tracker.NewStartup(peers, startupAlpha) - vdrs.RegisterCallbackListener(startup) + vdrs.RegisterCallbackListener(ctx.SubnetID, startup) commonCfg := common.Config{ Ctx: snow.DefaultConsensusContextTest(), Beacons: vdrs, - SampleK: vdrs.Len(), - Alpha: (vdrs.Weight() + 1) / 2, + SampleK: vdrs.Count(ctx.SubnetID), + Alpha: (totalWeight + 1) / 2, StartupTracker: startup, RetryBootstrap: true, // this sets RetryStateSyncing too RetryBootstrapWarnFrequency: 1, // this sets RetrySyncingWarnFrequency too diff --git a/snow/engine/snowman/syncer/utils_test.go b/snow/engine/snowman/syncer/utils_test.go index a39a81a2b8c6..f83a3006aaa1 100644 --- a/snow/engine/snowman/syncer/utils_test.go +++ b/snow/engine/snowman/syncer/utils_test.go @@ -54,13 +54,13 @@ type fullVM struct { *block.TestStateSyncableVM } -func buildTestPeers(t *testing.T) validators.Set { +func buildTestPeers(t *testing.T, subnetID ids.ID) validators.Manager { // we consider more than common.MaxOutstandingBroadcastRequests peers // so to test the effect of cap on number of requests sent out - vdrs := validators.NewSet() + vdrs := validators.NewManager() for idx := 0; idx < 2*common.MaxOutstandingBroadcastRequests; idx++ { beaconID := ids.GenerateTestNodeID() - require.NoError(t, vdrs.Add(beaconID, nil, ids.Empty, 1)) + require.NoError(t, vdrs.AddStaker(subnetID, beaconID, nil, ids.Empty, 1)) } return vdrs } diff --git a/snow/engine/snowman/transitive.go b/snow/engine/snowman/transitive.go index d61e36128006..803c03237c96 100644 --- a/snow/engine/snowman/transitive.go +++ b/snow/engine/snowman/transitive.go @@ -110,7 +110,7 @@ func newTransitive(config Config) (*Transitive, error) { } acceptedFrontiers := tracker.NewAccepted() - config.Validators.RegisterCallbackListener(acceptedFrontiers) + config.Validators.RegisterCallbackListener(config.Ctx.SubnetID, acceptedFrontiers) factory := poll.NewEarlyTermNoTraversalFactory( config.Params.AlphaPreference, @@ -367,6 +367,10 @@ func (*Transitive) Halt(context.Context) {} func (t *Transitive) Shutdown(ctx context.Context) error { t.Ctx.Log.Info("shutting down consensus engine") + + t.Ctx.Lock.Lock() + defer t.Ctx.Lock.Unlock() + return t.VM.Shutdown(ctx) } @@ -453,6 +457,9 @@ func (t *Transitive) Start(ctx context.Context, startReqID uint32) error { } func (t *Transitive) HealthCheck(ctx context.Context) (interface{}, error) { + t.Ctx.Lock.Lock() + defer t.Ctx.Lock.Unlock() + consensusIntf, consensusErr := t.Consensus.HealthCheck(ctx) vmIntf, vmErr := t.VM.HealthCheck(ctx) intf := 
map[string]interface{}{ @@ -784,7 +791,7 @@ func (t *Transitive) sendQuery( zap.Stringer("validators", t.Validators), ) - vdrIDs, err := t.Validators.Sample(t.Params.K) + vdrIDs, err := t.Validators.Sample(t.Ctx.SubnetID, t.Params.K) if err != nil { t.Ctx.Log.Error("dropped query for block", zap.String("reason", "insufficient number of validators"), @@ -958,9 +965,12 @@ func (t *Transitive) addToNonVerifieds(blk snowman.Block) { // addUnverifiedBlockToConsensus returns whether the block was added and an // error if one occurred while adding it to consensus. func (t *Transitive) addUnverifiedBlockToConsensus(ctx context.Context, blk snowman.Block) (bool, error) { + blkID := blk.ID() + // make sure this block is valid if err := blk.Verify(ctx); err != nil { t.Ctx.Log.Debug("block verification failed", + zap.Stringer("blkID", blkID), zap.Error(err), ) @@ -969,7 +979,6 @@ func (t *Transitive) addUnverifiedBlockToConsensus(ctx context.Context, blk snow return false, nil } - blkID := blk.ID() t.nonVerifieds.Remove(blkID) t.nonVerifiedCache.Evict(blkID) t.metrics.numNonVerifieds.Set(float64(t.nonVerifieds.Len())) diff --git a/snow/engine/snowman/transitive_test.go b/snow/engine/snowman/transitive_test.go index bb229b5cb530..8993a4e90f9b 100644 --- a/snow/engine/snowman/transitive_test.go +++ b/snow/engine/snowman/transitive_test.go @@ -32,14 +32,14 @@ var ( Genesis = ids.GenerateTestID() ) -func setup(t *testing.T, commonCfg common.Config, engCfg Config) (ids.NodeID, validators.Set, *common.SenderTest, *block.TestVM, *Transitive, snowman.Block) { +func setup(t *testing.T, commonCfg common.Config, engCfg Config) (ids.NodeID, validators.Manager, *common.SenderTest, *block.TestVM, *Transitive, snowman.Block) { require := require.New(t) - vals := validators.NewSet() + vals := validators.NewManager() engCfg.Validators = vals vdr := ids.GenerateTestNodeID() - require.NoError(vals.Add(vdr, nil, ids.Empty, 1)) + require.NoError(vals.AddStaker(commonCfg.Ctx.SubnetID, vdr, nil, ids.Empty, 1)) sender := &common.SenderTest{T: t} engCfg.Sender = sender @@ -86,7 +86,7 @@ func setup(t *testing.T, commonCfg common.Config, engCfg Config) (ids.NodeID, va return vdr, vals, sender, vm, te, gBlk } -func setupDefaultConfig(t *testing.T) (ids.NodeID, validators.Set, *common.SenderTest, *block.TestVM, *Transitive, snowman.Block) { +func setupDefaultConfig(t *testing.T) (ids.NodeID, validators.Manager, *common.SenderTest, *block.TestVM, *Transitive, snowman.Block) { commonCfg := common.DefaultConfigTest() engCfg := DefaultConfigs() return setup(t, commonCfg, engCfg) @@ -331,16 +331,16 @@ func TestEngineMultipleQuery(t *testing.T) { MaxItemProcessingTime: 1, } - vals := validators.NewSet() + vals := validators.NewManager() engCfg.Validators = vals vdr0 := ids.GenerateTestNodeID() vdr1 := ids.GenerateTestNodeID() vdr2 := ids.GenerateTestNodeID() - require.NoError(vals.Add(vdr0, nil, ids.Empty, 1)) - require.NoError(vals.Add(vdr1, nil, ids.Empty, 1)) - require.NoError(vals.Add(vdr2, nil, ids.Empty, 1)) + require.NoError(vals.AddStaker(engCfg.Ctx.SubnetID, vdr0, nil, ids.Empty, 1)) + require.NoError(vals.AddStaker(engCfg.Ctx.SubnetID, vdr1, nil, ids.Empty, 1)) + require.NoError(vals.AddStaker(engCfg.Ctx.SubnetID, vdr2, nil, ids.Empty, 1)) sender := &common.SenderTest{T: t} engCfg.Sender = sender @@ -726,16 +726,16 @@ func TestVoteCanceling(t *testing.T) { MaxItemProcessingTime: 1, } - vals := validators.NewSet() + vals := validators.NewManager() engCfg.Validators = vals vdr0 := ids.GenerateTestNodeID() vdr1 := 
ids.GenerateTestNodeID() vdr2 := ids.GenerateTestNodeID() - require.NoError(vals.Add(vdr0, nil, ids.Empty, 1)) - require.NoError(vals.Add(vdr1, nil, ids.Empty, 1)) - require.NoError(vals.Add(vdr2, nil, ids.Empty, 1)) + require.NoError(vals.AddStaker(engCfg.Ctx.SubnetID, vdr0, nil, ids.Empty, 1)) + require.NoError(vals.AddStaker(engCfg.Ctx.SubnetID, vdr1, nil, ids.Empty, 1)) + require.NoError(vals.AddStaker(engCfg.Ctx.SubnetID, vdr2, nil, ids.Empty, 1)) sender := &common.SenderTest{T: t} engCfg.Sender = sender @@ -1313,7 +1313,7 @@ func TestEngineInvalidBlockIgnoredFromUnexpectedPeer(t *testing.T) { vdr, vdrs, sender, vm, te, gBlk := setupDefaultConfig(t) secondVdr := ids.GenerateTestNodeID() - require.NoError(vdrs.Add(secondVdr, nil, ids.Empty, 1)) + require.NoError(vdrs.AddStaker(te.Ctx.SubnetID, secondVdr, nil, ids.Empty, 1)) sender.Default(true) @@ -1499,11 +1499,11 @@ func TestEngineAggressivePolling(t *testing.T) { engCfg := DefaultConfigs() engCfg.Params.ConcurrentRepolls = 2 - vals := validators.NewSet() + vals := validators.NewManager() engCfg.Validators = vals vdr := ids.GenerateTestNodeID() - require.NoError(vals.Add(vdr, nil, ids.Empty, 1)) + require.NoError(vals.AddStaker(engCfg.Ctx.SubnetID, vdr, nil, ids.Empty, 1)) sender := &common.SenderTest{T: t} engCfg.Sender = sender @@ -1597,14 +1597,14 @@ func TestEngineDoubleChit(t *testing.T) { MaxItemProcessingTime: 1, } - vals := validators.NewSet() + vals := validators.NewManager() engCfg.Validators = vals vdr0 := ids.GenerateTestNodeID() vdr1 := ids.GenerateTestNodeID() - require.NoError(vals.Add(vdr0, nil, ids.Empty, 1)) - require.NoError(vals.Add(vdr1, nil, ids.Empty, 1)) + require.NoError(vals.AddStaker(engCfg.Ctx.SubnetID, vdr0, nil, ids.Empty, 1)) + require.NoError(vals.AddStaker(engCfg.Ctx.SubnetID, vdr1, nil, ids.Empty, 1)) sender := &common.SenderTest{T: t} engCfg.Sender = sender @@ -1694,11 +1694,11 @@ func TestEngineBuildBlockLimit(t *testing.T) { engCfg.Params.AlphaConfidence = 1 engCfg.Params.OptimalProcessing = 1 - vals := validators.NewSet() + vals := validators.NewManager() engCfg.Validators = vals vdr := ids.GenerateTestNodeID() - require.NoError(vals.Add(vdr, nil, ids.Empty, 1)) + require.NoError(vals.AddStaker(engCfg.Ctx.SubnetID, vdr, nil, ids.Empty, 1)) sender := &common.SenderTest{T: t} engCfg.Sender = sender @@ -2723,11 +2723,11 @@ func TestEngineApplyAcceptedFrontierInQueryFailed(t *testing.T) { MaxItemProcessingTime: 1, } - vals := validators.NewSet() + vals := validators.NewManager() engCfg.Validators = vals vdr := ids.GenerateTestNodeID() - require.NoError(vals.Add(vdr, nil, ids.Empty, 1)) + require.NoError(vals.AddStaker(engCfg.Ctx.SubnetID, vdr, nil, ids.Empty, 1)) sender := &common.SenderTest{T: t} engCfg.Sender = sender diff --git a/snow/networking/benchlist/benchlist.go b/snow/networking/benchlist/benchlist.go index 4f7cb20ab244..394899a1f37a 100644 --- a/snow/networking/benchlist/benchlist.go +++ b/snow/networking/benchlist/benchlist.go @@ -4,19 +4,17 @@ package benchlist import ( - "container/heap" "fmt" "math/rand" "sync" "time" - "github.com/prometheus/client_golang/prometheus" - "go.uber.org/zap" "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/snow" "github.com/ava-labs/avalanchego/snow/validators" - "github.com/ava-labs/avalanchego/utils/logging" + "github.com/ava-labs/avalanchego/utils/heap" "github.com/ava-labs/avalanchego/utils/set" "github.com/ava-labs/avalanchego/utils/timer" "github.com/ava-labs/avalanchego/utils/timer/mockable" @@ -24,8 +22,6 @@ import ( 
safemath "github.com/ava-labs/avalanchego/utils/math" ) -var _ heap.Interface = (*benchedQueue)(nil) - // If a peer consistently does not respond to queries, it will // increase latencies on the network whenever that peer is polled. // If we cannot terminate the poll early, then the poll will wait @@ -45,46 +41,6 @@ type Benchlist interface { IsBenched(nodeID ids.NodeID) bool } -// Data about a validator who is benched -type benchData struct { - benchedUntil time.Time - nodeID ids.NodeID - index int -} - -// Each element is a benched validator -type benchedQueue []*benchData - -func (bq benchedQueue) Len() int { - return len(bq) -} - -func (bq benchedQueue) Less(i, j int) bool { - return bq[i].benchedUntil.Before(bq[j].benchedUntil) -} - -func (bq benchedQueue) Swap(i, j int) { - bq[i], bq[j] = bq[j], bq[i] - bq[i].index = i - bq[j].index = j -} - -// Push adds an item to this queue. x must have type *benchData -func (bq *benchedQueue) Push(x interface{}) { - item := x.(*benchData) - item.index = len(*bq) - *bq = append(*bq, item) -} - -// Pop returns the validator that should leave the bench next -func (bq *benchedQueue) Pop() interface{} { - n := len(*bq) - item := (*bq)[n-1] - (*bq)[n-1] = nil // make sure the item is freed from memory - *bq = (*bq)[:n-1] - return item -} - type failureStreak struct { // Time of first consecutive timeout firstFailure time.Time @@ -94,9 +50,8 @@ type failureStreak struct { type benchlist struct { lock sync.RWMutex - // This is the benchlist for chain [chainID] - chainID ids.ID - log logging.Logger + // Context of the chain this is the benchlist for + ctx *snow.ConsensusContext metrics metrics // Fires when the next validator should leave the bench @@ -110,7 +65,7 @@ type benchlist struct { benchable Benchable // Validator set of the network - vdrs validators.Set + vdrs validators.Manager // Validator ID --> Consecutive failure information // [streaklock] must be held when touching [failureStreaks] @@ -120,9 +75,8 @@ type benchlist struct { // IDs of validators that are currently benched benchlistSet set.Set[ids.NodeID] - // Min heap containing benched validators and their endtimes - // Pop() returns the next validator to leave - benchedQueue benchedQueue + // Min heap of benched validators ordered by when they can be unbenched + benchedHeap heap.Map[ids.NodeID, time.Time] // A validator will be benched if [threshold] messages in a row // to them time out and the first of those messages was more than @@ -140,25 +94,23 @@ type benchlist struct { // NewBenchlist returns a new Benchlist func NewBenchlist( - chainID ids.ID, - log logging.Logger, + ctx *snow.ConsensusContext, benchable Benchable, - validators validators.Set, + validators validators.Manager, threshold int, minimumFailingDuration, duration time.Duration, maxPortion float64, - registerer prometheus.Registerer, ) (Benchlist, error) { if maxPortion < 0 || maxPortion >= 1 { return nil, fmt.Errorf("max portion of benched stake must be in [0,1) but got %f", maxPortion) } benchlist := &benchlist{ - chainID: chainID, - log: log, + ctx: ctx, failureStreaks: make(map[ids.NodeID]failureStreak), benchlistSet: set.Set[ids.NodeID]{}, benchable: benchable, + benchedHeap: heap.NewMap[ids.NodeID, time.Time](time.Time.Before), vdrs: validators, threshold: threshold, minimumFailingDuration: minimumFailingDuration, @@ -167,7 +119,7 @@ func NewBenchlist( } benchlist.timer = timer.NewTimer(benchlist.update) go benchlist.timer.Dispatch() - return benchlist, benchlist.metrics.Initialize(registerer) + return benchlist, 
benchlist.metrics.Initialize(ctx.Registerer) } // Update removes benched validators whose time on the bench is over @@ -177,60 +129,59 @@ func (b *benchlist) update() { now := b.clock.Time() for { - // [next] is nil when no more validators should - // leave the bench at this time - next := b.nextToLeave(now) - if next == nil { + if !b.canUnbench(now) { break } - b.remove(next) + b.remove() } // Set next time update will be called b.setNextLeaveTime() } -// Remove [validator] from the benchlist +// Removes the next node from the benchlist // Assumes [b.lock] is held -func (b *benchlist) remove(node *benchData) { - // Update state - id := node.nodeID - b.log.Debug("removing node from benchlist", - zap.Stringer("nodeID", id), +func (b *benchlist) remove() { + nodeID, _, _ := b.benchedHeap.Pop() + b.ctx.Log.Debug("removing node from benchlist", + zap.Stringer("nodeID", nodeID), ) - heap.Remove(&b.benchedQueue, node.index) - b.benchlistSet.Remove(id) - b.benchable.Unbenched(b.chainID, id) + b.benchlistSet.Remove(nodeID) + b.benchable.Unbenched(b.ctx.ChainID, nodeID) // Update metrics - b.metrics.numBenched.Set(float64(b.benchedQueue.Len())) - benchedStake := b.vdrs.SubsetWeight(b.benchlistSet) + b.metrics.numBenched.Set(float64(b.benchedHeap.Len())) + benchedStake, err := b.vdrs.SubsetWeight(b.ctx.SubnetID, b.benchlistSet) + if err != nil { + b.ctx.Log.Error("error calculating benched stake", + zap.Stringer("subnetID", b.ctx.SubnetID), + zap.Error(err), + ) + return + } b.metrics.weightBenched.Set(float64(benchedStake)) } -// Returns the next validator that should leave -// the bench at time [now]. nil if no validator should. +// Returns if a validator should leave the bench at time [now]. +// False if no validator should. // Assumes [b.lock] is held -func (b *benchlist) nextToLeave(now time.Time) *benchData { - if b.benchedQueue.Len() == 0 { - return nil +func (b *benchlist) canUnbench(now time.Time) bool { + _, next, ok := b.benchedHeap.Peek() + if !ok { + return false } - next := b.benchedQueue[0] - if now.Before(next.benchedUntil) { - return nil - } - return next + return now.After(next) } // Set [b.timer] to fire when the next validator should leave the bench // Assumes [b.lock] is held func (b *benchlist) setNextLeaveTime() { - if b.benchedQueue.Len() == 0 { + _, next, ok := b.benchedHeap.Peek() + if !ok { b.timer.Cancel() return } now := b.clock.Time() - next := b.benchedQueue[0] - nextLeave := next.benchedUntil.Sub(now) + nextLeave := next.Sub(now) b.timer.SetTimeoutIn(nextLeave) } @@ -290,28 +241,44 @@ func (b *benchlist) RegisterFailure(nodeID ids.NodeID) { // Assumes [b.lock] is held // Assumes [nodeID] is not already benched func (b *benchlist) bench(nodeID ids.NodeID) { - validatorStake := b.vdrs.GetWeight(nodeID) + validatorStake := b.vdrs.GetWeight(b.ctx.SubnetID, nodeID) if validatorStake == 0 { // We might want to bench a non-validator because they don't respond to // my Get requests, but we choose to only bench validators. 
return } - benchedStake := b.vdrs.SubsetWeight(b.benchlistSet) + benchedStake, err := b.vdrs.SubsetWeight(b.ctx.SubnetID, b.benchlistSet) + if err != nil { + b.ctx.Log.Error("error calculating benched stake", + zap.Stringer("subnetID", b.ctx.SubnetID), + zap.Error(err), + ) + return + } + newBenchedStake, err := safemath.Add64(benchedStake, validatorStake) if err != nil { // This should never happen - b.log.Error("overflow calculating new benched stake", + b.ctx.Log.Error("overflow calculating new benched stake", zap.Stringer("nodeID", nodeID), ) return } - totalStake := b.vdrs.Weight() + totalStake, err := b.vdrs.TotalWeight(b.ctx.SubnetID) + if err != nil { + b.ctx.Log.Error("error calculating total stake", + zap.Stringer("subnetID", b.ctx.SubnetID), + zap.Error(err), + ) + return + } + maxBenchedStake := float64(totalStake) * b.maxPortion if float64(newBenchedStake) > maxBenchedStake { - b.log.Debug("not benching node", + b.ctx.Log.Debug("not benching node", zap.String("reason", "benched stake would exceed max"), zap.Stringer("nodeID", nodeID), zap.Float64("benchedStake", float64(newBenchedStake)), @@ -330,17 +297,14 @@ func (b *benchlist) bench(nodeID ids.NodeID) { // Add to benchlist times with randomized delay b.benchlistSet.Add(nodeID) - b.benchable.Benched(b.chainID, nodeID) + b.benchable.Benched(b.ctx.ChainID, nodeID) b.streaklock.Lock() delete(b.failureStreaks, nodeID) b.streaklock.Unlock() - heap.Push( - &b.benchedQueue, - &benchData{nodeID: nodeID, benchedUntil: benchedUntil}, - ) - b.log.Debug("benching validator after consecutive failed queries", + b.benchedHeap.Push(nodeID, benchedUntil) + b.ctx.Log.Debug("benching validator after consecutive failed queries", zap.Stringer("nodeID", nodeID), zap.Duration("benchDuration", benchedUntil.Sub(now)), zap.Int("numFailedQueries", b.threshold), @@ -350,6 +314,6 @@ func (b *benchlist) bench(nodeID ids.NodeID) { b.setNextLeaveTime() // Update metrics - b.metrics.numBenched.Set(float64(b.benchedQueue.Len())) + b.metrics.numBenched.Set(float64(b.benchedHeap.Len())) b.metrics.weightBenched.Set(float64(newBenchedStake)) } diff --git a/snow/networking/benchlist/benchlist_test.go b/snow/networking/benchlist/benchlist_test.go index a33abd943c59..75df4f454292 100644 --- a/snow/networking/benchlist/benchlist_test.go +++ b/snow/networking/benchlist/benchlist_test.go @@ -7,13 +7,11 @@ import ( "testing" "time" - "github.com/prometheus/client_golang/prometheus" - "github.com/stretchr/testify/require" "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/snow" "github.com/ava-labs/avalanchego/snow/validators" - "github.com/ava-labs/avalanchego/utils/logging" ) var minimumFailingDuration = 5 * time.Minute @@ -22,18 +20,19 @@ var minimumFailingDuration = 5 * time.Minute func TestBenchlistAdd(t *testing.T) { require := require.New(t) - vdrs := validators.NewSet() + ctx := snow.DefaultConsensusContextTest() + vdrs := validators.NewManager() vdrID0 := ids.GenerateTestNodeID() vdrID1 := ids.GenerateTestNodeID() vdrID2 := ids.GenerateTestNodeID() vdrID3 := ids.GenerateTestNodeID() vdrID4 := ids.GenerateTestNodeID() - require.NoError(vdrs.Add(vdrID0, nil, ids.Empty, 50)) - require.NoError(vdrs.Add(vdrID1, nil, ids.Empty, 50)) - require.NoError(vdrs.Add(vdrID2, nil, ids.Empty, 50)) - require.NoError(vdrs.Add(vdrID3, nil, ids.Empty, 50)) - require.NoError(vdrs.Add(vdrID4, nil, ids.Empty, 50)) + require.NoError(vdrs.AddStaker(ctx.SubnetID, vdrID0, nil, ids.Empty, 50)) + require.NoError(vdrs.AddStaker(ctx.SubnetID, vdrID1, nil, ids.Empty, 
50)) + require.NoError(vdrs.AddStaker(ctx.SubnetID, vdrID2, nil, ids.Empty, 50)) + require.NoError(vdrs.AddStaker(ctx.SubnetID, vdrID3, nil, ids.Empty, 50)) + require.NoError(vdrs.AddStaker(ctx.SubnetID, vdrID4, nil, ids.Empty, 50)) benchable := &TestBenchable{T: t} benchable.Default(true) @@ -42,15 +41,13 @@ func TestBenchlistAdd(t *testing.T) { duration := time.Minute maxPortion := 0.5 benchIntf, err := NewBenchlist( - ids.Empty, - logging.NoLog{}, + ctx, benchable, vdrs, threshold, minimumFailingDuration, duration, maxPortion, - prometheus.NewRegistry(), ) require.NoError(err) b := benchIntf.(*benchlist) @@ -66,7 +63,7 @@ func TestBenchlistAdd(t *testing.T) { require.False(b.isBenched(vdrID3)) require.False(b.isBenched(vdrID4)) require.Empty(b.failureStreaks) - require.Empty(b.benchedQueue) + require.Zero(b.benchedHeap.Len()) require.Empty(b.benchlistSet) b.lock.Unlock() @@ -77,7 +74,7 @@ func TestBenchlistAdd(t *testing.T) { // Still shouldn't be benched due to not enough consecutive failure require.False(b.isBenched(vdrID0)) - require.Empty(b.benchedQueue) + require.Zero(b.benchedHeap.Len()) require.Empty(b.benchlistSet) require.Len(b.failureStreaks, 1) fs := b.failureStreaks[vdrID0] @@ -91,7 +88,7 @@ func TestBenchlistAdd(t *testing.T) { // has passed since the first failure b.lock.Lock() require.False(b.isBenched(vdrID0)) - require.Empty(b.benchedQueue) + require.Zero(b.benchedHeap.Len()) require.Empty(b.benchlistSet) b.lock.Unlock() @@ -112,13 +109,14 @@ func TestBenchlistAdd(t *testing.T) { // Now this validator should be benched b.lock.Lock() require.True(b.isBenched(vdrID0)) - require.Equal(b.benchedQueue.Len(), 1) + require.Equal(b.benchedHeap.Len(), 1) require.Equal(b.benchlistSet.Len(), 1) - next := b.benchedQueue[0] - require.Equal(vdrID0, next.nodeID) - require.False(next.benchedUntil.After(now.Add(duration))) - require.False(next.benchedUntil.Before(now.Add(duration / 2))) + nodeID, benchedUntil, ok := b.benchedHeap.Peek() + require.True(ok) + require.Equal(vdrID0, nodeID) + require.False(benchedUntil.After(now.Add(duration))) + require.False(benchedUntil.Before(now.Add(duration / 2))) require.Empty(b.failureStreaks) require.True(benched) benchable.BenchedF = nil @@ -137,7 +135,7 @@ func TestBenchlistAdd(t *testing.T) { b.lock.Lock() require.True(b.isBenched(vdrID0)) require.False(b.isBenched(vdrID1)) - require.Equal(b.benchedQueue.Len(), 1) + require.Equal(b.benchedHeap.Len(), 1) require.Equal(b.benchlistSet.Len(), 1) require.Empty(b.failureStreaks) b.lock.Unlock() @@ -155,7 +153,8 @@ func TestBenchlistAdd(t *testing.T) { func TestBenchlistMaxStake(t *testing.T) { require := require.New(t) - vdrs := validators.NewSet() + ctx := snow.DefaultConsensusContextTest() + vdrs := validators.NewManager() vdrID0 := ids.GenerateTestNodeID() vdrID1 := ids.GenerateTestNodeID() vdrID2 := ids.GenerateTestNodeID() @@ -163,26 +162,24 @@ func TestBenchlistMaxStake(t *testing.T) { vdrID4 := ids.GenerateTestNodeID() // Total weight is 5100 - require.NoError(vdrs.Add(vdrID0, nil, ids.Empty, 1000)) - require.NoError(vdrs.Add(vdrID1, nil, ids.Empty, 1000)) - require.NoError(vdrs.Add(vdrID2, nil, ids.Empty, 1000)) - require.NoError(vdrs.Add(vdrID3, nil, ids.Empty, 2000)) - require.NoError(vdrs.Add(vdrID4, nil, ids.Empty, 100)) + require.NoError(vdrs.AddStaker(ctx.SubnetID, vdrID0, nil, ids.Empty, 1000)) + require.NoError(vdrs.AddStaker(ctx.SubnetID, vdrID1, nil, ids.Empty, 1000)) + require.NoError(vdrs.AddStaker(ctx.SubnetID, vdrID2, nil, ids.Empty, 1000)) + 
require.NoError(vdrs.AddStaker(ctx.SubnetID, vdrID3, nil, ids.Empty, 2000)) + require.NoError(vdrs.AddStaker(ctx.SubnetID, vdrID4, nil, ids.Empty, 100)) threshold := 3 duration := 1 * time.Hour // Shouldn't bench more than 2550 (5100/2) maxPortion := 0.5 benchIntf, err := NewBenchlist( - ids.Empty, - logging.NoLog{}, + ctx, &TestBenchable{T: t}, vdrs, threshold, minimumFailingDuration, duration, maxPortion, - prometheus.NewRegistry(), ) require.NoError(err) b := benchIntf.(*benchlist) @@ -215,7 +212,7 @@ func TestBenchlistMaxStake(t *testing.T) { require.True(b.isBenched(vdrID0)) require.True(b.isBenched(vdrID1)) require.False(b.isBenched(vdrID2)) - require.Equal(b.benchedQueue.Len(), 2) + require.Equal(b.benchedHeap.Len(), 2) require.Equal(b.benchlistSet.Len(), 2) require.Len(b.failureStreaks, 1) fs := b.failureStreaks[vdrID2] @@ -242,7 +239,7 @@ func TestBenchlistMaxStake(t *testing.T) { require.True(b.isBenched(vdrID0)) require.True(b.isBenched(vdrID1)) require.True(b.isBenched(vdrID4)) - require.Equal(3, b.benchedQueue.Len()) + require.Equal(3, b.benchedHeap.Len()) require.Equal(3, b.benchlistSet.Len()) require.Contains(b.benchlistSet, vdrID0) require.Contains(b.benchlistSet, vdrID1) @@ -261,19 +258,10 @@ func TestBenchlistMaxStake(t *testing.T) { require.True(b.isBenched(vdrID1)) require.True(b.isBenched(vdrID4)) require.False(b.isBenched(vdrID2)) - require.Equal(3, b.benchedQueue.Len()) + require.Equal(3, b.benchedHeap.Len()) require.Equal(3, b.benchlistSet.Len()) require.Len(b.failureStreaks, 1) require.Contains(b.failureStreaks, vdrID2) - - // Ensure the benched queue root has the min end time - minEndTime := b.benchedQueue[0].benchedUntil - benchedIDs := []ids.NodeID{vdrID0, vdrID1, vdrID4} - for _, benchedVdr := range b.benchedQueue { - require.Contains(benchedIDs, benchedVdr.nodeID) - require.False(benchedVdr.benchedUntil.Before(minEndTime)) - } - b.lock.Unlock() } @@ -281,7 +269,8 @@ func TestBenchlistMaxStake(t *testing.T) { func TestBenchlistRemove(t *testing.T) { require := require.New(t) - vdrs := validators.NewSet() + ctx := snow.DefaultConsensusContextTest() + vdrs := validators.NewManager() vdrID0 := ids.GenerateTestNodeID() vdrID1 := ids.GenerateTestNodeID() vdrID2 := ids.GenerateTestNodeID() @@ -289,11 +278,11 @@ func TestBenchlistRemove(t *testing.T) { vdrID4 := ids.GenerateTestNodeID() // Total weight is 5000 - require.NoError(vdrs.Add(vdrID0, nil, ids.Empty, 1000)) - require.NoError(vdrs.Add(vdrID1, nil, ids.Empty, 1000)) - require.NoError(vdrs.Add(vdrID2, nil, ids.Empty, 1000)) - require.NoError(vdrs.Add(vdrID3, nil, ids.Empty, 1000)) - require.NoError(vdrs.Add(vdrID4, nil, ids.Empty, 1000)) + require.NoError(vdrs.AddStaker(ctx.SubnetID, vdrID0, nil, ids.Empty, 1000)) + require.NoError(vdrs.AddStaker(ctx.SubnetID, vdrID1, nil, ids.Empty, 1000)) + require.NoError(vdrs.AddStaker(ctx.SubnetID, vdrID2, nil, ids.Empty, 1000)) + require.NoError(vdrs.AddStaker(ctx.SubnetID, vdrID3, nil, ids.Empty, 1000)) + require.NoError(vdrs.AddStaker(ctx.SubnetID, vdrID4, nil, ids.Empty, 1000)) count := 0 benchable := &TestBenchable{ @@ -308,15 +297,13 @@ func TestBenchlistRemove(t *testing.T) { duration := 2 * time.Second maxPortion := 0.76 // can bench 3 of the 5 validators benchIntf, err := NewBenchlist( - ids.Empty, - logging.NoLog{}, + ctx, benchable, vdrs, threshold, minimumFailingDuration, duration, maxPortion, - prometheus.NewRegistry(), ) require.NoError(err) b := benchIntf.(*benchlist) @@ -348,18 +335,10 @@ func TestBenchlistRemove(t *testing.T) { 
require.True(b.isBenched(vdrID0)) require.True(b.isBenched(vdrID1)) require.True(b.isBenched(vdrID2)) - require.Equal(3, b.benchedQueue.Len()) + require.Equal(3, b.benchedHeap.Len()) require.Equal(3, b.benchlistSet.Len()) require.Empty(b.failureStreaks) - // Ensure the benched queue root has the min end time - minEndTime := b.benchedQueue[0].benchedUntil - benchedIDs := []ids.NodeID{vdrID0, vdrID1, vdrID2} - for _, benchedVdr := range b.benchedQueue { - require.Contains(benchedIDs, benchedVdr.nodeID) - require.False(benchedVdr.benchedUntil.Before(minEndTime)) - } - // Set the benchlist's clock past when all validators should be unbenched // so that when its timer fires, it can remove them b.clock.Set(b.clock.Time().Add(duration)) diff --git a/snow/networking/benchlist/manager.go b/snow/networking/benchlist/manager.go index da5b0836c6dc..7a42e8245267 100644 --- a/snow/networking/benchlist/manager.go +++ b/snow/networking/benchlist/manager.go @@ -4,21 +4,15 @@ package benchlist import ( - "errors" "sync" "time" "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/snow" "github.com/ava-labs/avalanchego/snow/validators" - "github.com/ava-labs/avalanchego/utils/constants" ) -var ( - errUnknownValidators = errors.New("unknown validator set for provided chain") - - _ Manager = (*manager)(nil) -) +var _ Manager = (*manager)(nil) // Manager provides an interface for a benchlist to register whether // queries have been successful or unsuccessful and place validators with @@ -47,7 +41,6 @@ type Manager interface { type Config struct { Benchable Benchable `json:"-"` Validators validators.Manager `json:"-"` - SybilProtectionEnabled bool `json:"-"` Threshold int `json:"threshold"` MinimumFailingDuration time.Duration `json:"minimumFailingDuration"` Duration time.Duration `json:"duration"` @@ -115,30 +108,14 @@ func (m *manager) RegisterChain(ctx *snow.ConsensusContext) error { return nil } - var ( - vdrs validators.Set - ok bool - ) - if m.config.SybilProtectionEnabled { - vdrs, ok = m.config.Validators.Get(ctx.SubnetID) - } else { - // If sybil protection is disabled, everyone validates every chain - vdrs, ok = m.config.Validators.Get(constants.PrimaryNetworkID) - } - if !ok { - return errUnknownValidators - } - benchlist, err := NewBenchlist( - ctx.ChainID, - ctx.Log, + ctx, m.config.Benchable, - vdrs, + m.config.Validators, m.config.Threshold, m.config.MinimumFailingDuration, m.config.Duration, m.config.MaxPortion, - ctx.Registerer, ) if err != nil { return err diff --git a/snow/networking/handler/handler.go b/snow/networking/handler/handler.go index e908d76a57b5..1a9a1d89b6ae 100644 --- a/snow/networking/handler/handler.go +++ b/snow/networking/handler/handler.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "sync" + "sync/atomic" "time" "github.com/prometheus/client_golang/prometheus" @@ -84,10 +85,9 @@ type handler struct { clock mockable.Clock ctx *snow.ConsensusContext - // The validator set that validates this chain // TODO: consider using peerTracker instead of validators // since peerTracker is already tracking validators - validators validators.Set + validators validators.Manager // Receives messages from the VM msgFromVMChan <-chan common.Message preemptTimeouts chan struct{} @@ -116,7 +116,7 @@ type handler struct { startClosingTime time.Time totalClosingTime time.Duration closingChan chan struct{} - numDispatchersClosed int + numDispatchersClosed atomic.Uint32 // Closed when this handler and [engine] are done shutting down closed chan struct{} @@ -132,7 +132,7 @@ type handler 
struct { // [engine] must be initialized before initializing this handler func New( ctx *snow.ConsensusContext, - validators validators.Set, + validators validators.Manager, msgFromVMChan <-chan common.Message, gossipFrequency time.Duration, threadPoolSize int, @@ -164,11 +164,11 @@ func New( return nil, fmt.Errorf("initializing handler metrics errored with: %w", err) } cpuTracker := resourceTracker.CPUTracker() - h.syncMessageQueue, err = NewMessageQueue(h.ctx.Log, h.validators, cpuTracker, "handler", h.ctx.Registerer, message.SynchronousOps) + h.syncMessageQueue, err = NewMessageQueue(h.ctx, h.validators, cpuTracker, "handler", message.SynchronousOps) if err != nil { return nil, fmt.Errorf("initializing sync message queue errored with: %w", err) } - h.asyncMessageQueue, err = NewMessageQueue(h.ctx.Log, h.validators, cpuTracker, "handler_async", h.ctx.Registerer, message.AsynchronousOps) + h.asyncMessageQueue, err = NewMessageQueue(h.ctx, h.validators, cpuTracker, "handler_async", message.AsynchronousOps) if err != nil { return nil, fmt.Errorf("initializing async message queue errored with: %w", err) } @@ -180,7 +180,8 @@ func (h *handler) Context() *snow.ConsensusContext { } func (h *handler) ShouldHandle(nodeID ids.NodeID) bool { - return h.subnet.IsAllowed(nodeID, h.validators.Contains(nodeID)) + _, ok := h.validators.GetValidator(h.ctx.SubnetID, nodeID) + return h.subnet.IsAllowed(nodeID, ok) } func (h *handler) SetEngineManager(engineManager *EngineManager) { @@ -215,27 +216,27 @@ func (h *handler) selectStartingGear(ctx context.Context) (common.Engine, error) } // drop bootstrap state from previous runs before starting state sync - return engines.StateSyncer, engines.Bootstrapper.Clear() + return engines.StateSyncer, engines.Bootstrapper.Clear(ctx) } func (h *handler) Start(ctx context.Context, recoverPanic bool) { - h.ctx.Lock.Lock() - defer h.ctx.Lock.Unlock() - gear, err := h.selectStartingGear(ctx) if err != nil { h.ctx.Log.Error("chain failed to select starting gear", zap.Error(err), ) - h.shutdown(ctx) + h.shutdown(ctx, h.clock.Time()) return } - if err := gear.Start(ctx, 0); err != nil { + h.ctx.Lock.Lock() + err = gear.Start(ctx, 0) + h.ctx.Lock.Unlock() + if err != nil { h.ctx.Log.Error("chain failed to start", zap.Error(err), ) - h.shutdown(ctx) + h.shutdown(ctx, h.clock.Time()) return } @@ -326,7 +327,7 @@ func (h *handler) Stop(ctx context.Context) { state := h.ctx.State.Get() bootstrapper, ok := h.engineManager.Get(state.Type).Get(snow.Bootstrapping) if !ok { - h.ctx.Log.Error("bootstrapping engine doesn't exists", + h.ctx.Log.Error("bootstrapping engine doesn't exist", zap.Stringer("type", state.Type), ) return @@ -998,35 +999,27 @@ func (h *handler) popUnexpiredMsg( } } +// Invariant: if closeDispatcher is called, Stop has already been called. func (h *handler) closeDispatcher(ctx context.Context) { - h.ctx.Lock.Lock() - defer h.ctx.Lock.Unlock() - - h.numDispatchersClosed++ - if h.numDispatchersClosed < numDispatchersToClose { + if h.numDispatchersClosed.Add(1) < numDispatchersToClose { return } - h.shutdown(ctx) + h.shutdown(ctx, h.startClosingTime) } -// Note: shutdown is only called after all message dispatchers have exited. -func (h *handler) shutdown(ctx context.Context) { +// Note: shutdown is only called after all message dispatchers have exited or if +// no message dispatchers ever started. 
+func (h *handler) shutdown(ctx context.Context, startClosingTime time.Time) { defer func() { if h.onStopped != nil { go h.onStopped() } - h.totalClosingTime = h.clock.Time().Sub(h.startClosingTime) + h.totalClosingTime = h.clock.Time().Sub(startClosingTime) close(h.closed) }() - // shutdown may be called during Start, so we populate the start closing - // time here in case Stop was never called. - if h.startClosingTime.IsZero() { - h.startClosingTime = h.clock.Time() - } - state := h.ctx.State.Get() engine, ok := h.engineManager.Get(state.Type).Get(state.State) if !ok { diff --git a/snow/networking/handler/handler_test.go b/snow/networking/handler/handler_test.go index be76c7077a46..c28da4bc8b71 100644 --- a/snow/networking/handler/handler_test.go +++ b/snow/networking/handler/handler_test.go @@ -41,9 +41,9 @@ func TestHandlerDropsTimedOutMessages(t *testing.T) { ctx := snow.DefaultConsensusContextTest() - vdrs := validators.NewSet() + vdrs := validators.NewManager() vdr0 := ids.GenerateTestNodeID() - require.NoError(vdrs.Add(vdr0, nil, ids.Empty, 1)) + require.NoError(vdrs.AddStaker(ctx.SubnetID, vdr0, nil, ids.Empty, 1)) resourceTracker, err := tracker.NewResourceTracker( prometheus.NewRegistry(), @@ -139,8 +139,8 @@ func TestHandlerClosesOnError(t *testing.T) { closed := make(chan struct{}, 1) ctx := snow.DefaultConsensusContextTest() - vdrs := validators.NewSet() - require.NoError(vdrs.Add(ids.GenerateTestNodeID(), nil, ids.Empty, 1)) + vdrs := validators.NewManager() + require.NoError(vdrs.AddStaker(ctx.SubnetID, ids.GenerateTestNodeID(), nil, ids.Empty, 1)) resourceTracker, err := tracker.NewResourceTracker( prometheus.NewRegistry(), @@ -232,8 +232,8 @@ func TestHandlerDropsGossipDuringBootstrapping(t *testing.T) { closed := make(chan struct{}, 1) ctx := snow.DefaultConsensusContextTest() - vdrs := validators.NewSet() - require.NoError(vdrs.Add(ids.GenerateTestNodeID(), nil, ids.Empty, 1)) + vdrs := validators.NewManager() + require.NoError(vdrs.AddStaker(ctx.SubnetID, ids.GenerateTestNodeID(), nil, ids.Empty, 1)) resourceTracker, err := tracker.NewResourceTracker( prometheus.NewRegistry(), @@ -313,8 +313,8 @@ func TestHandlerDispatchInternal(t *testing.T) { ctx := snow.DefaultConsensusContextTest() msgFromVMChan := make(chan common.Message) - vdrs := validators.NewSet() - require.NoError(vdrs.Add(ids.GenerateTestNodeID(), nil, ids.Empty, 1)) + vdrs := validators.NewManager() + require.NoError(vdrs.AddStaker(ctx.SubnetID, ids.GenerateTestNodeID(), nil, ids.Empty, 1)) resourceTracker, err := tracker.NewResourceTracker( prometheus.NewRegistry(), @@ -384,8 +384,8 @@ func TestHandlerSubnetConnector(t *testing.T) { require := require.New(t) ctx := snow.DefaultConsensusContextTest() - vdrs := validators.NewSet() - require.NoError(vdrs.Add(ids.GenerateTestNodeID(), nil, ids.Empty, 1)) + vdrs := validators.NewManager() + require.NoError(vdrs.AddStaker(ctx.SubnetID, ids.GenerateTestNodeID(), nil, ids.Empty, 1)) resourceTracker, err := tracker.NewResourceTracker( prometheus.NewRegistry(), @@ -562,8 +562,8 @@ func TestDynamicEngineTypeDispatch(t *testing.T) { messageReceived := make(chan struct{}) ctx := snow.DefaultConsensusContextTest() - vdrs := validators.NewSet() - require.NoError(vdrs.Add(ids.GenerateTestNodeID(), nil, ids.Empty, 1)) + vdrs := validators.NewManager() + require.NoError(vdrs.AddStaker(ctx.SubnetID, ids.GenerateTestNodeID(), nil, ids.Empty, 1)) resourceTracker, err := tracker.NewResourceTracker( prometheus.NewRegistry(), @@ -633,3 +633,41 @@ func 
TestDynamicEngineTypeDispatch(t *testing.T) { }) } } + +func TestHandlerStartError(t *testing.T) { + require := require.New(t) + + ctx := snow.DefaultConsensusContextTest() + resourceTracker, err := tracker.NewResourceTracker( + prometheus.NewRegistry(), + resource.NoUsage, + meter.ContinuousFactory{}, + time.Second, + ) + require.NoError(err) + + handler, err := New( + ctx, + validators.NewManager(), + nil, + time.Second, + testThreadPoolSize, + resourceTracker, + nil, + subnets.New(ctx.NodeID, subnets.Config{}), + commontracker.NewPeers(), + ) + require.NoError(err) + + // Starting a handler with an unprovided engine should immediately cause the + // handler to shutdown. + handler.SetEngineManager(&EngineManager{}) + ctx.State.Set(snow.EngineState{ + Type: p2p.EngineType_ENGINE_TYPE_SNOWMAN, + State: snow.Initializing, + }) + handler.Start(context.Background(), false) + + _, err = handler.AwaitStopped(context.Background()) + require.NoError(err) +} diff --git a/snow/networking/handler/health.go b/snow/networking/handler/health.go index 31d6fe11fcec..b68ead089639 100644 --- a/snow/networking/handler/health.go +++ b/snow/networking/handler/health.go @@ -12,9 +12,6 @@ import ( var ErrNotConnectedEnoughStake = errors.New("not connected to enough stake") func (h *handler) HealthCheck(ctx context.Context) (interface{}, error) { - h.ctx.Lock.Lock() - defer h.ctx.Lock.Unlock() - state := h.ctx.State.Get() engine, ok := h.engineManager.Get(state.Type).Get(state.State) if !ok { diff --git a/snow/networking/handler/health_test.go b/snow/networking/handler/health_test.go index ba89790ce02a..9767859a4abf 100644 --- a/snow/networking/handler/health_test.go +++ b/snow/networking/handler/health_test.go @@ -49,7 +49,7 @@ func TestHealthCheckSubnet(t *testing.T) { ctx := snow.DefaultConsensusContextTest() - vdrs := validators.NewSet() + vdrs := validators.NewManager() resourceTracker, err := tracker.NewResourceTracker( prometheus.NewRegistry(), @@ -60,7 +60,7 @@ func TestHealthCheckSubnet(t *testing.T) { require.NoError(err) peerTracker := commontracker.NewPeers() - vdrs.RegisterCallbackListener(peerTracker) + vdrs.RegisterCallbackListener(ctx.SubnetID, peerTracker) sb := subnets.New( ctx.NodeID, @@ -121,7 +121,7 @@ func TestHealthCheckSubnet(t *testing.T) { vdrID := ids.GenerateTestNodeID() vdrIDs.Add(vdrID) - require.NoError(vdrs.Add(vdrID, nil, ids.Empty, 100)) + require.NoError(vdrs.AddStaker(ctx.SubnetID, vdrID, nil, ids.Empty, 100)) } for index, nodeID := range vdrIDs.List() { diff --git a/snow/networking/handler/message_queue.go b/snow/networking/handler/message_queue.go index 1dfeea5eca50..6fe4137b940e 100644 --- a/snow/networking/handler/message_queue.go +++ b/snow/networking/handler/message_queue.go @@ -7,16 +7,14 @@ import ( "context" "sync" - "github.com/prometheus/client_golang/prometheus" - "go.uber.org/zap" "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/message" "github.com/ava-labs/avalanchego/proto/pb/p2p" + "github.com/ava-labs/avalanchego/snow" "github.com/ava-labs/avalanchego/snow/networking/tracker" "github.com/ava-labs/avalanchego/snow/validators" - "github.com/ava-labs/avalanchego/utils/logging" "github.com/ava-labs/avalanchego/utils/timer/mockable" ) @@ -60,9 +58,9 @@ type messageQueue struct { clock mockable.Clock metrics messageQueueMetrics - log logging.Logger + ctx *snow.ConsensusContext // Validator set for the chain associated with this - vdrs validators.Set + vdrs validators.Manager // Tracks CPU utilization of each node cpuTracker tracker.Tracker @@ 
-75,21 +73,20 @@ type messageQueue struct { } func NewMessageQueue( - log logging.Logger, - vdrs validators.Set, + ctx *snow.ConsensusContext, + vdrs validators.Manager, cpuTracker tracker.Tracker, metricsNamespace string, - metricsRegisterer prometheus.Registerer, ops []message.Op, ) (MessageQueue, error) { m := &messageQueue{ - log: log, + ctx: ctx, vdrs: vdrs, cpuTracker: cpuTracker, cond: sync.NewCond(&sync.Mutex{}), nodeToUnprocessedMsgs: make(map[ids.NodeID]int), } - return m, m.metrics.initialize(metricsNamespace, metricsRegisterer, ops) + return m, m.metrics.initialize(metricsNamespace, ctx.Registerer, ops) } func (m *messageQueue) Push(ctx context.Context, msg Message) { @@ -137,7 +134,7 @@ func (m *messageQueue) Pop() (context.Context, Message, bool) { i := 0 for { if i == n { - m.log.Debug("canPop is false for all unprocessed messages", + m.ctx.Log.Debug("canPop is false for all unprocessed messages", zap.Int("numMessages", n), ) } @@ -218,14 +215,26 @@ func (m *messageQueue) canPop(msg message.InboundMessage) bool { // the number of nodes with unprocessed messages. baseMaxCPU := 1 / float64(len(m.nodeToUnprocessedMsgs)) nodeID := msg.NodeID() - weight := m.vdrs.GetWeight(nodeID) - // The sum of validator weights should never be 0, but handle - // that case for completeness here to avoid divide by 0. - portionWeight := float64(0) - totalVdrsWeight := m.vdrs.Weight() - if totalVdrsWeight != 0 { + weight := m.vdrs.GetWeight(m.ctx.SubnetID, nodeID) + + var portionWeight float64 + if totalVdrsWeight, err := m.vdrs.TotalWeight(m.ctx.SubnetID); err != nil { + // The sum of validator weights should never overflow, but if they do, + // we treat portionWeight as 0. + m.ctx.Log.Error("failed to get total weight of validators", + zap.Stringer("subnetID", m.ctx.SubnetID), + zap.Error(err), + ) + } else if totalVdrsWeight == 0 { + // The sum of validator weights should never be 0, but handle that case + // for completeness here to avoid divide by 0. + m.ctx.Log.Warn("validator set is empty", + zap.Stringer("subnetID", m.ctx.SubnetID), + ) + } else { portionWeight = float64(weight) / float64(totalVdrsWeight) } + // Validators are allowed to use more CPU. More weight --> more CPU use allowed. 
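// As a rough worked example (hypothetical numbers, for illustration only): if
// 4 nodes currently have unprocessed messages, baseMaxCPU = 1/4 = 0.25; a
// validator holding 10% of the subnet's total stake gets portionWeight = 0.10,
// so maxCPU = 0.25 + (1 - 0.25)*0.10 = 0.325, while a node with zero weight
// stays at the 0.25 base.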
recentCPUUsage := m.cpuTracker.Usage(nodeID, m.clock.Time()) maxCPU := baseMaxCPU + (1.0-baseMaxCPU)*portionWeight diff --git a/snow/networking/handler/message_queue_test.go b/snow/networking/handler/message_queue_test.go index 352733dbb139..1eabfd96c410 100644 --- a/snow/networking/handler/message_queue_test.go +++ b/snow/networking/handler/message_queue_test.go @@ -8,8 +8,6 @@ import ( "testing" "time" - "github.com/prometheus/client_golang/prometheus" - "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" @@ -17,9 +15,9 @@ import ( "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/message" "github.com/ava-labs/avalanchego/proto/pb/p2p" + "github.com/ava-labs/avalanchego/snow" "github.com/ava-labs/avalanchego/snow/networking/tracker" "github.com/ava-labs/avalanchego/snow/validators" - "github.com/ava-labs/avalanchego/utils/logging" ) const engineType = p2p.EngineType_ENGINE_TYPE_SNOWMAN @@ -28,11 +26,12 @@ func TestQueue(t *testing.T) { ctrl := gomock.NewController(t) require := require.New(t) cpuTracker := tracker.NewMockTracker(ctrl) - vdrs := validators.NewSet() + ctx := snow.DefaultConsensusContextTest() + vdrs := validators.NewManager() vdr1ID, vdr2ID := ids.GenerateTestNodeID(), ids.GenerateTestNodeID() - require.NoError(vdrs.Add(vdr1ID, nil, ids.Empty, 1)) - require.NoError(vdrs.Add(vdr2ID, nil, ids.Empty, 1)) - mIntf, err := NewMessageQueue(logging.NoLog{}, vdrs, cpuTracker, "", prometheus.NewRegistry(), message.SynchronousOps) + require.NoError(vdrs.AddStaker(ctx.SubnetID, vdr1ID, nil, ids.Empty, 1)) + require.NoError(vdrs.AddStaker(ctx.SubnetID, vdr2ID, nil, ids.Empty, 1)) + mIntf, err := NewMessageQueue(ctx, vdrs, cpuTracker, "", message.SynchronousOps) require.NoError(err) u := mIntf.(*messageQueue) currentTime := time.Now() diff --git a/snow/networking/router/chain_router_test.go b/snow/networking/router/chain_router_test.go index 7854ff8e0b8a..d4f71828799f 100644 --- a/snow/networking/router/chain_router_test.go +++ b/snow/networking/router/chain_router_test.go @@ -46,8 +46,9 @@ const ( func TestShutdown(t *testing.T) { require := require.New(t) - vdrs := validators.NewSet() - require.NoError(vdrs.Add(ids.GenerateTestNodeID(), nil, ids.Empty, 1)) + chainCtx := snow.DefaultConsensusContextTest() + vdrs := validators.NewManager() + require.NoError(vdrs.AddStaker(chainCtx.SubnetID, ids.GenerateTestNodeID(), nil, ids.Empty, 1)) benchlist := benchlist.NewNoBenchlist() tm, err := timeout.NewManager( &timer.AdaptiveTimeoutConfig{ @@ -62,7 +63,9 @@ func TestShutdown(t *testing.T) { prometheus.NewRegistry(), ) require.NoError(err) + go tm.Dispatch() + defer tm.Stop() chainRouter := ChainRouter{} require.NoError(chainRouter.Initialize( @@ -81,7 +84,6 @@ func TestShutdown(t *testing.T) { shutdownCalled := make(chan struct{}, 1) - chainCtx := snow.DefaultConsensusContextTest() resourceTracker, err := tracker.NewResourceTracker( prometheus.NewRegistry(), resource.NoUsage, @@ -182,9 +184,10 @@ func TestShutdown(t *testing.T) { func TestShutdownTimesOut(t *testing.T) { require := require.New(t) + ctx := snow.DefaultConsensusContextTest() nodeID := ids.EmptyNodeID - vdrs := validators.NewSet() - require.NoError(vdrs.Add(ids.GenerateTestNodeID(), nil, ids.Empty, 1)) + vdrs := validators.NewManager() + require.NoError(vdrs.AddStaker(ctx.SubnetID, ids.GenerateTestNodeID(), nil, ids.Empty, 1)) benchlist := benchlist.NewNoBenchlist() metrics := prometheus.NewRegistry() // Ensure that the Ancestors request does not timeout @@ -201,7 +204,9 @@ func 
TestShutdownTimesOut(t *testing.T) { metrics, ) require.NoError(err) + go tm.Dispatch() + defer tm.Stop() chainRouter := ChainRouter{} @@ -219,7 +224,6 @@ func TestShutdownTimesOut(t *testing.T) { metrics, )) - ctx := snow.DefaultConsensusContextTest() resourceTracker, err := tracker.NewResourceTracker( prometheus.NewRegistry(), resource.NoUsage, @@ -325,6 +329,7 @@ func TestShutdownTimesOut(t *testing.T) { // Ensure that a timeout fires if we don't get a response to a request func TestRouterTimeout(t *testing.T) { require := require.New(t) + // Create a timeout manager maxTimeout := 25 * time.Millisecond tm, err := timeout.NewManager( @@ -340,7 +345,9 @@ func TestRouterTimeout(t *testing.T) { prometheus.NewRegistry(), ) require.NoError(err) + go tm.Dispatch() + defer tm.Stop() // Create a router chainRouter := ChainRouter{} @@ -357,13 +364,17 @@ func TestRouterTimeout(t *testing.T) { "", prometheus.NewRegistry(), )) + defer chainRouter.Shutdown(context.Background()) // Create bootstrapper, engine and handler var ( - calledGetStateSummaryFrontierFailed, calledGetAcceptedStateSummaryFailed, - calledGetAcceptedFrontierFailed, calledGetAcceptedFailed, + calledGetStateSummaryFrontierFailed, + calledGetAcceptedStateSummaryFailed, + calledGetAcceptedFrontierFailed, + calledGetAcceptedFailed, calledGetAncestorsFailed, - calledGetFailed, calledQueryFailed, + calledGetFailed, + calledQueryFailed, calledAppRequestFailed, calledCrossChainAppRequestFailed bool @@ -371,8 +382,8 @@ func TestRouterTimeout(t *testing.T) { ) ctx := snow.DefaultConsensusContextTest() - vdrs := validators.NewSet() - require.NoError(vdrs.Add(ids.GenerateTestNodeID(), nil, ids.Empty, 1)) + vdrs := validators.NewManager() + require.NoError(vdrs.AddStaker(ctx.SubnetID, ids.GenerateTestNodeID(), nil, ids.Empty, 1)) resourceTracker, err := tracker.NewResourceTracker( prometheus.NewRegistry(), @@ -412,6 +423,7 @@ func TestRouterTimeout(t *testing.T) { return nil } bootstrapper.HaltF = func(context.Context) {} + bootstrapper.ShutdownF = func(ctx context.Context) error { return nil } bootstrapper.GetStateSummaryFrontierFailedF = func(context.Context, ids.NodeID, uint32) error { defer wg.Done() @@ -694,7 +706,9 @@ func TestRouterHonorsRequestedEngine(t *testing.T) { prometheus.NewRegistry(), ) require.NoError(err) + go tm.Dispatch() + defer tm.Stop() // Create a router chainRouter := ChainRouter{} @@ -711,12 +725,15 @@ func TestRouterHonorsRequestedEngine(t *testing.T) { "", prometheus.NewRegistry(), )) + defer chainRouter.Shutdown(context.Background()) h := handler.NewMockHandler(ctrl) ctx := snow.DefaultConsensusContextTest() h.EXPECT().Context().Return(ctx).AnyTimes() h.EXPECT().SetOnStopped(gomock.Any()).AnyTimes() + h.EXPECT().Stop(gomock.Any()).AnyTimes() + h.EXPECT().AwaitStopped(gomock.Any()).AnyTimes() h.EXPECT().Push(gomock.Any(), gomock.Any()).Times(1) chainRouter.AddChain(context.Background(), h) @@ -822,7 +839,9 @@ func TestRouterClearTimeouts(t *testing.T) { prometheus.NewRegistry(), ) require.NoError(err) + go tm.Dispatch() + defer tm.Stop() // Create a router chainRouter := ChainRouter{} @@ -839,11 +858,12 @@ func TestRouterClearTimeouts(t *testing.T) { "", prometheus.NewRegistry(), )) + defer chainRouter.Shutdown(context.Background()) // Create bootstrapper, engine and handler ctx := snow.DefaultConsensusContextTest() - vdrs := validators.NewSet() - require.NoError(vdrs.Add(ids.GenerateTestNodeID(), nil, ids.Empty, 1)) + vdrs := validators.NewManager() + require.NoError(vdrs.AddStaker(ctx.SubnetID, 
ids.GenerateTestNodeID(), nil, ids.Empty, 1)) resourceTracker, err := tracker.NewResourceTracker( prometheus.NewRegistry(), @@ -1111,7 +1131,9 @@ func TestValidatorOnlyMessageDrops(t *testing.T) { prometheus.NewRegistry(), ) require.NoError(err) + go tm.Dispatch() + defer tm.Stop() // Create a router chainRouter := ChainRouter{} @@ -1128,6 +1150,7 @@ func TestValidatorOnlyMessageDrops(t *testing.T) { "", prometheus.NewRegistry(), )) + defer chainRouter.Shutdown(context.Background()) // Create bootstrapper, engine and handler calledF := false @@ -1135,9 +1158,9 @@ func TestValidatorOnlyMessageDrops(t *testing.T) { ctx := snow.DefaultConsensusContextTest() sb := subnets.New(ctx.NodeID, subnets.Config{ValidatorOnly: true}) - vdrs := validators.NewSet() + vdrs := validators.NewManager() vID := ids.GenerateTestNodeID() - require.NoError(vdrs.Add(vID, nil, ids.Empty, 1)) + require.NoError(vdrs.AddStaker(ctx.SubnetID, vID, nil, ids.Empty, 1)) resourceTracker, err := tracker.NewResourceTracker( prometheus.NewRegistry(), resource.NoUsage, @@ -1261,7 +1284,9 @@ func TestRouterCrossChainMessages(t *testing.T) { prometheus.NewRegistry(), ) require.NoError(err) + go tm.Dispatch() + defer tm.Stop() // Create chain router nodeID := ids.GenerateTestNodeID() @@ -1279,18 +1304,19 @@ func TestRouterCrossChainMessages(t *testing.T) { "", prometheus.NewRegistry(), )) + defer chainRouter.Shutdown(context.Background()) - // Set up validators - vdrs := validators.NewSet() - require.NoError(vdrs.Add(ids.GenerateTestNodeID(), nil, ids.Empty, 1)) - - // Create bootstrapper, engine and handler requester := snow.DefaultConsensusContextTest() requester.ChainID = ids.GenerateTestID() requester.Registerer = prometheus.NewRegistry() requester.Metrics = metrics.NewOptionalGatherer() requester.Executing.Set(false) + // Set up validators + vdrs := validators.NewManager() + require.NoError(vdrs.AddStaker(requester.SubnetID, ids.GenerateTestNodeID(), nil, ids.Empty, 1)) + + // Create bootstrapper, engine and handler resourceTracker, err := tracker.NewResourceTracker( prometheus.NewRegistry(), resource.NoUsage, @@ -1311,6 +1337,18 @@ func TestRouterCrossChainMessages(t *testing.T) { commontracker.NewPeers(), ) require.NoError(err) + requesterHandler.SetEngineManager(&handler.EngineManager{ + Avalanche: &handler.Engine{ + StateSyncer: nil, + Bootstrapper: &common.BootstrapperTest{}, + Consensus: &common.EngineTest{}, + }, + Snowman: &handler.Engine{ + StateSyncer: nil, + Bootstrapper: &common.BootstrapperTest{}, + Consensus: &common.EngineTest{}, + }, + }) responder := snow.DefaultConsensusContextTest() responder.ChainID = ids.GenerateTestID() @@ -1330,6 +1368,18 @@ func TestRouterCrossChainMessages(t *testing.T) { commontracker.NewPeers(), ) require.NoError(err) + responderHandler.SetEngineManager(&handler.EngineManager{ + Avalanche: &handler.Engine{ + StateSyncer: nil, + Bootstrapper: &common.BootstrapperTest{}, + Consensus: &common.EngineTest{}, + }, + Snowman: &handler.Engine{ + StateSyncer: nil, + Bootstrapper: &common.BootstrapperTest{}, + Consensus: &common.EngineTest{}, + }, + }) // assumed bootstrapping is done responder.State.Set(snow.EngineState{ @@ -1408,7 +1458,9 @@ func TestConnectedSubnet(t *testing.T) { prometheus.NewRegistry(), ) require.NoError(err) + go tm.Dispatch() + defer tm.Stop() // Create chain router myNodeID := ids.GenerateTestNodeID() @@ -1526,7 +1578,9 @@ func TestValidatorOnlyAllowedNodeMessageDrops(t *testing.T) { prometheus.NewRegistry(), ) require.NoError(err) + go tm.Dispatch() + defer tm.Stop() 
// Create a router chainRouter := ChainRouter{} @@ -1543,6 +1597,7 @@ func TestValidatorOnlyAllowedNodeMessageDrops(t *testing.T) { "", prometheus.NewRegistry(), )) + defer chainRouter.Shutdown(context.Background()) // Create bootstrapper, engine and handler calledF := false @@ -1553,9 +1608,9 @@ func TestValidatorOnlyAllowedNodeMessageDrops(t *testing.T) { allowedSet := set.Of(allowedID) sb := subnets.New(ctx.NodeID, subnets.Config{ValidatorOnly: true, AllowedNodes: allowedSet}) - vdrs := validators.NewSet() + vdrs := validators.NewManager() vID := ids.GenerateTestNodeID() - require.NoError(vdrs.Add(vID, nil, ids.Empty, 1)) + require.NoError(vdrs.AddStaker(ctx.SubnetID, vID, nil, ids.Empty, 1)) resourceTracker, err := tracker.NewResourceTracker( prometheus.NewRegistry(), diff --git a/snow/networking/router/main_test.go b/snow/networking/router/main_test.go new file mode 100644 index 000000000000..afc1dddb173e --- /dev/null +++ b/snow/networking/router/main_test.go @@ -0,0 +1,14 @@ +// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package router + +import ( + "testing" + + "go.uber.org/goleak" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} diff --git a/snow/networking/sender/sender_test.go b/snow/networking/sender/sender_test.go index a22322d55c6c..6355834dcf78 100644 --- a/snow/networking/sender/sender_test.go +++ b/snow/networking/sender/sender_test.go @@ -53,8 +53,9 @@ var defaultSubnetConfig = subnets.Config{ func TestTimeout(t *testing.T) { require := require.New(t) - vdrs := validators.NewSet() - require.NoError(vdrs.Add(ids.GenerateTestNodeID(), nil, ids.Empty, 1)) + ctx := snow.DefaultConsensusContextTest() + vdrs := validators.NewManager() + require.NoError(vdrs.AddStaker(ctx.SubnetID, ids.GenerateTestNodeID(), nil, ids.Empty, 1)) benchlist := benchlist.NewNoBenchlist() tm, err := timeout.NewManager( &timer.AdaptiveTimeoutConfig{ @@ -97,7 +98,6 @@ func TestTimeout(t *testing.T) { prometheus.NewRegistry(), )) - ctx := snow.DefaultConsensusContextTest() externalSender := &ExternalSenderTest{TB: t} externalSender.Default(false) @@ -310,8 +310,9 @@ func TestTimeout(t *testing.T) { func TestReliableMessages(t *testing.T) { require := require.New(t) - vdrs := validators.NewSet() - require.NoError(vdrs.Add(ids.NodeID{1}, nil, ids.Empty, 1)) + ctx := snow.DefaultConsensusContextTest() + vdrs := validators.NewManager() + require.NoError(vdrs.AddStaker(ctx.SubnetID, ids.NodeID{1}, nil, ids.Empty, 1)) benchlist := benchlist.NewNoBenchlist() tm, err := timeout.NewManager( &timer.AdaptiveTimeoutConfig{ @@ -355,8 +356,6 @@ func TestReliableMessages(t *testing.T) { prometheus.NewRegistry(), )) - ctx := snow.DefaultConsensusContextTest() - externalSender := &ExternalSenderTest{TB: t} externalSender.Default(false) @@ -460,8 +459,9 @@ func TestReliableMessagesToMyself(t *testing.T) { require := require.New(t) benchlist := benchlist.NewNoBenchlist() - vdrs := validators.NewSet() - require.NoError(vdrs.Add(ids.GenerateTestNodeID(), nil, ids.Empty, 1)) + ctx := snow.DefaultConsensusContextTest() + vdrs := validators.NewManager() + require.NoError(vdrs.AddStaker(ctx.SubnetID, ids.GenerateTestNodeID(), nil, ids.Empty, 1)) tm, err := timeout.NewManager( &timer.AdaptiveTimeoutConfig{ InitialTimeout: 10 * time.Millisecond, @@ -504,8 +504,6 @@ func TestReliableMessagesToMyself(t *testing.T) { prometheus.NewRegistry(), )) - ctx := snow.DefaultConsensusContextTest() - externalSender := &ExternalSenderTest{TB: t} 
externalSender.Default(false) diff --git a/snow/networking/timeout/main_test.go b/snow/networking/timeout/main_test.go new file mode 100644 index 000000000000..f3bee130e58b --- /dev/null +++ b/snow/networking/timeout/main_test.go @@ -0,0 +1,14 @@ +// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package timeout + +import ( + "testing" + + "go.uber.org/goleak" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} diff --git a/snow/networking/timeout/manager.go b/snow/networking/timeout/manager.go index 3c2739d6daa5..d94c34a1f663 100644 --- a/snow/networking/timeout/manager.go +++ b/snow/networking/timeout/manager.go @@ -5,6 +5,7 @@ package timeout import ( "fmt" + "sync" "time" "github.com/prometheus/client_golang/prometheus" @@ -62,6 +63,9 @@ type Manager interface { // Mark that we no longer expect a response to this request we sent. // Does not modify the timeout. RemoveRequest(requestID ids.RequestID) + + // Stops the manager. + Stop() } func NewManager( @@ -88,6 +92,7 @@ type manager struct { tm timer.AdaptiveTimeoutManager benchlistMgr benchlist.Manager metrics metrics + stopOnce sync.Once } func (m *manager) Dispatch() { @@ -156,3 +161,9 @@ func (m *manager) RemoveRequest(requestID ids.RequestID) { func (m *manager) RegisterRequestToUnreachableValidator() { m.tm.ObserveLatency(m.TimeoutDuration()) } + +func (m *manager) Stop() { + m.stopOnce.Do(func() { + m.tm.Stop() + }) +} diff --git a/snow/networking/timeout/manager_test.go b/snow/networking/timeout/manager_test.go index d84afbec95ab..ce412150b1b6 100644 --- a/snow/networking/timeout/manager_test.go +++ b/snow/networking/timeout/manager_test.go @@ -33,6 +33,7 @@ func TestManagerFire(t *testing.T) { ) require.NoError(t, err) go manager.Dispatch() + defer manager.Stop() wg := sync.WaitGroup{} wg.Add(1) diff --git a/snow/networking/timeout/mock_manager.go b/snow/networking/timeout/mock_manager.go index 0a1a281b2df9..5a1bda7cb0b6 100644 --- a/snow/networking/timeout/mock_manager.go +++ b/snow/networking/timeout/mock_manager.go @@ -128,6 +128,18 @@ func (mr *MockManagerMockRecorder) RemoveRequest(arg0 interface{}) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveRequest", reflect.TypeOf((*MockManager)(nil).RemoveRequest), arg0) } +// Stop mocks base method. +func (m *MockManager) Stop() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Stop") +} + +// Stop indicates an expected call of Stop. +func (mr *MockManagerMockRecorder) Stop() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stop", reflect.TypeOf((*MockManager)(nil).Stop)) +} + // TimeoutDuration mocks base method. 
func (m *MockManager) TimeoutDuration() time.Duration { m.ctrl.T.Helper() diff --git a/snow/networking/tracker/targeter.go b/snow/networking/tracker/targeter.go index 216bb9ec1e13..4c69ab9508c1 100644 --- a/snow/networking/tracker/targeter.go +++ b/snow/networking/tracker/targeter.go @@ -6,8 +6,12 @@ package tracker import ( "math" + "go.uber.org/zap" + "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/snow/validators" + "github.com/ava-labs/avalanchego/utils/constants" + "github.com/ava-labs/avalanchego/utils/logging" ) var _ Targeter = (*targeter)(nil) @@ -32,11 +36,13 @@ type TargeterConfig struct { } func NewTargeter( + logger logging.Logger, config *TargeterConfig, - vdrs validators.Set, + vdrs validators.Manager, tracker Tracker, ) Targeter { return &targeter{ + log: logger, vdrs: vdrs, tracker: tracker, vdrAlloc: config.VdrAlloc, @@ -46,7 +52,8 @@ func NewTargeter( } type targeter struct { - vdrs validators.Set + vdrs validators.Manager + log logging.Logger tracker Tracker vdrAlloc float64 maxNonVdrUsage float64 @@ -60,7 +67,19 @@ func (t *targeter) TargetUsage(nodeID ids.NodeID) float64 { baseAlloc = math.Min(baseAlloc, t.maxNonVdrNodeUsage) // This node gets a stake-weighted portion of the validator allocation. - weight := t.vdrs.GetWeight(nodeID) - vdrAlloc := t.vdrAlloc * float64(weight) / float64(t.vdrs.Weight()) + weight := t.vdrs.GetWeight(constants.PrimaryNetworkID, nodeID) + if weight == 0 { + return baseAlloc + } + + totalWeight, err := t.vdrs.TotalWeight(constants.PrimaryNetworkID) + if err != nil { + t.log.Error("couldn't get total weight of primary network", + zap.Error(err), + ) + return baseAlloc + } + + vdrAlloc := t.vdrAlloc * float64(weight) / float64(totalWeight) return vdrAlloc + baseAlloc } diff --git a/snow/networking/tracker/targeter_test.go b/snow/networking/tracker/targeter_test.go index 72f18cf6dd87..23096adbed28 100644 --- a/snow/networking/tracker/targeter_test.go +++ b/snow/networking/tracker/targeter_test.go @@ -12,6 +12,8 @@ import ( "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/snow/validators" + "github.com/ava-labs/avalanchego/utils/constants" + "github.com/ava-labs/avalanchego/utils/logging" ) // Assert fields are set correctly. 
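The stake-weighted allocation in TargetUsage above can be sanity-checked with a small, self-contained sketch (illustrative only; the numbers and the simplified signature are assumptions, not taken from this change):

// targetUsage mirrors the formula above: a validator receives a share of
// vdrAlloc proportional to its primary-network weight, on top of baseAlloc.
func targetUsage(vdrAlloc, baseAlloc float64, weight, totalWeight uint64) float64 {
	if weight == 0 || totalWeight == 0 {
		return baseAlloc
	}
	return vdrAlloc*float64(weight)/float64(totalWeight) + baseAlloc
}

// For example, targetUsage(16, 1, 2, 10) = 16*0.2 + 1 = 4.2 of target usage,
// while a non-validator gets just baseAlloc = 1.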
@@ -24,10 +26,11 @@ func TestNewTargeter(t *testing.T) { MaxNonVdrUsage: 10, MaxNonVdrNodeUsage: 10, } - vdrs := validators.NewSet() + vdrs := validators.NewManager() tracker := NewMockTracker(ctrl) targeterIntf := NewTargeter( + logging.NoLog{}, config, vdrs, tracker, @@ -47,9 +50,9 @@ func TestTarget(t *testing.T) { vdrWeight := uint64(1) totalVdrWeight := uint64(10) nonVdr := ids.NodeID{2} - vdrs := validators.NewSet() - require.NoError(t, vdrs.Add(vdr, nil, ids.Empty, 1)) - require.NoError(t, vdrs.Add(ids.GenerateTestNodeID(), nil, ids.Empty, totalVdrWeight-vdrWeight)) + vdrs := validators.NewManager() + require.NoError(t, vdrs.AddStaker(constants.PrimaryNetworkID, vdr, nil, ids.Empty, 1)) + require.NoError(t, vdrs.AddStaker(constants.PrimaryNetworkID, ids.GenerateTestNodeID(), nil, ids.Empty, totalVdrWeight-vdrWeight)) tracker := NewMockTracker(ctrl) config := &TargeterConfig{ @@ -59,6 +62,7 @@ } targeter := NewTargeter( + logging.NoLog{}, config, vdrs, tracker, diff --git a/snow/validators/manager.go b/snow/validators/manager.go index a58d86c71d0f..c42ea779d96b 100644 --- a/snow/validators/manager.go +++ b/snow/validators/manager.go @@ -14,33 +14,89 @@ import ( "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/utils" "github.com/ava-labs/avalanchego/utils/crypto/bls" + "github.com/ava-labs/avalanchego/utils/set" ) var ( _ Manager = (*manager)(nil) + ErrZeroWeight = errors.New("weight must be non-zero") ErrMissingValidators = errors.New("missing validators") ) +type SetCallbackListener interface { + OnValidatorAdded(nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) + OnValidatorRemoved(nodeID ids.NodeID, weight uint64) + OnValidatorWeightChanged(nodeID ids.NodeID, oldWeight, newWeight uint64) +} + // Manager holds the validator set of each subnet type Manager interface { fmt.Stringer - // Add a subnet's validator set to the manager. - // - // If the subnet had previously registered a validator set, false will be - // returned and the manager will not be modified. - Add(subnetID ids.ID, set Set) bool + // Add a new staker to the subnet. + // Returns an error if: + // - [weight] is 0 + // - [nodeID] is already in the validator set + // If an error is returned, the set will be unmodified. + AddStaker(subnetID ids.ID, nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) error + + // AddWeight to an existing staker in the subnet. + // Returns an error if: + // - [weight] is 0 + // - [nodeID] is not already in the validator set + // If an error is returned, the set will be unmodified. + // AddWeight can result in a total weight that overflows uint64. + // In this case no error will be returned for this call. + // However, the next TotalWeight call will return an error. + AddWeight(subnetID ids.ID, nodeID ids.NodeID, weight uint64) error + + // GetWeight retrieves the validator weight from the subnet. + GetWeight(subnetID ids.ID, nodeID ids.NodeID) uint64 + + // GetValidator returns the validator tied to the specified ID in the subnet. + // If the validator doesn't exist, returns false. + GetValidator(subnetID ids.ID, nodeID ids.NodeID) (*Validator, bool) + + // GetValidatorIDs returns the validator IDs in the subnet. + GetValidatorIDs(subnetID ids.ID) []ids.NodeID + + // SubsetWeight returns the sum of the weights of the validators in the subnet. + // Returns err if subset weight overflows uint64.
+ SubsetWeight(subnetID ids.ID, validatorIDs set.Set[ids.NodeID]) (uint64, error) + + // RemoveWeight from a staker in the subnet. If the staker's weight becomes 0, the staker + // will be removed from the subnet set. + // Returns an error if: + // - [weight] is 0 + // - [nodeID] is not already in the subnet set + // - the weight of the validator would become negative + // If an error is returned, the set will be unmodified. + RemoveWeight(subnetID ids.ID, nodeID ids.NodeID, weight uint64) error + + // Count returns the number of validators currently in the subnet. + Count(subnetID ids.ID) int + + // TotalWeight returns the cumulative weight of all validators in the subnet. + // Returns err if total weight overflows uint64. + TotalWeight(subnetID ids.ID) (uint64, error) + + // Sample returns a collection of validatorIDs in the subnet, potentially with duplicates. + // If sampling the requested size isn't possible, an error will be returned. + Sample(subnetID ids.ID, size int) ([]ids.NodeID, error) - // Get returns the validator set for the given subnet - // Returns false if the subnet doesn't exist - Get(ids.ID) (Set, bool) + // GetMap returns a map of the validators in this subnet + GetMap(subnetID ids.ID) map[ids.NodeID]*GetValidatorOutput + + // When a validator's weight changes, or a validator is added/removed, + // this listener is called. + RegisterCallbackListener(subnetID ids.ID, listener SetCallbackListener) } // NewManager returns a new, empty manager func NewManager() Manager { return &manager{ - subnetToVdrs: make(map[ids.ID]Set), + subnetToVdrs: make(map[ids.ID]*vdrSet), } } @@ -49,27 +105,176 @@ type manager struct { // Key: Subnet ID // Value: The validators that validate the subnet - subnetToVdrs map[ids.ID]Set + subnetToVdrs map[ids.ID]*vdrSet } -func (m *manager) Add(subnetID ids.ID, set Set) bool { +func (m *manager) AddStaker(subnetID ids.ID, nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) error { + if weight == 0 { + return ErrZeroWeight + } + m.lock.Lock() defer m.lock.Unlock() - if _, exists := m.subnetToVdrs[subnetID]; exists { - return false + set, exists := m.subnetToVdrs[subnetID] + if !exists { + set = newSet() + m.subnetToVdrs[subnetID] = set } - m.subnetToVdrs[subnetID] = set - return true + return set.Add(nodeID, pk, txID, weight) } -func (m *manager) Get(subnetID ids.ID) (Set, bool) { +func (m *manager) AddWeight(subnetID ids.ID, nodeID ids.NodeID, weight uint64) error { + if weight == 0 { + return ErrZeroWeight + } + + // We do not need to grab a write lock here because we never modify the + // subnetToVdrs map. However, we must hold the read lock during the entirety + // of this function to ensure that errors are returned consistently. + // + // Consider the case that: + // AddStaker(subnetID, nodeID, 1) + // go func() { + // AddWeight(subnetID, nodeID, 1) + // } + // go func() { + // RemoveWeight(subnetID, nodeID, 1) + // } + + // In this case, after both goroutines have finished, either AddWeight + // should have errored, or the weight of the node should equal 1. It would + // be unexpected to not have received an error from AddWeight but for the + // node to no longer be tracked as a validator.
m.lock.RLock() defer m.lock.RUnlock() - vdrs, ok := m.subnetToVdrs[subnetID] - return vdrs, ok + set, exists := m.subnetToVdrs[subnetID] + if !exists { + return errMissingValidator + } + + return set.AddWeight(nodeID, weight) +} + +func (m *manager) GetWeight(subnetID ids.ID, nodeID ids.NodeID) uint64 { + m.lock.RLock() + set, exists := m.subnetToVdrs[subnetID] + m.lock.RUnlock() + if !exists { + return 0 + } + + return set.GetWeight(nodeID) +} + +func (m *manager) GetValidator(subnetID ids.ID, nodeID ids.NodeID) (*Validator, bool) { + m.lock.RLock() + set, exists := m.subnetToVdrs[subnetID] + m.lock.RUnlock() + if !exists { + return nil, false + } + + return set.Get(nodeID) +} + +func (m *manager) SubsetWeight(subnetID ids.ID, validatorIDs set.Set[ids.NodeID]) (uint64, error) { + m.lock.RLock() + set, exists := m.subnetToVdrs[subnetID] + m.lock.RUnlock() + if !exists { + return 0, nil + } + + return set.SubsetWeight(validatorIDs) +} + +func (m *manager) RemoveWeight(subnetID ids.ID, nodeID ids.NodeID, weight uint64) error { + if weight == 0 { + return ErrZeroWeight + } + + m.lock.Lock() + defer m.lock.Unlock() + + set, exists := m.subnetToVdrs[subnetID] + if !exists { + return errMissingValidator + } + + if err := set.RemoveWeight(nodeID, weight); err != nil { + return err + } + // If this was the last validator in the subnet and no callback listeners + // are registered, remove the subnet + if set.Len() == 0 && !set.HasCallbackRegistered() { + delete(m.subnetToVdrs, subnetID) + } + + return nil +} + +func (m *manager) Count(subnetID ids.ID) int { + m.lock.RLock() + set, exists := m.subnetToVdrs[subnetID] + m.lock.RUnlock() + if !exists { + return 0 + } + + return set.Len() +} + +func (m *manager) TotalWeight(subnetID ids.ID) (uint64, error) { + m.lock.RLock() + set, exists := m.subnetToVdrs[subnetID] + m.lock.RUnlock() + if !exists { + return 0, nil + } + + return set.TotalWeight() +} + +func (m *manager) Sample(subnetID ids.ID, size int) ([]ids.NodeID, error) { + if size == 0 { + return nil, nil + } + + m.lock.RLock() + set, exists := m.subnetToVdrs[subnetID] + m.lock.RUnlock() + if !exists { + return nil, ErrMissingValidators + } + + return set.Sample(size) +} + +func (m *manager) GetMap(subnetID ids.ID) map[ids.NodeID]*GetValidatorOutput { + m.lock.RLock() + set, exists := m.subnetToVdrs[subnetID] + m.lock.RUnlock() + if !exists { + return make(map[ids.NodeID]*GetValidatorOutput) + } + + return set.Map() +} + +func (m *manager) RegisterCallbackListener(subnetID ids.ID, listener SetCallbackListener) { + m.lock.Lock() + defer m.lock.Unlock() + + set, exists := m.subnetToVdrs[subnetID] + if !exists { + set = newSet() + m.subnetToVdrs[subnetID] = set + } + + set.RegisterCallbackListener(listener) } func (m *manager) String() string { @@ -96,62 +301,13 @@ func (m *manager) String() string { return sb.String() } -// Add is a helper that fetches the validator set of [subnetID] from [m] and -// adds [nodeID] to the validator set. -// Returns an error if: -// - [subnetID] does not have a registered validator set in [m] -// - adding [nodeID] to the validator set returns an error -func Add(m Manager, subnetID ids.ID, nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) error { - vdrs, ok := m.Get(subnetID) - if !ok { - return fmt.Errorf("%w: %s", ErrMissingValidators, subnetID) - } - return vdrs.Add(nodeID, pk, txID, weight) -} - -// AddWeight is a helper that fetches the validator set of [subnetID] from [m] -// and adds [weight] to [nodeID] in the validator set. 
-// Returns an error if: -// - [subnetID] does not have a registered validator set in [m] -// - adding [weight] to [nodeID] in the validator set returns an error -func AddWeight(m Manager, subnetID ids.ID, nodeID ids.NodeID, weight uint64) error { - vdrs, ok := m.Get(subnetID) - if !ok { - return fmt.Errorf("%w: %s", ErrMissingValidators, subnetID) - } - return vdrs.AddWeight(nodeID, weight) -} - -// RemoveWeight is a helper that fetches the validator set of [subnetID] from -// [m] and removes [weight] from [nodeID] in the validator set. -// Returns an error if: -// - [subnetID] does not have a registered validator set in [m] -// - removing [weight] from [nodeID] in the validator set returns an error -func RemoveWeight(m Manager, subnetID ids.ID, nodeID ids.NodeID, weight uint64) error { - vdrs, ok := m.Get(subnetID) - if !ok { - return fmt.Errorf("%w: %s", ErrMissingValidators, subnetID) - } - return vdrs.RemoveWeight(nodeID, weight) -} - -// Contains is a helper that fetches the validator set of [subnetID] from [m] -// and returns if the validator set contains [nodeID]. If [m] does not contain a -// validator set for [subnetID], false is returned. -func Contains(m Manager, subnetID ids.ID, nodeID ids.NodeID) bool { - vdrs, ok := m.Get(subnetID) - if !ok { - return false - } - return vdrs.Contains(nodeID) -} - -func NodeIDs(m Manager, subnetID ids.ID) ([]ids.NodeID, error) { - vdrs, exist := m.Get(subnetID) +func (m *manager) GetValidatorIDs(subnetID ids.ID) []ids.NodeID { + m.lock.RLock() + vdrs, exist := m.subnetToVdrs[subnetID] + m.lock.RUnlock() if !exist { - return nil, fmt.Errorf("%w: %s", ErrMissingValidators, subnetID) + return nil } - vdrsMap := vdrs.Map() - return maps.Keys(vdrsMap), nil + return vdrs.GetValidatorIDs() } diff --git a/snow/validators/manager_test.go b/snow/validators/manager_test.go index e18ad7a5ddd1..01a84201f91d 100644 --- a/snow/validators/manager_test.go +++ b/snow/validators/manager_test.go @@ -4,98 +4,549 @@ package validators import ( + "math" "testing" "github.com/stretchr/testify/require" "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/utils/crypto/bls" + "github.com/ava-labs/avalanchego/utils/sampler" + "github.com/ava-labs/avalanchego/utils/set" + + safemath "github.com/ava-labs/avalanchego/utils/math" ) -func TestAdd(t *testing.T) { +func TestAddZeroWeight(t *testing.T) { + require := require.New(t) + + m := NewManager().(*manager) + err := m.AddStaker(ids.GenerateTestID(), ids.GenerateTestNodeID(), nil, ids.Empty, 0) + require.ErrorIs(err, ErrZeroWeight) + require.Empty(m.subnetToVdrs) +} + +func TestAddDuplicate(t *testing.T) { require := require.New(t) m := NewManager() + subnetID := ids.GenerateTestID() + + nodeID := ids.GenerateTestNodeID() + require.NoError(m.AddStaker(subnetID, nodeID, nil, ids.Empty, 1)) + + err := m.AddStaker(subnetID, nodeID, nil, ids.Empty, 1) + require.ErrorIs(err, errDuplicateValidator) +} +func TestAddOverflow(t *testing.T) { + require := require.New(t) + + m := NewManager() subnetID := ids.GenerateTestID() + nodeID1 := ids.GenerateTestNodeID() + nodeID2 := ids.GenerateTestNodeID() + require.NoError(m.AddStaker(subnetID, nodeID1, nil, ids.Empty, 1)) + + require.NoError(m.AddStaker(subnetID, nodeID2, nil, ids.Empty, math.MaxUint64)) + + _, err := m.TotalWeight(subnetID) + require.ErrorIs(err, errTotalWeightNotUint64) + + set := set.Of(nodeID1, nodeID2) + _, err = m.SubsetWeight(subnetID, set) + require.ErrorIs(err, safemath.ErrOverflow) +} + +func TestAddWeightZeroWeight(t *testing.T) { + require 
:= require.New(t) + + m := NewManager() + subnetID := ids.GenerateTestID() + + nodeID := ids.GenerateTestNodeID() + require.NoError(m.AddStaker(subnetID, nodeID, nil, ids.Empty, 1)) + + err := m.AddWeight(subnetID, nodeID, 0) + require.ErrorIs(err, ErrZeroWeight) +} + +func TestAddWeightOverflow(t *testing.T) { + require := require.New(t) + + m := NewManager() + subnetID := ids.GenerateTestID() + require.NoError(m.AddStaker(subnetID, ids.GenerateTestNodeID(), nil, ids.Empty, 1)) + nodeID := ids.GenerateTestNodeID() + require.NoError(m.AddStaker(subnetID, nodeID, nil, ids.Empty, 1)) + + require.NoError(m.AddWeight(subnetID, nodeID, math.MaxUint64-1)) + + _, err := m.TotalWeight(subnetID) + require.ErrorIs(err, errTotalWeightNotUint64) +} + +func TestGetWeight(t *testing.T) { + require := require.New(t) - err := Add(m, subnetID, nodeID, nil, ids.Empty, 1) - require.ErrorIs(err, ErrMissingValidators) + m := NewManager() + subnetID := ids.GenerateTestID() - s := NewSet() - m.Add(subnetID, s) + nodeID := ids.GenerateTestNodeID() + require.Zero(m.GetWeight(subnetID, nodeID)) - require.NoError(Add(m, subnetID, nodeID, nil, ids.Empty, 1)) + require.NoError(m.AddStaker(subnetID, nodeID, nil, ids.Empty, 1)) - require.Equal(uint64(1), s.Weight()) + totalWeight, err := m.TotalWeight(subnetID) + require.NoError(err) + require.Equal(uint64(1), totalWeight) } -func TestAddWeight(t *testing.T) { +func TestSubsetWeight(t *testing.T) { require := require.New(t) + nodeID0 := ids.GenerateTestNodeID() + nodeID1 := ids.GenerateTestNodeID() + nodeID2 := ids.GenerateTestNodeID() + + weight0 := uint64(93) + weight1 := uint64(123) + weight2 := uint64(810) + + subset := set.Of(nodeID0, nodeID1) + m := NewManager() + subnetID := ids.GenerateTestID() + + require.NoError(m.AddStaker(subnetID, nodeID0, nil, ids.Empty, weight0)) + require.NoError(m.AddStaker(subnetID, nodeID1, nil, ids.Empty, weight1)) + require.NoError(m.AddStaker(subnetID, nodeID2, nil, ids.Empty, weight2)) + + expectedWeight := weight0 + weight1 + subsetWeight, err := m.SubsetWeight(subnetID, subset) + require.NoError(err) + require.Equal(expectedWeight, subsetWeight) +} +func TestRemoveWeightZeroWeight(t *testing.T) { + require := require.New(t) + + m := NewManager() subnetID := ids.GenerateTestID() nodeID := ids.GenerateTestNodeID() + require.NoError(m.AddStaker(subnetID, nodeID, nil, ids.Empty, 1)) + + err := m.RemoveWeight(subnetID, nodeID, 0) + require.ErrorIs(err, ErrZeroWeight) +} - err := AddWeight(m, subnetID, nodeID, 1) - require.ErrorIs(err, ErrMissingValidators) +func TestRemoveWeightMissingValidator(t *testing.T) { + require := require.New(t) + + m := NewManager() + subnetID := ids.GenerateTestID() - s := NewSet() - m.Add(subnetID, s) + require.NoError(m.AddStaker(subnetID, ids.GenerateTestNodeID(), nil, ids.Empty, 1)) - err = AddWeight(m, subnetID, nodeID, 1) + err := m.RemoveWeight(subnetID, ids.GenerateTestNodeID(), 1) require.ErrorIs(err, errMissingValidator) +} + +func TestRemoveWeightUnderflow(t *testing.T) { + require := require.New(t) - require.NoError(Add(m, subnetID, nodeID, nil, ids.Empty, 1)) + m := NewManager() + subnetID := ids.GenerateTestID() - require.NoError(AddWeight(m, subnetID, nodeID, 1)) + require.NoError(m.AddStaker(subnetID, ids.GenerateTestNodeID(), nil, ids.Empty, 1)) - require.Equal(uint64(2), s.Weight()) + nodeID := ids.GenerateTestNodeID() + require.NoError(m.AddStaker(subnetID, nodeID, nil, ids.Empty, 1)) + + err := m.RemoveWeight(subnetID, nodeID, 2) + require.ErrorIs(err, safemath.ErrUnderflow) + + 
totalWeight, err := m.TotalWeight(subnetID) + require.NoError(err) + require.Equal(uint64(2), totalWeight) } -func TestRemoveWeight(t *testing.T) { +func TestGet(t *testing.T) { require := require.New(t) m := NewManager() - subnetID := ids.GenerateTestID() + nodeID := ids.GenerateTestNodeID() + _, ok := m.GetValidator(subnetID, nodeID) + require.False(ok) + + sk, err := bls.NewSecretKey() + require.NoError(err) + + pk := bls.PublicFromSecretKey(sk) + require.NoError(m.AddStaker(subnetID, nodeID, pk, ids.Empty, 1)) + + vdr0, ok := m.GetValidator(subnetID, nodeID) + require.True(ok) + require.Equal(nodeID, vdr0.NodeID) + require.Equal(pk, vdr0.PublicKey) + require.Equal(uint64(1), vdr0.Weight) + + require.NoError(m.AddWeight(subnetID, nodeID, 1)) + + vdr1, ok := m.GetValidator(subnetID, nodeID) + require.True(ok) + require.Equal(nodeID, vdr0.NodeID) + require.Equal(pk, vdr0.PublicKey) + require.Equal(uint64(1), vdr0.Weight) + require.Equal(nodeID, vdr1.NodeID) + require.Equal(pk, vdr1.PublicKey) + require.Equal(uint64(2), vdr1.Weight) + + require.NoError(m.RemoveWeight(subnetID, nodeID, 2)) + _, ok = m.GetValidator(subnetID, nodeID) + require.False(ok) +} + +func TestLen(t *testing.T) { + require := require.New(t) + + m := NewManager() + subnetID := ids.GenerateTestID() + + len := m.Count(subnetID) + require.Zero(len) + + nodeID0 := ids.GenerateTestNodeID() + require.NoError(m.AddStaker(subnetID, nodeID0, nil, ids.Empty, 1)) - err := RemoveWeight(m, subnetID, nodeID, 1) - require.ErrorIs(err, ErrMissingValidators) + len = m.Count(subnetID) + require.Equal(1, len) - s := NewSet() - m.Add(subnetID, s) + nodeID1 := ids.GenerateTestNodeID() + require.NoError(m.AddStaker(subnetID, nodeID1, nil, ids.Empty, 1)) - require.NoError(Add(m, subnetID, nodeID, nil, ids.Empty, 2)) + len = m.Count(subnetID) + require.Equal(2, len) - require.NoError(RemoveWeight(m, subnetID, nodeID, 1)) + require.NoError(m.RemoveWeight(subnetID, nodeID1, 1)) - require.Equal(uint64(1), s.Weight()) + len = m.Count(subnetID) + require.Equal(1, len) - require.NoError(RemoveWeight(m, subnetID, nodeID, 1)) + require.NoError(m.RemoveWeight(subnetID, nodeID0, 1)) - require.Zero(s.Weight()) + len = m.Count(subnetID) + require.Zero(len) } -func TestContains(t *testing.T) { +func TestGetMap(t *testing.T) { require := require.New(t) m := NewManager() + subnetID := ids.GenerateTestID() + + mp := m.GetMap(subnetID) + require.Empty(mp) + + sk, err := bls.NewSecretKey() + require.NoError(err) + + pk := bls.PublicFromSecretKey(sk) + nodeID0 := ids.GenerateTestNodeID() + require.NoError(m.AddStaker(subnetID, nodeID0, pk, ids.Empty, 2)) + + mp = m.GetMap(subnetID) + require.Len(mp, 1) + require.Contains(mp, nodeID0) + + node0 := mp[nodeID0] + require.Equal(nodeID0, node0.NodeID) + require.Equal(pk, node0.PublicKey) + require.Equal(uint64(2), node0.Weight) + + nodeID1 := ids.GenerateTestNodeID() + require.NoError(m.AddStaker(subnetID, nodeID1, nil, ids.Empty, 1)) + + mp = m.GetMap(subnetID) + require.Len(mp, 2) + require.Contains(mp, nodeID0) + require.Contains(mp, nodeID1) + + node0 = mp[nodeID0] + require.Equal(nodeID0, node0.NodeID) + require.Equal(pk, node0.PublicKey) + require.Equal(uint64(2), node0.Weight) + + node1 := mp[nodeID1] + require.Equal(nodeID1, node1.NodeID) + require.Nil(node1.PublicKey) + require.Equal(uint64(1), node1.Weight) + require.NoError(m.RemoveWeight(subnetID, nodeID0, 1)) + require.Equal(nodeID0, node0.NodeID) + require.Equal(pk, node0.PublicKey) + require.Equal(uint64(2), node0.Weight) + + mp = m.GetMap(subnetID) + 
require.Len(mp, 2) + require.Contains(mp, nodeID0) + require.Contains(mp, nodeID1) + + node0 = mp[nodeID0] + require.Equal(nodeID0, node0.NodeID) + require.Equal(pk, node0.PublicKey) + require.Equal(uint64(1), node0.Weight) + + node1 = mp[nodeID1] + require.Equal(nodeID1, node1.NodeID) + require.Nil(node1.PublicKey) + require.Equal(uint64(1), node1.Weight) + + require.NoError(m.RemoveWeight(subnetID, nodeID0, 1)) + + mp = m.GetMap(subnetID) + require.Len(mp, 1) + require.Contains(mp, nodeID1) + + node1 = mp[nodeID1] + require.Equal(nodeID1, node1.NodeID) + require.Nil(node1.PublicKey) + require.Equal(uint64(1), node1.Weight) + + require.NoError(m.RemoveWeight(subnetID, nodeID1, 1)) + + require.Empty(m.GetMap(subnetID)) +} + +func TestWeight(t *testing.T) { + require := require.New(t) + + vdr0 := ids.NodeID{1} + weight0 := uint64(93) + vdr1 := ids.NodeID{2} + weight1 := uint64(123) + + m := NewManager() subnetID := ids.GenerateTestID() - nodeID := ids.GenerateTestNodeID() + require.NoError(m.AddStaker(subnetID, vdr0, nil, ids.Empty, weight0)) + + require.NoError(m.AddStaker(subnetID, vdr1, nil, ids.Empty, weight1)) + + setWeight, err := m.TotalWeight(subnetID) + require.NoError(err) + expectedWeight := weight0 + weight1 + require.Equal(expectedWeight, setWeight) +} + +func TestSample(t *testing.T) { + require := require.New(t) + + m := NewManager() + subnetID := ids.GenerateTestID() + + sampled, err := m.Sample(subnetID, 0) + require.NoError(err) + require.Empty(sampled) + + sk, err := bls.NewSecretKey() + require.NoError(err) + + nodeID0 := ids.GenerateTestNodeID() + pk := bls.PublicFromSecretKey(sk) + require.NoError(m.AddStaker(subnetID, nodeID0, pk, ids.Empty, 1)) + + sampled, err = m.Sample(subnetID, 1) + require.NoError(err) + require.Equal([]ids.NodeID{nodeID0}, sampled) - require.False(Contains(m, subnetID, nodeID)) + _, err = m.Sample(subnetID, 2) + require.ErrorIs(err, sampler.ErrOutOfRange) - s := NewSet() - m.Add(subnetID, s) - require.False(Contains(m, subnetID, nodeID)) + nodeID1 := ids.GenerateTestNodeID() + require.NoError(m.AddStaker(subnetID, nodeID1, nil, ids.Empty, math.MaxInt64-1)) - require.NoError(Add(m, subnetID, nodeID, nil, ids.Empty, 1)) - require.True(Contains(m, subnetID, nodeID)) + sampled, err = m.Sample(subnetID, 1) + require.NoError(err) + require.Equal([]ids.NodeID{nodeID1}, sampled) - require.NoError(RemoveWeight(m, subnetID, nodeID, 1)) - require.False(Contains(m, subnetID, nodeID)) + sampled, err = m.Sample(subnetID, 2) + require.NoError(err) + require.Equal([]ids.NodeID{nodeID1, nodeID1}, sampled) + + sampled, err = m.Sample(subnetID, 3) + require.NoError(err) + require.Equal([]ids.NodeID{nodeID1, nodeID1, nodeID1}, sampled) +} + +func TestString(t *testing.T) { + require := require.New(t) + + nodeID0 := ids.EmptyNodeID + nodeID1, err := ids.NodeIDFromString("NodeID-QLbz7JHiBTspS962RLKV8GndWFwdYhk6V") + require.NoError(err) + + subnetID0, err := ids.FromString("TtF4d2QWbk5vzQGTEPrN48x6vwgAoAmKQ9cbp79inpQmcRKES") + require.NoError(err) + subnetID1, err := ids.FromString("2mcwQKiD8VEspmMJpL1dc7okQQ5dDVAWeCBZ7FWBFAbxpv3t7w") + require.NoError(err) + + m := NewManager() + require.NoError(m.AddStaker(subnetID0, nodeID0, nil, ids.Empty, 1)) + require.NoError(m.AddStaker(subnetID0, nodeID1, nil, ids.Empty, math.MaxInt64-1)) + require.NoError(m.AddStaker(subnetID1, nodeID1, nil, ids.Empty, 1)) + + expected := "Validator Manager: (Size = 2)\n" + + " Subnet[TtF4d2QWbk5vzQGTEPrN48x6vwgAoAmKQ9cbp79inpQmcRKES]: Validator Set: (Size = 2, Weight = 
9223372036854775807)\n" + + " Validator[0]: NodeID-111111111111111111116DBWJs, 1\n" + + " Validator[1]: NodeID-QLbz7JHiBTspS962RLKV8GndWFwdYhk6V, 9223372036854775806\n" + + " Subnet[2mcwQKiD8VEspmMJpL1dc7okQQ5dDVAWeCBZ7FWBFAbxpv3t7w]: Validator Set: (Size = 1, Weight = 1)\n" + + " Validator[0]: NodeID-QLbz7JHiBTspS962RLKV8GndWFwdYhk6V, 1" + result := m.String() + require.Equal(expected, result) +} + +func TestAddCallback(t *testing.T) { + require := require.New(t) + + nodeID0 := ids.NodeID{1} + sk0, err := bls.NewSecretKey() + require.NoError(err) + pk0 := bls.PublicFromSecretKey(sk0) + txID0 := ids.GenerateTestID() + weight0 := uint64(1) + + m := NewManager() + subnetID := ids.GenerateTestID() + callCount := 0 + m.RegisterCallbackListener(subnetID, &callbackListener{ + t: t, + onAdd: func(nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) { + require.Equal(nodeID0, nodeID) + require.Equal(pk0, pk) + require.Equal(txID0, txID) + require.Equal(weight0, weight) + callCount++ + }, + }) + require.NoError(m.AddStaker(subnetID, nodeID0, pk0, txID0, weight0)) + // setup another subnetID + subnetID2 := ids.GenerateTestID() + require.NoError(m.AddStaker(subnetID2, nodeID0, nil, txID0, weight0)) + // should not be called for subnetID2 + require.Equal(1, callCount) +} + +func TestAddWeightCallback(t *testing.T) { + require := require.New(t) + + nodeID0 := ids.NodeID{1} + txID0 := ids.GenerateTestID() + weight0 := uint64(1) + weight1 := uint64(93) + + m := NewManager() + subnetID := ids.GenerateTestID() + require.NoError(m.AddStaker(subnetID, nodeID0, nil, txID0, weight0)) + + callCount := 0 + m.RegisterCallbackListener(subnetID, &callbackListener{ + t: t, + onAdd: func(nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) { + require.Equal(nodeID0, nodeID) + require.Nil(pk) + require.Equal(txID0, txID) + require.Equal(weight0, weight) + callCount++ + }, + onWeight: func(nodeID ids.NodeID, oldWeight, newWeight uint64) { + require.Equal(nodeID0, nodeID) + require.Equal(weight0, oldWeight) + require.Equal(weight0+weight1, newWeight) + callCount++ + }, + }) + require.NoError(m.AddWeight(subnetID, nodeID0, weight1)) + // setup another subnetID + subnetID2 := ids.GenerateTestID() + require.NoError(m.AddStaker(subnetID2, nodeID0, nil, txID0, weight0)) + require.NoError(m.AddWeight(subnetID2, nodeID0, weight1)) + // should not be called for subnetID2 + require.Equal(2, callCount) +} + +func TestRemoveWeightCallback(t *testing.T) { + require := require.New(t) + + nodeID0 := ids.NodeID{1} + txID0 := ids.GenerateTestID() + weight0 := uint64(93) + weight1 := uint64(92) + + m := NewManager() + subnetID := ids.GenerateTestID() + require.NoError(m.AddStaker(subnetID, nodeID0, nil, txID0, weight0)) + + callCount := 0 + m.RegisterCallbackListener(subnetID, &callbackListener{ + t: t, + onAdd: func(nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) { + require.Equal(nodeID0, nodeID) + require.Nil(pk) + require.Equal(txID0, txID) + require.Equal(weight0, weight) + callCount++ + }, + onWeight: func(nodeID ids.NodeID, oldWeight, newWeight uint64) { + require.Equal(nodeID0, nodeID) + require.Equal(weight0, oldWeight) + require.Equal(weight0-weight1, newWeight) + callCount++ + }, + }) + require.NoError(m.RemoveWeight(subnetID, nodeID0, weight1)) + // setup another subnetID + subnetID2 := ids.GenerateTestID() + require.NoError(m.AddStaker(subnetID2, nodeID0, nil, txID0, weight0)) + require.NoError(m.RemoveWeight(subnetID2, nodeID0, weight1)) + // should not be called for subnetID2 + 
require.Equal(2, callCount) +} + +func TestValidatorRemovedCallback(t *testing.T) { + require := require.New(t) + + nodeID0 := ids.NodeID{1} + txID0 := ids.GenerateTestID() + weight0 := uint64(93) + + m := NewManager() + subnetID := ids.GenerateTestID() + require.NoError(m.AddStaker(subnetID, nodeID0, nil, txID0, weight0)) + + callCount := 0 + m.RegisterCallbackListener(subnetID, &callbackListener{ + t: t, + onAdd: func(nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) { + require.Equal(nodeID0, nodeID) + require.Nil(pk) + require.Equal(txID0, txID) + require.Equal(weight0, weight) + callCount++ + }, + onRemoved: func(nodeID ids.NodeID, weight uint64) { + require.Equal(nodeID0, nodeID) + require.Equal(weight0, weight) + callCount++ + }, + }) + require.NoError(m.RemoveWeight(subnetID, nodeID0, weight0)) + // setup another subnetID + subnetID2 := ids.GenerateTestID() + require.NoError(m.AddStaker(subnetID2, nodeID0, nil, txID0, weight0)) + require.NoError(m.AddWeight(subnetID2, nodeID0, weight0)) + // should not be called for subnetID2 + require.Equal(2, callCount) } diff --git a/snow/validators/mock_manager.go b/snow/validators/mock_manager.go index 347a1a6d5f67..2b99245710fb 100644 --- a/snow/validators/mock_manager.go +++ b/snow/validators/mock_manager.go @@ -1,6 +1,7 @@ // Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. // See the file LICENSE for licensing terms. +// Do not include this in mocks.mockgen.txt as bls package won't be available. // Code generated by MockGen. DO NOT EDIT. // Source: github.com/ava-labs/avalanchego/snow/validators (interfaces: Manager) @@ -11,6 +12,8 @@ import ( reflect "reflect" ids "github.com/ava-labs/avalanchego/ids" + bls "github.com/ava-labs/avalanchego/utils/crypto/bls" + set "github.com/ava-labs/avalanchego/utils/set" gomock "go.uber.org/mock/gomock" ) @@ -37,33 +40,159 @@ func (m *MockManager) EXPECT() *MockManagerMockRecorder { return m.recorder } -// Add mocks base method. -func (m *MockManager) Add(arg0 ids.ID, arg1 Set) bool { +// AddStaker mocks base method. +func (m *MockManager) AddStaker(arg0 ids.ID, arg1 ids.NodeID, arg2 *bls.PublicKey, arg3 ids.ID, arg4 uint64) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Add", arg0, arg1) + ret := m.ctrl.Call(m, "AddStaker", arg0, arg1, arg2, arg3, arg4) + ret0, _ := ret[0].(error) + return ret0 +} + +// AddStaker indicates an expected call of AddStaker. +func (mr *MockManagerMockRecorder) AddStaker(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddStaker", reflect.TypeOf((*MockManager)(nil).AddStaker), arg0, arg1, arg2, arg3, arg4) +} + +// AddWeight mocks base method. +func (m *MockManager) AddWeight(arg0 ids.ID, arg1 ids.NodeID, arg2 uint64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddWeight", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// AddWeight indicates an expected call of AddWeight. +func (mr *MockManagerMockRecorder) AddWeight(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddWeight", reflect.TypeOf((*MockManager)(nil).AddWeight), arg0, arg1, arg2) +} + +// Contains mocks base method. +func (m *MockManager) Contains(arg0 ids.ID, arg1 ids.NodeID) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Contains", arg0, arg1) ret0, _ := ret[0].(bool) return ret0 } -// Add indicates an expected call of Add. 
-func (mr *MockManagerMockRecorder) Add(arg0, arg1 interface{}) *gomock.Call { +// Contains indicates an expected call of Contains. +func (mr *MockManagerMockRecorder) Contains(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Contains", reflect.TypeOf((*MockManager)(nil).Contains), arg0, arg1) +} + +// GetMap mocks base method. +func (m *MockManager) GetMap(arg0 ids.ID) map[ids.NodeID]*GetValidatorOutput { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetMap", arg0) + ret0, _ := ret[0].(map[ids.NodeID]*GetValidatorOutput) + return ret0 +} + +// GetMap indicates an expected call of GetMap. +func (mr *MockManagerMockRecorder) GetMap(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockManager)(nil).Add), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMap", reflect.TypeOf((*MockManager)(nil).GetMap), arg0) } -// Get mocks base method. -func (m *MockManager) Get(arg0 ids.ID) (Set, bool) { +// GetValidator mocks base method. +func (m *MockManager) GetValidator(arg0 ids.ID, arg1 ids.NodeID) (*Validator, bool) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Get", arg0) - ret0, _ := ret[0].(Set) + ret := m.ctrl.Call(m, "GetValidator", arg0, arg1) + ret0, _ := ret[0].(*Validator) ret1, _ := ret[1].(bool) return ret0, ret1 } -// Get indicates an expected call of Get. -func (mr *MockManagerMockRecorder) Get(arg0 interface{}) *gomock.Call { +// GetValidator indicates an expected call of GetValidator. +func (mr *MockManagerMockRecorder) GetValidator(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidator", reflect.TypeOf((*MockManager)(nil).GetValidator), arg0, arg1) +} + +// GetValidatorIDs mocks base method. +func (m *MockManager) GetValidatorIDs(arg0 ids.ID) ([]ids.NodeID, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetValidatorIDs", arg0) + ret0, _ := ret[0].([]ids.NodeID) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetValidatorIDs indicates an expected call of GetValidatorIDs. +func (mr *MockManagerMockRecorder) GetValidatorIDs(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatorIDs", reflect.TypeOf((*MockManager)(nil).GetValidatorIDs), arg0) +} + +// GetWeight mocks base method. +func (m *MockManager) GetWeight(arg0 ids.ID, arg1 ids.NodeID) uint64 { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetWeight", arg0, arg1) + ret0, _ := ret[0].(uint64) + return ret0 +} + +// GetWeight indicates an expected call of GetWeight. +func (mr *MockManagerMockRecorder) GetWeight(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWeight", reflect.TypeOf((*MockManager)(nil).GetWeight), arg0, arg1) +} + +// Len mocks base method. +func (m *MockManager) Len(arg0 ids.ID) int { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Len", arg0) + ret0, _ := ret[0].(int) + return ret0 +} + +// Len indicates an expected call of Len. +func (mr *MockManagerMockRecorder) Len(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Len", reflect.TypeOf((*MockManager)(nil).Len), arg0) +} + +// RegisterCallbackListener mocks base method. 
+func (m *MockManager) RegisterCallbackListener(arg0 ids.ID, arg1 SetCallbackListener) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "RegisterCallbackListener", arg0, arg1) +} + +// RegisterCallbackListener indicates an expected call of RegisterCallbackListener. +func (mr *MockManagerMockRecorder) RegisterCallbackListener(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterCallbackListener", reflect.TypeOf((*MockManager)(nil).RegisterCallbackListener), arg0, arg1) +} + +// RemoveWeight mocks base method. +func (m *MockManager) RemoveWeight(arg0 ids.ID, arg1 ids.NodeID, arg2 uint64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemoveWeight", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// RemoveWeight indicates an expected call of RemoveWeight. +func (mr *MockManagerMockRecorder) RemoveWeight(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveWeight", reflect.TypeOf((*MockManager)(nil).RemoveWeight), arg0, arg1, arg2) +} + +// Sample mocks base method. +func (m *MockManager) Sample(arg0 ids.ID, arg1 int) ([]ids.NodeID, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Sample", arg0, arg1) + ret0, _ := ret[0].([]ids.NodeID) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Sample indicates an expected call of Sample. +func (mr *MockManagerMockRecorder) Sample(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockManager)(nil).Get), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Sample", reflect.TypeOf((*MockManager)(nil).Sample), arg0, arg1) } // String mocks base method. @@ -79,3 +208,33 @@ func (mr *MockManagerMockRecorder) String() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "String", reflect.TypeOf((*MockManager)(nil).String)) } + +// SubsetWeight mocks base method. +func (m *MockManager) SubsetWeight(arg0 ids.ID, arg1 set.Set[ids.NodeID]) (uint64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SubsetWeight", arg0, arg1) + ret0, _ := ret[0].(uint64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SubsetWeight indicates an expected call of SubsetWeight. +func (mr *MockManagerMockRecorder) SubsetWeight(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubsetWeight", reflect.TypeOf((*MockManager)(nil).SubsetWeight), arg0, arg1) +} + +// TotalWeight mocks base method. +func (m *MockManager) TotalWeight(arg0 ids.ID) (uint64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TotalWeight", arg0) + ret0, _ := ret[0].(uint64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// TotalWeight indicates an expected call of TotalWeight. +func (mr *MockManagerMockRecorder) TotalWeight(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TotalWeight", reflect.TypeOf((*MockManager)(nil).TotalWeight), arg0) +} diff --git a/snow/validators/mock_set.go b/snow/validators/mock_set.go deleted file mode 100644 index a1667f0a5ca5..000000000000 --- a/snow/validators/mock_set.go +++ /dev/null @@ -1,236 +0,0 @@ -// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -// Code generated by MockGen. DO NOT EDIT. 
-// Source: github.com/ava-labs/avalanchego/snow/validators (interfaces: Set) - -// Package validators is a generated GoMock package. -package validators - -import ( - reflect "reflect" - - ids "github.com/ava-labs/avalanchego/ids" - bls "github.com/ava-labs/avalanchego/utils/crypto/bls" - set "github.com/ava-labs/avalanchego/utils/set" - gomock "go.uber.org/mock/gomock" -) - -// MockSet is a mock of Set interface. -type MockSet struct { - ctrl *gomock.Controller - recorder *MockSetMockRecorder -} - -// MockSetMockRecorder is the mock recorder for MockSet. -type MockSetMockRecorder struct { - mock *MockSet -} - -// NewMockSet creates a new mock instance. -func NewMockSet(ctrl *gomock.Controller) *MockSet { - mock := &MockSet{ctrl: ctrl} - mock.recorder = &MockSetMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockSet) EXPECT() *MockSetMockRecorder { - return m.recorder -} - -// Add mocks base method. -func (m *MockSet) Add(arg0 ids.NodeID, arg1 *bls.PublicKey, arg2 ids.ID, arg3 uint64) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Add", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(error) - return ret0 -} - -// Add indicates an expected call of Add. -func (mr *MockSetMockRecorder) Add(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockSet)(nil).Add), arg0, arg1, arg2, arg3) -} - -// AddWeight mocks base method. -func (m *MockSet) AddWeight(arg0 ids.NodeID, arg1 uint64) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AddWeight", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 -} - -// AddWeight indicates an expected call of AddWeight. -func (mr *MockSetMockRecorder) AddWeight(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddWeight", reflect.TypeOf((*MockSet)(nil).AddWeight), arg0, arg1) -} - -// Contains mocks base method. -func (m *MockSet) Contains(arg0 ids.NodeID) bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Contains", arg0) - ret0, _ := ret[0].(bool) - return ret0 -} - -// Contains indicates an expected call of Contains. -func (mr *MockSetMockRecorder) Contains(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Contains", reflect.TypeOf((*MockSet)(nil).Contains), arg0) -} - -// Get mocks base method. -func (m *MockSet) Get(arg0 ids.NodeID) (*Validator, bool) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Get", arg0) - ret0, _ := ret[0].(*Validator) - ret1, _ := ret[1].(bool) - return ret0, ret1 -} - -// Get indicates an expected call of Get. -func (mr *MockSetMockRecorder) Get(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockSet)(nil).Get), arg0) -} - -// GetWeight mocks base method. -func (m *MockSet) GetWeight(arg0 ids.NodeID) uint64 { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetWeight", arg0) - ret0, _ := ret[0].(uint64) - return ret0 -} - -// GetWeight indicates an expected call of GetWeight. -func (mr *MockSetMockRecorder) GetWeight(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWeight", reflect.TypeOf((*MockSet)(nil).GetWeight), arg0) -} - -// Len mocks base method. 
-func (m *MockSet) Len() int { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Len") - ret0, _ := ret[0].(int) - return ret0 -} - -// Len indicates an expected call of Len. -func (mr *MockSetMockRecorder) Len() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Len", reflect.TypeOf((*MockSet)(nil).Len)) -} - -// Map mocks base method. -func (m *MockSet) Map() map[ids.NodeID]*GetValidatorOutput { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Map") - ret0, _ := ret[0].(map[ids.NodeID]*GetValidatorOutput) - return ret0 -} - -// Map indicates an expected call of Map. -func (mr *MockSetMockRecorder) Map() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Map", reflect.TypeOf((*MockSet)(nil).Map)) -} - -// PrefixedString mocks base method. -func (m *MockSet) PrefixedString(arg0 string) string { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PrefixedString", arg0) - ret0, _ := ret[0].(string) - return ret0 -} - -// PrefixedString indicates an expected call of PrefixedString. -func (mr *MockSetMockRecorder) PrefixedString(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PrefixedString", reflect.TypeOf((*MockSet)(nil).PrefixedString), arg0) -} - -// RegisterCallbackListener mocks base method. -func (m *MockSet) RegisterCallbackListener(arg0 SetCallbackListener) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "RegisterCallbackListener", arg0) -} - -// RegisterCallbackListener indicates an expected call of RegisterCallbackListener. -func (mr *MockSetMockRecorder) RegisterCallbackListener(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterCallbackListener", reflect.TypeOf((*MockSet)(nil).RegisterCallbackListener), arg0) -} - -// RemoveWeight mocks base method. -func (m *MockSet) RemoveWeight(arg0 ids.NodeID, arg1 uint64) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RemoveWeight", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 -} - -// RemoveWeight indicates an expected call of RemoveWeight. -func (mr *MockSetMockRecorder) RemoveWeight(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveWeight", reflect.TypeOf((*MockSet)(nil).RemoveWeight), arg0, arg1) -} - -// Sample mocks base method. -func (m *MockSet) Sample(arg0 int) ([]ids.NodeID, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Sample", arg0) - ret0, _ := ret[0].([]ids.NodeID) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Sample indicates an expected call of Sample. -func (mr *MockSetMockRecorder) Sample(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Sample", reflect.TypeOf((*MockSet)(nil).Sample), arg0) -} - -// String mocks base method. -func (m *MockSet) String() string { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "String") - ret0, _ := ret[0].(string) - return ret0 -} - -// String indicates an expected call of String. -func (mr *MockSetMockRecorder) String() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "String", reflect.TypeOf((*MockSet)(nil).String)) -} - -// SubsetWeight mocks base method. 
-func (m *MockSet) SubsetWeight(arg0 set.Set[ids.NodeID]) uint64 { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SubsetWeight", arg0) - ret0, _ := ret[0].(uint64) - return ret0 -} - -// SubsetWeight indicates an expected call of SubsetWeight. -func (mr *MockSetMockRecorder) SubsetWeight(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubsetWeight", reflect.TypeOf((*MockSet)(nil).SubsetWeight), arg0) -} - -// Weight mocks base method. -func (m *MockSet) Weight() uint64 { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Weight") - ret0, _ := ret[0].(uint64) - return ret0 -} - -// Weight indicates an expected call of Weight. -func (mr *MockSetMockRecorder) Weight() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Weight", reflect.TypeOf((*MockSet)(nil).Weight)) -} diff --git a/snow/validators/set.go b/snow/validators/set.go index fc868783afc2..dfa294a70bbe 100644 --- a/snow/validators/set.go +++ b/snow/validators/set.go @@ -6,6 +6,7 @@ package validators import ( "errors" "fmt" + "math/big" "strings" "sync" @@ -18,92 +19,17 @@ import ( ) var ( - _ Set = (*vdrSet)(nil) - - errZeroWeight = errors.New("weight must be non-zero") - errDuplicateValidator = errors.New("duplicate validator") - errMissingValidator = errors.New("missing validator") + errDuplicateValidator = errors.New("duplicate validator") + errMissingValidator = errors.New("missing validator") + errTotalWeightNotUint64 = errors.New("total weight is not a uint64") ) -// Set of validators that can be sampled -type Set interface { - formatting.PrefixedStringer - - // Add a new staker to the set. - // Returns an error if: - // - [weight] is 0 - // - [nodeID] is already in the validator set - // - the total weight of the validator set would overflow uint64 - // If an error is returned, the set will be unmodified. - Add(nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) error - - // AddWeight to an existing staker. - // Returns an error if: - // - [weight] is 0 - // - [nodeID] is not already in the validator set - // - the total weight of the validator set would overflow uint64 - // If an error is returned, the set will be unmodified. - AddWeight(nodeID ids.NodeID, weight uint64) error - - // GetWeight retrieves the validator weight from the set. - GetWeight(ids.NodeID) uint64 - - // Get returns the validator tied to the specified ID. - Get(ids.NodeID) (*Validator, bool) - - // SubsetWeight returns the sum of the weights of the validators. - SubsetWeight(set.Set[ids.NodeID]) uint64 - - // RemoveWeight from a staker. If the staker's weight becomes 0, the staker - // will be removed from the validator set. - // Returns an error if: - // - [weight] is 0 - // - [nodeID] is not already in the validator set - // - the weight of the validator would become negative - // If an error is returned, the set will be unmodified. - RemoveWeight(nodeID ids.NodeID, weight uint64) error - - // Contains returns true if there is a validator with the specified ID - // currently in the set. - Contains(ids.NodeID) bool - - // Len returns the number of validators currently in the set. - Len() int - - // Map of the validators in this set - Map() map[ids.NodeID]*GetValidatorOutput - - // Weight returns the cumulative weight of all validators in the set. - Weight() uint64 - - // Sample returns a collection of validatorIDs, potentially with duplicates. - // If sampling the requested size isn't possible, an error will be returned. 
- Sample(size int) ([]ids.NodeID, error) - - // When a validator's weight changes, or a validator is added/removed, - // this listener is called. - RegisterCallbackListener(SetCallbackListener) -} - -type SetCallbackListener interface { - OnValidatorAdded(validatorID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) - OnValidatorRemoved(validatorID ids.NodeID, weight uint64) - OnValidatorWeightChanged(validatorID ids.NodeID, oldWeight, newWeight uint64) -} - -// NewSet returns a new, empty set of validators. -func NewSet() Set { +// newSet returns a new, empty set of validators. +func newSet() *vdrSet { return &vdrSet{ - vdrs: make(map[ids.NodeID]*Validator), - sampler: sampler.NewWeightedWithoutReplacement(), - } -} - -// NewBestSet returns a new, empty set of validators. -func NewBestSet(expectedSampleSize int) Set { - return &vdrSet{ - vdrs: make(map[ids.NodeID]*Validator), - sampler: sampler.NewBestWeightedWithoutReplacement(expectedSampleSize), + vdrs: make(map[ids.NodeID]*Validator), + sampler: sampler.NewWeightedWithoutReplacement(), + totalWeight: new(big.Int), } } @@ -112,7 +38,7 @@ type vdrSet struct { vdrs map[ids.NodeID]*Validator vdrSlice []*Validator weights []uint64 - totalWeight uint64 + totalWeight *big.Int samplerInitialized bool sampler sampler.WeightedWithoutReplacement @@ -121,10 +47,6 @@ type vdrSet struct { } func (s *vdrSet) Add(nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) error { - if weight == 0 { - return errZeroWeight - } - s.lock.Lock() defer s.lock.Unlock() @@ -137,13 +59,6 @@ func (s *vdrSet) add(nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight u return errDuplicateValidator } - // We first calculate the new total weight of the set, as this guarantees - // that none of the following operations can overflow. - newTotalWeight, err := math.Add64(s.totalWeight, weight) - if err != nil { - return err - } - vdr := &Validator{ NodeID: nodeID, PublicKey: pk, @@ -154,7 +69,7 @@ func (s *vdrSet) add(nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight u s.vdrs[nodeID] = vdr s.vdrSlice = append(s.vdrSlice, vdr) s.weights = append(s.weights, weight) - s.totalWeight = newTotalWeight + s.totalWeight.Add(s.totalWeight, new(big.Int).SetUint64(weight)) s.samplerInitialized = false s.callValidatorAddedCallbacks(nodeID, pk, txID, weight) @@ -162,10 +77,6 @@ func (s *vdrSet) add(nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight u } func (s *vdrSet) AddWeight(nodeID ids.NodeID, weight uint64) error { - if weight == 0 { - return errZeroWeight - } - s.lock.Lock() defer s.lock.Unlock() @@ -178,17 +89,14 @@ func (s *vdrSet) addWeight(nodeID ids.NodeID, weight uint64) error { return errMissingValidator } - // We first calculate the new total weight of the set, as this guarantees - // that none of the following operations can overflow. 
- newTotalWeight, err := math.Add64(s.totalWeight, weight) + oldWeight := vdr.Weight + newWeight, err := math.Add64(oldWeight, weight) if err != nil { return err } - - oldWeight := vdr.Weight - vdr.Weight += weight - s.weights[vdr.index] += weight - s.totalWeight = newTotalWeight + vdr.Weight = newWeight + s.weights[vdr.index] = newWeight + s.totalWeight.Add(s.totalWeight, new(big.Int).SetUint64(weight)) s.samplerInitialized = false s.callWeightChangeCallbacks(nodeID, oldWeight, vdr.Weight) @@ -209,28 +117,28 @@ func (s *vdrSet) getWeight(nodeID ids.NodeID) uint64 { return 0 } -func (s *vdrSet) SubsetWeight(subset set.Set[ids.NodeID]) uint64 { +func (s *vdrSet) SubsetWeight(subset set.Set[ids.NodeID]) (uint64, error) { s.lock.RLock() defer s.lock.RUnlock() return s.subsetWeight(subset) } -func (s *vdrSet) subsetWeight(subset set.Set[ids.NodeID]) uint64 { - var totalWeight uint64 +func (s *vdrSet) subsetWeight(subset set.Set[ids.NodeID]) (uint64, error) { + var ( + totalWeight uint64 + err error + ) for nodeID := range subset { - // Because [totalWeight] will be <= [s.totalWeight], we are guaranteed - // this will not overflow. - totalWeight += s.getWeight(nodeID) + totalWeight, err = math.Add64(totalWeight, s.getWeight(nodeID)) + if err != nil { + return 0, err + } } - return totalWeight + return totalWeight, nil } func (s *vdrSet) RemoveWeight(nodeID ids.NodeID, weight uint64) error { - if weight == 0 { - return errZeroWeight - } - s.lock.Lock() defer s.lock.Unlock() @@ -274,7 +182,7 @@ func (s *vdrSet) removeWeight(nodeID ids.NodeID, weight uint64) error { s.callWeightChangeCallbacks(nodeID, oldWeight, newWeight) } - s.totalWeight -= weight + s.totalWeight.Sub(s.totalWeight, new(big.Int).SetUint64(weight)) s.samplerInitialized = false return nil } @@ -295,27 +203,22 @@ func (s *vdrSet) get(nodeID ids.NodeID) (*Validator, bool) { return &copiedVdr, true } -func (s *vdrSet) Contains(nodeID ids.NodeID) bool { +func (s *vdrSet) Len() int { s.lock.RLock() defer s.lock.RUnlock() - return s.contains(nodeID) + return s.len() } -func (s *vdrSet) contains(nodeID ids.NodeID) bool { - _, contains := s.vdrs[nodeID] - return contains +func (s *vdrSet) len() int { + return len(s.vdrSlice) } -func (s *vdrSet) Len() int { +func (s *vdrSet) HasCallbackRegistered() bool { s.lock.RLock() defer s.lock.RUnlock() - return s.len() -} - -func (s *vdrSet) len() int { - return len(s.vdrSlice) + return len(s.callbackListeners) > 0 } func (s *vdrSet) Map() map[ids.NodeID]*GetValidatorOutput { @@ -334,10 +237,6 @@ func (s *vdrSet) Map() map[ids.NodeID]*GetValidatorOutput { } func (s *vdrSet) Sample(size int) ([]ids.NodeID, error) { - if size == 0 { - return nil, nil - } - s.lock.Lock() defer s.lock.Unlock() @@ -364,11 +263,15 @@ func (s *vdrSet) sample(size int) ([]ids.NodeID, error) { return list, nil } -func (s *vdrSet) Weight() uint64 { +func (s *vdrSet) TotalWeight() (uint64, error) { s.lock.RLock() defer s.lock.RUnlock() - return s.totalWeight + if !s.totalWeight.IsUint64() { + return 0, fmt.Errorf("%w, total weight: %s", errTotalWeightNotUint64, s.totalWeight) + } + + return s.totalWeight.Uint64(), nil } func (s *vdrSet) String() string { @@ -432,3 +335,14 @@ func (s *vdrSet) callValidatorRemovedCallbacks(node ids.NodeID, weight uint64) { callbackListener.OnValidatorRemoved(node, weight) } } + +func (s *vdrSet) GetValidatorIDs() []ids.NodeID { + s.lock.RLock() + defer s.lock.RUnlock() + + list := make([]ids.NodeID, len(s.vdrSlice)) + for i, vdr := range s.vdrSlice { + list[i] = vdr.NodeID + } + return list +} 
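For orientation, here is a minimal sketch of how callers are expected to use the subnet-scoped Manager defined above. It is illustrative only and not part of this change; the subnet ID, node ID, and weight are placeholders.

package main

import (
	"fmt"

	"github.com/ava-labs/avalanchego/ids"
	"github.com/ava-labs/avalanchego/snow/validators"
)

func main() {
	vdrs := validators.NewManager()

	// Placeholder IDs; real callers pass the subnet and node being tracked.
	subnetID := ids.GenerateTestID()
	nodeID := ids.GenerateTestNodeID()

	// Every call is scoped by subnet ID; the manager lazily creates the
	// per-subnet set on first use instead of requiring a pre-registered Set.
	if err := vdrs.AddStaker(subnetID, nodeID, nil, ids.Empty, 100); err != nil {
		panic(err)
	}

	// TotalWeight returns an error because the tracked weight can overflow uint64.
	weight, err := vdrs.TotalWeight(subnetID)
	if err != nil {
		panic(err)
	}
	fmt.Println("validators:", vdrs.Count(subnetID), "total weight:", weight)

	// Sampling is also routed through the manager.
	sampled, err := vdrs.Sample(subnetID, 1)
	if err != nil {
		panic(err)
	}
	fmt.Println("sampled:", sampled)
}

Note that RemoveWeight drops a subnet's set entirely once it is empty and has no registered callback listeners, so transient subnets do not accumulate in the manager.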
diff --git a/snow/validators/set_test.go b/snow/validators/set_test.go index 91dbb18b0238..99651e7930e0 100644 --- a/snow/validators/set_test.go +++ b/snow/validators/set_test.go @@ -17,18 +17,10 @@ import ( safemath "github.com/ava-labs/avalanchego/utils/math" ) -func TestSetAddZeroWeight(t *testing.T) { - require := require.New(t) - - s := NewSet() - err := s.Add(ids.GenerateTestNodeID(), nil, ids.Empty, 0) - require.ErrorIs(err, errZeroWeight) -} - func TestSetAddDuplicate(t *testing.T) { require := require.New(t) - s := NewSet() + s := newSet() nodeID := ids.GenerateTestNodeID() require.NoError(s.Add(nodeID, nil, ids.Empty, 1)) @@ -40,47 +32,35 @@ func TestSetAddDuplicate(t *testing.T) { func TestSetAddOverflow(t *testing.T) { require := require.New(t) - s := NewSet() + s := newSet() require.NoError(s.Add(ids.GenerateTestNodeID(), nil, ids.Empty, 1)) - err := s.Add(ids.GenerateTestNodeID(), nil, ids.Empty, math.MaxUint64) - require.ErrorIs(err, safemath.ErrOverflow) - - require.Equal(uint64(1), s.Weight()) -} - -func TestSetAddWeightZeroWeight(t *testing.T) { - require := require.New(t) - - s := NewSet() - - nodeID := ids.GenerateTestNodeID() - require.NoError(s.Add(nodeID, nil, ids.Empty, 1)) + require.NoError(s.Add(ids.GenerateTestNodeID(), nil, ids.Empty, math.MaxUint64)) - err := s.AddWeight(nodeID, 0) - require.ErrorIs(err, errZeroWeight) + _, err := s.TotalWeight() + require.ErrorIs(err, errTotalWeightNotUint64) } func TestSetAddWeightOverflow(t *testing.T) { require := require.New(t) - s := NewSet() + s := newSet() require.NoError(s.Add(ids.GenerateTestNodeID(), nil, ids.Empty, 1)) nodeID := ids.GenerateTestNodeID() require.NoError(s.Add(nodeID, nil, ids.Empty, 1)) - err := s.AddWeight(nodeID, math.MaxUint64-1) - require.ErrorIs(err, safemath.ErrOverflow) + require.NoError(s.AddWeight(nodeID, math.MaxUint64-1)) - require.Equal(uint64(2), s.Weight()) + _, err := s.TotalWeight() + require.ErrorIs(err, errTotalWeightNotUint64) } func TestSetGetWeight(t *testing.T) { require := require.New(t) - s := NewSet() + s := newSet() nodeID := ids.GenerateTestNodeID() require.Zero(s.GetWeight(nodeID)) @@ -103,33 +83,22 @@ func TestSetSubsetWeight(t *testing.T) { subset := set.Of(nodeID0, nodeID1) - s := NewSet() + s := newSet() require.NoError(s.Add(nodeID0, nil, ids.Empty, weight0)) require.NoError(s.Add(nodeID1, nil, ids.Empty, weight1)) require.NoError(s.Add(nodeID2, nil, ids.Empty, weight2)) expectedWeight := weight0 + weight1 - subsetWeight := s.SubsetWeight(subset) + subsetWeight, err := s.SubsetWeight(subset) + require.NoError(err) require.Equal(expectedWeight, subsetWeight) } -func TestSetRemoveWeightZeroWeight(t *testing.T) { - require := require.New(t) - - s := NewSet() - - nodeID := ids.GenerateTestNodeID() - require.NoError(s.Add(nodeID, nil, ids.Empty, 1)) - - err := s.RemoveWeight(nodeID, 0) - require.ErrorIs(err, errZeroWeight) -} - func TestSetRemoveWeightMissingValidator(t *testing.T) { require := require.New(t) - s := NewSet() + s := newSet() require.NoError(s.Add(ids.GenerateTestNodeID(), nil, ids.Empty, 1)) @@ -140,7 +109,7 @@ func TestSetRemoveWeightMissingValidator(t *testing.T) { func TestSetRemoveWeightUnderflow(t *testing.T) { require := require.New(t) - s := NewSet() + s := newSet() require.NoError(s.Add(ids.GenerateTestNodeID(), nil, ids.Empty, 1)) @@ -150,13 +119,15 @@ func TestSetRemoveWeightUnderflow(t *testing.T) { err := s.RemoveWeight(nodeID, 2) require.ErrorIs(err, safemath.ErrUnderflow) - require.Equal(uint64(2), s.Weight()) + totalWeight, err := 
s.TotalWeight() + require.NoError(err) + require.Equal(uint64(2), totalWeight) } func TestSetGet(t *testing.T) { require := require.New(t) - s := NewSet() + s := newSet() nodeID := ids.GenerateTestNodeID() _, ok := s.Get(nodeID) @@ -184,27 +155,16 @@ func TestSetGet(t *testing.T) { require.Equal(nodeID, vdr1.NodeID) require.Equal(pk, vdr1.PublicKey) require.Equal(uint64(2), vdr1.Weight) -} - -func TestSetContains(t *testing.T) { - require := require.New(t) - - s := NewSet() - - nodeID := ids.GenerateTestNodeID() - require.False(s.Contains(nodeID)) - require.NoError(s.Add(nodeID, nil, ids.Empty, 1)) - - require.True(s.Contains(nodeID)) - require.NoError(s.RemoveWeight(nodeID, 1)) - require.False(s.Contains(nodeID)) + require.NoError(s.RemoveWeight(nodeID, 2)) + _, ok = s.Get(nodeID) + require.False(ok) } func TestSetLen(t *testing.T) { require := require.New(t) - s := NewSet() + s := newSet() len := s.Len() require.Zero(len) @@ -235,7 +195,7 @@ func TestSetLen(t *testing.T) { func TestSetMap(t *testing.T) { require := require.New(t) - s := NewSet() + s := newSet() m := s.Map() require.Empty(m) @@ -318,12 +278,13 @@ func TestSetWeight(t *testing.T) { vdr1 := ids.NodeID{2} weight1 := uint64(123) - s := NewSet() + s := newSet() require.NoError(s.Add(vdr0, nil, ids.Empty, weight0)) require.NoError(s.Add(vdr1, nil, ids.Empty, weight1)) - setWeight := s.Weight() + setWeight, err := s.TotalWeight() + require.NoError(err) expectedWeight := weight0 + weight1 require.Equal(expectedWeight, setWeight) } @@ -331,7 +292,7 @@ func TestSetWeight(t *testing.T) { func TestSetSample(t *testing.T) { require := require.New(t) - s := NewSet() + s := newSet() sampled, err := s.Sample(0) require.NoError(err) @@ -376,7 +337,7 @@ func TestSetString(t *testing.T) { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, } - s := NewSet() + s := newSet() require.NoError(s.Add(nodeID0, nil, ids.Empty, 1)) require.NoError(s.Add(nodeID1, nil, ids.Empty, math.MaxInt64-1)) @@ -431,8 +392,9 @@ func TestSetAddCallback(t *testing.T) { txID0 := ids.GenerateTestID() weight0 := uint64(1) - s := NewSet() + s := newSet() callCount := 0 + require.False(s.HasCallbackRegistered()) s.RegisterCallbackListener(&callbackListener{ t: t, onAdd: func(nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) { @@ -443,6 +405,7 @@ func TestSetAddCallback(t *testing.T) { callCount++ }, }) + require.True(s.HasCallbackRegistered()) require.NoError(s.Add(nodeID0, pk0, txID0, weight0)) require.Equal(1, callCount) } @@ -455,10 +418,11 @@ func TestSetAddWeightCallback(t *testing.T) { weight0 := uint64(1) weight1 := uint64(93) - s := NewSet() + s := newSet() require.NoError(s.Add(nodeID0, nil, txID0, weight0)) callCount := 0 + require.False(s.HasCallbackRegistered()) s.RegisterCallbackListener(&callbackListener{ t: t, onAdd: func(nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) { @@ -475,6 +439,7 @@ func TestSetAddWeightCallback(t *testing.T) { callCount++ }, }) + require.True(s.HasCallbackRegistered()) require.NoError(s.AddWeight(nodeID0, weight1)) require.Equal(2, callCount) } @@ -487,10 +452,11 @@ func TestSetRemoveWeightCallback(t *testing.T) { weight0 := uint64(93) weight1 := uint64(92) - s := NewSet() + s := newSet() require.NoError(s.Add(nodeID0, nil, txID0, weight0)) callCount := 0 + require.False(s.HasCallbackRegistered()) s.RegisterCallbackListener(&callbackListener{ t: t, onAdd: func(nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) { @@ -507,6 +473,7 @@ func TestSetRemoveWeightCallback(t 
*testing.T) { callCount++ }, }) + require.True(s.HasCallbackRegistered()) require.NoError(s.RemoveWeight(nodeID0, weight1)) require.Equal(2, callCount) } @@ -518,10 +485,11 @@ func TestSetValidatorRemovedCallback(t *testing.T) { txID0 := ids.GenerateTestID() weight0 := uint64(93) - s := NewSet() + s := newSet() require.NoError(s.Add(nodeID0, nil, txID0, weight0)) callCount := 0 + require.False(s.HasCallbackRegistered()) s.RegisterCallbackListener(&callbackListener{ t: t, onAdd: func(nodeID ids.NodeID, pk *bls.PublicKey, txID ids.ID, weight uint64) { @@ -537,6 +505,7 @@ func TestSetValidatorRemovedCallback(t *testing.T) { callCount++ }, }) + require.True(s.HasCallbackRegistered()) require.NoError(s.RemoveWeight(nodeID0, weight0)) require.Equal(2, callCount) } diff --git a/staking/large_rsa_key.cert b/staking/large_rsa_key.cert new file mode 100644 index 000000000000..45e60a6b7991 Binary files /dev/null and b/staking/large_rsa_key.cert differ diff --git a/staking/large_rsa_key.sig b/staking/large_rsa_key.sig new file mode 100644 index 000000000000..61000a9903cf Binary files /dev/null and b/staking/large_rsa_key.sig differ diff --git a/staking/verify_test.go b/staking/verify_test.go new file mode 100644 index 000000000000..e7cee91c1b43 --- /dev/null +++ b/staking/verify_test.go @@ -0,0 +1,30 @@ +// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package staking + +import ( + "testing" + + _ "embed" + + "github.com/stretchr/testify/require" +) + +var ( + //go:embed large_rsa_key.cert + largeRSAKeyCert []byte + //go:embed large_rsa_key.sig + largeRSAKeySig []byte +) + +func TestCheckSignatureLargePublicKey(t *testing.T) { + require := require.New(t) + + cert, err := ParseCertificate(largeRSAKeyCert) + require.NoError(err) + + msg := []byte("TODO: put something clever") + err = CheckSignature(cert, msg, largeRSAKeySig) + require.ErrorIs(err, ErrInvalidRSAPublicKey) +} diff --git a/utils/heap/map.go b/utils/heap/map.go new file mode 100644 index 000000000000..dbe06c06446e --- /dev/null +++ b/utils/heap/map.go @@ -0,0 +1,132 @@ +// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
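+// Map is a binary heap of key/value entries ordered by value. An index from key
+// to heap position keeps keyed lookups O(1), while pushes, pops, and keyed
+// updates or removals re-heapify in O(log n).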
+ +package heap + +import ( + "container/heap" + + "github.com/ava-labs/avalanchego/utils" +) + +var _ heap.Interface = (*indexedQueue[int, int])(nil) + +func MapValues[K comparable, V any](m Map[K, V]) []V { + result := make([]V, 0, m.Len()) + for _, e := range m.queue.entries { + result = append(result, e.v) + } + return result +} + +// NewMap returns a heap without duplicates ordered by its values +func NewMap[K comparable, V any](less func(a, b V) bool) Map[K, V] { + return Map[K, V]{ + queue: &indexedQueue[K, V]{ + queue: queue[entry[K, V]]{ + less: func(a, b entry[K, V]) bool { + return less(a.v, b.v) + }, + }, + index: make(map[K]int), + }, + } +} + +type Map[K comparable, V any] struct { + queue *indexedQueue[K, V] +} + +// Push returns the evicted previous value if present +func (m *Map[K, V]) Push(k K, v V) (V, bool) { + if i, ok := m.queue.index[k]; ok { + prev := m.queue.entries[i] + m.queue.entries[i].v = v + heap.Fix(m.queue, i) + return prev.v, true + } + + heap.Push(m.queue, entry[K, V]{k: k, v: v}) + return utils.Zero[V](), false +} + +func (m *Map[K, V]) Pop() (K, V, bool) { + if m.Len() == 0 { + return utils.Zero[K](), utils.Zero[V](), false + } + + popped := heap.Pop(m.queue).(entry[K, V]) + return popped.k, popped.v, true +} + +func (m *Map[K, V]) Peek() (K, V, bool) { + if m.Len() == 0 { + return utils.Zero[K](), utils.Zero[V](), false + } + + entry := m.queue.entries[0] + return entry.k, entry.v, true +} + +func (m *Map[K, V]) Len() int { + return m.queue.Len() +} + +func (m *Map[K, V]) Remove(k K) (V, bool) { + if i, ok := m.queue.index[k]; ok { + removed := heap.Remove(m.queue, i).(entry[K, V]) + return removed.v, true + } + return utils.Zero[V](), false +} + +func (m *Map[K, V]) Contains(k K) bool { + _, ok := m.queue.index[k] + return ok +} + +func (m *Map[K, V]) Get(k K) (V, bool) { + if i, ok := m.queue.index[k]; ok { + got := m.queue.entries[i] + return got.v, true + } + return utils.Zero[V](), false +} + +func (m *Map[K, V]) Fix(k K) { + if i, ok := m.queue.index[k]; ok { + heap.Fix(m.queue, i) + } +} + +type indexedQueue[K comparable, V any] struct { + queue[entry[K, V]] + index map[K]int +} + +func (h *indexedQueue[K, V]) Swap(i, j int) { + h.entries[i], h.entries[j] = h.entries[j], h.entries[i] + h.index[h.entries[i].k], h.index[h.entries[j].k] = i, j +} + +func (h *indexedQueue[K, V]) Push(x any) { + entry := x.(entry[K, V]) + h.entries = append(h.entries, entry) + h.index[entry.k] = len(h.index) +} + +func (h *indexedQueue[K, V]) Pop() any { + end := len(h.entries) - 1 + + popped := h.entries[end] + h.entries[end] = entry[K, V]{} + h.entries = h.entries[:end] + + delete(h.index, popped.k) + return popped +} + +type entry[K any, V any] struct { + k K + v V +} diff --git a/utils/heap/map_test.go b/utils/heap/map_test.go new file mode 100644 index 000000000000..cc774a5a50df --- /dev/null +++ b/utils/heap/map_test.go @@ -0,0 +1,96 @@ +// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
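For context, a minimal sketch (not part of the change) of how the new heap.Map defined above might be used, assuming the import path github.com/ava-labs/avalanchego/utils/heap:

package main

import (
	"fmt"

	"github.com/ava-labs/avalanchego/utils/heap"
)

func main() {
	// Min-heap keyed by string, ordered by the int values.
	m := heap.NewMap[string, int](func(a, b int) bool { return a < b })

	m.Push("b", 2)
	m.Push("a", 5)
	prev, replaced := m.Push("a", 1) // re-pushing "a" overrides its value
	fmt.Println(prev, replaced)      // 5 true

	k, v, ok := m.Pop() // smallest value first
	fmt.Println(k, v, ok) // a 1 true
}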
+ +package heap + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestMap(t *testing.T) { + tests := []struct { + name string + setup func(h Map[string, int]) + expected []entry[string, int] + }{ + { + name: "only push", + setup: func(h Map[string, int]) { + h.Push("a", 1) + h.Push("b", 2) + h.Push("c", 3) + }, + expected: []entry[string, int]{ + {k: "a", v: 1}, + {k: "b", v: 2}, + {k: "c", v: 3}, + }, + }, + { + name: "out of order pushes", + setup: func(h Map[string, int]) { + h.Push("a", 1) + h.Push("e", 5) + h.Push("b", 2) + h.Push("d", 4) + h.Push("c", 3) + }, + expected: []entry[string, int]{ + {"a", 1}, + {"b", 2}, + {"c", 3}, + {"d", 4}, + {"e", 5}, + }, + }, + { + name: "push and pop", + setup: func(m Map[string, int]) { + m.Push("a", 1) + m.Push("e", 5) + m.Push("b", 2) + m.Push("d", 4) + m.Push("c", 3) + m.Pop() + m.Pop() + m.Pop() + }, + expected: []entry[string, int]{ + {"d", 4}, + {"e", 5}, + }, + }, + { + name: "duplicate key is overridden", + setup: func(h Map[string, int]) { + h.Push("a", 1) + h.Push("a", 2) + }, + expected: []entry[string, int]{ + {k: "a", v: 2}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + + h := NewMap[string, int](func(a, b int) bool { + return a < b + }) + + tt.setup(h) + + require.Equal(len(tt.expected), h.Len()) + for _, expected := range tt.expected { + k, v, ok := h.Pop() + require.True(ok) + require.Equal(expected.k, k) + require.Equal(expected.v, v) + } + }) + } +} diff --git a/utils/heap/set.go b/utils/heap/set.go new file mode 100644 index 000000000000..15fab421b278 --- /dev/null +++ b/utils/heap/set.go @@ -0,0 +1,48 @@ +// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package heap + +// NewSet returns a heap without duplicates ordered by its values +func NewSet[T comparable](less func(a, b T) bool) Set[T] { + return Set[T]{ + set: NewMap[T, T](less), + } +} + +type Set[T comparable] struct { + set Map[T, T] +} + +// Push returns if the entry was added +func (s Set[T]) Push(t T) bool { + _, hadValue := s.set.Push(t, t) + return !hadValue +} + +func (s Set[T]) Pop() (T, bool) { + pop, _, ok := s.set.Pop() + return pop, ok +} + +func (s Set[T]) Peek() (T, bool) { + peek, _, ok := s.set.Peek() + return peek, ok +} + +func (s Set[T]) Len() int { + return s.set.Len() +} + +func (s Set[T]) Remove(t T) bool { + _, existed := s.set.Remove(t) + return existed +} + +func (s Set[T]) Fix(t T) { + s.set.Fix(t) +} + +func (s Set[T]) Contains(t T) bool { + return s.set.Contains(t) +} diff --git a/utils/heap/set_test.go b/utils/heap/set_test.go new file mode 100644 index 000000000000..fd93f996d5ff --- /dev/null +++ b/utils/heap/set_test.go @@ -0,0 +1,72 @@ +// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
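Similarly, a brief illustrative sketch of the heap.Set wrapper added above (Push reports whether the element was newly inserted, duplicates are rejected):

package main

import (
	"fmt"

	"github.com/ava-labs/avalanchego/utils/heap"
)

func main() {
	s := heap.NewSet[int](func(a, b int) bool { return a < b })

	fmt.Println(s.Push(3)) // true: newly added
	fmt.Println(s.Push(1)) // true
	fmt.Println(s.Push(3)) // false: 3 is already present

	v, ok := s.Pop()     // smallest element first
	fmt.Println(v, ok)   // 1 true
	fmt.Println(s.Len()) // 1
}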
+ +package heap + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSet(t *testing.T) { + tests := []struct { + name string + setup func(h Set[int]) + expected []int + }{ + { + name: "only push", + setup: func(h Set[int]) { + h.Push(1) + h.Push(2) + h.Push(3) + }, + expected: []int{1, 2, 3}, + }, + { + name: "out of order pushes", + setup: func(h Set[int]) { + h.Push(1) + h.Push(5) + h.Push(2) + h.Push(4) + h.Push(3) + }, + expected: []int{1, 2, 3, 4, 5}, + }, + { + name: "push and pop", + setup: func(h Set[int]) { + h.Push(1) + h.Push(5) + h.Push(2) + h.Push(4) + h.Push(3) + h.Pop() + h.Pop() + h.Pop() + }, + expected: []int{4, 5}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + + h := NewSet[int](func(a, b int) bool { + return a < b + }) + + tt.setup(h) + + require.Equal(len(tt.expected), h.Len()) + for _, expected := range tt.expected { + got, ok := h.Pop() + require.True(ok) + require.Equal(expected, got) + } + }) + } +} diff --git a/utils/math/averager_heap.go b/utils/math/averager_heap.go index b09393b48803..070593f0eeb8 100644 --- a/utils/math/averager_heap.go +++ b/utils/math/averager_heap.go @@ -4,16 +4,13 @@ package math import ( - "container/heap" - "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/utils/heap" ) -var ( - _ AveragerHeap = averagerHeap{} - _ heap.Interface = (*averagerHeapBackend)(nil) -) +var _ AveragerHeap = (*averagerHeap)(nil) +// TODO replace this interface with utils/heap // AveragerHeap maintains a heap of the averagers. type AveragerHeap interface { // Add the average to the heap. If [nodeID] is already in the heap, the @@ -33,113 +30,36 @@ type AveragerHeap interface { Len() int } -type averagerHeapEntry struct { - nodeID ids.NodeID - averager Averager - index int -} - -type averagerHeapBackend struct { - isMaxHeap bool - nodeIDToEntry map[ids.NodeID]*averagerHeapEntry - entries []*averagerHeapEntry -} - type averagerHeap struct { - b *averagerHeapBackend -} - -// NewMinAveragerHeap returns a new empty min heap. The returned heap is not -// thread safe. -func NewMinAveragerHeap() AveragerHeap { - return averagerHeap{b: &averagerHeapBackend{ - nodeIDToEntry: make(map[ids.NodeID]*averagerHeapEntry), - }} + heap heap.Map[ids.NodeID, Averager] } // NewMaxAveragerHeap returns a new empty max heap. The returned heap is not // thread safe. 
func NewMaxAveragerHeap() AveragerHeap { - return averagerHeap{b: &averagerHeapBackend{ - isMaxHeap: true, - nodeIDToEntry: make(map[ids.NodeID]*averagerHeapEntry), - }} + return averagerHeap{ + heap: heap.NewMap[ids.NodeID, Averager](func(a, b Averager) bool { + return a.Read() > b.Read() + }), + } } func (h averagerHeap) Add(nodeID ids.NodeID, averager Averager) (Averager, bool) { - if e, exists := h.b.nodeIDToEntry[nodeID]; exists { - oldAverager := e.averager - e.averager = averager - heap.Fix(h.b, e.index) - return oldAverager, true - } - - heap.Push(h.b, &averagerHeapEntry{ - nodeID: nodeID, - averager: averager, - }) - return nil, false + return h.heap.Push(nodeID, averager) } func (h averagerHeap) Remove(nodeID ids.NodeID) (Averager, bool) { - e, exists := h.b.nodeIDToEntry[nodeID] - if !exists { - return nil, false - } - heap.Remove(h.b, e.index) - return e.averager, true + return h.heap.Remove(nodeID) } func (h averagerHeap) Pop() (ids.NodeID, Averager, bool) { - if len(h.b.entries) == 0 { - return ids.EmptyNodeID, nil, false - } - e := h.b.entries[0] - heap.Pop(h.b) - return e.nodeID, e.averager, true + return h.heap.Pop() } func (h averagerHeap) Peek() (ids.NodeID, Averager, bool) { - if len(h.b.entries) == 0 { - return ids.EmptyNodeID, nil, false - } - e := h.b.entries[0] - return e.nodeID, e.averager, true + return h.heap.Peek() } func (h averagerHeap) Len() int { - return len(h.b.entries) -} - -func (h *averagerHeapBackend) Len() int { - return len(h.entries) -} - -func (h *averagerHeapBackend) Less(i, j int) bool { - if h.isMaxHeap { - return h.entries[i].averager.Read() > h.entries[j].averager.Read() - } - return h.entries[i].averager.Read() < h.entries[j].averager.Read() -} - -func (h *averagerHeapBackend) Swap(i, j int) { - h.entries[i], h.entries[j] = h.entries[j], h.entries[i] - h.entries[i].index = i - h.entries[j].index = j -} - -func (h *averagerHeapBackend) Push(x interface{}) { - e := x.(*averagerHeapEntry) - e.index = len(h.entries) - h.nodeIDToEntry[e.nodeID] = e - h.entries = append(h.entries, e) -} - -func (h *averagerHeapBackend) Pop() interface{} { - newLen := len(h.entries) - 1 - e := h.entries[newLen] - h.entries[newLen] = nil - delete(h.nodeIDToEntry, e.nodeID) - h.entries = h.entries[:newLen] - return e + return h.heap.Len() } diff --git a/utils/math/averager_heap_test.go b/utils/math/averager_heap_test.go index a979612952c4..0586eb77947e 100644 --- a/utils/math/averager_heap_test.go +++ b/utils/math/averager_heap_test.go @@ -20,19 +20,13 @@ func TestAveragerHeap(t *testing.T) { n2 := ids.GenerateTestNodeID() tests := []struct { - h AveragerHeap - a []Averager + name string + h AveragerHeap + a []Averager }{ { - h: NewMinAveragerHeap(), - a: []Averager{ - NewAverager(0, time.Second, time.Now()), - NewAverager(1, time.Second, time.Now()), - NewAverager(2, time.Second, time.Now()), - }, - }, - { - h: NewMaxAveragerHeap(), + name: "max heap", + h: NewMaxAveragerHeap(), a: []Averager{ NewAverager(0, time.Second, time.Now()), NewAverager(-1, time.Second, time.Now()), @@ -42,67 +36,69 @@ func TestAveragerHeap(t *testing.T) { } for _, test := range tests { - _, _, ok := test.h.Pop() - require.False(ok) + t.Run(test.name, func(t *testing.T) { + _, _, ok := test.h.Pop() + require.False(ok) - _, _, ok = test.h.Peek() - require.False(ok) + _, _, ok = test.h.Peek() + require.False(ok) - l := test.h.Len() - require.Zero(l) + l := test.h.Len() + require.Zero(l) - _, ok = test.h.Add(n1, test.a[1]) - require.False(ok) + _, ok = test.h.Add(n1, test.a[1]) + 
require.False(ok) - n, a, ok := test.h.Peek() - require.True(ok) - require.Equal(n1, n) - require.Equal(test.a[1], a) + n, a, ok := test.h.Peek() + require.True(ok) + require.Equal(n1, n) + require.Equal(test.a[1], a) - l = test.h.Len() - require.Equal(1, l) + l = test.h.Len() + require.Equal(1, l) - a, ok = test.h.Add(n1, test.a[1]) - require.True(ok) - require.Equal(test.a[1], a) + a, ok = test.h.Add(n1, test.a[1]) + require.True(ok) + require.Equal(test.a[1], a) - l = test.h.Len() - require.Equal(1, l) + l = test.h.Len() + require.Equal(1, l) - _, ok = test.h.Add(n0, test.a[0]) - require.False(ok) + _, ok = test.h.Add(n0, test.a[0]) + require.False(ok) - _, ok = test.h.Add(n2, test.a[2]) - require.False(ok) + _, ok = test.h.Add(n2, test.a[2]) + require.False(ok) - n, a, ok = test.h.Pop() - require.True(ok) - require.Equal(n0, n) - require.Equal(test.a[0], a) + n, a, ok = test.h.Pop() + require.True(ok) + require.Equal(n0, n) + require.Equal(test.a[0], a) - l = test.h.Len() - require.Equal(2, l) + l = test.h.Len() + require.Equal(2, l) - a, ok = test.h.Remove(n1) - require.True(ok) - require.Equal(test.a[1], a) + a, ok = test.h.Remove(n1) + require.True(ok) + require.Equal(test.a[1], a) - l = test.h.Len() - require.Equal(1, l) + l = test.h.Len() + require.Equal(1, l) - _, ok = test.h.Remove(n1) - require.False(ok) + _, ok = test.h.Remove(n1) + require.False(ok) - l = test.h.Len() - require.Equal(1, l) + l = test.h.Len() + require.Equal(1, l) - a, ok = test.h.Add(n2, test.a[0]) - require.True(ok) - require.Equal(test.a[2], a) + a, ok = test.h.Add(n2, test.a[0]) + require.True(ok) + require.Equal(test.a[2], a) - n, a, ok = test.h.Pop() - require.True(ok) - require.Equal(n2, n) - require.Equal(test.a[0], a) + n, a, ok = test.h.Pop() + require.True(ok) + require.Equal(n2, n) + require.Equal(test.a[0], a) + }) } } diff --git a/utils/timer/adaptive_timeout_manager.go b/utils/timer/adaptive_timeout_manager.go index 0a0a299cd1da..95b284a48c5f 100644 --- a/utils/timer/adaptive_timeout_manager.go +++ b/utils/timer/adaptive_timeout_manager.go @@ -4,7 +4,6 @@ package timer import ( - "container/heap" "errors" "fmt" "sync" @@ -13,6 +12,7 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/utils/heap" "github.com/ava-labs/avalanchego/utils/math" "github.com/ava-labs/avalanchego/utils/timer/mockable" "github.com/ava-labs/avalanchego/utils/wrappers" @@ -24,12 +24,10 @@ var ( errInitialTimeoutBelowMinimum = errors.New("initial timeout cannot be less than minimum timeout") errTooSmallTimeoutCoefficient = errors.New("timeout coefficient must be >= 1") - _ heap.Interface = (*timeoutQueue)(nil) _ AdaptiveTimeoutManager = (*adaptiveTimeoutManager)(nil) ) type adaptiveTimeout struct { - index int // Index in the wait queue id ids.RequestID // Unique ID of this timeout handler func() // Function to execute if timed out duration time.Duration // How long this timeout was set for @@ -37,38 +35,6 @@ type adaptiveTimeout struct { measureLatency bool // Whether this request should impact latency } -type timeoutQueue []*adaptiveTimeout - -func (tq timeoutQueue) Len() int { - return len(tq) -} - -func (tq timeoutQueue) Less(i, j int) bool { - return tq[i].deadline.Before(tq[j].deadline) -} - -func (tq timeoutQueue) Swap(i, j int) { - tq[i], tq[j] = tq[j], tq[i] - tq[i].index = i - tq[j].index = j -} - -// Push adds an item to this priority queue. 
x must have type *adaptiveTimeout -func (tq *timeoutQueue) Push(x interface{}) { - item := x.(*adaptiveTimeout) - item.index = len(*tq) - *tq = append(*tq, item) -} - -// Pop returns the next item in this queue -func (tq *timeoutQueue) Pop() interface{} { - n := len(*tq) - item := (*tq)[n-1] - (*tq)[n-1] = nil // make sure the item is freed from memory - *tq = (*tq)[:n-1] - return item -} - // AdaptiveTimeoutConfig contains the parameters provided to the // adaptive timeout manager. type AdaptiveTimeoutConfig struct { @@ -120,8 +86,7 @@ type adaptiveTimeoutManager struct { minimumTimeout time.Duration maximumTimeout time.Duration currentTimeout time.Duration // Amount of time before a timeout - timeoutMap map[ids.RequestID]*adaptiveTimeout - timeoutQueue timeoutQueue + timeoutHeap heap.Map[ids.RequestID, *adaptiveTimeout] timer *Timer // Timer that will fire to clear the timeouts } @@ -166,7 +131,9 @@ func NewAdaptiveTimeoutManager( maximumTimeout: config.MaximumTimeout, currentTimeout: config.InitialTimeout, timeoutCoefficient: config.TimeoutCoefficient, - timeoutMap: make(map[ids.RequestID]*adaptiveTimeout), + timeoutHeap: heap.NewMap[ids.RequestID, *adaptiveTimeout](func(a, b *adaptiveTimeout) bool { + return a.deadline.Before(b.deadline) + }), } tm.timer = NewTimer(tm.timeout) tm.averager = math.NewAverager(float64(config.InitialTimeout), config.TimeoutHalflife, tm.clock.Time()) @@ -215,9 +182,8 @@ func (tm *adaptiveTimeoutManager) put(id ids.RequestID, measureLatency bool, han deadline: now.Add(tm.currentTimeout), measureLatency: measureLatency, } - tm.timeoutMap[id] = timeout - tm.numPendingTimeouts.Set(float64(len(tm.timeoutMap))) - heap.Push(&tm.timeoutQueue, timeout) + tm.timeoutHeap.Push(id, timeout) + tm.numPendingTimeouts.Set(float64(tm.timeoutHeap.Len())) tm.setNextTimeoutTime() } @@ -231,24 +197,18 @@ func (tm *adaptiveTimeoutManager) Remove(id ids.RequestID) { // Assumes [tm.lock] is held func (tm *adaptiveTimeoutManager) remove(id ids.RequestID, now time.Time) { - timeout, exists := tm.timeoutMap[id] + // Observe the response time to update average network response time. + timeout, exists := tm.timeoutHeap.Remove(id) if !exists { return } - // Observe the response time to update average network response time. if timeout.measureLatency { timeoutRegisteredAt := timeout.deadline.Add(-1 * timeout.duration) latency := now.Sub(timeoutRegisteredAt) tm.observeLatencyAndUpdateTimeout(latency, now) } - - // Remove the timeout from the map - delete(tm.timeoutMap, id) - tm.numPendingTimeouts.Set(float64(len(tm.timeoutMap))) - - // Remove the timeout from the queue - heap.Remove(&tm.timeoutQueue, timeout.index) + tm.numPendingTimeouts.Set(float64(tm.timeoutHeap.Len())) } // Assumes [tm.lock] is not held. @@ -300,11 +260,10 @@ func (tm *adaptiveTimeoutManager) observeLatencyAndUpdateTimeout(latency time.Du // returns nil. // Assumes [tm.lock] is held func (tm *adaptiveTimeoutManager) getNextTimeoutHandler(now time.Time) func() { - if tm.timeoutQueue.Len() == 0 { + _, nextTimeout, ok := tm.timeoutHeap.Peek() + if !ok { return nil } - - nextTimeout := tm.timeoutQueue[0] if nextTimeout.deadline.After(now) { return nil } @@ -315,14 +274,14 @@ func (tm *adaptiveTimeoutManager) getNextTimeoutHandler(now time.Time) func() { // Calculate the time of the next timeout and set // the timer to fire at that time. 
func (tm *adaptiveTimeoutManager) setNextTimeoutTime() { - if tm.timeoutQueue.Len() == 0 { + _, nextTimeout, ok := tm.timeoutHeap.Peek() + if !ok { // There are no pending timeouts tm.timer.Cancel() return } now := tm.clock.Time() - nextTimeout := tm.timeoutQueue[0] timeToNextTimeout := nextTimeout.deadline.Sub(now) tm.timer.SetTimeoutIn(timeToNextTimeout) } diff --git a/utils/wrappers/errors.go b/utils/wrappers/errors.go index 1f0f19846ffe..641734da16c0 100644 --- a/utils/wrappers/errors.go +++ b/utils/wrappers/errors.go @@ -3,10 +3,6 @@ package wrappers -import "strings" - -var _ error = (*aggregate)(nil) - type Errs struct{ Err error } func (errs *Errs) Errored() bool { @@ -23,28 +19,3 @@ func (errs *Errs) Add(errors ...error) { } } } - -// NewAggregate returns an aggregate error from a list of errors -func NewAggregate(errs []error) error { - err := &aggregate{errs} - if len(err.Errors()) == 0 { - return nil - } - return err -} - -type aggregate struct{ errs []error } - -// Error returns the slice of errors with comma separated messsages wrapped in brackets -// [ error string 0 ], [ error string 1 ] ... -func (a *aggregate) Error() string { - errString := make([]string, len(a.errs)) - for i, err := range a.errs { - errString[i] = "[" + err.Error() + "]" - } - return strings.Join(errString, ",") -} - -func (a *aggregate) Errors() []error { - return a.errs -} diff --git a/version/compatibility.json b/version/compatibility.json index 2c43f7ef06d9..65bf1d393f2b 100644 --- a/version/compatibility.json +++ b/version/compatibility.json @@ -1,5 +1,6 @@ { "29": [ + "v1.1.14", "v1.1.13" ], "28": [ diff --git a/version/constants.go b/version/constants.go index 3365ef8b0a49..859359224cee 100644 --- a/version/constants.go +++ b/version/constants.go @@ -35,7 +35,7 @@ var ( Current = &Semantic{ Major: 1, Minor: 1, - Patch: 13, + Patch: 14, } CurrentApp = &Application{ Major: Current.Major, diff --git a/vms/platformvm/block/builder/helpers_test.go b/vms/platformvm/block/builder/helpers_test.go index 551b6799be16..58c71418b6e2 100644 --- a/vms/platformvm/block/builder/helpers_test.go +++ b/vms/platformvm/block/builder/helpers_test.go @@ -282,13 +282,10 @@ func defaultCtx(db database.Database) (*snow.Context, *mutableSharedMemory) { } func defaultConfig() *config.Config { - vdrs := validators.NewManager() - primaryVdrs := validators.NewSet() - _ = vdrs.Add(constants.PrimaryNetworkID, primaryVdrs) return &config.Config{ Chains: chains.TestManager, UptimeLockedCalculator: uptime.NewLockedCalculator(), - Validators: vdrs, + Validators: validators.NewManager(), TxFee: defaultTxFee, CreateSubnetTxFee: 100 * defaultTxFee, CreateBlockchainTxFee: 100 * defaultTxFee, @@ -409,11 +406,10 @@ func buildGenesisTest(t *testing.T, ctx *snow.Context) []byte { } func shutdownEnvironment(env *environment) error { + env.Builder.Shutdown() + if env.isBootstrapped.Get() { - validatorIDs, err := validators.NodeIDs(env.config.Validators, constants.PrimaryNetworkID) - if err != nil { - return err - } + validatorIDs := env.config.Validators.GetValidatorIDs(constants.PrimaryNetworkID) if err := env.uptimes.StopTracking(validatorIDs, constants.PrimaryNetworkID); err != nil { return err diff --git a/vms/platformvm/block/builder/main_test.go b/vms/platformvm/block/builder/main_test.go new file mode 100644 index 000000000000..01135c523738 --- /dev/null +++ b/vms/platformvm/block/builder/main_test.go @@ -0,0 +1,14 @@ +// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
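As a rough sketch of the pattern the adaptive timeout manager above now follows (illustrative only; the real timeout entry type is unexported), pending requests can be kept in a heap.Map keyed by request ID and ordered by deadline, so Peek always yields the soonest deadline:

package main

import (
	"fmt"
	"time"

	"github.com/ava-labs/avalanchego/utils/heap"
)

// request stands in for the unexported per-request timeout entry.
type request struct {
	deadline time.Time
}

func main() {
	now := time.Now()
	pending := heap.NewMap[uint32, request](func(a, b request) bool {
		return a.deadline.Before(b.deadline)
	})

	pending.Push(1, request{deadline: now.Add(3 * time.Second)})
	pending.Push(2, request{deadline: now.Add(time.Second)})

	if id, next, ok := pending.Peek(); ok {
		// Request 2 has the earliest deadline, so the timer fires for it first.
		fmt.Println(id, next.deadline.Sub(now))
	}
}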
+ +package builder + +import ( + "testing" + + "go.uber.org/goleak" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} diff --git a/vms/platformvm/block/codec.go b/vms/platformvm/block/codec.go index f2a004599132..def29219d08e 100644 --- a/vms/platformvm/block/codec.go +++ b/vms/platformvm/block/codec.go @@ -46,6 +46,7 @@ func init() { RegisterApricotBlockTypes(c), txs.RegisterUnsignedTxsTypes(c), RegisterBanffBlockTypes(c), + txs.RegisterDUnsignedTxsTypes(c), ) } errs.Add( diff --git a/vms/platformvm/block/executor/helpers_test.go b/vms/platformvm/block/executor/helpers_test.go index 0d36d4ceb739..499d82669ad9 100644 --- a/vms/platformvm/block/executor/helpers_test.go +++ b/vms/platformvm/block/executor/helpers_test.go @@ -323,13 +323,10 @@ func defaultCtx(db database.Database) *snow.Context { } func defaultConfig() *config.Config { - vdrs := validators.NewManager() - primaryVdrs := validators.NewSet() - _ = vdrs.Add(constants.PrimaryNetworkID, primaryVdrs) return &config.Config{ Chains: chains.TestManager, UptimeLockedCalculator: uptime.NewLockedCalculator(), - Validators: vdrs, + Validators: validators.NewManager(), TxFee: defaultTxFee, CreateSubnetTxFee: 100 * defaultTxFee, CreateBlockchainTxFee: 100 * defaultTxFee, @@ -463,10 +460,7 @@ func shutdownEnvironment(t *environment) error { } if t.isBootstrapped.Get() { - validatorIDs, err := validators.NodeIDs(t.config.Validators, constants.PrimaryNetworkID) - if err != nil { - return err - } + validatorIDs := t.config.Validators.GetValidatorIDs(constants.PrimaryNetworkID) if err := t.uptimes.StopTracking(validatorIDs, constants.PrimaryNetworkID); err != nil { return err diff --git a/vms/platformvm/block/executor/proposal_block_test.go b/vms/platformvm/block/executor/proposal_block_test.go index 497319ab28fa..aba5a0d15730 100644 --- a/vms/platformvm/block/executor/proposal_block_test.go +++ b/vms/platformvm/block/executor/proposal_block_test.go @@ -26,7 +26,6 @@ import ( "github.com/ava-labs/avalanchego/database" "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/snow/consensus/snowman" - "github.com/ava-labs/avalanchego/snow/validators" "github.com/ava-labs/avalanchego/utils/constants" "github.com/ava-labs/avalanchego/utils/crypto/secp256k1" "github.com/ava-labs/avalanchego/vms/components/avax" @@ -592,7 +591,6 @@ func TestBanffProposalBlockUpdateStakers(t *testing.T) { subnetID := testSubnet1.ID() env.config.TrackedSubnets.Add(subnetID) - env.config.Validators.Add(subnetID, validators.NewSet()) for _, staker := range test.stakers { tx, err := env.txBuilder.NewAddValidatorTx( @@ -709,20 +707,24 @@ func TestBanffProposalBlockUpdateStakers(t *testing.T) { case pending: _, err := env.state.GetPendingValidator(constants.PrimaryNetworkID, stakerNodeID) require.NoError(err) - require.False(validators.Contains(env.config.Validators, constants.PrimaryNetworkID, stakerNodeID)) + _, ok := env.config.Validators.GetValidator(constants.PrimaryNetworkID, stakerNodeID) + require.False(ok) case current: _, err := env.state.GetCurrentValidator(constants.PrimaryNetworkID, stakerNodeID) require.NoError(err) - require.True(validators.Contains(env.config.Validators, constants.PrimaryNetworkID, stakerNodeID)) + _, ok := env.config.Validators.GetValidator(constants.PrimaryNetworkID, stakerNodeID) + require.True(ok) } } for stakerNodeID, status := range test.expectedSubnetStakers { switch status { case pending: - require.False(validators.Contains(env.config.Validators, subnetID, stakerNodeID)) + _, ok := 
env.config.Validators.GetValidator(subnetID, stakerNodeID) + require.False(ok) case current: - require.True(validators.Contains(env.config.Validators, subnetID, stakerNodeID)) + _, ok := env.config.Validators.GetValidator(subnetID, stakerNodeID) + require.True(ok) } } }) @@ -739,7 +741,6 @@ func TestBanffProposalBlockRemoveSubnetValidator(t *testing.T) { subnetID := testSubnet1.ID() env.config.TrackedSubnets.Add(subnetID) - env.config.Validators.Add(subnetID, validators.NewSet()) // Add a subnet validator to the staker set subnetValidatorNodeID := ids.NodeID(preFundedKeys[0].PublicKey().Address()) @@ -861,8 +862,10 @@ func TestBanffProposalBlockRemoveSubnetValidator(t *testing.T) { // Check VM Validators are removed successfully require.NoError(propBlk.Accept(context.Background())) require.NoError(commitBlk.Accept(context.Background())) - require.False(validators.Contains(env.config.Validators, subnetID, subnetVdr2NodeID)) - require.False(validators.Contains(env.config.Validators, subnetID, subnetValidatorNodeID)) + _, ok := env.config.Validators.GetValidator(subnetID, subnetVdr2NodeID) + require.False(ok) + _, ok = env.config.Validators.GetValidator(subnetID, subnetValidatorNodeID) + require.False(ok) } func TestBanffProposalBlockTrackedSubnet(t *testing.T) { @@ -878,7 +881,6 @@ func TestBanffProposalBlockTrackedSubnet(t *testing.T) { subnetID := testSubnet1.ID() if tracked { env.config.TrackedSubnets.Add(subnetID) - env.config.Validators.Add(subnetID, validators.NewSet()) } // Add a subnet validator to the staker set @@ -966,7 +968,8 @@ func TestBanffProposalBlockTrackedSubnet(t *testing.T) { require.NoError(propBlk.Accept(context.Background())) require.NoError(commitBlk.Accept(context.Background())) - require.Equal(tracked, validators.Contains(env.config.Validators, subnetID, subnetValidatorNodeID)) + _, ok := env.config.Validators.GetValidator(subnetID, subnetValidatorNodeID) + require.Equal(tracked, ok) }) } } @@ -1054,9 +1057,7 @@ func TestBanffProposalBlockDelegatorStakerWeight(t *testing.T) { require.NoError(commitBlk.Accept(context.Background())) // Test validator weight before delegation - primarySet, ok := env.config.Validators.Get(constants.PrimaryNetworkID) - require.True(ok) - vdrWeight := primarySet.GetWeight(nodeID) + vdrWeight := env.config.Validators.GetWeight(constants.PrimaryNetworkID, nodeID) require.Equal(env.config.MinValidatorStake, vdrWeight) // Add delegator @@ -1148,7 +1149,7 @@ func TestBanffProposalBlockDelegatorStakerWeight(t *testing.T) { require.NoError(commitBlk.Accept(context.Background())) // Test validator weight after delegation - vdrWeight = primarySet.GetWeight(nodeID) + vdrWeight = env.config.Validators.GetWeight(constants.PrimaryNetworkID, nodeID) require.Equal(env.config.MinDelegatorStake+env.config.MinValidatorStake, vdrWeight) } @@ -1238,9 +1239,7 @@ func TestBanffProposalBlockDelegatorStakers(t *testing.T) { require.NoError(commitBlk.Accept(context.Background())) // Test validator weight before delegation - primarySet, ok := env.config.Validators.Get(constants.PrimaryNetworkID) - require.True(ok) - vdrWeight := primarySet.GetWeight(nodeID) + vdrWeight := env.config.Validators.GetWeight(constants.PrimaryNetworkID, nodeID) require.Equal(env.config.MinValidatorStake, vdrWeight) // Add delegator @@ -1330,6 +1329,6 @@ func TestBanffProposalBlockDelegatorStakers(t *testing.T) { require.NoError(commitBlk.Accept(context.Background())) // Test validator weight after delegation - vdrWeight = primarySet.GetWeight(nodeID) + vdrWeight = 
env.config.Validators.GetWeight(constants.PrimaryNetworkID, nodeID) require.Equal(env.config.MinDelegatorStake+env.config.MinValidatorStake, vdrWeight) } diff --git a/vms/platformvm/block/executor/standard_block_test.go b/vms/platformvm/block/executor/standard_block_test.go index 0665622f45df..50b614b2bd73 100644 --- a/vms/platformvm/block/executor/standard_block_test.go +++ b/vms/platformvm/block/executor/standard_block_test.go @@ -25,7 +25,6 @@ import ( "github.com/ava-labs/avalanchego/database" "github.com/ava-labs/avalanchego/ids" - "github.com/ava-labs/avalanchego/snow/validators" "github.com/ava-labs/avalanchego/utils/constants" "github.com/ava-labs/avalanchego/utils/crypto/secp256k1" "github.com/ava-labs/avalanchego/vms/components/avax" @@ -370,7 +369,8 @@ func TestBanffStandardBlockUpdatePrimaryNetworkStakers(t *testing.T) { // Test VM validators require.NoError(block.Accept(context.Background())) - require.True(validators.Contains(env.config.Validators, constants.PrimaryNetworkID, nodeID)) + _, ok := env.config.Validators.GetValidator(constants.PrimaryNetworkID, nodeID) + require.True(ok) } // Ensure semantic verification updates the current and pending staker sets correctly. @@ -519,7 +519,6 @@ func TestBanffStandardBlockUpdateStakers(t *testing.T) { subnetID := testSubnet1.ID() env.config.TrackedSubnets.Add(subnetID) - env.config.Validators.Add(subnetID, validators.NewSet()) for _, staker := range test.stakers { _, err := addPendingValidator( @@ -584,20 +583,24 @@ func TestBanffStandardBlockUpdateStakers(t *testing.T) { case pending: _, err := env.state.GetPendingValidator(constants.PrimaryNetworkID, stakerNodeID) require.NoError(err) - require.False(validators.Contains(env.config.Validators, constants.PrimaryNetworkID, stakerNodeID)) + _, ok := env.config.Validators.GetValidator(constants.PrimaryNetworkID, stakerNodeID) + require.False(ok) case current: _, err := env.state.GetCurrentValidator(constants.PrimaryNetworkID, stakerNodeID) require.NoError(err) - require.True(validators.Contains(env.config.Validators, constants.PrimaryNetworkID, stakerNodeID)) + _, ok := env.config.Validators.GetValidator(constants.PrimaryNetworkID, stakerNodeID) + require.True(ok) } } for stakerNodeID, status := range test.expectedSubnetStakers { switch status { case pending: - require.False(validators.Contains(env.config.Validators, subnetID, stakerNodeID)) + _, ok := env.config.Validators.GetValidator(subnetID, stakerNodeID) + require.False(ok) case current: - require.True(validators.Contains(env.config.Validators, subnetID, stakerNodeID)) + _, ok := env.config.Validators.GetValidator(subnetID, stakerNodeID) + require.True(ok) } } }) @@ -618,7 +621,6 @@ func TestBanffStandardBlockRemoveSubnetValidator(t *testing.T) { subnetID := testSubnet1.ID() env.config.TrackedSubnets.Add(subnetID) - env.config.Validators.Add(subnetID, validators.NewSet()) // Add a subnet validator to the staker set subnetValidatorNodeID := ids.NodeID(preFundedKeys[0].PublicKey().Address()) @@ -699,8 +701,10 @@ func TestBanffStandardBlockRemoveSubnetValidator(t *testing.T) { // Check VM Validators are removed successfully require.NoError(block.Accept(context.Background())) - require.False(validators.Contains(env.config.Validators, subnetID, subnetVdr2NodeID)) - require.False(validators.Contains(env.config.Validators, subnetID, subnetValidatorNodeID)) + _, ok := env.config.Validators.GetValidator(subnetID, subnetVdr2NodeID) + require.False(ok) + _, ok = env.config.Validators.GetValidator(subnetID, subnetValidatorNodeID) + 
require.False(ok) } func TestBanffStandardBlockTrackedSubnet(t *testing.T) { @@ -716,7 +720,6 @@ func TestBanffStandardBlockTrackedSubnet(t *testing.T) { subnetID := testSubnet1.ID() if tracked { env.config.TrackedSubnets.Add(subnetID) - env.config.Validators.Add(subnetID, validators.NewSet()) } // Add a subnet validator to the staker set @@ -764,7 +767,8 @@ func TestBanffStandardBlockTrackedSubnet(t *testing.T) { // update staker set require.NoError(block.Verify(context.Background())) require.NoError(block.Accept(context.Background())) - require.Equal(tracked, validators.Contains(env.config.Validators, subnetID, subnetValidatorNodeID)) + _, ok := env.config.Validators.GetValidator(subnetID, subnetValidatorNodeID) + require.Equal(tracked, ok) }) } } @@ -810,9 +814,7 @@ func TestBanffStandardBlockDelegatorStakerWeight(t *testing.T) { require.NoError(env.state.Commit()) // Test validator weight before delegation - primarySet, ok := env.config.Validators.Get(constants.PrimaryNetworkID) - require.True(ok) - vdrWeight := primarySet.GetWeight(nodeID) + vdrWeight := env.config.Validators.GetWeight(constants.PrimaryNetworkID, nodeID) require.Equal(env.config.MinValidatorStake, vdrWeight) // Add delegator @@ -862,6 +864,6 @@ func TestBanffStandardBlockDelegatorStakerWeight(t *testing.T) { require.NoError(env.state.Commit()) // Test validator weight after delegation - vdrWeight = primarySet.GetWeight(nodeID) + vdrWeight = env.config.Validators.GetWeight(constants.PrimaryNetworkID, nodeID) require.Equal(env.config.MinDelegatorStake+env.config.MinValidatorStake, vdrWeight) } diff --git a/vms/platformvm/block/serialization_test.go b/vms/platformvm/block/serialization_test.go new file mode 100644 index 000000000000..031e527be25f --- /dev/null +++ b/vms/platformvm/block/serialization_test.go @@ -0,0 +1,115 @@ +// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
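The test changes above all make the same substitution; roughly (a sketch, using only the Manager methods exercised in this diff):

package main

import (
	"fmt"

	"github.com/ava-labs/avalanchego/ids"
	"github.com/ava-labs/avalanchego/snow/validators"
	"github.com/ava-labs/avalanchego/utils/constants"
)

// isValidator replaces the old validators.Contains/set.GetWeight pattern:
// membership and weight are read from the Manager with an explicit subnet ID.
func isValidator(vdrs validators.Manager, nodeID ids.NodeID) (bool, uint64) {
	_, ok := vdrs.GetValidator(constants.PrimaryNetworkID, nodeID)
	weight := vdrs.GetWeight(constants.PrimaryNetworkID, nodeID)
	return ok, weight
}

func main() {
	vdrs := validators.NewManager()
	ok, weight := isValidator(vdrs, ids.GenerateTestNodeID())
	fmt.Println(ok, weight) // expected: false 0 for an empty manager
}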
+ +package block + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/ava-labs/avalanchego/vms/platformvm/txs" +) + +func TestBanffBlockSerialization(t *testing.T) { + type test struct { + block BanffBlock + bytes []byte + } + + tests := []test{ + { + block: &BanffProposalBlock{ + ApricotProposalBlock: ApricotProposalBlock{ + Tx: &txs.Tx{ + Unsigned: &txs.AdvanceTimeTx{}, + }, + }, + }, + bytes: []byte{ + // Codec version + 0x00, 0x00, + // Type ID + 0x00, 0x00, 0x00, 0x1d, + // Rest + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x13, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + }, + }, + { + block: &BanffCommitBlock{ + ApricotCommitBlock: ApricotCommitBlock{}, + }, + bytes: []byte{ + // Codec version + 0x00, 0x00, + // Type ID + 0x00, 0x00, 0x00, 0x1f, + // Rest + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }, + }, + { + block: &BanffAbortBlock{ + ApricotAbortBlock: ApricotAbortBlock{}, + }, + bytes: []byte{ + // Codec version + 0x00, 0x00, + // Type ID + 0x00, 0x00, 0x00, 0x1e, + // Rest + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }, + }, + { + block: &BanffStandardBlock{ + ApricotStandardBlock: ApricotStandardBlock{ + Transactions: []*txs.Tx{}, + }, + }, + bytes: []byte{ + // Codec version + 0x00, 0x00, + // Type ID + 0x00, 0x00, 0x00, 0x20, + // Rest + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + }, + }, + } + + for _, test := range tests { + testName := fmt.Sprintf("%T", test.block) + t.Run(testName, func(t *testing.T) { + require := require.New(t) + + require.NoError(initialize(test.block)) + require.Equal(test.bytes, test.block.Bytes()) + }) + } +} diff --git a/vms/platformvm/camino_helpers_test.go b/vms/platformvm/camino_helpers_test.go index 36cc92833040..f3a544bdaaeb 100644 --- a/vms/platformvm/camino_helpers_test.go +++ b/vms/platformvm/camino_helpers_test.go @@ -5,7 +5,6 @@ package platformvm import ( "context" - "fmt" "testing" "time" @@ -28,7 +27,6 @@ import ( "github.com/ava-labs/avalanchego/utils/json" "github.com/ava-labs/avalanchego/utils/logging" "github.com/ava-labs/avalanchego/utils/nodeid" - "github.com/ava-labs/avalanchego/utils/timer/mockable" "github.com/ava-labs/avalanchego/utils/units" "github.com/ava-labs/avalanchego/version" "github.com/ava-labs/avalanchego/vms/components/avax" @@ -38,7 +36,6 @@ import ( "github.com/ava-labs/avalanchego/vms/platformvm/config" "github.com/ava-labs/avalanchego/vms/platformvm/genesis" 
"github.com/ava-labs/avalanchego/vms/platformvm/locked" - "github.com/ava-labs/avalanchego/vms/platformvm/reward" "github.com/ava-labs/avalanchego/vms/secp256k1fx" ) @@ -50,8 +47,22 @@ var ( localStakingPath = "../../staking/local/" caminoPreFundedKeys = secp256k1.TestKeys() _, caminoPreFundedNodeIDs = nodeid.LoadLocalCaminoNodeKeysAndIDs(localStakingPath) + defaultStartTime = banffForkTime.Add(time.Second) + + testAddressID ids.ShortID ) +func init() { + _, _, testAddressBytes, err := address.Parse(testAddress) + if err != nil { + panic(err) + } + testAddressID, err = ids.ToShortID(testAddressBytes) + if err != nil { + panic(err) + } +} + func defaultCaminoService(t *testing.T, camino api.Camino, utxos []api.UTXO) *CaminoService { vm := newCaminoVM(t, camino, utxos, nil) @@ -69,14 +80,15 @@ func defaultCaminoService(t *testing.T, camino api.Camino, utxos []api.UTXO) *Ca } func newCaminoVM(t *testing.T, genesisConfig api.Camino, genesisUTXOs []api.UTXO, startTime *time.Time) *VM { - vm := &VM{Config: defaultCaminoConfig(true)} + require := require.New(t) + + vm := &VM{Config: defaultCaminoConfig()} baseDBManager := manager.NewMemDB(version.Semantic1_0_0) chainDBManager := baseDBManager.NewPrefixDBManager([]byte{0}) atomicDB := prefixdb.New([]byte{1}, baseDBManager.Current().Database) if startTime == nil { - defaultStartTime := banffForkTime.Add(time.Second) startTime = &defaultStartTime } vm.clock.Set(*startTime) @@ -91,26 +103,33 @@ func newCaminoVM(t *testing.T, genesisConfig api.Camino, genesisUTXOs []api.UTXO ctx.Lock.Lock() defer ctx.Lock.Unlock() - _, genesisBytes := newCaminoGenesisWithUTXOs(genesisConfig, genesisUTXOs, startTime) + _, genesisBytes := newCaminoGenesisWithUTXOs(t, genesisConfig, genesisUTXOs, startTime) + // _, genesisBytes := defaultGenesis(t) appSender := &common.SenderTest{} appSender.CantSendAppGossip = true appSender.SendAppGossipF = func(context.Context, []byte) error { return nil } - if err := vm.Initialize(context.TODO(), ctx, chainDBManager, genesisBytes, nil, nil, msgChan, nil, appSender); err != nil { - panic(err) - } - if err := vm.SetState(context.TODO(), snow.NormalOp); err != nil { - panic(err) - } + require.NoError(vm.Initialize( + context.Background(), + ctx, + chainDBManager, + genesisBytes, + nil, + nil, + msgChan, + nil, + appSender, + )) + + require.NoError(vm.SetState(context.Background(), snow.NormalOp)) // Create a subnet and store it in testSubnet1 // Note: following Banff activation, block acceptance will move // chain time ahead - var err error - testSubnet1, err = vm.txBuilder.NewCreateSubnetTx( - 2, // threshold; 2 sigs from keys[0], keys[1], keys[2] needed to add validator to this subnet + testSubnet1, err := vm.txBuilder.NewCreateSubnetTx( + 2, // threshold; 2 sigs from control keys needed to add validator to this subnet []ids.ShortID{ // control keys caminoPreFundedKeys[0].PublicKey().Address(), caminoPreFundedKeys[1].PublicKey().Address(), @@ -119,53 +138,37 @@ func newCaminoVM(t *testing.T, genesisConfig api.Camino, genesisUTXOs []api.UTXO []*secp256k1.PrivateKey{caminoPreFundedKeys[0]}, caminoPreFundedKeys[0].PublicKey().Address(), ) - if err != nil { - panic(err) - } else if err := vm.Builder.AddUnverifiedTx(testSubnet1); err != nil { - panic(err) - } else if blk, err := vm.Builder.BuildBlock(context.TODO()); err != nil { - panic(err) - } else if err := blk.Verify(context.TODO()); err != nil { - panic(err) - } else if err := blk.Accept(context.TODO()); err != nil { - panic(err) - } else if err := vm.SetPreference(context.TODO(), 
vm.manager.LastAccepted()); err != nil { - panic(err) - } + require.NoError(err) + require.NoError(vm.Builder.AddUnverifiedTx(testSubnet1)) + blk, err := vm.Builder.BuildBlock(context.Background()) + require.NoError(err) + require.NoError(blk.Verify(context.Background())) + require.NoError(blk.Accept(context.Background())) + require.NoError(vm.SetPreference(context.Background(), vm.manager.LastAccepted())) return vm + // return vm, baseDBManager.Current().Database, msm } -func defaultCaminoConfig(postBanff bool) config.Config { //nolint:unparam - banffTime := mockable.MaxTime - if postBanff { - banffTime = defaultValidateEndTime.Add(-2 * time.Second) - } - - vdrs := validators.NewManager() - primaryVdrs := validators.NewSet() - _ = vdrs.Add(constants.PrimaryNetworkID, primaryVdrs) +func defaultCaminoConfig() config.Config { return config.Config{ Chains: chains.TestManager, UptimeLockedCalculator: uptime.NewLockedCalculator(), - Validators: vdrs, + SybilProtectionEnabled: true, + Validators: validators.NewManager(), TxFee: defaultTxFee, CreateSubnetTxFee: 100 * defaultTxFee, + TransformSubnetTxFee: 100 * defaultTxFee, CreateBlockchainTxFee: 100 * defaultTxFee, MinValidatorStake: defaultCaminoValidatorWeight, MaxValidatorStake: defaultCaminoValidatorWeight, MinDelegatorStake: 1 * units.MilliAvax, MinStakeDuration: defaultMinStakingDuration, MaxStakeDuration: defaultMaxStakingDuration, - RewardConfig: reward.Config{ - MaxConsumptionRate: .12 * reward.PercentDenominator, - MinConsumptionRate: .10 * reward.PercentDenominator, - MintingPeriod: 365 * 24 * time.Hour, - SupplyCap: 720 * units.MegaAvax, - }, - ApricotPhase3Time: defaultValidateEndTime, - ApricotPhase5Time: defaultValidateEndTime, - BanffTime: banffTime, + RewardConfig: defaultRewardConfig, + ApricotPhase3Time: defaultValidateEndTime, + ApricotPhase5Time: defaultValidateEndTime, + BanffTime: banffForkTime, CaminoConfig: caminoconfig.Config{ DACProposalBondAmount: 100 * units.Avax, }, @@ -175,7 +178,9 @@ func defaultCaminoConfig(postBanff bool) config.Config { //nolint:unparam // Returns: // 1) The genesis state // 2) The byte representation of the default genesis for tests -func newCaminoGenesisWithUTXOs(caminoGenesisConfig api.Camino, genesisUTXOs []api.UTXO, starttime *time.Time) (*api.BuildGenesisArgs, []byte) { +func newCaminoGenesisWithUTXOs(t *testing.T, caminoGenesisConfig api.Camino, genesisUTXOs []api.UTXO, starttime *time.Time) (*api.BuildGenesisArgs, []byte) { + require := require.New(t) + if starttime == nil { starttime = &defaultValidateStartTime } @@ -186,9 +191,7 @@ func newCaminoGenesisWithUTXOs(caminoGenesisConfig api.Camino, genesisUTXOs []ap genesisValidators := make([]api.PermissionlessValidator, len(caminoPreFundedKeys)) for i, key := range caminoPreFundedKeys { addr, err := address.FormatBech32(constants.UnitTestHRP, key.PublicKey().Address().Bytes()) - if err != nil { - panic(err) - } + require.NoError(err) genesisValidators[i] = api.PermissionlessValidator{ Staker: api.Staker{ StartTime: json.Uint64(starttime.Unix()), @@ -226,14 +229,10 @@ func newCaminoGenesisWithUTXOs(caminoGenesisConfig api.Camino, genesisUTXOs []ap buildGenesisResponse := api.BuildGenesisReply{} platformvmSS := api.StaticService{} - if err := platformvmSS.BuildGenesis(nil, &buildGenesisArgs, &buildGenesisResponse); err != nil { - panic(fmt.Errorf("problem while building platform chain's genesis state: %w", err)) - } + require.NoError(platformvmSS.BuildGenesis(nil, &buildGenesisArgs, &buildGenesisResponse)) genesisBytes, err := 
formatting.Decode(buildGenesisResponse.Encoding, buildGenesisResponse.Bytes) - if err != nil { - panic(err) - } + require.NoError(err) return &buildGenesisArgs, genesisBytes } @@ -250,10 +249,13 @@ func generateKeyAndOwner(t *testing.T) (*secp256k1.PrivateKey, ids.ShortID, secp } } -func stopService(t *testing.T, service *CaminoService) { - service.vm.ctx.Lock.Lock() - require.NoError(t, service.vm.Shutdown(context.TODO())) - service.vm.ctx.Lock.Unlock() +func stopVM(t *testing.T, vm *VM, doLock bool) { + t.Helper() + if doLock { + vm.ctx.Lock.Lock() + } + require.NoError(t, vm.Shutdown(context.TODO())) + vm.ctx.Lock.Unlock() } func generateTestUTXO(txID ids.ID, assetID ids.ID, amount uint64, outputOwners secp256k1fx.OutputOwners, depositTxID, bondTxID ids.ID) *avax.UTXO { diff --git a/vms/platformvm/camino_service_test.go b/vms/platformvm/camino_service_test.go index 067035f6ba45..2b347254714d 100644 --- a/vms/platformvm/camino_service_test.go +++ b/vms/platformvm/camino_service_test.go @@ -101,7 +101,7 @@ func TestGetCaminoBalance(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { service := defaultCaminoService(t, tt.camino, tt.genesisUTXOs) - defer stopService(t, service) + defer stopVM(t, service.vm, true) request := GetBalanceRequest{ Addresses: []string{ @@ -173,24 +173,17 @@ func TestGetCaminoBalance(t *testing.T) { } func TestCaminoService_GetAllDepositOffers(t *testing.T) { - type fields struct { - Service CaminoService - } type args struct { depositOffersArgs *GetAllDepositOffersArgs response *GetAllDepositOffersReply } tests := map[string]struct { - fields fields args args want []*APIDepositOffer wantErr error - prepare func(service CaminoService) + prepare func(service *CaminoService) }{ "OK": { - fields: fields{ - Service: *defaultCaminoService(t, api.Camino{}, []api.UTXO{}), - }, args: args{ depositOffersArgs: &GetAllDepositOffersArgs{ Timestamp: 50, @@ -214,7 +207,7 @@ func TestCaminoService_GetAllDepositOffers(t *testing.T) { End: 100, }, }, - prepare: func(service CaminoService) { + prepare: func(service *CaminoService) { service.vm.ctx.Lock.Lock() service.vm.state.SetDepositOffer(&deposit.Offer{ ID: ids.ID{0}, @@ -247,8 +240,12 @@ func TestCaminoService_GetAllDepositOffers(t *testing.T) { } for name, tt := range tests { t.Run(name, func(t *testing.T) { - tt.prepare(tt.fields.Service) - err := tt.fields.Service.GetAllDepositOffers(nil, tt.args.depositOffersArgs, tt.args.response) + s := defaultCaminoService(t, api.Camino{}, []api.UTXO{}) + defer stopVM(t, s.vm, true) + + tt.prepare(s) + + err := s.GetAllDepositOffers(nil, tt.args.depositOffersArgs, tt.args.response) require.ErrorIs(t, err, tt.wantErr) require.ElementsMatch(t, tt.want, tt.args.response.DepositOffers) }) @@ -256,12 +253,7 @@ func TestCaminoService_GetAllDepositOffers(t *testing.T) { } func TestGetKeystoreKeys(t *testing.T) { - s, _ := defaultService(t) userPass := json_api.UserPass{Username: testUsername, Password: testPassword} - // Insert testAddress into keystore - defaultAddress(t, s) - _, _, testAddressBytes, _ := address.Parse(testAddress) - testAddressID, _ := ids.ToShortID(testAddressBytes) tests := map[string]struct { from json_api.JSONFromAddrs @@ -290,8 +282,11 @@ func TestGetKeystoreKeys(t *testing.T) { } for name, tt := range tests { t.Run(name, func(t *testing.T) { + s, _ := defaultService(t) + defaultAddress(t, s) // Insert testAddress into keystore s.vm.ctx.Lock.Lock() - defer s.vm.ctx.Lock.Unlock() + defer stopVM(t, s.vm, false) + keys, err := 
s.getKeystoreKeys(&userPass, &tt.from) //nolint:gosec require.ErrorIs(t, err, tt.expectedError) @@ -308,6 +303,7 @@ func TestGetKeystoreKeys(t *testing.T) { func TestGetFakeKeys(t *testing.T) { s, _ := defaultService(t) + defer stopVM(t, s.vm, true) _, _, testAddressBytes, _ := address.Parse(testAddress) testAddressID, _ := ids.ToShortID(testAddressBytes) @@ -364,6 +360,7 @@ func TestSpend(t *testing.T) { Message: "", }}, ) + defer stopVM(t, service.vm, true) spendArgs := SpendArgs{ JSONFromAddrs: json_api.JSONFromAddrs{ diff --git a/vms/platformvm/camino_vm_test.go b/vms/platformvm/camino_vm_test.go index 05ac75da182e..6e186a3c3b74 100644 --- a/vms/platformvm/camino_vm_test.go +++ b/vms/platformvm/camino_vm_test.go @@ -12,6 +12,7 @@ import ( "github.com/ava-labs/avalanchego/database" "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/snow/consensus/snowman" "github.com/ava-labs/avalanchego/utils/constants" "github.com/ava-labs/avalanchego/utils/crypto/secp256k1" "github.com/ava-labs/avalanchego/utils/formatting/address" @@ -22,6 +23,8 @@ import ( as "github.com/ava-labs/avalanchego/vms/platformvm/addrstate" "github.com/ava-labs/avalanchego/vms/platformvm/api" "github.com/ava-labs/avalanchego/vms/platformvm/block" + "github.com/ava-labs/avalanchego/vms/platformvm/block/builder" + blockexecutor "github.com/ava-labs/avalanchego/vms/platformvm/block/executor" "github.com/ava-labs/avalanchego/vms/platformvm/dac" "github.com/ava-labs/avalanchego/vms/platformvm/deposit" "github.com/ava-labs/avalanchego/vms/platformvm/genesis" @@ -31,12 +34,8 @@ import ( "github.com/ava-labs/avalanchego/vms/platformvm/status" "github.com/ava-labs/avalanchego/vms/platformvm/treasury" "github.com/ava-labs/avalanchego/vms/platformvm/txs" - "github.com/ava-labs/avalanchego/vms/secp256k1fx" - - smcon "github.com/ava-labs/avalanchego/snow/consensus/snowman" - "github.com/ava-labs/avalanchego/vms/platformvm/block/builder" - blockexecutor "github.com/ava-labs/avalanchego/vms/platformvm/block/executor" txexecutor "github.com/ava-labs/avalanchego/vms/platformvm/txs/executor" + "github.com/ava-labs/avalanchego/vms/secp256k1fx" ) func TestRemoveDeferredValidator(t *testing.T) { @@ -71,10 +70,7 @@ func TestRemoveDeferredValidator(t *testing.T) { vm := newCaminoVM(t, caminoGenesisConf, genesisUTXOs, nil) vm.ctx.Lock.Lock() - defer func() { - require.NoError(vm.Shutdown(context.Background())) - vm.ctx.Lock.Unlock() - }() + defer stopVM(t, vm, false) utxo := generateTestUTXO(ids.GenerateTestID(), avaxAssetID, defaultBalance, *outputOwners, ids.Empty, ids.Empty) vm.state.AddUTXO(utxo) @@ -174,7 +170,7 @@ func TestRemoveDeferredValidator(t *testing.T) { require.NoError(err) // Assert preferences are correct - oracleBlk := blk.(smcon.OracleBlock) + oracleBlk := blk.(snowman.OracleBlock) options, err := oracleBlk.Options(context.Background()) require.NoError(err) @@ -253,10 +249,7 @@ func TestRemoveReactivatedValidator(t *testing.T) { vm := newCaminoVM(t, caminoGenesisConf, genesisUTXOs, nil) vm.ctx.Lock.Lock() - defer func() { - require.NoError(vm.Shutdown(context.Background())) - vm.ctx.Lock.Unlock() - }() + defer stopVM(t, vm, false) utxo := generateTestUTXO(ids.GenerateTestID(), avaxAssetID, defaultBalance, *outputOwners, ids.Empty, ids.Empty) vm.state.AddUTXO(utxo) @@ -370,7 +363,7 @@ func TestRemoveReactivatedValidator(t *testing.T) { require.NoError(err) // Assert preferences are correct - oracleBlk := blk.(smcon.OracleBlock) + oracleBlk := blk.(snowman.OracleBlock) options, err := 
oracleBlk.Options(context.Background()) require.NoError(err) @@ -440,7 +433,7 @@ func TestDepositsAutoUnlock(t *testing.T) { Address: depositOwnerAddrBech32, }}, nil) vm.ctx.Lock.Lock() - defer func() { require.NoError(vm.Shutdown(context.Background())) }() //nolint:lint + defer stopVM(t, vm, false) // Add deposit depositTx, err := vm.txBuilder.NewDepositTx( @@ -498,7 +491,7 @@ func TestProposals(t *testing.T) { caminoPreFundedKey0AddrStr, err := address.FormatBech32(constants.UnitTestHRP, caminoPreFundedKeys[0].Address().Bytes()) require.NoError(t, err) - defaultConfig := defaultCaminoConfig(true) + defaultConfig := defaultCaminoConfig() proposalBondAmount := defaultConfig.CaminoConfig.DACProposalBondAmount newFee := (defaultTxFee + 7) * 10 @@ -579,7 +572,7 @@ func TestProposals(t *testing.T) { }, }, &defaultConfig.BanffTime) vm.ctx.Lock.Lock() - defer func() { require.NoError(vm.Shutdown(context.Background())) }() //nolint:lint + defer stopVM(t, vm, false) checkBalance(t, vm.state, proposerAddr, balance, // total 0, 0, 0, balance, // unlocked @@ -701,7 +694,7 @@ func TestAdminProposals(t *testing.T) { applicantAddr := proposerAddr - defaultConfig := defaultCaminoConfig(true) + defaultConfig := defaultCaminoConfig() proposalBondAmount := defaultConfig.CaminoConfig.DACProposalBondAmount balance := proposalBondAmount + defaultTxFee @@ -721,7 +714,7 @@ func TestAdminProposals(t *testing.T) { }, }, &defaultConfig.BanffTime) vm.ctx.Lock.Lock() - defer func() { require.NoError(vm.Shutdown(context.Background())) }() //nolint:lint + defer stopVM(t, vm, false) checkBalance(t, vm.state, proposerAddr, balance, // total 0, 0, 0, balance, // unlocked @@ -817,7 +810,7 @@ func TestExcludeMemberProposals(t *testing.T) { fundsKeyAddrStr, err := address.FormatBech32(constants.UnitTestHRP, fundsKey.Address().Bytes()) require.NoError(t, err) - defaultConfig := defaultCaminoConfig(true) + defaultConfig := defaultCaminoConfig() fee := defaultConfig.TxFee addValidatorFee := defaultConfig.AddPrimaryNetworkValidatorFee proposalBondAmount := defaultConfig.CaminoConfig.DACProposalBondAmount @@ -912,7 +905,7 @@ func TestExcludeMemberProposals(t *testing.T) { InitialAdmin: rootAdminKey.Address(), }, []api.UTXO{{Amount: json.Uint64(initialBalance - defaultCaminoValidatorWeight), Address: fundsKeyAddrStr}}, &defaultConfig.BanffTime) vm.ctx.Lock.Lock() - defer func() { require.NoError(vm.Shutdown(context.Background())) }() //nolint:lint + defer stopVM(t, vm, false) height, err := vm.GetCurrentHeight(context.Background()) require.NoError(err) require.Equal(expectedHeight, height) diff --git a/vms/platformvm/main_test.go b/vms/platformvm/main_test.go new file mode 100644 index 000000000000..88a571cfa5cb --- /dev/null +++ b/vms/platformvm/main_test.go @@ -0,0 +1,14 @@ +// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
+ +package platformvm + +import ( + "testing" + + "go.uber.org/goleak" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} diff --git a/vms/platformvm/metrics/tx_metrics.go b/vms/platformvm/metrics/tx_metrics.go index 118f1156e677..17d6a090957b 100644 --- a/vms/platformvm/metrics/tx_metrics.go +++ b/vms/platformvm/metrics/tx_metrics.go @@ -27,7 +27,8 @@ type txMetrics struct { numRemoveSubnetValidatorTxs, numTransformSubnetTxs, numAddPermissionlessValidatorTxs, - numAddPermissionlessDelegatorTxs prometheus.Counter + numAddPermissionlessDelegatorTxs, + numTransferSubnetOwnershipTxs prometheus.Counter } func newTxMetrics( @@ -49,6 +50,7 @@ func newTxMetrics( numTransformSubnetTxs: newTxMetric(namespace, "transform_subnet", registerer, &errs), numAddPermissionlessValidatorTxs: newTxMetric(namespace, "add_permissionless_validator", registerer, &errs), numAddPermissionlessDelegatorTxs: newTxMetric(namespace, "add_permissionless_delegator", registerer, &errs), + numTransferSubnetOwnershipTxs: newTxMetric(namespace, "transfer_subnet_ownership", registerer, &errs), } return m, errs.Err } @@ -132,3 +134,8 @@ func (m *txMetrics) AddPermissionlessDelegatorTx(*txs.AddPermissionlessDelegator m.numAddPermissionlessDelegatorTxs.Inc() return nil } + +func (m *txMetrics) TransferSubnetOwnershipTx(*txs.TransferSubnetOwnershipTx) error { + m.numTransferSubnetOwnershipTxs.Inc() + return nil +} diff --git a/vms/platformvm/service.go b/vms/platformvm/service.go index dfbac9791984..be49c0e26fc6 100644 --- a/vms/platformvm/service.go +++ b/vms/platformvm/service.go @@ -1173,17 +1173,9 @@ func (s *Service) SampleValidators(_ *http.Request, args *SampleValidatorsArgs, zap.Uint16("size", uint16(args.Size)), ) - validators, ok := s.vm.Validators.Get(args.SubnetID) - if !ok { - return fmt.Errorf( - "couldn't get validators of subnet %q. 
Is it being validated?", - args.SubnetID, - ) - } - - sample, err := validators.Sample(int(args.Size)) + sample, err := s.vm.Validators.Sample(args.SubnetID, int(args.Size)) if err != nil { - return fmt.Errorf("sampling errored with %w", err) + return fmt.Errorf("sampling %s errored with %w", args.SubnetID, err) } if sample == nil { @@ -1894,12 +1886,8 @@ func (s *Service) nodeValidates(blockchainID ids.ID) bool { return false } - validators, ok := s.vm.Validators.Get(chain.SubnetID) - if !ok { - return false - } - - return validators.Contains(s.vm.ctx.NodeID) + _, isValidator := s.vm.Validators.GetValidator(chain.SubnetID, s.vm.ctx.NodeID) + return isValidator } func (s *Service) chainExists(ctx context.Context, blockID ids.ID, chainID ids.ID) (bool, error) { @@ -2396,11 +2384,11 @@ func (s *Service) GetTotalStake(_ *http.Request, args *GetTotalStakeArgs, reply zap.String("method", "getTotalStake"), ) - vdrs, ok := s.vm.Validators.Get(args.SubnetID) - if !ok { - return errMissingValidatorSet + totalWeight, err := s.vm.Validators.TotalWeight(args.SubnetID) + if err != nil { + return fmt.Errorf("couldn't get total weight: %w", err) } - weight := json.Uint64(vdrs.Weight()) + weight := json.Uint64(totalWeight) reply.Weight = weight reply.Stake = weight return nil diff --git a/vms/platformvm/service_test.go b/vms/platformvm/service_test.go index a0d94d6058a1..bbb238f15701 100644 --- a/vms/platformvm/service_test.go +++ b/vms/platformvm/service_test.go @@ -777,6 +777,11 @@ func TestGetBlock(t *testing.T) { require := require.New(t) service, _ := defaultService(t) service.vm.ctx.Lock.Lock() + defer func() { + service.vm.ctx.Lock.Lock() + require.NoError(service.vm.Shutdown(context.Background())) + service.vm.ctx.Lock.Unlock() + }() service.vm.Config.CreateAssetTxFee = 100 * defaultTxFee diff --git a/vms/platformvm/state/camino_helpers_test.go b/vms/platformvm/state/camino_helpers_test.go index a4e1f36891fc..a4afce530793 100644 --- a/vms/platformvm/state/camino_helpers_test.go +++ b/vms/platformvm/state/camino_helpers_test.go @@ -16,7 +16,6 @@ import ( "github.com/ava-labs/avalanchego/snow" "github.com/ava-labs/avalanchego/snow/validators" "github.com/ava-labs/avalanchego/utils" - "github.com/ava-labs/avalanchego/utils/constants" "github.com/ava-labs/avalanchego/utils/units" "github.com/ava-labs/avalanchego/vms/components/avax" "github.com/ava-labs/avalanchego/vms/platformvm/config" @@ -55,14 +54,12 @@ func generateBaseTx(assetID ids.ID, amount uint64, outputOwners secp256k1fx.Outp } func newEmptyState(t *testing.T) *state { - vdrs := validators.NewManager() - _ = vdrs.Add(constants.PrimaryNetworkID, validators.NewSet()) execCfg, _ := config.GetExecutionConfig(nil) newState, err := newState( memdb.New(), metrics.Noop, &config.Config{ - Validators: vdrs, + Validators: validators.NewManager(), }, execCfg, &snow.Context{}, diff --git a/vms/platformvm/state/metadata_codec.go b/vms/platformvm/state/metadata_codec.go new file mode 100644 index 000000000000..6240bbd879ca --- /dev/null +++ b/vms/platformvm/state/metadata_codec.go @@ -0,0 +1,28 @@ +// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
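The service changes above replace the per-subnet Set lookups with direct, subnet-scoped calls on validators.Manager. A hedged sketch of the new call shapes, with signatures inferred only from the call sites in this diff; the helper itself is illustrative:

package example

import (
	"fmt"

	"github.com/ava-labs/avalanchego/ids"
	"github.com/ava-labs/avalanchego/snow/validators"
)

// describeSubnet exercises the Manager calls used in the service changes above.
func describeSubnet(vdrs validators.Manager, subnetID ids.ID, nodeID ids.NodeID) error {
	// Membership check replaces Get(subnetID) followed by Contains(nodeID).
	if _, ok := vdrs.GetValidator(subnetID, nodeID); !ok {
		return fmt.Errorf("%s is not validating %s", nodeID, subnetID)
	}

	// Total weight is now a fallible call instead of a plain Weight() accessor.
	totalWeight, err := vdrs.TotalWeight(subnetID)
	if err != nil {
		return fmt.Errorf("couldn't get total weight: %w", err)
	}

	// Sampling takes the subnet ID directly rather than a per-subnet Set.
	sample, err := vdrs.Sample(subnetID, 5)
	if err != nil {
		return fmt.Errorf("sampling %s errored with %w", subnetID, err)
	}

	fmt.Printf("weight=%d myWeight=%d sample=%v\n",
		totalWeight, vdrs.GetWeight(subnetID, nodeID), sample)
	return nil
}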
+ +package state + +import ( + "math" + + "github.com/ava-labs/avalanchego/codec" + "github.com/ava-labs/avalanchego/codec/linearcodec" +) + +const ( + v0tag = "v0" + v0 = uint16(0) +) + +var metadataCodec codec.Manager + +func init() { + c := linearcodec.New([]string{v0tag}, math.MaxInt32) + metadataCodec = codec.NewManager(math.MaxInt32) + + err := metadataCodec.RegisterCodec(v0, c) + if err != nil { + panic(err) + } +} diff --git a/vms/platformvm/state/metadata_delegator.go b/vms/platformvm/state/metadata_delegator.go new file mode 100644 index 000000000000..04e7ef6a8795 --- /dev/null +++ b/vms/platformvm/state/metadata_delegator.go @@ -0,0 +1,25 @@ +// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package state + +import ( + "github.com/ava-labs/avalanchego/database" + "github.com/ava-labs/avalanchego/ids" +) + +type delegatorMetadata struct { + PotentialReward uint64 + + txID ids.ID +} + +func parseDelegatorMetadata(bytes []byte, metadata *delegatorMetadata) error { + var err error + metadata.PotentialReward, err = database.ParseUInt64(bytes) + return err +} + +func writeDelegatorMetadata(db database.KeyValueWriter, metadata *delegatorMetadata) error { + return database.PutUInt64(db, metadata.txID[:], metadata.PotentialReward) +} diff --git a/vms/platformvm/state/validator_metadata.go b/vms/platformvm/state/metadata_validator.go similarity index 90% rename from vms/platformvm/state/validator_metadata.go rename to vms/platformvm/state/metadata_validator.go index a14b9331f0c1..6b839ccad801 100644 --- a/vms/platformvm/state/validator_metadata.go +++ b/vms/platformvm/state/metadata_validator.go @@ -11,8 +11,6 @@ import ( "github.com/ava-labs/avalanchego/utils/constants" "github.com/ava-labs/avalanchego/utils/set" "github.com/ava-labs/avalanchego/utils/wrappers" - "github.com/ava-labs/avalanchego/vms/platformvm/genesis" - "github.com/ava-labs/avalanchego/vms/platformvm/txs" ) // preDelegateeRewardSize is the size of codec marshalling @@ -24,16 +22,16 @@ const preDelegateeRewardSize = wrappers.ShortLen + 3*wrappers.LongLen var _ validatorState = (*metadata)(nil) type preDelegateeRewardMetadata struct { - UpDuration time.Duration `serialize:"true"` - LastUpdated uint64 `serialize:"true"` // Unix time in seconds - PotentialReward uint64 `serialize:"true"` + UpDuration time.Duration `v0:"true"` + LastUpdated uint64 `v0:"true"` // Unix time in seconds + PotentialReward uint64 `v0:"true"` } type validatorMetadata struct { - UpDuration time.Duration `serialize:"true"` - LastUpdated uint64 `serialize:"true"` // Unix time in seconds - PotentialReward uint64 `serialize:"true"` - PotentialDelegateeReward uint64 `serialize:"true"` + UpDuration time.Duration `v0:"true"` + LastUpdated uint64 `v0:"true"` // Unix time in seconds + PotentialReward uint64 `v0:"true"` + PotentialDelegateeReward uint64 `v0:"true"` txID ids.ID lastUpdated time.Time @@ -60,7 +58,7 @@ func parseValidatorMetadata(bytes []byte, metadata *validatorMetadata) error { // potential reward and uptime was stored but potential delegatee reward // was not tmp := preDelegateeRewardMetadata{} - if _, err := txs.Codec.Unmarshal(bytes, &tmp); err != nil { + if _, err := metadataCodec.Unmarshal(bytes, &tmp); err != nil { return err } @@ -69,7 +67,7 @@ func parseValidatorMetadata(bytes []byte, metadata *validatorMetadata) error { metadata.PotentialReward = tmp.PotentialReward default: // everything was stored - if _, err := txs.Codec.Unmarshal(bytes, metadata); err != nil { + if _, 
err := metadataCodec.Unmarshal(bytes, metadata); err != nil { return err } } @@ -238,7 +236,7 @@ func (m *metadata) WriteValidatorMetadata( metadata := m.metadata[vdrID][subnetID] metadata.LastUpdated = uint64(metadata.lastUpdated.Unix()) - metadataBytes, err := genesis.Codec.Marshal(txs.Version, metadata) + metadataBytes, err := metadataCodec.Marshal(v0, metadata) if err != nil { return err } diff --git a/vms/platformvm/state/validator_metadata_test.go b/vms/platformvm/state/metadata_validator_test.go similarity index 100% rename from vms/platformvm/state/validator_metadata_test.go rename to vms/platformvm/state/metadata_validator_test.go diff --git a/vms/platformvm/state/mock_state.go b/vms/platformvm/state/mock_state.go index 58b97dcb7967..de76e4529999 100644 --- a/vms/platformvm/state/mock_state.go +++ b/vms/platformvm/state/mock_state.go @@ -163,6 +163,20 @@ func (mr *MockStateMockRecorder) AddUTXO(arg0 interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUTXO", reflect.TypeOf((*MockState)(nil).AddUTXO), arg0) } +// ApplyCurrentValidators mocks base method. +func (m *MockState) ApplyCurrentValidators(arg0 ids.ID, arg1 validators.Manager) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ApplyCurrentValidators", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// ApplyCurrentValidators indicates an expected call of ApplyCurrentValidators. +func (mr *MockStateMockRecorder) ApplyCurrentValidators(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplyCurrentValidators", reflect.TypeOf((*MockState)(nil).ApplyCurrentValidators), arg0, arg1) +} + // ApplyValidatorPublicKeyDiffs mocks base method. func (m *MockState) ApplyValidatorPublicKeyDiffs(arg0 context.Context, arg1 map[ids.NodeID]*validators.GetValidatorOutput, arg2, arg3 uint64) error { m.ctrl.T.Helper() @@ -1303,20 +1317,6 @@ func (mr *MockStateMockRecorder) UTXOIDs(arg0, arg1, arg2 interface{}) *gomock.C return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UTXOIDs", reflect.TypeOf((*MockState)(nil).UTXOIDs), arg0, arg1, arg2) } -// ValidatorSet mocks base method. -func (m *MockState) ValidatorSet(arg0 ids.ID, arg1 validators.Set) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ValidatorSet", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 -} - -// ValidatorSet indicates an expected call of ValidatorSet. -func (mr *MockStateMockRecorder) ValidatorSet(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidatorSet", reflect.TypeOf((*MockState)(nil).ValidatorSet), arg0, arg1) -} - // AddDeposit mocks base method. func (m *MockState) AddDeposit(arg0 ids.ID, arg1 *deposit.Deposit) { m.ctrl.T.Helper() diff --git a/vms/platformvm/state/state.go b/vms/platformvm/state/state.go index 9e95210de38a..0b08a096ce8b 100644 --- a/vms/platformvm/state/state.go +++ b/vms/platformvm/state/state.go @@ -68,9 +68,7 @@ const ( var ( _ State = (*state)(nil) - errMissingValidatorSet = errors.New("missing validator set") errValidatorSetAlreadyPopulated = errors.New("validator set already populated") - errDuplicateValidatorSet = errors.New("duplicate validator set") errIsNotSubnet = errors.New("is not a subnet") blockIDPrefix = []byte("blockID") @@ -153,9 +151,9 @@ type State interface { GetBlockIDAtHeight(height uint64) (ids.ID, error) - // ValidatorSet adds all the validators and delegators of [subnetID] into - // [vdrs]. 
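metadata_codec.go above registers a dedicated linear codec under the "v0" struct tag, which is why the metadata structs switch their fields from serialize:"true" to v0:"true". A minimal, self-contained sketch of that tag-scoped codec pattern; the struct and field names here are made up for illustration:

package main

import (
	"fmt"
	"math"

	"github.com/ava-labs/avalanchego/codec"
	"github.com/ava-labs/avalanchego/codec/linearcodec"
)

// example only serializes the field carrying the v0 tag.
type example struct {
	Kept    uint64 `v0:"true"`
	Ignored uint64 // untagged, so the v0 codec skips it
}

func main() {
	c := linearcodec.New([]string{"v0"}, math.MaxInt32)
	mgr := codec.NewManager(math.MaxInt32)
	if err := mgr.RegisterCodec(0, c); err != nil {
		panic(err)
	}

	bytes, err := mgr.Marshal(0, &example{Kept: 7, Ignored: 9})
	if err != nil {
		panic(err)
	}

	parsed := &example{}
	if _, err := mgr.Unmarshal(bytes, parsed); err != nil {
		panic(err)
	}
	fmt.Println(parsed.Kept, parsed.Ignored) // 7 0 -- only the tagged field survives the round trip
}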
- ValidatorSet(subnetID ids.ID, vdrs validators.Set) error + // ApplyCurrentValidators adds all the current validators and delegators of + // [subnetID] into [vdrs]. + ApplyCurrentValidators(subnetID ids.ID, vdrs validators.Manager) error // ApplyValidatorWeightDiffs iterates from [startHeight] towards the genesis // block until it has applied all of the diffs up to and including @@ -1161,17 +1159,17 @@ func (s *state) SetCurrentSupply(subnetID ids.ID, cs uint64) { } } -func (s *state) ValidatorSet(subnetID ids.ID, vdrs validators.Set) error { +func (s *state) ApplyCurrentValidators(subnetID ids.ID, vdrs validators.Manager) error { for nodeID, validator := range s.currentStakers.validators[subnetID] { staker := validator.validator - if err := vdrs.Add(nodeID, staker.PublicKey, staker.TxID, staker.Weight); err != nil { + if err := vdrs.AddStaker(subnetID, nodeID, staker.PublicKey, staker.TxID, staker.Weight); err != nil { return err } delegatorIterator := NewTreeIterator(validator.delegators) for delegatorIterator.Next() { staker := delegatorIterator.Value() - if err := vdrs.AddWeight(nodeID, staker.Weight); err != nil { + if err := vdrs.AddWeight(subnetID, nodeID, staker.Weight); err != nil { delegatorIterator.Release() return err } @@ -1591,8 +1589,10 @@ func (s *state) loadCurrentValidators() error { return err } - potentialRewardBytes := delegatorIt.Value() - potentialReward, err := database.ParseUInt64(potentialRewardBytes) + metadata := &delegatorMetadata{ + txID: txID, + } + err = parseDelegatorMetadata(delegatorIt.Value(), metadata) if err != nil { return err } @@ -1602,7 +1602,7 @@ func (s *state) loadCurrentValidators() error { return fmt.Errorf("expected tx type txs.Staker but got %T", tx.Unsigned) } - staker, err := NewCurrentStaker(txID, stakerTx, potentialReward) + staker, err := NewCurrentStaker(txID, stakerTx, metadata.PotentialReward) if err != nil { return err } @@ -1716,38 +1716,37 @@ func (s *state) loadPendingValidators() error { // Invariant: initValidatorSets requires loadCurrentValidators to have already // been called. func (s *state) initValidatorSets() error { - primaryValidators, ok := s.cfg.Validators.Get(constants.PrimaryNetworkID) - if !ok { - return errMissingValidatorSet - } - if primaryValidators.Len() != 0 { + if s.cfg.Validators.Count(constants.PrimaryNetworkID) != 0 { // Enforce the invariant that the validator set is empty here. return errValidatorSetAlreadyPopulated } - err := s.ValidatorSet(constants.PrimaryNetworkID, primaryValidators) + err := s.ApplyCurrentValidators(constants.PrimaryNetworkID, s.cfg.Validators) if err != nil { return err } vl := validators.NewLogger(s.ctx.Log, s.bootstrapped, constants.PrimaryNetworkID, s.ctx.NodeID) - primaryValidators.RegisterCallbackListener(vl) + s.cfg.Validators.RegisterCallbackListener(constants.PrimaryNetworkID, vl) - s.metrics.SetLocalStake(primaryValidators.GetWeight(s.ctx.NodeID)) - s.metrics.SetTotalStake(primaryValidators.Weight()) + s.metrics.SetLocalStake(s.cfg.Validators.GetWeight(constants.PrimaryNetworkID, s.ctx.NodeID)) + totalWeight, err := s.cfg.Validators.TotalWeight(constants.PrimaryNetworkID) + if err != nil { + return fmt.Errorf("failed to get total weight of primary network validators: %w", err) + } + s.metrics.SetTotalStake(totalWeight) for subnetID := range s.cfg.TrackedSubnets { - subnetValidators := validators.NewSet() - err := s.ValidatorSet(subnetID, subnetValidators) + if s.cfg.Validators.Count(subnetID) != 0 { + // Enforce the invariant that the validator set is empty here. 
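Today the delegator metadata is just the potential reward, written under the delegator's txID key (see writeDelegatorMetadata later in this diff). A hedged round-trip sketch of the helpers used above; this test is illustrative only and not part of the PR:

package state

import (
	"testing"

	"github.com/stretchr/testify/require"

	"github.com/ava-labs/avalanchego/database/memdb"
	"github.com/ava-labs/avalanchego/ids"
)

func TestDelegatorMetadataRoundTrip(t *testing.T) {
	require := require.New(t)

	db := memdb.New()
	txID := ids.GenerateTestID()

	// Write the metadata: the potential reward keyed by the delegator's tx ID.
	written := &delegatorMetadata{txID: txID, PotentialReward: 2023}
	require.NoError(writeDelegatorMetadata(db, written))

	storedBytes, err := db.Get(txID[:])
	require.NoError(err)

	// Parse it back and confirm the reward survived.
	parsed := &delegatorMetadata{txID: txID}
	require.NoError(parseDelegatorMetadata(storedBytes, parsed))
	require.Equal(written.PotentialReward, parsed.PotentialReward)
}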
+ return errValidatorSetAlreadyPopulated + } + err := s.ApplyCurrentValidators(subnetID, s.cfg.Validators) if err != nil { return err } - if !s.cfg.Validators.Add(subnetID, subnetValidators) { - return fmt.Errorf("%w: %s", errDuplicateValidatorSet, subnetID) - } - vl := validators.NewLogger(s.ctx.Log, s.bootstrapped, subnetID, s.ctx.NodeID) - subnetValidators.RegisterCallbackListener(vl) + s.cfg.Validators.RegisterCallbackListener(subnetID, vl) } return nil } @@ -2062,7 +2061,7 @@ func (s *state) writeCurrentStakers(updateValidators bool, height uint64) error PotentialDelegateeReward: 0, } - metadataBytes, err := block.GenesisCodec.Marshal(block.Version, metadata) + metadataBytes, err := metadataCodec.Marshal(v0, metadata) if err != nil { return fmt.Errorf("failed to serialize current validator: %w", err) } @@ -2153,12 +2152,11 @@ func (s *state) writeCurrentStakers(updateValidators bool, height uint64) error } if weightDiff.Decrease { - err = validators.RemoveWeight(s.cfg.Validators, subnetID, nodeID, weightDiff.Amount) + err = s.cfg.Validators.RemoveWeight(subnetID, nodeID, weightDiff.Amount) } else { if validatorDiff.validatorStatus == added { staker := validatorDiff.validator - err = validators.Add( - s.cfg.Validators, + err = s.cfg.Validators.AddStaker( subnetID, nodeID, staker.PublicKey, @@ -2166,7 +2164,7 @@ func (s *state) writeCurrentStakers(updateValidators bool, height uint64) error weightDiff.Amount, ) } else { - err = validators.AddWeight(s.cfg.Validators, subnetID, nodeID, weightDiff.Amount) + err = s.cfg.Validators.AddWeight(subnetID, nodeID, weightDiff.Amount) } } if err != nil { @@ -2181,12 +2179,14 @@ func (s *state) writeCurrentStakers(updateValidators bool, height uint64) error if !updateValidators { return nil } - primaryValidators, ok := s.cfg.Validators.Get(constants.PrimaryNetworkID) - if !ok { - return nil + + totalWeight, err := s.cfg.Validators.TotalWeight(constants.PrimaryNetworkID) + if err != nil { + return fmt.Errorf("failed to get total weight of primary network: %w", err) } - s.metrics.SetLocalStake(primaryValidators.GetWeight(s.ctx.NodeID)) - s.metrics.SetTotalStake(primaryValidators.Weight()) + + s.metrics.SetLocalStake(s.cfg.Validators.GetWeight(constants.PrimaryNetworkID, s.ctx.NodeID)) + s.metrics.SetTotalStake(totalWeight) return nil } @@ -2204,7 +2204,11 @@ func writeCurrentDelegatorDiff( return fmt.Errorf("failed to increase node weight diff: %w", err) } - if err := database.PutUInt64(currentDelegatorList, staker.TxID[:], staker.PotentialReward); err != nil { + metadata := &delegatorMetadata{ + txID: staker.TxID, + PotentialReward: staker.PotentialReward, + } + if err := writeDelegatorMetadata(currentDelegatorList, metadata); err != nil { return fmt.Errorf("failed to write current delegator to list: %w", err) } } diff --git a/vms/platformvm/state/state_test.go b/vms/platformvm/state/state_test.go index d76505d0db85..5a29619c1beb 100644 --- a/vms/platformvm/state/state_test.go +++ b/vms/platformvm/state/state_test.go @@ -162,16 +162,12 @@ func newUninitializedState(require *require.Assertions) (State, database.Databas } func newStateFromDB(require *require.Assertions, db database.Database) State { - vdrs := validators.NewManager() - primaryVdrs := validators.NewSet() - _ = vdrs.Add(constants.PrimaryNetworkID, primaryVdrs) - execCfg, _ := config.GetExecutionConfig(nil) state, err := newState( db, metrics.Noop, &config.Config{ - Validators: vdrs, + Validators: validators.NewManager(), }, execCfg, &snow.Context{}, diff --git 
a/vms/platformvm/txs/builder/builder.go b/vms/platformvm/txs/builder/builder.go index 279202087619..3f13ec2ecad9 100644 --- a/vms/platformvm/txs/builder/builder.go +++ b/vms/platformvm/txs/builder/builder.go @@ -159,6 +159,19 @@ type ProposalTxBuilder interface { changeAddr ids.ShortID, ) (*txs.Tx, error) + // Creates a transaction that transfers ownership of [subnetID] + // threshold: [threshold] of [ownerAddrs] needed to manage this subnet + // ownerAddrs: control addresses for the new subnet + // keys: keys to use for modifying the subnet + // changeAddr: address to send change to, if there is any + NewTransferSubnetOwnershipTx( + subnetID ids.ID, + threshold uint32, + ownerAddrs []ids.ShortID, + keys []*secp256k1.PrivateKey, + changeAddr ids.ShortID, + ) (*txs.Tx, error) + // newAdvanceTimeTx creates a new tx that, if it is accepted and followed by a // Commit block, will set the chain's timestamp to [timestamp]. NewAdvanceTimeTx(timestamp time.Time) (*txs.Tx, error) @@ -609,3 +622,42 @@ func (b *builder) NewRewardValidatorTx(txID ids.ID) (*txs.Tx, error) { return tx, tx.SyntacticVerify(b.ctx) } + +func (b *builder) NewTransferSubnetOwnershipTx( + subnetID ids.ID, + threshold uint32, + ownerAddrs []ids.ShortID, + keys []*secp256k1.PrivateKey, + changeAddr ids.ShortID, +) (*txs.Tx, error) { + ins, outs, _, signers, err := b.Spend(b.state, keys, 0, b.cfg.TxFee, changeAddr) + if err != nil { + return nil, fmt.Errorf("couldn't generate tx inputs/outputs: %w", err) + } + + subnetAuth, subnetSigners, err := b.Authorize(b.state, subnetID, keys) + if err != nil { + return nil, fmt.Errorf("couldn't authorize tx's subnet restrictions: %w", err) + } + signers = append(signers, subnetSigners) + + utx := &txs.TransferSubnetOwnershipTx{ + BaseTx: txs.BaseTx{BaseTx: avax.BaseTx{ + NetworkID: b.ctx.NetworkID, + BlockchainID: b.ctx.ChainID, + Ins: ins, + Outs: outs, + }}, + Subnet: subnetID, + SubnetAuth: subnetAuth, + Owner: &secp256k1fx.OutputOwners{ + Threshold: threshold, + Addrs: ownerAddrs, + }, + } + tx, err := txs.NewSigned(utx, txs.Codec, signers) + if err != nil { + return nil, err + } + return tx, tx.SyntacticVerify(b.ctx) +} diff --git a/vms/platformvm/txs/builder/camino_helpers_test.go b/vms/platformvm/txs/builder/camino_helpers_test.go index 5e9449c9e608..473ea80f8776 100644 --- a/vms/platformvm/txs/builder/camino_helpers_test.go +++ b/vms/platformvm/txs/builder/camino_helpers_test.go @@ -374,14 +374,10 @@ func defaultCaminoConfig(postBanff bool) config.Config { banffTime = defaultValidateEndTime.Add(-2 * time.Second) } - vdrs := validators.NewManager() - primaryVdrs := validators.NewSet() - _ = vdrs.Add(constants.PrimaryNetworkID, primaryVdrs) - return config.Config{ Chains: chains.TestManager, UptimeLockedCalculator: uptime.NewLockedCalculator(), - Validators: vdrs, + Validators: validators.NewManager(), TxFee: defaultTxFee, CreateSubnetTxFee: 100 * defaultTxFee, CreateBlockchainTxFee: 100 * defaultTxFee, @@ -476,10 +472,7 @@ func buildCaminoGenesisTest(ctx *snow.Context, caminoGenesisConf api.Camino) []b func shutdownCaminoEnvironment(env *caminoEnvironment) error { if env.isBootstrapped.Get() { - validatorIDs, err := validators.NodeIDs(env.config.Validators, constants.PrimaryNetworkID) - if err != nil { - return err - } + validatorIDs := env.config.Validators.GetValidatorIDs(constants.PrimaryNetworkID) if err := env.uptimes.StopTracking(validatorIDs, constants.PrimaryNetworkID); err != nil { return err diff --git a/vms/platformvm/txs/builder/mock_builder.go 
b/vms/platformvm/txs/builder/mock_builder.go index 20fd52e4abd4..79291afb7cd7 100644 --- a/vms/platformvm/txs/builder/mock_builder.go +++ b/vms/platformvm/txs/builder/mock_builder.go @@ -189,3 +189,18 @@ func (mr *MockBuilderMockRecorder) NewRewardValidatorTx(arg0 interface{}) *gomoc mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewRewardValidatorTx", reflect.TypeOf((*MockBuilder)(nil).NewRewardValidatorTx), arg0) } + +// NewTransferSubnetOwnershipTx mocks base method. +func (m *MockBuilder) NewTransferSubnetOwnershipTx(arg0 ids.ID, arg1 uint32, arg2 []ids.ShortID, arg3 []*secp256k1.PrivateKey, arg4 ids.ShortID) (*txs.Tx, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewTransferSubnetOwnershipTx", arg0, arg1, arg2, arg3, arg4) + ret0, _ := ret[0].(*txs.Tx) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// NewTransferSubnetOwnershipTx indicates an expected call of NewTransferSubnetOwnershipTx. +func (mr *MockBuilderMockRecorder) NewTransferSubnetOwnershipTx(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewTransferSubnetOwnershipTx", reflect.TypeOf((*MockBuilder)(nil).NewTransferSubnetOwnershipTx), arg0, arg1, arg2, arg3, arg4) +} diff --git a/vms/platformvm/txs/codec.go b/vms/platformvm/txs/codec.go index 59dff5bb7c31..5efe9d7e382f 100644 --- a/vms/platformvm/txs/codec.go +++ b/vms/platformvm/txs/codec.go @@ -54,6 +54,10 @@ func init() { c.SkipRegistrations(5) errs.Add(RegisterUnsignedTxsTypes(c)) + + c.SkipRegistrations(4) + + errs.Add(RegisterDUnsignedTxsTypes(c)) } errs.Add( Codec.RegisterCodec(Version, c), @@ -140,3 +144,7 @@ func RegisterUnsignedTxsTypes(targetCodec linearcodec.CaminoCodec) error { ) return errs.Err } + +func RegisterDUnsignedTxsTypes(targetCodec linearcodec.Codec) error { + return targetCodec.RegisterType(&TransferSubnetOwnershipTx{}) +} diff --git a/vms/platformvm/txs/executor/advance_time_test.go b/vms/platformvm/txs/executor/advance_time_test.go index bfe6d58fa085..9bf5aafed7ac 100644 --- a/vms/platformvm/txs/executor/advance_time_test.go +++ b/vms/platformvm/txs/executor/advance_time_test.go @@ -12,7 +12,6 @@ import ( "github.com/ava-labs/avalanchego/database" "github.com/ava-labs/avalanchego/ids" - "github.com/ava-labs/avalanchego/snow/validators" "github.com/ava-labs/avalanchego/utils/constants" "github.com/ava-labs/avalanchego/utils/crypto/secp256k1" "github.com/ava-labs/avalanchego/vms/platformvm/reward" @@ -77,7 +76,8 @@ func TestAdvanceTimeTxUpdatePrimaryNetworkStakers(t *testing.T) { env.state.SetHeight(dummyHeight) require.NoError(env.state.Commit()) - require.True(validators.Contains(env.config.Validators, constants.PrimaryNetworkID, nodeID)) + _, ok := env.config.Validators.GetValidator(constants.PrimaryNetworkID, nodeID) + require.True(ok) } // Ensure semantic verification fails when proposed timestamp is at or before current timestamp @@ -356,7 +356,6 @@ func TestAdvanceTimeTxUpdateStakers(t *testing.T) { subnetID := testSubnet1.ID() env.config.TrackedSubnets.Add(subnetID) - env.config.Validators.Add(subnetID, validators.NewSet()) for _, staker := range test.stakers { _, err := addPendingValidator( @@ -422,20 +421,24 @@ func TestAdvanceTimeTxUpdateStakers(t *testing.T) { case pending: _, err := env.state.GetPendingValidator(constants.PrimaryNetworkID, stakerNodeID) require.NoError(err) - require.False(validators.Contains(env.config.Validators, constants.PrimaryNetworkID, stakerNodeID)) + _, ok := 
env.config.Validators.GetValidator(constants.PrimaryNetworkID, stakerNodeID) + require.False(ok) case current: _, err := env.state.GetCurrentValidator(constants.PrimaryNetworkID, stakerNodeID) require.NoError(err) - require.True(validators.Contains(env.config.Validators, constants.PrimaryNetworkID, stakerNodeID)) + _, ok := env.config.Validators.GetValidator(constants.PrimaryNetworkID, stakerNodeID) + require.True(ok) } } for stakerNodeID, status := range test.expectedSubnetStakers { switch status { case pending: - require.False(validators.Contains(env.config.Validators, subnetID, stakerNodeID)) + _, ok := env.config.Validators.GetValidator(subnetID, stakerNodeID) + require.False(ok) case current: - require.True(validators.Contains(env.config.Validators, subnetID, stakerNodeID)) + _, ok := env.config.Validators.GetValidator(subnetID, stakerNodeID) + require.True(ok) } } }) @@ -456,7 +459,6 @@ func TestAdvanceTimeTxRemoveSubnetValidator(t *testing.T) { subnetID := testSubnet1.ID() env.config.TrackedSubnets.Add(subnetID) - env.config.Validators.Add(subnetID, validators.NewSet()) dummyHeight := uint64(1) // Add a subnet validator to the staker set @@ -542,8 +544,10 @@ func TestAdvanceTimeTxRemoveSubnetValidator(t *testing.T) { env.state.SetHeight(dummyHeight) require.NoError(env.state.Commit()) - require.False(validators.Contains(env.config.Validators, subnetID, subnetVdr2NodeID)) - require.False(validators.Contains(env.config.Validators, subnetID, subnetValidatorNodeID)) + _, ok := env.config.Validators.GetValidator(subnetID, subnetVdr2NodeID) + require.False(ok) + _, ok = env.config.Validators.GetValidator(subnetID, subnetValidatorNodeID) + require.False(ok) } func TestTrackedSubnet(t *testing.T) { @@ -560,7 +564,6 @@ func TestTrackedSubnet(t *testing.T) { subnetID := testSubnet1.ID() if tracked { env.config.TrackedSubnets.Add(subnetID) - env.config.Validators.Add(subnetID, validators.NewSet()) } // Add a subnet validator to the staker set @@ -613,7 +616,8 @@ func TestTrackedSubnet(t *testing.T) { env.state.SetHeight(dummyHeight) require.NoError(env.state.Commit()) - require.Equal(tracked, validators.Contains(env.config.Validators, subnetID, ids.NodeID(subnetValidatorNodeID))) + _, ok := env.config.Validators.GetValidator(subnetID, ids.NodeID(subnetValidatorNodeID)) + require.Equal(tracked, ok) }) } } @@ -664,9 +668,7 @@ func TestAdvanceTimeTxDelegatorStakerWeight(t *testing.T) { require.NoError(env.state.Commit()) // Test validator weight before delegation - primarySet, ok := env.config.Validators.Get(constants.PrimaryNetworkID) - require.True(ok) - vdrWeight := primarySet.GetWeight(nodeID) + vdrWeight := env.config.Validators.GetWeight(constants.PrimaryNetworkID, nodeID) require.Equal(env.config.MinValidatorStake, vdrWeight) // Add delegator @@ -723,7 +725,7 @@ func TestAdvanceTimeTxDelegatorStakerWeight(t *testing.T) { require.NoError(env.state.Commit()) // Test validator weight after delegation - vdrWeight = primarySet.GetWeight(nodeID) + vdrWeight = env.config.Validators.GetWeight(constants.PrimaryNetworkID, nodeID) require.Equal(env.config.MinDelegatorStake+env.config.MinValidatorStake, vdrWeight) } @@ -767,9 +769,7 @@ func TestAdvanceTimeTxDelegatorStakers(t *testing.T) { require.NoError(env.state.Commit()) // Test validator weight before delegation - primarySet, ok := env.config.Validators.Get(constants.PrimaryNetworkID) - require.True(ok) - vdrWeight := primarySet.GetWeight(nodeID) + vdrWeight := env.config.Validators.GetWeight(constants.PrimaryNetworkID, nodeID) 
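The weight assertions above follow from the Manager's bookkeeping: AddStaker sets the validator's own weight, and each delegation is layered on top with AddWeight. A small hedged sketch (node and tx IDs are generated test values; no BLS key is registered here):

package main

import (
	"fmt"

	"github.com/ava-labs/avalanchego/ids"
	"github.com/ava-labs/avalanchego/snow/validators"
	"github.com/ava-labs/avalanchego/utils/constants"
)

func main() {
	vdrs := validators.NewManager()
	nodeID := ids.GenerateTestNodeID()

	// The validator joins with its own stake (nil BLS public key in this sketch).
	if err := vdrs.AddStaker(constants.PrimaryNetworkID, nodeID, nil, ids.GenerateTestID(), 2_000); err != nil {
		panic(err)
	}
	// A delegation shows up as extra weight on the same node.
	if err := vdrs.AddWeight(constants.PrimaryNetworkID, nodeID, 25); err != nil {
		panic(err)
	}

	fmt.Println(vdrs.GetWeight(constants.PrimaryNetworkID, nodeID)) // 2025
}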
require.Equal(env.config.MinValidatorStake, vdrWeight) // Add delegator @@ -821,7 +821,7 @@ func TestAdvanceTimeTxDelegatorStakers(t *testing.T) { require.NoError(env.state.Commit()) // Test validator weight after delegation - vdrWeight = primarySet.GetWeight(nodeID) + vdrWeight = env.config.Validators.GetWeight(constants.PrimaryNetworkID, nodeID) require.Equal(env.config.MinDelegatorStake+env.config.MinValidatorStake, vdrWeight) } diff --git a/vms/platformvm/txs/executor/atomic_tx_executor.go b/vms/platformvm/txs/executor/atomic_tx_executor.go index 43bf448f3eb9..b9c8ba5f7b7f 100644 --- a/vms/platformvm/txs/executor/atomic_tx_executor.go +++ b/vms/platformvm/txs/executor/atomic_tx_executor.go @@ -74,6 +74,10 @@ func (*AtomicTxExecutor) TransformSubnetTx(*txs.TransformSubnetTx) error { return ErrWrongTxType } +func (*AtomicTxExecutor) TransferSubnetOwnershipTx(*txs.TransferSubnetOwnershipTx) error { + return ErrWrongTxType +} + func (*AtomicTxExecutor) AddPermissionlessValidatorTx(*txs.AddPermissionlessValidatorTx) error { return ErrWrongTxType } diff --git a/vms/platformvm/txs/executor/camino_advance_time_test.go b/vms/platformvm/txs/executor/camino_advance_time_test.go index 57fffc706fcb..c3d408a987d3 100644 --- a/vms/platformvm/txs/executor/camino_advance_time_test.go +++ b/vms/platformvm/txs/executor/camino_advance_time_test.go @@ -11,7 +11,6 @@ import ( "github.com/ava-labs/avalanchego/database" "github.com/ava-labs/avalanchego/ids" - "github.com/ava-labs/avalanchego/snow/validators" "github.com/ava-labs/avalanchego/utils/constants" "github.com/ava-labs/avalanchego/utils/crypto/secp256k1" "github.com/ava-labs/avalanchego/utils/nodeid" @@ -208,7 +207,6 @@ func TestDeferredStakers(t *testing.T) { subnetID := testSubnet1.ID() env.config.TrackedSubnets.Add(subnetID) - env.config.Validators.Add(subnetID, validators.NewSet()) for _, staker := range test.stakers { _, err := addCaminoPendingValidator( @@ -282,11 +280,13 @@ func TestDeferredStakers(t *testing.T) { case pending: _, err := env.state.GetPendingValidator(constants.PrimaryNetworkID, stakerNodeID) require.NoError(err) - require.False(validators.Contains(env.config.Validators, constants.PrimaryNetworkID, stakerNodeID)) + _, ok := env.config.Validators.GetValidator(constants.PrimaryNetworkID, stakerNodeID) + require.False(ok) case current: _, err := env.state.GetCurrentValidator(constants.PrimaryNetworkID, stakerNodeID) require.NoError(err) - require.True(validators.Contains(env.config.Validators, constants.PrimaryNetworkID, stakerNodeID)) + _, ok := env.config.Validators.GetValidator(constants.PrimaryNetworkID, stakerNodeID) + require.True(ok) case expired: _, err := env.state.GetCurrentValidator(constants.PrimaryNetworkID, stakerNodeID) require.ErrorIs(err, database.ErrNotFound) @@ -300,11 +300,13 @@ func TestDeferredStakers(t *testing.T) { case pending: _, err := env.state.GetPendingValidator(subnetID, stakerNodeID) require.NoError(err) - require.False(validators.Contains(env.config.Validators, subnetID, stakerNodeID)) + _, ok := env.config.Validators.GetValidator(subnetID, stakerNodeID) + require.False(ok) case current: _, err := env.state.GetCurrentValidator(subnetID, stakerNodeID) require.NoError(err) - require.True(validators.Contains(env.config.Validators, subnetID, stakerNodeID)) + _, ok := env.config.Validators.GetValidator(subnetID, stakerNodeID) + require.True(ok) case expired: _, err := env.state.GetCurrentValidator(subnetID, stakerNodeID) require.ErrorIs(err, database.ErrNotFound) diff --git 
a/vms/platformvm/txs/executor/camino_helpers_test.go b/vms/platformvm/txs/executor/camino_helpers_test.go index 7b77e656ed4f..48bfddcffc06 100644 --- a/vms/platformvm/txs/executor/camino_helpers_test.go +++ b/vms/platformvm/txs/executor/camino_helpers_test.go @@ -247,14 +247,10 @@ func defaultCaminoConfig(postBanff bool) config.Config { banffTime = defaultValidateEndTime.Add(-2 * time.Second) } - vdrs := validators.NewManager() - primaryVdrs := validators.NewSet() - _ = vdrs.Add(constants.PrimaryNetworkID, primaryVdrs) - return config.Config{ Chains: chains.TestManager, UptimeLockedCalculator: uptime.NewLockedCalculator(), - Validators: vdrs, + Validators: validators.NewManager(), TxFee: defaultTxFee, CreateSubnetTxFee: 100 * defaultTxFee, CreateBlockchainTxFee: 100 * defaultTxFee, @@ -597,10 +593,7 @@ func generateMsigAliasAndKeys(t *testing.T, threshold, addrsCount uint32, sorted func shutdownCaminoEnvironment(env *caminoEnvironment) error { if env.isBootstrapped.Get() { - validatorIDs, err := validators.NodeIDs(env.config.Validators, constants.PrimaryNetworkID) - if err != nil { - return err - } + validatorIDs := env.config.Validators.GetValidatorIDs(constants.PrimaryNetworkID) if err := env.uptimes.StopTracking(validatorIDs, constants.PrimaryNetworkID); err != nil { return err diff --git a/vms/platformvm/txs/executor/helpers_test.go b/vms/platformvm/txs/executor/helpers_test.go index d3d842ee8129..74a5bb40764f 100644 --- a/vms/platformvm/txs/executor/helpers_test.go +++ b/vms/platformvm/txs/executor/helpers_test.go @@ -287,13 +287,10 @@ func defaultConfig(postBanff, postCortina bool) config.Config { cortinaTime = defaultValidateStartTime.Add(-2 * time.Second) } - vdrs := validators.NewManager() - primaryVdrs := validators.NewSet() - _ = vdrs.Add(constants.PrimaryNetworkID, primaryVdrs) return config.Config{ Chains: chains.TestManager, UptimeLockedCalculator: uptime.NewLockedCalculator(), - Validators: vdrs, + Validators: validators.NewManager(), TxFee: defaultTxFee, CreateSubnetTxFee: 100 * defaultTxFee, CreateBlockchainTxFee: 100 * defaultTxFee, @@ -428,23 +425,14 @@ func buildGenesisTest(ctx *snow.Context) []byte { func shutdownEnvironment(env *environment) error { if env.isBootstrapped.Get() { - validatorIDs, err := validators.NodeIDs(env.config.Validators, constants.PrimaryNetworkID) - if err != nil { - return err - } + validatorIDs := env.config.Validators.GetValidatorIDs(constants.PrimaryNetworkID) if err := env.uptimes.StopTracking(validatorIDs, constants.PrimaryNetworkID); err != nil { return err } for subnetID := range env.config.TrackedSubnets { - validatorIDs, err := validators.NodeIDs(env.config.Validators, subnetID) - if errors.Is(err, validators.ErrMissingValidators) { - return nil - } - if err != nil { - return err - } + validatorIDs := env.config.Validators.GetValidatorIDs(subnetID) if err := env.uptimes.StopTracking(validatorIDs, subnetID); err != nil { return err diff --git a/vms/platformvm/txs/executor/proposal_tx_executor.go b/vms/platformvm/txs/executor/proposal_tx_executor.go index babb6bbdb921..f182a979c79c 100644 --- a/vms/platformvm/txs/executor/proposal_tx_executor.go +++ b/vms/platformvm/txs/executor/proposal_tx_executor.go @@ -107,6 +107,10 @@ func (*ProposalTxExecutor) AddPermissionlessDelegatorTx(*txs.AddPermissionlessDe return ErrWrongTxType } +func (*ProposalTxExecutor) TransferSubnetOwnershipTx(*txs.TransferSubnetOwnershipTx) error { + return ErrWrongTxType +} + func (e *ProposalTxExecutor) AddValidatorTx(tx *txs.AddValidatorTx) error { // 
AddValidatorTx is a proposal transaction until the Banff fork // activation. Following the activation, AddValidatorTxs must be issued into diff --git a/vms/platformvm/txs/executor/reward_validator_test.go b/vms/platformvm/txs/executor/reward_validator_test.go index 8ccab002cc91..5871b9eef531 100644 --- a/vms/platformvm/txs/executor/reward_validator_test.go +++ b/vms/platformvm/txs/executor/reward_validator_test.go @@ -287,10 +287,7 @@ func TestRewardDelegatorTxExecuteOnCommitPreDelegateeDeferral(t *testing.T) { require.NoError(env.state.Commit()) // test validator stake - vdrSet, ok := env.config.Validators.Get(constants.PrimaryNetworkID) - require.True(ok) - - stake := vdrSet.GetWeight(vdrNodeID) + stake := env.config.Validators.GetWeight(constants.PrimaryNetworkID, vdrNodeID) require.Equal(env.config.MinValidatorStake+env.config.MinDelegatorStake, stake) tx, err := env.txBuilder.NewRewardValidatorTx(delTx.ID()) @@ -342,7 +339,8 @@ func TestRewardDelegatorTxExecuteOnCommitPreDelegateeDeferral(t *testing.T) { require.Less(vdrReward, delReward, "the delegator's reward should be greater than the delegatee's because the delegatee's share is 25%") require.Equal(expectedReward, delReward+vdrReward, "expected total reward to be %d but is %d", expectedReward, delReward+vdrReward) - require.Equal(env.config.MinValidatorStake, vdrSet.GetWeight(vdrNodeID)) + stake = env.config.Validators.GetWeight(constants.PrimaryNetworkID, vdrNodeID) + require.Equal(env.config.MinValidatorStake, stake) } func TestRewardDelegatorTxExecuteOnCommitPostDelegateeDeferral(t *testing.T) { @@ -419,10 +417,7 @@ func TestRewardDelegatorTxExecuteOnCommitPostDelegateeDeferral(t *testing.T) { require.NoError(err) // test validator stake - vdrSet, ok := env.config.Validators.Get(constants.PrimaryNetworkID) - require.True(ok) - - stake := vdrSet.GetWeight(vdrNodeID) + stake := env.config.Validators.GetWeight(constants.PrimaryNetworkID, vdrNodeID) require.Equal(env.config.MinValidatorStake+env.config.MinDelegatorStake, stake) tx, err := env.txBuilder.NewRewardValidatorTx(delTx.ID()) diff --git a/vms/platformvm/txs/executor/staker_tx_verification.go b/vms/platformvm/txs/executor/staker_tx_verification.go index e48b5e80cf73..95908614c5e3 100644 --- a/vms/platformvm/txs/executor/staker_tx_verification.go +++ b/vms/platformvm/txs/executor/staker_tx_verification.go @@ -47,6 +47,7 @@ var ( ErrDuplicateValidator = errors.New("duplicate validator") ErrDelegateToPermissionedValidator = errors.New("delegation to permissioned validator") ErrWrongStakedAssetID = errors.New("incorrect staked assetID") + ErrDUpgradeNotActive = errors.New("attempting to use a D-upgrade feature prior to activation") ) // verifySubnetValidatorPrimaryNetworkRequirements verifies the primary @@ -729,3 +730,50 @@ func verifyAddPermissionlessDelegatorTx( return nil } + +// Returns an error if the given tx is invalid. +// The transaction is valid if: +// * [sTx]'s creds authorize it to spend the stated inputs. +// * [sTx]'s creds authorize it to transfer ownership of [tx.Subnet]. +// * The flow checker passes. +func verifyTransferSubnetOwnershipTx( + backend *Backend, + chainState state.Chain, + sTx *txs.Tx, + tx *txs.TransferSubnetOwnershipTx, +) error { + if !backend.Config.IsDActivated(chainState.GetTimestamp()) { + return ErrDUpgradeNotActive + } + + // Verify the tx is well-formed + if err := sTx.SyntacticVerify(backend.Ctx); err != nil { + return err + } + + if !backend.Bootstrapped.Get() { + // Not bootstrapped yet -- don't need to do full verification. 
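Each executor in this diff gains a TransferSubnetOwnershipTx method because tx execution is driven by a visitor: the tx dispatches to whichever visitor it is handed, and visitors that should never see this tx return ErrWrongTxType. A self-contained sketch of that pattern; the types here are simplified stand-ins, not the real platformvm ones:

package main

import "fmt"

type Visitor interface {
	TransferSubnetOwnershipTx(subnet string) error
}

type UnsignedTx interface {
	Visit(Visitor) error
}

type transferSubnetOwnershipTx struct {
	subnet string
}

// Visit dispatches the tx to the given visitor.
func (tx *transferSubnetOwnershipTx) Visit(v Visitor) error {
	return v.TransferSubnetOwnershipTx(tx.subnet)
}

// A proposal-style visitor rejects the tx, mirroring ErrWrongTxType above.
type proposalVisitor struct{}

func (proposalVisitor) TransferSubnetOwnershipTx(string) error {
	return fmt.Errorf("wrong tx type")
}

// A standard-style visitor accepts and executes it.
type standardVisitor struct{}

func (standardVisitor) TransferSubnetOwnershipTx(subnet string) error {
	fmt.Println("transferring ownership of", subnet)
	return nil
}

func main() {
	var tx UnsignedTx = &transferSubnetOwnershipTx{subnet: "my-subnet"}
	fmt.Println(tx.Visit(proposalVisitor{})) // wrong tx type
	fmt.Println(tx.Visit(standardVisitor{})) // <nil>
}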
+ return nil + } + + baseTxCreds, err := verifySubnetAuthorization(backend, chainState, sTx, tx.Subnet, tx.SubnetAuth) + if err != nil { + return err + } + + // Verify the flowcheck + if err := backend.FlowChecker.VerifySpend( + tx, + chainState, + tx.Ins, + tx.Outs, + baseTxCreds, + map[ids.ID]uint64{ + backend.Ctx.AVAXAssetID: backend.Config.TxFee, + }, + ); err != nil { + return fmt.Errorf("%w: %w", ErrFlowCheckFailed, err) + } + + return nil +} diff --git a/vms/platformvm/txs/executor/standard_tx_executor.go b/vms/platformvm/txs/executor/standard_tx_executor.go index 560f263ef10c..d8fab705515d 100644 --- a/vms/platformvm/txs/executor/standard_tx_executor.go +++ b/vms/platformvm/txs/executor/standard_tx_executor.go @@ -511,3 +511,27 @@ func (e *StandardTxExecutor) AddPermissionlessDelegatorTx(tx *txs.AddPermissionl return nil } + +// Verifies a [*txs.TransferSubnetOwnershipTx] and, if it passes, executes it on +// [e.State]. For verification rules, see [verifyTransferSubnetOwnershipTx]. +// This transaction will result in the ownership of [tx.Subnet] being transferred +// to [tx.Owner]. +func (e *StandardTxExecutor) TransferSubnetOwnershipTx(tx *txs.TransferSubnetOwnershipTx) error { + err := verifyTransferSubnetOwnershipTx( + e.Backend, + e.State, + e.Tx, + tx, + ) + if err != nil { + return err + } + + e.State.SetSubnetOwner(tx.Subnet, tx.Owner) + + txID := e.Tx.ID() + avax.Consume(e.State, tx.Ins) + avax.Produce(e.State, txID, tx.Outs) + + return nil +} diff --git a/vms/platformvm/txs/executor/tx_mempool_verifier.go b/vms/platformvm/txs/executor/tx_mempool_verifier.go index 2aed1bafbac2..83a17b615803 100644 --- a/vms/platformvm/txs/executor/tx_mempool_verifier.go +++ b/vms/platformvm/txs/executor/tx_mempool_verifier.go @@ -84,6 +84,10 @@ func (v *MempoolTxVerifier) AddPermissionlessDelegatorTx(tx *txs.AddPermissionle return v.standardTx(tx) } +func (v *MempoolTxVerifier) TransferSubnetOwnershipTx(tx *txs.TransferSubnetOwnershipTx) error { + return v.standardTx(tx) +} + func (v *MempoolTxVerifier) standardTx(tx txs.UnsignedTx) error { baseState, err := v.standardBaseState() if err != nil { diff --git a/vms/platformvm/txs/mempool/issuer.go b/vms/platformvm/txs/mempool/issuer.go index aa5e5c707a6c..e24afb5282da 100644 --- a/vms/platformvm/txs/mempool/issuer.go +++ b/vms/platformvm/txs/mempool/issuer.go @@ -74,6 +74,11 @@ func (i *issuer) TransformSubnetTx(*txs.TransformSubnetTx) error { return nil } +func (i *issuer) TransferSubnetOwnershipTx(*txs.TransferSubnetOwnershipTx) error { + i.m.addDecisionTx(i.tx) + return nil +} + func (i *issuer) AddPermissionlessValidatorTx(*txs.AddPermissionlessValidatorTx) error { i.m.addStakerTx(i.tx) return nil diff --git a/vms/platformvm/txs/mempool/remover.go b/vms/platformvm/txs/mempool/remover.go index fcdeca380679..e418cf46c342 100644 --- a/vms/platformvm/txs/mempool/remover.go +++ b/vms/platformvm/txs/mempool/remover.go @@ -57,6 +57,11 @@ func (r *remover) TransformSubnetTx(*txs.TransformSubnetTx) error { return nil } +func (r *remover) TransferSubnetOwnershipTx(*txs.TransferSubnetOwnershipTx) error { + r.m.removeDecisionTxs([]*txs.Tx{r.tx}) + return nil +} + func (r *remover) AddPermissionlessValidatorTx(*txs.AddPermissionlessValidatorTx) error { r.m.removeStakerTx(r.tx) return nil diff --git a/vms/platformvm/txs/transfer_subnet_ownership_tx.go b/vms/platformvm/txs/transfer_subnet_ownership_tx.go new file mode 100644 index 000000000000..78dbf28b48b4 --- /dev/null +++ b/vms/platformvm/txs/transfer_subnet_ownership_tx.go @@ -0,0 +1,65 @@ +// 
Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package txs + +import ( + "errors" + + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/snow" + "github.com/ava-labs/avalanchego/utils/constants" + "github.com/ava-labs/avalanchego/vms/components/verify" + "github.com/ava-labs/avalanchego/vms/platformvm/fx" +) + +var ( + _ UnsignedTx = (*TransferSubnetOwnershipTx)(nil) + + ErrTransferPermissionlessSubnet = errors.New("cannot transfer ownership of a permissionless subnet") +) + +type TransferSubnetOwnershipTx struct { + // Metadata, inputs and outputs + BaseTx `serialize:"true"` + // ID of the subnet this tx is modifying + Subnet ids.ID `serialize:"true" json:"subnetID"` + // Proves that the issuer has the right to remove the node from the subnet. + SubnetAuth verify.Verifiable `serialize:"true" json:"subnetAuthorization"` + // Who is now authorized to manage this subnet + Owner fx.Owner `serialize:"true" json:"newOwner"` +} + +// InitCtx sets the FxID fields in the inputs and outputs of this +// [TransferSubnetOwnershipTx]. Also sets the [ctx] to the given [vm.ctx] so +// that the addresses can be json marshalled into human readable format +func (tx *TransferSubnetOwnershipTx) InitCtx(ctx *snow.Context) { + tx.BaseTx.InitCtx(ctx) + tx.Owner.InitCtx(ctx) +} + +func (tx *TransferSubnetOwnershipTx) SyntacticVerify(ctx *snow.Context) error { + switch { + case tx == nil: + return ErrNilTx + case tx.SyntacticallyVerified: + // already passed syntactic verification + return nil + case tx.Subnet == constants.PrimaryNetworkID: + return ErrTransferPermissionlessSubnet + } + + if err := tx.BaseTx.SyntacticVerify(ctx); err != nil { + return err + } + if err := verify.All(tx.SubnetAuth, tx.Owner); err != nil { + return err + } + + tx.SyntacticallyVerified = true + return nil +} + +func (tx *TransferSubnetOwnershipTx) Visit(visitor Visitor) error { + return visitor.TransferSubnetOwnershipTx(tx) +} diff --git a/vms/platformvm/txs/transfer_subnet_ownership_tx_test.go b/vms/platformvm/txs/transfer_subnet_ownership_tx_test.go new file mode 100644 index 000000000000..7e6f5835a283 --- /dev/null +++ b/vms/platformvm/txs/transfer_subnet_ownership_tx_test.go @@ -0,0 +1,669 @@ +// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
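With the tx type above in place, callers would normally go through the builder method added earlier in this diff rather than constructing the struct by hand. A hedged usage sketch: only the NewTransferSubnetOwnershipTx signature comes from the change; the narrow interface and helper are assumptions for illustration.

package example

import (
	"github.com/ava-labs/avalanchego/ids"
	"github.com/ava-labs/avalanchego/utils/crypto/secp256k1"
	"github.com/ava-labs/avalanchego/vms/platformvm/txs"
)

// subnetOwnershipBuilder is just the slice of the tx builder this sketch needs;
// the concrete builder lives in vms/platformvm/txs/builder.
type subnetOwnershipBuilder interface {
	NewTransferSubnetOwnershipTx(
		subnetID ids.ID,
		threshold uint32,
		ownerAddrs []ids.ShortID,
		keys []*secp256k1.PrivateKey,
		changeAddr ids.ShortID,
	) (*txs.Tx, error)
}

// transferOwnership hands [subnetID] to a 1-of-1 owner at [newOwner]. [keys]
// must satisfy the subnet's current control-key threshold and pay the tx fee;
// change is returned to the first key's address.
func transferOwnership(
	b subnetOwnershipBuilder,
	subnetID ids.ID,
	newOwner ids.ShortID,
	keys []*secp256k1.PrivateKey,
) (*txs.Tx, error) {
	return b.NewTransferSubnetOwnershipTx(subnetID, 1, []ids.ShortID{newOwner}, keys, keys[0].Address())
}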
+ +package txs + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" + + "go.uber.org/mock/gomock" + + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/snow" + "github.com/ava-labs/avalanchego/utils" + "github.com/ava-labs/avalanchego/utils/constants" + "github.com/ava-labs/avalanchego/utils/units" + "github.com/ava-labs/avalanchego/vms/components/avax" + "github.com/ava-labs/avalanchego/vms/components/verify" + "github.com/ava-labs/avalanchego/vms/platformvm/fx" + "github.com/ava-labs/avalanchego/vms/platformvm/stakeable" + "github.com/ava-labs/avalanchego/vms/secp256k1fx" + "github.com/ava-labs/avalanchego/vms/types" +) + +func TestTransferSubnetOwnershipTxSerialization(t *testing.T) { + require := require.New(t) + + addr := ids.ShortID{ + 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, + 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, + 0x44, 0x55, 0x66, 0x77, + } + + avaxAssetID, err := ids.FromString("FvwEAhmxKfeiG8SnEvq42hc6whRyY3EFYAvebMqDNDGCgxN5Z") + require.NoError(err) + + customAssetID := ids.ID{ + 0x99, 0x77, 0x55, 0x77, 0x11, 0x33, 0x55, 0x31, + 0x99, 0x77, 0x55, 0x77, 0x11, 0x33, 0x55, 0x31, + 0x99, 0x77, 0x55, 0x77, 0x11, 0x33, 0x55, 0x31, + 0x99, 0x77, 0x55, 0x77, 0x11, 0x33, 0x55, 0x31, + } + + txID := ids.ID{ + 0xff, 0xee, 0xdd, 0xcc, 0xbb, 0xaa, 0x99, 0x88, + 0xff, 0xee, 0xdd, 0xcc, 0xbb, 0xaa, 0x99, 0x88, + 0xff, 0xee, 0xdd, 0xcc, 0xbb, 0xaa, 0x99, 0x88, + 0xff, 0xee, 0xdd, 0xcc, 0xbb, 0xaa, 0x99, 0x88, + } + subnetID := ids.ID{ + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, + 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, + 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, + 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, + } + + simpleTransferSubnetOwnershipTx := &TransferSubnetOwnershipTx{ + BaseTx: BaseTx{ + BaseTx: avax.BaseTx{ + NetworkID: constants.MainnetID, + BlockchainID: constants.PlatformChainID, + Outs: []*avax.TransferableOutput{}, + Ins: []*avax.TransferableInput{ + { + UTXOID: avax.UTXOID{ + TxID: txID, + OutputIndex: 1, + }, + Asset: avax.Asset{ + ID: avaxAssetID, + }, + In: &secp256k1fx.TransferInput{ + Amt: units.MilliAvax, + Input: secp256k1fx.Input{ + SigIndices: []uint32{5}, + }, + }, + }, + }, + Memo: types.JSONByteSlice{}, + }, + }, + Subnet: subnetID, + SubnetAuth: &secp256k1fx.Input{ + SigIndices: []uint32{3}, + }, + Owner: &secp256k1fx.OutputOwners{ + Locktime: 0, + Threshold: 1, + Addrs: []ids.ShortID{ + addr, + }, + }, + } + require.NoError(simpleTransferSubnetOwnershipTx.SyntacticVerify(&snow.Context{ + NetworkID: 1, + ChainID: constants.PlatformChainID, + AVAXAssetID: avaxAssetID, + })) + + expectedUnsignedSimpleTransferSubnetOwnershipTxBytes := []byte{ + // Codec version + 0x00, 0x00, + // RemoveSubnetValidatorTx Type ID + 0x00, 0x00, 0x00, 0x21, + // Mainnet network ID + 0x00, 0x00, 0x00, 0x01, + // P-chain blockchain ID + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // Number of outputs + 0x00, 0x00, 0x00, 0x00, + // Number of inputs + 0x00, 0x00, 0x00, 0x01, + // Inputs[0] + // TxID + 0xff, 0xee, 0xdd, 0xcc, 0xbb, 0xaa, 0x99, 0x88, + 0xff, 0xee, 0xdd, 0xcc, 0xbb, 0xaa, 0x99, 0x88, + 0xff, 0xee, 0xdd, 0xcc, 0xbb, 0xaa, 0x99, 0x88, + 0xff, 0xee, 0xdd, 0xcc, 0xbb, 0xaa, 0x99, 0x88, + // Tx output index + 0x00, 0x00, 0x00, 0x01, + // Mainnet AVAX assetID + 0x21, 0xe6, 0x73, 0x17, 0xcb, 0xc4, 0xbe, 0x2a, + 0xeb, 0x00, 0x67, 0x7a, 
0xd6, 0x46, 0x27, 0x78, + 0xa8, 0xf5, 0x22, 0x74, 0xb9, 0xd6, 0x05, 0xdf, + 0x25, 0x91, 0xb2, 0x30, 0x27, 0xa8, 0x7d, 0xff, + // secp256k1fx transfer input type ID + 0x00, 0x00, 0x00, 0x05, + // input amount = 1 MilliAvax + 0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, 0x42, 0x40, + // number of signatures needed in input + 0x00, 0x00, 0x00, 0x01, + // index of signer + 0x00, 0x00, 0x00, 0x05, + // length of memo + 0x00, 0x00, 0x00, 0x00, + // subnetID to modify + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, + 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, + 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, + 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, + // secp256k1fx authorization type ID + 0x00, 0x00, 0x00, 0x0a, + // number of signatures needed in authorization + 0x00, 0x00, 0x00, 0x01, + // index of signer + 0x00, 0x00, 0x00, 0x03, + // secp256k1fx output owners type ID + 0x00, 0x00, 0x00, 0x0b, + // locktime + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // threshold + 0x00, 0x00, 0x00, 0x01, + // number of addrs + 0x00, 0x00, 0x00, 0x01, + // Addrs[0] + 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, + 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, + 0x44, 0x55, 0x66, 0x77, + } + var unsignedSimpleTransferSubnetOwnershipTx UnsignedTx = simpleTransferSubnetOwnershipTx + unsignedSimpleTransferSubnetOwnershipTxBytes, err := Codec.Marshal(Version, &unsignedSimpleTransferSubnetOwnershipTx) + require.NoError(err) + require.Equal(expectedUnsignedSimpleTransferSubnetOwnershipTxBytes, unsignedSimpleTransferSubnetOwnershipTxBytes) + + complexTransferSubnetOwnershipTx := &TransferSubnetOwnershipTx{ + BaseTx: BaseTx{ + BaseTx: avax.BaseTx{ + NetworkID: constants.MainnetID, + BlockchainID: constants.PlatformChainID, + Outs: []*avax.TransferableOutput{ + { + Asset: avax.Asset{ + ID: avaxAssetID, + }, + Out: &stakeable.LockOut{ + Locktime: 87654321, + TransferableOut: &secp256k1fx.TransferOutput{ + Amt: 1, + OutputOwners: secp256k1fx.OutputOwners{ + Locktime: 12345678, + Threshold: 0, + Addrs: []ids.ShortID{}, + }, + }, + }, + }, + { + Asset: avax.Asset{ + ID: customAssetID, + }, + Out: &stakeable.LockOut{ + Locktime: 876543210, + TransferableOut: &secp256k1fx.TransferOutput{ + Amt: 0xffffffffffffffff, + OutputOwners: secp256k1fx.OutputOwners{ + Locktime: 0, + Threshold: 1, + Addrs: []ids.ShortID{ + addr, + }, + }, + }, + }, + }, + }, + Ins: []*avax.TransferableInput{ + { + UTXOID: avax.UTXOID{ + TxID: txID, + OutputIndex: 1, + }, + Asset: avax.Asset{ + ID: avaxAssetID, + }, + In: &secp256k1fx.TransferInput{ + Amt: units.Avax, + Input: secp256k1fx.Input{ + SigIndices: []uint32{2, 5}, + }, + }, + }, + { + UTXOID: avax.UTXOID{ + TxID: txID, + OutputIndex: 2, + }, + Asset: avax.Asset{ + ID: customAssetID, + }, + In: &stakeable.LockIn{ + Locktime: 876543210, + TransferableIn: &secp256k1fx.TransferInput{ + Amt: 0xefffffffffffffff, + Input: secp256k1fx.Input{ + SigIndices: []uint32{0}, + }, + }, + }, + }, + { + UTXOID: avax.UTXOID{ + TxID: txID, + OutputIndex: 3, + }, + Asset: avax.Asset{ + ID: customAssetID, + }, + In: &secp256k1fx.TransferInput{ + Amt: 0x1000000000000000, + Input: secp256k1fx.Input{ + SigIndices: []uint32{}, + }, + }, + }, + }, + Memo: types.JSONByteSlice("😅\nwell that's\x01\x23\x45!"), + }, + }, + Subnet: subnetID, + SubnetAuth: &secp256k1fx.Input{ + SigIndices: []uint32{}, + }, + Owner: &secp256k1fx.OutputOwners{ + Locktime: 876543210, + Threshold: 1, + Addrs: []ids.ShortID{ + addr, + }, + }, + } + avax.SortTransferableOutputs(complexTransferSubnetOwnershipTx.Outs, Codec) + 
utils.Sort(complexTransferSubnetOwnershipTx.Ins) + require.NoError(simpleTransferSubnetOwnershipTx.SyntacticVerify(&snow.Context{ + NetworkID: 1, + ChainID: constants.PlatformChainID, + AVAXAssetID: avaxAssetID, + })) + + expectedUnsignedComplexTransferSubnetOwnershipTxBytes := []byte{ + // Codec version + 0x00, 0x00, + // TransferSubnetOwnershipTx Type ID + 0x00, 0x00, 0x00, 0x21, + // Mainnet network ID + 0x00, 0x00, 0x00, 0x01, + // P-chain blockchain ID + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // Number of outputs + 0x00, 0x00, 0x00, 0x02, + // Outputs[0] + // Mainnet AVAX assetID + 0x21, 0xe6, 0x73, 0x17, 0xcb, 0xc4, 0xbe, 0x2a, + 0xeb, 0x00, 0x67, 0x7a, 0xd6, 0x46, 0x27, 0x78, + 0xa8, 0xf5, 0x22, 0x74, 0xb9, 0xd6, 0x05, 0xdf, + 0x25, 0x91, 0xb2, 0x30, 0x27, 0xa8, 0x7d, 0xff, + // Stakeable locked output type ID + 0x00, 0x00, 0x00, 0x16, + // Locktime + 0x00, 0x00, 0x00, 0x00, 0x05, 0x39, 0x7f, 0xb1, + // secp256k1fx transfer output type ID + 0x00, 0x00, 0x00, 0x07, + // amount + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + // secp256k1fx output locktime + 0x00, 0x00, 0x00, 0x00, 0x00, 0xbc, 0x61, 0x4e, + // threshold + 0x00, 0x00, 0x00, 0x00, + // number of addresses + 0x00, 0x00, 0x00, 0x00, + // Outputs[1] + // custom asset ID + 0x99, 0x77, 0x55, 0x77, 0x11, 0x33, 0x55, 0x31, + 0x99, 0x77, 0x55, 0x77, 0x11, 0x33, 0x55, 0x31, + 0x99, 0x77, 0x55, 0x77, 0x11, 0x33, 0x55, 0x31, + 0x99, 0x77, 0x55, 0x77, 0x11, 0x33, 0x55, 0x31, + // Stakeable locked output type ID + 0x00, 0x00, 0x00, 0x16, + // Locktime + 0x00, 0x00, 0x00, 0x00, 0x34, 0x3e, 0xfc, 0xea, + // secp256k1fx transfer output type ID + 0x00, 0x00, 0x00, 0x07, + // amount + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + // secp256k1fx output locktime + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // threshold + 0x00, 0x00, 0x00, 0x01, + // number of addresses + 0x00, 0x00, 0x00, 0x01, + // address[0] + 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, + 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, + 0x44, 0x55, 0x66, 0x77, + // number of inputs + 0x00, 0x00, 0x00, 0x03, + // inputs[0] + // TxID + 0xff, 0xee, 0xdd, 0xcc, 0xbb, 0xaa, 0x99, 0x88, + 0xff, 0xee, 0xdd, 0xcc, 0xbb, 0xaa, 0x99, 0x88, + 0xff, 0xee, 0xdd, 0xcc, 0xbb, 0xaa, 0x99, 0x88, + 0xff, 0xee, 0xdd, 0xcc, 0xbb, 0xaa, 0x99, 0x88, + // Tx output index + 0x00, 0x00, 0x00, 0x01, + // Mainnet AVAX assetID + 0x21, 0xe6, 0x73, 0x17, 0xcb, 0xc4, 0xbe, 0x2a, + 0xeb, 0x00, 0x67, 0x7a, 0xd6, 0x46, 0x27, 0x78, + 0xa8, 0xf5, 0x22, 0x74, 0xb9, 0xd6, 0x05, 0xdf, + 0x25, 0x91, 0xb2, 0x30, 0x27, 0xa8, 0x7d, 0xff, + // secp256k1fx transfer input type ID + 0x00, 0x00, 0x00, 0x05, + // input amount = 1 Avax + 0x00, 0x00, 0x00, 0x00, 0x3b, 0x9a, 0xca, 0x00, + // number of signatures needed in input + 0x00, 0x00, 0x00, 0x02, + // index of first signer + 0x00, 0x00, 0x00, 0x02, + // index of second signer + 0x00, 0x00, 0x00, 0x05, + // inputs[1] + // TxID + 0xff, 0xee, 0xdd, 0xcc, 0xbb, 0xaa, 0x99, 0x88, + 0xff, 0xee, 0xdd, 0xcc, 0xbb, 0xaa, 0x99, 0x88, + 0xff, 0xee, 0xdd, 0xcc, 0xbb, 0xaa, 0x99, 0x88, + 0xff, 0xee, 0xdd, 0xcc, 0xbb, 0xaa, 0x99, 0x88, + // Tx output index + 0x00, 0x00, 0x00, 0x02, + // Custom asset ID + 0x99, 0x77, 0x55, 0x77, 0x11, 0x33, 0x55, 0x31, + 0x99, 0x77, 0x55, 0x77, 0x11, 0x33, 0x55, 0x31, + 0x99, 0x77, 0x55, 0x77, 0x11, 0x33, 0x55, 0x31, + 0x99, 0x77, 0x55, 0x77, 0x11, 0x33, 0x55, 0x31, + 
// Stakeable locked input type ID + 0x00, 0x00, 0x00, 0x15, + // Locktime + 0x00, 0x00, 0x00, 0x00, 0x34, 0x3e, 0xfc, 0xea, + // secp256k1fx transfer input type ID + 0x00, 0x00, 0x00, 0x05, + // input amount + 0xef, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + // number of signatures needed in input + 0x00, 0x00, 0x00, 0x01, + // index of signer + 0x00, 0x00, 0x00, 0x00, + // inputs[2] + // TxID + 0xff, 0xee, 0xdd, 0xcc, 0xbb, 0xaa, 0x99, 0x88, + 0xff, 0xee, 0xdd, 0xcc, 0xbb, 0xaa, 0x99, 0x88, + 0xff, 0xee, 0xdd, 0xcc, 0xbb, 0xaa, 0x99, 0x88, + 0xff, 0xee, 0xdd, 0xcc, 0xbb, 0xaa, 0x99, 0x88, + // Tx output index + 0x00, 0x00, 0x00, 0x03, + // custom asset ID + 0x99, 0x77, 0x55, 0x77, 0x11, 0x33, 0x55, 0x31, + 0x99, 0x77, 0x55, 0x77, 0x11, 0x33, 0x55, 0x31, + 0x99, 0x77, 0x55, 0x77, 0x11, 0x33, 0x55, 0x31, + 0x99, 0x77, 0x55, 0x77, 0x11, 0x33, 0x55, 0x31, + // secp256k1fx transfer input type ID + 0x00, 0x00, 0x00, 0x05, + // input amount + 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // number of signatures needed in input + 0x00, 0x00, 0x00, 0x00, + // length of memo + 0x00, 0x00, 0x00, 0x14, + // memo + 0xf0, 0x9f, 0x98, 0x85, 0x0a, 0x77, 0x65, 0x6c, + 0x6c, 0x20, 0x74, 0x68, 0x61, 0x74, 0x27, 0x73, + 0x01, 0x23, 0x45, 0x21, + // subnetID to modify + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, + 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, + 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, + 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, + // secp256k1fx authorization type ID + 0x00, 0x00, 0x00, 0x0a, + // number of signatures needed in authorization + 0x00, 0x00, 0x00, 0x00, + // secp256k1fx output owners type ID + 0x00, 0x00, 0x00, 0x0b, + // locktime + 0x00, 0x00, 0x00, 0x00, 0x34, 0x3e, 0xfc, 0xea, + // threshold + 0x00, 0x00, 0x00, 0x01, + // number of addrs + 0x00, 0x00, 0x00, 0x01, + // Addrs[0] + 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, + 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, + 0x44, 0x55, 0x66, 0x77, + } + var unsignedComplexTransferSubnetOwnershipTx UnsignedTx = complexTransferSubnetOwnershipTx + unsignedComplexTransferSubnetOwnershipTxBytes, err := Codec.Marshal(Version, &unsignedComplexTransferSubnetOwnershipTx) + require.NoError(err) + require.Equal(expectedUnsignedComplexTransferSubnetOwnershipTxBytes, unsignedComplexTransferSubnetOwnershipTxBytes) + + aliaser := ids.NewAliaser() + require.NoError(aliaser.Alias(constants.PlatformChainID, "P")) + + unsignedComplexTransferSubnetOwnershipTx.InitCtx(&snow.Context{ + NetworkID: 1, + ChainID: constants.PlatformChainID, + AVAXAssetID: avaxAssetID, + BCLookup: aliaser, + }) + + unsignedComplexTransferSubnetOwnershipTxJSONBytes, err := json.MarshalIndent(unsignedComplexTransferSubnetOwnershipTx, "", "\t") + require.NoError(err) + require.Equal(`{ + "networkID": 1, + "blockchainID": "11111111111111111111111111111111LpoYY", + "outputs": [ + { + "assetID": "FvwEAhmxKfeiG8SnEvq42hc6whRyY3EFYAvebMqDNDGCgxN5Z", + "fxID": "spdxUxVJQbX85MGxMHbKw1sHxMnSqJ3QBzDyDYEP3h6TLuxqQ", + "output": { + "locktime": 87654321, + "output": { + "addresses": [], + "amount": 1, + "locktime": 12345678, + "threshold": 0 + } + } + }, + { + "assetID": "2Ab62uWwJw1T6VvmKD36ufsiuGZuX1pGykXAvPX1LtjTRHxwcc", + "fxID": "spdxUxVJQbX85MGxMHbKw1sHxMnSqJ3QBzDyDYEP3h6TLuxqQ", + "output": { + "locktime": 876543210, + "output": { + "addresses": [ + "P-avax1g32kvaugnx4tk3z4vemc3xd2hdz92enh972wxr" + ], + "amount": 18446744073709551615, + "locktime": 0, + "threshold": 1 + } + } + } + ], + "inputs": [ + { + "txID": 
"2wiU5PnFTjTmoAXGZutHAsPF36qGGyLHYHj9G1Aucfmb3JFFGN", + "outputIndex": 1, + "assetID": "FvwEAhmxKfeiG8SnEvq42hc6whRyY3EFYAvebMqDNDGCgxN5Z", + "fxID": "spdxUxVJQbX85MGxMHbKw1sHxMnSqJ3QBzDyDYEP3h6TLuxqQ", + "input": { + "amount": 1000000000, + "signatureIndices": [ + 2, + 5 + ] + } + }, + { + "txID": "2wiU5PnFTjTmoAXGZutHAsPF36qGGyLHYHj9G1Aucfmb3JFFGN", + "outputIndex": 2, + "assetID": "2Ab62uWwJw1T6VvmKD36ufsiuGZuX1pGykXAvPX1LtjTRHxwcc", + "fxID": "spdxUxVJQbX85MGxMHbKw1sHxMnSqJ3QBzDyDYEP3h6TLuxqQ", + "input": { + "locktime": 876543210, + "input": { + "amount": 17293822569102704639, + "signatureIndices": [ + 0 + ] + } + } + }, + { + "txID": "2wiU5PnFTjTmoAXGZutHAsPF36qGGyLHYHj9G1Aucfmb3JFFGN", + "outputIndex": 3, + "assetID": "2Ab62uWwJw1T6VvmKD36ufsiuGZuX1pGykXAvPX1LtjTRHxwcc", + "fxID": "spdxUxVJQbX85MGxMHbKw1sHxMnSqJ3QBzDyDYEP3h6TLuxqQ", + "input": { + "amount": 1152921504606846976, + "signatureIndices": [] + } + } + ], + "memo": "0xf09f98850a77656c6c2074686174277301234521", + "subnetID": "SkB92YpWm4UpburLz9tEKZw2i67H3FF6YkjaU4BkFUDTG9Xm", + "subnetAuthorization": { + "signatureIndices": [] + }, + "newOwner": { + "addresses": [ + "P-avax1g32kvaugnx4tk3z4vemc3xd2hdz92enh972wxr" + ], + "locktime": 876543210, + "threshold": 1 + } +}`, string(unsignedComplexTransferSubnetOwnershipTxJSONBytes)) +} + +func TestTransferSubnetOwnershipTxSyntacticVerify(t *testing.T) { + type test struct { + name string + txFunc func(*gomock.Controller) *TransferSubnetOwnershipTx + expectedErr error + } + + var ( + networkID = uint32(1337) + chainID = ids.GenerateTestID() + ) + + ctx := &snow.Context{ + ChainID: chainID, + NetworkID: networkID, + } + + // A BaseTx that already passed syntactic verification. + verifiedBaseTx := BaseTx{ + SyntacticallyVerified: true, + } + // Sanity check. + require.NoError(t, verifiedBaseTx.SyntacticVerify(ctx)) + + // A BaseTx that passes syntactic verification. + validBaseTx := BaseTx{ + BaseTx: avax.BaseTx{ + NetworkID: networkID, + BlockchainID: chainID, + }, + } + // Sanity check. + require.NoError(t, validBaseTx.SyntacticVerify(ctx)) + // Make sure we're not caching the verification result. + require.False(t, validBaseTx.SyntacticallyVerified) + + // A BaseTx that fails syntactic verification. + invalidBaseTx := BaseTx{} + + tests := []test{ + { + name: "nil tx", + txFunc: func(*gomock.Controller) *TransferSubnetOwnershipTx { + return nil + }, + expectedErr: ErrNilTx, + }, + { + name: "already verified", + txFunc: func(*gomock.Controller) *TransferSubnetOwnershipTx { + return &TransferSubnetOwnershipTx{BaseTx: verifiedBaseTx} + }, + expectedErr: nil, + }, + { + name: "invalid BaseTx", + txFunc: func(*gomock.Controller) *TransferSubnetOwnershipTx { + return &TransferSubnetOwnershipTx{ + // Set subnetID so we don't error on that check. + Subnet: ids.GenerateTestID(), + BaseTx: invalidBaseTx, + } + }, + expectedErr: avax.ErrWrongNetworkID, + }, + { + name: "invalid subnetID", + txFunc: func(*gomock.Controller) *TransferSubnetOwnershipTx { + return &TransferSubnetOwnershipTx{ + BaseTx: validBaseTx, + Subnet: constants.PrimaryNetworkID, + } + }, + expectedErr: ErrTransferPermissionlessSubnet, + }, + { + name: "invalid subnetAuth", + txFunc: func(ctrl *gomock.Controller) *TransferSubnetOwnershipTx { + // This SubnetAuth fails verification. + invalidSubnetAuth := verify.NewMockVerifiable(ctrl) + invalidSubnetAuth.EXPECT().Verify().Return(errInvalidSubnetAuth) + return &TransferSubnetOwnershipTx{ + // Set subnetID so we don't error on that check. 
+ Subnet: ids.GenerateTestID(), + BaseTx: validBaseTx, + SubnetAuth: invalidSubnetAuth, + } + }, + expectedErr: errInvalidSubnetAuth, + }, + { + name: "passes verification", + txFunc: func(ctrl *gomock.Controller) *TransferSubnetOwnershipTx { + // This SubnetAuth passes verification. + validSubnetAuth := verify.NewMockVerifiable(ctrl) + validSubnetAuth.EXPECT().Verify().Return(nil) + mockOwner := fx.NewMockOwner(ctrl) + mockOwner.EXPECT().Verify().Return(nil) + return &TransferSubnetOwnershipTx{ + // Set subnetID so we don't error on that check. + Subnet: ids.GenerateTestID(), + BaseTx: validBaseTx, + SubnetAuth: validSubnetAuth, + Owner: mockOwner, + } + }, + expectedErr: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + ctrl := gomock.NewController(t) + + tx := tt.txFunc(ctrl) + err := tx.SyntacticVerify(ctx) + require.ErrorIs(err, tt.expectedErr) + if tt.expectedErr != nil { + return + } + require.True(tx.SyntacticallyVerified) + }) + } +} diff --git a/vms/platformvm/txs/txheap/by_age.go b/vms/platformvm/txs/txheap/by_age.go index a445822dd6e6..be888c437a0f 100644 --- a/vms/platformvm/txs/txheap/by_age.go +++ b/vms/platformvm/txs/txheap/by_age.go @@ -3,18 +3,15 @@ package txheap -var _ Heap = (*byAge)(nil) - -type byAge struct { - txHeap -} +import ( + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/utils/heap" +) func NewByAge() Heap { - h := &byAge{} - h.initialize(h) - return h -} - -func (h *byAge) Less(i, j int) bool { - return h.txs[i].age < h.txs[j].age + return &txHeap{ + heap: heap.NewMap[ids.ID, heapTx](func(a, b heapTx) bool { + return a.age < b.age + }), + } } diff --git a/vms/platformvm/txs/txheap/by_end_time.go b/vms/platformvm/txs/txheap/by_end_time.go index 2b0cbd8d3817..ba144448919d 100644 --- a/vms/platformvm/txs/txheap/by_end_time.go +++ b/vms/platformvm/txs/txheap/by_end_time.go @@ -6,6 +6,8 @@ package txheap import ( "time" + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/utils/heap" "github.com/ava-labs/avalanchego/vms/platformvm/txs" ) @@ -16,15 +18,15 @@ type byEndTime struct { } func NewByEndTime() TimedHeap { - h := &byEndTime{} - h.initialize(h) - return h -} - -func (h *byEndTime) Less(i, j int) bool { - iTime := h.txs[i].tx.Unsigned.(txs.Staker).EndTime() - jTime := h.txs[j].tx.Unsigned.(txs.Staker).EndTime() - return iTime.Before(jTime) + return &byEndTime{ + txHeap: txHeap{ + heap: heap.NewMap[ids.ID, heapTx](func(a, b heapTx) bool { + aTime := a.tx.Unsigned.(txs.Staker).EndTime() + bTime := b.tx.Unsigned.(txs.Staker).EndTime() + return aTime.Before(bTime) + }), + }, + } } func (h *byEndTime) Timestamp() time.Time { diff --git a/vms/platformvm/txs/txheap/by_end_time_test.go b/vms/platformvm/txs/txheap/by_end_time_test.go index 33ddc3cc3d1a..8ea152d27e02 100644 --- a/vms/platformvm/txs/txheap/by_end_time_test.go +++ b/vms/platformvm/txs/txheap/by_end_time_test.go @@ -14,7 +14,7 @@ import ( "github.com/ava-labs/avalanchego/vms/secp256k1fx" ) -func TestByStopTime(t *testing.T) { +func TestByEndTime(t *testing.T) { require := require.New(t) txHeap := NewByEndTime() diff --git a/vms/platformvm/txs/txheap/by_start_time.go b/vms/platformvm/txs/txheap/by_start_time.go index 31834cf0603d..f19c28d76436 100644 --- a/vms/platformvm/txs/txheap/by_start_time.go +++ b/vms/platformvm/txs/txheap/by_start_time.go @@ -6,6 +6,8 @@ package txheap import ( "time" + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/utils/heap" 
"github.com/ava-labs/avalanchego/vms/platformvm/txs" ) @@ -22,15 +24,15 @@ type byStartTime struct { } func NewByStartTime() TimedHeap { - h := &byStartTime{} - h.initialize(h) - return h -} - -func (h *byStartTime) Less(i, j int) bool { - iTime := h.txs[i].tx.Unsigned.(txs.Staker).StartTime() - jTime := h.txs[j].tx.Unsigned.(txs.Staker).StartTime() - return iTime.Before(jTime) + return &byStartTime{ + txHeap: txHeap{ + heap: heap.NewMap[ids.ID, heapTx](func(a, b heapTx) bool { + aTime := a.tx.Unsigned.(txs.Staker).StartTime() + bTime := b.tx.Unsigned.(txs.Staker).StartTime() + return aTime.Before(bTime) + }), + }, + } } func (h *byStartTime) Timestamp() time.Time { diff --git a/vms/platformvm/txs/txheap/heap.go b/vms/platformvm/txs/txheap/heap.go index 4b6ba68614cb..3727bb891d92 100644 --- a/vms/platformvm/txs/txheap/heap.go +++ b/vms/platformvm/txs/txheap/heap.go @@ -4,14 +4,11 @@ package txheap import ( - "container/heap" - "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/utils/heap" "github.com/ava-labs/avalanchego/vms/platformvm/txs" ) -var _ Heap = (*txHeap)(nil) - type Heap interface { Add(tx *txs.Tx) Get(txID ids.ID) *txs.Tx @@ -23,107 +20,57 @@ type Heap interface { } type heapTx struct { - tx *txs.Tx - index int - age int + tx *txs.Tx + age int } type txHeap struct { - self heap.Interface - - txIDToIndex map[ids.ID]int - txs []*heapTx - currentAge int -} - -func (h *txHeap) initialize(self heap.Interface) { - h.self = self - h.txIDToIndex = make(map[ids.ID]int) + heap heap.Map[ids.ID, heapTx] + currentAge int } func (h *txHeap) Add(tx *txs.Tx) { - heap.Push(h.self, tx) + txID := tx.ID() + if h.heap.Contains(txID) { + return + } + htx := heapTx{ + tx: tx, + age: h.currentAge, + } + h.currentAge++ + h.heap.Push(txID, htx) } func (h *txHeap) Get(txID ids.ID) *txs.Tx { - index, exists := h.txIDToIndex[txID] - if !exists { - return nil - } - return h.txs[index].tx + got, _ := h.heap.Get(txID) + return got.tx } func (h *txHeap) List() []*txs.Tx { - res := make([]*txs.Tx, 0, len(h.txs)) - for _, tx := range h.txs { + heapTxs := heap.MapValues(h.heap) + res := make([]*txs.Tx, 0, len(heapTxs)) + for _, tx := range heapTxs { res = append(res, tx.tx) } return res } func (h *txHeap) Remove(txID ids.ID) *txs.Tx { - index, exists := h.txIDToIndex[txID] - if !exists { - return nil - } - return heap.Remove(h.self, index).(*txs.Tx) + removed, _ := h.heap.Remove(txID) + return removed.tx } func (h *txHeap) Peek() *txs.Tx { - return h.txs[0].tx + _, peeked, _ := h.heap.Peek() + return peeked.tx } func (h *txHeap) RemoveTop() *txs.Tx { - return heap.Pop(h.self).(*txs.Tx) + _, popped, _ := h.heap.Pop() + return popped.tx } func (h *txHeap) Len() int { - return len(h.txs) -} - -func (h *txHeap) Swap(i, j int) { - // The follow "i"s and "j"s are intentionally swapped to perform the actual - // swap - iTx := h.txs[j] - jTx := h.txs[i] - - iTx.index = i - jTx.index = j - h.txs[i] = iTx - h.txs[j] = jTx - - iTxID := iTx.tx.ID() - jTxID := jTx.tx.ID() - h.txIDToIndex[iTxID] = i - h.txIDToIndex[jTxID] = j -} - -func (h *txHeap) Push(x interface{}) { - tx := x.(*txs.Tx) - - txID := tx.ID() - _, exists := h.txIDToIndex[txID] - if exists { - return - } - htx := &heapTx{ - tx: tx, - index: len(h.txs), - age: h.currentAge, - } - h.currentAge++ - h.txIDToIndex[txID] = htx.index - h.txs = append(h.txs, htx) -} - -func (h *txHeap) Pop() interface{} { - newLen := len(h.txs) - 1 - htx := h.txs[newLen] - h.txs[newLen] = nil - h.txs = h.txs[:newLen] - - tx := htx.tx - txID := tx.ID() - 
delete(h.txIDToIndex, txID) - return tx + return h.heap.Len() } diff --git a/vms/platformvm/txs/visitor.go b/vms/platformvm/txs/visitor.go index e0b4b3a8b09b..41f12d74d1d8 100644 --- a/vms/platformvm/txs/visitor.go +++ b/vms/platformvm/txs/visitor.go @@ -28,6 +28,7 @@ type Visitor interface { TransformSubnetTx(*TransformSubnetTx) error AddPermissionlessValidatorTx(*AddPermissionlessValidatorTx) error AddPermissionlessDelegatorTx(*AddPermissionlessDelegatorTx) error + TransferSubnetOwnershipTx(*TransferSubnetOwnershipTx) error CaminoVisitor } diff --git a/vms/platformvm/utxo/camino_helpers_test.go b/vms/platformvm/utxo/camino_helpers_test.go index e169785f869a..6801a939453b 100644 --- a/vms/platformvm/utxo/camino_helpers_test.go +++ b/vms/platformvm/utxo/camino_helpers_test.go @@ -57,13 +57,10 @@ var ( ) func defaultConfig() *config.Config { - vdrs := validators.NewManager() - primaryVdrs := validators.NewSet() - _ = vdrs.Add(constants.PrimaryNetworkID, primaryVdrs) return &config.Config{ Chains: chains.TestManager, UptimeLockedCalculator: uptime.NewLockedCalculator(), - Validators: vdrs, + Validators: validators.NewManager(), TxFee: defaultTxFee, CreateSubnetTxFee: 100 * defaultTxFee, CreateBlockchainTxFee: 100 * defaultTxFee, diff --git a/vms/platformvm/validator_set_property_test.go b/vms/platformvm/validator_set_property_test.go index bfcdea3c0683..d984a1a986b0 100644 --- a/vms/platformvm/validator_set_property_test.go +++ b/vms/platformvm/validator_set_property_test.go @@ -725,16 +725,12 @@ func TestTimestampListGenerator(t *testing.T) { // add a single validator at the end of times, // to make sure it won't pollute our tests func buildVM(t *testing.T) (*VM, ids.ID, error) { - vdrs := validators.NewManager() - primaryVdrs := validators.NewSet() - _ = vdrs.Add(constants.PrimaryNetworkID, primaryVdrs) - forkTime := defaultGenesisTime vm := &VM{Config: config.Config{ Chains: chains.TestManager, UptimeLockedCalculator: uptime.NewLockedCalculator(), SybilProtectionEnabled: true, - Validators: vdrs, + Validators: validators.NewManager(), TxFee: defaultTxFee, CreateSubnetTxFee: 100 * defaultTxFee, TransformSubnetTxFee: 100 * defaultTxFee, diff --git a/vms/platformvm/validators/manager.go b/vms/platformvm/validators/manager.go index 4cc7336fc479..fb7c314c90a7 100644 --- a/vms/platformvm/validators/manager.go +++ b/vms/platformvm/validators/manager.go @@ -5,12 +5,9 @@ package validators import ( "context" - "errors" "fmt" "time" - "go.uber.org/zap" - "github.com/ava-labs/avalanchego/cache" "github.com/ava-labs/avalanchego/database" "github.com/ava-labs/avalanchego/ids" @@ -33,11 +30,7 @@ const ( recentlyAcceptedWindowTTL = 2 * time.Minute ) -var ( - _ validators.State = (*manager)(nil) - - ErrMissingValidatorSet = errors.New("missing validator set") -) +var _ validators.State = (*manager)(nil) // Manager adds the ability to introduce newly accepted blocks IDs to the State // interface. @@ -55,9 +48,9 @@ type State interface { GetLastAccepted() ids.ID GetStatelessBlock(blockID ids.ID) (block.Block, error) - // ValidatorSet adds all the validators and delegators of [subnetID] into - // [vdrs]. - ValidatorSet(subnetID ids.ID, vdrs validators.Set) error + // ApplyCurrentValidators adds all the current validators and delegators of + // [subnetID] into [vdrs]. 
+ ApplyCurrentValidators(subnetID ids.ID, vdrs validators.Manager) error // ApplyValidatorWeightDiffs iterates from [startHeight] towards the genesis // block until it has applied all of the diffs up to and including @@ -291,17 +284,9 @@ func (m *manager) makePrimaryNetworkValidatorSet( func (m *manager) getCurrentPrimaryValidatorSet( ctx context.Context, ) (map[ids.NodeID]*validators.GetValidatorOutput, uint64, error) { - currentValidators, ok := m.cfg.Validators.Get(constants.PrimaryNetworkID) - if !ok { - // This should never happen - m.log.Error(ErrMissingValidatorSet.Error(), - zap.Stringer("subnetID", constants.PrimaryNetworkID), - ) - return nil, 0, ErrMissingValidatorSet - } - + primaryMap := m.cfg.Validators.GetMap(constants.PrimaryNetworkID) currentHeight, err := m.getCurrentHeight(ctx) - return currentValidators.Map(), currentHeight, err + return primaryMap, currentHeight, err } func (m *manager) makeSubnetValidatorSet( @@ -361,28 +346,25 @@ func (m *manager) getCurrentValidatorSets( ctx context.Context, subnetID ids.ID, ) (map[ids.NodeID]*validators.GetValidatorOutput, map[ids.NodeID]*validators.GetValidatorOutput, uint64, error) { - currentSubnetValidators, ok := m.cfg.Validators.Get(subnetID) - if !ok { - // TODO: Require that the current validator set for all subnets is - // included in the validator manager. - currentSubnetValidators = validators.NewSet() - err := m.state.ValidatorSet(subnetID, currentSubnetValidators) - if err != nil { + subnetManager := m.cfg.Validators + if subnetManager.Count(subnetID) == 0 { + // If this subnet isn't tracked, there will not be any registered + // validators. To calculate the current validators we need to first + // fetch them from state. We generate a new manager as we don't want to + // modify that long-lived reference. + // + // TODO: remove this once all subnets are included in the validator + // manager. 
+ subnetManager = validators.NewManager() + if err := m.state.ApplyCurrentValidators(subnetID, subnetManager); err != nil { return nil, nil, 0, err } } - currentPrimaryValidators, ok := m.cfg.Validators.Get(constants.PrimaryNetworkID) - if !ok { - // This should never happen - m.log.Error(ErrMissingValidatorSet.Error(), - zap.Stringer("subnetID", constants.PrimaryNetworkID), - ) - return nil, nil, 0, ErrMissingValidatorSet - } - + subnetMap := subnetManager.GetMap(subnetID) + primaryMap := m.cfg.Validators.GetMap(constants.PrimaryNetworkID) currentHeight, err := m.getCurrentHeight(ctx) - return currentSubnetValidators.Map(), currentPrimaryValidators.Map(), currentHeight, err + return subnetMap, primaryMap, currentHeight, err } func (m *manager) GetSubnetID(_ context.Context, chainID ids.ID) (ids.ID, error) { diff --git a/vms/platformvm/validators/manager_benchmark_test.go b/vms/platformvm/validators/manager_benchmark_test.go index e8202cffcc77..54d0e264e63e 100644 --- a/vms/platformvm/validators/manager_benchmark_test.go +++ b/vms/platformvm/validators/manager_benchmark_test.go @@ -102,7 +102,6 @@ func BenchmarkGetValidatorSet(b *testing.B) { require.NoError(err) vdrs := validators.NewManager() - vdrs.Add(constants.PrimaryNetworkID, validators.NewSet()) execConfig, err := config.GetExecutionConfig(nil) require.NoError(err) diff --git a/vms/platformvm/vm.go b/vms/platformvm/vm.go index c9e8e397afa2..761d4c4ef2ee 100644 --- a/vms/platformvm/vm.go +++ b/vms/platformvm/vm.go @@ -15,7 +15,6 @@ package platformvm import ( "context" - "errors" "fmt" "net/http" @@ -68,8 +67,6 @@ var ( _ secp256k1fx.VM = (*VM)(nil) _ validators.State = (*VM)(nil) _ validators.SubnetConnector = (*VM)(nil) - - errMissingValidatorSet = errors.New("missing validator set") ) type VM struct { @@ -165,7 +162,7 @@ func (vm *VM) Initialize( validatorManager := pvalidators.NewManager(chainCtx.Log, vm.Config, vm.state, vm.metrics, &vm.clock) vm.State = validatorManager vm.atomicUtxosManager = avax.NewAtomicUTXOManager(chainCtx.SharedMemory, txs.Codec) - camCfg, _ := vm.state.CaminoConfig() + camCfg, _ := vm.state.CaminoConfig() // should never error utxoHandler := utxo.NewCaminoHandler(vm.ctx, &vm.clock, vm.fx, camCfg != nil && camCfg.LockModeBondDeposit) vm.uptimeManager = uptime.NewManager(vm.state, &vm.clock) vm.UptimeLockedCalculator.SetCalculator(&vm.bootstrapped, &chainCtx.Lock, vm.uptimeManager) @@ -318,19 +315,15 @@ func (vm *VM) onNormalOperationsStarted() error { return err } - primaryVdrIDs, err := validators.NodeIDs(vm.Validators, constants.PrimaryNetworkID) - if err != nil { - return err - } + primaryVdrIDs := vm.Validators.GetValidatorIDs(constants.PrimaryNetworkID) + if err := vm.uptimeManager.StartTracking(primaryVdrIDs, constants.PrimaryNetworkID); err != nil { return err } for subnetID := range vm.TrackedSubnets { - vdrIDs, err := validators.NodeIDs(vm.Validators, subnetID) - if err != nil { - return err - } + vdrIDs := vm.Validators.GetValidatorIDs(subnetID) + if err := vm.uptimeManager.StartTracking(vdrIDs, subnetID); err != nil { return err } @@ -365,19 +358,13 @@ func (vm *VM) Shutdown(context.Context) error { vm.Builder.Shutdown() if vm.bootstrapped.Get() { - primaryVdrIDs, err := validators.NodeIDs(vm.Validators, constants.PrimaryNetworkID) - if err != nil { - return err - } + primaryVdrIDs := vm.Validators.GetValidatorIDs(constants.PrimaryNetworkID) if err := vm.uptimeManager.StopTracking(primaryVdrIDs, constants.PrimaryNetworkID); err != nil { return err } for subnetID := range vm.TrackedSubnets { - 
vdrIDs, err := validators.NodeIDs(vm.Validators, subnetID) - if err != nil { - return err - } + vdrIDs := vm.Validators.GetValidatorIDs(subnetID) if err := vm.uptimeManager.StopTracking(vdrIDs, subnetID); err != nil { return err } diff --git a/vms/platformvm/vm_regression_test.go b/vms/platformvm/vm_regression_test.go index 1d88b993f831..9427f95ffb41 100644 --- a/vms/platformvm/vm_regression_test.go +++ b/vms/platformvm/vm_regression_test.go @@ -346,12 +346,9 @@ func TestUnverifiedParentPanicRegression(t *testing.T) { baseDBManager := manager.NewMemDB(version.Semantic1_0_0) atomicDB := prefixdb.New([]byte{1}, baseDBManager.Current().Database) - vdrs := validators.NewManager() - primaryVdrs := validators.NewSet() - _ = vdrs.Add(constants.PrimaryNetworkID, primaryVdrs) vm := &VM{Config: config.Config{ Chains: chains.TestManager, - Validators: vdrs, + Validators: validators.NewManager(), UptimeLockedCalculator: uptime.NewLockedCalculator(), MinStakeDuration: defaultMinStakingDuration, MaxStakeDuration: defaultMaxStakingDuration, @@ -653,7 +650,6 @@ func TestRejectedStateRegressionInvalidValidatorTimestamp(t *testing.T) { // Force a reload of the state from the database. vm.Config.Validators = validators.NewManager() - vm.Config.Validators.Add(constants.PrimaryNetworkID, validators.NewSet()) execCfg, _ := config.GetExecutionConfig(nil) newState, err := state.New( vm.dbManager.Current().Database, @@ -963,7 +959,6 @@ func TestRejectedStateRegressionInvalidValidatorReward(t *testing.T) { // Force a reload of the state from the database. vm.Config.Validators = validators.NewManager() - vm.Config.Validators.Add(constants.PrimaryNetworkID, validators.NewSet()) execCfg, _ := config.GetExecutionConfig(nil) newState, err := state.New( vm.dbManager.Current().Database, @@ -1404,10 +1399,7 @@ func TestRemovePermissionedValidatorDuringPendingToCurrentTransitionTracked(t *t require.NoError(vm.SetPreference(context.Background(), vm.manager.LastAccepted())) vm.TrackedSubnets.Add(createSubnetTx.ID()) - subnetValidators := validators.NewSet() - require.NoError(vm.state.ValidatorSet(createSubnetTx.ID(), subnetValidators)) - - require.True(vm.Validators.Add(createSubnetTx.ID(), subnetValidators)) + require.NoError(vm.state.ApplyCurrentValidators(createSubnetTx.ID(), vm.Validators)) addSubnetValidatorTx, err := vm.txBuilder.NewAddSubnetValidatorTx( defaultMaxValidatorStake, diff --git a/vms/platformvm/vm_test.go b/vms/platformvm/vm_test.go index 4b1a8ef01dfc..2db4146d05df 100644 --- a/vms/platformvm/vm_test.go +++ b/vms/platformvm/vm_test.go @@ -298,14 +298,11 @@ func BuildGenesisTestWithArgs(t *testing.T, args *api.BuildGenesisArgs) (*api.Bu func defaultVM(t *testing.T) (*VM, database.Database, *mutableSharedMemory) { require := require.New(t) - vdrs := validators.NewManager() - primaryVdrs := validators.NewSet() - _ = vdrs.Add(constants.PrimaryNetworkID, primaryVdrs) vm := &VM{Config: config.Config{ Chains: chains.TestManager, UptimeLockedCalculator: uptime.NewLockedCalculator(), SybilProtectionEnabled: true, - Validators: vdrs, + Validators: validators.NewManager(), TxFee: defaultTxFee, CreateSubnetTxFee: 100 * defaultTxFee, TransformSubnetTxFee: 100 * defaultTxFee, @@ -424,13 +421,12 @@ func TestGenesis(t *testing.T) { } // Ensure current validator set of primary network is correct - vdrSet, ok := vm.Validators.Get(constants.PrimaryNetworkID) - require.True(ok) - require.Len(genesisState.Validators, vdrSet.Len()) + require.Len(genesisState.Validators, vm.Validators.Count(constants.PrimaryNetworkID)) for _, key 
:= range keys { nodeID := ids.NodeID(key.PublicKey().Address()) - require.True(vdrSet.Contains(nodeID)) + _, ok := vm.Validators.GetValidator(constants.PrimaryNetworkID, nodeID) + require.True(ok) } // Ensure the new subnet we created exists @@ -1182,12 +1178,9 @@ func TestRestartFullyAccepted(t *testing.T) { db := manager.NewMemDB(version.Semantic1_0_0) firstDB := db.NewPrefixDBManager([]byte{}) - firstVdrs := validators.NewManager() - firstPrimaryVdrs := validators.NewSet() - _ = firstVdrs.Add(constants.PrimaryNetworkID, firstPrimaryVdrs) firstVM := &VM{Config: config.Config{ Chains: chains.TestManager, - Validators: firstVdrs, + Validators: validators.NewManager(), UptimeLockedCalculator: uptime.NewLockedCalculator(), MinStakeDuration: defaultMinStakingDuration, MaxStakeDuration: defaultMaxStakingDuration, @@ -1270,12 +1263,9 @@ func TestRestartFullyAccepted(t *testing.T) { require.NoError(firstVM.Shutdown(context.Background())) firstCtx.Lock.Unlock() - secondVdrs := validators.NewManager() - secondPrimaryVdrs := validators.NewSet() - _ = secondVdrs.Add(constants.PrimaryNetworkID, secondPrimaryVdrs) secondVM := &VM{Config: config.Config{ Chains: chains.TestManager, - Validators: secondVdrs, + Validators: validators.NewManager(), UptimeLockedCalculator: uptime.NewLockedCalculator(), MinStakeDuration: defaultMinStakingDuration, MaxStakeDuration: defaultMaxStakingDuration, @@ -1324,12 +1314,9 @@ func TestBootstrapPartiallyAccepted(t *testing.T) { blocked, err := queue.NewWithMissing(bootstrappingDB, "", prometheus.NewRegistry()) require.NoError(err) - vdrs := validators.NewManager() - primaryVdrs := validators.NewSet() - _ = vdrs.Add(constants.PrimaryNetworkID, primaryVdrs) vm := &VM{Config: config.Config{ Chains: chains.TestManager, - Validators: vdrs, + Validators: validators.NewManager(), UptimeLockedCalculator: uptime.NewLockedCalculator(), MinStakeDuration: defaultMinStakingDuration, MaxStakeDuration: defaultMaxStakingDuration, @@ -1406,8 +1393,8 @@ func TestBootstrapPartiallyAccepted(t *testing.T) { advanceTimeBlkBytes := advanceTimeBlk.Bytes() peerID := ids.NodeID{1, 2, 3, 4, 5, 4, 3, 2, 1} - beacons := validators.NewSet() - require.NoError(beacons.Add(peerID, nil, ids.Empty, 1)) + beacons := validators.NewManager() + require.NoError(beacons.AddStaker(ctx.SubnetID, peerID, nil, ids.Empty, 1)) benchlist := benchlist.NewNoBenchlist() timeoutManager, err := timeout.NewManager( @@ -1425,6 +1412,7 @@ func TestBootstrapPartiallyAccepted(t *testing.T) { require.NoError(err) go timeoutManager.Dispatch() + defer timeoutManager.Stop() chainRouter := &router.ChainRouter{} @@ -1492,17 +1480,19 @@ func TestBootstrapPartiallyAccepted(t *testing.T) { } peers := tracker.NewPeers() - startup := tracker.NewStartup(peers, (beacons.Weight()+1)/2) - beacons.RegisterCallbackListener(startup) + totalWeight, err := beacons.TotalWeight(ctx.SubnetID) + require.NoError(err) + startup := tracker.NewStartup(peers, (totalWeight+1)/2) + beacons.RegisterCallbackListener(ctx.SubnetID, startup) // The engine handles consensus consensus := &smcon.Topological{} commonCfg := common.Config{ Ctx: consensusCtx, Beacons: beacons, - SampleK: beacons.Len(), + SampleK: beacons.Count(ctx.SubnetID), StartupTracker: startup, - Alpha: (beacons.Weight() + 1) / 2, + Alpha: (totalWeight + 1) / 2, Sender: sender, BootstrapTracker: bootstrapTracker, AncestorsMaxContainersSent: 2000, @@ -1645,12 +1635,9 @@ func TestUnverifiedParent(t *testing.T) { _, genesisBytes := defaultGenesis(t) dbManager := manager.NewMemDB(version.Semantic1_0_0) 
- vdrs := validators.NewManager() - primaryVdrs := validators.NewSet() - _ = vdrs.Add(constants.PrimaryNetworkID, primaryVdrs) vm := &VM{Config: config.Config{ Chains: chains.TestManager, - Validators: vdrs, + Validators: validators.NewManager(), UptimeLockedCalculator: uptime.NewLockedCalculator(), MinStakeDuration: defaultMinStakingDuration, MaxStakeDuration: defaultMaxStakingDuration, @@ -1806,16 +1793,12 @@ func TestUptimeDisallowedWithRestart(t *testing.T) { db := manager.NewMemDB(version.Semantic1_0_0) firstDB := db.NewPrefixDBManager([]byte{}) - firstVdrs := validators.NewManager() - firstPrimaryVdrs := validators.NewSet() - _ = firstVdrs.Add(constants.PrimaryNetworkID, firstPrimaryVdrs) - const firstUptimePercentage = 20 // 20% firstVM := &VM{Config: config.Config{ Chains: chains.TestManager, UptimePercentage: firstUptimePercentage / 100., RewardConfig: defaultRewardConfig, - Validators: firstVdrs, + Validators: validators.NewManager(), UptimeLockedCalculator: uptime.NewLockedCalculator(), BanffTime: banffForkTime, }} @@ -1854,15 +1837,11 @@ func TestUptimeDisallowedWithRestart(t *testing.T) { // Restart the VM with a larger uptime requirement secondDB := db.NewPrefixDBManager([]byte{}) - secondVdrs := validators.NewManager() - secondPrimaryVdrs := validators.NewSet() - _ = secondVdrs.Add(constants.PrimaryNetworkID, secondPrimaryVdrs) - const secondUptimePercentage = 21 // 21% > firstUptimePercentage, so uptime for reward is not met now secondVM := &VM{Config: config.Config{ Chains: chains.TestManager, UptimePercentage: secondUptimePercentage / 100., - Validators: secondVdrs, + Validators: validators.NewManager(), UptimeLockedCalculator: uptime.NewLockedCalculator(), BanffTime: banffForkTime, }} @@ -1951,14 +1930,11 @@ func TestUptimeDisallowedAfterNeverConnecting(t *testing.T) { _, genesisBytes := defaultGenesis(t) db := manager.NewMemDB(version.Semantic1_0_0) - vdrs := validators.NewManager() - primaryVdrs := validators.NewSet() - _ = vdrs.Add(constants.PrimaryNetworkID, primaryVdrs) vm := &VM{Config: config.Config{ Chains: chains.TestManager, UptimePercentage: .2, RewardConfig: defaultRewardConfig, - Validators: vdrs, + Validators: validators.NewManager(), UptimeLockedCalculator: uptime.NewLockedCalculator(), BanffTime: banffForkTime, }} @@ -2142,3 +2118,78 @@ func TestRemovePermissionedValidatorDuringAddPending(t *testing.T) { _, err = vm.state.GetPendingValidator(createSubnetTx.ID(), ids.NodeID(id)) require.ErrorIs(err, database.ErrNotFound) } + +func TestTransferSubnetOwnershipTx(t *testing.T) { + require := require.New(t) + vm, _, _ := defaultVM(t) + vm.ctx.Lock.Lock() + defer func() { + require.NoError(vm.Shutdown(context.Background())) + vm.ctx.Lock.Unlock() + }() + + // Create a subnet + createSubnetTx, err := vm.txBuilder.NewCreateSubnetTx( + 1, + []ids.ShortID{keys[0].PublicKey().Address()}, + []*secp256k1.PrivateKey{keys[0]}, + keys[0].Address(), + ) + require.NoError(err) + subnetID := createSubnetTx.ID() + + require.NoError(vm.Builder.AddUnverifiedTx(createSubnetTx)) + createSubnetBlock, err := vm.Builder.BuildBlock(context.Background()) + require.NoError(err) + + createSubnetRawBlock := createSubnetBlock.(*blockexecutor.Block).Block + require.IsType(&block.BanffStandardBlock{}, createSubnetRawBlock) + require.Contains(createSubnetRawBlock.Txs(), createSubnetTx) + + require.NoError(createSubnetBlock.Verify(context.Background())) + require.NoError(createSubnetBlock.Accept(context.Background())) + require.NoError(vm.SetPreference(context.Background(), 
vm.manager.LastAccepted())) + + subnetOwner, err := vm.state.GetSubnetOwner(subnetID) + require.NoError(err) + expectedOwner := &secp256k1fx.OutputOwners{ + Locktime: 0, + Threshold: 1, + Addrs: []ids.ShortID{ + keys[0].PublicKey().Address(), + }, + } + require.Equal(expectedOwner, subnetOwner) + + transferSubnetOwnershipTx, err := vm.txBuilder.NewTransferSubnetOwnershipTx( + subnetID, + 1, + []ids.ShortID{keys[1].PublicKey().Address()}, + []*secp256k1.PrivateKey{keys[0]}, + ids.ShortEmpty, // change addr + ) + require.NoError(err) + + require.NoError(vm.Builder.AddUnverifiedTx(transferSubnetOwnershipTx)) + transferSubnetOwnershipBlock, err := vm.Builder.BuildBlock(context.Background()) + require.NoError(err) + + transferSubnetOwnershipRawBlock := transferSubnetOwnershipBlock.(*blockexecutor.Block).Block + require.IsType(&block.BanffStandardBlock{}, transferSubnetOwnershipRawBlock) + require.Contains(transferSubnetOwnershipRawBlock.Txs(), transferSubnetOwnershipTx) + + require.NoError(transferSubnetOwnershipBlock.Verify(context.Background())) + require.NoError(transferSubnetOwnershipBlock.Accept(context.Background())) + require.NoError(vm.SetPreference(context.Background(), vm.manager.LastAccepted())) + + subnetOwner, err = vm.state.GetSubnetOwner(subnetID) + require.NoError(err) + expectedOwner = &secp256k1fx.OutputOwners{ + Locktime: 0, + Threshold: 1, + Addrs: []ids.ShortID{ + keys[1].PublicKey().Address(), + }, + } + require.Equal(expectedOwner, subnetOwner) +} diff --git a/vms/proposervm/block.go b/vms/proposervm/block.go index 3035bc321e80..dcd5e63a213f 100644 --- a/vms/proposervm/block.go +++ b/vms/proposervm/block.go @@ -145,7 +145,11 @@ func (p *postForkCommonComponents) Verify( return err } if childPChainHeight > currentPChainHeight { - return errPChainHeightNotReached + return fmt.Errorf("%w: %d > %d", + errPChainHeightNotReached, + childPChainHeight, + currentPChainHeight, + ) } childHeight := child.Height() diff --git a/vms/proposervm/main_test.go b/vms/proposervm/main_test.go index 263fabf0b665..913e29613f1c 100644 --- a/vms/proposervm/main_test.go +++ b/vms/proposervm/main_test.go @@ -9,8 +9,6 @@ import ( "go.uber.org/goleak" ) -// TestMain uses goleak to verify tests in this package do not leak unexpected -// goroutines. 
func TestMain(m *testing.M) { goleak.VerifyTestMain(m) } diff --git a/vms/proposervm/pre_fork_block.go b/vms/proposervm/pre_fork_block.go index 0b11ef8e4716..ed665e473910 100644 --- a/vms/proposervm/pre_fork_block.go +++ b/vms/proposervm/pre_fork_block.go @@ -5,6 +5,7 @@ package proposervm import ( "context" + "fmt" "time" "go.uber.org/zap" @@ -128,7 +129,11 @@ func (b *preForkBlock) verifyPostForkChild(ctx context.Context, child *postForkB return err } if childPChainHeight > currentPChainHeight { - return errPChainHeightNotReached + return fmt.Errorf("%w: %d > %d", + errPChainHeightNotReached, + childPChainHeight, + currentPChainHeight, + ) } if childPChainHeight < b.vm.minimumPChainHeight { return errPChainHeightTooLow diff --git a/wallet/chain/p/backend_visitor.go b/wallet/chain/p/backend_visitor.go index 9830d87ade05..da2fc591ecd5 100644 --- a/wallet/chain/p/backend_visitor.go +++ b/wallet/chain/p/backend_visitor.go @@ -53,6 +53,11 @@ func (b *backendVisitor) RemoveSubnetValidatorTx(tx *txs.RemoveSubnetValidatorTx return b.baseTx(&tx.BaseTx) } +func (b *backendVisitor) TransferSubnetOwnershipTx(tx *txs.TransferSubnetOwnershipTx) error { + // TODO: Correctly track subnet owners in [getSubnetSigners] + return b.baseTx(&tx.BaseTx) +} + func (b *backendVisitor) ImportTx(tx *txs.ImportTx) error { err := b.b.removeUTXOs( b.ctx, diff --git a/wallet/chain/p/signer_visitor.go b/wallet/chain/p/signer_visitor.go index 52269ee69081..6df1687400ac 100644 --- a/wallet/chain/p/signer_visitor.go +++ b/wallet/chain/p/signer_visitor.go @@ -135,6 +135,19 @@ func (s *signerVisitor) RemoveSubnetValidatorTx(tx *txs.RemoveSubnetValidatorTx) return sign(s.tx, true, txSigners) } +func (s *signerVisitor) TransferSubnetOwnershipTx(tx *txs.TransferSubnetOwnershipTx) error { + txSigners, err := s.getSigners(constants.PlatformChainID, tx.Ins) + if err != nil { + return err + } + subnetAuthSigners, err := s.getSubnetSigners(tx.Subnet, tx.SubnetAuth) + if err != nil { + return err + } + txSigners = append(txSigners, subnetAuthSigners) + return sign(s.tx, true, txSigners) +} + func (s *signerVisitor) TransformSubnetTx(tx *txs.TransformSubnetTx) error { txSigners, err := s.getSigners(constants.PlatformChainID, tx.Ins) if err != nil { diff --git a/x/merkledb/README.md b/x/merkledb/README.md index 270382bf07d5..467a60e19b08 100644 --- a/x/merkledb/README.md +++ b/x/merkledb/README.md @@ -21,9 +21,9 @@ To reduce the depth of nodes in the trie, a `Merkle Node` utilizes path compress | Merkle Node | | | | ID: 0x0131 | an id representing the current node, derived from the node's value and all children ids -| Key: 0x91 | prefix of the key path, representing the location of the node in the trie -| Value: 0x00 | the value, if one exists, that is stored at the key path (pathPrefix + compressedPath) -| Children: | a map of children node ids for any nodes in the trie that have this node's key path as a prefix +| Key: 0x91 | prefix of the key, representing the location of the node in the trie +| Value: 0x00 | the value, if one exists, that is stored at the key (keyPrefix + compressedKey) +| Children: | a map of children node ids for any nodes in the trie that have this node's key as a prefix | 0: [:0x00542F] | child 0 represents a node with key 0x910 with ID 0x00542F | 1: [0x432:0xA0561C] | child 1 represents a node with key 0x911432 with ID 0xA0561C | ... 
| @@ -52,9 +52,9 @@ The node serialization format is as follows: +----------------------------------------------------+ | Child index (varint) | +----------------------------------------------------+ -| Child compressed path length (varint) | +| Child compressed key length (varint) | +----------------------------------------------------+ -| Child compressed path (variable length bytes) | +| Child compressed key (variable length bytes) | +----------------------------------------------------+ | Child ID (32 bytes) | +----------------------------------------------------+ @@ -62,9 +62,9 @@ The node serialization format is as follows: +----------------------------------------------------+ | Child index (varint) | +----------------------------------------------------+ -| Child compressed path length (varint) | +| Child compressed key length (varint) | +----------------------------------------------------+ -| Child compressed path (variable length bytes) | +| Child compressed key (variable length bytes) | +----------------------------------------------------+ | Child ID (32 bytes) | +----------------------------------------------------+ @@ -80,8 +80,8 @@ Where: * `Value` is the value, if it exists (i.e. if `Value existince flag` is `1`.) Otherwise not serialized. * `Number of children` is the number of children this node has. * `Child index` is the index of a child node within the list of the node's children. -* `Child compressed path length` is the length of the child node's compressed path. -* `Child compressed path` is the child node's compressed path. +* `Child compressed key length` is the length of the child node's compressed key. +* `Child compressed key` is the child node's compressed key. * `Child ID` is the child node's ID. * `Child has value` indicates if that child has a value. @@ -91,9 +91,9 @@ For each child of the node, we have an additional: +----------------------------------------------------+ | Child index (varint) | +----------------------------------------------------+ -| Child compressed path length (varint) | +| Child compressed key length (varint) | +----------------------------------------------------+ -| Child compressed path (variable length bytes) | +| Child compressed key (variable length bytes) | +----------------------------------------------------+ | Child ID (32 bytes) | +----------------------------------------------------+ @@ -114,8 +114,8 @@ Its byte representation (in hex) is: `0x01020204000210579EB3718A7E437D2DDCE931AC The node's key is empty (its the root) and has value `0x02`. It has two children. -The first is at child index `0`, has compressed path `0x01` and ID (in hex) `0x579eb3718a7e437d2ddce931ac7cc05a0bc695a9c2084f5df12fb96ad0fa3266`. -The second is at child index `14`, has compressed path `0x0F0F0F` and ID (in hex) `0x9845893c4f9d92c4e097fcf2589bc9d6882b1f18d1c2fc91d7df1d3fcbdb4238`. +The first is at child index `0`, has compressed key `0x01` and ID (in hex) `0x579eb3718a7e437d2ddce931ac7cc05a0bc695a9c2084f5df12fb96ad0fa3266`. +The second is at child index `14`, has compressed key `0x0F0F0F` and ID (in hex) `0x9845893c4f9d92c4e097fcf2589bc9d6882b1f18d1c2fc91d7df1d3fcbdb4238`. 
``` +--------------------------------------------------------------------+ @@ -134,10 +134,10 @@ The second is at child index `14`, has compressed path `0x0F0F0F` and ID (in hex | Child index (varint) | | 0x00 | +--------------------------------------------------------------------+ -| Child compressed path length (varint) | +| Child compressed key length (varint) | | 0x02 | +--------------------------------------------------------------------+ -| Child compressed path (variable length bytes) | +| Child compressed key (variable length bytes) | | 0x10 | +--------------------------------------------------------------------+ | Child ID (32 bytes) | @@ -146,10 +146,10 @@ The second is at child index `14`, has compressed path `0x0F0F0F` and ID (in hex | Child index (varint) | | 0x0E | +--------------------------------------------------------------------+ -| Child compressed path length (varint) | +| Child compressed key length (varint) | | 0x06 | +--------------------------------------------------------------------+ -| Child compressed path (variable length bytes) | +| Child compressed key (variable length bytes) | | 0xFFF0 | +--------------------------------------------------------------------+ | Child ID (32 bytes) | @@ -204,7 +204,7 @@ Where: Note that, as with the node serialization format, the `Child index` values aren't necessarily sequential, but they are unique and strictly increasing. Also like the node serialization format, there can be up to 16 blocks of children data. -However, note that child compressed paths are not included in the node ID calculation. +However, note that child compressed keys are not included in the node ID calculation. Once this is encoded, we `sha256` hash the resulting bytes to get the node's ID. @@ -227,7 +227,7 @@ By splitting the nodes up by value, it allows better key/value iteration and a m ### Single node type -A `Merkle Node` holds the IDs of its children, its value, as well as any path extension. This simplifies some logic and allows all of the data about a node to be loaded in a single database read. This trades off a small amount of storage efficiency (some fields may be `nil` but are still stored for every node). +A `Merkle Node` holds the IDs of its children, its value, as well as any key extension. This simplifies some logic and allows all of the data about a node to be loaded in a single database read. This trades off a small amount of storage efficiency (some fields may be `nil` but are still stored for every node). 
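To make the single-node layout described above concrete, here is a minimal, self-contained Go sketch. The `merkleNode`/`childEntry` names and the simplified `id()` hashing are assumptions for illustration only; the actual merkledb node, its `Key` type, and its ID calculation are defined by the codec changes below (which encode children, value, and key before hashing, and exclude child compressed keys from the ID).

```go
package main

import (
	"crypto/sha256"
	"fmt"
)

// childEntry mirrors the per-child data described above: the compressed key
// segment leading to the child, the child's ID, and whether it has a value.
// (Names are illustrative; the real merkledb types differ.)
type childEntry struct {
	compressedKey []byte
	id            [32]byte
	hasValue      bool
}

// merkleNode is the single node type: a key prefix, an optional value, and a
// map from child index (up to 16 branches) to child metadata.
type merkleNode struct {
	key      []byte
	value    []byte // nil when no value is stored at this key
	children map[byte]childEntry
}

// id sketches the ID derivation: hash the node's value together with each
// present child's index and ID. The real implementation serializes the
// node's children, value, and key with the codec before taking sha256.
func (n *merkleNode) id() [32]byte {
	h := sha256.New()
	h.Write(n.value)
	for index := byte(0); index < 16; index++ {
		if c, ok := n.children[index]; ok {
			h.Write([]byte{index})
			h.Write(c.id[:])
		}
	}
	var out [32]byte
	copy(out[:], h.Sum(nil))
	return out
}

func main() {
	// Shape of the worked example above: an empty key, value 0x02, and
	// children at indices 0 and 14 (compressed keys shown as raw bytes here).
	root := merkleNode{
		value: []byte{0x02},
		children: map[byte]childEntry{
			0x0: {compressedKey: []byte{0x01}},
			0xE: {compressedKey: []byte{0x0F, 0x0F, 0x0F}},
		},
	}
	fmt.Printf("illustrative root ID: %x\n", root.id())
}
```

The real implementation stores keys as the `Key` type introduced by this change and derives IDs from the codec-encoded `hashValues`, but the overall shape is the same: one node type whose fields can all be loaded in a single database read.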
### Validity diff --git a/x/merkledb/codec.go b/x/merkledb/codec.go index 58004aba5088..6420baac56e9 100644 --- a/x/merkledb/codec.go +++ b/x/merkledb/codec.go @@ -21,16 +21,16 @@ const ( falseByte = 0 minVarIntLen = 1 minMaybeByteSliceLen = boolLen - minPathLen = minVarIntLen + minKeyLen = minVarIntLen minByteSliceLen = minVarIntLen minDBNodeLen = minMaybeByteSliceLen + minVarIntLen - minChildLen = minVarIntLen + minPathLen + ids.IDLen + boolLen + minChildLen = minVarIntLen + minKeyLen + ids.IDLen + boolLen - estimatedKeyLen = 64 - estimatedValueLen = 64 - estimatedCompressedPathLen = 8 - // Child index, child compressed path, child ID, child has value - estimatedNodeChildLen = minVarIntLen + estimatedCompressedPathLen + ids.IDLen + boolLen + estimatedKeyLen = 64 + estimatedValueLen = 64 + estimatedCompressedKeyLen = 8 + // Child index, child compressed key, child ID, child has value + estimatedNodeChildLen = minVarIntLen + estimatedCompressedKeyLen + ids.IDLen + boolLen // Child index, child ID hashValuesChildLen = minVarIntLen + ids.IDLen ) @@ -45,7 +45,7 @@ var ( errChildIndexTooLarge = errors.New("invalid child index. Must be less than branching factor") errLeadingZeroes = errors.New("varint has leading zeroes") errInvalidBool = errors.New("decoded bool is neither true nor false") - errNonZeroPathPadding = errors.New("path partial byte should be padded with 0s") + errNonZeroKeyPadding = errors.New("key partial byte should be padded with 0s") errExtraSpace = errors.New("trailing buffer space") errIntOverflow = errors.New("value overflows int") ) @@ -102,7 +102,7 @@ func (c *codecImpl) encodeDBNode(n *dbNode, branchFactor BranchFactor) []byte { for index := 0; BranchFactor(index) < branchFactor; index++ { if entry, ok := n.children[byte(index)]; ok { c.encodeUint(buf, uint64(index)) - c.encodePath(buf, entry.compressedPath) + c.encodeKey(buf, entry.compressedKey) _, _ = buf.Write(entry.id[:]) c.encodeBool(buf, entry.hasValue) } @@ -128,7 +128,7 @@ func (c *codecImpl) encodeHashValues(hv *hashValues) []byte { } } c.encodeMaybeByteSlice(buf, hv.Value) - c.encodePath(buf, hv.Key) + c.encodeKey(buf, hv.Key) return buf.Bytes() } @@ -168,7 +168,7 @@ func (c *codecImpl) decodeDBNode(b []byte, n *dbNode, branchFactor BranchFactor) } previousChild = index - compressedPath, err := c.decodePath(src, branchFactor) + compressedKey, err := c.decodeKey(src, branchFactor) if err != nil { return err } @@ -181,9 +181,9 @@ func (c *codecImpl) decodeDBNode(b []byte, n *dbNode, branchFactor BranchFactor) return err } n.children[byte(index)] = child{ - compressedPath: compressedPath, - id: childID, - hasValue: hasValue, + compressedKey: compressedKey, + id: childID, + hasValue: hasValue, } } if src.Len() != 0 { @@ -326,43 +326,43 @@ func (*codecImpl) decodeID(src *bytes.Reader) (ids.ID, error) { return id, err } -func (c *codecImpl) encodePath(dst *bytes.Buffer, p Path) { - c.encodeUint(dst, uint64(p.tokensLength)) - _, _ = dst.Write(p.Bytes()) +func (c *codecImpl) encodeKey(dst *bytes.Buffer, key Key) { + c.encodeUint(dst, uint64(key.tokenLength)) + _, _ = dst.Write(key.Bytes()) } -func (c *codecImpl) decodePath(src *bytes.Reader, branchFactor BranchFactor) (Path, error) { - if minPathLen > src.Len() { - return Path{}, io.ErrUnexpectedEOF +func (c *codecImpl) decodeKey(src *bytes.Reader, branchFactor BranchFactor) (Key, error) { + if minKeyLen > src.Len() { + return Key{}, io.ErrUnexpectedEOF } length, err := c.decodeUint(src) if err != nil { - return Path{}, err + return Key{}, err } if length > 
math.MaxInt { - return Path{}, errIntOverflow + return Key{}, errIntOverflow } - result := emptyPath(branchFactor) - result.tokensLength = int(length) - pathBytesLen := result.bytesNeeded(result.tokensLength) - if pathBytesLen > src.Len() { - return Path{}, io.ErrUnexpectedEOF + result := emptyKey(branchFactor) + result.tokenLength = int(length) + keyBytesLen := result.bytesNeeded(result.tokenLength) + if keyBytesLen > src.Len() { + return Key{}, io.ErrUnexpectedEOF } - buffer := make([]byte, pathBytesLen) + buffer := make([]byte, keyBytesLen) if _, err := io.ReadFull(src, buffer); err != nil { if err == io.EOF { err = io.ErrUnexpectedEOF } - return Path{}, err + return Key{}, err } if result.hasPartialByte() { // Confirm that the padding bits in the partial byte are 0. // We want to only look at the bits to the right of the last token, which is at index length-1. // Generate a mask with (8-bitsToShift) 0s followed by bitsToShift 1s. - paddingMask := byte(0xFF >> (8 - result.bitsToShift(result.tokensLength-1))) - if buffer[pathBytesLen-1]&paddingMask != 0 { - return Path{}, errNonZeroPathPadding + paddingMask := byte(0xFF >> (8 - result.bitsToShift(result.tokenLength-1))) + if buffer[keyBytesLen-1]&paddingMask != 0 { + return Key{}, errNonZeroKeyPadding } } result.value = string(buffer) diff --git a/x/merkledb/codec_test.go b/x/merkledb/codec_test.go index 24bc6b061524..a948ea520b8c 100644 --- a/x/merkledb/codec_test.go +++ b/x/merkledb/codec_test.go @@ -73,7 +73,7 @@ func FuzzCodecInt(f *testing.F) { ) } -func FuzzCodecPath(f *testing.F) { +func FuzzCodecKey(f *testing.F) { f.Fuzz( func( t *testing.T, @@ -84,7 +84,7 @@ func FuzzCodecPath(f *testing.F) { codec := codec.(*codecImpl) reader := bytes.NewReader(b) startLen := reader.Len() - got, err := codec.decodePath(reader, branchFactor) + got, err := codec.decodeKey(reader, branchFactor) if err != nil { t.SkipNow() } @@ -93,7 +93,7 @@ func FuzzCodecPath(f *testing.F) { // Encoding [got] should be the same as [b]. 
var buf bytes.Buffer - codec.encodePath(&buf, got) + codec.encodeKey(&buf, got) bufBytes := buf.Bytes() require.Len(bufBytes, numRead) require.Equal(b[:numRead], bufBytes) @@ -155,12 +155,12 @@ func FuzzCodecDBNodeDeterministic(f *testing.F) { var childID ids.ID _, _ = r.Read(childID[:]) // #nosec G404 - childPathBytes := make([]byte, r.Intn(32)) // #nosec G404 - _, _ = r.Read(childPathBytes) // #nosec G404 + childKeyBytes := make([]byte, r.Intn(32)) // #nosec G404 + _, _ = r.Read(childKeyBytes) // #nosec G404 children[byte(i)] = child{ - compressedPath: NewPath(childPathBytes, branchFactor), - id: childID, + compressedKey: ToKey(childKeyBytes, branchFactor), + id: childID, } } node := dbNode{ @@ -225,14 +225,14 @@ func FuzzEncodeHashValues(f *testing.F) { children := map[byte]child{} numChildren := r.Intn(int(branchFactor)) // #nosec G404 for i := 0; i < numChildren; i++ { - compressedPathLen := r.Intn(32) // #nosec G404 - compressedPathBytes := make([]byte, compressedPathLen) - _, _ = r.Read(compressedPathBytes) // #nosec G404 + compressedKeyLen := r.Intn(32) // #nosec G404 + compressedKeyBytes := make([]byte, compressedKeyLen) + _, _ = r.Read(compressedKeyBytes) // #nosec G404 children[byte(i)] = child{ - compressedPath: NewPath(compressedPathBytes, branchFactor), - id: ids.GenerateTestID(), - hasValue: r.Intn(2) == 1, // #nosec G404 + compressedKey: ToKey(compressedKeyBytes, branchFactor), + id: ids.GenerateTestID(), + hasValue: r.Intn(2) == 1, // #nosec G404 } } @@ -250,7 +250,7 @@ func FuzzEncodeHashValues(f *testing.F) { hv := &hashValues{ Children: children, Value: value, - Key: NewPath(key, branchFactor), + Key: ToKey(key, branchFactor), } // Serialize the *hashValues with both codecs @@ -264,9 +264,9 @@ func FuzzEncodeHashValues(f *testing.F) { ) } -func TestCodecDecodePathLengthOverflowRegression(t *testing.T) { +func TestCodecDecodeKeyLengthOverflowRegression(t *testing.T) { codec := codec.(*codecImpl) bytes := bytes.NewReader(binary.AppendUvarint(nil, math.MaxInt)) - _, err := codec.decodePath(bytes, BranchFactor16) + _, err := codec.decodeKey(bytes, BranchFactor16) require.ErrorIs(t, err, io.ErrUnexpectedEOF) } diff --git a/x/merkledb/db.go b/x/merkledb/db.go index e5a5e1170dc8..247bdfbd5822 100644 --- a/x/merkledb/db.go +++ b/x/merkledb/db.go @@ -204,8 +204,8 @@ type merkleDB struct { // [calculateNodeIDsHelper] at any given time. calculateNodeIDsSema *semaphore.Weighted - newPath func(p []byte) Path - rootPath Path + toKey func(p []byte) Key + rootKey Key } // New returns a new merkle database. 
@@ -232,8 +232,8 @@ func newDatabase( return nil, err } - newPath := func(b []byte) Path { - return NewPath(b, config.BranchFactor) + toKey := func(b []byte) Key { + return ToKey(b, config.BranchFactor) } // Share a sync.Pool of []byte between the intermediateNodeDB and valueNodeDB @@ -248,13 +248,13 @@ func newDatabase( baseDB: db, valueNodeDB: newValueNodeDB(db, bufferPool, metrics, int(config.ValueNodeCacheSize), config.BranchFactor), intermediateNodeDB: newIntermediateNodeDB(db, bufferPool, metrics, int(config.IntermediateNodeCacheSize), int(config.EvictionBatchSize)), - history: newTrieHistory(int(config.HistoryLength), newPath), + history: newTrieHistory(int(config.HistoryLength), toKey), debugTracer: getTracerIfEnabled(config.TraceLevel, DebugTrace, config.Tracer), infoTracer: getTracerIfEnabled(config.TraceLevel, InfoTrace, config.Tracer), childViews: make([]*trieView, 0, defaultPreallocationSize), calculateNodeIDsSema: semaphore.NewWeighted(int64(rootGenConcurrency)), - newPath: newPath, - rootPath: newPath(rootKey), + toKey: toKey, + rootKey: toKey(rootKey), } root, err := trieDB.initializeRootIfNeeded() @@ -265,8 +265,8 @@ func newDatabase( // add current root to history (has no changes) trieDB.history.record(&changeSummary{ rootID: root, - values: map[Path]*change[maybe.Maybe[[]byte]]{}, - nodes: map[Path]*change[*node]{}, + values: map[Key]*change[maybe.Maybe[[]byte]]{}, + nodes: map[Key]*change[*node]{}, }) shutdownType, err := trieDB.baseDB.Get(cleanShutdownKey) @@ -292,7 +292,7 @@ func newDatabase( // Deletes every intermediate node and rebuilds them by re-adding every key/value. // TODO: make this more efficient by only clearing out the stale portions of the trie. func (db *merkleDB) rebuild(ctx context.Context, cacheSize int) error { - db.root = newNode(nil, db.rootPath) + db.root = newNode(nil, db.rootKey) // Delete intermediate nodes. if err := database.ClearPrefix(db.baseDB, intermediateNodePrefix, rebuildIntermediateDeletionWriteSize); err != nil { @@ -473,8 +473,8 @@ func (db *merkleDB) PrefetchPath(key []byte) error { return db.prefetchPath(tempView, key) } -func (db *merkleDB) prefetchPath(view *trieView, key []byte) error { - pathToKey, err := view.getPathTo(db.newPath(key)) +func (db *merkleDB) prefetchPath(view *trieView, keyBytes []byte) error { + pathToKey, err := view.getPathTo(db.toKey(keyBytes)) if err != nil { return err } @@ -508,7 +508,7 @@ func (db *merkleDB) GetValues(ctx context.Context, keys [][]byte) ([][]byte, []e values := make([][]byte, len(keys)) errors := make([]error, len(keys)) for i, key := range keys { - values[i], errors[i] = db.getValueCopy(db.newPath(key)) + values[i], errors[i] = db.getValueCopy(db.toKey(key)) } return values, errors } @@ -522,13 +522,13 @@ func (db *merkleDB) GetValue(ctx context.Context, key []byte) ([]byte, error) { db.lock.RLock() defer db.lock.RUnlock() - return db.getValueCopy(db.newPath(key)) + return db.getValueCopy(db.toKey(key)) } // getValueCopy returns a copy of the value for the given [key]. // Returns database.ErrNotFound if it doesn't exist. // Assumes [db.lock] is read locked. -func (db *merkleDB) getValueCopy(key Path) ([]byte, error) { +func (db *merkleDB) getValueCopy(key Key) ([]byte, error) { val, err := db.getValueWithoutLock(key) if err != nil { return nil, err @@ -539,7 +539,7 @@ func (db *merkleDB) getValueCopy(key Path) ([]byte, error) { // getValue returns the value for the given [key]. // Returns database.ErrNotFound if it doesn't exist. // Assumes [db.lock] isn't held. 
-func (db *merkleDB) getValue(key Path) ([]byte, error) { +func (db *merkleDB) getValue(key Key) ([]byte, error) { db.lock.RLock() defer db.lock.RUnlock() @@ -549,7 +549,7 @@ func (db *merkleDB) getValue(key Path) ([]byte, error) { // getValueWithoutLock returns the value for the given [key]. // Returns database.ErrNotFound if it doesn't exist. // Assumes [db.lock] is read locked. -func (db *merkleDB) getValueWithoutLock(key Path) ([]byte, error) { +func (db *merkleDB) getValueWithoutLock(key Key) ([]byte, error) { if db.closed { return nil, database.ErrClosed } @@ -731,7 +731,7 @@ func (db *merkleDB) GetChangeProof( commonNodeIndex := 0 for ; commonNodeIndex < len(result.StartProof) && commonNodeIndex < len(result.EndProof) && - result.StartProof[commonNodeIndex].KeyPath == result.EndProof[commonNodeIndex].KeyPath; commonNodeIndex++ { + result.StartProof[commonNodeIndex].Key == result.EndProof[commonNodeIndex].Key; commonNodeIndex++ { } result.StartProof = result.StartProof[commonNodeIndex:] } @@ -788,7 +788,7 @@ func (db *merkleDB) Has(k []byte) (bool, error) { return false, database.ErrClosed } - _, err := db.getValueWithoutLock(db.newPath(k)) + _, err := db.getValueWithoutLock(db.toKey(k)) if err == database.ErrNotFound { return false, nil } @@ -926,7 +926,7 @@ func (db *merkleDB) commitChanges(ctx context.Context, trieToCommit *trieView) e return nil } - rootChange, ok := changes.nodes[db.rootPath] + rootChange, ok := changes.nodes[db.rootKey] if !ok { return errNoNewRoot } @@ -1025,32 +1025,32 @@ func (db *merkleDB) VerifyChangeProof( return err } - smallestPath := maybe.Bind(start, db.newPath) + smallestKey := maybe.Bind(start, db.toKey) // Make sure the start proof, if given, is well-formed. - if err := verifyProofPath(proof.StartProof, smallestPath); err != nil { + if err := verifyProofPath(proof.StartProof, smallestKey); err != nil { return err } // Find the greatest key in [proof.KeyChanges] // Note that [proof.EndProof] is a proof for this key. - // [largestPath] is also used when we add children of proof nodes to [trie] below. - largestPath := maybe.Bind(end, db.newPath) + // [largestKey] is also used when we add children of proof nodes to [trie] below. + largestKey := maybe.Bind(end, db.toKey) if len(proof.KeyChanges) > 0 { // If [proof] has key-value pairs, we should insert children // greater than [end] to ancestors of the node containing [end] // so that we get the expected root ID. - largestPath = maybe.Some(db.newPath(proof.KeyChanges[len(proof.KeyChanges)-1].Key)) + largestKey = maybe.Some(db.toKey(proof.KeyChanges[len(proof.KeyChanges)-1].Key)) } // Make sure the end proof, if given, is well-formed. 
- if err := verifyProofPath(proof.EndProof, largestPath); err != nil { + if err := verifyProofPath(proof.EndProof, largestKey); err != nil { return err } - keyValues := make(map[Path]maybe.Maybe[[]byte], len(proof.KeyChanges)) + keyValues := make(map[Key]maybe.Maybe[[]byte], len(proof.KeyChanges)) for _, keyValue := range proof.KeyChanges { - keyValues[db.newPath(keyValue.Key)] = keyValue.Value + keyValues[db.toKey(keyValue.Key)] = keyValue.Value } // want to prevent commit writes to DB, but not prevent DB reads @@ -1065,8 +1065,8 @@ func (db *merkleDB) VerifyChangeProof( ctx, db, proof.StartProof, - smallestPath, - largestPath, + smallestKey, + largestKey, keyValues, ); err != nil { return err @@ -1076,8 +1076,8 @@ func (db *merkleDB) VerifyChangeProof( ctx, db, proof.EndProof, - smallestPath, - largestPath, + smallestKey, + largestKey, keyValues, ); err != nil { return err @@ -1106,16 +1106,16 @@ func (db *merkleDB) VerifyChangeProof( if err := addPathInfo( view, proof.StartProof, - smallestPath, - largestPath, + smallestKey, + largestKey, ); err != nil { return err } if err := addPathInfo( view, proof.EndProof, - smallestPath, - largestPath, + smallestKey, + largestKey, ); err != nil { return err } @@ -1154,9 +1154,9 @@ func (db *merkleDB) initializeRootIfNeeded() (ids.ID, error) { // not sure if the root exists or had a value or not // check under both prefixes var err error - db.root, err = db.intermediateNodeDB.Get(db.rootPath) + db.root, err = db.intermediateNodeDB.Get(db.rootKey) if err == database.ErrNotFound { - db.root, err = db.valueNodeDB.Get(db.rootPath) + db.root, err = db.valueNodeDB.Get(db.rootKey) } if err == nil { // Root already exists, so calculate its id @@ -1168,12 +1168,12 @@ func (db *merkleDB) initializeRootIfNeeded() (ids.ID, error) { } // Root doesn't exist; make a new one. - db.root = newNode(nil, db.rootPath) + db.root = newNode(nil, db.rootKey) // update its ID db.root.calculateID(db.metrics) - if err := db.intermediateNodeDB.Put(db.rootPath, db.root); err != nil { + if err := db.intermediateNodeDB.Put(db.rootKey, db.root); err != nil { return ids.Empty, err } @@ -1233,7 +1233,7 @@ func (db *merkleDB) getKeysNotInSet(start, end maybe.Maybe[[]byte], keySet set.S // This copy may be edited by the caller without affecting the database state. // Returns database.ErrNotFound if the node doesn't exist. // Assumes [db.lock] isn't held. -func (db *merkleDB) getEditableNode(key Path, hasValue bool) (*node, error) { +func (db *merkleDB) getEditableNode(key Key, hasValue bool) (*node, error) { db.lock.RLock() defer db.lock.RUnlock() @@ -1249,11 +1249,11 @@ func (db *merkleDB) getEditableNode(key Path, hasValue bool) (*node, error) { // Editing the returned node affects the database state. // Returns database.ErrNotFound if the node doesn't exist. // Assumes [db.lock] is read locked. 
-func (db *merkleDB) getNode(key Path, hasValue bool) (*node, error) { +func (db *merkleDB) getNode(key Key, hasValue bool) (*node, error) { switch { case db.closed: return nil, database.ErrClosed - case key == db.rootPath: + case key == db.rootKey: return db.root, nil case hasValue: return db.valueNodeDB.Get(key) @@ -1289,11 +1289,11 @@ func getBufferFromPool(bufferPool *sync.Pool, size int) []byte { return buffer } -// cacheEntrySize returns a rough approximation of the memory consumed by storing the path and node -func cacheEntrySize(p Path, n *node) int { +// cacheEntrySize returns a rough approximation of the memory consumed by storing the key and node +func cacheEntrySize(key Key, n *node) int { if n == nil { - return len(p.Bytes()) + return len(key.Bytes()) } // nodes cache their bytes representation so the total memory consumed is roughly twice that - return len(p.Bytes()) + 2*len(n.bytes()) + return len(key.Bytes()) + 2*len(n.bytes()) } diff --git a/x/merkledb/db_test.go b/x/merkledb/db_test.go index a2e1cdc30f71..d4f09803cdaf 100644 --- a/x/merkledb/db_test.go +++ b/x/merkledb/db_test.go @@ -63,7 +63,7 @@ func Test_MerkleDB_Get_Safety(t *testing.T) { val, err := db.Get(keyBytes) require.NoError(err) - n, err := db.getNode(NewPath(keyBytes, BranchFactor16), true) + n, err := db.getNode(ToKey(keyBytes, BranchFactor16), true) require.NoError(err) // node's value shouldn't be affected by the edit @@ -861,10 +861,10 @@ func runRandDBTest(require *require.Assertions, r *rand.Rand, rt randTest, bf Br ) var ( - values = make(map[Path][]byte) // tracks content of the trie + values = make(map[Key][]byte) // tracks content of the trie currentBatch = db.NewBatch() - uncommittedKeyValues = make(map[Path][]byte) - uncommittedDeletes = set.Set[Path]{} + uncommittedKeyValues = make(map[Key][]byte) + uncommittedDeletes = set.Set[Key]{} pastRoots = []ids.ID{} ) @@ -877,13 +877,13 @@ func runRandDBTest(require *require.Assertions, r *rand.Rand, rt randTest, bf Br case opUpdate: require.NoError(currentBatch.Put(step.key, step.value)) - uncommittedKeyValues[NewPath(step.key, bf)] = step.value - uncommittedDeletes.Remove(NewPath(step.key, bf)) + uncommittedKeyValues[ToKey(step.key, bf)] = step.value + uncommittedDeletes.Remove(ToKey(step.key, bf)) case opDelete: require.NoError(currentBatch.Delete(step.key)) - uncommittedDeletes.Add(NewPath(step.key, bf)) - delete(uncommittedKeyValues, NewPath(step.key, bf)) + uncommittedDeletes.Add(ToKey(step.key, bf)) + delete(uncommittedKeyValues, ToKey(step.key, bf)) case opGenerateRangeProof: root, err := db.GetMerkleRoot(context.Background()) require.NoError(err) @@ -984,7 +984,7 @@ func runRandDBTest(require *require.Assertions, r *rand.Rand, rt randTest, bf Br require.ErrorIs(err, database.ErrNotFound) } - want := values[NewPath(step.key, bf)] + want := values[ToKey(step.key, bf)] require.True(bytes.Equal(want, v)) // Use bytes.Equal so nil treated equal to []byte{} trieValue, err := getNodeValueWithBranchFactor(db, string(step.key), bf) diff --git a/x/merkledb/helpers_test.go b/x/merkledb/helpers_test.go index 5b3bb5508146..3cd84ce11e7c 100644 --- a/x/merkledb/helpers_test.go +++ b/x/merkledb/helpers_test.go @@ -52,7 +52,7 @@ func writeBasicBatch(t *testing.T, db *merkleDB) { func newRandomProofNode(r *rand.Rand) ProofNode { key := make([]byte, r.Intn(32)) // #nosec G404 _, _ = r.Read(key) // #nosec G404 - serializedKey := NewPath(key, BranchFactor16) + serializedKey := ToKey(key, BranchFactor16) val := make([]byte, r.Intn(64)) // #nosec G404 _, _ = 
r.Read(val) // #nosec G404 @@ -83,7 +83,7 @@ func newRandomProofNode(r *rand.Rand) ProofNode { } return ProofNode{ - KeyPath: serializedKey, + Key: serializedKey, ValueOrHash: valueOrHash, Children: children, } diff --git a/x/merkledb/history.go b/x/merkledb/history.go index 65975f034694..c82fbb1e5f78 100644 --- a/x/merkledb/history.go +++ b/x/merkledb/history.go @@ -33,7 +33,7 @@ type trieHistory struct { // Each change is tagged with this monotonic increasing number. nextInsertNumber uint64 - newPath func([]byte) Path + toKey func([]byte) Key } // Tracks the beginning and ending state of a value. @@ -54,23 +54,23 @@ type changeSummaryAndInsertNumber struct { // Tracks all of the node and value changes that resulted in the rootID. type changeSummary struct { rootID ids.ID - nodes map[Path]*change[*node] - values map[Path]*change[maybe.Maybe[[]byte]] + nodes map[Key]*change[*node] + values map[Key]*change[maybe.Maybe[[]byte]] } func newChangeSummary(estimatedSize int) *changeSummary { return &changeSummary{ - nodes: make(map[Path]*change[*node], estimatedSize), - values: make(map[Path]*change[maybe.Maybe[[]byte]], estimatedSize), + nodes: make(map[Key]*change[*node], estimatedSize), + values: make(map[Key]*change[maybe.Maybe[[]byte]], estimatedSize), } } -func newTrieHistory(maxHistoryLookback int, newPath func([]byte) Path) *trieHistory { +func newTrieHistory(maxHistoryLookback int, toKey func([]byte) Key) *trieHistory { return &trieHistory{ maxHistoryLen: maxHistoryLookback, history: buffer.NewUnboundedDeque[*changeSummaryAndInsertNumber](maxHistoryLookback), lastChanges: make(map[ids.ID]*changeSummaryAndInsertNumber), - newPath: newPath, + toKey: toKey, } } @@ -156,10 +156,10 @@ func (th *trieHistory) getValueChanges( var ( // Keep track of changed keys so the largest can be removed // in order to stay within the [maxLength] limit if necessary. - changedKeys = set.Set[Path]{} + changedKeys = set.Set[Key]{} - startPath = maybe.Bind(start, th.newPath) - endPath = maybe.Bind(end, th.newPath) + startKey = maybe.Bind(start, th.toKey) + endKey = maybe.Bind(end, th.toKey) // For each element in the history in the range between [startRoot]'s // last appearance (exclusive) and [endRoot]'s last appearance (inclusive), @@ -183,8 +183,8 @@ func (th *trieHistory) getValueChanges( // Add the changes from this commit to [combinedChanges]. for key, valueChange := range changes.values { // The key is outside the range [start, end]. 
- if (startPath.HasValue() && key.Less(startPath.Value())) || - (end.HasValue() && key.Greater(endPath.Value())) { + if (startKey.HasValue() && key.Less(startKey.Value())) || + (end.HasValue() && key.Greater(endKey.Value())) { continue } @@ -237,8 +237,8 @@ func (th *trieHistory) getChangesToGetToRoot(rootID ids.ID, start maybe.Maybe[[] } var ( - startPath = maybe.Bind(start, th.newPath) - endPath = maybe.Bind(end, th.newPath) + startKey = maybe.Bind(start, th.toKey) + endKey = maybe.Bind(end, th.toKey) combinedChanges = newChangeSummary(defaultPreallocationSize) mostRecentChangeInsertNumber = th.nextInsertNumber - 1 mostRecentChangeIndex = th.history.Len() - 1 @@ -259,8 +259,8 @@ func (th *trieHistory) getChangesToGetToRoot(rootID ids.ID, start maybe.Maybe[[] } for key, valueChange := range changes.values { - if (startPath.IsNothing() || !key.Less(startPath.Value())) && - (endPath.IsNothing() || !key.Greater(endPath.Value())) { + if (startKey.IsNothing() || !key.Less(startKey.Value())) && + (endKey.IsNothing() || !key.Greater(endKey.Value())) { if existing, ok := combinedChanges.values[key]; ok { existing.after = valueChange.before } else { diff --git a/x/merkledb/history_test.go b/x/merkledb/history_test.go index 94627aef0d90..1261c92b22df 100644 --- a/x/merkledb/history_test.go +++ b/x/merkledb/history_test.go @@ -312,10 +312,10 @@ func Test_History_Values_Lookup_Over_Queue_Break(t *testing.T) { // changes should still be collectable even though the history has had to loop due to hitting max size changes, err := db.history.getValueChanges(startRoot, endRoot, maybe.Nothing[[]byte](), maybe.Nothing[[]byte](), 10) require.NoError(err) - require.Contains(changes.values, NewPath([]byte("key1"), BranchFactor16)) - require.Equal([]byte("value1"), changes.values[NewPath([]byte("key1"), BranchFactor16)].after.Value()) - require.Contains(changes.values, NewPath([]byte("key2"), BranchFactor16)) - require.Equal([]byte("value3"), changes.values[NewPath([]byte("key2"), BranchFactor16)].after.Value()) + require.Contains(changes.values, ToKey([]byte("key1"), BranchFactor16)) + require.Equal([]byte("value1"), changes.values[ToKey([]byte("key1"), BranchFactor16)].after.Value()) + require.Contains(changes.values, ToKey([]byte("key2"), BranchFactor16)) + require.Equal([]byte("value3"), changes.values[ToKey([]byte("key2"), BranchFactor16)].after.Value()) } func Test_History_RepeatedRoot(t *testing.T) { @@ -572,8 +572,8 @@ func TestHistoryRecord(t *testing.T) { require := require.New(t) maxHistoryLen := 3 - th := newTrieHistory(maxHistoryLen, func(bytes []byte) Path { - return NewPath(bytes, BranchFactor16) + th := newTrieHistory(maxHistoryLen, func(bytes []byte) Key { + return ToKey(bytes, BranchFactor16) }) changes := []*changeSummary{} @@ -647,22 +647,22 @@ func TestHistoryRecord(t *testing.T) { func TestHistoryGetChangesToRoot(t *testing.T) { maxHistoryLen := 3 - history := newTrieHistory(maxHistoryLen, func(bytes []byte) Path { - return NewPath(bytes, BranchFactor16) + history := newTrieHistory(maxHistoryLen, func(bytes []byte) Key { + return ToKey(bytes, BranchFactor16) }) changes := []*changeSummary{} for i := 0; i < maxHistoryLen; i++ { // Fill the history changes = append(changes, &changeSummary{ rootID: ids.GenerateTestID(), - nodes: map[Path]*change[*node]{ - history.newPath([]byte{byte(i)}): { + nodes: map[Key]*change[*node]{ + history.toKey([]byte{byte(i)}): { before: &node{id: ids.GenerateTestID()}, after: &node{id: ids.GenerateTestID()}, }, }, - values: map[Path]*change[maybe.Maybe[[]byte]]{ - 
history.newPath([]byte{byte(i)}): { + values: map[Key]*change[maybe.Maybe[[]byte]]{ + history.toKey([]byte{byte(i)}): { before: maybe.Some([]byte{byte(i)}), after: maybe.Some([]byte{byte(i + 1)}), }, @@ -701,7 +701,7 @@ func TestHistoryGetChangesToRoot(t *testing.T) { require.Len(got.nodes, 1) require.Len(got.values, 1) reversedChanges := changes[maxHistoryLen-1] - removedKey := history.newPath([]byte{byte(maxHistoryLen - 1)}) + removedKey := history.toKey([]byte{byte(maxHistoryLen - 1)}) require.Equal(reversedChanges.nodes[removedKey].before, got.nodes[removedKey].after) require.Equal(reversedChanges.values[removedKey].before, got.values[removedKey].after) require.Equal(reversedChanges.values[removedKey].after, got.values[removedKey].before) @@ -714,12 +714,12 @@ func TestHistoryGetChangesToRoot(t *testing.T) { require.Len(got.nodes, 2) require.Len(got.values, 2) reversedChanges1 := changes[maxHistoryLen-1] - removedKey1 := history.newPath([]byte{byte(maxHistoryLen - 1)}) + removedKey1 := history.toKey([]byte{byte(maxHistoryLen - 1)}) require.Equal(reversedChanges1.nodes[removedKey1].before, got.nodes[removedKey1].after) require.Equal(reversedChanges1.values[removedKey1].before, got.values[removedKey1].after) require.Equal(reversedChanges1.values[removedKey1].after, got.values[removedKey1].before) reversedChanges2 := changes[maxHistoryLen-2] - removedKey2 := history.newPath([]byte{byte(maxHistoryLen - 2)}) + removedKey2 := history.toKey([]byte{byte(maxHistoryLen - 2)}) require.Equal(reversedChanges2.nodes[removedKey2].before, got.nodes[removedKey2].after) require.Equal(reversedChanges2.values[removedKey2].before, got.values[removedKey2].after) require.Equal(reversedChanges2.values[removedKey2].after, got.values[removedKey2].before) @@ -733,12 +733,12 @@ func TestHistoryGetChangesToRoot(t *testing.T) { require.Len(got.nodes, 2) require.Len(got.values, 1) reversedChanges1 := changes[maxHistoryLen-1] - removedKey1 := history.newPath([]byte{byte(maxHistoryLen - 1)}) + removedKey1 := history.toKey([]byte{byte(maxHistoryLen - 1)}) require.Equal(reversedChanges1.nodes[removedKey1].before, got.nodes[removedKey1].after) require.Equal(reversedChanges1.values[removedKey1].before, got.values[removedKey1].after) require.Equal(reversedChanges1.values[removedKey1].after, got.values[removedKey1].before) reversedChanges2 := changes[maxHistoryLen-2] - removedKey2 := history.newPath([]byte{byte(maxHistoryLen - 2)}) + removedKey2 := history.toKey([]byte{byte(maxHistoryLen - 2)}) require.Equal(reversedChanges2.nodes[removedKey2].before, got.nodes[removedKey2].after) }, }, @@ -750,10 +750,10 @@ func TestHistoryGetChangesToRoot(t *testing.T) { require.Len(got.nodes, 2) require.Len(got.values, 1) reversedChanges1 := changes[maxHistoryLen-1] - removedKey1 := history.newPath([]byte{byte(maxHistoryLen - 1)}) + removedKey1 := history.toKey([]byte{byte(maxHistoryLen - 1)}) require.Equal(reversedChanges1.nodes[removedKey1].before, got.nodes[removedKey1].after) reversedChanges2 := changes[maxHistoryLen-2] - removedKey2 := history.newPath([]byte{byte(maxHistoryLen - 2)}) + removedKey2 := history.toKey([]byte{byte(maxHistoryLen - 2)}) require.Equal(reversedChanges2.nodes[removedKey2].before, got.nodes[removedKey2].after) require.Equal(reversedChanges2.values[removedKey2].before, got.values[removedKey2].after) require.Equal(reversedChanges2.values[removedKey2].after, got.values[removedKey2].before) diff --git a/x/merkledb/intermediate_node_db.go b/x/merkledb/intermediate_node_db.go index ba65f6313a71..e146b943d6c2 100644 
--- a/x/merkledb/intermediate_node_db.go +++ b/x/merkledb/intermediate_node_db.go @@ -27,7 +27,7 @@ type intermediateNodeDB struct { // from the cache, which will call [OnEviction]. // A non-nil error returned from Put is considered fatal. // Keys in [nodeCache] aren't prefixed with [intermediateNodePrefix]. - nodeCache onEvictCache[Path, *node] + nodeCache onEvictCache[Key, *node] // the number of bytes to evict during an eviction batch evictionBatchSize int metrics merkleMetrics @@ -55,7 +55,7 @@ func newIntermediateNodeDB( } // A non-nil error is considered fatal and closes [db.baseDB]. -func (db *intermediateNodeDB) onEviction(key Path, n *node) error { +func (db *intermediateNodeDB) onEviction(key Key, n *node) error { writeBatch := db.baseDB.NewBatch() totalSize := cacheEntrySize(key, n) @@ -88,7 +88,7 @@ func (db *intermediateNodeDB) onEviction(key Path, n *node) error { return nil } -func (db *intermediateNodeDB) addToBatch(b database.Batch, key Path, n *node) error { +func (db *intermediateNodeDB) addToBatch(b database.Batch, key Key, n *node) error { dbKey := db.constructDBKey(key) defer db.bufferPool.Put(dbKey) db.metrics.DatabaseNodeWrite() @@ -98,7 +98,7 @@ func (db *intermediateNodeDB) addToBatch(b database.Batch, key Path, n *node) er return b.Put(dbKey, n.bytes()) } -func (db *intermediateNodeDB) Get(key Path) (*node, error) { +func (db *intermediateNodeDB) Get(key Key) (*node, error) { if cachedValue, isCached := db.nodeCache.Get(key); isCached { db.metrics.IntermediateNodeCacheHit() if cachedValue == nil { @@ -120,10 +120,10 @@ func (db *intermediateNodeDB) Get(key Path) (*node, error) { } // constructDBKey returns a key that can be used in [db.baseDB]. -// We need to be able to differentiate between two paths of equal +// We need to be able to differentiate between two keys of equal // byte length but different token length, so we add padding to differentiate. // Additionally, we add a prefix indicating it is part of the intermediateNodeDB. 
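As a standalone illustration of the padding rationale stated above (this sketch is not part of the change; it only assumes the exported ToKey/Take/Append/Bytes API introduced by this rename and the module path github.com/ava-labs/avalanchego/x/merkledb): with a branch factor below 256, a key that ends mid-byte serializes to the same bytes as the key padded out with a trailing zero token, so a terminating token has to be appended before the bytes can be used in the base database. The actual implementation follows below.

package main

import (
	"fmt"

	"github.com/ava-labs/avalanchego/x/merkledb"
)

func main() {
	// One nibble (0x1) vs. two nibbles (0x1, 0x0): different keys,
	// identical byte serialization.
	oneToken := merkledb.ToKey([]byte{0x10}, merkledb.BranchFactor16).Take(1)
	twoTokens := merkledb.ToKey([]byte{0x10}, merkledb.BranchFactor16)
	fmt.Printf("%x %x\n", oneToken.Bytes(), twoTokens.Bytes()) // 10 10

	// Appending a terminating token, as constructDBKey does for branch
	// factors below 256, makes the two serializations distinct again.
	fmt.Printf("%x %x\n", oneToken.Append(1).Bytes(), twoTokens.Append(1).Bytes()) // 11 1010
}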
-func (db *intermediateNodeDB) constructDBKey(key Path) []byte { +func (db *intermediateNodeDB) constructDBKey(key Key) []byte { if key.branchFactor == BranchFactor256 { // For BranchFactor256, no padding is needed since byte length == token length return addPrefixToKey(db.bufferPool, intermediateNodePrefix, key.Bytes()) @@ -132,7 +132,7 @@ func (db *intermediateNodeDB) constructDBKey(key Path) []byte { return addPrefixToKey(db.bufferPool, intermediateNodePrefix, key.Append(1).Bytes()) } -func (db *intermediateNodeDB) Put(key Path, n *node) error { +func (db *intermediateNodeDB) Put(key Key, n *node) error { return db.nodeCache.Put(key, n) } @@ -140,6 +140,6 @@ func (db *intermediateNodeDB) Flush() error { return db.nodeCache.Flush() } -func (db *intermediateNodeDB) Delete(key Path) error { +func (db *intermediateNodeDB) Delete(key Key) error { return db.nodeCache.Put(key, nil) } diff --git a/x/merkledb/intermediate_node_db_test.go b/x/merkledb/intermediate_node_db_test.go index b8122753cd70..3d40aa7f8a05 100644 --- a/x/merkledb/intermediate_node_db_test.go +++ b/x/merkledb/intermediate_node_db_test.go @@ -23,7 +23,7 @@ import ( func Test_IntermediateNodeDB(t *testing.T) { require := require.New(t) - n := newNode(nil, NewPath([]byte{0x00}, BranchFactor16)) + n := newNode(nil, ToKey([]byte{0x00}, BranchFactor16)) n.setValue(maybe.Some([]byte{byte(0x02)})) nodeSize := cacheEntrySize(n.key, n) @@ -42,7 +42,7 @@ func Test_IntermediateNodeDB(t *testing.T) { ) // Put a key-node pair - node1Key := NewPath([]byte{0x01}, BranchFactor16) + node1Key := ToKey([]byte{0x01}, BranchFactor16) node1 := newNode(nil, node1Key) node1.setValue(maybe.Some([]byte{byte(0x01)})) require.NoError(db.Put(node1Key, node1)) @@ -73,8 +73,8 @@ func Test_IntermediateNodeDB(t *testing.T) { expectedSize := 0 added := 0 for { - key := NewPath([]byte{byte(added)}, BranchFactor16) - node := newNode(nil, emptyPath(BranchFactor16)) + key := ToKey([]byte{byte(added)}, BranchFactor16) + node := newNode(nil, emptyKey(BranchFactor16)) node.setValue(maybe.Some([]byte{byte(added)})) newExpectedSize := expectedSize + cacheEntrySize(key, node) if newExpectedSize > cacheSize { @@ -93,8 +93,8 @@ func Test_IntermediateNodeDB(t *testing.T) { // Put one more element in the cache, which should trigger an eviction // of all but 2 elements. 2 elements remain rather than 1 element because of // the added key prefix increasing the size tracked by the batch. 
- key := NewPath([]byte{byte(added)}, BranchFactor16) - node := newNode(nil, emptyPath(BranchFactor16)) + key := ToKey([]byte{byte(added)}, BranchFactor16) + node := newNode(nil, emptyKey(BranchFactor16)) node.setValue(maybe.Some([]byte{byte(added)})) require.NoError(db.Put(key, node)) @@ -102,7 +102,7 @@ func Test_IntermediateNodeDB(t *testing.T) { require.Equal(1, db.nodeCache.fifo.Len()) gotKey, _, ok := db.nodeCache.fifo.Oldest() require.True(ok) - require.Equal(NewPath([]byte{byte(added)}, BranchFactor16), gotKey) + require.Equal(ToKey([]byte{byte(added)}, BranchFactor16), gotKey) // Get a node from the base database // Use an early key that has been evicted from the cache @@ -150,8 +150,8 @@ func FuzzIntermediateNodeDBConstructDBKey(f *testing.F) { ) { require := require.New(t) for _, branchFactor := range branchFactors { - p := NewPath(key, branchFactor) - if p.tokensLength <= int(tokenLength) { + p := ToKey(key, branchFactor) + if p.tokenLength <= int(tokenLength) { t.SkipNow() } p = p.Take(int(tokenLength)) @@ -190,7 +190,7 @@ func Test_IntermediateNodeDB_ConstructDBKey_DirtyBuffer(t *testing.T) { ) db.bufferPool.Put([]byte{0xFF, 0xFF, 0xFF}) - constructedKey := db.constructDBKey(NewPath([]byte{}, BranchFactor16)) + constructedKey := db.constructDBKey(ToKey([]byte{}, BranchFactor16)) require.Len(constructedKey, 2) require.Equal(intermediateNodePrefix, constructedKey[:len(intermediateNodePrefix)]) require.Equal(byte(16), constructedKey[len(constructedKey)-1]) @@ -201,7 +201,7 @@ func Test_IntermediateNodeDB_ConstructDBKey_DirtyBuffer(t *testing.T) { }, } db.bufferPool.Put([]byte{0xFF, 0xFF, 0xFF}) - p := NewPath([]byte{0xF0}, BranchFactor16).Take(1) + p := ToKey([]byte{0xF0}, BranchFactor16).Take(1) constructedKey = db.constructDBKey(p) require.Len(constructedKey, 2) require.Equal(intermediateNodePrefix, constructedKey[:len(intermediateNodePrefix)]) diff --git a/x/merkledb/path.go b/x/merkledb/key.go similarity index 59% rename from x/merkledb/path.go rename to x/merkledb/key.go index 20fbb599e3c2..a2b6a3da065e 100644 --- a/x/merkledb/path.go +++ b/x/merkledb/key.go @@ -13,7 +13,7 @@ import ( var ( errInvalidBranchFactor = errors.New("invalid branch factor") - branchFactorToPathConfig = map[BranchFactor]pathConfig{ + branchFactorToTokenConfig = map[BranchFactor]tokenConfig{ BranchFactor2: { branchFactor: BranchFactor2, tokenBitSize: 1, @@ -51,71 +51,71 @@ const ( ) func (f BranchFactor) Valid() error { - if _, ok := branchFactorToPathConfig[f]; ok { + if _, ok := branchFactorToTokenConfig[f]; ok { return nil } return fmt.Errorf("%w: %d", errInvalidBranchFactor, f) } -type pathConfig struct { +type tokenConfig struct { branchFactor BranchFactor tokensPerByte int tokenBitSize byte singleTokenMask byte } -type Path struct { - tokensLength int - value string - pathConfig +type Key struct { + tokenLength int + value string + tokenConfig } -func emptyPath(bf BranchFactor) Path { - return Path{ - pathConfig: branchFactorToPathConfig[bf], +func emptyKey(bf BranchFactor) Key { + return Key{ + tokenConfig: branchFactorToTokenConfig[bf], } } -// NewPath returns [p] as a new path with the given [branchFactor]. +// ToKey returns [keyBytes] as a new key with the given [branchFactor]. // Assumes [branchFactor] is valid. 
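For context on the Key struct above: a Key pairs raw bytes with an explicit token count, so the same byte slice represents a different number of tokens depending on the branch factor. A minimal usage sketch (standalone, not part of this diff; it assumes only the exported ToKey and TokensLength shown here):

package main

import (
	"fmt"

	"github.com/ava-labs/avalanchego/x/merkledb"
)

func main() {
	keyBytes := []byte{0xAB, 0xCD}
	for _, bf := range []merkledb.BranchFactor{
		merkledb.BranchFactor2,
		merkledb.BranchFactor16,
		merkledb.BranchFactor256,
	} {
		// Two bytes hold 16 one-bit tokens, 4 four-bit tokens,
		// or 2 whole-byte tokens.
		k := merkledb.ToKey(keyBytes, bf)
		fmt.Printf("branch factor %d: %d tokens\n", bf, k.TokensLength())
	}
}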
-func NewPath(p []byte, branchFactor BranchFactor) Path { - pConfig := branchFactorToPathConfig[branchFactor] - return Path{ - value: byteSliceToString(p), - pathConfig: pConfig, - tokensLength: len(p) * pConfig.tokensPerByte, +func ToKey(keyBytes []byte, branchFactor BranchFactor) Key { + tc := branchFactorToTokenConfig[branchFactor] + return Key{ + value: byteSliceToString(keyBytes), + tokenConfig: tc, + tokenLength: len(keyBytes) * tc.tokensPerByte, } } -// TokensLength returns the number of tokens in [p]. -func (p Path) TokensLength() int { - return p.tokensLength +// TokensLength returns the number of tokens in [k]. +func (k Key) TokensLength() int { + return k.tokenLength } -// hasPartialByte returns true iff the path fits into a non-whole number of bytes -func (p Path) hasPartialByte() bool { - return p.tokensLength%p.tokensPerByte > 0 +// hasPartialByte returns true iff the key fits into a non-whole number of bytes +func (k Key) hasPartialByte() bool { + return k.tokenLength%k.tokensPerByte > 0 } -// HasPrefix returns true iff [prefix] is a prefix of [p] or equal to it. -func (p Path) HasPrefix(prefix Path) bool { - // [prefix] must be shorter than [p] to be a prefix. - if p.tokensLength < prefix.tokensLength { +// HasPrefix returns true iff [prefix] is a prefix of [k] or equal to it. +func (k Key) HasPrefix(prefix Key) bool { + // [prefix] must be shorter than [k] to be a prefix. + if k.tokenLength < prefix.tokenLength { return false } // The number of tokens in the last byte of [prefix], or zero // if [prefix] fits into a whole number of bytes. - remainderTokensCount := prefix.tokensLength % p.tokensPerByte + remainderTokensCount := prefix.tokenLength % k.tokensPerByte if remainderTokensCount == 0 { - return strings.HasPrefix(p.value, prefix.value) + return strings.HasPrefix(k.value, prefix.value) } // check that the tokens in the partially filled final byte of [prefix] are - // equal to the tokens in the final byte of [p]. - remainderBitsMask := byte(0xFF << (8 - remainderTokensCount*int(p.tokenBitSize))) + // equal to the tokens in the final byte of [k]. + remainderBitsMask := byte(0xFF << (8 - remainderTokensCount*int(k.tokenBitSize))) prefixRemainderTokens := prefix.value[len(prefix.value)-1] & remainderBitsMask - remainderTokens := p.value[len(prefix.value)-1] & remainderBitsMask + remainderTokens := k.value[len(prefix.value)-1] & remainderBitsMask if prefixRemainderTokens != remainderTokens { return false @@ -125,60 +125,46 @@ func (p Path) HasPrefix(prefix Path) bool { // If len(prefix.value) == 0 were true, [remainderTokens] would be 0 so we // would have returned above. prefixWithoutPartialByte := prefix.value[:len(prefix.value)-1] - return strings.HasPrefix(p.value, prefixWithoutPartialByte) + return strings.HasPrefix(k.value, prefixWithoutPartialByte) } -// iteratedHasPrefix checks if the provided prefix path is a prefix of the current path after having skipped [skipTokens] tokens first -// this has better performance than constructing the actual path via Skip() then calling HasPrefix because it avoids the []byte allocation -func (p Path) iteratedHasPrefix(skipTokens int, prefix Path) bool { - if p.tokensLength-skipTokens < prefix.tokensLength { - return false - } - for i := 0; i < prefix.tokensLength; i++ { - if p.Token(skipTokens+i) != prefix.Token(i) { - return false - } - } - return true -} - -// HasStrictPrefix returns true iff [prefix] is a prefix of [p] +// HasStrictPrefix returns true iff [prefix] is a prefix of [k] // but is not equal to it. 
-func (p Path) HasStrictPrefix(prefix Path) bool { - return p != prefix && p.HasPrefix(prefix) +func (k Key) HasStrictPrefix(prefix Key) bool { + return k != prefix && k.HasPrefix(prefix) } // Token returns the token at the specified index, -func (p Path) Token(index int) byte { - // Find the index in [p.value] of the byte containing the token at [index]. - storageByteIndex := index / p.tokensPerByte - storageByte := p.value[storageByteIndex] +func (k Key) Token(index int) byte { + // Find the index in [k.value] of the byte containing the token at [index]. + storageByteIndex := index / k.tokensPerByte + storageByte := k.value[storageByteIndex] // Shift the byte right to get the token to the rightmost position. - storageByte >>= p.bitsToShift(index) + storageByte >>= k.bitsToShift(index) // Apply a mask to remove any other tokens in the byte. - return storageByte & p.singleTokenMask + return storageByte & k.singleTokenMask } // Append returns a new Path that equals the current // Path with [token] appended to the end. -func (p Path) Append(token byte) Path { - buffer := make([]byte, p.bytesNeeded(p.tokensLength+1)) - p.appendIntoBuffer(buffer, token) - return Path{ - value: byteSliceToString(buffer), - tokensLength: p.tokensLength + 1, - pathConfig: p.pathConfig, +func (k Key) Append(token byte) Key { + buffer := make([]byte, k.bytesNeeded(k.tokenLength+1)) + k.appendIntoBuffer(buffer, token) + return Key{ + value: byteSliceToString(buffer), + tokenLength: k.tokenLength + 1, + tokenConfig: k.tokenConfig, } } -// Greater returns true if current Path is greater than other Path -func (p Path) Greater(other Path) bool { - return p.value > other.value || (p.value == other.value && p.tokensLength > other.tokensLength) +// Greater returns true if current Key is greater than other Key +func (k Key) Greater(other Key) bool { + return k.value > other.value || (k.value == other.value && k.tokenLength > other.tokenLength) } -// Less returns true if current Path is less than other Path -func (p Path) Less(other Path) bool { - return p.value < other.value || (p.value == other.value && p.tokensLength < other.tokensLength) +// Less returns true if current Key is less than other Key +func (k Key) Less(other Key) bool { + return k.value < other.value || (k.value == other.value && k.tokenLength < other.tokenLength) } // bitsToShift returns the number of bits to right shift a token @@ -198,15 +184,15 @@ func (p Path) Less(other Path) bool { // * Token at index 1 (0b0010) needs to be shifted by 0 bits // * Token at index 2 (0b0011) needs to be shifted by 4 bits // * Token at index 3 (0b0100) needs to be shifted by 0 bits -func (p Path) bitsToShift(index int) byte { +func (k Key) bitsToShift(index int) byte { // [tokenIndex] is the index of the token in the byte. // For example, if the branch factor is 16, then each byte contains 2 tokens. // The first is at index 0, and the second is at index 1, by this definition. - tokenIndex := index % p.tokensPerByte + tokenIndex := index % k.tokensPerByte // The bit within the byte that the token starts at. - startBitIndex := p.tokenBitSize * byte(tokenIndex) + startBitIndex := k.tokenBitSize * byte(tokenIndex) // The bit within the byte that the token ends at. - endBitIndex := startBitIndex + p.tokenBitSize - 1 + endBitIndex := startBitIndex + k.tokenBitSize - 1 // We want to right shift until [endBitIndex] is at the last index, so return // the distance from the end of the byte to the end of the token. // Note that 7 is the index of the last bit in a byte. 
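As a companion to the worked example in the bitsToShift comment above, the following standalone sketch (illustrative only; it uses just the exported API renamed in this diff) reads individual tokens out of a branch-factor-16 key, where each byte packs two 4-bit tokens and the high nibble must be shifted down before masking:

package main

import (
	"fmt"

	"github.com/ava-labs/avalanchego/x/merkledb"
)

func main() {
	// The bytes 0x12 0x34 hold the four 4-bit tokens 0x1, 0x2, 0x3, 0x4.
	k := merkledb.ToKey([]byte{0x12, 0x34}, merkledb.BranchFactor16)
	for i := 0; i < k.TokensLength(); i++ {
		// Token(i) shifts the containing byte right (4 bits for even
		// indices, 0 for odd ones) and masks off the other nibble.
		fmt.Printf("token %d = %#x\n", i, k.Token(i))
	}
	// Output:
	// token 0 = 0x1
	// token 1 = 0x2
	// token 2 = 0x3
	// token 3 = 0x4
}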
@@ -218,14 +204,62 @@ func (p Path) bitsToShift(index int) byte { // // Invariant: [tokens] is a non-negative, but otherwise untrusted, input and // this method must never overflow. -func (p Path) bytesNeeded(tokens int) int { - size := tokens / p.tokensPerByte - if tokens%p.tokensPerByte != 0 { +func (k Key) bytesNeeded(tokens int) int { + size := tokens / k.tokensPerByte + if tokens%k.tokensPerByte != 0 { size++ } return size } +func (k Key) AppendExtend(token byte, extensionKey Key) Key { + appendBytes := k.bytesNeeded(k.tokenLength + 1) + totalLength := k.tokenLength + 1 + extensionKey.tokenLength + buffer := make([]byte, k.bytesNeeded(totalLength)) + k.appendIntoBuffer(buffer[:appendBytes], token) + + // the extension path will be shifted based on the number of tokens in the partial byte + tokenRemainder := (k.tokenLength + 1) % k.tokensPerByte + result := Key{ + value: byteSliceToString(buffer), + tokenLength: totalLength, + tokenConfig: k.tokenConfig, + } + + extensionBuffer := buffer[appendBytes-1:] + if extensionKey.tokenLength == 0 { + return result + } + + // If the existing value fits into a whole number of bytes, + // the extension path can be copied directly into the buffer. + if tokenRemainder == 0 { + copy(extensionBuffer[1:], extensionKey.value) + return result + } + + // The existing path doesn't fit into a whole number of bytes. + // Figure out how many bits to shift. + shift := extensionKey.bitsToShift(tokenRemainder - 1) + // Fill the partial byte with the first [shift] bits of the extension path + extensionBuffer[0] |= extensionKey.value[0] >> (8 - shift) + + // copy the rest of the extension path bytes into the buffer, + // shifted byte shift bits + shiftCopy(extensionBuffer[1:], extensionKey.value, shift) + + return result +} + +func (k Key) appendIntoBuffer(buffer []byte, token byte) { + copy(buffer, k.value) + + // Shift [token] to the left such that it's at the correct + // index within its storage byte, then OR it with its storage + // byte to write the token into the byte. + buffer[len(buffer)-1] |= token << k.bitsToShift(k.tokenLength) +} + // Treats [src] as a bit array and copies it into [dst] shifted by [shift] bits. // For example, if [src] is [0b0000_0001, 0b0000_0010] and [shift] is 4, // we copy [0b0001_0000, 0b0010_0000] into [dst]. @@ -244,28 +278,28 @@ func shiftCopy(dst []byte, src string, shift byte) { } } -// Skip returns a new Path that contains the last -// p.length-tokensToSkip tokens of [p]. -func (p Path) Skip(tokensToSkip int) Path { - if p.tokensLength == tokensToSkip { - return emptyPath(p.branchFactor) +// Skip returns a new Key that contains the last +// k.length-tokensToSkip tokens of [k]. +func (k Key) Skip(tokensToSkip int) Key { + if k.tokenLength == tokensToSkip { + return emptyKey(k.branchFactor) } - result := Path{ - value: p.value[tokensToSkip/p.tokensPerByte:], - tokensLength: p.tokensLength - tokensToSkip, - pathConfig: p.pathConfig, + result := Key{ + value: k.value[tokensToSkip/k.tokensPerByte:], + tokenLength: k.tokenLength - tokensToSkip, + tokenConfig: k.tokenConfig, } // if the tokens to skip is a whole number of bytes, - // the remaining bytes exactly equals the new path. - if tokensToSkip%p.tokensPerByte == 0 { + // the remaining bytes exactly equals the new key. + if tokensToSkip%k.tokensPerByte == 0 { return result } // tokensToSkip does not remove a whole number of bytes. // copy the remaining shifted bytes into a new buffer. 
- buffer := make([]byte, p.bytesNeeded(result.tokensLength)) - bitsSkipped := tokensToSkip * int(p.tokenBitSize) + buffer := make([]byte, k.bytesNeeded(result.tokenLength)) + bitsSkipped := tokensToSkip * int(k.tokenBitSize) bitsRemovedFromFirstRemainingByte := byte(bitsSkipped % 8) shiftCopy(buffer, result.value, bitsRemovedFromFirstRemainingByte) @@ -273,90 +307,56 @@ func (p Path) Skip(tokensToSkip int) Path { return result } -func (p Path) AppendExtend(token byte, extensionPath Path) Path { - appendBytes := p.bytesNeeded(p.tokensLength + 1) - totalLength := p.tokensLength + 1 + extensionPath.tokensLength - buffer := make([]byte, p.bytesNeeded(totalLength)) - p.appendIntoBuffer(buffer[:appendBytes], token) - - // the extension path will be shifted based on the number of tokens in the partial byte - tokenRemainder := (p.tokensLength + 1) % p.tokensPerByte - result := Path{ - value: byteSliceToString(buffer), - tokensLength: totalLength, - pathConfig: p.pathConfig, - } - - extensionBuffer := buffer[appendBytes-1:] - if extensionPath.tokensLength == 0 { - return result - } - - // If the existing value fits into a whole number of bytes, - // the extension path can be copied directly into the buffer. - if tokenRemainder == 0 { - copy(extensionBuffer[1:], extensionPath.value) - return result - } - - // The existing path doesn't fit into a whole number of bytes. - // Figure out how many bits to shift. - shift := extensionPath.bitsToShift(tokenRemainder - 1) - // Fill the partial byte with the first [shift] bits of the extension path - extensionBuffer[0] |= extensionPath.value[0] >> (8 - shift) - - // copy the rest of the extension path bytes into the buffer, - // shifted byte shift bits - shiftCopy(extensionBuffer[1:], extensionPath.value, shift) - - return result -} - -func (p Path) appendIntoBuffer(buffer []byte, token byte) { - copy(buffer, p.value) - - // Shift [token] to the left such that it's at the correct - // index within its storage byte, then OR it with its storage - // byte to write the token into the byte. 
- buffer[len(buffer)-1] |= token << p.bitsToShift(p.tokensLength) -} - -// Take returns a new Path that contains the first tokensToTake tokens of the current Path -func (p Path) Take(tokensToTake int) Path { - if p.tokensLength == tokensToTake { - return p +// Take returns a new Key that contains the first tokensToTake tokens of the current Key +func (k Key) Take(tokensToTake int) Key { + if k.tokenLength <= tokensToTake { + return k } - result := Path{ - tokensLength: tokensToTake, - pathConfig: p.pathConfig, + result := Key{ + tokenLength: tokensToTake, + tokenConfig: k.tokenConfig, } if !result.hasPartialByte() { - result.value = p.value[:tokensToTake/p.tokensPerByte] + result.value = k.value[:tokensToTake/k.tokensPerByte] return result } // We need to zero out some bits of the last byte so a simple slice will not work // Create a new []byte to store the altered value - buffer := make([]byte, p.bytesNeeded(tokensToTake)) - copy(buffer, p.value) + buffer := make([]byte, k.bytesNeeded(tokensToTake)) + copy(buffer, k.value) // We want to zero out everything to the right of the last token, which is at index [tokensToTake] - 1 // Mask will be (8-bitsToShift) number of 1's followed by (bitsToShift) number of 0's - mask := byte(0xFF << p.bitsToShift(tokensToTake-1)) + mask := byte(0xFF << k.bitsToShift(tokensToTake-1)) buffer[len(buffer)-1] &= mask result.value = byteSliceToString(buffer) return result } -// Bytes returns the raw bytes of the Path +// Bytes returns the raw bytes of the Key // Invariant: The returned value must not be modified. -func (p Path) Bytes() []byte { +func (k Key) Bytes() []byte { // avoid copying during the conversion // "safe" because we never edit the value, only used as DB key - return stringToByteSlice(p.value) + return stringToByteSlice(k.value) +} + +// iteratedHasPrefix checks if the provided prefix path is a prefix of the current path after having skipped [skipTokens] tokens first +// this has better performance than constructing the actual path via Skip() then calling HasPrefix because it avoids the []byte allocation +func (k Key) iteratedHasPrefix(skipTokens int, prefix Key) bool { + if k.tokenLength-skipTokens < prefix.tokenLength { + return false + } + for i := 0; i < prefix.tokenLength; i++ { + if k.Token(skipTokens+i) != prefix.Token(i) { + return false + } + } + return true } // byteSliceToString converts the []byte to a string diff --git a/x/merkledb/path_test.go b/x/merkledb/key_test.go similarity index 56% rename from x/merkledb/path_test.go rename to x/merkledb/key_test.go index a03de5c2f94a..e56ee1a98050 100644 --- a/x/merkledb/path_test.go +++ b/x/merkledb/key_test.go @@ -22,41 +22,41 @@ func TestHasPartialByte(t *testing.T) { t.Run(fmt.Sprint(branchFactor), func(t *testing.T) { require := require.New(t) - path := emptyPath(branchFactor) - require.False(path.hasPartialByte()) + key := emptyKey(branchFactor) + require.False(key.hasPartialByte()) if branchFactor == BranchFactor256 { // Tokens are an entire byte so // there is never a partial byte. - path = path.Append(0) - require.False(path.hasPartialByte()) - path = path.Append(0) - require.False(path.hasPartialByte()) + key = key.Append(0) + require.False(key.hasPartialByte()) + key = key.Append(0) + require.False(key.hasPartialByte()) return } // Fill all but the last token of the first byte. 
- for i := 0; i < path.tokensPerByte-1; i++ { - path = path.Append(0) - require.True(path.hasPartialByte()) + for i := 0; i < key.tokensPerByte-1; i++ { + key = key.Append(0) + require.True(key.hasPartialByte()) } // Fill the last token of the first byte. - path = path.Append(0) - require.False(path.hasPartialByte()) + key = key.Append(0) + require.False(key.hasPartialByte()) // Fill the first token of the second byte. - path = path.Append(0) - require.True(path.hasPartialByte()) + key = key.Append(0) + require.True(key.hasPartialByte()) }) } } -func Test_Path_Has_Prefix(t *testing.T) { +func Test_Key_Has_Prefix(t *testing.T) { type test struct { name string - pathA func(bf BranchFactor) Path - pathB func(bf BranchFactor) Path + keyA func(bf BranchFactor) Key + keyB func(bf BranchFactor) Key isStrictPrefix bool isPrefix bool } @@ -64,43 +64,43 @@ func Test_Path_Has_Prefix(t *testing.T) { key := "Key" keyLength := map[BranchFactor]int{} for _, branchFactor := range branchFactors { - config := branchFactorToPathConfig[branchFactor] + config := branchFactorToTokenConfig[branchFactor] keyLength[branchFactor] = len(key) * config.tokensPerByte } tests := []test{ { name: "equal keys", - pathA: func(bf BranchFactor) Path { return NewPath([]byte(key), bf) }, - pathB: func(bf BranchFactor) Path { return NewPath([]byte(key), bf) }, + keyA: func(bf BranchFactor) Key { return ToKey([]byte(key), bf) }, + keyB: func(bf BranchFactor) Key { return ToKey([]byte(key), bf) }, isPrefix: true, isStrictPrefix: false, }, { name: "one key has one fewer token", - pathA: func(bf BranchFactor) Path { return NewPath([]byte(key), bf) }, - pathB: func(bf BranchFactor) Path { return NewPath([]byte(key), bf).Take(keyLength[bf] - 1) }, + keyA: func(bf BranchFactor) Key { return ToKey([]byte(key), bf) }, + keyB: func(bf BranchFactor) Key { return ToKey([]byte(key), bf).Take(keyLength[bf] - 1) }, isPrefix: true, isStrictPrefix: true, }, { name: "equal keys, both have one fewer token", - pathA: func(bf BranchFactor) Path { return NewPath([]byte(key), bf).Take(keyLength[bf] - 1) }, - pathB: func(bf BranchFactor) Path { return NewPath([]byte(key), bf).Take(keyLength[bf] - 1) }, + keyA: func(bf BranchFactor) Key { return ToKey([]byte(key), bf).Take(keyLength[bf] - 1) }, + keyB: func(bf BranchFactor) Key { return ToKey([]byte(key), bf).Take(keyLength[bf] - 1) }, isPrefix: true, isStrictPrefix: false, }, { name: "different keys", - pathA: func(bf BranchFactor) Path { return NewPath([]byte{0xF7}, bf) }, - pathB: func(bf BranchFactor) Path { return NewPath([]byte{0xF0}, bf) }, + keyA: func(bf BranchFactor) Key { return ToKey([]byte{0xF7}, bf) }, + keyB: func(bf BranchFactor) Key { return ToKey([]byte{0xF0}, bf) }, isPrefix: false, isStrictPrefix: false, }, { name: "same bytes, different lengths", - pathA: func(bf BranchFactor) Path { return NewPath([]byte{0x10, 0x00}, bf).Take(1) }, - pathB: func(bf BranchFactor) Path { return NewPath([]byte{0x10, 0x00}, bf).Take(2) }, + keyA: func(bf BranchFactor) Key { return ToKey([]byte{0x10, 0x00}, bf).Take(1) }, + keyB: func(bf BranchFactor) Key { return ToKey([]byte{0x10, 0x00}, bf).Take(2) }, isPrefix: false, isStrictPrefix: false, }, @@ -110,76 +110,80 @@ func Test_Path_Has_Prefix(t *testing.T) { for _, bf := range branchFactors { t.Run(tt.name+" bf "+fmt.Sprint(bf), func(t *testing.T) { require := require.New(t) - pathA := tt.pathA(bf) - pathB := tt.pathB(bf) + keyA := tt.keyA(bf) + keyB := tt.keyB(bf) - require.Equal(tt.isPrefix, pathA.HasPrefix(pathB)) - require.Equal(tt.isPrefix, 
pathA.iteratedHasPrefix(0, pathB)) - require.Equal(tt.isStrictPrefix, pathA.HasStrictPrefix(pathB)) + require.Equal(tt.isPrefix, keyA.HasPrefix(keyB)) + require.Equal(tt.isPrefix, keyA.iteratedHasPrefix(0, keyB)) + require.Equal(tt.isStrictPrefix, keyA.HasStrictPrefix(keyB)) }) } } } -func Test_Path_Skip(t *testing.T) { +func Test_Key_Skip(t *testing.T) { require := require.New(t) for _, bf := range branchFactors { - empty := emptyPath(bf) - require.Equal(NewPath([]byte{0}, bf).Skip(empty.tokensPerByte), empty) + empty := emptyKey(bf) + require.Equal(ToKey([]byte{0}, bf).Skip(empty.tokensPerByte), empty) if bf == BranchFactor256 { continue } - shortPath := NewPath([]byte{0b0101_0101}, bf) - longPath := NewPath([]byte{0b0101_0101, 0b0101_0101}, bf) - for i := 0; i < shortPath.tokensPerByte; i++ { - shift := byte(i) * shortPath.tokenBitSize - skipPath := shortPath.Skip(i) - require.Equal(byte(0b0101_0101<>(8-shift)), skipPath.value[0]) - require.Equal(byte(0b0101_0101<>(8-shift)), skipKey.value[0]) + require.Equal(byte(0b0101_0101<>shift)< 0 { - path1 = path1.Take(path1.tokensLength - 1) + key1 := ToKey(first, branchFactor) + if forceFirstOdd && key1.tokenLength > 0 { + key1 = key1.Take(key1.tokenLength - 1) } - path2 := NewPath(second, branchFactor) - if forceSecondOdd && path2.tokensLength > 0 { - path2 = path2.Take(path2.tokensLength - 1) + key2 := ToKey(second, branchFactor) + if forceSecondOdd && key2.tokenLength > 0 { + key2 = key2.Take(key2.tokenLength - 1) } token = byte(int(token) % int(branchFactor)) - extendedP := path1.AppendExtend(token, path2) - require.Equal(path1.tokensLength+path2.tokensLength+1, extendedP.tokensLength) - for i := 0; i < path1.tokensLength; i++ { - require.Equal(path1.Token(i), extendedP.Token(i)) + extendedP := key1.AppendExtend(token, key2) + require.Equal(key1.tokenLength+key2.tokenLength+1, extendedP.tokenLength) + for i := 0; i < key1.tokenLength; i++ { + require.Equal(key1.Token(i), extendedP.Token(i)) } - require.Equal(token, extendedP.Token(path1.tokensLength)) - for i := 0; i < path2.tokensLength; i++ { - require.Equal(path2.Token(i), extendedP.Token(i+1+path1.tokensLength)) + require.Equal(token, extendedP.Token(key1.tokenLength)) + for i := 0; i < key2.tokenLength; i++ { + require.Equal(key2.Token(i), extendedP.Token(i+1+key1.tokenLength)) } } }) } -func FuzzPathSkip(f *testing.F) { +func FuzzKeySkip(f *testing.F) { f.Fuzz(func( t *testing.T, first []byte, @@ -506,20 +510,20 @@ func FuzzPathSkip(f *testing.F) { ) { require := require.New(t) for _, branchFactor := range branchFactors { - path1 := NewPath(first, branchFactor) - if int(tokensToSkip) >= path1.tokensLength { + key1 := ToKey(first, branchFactor) + if int(tokensToSkip) >= key1.tokenLength { t.SkipNow() } - path2 := path1.Skip(int(tokensToSkip)) - require.Equal(path1.tokensLength-int(tokensToSkip), path2.tokensLength) - for i := 0; i < path2.tokensLength; i++ { - require.Equal(path1.Token(int(tokensToSkip)+i), path2.Token(i)) + key2 := key1.Skip(int(tokensToSkip)) + require.Equal(key1.tokenLength-int(tokensToSkip), key2.tokenLength) + for i := 0; i < key2.tokenLength; i++ { + require.Equal(key1.Token(int(tokensToSkip)+i), key2.Token(i)) } } }) } -func FuzzPathTake(f *testing.F) { +func FuzzKeyTake(f *testing.F) { f.Fuzz(func( t *testing.T, first []byte, @@ -527,15 +531,15 @@ func FuzzPathTake(f *testing.F) { ) { require := require.New(t) for _, branchFactor := range branchFactors { - path1 := NewPath(first, branchFactor) - if int(tokensToTake) >= path1.tokensLength { + key1 := 
ToKey(first, branchFactor) + if int(tokensToTake) >= key1.tokenLength { t.SkipNow() } - path2 := path1.Take(int(tokensToTake)) - require.Equal(int(tokensToTake), path2.tokensLength) + key2 := key1.Take(int(tokensToTake)) + require.Equal(int(tokensToTake), key2.tokenLength) - for i := 0; i < path2.tokensLength; i++ { - require.Equal(path1.Token(i), path2.Token(i)) + for i := 0; i < key2.tokenLength; i++ { + require.Equal(key1.Token(i), key2.Token(i)) } } }) diff --git a/x/merkledb/mock_db.go b/x/merkledb/mock_db.go index f2d8eae8a69b..f7e35883c177 100644 --- a/x/merkledb/mock_db.go +++ b/x/merkledb/mock_db.go @@ -402,7 +402,7 @@ func (mr *MockMerkleDBMockRecorder) VerifyChangeProof(arg0, arg1, arg2, arg3, ar } // getEditableNode mocks base method. -func (m *MockMerkleDB) getEditableNode(arg0 Path, arg1 bool) (*node, error) { +func (m *MockMerkleDB) getEditableNode(arg0 Key, arg1 bool) (*node, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "getEditableNode", arg0, arg1) ret0, _ := ret[0].(*node) @@ -417,7 +417,7 @@ func (mr *MockMerkleDBMockRecorder) getEditableNode(arg0, arg1 interface{}) *gom } // getValue mocks base method. -func (m *MockMerkleDB) getValue(arg0 Path) ([]byte, error) { +func (m *MockMerkleDB) getValue(arg0 Key) ([]byte, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "getValue", arg0) ret0, _ := ret[0].([]byte) diff --git a/x/merkledb/node.go b/x/merkledb/node.go index f51ca92ae653..f78e4636c815 100644 --- a/x/merkledb/node.go +++ b/x/merkledb/node.go @@ -18,7 +18,7 @@ const HashLength = 32 type hashValues struct { Children map[byte]child Value maybe.Maybe[[]byte] - Key Path + Key Key } // Representation of a node stored in the database. @@ -28,23 +28,23 @@ type dbNode struct { } type child struct { - compressedPath Path - id ids.ID - hasValue bool + compressedKey Key + id ids.ID + hasValue bool } // node holds additional information on top of the dbNode that makes calculations easier to do type node struct { dbNode id ids.ID - key Path + key Key nodeBytes []byte valueDigest maybe.Maybe[[]byte] } // Returns a new node with the given [key] and no value. // If [parent] isn't nil, the new node is added as a child of [parent]. -func newNode(parent *node, key Path) *node { +func newNode(parent *node, key Key) *node { newNode := &node{ dbNode: dbNode{ children: make(map[byte]child, key.branchFactor), @@ -58,7 +58,7 @@ func newNode(parent *node, key Path) *node { } // Parse [nodeBytes] to a node and set its key to [key]. -func parseNode(key Path, nodeBytes []byte) (*node, error) { +func parseNode(key Key, nodeBytes []byte) (*node, error) { n := dbNode{} if err := codec.decodeDBNode(nodeBytes, &n, key.branchFactor); err != nil { return nil, err @@ -129,11 +129,11 @@ func (n *node) setValueDigest() { // That is, [n.key] is a prefix of [child.key]. func (n *node) addChild(childNode *node) { n.setChildEntry( - childNode.key.Token(n.key.tokensLength), + childNode.key.Token(n.key.tokenLength), child{ - compressedPath: childNode.key.Skip(n.key.tokensLength + 1), - id: childNode.id, - hasValue: childNode.hasValue(), + compressedKey: childNode.key.Skip(n.key.tokenLength + 1), + id: childNode.id, + hasValue: childNode.hasValue(), }, ) } @@ -147,7 +147,7 @@ func (n *node) setChildEntry(index byte, childEntry child) { // Removes [child] from [n]'s children. func (n *node) removeChild(child *node) { n.onNodeChanged() - delete(n.children, child.key.Token(n.key.tokensLength)) + delete(n.children, child.key.Token(n.key.tokenLength)) } // clone Returns a copy of [n]. 
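The addChild/removeChild changes above preserve the existing compression scheme: a parent stores, for each child, the token at which the child branches off plus the child's key with the parent prefix and that token stripped. A standalone sketch (illustrative only, built on the exported Key API from this diff) showing that AppendExtend reverses the split that addChild performs:

package main

import (
	"fmt"

	"github.com/ava-labs/avalanchego/x/merkledb"
)

func main() {
	parent := merkledb.ToKey([]byte{0x12}, merkledb.BranchFactor16)      // tokens 1, 2
	child := merkledb.ToKey([]byte{0x12, 0x34}, merkledb.BranchFactor16) // tokens 1, 2, 3, 4

	// What addChild records for this edge: the branching token and the
	// child's key with the parent prefix plus that token skipped.
	branchToken := child.Token(parent.TokensLength())      // 0x3
	compressedKey := child.Skip(parent.TokensLength() + 1) // the single token 0x4

	// AppendExtend rebuilds the child's full key from the parent key,
	// the branching token, and the compressed key.
	rebuilt := parent.AppendExtend(branchToken, compressedKey)
	fmt.Println(rebuilt == child) // true
}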
@@ -170,7 +170,7 @@ func (n *node) clone() *node { // Returns the ProofNode representation of this node. func (n *node) asProofNode() ProofNode { pn := ProofNode{ - KeyPath: n.key, + Key: n.key, Children: make(map[byte]ids.ID, len(n.children)), ValueOrHash: maybe.Bind(n.valueDigest, slices.Clone[[]byte]), } diff --git a/x/merkledb/node_test.go b/x/merkledb/node_test.go index 1639d49bebb9..9632b7c7dacb 100644 --- a/x/merkledb/node_test.go +++ b/x/merkledb/node_test.go @@ -13,11 +13,11 @@ import ( ) func Test_Node_Marshal(t *testing.T) { - root := newNode(nil, emptyPath(BranchFactor16)) + root := newNode(nil, emptyKey(BranchFactor16)) require.NotNil(t, root) - fullpath := NewPath([]byte("key"), BranchFactor16) - childNode := newNode(root, fullpath) + fullKey := ToKey([]byte("key"), BranchFactor16) + childNode := newNode(root, fullKey) childNode.setValue(maybe.Some([]byte("value"))) require.NotNil(t, childNode) @@ -25,31 +25,31 @@ func Test_Node_Marshal(t *testing.T) { root.addChild(childNode) data := root.bytes() - rootParsed, err := parseNode(NewPath([]byte(""), BranchFactor16), data) + rootParsed, err := parseNode(ToKey([]byte(""), BranchFactor16), data) require.NoError(t, err) require.Len(t, rootParsed.children, 1) - rootIndex := getSingleChildPath(root).Token(root.key.tokensLength) - parsedIndex := getSingleChildPath(rootParsed).Token(rootParsed.key.tokensLength) + rootIndex := getSingleChildKey(root).Token(root.key.tokenLength) + parsedIndex := getSingleChildKey(rootParsed).Token(rootParsed.key.tokenLength) rootChildEntry := root.children[rootIndex] parseChildEntry := rootParsed.children[parsedIndex] require.Equal(t, rootChildEntry.id, parseChildEntry.id) } func Test_Node_Marshal_Errors(t *testing.T) { - root := newNode(nil, emptyPath(BranchFactor16)) + root := newNode(nil, emptyKey(BranchFactor16)) require.NotNil(t, root) - fullpath := NewPath([]byte{255}, BranchFactor16) - childNode1 := newNode(root, fullpath) + fullKey := ToKey([]byte{255}, BranchFactor16) + childNode1 := newNode(root, fullKey) childNode1.setValue(maybe.Some([]byte("value1"))) require.NotNil(t, childNode1) childNode1.calculateID(&mockMetrics{}) root.addChild(childNode1) - fullpath = NewPath([]byte{237}, BranchFactor16) - childNode2 := newNode(root, fullpath) + fullKey = ToKey([]byte{237}, BranchFactor16) + childNode2 := newNode(root, fullKey) childNode2.setValue(maybe.Some([]byte("value2"))) require.NotNil(t, childNode2) @@ -60,7 +60,7 @@ func Test_Node_Marshal_Errors(t *testing.T) { for i := 1; i < len(data); i++ { broken := data[:i] - _, err := parseNode(NewPath([]byte(""), BranchFactor16), broken) + _, err := parseNode(ToKey([]byte(""), BranchFactor16), broken) require.ErrorIs(t, err, io.ErrUnexpectedEOF) } } diff --git a/x/merkledb/proof.go b/x/merkledb/proof.go index 23cc7a3532ec..63ea34542c9b 100644 --- a/x/merkledb/proof.go +++ b/x/merkledb/proof.go @@ -45,8 +45,8 @@ var ( ErrInvalidChildIndex = errors.New("child index must be less than branch factor") ErrNilProofNode = errors.New("proof node is nil") ErrNilValueOrHash = errors.New("proof node's valueOrHash field is nil") - ErrNilPath = errors.New("path is nil") - ErrInvalidPathLength = errors.New("path length doesn't match bytes length, check specified branchFactor") + ErrNilKey = errors.New("key is nil") + ErrInvalidKeyLength = errors.New("key length doesn't match bytes length, check specified branchFactor") ErrNilRangeProof = errors.New("range proof is nil") ErrNilChangeProof = errors.New("change proof is nil") ErrNilMaybeBytes = errors.New("maybe bytes is 
nil") @@ -57,7 +57,7 @@ var ( ) type ProofNode struct { - KeyPath Path + Key Key // Nothing if this is an intermediate node. // The value in this node if its length < [HashLen]. // The hash of the value in this node otherwise. @@ -65,12 +65,12 @@ type ProofNode struct { Children map[byte]ids.ID } -// Assumes [node.Key.KeyPath.length] <= math.MaxUint64. +// Assumes [node.Key.Key.length] <= math.MaxUint64. func (node *ProofNode) ToProto() *pb.ProofNode { pbNode := &pb.ProofNode{ - Key: &pb.Path{ - Length: uint64(node.KeyPath.tokensLength), - Value: node.KeyPath.Bytes(), + Key: &pb.Key{ + Length: uint64(node.Key.tokenLength), + Value: node.Key.Bytes(), }, ValueOrHash: &pb.MaybeBytes{ Value: node.ValueOrHash.Value(), @@ -96,12 +96,12 @@ func (node *ProofNode) UnmarshalProto(pbNode *pb.ProofNode, bf BranchFactor) err case pbNode.ValueOrHash.IsNothing && len(pbNode.ValueOrHash.Value) != 0: return ErrInvalidMaybe case pbNode.Key == nil: - return ErrNilPath + return ErrNilKey } - node.KeyPath = NewPath(pbNode.Key.Value, bf).Take(int(pbNode.Key.Length)) + node.Key = ToKey(pbNode.Key.Value, bf).Take(int(pbNode.Key.Length)) - if len(node.KeyPath.value) != node.KeyPath.bytesNeeded(node.KeyPath.tokensLength) { - return ErrInvalidPathLength + if len(pbNode.Key.Value) != node.Key.bytesNeeded(node.Key.tokenLength) { + return ErrInvalidKeyLength } node.Children = make(map[byte]ids.ID, len(pbNode.Children)) @@ -130,7 +130,7 @@ type Proof struct { // Must always be non-empty (i.e. have the root node). Path []ProofNode // This is a proof that [key] exists/doesn't exist. - Key Path + Key Key // Nothing if [Key] isn't in the trie. // Otherwise the value corresponding to [Key]. @@ -156,8 +156,8 @@ func (proof *Proof) Verify(ctx context.Context, expectedRootID ids.ID) error { // then the value of the last proof node must match [proof.Value]. // Note partial byte length keys can never match the [proof.Key] since it's bytes, // and thus has a whole number of bytes - if !lastNode.KeyPath.hasPartialByte() && - proof.Key == lastNode.KeyPath && + if !lastNode.Key.hasPartialByte() && + proof.Key == lastNode.Key && !valueOrHashMatches(proof.Value, lastNode.ValueOrHash) { return ErrProofValueDoesntMatch } @@ -166,7 +166,7 @@ func (proof *Proof) Verify(ctx context.Context, expectedRootID ids.ID) error { // then this is an exclusion proof and should prove that [proof.Key] isn't in the trie. // Note length not evenly divisible into bytes can never match the [proof.Key] since it's bytes, // and thus an exact number of bytes. - if (lastNode.KeyPath.hasPartialByte() || proof.Key != lastNode.KeyPath) && + if (lastNode.Key.hasPartialByte() || proof.Key != lastNode.Key) && proof.Value.HasValue() { return ErrProofValueDoesntMatch } @@ -180,7 +180,7 @@ func (proof *Proof) Verify(ctx context.Context, expectedRootID ids.ID) error { // Insert all proof nodes. // [provenPath] is the path that we are proving exists, or the path // that is where the path we are proving doesn't exist should be. 
- provenPath := maybe.Some(proof.Path[len(proof.Path)-1].KeyPath) + provenPath := maybe.Some(proof.Path[len(proof.Path)-1].Key) if err = addPathInfo(view, proof.Path, provenPath, provenPath); err != nil { return err @@ -225,7 +225,7 @@ func (proof *Proof) UnmarshalProto(pbProof *pb.Proof, bf BranchFactor) error { return ErrInvalidMaybe } - proof.Key = NewPath(pbProof.Key, bf) + proof.Key = ToKey(pbProof.Key, bf) if !pbProof.Value.IsNothing { proof.Value = maybe.Some(pbProof.Value.Value) @@ -304,10 +304,10 @@ func (proof *RangeProof) Verify( // determine branch factor based on proof paths var branchFactor BranchFactor if len(proof.StartProof) > 0 { - branchFactor = proof.StartProof[0].KeyPath.branchFactor + branchFactor = proof.StartProof[0].Key.branchFactor } else { // safe because invariants prevent both start proof and end proof from being empty at the same time - branchFactor = proof.EndProof[0].KeyPath.branchFactor + branchFactor = proof.EndProof[0].Key.branchFactor } // Make sure the key-value pairs are sorted and in [start, end]. @@ -322,24 +322,24 @@ func (proof *RangeProof) Verify( // If [largestProvenPath] is Nothing, [proof] should // provide and prove all keys > [smallestProvenPath]. // If both are Nothing, [proof] should prove the entire trie. - smallestProvenPath := maybe.Bind(start, func(b []byte) Path { - return NewPath(b, branchFactor) + smallestProvenPath := maybe.Bind(start, func(b []byte) Key { + return ToKey(b, branchFactor) }) - largestProvenPath := maybe.Bind(end, func(b []byte) Path { - return NewPath(b, branchFactor) + largestProvenPath := maybe.Bind(end, func(b []byte) Key { + return ToKey(b, branchFactor) }) if len(proof.KeyValues) > 0 { // If [proof] has key-value pairs, we should insert children // greater than [largestProvenPath] to ancestors of the node containing // [largestProvenPath] so that we get the expected root ID. - largestProvenPath = maybe.Some(NewPath(proof.KeyValues[len(proof.KeyValues)-1].Key, branchFactor)) + largestProvenPath = maybe.Some(ToKey(proof.KeyValues[len(proof.KeyValues)-1].Key, branchFactor)) } // The key-value pairs (allegedly) proven by [proof]. - keyValues := make(map[Path][]byte, len(proof.KeyValues)) + keyValues := make(map[Key][]byte, len(proof.KeyValues)) for _, keyValue := range proof.KeyValues { - keyValues[NewPath(keyValue.Key, branchFactor)] = keyValue.Value + keyValues[ToKey(keyValue.Key, branchFactor)] = keyValue.Value } // Ensure that the start proof is valid and contains values that @@ -476,11 +476,11 @@ func (proof *RangeProof) UnmarshalProto(pbProof *pb.RangeProof, bf BranchFactor) // Verify that all non-intermediate nodes in [proof] which have keys // in [[start], [end]] have the value given for that key in [keysValues]. -func verifyAllRangeProofKeyValuesPresent(proof []ProofNode, start maybe.Maybe[Path], end maybe.Maybe[Path], keysValues map[Path][]byte) error { +func verifyAllRangeProofKeyValuesPresent(proof []ProofNode, start maybe.Maybe[Key], end maybe.Maybe[Key], keysValues map[Key][]byte) error { for i := 0; i < len(proof); i++ { var ( node = proof[i] - nodePath = node.KeyPath + nodePath = node.Key ) // Skip keys that cannot have a value (enforced by [verifyProofPath]). 
@@ -645,14 +645,14 @@ func verifyAllChangeProofKeyValuesPresent( ctx context.Context, db MerkleDB, proof []ProofNode, - start maybe.Maybe[Path], - end maybe.Maybe[Path], - keysValues map[Path]maybe.Maybe[[]byte], + start maybe.Maybe[Key], + end maybe.Maybe[Key], + keysValues map[Key]maybe.Maybe[[]byte], ) error { for i := 0; i < len(proof); i++ { var ( node = proof[i] - nodePath = node.KeyPath + nodePath = node.Key ) // Check the value of any node with a key that is within the range. @@ -745,17 +745,17 @@ func verifyKeyValues(kvs []KeyValue, start maybe.Maybe[[]byte], end maybe.Maybe[ // since all keys with values are written in complete bytes([]byte). // - Each key in [proof] is a strict prefix of the following key. // - Each key in [proof] is a strict prefix of [keyBytes], except possibly the last. -// - If the last element in [proof] is [keyPath], this is an inclusion proof. +// - If the last element in [proof] is [Key], this is an inclusion proof. // Otherwise, this is an exclusion proof and [keyBytes] must not be in [proof]. -func verifyProofPath(proof []ProofNode, keyPath maybe.Maybe[Path]) error { +func verifyProofPath(proof []ProofNode, key maybe.Maybe[Key]) error { if len(proof) == 0 { return nil } // loop over all but the last node since it will not have the prefix in exclusion proofs for i := 0; i < len(proof)-1; i++ { - nodeKey := proof[i].KeyPath - if keyPath.HasValue() && nodeKey.branchFactor != keyPath.Value().branchFactor { + nodeKey := proof[i].Key + if key.HasValue() && nodeKey.branchFactor != key.Value().branchFactor { return ErrInconsistentBranchFactor } @@ -766,12 +766,12 @@ func verifyProofPath(proof []ProofNode, keyPath maybe.Maybe[Path]) error { } // each node should have a key that has the proven key as a prefix - if keyPath.HasValue() && !keyPath.Value().HasStrictPrefix(nodeKey) { + if key.HasValue() && !key.Value().HasStrictPrefix(nodeKey) { return ErrProofNodeNotForKey } // each node should have a key that has a matching BranchFactor and is a prefix of the next node's key - nextKey := proof[i+1].KeyPath + nextKey := proof[i+1].Key if nextKey.branchFactor != nodeKey.branchFactor { return ErrInconsistentBranchFactor } @@ -783,7 +783,7 @@ func verifyProofPath(proof []ProofNode, keyPath maybe.Maybe[Path]) error { // check the last node for a value since the above loop doesn't check the last node if len(proof) > 0 { lastNode := proof[len(proof)-1] - if lastNode.KeyPath.hasPartialByte() && !lastNode.ValueOrHash.IsNothing() { + if lastNode.Key.hasPartialByte() && !lastNode.ValueOrHash.IsNothing() { return ErrPartialByteLengthWithValue } } @@ -823,8 +823,8 @@ func valueOrHashMatches(value maybe.Maybe[[]byte], valueOrHash maybe.Maybe[[]byt func addPathInfo( t *trieView, proofPath []ProofNode, - insertChildrenLessThan maybe.Maybe[Path], - insertChildrenGreaterThan maybe.Maybe[Path], + insertChildrenLessThan maybe.Maybe[Key], + insertChildrenGreaterThan maybe.Maybe[Key], ) error { var ( shouldInsertLeftChildren = insertChildrenLessThan.HasValue() @@ -833,15 +833,15 @@ func addPathInfo( for i := len(proofPath) - 1; i >= 0; i-- { proofNode := proofPath[i] - keyPath := proofNode.KeyPath + key := proofNode.Key - if keyPath.hasPartialByte() && !proofNode.ValueOrHash.IsNothing() { + if key.hasPartialByte() && !proofNode.ValueOrHash.IsNothing() { return ErrPartialByteLengthWithValue } // load the node associated with the key or create a new one // pass nothing because we are going to overwrite the value digest below - n, err := t.insert(keyPath, maybe.Nothing[[]byte]()) + n, err 
:= t.insert(key, maybe.Nothing[[]byte]()) if err != nil { return err } @@ -857,12 +857,12 @@ func addPathInfo( // Add [proofNode]'s children which are outside the range // [insertChildrenLessThan, insertChildrenGreaterThan]. - compressedPath := emptyPath(keyPath.branchFactor) + compressedPath := emptyKey(key.branchFactor) for index, childID := range proofNode.Children { if existingChild, ok := n.children[index]; ok { - compressedPath = existingChild.compressedPath + compressedPath = existingChild.compressedKey } - childPath := keyPath.AppendExtend(index, compressedPath) + childPath := key.AppendExtend(index, compressedPath) if (shouldInsertLeftChildren && childPath.Less(insertChildrenLessThan.Value())) || (shouldInsertRightChildren && childPath.Greater(insertChildrenGreaterThan.Value())) { // We didn't set the other values on the child entry, but it doesn't matter. @@ -870,8 +870,8 @@ func addPathInfo( n.setChildEntry( index, child{ - id: childID, - compressedPath: compressedPath, + id: childID, + compressedKey: compressedPath, }) } } diff --git a/x/merkledb/proof_test.go b/x/merkledb/proof_test.go index b288fe74bb4f..bf9d9da18996 100644 --- a/x/merkledb/proof_test.go +++ b/x/merkledb/proof_test.go @@ -60,7 +60,7 @@ func Test_Proof_Verify_Bad_Data(t *testing.T) { expectedErr: nil, }, { - name: "odd length key path with value", + name: "odd length key with value", malform: func(proof *Proof) { proof.Path[1].ValueOrHash = maybe.Some([]byte{1, 2}) }, @@ -185,7 +185,7 @@ func Test_RangeProof_Verify_Bad_Data(t *testing.T) { expectedErr: ErrProofValueDoesntMatch, }, { - name: "EndProof: odd length key path with value", + name: "EndProof: odd length key with value", malform: func(proof *RangeProof) { proof.EndProof[1].ValueOrHash = maybe.Some([]byte{1, 2}) }, @@ -271,10 +271,10 @@ func Test_Proof(t *testing.T) { require.Len(proof.Path, 3) - require.Equal(NewPath([]byte("key1"), BranchFactor16), proof.Path[2].KeyPath) + require.Equal(ToKey([]byte("key1"), BranchFactor16), proof.Path[2].Key) require.Equal(maybe.Some([]byte("value1")), proof.Path[2].ValueOrHash) - require.Equal(NewPath([]byte{}, BranchFactor16), proof.Path[0].KeyPath) + require.Equal(ToKey([]byte{}, BranchFactor16), proof.Path[0].Key) require.True(proof.Path[0].ValueOrHash.IsNothing()) expectedRootID, err := trie.GetMerkleRoot(context.Background()) @@ -357,7 +357,7 @@ func Test_RangeProof_Syntactic_Verify(t *testing.T) { {Key: []byte{1}, Value: []byte{1}}, {Key: []byte{0}, Value: []byte{0}}, }, - EndProof: []ProofNode{{KeyPath: emptyPath(BranchFactor16)}}, + EndProof: []ProofNode{{Key: emptyKey(BranchFactor16)}}, }, expectedErr: ErrNonIncreasingValues, }, @@ -369,7 +369,7 @@ func Test_RangeProof_Syntactic_Verify(t *testing.T) { KeyValues: []KeyValue{ {Key: []byte{0}, Value: []byte{0}}, }, - EndProof: []ProofNode{{KeyPath: emptyPath(BranchFactor16)}}, + EndProof: []ProofNode{{Key: emptyKey(BranchFactor16)}}, }, expectedErr: ErrStateFromOutsideOfRange, }, @@ -381,7 +381,7 @@ func Test_RangeProof_Syntactic_Verify(t *testing.T) { KeyValues: []KeyValue{ {Key: []byte{2}, Value: []byte{0}}, }, - EndProof: []ProofNode{{KeyPath: emptyPath(BranchFactor16)}}, + EndProof: []ProofNode{{Key: emptyKey(BranchFactor16)}}, }, expectedErr: ErrStateFromOutsideOfRange, }, @@ -395,13 +395,13 @@ func Test_RangeProof_Syntactic_Verify(t *testing.T) { }, StartProof: []ProofNode{ { - KeyPath: NewPath([]byte{2}, BranchFactor16), + Key: ToKey([]byte{2}, BranchFactor16), }, { - KeyPath: NewPath([]byte{1}, BranchFactor16), + Key: ToKey([]byte{1}, BranchFactor16), 
}, }, - EndProof: []ProofNode{{KeyPath: emptyPath(BranchFactor16)}}, + EndProof: []ProofNode{{Key: emptyKey(BranchFactor16)}}, }, expectedErr: ErrProofNodeNotForKey, }, @@ -415,16 +415,16 @@ func Test_RangeProof_Syntactic_Verify(t *testing.T) { }, StartProof: []ProofNode{ { - KeyPath: NewPath([]byte{1}, BranchFactor16), + Key: ToKey([]byte{1}, BranchFactor16), }, { - KeyPath: NewPath([]byte{1, 2, 3}, BranchFactor16), // Not a prefix of [1, 2] + Key: ToKey([]byte{1, 2, 3}, BranchFactor16), // Not a prefix of [1, 2] }, { - KeyPath: NewPath([]byte{1, 2, 3, 4}, BranchFactor16), + Key: ToKey([]byte{1, 2, 3, 4}, BranchFactor16), }, }, - EndProof: []ProofNode{{KeyPath: emptyPath(BranchFactor16)}}, + EndProof: []ProofNode{{Key: emptyKey(BranchFactor16)}}, }, expectedErr: ErrProofNodeNotForKey, }, @@ -438,10 +438,10 @@ func Test_RangeProof_Syntactic_Verify(t *testing.T) { }, EndProof: []ProofNode{ { - KeyPath: NewPath([]byte{2}, BranchFactor16), + Key: ToKey([]byte{2}, BranchFactor16), }, { - KeyPath: NewPath([]byte{1}, BranchFactor16), + Key: ToKey([]byte{1}, BranchFactor16), }, }, }, @@ -454,18 +454,18 @@ func Test_RangeProof_Syntactic_Verify(t *testing.T) { proof: &RangeProof{ StartProof: []ProofNode{ { - KeyPath: NewPath([]byte{1}, BranchFactor16), + Key: ToKey([]byte{1}, BranchFactor16), }, { - KeyPath: NewPath([]byte{1, 2}, BranchFactor16), + Key: ToKey([]byte{1, 2}, BranchFactor16), }, }, EndProof: []ProofNode{ { - KeyPath: NewPath([]byte{1}, BranchFactor4), + Key: ToKey([]byte{1}, BranchFactor4), }, { - KeyPath: NewPath([]byte{1, 2}, BranchFactor4), + Key: ToKey([]byte{1, 2}, BranchFactor4), }, }, }, @@ -481,13 +481,13 @@ func Test_RangeProof_Syntactic_Verify(t *testing.T) { }, EndProof: []ProofNode{ { - KeyPath: NewPath([]byte{1}, BranchFactor16), + Key: ToKey([]byte{1}, BranchFactor16), }, { - KeyPath: NewPath([]byte{1, 2, 3}, BranchFactor16), // Not a prefix of [1, 2] + Key: ToKey([]byte{1, 2, 3}, BranchFactor16), // Not a prefix of [1, 2] }, { - KeyPath: NewPath([]byte{1, 2, 3, 4}, BranchFactor16), + Key: ToKey([]byte{1, 2, 3, 4}, BranchFactor16), }, }, }, @@ -523,12 +523,12 @@ func Test_RangeProof(t *testing.T) { require.Equal([]byte{2}, proof.KeyValues[1].Value) require.Equal([]byte{3}, proof.KeyValues[2].Value) - require.Nil(proof.EndProof[0].KeyPath.Bytes()) - require.Equal([]byte{0}, proof.EndProof[1].KeyPath.Bytes()) - require.Equal([]byte{3}, proof.EndProof[2].KeyPath.Bytes()) + require.Nil(proof.EndProof[0].Key.Bytes()) + require.Equal([]byte{0}, proof.EndProof[1].Key.Bytes()) + require.Equal([]byte{3}, proof.EndProof[2].Key.Bytes()) // only a single node here since others are duplicates in endproof - require.Equal([]byte{1}, proof.StartProof[0].KeyPath.Bytes()) + require.Equal([]byte{1}, proof.StartProof[0].Key.Bytes()) require.NoError(proof.Verify( context.Background(), @@ -578,9 +578,9 @@ func Test_RangeProof_NilStart(t *testing.T) { require.Equal([]byte("value1"), proof.KeyValues[0].Value) require.Equal([]byte("value2"), proof.KeyValues[1].Value) - require.Equal(NewPath([]byte("key2"), BranchFactor16), proof.EndProof[2].KeyPath, BranchFactor16) - require.Equal(NewPath([]byte("key2"), BranchFactor16).Take(7), proof.EndProof[1].KeyPath) - require.Equal(NewPath([]byte(""), BranchFactor16), proof.EndProof[0].KeyPath, BranchFactor16) + require.Equal(ToKey([]byte("key2"), BranchFactor16), proof.EndProof[2].Key, BranchFactor16) + require.Equal(ToKey([]byte("key2"), BranchFactor16).Take(7), proof.EndProof[1].Key) + require.Equal(ToKey([]byte(""), BranchFactor16), 
proof.EndProof[0].Key, BranchFactor16) require.NoError(proof.Verify( context.Background(), @@ -610,11 +610,11 @@ func Test_RangeProof_NilEnd(t *testing.T) { require.Equal([]byte{1}, proof.KeyValues[0].Value) require.Equal([]byte{2}, proof.KeyValues[1].Value) - require.Equal([]byte{1}, proof.StartProof[0].KeyPath.Bytes()) + require.Equal([]byte{1}, proof.StartProof[0].Key.Bytes()) - require.Nil(proof.EndProof[0].KeyPath.Bytes()) - require.Equal([]byte{0}, proof.EndProof[1].KeyPath.Bytes()) - require.Equal([]byte{2}, proof.EndProof[2].KeyPath.Bytes()) + require.Nil(proof.EndProof[0].Key.Bytes()) + require.Equal([]byte{0}, proof.EndProof[1].Key.Bytes()) + require.Equal([]byte{2}, proof.EndProof[2].Key.Bytes()) require.NoError(proof.Verify( context.Background(), @@ -652,11 +652,11 @@ func Test_RangeProof_EmptyValues(t *testing.T) { require.Empty(proof.KeyValues[2].Value) require.Len(proof.StartProof, 1) - require.Equal(NewPath([]byte("key1"), BranchFactor16), proof.StartProof[0].KeyPath, BranchFactor16) + require.Equal(ToKey([]byte("key1"), BranchFactor16), proof.StartProof[0].Key, BranchFactor16) require.Len(proof.EndProof, 3) - require.Equal(NewPath([]byte("key2"), BranchFactor16), proof.EndProof[2].KeyPath, BranchFactor16) - require.Equal(NewPath([]byte{}, BranchFactor16), proof.EndProof[0].KeyPath, BranchFactor16) + require.Equal(ToKey([]byte("key2"), BranchFactor16), proof.EndProof[2].Key, BranchFactor16) + require.Equal(ToKey([]byte{}, BranchFactor16), proof.EndProof[0].Key, BranchFactor16) require.NoError(proof.Verify( context.Background(), @@ -942,8 +942,8 @@ func Test_ChangeProof_Syntactic_Verify(t *testing.T) { name: "start proof node has wrong prefix", proof: &ChangeProof{ StartProof: []ProofNode{ - {KeyPath: NewPath([]byte{2}, BranchFactor16)}, - {KeyPath: NewPath([]byte{2, 3}, BranchFactor16)}, + {Key: ToKey([]byte{2}, BranchFactor16)}, + {Key: ToKey([]byte{2, 3}, BranchFactor16)}, }, }, start: maybe.Some([]byte{1, 2, 3}), @@ -954,8 +954,8 @@ func Test_ChangeProof_Syntactic_Verify(t *testing.T) { name: "start proof non-increasing", proof: &ChangeProof{ StartProof: []ProofNode{ - {KeyPath: NewPath([]byte{1}, BranchFactor16)}, - {KeyPath: NewPath([]byte{2, 3}, BranchFactor16)}, + {Key: ToKey([]byte{1}, BranchFactor16)}, + {Key: ToKey([]byte{2, 3}, BranchFactor16)}, }, }, start: maybe.Some([]byte{1, 2, 3}), @@ -969,8 +969,8 @@ func Test_ChangeProof_Syntactic_Verify(t *testing.T) { {Key: []byte{1, 2}, Value: maybe.Some([]byte{0})}, }, EndProof: []ProofNode{ - {KeyPath: NewPath([]byte{2}, BranchFactor16)}, - {KeyPath: NewPath([]byte{2, 3}, BranchFactor16)}, + {Key: ToKey([]byte{2}, BranchFactor16)}, + {Key: ToKey([]byte{2, 3}, BranchFactor16)}, }, }, start: maybe.Nothing[[]byte](), @@ -984,8 +984,8 @@ func Test_ChangeProof_Syntactic_Verify(t *testing.T) { {Key: []byte{1, 2, 3}}, }, EndProof: []ProofNode{ - {KeyPath: NewPath([]byte{1}, BranchFactor16)}, - {KeyPath: NewPath([]byte{2, 3}, BranchFactor16)}, + {Key: ToKey([]byte{1}, BranchFactor16)}, + {Key: ToKey([]byte{2, 3}, BranchFactor16)}, }, }, start: maybe.Nothing[[]byte](), @@ -1087,7 +1087,7 @@ func TestVerifyProofPath(t *testing.T) { type test struct { name string path []ProofNode - proofKey maybe.Maybe[Path] + proofKey maybe.Maybe[Key] expectedErr error } @@ -1095,124 +1095,124 @@ func TestVerifyProofPath(t *testing.T) { { name: "empty", path: nil, - proofKey: maybe.Nothing[Path](), + proofKey: maybe.Nothing[Key](), expectedErr: nil, }, { name: "1 element", - path: []ProofNode{{KeyPath: NewPath([]byte{1}, BranchFactor16)}}, - 
proofKey: maybe.Nothing[Path](), + path: []ProofNode{{Key: ToKey([]byte{1}, BranchFactor16)}}, + proofKey: maybe.Nothing[Key](), expectedErr: nil, }, { name: "non-increasing keys", path: []ProofNode{ - {KeyPath: NewPath([]byte{1}, BranchFactor16)}, - {KeyPath: NewPath([]byte{1, 2}, BranchFactor16)}, - {KeyPath: NewPath([]byte{1, 3}, BranchFactor16)}, + {Key: ToKey([]byte{1}, BranchFactor16)}, + {Key: ToKey([]byte{1, 2}, BranchFactor16)}, + {Key: ToKey([]byte{1, 3}, BranchFactor16)}, }, - proofKey: maybe.Some(NewPath([]byte{1, 2, 3}, BranchFactor16)), + proofKey: maybe.Some(ToKey([]byte{1, 2, 3}, BranchFactor16)), expectedErr: ErrNonIncreasingProofNodes, }, { name: "invalid key", path: []ProofNode{ - {KeyPath: NewPath([]byte{1}, BranchFactor16)}, - {KeyPath: NewPath([]byte{1, 2}, BranchFactor16)}, - {KeyPath: NewPath([]byte{1, 2, 4}, BranchFactor16)}, - {KeyPath: NewPath([]byte{1, 2, 3}, BranchFactor16)}, + {Key: ToKey([]byte{1}, BranchFactor16)}, + {Key: ToKey([]byte{1, 2}, BranchFactor16)}, + {Key: ToKey([]byte{1, 2, 4}, BranchFactor16)}, + {Key: ToKey([]byte{1, 2, 3}, BranchFactor16)}, }, - proofKey: maybe.Some(NewPath([]byte{1, 2, 3}, BranchFactor16)), + proofKey: maybe.Some(ToKey([]byte{1, 2, 3}, BranchFactor16)), expectedErr: ErrProofNodeNotForKey, }, { name: "extra node inclusion proof", path: []ProofNode{ - {KeyPath: NewPath([]byte{1}, BranchFactor16)}, - {KeyPath: NewPath([]byte{1, 2}, BranchFactor16)}, - {KeyPath: NewPath([]byte{1, 2, 3}, BranchFactor16)}, + {Key: ToKey([]byte{1}, BranchFactor16)}, + {Key: ToKey([]byte{1, 2}, BranchFactor16)}, + {Key: ToKey([]byte{1, 2, 3}, BranchFactor16)}, }, - proofKey: maybe.Some(NewPath([]byte{1, 2}, BranchFactor16)), + proofKey: maybe.Some(ToKey([]byte{1, 2}, BranchFactor16)), expectedErr: ErrProofNodeNotForKey, }, { name: "extra node exclusion proof", path: []ProofNode{ - {KeyPath: NewPath([]byte{1}, BranchFactor16)}, - {KeyPath: NewPath([]byte{1, 3}, BranchFactor16)}, - {KeyPath: NewPath([]byte{1, 3, 4}, BranchFactor16)}, + {Key: ToKey([]byte{1}, BranchFactor16)}, + {Key: ToKey([]byte{1, 3}, BranchFactor16)}, + {Key: ToKey([]byte{1, 3, 4}, BranchFactor16)}, }, - proofKey: maybe.Some(NewPath([]byte{1, 2}, BranchFactor16)), + proofKey: maybe.Some(ToKey([]byte{1, 2}, BranchFactor16)), expectedErr: ErrProofNodeNotForKey, }, { name: "happy path exclusion proof", path: []ProofNode{ - {KeyPath: NewPath([]byte{1}, BranchFactor16)}, - {KeyPath: NewPath([]byte{1, 2}, BranchFactor16)}, - {KeyPath: NewPath([]byte{1, 2, 4}, BranchFactor16)}, + {Key: ToKey([]byte{1}, BranchFactor16)}, + {Key: ToKey([]byte{1, 2}, BranchFactor16)}, + {Key: ToKey([]byte{1, 2, 4}, BranchFactor16)}, }, - proofKey: maybe.Some(NewPath([]byte{1, 2, 3}, BranchFactor16)), + proofKey: maybe.Some(ToKey([]byte{1, 2, 3}, BranchFactor16)), expectedErr: nil, }, { name: "happy path inclusion proof", path: []ProofNode{ - {KeyPath: NewPath([]byte{1}, BranchFactor16)}, - {KeyPath: NewPath([]byte{1, 2}, BranchFactor16)}, - {KeyPath: NewPath([]byte{1, 2, 3}, BranchFactor16)}, + {Key: ToKey([]byte{1}, BranchFactor16)}, + {Key: ToKey([]byte{1, 2}, BranchFactor16)}, + {Key: ToKey([]byte{1, 2, 3}, BranchFactor16)}, }, - proofKey: maybe.Some(NewPath([]byte{1, 2, 3}, BranchFactor16)), + proofKey: maybe.Some(ToKey([]byte{1, 2, 3}, BranchFactor16)), expectedErr: nil, }, { name: "repeat nodes", path: []ProofNode{ - {KeyPath: NewPath([]byte{1}, BranchFactor16)}, - {KeyPath: NewPath([]byte{1}, BranchFactor16)}, - {KeyPath: NewPath([]byte{1, 2}, BranchFactor16)}, - {KeyPath: NewPath([]byte{1, 2, 3}, 
BranchFactor16)}, + {Key: ToKey([]byte{1}, BranchFactor16)}, + {Key: ToKey([]byte{1}, BranchFactor16)}, + {Key: ToKey([]byte{1, 2}, BranchFactor16)}, + {Key: ToKey([]byte{1, 2, 3}, BranchFactor16)}, }, - proofKey: maybe.Some(NewPath([]byte{1, 2, 3}, BranchFactor16)), + proofKey: maybe.Some(ToKey([]byte{1, 2, 3}, BranchFactor16)), expectedErr: ErrNonIncreasingProofNodes, }, { name: "repeat nodes 2", path: []ProofNode{ - {KeyPath: NewPath([]byte{1}, BranchFactor16)}, - {KeyPath: NewPath([]byte{1, 2}, BranchFactor16)}, - {KeyPath: NewPath([]byte{1, 2}, BranchFactor16)}, - {KeyPath: NewPath([]byte{1, 2, 3}, BranchFactor16)}, + {Key: ToKey([]byte{1}, BranchFactor16)}, + {Key: ToKey([]byte{1, 2}, BranchFactor16)}, + {Key: ToKey([]byte{1, 2}, BranchFactor16)}, + {Key: ToKey([]byte{1, 2, 3}, BranchFactor16)}, }, - proofKey: maybe.Some(NewPath([]byte{1, 2, 3}, BranchFactor16)), + proofKey: maybe.Some(ToKey([]byte{1, 2, 3}, BranchFactor16)), expectedErr: ErrNonIncreasingProofNodes, }, { name: "repeat nodes 3", path: []ProofNode{ - {KeyPath: NewPath([]byte{1}, BranchFactor16)}, - {KeyPath: NewPath([]byte{1, 2}, BranchFactor16)}, - {KeyPath: NewPath([]byte{1, 2, 3}, BranchFactor16)}, - {KeyPath: NewPath([]byte{1, 2, 3}, BranchFactor16)}, + {Key: ToKey([]byte{1}, BranchFactor16)}, + {Key: ToKey([]byte{1, 2}, BranchFactor16)}, + {Key: ToKey([]byte{1, 2, 3}, BranchFactor16)}, + {Key: ToKey([]byte{1, 2, 3}, BranchFactor16)}, }, - proofKey: maybe.Some(NewPath([]byte{1, 2, 3}, BranchFactor16)), + proofKey: maybe.Some(ToKey([]byte{1, 2, 3}, BranchFactor16)), expectedErr: ErrProofNodeNotForKey, }, { name: "oddLength key with value", path: []ProofNode{ - {KeyPath: NewPath([]byte{1}, BranchFactor16)}, - {KeyPath: NewPath([]byte{1, 2}, BranchFactor16)}, + {Key: ToKey([]byte{1}, BranchFactor16)}, + {Key: ToKey([]byte{1, 2}, BranchFactor16)}, { - KeyPath: Path{ - value: string([]byte{1, 2, 240}), - tokensLength: 5, - pathConfig: branchFactorToPathConfig[BranchFactor16], + Key: Key{ + value: string([]byte{1, 2, 240}), + tokenLength: 5, + tokenConfig: branchFactorToTokenConfig[BranchFactor16], }, ValueOrHash: maybe.Some([]byte{1}), }, }, - proofKey: maybe.Some(NewPath([]byte{1, 2, 3}, BranchFactor16)), + proofKey: maybe.Some(ToKey([]byte{1, 2, 3}, BranchFactor16)), expectedErr: ErrPartialByteLengthWithValue, }, } @@ -1314,7 +1314,7 @@ func TestProofNodeUnmarshalProtoMissingFields(t *testing.T) { protoNode.Key = nil return protoNode }, - expectedErr: ErrNilPath, + expectedErr: ErrNilKey, }, } @@ -1575,7 +1575,7 @@ func FuzzProofProtoMarshalUnmarshal(f *testing.F) { } proof := Proof{ - Key: NewPath(key, BranchFactor16), + Key: ToKey(key, BranchFactor16), Value: value, Path: proofPath, } @@ -1698,13 +1698,13 @@ func FuzzRangeProofInvariants(f *testing.F) { // Make sure the start proof doesn't contain any nodes // that are in the end proof. 
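The TestVerifyProofPath cases above (non-increasing keys, repeated nodes, extra nodes) all exercise one invariant: every proof node's key must be a strict prefix of the next node's key. A self-contained sketch of that check for byte-aligned keys is shown below; the real code compares token by token and also enforces a consistent branch factor, so this is only the core idea.

```go
package main

import (
	"bytes"
	"errors"
	"fmt"
)

var errNonIncreasingProofNodes = errors.New("each proof node key must be a strict prefix of the next")

// checkStrictPrefixes rejects repeated or out-of-order keys: every key must be
// a strict (proper) prefix of the key that follows it.
func checkStrictPrefixes(keys [][]byte) error {
	for i := 0; i+1 < len(keys); i++ {
		cur, next := keys[i], keys[i+1]
		isStrictPrefix := len(cur) < len(next) && bytes.HasPrefix(next, cur)
		if !isStrictPrefix {
			return errNonIncreasingProofNodes
		}
	}
	return nil
}

func main() {
	fmt.Println(checkStrictPrefixes([][]byte{{1}, {1, 2}, {1, 2, 3}})) // <nil>
	fmt.Println(checkStrictPrefixes([][]byte{{1}, {1}, {1, 2}}))       // repeated node -> error
}
```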
- endProofKeys := set.Set[Path]{} + endProofKeys := set.Set[Key]{} for _, node := range rangeProof.EndProof { - endProofKeys.Add(node.KeyPath) + endProofKeys.Add(node.Key) } for _, node := range rangeProof.StartProof { - require.NotContains(endProofKeys, node.KeyPath) + require.NotContains(endProofKeys, node.Key) } // Make sure the EndProof invariant is maintained @@ -1713,7 +1713,7 @@ func FuzzRangeProofInvariants(f *testing.F) { if len(rangeProof.KeyValues) == 0 { if len(rangeProof.StartProof) == 0 { require.Len(rangeProof.EndProof, 1) // Just the root - require.Empty(rangeProof.EndProof[0].KeyPath.Bytes()) + require.Empty(rangeProof.EndProof[0].Key.Bytes()) } else { require.Empty(rangeProof.EndProof) } @@ -1732,7 +1732,7 @@ func FuzzRangeProofInvariants(f *testing.F) { proof := Proof{ Path: rangeProof.EndProof, - Key: NewPath(endBytes, BranchFactor16), + Key: ToKey(endBytes, BranchFactor16), Value: value, } @@ -1747,7 +1747,7 @@ func FuzzRangeProofInvariants(f *testing.F) { // EndProof should be a proof for largest key-value. proof := Proof{ Path: rangeProof.EndProof, - Key: NewPath(greatestKV.Key, BranchFactor16), + Key: ToKey(greatestKV.Key, BranchFactor16), Value: maybe.Some(greatestKV.Value), } diff --git a/x/merkledb/trie.go b/x/merkledb/trie.go index 998cd34f9ed2..d4b01d2de29a 100644 --- a/x/merkledb/trie.go +++ b/x/merkledb/trie.go @@ -19,7 +19,7 @@ type MerkleRootGetter interface { type ProofGetter interface { // GetProof generates a proof of the value associated with a particular key, // or a proof of its absence from the trie - GetProof(ctx context.Context, bytesPath []byte) (*Proof, error) + GetProof(ctx context.Context, keyBytes []byte) (*Proof, error) } type ReadOnlyTrie interface { @@ -36,11 +36,11 @@ type ReadOnlyTrie interface { // get the value associated with the key in path form // database.ErrNotFound if the key is not present - getValue(key Path) ([]byte, error) + getValue(key Key) ([]byte, error) // get an editable copy of the node with the given key path // hasValue indicates which db to look in (value or intermediate) - getEditableNode(key Path, hasValue bool) (*node, error) + getEditableNode(key Key, hasValue bool) (*node, error) // GetRangeProof returns a proof of up to [maxLength] key-value pairs with // keys in range [start, end]. 
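getEditableNode's hasValue parameter, carried through the rename above, tells the trie which backing store to consult: nodes that carry values live in the value-node database, purely intermediate nodes in a separate one. The sketch below shows that lookup pattern with two plain maps as hypothetical stand-ins for those databases; it is not the real storage layer.

```go
package main

import (
	"errors"
	"fmt"
)

var errNotFound = errors.New("not found")

// twoTierStore sketches the split hinted at by the hasValue flag: one store for
// nodes with values, another for intermediate nodes.
type twoTierStore struct {
	valueNodes        map[string][]byte
	intermediateNodes map[string][]byte
}

// get looks the key up in whichever store the caller says the node lives in.
func (s *twoTierStore) get(key string, hasValue bool) ([]byte, error) {
	src := s.intermediateNodes
	if hasValue {
		src = s.valueNodes
	}
	if v, ok := src[key]; ok {
		return v, nil
	}
	return nil, errNotFound
}

func main() {
	s := &twoTierStore{
		valueNodes:        map[string][]byte{"key1": []byte("value1")},
		intermediateNodes: map[string][]byte{"ke": nil},
	}
	v, err := s.get("key1", true)
	fmt.Println(string(v), err) // value1 <nil>
}
```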
diff --git a/x/merkledb/trie_test.go b/x/merkledb/trie_test.go index d13680b0b064..bd666b44ea9e 100644 --- a/x/merkledb/trie_test.go +++ b/x/merkledb/trie_test.go @@ -27,7 +27,7 @@ func getNodeValueWithBranchFactor(t ReadOnlyTrie, key string, bf BranchFactor) ( if err := asTrieView.calculateNodeIDs(context.Background()); err != nil { return nil, err } - path := NewPath([]byte(key), bf) + path := ToKey([]byte(key), bf) nodePath, err := asTrieView.getPathTo(path) if err != nil { return nil, err @@ -44,7 +44,7 @@ func getNodeValueWithBranchFactor(t ReadOnlyTrie, key string, bf BranchFactor) ( if err != nil { return nil, err } - path := NewPath([]byte(key), bf) + path := ToKey([]byte(key), bf) nodePath, err := view.(*trieView).getPathTo(path) if err != nil { return nil, err @@ -127,7 +127,7 @@ func TestTrieViewGetPathTo(t *testing.T) { require.IsType(&trieView{}, trieIntf) trie := trieIntf.(*trieView) - nodePath, err := trie.getPathTo(NewPath(nil, BranchFactor16)) + nodePath, err := trie.getPathTo(ToKey(nil, BranchFactor16)) require.NoError(err) // Just the root @@ -149,13 +149,13 @@ func TestTrieViewGetPathTo(t *testing.T) { trie = trieIntf.(*trieView) require.NoError(trie.calculateNodeIDs(context.Background())) - nodePath, err = trie.getPathTo(NewPath(key1, BranchFactor16)) + nodePath, err = trie.getPathTo(ToKey(key1, BranchFactor16)) require.NoError(err) // Root and 1 value require.Len(nodePath, 2) require.Equal(trie.root, nodePath[0]) - require.Equal(NewPath(key1, BranchFactor16), nodePath[1].key) + require.Equal(ToKey(key1, BranchFactor16), nodePath[1].key) // Insert another key which is a child of the first key2 := []byte{0, 1} @@ -172,12 +172,12 @@ func TestTrieViewGetPathTo(t *testing.T) { trie = trieIntf.(*trieView) require.NoError(trie.calculateNodeIDs(context.Background())) - nodePath, err = trie.getPathTo(NewPath(key2, BranchFactor16)) + nodePath, err = trie.getPathTo(ToKey(key2, BranchFactor16)) require.NoError(err) require.Len(nodePath, 3) require.Equal(trie.root, nodePath[0]) - require.Equal(NewPath(key1, BranchFactor16), nodePath[1].key) - require.Equal(NewPath(key2, BranchFactor16), nodePath[2].key) + require.Equal(ToKey(key1, BranchFactor16), nodePath[1].key) + require.Equal(ToKey(key2, BranchFactor16), nodePath[2].key) // Insert a key which shares no prefix with the others key3 := []byte{255} @@ -194,32 +194,32 @@ func TestTrieViewGetPathTo(t *testing.T) { trie = trieIntf.(*trieView) require.NoError(trie.calculateNodeIDs(context.Background())) - nodePath, err = trie.getPathTo(NewPath(key3, BranchFactor16)) + nodePath, err = trie.getPathTo(ToKey(key3, BranchFactor16)) require.NoError(err) require.Len(nodePath, 2) require.Equal(trie.root, nodePath[0]) - require.Equal(NewPath(key3, BranchFactor16), nodePath[1].key) + require.Equal(ToKey(key3, BranchFactor16), nodePath[1].key) // Other key path not affected - nodePath, err = trie.getPathTo(NewPath(key2, BranchFactor16)) + nodePath, err = trie.getPathTo(ToKey(key2, BranchFactor16)) require.NoError(err) require.Len(nodePath, 3) require.Equal(trie.root, nodePath[0]) - require.Equal(NewPath(key1, BranchFactor16), nodePath[1].key) - require.Equal(NewPath(key2, BranchFactor16), nodePath[2].key) + require.Equal(ToKey(key1, BranchFactor16), nodePath[1].key) + require.Equal(ToKey(key2, BranchFactor16), nodePath[2].key) // Gets closest node when key doesn't exist key4 := []byte{0, 1, 2} - nodePath, err = trie.getPathTo(NewPath(key4, BranchFactor16)) + nodePath, err = trie.getPathTo(ToKey(key4, BranchFactor16)) require.NoError(err) 
require.Len(nodePath, 3) require.Equal(trie.root, nodePath[0]) - require.Equal(NewPath(key1, BranchFactor16), nodePath[1].key) - require.Equal(NewPath(key2, BranchFactor16), nodePath[2].key) + require.Equal(ToKey(key1, BranchFactor16), nodePath[1].key) + require.Equal(ToKey(key2, BranchFactor16), nodePath[2].key) // Gets just root when key doesn't exist and no key shares a prefix key5 := []byte{128} - nodePath, err = trie.getPathTo(NewPath(key5, BranchFactor16)) + nodePath, err = trie.getPathTo(ToKey(key5, BranchFactor16)) require.NoError(err) require.Len(nodePath, 1) require.Equal(trie.root, nodePath[0]) @@ -304,7 +304,7 @@ func Test_Trie_WriteToDB(t *testing.T) { rawBytes, err := dbTrie.baseDB.Get(prefixedKey) require.NoError(err) - node, err := parseNode(NewPath(key, BranchFactor16), rawBytes) + node, err := parseNode(ToKey(key, BranchFactor16), rawBytes) require.NoError(err) require.Equal([]byte("value"), node.value.Value()) } @@ -603,7 +603,7 @@ func Test_Trie_HashCountOnBranch(t *testing.T) { // Make sure the branch node with the common prefix was created. // Note it's only created on call to GetMerkleRoot, not in NewView. - _, err = view2.getEditableNode(NewPath(keyPrefix, BranchFactor16), false) + _, err = view2.getEditableNode(ToKey(keyPrefix, BranchFactor16), false) require.NoError(err) // only hashes the new branch node, the new child node, and root @@ -744,7 +744,7 @@ func Test_Trie_ChainDeletion(t *testing.T) { require.NoError(err) require.NoError(newTrie.(*trieView).calculateNodeIDs(context.Background())) - root, err := newTrie.getEditableNode(emptyPath(BranchFactor16), false) + root, err := newTrie.getEditableNode(emptyKey(BranchFactor16), false) require.NoError(err) require.Len(root.children, 1) @@ -761,7 +761,7 @@ func Test_Trie_ChainDeletion(t *testing.T) { ) require.NoError(err) require.NoError(newTrie.(*trieView).calculateNodeIDs(context.Background())) - root, err = newTrie.getEditableNode(emptyPath(BranchFactor16), false) + root, err = newTrie.getEditableNode(emptyKey(BranchFactor16), false) require.NoError(err) // since all values have been deleted, the nodes should have been cleaned up require.Empty(root.children) @@ -826,15 +826,15 @@ func Test_Trie_NodeCollapse(t *testing.T) { require.NoError(err) require.NoError(trie.(*trieView).calculateNodeIDs(context.Background())) - root, err := trie.getEditableNode(emptyPath(BranchFactor16), false) + root, err := trie.getEditableNode(emptyKey(BranchFactor16), false) require.NoError(err) require.Len(root.children, 1) - root, err = trie.getEditableNode(emptyPath(BranchFactor16), false) + root, err = trie.getEditableNode(emptyKey(BranchFactor16), false) require.NoError(err) require.Len(root.children, 1) - firstNode, err := trie.getEditableNode(getSingleChildPath(root), true) + firstNode, err := trie.getEditableNode(getSingleChildKey(root), true) require.NoError(err) require.Len(firstNode.children, 1) @@ -852,11 +852,11 @@ func Test_Trie_NodeCollapse(t *testing.T) { require.NoError(err) require.NoError(trie.(*trieView).calculateNodeIDs(context.Background())) - root, err = trie.getEditableNode(emptyPath(BranchFactor16), false) + root, err = trie.getEditableNode(emptyKey(BranchFactor16), false) require.NoError(err) require.Len(root.children, 1) - firstNode, err = trie.getEditableNode(getSingleChildPath(root), true) + firstNode, err = trie.getEditableNode(getSingleChildKey(root), true) require.NoError(err) require.Len(firstNode.children, 2) } @@ -1199,11 +1199,11 @@ func Test_Trie_ConcurrentNewViewAndCommit(t *testing.T) { // 
Returns the path of the only child of this node. // Assumes this node has exactly one child. -func getSingleChildPath(n *node) Path { +func getSingleChildKey(n *node) Key { for index, entry := range n.children { - return n.key.AppendExtend(index, entry.compressedPath) + return n.key.AppendExtend(index, entry.compressedKey) } - return Path{} + return Key{} } func TestTrieCommitToDB(t *testing.T) { diff --git a/x/merkledb/trieview.go b/x/merkledb/trieview.go index a92970866e14..d83c50901096 100644 --- a/x/merkledb/trieview.go +++ b/x/merkledb/trieview.go @@ -145,7 +145,7 @@ func newTrieView( parentTrie TrieView, changes ViewChanges, ) (*trieView, error) { - root, err := parentTrie.getEditableNode(db.rootPath, false /* hasValue */) + root, err := parentTrie.getEditableNode(db.rootKey, false /* hasValue */) if err != nil { if err == database.ErrNotFound { return nil, ErrNoValidRoot @@ -173,7 +173,7 @@ func newTrieView( newVal = maybe.Some(slices.Clone(op.Value)) } } - if err := newView.recordValueChange(db.newPath(key), newVal); err != nil { + if err := newView.recordValueChange(db.toKey(key), newVal); err != nil { return nil, err } } @@ -181,7 +181,7 @@ func newTrieView( if !changes.ConsumeBytes { val = maybe.Bind(val, slices.Clone[[]byte]) } - if err := newView.recordValueChange(db.newPath(stringToByteSlice(key)), val); err != nil { + if err := newView.recordValueChange(db.toKey(stringToByteSlice(key)), val); err != nil { return nil, err } } @@ -197,7 +197,7 @@ func newHistoricalTrieView( return nil, ErrNoValidRoot } - passedRootChange, ok := changes.nodes[db.rootPath] + passedRootChange, ok := changes.nodes[db.rootKey] if !ok { return nil, ErrNoValidRoot } @@ -269,7 +269,7 @@ func (t *trieView) calculateNodeIDsHelper(n *node) { ) for childIndex, child := range n.children { - childPath := n.key.AppendExtend(childIndex, child.compressedPath) + childPath := n.key.AppendExtend(childIndex, child.compressedKey) childNodeChange, ok := t.changes.nodes[childPath] if !ok { // This child wasn't changed. @@ -302,13 +302,13 @@ func (t *trieView) calculateNodeIDsHelper(n *node) { wg.Wait() close(updatedChildren) - keyLength := n.key.tokensLength + keyLength := n.key.tokenLength for updatedChild := range updatedChildren { index := updatedChild.key.Token(keyLength) n.setChildEntry(index, child{ - compressedPath: n.children[index].compressedPath, - id: updatedChild.id, - hasValue: updatedChild.hasValue(), + compressedKey: n.children[index].compressedKey, + id: updatedChild.id, + hasValue: updatedChild.hasValue(), }) } @@ -334,7 +334,7 @@ func (t *trieView) getProof(ctx context.Context, key []byte) (*Proof, error) { defer span.End() proof := &Proof{ - Key: t.db.newPath(key), + Key: t.db.toKey(key), } proofPath, err := t.getPathTo(proof.Key) @@ -359,7 +359,7 @@ func (t *trieView) getProof(ctx context.Context, key []byte) (*Proof, error) { // There is no node with the given [key]. // If there is a child at the index where the node would be // if it existed, include that child in the proof. 
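Several hunks above rebuild a child's full key with Key.AppendExtend: the parent's key, then the child's index token, then the compressed key stored on the child entry. For the byte-aligned case (one byte per token; real keys can end mid-byte) the operation reduces to a concatenation, roughly as in this standalone sketch:

```go
package main

import "fmt"

// appendExtend illustrates the byte-aligned case of Key.AppendExtend: the
// child's full key is the parent's key, followed by the child's index token,
// followed by the compressed key stored on the child entry.
func appendExtend(parentKey []byte, childIndex byte, compressedKey []byte) []byte {
	out := make([]byte, 0, len(parentKey)+1+len(compressedKey))
	out = append(out, parentKey...)
	out = append(out, childIndex)
	out = append(out, compressedKey...)
	return out
}

func main() {
	// Parent "ab" with a child at index 'c' whose compressed key is "de" -> "abcde".
	fmt.Printf("%s\n", appendExtend([]byte("ab"), 'c', []byte("de")))
}
```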
- nextIndex := proof.Key.Token(closestNode.key.tokensLength) + nextIndex := proof.Key.Token(closestNode.key.tokenLength) child, ok := closestNode.children[nextIndex] if !ok { return proof, nil @@ -367,7 +367,7 @@ func (t *trieView) getProof(ctx context.Context, key []byte) (*Proof, error) { childNode, err := t.getNodeWithID( child.id, - closestNode.key.AppendExtend(nextIndex, child.compressedPath), + closestNode.key.AppendExtend(nextIndex, child.compressedKey), child.hasValue, ) if err != nil { @@ -454,7 +454,7 @@ func (t *trieView) GetRangeProof( i := 0 for ; i < len(result.StartProof) && i < len(result.EndProof) && - result.StartProof[i].KeyPath == result.EndProof[i].KeyPath; i++ { + result.StartProof[i].Key == result.EndProof[i].Key; i++ { } result.StartProof = result.StartProof[i:] } @@ -561,7 +561,7 @@ func (t *trieView) GetValues(ctx context.Context, keys [][]byte) ([][]byte, []er valueErrors := make([]error, len(keys)) for i, key := range keys { - results[i], valueErrors[i] = t.getValueCopy(t.db.newPath(key)) + results[i], valueErrors[i] = t.getValueCopy(t.db.toKey(key)) } return results, valueErrors } @@ -572,12 +572,12 @@ func (t *trieView) GetValue(ctx context.Context, key []byte) ([]byte, error) { _, span := t.db.debugTracer.Start(ctx, "MerkleDB.trieview.GetValue") defer span.End() - return t.getValueCopy(t.db.newPath(key)) + return t.getValueCopy(t.db.toKey(key)) } // getValueCopy returns a copy of the value for the given [key]. // Returns database.ErrNotFound if it doesn't exist. -func (t *trieView) getValueCopy(key Path) ([]byte, error) { +func (t *trieView) getValueCopy(key Key) ([]byte, error) { val, err := t.getValue(key) if err != nil { return nil, err @@ -585,7 +585,7 @@ func (t *trieView) getValueCopy(key Path) ([]byte, error) { return slices.Clone(val), nil } -func (t *trieView) getValue(key Path) ([]byte, error) { +func (t *trieView) getValue(key Key) ([]byte, error) { if t.isInvalid() { return nil, ErrInvalid } @@ -614,7 +614,7 @@ func (t *trieView) getValue(key Path) ([]byte, error) { } // Must not be called after [calculateNodeIDs] has returned. -func (t *trieView) remove(key Path) error { +func (t *trieView) remove(key Key) error { if t.nodesAlreadyCalculated.Get() { return ErrNodesAlreadyCalculated } @@ -688,13 +688,13 @@ func (t *trieView) compressNodePath(parent, node *node) error { var ( childEntry child - childPath Path + childPath Key ) // There is only one child, but we don't know the index. // "Cycle" over the key/values to find the only child. // Note this iteration once because len(node.children) == 1. for index, entry := range node.children { - childPath = node.key.AppendExtend(index, entry.compressedPath) + childPath = node.key.AppendExtend(index, entry.compressedKey) childEntry = entry } @@ -752,7 +752,7 @@ func (t *trieView) deleteEmptyNodes(nodePath []*node) error { // given [key], if it's in the trie, or the node with the largest prefix of // the [key] if it isn't in the trie. // Always returns at least the root node. 
-func (t *trieView) getPathTo(key Path) ([]*node, error) { +func (t *trieView) getPathTo(key Key) ([]*node, error) { var ( // all node paths start at the root currentNode = t.root @@ -761,20 +761,20 @@ func (t *trieView) getPathTo(key Path) ([]*node, error) { ) // while the entire path hasn't been matched - for matchedPathIndex < key.tokensLength { + for matchedPathIndex < key.tokenLength { // confirm that a child exists and grab its ID before attempting to load it nextChildEntry, hasChild := currentNode.children[key.Token(matchedPathIndex)] // the current token for the child entry has now been handled, so increment the matchedPathIndex matchedPathIndex += 1 - if !hasChild || !key.iteratedHasPrefix(matchedPathIndex, nextChildEntry.compressedPath) { + if !hasChild || !key.iteratedHasPrefix(matchedPathIndex, nextChildEntry.compressedKey) { // there was no child along the path or the child that was there doesn't match the remaining path return nodes, nil } // the compressed path of the entry there matched the path, so increment the matched index - matchedPathIndex += nextChildEntry.compressedPath.tokensLength + matchedPathIndex += nextChildEntry.compressedKey.tokenLength // grab the next node along the path var err error @@ -789,9 +789,9 @@ func (t *trieView) getPathTo(key Path) ([]*node, error) { return nodes, nil } -func getLengthOfCommonPrefix(first, second Path, secondOffset int) int { +func getLengthOfCommonPrefix(first, second Key, secondOffset int) int { commonIndex := 0 - for first.tokensLength > commonIndex && second.tokensLength > (commonIndex+secondOffset) && first.Token(commonIndex) == second.Token(commonIndex+secondOffset) { + for first.tokenLength > commonIndex && second.tokenLength > (commonIndex+secondOffset) && first.Token(commonIndex) == second.Token(commonIndex+secondOffset) { commonIndex++ } return commonIndex @@ -799,7 +799,7 @@ func getLengthOfCommonPrefix(first, second Path, secondOffset int) int { // Get a copy of the node matching the passed key from the trie. // Used by views to get nodes from their ancestors. -func (t *trieView) getEditableNode(key Path, hadValue bool) (*node, error) { +func (t *trieView) getEditableNode(key Key, hadValue bool) (*node, error) { if t.isInvalid() { return nil, ErrInvalid } @@ -822,7 +822,7 @@ func (t *trieView) getEditableNode(key Path, hadValue bool) (*node, error) { // insert a key/value pair into the correct node of the trie. // Must not be called after [calculateNodeIDs] has returned. 
func (t *trieView) insert( - key Path, + key Key, value maybe.Maybe[[]byte], ) (*node, error) { if t.nodesAlreadyCalculated.Get() { @@ -852,7 +852,7 @@ func (t *trieView) insert( return closestNode, nil } - closestNodeKeyLength := closestNode.key.tokensLength + closestNodeKeyLength := closestNode.key.tokenLength // A node with the exact key doesn't exist so determine the portion of the // key that hasn't been matched yet @@ -878,11 +878,11 @@ func (t *trieView) insert( // find how many tokens are common between the existing child's compressed path and // the current key(offset by the closest node's key), // then move all the common tokens into the branch node - commonPrefixLength := getLengthOfCommonPrefix(existingChildEntry.compressedPath, key, closestNodeKeyLength+1) + commonPrefixLength := getLengthOfCommonPrefix(existingChildEntry.compressedKey, key, closestNodeKeyLength+1) // If the length of the existing child's compressed path is less than or equal to the branch node's key that implies that the existing child's key matched the key to be inserted. // Since it matched the key to be inserted, it should have been the last node returned by GetPathTo - if existingChildEntry.compressedPath.tokensLength <= commonPrefixLength { + if existingChildEntry.compressedKey.tokenLength <= commonPrefixLength { return nil, ErrGetPathToFailure } @@ -892,7 +892,7 @@ func (t *trieView) insert( ) nodeWithValue := branchNode - if key.tokensLength == branchNode.key.tokensLength { + if key.tokenLength == branchNode.key.tokenLength { // the branch node has exactly the key to be inserted as its key, so set the value on the branch node branchNode.setValue(value) } else { @@ -911,11 +911,11 @@ func (t *trieView) insert( // add the existing child onto the branch node branchNode.setChildEntry( - existingChildEntry.compressedPath.Token(commonPrefixLength), + existingChildEntry.compressedKey.Token(commonPrefixLength), child{ - compressedPath: existingChildEntry.compressedPath.Skip(commonPrefixLength + 1), - id: existingChildEntry.id, - hasValue: existingChildEntry.hasValue, + compressedKey: existingChildEntry.compressedKey.Skip(commonPrefixLength + 1), + id: existingChildEntry.id, + hasValue: existingChildEntry.hasValue, }) return nodeWithValue, t.recordNewNode(branchNode) @@ -937,7 +937,7 @@ func (t *trieView) recordNodeChange(after *node) error { // Must not be called after [calculateNodeIDs] has returned. func (t *trieView) recordNodeDeleted(after *node) error { // don't delete the root. - if after.key.tokensLength == 0 { + if after.key.tokenLength == 0 { return t.recordKeyChange(after.key, after, after.hasValue(), false /* newNode */) } return t.recordKeyChange(after.key, nil, after.hasValue(), false /* newNode */) @@ -946,7 +946,7 @@ func (t *trieView) recordNodeDeleted(after *node) error { // Records that the node associated with the given key has been changed. // If it is an existing node, record what its value was before it was changed. // Must not be called after [calculateNodeIDs] has returned. -func (t *trieView) recordKeyChange(key Path, after *node, hadValue bool, newNode bool) error { +func (t *trieView) recordKeyChange(key Key, after *node, hadValue bool, newNode bool) error { if t.nodesAlreadyCalculated.Get() { return ErrNodesAlreadyCalculated } @@ -978,7 +978,7 @@ func (t *trieView) recordKeyChange(key Path, after *node, hadValue bool, newNode // Doesn't actually change the trie data structure. // That's deferred until we call [calculateNodeIDs]. 
// Must not be called after [calculateNodeIDs] has returned. -func (t *trieView) recordValueChange(key Path, value maybe.Maybe[[]byte]) error { +func (t *trieView) recordValueChange(key Key, value maybe.Maybe[[]byte]) error { if t.nodesAlreadyCalculated.Get() { return ErrNodesAlreadyCalculated } @@ -1013,7 +1013,7 @@ func (t *trieView) recordValueChange(key Path, value maybe.Maybe[[]byte]) error // sets the node's ID to [id]. // If the node is loaded from the baseDB, [hasValue] determines which database the node is stored in. // Returns database.ErrNotFound if the node doesn't exist. -func (t *trieView) getNodeWithID(id ids.ID, key Path, hasValue bool) (*node, error) { +func (t *trieView) getNodeWithID(id ids.ID, key Key, hasValue bool) (*node, error) { // check for the key within the changed nodes if nodeChange, isChanged := t.changes.nodes[key]; isChanged { t.db.metrics.ViewNodeCacheHit() diff --git a/x/merkledb/value_node_db.go b/x/merkledb/value_node_db.go index b8b5788f27bb..8f168560d7fa 100644 --- a/x/merkledb/value_node_db.go +++ b/x/merkledb/value_node_db.go @@ -24,7 +24,7 @@ type valueNodeDB struct { // If a value is nil, the corresponding key isn't in the trie. // Paths in [nodeCache] aren't prefixed with [valueNodePrefix]. - nodeCache cache.Cacher[Path, *node] + nodeCache cache.Cacher[Key, *node] metrics merkleMetrics closed utils.Atomic[bool] @@ -66,11 +66,11 @@ func (db *valueNodeDB) Close() { func (db *valueNodeDB) NewBatch() *valueNodeBatch { return &valueNodeBatch{ db: db, - ops: make(map[Path]*node, defaultBufferLength), + ops: make(map[Key]*node, defaultBufferLength), } } -func (db *valueNodeDB) Get(key Path) (*node, error) { +func (db *valueNodeDB) Get(key Key) (*node, error) { if cachedValue, isCached := db.nodeCache.Get(key); isCached { db.metrics.ValueNodeCacheHit() if cachedValue == nil { @@ -95,14 +95,14 @@ func (db *valueNodeDB) Get(key Path) (*node, error) { // Batch of database operations type valueNodeBatch struct { db *valueNodeDB - ops map[Path]*node + ops map[Key]*node } -func (b *valueNodeBatch) Put(key Path, value *node) { +func (b *valueNodeBatch) Put(key Key, value *node) { b.ops[key] = value } -func (b *valueNodeBatch) Delete(key Path) { +func (b *valueNodeBatch) Delete(key Key) { b.ops[key] = nil } @@ -170,7 +170,7 @@ func (i *iterator) Next() bool { i.db.metrics.DatabaseNodeRead() key := i.nodeIter.Key() key = key[valueNodePrefixLen:] - n, err := parseNode(NewPath(key, i.db.branchFactor), i.nodeIter.Value()) + n, err := parseNode(ToKey(key, i.db.branchFactor), i.nodeIter.Value()) if err != nil { i.err = err return false diff --git a/x/merkledb/value_node_db_test.go b/x/merkledb/value_node_db_test.go index 46d9385ebda7..910c6e1e9d6b 100644 --- a/x/merkledb/value_node_db_test.go +++ b/x/merkledb/value_node_db_test.go @@ -32,7 +32,7 @@ func TestValueNodeDB(t *testing.T) { ) // Getting a key that doesn't exist should return an error. - key := NewPath([]byte{0x01}, BranchFactor16) + key := ToKey([]byte{0x01}, BranchFactor16) _, err := db.Get(key) require.ErrorIs(err, database.ErrNotFound) @@ -129,7 +129,7 @@ func TestValueNodeDBIterator(t *testing.T) { // Put key-node pairs. for i := 0; i < cacheSize; i++ { - key := NewPath([]byte{byte(i)}, BranchFactor16) + key := ToKey([]byte{byte(i)}, BranchFactor16) node := &node{ dbNode: dbNode{ value: maybe.Some([]byte{byte(i)}), @@ -167,7 +167,7 @@ func TestValueNodeDBIterator(t *testing.T) { it.Release() // Put key-node pairs with a common prefix. 
- key := NewPath([]byte{0xFF, 0x00}, BranchFactor16) + key := ToKey([]byte{0xFF, 0x00}, BranchFactor16) n := &node{ dbNode: dbNode{ value: maybe.Some([]byte{0xFF, 0x00}), @@ -178,7 +178,7 @@ func TestValueNodeDBIterator(t *testing.T) { batch.Put(key, n) require.NoError(batch.Write()) - key = NewPath([]byte{0xFF, 0x01}, BranchFactor16) + key = ToKey([]byte{0xFF, 0x01}, BranchFactor16) n = &node{ dbNode: dbNode{ value: maybe.Some([]byte{0xFF, 0x01}), diff --git a/x/merkledb/view_iterator.go b/x/merkledb/view_iterator.go index 7207e2920faa..263aa409e882 100644 --- a/x/merkledb/view_iterator.go +++ b/x/merkledb/view_iterator.go @@ -25,17 +25,17 @@ func (t *trieView) NewIteratorWithPrefix(prefix []byte) database.Iterator { func (t *trieView) NewIteratorWithStartAndPrefix(start, prefix []byte) database.Iterator { var ( - changes = make([]KeyChange, 0, len(t.changes.values)) - startPath = t.db.newPath(start) - prefixPath = t.db.newPath(prefix) + changes = make([]KeyChange, 0, len(t.changes.values)) + startKey = t.db.toKey(start) + prefixKey = t.db.toKey(prefix) ) - for path, change := range t.changes.values { - if len(start) > 0 && startPath.Greater(path) || !path.HasPrefix(prefixPath) { + for key, change := range t.changes.values { + if len(start) > 0 && startKey.Greater(key) || !key.HasPrefix(prefixKey) { continue } changes = append(changes, KeyChange{ - Key: path.Bytes(), + Key: key.Bytes(), Value: change.after, }) } diff --git a/x/sync/manager.go b/x/sync/manager.go index 96b104fcfdd0..0a13a89eb32b 100644 --- a/x/sync/manager.go +++ b/x/sync/manager.go @@ -404,16 +404,16 @@ func (m *Manager) findNextKey( // and traversing them from the longest key to the shortest key. // For each node in these proofs, compare if the children of that node exist // or have the same ID in the other proof. - proofKeyPath := merkledb.NewPath(lastReceivedKey, m.branchFactor) + proofKeyPath := merkledb.ToKey(lastReceivedKey, m.branchFactor) // If the received proof is an exclusion proof, the last node may be for a // key that is after the [lastReceivedKey]. // If the last received node's key is after the [lastReceivedKey], it can // be removed to obtain a valid proof for a prefix of the [lastReceivedKey]. - if !proofKeyPath.HasPrefix(endProof[len(endProof)-1].KeyPath) { + if !proofKeyPath.HasPrefix(endProof[len(endProof)-1].Key) { endProof = endProof[:len(endProof)-1] // update the proofKeyPath to be for the prefix - proofKeyPath = endProof[len(endProof)-1].KeyPath + proofKeyPath = endProof[len(endProof)-1].Key } // get a proof for the same key as the received proof from the local db @@ -425,7 +425,7 @@ func (m *Manager) findNextKey( // The local proof may also be an exclusion proof with an extra node. 
// Remove this extra node if it exists to get a proof of the same key as the received proof - if !proofKeyPath.HasPrefix(localProofNodes[len(localProofNodes)-1].KeyPath) { + if !proofKeyPath.HasPrefix(localProofNodes[len(localProofNodes)-1].Key) { localProofNodes = localProofNodes[:len(localProofNodes)-1] } @@ -447,7 +447,7 @@ func (m *Manager) findNextKey( // select the deepest proof node from the two proofs switch { - case receivedProofNode.KeyPath.TokensLength() > localProofNode.KeyPath.TokensLength(): + case receivedProofNode.Key.TokensLength() > localProofNode.Key.TokensLength(): // there was a branch node in the received proof that isn't in the local proof // see if the received proof node has children not present in the local proof deepestNode = &receivedProofNode @@ -455,7 +455,7 @@ func (m *Manager) findNextKey( // we have dealt with this received node, so move on to the next received node receivedProofNodeIndex-- - case localProofNode.KeyPath.TokensLength() > receivedProofNode.KeyPath.TokensLength(): + case localProofNode.Key.TokensLength() > receivedProofNode.Key.TokensLength(): // there was a branch node in the local proof that isn't in the received proof // see if the local proof node has children not present in the received proof deepestNode = &localProofNode @@ -489,13 +489,13 @@ func (m *Manager) findNextKey( // node's children have keys larger than [proofKeyPath]. // Any child with a token greater than the [proofKeyPath]'s token at that // index will have a larger key. - if deepestNode.KeyPath.TokensLength() < proofKeyPath.TokensLength() { - startingChildToken = proofKeyPath.Token(deepestNode.KeyPath.TokensLength()) + 1 + if deepestNode.Key.TokensLength() < proofKeyPath.TokensLength() { + startingChildToken = proofKeyPath.Token(deepestNode.Key.TokensLength()) + 1 } // determine if there are any differences in the children for the deepest unhandled node of the two proofs if childIndex, hasDifference := findChildDifference(deepestNode, deepestNodeFromOtherProof, startingChildToken, m.branchFactor); hasDifference { - nextKey = maybe.Some(deepestNode.KeyPath.Append(childIndex).Bytes()) + nextKey = maybe.Some(deepestNode.Key.Append(childIndex).Bytes()) break } } diff --git a/x/sync/peer_tracker.go b/x/sync/peer_tracker.go index e1d471cc40ab..7c105f3363af 100644 --- a/x/sync/peer_tracker.go +++ b/x/sync/peer_tracker.go @@ -14,6 +14,7 @@ import ( "go.uber.org/zap" "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/utils/heap" "github.com/ava-labs/avalanchego/utils/logging" "github.com/ava-labs/avalanchego/utils/set" "github.com/ava-labs/avalanchego/utils/wrappers" @@ -55,7 +56,7 @@ type peerTracker struct { // Peers that we're connected to that responded to the last request they were sent. responsivePeers set.Set[ids.NodeID] // Max heap that contains the average bandwidth of peers. 
- bandwidthHeap safemath.AveragerHeap + bandwidthHeap heap.Map[ids.NodeID, safemath.Averager] averageBandwidth safemath.Averager log logging.Logger numTrackedPeers prometheus.Gauge @@ -69,10 +70,12 @@ func newPeerTracker( registerer prometheus.Registerer, ) (*peerTracker, error) { t := &peerTracker{ - peers: make(map[ids.NodeID]*peerInfo), - trackedPeers: make(set.Set[ids.NodeID]), - responsivePeers: make(set.Set[ids.NodeID]), - bandwidthHeap: safemath.NewMaxAveragerHeap(), + peers: make(map[ids.NodeID]*peerInfo), + trackedPeers: make(set.Set[ids.NodeID]), + responsivePeers: make(set.Set[ids.NodeID]), + bandwidthHeap: heap.NewMap[ids.NodeID, safemath.Averager](func(a, b safemath.Averager) bool { + return a.Read() > b.Read() + }), averageBandwidth: safemath.NewAverager(0, bandwidthHalflife, time.Now()), log: log, numTrackedPeers: prometheus.NewGauge( @@ -212,7 +215,7 @@ func (p *peerTracker) TrackBandwidth(nodeID ids.NodeID, bandwidth float64) { } else { peer.bandwidth.Observe(bandwidth, now) } - p.bandwidthHeap.Add(nodeID, peer.bandwidth) + p.bandwidthHeap.Push(nodeID, peer.bandwidth) if bandwidth == 0 { p.responsivePeers.Remove(nodeID) diff --git a/x/sync/sync_test.go b/x/sync/sync_test.go index e3a363d2d4b6..9a6a5de2dba9 100644 --- a/x/sync/sync_test.go +++ b/x/sync/sync_test.go @@ -668,7 +668,7 @@ func TestFindNextKeyRandom(t *testing.T) { require.NoError(err) type keyAndID struct { - key merkledb.Path + key merkledb.Key id ids.ID } @@ -677,7 +677,7 @@ func TestFindNextKeyRandom(t *testing.T) { for _, node := range remoteProof.EndProof { for childIdx, childID := range node.Children { remoteKeyIDs = append(remoteKeyIDs, keyAndID{ - key: node.KeyPath.Append(childIdx), + key: node.Key.Append(childIdx), id: childID, }) } @@ -688,7 +688,7 @@ func TestFindNextKeyRandom(t *testing.T) { for _, node := range localProof.Path { for childIdx, childID := range node.Children { localKeyIDs = append(localKeyIDs, keyAndID{ - key: node.KeyPath.Append(childIdx), + key: node.Key.Append(childIdx), id: childID, }) } @@ -731,7 +731,7 @@ func TestFindNextKeyRandom(t *testing.T) { // Find smallest difference between the set of key/ID pairs proven by // the remote/local proofs for key/ID pairs after the last received key. var ( - smallestDiffKey merkledb.Path + smallestDiffKey merkledb.Key foundDiff bool ) for i := 0; i < len(remoteKeyIDs) && i < len(localKeyIDs); i++ { diff --git a/x/sync/workheap.go b/x/sync/workheap.go index 36b27e229296..76d438c92d17 100644 --- a/x/sync/workheap.go +++ b/x/sync/workheap.go @@ -5,23 +5,14 @@ package sync import ( "bytes" - "container/heap" + "github.com/ava-labs/avalanchego/utils/heap" "github.com/ava-labs/avalanchego/utils/math" "github.com/ava-labs/avalanchego/utils/maybe" "github.com/google/btree" ) -var _ heap.Interface = (*innerHeap)(nil) - -type heapItem struct { - workItem *workItem - heapIndex int -} - -type innerHeap []*heapItem - // A priority queue of syncWorkItems. // Note that work item ranges never overlap. // Supports range merging and priority updating. @@ -29,20 +20,23 @@ type innerHeap []*heapItem type workHeap struct { // Max heap of items by priority. // i.e. heap.Pop returns highest priority item. - innerHeap innerHeap + innerHeap heap.Set[*workItem] // The heap items sorted by range start. // A Nothing start is considered to be the smallest. 
- sortedItems *btree.BTreeG[*heapItem] + sortedItems *btree.BTreeG[*workItem] closed bool } func newWorkHeap() *workHeap { return &workHeap{ + innerHeap: heap.NewSet[*workItem](func(a, b *workItem) bool { + return a.priority > b.priority + }), sortedItems: btree.NewG( 2, - func(a, b *heapItem) bool { - aNothing := a.workItem.start.IsNothing() - bNothing := b.workItem.start.IsNothing() + func(a, b *workItem) bool { + aNothing := a.start.IsNothing() + bNothing := b.start.IsNothing() if aNothing { // [a] is Nothing, so if [b] is Nothing, they're equal. // Otherwise, [b] is greater. @@ -53,7 +47,7 @@ func newWorkHeap() *workHeap { return false } // [a] and [b] both contain values. Compare the values. - return bytes.Compare(a.workItem.start.Value(), b.workItem.start.Value()) < 0 + return bytes.Compare(a.start.Value(), b.start.Value()) < 0 }, ), } @@ -70,10 +64,8 @@ func (wh *workHeap) Insert(item *workItem) { return } - wrappedItem := &heapItem{workItem: item} - - heap.Push(&wh.innerHeap, wrappedItem) - wh.sortedItems.ReplaceOrInsert(wrappedItem) + wh.innerHeap.Push(item) + wh.sortedItems.ReplaceOrInsert(item) } // Pops and returns a work item from the heap. @@ -82,9 +74,9 @@ func (wh *workHeap) GetWork() *workItem { if wh.closed || wh.Len() == 0 { return nil } - item := heap.Pop(&wh.innerHeap).(*heapItem) + item, _ := wh.innerHeap.Pop() wh.sortedItems.Delete(item) - return item.workItem + return item } // Insert the item into the heap, merging it with existing items @@ -99,25 +91,23 @@ func (wh *workHeap) MergeInsert(item *workItem) { return } - var mergedBefore, mergedAfter *heapItem - searchItem := &heapItem{ - workItem: &workItem{ - start: item.start, - }, + var mergedBefore, mergedAfter *workItem + searchItem := &workItem{ + start: item.start, } // Find the item with the greatest start range which is less than [item.start]. // Note that the iterator function will run at most once, since it always returns false. wh.sortedItems.DescendLessOrEqual( searchItem, - func(beforeItem *heapItem) bool { - if item.localRootID == beforeItem.workItem.localRootID && - maybe.Equal(item.start, beforeItem.workItem.end, bytes.Equal) { + func(beforeItem *workItem) bool { + if item.localRootID == beforeItem.localRootID && + maybe.Equal(item.start, beforeItem.end, bytes.Equal) { // [beforeItem.start, beforeItem.end] and [item.start, item.end] are // merged into [beforeItem.start, item.end] - beforeItem.workItem.end = item.end - beforeItem.workItem.priority = math.Max(item.priority, beforeItem.workItem.priority) - heap.Fix(&wh.innerHeap, beforeItem.heapIndex) + beforeItem.end = item.end + beforeItem.priority = math.Max(item.priority, beforeItem.priority) + wh.innerHeap.Fix(beforeItem) mergedBefore = beforeItem } return false @@ -127,14 +117,14 @@ func (wh *workHeap) MergeInsert(item *workItem) { // Note that the iterator function will run at most once, since it always returns false. wh.sortedItems.AscendGreaterOrEqual( searchItem, - func(afterItem *heapItem) bool { - if item.localRootID == afterItem.workItem.localRootID && - maybe.Equal(item.end, afterItem.workItem.start, bytes.Equal) { + func(afterItem *workItem) bool { + if item.localRootID == afterItem.localRootID && + maybe.Equal(item.end, afterItem.start, bytes.Equal) { // [item.start, item.end] and [afterItem.start, afterItem.end] are merged into // [item.start, afterItem.end]. 
- afterItem.workItem.start = item.start - afterItem.workItem.priority = math.Max(item.priority, afterItem.workItem.priority) - heap.Fix(&wh.innerHeap, afterItem.heapIndex) + afterItem.start = item.start + afterItem.priority = math.Max(item.priority, afterItem.priority) + wh.innerHeap.Fix(afterItem) mergedAfter = afterItem } return false @@ -144,12 +134,12 @@ func (wh *workHeap) MergeInsert(item *workItem) { // we can combine the before item with the after item if mergedBefore != nil && mergedAfter != nil { // combine the two ranges - mergedBefore.workItem.end = mergedAfter.workItem.end + mergedBefore.end = mergedAfter.end // remove the second range since it is now covered by the first wh.remove(mergedAfter) // update the priority - mergedBefore.workItem.priority = math.Max(mergedBefore.workItem.priority, mergedAfter.workItem.priority) - heap.Fix(&wh.innerHeap, mergedBefore.heapIndex) + mergedBefore.priority = math.Max(mergedBefore.priority, mergedAfter.priority) + wh.innerHeap.Fix(mergedBefore) } // nothing was merged, so add new item to the heap @@ -160,43 +150,11 @@ func (wh *workHeap) MergeInsert(item *workItem) { } // Deletes [item] from the heap. -func (wh *workHeap) remove(item *heapItem) { - heap.Remove(&wh.innerHeap, item.heapIndex) - +func (wh *workHeap) remove(item *workItem) { + wh.innerHeap.Remove(item) wh.sortedItems.Delete(item) } func (wh *workHeap) Len() int { return wh.innerHeap.Len() } - -// below this line are the implementations required for heap.Interface - -func (h innerHeap) Len() int { - return len(h) -} - -func (h innerHeap) Less(i int, j int) bool { - return h[i].workItem.priority > h[j].workItem.priority -} - -func (h innerHeap) Swap(i int, j int) { - h[i], h[j] = h[j], h[i] - h[i].heapIndex = i - h[j].heapIndex = j -} - -func (h *innerHeap) Pop() interface{} { - old := *h - n := len(old) - item := old[n-1] - old[n-1] = nil - *h = old[0 : n-1] - return item -} - -func (h *innerHeap) Push(x interface{}) { - item := x.(*heapItem) - item.heapIndex = len(*h) - *h = append(*h, item) -} diff --git a/x/sync/workheap_test.go b/x/sync/workheap_test.go index 7f50468a1fbd..0a3262a9310f 100644 --- a/x/sync/workheap_test.go +++ b/x/sync/workheap_test.go @@ -17,102 +17,6 @@ import ( "github.com/ava-labs/avalanchego/utils/maybe" ) -// Tests heap.Interface methods Push, Pop, Swap, Len, Less. -func Test_WorkHeap_InnerHeap(t *testing.T) { - require := require.New(t) - - lowPriorityItem := &heapItem{ - workItem: &workItem{ - start: maybe.Some([]byte{1}), - end: maybe.Some([]byte{2}), - priority: lowPriority, - localRootID: ids.GenerateTestID(), - }, - } - - mediumPriorityItem := &heapItem{ - workItem: &workItem{ - start: maybe.Some([]byte{3}), - end: maybe.Some([]byte{4}), - priority: medPriority, - localRootID: ids.GenerateTestID(), - }, - } - - highPriorityItem := &heapItem{ - workItem: &workItem{ - start: maybe.Some([]byte{5}), - end: maybe.Some([]byte{6}), - priority: highPriority, - localRootID: ids.GenerateTestID(), - }, - } - - h := innerHeap{} - require.Zero(h.Len()) - - // Note we're calling Push and Pop on the heap directly, - // not using heap.Push and heap.Pop. 
- h.Push(lowPriorityItem) - // Heap has [lowPriorityItem] - require.Equal(1, h.Len()) - require.Equal(lowPriorityItem, h[0]) - - got := h.Pop() - // Heap has [] - require.Equal(lowPriorityItem, got) - require.Zero(h.Len()) - - h.Push(lowPriorityItem) - h.Push(mediumPriorityItem) - // Heap has [lowPriorityItem, mediumPriorityItem] - require.Equal(2, h.Len()) - require.Equal(lowPriorityItem, h[0]) - require.Equal(mediumPriorityItem, h[1]) - - got = h.Pop() - // Heap has [lowPriorityItem] - require.Equal(mediumPriorityItem, got) - require.Equal(1, h.Len()) - - got = h.Pop() - // Heap has [] - require.Equal(lowPriorityItem, got) - require.Zero(h.Len()) - - h.Push(mediumPriorityItem) - h.Push(lowPriorityItem) - h.Push(highPriorityItem) - // Heap has [mediumPriorityItem, lowPriorityItem, highPriorityItem] - require.Equal(mediumPriorityItem, h[0]) - require.Equal(lowPriorityItem, h[1]) - require.Equal(highPriorityItem, h[2]) - - h.Swap(0, 1) - // Heap has [lowPriorityItem, mediumPriorityItem, highPriorityItem] - require.Equal(lowPriorityItem, h[0]) - require.Equal(mediumPriorityItem, h[1]) - require.Equal(highPriorityItem, h[2]) - - h.Swap(1, 2) - // Heap has [lowPriorityItem, highPriorityItem, mediumPriorityItem] - require.Equal(lowPriorityItem, h[0]) - require.Equal(highPriorityItem, h[1]) - require.Equal(mediumPriorityItem, h[2]) - - h.Swap(0, 2) - // Heap has [mediumPriorityItem, highPriorityItem, lowPriorityItem] - require.Equal(mediumPriorityItem, h[0]) - require.Equal(highPriorityItem, h[1]) - require.Equal(lowPriorityItem, h[2]) - require.False(h.Less(0, 1)) - require.True(h.Less(1, 0)) - require.True(h.Less(1, 2)) - require.False(h.Less(2, 1)) - require.True(h.Less(0, 2)) - require.False(h.Less(2, 0)) -} - // Tests Insert and GetWork func Test_WorkHeap_Insert_GetWork(t *testing.T) { require := require.New(t) @@ -144,8 +48,8 @@ func Test_WorkHeap_Insert_GetWork(t *testing.T) { // Ensure [sortedItems] is in right order. 
got := []*workItem{} h.sortedItems.Ascend( - func(i *heapItem) bool { - got = append(got, i.workItem) + func(i *workItem) bool { + got = append(got, i) return true }, ) @@ -195,40 +99,42 @@ func Test_WorkHeap_remove(t *testing.T) { h.Insert(lowPriorityItem) - wrappedLowPriorityItem := h.innerHeap[0] + wrappedLowPriorityItem, ok := h.innerHeap.Peek() + require.True(ok) h.remove(wrappedLowPriorityItem) require.Zero(h.Len()) - require.Empty(h.innerHeap) require.Zero(h.sortedItems.Len()) h.Insert(lowPriorityItem) h.Insert(mediumPriorityItem) h.Insert(highPriorityItem) - wrappedhighPriorityItem := h.innerHeap[0] - require.Equal(highPriorityItem, wrappedhighPriorityItem.workItem) + wrappedhighPriorityItem, ok := h.innerHeap.Peek() + require.True(ok) + require.Equal(highPriorityItem, wrappedhighPriorityItem) h.remove(wrappedhighPriorityItem) require.Equal(2, h.Len()) - require.Len(h.innerHeap, 2) require.Equal(2, h.sortedItems.Len()) - require.Zero(h.innerHeap[0].heapIndex) - require.Equal(mediumPriorityItem, h.innerHeap[0].workItem) + got, ok := h.innerHeap.Peek() + require.True(ok) + require.Equal(mediumPriorityItem, got) - wrappedMediumPriorityItem := h.innerHeap[0] - require.Equal(mediumPriorityItem, wrappedMediumPriorityItem.workItem) + wrappedMediumPriorityItem, ok := h.innerHeap.Peek() + require.True(ok) + require.Equal(mediumPriorityItem, wrappedMediumPriorityItem) h.remove(wrappedMediumPriorityItem) require.Equal(1, h.Len()) - require.Len(h.innerHeap, 1) require.Equal(1, h.sortedItems.Len()) - require.Zero(h.innerHeap[0].heapIndex) - require.Equal(lowPriorityItem, h.innerHeap[0].workItem) + got, ok = h.innerHeap.Peek() + require.True(ok) + require.Equal(lowPriorityItem, got) - wrappedLowPriorityItem = h.innerHeap[0] - require.Equal(lowPriorityItem, wrappedLowPriorityItem.workItem) + wrappedLowPriorityItem, ok = h.innerHeap.Peek() + require.True(ok) + require.Equal(lowPriorityItem, wrappedLowPriorityItem) h.remove(wrappedLowPriorityItem) require.Zero(h.Len()) - require.Empty(h.innerHeap) require.Zero(h.sortedItems.Len()) } @@ -367,13 +273,11 @@ func TestWorkHeapMergeInsertRandom(t *testing.T) { start = maybe.Nothing[[]byte]() } // Make sure end is updated - got, ok := h.sortedItems.Get(&heapItem{ - workItem: &workItem{ - start: start, - }, + got, ok := h.sortedItems.Get(&workItem{ + start: start, }) require.True(ok) - require.Equal(newEnd, got.workItem.end.Value()) + require.Equal(newEnd, got.end.Value()) } } @@ -397,13 +301,11 @@ func TestWorkHeapMergeInsertRandom(t *testing.T) { require.Equal(len(ranges), h.Len()) // Make sure start is updated - got, ok := h.sortedItems.Get(&heapItem{ - workItem: &workItem{ - start: newStart, - }, + got, ok := h.sortedItems.Get(&workItem{ + start: newStart, }) require.True(ok) - require.Equal(newStartBytes, got.workItem.start.Value()) + require.Equal(newStartBytes, got.start.Value()) } } }
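For readers unfamiliar with the generic utils/heap package this diff migrates to, below is a minimal usage sketch of heap.NewSet, assuming only the behavior implied by the calls that appear above (NewSet with a less function, Push, Peek, Fix, Pop, Len). The task type and its values are hypothetical and exist purely for illustration; this is not part of the change itself.

// A minimal sketch of the generic utils/heap Set API adopted by this diff.
// Only methods exercised in the diff above are used; the task type is made up.
package main

import (
	"fmt"

	"github.com/ava-labs/avalanchego/utils/heap"
)

type task struct {
	name     string
	priority int
}

func main() {
	// Max-heap ordered by priority, mirroring how workHeap orders *workItem.
	h := heap.NewSet[*task](func(a, b *task) bool {
		return a.priority > b.priority
	})

	low := &task{name: "low", priority: 1}
	high := &task{name: "high", priority: 3}
	h.Push(low)
	h.Push(high)

	// Peek returns the highest-priority element without removing it.
	if top, ok := h.Peek(); ok {
		fmt.Println("top:", top.name) // top: high
	}

	// After mutating an element in place, Fix restores heap order; this replaces
	// the old heap.Fix(&innerHeap, item.heapIndex) index bookkeeping.
	low.priority = 10
	h.Fix(low)

	if top, ok := h.Pop(); ok {
		fmt.Println("popped:", top.name, "remaining:", h.Len()) // popped: low remaining: 1
	}
}

The heap.NewMap constructor used for peerTracker.bandwidthHeap is the keyed counterpart: it takes the same kind of less function over values, and Push takes a key and a value, as shown in the TrackBandwidth change above.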