diff --git a/app/scripts/lib/accounts/BalancesController.test.ts b/app/scripts/lib/accounts/BalancesController.test.ts index 6bfe0bfabfc6..036b29f1da76 100644 --- a/app/scripts/lib/accounts/BalancesController.test.ts +++ b/app/scripts/lib/accounts/BalancesController.test.ts @@ -14,7 +14,7 @@ import { defaultState, BalancesControllerMessenger, } from './BalancesController'; -import { Poller } from './Poller'; +import { BalancesTracker } from './BalancesTracker'; const mockBtcAccount = createMockInternalAccount({ address: '', @@ -54,8 +54,14 @@ const setupController = ({ const balancesControllerMessenger: BalancesControllerMessenger = controllerMessenger.getRestricted({ name: 'BalancesController', - allowedActions: ['SnapController:handleRequest'], - allowedEvents: ['AccountsController:stateChange'], + allowedActions: [ + 'SnapController:handleRequest', + 'AccountsController:listMultichainAccounts', + ], + allowedEvents: [ + 'AccountsController:accountAdded', + 'AccountsController:accountRemoved', + ], }); const mockSnapHandleRequest = jest.fn(); @@ -66,20 +72,22 @@ const setupController = ({ ), ); - // TODO: remove when listMultichainAccounts action is available - const mockListMultichainAccounts = jest - .fn() - .mockReturnValue(mocks?.listMultichainAccounts ?? [mockBtcAccount]); + const mockListMultichainAccounts = jest.fn(); + controllerMessenger.registerActionHandler( + 'AccountsController:listMultichainAccounts', + mockListMultichainAccounts.mockReturnValue( + mocks?.listMultichainAccounts ?? [mockBtcAccount], + ), + ); const controller = new BalancesController({ messenger: balancesControllerMessenger, state, - // TODO: remove when listMultichainAccounts action is available - listMultichainAccounts: mockListMultichainAccounts, }); return { controller, + messenger: controllerMessenger, mockSnapHandleRequest, mockListMultichainAccounts, }; @@ -91,19 +99,19 @@ describe('BalancesController', () => { expect(controller.state).toEqual({ balances: {} }); }); - it('starts polling when calling start', async () => { - const spyPoller = jest.spyOn(Poller.prototype, 'start'); + it('starts tracking when calling start', async () => { + const spyTracker = jest.spyOn(BalancesTracker.prototype, 'start'); const { controller } = setupController(); await controller.start(); - expect(spyPoller).toHaveBeenCalledTimes(1); + expect(spyTracker).toHaveBeenCalledTimes(1); }); - it('stops polling when calling stop', async () => { - const spyPoller = jest.spyOn(Poller.prototype, 'stop'); + it('stops tracking when calling stop', async () => { + const spyTracker = jest.spyOn(BalancesTracker.prototype, 'stop'); const { controller } = setupController(); await controller.start(); await controller.stop(); - expect(spyPoller).toHaveBeenCalledTimes(1); + expect(spyTracker).toHaveBeenCalledTimes(1); }); it('update balances when calling updateBalances', async () => { @@ -113,13 +121,49 @@ describe('BalancesController', () => { expect(controller.state).toEqual({ balances: { - [mockBtcAccount.id]: { - 'bip122:000000000933ea01ad0ee984209779ba/slip44:0': { - amount: '0.00000000', - unit: 'BTC', - }, + [mockBtcAccount.id]: mockBalanceResult, + }, + }); + }); + + it('update balances when "AccountsController:accountAdded" is fired', async () => { + const { controller, messenger, mockListMultichainAccounts } = + setupController({ + mocks: { + listMultichainAccounts: [], }, + }); + + controller.start(); + mockListMultichainAccounts.mockReturnValue([mockBtcAccount]); + messenger.publish('AccountsController:accountAdded', mockBtcAccount); + await controller.updateBalances(); + + expect(controller.state).toEqual({ + balances: { + [mockBtcAccount.id]: mockBalanceResult, }, }); }); + + it('update balances when "AccountsController:accountRemoved" is fired', async () => { + const { controller, messenger, mockListMultichainAccounts } = + setupController(); + + controller.start(); + await controller.updateBalances(); + expect(controller.state).toEqual({ + balances: { + [mockBtcAccount.id]: mockBalanceResult, + }, + }); + + messenger.publish('AccountsController:accountRemoved', mockBtcAccount.id); + mockListMultichainAccounts.mockReturnValue([]); + await controller.updateBalances(); + + expect(controller.state).toEqual({ + balances: {}, + }); + }); }); diff --git a/app/scripts/lib/accounts/BalancesController.ts b/app/scripts/lib/accounts/BalancesController.ts index ab1eb8c6cfe6..5e76c4ad5a7f 100644 --- a/app/scripts/lib/accounts/BalancesController.ts +++ b/app/scripts/lib/accounts/BalancesController.ts @@ -19,11 +19,12 @@ import type { SnapId } from '@metamask/snaps-sdk'; import { HandlerType } from '@metamask/snaps-utils'; import type { Draft } from 'immer'; import type { - AccountsControllerChangeEvent, - AccountsControllerState, + AccountsControllerAccountAddedEvent, + AccountsControllerAccountRemovedEvent, + AccountsControllerListMultichainAccountsAction, } from '@metamask/accounts-controller'; import { isBtcMainnetAddress } from '../../../../shared/lib/multichain'; -import { Poller } from './Poller'; +import { BalancesTracker } from './BalancesTracker'; const controllerName = 'BalancesController'; @@ -85,12 +86,16 @@ export type BalancesControllerEvents = BalancesControllerStateChange; /** * Actions that this controller is allowed to call. */ -export type AllowedActions = HandleSnapRequest; +export type AllowedActions = + | HandleSnapRequest + | AccountsControllerListMultichainAccountsAction; /** * Events that this controller is allowed to subscribe. */ -export type AllowedEvents = AccountsControllerChangeEvent; +export type AllowedEvents = + | AccountsControllerAccountAddedEvent + | AccountsControllerAccountRemovedEvent; /** * Messenger type for the BalancesController. @@ -119,7 +124,11 @@ const balancesControllerMetadata = { const BTC_TESTNET_ASSETS = ['bip122:000000000933ea01ad0ee984209779ba/slip44:0']; const BTC_MAINNET_ASSETS = ['bip122:000000000019d6689c085ae165831e93/slip44:0']; -export const BTC_AVG_BLOCK_TIME = 600000; // 10 minutes in milliseconds +const BTC_AVG_BLOCK_TIME = 10 * 60 * 1000; // 10 minutes in milliseconds + +// NOTE: We set an interval of half the average block time to mitigate when our interval +// is de-synchronized with the actual block time. +export const BALANCES_UPDATE_TIME = BTC_AVG_BLOCK_TIME / 2; /** * The BalancesController is responsible for fetching and caching account @@ -130,19 +139,14 @@ export class BalancesController extends BaseController< BalancesControllerState, BalancesControllerMessenger > { - #poller: Poller; - - // TODO: remove once action is implemented - #listMultichainAccounts: () => InternalAccount[]; + #tracker: BalancesTracker; constructor({ messenger, state, - listMultichainAccounts, }: { messenger: BalancesControllerMessenger; state: BalancesControllerState; - listMultichainAccounts: () => InternalAccount[]; }) { super({ messenger, @@ -154,27 +158,50 @@ export class BalancesController extends BaseController< }, }); - this.messagingSystem.subscribe( - 'AccountsController:stateChange', - (newState) => this.#handleOnAccountsControllerChange(newState), + this.#tracker = new BalancesTracker( + async (accountId: string) => await this.#updateBalance(accountId), ); - this.#listMultichainAccounts = listMultichainAccounts; - this.#poller = new Poller(() => this.updateBalances(), BTC_AVG_BLOCK_TIME); + // Register all non-EVM accounts into the tracker + for (const account of this.#listAccounts()) { + if (this.#isNonEvmAccount(account)) { + this.#tracker.track(account.id, BALANCES_UPDATE_TIME); + } + } + + this.messagingSystem.subscribe( + 'AccountsController:accountAdded', + (account) => this.#handleOnAccountAdded(account), + ); + this.messagingSystem.subscribe( + 'AccountsController:accountRemoved', + (account) => this.#handleOnAccountRemoved(account), + ); } /** * Starts the polling process. */ async start(): Promise { - this.#poller.start(); + this.#tracker.start(); } /** * Stops the polling process. */ async stop(): Promise { - this.#poller.stop(); + this.#tracker.stop(); + } + + /** + * Lists the multichain accounts coming from the `AccountsController`. + * + * @returns A list of multichain accounts. + */ + #listMultichainAccounts(): InternalAccount[] { + return this.messagingSystem.call( + 'AccountsController:listMultichainAccounts', + ); } /** @@ -185,50 +212,119 @@ export class BalancesController extends BaseController< * * @returns A list of accounts that we should get balances for. */ - async #listAccounts(): Promise { + #listAccounts(): InternalAccount[] { const accounts = this.#listMultichainAccounts(); return accounts.filter((account) => account.type === BtcAccountType.P2wpkh); } /** - * Updates the balances of all supported accounts. This method doesn't return + * Get a non-EVM account from its ID. + * + * @param accountId - The account ID. + */ + #getAccount(accountId: string): InternalAccount { + const account: InternalAccount = this.#listMultichainAccounts().find( + (multichainAccount) => multichainAccount.id === accountId, + ); + + if (!account) { + throw new Error(`Unknown account: ${accountId}`); + } + if (!this.#isNonEvmAccount(account)) { + throw new Error(`Account is not a non-EVM account: ${accountId}`); + } + return account; + } + + /** + * Updates the balances of one account. This method doesn't return * anything, but it updates the state of the controller. + * + * @param accountId - The account ID. */ - async updateBalances() { - const accounts = await this.#listAccounts(); + async #updateBalance(accountId: string) { + const account = this.#getAccount(accountId); const partialState: BalancesControllerState = { balances: {} }; - for (const account of accounts) { - if (account.metadata.snap) { - partialState.balances[account.id] = await this.#getBalances( - account.id, - account.metadata.snap.id, - isBtcMainnetAddress(account.address) - ? BTC_MAINNET_ASSETS - : BTC_TESTNET_ASSETS, - ); - } - } + partialState.balances[account.id] = await this.#getBalances( + account.id, + account.metadata.snap.id, + isBtcMainnetAddress(account.address) + ? BTC_MAINNET_ASSETS + : BTC_TESTNET_ASSETS, + ); this.update((state: Draft) => ({ ...state, - ...partialState, + balances: { + ...state.balances, + ...partialState.balances, + }, })); } /** - * Handles changes in the accounts state, specifically when new non-EVM accounts are added. + * Updates the balances of one account. This method doesn't return + * anything, but it updates the state of the controller. + * + * @param accountId - The account ID. + */ + async updateBalance(accountId: string) { + await this.#tracker.updateBalance(accountId); + } + + /** + * Updates the balances of all supported accounts. This method doesn't return + * anything, but it updates the state of the controller. + */ + async updateBalances() { + await this.#tracker.updateBalances(); + } + + /** + * Checks for non-EVM accounts. * - * @param newState - The new state of the accounts controller. + * @param account - The new account to be checked. + * @returns True if the account is a non-EVM account, false otherwise. */ - #handleOnAccountsControllerChange(newState: AccountsControllerState) { - // If we have any new non-EVM accounts, we just update non-EVM balances - const newNonEvmAccounts = Object.values( - newState.internalAccounts.accounts, - ).filter((account) => !isEvmAccountType(account.type)); - if (newNonEvmAccounts.length) { - this.updateBalances(); + #isNonEvmAccount(account: InternalAccount): boolean { + return ( + !isEvmAccountType(account.type) && + // Non-EVM accounts are backed by a Snap for now + account.metadata.snap + ); + } + + /** + * Handles changes when a new account has been added. + * + * @param account - The new account being added. + */ + async #handleOnAccountAdded(account: InternalAccount) { + if (!this.#isNonEvmAccount(account)) { + // Nothing to do here for EVM accounts + return; + } + + this.#tracker.track(account.id, BTC_AVG_BLOCK_TIME); + } + + /** + * Handles changes when a new account has been removed. + * + * @param accountId - The account ID being removed. + */ + async #handleOnAccountRemoved(accountId: string) { + if (this.#tracker.isTracked(accountId)) { + this.#tracker.untrack(accountId); + } + + if (accountId in this.state.balances) { + this.update((state: Draft) => { + delete state.balances[accountId]; + return state; + }); } } diff --git a/app/scripts/lib/accounts/BalancesTracker.test.ts b/app/scripts/lib/accounts/BalancesTracker.test.ts new file mode 100644 index 000000000000..27c862aa4460 --- /dev/null +++ b/app/scripts/lib/accounts/BalancesTracker.test.ts @@ -0,0 +1,121 @@ +import { BtcAccountType } from '@metamask/keyring-api'; +import { createMockInternalAccount } from '../../../../test/jest/mocks'; +import { Poller } from './Poller'; +import { BalancesTracker } from './BalancesTracker'; + +const MOCK_TIMESTAMP = 1709983353; + +const mockBtcAccount = createMockInternalAccount({ + address: '', + name: 'Btc Account', + // @ts-expect-error - account type may be btc or eth, mock file is not typed + type: BtcAccountType.P2wpkh, + // @ts-expect-error - snap options is not typed and defaults to undefined + snapOptions: { + id: 'mock-btc-snap', + name: 'mock-btc-snap', + enabled: true, + }, +}); + +function setupTracker() { + const mockUpdateBalance = jest.fn(); + const tracker = new BalancesTracker(mockUpdateBalance); + + return { + tracker, + mockUpdateBalance, + }; +} + +describe('BalancesTracker', () => { + it('starts polling when calling start', async () => { + const { tracker } = setupTracker(); + const spyPoller = jest.spyOn(Poller.prototype, 'start'); + + await tracker.start(); + expect(spyPoller).toHaveBeenCalledTimes(1); + }); + + it('stops polling when calling stop', async () => { + const { tracker } = setupTracker(); + const spyPoller = jest.spyOn(Poller.prototype, 'stop'); + + await tracker.start(); + await tracker.stop(); + expect(spyPoller).toHaveBeenCalledTimes(1); + }); + + it('is not tracking if none accounts have been registered', async () => { + const { tracker, mockUpdateBalance } = setupTracker(); + + await tracker.start(); + await tracker.updateBalances(); + + expect(mockUpdateBalance).not.toHaveBeenCalled(); + }); + + it('tracks account balances', async () => { + const { tracker, mockUpdateBalance } = setupTracker(); + + await tracker.start(); + // We must track account IDs explicitly + tracker.track(mockBtcAccount.id, 0); + // Trigger balances refresh (not waiting for the Poller here) + await tracker.updateBalances(); + + expect(mockUpdateBalance).toHaveBeenCalledWith(mockBtcAccount.id); + }); + + it('untracks account balances', async () => { + const { tracker, mockUpdateBalance } = setupTracker(); + + await tracker.start(); + tracker.track(mockBtcAccount.id, 0); + await tracker.updateBalances(); + expect(mockUpdateBalance).toHaveBeenCalledWith(mockBtcAccount.id); + + tracker.untrack(mockBtcAccount.id); + await tracker.updateBalances(); + expect(mockUpdateBalance).toHaveBeenCalledTimes(1); // No second call after untracking + }); + + it('tracks account after being registered', async () => { + const { tracker } = setupTracker(); + + await tracker.start(); + tracker.track(mockBtcAccount.id, 0); + expect(tracker.isTracked(mockBtcAccount.id)).toBe(true); + }); + + it('does not track account if not registered', async () => { + const { tracker } = setupTracker(); + + await tracker.start(); + expect(tracker.isTracked(mockBtcAccount.id)).toBe(false); + }); + + it('does not refresh balance if they are considered up-to-date', async () => { + const { tracker, mockUpdateBalance } = setupTracker(); + + const blockTime = 10 * 60 * 1000; // 10 minutes in milliseconds. + jest + .spyOn(global.Date, 'now') + .mockImplementation(() => new Date(MOCK_TIMESTAMP).getTime()); + + await tracker.start(); + tracker.track(mockBtcAccount.id, blockTime); + await tracker.updateBalances(); + expect(mockUpdateBalance).toHaveBeenCalledTimes(1); + + await tracker.updateBalances(); + expect(mockUpdateBalance).toHaveBeenCalledTimes(1); // No second call since the balances is already still up-to-date + + jest + .spyOn(global.Date, 'now') + .mockImplementation(() => new Date(MOCK_TIMESTAMP + blockTime).getTime()); + + await tracker.updateBalances(); + expect(mockUpdateBalance).toHaveBeenCalledTimes(2); // Now the balance will update + }); +}); diff --git a/app/scripts/lib/accounts/BalancesTracker.ts b/app/scripts/lib/accounts/BalancesTracker.ts new file mode 100644 index 000000000000..48ecd6f84cca --- /dev/null +++ b/app/scripts/lib/accounts/BalancesTracker.ts @@ -0,0 +1,122 @@ +import { Poller } from './Poller'; + +type BalanceInfo = { + lastUpdated: number; + blockTime: number; +}; + +const BALANCES_TRACKING_INTERVAL = 30 * 1000; // Every 30s in milliseconds. + +export class BalancesTracker { + #poller: Poller; + + #updateBalance: (accountId: string) => Promise; + + #balances: Record = {}; + + constructor(updateBalanceCallback: (accountId: string) => Promise) { + this.#updateBalance = updateBalanceCallback; + + this.#poller = new Poller(() => { + this.updateBalances(); + }, BALANCES_TRACKING_INTERVAL); + } + + /** + * Starts the tracking process. + */ + async start(): Promise { + this.#poller.start(); + } + + /** + * Stops the tracking process. + */ + async stop(): Promise { + this.#poller.stop(); + } + + /** + * Checks if an account ID is being tracked. + * + * @param accountId - The account ID. + * @returns True if the account is being tracker, false otherwise. + */ + isTracked(accountId: string) { + return accountId in this.#balances; + } + + /** + * Asserts that an account ID is being tracked. + * + * @param accountId - The account ID. + * @throws If the account ID is not being tracked. + */ + assertBeingTracked(accountId: string) { + if (!this.isTracked(accountId)) { + throw new Error(`Account is not being tracked: ${accountId}`); + } + } + + /** + * Starts tracking a new account ID. This method has no effect on already tracked + * accounts. + * + * @param accountId - The account ID. + * @param blockTime - The block time (used when refreshing the account balances). + */ + track(accountId: string, blockTime: number) { + // Do not overwrite current info if already being tracked! + if (!this.isTracked(accountId)) { + this.#balances[accountId] = { + lastUpdated: 0, + blockTime, + }; + } + } + + /** + * Stops tracking a tracked account ID. + * + * @param accountId - The account ID. + * @throws If the account ID is not being tracked. + */ + untrack(accountId: string) { + this.assertBeingTracked(accountId); + delete this.#balances[accountId]; + } + + /** + * Update the balances for a tracked account ID. + * + * @param accountId - The account ID. + * @throws If the account ID is not being tracked. + */ + async updateBalance(accountId: string) { + this.assertBeingTracked(accountId); + + // We check if the balance is outdated (by comparing to the block time associated + // with this kind of account). + // + // This might not be super accurate, but we could probably compute this differently + // and try to sync with the "real block time"! + const info = this.#balances[accountId]; + const isOutdated = Date.now() - info.lastUpdated >= info.blockTime; + if (isOutdated) { + await this.#updateBalance(accountId); + this.#balances[accountId].lastUpdated = Date.now(); + } + } + + /** + * Update the balances of all tracked accounts (only if the balances + * is considered outdated). + */ + async updateBalances() { + await Promise.allSettled( + Object.keys(this.#balances).map(async (accountId) => { + await this.updateBalance(accountId); + }), + ); + } +} diff --git a/app/scripts/metamask-controller.js b/app/scripts/metamask-controller.js index 59dfc9d96540..3d739072d7cb 100644 --- a/app/scripts/metamask-controller.js +++ b/app/scripts/metamask-controller.js @@ -876,18 +876,19 @@ export default class MetamaskController extends EventEmitter { const multichainBalancesControllerMessenger = this.controllerMessenger.getRestricted({ name: 'BalancesController', - allowedEvents: ['AccountsController:stateChange'], - allowedActions: ['SnapController:handleRequest'], + allowedEvents: [ + 'AccountsController:accountAdded', + 'AccountsController:accountRemoved', + ], + allowedActions: [ + 'AccountsController:listMultichainAccounts', + 'SnapController:handleRequest', + ], }); this.multichainBalancesController = new MultichainBalancesController({ messenger: multichainBalancesControllerMessenger, - state: {}, - // TODO: remove when listMultichainAccounts action is available - listMultichainAccounts: - this.accountsController.listMultichainAccounts.bind( - this.accountsController, - ), + state: initState.MultichainBalancesController, }); const multichainRatesControllerMessenger = @@ -3749,6 +3750,13 @@ export default class MetamaskController extends EventEmitter { ), setName: this.nameController.setName.bind(this.nameController), + // MultichainBalancesController + multichainUpdateBalance: (accountId) => + this.multichainBalancesController.updateBalance(accountId), + + multichainUpdateBalances: () => + this.multichainBalancesController.updateBalances(), + // Transaction Decode decodeTransactionData: (request) => decodeTransactionData({ diff --git a/app/scripts/metamask-controller.test.js b/app/scripts/metamask-controller.test.js index 1ad12736fe2a..3aeac2866219 100644 --- a/app/scripts/metamask-controller.test.js +++ b/app/scripts/metamask-controller.test.js @@ -41,8 +41,9 @@ import { ETH_EOA_METHODS } from '../../shared/constants/eth-methods'; import { createMockInternalAccount } from '../../test/jest/mocks'; import { BalancesController as MultichainBalancesController, - BTC_AVG_BLOCK_TIME, + BALANCES_UPDATE_TIME as MULTICHAIN_BALANCES_UPDATE_TIME, } from './lib/accounts/BalancesController'; +import { BalancesTracker as MultichainBalancesTracker } from './lib/accounts/BalancesTracker'; import { deferredPromise } from './lib/util'; import MetaMaskController from './metamask-controller'; @@ -2240,12 +2241,31 @@ describe('MetaMaskController', () => { type: BtcAccountType.P2wpkh, methods: [BtcMethod.SendMany], address: 'bc1qar0srrr7xfkvy5l643lydnw9re59gtzzwf5mdq', + // We need to have a "Snap account" account here, since the MultichainBalancesController will + // filter it out otherwise! + metadata: { + name: 'Bitcoin Account', + importTime: Date.now(), + keyring: { + type: KeyringType.snap, + }, + snap: { + id: 'npm:@metamask/bitcoin-wallet-snap', + }, + }, }; let localMetamaskController; + let spyBalancesTrackerUpdateBalance; beforeEach(() => { jest.useFakeTimers(); jest.spyOn(MultichainBalancesController.prototype, 'updateBalances'); + jest + .spyOn(MultichainBalancesController.prototype, 'updateBalance') + .mockResolvedValue(); + spyBalancesTrackerUpdateBalance = jest + .spyOn(MultichainBalancesTracker.prototype, 'updateBalance') + .mockResolvedValue(); localMetamaskController = new MetaMaskController({ showUserConfirmation: noop, encryptor: mockEncryptor, @@ -2284,11 +2304,30 @@ describe('MetaMaskController', () => { }); it('calls updateBalances after the interval has passed', async () => { - jest.advanceTimersByTime(BTC_AVG_BLOCK_TIME); - // 2 calls because 1 is during startup + // 1st call is during startup: + // updatesBalances is going to call updateBalance for the only non-EVM + // account that we have expect( localMetamaskController.multichainBalancesController.updateBalances, - ).toHaveBeenCalledTimes(2); + ).toHaveBeenCalledTimes(1); + expect(spyBalancesTrackerUpdateBalance).toHaveBeenCalledTimes(1); + expect(spyBalancesTrackerUpdateBalance).toHaveBeenCalledWith( + mockNonEvmAccount.id, + ); + + // Wait for "block time", so balances will have to be refreshed + jest.advanceTimersByTime(MULTICHAIN_BALANCES_UPDATE_TIME); + + // Check that we tried to fetch the balances more than once + // NOTE: For now, this method might be called a lot more than just twice, but this + // method has some internal logic to prevent fetching the balance too often if we + // consider the balance to be "up-to-date" + expect( + spyBalancesTrackerUpdateBalance.mock.calls.length, + ).toBeGreaterThan(1); + expect(spyBalancesTrackerUpdateBalance).toHaveBeenLastCalledWith( + mockNonEvmAccount.id, + ); }); }); }); diff --git a/ui/components/multichain/create-btc-account/create-btc-account.test.tsx b/ui/components/multichain/create-btc-account/create-btc-account.test.tsx index 23fbe9226707..6c64bb7d553b 100644 --- a/ui/components/multichain/create-btc-account/create-btc-account.test.tsx +++ b/ui/components/multichain/create-btc-account/create-btc-account.test.tsx @@ -33,12 +33,22 @@ const mockBtcAccount = { methods: [BtcMethod.SendMany], }; const mockBitcoinWalletSnapSend = jest.fn().mockReturnValue(mockBtcAccount); +const mockMultichainUpdateBalance = jest.fn().mockReturnValue({ + [mockBtcAccount.address]: { + [`${MultichainNetworks.BITCOIN_TESTNET}/slip44:0`]: { + amount: '0.00000000', + unit: 'BTC', + }, + }, +}); const mockSetAccountLabel = jest.fn().mockReturnValue({ type: 'TYPE' }); jest.mock('../../../store/actions', () => ({ forceUpdateMetamaskState: jest.fn(), setAccountLabel: (address: string, label: string) => mockSetAccountLabel(address, label), + multichainUpdateBalance: (accountId: string) => + mockMultichainUpdateBalance(accountId), })); jest.mock( @@ -85,6 +95,7 @@ describe('CreateBtcAccount', () => { newAccountName, ), ); + await waitFor(() => expect(mockMultichainUpdateBalance).toHaveBeenCalled()); await waitFor(() => expect(onActionComplete).toHaveBeenCalled()); }); diff --git a/ui/components/multichain/create-btc-account/create-btc-account.tsx b/ui/components/multichain/create-btc-account/create-btc-account.tsx index 29c7b8f345b3..98b4619daeca 100644 --- a/ui/components/multichain/create-btc-account/create-btc-account.tsx +++ b/ui/components/multichain/create-btc-account/create-btc-account.tsx @@ -7,6 +7,7 @@ import { BitcoinWalletSnapSender } from '../../../../app/scripts/lib/snap-keyrin import { setAccountLabel, forceUpdateMetamaskState, + multichainUpdateBalance, } from '../../../store/actions'; type CreateBtcAccountOptions = { @@ -51,7 +52,21 @@ export const CreateBtcAccount = ({ dispatch(setAccountLabel(account.address, name)); } + // This will close up the name dialog await onActionComplete(true); + + // Force update the balances + try { + await multichainUpdateBalance(account.id); + } catch (error) { + // To avoid breaking the flow entirely, we do catch any error that might happens while fetching + // the balance. + // Worst case scenario, the balance will be updated during a future tick of the + // MultichainBalancesTracker! + console.warn( + `Unable to fetch Bitcoin balance: ${(error as Error).message}`, + ); + } }; const getNextAvailableAccountName = async (_accounts: InternalAccount[]) => { diff --git a/ui/store/actions.ts b/ui/store/actions.ts index 7493ee03ba3f..e62b00a5dced 100644 --- a/ui/store/actions.ts +++ b/ui/store/actions.ts @@ -5660,3 +5660,15 @@ export async function decodeTransactionData({ }, ]); } + +export async function multichainUpdateBalance( + accountId: string, +): Promise { + return await submitRequestToBackground('multichainUpdateBalance', [ + accountId, + ]); +} + +export async function multichainUpdateBalances(): Promise { + return await submitRequestToBackground('multichainUpdateBalances', []); +}