diff --git a/rfq/oracle.go b/rfq/oracle.go index 50ce9c522..d2423b553 100644 --- a/rfq/oracle.go +++ b/rfq/oracle.go @@ -221,11 +221,14 @@ func (r *RpcPriceOracle) QueryAskPrice(ctx context.Context, return nil, fmt.Errorf("asset ID is nil") } - // Construct query request. - var subjectAssetId []byte - copy(subjectAssetId, assetId[:]) + var ( + subjectAssetId = make([]byte, 32) + paymentAssetId = make([]byte, 32) + ) - paymentAssetId := make([]byte, 32) + // The payment asset ID is BTC, so we leave it at all zeroes. We only + // set the subject asset ID. + copy(subjectAssetId, assetId[:]) // Construct the RPC rate tick hint. var rateTickHint *oraclerpc.RateTick @@ -303,11 +306,14 @@ func (r *RpcPriceOracle) QueryBidPrice(ctx context.Context, assetId *asset.ID, return nil, fmt.Errorf("asset ID is nil") } - // Construct query request. - var subjectAssetId []byte - copy(subjectAssetId, assetId[:]) + var ( + subjectAssetId = make([]byte, 32) + paymentAssetId = make([]byte, 32) + ) - paymentAssetId := make([]byte, 32) + // The payment asset ID is BTC, so we leave it at all zeroes. We only + // set the subject asset ID. + copy(subjectAssetId, assetId[:]) req := &oraclerpc.QueryRateTickRequest{ TransactionType: oraclerpc.TransactionType_PURCHASE, diff --git a/rfq/oracle_test.go b/rfq/oracle_test.go index 58a500642..3f1637798 100644 --- a/rfq/oracle_test.go +++ b/rfq/oracle_test.go @@ -1,6 +1,7 @@ package rfq import ( + "bytes" "context" "fmt" "net" @@ -43,6 +44,11 @@ func (p *mockRpcPriceOracleServer) QueryRateTick(_ context.Context, ExpiryTimestamp: uint64(expiry), } + err := validateRateTickRequest(req) + if err != nil { + return nil, err + } + // If a rate tick hint is provided, return it as the rate tick. if req.RateTickHint != nil { rateTick.Rate = req.RateTickHint.Rate @@ -58,6 +64,32 @@ func (p *mockRpcPriceOracleServer) QueryRateTick(_ context.Context, }, nil } +// validateRateTickRequest validates the given rate tick request. +func validateRateTickRequest(req *priceoraclerpc.QueryRateTickRequest) error { + var zeroAssetID [32]byte + if req.SubjectAsset == nil { + return fmt.Errorf("subject asset must be specified") + } + if len(req.SubjectAsset.GetAssetId()) != 32 { + return fmt.Errorf("invalid subject asset ID length") + } + if bytes.Equal(req.SubjectAsset.GetAssetId(), zeroAssetID[:]) { + return fmt.Errorf("subject asset ID must NOT be all zero") + } + + if req.PaymentAsset == nil { + return fmt.Errorf("payment asset must be specified") + } + if len(req.PaymentAsset.GetAssetId()) != 32 { + return fmt.Errorf("invalid payment asset ID length") + } + if !bytes.Equal(req.PaymentAsset.GetAssetId(), zeroAssetID[:]) { + return fmt.Errorf("payment asset ID must be all zero") + } + + return nil +} + // startBackendRPC starts the given RPC server and blocks until the server is // shut down. func startBackendRPC(grpcServer *grpc.Server) error { @@ -95,7 +127,7 @@ func runQueryAskPriceTest(t *testing.T, tc *testCaseQueryAskPrice) { defer backendService.Stop() // Wait for the server to start. - time.Sleep(2 * time.Second) + time.Sleep(200 * time.Millisecond) // Create a new RPC price oracle client and connect to the mock service. serviceAddr := fmt.Sprintf("rfqrpc://%s", testServiceAddress)