diff --git a/config/config.go b/config/config.go index a5c74013..51a1ffe4 100644 --- a/config/config.go +++ b/config/config.go @@ -68,8 +68,10 @@ type SourceBlockchain struct { SupportedDestinations []string `mapstructure:"supported-destinations" json:"supported-destinations"` ProcessHistoricalBlocksFromHeight uint64 `mapstructure:"process-historical-blocks-from-height" json:"process-historical-blocks-from-height"` - // convenience field to access the supported destinations after initialization + // convenience fields to access parsed data after initialization supportedDestinations set.Set[ids.ID] + subnetID ids.ID + blockchainID ids.ID } type DestinationBlockchain struct { @@ -81,6 +83,10 @@ type DestinationBlockchain struct { // Fetched from the chain after startup warpQuorum WarpQuorum + + // convenience fields to access parsed data after initialization + subnetID ids.ID + blockchainID ids.ID } type WarpQuorum struct { @@ -98,9 +104,8 @@ type Config struct { ProcessMissedBlocks bool `mapstructure:"process-missed-blocks" json:"process-missed-blocks"` ManualWarpMessages []*ManualWarpMessage `mapstructure:"manual-warp-messages" json:"manual-warp-messages"` - // convenience fields to access the source subnet and chain IDs after initialization - sourceSubnetIDs []ids.ID - sourceBlockchainIDs []ids.ID + // convenience field to fetch a blockchain's subnet ID + blockchainIDToSubnetID map[ids.ID]ids.ID } func SetDefaultConfigValues(v *viper.Viper) { @@ -194,6 +199,8 @@ func (c *Config) Validate() error { return err } + blockchainIDToSubnetID := make(map[ids.ID]ids.ID) + // Validate the destination chains destinationChains := set.NewSet[string](len(c.DestinationBlockchains)) for _, s := range c.DestinationBlockchains { @@ -204,12 +211,11 @@ func (c *Config) Validate() error { return errors.New("configured destination subnets must have unique chain IDs") } destinationChains.Add(s.BlockchainID) + blockchainIDToSubnetID[s.blockchainID] = s.subnetID } // Validate the source chains and store the source subnet and chain IDs for future use sourceBlockchains := set.NewSet[string](len(c.SourceBlockchains)) - var sourceSubnetIDs []ids.ID - var sourceBlockchainIDs []ids.ID for _, s := range c.SourceBlockchains { // Validate configuration if err := s.Validate(&destinationChains); err != nil { @@ -220,23 +226,9 @@ func (c *Config) Validate() error { return errors.New("configured source subnets must have unique chain IDs") } sourceBlockchains.Add(s.BlockchainID) - - // Save IDs for future use - subnetID, err := ids.FromString(s.SubnetID) - if err != nil { - return fmt.Errorf("invalid subnetID in configuration. error: %w", err) - } - sourceSubnetIDs = append(sourceSubnetIDs, subnetID) - - blockchainID, err := ids.FromString(s.BlockchainID) - if err != nil { - return fmt.Errorf("invalid subnetID in configuration. error: %w", err) - } - sourceBlockchainIDs = append(sourceBlockchainIDs, blockchainID) + blockchainIDToSubnetID[s.blockchainID] = s.subnetID } - - c.sourceSubnetIDs = sourceSubnetIDs - c.sourceBlockchainIDs = sourceBlockchainIDs + c.blockchainIDToSubnetID = blockchainIDToSubnetID // Validate the manual warp messages for i, msg := range c.ManualWarpMessages { @@ -248,6 +240,10 @@ func (c *Config) Validate() error { return nil } +func (c *Config) GetSubnetID(blockchainID ids.ID) ids.ID { + return c.blockchainIDToSubnetID[blockchainID] +} + func (m *ManualWarpMessage) GetUnsignedMessageBytes() []byte { return m.unsignedMessageBytes } @@ -417,6 +413,18 @@ func (s *SourceBlockchain) Validate(destinationBlockchainIDs *set.Set[string]) e } } + // Validate and store the subnet and blockchain IDs for future use + blockchainID, err := ids.FromString(s.BlockchainID) + if err != nil { + return fmt.Errorf("invalid blockchainID in configuration. error: %w", err) + } + s.blockchainID = blockchainID + subnetID, err := ids.FromString(s.SubnetID) + if err != nil { + return fmt.Errorf("invalid subnetID in configuration. error: %w", err) + } + s.subnetID = subnetID + // Validate and store the allowed destinations for future use s.supportedDestinations = set.Set[ids.ID]{} @@ -447,6 +455,14 @@ func (s *SourceBlockchain) Validate(destinationBlockchainIDs *set.Set[string]) e return nil } +func (s *SourceBlockchain) GetSubnetID() ids.ID { + return s.subnetID +} + +func (s *SourceBlockchain) GetBlockchainID() ids.ID { + return s.blockchainID +} + // Validatees the destination subnet configuration func (s *DestinationBlockchain) Validate() error { if _, err := ids.FromString(s.SubnetID); err != nil { @@ -473,9 +489,29 @@ func (s *DestinationBlockchain) Validate() error { return fmt.Errorf("unsupported VM type for source subnet: %s", s.VM) } + // Validate and store the subnet and blockchain IDs for future use + blockchainID, err := ids.FromString(s.BlockchainID) + if err != nil { + return fmt.Errorf("invalid blockchainID in configuration. error: %w", err) + } + s.blockchainID = blockchainID + subnetID, err := ids.FromString(s.SubnetID) + if err != nil { + return fmt.Errorf("invalid subnetID in configuration. error: %w", err) + } + s.subnetID = subnetID + return nil } +func (s *DestinationBlockchain) GetSubnetID() ids.ID { + return s.subnetID +} + +func (s *DestinationBlockchain) GetBlockchainID() ids.ID { + return s.blockchainID +} + func (s *DestinationBlockchain) initializeWarpQuorum() error { blockchainID, err := ids.FromString(s.BlockchainID) if err != nil { @@ -514,11 +550,6 @@ func (s *DestinationBlockchain) GetRelayerAccountInfo() (*ecdsa.PrivateKey, comm // Top-level config getters // -// GetSourceIDs returns the Subnet and Chain IDs of all subnets configured as a source -func (c *Config) GetSourceIDs() ([]ids.ID, []ids.ID) { - return c.sourceSubnetIDs, c.sourceBlockchainIDs -} - func (c *Config) GetWarpQuorum(blockchainID ids.ID) (WarpQuorum, error) { for _, s := range c.DestinationBlockchains { if blockchainID.String() == s.BlockchainID { diff --git a/database/json_file_storage.go b/database/json_file_storage.go index 68ff7d61..67cc5779 100644 --- a/database/json_file_storage.go +++ b/database/json_file_storage.go @@ -12,6 +12,7 @@ import ( "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/utils/logging" + "github.com/ava-labs/awm-relayer/config" "github.com/pkg/errors" "go.uber.org/zap" ) @@ -34,7 +35,7 @@ type JSONFileStorage struct { } // NewJSONFileStorage creates a new JSONFileStorage instance -func NewJSONFileStorage(logger logging.Logger, dir string, networks []ids.ID) (*JSONFileStorage, error) { +func NewJSONFileStorage(logger logging.Logger, dir string, sourceBlockchains []*config.SourceBlockchain) (*JSONFileStorage, error) { storage := &JSONFileStorage{ dir: filepath.Clean(dir), mutexes: make(map[ids.ID]*sync.RWMutex), @@ -42,22 +43,24 @@ func NewJSONFileStorage(logger logging.Logger, dir string, networks []ids.ID) (* currentState: make(map[ids.ID]chainState), } - for _, network := range networks { - storage.currentState[network] = make(chainState) - storage.mutexes[network] = &sync.RWMutex{} + for _, sourceBlockchain := range sourceBlockchains { + sourceBlockchainID := sourceBlockchain.GetBlockchainID() + storage.currentState[sourceBlockchainID] = make(chainState) + storage.mutexes[sourceBlockchainID] = &sync.RWMutex{} } _, err := os.Stat(dir) if err == nil { // Directory already exists. // Read the existing storage. - for _, network := range networks { - currentState, fileExists, err := storage.getCurrentState(network) + for _, sourceBlockchain := range sourceBlockchains { + sourceBlockchainID := sourceBlockchain.GetBlockchainID() + currentState, fileExists, err := storage.getCurrentState(sourceBlockchainID) if err != nil { return nil, err } if fileExists { - storage.currentState[network] = currentState + storage.currentState[sourceBlockchainID] = currentState } } return storage, nil diff --git a/database/json_file_storage_test.go b/database/json_file_storage_test.go index f706c89a..268294d8 100644 --- a/database/json_file_storage_test.go +++ b/database/json_file_storage_test.go @@ -13,15 +13,44 @@ import ( "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/utils/logging" + "github.com/ava-labs/avalanchego/utils/set" + "github.com/ava-labs/awm-relayer/config" "github.com/stretchr/testify/assert" ) +var validSourceBlockchainConfig = &config.SourceBlockchain{ + RPCEndpoint: "http://test.avax.network/ext/bc/C/rpc", + WSEndpoint: "ws://test.avax.network/ext/bc/C/ws", + BlockchainID: "S4mMqUXe7vHsGiRAma6bv3CKnyaLssyAxmQ2KvFpX1KEvfFCD", + SubnetID: "2TGBXcnwx5PqiXWiqxAKUaNSqDguXNh1mxnp82jui68hxJSZAx", + VM: "evm", + MessageContracts: map[string]config.MessageProtocolConfig{ + "0xd81545385803bCD83bd59f58Ba2d2c0562387F83": { + MessageFormat: config.TELEPORTER.String(), + }, + }, +} + +func populateSourceConfig(blockchainIDs []ids.ID) []*config.SourceBlockchain { + sourceBlockchains := make([]*config.SourceBlockchain, len(blockchainIDs)) + for i, id := range blockchainIDs { + sourceBlockchains[i] = validSourceBlockchainConfig + sourceBlockchains[i].BlockchainID = id.String() + } + destinationsBlockchainIDs := set.NewSet[string](1) // just needs to be non-nil + destinationsBlockchainIDs.Add(ids.GenerateTestID().String()) + sourceBlockchains[0].Validate(&destinationsBlockchainIDs) + return sourceBlockchains +} + // Test that the JSON database can write and read to a single chain concurrently. func TestConcurrentWriteReadSingleChain(t *testing.T) { - networks := []ids.ID{ - ids.GenerateTestID(), - } - jsonStorage := setupJsonStorage(t, networks) + sourceBlockchains := populateSourceConfig( + []ids.ID{ + ids.GenerateTestID(), + }, + ) + jsonStorage := setupJsonStorage(t, sourceBlockchains) // Test writing to the JSON database concurrently. wg := sync.WaitGroup{} @@ -30,16 +59,16 @@ func TestConcurrentWriteReadSingleChain(t *testing.T) { idx := i go func() { defer wg.Done() - testWrite(jsonStorage, networks[0], uint64(idx)) + testWrite(jsonStorage, sourceBlockchains[0].GetBlockchainID(), uint64(idx)) }() } wg.Wait() // Write one final time to ensure that concurrent writes don't cause any issues. finalTargetValue := uint64(11) - testWrite(jsonStorage, networks[0], finalTargetValue) + testWrite(jsonStorage, sourceBlockchains[0].GetBlockchainID(), finalTargetValue) - latestProcessedBlockData, err := jsonStorage.Get(networks[0], []byte(LatestProcessedBlockKey)) + latestProcessedBlockData, err := jsonStorage.Get(sourceBlockchains[0].GetBlockchainID(), []byte(LatestProcessedBlockKey)) if err != nil { t.Fatalf("failed to retrieve from JSON storage. err: %v", err) } @@ -52,12 +81,14 @@ func TestConcurrentWriteReadSingleChain(t *testing.T) { // Test that the JSON database can write and read from multiple chains concurrently. Write to any given chain are not concurrent. func TestConcurrentWriteReadMultipleChains(t *testing.T) { - networks := []ids.ID{ - ids.GenerateTestID(), - ids.GenerateTestID(), - ids.GenerateTestID(), - } - jsonStorage := setupJsonStorage(t, networks) + sourceBlockchains := populateSourceConfig( + []ids.ID{ + ids.GenerateTestID(), + ids.GenerateTestID(), + ids.GenerateTestID(), + }, + ) + jsonStorage := setupJsonStorage(t, sourceBlockchains) // Test writing to the JSON database concurrently. wg := sync.WaitGroup{} @@ -66,19 +97,19 @@ func TestConcurrentWriteReadMultipleChains(t *testing.T) { index := i go func() { defer wg.Done() - testWrite(jsonStorage, networks[index], uint64(index)) + testWrite(jsonStorage, sourceBlockchains[index].GetBlockchainID(), uint64(index)) }() } wg.Wait() // Write one final time to ensure that concurrent writes don't cause any issues. finalTargetValue := uint64(3) - for _, network := range networks { - testWrite(jsonStorage, network, finalTargetValue) + for _, sourceBlockchain := range sourceBlockchains { + testWrite(jsonStorage, sourceBlockchain.GetBlockchainID(), finalTargetValue) } - for i, id := range networks { - latestProcessedBlockData, err := jsonStorage.Get(id, []byte(LatestProcessedBlockKey)) + for i, sourceBlockchain := range sourceBlockchains { + latestProcessedBlockData, err := jsonStorage.Get(sourceBlockchain.GetBlockchainID(), []byte(LatestProcessedBlockKey)) if err != nil { t.Fatalf("failed to retrieve from JSON storage. networkID: %d err: %v", i, err) } @@ -90,7 +121,7 @@ func TestConcurrentWriteReadMultipleChains(t *testing.T) { } } -func setupJsonStorage(t *testing.T, networks []ids.ID) *JSONFileStorage { +func setupJsonStorage(t *testing.T, sourceBlockchains []*config.SourceBlockchain) *JSONFileStorage { logger := logging.NewLogger( "awm-relayer-test", logging.NewWrappedCore( @@ -101,7 +132,7 @@ func setupJsonStorage(t *testing.T, networks []ids.ID) *JSONFileStorage { ) storageDir := t.TempDir() - jsonStorage, err := NewJSONFileStorage(logger, storageDir, networks) + jsonStorage, err := NewJSONFileStorage(logger, storageDir, sourceBlockchains) if err != nil { t.Fatal(err) } diff --git a/main/main.go b/main/main.go index 4a816856..063a0c77 100644 --- a/main/main.go +++ b/main/main.go @@ -12,6 +12,7 @@ import ( "os" "github.com/alexliesenfeld/health" + "github.com/ava-labs/avalanchego/api/info" "github.com/ava-labs/avalanchego/api/metrics" "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/message" @@ -74,8 +75,9 @@ func main() { } logger.Info(fmt.Sprintf("Set config options.%s", overwrittenLog)) - // Global P-Chain client used to get subnet validator sets + // Global P-Chain and Info clients used to get subnet validator sets pChainClient := platformvm.NewClient(cfg.PChainAPIURL) + infoClient := info.NewClient(cfg.InfoAPIURL) // Initialize all destination clients logger.Info("Initializing destination clients") @@ -98,7 +100,6 @@ func main() { // Initialize the global app request network logger.Info("Initializing app request network") - sourceSubnetIDs, sourceBlockchainIDs := cfg.GetSourceIDs() // The app request network generates P2P networking logs that are verbose at the info level. // Unless the log level is debug or lower, set the network log level to error to avoid spamming the logs. @@ -106,7 +107,13 @@ func main() { if logLevel <= logging.Debug { networkLogLevel = logLevel } - network, responseChans, err := peers.NewNetwork(networkLogLevel, registerer, sourceSubnetIDs, sourceBlockchainIDs, cfg.InfoAPIURL) + network, responseChans, err := peers.NewNetwork( + networkLogLevel, + registerer, + &cfg, + infoClient, + pChainClient, + ) if err != nil { logger.Error( "Failed to create app request network", @@ -167,7 +174,7 @@ func main() { } // Initialize the database - db, err := database.NewJSONFileStorage(logger, cfg.StorageLocation, sourceBlockchainIDs) + db, err := database.NewJSONFileStorage(logger, cfg.StorageLocation, cfg.SourceBlockchains) if err != nil { logger.Error( "Failed to create database", diff --git a/peers/app_request_network.go b/peers/app_request_network.go index 3fe351dc..db45f7a4 100644 --- a/peers/app_request_network.go +++ b/peers/app_request_network.go @@ -5,8 +5,7 @@ package peers import ( "context" - "fmt" - "math/rand" + "math/big" "os" "sync" "time" @@ -15,10 +14,16 @@ import ( "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/message" "github.com/ava-labs/avalanchego/network" - "github.com/ava-labs/avalanchego/snow/validators" + snowVdrs "github.com/ava-labs/avalanchego/snow/validators" + "github.com/ava-labs/avalanchego/utils/constants" "github.com/ava-labs/avalanchego/utils/ips" "github.com/ava-labs/avalanchego/utils/logging" "github.com/ava-labs/avalanchego/utils/set" + "github.com/ava-labs/avalanchego/vms/platformvm" + "github.com/ava-labs/avalanchego/vms/platformvm/warp" + "github.com/ava-labs/awm-relayer/config" + "github.com/ava-labs/awm-relayer/utils" + "github.com/ava-labs/awm-relayer/validators" "github.com/prometheus/client_golang/prometheus" "go.uber.org/zap" ) @@ -26,25 +31,24 @@ import ( const ( InboundMessageChannelSize = 1000 DefaultAppRequestTimeout = time.Second * 2 - - numInitialTestPeers = 5 ) type AppRequestNetwork struct { - Network network.Network - Handler *RelayerExternalHandler - infoClient info.Client - logger logging.Logger - lock *sync.Mutex + Network network.Network + Handler *RelayerExternalHandler + infoClient info.Client + logger logging.Logger + lock *sync.Mutex + validatorClient *validators.CanonicalValidatorClient } // NewNetwork connects to a peers at the app request level. func NewNetwork( logLevel logging.Level, registerer prometheus.Registerer, - subnetIDs []ids.ID, - blockchainIDs []ids.ID, - infoAPINodeURL string, + cfg *config.Config, + infoClient info.Client, + pChainClient platformvm.Client, ) (*AppRequestNetwork, map[ids.ID]chan message.InboundMessage, error) { logger := logging.NewLogger( "awm-relayer-p2p", @@ -55,13 +59,6 @@ func NewNetwork( ), ) - if infoAPINodeURL == "" { - logger.Error("No InfoAPI node URL provided") - return nil, nil, fmt.Errorf("must provide an Inffo API URL") - } - - // Create the info client - infoClient := info.NewClient(infoAPINodeURL) networkID, err := infoClient.GetNetworkID(context.Background()) if err != nil { logger.Error( @@ -73,15 +70,15 @@ func NewNetwork( // Create the test network for AppRequests var trackedSubnets set.Set[ids.ID] - for _, subnetID := range subnetIDs { - trackedSubnets.Add(subnetID) + for _, sourceBlockchain := range cfg.SourceBlockchains { + trackedSubnets.Add(sourceBlockchain.GetSubnetID()) } // Construct a response chan for each chain. Inbound messages will be routed to the proper channel in the handler responseChans := make(map[ids.ID]chan message.InboundMessage) - for _, blockchainID := range blockchainIDs { + for _, sourceBlockchain := range cfg.SourceBlockchains { responseChan := make(chan message.InboundMessage, InboundMessageChannelSize) - responseChans[blockchainID] = responseChan + responseChans[sourceBlockchain.GetBlockchainID()] = responseChan } responseChansLock := new(sync.RWMutex) @@ -94,7 +91,7 @@ func NewNetwork( return nil, nil, err } - network, err := network.NewTestNetwork(logger, networkID, validators.NewManager(), trackedSubnets, handler) + testNetwork, err := network.NewTestNetwork(logger, networkID, snowVdrs.NewManager(), trackedSubnets, handler) if err != nil { logger.Error( "Failed to create test network", @@ -103,83 +100,38 @@ func NewNetwork( return nil, nil, err } - // We need to initially connect to some nodes in the network before peer - // gossip will enable connecting to all the remaining nodes in the network. - var beaconIPs, beaconIDs []string - - peers, err := infoClient.Peers(context.Background()) - if err != nil { - logger.Error( - "Failed to get peers", - zap.Error(err), - ) - return nil, nil, err - } + validatorClient := validators.NewCanonicalValidatorClient(logger, pChainClient) - // Randomly select peers to connect to until we have numInitialTestPeers - indices := rand.Perm(len(peers)) - for _, index := range indices { - // Do not attempt to connect to private peers - if len(peers[index].PublicIP) == 0 { - continue - } - beaconIPs = append(beaconIPs, peers[index].PublicIP) - beaconIDs = append(beaconIDs, peers[index].ID.String()) - if len(beaconIDs) == numInitialTestPeers { - break - } - } - if len(beaconIPs) == 0 { - logger.Error( - "Failed to find any peers to connect to", - zap.Error(err), - ) - return nil, nil, err - } - if len(beaconIPs) < numInitialTestPeers { - logger.Warn( - "Failed to find a full set of peers to connect to on startup", - zap.Int("connectedPeers", len(beaconIPs)), - zap.Int("expectedConnectedPeers", numInitialTestPeers), - ) + arNetwork := &AppRequestNetwork{ + Network: testNetwork, + Handler: handler, + infoClient: infoClient, + logger: logger, + lock: new(sync.Mutex), + validatorClient: validatorClient, } - for i, beaconIDStr := range beaconIDs { - beaconID, err := ids.NodeIDFromString(beaconIDStr) - if err != nil { - logger.Error( - "Failed to parse beaconID", - zap.String("beaconID", beaconIDStr), - zap.Error(err), - ) - return nil, nil, err - } - - beaconIPStr := beaconIPs[i] - ipPort, err := ips.ToIPPort(beaconIPStr) - if err != nil { - logger.Error( - "Failed to parse beaconIP", - zap.String("beaconIP", beaconIPStr), - zap.Error(err), - ) - return nil, nil, err + // Manually connect to the validators of each of the source subnets. + // We return an error if we are unable to connect to sufficient stake on any of the subnets. + // Sufficient stake is determined by the Warp quora of the configured supported destinations, + // or if the subnet supports all destinations, by the quora of all configured destinations. + for _, sourceBlockchain := range cfg.SourceBlockchains { + if sourceBlockchain.GetSubnetID() == constants.PrimaryNetworkID { + if err := arNetwork.connectToPrimaryNetworkPeers(cfg, sourceBlockchain); err != nil { + return nil, nil, err + } + } else { + if err := arNetwork.connectToNonPrimaryNetworkPeers(cfg, sourceBlockchain); err != nil { + return nil, nil, err + } } - - network.ManuallyTrack(beaconID, ipPort) } go logger.RecoverAndPanic(func() { - network.Dispatch() + testNetwork.Dispatch() }) - return &AppRequestNetwork{ - Network: network, - Handler: handler, - infoClient: infoClient, - logger: logger, - lock: new(sync.Mutex), - }, responseChans, nil + return arNetwork, responseChans, nil } // ConnectPeers connects the network to peers with the given nodeIDs. @@ -257,3 +209,139 @@ func (n *AppRequestNetwork) ConnectPeers(nodeIDs set.Set[ids.NodeID]) set.Set[id return trackedNodes } + +// Helper struct to hold connected validator information +// Warp Validators sharing the same BLS key may consist of multiple nodes, +// so we need to track the node ID to validator index mapping +type ConnectedCanonicalValidators struct { + ConnectedWeight uint64 + TotalValidatorWeight uint64 + ValidatorSet []*warp.Validator + nodeValidatorIndexMap map[ids.NodeID]int +} + +// Returns the Warp Validator and its index in the canonical Validator ordering for a given nodeID +func (c *ConnectedCanonicalValidators) GetValidator(nodeID ids.NodeID) (*warp.Validator, int) { + return c.ValidatorSet[c.nodeValidatorIndexMap[nodeID]], c.nodeValidatorIndexMap[nodeID] +} + +// ConnectToCanonicalValidators connects to the canonical validators of the given subnet and returns the connected +// validator information +func (n *AppRequestNetwork) ConnectToCanonicalValidators(subnetID ids.ID) (*ConnectedCanonicalValidators, error) { + // Get the subnet's current canonical validator set + validatorSet, totalValidatorWeight, err := n.validatorClient.GetCurrentCanonicalValidatorSet(subnetID) + if err != nil { + return nil, err + } + + // We make queries to node IDs, not unique validators as represented by a BLS pubkey, so we need this map to track + // responses from nodes and populate the signatureMap with the corresponding validator signature + // This maps node IDs to the index in the canonical validator set + nodeValidatorIndexMap := make(map[ids.NodeID]int) + for i, vdr := range validatorSet { + for _, node := range vdr.NodeIDs { + nodeValidatorIndexMap[node] = i + } + } + + // Manually connect to all peers in the validator set + // If new peers are connected, AppRequests may fail while the handshake is in progress. + // In that case, AppRequests to those nodes will be retried in the next iteration of the retry loop. + nodeIDs := set.NewSet[ids.NodeID](len(nodeValidatorIndexMap)) + for node := range nodeValidatorIndexMap { + nodeIDs.Add(node) + } + connectedNodes := n.ConnectPeers(nodeIDs) + + // Check if we've connected to a stake threshold of nodes + connectedWeight := uint64(0) + for node := range connectedNodes { + connectedWeight += validatorSet[nodeValidatorIndexMap[node]].Weight + } + return &ConnectedCanonicalValidators{ + ConnectedWeight: connectedWeight, + TotalValidatorWeight: totalValidatorWeight, + ValidatorSet: validatorSet, + nodeValidatorIndexMap: nodeValidatorIndexMap, + }, nil +} + +// Private helpers + +// Connect to the validators of the source blockchain. For each destination blockchain, verify that we have connected to a threshold of stake. +func (n *AppRequestNetwork) connectToNonPrimaryNetworkPeers(cfg *config.Config, sourceBlockchain *config.SourceBlockchain) error { + subnetID := sourceBlockchain.GetSubnetID() + connectedValidators, err := n.ConnectToCanonicalValidators(subnetID) + if err != nil { + n.logger.Error( + "Failed to connect to canonical validators", + zap.String("subnetID", subnetID.String()), + zap.Error(err), + ) + return err + } + for _, destinationBlockchainID := range sourceBlockchain.GetSupportedDestinations().List() { + if ok, quorum, err := n.checkForSufficientConnectedStake(cfg, connectedValidators, destinationBlockchainID); !ok { + n.logger.Error( + "Failed to connect to a threshold of stake", + zap.String("destinationBlockchainID", destinationBlockchainID.String()), + zap.Uint64("connectedWeight", connectedValidators.ConnectedWeight), + zap.Uint64("totalValidatorWeight", connectedValidators.TotalValidatorWeight), + zap.Any("warpQuorum", quorum), + ) + return err + } + } + return nil +} + +// Connect to the validators of the destination blockchains. Verify that we have connected to a threshold of stake for each blockchain. +func (n *AppRequestNetwork) connectToPrimaryNetworkPeers(cfg *config.Config, sourceBlockchain *config.SourceBlockchain) error { + for _, destinationBlockchainID := range sourceBlockchain.GetSupportedDestinations().List() { + subnetID := cfg.GetSubnetID(destinationBlockchainID) + connectedValidators, err := n.ConnectToCanonicalValidators(subnetID) + if err != nil { + n.logger.Error( + "Failed to connect to canonical validators", + zap.String("subnetID", subnetID.String()), + zap.Error(err), + ) + return err + } + + if ok, quorum, err := n.checkForSufficientConnectedStake(cfg, connectedValidators, destinationBlockchainID); !ok { + n.logger.Error( + "Failed to connect to a threshold of stake", + zap.String("destinationBlockchainID", destinationBlockchainID.String()), + zap.Uint64("connectedWeight", connectedValidators.ConnectedWeight), + zap.Uint64("totalValidatorWeight", connectedValidators.TotalValidatorWeight), + zap.Any("warpQuorum", quorum), + ) + return err + } + } + return nil +} + +// Fetch the warp quorum from the config and check if the connected stake exceeds the threshold +func (n *AppRequestNetwork) checkForSufficientConnectedStake( + cfg *config.Config, + connectedValidators *ConnectedCanonicalValidators, + destinationBlockchainID ids.ID, +) (bool, *config.WarpQuorum, error) { + quorum, err := cfg.GetWarpQuorum(destinationBlockchainID) + if err != nil { + n.logger.Error( + "Failed to get warp quorum from config", + zap.String("destinationBlockchainID", destinationBlockchainID.String()), + zap.Error(err), + ) + return false, nil, err + } + return utils.CheckStakeWeightExceedsThreshold( + big.NewInt(0).SetUint64(connectedValidators.ConnectedWeight), + connectedValidators.TotalValidatorWeight, + quorum.QuorumNumerator, + quorum.QuorumDenominator, + ), &quorum, nil +} diff --git a/relayer/message_relayer.go b/relayer/message_relayer.go index dfc47a94..28ced466 100644 --- a/relayer/message_relayer.go +++ b/relayer/message_relayer.go @@ -54,6 +54,7 @@ type messageRelayer struct { relayer *Relayer warpMessage *avalancheWarp.UnsignedMessage destinationBlockchainID ids.ID + signingSubnetID ids.ID warpQuorum config.WarpQuorum } @@ -71,10 +72,19 @@ func newMessageRelayer( ) return nil, err } + var signingSubnet ids.ID + if relayer.sourceSubnetID == constants.PrimaryNetworkID { + // If the message originates from the primary subnet, then we instead "self sign" the message using the validators of the destination subnet. + signingSubnet = relayer.globalConfig.GetSubnetID(destinationBlockchainID) + } else { + // Otherwise, the source subnet signs the message. + signingSubnet = relayer.sourceSubnetID + } return &messageRelayer{ relayer: relayer, warpMessage: warpMessage, destinationBlockchainID: destinationBlockchainID, + signingSubnetID: signingSubnet, warpQuorum: quorum, }, nil } @@ -156,18 +166,6 @@ func (r *messageRelayer) createSignedMessage() (*avalancheWarp.Message, error) { ) return nil, err } - signingSubnetID := r.relayer.sourceSubnetID - if r.relayer.sourceSubnetID == constants.PrimaryNetworkID { - signingSubnetID, err = r.relayer.pChainClient.ValidatedBy(context.Background(), r.destinationBlockchainID) - if err != nil { - r.relayer.logger.Error( - "failed to get validating subnet for destination chain", - zap.String("destinationBlockchainID", r.destinationBlockchainID.String()), - zap.Error(err), - ) - return nil, err - } - } var signedWarpMessageBytes []byte for attempt := 1; attempt <= maxRelayerQueryAttempts; attempt++ { @@ -176,13 +174,13 @@ func (r *messageRelayer) createSignedMessage() (*avalancheWarp.Message, error) { zap.Int("attempt", attempt), zap.String("sourceBlockchainID", r.relayer.sourceBlockchainID.String()), zap.String("destinationBlockchainID", r.destinationBlockchainID.String()), - zap.String("signingSubnetID", signingSubnetID.String()), + zap.String("signingSubnetID", r.signingSubnetID.String()), ) signedWarpMessageBytes, err = warpClient.GetMessageAggregateSignature( context.Background(), r.warpMessage.ID(), r.warpQuorum.QuorumNumerator, - signingSubnetID.String(), + r.signingSubnetID.String(), ) if err == nil { warpMsg, err := avalancheWarp.ParseMessage(signedWarpMessageBytes) @@ -211,7 +209,7 @@ func (r *messageRelayer) createSignedMessage() (*avalancheWarp.Message, error) { zap.Int("attempts", maxRelayerQueryAttempts), zap.String("sourceBlockchainID", r.relayer.sourceBlockchainID.String()), zap.String("destinationBlockchainID", r.destinationBlockchainID.String()), - zap.String("signingSubnetID", signingSubnetID.String()), + zap.String("signingSubnetID", r.signingSubnetID.String()), ) return nil, errFailedToGetAggSig } @@ -219,52 +217,24 @@ func (r *messageRelayer) createSignedMessage() (*avalancheWarp.Message, error) { // createSignedMessageAppRequest collects signatures from nodes by directly querying them via AppRequest, then aggregates the signatures, and constructs the signed warp message. func (r *messageRelayer) createSignedMessageAppRequest(requestID uint32) (*avalancheWarp.Message, error) { r.relayer.logger.Info("Fetching aggregate signature from the source chain validators via AppRequest") - - // Get the current canonical validator set of the source subnet. - validatorSet, totalValidatorWeight, err := r.getCurrentCanonicalValidatorSet() + connectedValidators, err := r.relayer.network.ConnectToCanonicalValidators(r.signingSubnetID) if err != nil { r.relayer.logger.Error( - "Failed to get the canonical subnet validator set", - zap.String("subnetID", r.relayer.sourceSubnetID.String()), + "Failed to connect to canonical validators", zap.Error(err), ) return nil, err } - - // We make queries to node IDs, not unique validators as represented by a BLS pubkey, so we need this map to track - // responses from nodes and populate the signatureMap with the corresponding validator signature - // This maps node IDs to the index in the canonical validator set - nodeValidatorIndexMap := make(map[ids.NodeID]int) - for i, vdr := range validatorSet { - for _, node := range vdr.NodeIDs { - nodeValidatorIndexMap[node] = i - } - } - - // Manually connect to all peers in the validator set - // If new peers are connected, AppRequests may fail while the handshake is in progress. - // In that case, AppRequests to those nodes will be retried in the next iteration of the retry loop. - nodeIDs := set.NewSet[ids.NodeID](len(nodeValidatorIndexMap)) - for node := range nodeValidatorIndexMap { - nodeIDs.Add(node) - } - connectedNodes := r.relayer.network.ConnectPeers(nodeIDs) - - // Check if we've connected to a stake threshold of nodes - connectedWeight := uint64(0) - for node := range connectedNodes { - connectedWeight += validatorSet[nodeValidatorIndexMap[node]].Weight - } if !utils.CheckStakeWeightExceedsThreshold( - big.NewInt(0).SetUint64(connectedWeight), - totalValidatorWeight, + big.NewInt(0).SetUint64(connectedValidators.ConnectedWeight), + connectedValidators.TotalValidatorWeight, r.warpQuorum.QuorumNumerator, r.warpQuorum.QuorumDenominator, ) { r.relayer.logger.Error( "Failed to connect to a threshold of stake", - zap.Uint64("connectedWeight", connectedWeight), - zap.Uint64("totalValidatorWeight", totalValidatorWeight), + zap.Uint64("connectedWeight", connectedValidators.ConnectedWeight), + zap.Uint64("totalValidatorWeight", connectedValidators.TotalValidatorWeight), zap.Any("warpQuorum", r.warpQuorum), ) return nil, errNotEnoughConnectedStake @@ -309,18 +279,18 @@ func (r *messageRelayer) createSignedMessageAppRequest(requestID uint32) (*avala signatureMap := make(map[int]blsSignatureBuf) for attempt := 1; attempt <= maxRelayerQueryAttempts; attempt++ { - responsesExpected := len(validatorSet) - len(signatureMap) + responsesExpected := len(connectedValidators.ValidatorSet) - len(signatureMap) r.relayer.logger.Debug( "Relayer collecting signatures from peers.", zap.Int("attempt", attempt), zap.String("destinationBlockchainID", r.destinationBlockchainID.String()), - zap.Int("validatorSetSize", len(validatorSet)), + zap.Int("validatorSetSize", len(connectedValidators.ValidatorSet)), zap.Int("signatureMapSize", len(signatureMap)), zap.Int("responsesExpected", responsesExpected), ) - vdrSet := set.NewSet[ids.NodeID](len(validatorSet)) - for i, vdr := range validatorSet { + vdrSet := set.NewSet[ids.NodeID](len(connectedValidators.ValidatorSet)) + for i, vdr := range connectedValidators.ValidatorSet { // If we already have the signature for this validator, do not query any of the composite nodes again if _, ok := signatureMap[i]; ok { continue @@ -401,14 +371,14 @@ func (r *messageRelayer) createSignedMessageAppRequest(requestID uint32) (*avala return nil, nil } - validator := validatorSet[nodeValidatorIndexMap[nodeID]] + validator, vdrIndex := connectedValidators.GetValidator(nodeID) signature, valid := r.isValidSignatureResponse(response, validator.PublicKey) if valid { r.relayer.logger.Debug( "Got valid signature response", zap.String("nodeID", nodeID.String()), ) - signatureMap[nodeValidatorIndexMap[nodeID]] = signature + signatureMap[vdrIndex] = signature accumulatedSignatureWeight.Add(accumulatedSignatureWeight, new(big.Int).SetUint64(validator.Weight)) } else { r.relayer.logger.Debug( @@ -421,7 +391,7 @@ func (r *messageRelayer) createSignedMessageAppRequest(requestID uint32) (*avala // As soon as the signatures exceed the stake weight threshold we try to aggregate and send the transaction. if utils.CheckStakeWeightExceedsThreshold( accumulatedSignatureWeight, - totalValidatorWeight, + connectedValidators.TotalValidatorWeight, r.warpQuorum.QuorumNumerator, r.warpQuorum.QuorumDenominator, ) { @@ -483,55 +453,6 @@ func (r *messageRelayer) createSignedMessageAppRequest(requestID uint32) (*avala return nil, errNotEnoughSignatures } -func (r *messageRelayer) getCurrentCanonicalValidatorSet() ([]*avalancheWarp.Validator, uint64, error) { - var ( - signingSubnet ids.ID - err error - ) - if r.relayer.sourceSubnetID == constants.PrimaryNetworkID { - // If the message originates from the primary subnet, then we instead "self sign" the message using the validators of the destination subnet. - signingSubnet, err = r.relayer.pChainClient.ValidatedBy(context.Background(), r.destinationBlockchainID) - if err != nil { - r.relayer.logger.Error( - "Failed to get validating subnet for destination chain", - zap.String("destinationBlockchainID", r.destinationBlockchainID.String()), - zap.Error(err), - ) - return nil, 0, err - } - } else { - // Otherwise, the source subnet signs the message. - signingSubnet = r.relayer.sourceSubnetID - } - - height, err := r.relayer.pChainClient.GetHeight(context.Background()) - if err != nil { - r.relayer.logger.Error( - "Failed to get P-Chain height", - zap.Error(err), - ) - return nil, 0, err - } - - // Get the current canonical validator set of the source subnet. - canonicalSubnetValidators, totalValidatorWeight, err := avalancheWarp.GetCanonicalValidatorSet( - context.Background(), - r.relayer.canonicalValidatorClient, - height, - signingSubnet, - ) - if err != nil { - r.relayer.logger.Error( - "Failed to get the canonical subnet validator set", - zap.String("subnetID", r.relayer.sourceSubnetID.String()), - zap.Error(err), - ) - return nil, 0, err - } - - return canonicalSubnetValidators, totalValidatorWeight, nil -} - // isValidSignatureResponse tries to generate a signature from the peer.AsyncResponse, then verifies the signature against the node's public key. // If we are unable to generate the signature or verify correctly, false will be returned to indicate no valid signature was found in response. func (r *messageRelayer) isValidSignatureResponse( diff --git a/relayer/relayer.go b/relayer/relayer.go index 94c3cb4d..592f9801 100644 --- a/relayer/relayer.go +++ b/relayer/relayer.go @@ -38,25 +38,24 @@ const ( // Relayer handles all messages sent from a given source chain type Relayer struct { - Subscriber vms.Subscriber - pChainClient platformvm.Client - canonicalValidatorClient *CanonicalValidatorClient - currentRequestID uint32 - network *peers.AppRequestNetwork - sourceSubnetID ids.ID - sourceBlockchainID ids.ID - responseChan chan message.InboundMessage - contractMessage vms.ContractMessage - messageManagers map[common.Address]messages.MessageManager - logger logging.Logger - metrics *MessageRelayerMetrics - db database.RelayerDatabase - supportedDestinations set.Set[ids.ID] - rpcEndpoint string - messageCreator message.Creator - catchUpResultChan chan bool - healthStatus *atomic.Bool - globalConfig config.Config + Subscriber vms.Subscriber + pChainClient platformvm.Client + currentRequestID uint32 + network *peers.AppRequestNetwork + sourceSubnetID ids.ID + sourceBlockchainID ids.ID + responseChan chan message.InboundMessage + contractMessage vms.ContractMessage + messageManagers map[common.Address]messages.MessageManager + logger logging.Logger + metrics *MessageRelayerMetrics + db database.RelayerDatabase + supportedDestinations set.Set[ids.ID] + rpcEndpoint string + messageCreator message.Creator + catchUpResultChan chan bool + healthStatus *atomic.Bool + globalConfig config.Config } func NewRelayer( @@ -123,25 +122,24 @@ func NewRelayer( zap.String("blockchainIDHex", blockchainID.Hex()), ) r := Relayer{ - Subscriber: sub, - pChainClient: pChainClient, - canonicalValidatorClient: NewCanonicalValidatorClient(logger, pChainClient), - currentRequestID: rand.Uint32(), // Initialize to a random value to mitigate requestID collision - network: network, - sourceSubnetID: subnetID, - sourceBlockchainID: blockchainID, - responseChan: responseChan, - contractMessage: vms.NewContractMessage(logger, sourceSubnetInfo), - messageManagers: messageManagers, - logger: logger, - metrics: metrics, - db: db, - supportedDestinations: sourceSubnetInfo.GetSupportedDestinations(), - rpcEndpoint: sourceSubnetInfo.RPCEndpoint, - messageCreator: messageCreator, - catchUpResultChan: catchUpResultChan, - healthStatus: relayerHealth, - globalConfig: cfg, + Subscriber: sub, + pChainClient: pChainClient, + currentRequestID: rand.Uint32(), // Initialize to a random value to mitigate requestID collision + network: network, + sourceSubnetID: subnetID, + sourceBlockchainID: blockchainID, + responseChan: responseChan, + contractMessage: vms.NewContractMessage(logger, sourceSubnetInfo), + messageManagers: messageManagers, + logger: logger, + metrics: metrics, + db: db, + supportedDestinations: sourceSubnetInfo.GetSupportedDestinations(), + rpcEndpoint: sourceSubnetInfo.RPCEndpoint, + messageCreator: messageCreator, + catchUpResultChan: catchUpResultChan, + healthStatus: relayerHealth, + globalConfig: cfg, } // Open the subscription. We must do this before processing any missed messages, otherwise we may miss an incoming message diff --git a/tests/basic_relay.go b/tests/basic_relay.go index 8142636d..d19c7fb7 100644 --- a/tests/basic_relay.go +++ b/tests/basic_relay.go @@ -118,7 +118,7 @@ func BasicRelay(network interfaces.LocalNetwork) { logging.JSON.ConsoleEncoder(), ), ) - jsonDB, err := database.NewJSONFileStorage(logger, testUtils.RelayerStorageLocation(), []ids.ID{subnetAInfo.BlockchainID, subnetBInfo.BlockchainID}) + jsonDB, err := database.NewJSONFileStorage(logger, testUtils.RelayerStorageLocation(), relayerConfig.SourceBlockchains) Expect(err).Should(BeNil()) // Modify the JSON database to force the relayer to re-process old blocks diff --git a/relayer/canonical_validator_client.go b/validators/canonical_validator_client.go similarity index 85% rename from relayer/canonical_validator_client.go rename to validators/canonical_validator_client.go index fea26306..ccf44b03 100644 --- a/relayer/canonical_validator_client.go +++ b/validators/canonical_validator_client.go @@ -1,7 +1,7 @@ // Copyright (C) 2023, Ava Labs, Inc. All rights reserved. // See the file LICENSE for licensing terms. -package relayer +package validators import ( "context" @@ -10,6 +10,7 @@ import ( "github.com/ava-labs/avalanchego/snow/validators" "github.com/ava-labs/avalanchego/utils/logging" "github.com/ava-labs/avalanchego/vms/platformvm" + avalancheWarp "github.com/ava-labs/avalanchego/vms/platformvm/warp" "go.uber.org/zap" ) @@ -28,6 +29,35 @@ func NewCanonicalValidatorClient(logger logging.Logger, client platformvm.Client } } +func (v *CanonicalValidatorClient) GetCurrentCanonicalValidatorSet(subnetID ids.ID) ([]*avalancheWarp.Validator, uint64, error) { + height, err := v.GetCurrentHeight(context.Background()) + if err != nil { + v.logger.Error( + "Failed to get P-Chain height", + zap.Error(err), + ) + return nil, 0, err + } + + // Get the current canonical validator set of the source subnet. + canonicalSubnetValidators, totalValidatorWeight, err := avalancheWarp.GetCanonicalValidatorSet( + context.Background(), + v, + height, + subnetID, + ) + if err != nil { + v.logger.Error( + "Failed to get the canonical subnet validator set", + zap.String("subnetID", subnetID.String()), + zap.Error(err), + ) + return nil, 0, err + } + + return canonicalSubnetValidators, totalValidatorWeight, nil +} + func (v *CanonicalValidatorClient) GetMinimumHeight(ctx context.Context) (uint64, error) { return v.client.GetHeight(ctx) }