Skip to content

Commit

Permalink
feat: getContractVersion and getCairoVersion by classHash, deploy_acc…
Browse files Browse the repository at this point in the history
…ount and bulk action autodetect
  • Loading branch information
tabaktoni committed Sep 23, 2023
1 parent 85bbe39 commit 54ffca4
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 19 deletions.
5 changes: 3 additions & 2 deletions __tests__/defaultProvider.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,9 @@ describe('defaultProvider', () => {
});

test('getContractVersion', async () => {
const version = await testProvider.getContractVersion(erc20ContractAddress);
expect(version).toEqual({ cairo: '0', compiler: '0' });
const expected = { cairo: '0', compiler: '0' };
expect(await testProvider.getContractVersion(erc20ContractAddress)).toEqual(expected);
expect(await testProvider.getContractVersion(undefined, erc20ClassHash)).toEqual(expected);
});

describe('getBlock', () => {
Expand Down
33 changes: 23 additions & 10 deletions src/account/default.ts
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,17 @@ export class Account extends Provider implements AccountInterface {

/**
* Async Get cairo version (auto set it, if not set by user)
* @param version CairoVersion
* @param classHash (optional) string - if provided detect cairoVersion from classHash
*/
public async getCairoVersion() {
public async getCairoVersion(classHash?: string) {
if (!this.cairoVersion) {
const { cairo } = await super.getContractVersion(this.address);
this.cairoVersion = cairo;
if (classHash) {
const { cairo } = await super.getContractVersion(undefined, classHash);
this.cairoVersion = cairo;
} else {
const { cairo } = await super.getContractVersion(this.address);
this.cairoVersion = cairo;
}
}
return this.cairoVersion;
}
Expand Down Expand Up @@ -167,7 +172,7 @@ export class Account extends Provider implements AccountInterface {
version,
walletAddress: this.address,
maxFee: ZERO,
cairoVersion: await this.getCairoVersion(),
cairoVersion: undefined, // unused parameter
}
);

Expand Down Expand Up @@ -204,9 +209,9 @@ export class Account extends Provider implements AccountInterface {
nonce,
chainId,
version,
walletAddress: this.address,
walletAddress: this.address, // unused parameter
maxFee: ZERO,
cairoVersion: await this.getCairoVersion(),
cairoVersion: undefined, // unused parameter
}
);

Expand Down Expand Up @@ -354,7 +359,7 @@ export class Account extends Provider implements AccountInterface {
const declareContractTransaction = await this.buildDeclarePayload(declareContractPayload, {
...details,
walletAddress: this.address,
cairoVersion: await this.getCairoVersion(),
cairoVersion: await this.getCairoVersion(), // unused parameter
});

return this.declareContract(declareContractTransaction, details);
Expand Down Expand Up @@ -659,18 +664,26 @@ export class Account extends Provider implements AccountInterface {
const version = versions[0];
const safeNonce = await this.getNonceSafe(nonce);
const chainId = await this.getChainId();
let newAccClassHash: string;

return Promise.all(
([] as Invocations).concat(invocations).map(async (transaction, index: number) => {
const txPayload: any = 'payload' in transaction ? transaction.payload : transaction;
// BULK ACTION FROM NEW ACCOUNT START WITH DEPLOY_ACCOUNT
if (index === 0 && transaction.type === TransactionType.DEPLOY_ACCOUNT) {
newAccClassHash = txPayload.classHash;
}

const signerDetails: InvocationsSignerDetails = {
walletAddress: this.address,
nonce: toBigInt(Number(safeNonce) + index),
maxFee: ZERO,
version,
chainId,
cairoVersion: await this.getCairoVersion(),
cairoVersion: newAccClassHash
? await this.getCairoVersion(newAccClassHash)
: await this.getCairoVersion(),
};
const txPayload: any = 'payload' in transaction ? transaction.payload : transaction;
const common = {
type: transaction.type,
version,
Expand Down
8 changes: 6 additions & 2 deletions src/provider/default.ts
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,11 @@ export class Provider implements ProviderInterface {
return getAddressFromStarkName(this, name, StarknetIdContract);
}

public async getContractVersion(contractAddress: string, options?: getContractVersionOptions) {
return this.provider.getContractVersion(contractAddress, options);
public async getContractVersion(
contractAddress?: string,
classHash?: string,
options?: getContractVersionOptions
) {
return this.provider.getContractVersion(contractAddress, classHash, options);
}
}
6 changes: 4 additions & 2 deletions src/provider/interface.ts
Original file line number Diff line number Diff line change
Expand Up @@ -335,13 +335,15 @@ export abstract class ProviderInterface {

/**
* Gets the contract version from the provided address
* @param contractAddress string
* @param contractAddress required if no classHash
* @param classHash required if no contractAddress
* @param options - getContractVersionOptions
* - (optional) compiler - (default true) extract compiler version using type tactic from abi
* - (optional) blockIdentifier - block identifier
*/
public abstract getContractVersion(
contractAddress: string,
contractAddress?: string,
classHash?: string,
options?: getContractVersionOptions
): Promise<ContractVersion>;
}
11 changes: 10 additions & 1 deletion src/provider/rpc.ts
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,18 @@ export class RpcProvider implements ProviderInterface {

public async getContractVersion(
contractAddress: string,
classHash: string,
{ blockIdentifier = this.blockIdentifier, compiler = true }: getContractVersionOptions
): Promise<ContractVersion> {
const contractClass = await this.getClassAt(contractAddress, blockIdentifier);
let contractClass;
if (contractAddress) {
contractClass = await this.getClassAt(contractAddress, blockIdentifier);
} else if (classHash) {
contractClass = await this.getClass(classHash, blockIdentifier);
} else {
throw Error('getContractVersion require contractAddress or classHash');
}

if (isSierra(contractClass)) {
if (compiler) {
const abiTest = getAbiContractVersion(contractClass.abi);
Expand Down
13 changes: 11 additions & 2 deletions src/provider/sequencer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -346,13 +346,22 @@ export class SequencerProvider implements ProviderInterface {
}

public async getContractVersion(
contractAddress: string,
contractAddress?: string,
classHash?: string,
{ blockIdentifier, compiler }: getContractVersionOptions = {
blockIdentifier: this.blockIdentifier,
compiler: true,
}
): Promise<ContractVersion> {
const contractClass = await this.getClassAt(contractAddress, blockIdentifier);
let contractClass;
if (contractAddress) {
contractClass = await this.getClassAt(contractAddress, blockIdentifier);
} else if (classHash) {
contractClass = await this.getClassByHash(classHash, blockIdentifier);
} else {
throw Error('getContractVersion require contractAddress or classHash');
}

if (isSierra(contractClass)) {
if (compiler) {
const abiTest = getAbiContractVersion(contractClass.abi);
Expand Down

0 comments on commit 54ffca4

Please sign in to comment.