diff --git a/packages/assets-controllers/CHANGELOG.md b/packages/assets-controllers/CHANGELOG.md index da35c77070..2c875368ce 100644 --- a/packages/assets-controllers/CHANGELOG.md +++ b/packages/assets-controllers/CHANGELOG.md @@ -5,6 +5,15 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +### Added +- `TokenListController` now exports a `TokenListControllerMessenger` type ([#3609](https://github.com/MetaMask/core/pull/3609)). +- `TokenDetectionController` exports types `TokenDetectionControllerMessenger`, `TokenDetectionControllerActions`, `TokenDetectionControllerGetStateAction`, `TokenDetectionControllerEvents`, `TokenDetectionControllerStateChangeEvent` ([#3609](https://github.com/MetaMask/core/pull/3609)). +- Add `enable` and `disable` methods to `TokenDetectionController`, which control whether the controller is able to make polling requests or all of its network calls are blocked. ([#3609](https://github.com/MetaMask/core/pull/3609)). + - Note that if the controller is initiated without the `disabled` constructor option set to `false`, the `enable` method will need to be called before the controller can make polling requests in response to subscribed events. + +### Changed +- **BREAKING:** `TokenDetectionController` is upgraded to extend `BaseControllerV2` and `StaticIntervalPollingController` ([#3609](https://github.com/MetaMask/core/pull/3609)). + - The constructor now expects an options object as its only argument, with required properties `messenger`, `networkClientId`, required callbacks `onPreferencesStateChange`, `getBalancesInSingleCall`, `addDetectedTokens`, `getTokenState`, `getPreferencesState`, and optional properties `disabled`, `interval`, `selectedAddress`. ## [22.0.0] ### Changed diff --git a/packages/assets-controllers/jest.config.js b/packages/assets-controllers/jest.config.js index 13bec0c96f..6d35a70f3a 100644 --- a/packages/assets-controllers/jest.config.js +++ b/packages/assets-controllers/jest.config.js @@ -17,10 +17,10 @@ module.exports = merge(baseConfig, { // An object that configures minimum threshold enforcement for coverage results coverageThreshold: { global: { - branches: 88.2, - functions: 95.95, - lines: 96.25, - statements: 96.5, + branches: 88.36, + functions: 97.08, + lines: 97.23, + statements: 97.28, }, }, diff --git a/packages/assets-controllers/src/TokenDetectionController.test.ts b/packages/assets-controllers/src/TokenDetectionController.test.ts index 6c73d4ba26..0bdb82b31e 100644 --- a/packages/assets-controllers/src/TokenDetectionController.test.ts +++ b/packages/assets-controllers/src/TokenDetectionController.test.ts @@ -1,3 +1,4 @@ +import type { AddApprovalRequest } from '@metamask/approval-controller'; import { ControllerMessenger } from '@metamask/base-controller'; import { ChainId, @@ -8,7 +9,6 @@ import { } from '@metamask/controller-utils'; import { defaultState as defaultNetworkState } from '@metamask/network-controller'; import type { - NetworkControllerStateChangeEvent, NetworkState, ProviderConfig, } from '@metamask/network-controller'; @@ -22,20 +22,23 @@ import { advanceTime } from '../../../tests/helpers'; import type { AssetsContractController } from './AssetsContractController'; import { formatAggregatorNames, - isTokenDetectionSupportedForNetwork, SupportedTokenDetectionNetworks, } from './assetsUtil'; import { TOKEN_END_POINT_API } from './token-service'; -import { TokenDetectionController } from './TokenDetectionController'; -import { TokenListController } from './TokenListController'; import type { - GetTokenListState, - TokenListStateChange, - TokenListToken, -} from './TokenListController'; + AllowedActions, + AllowedEvents, + TokenDetectionControllerMessenger, +} from './TokenDetectionController'; +import { + TokenDetectionController, + controllerName, +} from './TokenDetectionController'; +import { TokenListController } from './TokenListController'; +import type { TokenListToken } from './TokenListController'; import type { Token } from './TokenRatesController'; -import { TokensController } from './TokensController'; import type { TokensControllerMessenger } from './TokensController'; +import { TokensController } from './TokensController'; const DEFAULT_INTERVAL = 180000; @@ -96,17 +99,26 @@ const sampleTokenB: Token = { }; type MainControllerMessenger = ControllerMessenger< - GetTokenListState, - TokenListStateChange | NetworkControllerStateChangeEvent + AllowedActions | AddApprovalRequest, + AllowedEvents >; -const getControllerMessenger = (): MainControllerMessenger => { +/** + * Returns a new `MainControllerMessenger` instance that can be used to create restricted messengers. + * @returns The new `MainControllerMessenger` instance. + */ +function getControllerMessenger(): MainControllerMessenger { return new ControllerMessenger(); -}; +} -const setupTokenListController = ( +/** + * Sets up a `TokenListController` and its restricted messenger. + * @param controllerMessenger - The main controller messenger. + * @returns An object containing the TokenListController and its restricted messenger. + */ +function setupTokenListController( controllerMessenger: MainControllerMessenger, -) => { +) { const tokenListMessenger = controllerMessenger.getRestricted({ name: 'TokenListController', allowedActions: [], @@ -120,7 +132,29 @@ const setupTokenListController = ( }); return { tokenList, tokenListMessenger }; -}; +} + +/** + * Builds a messenger that `TokenDetectionController` can use to communicate with other controllers. + * @param controllerMessenger - The main controller messenger. + * @returns The restricted messenger. + */ +function buildTokenDetectionControllerMessenger( + controllerMessenger: MainControllerMessenger = getControllerMessenger(), +): TokenDetectionControllerMessenger { + return controllerMessenger.getRestricted({ + name: controllerName, + allowedActions: [ + 'NetworkController:getNetworkConfigurationByNetworkClientId', + 'TokenListController:getState', + ], + allowedEvents: [ + 'NetworkController:stateChange', + 'NetworkController:networkDidChange', + 'TokenListController:stateChange', + ], + }); +} describe('TokenDetectionController', () => { let tokenDetection: TokenDetectionController; @@ -134,6 +168,7 @@ describe('TokenDetectionController', () => { >; const onNetworkDidChangeListeners: ((state: NetworkState) => void)[] = []; + const getNetworkConfigurationByNetworkClientIdHandler = jest.fn(); const changeNetwork = (providerConfig: ProviderConfig) => { onNetworkDidChangeListeners.forEach((listener) => { listener({ @@ -141,9 +176,35 @@ describe('TokenDetectionController', () => { providerConfig, }); }); + + controllerMessenger.publish('NetworkController:networkDidChange', { + ...defaultNetworkState, + providerConfig, + selectedNetworkClientId: providerConfig.type, + }); + + getNetworkConfigurationByNetworkClientIdHandler.mockReturnValue({ + chainId: providerConfig.chainId, + }); + + controllerMessenger.unregisterActionHandler( + 'NetworkController:getNetworkConfigurationByNetworkClientId', + ); + controllerMessenger.registerActionHandler( + 'NetworkController:getNetworkConfigurationByNetworkClientId', + getNetworkConfigurationByNetworkClientIdHandler, + ); + }; + + const goerli = { + chainId: ChainId.goerli, + id: 'goerli', + type: NetworkType.goerli, + ticker: NetworksTicker.goerli, }; const mainnet = { chainId: ChainId.mainnet, + id: 'mainnet', type: NetworkType.mainnet, ticker: NetworksTicker.mainnet, }; @@ -174,6 +235,18 @@ describe('TokenDetectionController', () => { // eslint-disable-next-line @typescript-eslint/no-explicit-any .callsFake(() => null as any); + controllerMessenger.publish('NetworkController:networkDidChange', { + ...defaultNetworkState, + providerConfig: mainnet, + selectedNetworkClientId: NetworkType.mainnet, + }); + controllerMessenger.registerActionHandler( + `NetworkController:getNetworkConfigurationByNetworkClientId`, + getNetworkConfigurationByNetworkClientIdHandler.mockReturnValue({ + chainId: ChainId.mainnet, + }), + ); + tokensController = new TokensController({ chainId: ChainId.mainnet, onPreferencesStateChange: (listener) => preferences.subscribe(listener), @@ -192,30 +265,15 @@ describe('TokenDetectionController', () => { getBalancesInSingleCall = sinon.stub(); tokenDetection = new TokenDetectionController({ + networkClientId: NetworkType.mainnet, onPreferencesStateChange: (listener) => preferences.subscribe(listener), - onNetworkStateChange: (listener) => - onNetworkDidChangeListeners.push(listener), - onTokenListStateChange: (listener) => - tokenListSetup.tokenListMessenger.subscribe( - `TokenListController:stateChange`, - listener, - ), getBalancesInSingleCall: getBalancesInSingleCall as unknown as AssetsContractController['getBalancesInSingleCall'], addDetectedTokens: tokensController.addDetectedTokens.bind(tokensController), getTokensState: () => tokensController.state, - getTokenListState: () => tokenList.state, - getNetworkState: () => defaultNetworkState, getPreferencesState: () => preferences.state, - getNetworkClientById: jest.fn().mockReturnValueOnce({ - configuration: { - chainId: ChainId.mainnet, - }, - provider: {}, - blockTracker: {}, - destroy: jest.fn(), - }), + messenger: buildTokenDetectionControllerMessenger(controllerMessenger), }); sinon @@ -227,28 +285,12 @@ describe('TokenDetectionController', () => { sinon.restore(); tokenDetection.stop(); tokenList.destroy(); - controllerMessenger.clearEventSubscriptions( - 'NetworkController:stateChange', - ); - }); - - it('should set default config', () => { - expect(tokenDetection.config).toStrictEqual({ - interval: DEFAULT_INTERVAL, - selectedAddress: '', - disabled: true, - chainId: ChainId.mainnet, - isDetectionEnabledForNetwork: true, - isDetectionEnabledFromPreferences: true, - }); }); it('should poll and detect tokens on interval while on supported networks', async () => { await new Promise(async (resolve) => { const mockTokens = sinon.stub(tokenDetection, 'detectTokens'); - tokenDetection.configure({ - interval: 10, - }); + tokenDetection.setIntervalLength(10); await tokenDetection.start(); expect(mockTokens.calledOnce).toBe(true); @@ -259,31 +301,8 @@ describe('TokenDetectionController', () => { }); }); - it('should detect supported networks correctly', () => { - tokenDetection.configure({ - chainId: SupportedTokenDetectionNetworks.mainnet, - }); - - expect( - isTokenDetectionSupportedForNetwork(tokenDetection.config.chainId), - ).toBe(true); - tokenDetection.configure({ chainId: SupportedTokenDetectionNetworks.bsc }); - expect( - isTokenDetectionSupportedForNetwork(tokenDetection.config.chainId), - ).toBe(true); - tokenDetection.configure({ chainId: ChainId.goerli }); - expect( - isTokenDetectionSupportedForNetwork(tokenDetection.config.chainId), - ).toBe(false); - }); - it('should not autodetect while not on supported networks', async () => { - tokenDetection.configure({ - selectedAddress: '0x1', - chainId: ChainId.goerli, - isDetectionEnabledForNetwork: false, - }); - + changeNetwork(goerli); getBalancesInSingleCall.resolves({ [sampleTokenA.address]: new BN(1), }); @@ -302,14 +321,13 @@ describe('TokenDetectionController', () => { expect(tokensController.state.detectedTokens).toStrictEqual([sampleTokenA]); }); - it('should detect tokens correctly on the Aurora network', async () => { - const auroraMainnet = { - chainId: ChainId.aurora, - type: NetworkType.mainnet, - ticker: 'Aurora ETH', - }; - preferences.update({ selectedAddress: '0x1' }); - changeNetwork(auroraMainnet); + it('should detect tokens correctly on the Polygon network', async () => { + preferences.update({ selectedAddress: '0x2' }); + changeNetwork({ + chainId: SupportedTokenDetectionNetworks.polygon, + type: NetworkType.rpc, + ticker: NetworksTicker.rpc, + }); getBalancesInSingleCall.resolves({ [sampleTokenA.address]: new BN(1), @@ -438,168 +456,137 @@ describe('TokenDetectionController', () => { expect(tokensController.state.detectedTokens).toStrictEqual([sampleTokenA]); }); - it('should not call getBalancesInSingleCall after stopping polling, and then switching between networks that support token detection', async () => { - const polygonDecimalChainId = '137'; - nock(TOKEN_END_POINT_API) - .get(getTokensPath(toHex(polygonDecimalChainId))) - .reply(200, sampleTokenList); - - const stub = sinon.stub(); - const getBalancesInSingleCallMock = sinon.stub(); - // TODO: Replace `any` with type - // eslint-disable-next-line @typescript-eslint/no-explicit-any - let networkStateChangeListener: (state: any) => void; - const onNetworkStateChange = sinon.stub().callsFake((listener) => { - networkStateChangeListener = listener; + describe('getBalancesInSingleCall', () => { + let stub: sinon.SinonStub; + let getBalancesInSingleCallMock: sinon.SinonStub< + Parameters, + ReturnType + >; + beforeEach(async () => { + stub = sinon.stub(); + getBalancesInSingleCallMock = sinon.stub(); + + controllerMessenger = getControllerMessenger(); + controllerMessenger.publish('NetworkController:networkDidChange', { + ...defaultNetworkState, + providerConfig: mainnet, + selectedNetworkClientId: NetworkType.mainnet, + }); + controllerMessenger.registerActionHandler( + `NetworkController:getNetworkConfigurationByNetworkClientId`, + getNetworkConfigurationByNetworkClientIdHandler.mockReturnValue({ + chainId: ChainId.mainnet, + }), + ); + + const tokenListSetup = setupTokenListController(controllerMessenger); + tokenList = tokenListSetup.tokenList; + await tokenList.start(); }); - tokenDetection = new TokenDetectionController( - { - onTokenListStateChange: stub, + it('should not be called after stopping polling, and then switching between networks that support token detection', async () => { + const polygonDecimalChainId = '137'; + nock(TOKEN_END_POINT_API) + .get(getTokensPath(toHex(polygonDecimalChainId))) + .reply(200, sampleTokenList); + + tokenDetection = new TokenDetectionController({ + networkClientId: NetworkType.mainnet, + selectedAddress: '0x1', onPreferencesStateChange: stub, - onNetworkStateChange, getBalancesInSingleCall: getBalancesInSingleCallMock, - addDetectedTokens: stub, + addDetectedTokens: + tokensController.addDetectedTokens.bind(tokensController), getTokensState: () => tokensController.state, - getTokenListState: () => tokenList.state, - getNetworkState: () => defaultNetworkState, getPreferencesState: () => preferences.state, - getNetworkClientById: jest.fn(), - }, - { - disabled: false, - isDetectionEnabledForNetwork: true, - isDetectionEnabledFromPreferences: true, - selectedAddress: '0x1', - chainId: ChainId.mainnet, - }, - ); - - await tokenDetection.start(); + messenger: buildTokenDetectionControllerMessenger(controllerMessenger), + }); + await tokenDetection.start(); - expect(getBalancesInSingleCallMock.called).toBe(true); - getBalancesInSingleCallMock.reset(); + expect(getBalancesInSingleCallMock.called).toBe(true); + getBalancesInSingleCallMock.reset(); - tokenDetection.stop(); - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - await networkStateChangeListener!({ - providerConfig: { chainId: toHex(polygonDecimalChainId) }, + tokenDetection.stop(); + changeNetwork({ + chainId: toHex(polygonDecimalChainId), + type: NetworkType.rpc, + ticker: 'MATIC', + }); + expect(getBalancesInSingleCallMock.called).toBe(false); }); - expect(getBalancesInSingleCallMock.called).toBe(false); - }); - - it('should not call getBalancesInSingleCall if onTokenListStateChange is called with an empty token list', async () => { - const stub = sinon.stub(); - const getBalancesInSingleCallMock = sinon.stub(); - // TODO: Replace `any` with type - // eslint-disable-next-line @typescript-eslint/no-explicit-any - let tokenListStateChangeListener: (state: any) => void; - const onTokenListStateChange = sinon.stub().callsFake((listener) => { - tokenListStateChangeListener = listener; - }); - tokenDetection = new TokenDetectionController( - { - onTokenListStateChange, + it('should not be called if TokenListController is updated to have an empty token list', async () => { + tokenDetection = new TokenDetectionController({ + networkClientId: NetworkType.mainnet, onPreferencesStateChange: stub, - onNetworkStateChange: stub, getBalancesInSingleCall: getBalancesInSingleCallMock, - addDetectedTokens: stub, - getTokensState: stub, - getTokenListState: stub, - getNetworkState: () => defaultNetworkState, + addDetectedTokens: + tokensController.addDetectedTokens.bind(tokensController), + getTokensState: () => tokensController.state, getPreferencesState: () => preferences.state, - getNetworkClientById: jest.fn(), - }, - { - disabled: false, - isDetectionEnabledForNetwork: true, - isDetectionEnabledFromPreferences: true, - }, - ); + messenger: buildTokenDetectionControllerMessenger(controllerMessenger), + }); - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - await tokenListStateChangeListener!({ tokenList: {} }); + tokenList.clearingTokenListData(); + expect(getBalancesInSingleCallMock.called).toBe(false); + }); - expect(getBalancesInSingleCallMock.called).toBe(false); - }); + it('should be called if onPreferencesStateChange is called with useTokenDetection being true and selectedAddress is changed', async () => { + // TODO: Replace `any` with type + // eslint-disable-next-line @typescript-eslint/no-explicit-any + let preferencesStateChangeListener: (state: any) => void; + const onPreferencesStateChange = sinon.stub().callsFake((listener) => { + preferencesStateChangeListener = listener; + }); - it('should call getBalancesInSingleCall if onPreferencesStateChange is called with useTokenDetection being true and is changed', async () => { - const stub = sinon.stub(); - const getBalancesInSingleCallMock = sinon.stub(); - // TODO: Replace `any` with type - // eslint-disable-next-line @typescript-eslint/no-explicit-any - let preferencesStateChangeListener: (state: any) => void; - const onPreferencesStateChange = sinon.stub().callsFake((listener) => { - preferencesStateChangeListener = listener; - }); - tokenDetection = new TokenDetectionController( - { + tokenDetection = new TokenDetectionController({ + disabled: false, + networkClientId: NetworkType.mainnet, + selectedAddress: '0x1', onPreferencesStateChange, - onTokenListStateChange: stub, - onNetworkStateChange: stub, getBalancesInSingleCall: getBalancesInSingleCallMock, - addDetectedTokens: stub, + addDetectedTokens: + tokensController.addDetectedTokens.bind(tokensController), getTokensState: () => tokensController.state, - getTokenListState: () => tokenList.state, - getNetworkState: () => defaultNetworkState, getPreferencesState: () => preferences.state, - getNetworkClientById: jest.fn(), - }, - { - disabled: false, - isDetectionEnabledForNetwork: true, - isDetectionEnabledFromPreferences: false, - selectedAddress: '0x1', - }, - ); + messenger: buildTokenDetectionControllerMessenger(controllerMessenger), + }); - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - await preferencesStateChangeListener!({ - selectedAddress: '0x1', - useTokenDetection: true, + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + preferencesStateChangeListener!({ + selectedAddress: '0x2', + useTokenDetection: true, + }); + expect(getBalancesInSingleCallMock.calledOnce).toBe(true); }); - expect(getBalancesInSingleCallMock.called).toBe(true); - }); - - it('should call getBalancesInSingleCall if onNetworkStateChange is called with a chainId that supports token detection and is changed', async () => { - const stub = sinon.stub(); - const getBalancesInSingleCallMock = sinon.stub(); - // TODO: Replace `any` with type - // eslint-disable-next-line @typescript-eslint/no-explicit-any - let networkStateChangeListener: (state: any) => void; - const onNetworkStateChange = sinon.stub().callsFake((listener) => { - networkStateChangeListener = listener; - }); - tokenDetection = new TokenDetectionController( - { - onNetworkStateChange, - onTokenListStateChange: stub, + it('should be called if network is changed to a chainId that supports token detection', async () => { + tokenDetection = new TokenDetectionController({ + disabled: false, + networkClientId: 'polygon', + selectedAddress: '0x1', onPreferencesStateChange: stub, getBalancesInSingleCall: getBalancesInSingleCallMock, - addDetectedTokens: stub, + addDetectedTokens: + tokensController.addDetectedTokens.bind(tokensController), getTokensState: () => tokensController.state, - getTokenListState: () => tokenList.state, - getNetworkState: () => defaultNetworkState, getPreferencesState: () => preferences.state, - getNetworkClientById: jest.fn(), - }, - { - disabled: false, - isDetectionEnabledFromPreferences: true, - chainId: SupportedTokenDetectionNetworks.polygon, - isDetectionEnabledForNetwork: true, - selectedAddress: '0x1', - }, - ); + messenger: buildTokenDetectionControllerMessenger(controllerMessenger), + }); - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - await networkStateChangeListener!({ - providerConfig: { chainId: ChainId.mainnet }, + controllerMessenger.unregisterActionHandler( + 'NetworkController:getNetworkConfigurationByNetworkClientId', + ); + controllerMessenger.registerActionHandler( + 'NetworkController:getNetworkConfigurationByNetworkClientId', + getNetworkConfigurationByNetworkClientIdHandler.mockReturnValue({ + chainId: SupportedTokenDetectionNetworks.polygon, + }), + ); + + changeNetwork(mainnet); + expect(getBalancesInSingleCallMock.calledOnce).toBe(true); }); - - expect(getBalancesInSingleCallMock.called).toBe(true); }); describe('startPollingByNetworkClientId', () => { @@ -618,6 +605,7 @@ describe('TokenDetectionController', () => { .mockImplementation(() => { return Promise.resolve(); }); + tokenDetection.enable(); tokenDetection.startPollingByNetworkClientId('mainnet', { address: '0x1', }); @@ -650,15 +638,13 @@ describe('TokenDetectionController', () => { describe('detectTokens', () => { it('should detect and add tokens by networkClientId correctly', async () => { - const selectedAddress = '0x1'; - tokenDetection.configure({ - disabled: false, - }); + const selectedAddress = '0x2'; getBalancesInSingleCall.resolves({ [sampleTokenA.address]: new BN(1), }); + tokenDetection.enable(); await tokenDetection.detectTokens({ - networkClientId: 'mainnet', + networkClientId: NetworkType.mainnet, accountAddress: selectedAddress, }); const tokens = diff --git a/packages/assets-controllers/src/TokenDetectionController.ts b/packages/assets-controllers/src/TokenDetectionController.ts index 30a3249f18..19c4d90871 100644 --- a/packages/assets-controllers/src/TokenDetectionController.ts +++ b/packages/assets-controllers/src/TokenDetectionController.ts @@ -1,242 +1,284 @@ -import type { BaseConfig, BaseState } from '@metamask/base-controller'; +import type { + RestrictedControllerMessenger, + ControllerGetStateAction, + ControllerStateChangeEvent, +} from '@metamask/base-controller'; import { safelyExecute, toChecksumHexAddress, } from '@metamask/controller-utils'; import type { NetworkClientId, - NetworkController, - NetworkState, + NetworkControllerNetworkDidChangeEvent, + NetworkControllerStateChangeEvent, + NetworkControllerGetNetworkConfigurationByNetworkClientId, } from '@metamask/network-controller'; -import { StaticIntervalPollingControllerV1 } from '@metamask/polling-controller'; +import { StaticIntervalPollingController } from '@metamask/polling-controller'; import type { PreferencesState } from '@metamask/preferences-controller'; import type { Hex } from '@metamask/utils'; import type { AssetsContractController } from './AssetsContractController'; import { isTokenDetectionSupportedForNetwork } from './assetsUtil'; -import type { TokenListState } from './TokenListController'; +import type { + GetTokenListState, + TokenListStateChange, +} from './TokenListController'; import type { Token } from './TokenRatesController'; import type { TokensController, TokensState } from './TokensController'; const DEFAULT_INTERVAL = 180000; +export const controllerName = 'TokenDetectionController'; + +export type TokenDetectionState = Record; + +export type TokenDetectionControllerGetStateAction = ControllerGetStateAction< + typeof controllerName, + TokenDetectionState +>; + +export type TokenDetectionControllerActions = + TokenDetectionControllerGetStateAction; + +export type AllowedActions = + | NetworkControllerGetNetworkConfigurationByNetworkClientId + | GetTokenListState; + +export type TokenDetectionControllerStateChangeEvent = + ControllerStateChangeEvent; + +export type TokenDetectionControllerEvents = + TokenDetectionControllerStateChangeEvent; + +export type AllowedEvents = + | NetworkControllerStateChangeEvent + | NetworkControllerNetworkDidChangeEvent + | TokenListStateChange; + +export type TokenDetectionControllerMessenger = RestrictedControllerMessenger< + typeof controllerName, + TokenDetectionControllerActions | AllowedActions, + TokenDetectionControllerEvents | AllowedEvents, + AllowedActions['type'], + AllowedEvents['type'] +>; + /** - * @type TokenDetectionConfig - * - * TokenDetection configuration - * @property interval - Polling interval used to fetch new token rates - * @property selectedAddress - Vault selected address + * Controller that passively polls on a set interval for Tokens auto detection + * @property intervalId - Polling interval used to fetch new token rates * @property chainId - The chain ID of the current network + * @property selectedAddress - Vault selected address + * @property networkClientId - The network client ID of the current selected network + * @property disabled - Boolean to track if network requests are blocked * @property isDetectionEnabledFromPreferences - Boolean to track if detection is enabled from PreferencesController * @property isDetectionEnabledForNetwork - Boolean to track if detected is enabled for current network */ -// This interface was created before this ESLint rule was added. -// Convert to a `type` in a future major version. -// eslint-disable-next-line @typescript-eslint/consistent-type-definitions -export interface TokenDetectionConfig extends BaseConfig { - interval: number; - selectedAddress: string; - chainId: Hex; - isDetectionEnabledFromPreferences: boolean; - isDetectionEnabledForNetwork: boolean; -} - -/** - * Controller that passively polls on a set interval for Tokens auto detection - */ -export class TokenDetectionController extends StaticIntervalPollingControllerV1< - TokenDetectionConfig, - BaseState +export class TokenDetectionController extends StaticIntervalPollingController< + typeof controllerName, + TokenDetectionState, + TokenDetectionControllerMessenger > { - private intervalId?: ReturnType; + #intervalId?: ReturnType; - /** - * Name of this controller used during composition - */ - override name = 'TokenDetectionController'; + #chainId: Hex; + + #selectedAddress: string; + + #networkClientId: NetworkClientId; + + #disabled: boolean; - private readonly getBalancesInSingleCall: AssetsContractController['getBalancesInSingleCall']; + #isDetectionEnabledFromPreferences: boolean; - private readonly addDetectedTokens: TokensController['addDetectedTokens']; + #isDetectionEnabledForNetwork: boolean; - private readonly getTokensState: () => TokensState; + readonly #addDetectedTokens: TokensController['addDetectedTokens']; - private readonly getTokenListState: () => TokenListState; + readonly #getBalancesInSingleCall: AssetsContractController['getBalancesInSingleCall']; - private readonly getNetworkClientById: NetworkController['getNetworkClientById']; + readonly #getTokensState: () => TokensState; /** * Creates a TokenDetectionController instance. * * @param options - The controller options. + * @param options.messenger - The controller messaging system. + * @param options.disabled - If set to true, all network requests are blocked. + * @param options.interval - Polling interval used to fetch new token rates + * @param options.networkClientId - The selected network client ID of the current network + * @param options.selectedAddress - Vault selected address * @param options.onPreferencesStateChange - Allows subscribing to preferences controller state changes. - * @param options.onNetworkStateChange - Allows subscribing to network controller state changes. - * @param options.onTokenListStateChange - Allows subscribing to token list controller state changes. - * @param options.getBalancesInSingleCall - Gets the balances of a list of tokens for the given address. * @param options.addDetectedTokens - Add a list of detected tokens. - * @param options.getTokenListState - Gets the current state of the TokenList controller. + * @param options.getBalancesInSingleCall - Gets the balances of a list of tokens for the given address. * @param options.getTokensState - Gets the current state of the Tokens controller. - * @param options.getNetworkState - Gets the state of the network controller. * @param options.getPreferencesState - Gets the state of the preferences controller. - * @param options.getNetworkClientById - Gets the network client by ID. - * @param config - Initial options used to configure this controller. - * @param state - Initial state to set on this controller. */ - constructor( - { - onPreferencesStateChange, - onNetworkStateChange, - onTokenListStateChange, - getBalancesInSingleCall, - addDetectedTokens, - getTokenListState, - getTokensState, - getNetworkState, - getPreferencesState, - getNetworkClientById, - }: { - onPreferencesStateChange: ( - listener: (preferencesState: PreferencesState) => void, - ) => void; - onNetworkStateChange: ( - listener: (networkState: NetworkState) => void, - ) => void; - onTokenListStateChange: ( - listener: (tokenListState: TokenListState) => void, - ) => void; - getBalancesInSingleCall: AssetsContractController['getBalancesInSingleCall']; - addDetectedTokens: TokensController['addDetectedTokens']; - getTokenListState: () => TokenListState; - getTokensState: () => TokensState; - getNetworkState: () => NetworkState; - getPreferencesState: () => PreferencesState; - getNetworkClientById: NetworkController['getNetworkClientById']; - }, - config?: Partial, - state?: Partial, - ) { - const { - providerConfig: { chainId: defaultChainId }, - } = getNetworkState(); + constructor({ + networkClientId, + selectedAddress = '', + interval = DEFAULT_INTERVAL, + disabled = true, + onPreferencesStateChange, + getBalancesInSingleCall, + addDetectedTokens, + getPreferencesState, + getTokensState, + messenger, + }: { + networkClientId: NetworkClientId; + selectedAddress?: string; + interval?: number; + disabled?: boolean; + onPreferencesStateChange: ( + listener: (preferencesState: PreferencesState) => void, + ) => void; + addDetectedTokens: TokensController['addDetectedTokens']; + getBalancesInSingleCall: AssetsContractController['getBalancesInSingleCall']; + getTokensState: () => TokensState; + getPreferencesState: () => PreferencesState; + messenger: TokenDetectionControllerMessenger; + }) { const { useTokenDetection: defaultUseTokenDetection } = getPreferencesState(); - super(config, state); - this.defaultConfig = { - interval: DEFAULT_INTERVAL, - selectedAddress: '', - disabled: true, - chainId: defaultChainId, - isDetectionEnabledFromPreferences: defaultUseTokenDetection, - isDetectionEnabledForNetwork: - isTokenDetectionSupportedForNetwork(defaultChainId), - ...config, - }; - - this.initialize(); - this.setIntervalLength(this.config.interval); - this.getTokensState = getTokensState; - this.getTokenListState = getTokenListState; - this.addDetectedTokens = addDetectedTokens; - this.getBalancesInSingleCall = getBalancesInSingleCall; - this.getNetworkClientById = getNetworkClientById; - - onTokenListStateChange(({ tokenList }) => { - const hasTokens = Object.keys(tokenList).length; - - if (hasTokens) { - this.detectTokens(); - } + super({ + name: controllerName, + messenger, + state: {}, + metadata: {}, }); - onPreferencesStateChange(({ selectedAddress, useTokenDetection }) => { - const { - selectedAddress: currentSelectedAddress, - isDetectionEnabledFromPreferences, - } = this.config; - const isSelectedAddressChanged = - selectedAddress !== currentSelectedAddress; - const isDetectionChangedFromPreferences = - isDetectionEnabledFromPreferences !== useTokenDetection; - - this.configure({ - isDetectionEnabledFromPreferences: useTokenDetection, - selectedAddress, - }); + this.#disabled = disabled; + this.setIntervalLength(interval); - if ( - useTokenDetection && - (isSelectedAddressChanged || isDetectionChangedFromPreferences) - ) { - this.detectTokens(); - } - }); + this.#networkClientId = networkClientId; + this.#selectedAddress = selectedAddress; + this.#chainId = this.#getCorrectChainId(networkClientId); - onNetworkStateChange(({ providerConfig: { chainId } }) => { - const { chainId: currentChainId } = this.config; - const isDetectionEnabledForNetwork = - isTokenDetectionSupportedForNetwork(chainId); - const isChainIdChanged = currentChainId !== chainId; + this.#isDetectionEnabledFromPreferences = defaultUseTokenDetection; + this.#isDetectionEnabledForNetwork = isTokenDetectionSupportedForNetwork( + this.#chainId, + ); - this.configure({ - chainId, - isDetectionEnabledForNetwork, - }); + this.#addDetectedTokens = addDetectedTokens; + this.#getBalancesInSingleCall = getBalancesInSingleCall; + this.#getTokensState = getTokensState; - if (isDetectionEnabledForNetwork && isChainIdChanged) { - this.detectTokens(); - } - }); + this.messagingSystem.subscribe( + 'TokenListController:stateChange', + async ({ tokenList }) => { + const hasTokens = Object.keys(tokenList).length; + + if (hasTokens) { + await this.detectTokens(); + } + }, + ); + + onPreferencesStateChange( + async ({ selectedAddress: newSelectedAddress, useTokenDetection }) => { + const isSelectedAddressChanged = + this.#selectedAddress !== newSelectedAddress; + const isDetectionChangedFromPreferences = + this.#isDetectionEnabledFromPreferences !== useTokenDetection; + + this.#selectedAddress = newSelectedAddress; + this.#isDetectionEnabledFromPreferences = useTokenDetection; + + if ( + useTokenDetection && + (isSelectedAddressChanged || isDetectionChangedFromPreferences) + ) { + await this.detectTokens(); + } + }, + ); + + this.messagingSystem.subscribe( + 'NetworkController:networkDidChange', + async ({ selectedNetworkClientId }) => { + this.#networkClientId = selectedNetworkClientId; + const newChainId = this.#getCorrectChainId(selectedNetworkClientId); + const isChainIdChanged = this.#chainId !== newChainId; + this.#chainId = newChainId; + + this.#isDetectionEnabledForNetwork = + isTokenDetectionSupportedForNetwork(newChainId); + + if (this.#isDetectionEnabledForNetwork && isChainIdChanged) { + await this.detectTokens(); + } + }, + ); + } + + /** + * Allows controller to make active and passive polling requests + */ + enable() { + this.#disabled = false; + } + + /** + * Blocks controller from making network calls + */ + disable() { + this.#disabled = true; } /** * Start polling for detected tokens. */ async start() { - this.configure({ disabled: false }); - await this.startPolling(); + this.enable(); + await this.#startPolling(); } /** * Stop polling for detected tokens. */ stop() { - this.configure({ disabled: true }); - this.stopPolling(); + this.disable(); + this.#stopPolling(); } - private stopPolling() { - if (this.intervalId) { - clearInterval(this.intervalId); + #stopPolling() { + if (this.#intervalId) { + clearInterval(this.#intervalId); } } /** * Starts a new polling interval. - * - * @param interval - An interval on which to poll. */ - private async startPolling(interval?: number): Promise { - interval && this.configure({ interval }, false, false); - this.stopPolling(); + async #startPolling(): Promise { + if (this.#disabled) { + return; + } + this.#stopPolling(); await this.detectTokens(); - this.intervalId = setInterval(async () => { + this.#intervalId = setInterval(async () => { await this.detectTokens(); - }, this.config.interval); + }, this.getIntervalLength()); } - private getCorrectChainId(networkClientId?: NetworkClientId) { - if (networkClientId) { - return this.getNetworkClientById(networkClientId).configuration.chainId; - } - return this.config.chainId; + #getCorrectChainId(networkClientId?: NetworkClientId) { + const { chainId } = + this.messagingSystem.call( + 'NetworkController:getNetworkConfigurationByNetworkClientId', + networkClientId ?? this.#networkClientId, + ) ?? {}; + return chainId ?? this.#chainId; } - _executePoll( + async _executePoll( networkClientId: string, options: { address: string }, ): Promise { - return this.detectTokens({ + if (this.#disabled) { + return; + } + await this.detectTokens({ networkClientId, accountAddress: options.address, }); @@ -249,31 +291,30 @@ export class TokenDetectionController extends StaticIntervalPollingControllerV1< * @param options.networkClientId - The ID of the network client to use. * @param options.accountAddress - The account address to use. */ - async detectTokens(options?: { + async detectTokens({ + networkClientId, + accountAddress, + }: { networkClientId?: NetworkClientId; accountAddress?: string; - }) { - const { networkClientId, accountAddress } = options || {}; - const { - disabled, - isDetectionEnabledForNetwork, - isDetectionEnabledFromPreferences, - } = this.config; + } = {}): Promise { if ( - disabled || - !isDetectionEnabledForNetwork || - !isDetectionEnabledFromPreferences + this.#disabled || + !this.#isDetectionEnabledForNetwork || + !this.#isDetectionEnabledFromPreferences ) { return; } - const { tokens } = this.getTokensState(); - const selectedAddress = accountAddress || this.config.selectedAddress; - const chainId = this.getCorrectChainId(networkClientId); + const { tokens } = this.#getTokensState(); + const selectedAddress = accountAddress || this.#selectedAddress; + const chainId = this.#getCorrectChainId(networkClientId); const tokensAddresses = tokens.map( /* istanbul ignore next*/ (token) => token.address.toLowerCase(), ); - const { tokenList } = this.getTokenListState(); + const { tokenList } = this.messagingSystem.call( + 'TokenListController:getState', + ); const tokensToDetect: string[] = []; for (const address of Object.keys(tokenList)) { if (!tokensAddresses.includes(address)) { @@ -298,7 +339,7 @@ export class TokenDetectionController extends StaticIntervalPollingControllerV1< } await safelyExecute(async () => { - const balances = await this.getBalancesInSingleCall( + const balances = await this.#getBalancesInSingleCall( selectedAddress, tokensSlice, ); @@ -306,7 +347,7 @@ export class TokenDetectionController extends StaticIntervalPollingControllerV1< for (const tokenAddress of Object.keys(balances)) { let ignored; /* istanbul ignore else */ - const { ignoredTokens } = this.getTokensState(); + const { ignoredTokens } = this.#getTokensState(); if (ignoredTokens.length) { ignored = ignoredTokens.find( (ignoredTokenAddress) => @@ -316,7 +357,7 @@ export class TokenDetectionController extends StaticIntervalPollingControllerV1< const caseInsensitiveTokenKey = Object.keys(tokenList).find( (i) => i.toLowerCase() === tokenAddress.toLowerCase(), - ) || ''; + ) ?? ''; if (ignored === undefined) { const { decimals, symbol, aggregators, iconUrl, name } = @@ -334,7 +375,7 @@ export class TokenDetectionController extends StaticIntervalPollingControllerV1< } if (tokensToAdd.length) { - await this.addDetectedTokens(tokensToAdd, { + await this.#addDetectedTokens(tokensToAdd, { selectedAddress, chainId, }); diff --git a/packages/assets-controllers/src/TokenListController.ts b/packages/assets-controllers/src/TokenListController.ts index ba89c200c9..7b519aeb16 100644 --- a/packages/assets-controllers/src/TokenListController.ts +++ b/packages/assets-controllers/src/TokenListController.ts @@ -68,12 +68,14 @@ export type TokenListControllerActions = GetTokenListState; type AllowedActions = NetworkControllerGetNetworkClientByIdAction; -type TokenListMessenger = RestrictedControllerMessenger< +type AllowedEvents = NetworkControllerStateChangeEvent; + +export type TokenListControllerMessenger = RestrictedControllerMessenger< typeof name, TokenListControllerActions | AllowedActions, - TokenListControllerEvents | NetworkControllerStateChangeEvent, + TokenListControllerEvents | AllowedEvents, AllowedActions['type'], - (TokenListControllerEvents | NetworkControllerStateChangeEvent)['type'] + AllowedEvents['type'] >; const metadata = { @@ -94,7 +96,7 @@ const defaultState: TokenListState = { export class TokenListController extends StaticIntervalPollingController< typeof name, TokenListState, - TokenListMessenger + TokenListControllerMessenger > { private readonly mutex = new Mutex(); @@ -136,7 +138,7 @@ export class TokenListController extends StaticIntervalPollingController< ) => void; interval?: number; cacheRefreshThreshold?: number; - messenger: TokenListMessenger; + messenger: TokenListControllerMessenger; state?: Partial; }) { super({ diff --git a/packages/assets-controllers/src/index.ts b/packages/assets-controllers/src/index.ts index 7cf4296247..fc1a068472 100644 --- a/packages/assets-controllers/src/index.ts +++ b/packages/assets-controllers/src/index.ts @@ -4,7 +4,14 @@ export * from './CurrencyRateController'; export * from './NftController'; export * from './NftDetectionController'; export * from './TokenBalancesController'; -export * from './TokenDetectionController'; +export type { + TokenDetectionControllerMessenger, + TokenDetectionControllerActions, + TokenDetectionControllerGetStateAction, + TokenDetectionControllerEvents, + TokenDetectionControllerStateChangeEvent, +} from './TokenDetectionController'; +export { TokenDetectionController } from './TokenDetectionController'; export * from './TokenListController'; export * from './TokenRatesController'; export * from './TokensController';