diff --git a/__tests__/sequencerProvider.test.ts b/__tests__/sequencerProvider.test.ts index e9a97bbf9..bcd906162 100644 --- a/__tests__/sequencerProvider.test.ts +++ b/__tests__/sequencerProvider.test.ts @@ -1,5 +1,6 @@ -import { Contract, Provider, SequencerProvider, stark } from '../src'; -import { toBigInt } from '../src/utils/number'; +import { Contract, GatewayError, HttpError, Provider, SequencerProvider, stark } from '../src'; +import { getSelector } from '../src/utils/hash'; +import { getDecimalString, getHexStringArray, toBigInt } from '../src/utils/number'; import { encodeShortString } from '../src/utils/shortString'; import { compiledErc20, @@ -9,31 +10,41 @@ import { describeIfSequencer, getTestAccount, getTestProvider, + wrongClassHash, } from './fixtures'; describeIfSequencer('SequencerProvider', () => { const sequencerProvider = getTestProvider() as SequencerProvider; const account = getTestAccount(sequencerProvider); + const L1_ADDRESS = '0x8359E4B0152ed5A731162D3c7B0D8D56edB165A0'; + let customSequencerProvider: Provider; let exampleContractAddress: string; + let exampleTransactionHash: string; + let l1l2ContractAddress: string; - describe('Gateway specific methods', () => { - let exampleTransactionHash: string; + beforeAll(async () => { + const { deploy } = await account.declareAndDeploy({ + contract: compiledErc20, + constructorCalldata: [ + encodeShortString('Token'), + encodeShortString('ERC20'), + account.address, + ], + }); - beforeAll(async () => { - const { deploy } = await account.declareAndDeploy({ - contract: compiledErc20, - constructorCalldata: [ - encodeShortString('Token'), - encodeShortString('ERC20'), - account.address, - ], - }); + exampleTransactionHash = deploy.transaction_hash; + exampleContractAddress = deploy.contract_address; + }); - exampleTransactionHash = deploy.transaction_hash; - exampleContractAddress = deploy.contract_address; + beforeAll(async () => { + const { deploy } = await account.declareAndDeploy({ + contract: compiledL1L2, }); + l1l2ContractAddress = deploy.contract_address; + }); + describe('Gateway specific methods', () => { test('getTransactionStatus()', async () => { return expect( sequencerProvider.getTransactionStatus(exampleTransactionHash) @@ -61,16 +72,6 @@ describeIfSequencer('SequencerProvider', () => { }); describe('Test Estimate message fee', () => { - const L1_ADDRESS = '0x8359E4B0152ed5A731162D3c7B0D8D56edB165A0'; - let l1l2ContractAddress: string; - - beforeAll(async () => { - const { deploy } = await account.declareAndDeploy({ - contract: compiledL1L2, - }); - l1l2ContractAddress = deploy.contract_address; - }); - test('estimate message fee', async () => { const estimation = await sequencerProvider.estimateMessageFee( { @@ -92,6 +93,41 @@ describeIfSequencer('SequencerProvider', () => { }); }); + describe('Generic fetch', () => { + test('fetch http error', async () => { + expect(sequencerProvider.fetch('wrong')).rejects.toThrow(HttpError); + }); + + test('fetch gateway error', async () => { + const endpoint = '/feeder_gateway/get_class_by_hash'; + const query = `classHash=${wrongClassHash}`; + expect(sequencerProvider.fetch(`${endpoint}?${query}`)).rejects.toThrow(GatewayError); + }); + + test('fetch GET success', async () => { + const endpoint = '/feeder_gateway/get_transaction_status'; + const query = `transactionHash=${exampleTransactionHash}`; + const result = await sequencerProvider.fetch(`${endpoint}?${query}`); + expect(result).toHaveProperty('tx_status'); + }); + + test('fetch POST success', async () => { + const endpoint = '/feeder_gateway/estimate_message_fee'; + const query = `blockIdentifier=latest`; + const result = await sequencerProvider.fetch(`${endpoint}?${query}`, { + method: 'POST', + body: { + from_address: getDecimalString(L1_ADDRESS), + to_address: l1l2ContractAddress, + entry_point_selector: getSelector('deposit'), + payload: getHexStringArray(['556', '123']), + }, + parseAlwaysAsBigInt: true, + }); + expect(typeof result.overall_fee).toBe('bigint'); + }); + }); + describeIfDevnet('Test calls with Custom Devnet Sequencer Provider', () => { let erc20: Contract; const wallet = stark.randomAddress(); diff --git a/src/provider/sequencer.ts b/src/provider/sequencer.ts index 074bb7e68..86314ee54 100644 --- a/src/provider/sequencer.ts +++ b/src/provider/sequencer.ts @@ -46,9 +46,10 @@ import { ProviderInterface } from './interface'; import { Block, BlockIdentifier } from './utils'; type NetworkName = 'mainnet-alpha' | 'goerli-alpha' | 'goerli-alpha-2'; +type SequencerHttpMethod = 'POST' | 'GET'; export type SequencerProviderOptions = { - headers?: object; + headers?: Record; blockIdentifier?: BlockIdentifier; } & ( | { @@ -84,14 +85,14 @@ export class SequencerProvider implements ProviderInterface { public gatewayUrl: string; - public headers: object | undefined; + public headers?: Record; + + private blockIdentifier: BlockIdentifier; private chainId: StarknetChainId; private responseParser = new SequencerAPIResponseParser(); - private blockIdentifier: BlockIdentifier; - constructor(optionsOrProvider: SequencerProviderOptions = defaultOptions) { if ('network' in optionsOrProvider) { this.baseUrl = SequencerProvider.getNetworkFromName(optionsOrProvider.network); @@ -177,7 +178,7 @@ export class SequencerProvider implements ProviderInterface { return `?${queryString}`; } - private getHeaders(method: 'POST' | 'GET'): object | undefined { + private getHeaders(method: SequencerHttpMethod): Record | undefined { if (method === 'POST') { return { 'Content-Type': 'application/json', @@ -202,43 +203,52 @@ export class SequencerProvider implements ProviderInterface { const baseUrl = this.getFetchUrl(endpoint); const method = this.getFetchMethod(endpoint); const queryString = this.getQueryString(query); - const headers = this.getHeaders(method); const url = urljoin(baseUrl, endpoint, queryString); + return this.fetch(url, { + method, + body: request, + }); + } + + public async fetch( + endpoint: string, + options?: { + method?: SequencerHttpMethod; + body?: any; + parseAlwaysAsBigInt?: boolean; + } + ): Promise { + const url = buildUrl(this.baseUrl, '', endpoint); + const method = options?.method ?? 'GET'; + const headers = this.getHeaders(method); + try { - const res = await fetch(url, { + const response = await fetch(url, { method, - body: stringify(request), - headers: headers as Record, + body: stringify(options?.body), + headers, }); - const textResponse = await res.text(); - if (!res.ok) { - // This will allow user to handle contract errors + const textResponse = await response.text(); + + if (!response.ok) { + // This will allow the user to handle contract errors let responseBody: any; try { responseBody = parse(textResponse); } catch { - // if error parsing fails, return an http error - throw new HttpError(res.statusText, res.status); + throw new HttpError(response.statusText, response.status); } - - const errorCode = responseBody.code || ((responseBody as any)?.status_code as string); // starknet-devnet uses status_code instead of code; They need to fix that - throw new GatewayError(responseBody.message, errorCode); // Caught locally, and re-thrown for the user + throw new GatewayError(responseBody.message, responseBody.code); } - if (endpoint === 'estimate_fee') { - return parseAlwaysAsBig(textResponse); - } - return parse(textResponse) as Sequencer.Endpoints[T]['RESPONSE']; - } catch (err) { - // rethrow custom errors - if (err instanceof GatewayError || err instanceof HttpError) { - throw err; - } - if (err instanceof Error) { - throw Error(`Could not ${method} from endpoint \`${url}\`: ${err.message}`); - } - throw err; + const parseChoice = options?.parseAlwaysAsBigInt ? parseAlwaysAsBig : parse; + return parseChoice(textResponse); + } catch (error) { + if (error instanceof Error && !(error instanceof LibraryError)) + throw Error(`Could not ${method} from endpoint \`${url}\`: ${error.message}`); + + throw error; } }