diff --git a/exchanges/asset/asset.go b/exchanges/asset/asset.go index aa03ed4e16c..a125eb1120c 100644 --- a/exchanges/asset/asset.go +++ b/exchanges/asset/asset.go @@ -22,37 +22,33 @@ type Items []Item // Supported Assets const ( - Empty Item = 0 - Spot Item = 1 << (iota - 1) + Empty Item = iota + Spot Margin CrossMargin MarginFunding Index Binary + Futures // All Futures must come after this PerpetualContract PerpetualSwap - Futures DeliveryFutures UpsideProfitContract DownsideProfitContract CoinMarginedFutures USDTMarginedFutures USDCMarginedFutures - Options - OptionCombo FutureCombo LinearContract // Derivatives with a linear Base (e.g. USDT or USDC) - All // Must come immediately after all valid assets + Options // All Options must come after this + OptionCombo + All // Must come immediately after all valid assets ) const ( - optionsFlag = OptionCombo | Options - futuresFlag = PerpetualContract | PerpetualSwap | Futures | DeliveryFutures | UpsideProfitContract | DownsideProfitContract | CoinMarginedFutures | USDTMarginedFutures | USDCMarginedFutures | LinearContract | FutureCombo - supportedFlag = Spot | Margin | CrossMargin | MarginFunding | Index | Binary | PerpetualContract | PerpetualSwap | Futures | DeliveryFutures | UpsideProfitContract | DownsideProfitContract | CoinMarginedFutures | USDTMarginedFutures | USDCMarginedFutures | Options | LinearContract | OptionCombo | FutureCombo - spot = "spot" margin = "margin" - crossMargin = "cross_margin" // for Gateio exchange + crossMargin = "cross_margin" marginFunding = "marginfunding" index = "index" binary = "binary" @@ -160,7 +156,17 @@ func (a Items) JoinToString(separator string) string { // IsValid returns whether or not the supplied asset type is valid or not func (a Item) IsValid() bool { - return a != Empty && supportedFlag&a == a + return a > Empty && a < All +} + +// IsFutures checks if the asset type is a futures contract based asset +func (a Item) IsFutures() bool { + return a >= Futures && a < Options +} + +// IsOptions checks if the asset type is options contract based asset +func (a Item) IsOptions() bool { + return a >= Options && a < All } // UnmarshalJSON conforms type to the umarshaler interface @@ -242,13 +248,3 @@ func New(input string) (Item, error) { func UseDefault() Item { return Spot } - -// IsFutures checks if the asset type is a futures contract based asset -func (a Item) IsFutures() bool { - return a != Empty && futuresFlag&a == a -} - -// IsOptions checks if the asset type is options contract based asset -func (a Item) IsOptions() bool { - return a != Empty && optionsFlag&a == a -} diff --git a/exchanges/asset/asset_test.go b/exchanges/asset/asset_test.go index f82129fa4d1..66910e585c9 100644 --- a/exchanges/asset/asset_test.go +++ b/exchanges/asset/asset_test.go @@ -3,9 +3,11 @@ package asset import ( "encoding/json" "errors" + "slices" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestString(t *testing.T) { @@ -53,12 +55,37 @@ func TestJoinToString(t *testing.T) { func TestIsValid(t *testing.T) { t.Parallel() - if Item(0).IsValid() { - t.Fatal("TestIsValid returned an unexpected result") + for a := range All { + if a.String() == "" { + require.Falsef(t, a.IsValid(), "IsValid must return false with non-asset value %d", a) + } else { + require.Truef(t, a.IsValid(), "IsValid must return true for %s", a) + } } + require.Falsef(t, All.IsValid(), "IsValid must return false for All") +} - if !Spot.IsValid() { - t.Fatal("TestIsValid returned an unexpected result") +func TestIsFutures(t *testing.T) { + t.Parallel() + valid := []Item{PerpetualContract, PerpetualSwap, Futures, DeliveryFutures, UpsideProfitContract, DownsideProfitContract, CoinMarginedFutures, USDTMarginedFutures, USDCMarginedFutures, FutureCombo, LinearContract} + for a := range All { + if slices.Contains(valid, a) { + require.Truef(t, a.IsFutures(), "IsFutures must return true for %s", a) + } else { + require.Falsef(t, a.IsFutures(), "IsFutures must return false for non-asset value %d (%s)", a, a) + } + } +} + +func TestIsOptions(t *testing.T) { + t.Parallel() + valid := []Item{Options, OptionCombo} + for a := range All { + if slices.Contains(valid, a) { + require.Truef(t, a.IsOptions(), "IsOptions must return true for %s", a) + } else { + require.Falsef(t, a.IsOptions(), "IsOptions must return false for non-asset value %d (%s)", a, a) + } } } @@ -117,50 +144,6 @@ func TestSupported(t *testing.T) { } } -func TestIsFutures(t *testing.T) { - t.Parallel() - for _, a := range []Item{Spot, Margin, MarginFunding, Index, Binary} { - assert.Falsef(t, a.IsFutures(), "%s should return correctly for IsFutures", a) - } - for _, a := range []Item{PerpetualContract, PerpetualSwap, Futures, UpsideProfitContract, DownsideProfitContract, CoinMarginedFutures, USDTMarginedFutures, USDCMarginedFutures, FutureCombo} { - assert.Truef(t, a.IsFutures(), "%s should return correctly for IsFutures", a) - } -} - -func TestIsOptions(t *testing.T) { - t.Parallel() - type scenario struct { - item Item - isOptions bool - } - scenarios := []scenario{ - { - item: Options, - isOptions: true, - }, { - item: OptionCombo, - isOptions: true, - }, - { - item: Futures, - isOptions: false, - }, - { - item: Empty, - isOptions: false, - }, - } - for _, s := range scenarios { - testScenario := s - t.Run(testScenario.item.String(), func(t *testing.T) { - t.Parallel() - if testScenario.item.IsOptions() != testScenario.isOptions { - t.Errorf("expected %v isOptions to be %v", testScenario.item, testScenario.isOptions) - } - }) - } -} - func TestUnmarshalMarshal(t *testing.T) { t.Parallel() data, err := json.Marshal(Item(0)) diff --git a/exchanges/order/order_test.go b/exchanges/order/order_test.go index 04c0258d44a..2f2b8190595 100644 --- a/exchanges/order/order_test.go +++ b/exchanges/order/order_test.go @@ -55,7 +55,7 @@ func TestSubmit_Validate(t *testing.T) { Submit: &Submit{ Exchange: "test", Pair: testPair, - AssetType: 255, + AssetType: asset.All, }, }, // valid pair but invalid asset {