diff --git a/app/util/importAdditionalAccounts.js b/app/util/importAdditionalAccounts.js deleted file mode 100644 index 9b015ce661b..00000000000 --- a/app/util/importAdditionalAccounts.js +++ /dev/null @@ -1,60 +0,0 @@ -import Engine from '../core/Engine'; -import { BNToHex } from '../util/number'; -import Logger from '../util/Logger'; -import ExtendedKeyringTypes from '../../app/constants/keyringTypes'; - -const HD_KEY_TREE_ERROR = 'MetamaskController - No HD Key Tree found'; -const ZERO_BALANCE = '0x0'; -const MAX = 20; - -/** - * Get an account balance from the network. - * @param {string} address - The account address - * @param {EthQuery} ethQuery - The EthQuery instance to use when asking the network - */ -const getBalance = async (address, ethQuery) => - new Promise((resolve, reject) => { - ethQuery.getBalance(address, (error, balance) => { - if (error) { - reject(error); - Logger.error(error); - } else { - const balanceHex = BNToHex(balance); - resolve(balanceHex || ZERO_BALANCE); - } - }); - }); - -/** - * Add additional accounts in the wallet based on balance - */ -export default async () => { - const { KeyringController } = Engine.context; - - const ethQuery = Engine.getGlobalEthQuery(); - let accounts = await KeyringController.getAccounts(); - let lastBalance = await getBalance(accounts[accounts.length - 1], ethQuery); - - const { keyrings } = KeyringController.state; - const filteredKeyrings = keyrings.filter( - (keyring) => keyring.type === ExtendedKeyringTypes.hd, - ); - const primaryKeyring = filteredKeyrings[0]; - if (!primaryKeyring) throw new Error(HD_KEY_TREE_ERROR); - - let i = 0; - // seek out the first zero balance - while (lastBalance !== ZERO_BALANCE) { - if (i === MAX) break; - await KeyringController.addNewAccountWithoutUpdate(primaryKeyring); - accounts = await KeyringController.getAccounts(); - lastBalance = await getBalance(accounts[accounts.length - 1], ethQuery); - i++; - } - - // remove extra zero balance account potentially created from seeking ahead - if (accounts.length > 1 && lastBalance === ZERO_BALANCE) { - await KeyringController.removeAccount(accounts[accounts.length - 1]); - accounts = await KeyringController.getAccounts(); - } -}; diff --git a/app/util/importAdditionalAccounts.test.ts b/app/util/importAdditionalAccounts.test.ts new file mode 100644 index 00000000000..bc72c072edd --- /dev/null +++ b/app/util/importAdditionalAccounts.test.ts @@ -0,0 +1,111 @@ +import importAdditionalAccounts from './importAdditionalAccounts'; +import { BN } from 'ethereumjs-util'; + +const mockKeyring = { + addAccounts: jest.fn(), + removeAccount: jest.fn(), +}; + +const mockEthQuery = { + getBalance: jest.fn(), +}; + +jest.mock('../core/Engine', () => ({ + context: { + KeyringController: { + withKeyring: jest.fn((_keyring, callback) => callback(mockKeyring)), + }, + }, + getGlobalEthQuery: () => mockEthQuery, +})); + +/** + * Set the balance that will be queried for the account + * + * @param balance - The balance to be queried + */ +function setQueriedBalance(balance: BN) { + mockEthQuery.getBalance.mockImplementation((_, callback) => + callback(null, balance), + ); +} + +/** + * Set the balance that will be queried for the account once + * + * @param balance - The balance to be queried + */ +function setQueriedBalanceOnce(balance: BN) { + mockEthQuery.getBalance.mockImplementationOnce((_, callback) => { + callback(null, balance); + }); +} + +describe('importAdditionalAccounts', () => { + beforeEach(() => { + jest.clearAllMocks(); + }); + + describe('when there is no account with balance', () => { + it('should not add any account', async () => { + setQueriedBalance(new BN(0)); + mockKeyring.addAccounts.mockResolvedValue(['0x1234']); + + await importAdditionalAccounts(); + + expect(mockKeyring.addAccounts).toHaveBeenCalledTimes(1); + expect(mockKeyring.removeAccount).toHaveBeenCalledTimes(1); + expect(mockKeyring.removeAccount).toHaveBeenCalledWith('0x1234'); + }); + }); + + describe('when there is an account with balance', () => { + it('should add 1 account', async () => { + setQueriedBalanceOnce(new BN(1)); + setQueriedBalanceOnce(new BN(0)); + mockKeyring.addAccounts + .mockResolvedValueOnce(['0x1234']) + .mockResolvedValueOnce(['0x5678']); + + await importAdditionalAccounts(); + + expect(mockKeyring.addAccounts).toHaveBeenCalledTimes(2); + expect(mockKeyring.removeAccount).toHaveBeenCalledWith('0x5678'); + }); + }); + + describe('when there are multiple accounts with balance', () => { + it('should add 2 accounts', async () => { + setQueriedBalanceOnce(new BN(1)); + setQueriedBalanceOnce(new BN(2)); + setQueriedBalanceOnce(new BN(0)); + mockKeyring.addAccounts + .mockResolvedValueOnce(['0x1234']) + .mockResolvedValueOnce(['0x5678']) + .mockResolvedValueOnce(['0x9abc']); + + await importAdditionalAccounts(); + + expect(mockKeyring.addAccounts).toHaveBeenCalledTimes(3); + expect(mockKeyring.removeAccount).toHaveBeenCalledWith('0x9abc'); + }); + }); + + describe('when ethQuery.getBalance throws an error', () => { + it('should not remove all the accounts', async () => { + setQueriedBalanceOnce(new BN(1)); + mockEthQuery.getBalance.mockImplementationOnce((_, callback) => + callback(new Error('error')), + ); + mockKeyring.addAccounts + .mockResolvedValueOnce(['0x1234']) + .mockResolvedValueOnce(['0x5678']); + + await importAdditionalAccounts(); + + expect(mockKeyring.addAccounts).toHaveBeenCalledTimes(2); + expect(mockKeyring.removeAccount).toHaveBeenCalledTimes(1); + expect(mockKeyring.removeAccount).toHaveBeenCalledWith('0x5678'); + }); + }); +}); diff --git a/app/util/importAdditionalAccounts.ts b/app/util/importAdditionalAccounts.ts new file mode 100644 index 00000000000..8da7200d9f4 --- /dev/null +++ b/app/util/importAdditionalAccounts.ts @@ -0,0 +1,60 @@ +import Engine from '../core/Engine'; +import { BNToHex } from '../util/number'; +import Logger from '../util/Logger'; +import ExtendedKeyringTypes from '../../app/constants/keyringTypes'; +import type EthQuery from '@metamask/eth-query'; +import type { BN } from 'ethereumjs-util'; +import { Hex } from '@metamask/utils'; + +const ZERO_BALANCE = '0x0'; +const MAX = 20; + +/** + * Get an account balance from the network. + * @param address - The account address + * @param ethQuery - The EthQuery instance to use when asking the network + */ +const getBalance = async (address: string, ethQuery: EthQuery): Promise => + new Promise((resolve, reject) => { + ethQuery.getBalance(address, (error: Error, balance: BN) => { + if (error) { + reject(error); + Logger.error(error); + } else { + const balanceHex = BNToHex(balance); + resolve(balanceHex || ZERO_BALANCE); + } + }); + }); + +/** + * Add additional accounts in the wallet based on balance + */ +export default async () => { + const { KeyringController } = Engine.context; + const ethQuery = Engine.getGlobalEthQuery(); + + await KeyringController.withKeyring( + { type: ExtendedKeyringTypes.hd }, + async (primaryKeyring) => { + for (let i = 0; i < MAX; i++) { + const [newAccount] = await primaryKeyring.addAccounts(1); + + let newAccountBalance = ZERO_BALANCE; + try { + newAccountBalance = await getBalance(newAccount, ethQuery); + } catch (error) { + // Errors are gracefully handled so that `withKeyring` + // will not rollback the primary keyring, and accounts + // created in previous loop iterations will remain in place. + } + + if (newAccountBalance === ZERO_BALANCE) { + // remove extra zero balance account we just added and break the loop + primaryKeyring.removeAccount?.(newAccount); + break; + } + } + }, + ); +};