diff --git a/web3.js/src/connection.ts b/web3.js/src/connection.ts index 7794cbf3e65225..0dd6cd4c458c5f 100644 --- a/web3.js/src/connection.ts +++ b/web3.js/src/connection.ts @@ -88,6 +88,10 @@ type ClientSubscriptionId = number; /** @internal */ type ServerSubscriptionId = number; /** @internal */ type SubscriptionConfigHash = string; /** @internal */ type SubscriptionDisposeFn = () => Promise; +/** @internal */ type SubscriptionStateChangeCallback = ( + nextState: StatefulSubscription['state'], +) => void; +/** @internal */ type SubscriptionStateChangeDisposeFn = () => void; /** * @internal * Every subscription contains the args used to open the subscription with @@ -2715,6 +2719,16 @@ export class Connection { | SubscriptionDisposeFn | undefined; } = {}; + /** @internal */ private _subscriptionHashByClientSubscriptionId: { + [clientSubscriptionId: ClientSubscriptionId]: + | SubscriptionConfigHash + | undefined; + } = {}; + /** @internal */ private _subscriptionStateChangeCallbacksByHash: { + [hash: SubscriptionConfigHash]: + | Set + | undefined; + } = {}; /** @internal */ private _subscriptionCallbacksByServerSubscriptionId: { [serverSubscriptionId: ServerSubscriptionId]: | Set @@ -3372,7 +3386,10 @@ export class Connection { const subscriptionCommitment = commitment || this.commitment; let timeoutId; - let subscriptionId; + let signatureSubscriptionId: number | undefined; + let disposeSignatureSubscriptionStateChangeObserver: + | SubscriptionStateChangeDisposeFn + | undefined; let done = false; const confirmationPromise = new Promise<{ @@ -3380,10 +3397,10 @@ export class Connection { response: RpcResponseAndContext; }>((resolve, reject) => { try { - subscriptionId = this.onSignature( + signatureSubscriptionId = this.onSignature( rawSignature, (result: SignatureResult, context: Context) => { - subscriptionId = undefined; + signatureSubscriptionId = undefined; const response = { context, value: result, @@ -3393,6 +3410,46 @@ export class Connection { }, subscriptionCommitment, ); + const subscriptionSetupPromise = new Promise( + resolveSubscriptionSetup => { + if (signatureSubscriptionId == null) { + resolveSubscriptionSetup(); + } else { + disposeSignatureSubscriptionStateChangeObserver = + this._onSubscriptionStateChange( + signatureSubscriptionId, + nextState => { + if (nextState === 'subscribed') { + resolveSubscriptionSetup(); + } + }, + ); + } + }, + ); + (async () => { + await subscriptionSetupPromise; + if (done) return; + const response = await this.getSignatureStatus(rawSignature); + if (done) return; + if (response == null) { + return; + } + const {context, value} = response; + if (value?.err) { + reject(value.err); + } + if (value) { + done = true; + resolve({ + __type: TransactionStatus.PROCESSED, + response: { + context, + value, + }, + }); + } + })(); } catch (err) { reject(err); } @@ -3465,8 +3522,11 @@ export class Connection { } } finally { clearTimeout(timeoutId); - if (subscriptionId) { - this.removeSignatureListener(subscriptionId); + if (disposeSignatureSubscriptionStateChangeObserver) { + disposeSignatureSubscriptionStateChangeObserver(); + } + if (signatureSubscriptionId) { + this.removeSignatureListener(signatureSubscriptionId); } } return result; @@ -5106,13 +5166,60 @@ export class Connection { Object.entries( this._subscriptionsByHash as Record, ).forEach(([hash, subscription]) => { - this._subscriptionsByHash[hash] = { + this._setSubscription(hash, { ...subscription, state: 'pending', - }; + }); }); } + /** + * @internal + */ + private _setSubscription( + hash: SubscriptionConfigHash, + nextSubscription: Subscription, + ) { + const prevState = this._subscriptionsByHash[hash]?.state; + this._subscriptionsByHash[hash] = nextSubscription; + if (prevState !== nextSubscription.state) { + const stateChangeCallbacks = + this._subscriptionStateChangeCallbacksByHash[hash]; + if (stateChangeCallbacks) { + stateChangeCallbacks.forEach(cb => { + try { + cb(nextSubscription.state); + // eslint-disable-next-line no-empty + } catch {} + }); + } + } + } + + /** + * @internal + */ + private _onSubscriptionStateChange( + clientSubscriptionId: ClientSubscriptionId, + callback: SubscriptionStateChangeCallback, + ): SubscriptionStateChangeDisposeFn { + const hash = + this._subscriptionHashByClientSubscriptionId[clientSubscriptionId]; + if (hash == null) { + return () => {}; + } + const stateChangeCallbacks = (this._subscriptionStateChangeCallbacksByHash[ + hash + ] ||= new Set()); + stateChangeCallbacks.add(callback); + return () => { + stateChangeCallbacks.delete(callback); + if (stateChangeCallbacks.size === 0) { + delete this._subscriptionStateChangeCallbacksByHash[hash]; + } + }; + } + /** * @internal */ @@ -5193,17 +5300,17 @@ export class Connection { await (async () => { const {args, method} = subscription; try { - this._subscriptionsByHash[hash] = { + this._setSubscription(hash, { ...subscription, state: 'subscribing', - }; + }); const serverSubscriptionId: ServerSubscriptionId = (await this._rpcWebSocket.call(method, args)) as number; - this._subscriptionsByHash[hash] = { + this._setSubscription(hash, { ...subscription, serverSubscriptionId, state: 'subscribed', - }; + }); this._subscriptionCallbacksByServerSubscriptionId[ serverSubscriptionId ] = subscription.callbacks; @@ -5220,10 +5327,10 @@ export class Connection { return; } // TODO: Maybe add an 'errored' state or a retry limit? - this._subscriptionsByHash[hash] = { + this._setSubscription(hash, { ...subscription, state: 'pending', - }; + }); await this._updateSubscriptions(); } })(); @@ -5251,10 +5358,14 @@ export class Connection { serverSubscriptionId, ); } else { - this._subscriptionsByHash[hash] = { + this._setSubscription(hash, { + ...subscription, + state: 'unsubscribing', + }); + this._setSubscription(hash, { ...subscription, state: 'unsubscribing', - }; + }); try { await this._rpcWebSocket.call(unsubscribeMethod, [ serverSubscriptionId, @@ -5267,18 +5378,18 @@ export class Connection { return; } // TODO: Maybe add an 'errored' state or a retry limit? - this._subscriptionsByHash[hash] = { + this._setSubscription(hash, { ...subscription, state: 'subscribed', - }; + }); await this._updateSubscriptions(); return; } } - this._subscriptionsByHash[hash] = { + this._setSubscription(hash, { ...subscription, state: 'unsubscribed', - }; + }); await this._updateSubscriptions(); })(); } @@ -5381,12 +5492,14 @@ export class Connection { } else { existingSubscription.callbacks.add(subscriptionConfig.callback); } + this._subscriptionHashByClientSubscriptionId[clientSubscriptionId] = hash; this._subscriptionDisposeFunctionsByClientSubscriptionId[ clientSubscriptionId ] = async () => { delete this._subscriptionDisposeFunctionsByClientSubscriptionId[ clientSubscriptionId ]; + delete this._subscriptionHashByClientSubscriptionId[clientSubscriptionId]; const subscription = this._subscriptionsByHash[hash]; assert( subscription !== undefined, diff --git a/web3.js/test/connection.test.ts b/web3.js/test/connection.test.ts index 052528a24fb77a..e13260540ed397 100644 --- a/web3.js/test/connection.test.ts +++ b/web3.js/test/connection.test.ts @@ -3,7 +3,8 @@ import {Buffer} from 'buffer'; import * as splToken from '@solana/spl-token'; import {expect, use} from 'chai'; import chaiAsPromised from 'chai-as-promised'; -import {useFakeTimers, SinonFakeTimers} from 'sinon'; +import {mock, useFakeTimers, SinonFakeTimers} from 'sinon'; +import sinonChai from 'sinon-chai'; import { Authorized, @@ -67,6 +68,7 @@ import {MessageV0} from '../src/message/v0'; import {encodeData} from '../src/instruction'; use(chaiAsPromised); +use(sinonChai); const verifySignatureStatus = ( status: SignatureStatus | null, @@ -1133,6 +1135,85 @@ describe('Connection', function () { value: {err: null}, }); }); + + it('confirm transaction - does not check the signature status before the signature subscription comes alive', async () => { + const mockSignature = + 'w2Zeq8YkpyB463DttvfzARD7k9ZxGEwbsEw4boEK7jDp3pfoxZbTdLFSsEPhzXhpCcjGi2kHtHFobgX49MMhbWt'; + + await mockRpcMessage({ + method: 'signatureSubscribe', + params: [mockSignature, {commitment: 'finalized'}], + result: {err: null}, + subscriptionEstablishmentPromise: new Promise(() => {}), // Never resolve. + }); + const getSignatureStatusesExpectation = mock(connection) + .expects('getSignatureStatuses') + .never(); + connection.confirmTransaction(mockSignature); + getSignatureStatusesExpectation.verify(); + }); + + it('confirm transaction - checks the signature status once the signature subscription comes alive', async () => { + const mockSignature = + 'w2Zeq8YkpyB463DttvfzARD7k9ZxGEwbsEw4boEK7jDp3pfoxZbTdLFSsEPhzXhpCcjGi2kHtHFobgX49MMhbWt'; + + await mockRpcMessage({ + method: 'signatureSubscribe', + params: [mockSignature, {commitment: 'finalized'}], + result: {err: null}, + }); + const getSignatureStatusesExpectation = mock(connection) + .expects('getSignatureStatuses') + .once(); + + const confirmationPromise = + connection.confirmTransaction(mockSignature); + clock.runAllAsync(); + + await expect(confirmationPromise).to.eventually.deep.equal({ + context: {slot: 11}, + value: {err: null}, + }); + getSignatureStatusesExpectation.verify(); + }); + + // FIXME: This test does not work. + // it('confirm transaction - confirms transaction when signature status check yields confirmation before signature subscription does', async () => { + // const mockSignature = + // 'w2Zeq8YkpyB463DttvfzARD7k9ZxGEwbsEw4boEK7jDp3pfoxZbTdLFSsEPhzXhpCcjGi2kHtHFobgX49MMhbWt'; + + // // Keep the subscription from ever returning data. + // await mockRpcMessage({ + // method: 'signatureSubscribe', + // params: [mockSignature, {commitment: 'finalized'}], + // result: new Promise(() => {}), // Never resolve. + // }); + // clock.runAllAsync(); + + // const confirmationPromise = + // connection.confirmTransaction(mockSignature); + // clock.runAllAsync(); + + // // Return a signature status through the RPC API. + // await mockRpcResponse({ + // method: 'getSignatureStatuses', + // params: [[mockSignature]], + // value: [ + // { + // slot: 0, + // confirmations: 11, + // status: {Ok: null}, + // err: null, + // }, + // ], + // }); + // clock.runAllAsync(); + + // await expect(confirmationPromise).to.eventually.deep.equal({ + // context: {slot: 11}, + // value: {err: null}, + // }); + // }); }); } diff --git a/web3.js/test/mocks/rpc-websockets.ts b/web3.js/test/mocks/rpc-websockets.ts index eae2c4a7c5c777..769aa2bd80f5bc 100644 --- a/web3.js/test/mocks/rpc-websockets.ts +++ b/web3.js/test/mocks/rpc-websockets.ts @@ -7,6 +7,7 @@ import {Connection} from '../../src'; type RpcRequest = { method: string; params?: Array; + subscriptionEstablishmentPromise?: Promise; }; type RpcResponse = { @@ -24,13 +25,15 @@ export const mockRpcMessage = ({ method, params, result, + subscriptionEstablishmentPromise, }: { method: string; params: Array; result: any | Promise; + subscriptionEstablishmentPromise?: Promise; }) => { mockRpcSocket.push([ - {method, params}, + {method, params, subscriptionEstablishmentPromise}, { context: {slot: 11}, value: result, @@ -109,6 +112,10 @@ class MockClient { expect(params).to.eql(mockRequest.params); } + if (mockRequest.subscriptionEstablishmentPromise) { + await mockRequest.subscriptionEstablishmentPromise; + } + let id = ++this.subscriptionCounter; const response = { subscription: id,