diff --git a/currency/manager.go b/currency/manager.go index 271b0228d49..d57c125600a 100644 --- a/currency/manager.go +++ b/currency/manager.go @@ -33,8 +33,8 @@ var ( // GetAssetTypes returns a list of stored asset types func (p *PairsManager) GetAssetTypes(enabled bool) asset.Items { - p.Mutex.RLock() - defer p.Mutex.RUnlock() + p.mutex.RLock() + defer p.mutex.RUnlock() assetTypes := make(asset.Items, 0, len(p.Pairs)) for k, ps := range p.Pairs { if enabled && (ps.AssetEnabled == nil || !*ps.AssetEnabled) { @@ -51,8 +51,8 @@ func (p *PairsManager) Get(a asset.Item) (*PairStore, error) { return nil, fmt.Errorf("%s %w", a, asset.ErrNotSupported) } - p.Mutex.RLock() - defer p.Mutex.RUnlock() + p.mutex.RLock() + defer p.mutex.RUnlock() c, ok := p.Pairs[a] if !ok { return nil, @@ -70,20 +70,20 @@ func (p *PairsManager) Store(a asset.Item, ps *PairStore) error { if err != nil { return err } - p.Mutex.Lock() + p.mutex.Lock() if p.Pairs == nil { p.Pairs = make(map[asset.Item]*PairStore) } p.Pairs[a] = cpy - p.Mutex.Unlock() + p.mutex.Unlock() return nil } // Delete deletes a map entry based on the supplied asset type func (p *PairsManager) Delete(a asset.Item) { - p.Mutex.Lock() + p.mutex.Lock() delete(p.Pairs, a) - p.Mutex.Unlock() + p.mutex.Unlock() } // GetPairs gets a list of stored pairs based on the asset type and whether @@ -93,8 +93,8 @@ func (p *PairsManager) GetPairs(a asset.Item, enabled bool) (Pairs, error) { return nil, fmt.Errorf("%s %w", a, asset.ErrNotSupported) } - p.Mutex.RLock() - defer p.Mutex.RUnlock() + p.mutex.RLock() + defer p.mutex.RUnlock() pairStore, ok := p.Pairs[a] if !ok { return nil, nil @@ -134,8 +134,8 @@ func (p *PairsManager) StoreFormat(a asset.Item, pFmt *PairFormat, config bool) cpy := *pFmt - p.Mutex.Lock() - defer p.Mutex.Unlock() + p.mutex.Lock() + defer p.mutex.Unlock() if p.Pairs == nil { p.Pairs = make(map[asset.Item]*PairStore) @@ -167,8 +167,8 @@ func (p *PairsManager) StorePairs(a asset.Item, pairs Pairs, enabled bool) error cpy := make(Pairs, len(pairs)) copy(cpy, pairs) - p.Mutex.Lock() - defer p.Mutex.Unlock() + p.mutex.Lock() + defer p.mutex.Unlock() if p.Pairs == nil { p.Pairs = make(map[asset.Item]*PairStore) @@ -197,8 +197,8 @@ func (p *PairsManager) EnsureOnePairEnabled() (Pair, asset.Item, error) { if p == nil { return EMPTYPAIR, asset.Empty, common.ErrNilPointer } - p.Mutex.Lock() - defer p.Mutex.Unlock() + p.mutex.Lock() + defer p.mutex.Unlock() for _, v := range p.Pairs { if v.AssetEnabled == nil || !*v.AssetEnabled || @@ -235,8 +235,8 @@ func (p *PairsManager) DisablePair(a asset.Item, pair Pair) error { return ErrCurrencyPairEmpty } - p.Mutex.Lock() - defer p.Mutex.Unlock() + p.mutex.Lock() + defer p.mutex.Unlock() pairStore, err := p.getPairStoreRequiresLock(a) if err != nil { @@ -262,8 +262,8 @@ func (p *PairsManager) EnablePair(a asset.Item, pair Pair) error { return ErrCurrencyPairEmpty } - p.Mutex.Lock() - defer p.Mutex.Unlock() + p.mutex.Lock() + defer p.mutex.Unlock() pairStore, err := p.getPairStoreRequiresLock(a) if err != nil { @@ -292,8 +292,8 @@ func (p *PairsManager) IsAssetPairEnabled(a asset.Item, pair Pair) error { return ErrCurrencyPairEmpty } - p.Mutex.RLock() - defer p.Mutex.RUnlock() + p.mutex.RLock() + defer p.mutex.RUnlock() pairStore, err := p.getPairStoreRequiresLock(a) if err != nil { @@ -319,8 +319,8 @@ func (p *PairsManager) IsAssetEnabled(a asset.Item) error { return fmt.Errorf("%s %w", a, asset.ErrNotSupported) } - p.Mutex.RLock() - defer p.Mutex.RUnlock() + p.mutex.RLock() + defer p.mutex.RUnlock() pairStore, err := p.getPairStoreRequiresLock(a) if err != nil { @@ -343,8 +343,8 @@ func (p *PairsManager) SetAssetEnabled(a asset.Item, enabled bool) error { return fmt.Errorf("%s %w", a, asset.ErrNotSupported) } - p.Mutex.Lock() - defer p.Mutex.Unlock() + p.mutex.Lock() + defer p.mutex.Unlock() pairStore, err := p.getPairStoreRequiresLock(a) if err != nil { @@ -366,6 +366,37 @@ func (p *PairsManager) SetAssetEnabled(a asset.Item, enabled bool) error { return nil } +// Load sets the pair manager from a seed without copying mutexes +func (p *PairsManager) Load(seed *PairsManager) error { + if seed == nil { + return fmt.Errorf("%w PairsManager", common.ErrNilPointer) + } + p.mutex.Lock() + defer p.mutex.Unlock() + seed.mutex.RLock() + defer seed.mutex.RUnlock() + + var pN PairsManager + j, err := json.Marshal(seed) + if err != nil { + return err + } + err = json.Unmarshal(j, &pN) + if err != nil { + return err + } + p.BypassConfigFormatUpgrades = pN.BypassConfigFormatUpgrades + if pN.UseGlobalFormat { + p.UseGlobalFormat = pN.UseGlobalFormat + p.RequestFormat = pN.RequestFormat + p.ConfigFormat = pN.ConfigFormat + } + p.LastUpdated = pN.LastUpdated + p.Pairs = pN.Pairs + + return nil +} + func (p *PairsManager) getPairStoreRequiresLock(a asset.Item) (*PairStore, error) { if p.Pairs == nil { return nil, errors.New("pair manager not initialised") diff --git a/currency/manager_test.go b/currency/manager_test.go index 319292137d8..d9232ac248e 100644 --- a/currency/manager_test.go +++ b/currency/manager_test.go @@ -5,6 +5,8 @@ import ( "errors" "testing" + "github.com/stretchr/testify/assert" + "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/common/convert" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" ) @@ -672,3 +674,38 @@ func TestEnsureOnePairEnabled(t *testing.T) { t.Errorf("received: '%v' but expected: '%v'", err, ErrCurrencyPairsEmpty) } } + +func TestLoad(t *testing.T) { + t.Parallel() + base := PairsManager{} + fmt1 := &PairFormat{Uppercase: true} + fmt2 := &PairFormat{Uppercase: true, Delimiter: DashDelimiter} + p := NewPair(BTC, USDT) + tt := int64(1337) + seed := PairsManager{ + LastUpdated: tt, + UseGlobalFormat: true, + ConfigFormat: fmt1, + RequestFormat: fmt2, + Pairs: map[asset.Item]*PairStore{ + asset.Futures: { + AssetEnabled: convert.BoolPtr(true), + Available: []Pair{p}, + }, + asset.Options: { + AssetEnabled: convert.BoolPtr(false), + Available: []Pair{}, + }, + }, + } + + assert.ErrorIs(t, base.Load(nil), common.ErrNilPointer, "Load nil should error") + if assert.NoError(t, base.Load(&seed), "Loading from seed should not error") { + assert.True(t, *base.Pairs[asset.Futures].AssetEnabled, "Futures AssetEnabled should be true") + assert.True(t, base.Pairs[asset.Futures].Available.Contains(p, true), "Futures Available Pairs should contain BTCUSDT") + assert.False(t, *base.Pairs[asset.Options].AssetEnabled, "Options AssetEnabled should be false") + assert.Equal(t, tt, base.LastUpdated, "Last Updated should be correct") + assert.Equal(t, fmt1.Uppercase, base.ConfigFormat.Uppercase, "ConfigFormat Uppercase should be correct") + assert.Equal(t, fmt2.Delimiter, base.RequestFormat.Delimiter, "RequestFormat Delimiter should be correct") + } +} diff --git a/currency/manager_types.go b/currency/manager_types.go index d9a14a44d24..2b94d7cad0b 100644 --- a/currency/manager_types.go +++ b/currency/manager_types.go @@ -14,7 +14,7 @@ type PairsManager struct { UseGlobalFormat bool `json:"useGlobalFormat,omitempty"` LastUpdated int64 `json:"lastUpdated,omitempty"` Pairs FullStore `json:"pairs"` - Mutex sync.RWMutex `json:"-"` + mutex sync.RWMutex `json:"-"` } // FullStore holds all supported asset types with the enabled and available diff --git a/exchanges/sharedtestvalues/sharedtestvalues.go b/exchanges/sharedtestvalues/sharedtestvalues.go index 7691ae3e9de..c398f9314e7 100644 --- a/exchanges/sharedtestvalues/sharedtestvalues.go +++ b/exchanges/sharedtestvalues/sharedtestvalues.go @@ -157,12 +157,7 @@ func TestFixtureToDataHandler(t *testing.T, seed, e exchange.IBotExchange, fixtu b := e.GetBase() seedBase := seed.GetBase() - seedBase.CurrencyPairs.Mutex.RLock() - b.CurrencyPairs.RequestFormat = seedBase.CurrencyPairs.RequestFormat - b.CurrencyPairs.ConfigFormat = seedBase.CurrencyPairs.ConfigFormat - b.CurrencyPairs.UseGlobalFormat = seedBase.CurrencyPairs.UseGlobalFormat - b.CurrencyPairs.Pairs = seedBase.CurrencyPairs.Pairs - seedBase.CurrencyPairs.Mutex.RUnlock() + assert.NoError(t, b.CurrencyPairs.Load(&seedBase.CurrencyPairs), "Loading currency pairs should not error") b.Name = "fixture" b.Websocket = &stream.Websocket{