diff --git a/src/boilerplate/circuit/zokrates/nodes/BoilerplateGenerator.ts b/src/boilerplate/circuit/zokrates/nodes/BoilerplateGenerator.ts index d44994037..591274fee 100644 --- a/src/boilerplate/circuit/zokrates/nodes/BoilerplateGenerator.ts +++ b/src/boilerplate/circuit/zokrates/nodes/BoilerplateGenerator.ts @@ -45,7 +45,7 @@ const collectIncrements = (bpg: BoilerplateGenerator) => { } if (inc.nodeType === 'MemberAccess') inc.name ??= `${inc.expression.name}.${inc.memberName}`; if (!inc.name) inc.name = inc.value; - + let modName = inc.modName ? inc.modName : inc.name; if (incrementsArray.some(existingInc => inc.name === existingInc.name)) continue; incrementsArray.push({ @@ -54,9 +54,9 @@ const collectIncrements = (bpg: BoilerplateGenerator) => { }); if (inc === stateVarIndicator.increments[0]) { - incrementsString += `${inc.name}`; + incrementsString += `${modName}`; } else { - incrementsString += ` ${inc.precedingOperator} ${inc.name}`; + incrementsString += ` ${inc.precedingOperator} ${modName}`; } } for (const dec of stateVarIndicator.decrements) { @@ -72,17 +72,18 @@ const collectIncrements = (bpg: BoilerplateGenerator) => { if (!dec.name) dec.name = dec.value; if (incrementsArray.some(existingInc => dec.name === existingInc.name)) continue; + let modName = dec.modName ? dec.modName : dec.name; incrementsArray.push({ name: dec.name, precedingOperator: dec.precedingOperator, }); if (!stateVarIndicator.decrements[1] && !stateVarIndicator.increments[0]) { - incrementsString += `${dec.name}`; + incrementsString += `${modName}`; } else { // if we have decrements, this str represents the value we must take away // => it's a positive value with +'s - incrementsString += ` + ${dec.name}`; + incrementsString += ` + ${modName}`; } } return { incrementsArray, incrementsString }; diff --git a/src/boilerplate/circuit/zokrates/raw/BoilerplateGenerator.ts b/src/boilerplate/circuit/zokrates/raw/BoilerplateGenerator.ts index 247c93efa..30d872680 100644 --- a/src/boilerplate/circuit/zokrates/raw/BoilerplateGenerator.ts +++ b/src/boilerplate/circuit/zokrates/raw/BoilerplateGenerator.ts @@ -342,7 +342,7 @@ class BoilerplateGenerator { `${structProperties.map(p => newCommitmentValue[p] === '0' ? '' : `assert(${x0}.${p} + ${x1}.${p} >= ${y[p]})`).join('\n')} // TODO: assert no under/overflows - ${typeName} ${x}_newCommitment_value = ${typeName} { ${structProperties.map(p => ` ${p}: (${x0}.${p} + ${x1}.${p}) - ${y[p]}`)} }` + ${typeName} ${x}_newCommitment_value = ${typeName} { ${structProperties.map(p => ` ${p}: (${x0}.${p} + ${x1}.${p}) - (${y[p]})`)} }` ); } } else { diff --git a/src/transformers/visitors/checks/incrementedVisitor.ts b/src/transformers/visitors/checks/incrementedVisitor.ts index 4b10c8cff..2c7bff04f 100644 --- a/src/transformers/visitors/checks/incrementedVisitor.ts +++ b/src/transformers/visitors/checks/incrementedVisitor.ts @@ -24,11 +24,12 @@ const literalOneNode = { const collectIncrements = (increments: any, incrementedIdentifier: any) => { const { operands, precedingOperator } = increments; const newIncrements: any[] = []; + const Idname = incrementedIdentifier.name || incrementedIdentifier.expression?.name; for (const [index, operand] of operands.entries()) { operand.precedingOperator = precedingOperator[index]; if ( - operand.name !== incrementedIdentifier.name && - operand.baseExpression?.name !== incrementedIdentifier.name && + operand.name !== Idname && + operand.baseExpression?.name !== Idname && !newIncrements.some(inc => inc.id === operand.id) ) newIncrements.push(operand); diff --git a/src/transformers/visitors/toCircuitVisitor.ts b/src/transformers/visitors/toCircuitVisitor.ts index 012ce94ef..7ebf558af 100644 --- a/src/transformers/visitors/toCircuitVisitor.ts +++ b/src/transformers/visitors/toCircuitVisitor.ts @@ -3,7 +3,7 @@ import cloneDeep from 'lodash.clonedeep'; import { buildNode } from '../../types/zokrates-types.js'; import { TODOError } from '../../error/errors.js'; -import { traversePathsFast } from '../../traverse/traverse.js'; +import { traversePathsFast, traverseNodesFast } from '../../traverse/traverse.js'; import NodePath from '../../traverse/NodePath.js'; import explode from './explode.js'; import internalCallVisitor from './circuitInternalFunctionCallVisitor.js'; @@ -14,8 +14,51 @@ import { interactsWithSecretVisitor, internalFunctionCallVisitor, parentnewASTPo + +// Adjusts names of indicator.increment so that they match the names of the corresponding indicators i.e. index_1 instead of index +const incrementNames = (node: any, indicator: any) => { + if (node.bpType === 'incrementation'){ + let rhsNode = node.addend; + const adjustIncrementsVisitor = (thisNode: any) => { + if (thisNode.nodeType === 'Identifier'){ + if (!indicator.increments.some((inc: any) => inc.name === thisNode.name)){ + let lastUnderscoreIndex = thisNode.name.lastIndexOf("_"); + let origName = thisNode.name.substring(0, lastUnderscoreIndex); + let count =0; + indicator.increments.forEach((inc: any) => { + if (origName === inc.name && !inc.modName && count === 0){ + inc.modName = thisNode.name; + count++; + } + }); + } + } + } + if (rhsNode) traverseNodesFast(rhsNode, adjustIncrementsVisitor); + } else if (node.bpType === 'decrementation'){ + let rhsNode = node.subtrahend; + const adjustDecrementsVisitor = (thisNode: any) => { + if (thisNode.nodeType === 'Identifier'){ + if (!indicator.decrements.some((dec: any) => dec.name === thisNode.name)){ + let lastUnderscoreIndex = thisNode.name.lastIndexOf("_"); + let origName = thisNode.name.substring(0, lastUnderscoreIndex); + let count =0; + indicator.decrements.forEach((dec: any) => { + if (origName === dec.name && !dec.modName && count === 0){ + dec.modName = thisNode.name; + count++; + } + }); + } + } + } + if (rhsNode) traverseNodesFast(rhsNode, adjustDecrementsVisitor); + } +}; + + // public variables that interact with the secret also need to be modified within the circuit. -const publicVariablesVisitor = (path: NodePath, state: any, IDnode: any) => { +const publicVariables = (path: NodePath, state: any, IDnode: any) => { const {parent, node } = path; // Break if the identifier is a mapping or array. if ( parent.indexExpression && parent.baseExpression === node ) { @@ -35,7 +78,6 @@ const publicVariablesVisitor = (path: NodePath, state: any, IDnode: any) => { if (!fnDefNode) throw new Error(`Not in a function`); const modifiedBeforePaths = path.scope.getReferencedIndicator(node, true)?.modifyingPaths?.filter((p: NodePath) => p.node.id < node.id); - const statements = fnDefNode.node._newASTPointer.body.statements; let num_modifiers=0; @@ -131,9 +173,27 @@ const publicVariablesVisitor = (path: NodePath, state: any, IDnode: any) => { fnDefNode.node._newASTPointer.body.statements.push(endNode); } // We no longer need this because index expression nodes are not input. - //if (['Identifier', 'IndexAccess'].includes(node.indexExpression?.nodeType)) publicVariablesVisitor(NodePath.getPath(node.indexExpression), state, null); + //if (['Identifier', 'IndexAccess'].includes(node.indexExpression?.nodeType)) publicVariables(NodePath.getPath(node.indexExpression), state, null); } +//Visitor for publicVariables +const publicVariablesVisitor = (thisPath: NodePath, thisState: any) => { + const { node } = thisPath; + let { name } = node; + if (!['Identifier', 'IndexAccess'].includes(thisPath.nodeType)) return; + const binding = thisPath.getReferencedBinding(node); + if ( (binding instanceof VariableBinding) && !binding.isSecret && + binding.stateVariable && thisPath.getAncestorContainedWithin('rightHandSide') ){ + } else{ + name = thisPath.scope.getIdentifierMappingKeyName(node); + } + const newNode = buildNode( + node.nodeType, + { name, type: node.typeDescriptions?.typeString }, + ); + publicVariables(thisPath, thisState, newNode); +}; + // below stub will only work with a small subtree - passing a whole AST will always give true! // useful for subtrees like ExpressionStatements @@ -310,6 +370,7 @@ const visitor = { // a non secret function - we skip it for circuits state.skipSubNodes = true; } + }, exit(path: NodePath, state: any) { @@ -317,6 +378,7 @@ const visitor = { const { indicators } = scope; const newFunctionDefinitionNode = node._newASTPointer; + // We need to ensure the correctness of the circuitImport flag for each internal function call. The state may have been updated due to later function calls that modify the same secret state. let importStatementList: any; parent._newASTPointer.forEach((file: any) => { @@ -776,14 +838,12 @@ let childOfSecret = path.getAncestorOfType('ForStatement')?.containsSecret; case 'Assignment': { const { leftHandSide: lhs, rightHandSide: rhs } = expression; const lhsIndicator = scope.getReferencedIndicator(lhs); - if (!lhsIndicator?.isPartitioned) break; const rhsPath = NodePath.getPath(rhs); // We need to _clone_ the path, because we want to temporarily modify some of its properties for this traversal. For future AST transformations, we'll want to revert to the original path. const tempRHSPath = cloneDeep(rhsPath); const tempRHSParent = tempRHSPath.parent; - if (isDecremented) { newNode = buildNode('BoilerplateStatement', { bpType: 'decrementation', @@ -820,8 +880,15 @@ let childOfSecret = path.getAncestorOfType('ForStatement')?.containsSecret; tempRHSPath.traverse(visitor, { skipPublicInputs: true }); rhsPath.traversePathsFast(publicInputsVisitor, {}); + rhsPath.traversePathsFast(publicVariablesVisitor, {}); + path.traversePathsFast(p => { + if (p.node.nodeType === 'Identifier' && p.isStruct(p.node)){ + addStructDefinition(p); + } + }, state); state.skipSubNodes = true; parent._newASTPointer.push(newNode); + incrementNames(newNode, lhsIndicator); return; } default: @@ -897,12 +964,11 @@ let childOfSecret = path.getAncestorOfType('ForStatement')?.containsSecret; const fnDefNode = path.getAncestorOfType('FunctionDefinition'); // We ensure the original variable name is set to the initial value only at the end of the statements. //E.g index = index_init should only appear at the end of all the modifying statements. - let ind = fnDefNode.node._newASTPointer.body.statements.length - 2; - while (ind >= 0 && fnDefNode.node._newASTPointer.body.statements[ind].expression?.rightHandSide?.name && fnDefNode.node._newASTPointer.body.statements[ind].expression?.rightHandSide?.name.includes("_init")){ - let temp = fnDefNode.node._newASTPointer.body.statements[ind+1]; - fnDefNode.node._newASTPointer.body.statements[ind+1] = fnDefNode.node._newASTPointer.body.statements[ind]; - fnDefNode.node._newASTPointer.body.statements[ind] = temp; - ind--; + for (let i = fnDefNode.node._newASTPointer.body.statements.length - 1; i >= 0; i--) { + if (fnDefNode.node._newASTPointer.body.statements[i].isEndInit) { + let element = fnDefNode.node._newASTPointer.body.statements.splice(i, 1)[0]; + fnDefNode.node._newASTPointer.body.statements.push(element); + } } } }, @@ -1009,7 +1075,6 @@ let childOfSecret = path.getAncestorOfType('ForStatement')?.containsSecret; declarationType, }); - if (path.isStruct(node)) { state.structNode = addStructDefinition(path); newNode.typeName.name = state.structNode.name; @@ -1100,7 +1165,7 @@ let childOfSecret = path.getAncestorOfType('ForStatement')?.containsSecret; { name, type: node.typeDescriptions?.typeString }, ); if (path.isStruct(node)) addStructDefinition(path); - publicVariablesVisitor(path, state,newNode); + publicVariables(path, state,newNode); if (path.getAncestorOfType('IfStatement')) node._newASTPointer = newNode; // no pointer needed, because this is a leaf, so we won't be recursing any further. // UNLESS we must add and rename if conditionals @@ -1276,7 +1341,6 @@ let childOfSecret = path.getAncestorOfType('ForStatement')?.containsSecret; MemberAccess: { enter(path: NodePath, state: any) { const { parent, node } = path; - let newNode: any; if (path.isMsgSender()) { @@ -1311,7 +1375,7 @@ let childOfSecret = path.getAncestorOfType('ForStatement')?.containsSecret; const newNode = buildNode('IndexAccess'); if (path.isConstantArray(node) && (path.isLocalStackVariable(node) || path.isFunctionParameter(node))) newNode.isConstantArray = true; // We don't need this because index access expressions always contain identifiers. - //publicVariablesVisitor(path, state,newNode); + //publicVariables(path, state,newNode); node._newASTPointer = newNode; parent._newASTPointer[path.containerName] = newNode; }, diff --git a/src/transformers/visitors/toOrchestrationVisitor.ts b/src/transformers/visitors/toOrchestrationVisitor.ts index 21a6afda1..4c950ea59 100644 --- a/src/transformers/visitors/toOrchestrationVisitor.ts +++ b/src/transformers/visitors/toOrchestrationVisitor.ts @@ -35,10 +35,11 @@ const collectIncrements = (stateVarIndicator: StateVariableIndicator | MappingKe return structIncs; } for (const inc of stateVarIndicator.increments || []) { - if (inc.nodeType === 'IndexAccess' || inc.nodeType === 'MemberAccess') inc.name = getIndexAccessName(inc); if (!inc.name) inc.name = inc.value; - if (incrementsArray.some(existingInc => inc.name === existingInc.name)) + // Note: modName defined in circuit + let modName = inc.modName ? inc.modName : inc.name; + if (incrementsArray.some(existingInc => inc.name === existingInc.name )) continue; incrementsArray.push({ name: inc.name, @@ -48,17 +49,18 @@ const collectIncrements = (stateVarIndicator: StateVariableIndicator | MappingKe if (inc === stateVarIndicator.increments?.[0]) { incrementsString += inc.value - ? `parseInt(${inc.name}, 10)` - : `parseInt(${inc.name}.integer, 10)`; + ? `parseInt(${modName}, 10)` + : `parseInt(${modName}.integer, 10)`; } else { incrementsString += inc.value - ? ` ${inc.precedingOperator} parseInt(${inc.name}, 10)` - : ` ${inc.precedingOperator} parseInt(${inc.name}.integer, 10)`; + ? ` ${inc.precedingOperator} parseInt(${modName}, 10)` + : ` ${inc.precedingOperator} parseInt(${modName}.integer, 10)`; } } for (const dec of stateVarIndicator.decrements || []) { if (dec.nodeType === 'IndexAccess' || dec.nodeType === 'MemberAccess') dec.name = getIndexAccessName(dec); if (!dec.name) dec.name = dec.value; + let modName = dec.modName ? dec.modName : dec.name; if (incrementsArray.some(existingInc => dec.name === existingInc.name)) continue; incrementsArray.push({ @@ -68,14 +70,14 @@ const collectIncrements = (stateVarIndicator: StateVariableIndicator | MappingKe if (!stateVarIndicator.decrements?.[1] && !stateVarIndicator.increments?.[0]) { incrementsString += dec.value - ? `parseInt(${dec.name}, 10)` - : `parseInt(${dec.name}.integer, 10)`; + ? `parseInt(${modName}, 10)` + : `parseInt(${modName}.integer, 10)`; } else { // if we have decrements, this str represents the value we must take away // => it's a positive value with +'s incrementsString += dec.value - ? ` + parseInt(${dec.name}, 10)` - : ` + parseInt(${dec.name}.integer, 10)`; + ? ` + parseInt(${modName}, 10)` + : ` + parseInt(${modName}.integer, 10)`; } } return { incrementsArray, incrementsString }; @@ -676,10 +678,10 @@ const visitor = { state.wholeNullified ??= []; if (!state.wholeNullified.includes(name)) state.wholeNullified.push(name) } - - let { incrementsArray, incrementsString } = isIncremented + let increments = isIncremented ? collectIncrements(stateVarIndicator) : { incrementsArray: null, incrementsString: null }; + let {incrementsArray, incrementsString} = increments; if (!incrementsString) incrementsString = null; if (!incrementsArray) incrementsArray = null; @@ -1475,8 +1477,8 @@ const visitor = { // reset delete state.interactsWithSecret; if (node._newASTPointer?.incrementsSecretState && indicator) { - const increments = collectIncrements(indicator).incrementsString; - path.node._newASTPointer.increments = increments; + let increments = collectIncrements(indicator); + path.node._newASTPointer.increments = increments.incrementsString; } else if (indicator?.isWhole && node._newASTPointer) { // we add a general number statement after each whole state edit const tempNode = node._newASTPointer; diff --git a/src/traverse/Scope.ts b/src/traverse/Scope.ts index 6a8e34ea9..52b8e176b 100644 --- a/src/traverse/Scope.ts +++ b/src/traverse/Scope.ts @@ -756,10 +756,10 @@ export class Scope { if (keyBinding?.isModified) { let i = 0; // Consider each time the variable (which becomes the mapping's key) is edited throughout the scope: - for (const modifyingPath of keyBinding.modifyingPaths) { + let filteredModifyingPaths = keyBinding.modifyingPaths.filter(modifyingPath => modifyingPath.scope.scopeName === indexAccessPath.scope.scopeName); + for (const modifyingPath of filteredModifyingPaths) { // we have found the 'current' state (relative to the input node), so we don't need to move any further if (indexAccessNode.id < modifyingPath.node.id && i === 0) break; - i++; if ( diff --git a/src/types/solidity-types.ts b/src/types/solidity-types.ts index 6d53ee690..e7899da6c 100644 --- a/src/types/solidity-types.ts +++ b/src/types/solidity-types.ts @@ -71,6 +71,7 @@ export function getVisitableKeys(nodeType: string): string[] { case 'ModifierDefinition': case 'Break': case 'Continue': + case 'MsgSender': return []; // And again, if we haven't recognized the nodeType then we'll throw an diff --git a/test/contracts/SimpleStruct6.zol b/test/contracts/SimpleStruct6.zol new file mode 100644 index 000000000..9510f84bb --- /dev/null +++ b/test/contracts/SimpleStruct6.zol @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: CC0 + +pragma solidity ^0.8.0; + +contract Assign { + + + struct myStruct { + uint256 prop1; + uint256 prop2; + } + + secret myStruct public w; + secret uint256 public a; + + + uint256 public index; + + + function add(secret uint256 value) public { + unknown a += index; + index++; + unknown w.prop1 += value; + unknown w.prop2 += index + value; + index++; + unknown a += value; + } + + function remove(secret uint256 value) public { + index++; + unknown w.prop1 -= value; + unknown w.prop2 -= value; + unknown a -= index; + index++; + unknown w.prop2 -= index; + unknown a -= value; + } + + + +}