From aca89a43935047c73d0d293897030e0ef37de855 Mon Sep 17 00:00:00 2001 From: Mark Stacey Date: Thu, 7 Dec 2023 15:12:15 -0330 Subject: [PATCH] fix(assets-controllers): TokenRatesController state consistency (#3624) ## Explanation The TokenRatesController has been updated to ensure that the two sets of exchange rate state are always consistent with each other. Previously if the method `updateExchangeRatesByChainId` was called directly, it would not update the `contractExchangeRates` state even if the equivalent part of `contractExchangeRatesByChainId` was updated by that call. The tests have undergone a substantial refactor to ensure that we cover all cases between both of these update methods. The test cases written previously left out many test cases. We should now have better (but not perfect) coverage, and the tests are now shared between the two methods because they have identical test cases. ## References Fixes #3597 ## Changelog ### `@metamask/assets-controllers` - Fixed: Fixed bug where the `contractExchangeRates` state would sometimes be stale after calling `updateExchangeRatesByChainId` ## Checklist - [x] I've updated the test suite for new or updated code as appropriate - [x] I've updated documentation (JSDoc, Markdown, etc.) for new or updated code as appropriate - [x] I've highlighted breaking changes using the "BREAKING" category above as appropriate --------- Co-authored-by: Elliot Winkler --- .../src/TokenRatesController.test.ts | 764 +++++++++++++----- .../src/TokenRatesController.ts | 34 +- 2 files changed, 575 insertions(+), 223 deletions(-) diff --git a/packages/assets-controllers/src/TokenRatesController.test.ts b/packages/assets-controllers/src/TokenRatesController.test.ts index 433fc402c79..b3cf4203cf0 100644 --- a/packages/assets-controllers/src/TokenRatesController.test.ts +++ b/packages/assets-controllers/src/TokenRatesController.test.ts @@ -1,4 +1,6 @@ import { NetworksTicker, toHex } from '@metamask/controller-utils'; +import type { NetworkState } from '@metamask/network-controller'; +import type { PreferencesState } from '@metamask/preferences-controller'; import type { Hex } from '@metamask/utils'; import nock from 'nock'; import { useFakeTimers } from 'sinon'; @@ -9,9 +11,11 @@ import type { TokenPrice, TokenPricesByTokenContractAddress, } from './token-prices-service/abstract-token-prices-service'; +import type { TokenBalancesState } from './TokenBalancesController'; import { TokenRatesController } from './TokenRatesController'; +import type { TokenRatesConfig } from './TokenRatesController'; +import type { TokensState } from './TokensController'; -const ADDRESS = '0x01'; const defaultSelectedAddress = '0x0000000000000000000000000000000000000001'; describe('TokenRatesController', () => { @@ -1033,257 +1037,601 @@ describe('TokenRatesController', () => { }); }); - describe('updateExchangeRates', () => { - it('should not update exchange rates if legacy polling is disabled', async () => { - const tokenPricesService = buildMockTokenPricesService(); - jest.spyOn(tokenPricesService, 'fetchTokenPrices'); - const controller = new TokenRatesController( - { - chainId: '0x1', - ticker: NetworksTicker.mainnet, - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange: jest.fn(), - onTokensStateChange: jest.fn(), - onNetworkStateChange: jest.fn(), - getNetworkClientById: jest.fn(), - tokenPricesService, - }, - { - disabled: true, + // The TokenRatesController has two methods for updating exchange rates: + // `updateExchangeRates` and `updateExchangeRatesByChainId`. They are the same + // except in how the inputs are specified. `updateExchangeRates` gets the + // inputs from controller configuration, whereas `updateExchangeRatesByChainId` + // accepts the inputs as parameters. + // + // Here we test both of these methods using the same test cases. The + // differences between them are abstracted away by the helper function + // `callUpdateExchangeRatesMethod`. + describe.each([ + 'updateExchangeRates' as const, + 'updateExchangeRatesByChainId' as const, + ])('%s', (method) => { + it('does not update state when disabled', async () => { + await withController( + { config: { disabled: true } }, + async ({ controller, controllerEvents }) => { + const tokenAddress = '0x0000000000000000000000000000000000000001'; + + await callUpdateExchangeRatesMethod({ + allTokens: { + [toHex(1)]: { + [controller.config.selectedAddress]: [ + { + address: tokenAddress, + decimals: 18, + symbol: 'TST', + aggregators: [], + }, + ], + }, + }, + chainId: toHex(1), + controller, + controllerEvents, + method, + nativeCurrency: 'ETH', + }); + + expect(controller.state.contractExchangeRates).toStrictEqual({}); + expect(controller.state.contractExchangeRatesByChainId).toStrictEqual( + {}, + ); }, ); - - await controller.updateExchangeRates(); - - expect(tokenPricesService.fetchTokenPrices).not.toHaveBeenCalled(); }); - it('should update legacy state after updateExchangeRatesByChainId', async () => { - const controller = new TokenRatesController( - { - chainId: '0x1', - ticker: NetworksTicker.mainnet, - selectedAddress: defaultSelectedAddress, - onPreferencesStateChange: jest.fn(), - onTokensStateChange: jest.fn(), - onNetworkStateChange: jest.fn(), - getNetworkClientById: jest.fn(), - tokenPricesService: buildMockTokenPricesService(), - }, - { + it('does not update state if there are no tokens for the given chain and address', async () => { + await withController(async ({ controller, controllerEvents }) => { + const tokenAddress = '0x0000000000000000000000000000000000000001'; + const differentAccount = '0x1000000000000000000000000000000000000000'; + + await callUpdateExchangeRatesMethod({ allTokens: { - '0x1': { - [defaultSelectedAddress]: [ + // These tokens are for the right chain but wrong account + [toHex(1)]: { + [differentAccount]: [ { - address: '0x123', + address: tokenAddress, decimals: 18, - symbol: 'DAI', + symbol: 'TST', + aggregators: [], + }, + ], + }, + // These tokens are for the right account but wrong chain + [toHex(2)]: { + [controller.config.selectedAddress]: [ + { + address: tokenAddress, + decimals: 18, + symbol: 'TST', aggregators: [], }, - { address: ADDRESS, decimals: 0, symbol: '', aggregators: [] }, ], }, }, - }, - ); - - const updateExchangeRatesByChainIdSpy = jest - .spyOn(controller, 'updateExchangeRatesByChainId') - .mockResolvedValue(); - - // Setting mock state as if updateExchangeRatesByChainId updated it - controller.state.contractExchangeRatesByChainId = { - '0x1': { - [NetworksTicker.mainnet]: { - '0x123': 123, - '0x01': 100, - }, - }, - }; - - await controller.updateExchangeRates(); - - expect(updateExchangeRatesByChainIdSpy).toHaveBeenCalledWith({ - chainId: '0x1', - nativeCurrency: NetworksTicker.mainnet, - tokenContractAddresses: ['0x123', ADDRESS], - }); + chainId: toHex(1), + controller, + controllerEvents, + method, + nativeCurrency: 'ETH', + }); - expect(controller.state.contractExchangeRates).toStrictEqual({ - '0x123': 123, - '0x01': 100, + expect(controller.state.contractExchangeRates).toStrictEqual({}); + expect(controller.state.contractExchangeRatesByChainId).toStrictEqual( + {}, + ); }); }); - }); - describe('updateExchangeRatesByChainId', () => { - it('should not update state if no token contract addresses are provided', async () => { - const controller = new TokenRatesController({ - interval: 100, - chainId: '0x2', - ticker: 'ticker', - selectedAddress: '0xdeadbeef', - onPreferencesStateChange: jest.fn(), - onTokensStateChange: jest.fn(), - onNetworkStateChange: jest.fn(), - getNetworkClientById: jest.fn(), - tokenPricesService: buildMockTokenPricesService(), - }); - - expect(controller.state.contractExchangeRates).toStrictEqual({}); - await controller.updateExchangeRatesByChainId({ - chainId: '0x1', - nativeCurrency: 'ETH', - tokenContractAddresses: [], + it('does not update state if the price update fails', async () => { + const tokenPricesService = buildMockTokenPricesService({ + fetchTokenPrices: jest + .fn() + .mockRejectedValue(new Error('Failed to fetch')), }); - expect(controller.state.contractExchangeRatesByChainId).toStrictEqual({}); - }); - - it('should not update state when disabled', async () => { - const tokenContractAddress = '0x89d24A6b4CcB1B6fAA2625fE562bDD9a23260359'; - const controller = new TokenRatesController( - { - interval: 100, - chainId: '0x2', - ticker: 'ticker', - selectedAddress: '0xdeadbeef', - onPreferencesStateChange: jest.fn(), - onTokensStateChange: jest.fn(), - onNetworkStateChange: jest.fn(), - getNetworkClientById: jest.fn(), - tokenPricesService: buildMockTokenPricesService(), + await withController( + { options: { tokenPricesService } }, + async ({ controller, controllerEvents }) => { + const tokenAddress = '0x0000000000000000000000000000000000000001'; + + await expect( + async () => + await callUpdateExchangeRatesMethod({ + allTokens: { + [toHex(1)]: { + [controller.config.selectedAddress]: [ + { + address: tokenAddress, + decimals: 18, + symbol: 'TST', + aggregators: [], + }, + ], + }, + }, + chainId: toHex(1), + controller, + controllerEvents, + method, + nativeCurrency: 'ETH', + }), + ).rejects.toThrow('Failed to fetch'); + expect(controller.state.contractExchangeRates).toStrictEqual({}); + expect(controller.state.contractExchangeRatesByChainId).toStrictEqual( + {}, + ); }, - { disabled: true }, ); - expect(controller.state.contractExchangeRatesByChainId).toStrictEqual({}); + }); - await controller.updateExchangeRatesByChainId({ - chainId: '0x1', - nativeCurrency: 'ETH', - tokenContractAddresses: [tokenContractAddress], + it('updates all rates', async () => { + const tokenAddresses = [ + '0x0000000000000000000000000000000000000001', + '0x0000000000000000000000000000000000000002', + ]; + const tokenPricesService = buildMockTokenPricesService({ + fetchTokenPrices: jest.fn().mockResolvedValue({ + [tokenAddresses[0]]: { + currency: 'ETH', + tokenContractAddress: tokenAddresses[0], + value: 0.001, + }, + [tokenAddresses[1]]: { + currency: 'ETH', + tokenContractAddress: tokenAddresses[1], + value: 0.002, + }, + }), }); + await withController( + { options: { tokenPricesService } }, + async ({ controller, controllerEvents }) => { + await callUpdateExchangeRatesMethod({ + allTokens: { + [toHex(1)]: { + [controller.config.selectedAddress]: [ + { + address: tokenAddresses[0], + decimals: 18, + symbol: 'TST1', + aggregators: [], + }, + { + address: tokenAddresses[1], + decimals: 18, + symbol: 'TST2', + aggregators: [], + }, + ], + }, + }, + chainId: toHex(1), + controller, + controllerEvents, + method, + nativeCurrency: 'ETH', + }); - expect(controller.state.contractExchangeRatesByChainId).toStrictEqual({}); + expect(controller.state).toMatchInlineSnapshot(` + Object { + "contractExchangeRates": Object { + "0x0000000000000000000000000000000000000001": 0.001, + "0x0000000000000000000000000000000000000002": 0.002, + }, + "contractExchangeRatesByChainId": Object { + "0x1": Object { + "ETH": Object { + "0x0000000000000000000000000000000000000001": 0.001, + "0x0000000000000000000000000000000000000002": 0.002, + }, + }, + }, + } + `); + }, + ); }); - it('should update exchange rates for the given token addresses to undefined when the given chain ID is not supported by the Price API', async () => { - const controller = new TokenRatesController( - { - chainId: '0x2', - ticker: 'ticker', - selectedAddress: '0xdeadbeef', - onPreferencesStateChange: jest.fn(), - onTokensStateChange: jest.fn(), - onNetworkStateChange: jest.fn(), - getNetworkClientById: jest.fn(), - tokenPricesService: buildMockTokenPricesService({ - validateChainIdSupported(chainId: unknown): chainId is Hex { - return chainId !== '0x9999999999'; + if (method === 'updateExchangeRatesByChainId') { + it('updates rates only for a non-selected chain', async () => { + const tokenAddresses = [ + '0x0000000000000000000000000000000000000001', + '0x0000000000000000000000000000000000000002', + ]; + const tokenPricesService = buildMockTokenPricesService({ + fetchTokenPrices: jest.fn().mockResolvedValue({ + [tokenAddresses[0]]: { + currency: 'ETH', + tokenContractAddress: tokenAddresses[0], + value: 0.001, + }, + [tokenAddresses[1]]: { + currency: 'ETH', + tokenContractAddress: tokenAddresses[1], + value: 0.002, }, }), - }, - {}, - { - contractExchangeRatesByChainId: { - '0x9999999999': { - MATIC: { - '0x02': 0.01, - '0x03': 0.02, - '0x04': 0.03, + }); + await withController( + { options: { tokenPricesService } }, + async ({ controller, controllerEvents }) => { + await callUpdateExchangeRatesMethod({ + allTokens: { + [toHex(2)]: { + [controller.config.selectedAddress]: [ + { + address: tokenAddresses[0], + decimals: 18, + symbol: 'TST1', + aggregators: [], + }, + { + address: tokenAddresses[1], + decimals: 18, + symbol: 'TST2', + aggregators: [], + }, + ], + }, }, - }, - }, - }, - ); - - await controller.updateExchangeRatesByChainId({ - chainId: '0x9999999999', - nativeCurrency: 'MATIC', - tokenContractAddresses: ['0x02', '0x03'], - }); - - expect(controller.state.contractExchangeRatesByChainId).toStrictEqual({ - '0x9999999999': { - MATIC: { - '0x02': undefined, - '0x03': undefined, - '0x04': 0.03, + chainId: toHex(2), + controller, + controllerEvents, + method, + nativeCurrency: 'ETH', + setChainAsCurrent: false, + }); + + expect(controller.state).toMatchInlineSnapshot(` + Object { + "contractExchangeRates": Object {}, + "contractExchangeRatesByChainId": Object { + "0x2": Object { + "ETH": Object { + "0x0000000000000000000000000000000000000001": 0.001, + "0x0000000000000000000000000000000000000002": 0.002, + }, + }, + }, + } + `); }, - }, + ); }); - }); + } - it('should update exchange rates when native currency is supported by the Price API', async () => { + it('updates exchange rates when native currency is not supported by the Price API', async () => { + const tokenAddresses = [ + '0x0000000000000000000000000000000000000001', + '0x0000000000000000000000000000000000000002', + ]; const tokenPricesService = buildMockTokenPricesService({ - fetchTokenPrices: fetchTokenPricesWithIncreasingPriceForEachToken, - }); - const controller = new TokenRatesController({ - interval: 100, - chainId: '0x2', - ticker: 'ticker', - selectedAddress: '0xdeadbeef', - onPreferencesStateChange: jest.fn(), - onTokensStateChange: jest.fn(), - onNetworkStateChange: jest.fn(), - getNetworkClientById: jest.fn(), - tokenPricesService, + fetchTokenPrices: jest.fn().mockResolvedValue({ + [tokenAddresses[0]]: { + currency: 'ETH', + tokenContractAddress: tokenAddresses[0], + value: 0.001, + }, + [tokenAddresses[1]]: { + currency: 'ETH', + tokenContractAddress: tokenAddresses[1], + value: 0.002, + }, + }), + validateCurrencySupported: jest.fn().mockReturnValue( + false, + // Cast used because this method has an assertion in the return + // value that I don't know how to type properly with Jest's mock. + ) as unknown as AbstractTokenPricesService['validateCurrencySupported'], }); + nock('https://min-api.cryptocompare.com') + .get('/data/price') + .query({ + fsym: 'ETH', + tsyms: 'UNSUPPORTED', + }) + .reply(200, { UNSUPPORTED: 0.5 }); // .5 eth to 1 matic + + await withController( + { options: { tokenPricesService } }, + async ({ controller, controllerEvents }) => { + await callUpdateExchangeRatesMethod({ + allTokens: { + [toHex(137)]: { + [controller.config.selectedAddress]: [ + { + address: tokenAddresses[0], + decimals: 18, + symbol: 'TST1', + aggregators: [], + }, + { + address: tokenAddresses[1], + decimals: 18, + symbol: 'TST2', + aggregators: [], + }, + ], + }, + }, + chainId: toHex(137), + controller, + controllerEvents, + method, + nativeCurrency: 'UNSUPPORTED', + }); - expect(controller.state.contractExchangeRates).toStrictEqual({}); - await controller.updateExchangeRatesByChainId({ - chainId: '0x1', - nativeCurrency: 'ETH', - tokenContractAddresses: ['0xAAA'], - }); - expect(controller.state.contractExchangeRatesByChainId).toStrictEqual({ - '0x1': { - ETH: { - '0xAAA': 0.001, - }, + // token value in terms of matic should be (token value in eth) * (eth value in matic) + expect(controller.state).toMatchInlineSnapshot(` + Object { + "contractExchangeRates": Object { + "0x0000000000000000000000000000000000000001": 0.0005, + "0x0000000000000000000000000000000000000002": 0.001, + }, + "contractExchangeRatesByChainId": Object { + "0x89": Object { + "UNSUPPORTED": Object { + "0x0000000000000000000000000000000000000001": 0.0005, + "0x0000000000000000000000000000000000000002": 0.001, + }, + }, + }, + } + `); }, - }); + ); }); - it('should update exchange rates when native currency is not supported by the Price API', async () => { - nock('https://min-api.cryptocompare.com') - .get('/data/price?fsym=ETH&tsyms=LOL') - .reply(200, { LOL: 0.5 }); + it('sets rates to undefined when chain is not supported by the Price API', async () => { + const tokenAddresses = [ + '0x0000000000000000000000000000000000000001', + '0x0000000000000000000000000000000000000002', + ]; const tokenPricesService = buildMockTokenPricesService({ - fetchTokenPrices: fetchTokenPricesWithIncreasingPriceForEachToken, - validateCurrencySupported(currency: unknown): currency is string { - return currency !== 'LOL'; - }, - }); - const controller = new TokenRatesController({ - chainId: '0x2', - ticker: 'ticker', - selectedAddress: '0xdeadbeef', - onPreferencesStateChange: jest.fn(), - onTokensStateChange: jest.fn(), - onNetworkStateChange: jest.fn(), - getNetworkClientById: jest.fn(), - tokenPricesService, - }); - - await controller.updateExchangeRatesByChainId({ - chainId: '0x1', - nativeCurrency: 'LOL', - tokenContractAddresses: ['0x02', '0x03'], + fetchTokenPrices: jest.fn().mockResolvedValue({ + [tokenAddresses[0]]: { + currency: 'ETH', + tokenContractAddress: tokenAddresses[0], + value: 0.001, + }, + [tokenAddresses[1]]: { + currency: 'ETH', + tokenContractAddress: tokenAddresses[1], + value: 0.002, + }, + }), + validateChainIdSupported: jest.fn().mockReturnValue( + false, + // Cast used because this method has an assertion in the return + // value that I don't know how to type properly with Jest's mock. + ) as unknown as AbstractTokenPricesService['validateChainIdSupported'], }); + await withController( + { options: { tokenPricesService } }, + async ({ controller, controllerEvents }) => { + await callUpdateExchangeRatesMethod({ + allTokens: { + [toHex(999)]: { + [controller.config.selectedAddress]: [ + { + address: tokenAddresses[0], + decimals: 18, + symbol: 'TST1', + aggregators: [], + }, + { + address: tokenAddresses[1], + decimals: 18, + symbol: 'TST2', + aggregators: [], + }, + ], + }, + }, + chainId: toHex(999), + controller, + controllerEvents, + method, + nativeCurrency: 'TST', + }); - expect(controller.state.contractExchangeRatesByChainId).toStrictEqual({ - '0x1': { - LOL: { - // token price in LOL = (token price in ETH) * (ETH value in LOL) - '0x02': 0.0005, - '0x03': 0.001, - }, + expect(controller.state).toMatchInlineSnapshot(` + Object { + "contractExchangeRates": Object { + "0x0000000000000000000000000000000000000001": undefined, + "0x0000000000000000000000000000000000000002": undefined, + }, + "contractExchangeRatesByChainId": Object { + "0x3e7": Object { + "TST": Object { + "0x0000000000000000000000000000000000000001": undefined, + "0x0000000000000000000000000000000000000002": undefined, + }, + }, + }, + } + `); }, - }); + ); }); }); }); +/** + * A collection of mock external controller events. + */ +type ControllerEvents = { + networkStateChange: (state: NetworkState) => void; + preferencesStateChange: (state: PreferencesState) => void; + tokensStateChange: (state: TokensState) => void; +}; + +/** + * A callback for the `withController` helper function. + * + * @param args - The arguments. + * @param args.controller - The controller that the test helper created. + * @param args.controllerEvents - A collection of methods for dispatching mock + * events from external controllers. + */ +type WithControllerCallback = ({ + controller, + controllerEvents, +}: { + controller: TokenRatesController; + controllerEvents: ControllerEvents; +}) => Promise | ReturnValue; + +type PartialConstructorParameters = { + options?: Partial[0]>; + config?: Partial; + state?: Partial; +}; + +type WithControllerArgs = + | [WithControllerCallback] + | [PartialConstructorParameters, WithControllerCallback]; + +/** + * Builds a controller based on the given options, and calls the given function + * with that controller. + * + * @param args - Either a function, or a set of partial constructor parameters + * plus a function. The function will be called with the built controller and a + * collection of controller event handlers. + * @returns Whatever the callback returns. + */ +async function withController( + ...args: WithControllerArgs +) { + const [{ options, config, state }, testFunction] = + args.length === 2 + ? args + : [{ options: undefined, config: undefined, state: undefined }, args[0]]; + + // explit cast used here because we know the `on____` functions are always + // set in the constructor. + const controllerEvents = {} as ControllerEvents; + + const controllerOptions: ConstructorParameters< + typeof TokenRatesController + >[0] = { + chainId: toHex(1), + getNetworkClientById: jest.fn(), + onNetworkStateChange: (listener) => { + controllerEvents.networkStateChange = listener; + }, + onPreferencesStateChange: (listener) => { + controllerEvents.preferencesStateChange = listener; + }, + onTokensStateChange: (listener) => { + controllerEvents.tokensStateChange = listener; + }, + selectedAddress: defaultSelectedAddress, + ticker: NetworksTicker.mainnet, + tokenPricesService: buildMockTokenPricesService(), + ...options, + }; + + const controller = new TokenRatesController(controllerOptions, config, state); + try { + return await testFunction({ + controller, + controllerEvents, + }); + } finally { + controller.stop(); + } +} + +/** + * Call an "update exchange rates" method with the given parameters. + * + * The TokenRatesController has two methods for updating exchange rates: + * `updateExchangeRates` and `updateExchangeRatesByChainId`. They are the same + * except in how the inputs are specified. `updateExchangeRates` gets the + * inputs from controller configuration, whereas `updateExchangeRatesByChainId` + * accepts the inputs as parameters. + * + * This helper function normalizes between these two functions, so that we can + * test them the same way. + * + * @param args - The arguments. + * @param args.allTokens - The `allTokens` state (from the TokensController) + * @param args.chainId - The chain ID of the chain we want to update the + * exchange rates for. + * @param args.controller - The controller to call the method with. + * @param args.controllerEvents - Controller event handlers, used to + * update controller configuration. + * @param args.method - The "update exchange rates" method to call. + * @param args.nativeCurrency - The symbol for the native currency of the + * network we're getting updated exchange rates for. + * @param args.setChainAsCurrent - When calling `updateExchangeRatesByChainId`, + * this determines whether to set the chain as the globally selected chain. + */ +async function callUpdateExchangeRatesMethod({ + allTokens, + chainId, + controller, + controllerEvents, + method, + nativeCurrency, + setChainAsCurrent = true, +}: { + allTokens: TokenRatesConfig['allTokens']; + chainId: TokenRatesConfig['chainId']; + controller: TokenRatesController; + controllerEvents: ControllerEvents; + method: 'updateExchangeRates' | 'updateExchangeRatesByChainId'; + nativeCurrency: TokenRatesConfig['nativeCurrency']; + setChainAsCurrent?: boolean; +}) { + if (method === 'updateExchangeRates' && !setChainAsCurrent) { + throw new Error( + 'The "setChainAsCurrent" flag cannot be enabled when calling the "updateExchangeRates" method', + ); + } + + if (setChainAsCurrent) { + // We're using controller events here instead of calling `configure` + // because `configure` does not update internal controller state correctly. + // As with many BaseControllerV1-based controllers, runtime config + // modification is allowed by the API but not supported in practice. + controllerEvents.networkStateChange({ + // Note that the state given here is intentionally incomplete because the + // controller only uses these two properties, and the tests are written to + // only consider these two. We want this to break if we start relying on + // more, as we'd need to update the tests accordingly. + // @ts-expect-error Intentionally incomplete state + providerConfig: { chainId, ticker: nativeCurrency }, + }); + // Note that the state given here is intentionally incomplete because the + // controller only uses these two properties, and the tests are written to + // only consider these two. We want this to break if we start relying on + // more, as we'd need to update the tests accordingly. + // @ts-expect-error Intentionally incomplete state + controllerEvents.tokensStateChange({ allDetectedTokens: {}, allTokens }); + } + + if (method === 'updateExchangeRates') { + await controller.updateExchangeRates(); + } else { + const { selectedAddress } = controller.config; + const tokens = allTokens[chainId]?.[selectedAddress] || []; + const tokenContractAddresses = tokens.map((token) => toHex(token.address)); + await controller.updateExchangeRatesByChainId({ + chainId, + nativeCurrency, + tokenContractAddresses, + }); + } +} + /** * Builds a mock token prices service. * diff --git a/packages/assets-controllers/src/TokenRatesController.ts b/packages/assets-controllers/src/TokenRatesController.ts index edb511a777c..7b42fbfd54f 100644 --- a/packages/assets-controllers/src/TokenRatesController.ts +++ b/packages/assets-controllers/src/TokenRatesController.ts @@ -335,11 +335,6 @@ export class TokenRatesController extends PollingControllerV1< nativeCurrency, tokenContractAddresses, }); - - this.update({ - contractExchangeRates: - this.state.contractExchangeRatesByChainId[chainId][nativeCurrency], - }); } /** @@ -369,20 +364,29 @@ export class TokenRatesController extends PollingControllerV1< nativeCurrency, }); + const existingContractExchangeRates = this.state.contractExchangeRates; + const updatedContractExchangeRates = + chainId === this.config.chainId && + nativeCurrency === this.config.nativeCurrency + ? newContractExchangeRates + : existingContractExchangeRates; + const existingContractExchangeRatesForChainId = this.state.contractExchangeRatesByChainId[chainId] ?? {}; - - this.update({ - contractExchangeRatesByChainId: { - ...this.state.contractExchangeRatesByChainId, - [chainId]: { - ...existingContractExchangeRatesForChainId, - [nativeCurrency]: { - ...existingContractExchangeRatesForChainId[nativeCurrency], - ...newContractExchangeRates, - }, + const updatedContractExchangeRatesForChainId = { + ...this.state.contractExchangeRatesByChainId, + [chainId]: { + ...existingContractExchangeRatesForChainId, + [nativeCurrency]: { + ...existingContractExchangeRatesForChainId[nativeCurrency], + ...newContractExchangeRates, }, }, + }; + + this.update({ + contractExchangeRates: updatedContractExchangeRates, + contractExchangeRatesByChainId: updatedContractExchangeRatesForChainId, }); }