diff --git a/.gitignore b/.gitignore index 973ce80b6..2c6ca1ec9 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,5 @@ node_modules/ proving-files/ zapps/ test-zapps/ +truezapps/ +temp-zapps/ \ No newline at end of file diff --git a/src/boilerplate/circuit/zokrates/nodes/BoilerplateGenerator.ts b/src/boilerplate/circuit/zokrates/nodes/BoilerplateGenerator.ts index 184774098..c24bf2bda 100644 --- a/src/boilerplate/circuit/zokrates/nodes/BoilerplateGenerator.ts +++ b/src/boilerplate/circuit/zokrates/nodes/BoilerplateGenerator.ts @@ -164,35 +164,47 @@ class BoilerplateGenerator { }); } + prepareMappingBoilerplate(mappingKeyName: string, mappingKeyIndicator: MappingKey) { + mappingKeyIndicator.isMapping = true; + this.assignIndicators(mappingKeyIndicator); + if(mappingKeyName == 'msg') + mappingKeyName = mappingKeyName+mappingKeyIndicator.keyPath.parent.memberName.replace('sender','Sender').replace('value','Value'); + this.mappingKeyName = mappingKeyName.replace('[', '_').replace(']', ''); + if (this.mappingKeyName.split('.').length > 2) this.mappingKeyName = this.mappingKeyName.replace('.', 'dot'); + + if (mappingKeyIndicator.keyPath.isStruct() && !(mappingKeyIndicator.keyPath.node.nodeType === 'Identifier' && !mappingKeyIndicator.keyPath.node.typeDescriptions.typeString.includes('struct '))) + this.mappingKeyTypeName = mappingKeyIndicator.keyPath.getStructDeclaration().name; + + if (!mappingKeyIndicator.keyPath.isMsg() && + (mappingKeyIndicator.keyPath.node.nodeType === 'Literal'|| mappingKeyIndicator.keyPath.isLocalStackVariable() || !mappingKeyIndicator.keyPath.isSecret || mappingKeyIndicator.keyPath.node.accessedSecretState)) + this.mappingKeyTypeName = 'local'; + + this.mappingName = this.indicators.name; + this.name = //this.name.replaceAll('.', 'dot').replaceAll('[', '_').replaceAll(']', ''); + `${this.mappingName}_${mappingKeyName}`.replaceAll('.', 'dot').replaceAll('[', '_').replaceAll(']', ''); + + if (mappingKeyIndicator.isStruct && mappingKeyIndicator.isParent) { + this.typeName = this.indicators.referencingPaths[0]?.getStructDeclaration()?.name; + this.structProperties = this.indicators.referencingPaths[0]?.getStructDeclaration()?.members.map(m => m.name) + } else if (mappingKeyIndicator.referencingPaths[0]?.node.typeDescriptions.typeString.includes('struct ')) { + // somewhat janky way to include referenced structs not separated by property + this.typeName = mappingKeyIndicator.referencingPaths[0]?.getStructDeclaration()?.name; + } + this.generateBoilerplate(); + } + initialise(indicators: StateVariableIndicator){ this.indicators = indicators; if (indicators.isMapping && indicators.mappingKeys) { for (let [mappingKeyName, mappingKeyIndicator] of Object.entries(indicators.mappingKeys)) { mappingKeyIndicator.isMapping = true; - this.assignIndicators(mappingKeyIndicator); - if(mappingKeyName == 'msg') - mappingKeyName = mappingKeyName+mappingKeyIndicator.keyPath.parent.memberName.replace('sender','Sender').replace('value','Value'); - this.mappingKeyName = mappingKeyName.replace('[', '_').replace(']', ''); - if (this.mappingKeyName.split('.').length > 2) this.mappingKeyName = this.mappingKeyName.replace('.', 'dot'); - - if (mappingKeyIndicator.keyPath.isStruct() && !(mappingKeyIndicator.keyPath.node.nodeType === 'Identifier' && !mappingKeyIndicator.keyPath.node.typeDescriptions.typeString.includes('struct '))) - this.mappingKeyTypeName = mappingKeyIndicator.keyPath.getStructDeclaration().name; - - if (!mappingKeyIndicator.keyPath.isMsg() && - (mappingKeyIndicator.keyPath.node.nodeType === 'Literal'|| mappingKeyIndicator.keyPath.isLocalStackVariable() || !mappingKeyIndicator.keyPath.isSecret)) - this.mappingKeyTypeName = 'local'; - - this.mappingName = this.indicators.name; - this.name = `${this.mappingName}_${mappingKeyName}`.replaceAll('.', 'dot').replace('[', '_').replace(']', ''); - - if (mappingKeyIndicator.isStruct && mappingKeyIndicator.isParent) { - this.typeName = indicators.referencingPaths[0]?.getStructDeclaration()?.name; - this.structProperties = indicators.referencingPaths[0]?.getStructDeclaration()?.members.map(m => m.name) - } else if (mappingKeyIndicator.referencingPaths[0]?.node.typeDescriptions.typeString.includes('struct ')) { - // somewhat janky way to include referenced structs not separated by property - this.typeName = mappingKeyIndicator.referencingPaths[0]?.getStructDeclaration()?.name; + if (mappingKeyIndicator.mappingKeys) { + for (let [outerMappingKeyName, outerMappingKeyIndicator] of Object.entries(mappingKeyIndicator.mappingKeys)) { + this.prepareMappingBoilerplate(mappingKeyName + '_' + outerMappingKeyName, outerMappingKeyIndicator); + } + } else { + this.prepareMappingBoilerplate(mappingKeyName, mappingKeyIndicator); } - this.generateBoilerplate(); } } else { if (indicators instanceof StateVariableIndicator && indicators.structProperties) { @@ -205,7 +217,14 @@ class BoilerplateGenerator { } refresh(mappingKeyName: string) { - const mappingKeyIndicator = this.indicators.mappingKeys[mappingKeyName]; + let mappingKeyIndicator = this.indicators.mappingKeys[mappingKeyName]; + if (mappingKeyName.includes(`/`)) { + // nested mapping + mappingKeyName.split(`/`).forEach(name => { + if (name) mappingKeyIndicator = mappingKeyIndicator ? mappingKeyIndicator.mappingKeys[name] : this.indicators.mappingKeys[name]; + }); + mappingKeyName = mappingKeyName.replace(`/`, `_`); + } this.assignIndicators(mappingKeyIndicator); this.mappingKeyName = mappingKeyName.replace('[', '_').replace(']', ''); if (this.mappingKeyName.split('.').length > 2) this.mappingKeyName.replace('.', 'dot'); @@ -235,8 +254,11 @@ class BoilerplateGenerator { } _addBP = (bpType: string, extraParams?: any) => { + const lastModifiedNodeName = this.thisIndicator.isModified ? this.thisIndicator.modifyingPaths[0].scope.getIdentifierMappingKeyName(this.thisIndicator.modifyingPaths[this.thisIndicator.modifyingPaths.length - 1].node) : this.thisIndicator.node.name; if (this.isPartitioned) { this.newCommitmentValue = collectIncrements(this).incrementsString; + } else if (lastModifiedNodeName !== this.thisIndicator.node.name) { + this.newCommitmentValue = lastModifiedNodeName; } this.bpSections.forEach(bpSection => { this[bpSection] = this[bpSection] @@ -355,7 +377,11 @@ class BoilerplateGenerator { mapping = (bpSection) => ({ mappingName: this.mappingName, - mappingKeyName: bpSection === 'postStatements' ? this.mappingKeyName : bpSection === 'parameters' ? this.mappingKeyName.split('.')[0] : this.mappingKeyName.replace('.', 'dot'), + mappingKeyName: this.thisIndicator?.keyPath?.isNestedMapping() + ? [this.thisIndicator.container.referencedKeyName, this.thisIndicator.referencedKeyName] + : bpSection === 'parameters' + ? this.mappingKeyName.split('.')[0] + : bpSection === 'postStatements' ? this.mappingKeyName : this.mappingKeyName.replace('.', 'dot'), }); /** Partitioned states need boilerplate for an incrementation/decrementation, because it's so weird and different from `a = a - b`. Whole states inherit directly from the AST, so don't need boilerplate here. */ diff --git a/src/boilerplate/circuit/zokrates/raw/BoilerplateGenerator.ts b/src/boilerplate/circuit/zokrates/raw/BoilerplateGenerator.ts index 40f5d02ed..d9c439cf8 100644 --- a/src/boilerplate/circuit/zokrates/raw/BoilerplateGenerator.ts +++ b/src/boilerplate/circuit/zokrates/raw/BoilerplateGenerator.ts @@ -291,7 +291,7 @@ class BoilerplateGenerator { postStatements({ name: x, isWhole, isNullified, newCommitmentValue, structProperties, typeName }): string[] { // if (!isWhole && !newCommitmentValue) throw new Error('PATH'); - const y = isWhole ? x : newCommitmentValue; + const y = isWhole && !newCommitmentValue ? x : newCommitmentValue; const lines: string[] = []; if (!isWhole && isNullified) { // decrement @@ -429,12 +429,13 @@ class BoilerplateGenerator { mapping = { importStatements(): string[] { return [ - `from "./common/hashes/mimc/altbn254/mimc2.zok" import main as mimc2`, + `from "./common/hashes/poseidon/poseidon.zok" import main as poseidon`, ]; }, parameters({ mappingKeyName: k, mappingKeyTypeName: t }): string[] { if (t === 'local') return []; + if (Array.isArray(k)) return k.map(name => `private ${t ? t : 'field'} ${name}`); return [ `private ${t ? t : 'field'} ${k}`, // must be a field, in case we need to do arithmetic on it. ]; @@ -452,7 +453,7 @@ class BoilerplateGenerator { // const x = `${m}_${k}`; return [ ` - field ${x}_stateVarId_field = mimc2([${m}_mappingId, ${k}])`, + field ${x}_stateVarId_field = poseidon([${m}_mappingId, ${k}])`, ]; }, }; diff --git a/src/boilerplate/orchestration/javascript/nodes/boilerplate-generator.ts b/src/boilerplate/orchestration/javascript/nodes/boilerplate-generator.ts index f19bff972..e91fe167b 100644 --- a/src/boilerplate/orchestration/javascript/nodes/boilerplate-generator.ts +++ b/src/boilerplate/orchestration/javascript/nodes/boilerplate-generator.ts @@ -99,8 +99,11 @@ export function buildPrivateStateNode(nodeType: string, fields: any = {}): any { } case 'CalculateCommitment': { const { id, increment, privateStateName, indicator = {} } = fields; + const lastModifiedNodeName = indicator.modifyingPaths[0].scope.getIdentifierMappingKeyName(indicator.modifyingPaths[indicator.modifyingPaths.length - 1].node); + const newCommitmentValue = indicator.isWhole && lastModifiedNodeName !== indicator.node.name ? lastModifiedNodeName : null; return { privateStateName, + newCommitmentValue, stateVarId: id, increment, isWhole: indicator.isWhole, @@ -128,6 +131,7 @@ export function buildPrivateStateNode(nodeType: string, fields: any = {}): any { indicator = {}, } = fields; const structProperties = !indicator.isStruct ? null : indicator.isAccessed ? indicator.referencingPaths[0]?.getStructDeclaration()?.members.map(m => m.name) : Object.keys(indicator.structProperties); + const originalMappingKeyName = id[2] ? [ indicator.container.referencedKeyName, indicator.referencedKeyName] : id[1] ? [ indicator.referencedKeyName ] : []; return { privateStateName, stateVarId: id, @@ -142,6 +146,9 @@ export function buildPrivateStateNode(nodeType: string, fields: any = {}): any { isPartitioned: indicator.isPartitioned, isOwned: indicator.isOwned, mappingOwnershipType: indicator.mappingOwnershipType, + originalMappingKeyName: indicator.isMapping + ? originalMappingKeyName + : null, initialisationRequired: indicator.initialisationRequired, encryptionRequired: indicator.encryptionRequired, owner: indicator.isOwned diff --git a/src/boilerplate/orchestration/javascript/raw/boilerplate-generator.ts b/src/boilerplate/orchestration/javascript/raw/boilerplate-generator.ts index 4bcf7f1ad..d3afa8670 100644 --- a/src/boilerplate/orchestration/javascript/raw/boilerplate-generator.ts +++ b/src/boilerplate/orchestration/javascript/raw/boilerplate-generator.ts @@ -362,7 +362,7 @@ class BoilerplateGenerator { calculateCommitment = { - postStatements({ stateName, stateType, structProperties }): string[] { + postStatements({ stateName, newCommitmentValue, stateType, structProperties }): string[] { // once per state switch (stateType) { case 'increment': @@ -383,9 +383,10 @@ class BoilerplateGenerator { \nlet ${stateName}_2_newCommitment = poseidonHash([BigInt(${stateName}_stateVarId), ${structProperties ? `...${stateName}_change.hex(32).map(v => BigInt(v))` : `BigInt(${stateName}_change.hex(32))`}, BigInt(publicKey.hex(32)), BigInt(${stateName}_2_newSalt.hex(32))],); \n${stateName}_2_newCommitment = generalise(${stateName}_2_newCommitment.hex(32)); // truncate`]; case 'whole': - const value = structProperties ? structProperties.map(p => `BigInt(${stateName}.${p}.hex(32))`) :` BigInt(${stateName}.hex(32))`; + const name = newCommitmentValue || stateName; + const value = structProperties ? structProperties.map(p => `BigInt(${name}.${p}.hex(32))`) :` BigInt(${name}.hex(32))`; return [` - \n ${structProperties ? structProperties.map(p => `\n${stateName}.${p} = ${stateName}.${p} ? ${stateName}.${p} : ${stateName}_prev.${p};`).join('') : ''} + \n ${structProperties ? structProperties.map(p => `\n${name}.${p} = ${name}.${p} ? ${name}.${p} : ${stateName}_prev.${p};`).join('') : ''} \nconst ${stateName}_newSalt = generalise(utils.randomHex(31)); \nlet ${stateName}_newCommitment = poseidonHash([BigInt(${stateName}_stateVarId), ${value}, BigInt(${stateName}_newOwnerPublicKey.hex(32)), BigInt(${stateName}_newSalt.hex(32))],); \n${stateName}_newCommitment = generalise(${stateName}_newCommitment.hex(32)); // truncate`]; diff --git a/src/boilerplate/orchestration/javascript/raw/toOrchestration.ts b/src/boilerplate/orchestration/javascript/raw/toOrchestration.ts index c85f3cc20..731fa18f5 100644 --- a/src/boilerplate/orchestration/javascript/raw/toOrchestration.ts +++ b/src/boilerplate/orchestration/javascript/raw/toOrchestration.ts @@ -2,47 +2,44 @@ import OrchestrationBP from './boilerplate-generator.js'; - +let msgSenderAdded = false; const stateVariableIds = (node: any) => { - const {privateStateName, stateNode} = node; + const { privateStateName, stateNode } = node; const stateVarIds: string[] = []; // state variable ids // if not a mapping, use singular unique id (if mapping, stateVarId is an array) if (!stateNode.stateVarId[1]) { stateVarIds.push( - `\nconst ${privateStateName}_stateVarId = generalise(${stateNode.stateVarId}).hex(32);`, + `\nconst ${ privateStateName }_stateVarId = generalise(${ stateNode.stateVarId }).hex(32);`, ); } else { // if is a mapping... stateVarIds.push( - `\nlet ${privateStateName}_stateVarId = ${stateNode.stateVarId[0]};`, + `\nlet ${ privateStateName }_stateVarId = ${ stateNode.stateVarId[0] };`, ); - // ... and the mapping key is not msg.sender, but is a parameter - if ( - privateStateName.includes(stateNode.stateVarId[1].replaceAll('.', 'dot')) && - stateNode.stateVarId[1] !== 'msg' - ) { - if (+stateNode.stateVarId[1] || stateNode.stateVarId[1] === '0') { - stateVarIds.push( - `\nconst ${privateStateName}_stateVarId_key = generalise(${stateNode.stateVarId[1]});`, - ); - } else { - stateVarIds.push( - `\nconst ${privateStateName}_stateVarId_key = ${stateNode.stateVarId[1]};`, - ); + + let innerArgs: string[] = []; + stateNode.stateVarId.forEach((id, index) => { + if (index !== 0) { + innerArgs.push(`${ +id || id === '0' ? `${ id }` : `${ id }.hex(32)` }`); } - } + }); // ... and the mapping key is msg, and the caller of the fn has the msg key if ( - stateNode.stateVarId[1] === 'msg' && - privateStateName.includes('msg') + stateNode.stateVarId.includes('msg') && + privateStateName.includes('msg') && !msgSenderAdded ) { stateVarIds.push( - `\nconst ${privateStateName}_stateVarId_key = generalise(config.web3.options.defaultAccount); // emulates msg.sender`, + `\nconst msgSender = generalise(config.web3.options.defaultAccount); // emulates msg.sender`, ); + const index = stateNode.stateVarId.indexOf('msg'); + stateNode.stateVarId[index] = `msgSender`; } stateVarIds.push( - `\n${privateStateName}_stateVarId = generalise(utils.mimcHash([generalise(${privateStateName}_stateVarId).bigInt, ${privateStateName}_stateVarId_key.bigInt], 'ALT_BN_254')).hex(32);`, + `\n${ privateStateName }_stateVarId = poseidonHash([ + BigInt(${ privateStateName }_stateVarId), + ${ innerArgs.map(a => `BigInt(${ a })`) } + ]).hex(32);`, ); } return stateVarIds; @@ -89,8 +86,8 @@ export const sendTransactionBoilerplate = (node: any) => { // increment output[3].push(`${privateStateName}_newCommitment.integer`); if (stateNode.encryptionRequired) { - output[4].push(`${privateStateName}_cipherText`); - output[5].push(`${privateStateName}_encKey`); + output[4].push(`${ privateStateName }_cipherText`); + output[5].push(`${ privateStateName }_encKey`); } break; } @@ -131,36 +128,45 @@ export const generateProofBoilerplate = (node: any) => { if (stateNode.encryptionRequired) { stateNode.structProperties ? cipherTextLength.push(stateNode.structProperties.length + 2) : cipherTextLength.push(3); enc[0] ??= []; - enc[0].push(`const ${stateName}_cipherText = res.inputs.slice(START_SLICE, END_SLICE).map(e => generalise(e).integer);`); + enc[0].push(`const ${ stateName }_cipherText = res.inputs.slice(START_SLICE, END_SLICE).map(e => generalise(e).integer);`); enc[1] ??= []; - enc[1].push(`const ${stateName}_encKey = res.inputs.slice(START_SLICE END_SLICE).map(e => generalise(e).integer);`); + enc[1].push(`const ${ stateName }_encKey = res.inputs.slice(START_SLICE END_SLICE).map(e => generalise(e).integer);`); } + + let stateVarIdLines: string[] = []; const parameters: string[] = []; - // we include the state variable key (mapping key) if its not a param (we include params separately) - const msgSenderParamAndMappingKey = stateNode.isMapping && (node.parameters.includes('msgSender') || output.join().includes('_msg_stateVarId_key.integer')) && stateNode.stateVarId[1] === 'msg'; - const msgValueParamAndMappingKey = stateNode.isMapping && (node.parameters.includes('msgValue') || output.join().includes('_msg_stateVarId_key.integer')) && stateNode.stateVarId[1] === 'msg'; - - const constantMappingKey = stateNode.isMapping && (+stateNode.stateVarId[1] || stateNode.stateVarId[1] === '0'); - const stateVarIdLines = - stateNode.isMapping && !node.parameters.includes(stateNode.stateVarId[1]) && !msgSenderParamAndMappingKey && !msgValueParamAndMappingKey && !constantMappingKey - ? [`\n\t\t\t\t\t\t\t\t${stateName}_stateVarId_key.integer,`] - : []; + + if (stateNode.stateVarId[0]) stateNode.stateVarId.forEach((svid, index) => { + if (index !== 0) { + // we include the state variable key (mapping key) if its not a param (we include params separately) + const msgSenderParamAndMappingKey = stateNode.isMapping && (node.parameters.includes('msgSender') || output.join().includes('msgSender.integer')) && svid.includes('msg'); + const msgValueParamAndMappingKey = stateNode.isMapping && (node.parameters.includes('msgValue') || output.join().includes('msgValue.integer')) && svid.includes('msg'); + const constantMappingKey = stateNode.isMapping && (+svid || svid === '0'); + if (!node.parameters.includes(stateNode.originalMappingKeyName[index - 1] || svid) + && !msgSenderParamAndMappingKey + && !msgValueParamAndMappingKey + && !constantMappingKey + ) { + stateVarIdLines.push(`\n\t\t\t\t\t\t\t\t${ stateNode.originalMappingKeyName[index - 1] }.integer,`); + } + } + }); // we add any extra params the circuit needs node.parameters .filter( (para: string) => !privateStateNames.includes(para) && ( - !output.join().includes(`${para}.integer`) && !output.join().includes('msgValue')), + !output.join().includes(`${ para }.integer`) && !output.join().includes('msgValue')), ) ?.forEach((param: string) => { if (param == 'msgSender') { - parameters.unshift(`\t${param}.integer,`); - } + parameters.unshift(`\t${ param }.integer,`); + } else if (param == 'msgValue') { - parameters.unshift(`\t${param},`); + parameters.unshift(`\t${ param },`); } else { - parameters.push(`\t${param}.integer,`); + parameters.push(`\t${ param }.integer,`); } }); @@ -196,12 +202,12 @@ export const generateProofBoilerplate = (node: any) => { stateNode.increment.forEach((inc: any) => { // +inc.name tries to convert into a number - we don't want to add constants here if ( - !output.join().includes(`\t${inc.name}.integer`) && - !parameters.includes(`\t${inc.name}.integer,`) && + !output.join().includes(`\t${ inc.name }.integer`) && + !parameters.includes(`\t${ inc.name }.integer,`) && !privateStateNames.includes(inc.name) && !inc.accessed && !+inc.name ) - output.push(`\n\t\t\t\t\t\t\t\t${inc.name}.integer`); + output.push(`\n\t\t\t\t\t\t\t\t${ inc.name }.integer`); }); output.push( Orchestrationbp.generateProof.parameters({ @@ -228,14 +234,14 @@ export const generateProofBoilerplate = (node: any) => { if (stateNode.structProperties) stateNode.increment = Object.values(stateNode.increment).flat(Infinity); stateNode.increment.forEach((inc: any) => { if ( - !output.join().includes(`\t${inc.name}.integer`) && - !parameters.includes(`\t${inc.name}.integer,`) && !inc.accessed && + !output.join().includes(`\t${ inc.name }.integer`) && + !parameters.includes(`\t${ inc.name }.integer,`) && !inc.accessed && !+inc.name ) - output.push(`\n\t\t\t\t\t\t\t\t${inc.name}.integer`); + output.push(`\n\t\t\t\t\t\t\t\t${ inc.name }.integer`); }); output.push( - Orchestrationbp.generateProof.parameters( { + Orchestrationbp.generateProof.parameters({ stateName, stateType: 'increment', stateVarIds: stateVarIdLines, @@ -256,7 +262,7 @@ export const generateProofBoilerplate = (node: any) => { } // we now want to go backwards and calculate where our cipherText is let start = 0; - for (let i = cipherTextLength.length -1; i >= 0; i--) { + for (let i = cipherTextLength.length - 1; i >= 0; i--) { // extract enc key enc[1][i] = start === 0 ? enc[1][i].replace('END_SLICE', '') : enc[1][i].replace('END_SLICE', ', ' + start); enc[1][i] = enc[1][i].replace('START_SLICE', start - 2); @@ -279,11 +285,11 @@ export const preimageBoilerPlate = (node: any) => { for ([privateStateName, stateNode] of Object.entries(node.privateStates)) { const stateVarIds = stateVariableIds({ privateStateName, stateNode }); const initialiseParams: string[] = []; - const preimageParams:string[] = []; + const preimageParams: string[] = []; if (stateNode.accessedOnly) { output.push( Orchestrationbp.readPreimage.postStatements({ - stateName:privateStateName, + stateName: privateStateName, contractName: node.contractName, stateType: 'whole', mappingName: null, @@ -300,45 +306,45 @@ export const preimageBoilerPlate = (node: any) => { continue; } - initialiseParams.push(`\nlet ${privateStateName}_prev = generalise(0);`); - preimageParams.push(`\t${privateStateName}: 0,`); + initialiseParams.push(`\nlet ${ privateStateName }_prev = generalise(0);`); + preimageParams.push(`\t${ privateStateName }: 0,`); // ownership (PK in commitment) const newOwner = stateNode.isOwned ? stateNode.owner : null; let newOwnerStatment: string; switch (newOwner) { case null: - newOwnerStatment = `_${privateStateName}_newOwnerPublicKey === 0 ? publicKey : ${privateStateName}_newOwnerPublicKey;`; + newOwnerStatment = `_${ privateStateName }_newOwnerPublicKey === 0 ? publicKey : ${ privateStateName }_newOwnerPublicKey;`; break; case 'msg': if (privateStateName.includes('msg')) { newOwnerStatment = `publicKey;`; } else if (stateNode.mappingOwnershipType === 'key') { // the stateVarId[1] is the mapping key - newOwnerStatment = `generalise(await instance.methods.zkpPublicKeys(${stateNode.stateVarId[1]}.hex(20)).call()); // address should be registered`; + newOwnerStatment = `generalise(await instance.methods.zkpPublicKeys(${ stateNode.stateVarId[1] }.hex(20)).call()); // address should be registered`; } else if (stateNode.mappingOwnershipType === 'value') { // TODO test below // if the private state is an address (as here) its still in eth form - we need to convert - newOwnerStatment = `await instance.methods.zkpPublicKeys(${privateStateName}.hex(20)).call(); - \nif (${privateStateName}_newOwnerPublicKey === 0) { + newOwnerStatment = `await instance.methods.zkpPublicKeys(${ privateStateName }.hex(20)).call(); + \nif (${ privateStateName }_newOwnerPublicKey === 0) { console.log('WARNING: Public key for given eth address not found - reverting to your public key'); - ${privateStateName}_newOwnerPublicKey = publicKey; + ${ privateStateName }_newOwnerPublicKey = publicKey; } - \n${privateStateName}_newOwnerPublicKey = generalise(${privateStateName}_newOwnerPublicKey);`; + \n${ privateStateName }_newOwnerPublicKey = generalise(${ privateStateName }_newOwnerPublicKey);`; } else { - newOwnerStatment = `_${privateStateName}_newOwnerPublicKey === 0 ? publicKey : ${privateStateName}_newOwnerPublicKey;`; + newOwnerStatment = `_${ privateStateName }_newOwnerPublicKey === 0 ? publicKey : ${ privateStateName }_newOwnerPublicKey;`; } break; default: // TODO - this is the case where the owner is an admin (state var) // we have to let the user submit the key and check it in the contract if (!stateNode.ownerIsSecret && !stateNode.ownerIsParam) { - newOwnerStatment = `_${privateStateName}_newOwnerPublicKey === 0 ? generalise(await instance.methods.zkpPublicKeys(await instance.methods.${newOwner}().call()).call()) : ${privateStateName}_newOwnerPublicKey;`; + newOwnerStatment = `_${ privateStateName }_newOwnerPublicKey === 0 ? generalise(await instance.methods.zkpPublicKeys(await instance.methods.${ newOwner }().call()).call()) : ${ privateStateName }_newOwnerPublicKey;`; } else if (stateNode.ownerIsParam && newOwner) { - newOwnerStatment = `_${privateStateName}_newOwnerPublicKey === 0 ? ${newOwner} : ${privateStateName}_newOwnerPublicKey;`; + newOwnerStatment = `_${ privateStateName }_newOwnerPublicKey === 0 ? ${ newOwner } : ${ privateStateName }_newOwnerPublicKey;`; } else { // is secret - we just use the users to avoid revealing the secret owner - newOwnerStatment = `_${privateStateName}_newOwnerPublicKey === 0 ? publicKey : ${privateStateName}_newOwnerPublicKey;` + newOwnerStatment = `_${ privateStateName }_newOwnerPublicKey === 0 ? publicKey : ${ privateStateName }_newOwnerPublicKey;` // BELOW reveals the secret owner as we check the public key in the contract // `_${privateStateName}_newOwnerPublicKey === 0 ? generalise(await instance.methods.zkpPublicKeys(${newOwner}.hex(20)).call()) : ${privateStateName}_newOwnerPublicKey;` @@ -377,7 +383,7 @@ export const preimageBoilerPlate = (node: any) => { stateType: 'decrement', mappingName: stateNode.mappingName || privateStateName, mappingKey: stateNode.mappingKey - ? `[${privateStateName}_stateVarId_key.integer]` + ? `[${ privateStateName }_stateVarId_key.integer]` : ``, increment: stateNode.increment, structProperties: stateNode.structProperties, @@ -393,7 +399,7 @@ export const preimageBoilerPlate = (node: any) => { default: // increment output.push( - Orchestrationbp.readPreimage.postStatements({ + Orchestrationbp.readPreimage.postStatements({ stateName: privateStateName, contractName: node.contractName, stateType: 'increment', @@ -422,7 +428,7 @@ export const preimageBoilerPlate = (node: any) => { export const OrchestrationCodeBoilerPlate: any = (node: any) => { const lines: any[] = []; - const params:any[] = []; + const params: any[] = []; const states: string[] = []; const rtnparams: string[] = []; let stateName: string; @@ -430,78 +436,83 @@ export const OrchestrationCodeBoilerPlate: any = (node: any) => { switch (node.nodeType) { case 'Imports': - return { statements: Orchestrationbp.generateProof.import() } + return { statements: Orchestrationbp.generateProof.import() } case 'FunctionDefinition': // the main function + // reset for each new fn + msgSenderAdded = false; if (node.name !== 'cnstrctr') lines.push( `\n\n// Initialisation of variables: - \nconst instance = await getContractInstance('${node.contractName}'); - \nconst contractAddr = await getContractAddress('${node.contractName}'); `, + \nconst instance = await getContractInstance('${ node.contractName }'); + \nconst contractAddr = await getContractAddress('${ node.contractName }'); `, ); - if (node.msgSenderParam) + if (node.msgSenderParam) { lines.push(` \nconst msgSender = generalise(config.web3.options.defaultAccount);`); + msgSenderAdded = true; + } + if (node.msgValueParam) lines.push(` \nconst msgValue = 1;`); - else - lines.push(` - \nconst msgValue = 0;`); + // else + // lines.push(` + // \nconst msgValue = 0;`); node.inputParameters.forEach((param: string) => { - lines.push(`\nconst ${param} = generalise(_${param});`); - params.push(`_${param}`); + lines.push(`\nconst ${ param } = generalise(_${ param });`); + params.push(`_${ param }`); }); node.parameters.modifiedStateVariables.forEach((param: any) => { - states.push(`_${param.name}_newOwnerPublicKey = 0`); + states.push(`_${ param.name }_newOwnerPublicKey = 0`); lines.push( - `\nlet ${param.name}_newOwnerPublicKey = generalise(_${param.name}_newOwnerPublicKey);`, + `\nlet ${ param.name }_newOwnerPublicKey = generalise(_${ param.name }_newOwnerPublicKey);`, ); }); if (node.decrementsSecretState) { node.decrementedSecretStates.forEach((decrementedState: string) => { - states.push(` _${decrementedState}_0_oldCommitment = 0`); - states.push(` _${decrementedState}_1_oldCommitment = 0`); + states.push(` _${ decrementedState }_0_oldCommitment = 0`); + states.push(` _${ decrementedState }_1_oldCommitment = 0`); }); } - node.returnParameters.forEach( (param, index) => { - if(param === 'true') - rtnparams?.push('bool: bool'); - else if(param?.includes('Commitment')) - rtnparams?.push( ` ${param} : ${param}.integer `); - else - rtnparams.push(` ${param} :${param}.integer`); - }); + node.returnParameters.forEach((param, index) => { + if (param === 'true') + rtnparams?.push('bool: bool'); + else if (param?.includes('Commitment')) + rtnparams?.push(` ${ param } : ${ param }.integer `); + else + rtnparams.push(` ${ param } :${ param }.integer`); + }); if (params) params[params.length - 1] += `,`; if (node.name === 'cnstrctr') return { signature: [ - `\nexport default async function ${node.name}(${params} ${states}) {`, + `\nexport default async function ${ node.name }(${ params } ${ states }) {`, `\nprocess.exit(0); \n}`, ], statements: lines, }; - if(rtnparams.length == 0) { - return { - signature: [ - `\nexport default async function ${node.name}(${params} ${states}) {`, - `\n return { tx, encEvent }; + if (rtnparams.length == 0) { + return { + signature: [ + `\nexport default async function ${ node.name }(${ params } ${ states }) {`, + `\n return { tx, encEvent }; \n}`, - ], - statements: lines, - }; - } + ], + statements: lines, + }; + } - if(rtnparams.includes('bool: bool')) { + if (rtnparams.includes('bool: bool')) { return { signature: [ - `\nexport default async function ${node.name}(${params} ${states}) {`, - `\n const bool = true; \n return { tx, encEvent, ${rtnparams} }; + `\nexport default async function ${ node.name }(${ params } ${ states }) {`, + `\n const bool = true; \n return { tx, encEvent, ${ rtnparams } }; \n}`, ], statements: lines, @@ -510,8 +521,8 @@ export const OrchestrationCodeBoilerPlate: any = (node: any) => { return { signature: [ - `\nexport default async function ${node.name}(${params} ${states}) {`, - `\nreturn { tx, encEvent, ${rtnparams} }; + `\nexport default async function ${ node.name }(${ params } ${ states }) {`, + `\nreturn { tx, encEvent, ${ rtnparams } }; \n}`, ], statements: lines, @@ -523,7 +534,7 @@ export const OrchestrationCodeBoilerPlate: any = (node: any) => { switch (stateNode.mappingKey) { case 'msg': // msg.sender => key is _newOwnerPublicKey - mappingKey = `[${stateName}_stateVarId_key.integer]`; + mappingKey = `[${ stateName }_stateVarId_key.integer]`; break; case null: case undefined: @@ -533,17 +544,17 @@ export const OrchestrationCodeBoilerPlate: any = (node: any) => { default: if (+stateNode.mappingKey || stateNode.mappingKey === '0') { // we have a constant number - mappingKey = `[${stateNode.mappingKey}]`; + mappingKey = `[${ stateNode.mappingKey }]`; } else { // any other => a param or accessed var - mappingKey = `[${stateNode.mappingKey}.integer]`; + mappingKey = `[${ stateNode.mappingKey }.integer]`; } } lines.push( - Orchestrationbp.initialisePreimage.preStatements( { + Orchestrationbp.initialisePreimage.preStatements({ stateName, accessedOnly: stateNode.accessedOnly, - stateVarIds: stateVariableIds({ privateStateName: stateName, stateNode}), + stateVarIds: stateVariableIds({ privateStateName: stateName, stateNode }), mappingKey, mappingName: stateNode.mappingName || stateName, structProperties: stateNode.structProperties @@ -558,9 +569,9 @@ export const OrchestrationCodeBoilerPlate: any = (node: any) => { states[0] = node.onChainKeyRegistry ? `true` : `false`; return { statements: [ - `${Orchestrationbp.initialiseKeys.postStatements( - node.contractName, - states[0], + `${ Orchestrationbp.initialiseKeys.postStatements( + node.contractName, + states[0], ) }`, ], }; @@ -568,7 +579,7 @@ export const OrchestrationCodeBoilerPlate: any = (node: any) => { case 'ReadPreimage': lines[0] = preimageBoilerPlate(node); return { - statements: [`${params.join('\n')}`, lines[0].join('\n')], + statements: [`${ params.join('\n') }`, lines[0].join('\n')], }; case 'WritePreimage': @@ -584,7 +595,7 @@ export const OrchestrationCodeBoilerPlate: any = (node: any) => { stateType: 'decrement', mappingName: stateNode.mappingName || stateName, mappingKey: stateNode.mappingKey - ? `${stateName}_stateVarId_key.integer` + ? `${ stateName }_stateVarId_key.integer` : ``, burnedOnly: false, structProperties: stateNode.structProperties, @@ -594,12 +605,12 @@ export const OrchestrationCodeBoilerPlate: any = (node: any) => { case false: default: lines.push( - Orchestrationbp.writePreimage.postStatements({ + Orchestrationbp.writePreimage.postStatements({ stateName, stateType: 'increment', - mappingName:stateNode.mappingName || stateName, + mappingName: stateNode.mappingName || stateName, mappingKey: stateNode.mappingKey - ? `${stateName}_stateVarId_key.integer` + ? `${ stateName }_stateVarId_key.integer` : ``, burnedOnly: false, structProperties: stateNode.structProperties, @@ -611,12 +622,12 @@ export const OrchestrationCodeBoilerPlate: any = (node: any) => { case false: default: lines.push( - Orchestrationbp.writePreimage.postStatements({ + Orchestrationbp.writePreimage.postStatements({ stateName, stateType: 'whole', mappingName: stateNode.mappingName || stateName, mappingKey: stateNode.mappingKey - ? `${stateName}_stateVarId_key.integer` + ? `${ stateName }_stateVarId_key.integer` : ``, burnedOnly: stateNode.burnedOnly, structProperties: stateNode.structProperties, @@ -635,11 +646,11 @@ export const OrchestrationCodeBoilerPlate: any = (node: any) => { for ([stateName, stateNode] of Object.entries(node.privateStates)) { if (node.isConstructor) { lines.push([` - const ${stateName}_index = generalise(0); - const ${stateName}_root = generalise(0); - const ${stateName}_path = generalise(new Array(32).fill(0)).all;\n + const ${ stateName }_index = generalise(0); + const ${ stateName }_root = generalise(0); + const ${ stateName }_path = generalise(new Array(32).fill(0)).all;\n `]); - continue; + continue; } if (stateNode.isPartitioned) { lines.push( @@ -740,8 +751,9 @@ export const OrchestrationCodeBoilerPlate: any = (node: any) => { case undefined: case false: lines.push( - Orchestrationbp.calculateCommitment.postStatements( { + Orchestrationbp.calculateCommitment.postStatements({ stateName, + newCommitmentValue: stateNode.newCommitmentValue, stateType: 'whole', structProperties: stateNode.structProperties, })); @@ -753,8 +765,9 @@ export const OrchestrationCodeBoilerPlate: any = (node: any) => { case true: // decrement lines.push( - Orchestrationbp.calculateCommitment.postStatements( { + Orchestrationbp.calculateCommitment.postStatements({ stateName, + newCommitmentValue: null, stateType: 'decrement', structProperties: stateNode.structProperties, })); @@ -764,8 +777,9 @@ export const OrchestrationCodeBoilerPlate: any = (node: any) => { default: // increment lines.push( - Orchestrationbp.calculateCommitment.postStatements( { + Orchestrationbp.calculateCommitment.postStatements({ stateName, + newCommitmentValue: null, stateType: 'increment', structProperties: stateNode.structProperties, })); @@ -778,17 +792,17 @@ export const OrchestrationCodeBoilerPlate: any = (node: any) => { }; case 'GenerateProof': - [ lines[0], params[0] ] = generateProofBoilerplate(node); + [lines[0], params[0]] = generateProofBoilerplate(node); return { statements: [ `\n\n// Call Zokrates to generate the proof: \nconst allInputs = [`, - `${lines[0]}`, - `\nconst res = await generateProof('${node.circuitName}', allInputs);`, + `${ lines[0] }`, + `\nconst res = await generateProof('${ node.circuitName }', allInputs);`, `\nconst proof = generalise(Object.values(res.proof).flat(Infinity)) .map(coeff => coeff.integer) .flat(Infinity);`, - `${params[0].flat(Infinity).join('\n')}` + `${ params[0].flat(Infinity).join('\n') }` ], }; @@ -796,9 +810,9 @@ export const OrchestrationCodeBoilerPlate: any = (node: any) => { if (node.publicInputs[0]) { node.publicInputs.forEach((input: any) => { if (input.properties) { - lines.push(`[${input.properties.map(p => `${input.name}.${p}.integer`).join(',')}]`) + lines.push(`[${ input.properties.map(p => `${ input.name }.${ p }.integer`).join(',') }]`) } else - lines.push(`${input}.integer`); + lines.push(`${ input }.integer`); }); lines[lines.length - 1] += `, `; } @@ -815,8 +829,6 @@ export const OrchestrationCodeBoilerPlate: any = (node: any) => { if (params[0][3][0]) params[0][3] = `[${params[0][3]}],`; // commitments - array if (params[0][4][0]) params[0][4] = `[${params[0][4]}],`; // cipherText - array of arrays if (params[0][5][0]) params[0][5] = `[${params[0][5]}],`; // cipherText - array of arrays - - if (node.functionName === 'cnstrctr') return { statements: [ `\n\n// Save transaction for the constructor: @@ -828,7 +840,7 @@ export const OrchestrationCodeBoilerPlate: any = (node: any) => { statements: [ `\n\n// Send transaction to the blockchain: \nconst txData = await instance.methods - .${node.functionName}(${lines}${params[0][0]} ${params[0][1]} ${params[0][2]} ${params[0][3]} ${params[0][4]} ${params[0][5]} proof).encodeABI(); + .${ node.functionName }(${ lines }${ params[0][0] } ${ params[0][1] } ${ params[0][2] } ${ params[0][3] } ${ params[0][4] } ${ params[0][5] } proof).encodeABI(); \n let txParams = { from: config.web3.options.defaultAccount, to: contractAddr, diff --git a/src/codeGenerators/orchestration/nodejs/toOrchestration.ts b/src/codeGenerators/orchestration/nodejs/toOrchestration.ts index 3a36b0cd4..67a83e0ba 100644 --- a/src/codeGenerators/orchestration/nodejs/toOrchestration.ts +++ b/src/codeGenerators/orchestration/nodejs/toOrchestration.ts @@ -76,7 +76,7 @@ export default function codeGenerator(node: any, options: any = {}): any { if (!node.interactsWithSecret) return `\n// non-secret line would go here but has been filtered out`; if (node.initialValue?.nodeType === 'Assignment') { - if (node.declarations[0].isAccessed && node.declarations[0].isSecret) { + if (node.declarations[0].isAccessed && node.declarations[0].isSecret && node.declarations[0].declarationType === 'state') { return `${getAccessedValue( node.declarations[0].name, )}\n${codeGenerator(node.initialValue)};`; @@ -134,7 +134,7 @@ export default function codeGenerator(node: any, options: any = {}): any { } return `${codeGenerator(node.leftHandSide, { lhs: true })} ${ node.operator - } ${codeGenerator(node.rightHandSide)}`; + } ${codeGenerator(node.rightHandSide, { rhs: true })}`; case 'BinaryOperation': return `${codeGenerator(node.leftExpression, { lhs: options.condition })} ${ @@ -192,8 +192,10 @@ export default function codeGenerator(node: any, options: any = {}): any { return `${codeGenerator(node.arguments)}`; case 'UnaryOperation': + if (options?.rhs) return `generalise(parseInt(${node.subExpression.name}.integer, 10) ${node.operator.includes('+') ? `+ 1` : `- 1`})`; // ++ or -- on a parseInt() does not work return `generalise(${node.subExpression.name}.integer${node.operator})`; + // return `generalise(parseInt(${node.subExpression.name}.integer, 10) ${node.operator.includes('+') ? ``})`; case 'Literal': return node.value; diff --git a/src/transformers/visitors/checks/interactsWithSecretVisitor.ts b/src/transformers/visitors/checks/interactsWithSecretVisitor.ts index 3bea6a785..31b052ba0 100644 --- a/src/transformers/visitors/checks/interactsWithSecretVisitor.ts +++ b/src/transformers/visitors/checks/interactsWithSecretVisitor.ts @@ -1,4 +1,5 @@ /* eslint-disable no-param-reassign, no-unused-vars */ +import { ZKPError } from '../../../error/errors.js'; import NodePath from '../../../traverse/NodePath.js'; @@ -72,4 +73,16 @@ export default { } }, }, + + MemberAccess: { + exit(path: NodePath) { + const { node } = path; + const { expression, memberName } = node; + if (memberName === 'length' && (path.containsSecret || path.scope.getReferencedIndicator(expression)?.interactsWithSecret)) { + throw new ZKPError(`We can't loop a dynamic number of times when secret states are involved due to the fixed constraint system required. + If the .length here is constant, please use the constant value instead.`, node); + } + + } + }, }; diff --git a/src/transformers/visitors/common.ts b/src/transformers/visitors/common.ts index 65416f3b5..e594e0259 100644 --- a/src/transformers/visitors/common.ts +++ b/src/transformers/visitors/common.ts @@ -56,12 +56,14 @@ export const interactsWithSecretVisitor = (thisPath: NodePath, thisState: any) = }; export const getIndexAccessName = (node: any) => { - if (node.nodeType == 'MemberAccess') return `${node.expression.name}.${node.memberName}`; + if (node.nodeType == 'MemberAccess') return `${node.expression.name || getIndexAccessName(node.expression)}.${node.memberName}`; if (node.nodeType == 'IndexAccess') { + let baseName = node.baseExpression.name; const mappingKeyName = NodePath.getPath(node).scope.getMappingKeyName(node); + if (node.baseExpression.nodeType === 'IndexAccess') baseName = getIndexAccessName(node.baseExpression); if(mappingKeyName == 'msg') - return `${node.baseExpression.name}_${(mappingKeyName).replaceAll('.', 'dot').replace('[', '_').replace(']', '')}${node.indexExpression.memberName.replace('sender','Sender').replace('value','Value')}`; - return `${node.baseExpression.name}_${(mappingKeyName).replaceAll('.', 'dot').replace('[', '_').replace(']', '')}`; + return `${baseName}_${(mappingKeyName).replaceAll('.', 'dot').replace('[', '_').replace(']', '')}${node.indexExpression.memberName.replace('sender','Sender').replace('value','Value')}`; + return `${baseName}_${(mappingKeyName).replaceAll('.', 'dot').replace('[', '_').replace(']', '')}`; } return null; } diff --git a/src/transformers/visitors/orchestrationInternalFunctionCallVisitor.ts b/src/transformers/visitors/orchestrationInternalFunctionCallVisitor.ts index 5507da153..a347c08ac 100644 --- a/src/transformers/visitors/orchestrationInternalFunctionCallVisitor.ts +++ b/src/transformers/visitors/orchestrationInternalFunctionCallVisitor.ts @@ -95,7 +95,7 @@ const internalCallVisitor = { }); for(const [index, oldStateName] of oldStateArray.entries()) { node.initialValue.leftHandSide.name = node.initialValue.leftHandSide.name.replace('_'+oldStateName, '_'+ state.newStateArray[index]); - node.initialValue.rightHandSide.name = node.initialValue.rightHandSide.name.replace(oldStateName, state.newStateArray[index]); + node.initialValue.rightHandSide.name = node.initialValue.rightHandSide?.name?.replace(oldStateName, state.newStateArray[index]); } } if(node.nodeType === 'Assignment'){ diff --git a/src/transformers/visitors/ownership/errorChecksVisitor.ts b/src/transformers/visitors/ownership/errorChecksVisitor.ts index d589b17d4..2ffc55bc3 100644 --- a/src/transformers/visitors/ownership/errorChecksVisitor.ts +++ b/src/transformers/visitors/ownership/errorChecksVisitor.ts @@ -119,15 +119,15 @@ export default { let idInLoopExpression = []; traverseNodesFast(loopExpression, miniIdVisitor, idInLoopExpression); - const miniMappingVisitor = (thisNode: any) => { - if (thisNode.nodeType !== 'IndexAccess') return; - const key = path.getMappingKeyIdentifier(thisNode); - if (!key.referencedDeclaration) return; - if (idInLoopExpression.includes(key.referencedDeclaration)) - throw new ZKPError(`The mapping ${thisNode.baseExpression.name} is being accessed by the loop expression ${key.name}, which means we are editing as many secret states as there are loop iterations. This is not currently supported due to the computation involved.`, thisNode); - }; - - traverseNodesFast(body, miniMappingVisitor); + // const miniMappingVisitor = (thisNode: any) => { + // if (thisNode.nodeType !== 'IndexAccess') return; + // const key = path.getMappingKeyIdentifier(thisNode); + // if (!key.referencedDeclaration) return; + // if (idInLoopExpression.includes(key.referencedDeclaration)) + // throw new ZKPError(`The mapping ${thisNode.baseExpression.name} is being accessed by the loop expression ${key.name}, which means we are editing as many secret states as there are loop iterations. This is not currently supported due to the computation involved.`, thisNode); + // }; + + // traverseNodesFast(body, miniMappingVisitor); if ((condition.containsSecret || initializationExpression.containsSecret || loopExpression.containsSecret) && body.containsPublic) { throw new TODOError(`This For statement edits a public state based on a secret condition, which currently isn't supported.`, path.node); diff --git a/src/transformers/visitors/toCircuitVisitor.ts b/src/transformers/visitors/toCircuitVisitor.ts index c24646bb2..571983b0c 100644 --- a/src/transformers/visitors/toCircuitVisitor.ts +++ b/src/transformers/visitors/toCircuitVisitor.ts @@ -29,7 +29,7 @@ const publicInputsVisitor = (thisPath: NodePath, thisState: any) => { //Check if for-if statements are both together. if(thisPath.getAncestorContainedWithin('condition') && thisPath.getAncestorOfType('IfStatement') && thisPath.getAncestorOfType('ForStatement')){ //Currently We only support if statements inside a for loop no the other way around, so getting the public inputs according to inner if statement - if((thisPath.getAncestorOfType('IfStatement')).getAncestorOfType('ForStatement')) + if((thisPath.getAncestorOfType('IfStatement'))?.getAncestorOfType('ForStatement')) isForCondition = isCondition; } // below: we have a public state variable we need as a public input to the circuit @@ -576,15 +576,29 @@ let childOfSecret = path.getAncestorOfType('ForStatement')?.containsSecret; const tempRHSPath = cloneDeep(rhsPath); const tempRHSParent = tempRHSPath.parent; + let mappingKeyName = ``; + if (path.isNestedMapping(lhs)) { + let topMapping = NodePath.getPath(lhs); + while (topMapping.parentPath.getAncestorOfType('IndexAccess')) { + topMapping = topMapping.parentPath.getAncestorOfType('IndexAccess'); + } + + while (topMapping.node.baseExpression) { + mappingKeyName = scope.getMappingKeyName(topMapping.node) + `${mappingKeyName === `` ? `` : `/` + mappingKeyName}` ; + topMapping = NodePath.getPath(topMapping.node.baseExpression); + } + } else if (lhsIndicator.isMapping) { + mappingKeyName = scope.getMappingKeyName(lhs) || + lhs.indexExpression?.name || + lhs.indexExpression.expression.name; + } if (isDecremented) { newNode = buildNode('BoilerplateStatement', { bpType: 'decrementation', indicators: lhsIndicator, subtrahendId: rhs.id, ...(lhsIndicator.isMapping && { - mappingKeyName: scope.getMappingKeyName(lhs) || - lhs.indexExpression?.name || - lhs.indexExpression.expression.name, + mappingKeyName }), // TODO: tidy this }); tempRHSPath.containerName = 'subtrahend'; // a dangerous bodge that works @@ -596,9 +610,7 @@ let childOfSecret = path.getAncestorOfType('ForStatement')?.containsSecret; indicators: lhsIndicator, addendId: rhs.id, ...(lhsIndicator.isMapping && { - mappingKeyName: scope.getMappingKeyName(lhs) || - lhs.indexExpression?.name || - lhs.indexExpression.expression.name, + mappingKeyName }), // TODO: tidy this }); tempRHSPath.containerName = 'addend'; // a dangerous bodge that works @@ -969,6 +981,9 @@ let childOfSecret = path.getAncestorOfType('ForStatement')?.containsSecret; const { node, parent , parentPath } = path; const { value } = node; + // skip require msg warnings + if (path.getAncestorOfType('FunctionCall')?.node.requireStatementPrivate && node.kind === 'string') return; + if (node.kind !== 'number' && node.kind !== 'bool' && !path.getAncestorOfType('Return')) throw new Error( `Only literals of kind "number" are currently supported. Found literal of kind '${node.kind}'. Please open an issue.`, diff --git a/src/transformers/visitors/toContractVisitor.ts b/src/transformers/visitors/toContractVisitor.ts index 4a38689e4..ff3e7cceb 100644 --- a/src/transformers/visitors/toContractVisitor.ts +++ b/src/transformers/visitors/toContractVisitor.ts @@ -329,7 +329,7 @@ export default { ); if (path.scope.containsSecret) - postStatements.push( + postStatements.unshift( ...buildNode('FunctionBoilerplate', { bpSection: 'postStatements', scope, @@ -569,12 +569,26 @@ export default { const { node, parent } = path; const newNode = buildNode('ExpressionStatement'); node._newASTPointer = newNode; - parentnewASTPointer(parent, path, newNode , parent._newASTPointer[path.containerName]); + // We check whether this statement edits a public state which we need in the circuit + // If so, we need the edit to go AFTER verification + let moveToPost: boolean = false; + path.traversePathsFast(p => { + // console.log(p.node) + // console.log(p.isModification()) + if (p.isModification() && + p.getReferencedNode().stateVariable && + !p.getReferencedBinding().isSecret && + p.scope.getReferencedIndicator(p.node)?.interactsWithSecret) { + moveToPost = true; + } + }, {}); + // the 'statements' check is to ensure we are in a block nodetype + parentnewASTPointer(parent, path, newNode, moveToPost && path.containerName === 'statements' ? parent._newASTPointer.postStatements : parent._newASTPointer[path.containerName]); }, }, EventDefinition: { - enter(path: NodePath , state:any) { + enter(path: NodePath, state:any) { const { node, parent } = path; state.functionName = path.getUniqueFunctionName() const newNode = buildNode('EventDefinition', { diff --git a/src/transformers/visitors/toOrchestrationVisitor.ts b/src/transformers/visitors/toOrchestrationVisitor.ts index 451a5e93e..43acaa5e9 100644 --- a/src/transformers/visitors/toOrchestrationVisitor.ts +++ b/src/transformers/visitors/toOrchestrationVisitor.ts @@ -1,6 +1,6 @@ /* eslint-disable no-param-reassign, no-shadow, no-unused-vars, no-continue */ import NodePath from '../../traverse/NodePath.js'; -import { StateVariableIndicator, FunctionDefinitionIndicator } from '../../traverse/Indicator.js'; +import { StateVariableIndicator, FunctionDefinitionIndicator, LocalVariableIndicator } from '../../traverse/Indicator.js'; import { VariableBinding } from '../../traverse/Binding.js'; import MappingKey from '../../traverse/MappingKey.js'; import cloneDeep from 'lodash.clonedeep'; @@ -156,8 +156,13 @@ const addPublicInput = (path: NodePath, state: any) => { if (expNode) { expNode.interactsWithSecret = true; const moveExpNode = cloneDeep(expNode); - delete statements[statements.indexOf(expNode)]; + const i = statements.indexOf(expNode); + delete statements[i]; fnDefNode.node._newASTPointer.body.preStatements.push(moveExpNode); + const nextNode = cloneDeep(statements[i + 1]); + if (nextNode?.nodeType === 'Assignment' && nextNode.leftHandSide.name === nextNode.rightHandSide.name && nextNode.rightHandSide.subType === 'generalNumber') { + fnDefNode.node._newASTPointer.body.preStatements.push(nextNode); + } } } }); @@ -417,7 +422,15 @@ const visitor = { for (const [, mappingKey] of Object.entries( stateVarIndicator.mappingKeys || {} )) { - allIndicators.push(mappingKey); + if (mappingKey.mappingKeys) { + for (const [, innerMappingKey] of Object.entries( + mappingKey.mappingKeys || {} + )) { + allIndicators.push(innerMappingKey); + } + } else { + allIndicators.push(mappingKey); + } } } else if (stateVarIndicator instanceof StateVariableIndicator) { allIndicators.push(stateVarIndicator); @@ -435,14 +448,21 @@ const visitor = { stateVarIndicator.container?.isAccessed && !stateVarIndicator.container?.isModified; secretModified = stateVarIndicator.container?.isSecret && stateVarIndicator.container?.isModified; - id = [id, scope.getMappingKeyName(stateVarIndicator.keyPath.node) || ``]; - - name = (accessedOnly ? - getIndexAccessName(stateVarIndicator.accessedPaths[stateVarIndicator.accessedPaths.length -1]?.getAncestorOfType('IndexAccess')?.node) : - stateVarIndicator.container?.isModified ? - getIndexAccessName(stateVarIndicator.modifyingPaths[stateVarIndicator.modifyingPaths.length -1].getAncestorOfType('IndexAccess')?.node) : - getIndexAccessName(stateVarIndicator.referencingPaths[stateVarIndicator.referencingPaths.length -1].getAncestorOfType('IndexAccess')?.node)) - || ''; + + if (stateVarIndicator.container instanceof MappingKey && stateVarIndicator.isChild && !stateVarIndicator.container.isParent) { + // probably a better way to discover that we have a nested mapping, but it will do for now + id = [id, stateVarIndicator.container.referencedKeyName, stateVarIndicator.referencedKeyName]; + name = getIndexAccessName(stateVarIndicator.referencingPaths[stateVarIndicator.referencingPaths.length -1].getAncestorOfType('IndexAccess')?.parentPath.getAncestorOfType('IndexAccess')?.node) || ``; + } else { + id = [id, scope.getMappingKeyName(stateVarIndicator.keyPath.node) || ``]; + + name = (accessedOnly ? + getIndexAccessName(stateVarIndicator.accessedPaths[stateVarIndicator.accessedPaths.length -1]?.getAncestorOfType('IndexAccess')?.node) : + stateVarIndicator.container?.isModified ? + getIndexAccessName(stateVarIndicator.modifyingPaths[stateVarIndicator.modifyingPaths.length -1].getAncestorOfType('IndexAccess')?.node) : + getIndexAccessName(stateVarIndicator.referencingPaths[stateVarIndicator.referencingPaths.length -1].getAncestorOfType('IndexAccess')?.node)) + || ''; + } } let { incrementsArray, incrementsString } = isIncremented @@ -451,8 +471,9 @@ const visitor = { if (!incrementsString) incrementsString = null; if (!incrementsArray) incrementsArray = null; - if (accessedOnly || (stateVarIndicator.isWhole && functionIndicator.oldCommitmentAccessRequired)) { - if(stateVarIndicator.isSecret || stateVarIndicator.node.interactsWithSecret) + if (accessedOnly || (stateVarIndicator.isWhole && functionIndicator.oldCommitmentAccessRequired) && stateVarIndicator.isSecret) { + + newNodes.initialisePreimageNode.privateStates[ name ] = buildPrivateStateNode('InitialisePreimage', { @@ -631,34 +652,57 @@ const visitor = { newNodes.InitialiseKeysNode, ); - // OR they are local variable declarations we need for initialising preimage... + // OR they are local variable declarations/ accessed secret states we need for initialising preimage... let localVariableDeclarations: any[] = []; newFunctionDefinitionNode.body.statements.forEach((n, index) => { - if (n.nodeType === 'VariableDeclarationStatement' && n.declarations[0].declarationType === 'localStack') + if (n.nodeType === 'VariableDeclarationStatement' + && (n.declarations[0].declarationType === 'localStack' + || n.declarations[0].isSecret && n.declarations[0].isAccessed)) localVariableDeclarations.push({node: cloneDeep(n), index}); }); - + let toSplice = earliestPublicAccessIndex + 2; if (localVariableDeclarations[0]) { localVariableDeclarations.forEach(n => { const localIndicator = scope.indicators[n.node.declarations[0].id]; const indexExpressionPath = localIndicator.referencingPaths.find(p => p.getAncestorContainedWithin('indexExpression') && p.getAncestorOfType('IndexAccess')?.node.containsSecret ); - if (indexExpressionPath) { + if (indexExpressionPath && !indexExpressionPath.isNestedMapping()) { // we have found a local variable which is used as an indexExpression, so we need it before we get the mapping value // NB if there are multiple, we have just found the first one const varDecComesAfter = scope.getReferencedIndicator( NodePath.getPath(localIndicator.node)?.getCorrespondingRhsNode(), true ); - if (!varDecComesAfter) { + if (!varDecComesAfter || (varDecComesAfter instanceof LocalVariableIndicator && varDecComesAfter.isParam)) { // here, we don't need to worry about defining anything first, so we push this local var to the top newFunctionDefinitionNode.body.preStatements.splice( - earliestPublicAccessIndex + 2, + toSplice, 0, n.node, ); - } else { + toSplice += 1; + if (localIndicator.isSecret && n.node.declarations[0].declarationType === 'state') { + // we shift up the init preimage node if this accessed state is secret + const name = n.node.declarations[0].name; + const initPIndex = newFunctionDefinitionNode.body.preStatements.findIndex( + (nd: any) => + nd.nodeType === 'InitialisePreimage' && + nd.privateStates[name] + ); + const newInitPreimageNode = { nodeType: 'InitialisePreimage', privateStates: {}}; + newInitPreimageNode.privateStates[name] = cloneDeep( + newFunctionDefinitionNode.body.preStatements[initPIndex].privateStates[name] + ); + delete newFunctionDefinitionNode.body.preStatements[initPIndex].privateStates[name]; + newFunctionDefinitionNode.body.preStatements.splice( + earliestPublicAccessIndex + 2, + 0, + newInitPreimageNode, + ); + toSplice +=1; + } + } else if (n.node.declarations[0].declarationType !== 'state') { // now we have to split initPreimage const varDecComesBefore = scope.getReferencedIndicator( indexExpressionPath.getAncestorOfType('IndexAccess').node.baseExpression, true @@ -709,7 +753,15 @@ const visitor = { 0, newInitPreimageNode2, ); - + } + const nextNode = cloneDeep(newFunctionDefinitionNode.body.statements[n.index + 1]); + if (nextNode?.nodeType === 'Assignment' && nextNode.leftHandSide.name === nextNode.rightHandSide.name && nextNode.rightHandSide.subType === 'generalNumber') { + newFunctionDefinitionNode.body.preStatements.splice( + toSplice, + 0, + nextNode, + ); + toSplice += 1; } delete newFunctionDefinitionNode.body.statements[n.index]; } @@ -983,8 +1035,12 @@ const visitor = { path.traversePathsFast(interactsWithSecretVisitor, newState); const { interactsWithSecret } = newState; - let indicator; - let name; + const fnDefNode = path.getAncestorOfType('FunctionDefinition')?.node; + let indicator: any; + let name: string = ``; + let firstInstanceOfNewName: boolean = false; + let thisName: string = ``; + let accessedBeforeModification: boolean = false; // we mark this to grab anything we need from the db / contract state.interactsWithSecret = interactsWithSecret; // ExpressionStatements can contain an Assignment node. @@ -993,7 +1049,7 @@ const visitor = { if (!lhs) lhs = node.expression.subExpression; indicator = scope.getReferencedIndicator(lhs, true); - const name = indicator.isMapping + name = indicator.isMapping ? indicator.name .replace('[', '_') .replace(']', '') @@ -1014,9 +1070,9 @@ const visitor = { // collect all index names const names = indicator.referencingPaths.map((p: NodePath) => ({ name: scope.getIdentifierMappingKeyName(p.node), id: p.node.id })).filter(n => n.id <= lhs.id); - + thisName = scope.getIdentifierMappingKeyName(lhs); // check whether this is the first instance of a new index name - const firstInstanceOfNewName = names.length > 1 && names[names.length - 1].name !== names[names.length - 2].name; + firstInstanceOfNewName = names.findIndex(n => n.name === thisName) === names.length - 1; // check whether this should be a VariableDeclaration const firstEdit = @@ -1039,7 +1095,7 @@ const visitor = { }); // we still need to initialise accessed states if they were accessed _before_ this modification - const accessedBeforeModification = indicator.isAccessed && indicator.accessedPaths[0].node.id < lhs.id && !indicator.accessedPaths[0].isModification(); + accessedBeforeModification = indicator.isAccessed && indicator.accessedPaths[0].node.id < lhs.id && !indicator.accessedPaths[0].isModification(); if (accessedBeforeModification || path.isInSubScope()) accessed = true; @@ -1050,6 +1106,7 @@ const visitor = { name: indicator.isStruct && !indicator.isMapping ? lhs.name : name, isAccessed: accessed, isSecret: indicator.isSecret, + oldASTId: indicator.id, }), ], interactsWithSecret: true, @@ -1059,10 +1116,13 @@ const visitor = { if (accessedBeforeModification || path.isInSubScope()) { // we need to initialise an accessed state // or declare it outside of this subscope e.g. if statement - const fnDefNode = path.getAncestorOfType('FunctionDefinition')?.node; - delete newNode.initialValue; - fnDefNode._newASTPointer.body.statements.unshift(newNode); - } else { + if (!fnDefNode._newASTPointer.body.statements.find(n => n?.nodeType === newNode.nodeType && n.declarations[0]?.name === newNode.declarations[0].name)) { + // prevent duplicates + delete newNode.initialValue; + fnDefNode._newASTPointer.body.statements.unshift(newNode); + } + + } else if (!node.expression.incrementedDeclaration || !indicator.isPartitioned){ node._newASTPointer = newNode; parent._newASTPointer.push(newNode); return; @@ -1087,10 +1147,30 @@ const visitor = { } } if (node.expression.expression?.name !== 'require') { - const newNode = buildNode(node.nodeType, { - interactsWithSecret: interactsWithSecret || indicator?.interactsWithSecret, - oldASTId: node.id, - }); + let newNode; + if (!fnDefNode._newASTPointer.body.statements.find( + n => n.nodeType === 'VariableDeclarationStatement' && (n.declarations[0]?.name === thisName || name?.split('.')[0].length > thisName?.split('.')[0].length ? n.id === node.id : false) + ) && firstInstanceOfNewName) { + newNode = buildNode('VariableDeclarationStatement', { + oldASTId: node.id, + declarations: [ + buildNode('VariableDeclaration', { + name: indicator.isStruct && !indicator.isMapping ? thisName : name, + isSecret: indicator.isSecret, + declarationType: 'local', + isAccessed: accessedBeforeModification, + oldASTId: indicator.id, + }), + ], + interactsWithSecret: true, + }); + if (indicator.isStruct) newNode.declarations[0].isStruct = true; + } else { + newNode = buildNode(node.nodeType, { + interactsWithSecret: interactsWithSecret || indicator?.interactsWithSecret, + oldASTId: node.id, + }); + } node._newASTPointer = newNode; if (Array.isArray(parent._newASTPointer) || (!path.isInSubScope() && Array.isArray(parent._newASTPointer[path.containerName]))) { @@ -1107,29 +1187,26 @@ const visitor = { const { node, scope } = path; const { leftHandSide: lhs } = node.expression; const indicator = scope.getReferencedIndicator(lhs, true); - const name = indicator?.isMapping - ? indicator.name - .replace('[', '_') - .replace(']', '') - .replace('.sender', 'Sender') - .replace('.value', 'Value') - .replace('.', 'dot') - : indicator?.name || lhs?.name; + if (!lhs) return; + const name = indicator?.isMapping // wrong? + ? getIndexAccessName(lhs) + : scope.getIdentifierMappingKeyName(lhs) || indicator?.name || lhs?.name; // reset delete state.interactsWithSecret; if (node._newASTPointer?.incrementsSecretState && indicator) { const increments = collectIncrements(indicator).incrementsString; path.node._newASTPointer.increments = increments; } else if (indicator?.isWhole && node._newASTPointer) { + // we add a general number statement after each whole state edit - if (node._newASTPointer.interactsWithSecret) path.getAncestorOfType('FunctionDefinition')?.node._newASTPointer.body.statements.push( - buildNode('Assignment', { - leftHandSide: buildNode('Identifier', { name }), - operator: '=', - rightHandSide: buildNode('Identifier', { name, subType: 'generalNumber' }) - } - ) - ); + const newNode = buildNode('Assignment', { + leftHandSide: buildNode('Identifier', { name }), + operator: '=', + rightHandSide: buildNode('Identifier', { name, subType: 'generalNumber' }) + }); + newNode.interactsWithSecret = node._newASTPointer.interactsWithSecret; + if (node._newASTPointer.interactsWithSecret) + path.getAncestorOfType('FunctionDefinition')?.node._newASTPointer.body.statements.push(newNode); } if (node._newASTPointer?.interactsWithSecret && path.getAncestorOfType('ForStatement')) { @@ -1265,7 +1342,8 @@ const visitor = { }); newNode.isStruct = true; } - parent._newASTPointer[path.containerName] = newNode; + + if (parent._newASTPointer) parent._newASTPointer[path.containerName] = newNode; state.skipSubNodes = true; // the subnodes are ElementaryTypeNames }, diff --git a/src/traverse/Binding.ts b/src/traverse/Binding.ts index 9777fbbb8..b80e3d38b 100644 --- a/src/traverse/Binding.ts +++ b/src/traverse/Binding.ts @@ -256,7 +256,7 @@ export class VariableBinding extends Binding { // A binding will be updated if (some time after its creation) we encounter an AST node which refers to this binding's variable. // E.g. if we encounter an Identifier node. update(path: NodePath) { - if (this.isMapping) { + if (this.isMapping && path.getAncestorOfType('IndexAccess')) { this.addMappingKey(path).updateProperties(path); } else if (this.isStruct && path.getAncestorOfType('MemberAccess')) { this.addStructProperty(path).updateProperties(path); @@ -365,10 +365,15 @@ export class VariableBinding extends Binding { this.isWholeReason ??= []; this.isWholeReason.push(reason); - if (this.isMapping) { + if (this.isMapping && path.getAncestorOfType('IndexAccess')) { this.addMappingKey(path).isAccessed = true; this.addMappingKey(path).accessedPaths ??= []; this.addMappingKey(path).accessedPaths.push(path); + if (this.addMappingKey(path).mappingKeys) { + this.addMappingKey(path).addMappingKey(path).isAccessed = true; + this.addMappingKey(path).addMappingKey(path).accessedPaths ??= []; + this.addMappingKey(path).addMappingKey(path).accessedPaths.push(path); + } } if (this.isStruct && path.getAncestorOfType('MemberAccess')) { diff --git a/src/traverse/Indicator.ts b/src/traverse/Indicator.ts index 5c815854f..f61e225d5 100644 --- a/src/traverse/Indicator.ts +++ b/src/traverse/Indicator.ts @@ -289,11 +289,11 @@ export class LocalVariableIndicator extends FunctionDefinitionIndicator { if (path.isModification()) { this.addModifyingPath(path); } - if (this.isStruct && path.getAncestorOfType('MemberAccess')) { - this.addStructProperty(path).updateProperties(path); - } else if (this.isMapping) { + if (this.isMapping && path.getAncestorOfType('IndexAccess')) { this.addMappingKey(path).updateProperties(path); - } + } else if (this.isStruct && path.getAncestorOfType('MemberAccess')) { + this.addStructProperty(path).updateProperties(path); + } } updateProperties(path: NodePath) { @@ -476,7 +476,7 @@ export class StateVariableIndicator extends FunctionDefinitionIndicator { // A StateVariableIndicator will be updated if (some time after its creation) we encounter an AST node which refers to this state variable. // E.g. if we encounter an Identifier node. update(path: NodePath) { - if (this.isMapping) { + if (this.isMapping && (path.getAncestorOfType('IndexAccess'))) { this.addMappingKey(path).updateProperties(path); } else if (this.isStruct && path.getAncestorOfType('MemberAccess')) { this.addStructProperty(path).updateProperties(path); @@ -574,6 +574,11 @@ export class StateVariableIndicator extends FunctionDefinitionIndicator { this.addMappingKey(path).isAccessed = true; this.addMappingKey(path).accessedPaths ??= []; this.addMappingKey(path).accessedPaths.push(path); + if (this.addMappingKey(path).mappingKeys) { + this.addMappingKey(path).addMappingKey(path).isAccessed = true; + this.addMappingKey(path).addMappingKey(path).accessedPaths ??= []; + this.addMappingKey(path).addMappingKey(path).accessedPaths.push(path); + } } if (this.isStruct && path.getAncestorOfType('MemberAccess')) { @@ -787,6 +792,12 @@ export class StateVariableIndicator extends FunctionDefinitionIndicator { const mappingKeys: [string, MappingKey][] = Object.entries(this.mappingKeys ? this.mappingKeys : {}); for (const [, mappingKey] of mappingKeys) { mappingKey.newCommitmentsRequired = true; + if (mappingKey.mappingKeys) { + const innerMappingKeys: [string, MappingKey][] = Object.entries(mappingKey.mappingKeys); + for (const [, innerMappingKey] of innerMappingKeys) { + innerMappingKey.newCommitmentsRequired = true; + } + } } } if (this.isStruct) { diff --git a/src/traverse/MappingKey.ts b/src/traverse/MappingKey.ts index 09e2b8fb8..5f5c55084 100644 --- a/src/traverse/MappingKey.ts +++ b/src/traverse/MappingKey.ts @@ -1,5 +1,5 @@ import NodePath from './NodePath.js'; -import { Binding } from './Binding.js'; +import { Binding, VariableBinding } from './Binding.js'; import { StateVariableIndicator } from './Indicator.js'; import logger from '../utils/logger.js'; import { SyntaxUsageError, ZKPError } from '../error/errors.js'; @@ -54,6 +54,7 @@ export default class MappingKey { isParent?: boolean; isChild?: boolean; structProperties?: {[key: string]: any}; + mappingKeys?: {[key: string]: MappingKey}; isKnown?:boolean; isUnknown?:boolean; @@ -140,6 +141,35 @@ export default class MappingKey { this.nullifyingPaths = []; // array of paths of `Identifier` nodes which nullify this binding } + addMappingKey(referencingPath: NodePath): MappingKey { + // we assume that if we're here we have a nested mapping + // input is the index of the inner mapping + const parentMap = referencingPath.getAncestorOfType('IndexAccess')?.parentPath.getAncestorOfType('IndexAccess'); + if (!parentMap) throw new Error('No nested mapping - we have for some reason assumed there was a nested mapping, but havent found it'); + const keyNode = parentMap.getMappingKeyIdentifier(); + const keyPath = NodePath.getPath(keyNode); + if (!keyPath) throw new Error('No keyPath found in pathCache'); + + if (!['Identifier', 'MemberAccess', 'Literal'].includes(keyNode.nodeType)) { + throw new Error( + `A mapping key of nodeType '${keyNode.nodeType}' isn't supported yet. We've only written the code for keys of nodeType Identifier'`, + ); + } + + // naming of the key within mappingKeys: + const keyName = parentMap.scope.getMappingKeyName(parentMap); + + // add this mappingKey if it hasn't yet been added: + this.mappingKeys ??= {}; + const mappingKeyExists = !!this.mappingKeys[keyName]; + if (!mappingKeyExists) + this.mappingKeys[keyName] = new MappingKey(this, keyPath); + + this.mappingKeys[keyName].isChild = true; + + return this.mappingKeys[keyName]; + } + addStructProperty(referencingPath: NodePath): MappingKey { this.isParent = true; this.isStruct = true; @@ -152,10 +182,14 @@ export default class MappingKey { } updateProperties(path: NodePath) { + const parentMap = path.getAncestorOfType('IndexAccess')?.parentPath.getAncestorOfType('IndexAccess'); if (this.isMapping && this.node.typeDescriptions.typeString.includes('struct ') && !this.isChild && path.getAncestorOfType('MemberAccess')) { // in mapping[key].property, the node for .property is actually a parent value, so we need to make sure this isnt already a child of a mappingKey this.addStructProperty(path).updateProperties(path); + } else if (this.isMapping && !this.isChild && parentMap) { + // we have a nested mapping + this.addMappingKey(path).updateProperties(path); } this.addReferencingPath(path); this.isUnknown ??= path.node.isUnknown; @@ -180,6 +214,13 @@ export default class MappingKey { updateEncryption(options?: any) { // no new commitments => nothing to encrypt if (!this.newCommitmentsRequired) return; + if (this.mappingKeys) { + const mappingKeys: [string, MappingKey][] = Object.entries(this.mappingKeys ? this.mappingKeys : {}); + for (const [, mappingKey] of mappingKeys) { + mappingKey.updateEncryption() + } + return; + } // decremented only => no new commitments to encrypt if (this.isPartitioned && this.isDecremented && this.nullificationCount === this.referenceCount) return; // find whether enc for this scope only has been opted in @@ -250,6 +291,13 @@ export default class MappingKey { state.decrements.forEach((dec: any) => { this.decrements.push(dec); }); + + if (this.mappingKeys) { + this.addMappingKey(state.incrementedPath).updateIncrementation( + path, + state, + ); + } } // TODO: move into commonFunctions (because it's the same function as included in the Binding class) @@ -272,11 +320,13 @@ export default class MappingKey { this.isNullified = true; ++this.nullificationCount; this.nullifyingPaths.push(path); + if (this.mappingKeys) this.addMappingKey(path).addNullifyingPath(path); } addBurningPath(path: NodePath) { this.isBurned = true; this.burningPaths.push(path); + if (this.mappingKeys) this.addMappingKey(path).addBurningPath(path); } addSecretInteractingPath(path: NodePath) { @@ -343,7 +393,11 @@ export default class MappingKey { updateFromBinding() { // it's possible we dont know in this fn scope whether a state is whole/owned or not, but the binding (contract scope) will - const container = this.container instanceof Binding ? this.container : this.container.binding; + // const container = this.container instanceof Binding ? this.container : this.container.binding; + let { container } = this; + while (!(container instanceof VariableBinding)) { + container = container.container || container.binding; + } this.isWhole ??= container.isWhole; this.isWholeReason = this.isWhole ? container.isWholeReason @@ -356,5 +410,11 @@ export default class MappingKey { this.owner ??= container.owner; this.mappingOwnershipType = this.owner?.mappingOwnershipType; this.onChainKeyRegistry ??= container.onChainKeyRegistry; + if (this.mappingKeys) { + const mappingKeys: [string, MappingKey][] = Object.entries(this.mappingKeys ? this.mappingKeys : {}); + for (const [, mappingKey] of mappingKeys) { + mappingKey.updateFromBinding(); + } + } } } diff --git a/src/traverse/NodePath.ts b/src/traverse/NodePath.ts index ddef2aed5..b972c5809 100644 --- a/src/traverse/NodePath.ts +++ b/src/traverse/NodePath.ts @@ -834,6 +834,38 @@ export default class NodePath { return this.isMappingDeclaration(node) || this.isMappingIdentifier(node); } + /** + * Checks whether a node is a nested mapping. + * @param {node} node (optional - defaults to this.node) + * @returns {Boolean} + */ + isNestedMapping(node: any = this.node): boolean { + + const path = NodePath.getPath(node) || this; + if (!this.isMapping(node) && !path.getAncestorOfType('IndexAccess')) return false; + /** Nested mappings look like: + baseExpression: { + baseExpression: { + name: 'parent', + }, + indexExpression: { + name: 'innerKey', + }, + nodeType: 'IndexAccess', + }, + indexExpression: { + name: 'outerKey', + } + for parent[innerKey][outerKey] + */ + if (node.nodeType !== 'IndexAccess' && path.getAncestorOfType('IndexAccess')) { + return path.getAncestorOfType('IndexAccess').isNestedMapping(); + } + if (node.baseExpression?.nodeType === 'IndexAccess') return true; + if (path.parentPath.getAncestorOfType('IndexAccess')) return true; + return false; + } + /** * A mapping's key will contain an Identifier node pointing to a previously-declared variable. * @param {Object} - the mapping's index access node. @@ -975,6 +1007,31 @@ export default class NodePath { return memberAccNode && memberAccNode.node.baseExpression?.typeDescriptions?.typeIdentifier.includes('array'); } + /** + * Checks whether a node is of an array type. + * @param {node} node (optional - defaults to this.node) + * @returns {Boolean} + */ + isConstantArray(node: any = this.node): boolean { + if (!this.isArray(node)) return false; + let arrLen; + switch (node.nodeType) { + case 'IndexAccess': + arrLen = node.baseExpression.typeDescriptions.typeString.match(/(?<=\[)(\d+)(?=\])/); + break; + case 'Identifier': + default: + arrLen = node.typeDescriptions.typeString.match(/(?<=\[)(\d+)(?=\])/); + break; + } + if (!arrLen) return false; + for (const match of arrLen) { + // tries to convert to a number + if (+match) return true; + } + return false; + } + /** * Checks whether a node is a VariableDeclaration of a Mapping. * @param {node} node (optional - defaults to this.node) @@ -1009,7 +1066,7 @@ export default class NodePath { return ( !(this.queryAncestors(path => path.containerName === 'indexExpression')) && !this.getAncestorOfType('FunctionCall') && !this.getAncestorContainedWithin('initialValue') && - this.getLhsAncestor(true) && !(this.queryAncestors(path => path.containerName === 'condition') || this.queryAncestors(path => path.containerName === 'initializationExpression') || this.queryAncestors(path => path.containerName === 'loopExpression')) + !this.getAncestorContainedWithin('rightHandSide') && this.getLhsAncestor(true) && !(this.queryAncestors(path => path.containerName === 'condition') || this.queryAncestors(path => path.containerName === 'initializationExpression') || this.queryAncestors(path => path.containerName === 'loopExpression')) ); default: return false; @@ -1037,7 +1094,7 @@ export default class NodePath { id = referencingNode.referencedDeclaration; break; case 'IndexAccess': - id = referencingNode.baseExpression.referencedDeclaration; + id = referencingNode.baseExpression.referencedDeclaration || this.getReferencedDeclarationId(referencingNode.baseExpression); break; case 'MemberAccess': id = referencingNode.expression.referencedDeclaration || this.getReferencedDeclarationId(referencingNode.expression); diff --git a/src/traverse/Scope.ts b/src/traverse/Scope.ts index 101e8247d..a4bbaab58 100644 --- a/src/traverse/Scope.ts +++ b/src/traverse/Scope.ts @@ -427,15 +427,20 @@ export class Scope { : indicator; } + if ((path.isConstantArray(referencingNode) || referencingNode.memberName === 'length') && !NodePath.getPath(referencingNode).getAncestorOfType('IndexAccess')) return indicator; + + // getMappingKeyName requires an indexAccessNode - referencingNode may be a baseExpression or indexExpression contained Identifier - const indexAccessNode = - referencingNode.nodeType === 'IndexAccess' - ? referencingNode - : NodePath.getPath(referencingNode).getAncestorOfType('IndexAccess') - .node; + const indexAccessPath = referencingNode.nodeType === 'IndexAccess' + ? NodePath.getPath(referencingNode) + : NodePath.getPath(referencingNode).getAncestorOfType('IndexAccess'); + const indexAccessNode = indexAccessPath.node; + + if (!indicator.mappingKeys[this.getMappingKeyName(indexAccessPath)] && indexAccessNode.baseExpression.nodeType === 'IndexAccess') + return this.getReferencedIndicator(indexAccessNode.baseExpression, mappingKeyIndicatorOnly); return mappingKeyIndicatorOnly - ? indicator.mappingKeys[this.getMappingKeyName(indexAccessNode)] + ? indicator.mappingKeys[this.getMappingKeyName(indexAccessPath)] : indicator; } @@ -677,11 +682,11 @@ export class Scope { if (refPaths && thisIndex && refPaths[thisIndex]?.key === 'indexExpression') return this.getMappingKeyName(refPaths[thisIndex].getAncestorOfType('IndexAccess')); let { name } = identifierNode; - + // we find the next indexExpression after this identifier for (let i = thisIndex || 0; i < (refPaths?.length || 0); i++) { - if (refPaths?.[i].key !== 'indexExpression' || !thisIndex) continue; - if (refPaths[thisIndex].isModification() && !forceNotModification) { + if (refPaths?.[i]?.key !== 'indexExpression' || !thisIndex) continue; + if (refPaths[thisIndex]?.isModification() && !forceNotModification) { name = this.getMappingKeyName(refPaths[i].getAncestorOfType('IndexAccess')); break; // if this identifier is not a modification, we need the previous indexExpression diff --git a/test/contracts/nested-mapping.zol b/test/contracts/nested-mapping.zol new file mode 100644 index 000000000..fd6414adf --- /dev/null +++ b/test/contracts/nested-mapping.zol @@ -0,0 +1,19 @@ +// SPDX-License-Identifier: CC0 + +pragma solidity ^0.8.0; + +contract Nesting { + + secret mapping(uint256 => uint256[]) public parent; + + function deposit(secret uint256 child, secret uint256 index, uint256 amountDeposit) public { + unknown parent[child][index] += amountDeposit; + + } + + function transfer(secret uint256 child, secret uint256 recipient, secret uint256 index, secret uint256 amount) public { + parent[child][index] -= amount; + unknown parent[recipient][0] += amount; + + } +}