From 0b56833f6eb7557d55725a505a01185bbc9756db Mon Sep 17 00:00:00 2001 From: Janek Rahrt Date: Fri, 19 Aug 2022 03:34:17 +0200 Subject: [PATCH] fix: session account prooving --- __tests__/utils/merkle.test.ts | 1 + src/account/session.ts | 36 +++++++++++++++++++++++------ src/index.ts | 1 + src/utils/merkle.ts | 12 +++++----- src/utils/session.ts | 42 ++++++++++++++++++++++++++-------- src/utils/typedData/index.ts | 7 +++--- 6 files changed, 73 insertions(+), 26 deletions(-) diff --git a/__tests__/utils/merkle.test.ts b/__tests__/utils/merkle.test.ts index 5899ed39a..2c1fb8764 100644 --- a/__tests__/utils/merkle.test.ts +++ b/__tests__/utils/merkle.test.ts @@ -79,6 +79,7 @@ describe('MerkleTree class', () => { const proof = tree.getProof('0x7'); const manualProof = [ + '0x0', // proofs should always be as long as the tree is deep MerkleTree.hash('0x5', '0x6'), MerkleTree.hash(MerkleTree.hash('0x1', '0x2'), MerkleTree.hash('0x3', '0x4')), ]; diff --git a/src/account/session.ts b/src/account/session.ts index 2e1b2bfca..469794577 100644 --- a/src/account/session.ts +++ b/src/account/session.ts @@ -16,7 +16,7 @@ import { import { feeTransactionVersion, transactionVersion } from '../utils/hash'; import { MerkleTree } from '../utils/merkle'; import { BigNumberish, toBN } from '../utils/number'; -import { SignedSession, createMerkleTreeForPolicies } from '../utils/session'; +import { SignedSession, createMerkleTreeForPolicies, preparePolicy } from '../utils/session'; import { compileCalldata, estimatedFeeToMaxFee } from '../utils/stark'; import { fromCallsToExecuteCalldataWithNonce } from '../utils/transaction'; import { Account } from './default'; @@ -39,24 +39,44 @@ export class SessionAccount extends Account implements AccountInterface { assert(signedSession.root === this.merkleTree.root, 'Invalid session'); } - private async sessionToCall(session: SignedSession): Promise { + private async sessionToCall(session: SignedSession, proofs: string[][]): Promise { return { contractAddress: this.address, entrypoint: 'use_plugin', calldata: compileCalldata({ - SESSION_PLUGIN_CLASS_HASH, + classHash: SESSION_PLUGIN_CLASS_HASH, signer: await this.signer.getPubKey(), expires: session.expires.toString(), + root: session.root, + proofLength: proofs[0].length.toString(), + ...proofs.reduce( + (acc, proof, i) => ({ + ...acc, + ...proof.reduce((acc2, path, j) => ({ ...acc2, [`proof${i}:${j}`]: path }), {}), + }), + {} + ), + token1: session.signature[0], token2: session.signature[1], - root: session.root, - proof: [], }), }; } + private proofCalls(calls: Call[]): string[][] { + return calls.map((call) => { + const leaf = preparePolicy({ + contractAddress: call.contractAddress, + selector: call.entrypoint, + }); + return this.merkleTree.getProof(leaf); + }); + } + private async extendCallsBySession(calls: Call[], session: SignedSession): Promise { - return [await this.sessionToCall(session), ...calls]; + const proofs = this.proofCalls(calls); + const pluginCall = await this.sessionToCall(session, proofs); + return [pluginCall, ...calls]; } public async estimateFee( @@ -119,7 +139,9 @@ export class SessionAccount extends Account implements AccountInterface { if (transactionsDetail.maxFee || transactionsDetail.maxFee === 0) { maxFee = transactionsDetail.maxFee; } else { - const { suggestedMaxFee } = await this.estimateFee(transactions, { nonce }); + const { suggestedMaxFee } = await this.estimateFee(Array.isArray(calls) ? calls : [calls], { + nonce, + }); maxFee = suggestedMaxFee.toString(); } diff --git a/src/index.ts b/src/index.ts index 997d043c3..086bfa0b4 100644 --- a/src/index.ts +++ b/src/index.ts @@ -18,6 +18,7 @@ export * as number from './utils/number'; export * as transaction from './utils/transaction'; export * as stark from './utils/stark'; export * as merkle from './utils/merkle'; +export * as session from './utils/session'; export * as ec from './utils/ellipticCurve'; export * as uint256 from './utils/uint256'; export * as shortString from './utils/shortString'; diff --git a/src/utils/merkle.ts b/src/utils/merkle.ts index 04bd7e7ae..1dd288c0c 100644 --- a/src/utils/merkle.ts +++ b/src/utils/merkle.ts @@ -36,15 +36,15 @@ export class MerkleTree { } public getProof(leaf: string, branch = this.leaves, hashPath: string[] = []): string[] { - if (branch.length === 1) { - return hashPath; - } const index = branch.indexOf(leaf); if (index === -1) { throw new Error('leaf not found'); } + if (branch.length === 1) { + return hashPath; + } const isLeft = index % 2 === 0; - const neededBranch = (isLeft ? branch[index + 1] : branch[index - 1]) ?? branch[index]; + const neededBranch = (isLeft ? branch[index + 1] : branch[index - 1]) ?? '0x0'; const newHashPath = [...hashPath, neededBranch]; const currentBranchLevelIndex = this.leaves.length === branch.length @@ -52,11 +52,11 @@ export class MerkleTree { : this.branches.findIndex((b) => b.length === branch.length); const nextBranch = this.branches[currentBranchLevelIndex + 1] ?? [this.root]; return this.getProof( - neededBranch === leaf + neededBranch === '0x0' ? leaf : MerkleTree.hash(isLeft ? leaf : neededBranch, isLeft ? neededBranch : leaf), nextBranch, - neededBranch === leaf ? hashPath : newHashPath + newHashPath ); } } diff --git a/src/utils/session.ts b/src/utils/session.ts index baebc22ce..e4aee5554 100644 --- a/src/utils/session.ts +++ b/src/utils/session.ts @@ -1,8 +1,12 @@ import type { AccountInterface } from '../account'; +import { StarknetChainId } from '../constants'; +import { ProviderInterface } from '../provider'; import { Signature } from '../types'; -import { pedersen } from './hash'; +import { computeHashOnElements } from './hash'; import { MerkleTree } from './merkle'; -import { StarkNetDomain, prepareSelector } from './typedData'; +import { toBN } from './number'; +import { compileCalldata } from './stark'; +import { prepareSelector } from './typedData'; interface Policy { contractAddress: string; @@ -23,8 +27,25 @@ export interface SignedSession extends PreparedSession { signature: Signature; } -function preparePolicy({ contractAddress, selector }: Policy): string { - return pedersen([contractAddress, prepareSelector(selector)]); +export const SESSION_PLUGIN_CLASS_HASH = + '0x1031d8540af9d984d8d8aa5dff598467008c58b6f6147b7f90fda4b6d8db463'; +// H(Policy(contractAddress:felt,selector:selector)) +const POLICY_TYPE_HASH = '0x2f0026e78543f036f33e26a8f5891b88c58dc1e20cbbfaf0bb53274da6fa568'; + +export async function supportsSessions( + address: string, + provider: ProviderInterface +): Promise { + const { result } = await provider.callContract({ + contractAddress: address, + entrypoint: 'is_plugin', + calldata: compileCalldata({ classHash: SESSION_PLUGIN_CLASS_HASH }), + }); + return !toBN(result[0]).isZero(); +} + +export function preparePolicy({ contractAddress, selector }: Policy): string { + return computeHashOnElements([POLICY_TYPE_HASH, contractAddress, prepareSelector(selector)]); } export function createMerkleTreeForPolicies(policies: Policy[]): MerkleTree { @@ -38,8 +59,7 @@ export function prepareSession(session: RequestSession): PreparedSession { export async function createSession( session: RequestSession, - account: AccountInterface, - domain: StarkNetDomain = {} + account: AccountInterface ): Promise { const { expires, key, policies, root } = prepareSession(session); const signature = await account.signMessage({ @@ -52,15 +72,19 @@ export async function createSession( Session: [ { name: 'key', type: 'felt' }, { name: 'expires', type: 'felt' }, - { name: 'root', type: 'merkletree', contains: 'Policy*' }, + { name: 'root', type: 'merkletree', contains: 'Policy' }, ], StarkNetDomain: [ { name: 'name', type: 'felt' }, { name: 'version', type: 'felt' }, - { name: 'chain_id', type: 'felt' }, + { name: 'chainId', type: 'felt' }, ], }, - domain, + domain: { + name: '0x0', + version: '0x0', + chainId: StarknetChainId.TESTNET, + }, message: { key, expires, diff --git a/src/utils/typedData/index.ts b/src/utils/typedData/index.ts index 944f0f1bd..b99540b17 100644 --- a/src/utils/typedData/index.ts +++ b/src/utils/typedData/index.ts @@ -81,6 +81,9 @@ function getMerkleTreeType(types: TypedData['types'], ctx: Context) { if (!isMerkleTree) { throw new Error(`${ctx.key} is not a merkle tree`); } + if (merkleType.contains.endsWith('*')) { + throw new Error(`Merkle tree contain property must not be an array but was given ${ctx.key}`); + } return merkleType.contains; } return 'raw'; @@ -161,10 +164,6 @@ export const encodeValue = ( return ['felt*', computeHashOnElements(data as string[])]; } - if (type === 'raw') { - return ['felt', data as string]; - } - if (type === 'selector') { return ['felt', prepareSelector(data as string)]; }