diff --git a/src/codeGenerators/circuit/zokrates/toCircuit.ts b/src/codeGenerators/circuit/zokrates/toCircuit.ts index d040e3efa..d37c971ac 100644 --- a/src/codeGenerators/circuit/zokrates/toCircuit.ts +++ b/src/codeGenerators/circuit/zokrates/toCircuit.ts @@ -264,12 +264,31 @@ function codeGenerator(node: any) { } }); for (let i =0; i node.outsideIf).concat(node.falseBody.filter((node: any) => node.outsideIf)); + let newPreIfStatements = []; + preIfStatements.forEach((node: any) => { + newPreIfStatements.push(cloneDeep(node)); + newPreIfStatements[newPreIfStatements.length - 1].outsideIf = false; + }); + let preIfStatementsString = newPreIfStatements.flatMap(codeGenerator).join('\n'); + if(node.falseBody.length) - return `if (${codeGenerator(node.condition)}) { + return `${comment} + ${preIfStatementsString} + if (${codeGenerator(node.condition)}) { ${node.trueBody.flatMap(codeGenerator).join('\n')} } else { ${node.falseBody.flatMap(codeGenerator).join('\n')} }` else - return `if (${codeGenerator(node.condition)}) { + return `${comment} + ${preIfStatementsString} + if (${codeGenerator(node.condition)}) { ${node.trueBody.flatMap(codeGenerator).join('\n')} }` } diff --git a/src/transformers/visitors/toCircuitVisitor.ts b/src/transformers/visitors/toCircuitVisitor.ts index 7a702ba68..34a95eda7 100644 --- a/src/transformers/visitors/toCircuitVisitor.ts +++ b/src/transformers/visitors/toCircuitVisitor.ts @@ -57,6 +57,31 @@ const incrementNames = (node: any, indicator: any) => { }; +//Finds a statement with the correct ID +const findStatementId = (statements: any, ID: number) => { + let expNode = statements.find((n:any) => n?.id === ID); + let index_expNode = statements.indexOf(expNode); + let location = {index: index_expNode, trueIndex: -1, falseIndex: -1}; + statements.forEach((st:any) => { + if (st.trueBody){ + if (!expNode) { + expNode = st.trueBody.find((n:any) => n?.id === ID); + location.index = statements.indexOf(st); + location.trueIndex = st.trueBody.indexOf(expNode); + } + } + if (st.falseBody){ + if (!expNode) { + expNode = st.falseBody.find((n:any) => n?.id === ID); + location.index = statements.indexOf(st); + location.falseIndex = st.falseBody.indexOf(expNode); + } + } + }); + return {expNode, location}; +}; + + // public variables that interact with the secret also need to be modified within the circuit. const publicVariables = (path: NodePath, state: any, IDnode: any) => { const {parent, node } = path; @@ -91,8 +116,7 @@ const publicVariables = (path: NodePath, state: any, IDnode: any) => { if (path.containerName !== 'indexExpression') { num_modifiers++; } - let expNode = statements.find((n:any) => n?.id === expressionId); - let index_expNode = fnDefNode.node._newASTPointer.body.statements.indexOf(expNode); + let {expNode, location} = findStatementId(statements, expressionId); if (expNode && !expNode.isAccessed) { expNode.isAccessed = true; if((expNode.expression && expNode.expression.leftHandSide && expNode.expression.leftHandSide?.name === node.name) || @@ -109,8 +133,11 @@ const publicVariables = (path: NodePath, state: any, IDnode: any) => { interactsWithSecret: true, isVarDec: true, }); - if (index_expNode !== -1) { - fnDefNode.node._newASTPointer.body.statements.splice(index_expNode + 1, 0, newNode1); + newNode1.outsideIf = true; + if (location.index!== -1) { + if (location.trueIndex !== -1){ fnDefNode.node._newASTPointer.body.statements[location.index].trueBody.splice(location.trueIndex + 1, 0, newNode1); } + else if (location.falseIndex !== -1){ fnDefNode.node._newASTPointer.body.statements[location.index].falseBody.splice(location.falseIndex + 1, 0, newNode1); } + else {fnDefNode.node._newASTPointer.body.statements.splice(location.index + 1, 0, newNode1);} } } } else{ @@ -124,8 +151,11 @@ const publicVariables = (path: NodePath, state: any, IDnode: any) => { expression: InnerNode, interactsWithSecret: true, }); - if (index_expNode !== -1) { - fnDefNode.node._newASTPointer.body.statements.splice(index_expNode + 1, 0, newNode1); + newNode1.outsideIf = true; + if (location.index!== -1) { + if (location.trueIndex !== -1){ fnDefNode.node._newASTPointer.body.statements[location.index].trueBody.splice(location.trueIndex + 1, 0, newNode1); } + else if (location.falseIndex !== -1){ fnDefNode.node._newASTPointer.body.statements[location.index].falseBody.splice(location.falseIndex + 1, 0, newNode1); } + else {fnDefNode.node._newASTPointer.body.statements.splice(location.index + 1, 0, newNode1);} } if (`${modName}` !== `${node.name}_${num_modifiers}` && num_modifiers !==0){ const initInnerNode1 = buildNode('Assignment', { @@ -138,8 +168,11 @@ const publicVariables = (path: NodePath, state: any, IDnode: any) => { interactsWithSecret: true, isVarDec: true, }); - if (index_expNode !== -1) { - fnDefNode.node._newASTPointer.body.statements.splice(index_expNode + 2, 0, newNode2); + newNode2.outsideIf = true; + if (location.index!== -1) { + if (location.trueIndex !== -1){ fnDefNode.node._newASTPointer.body.statements[location.index].trueBody.splice(location.trueIndex + 2, 0, newNode2); } + else if (location.falseIndex !== -1){ fnDefNode.node._newASTPointer.body.statements[location.index].falseBody.splice(location.falseIndex + 2, 0, newNode2); } + else {fnDefNode.node._newASTPointer.body.statements.splice(location.index + 2, 0, newNode2);} } } } @@ -1193,9 +1226,34 @@ const visitor = { IfStatement: { enter(path: NodePath, state: any) { - const { node, parent } = path; - let isIfStatementSecret; - if(node.falseBody?.containsSecret || node.trueBody?.containsSecret || !node.condition?.containsPublic) + const { node, parent, scope } = path; + let isIfStatementSecret: boolean; + let interactsWithSecret = false; + function bodyInteractsWithSecrets(statements) { + statements.forEach((st) => { + if (st.nodeType === 'ExpressionStatement') { + if (st.expression.nodeType === 'UnaryOperation') { + const { operator, subExpression } = st.expression; + if ((operator === '++' || operator === '--') && subExpression.nodeType === 'Identifier') { + const referencedIndicator = scope.getReferencedIndicator(subExpression); + if (referencedIndicator?.interactsWithSecret) { + interactsWithSecret = true; + } + } + } else { + const referencedIndicator = scope.getReferencedIndicator(st.expression.leftHandSide); + if (referencedIndicator?.interactsWithSecret) { + interactsWithSecret = true; + } + } + } + }); + } + if (node.trueBody?.statements) bodyInteractsWithSecrets(node.trueBody?.statements); + if (node.falseBody?.statements) bodyInteractsWithSecrets(node.falseBody?.statements); + + + if(node.falseBody?.containsSecret || node.trueBody?.containsSecret || interactsWithSecret || node.condition?.containsSecret) isIfStatementSecret = true; if(isIfStatementSecret) { if(node.trueBody.statements[0].expression.nodeType === 'FunctionCall') diff --git a/src/transformers/visitors/toOrchestrationVisitor.ts b/src/transformers/visitors/toOrchestrationVisitor.ts index aebf69268..afbb0b2d0 100644 --- a/src/transformers/visitors/toOrchestrationVisitor.ts +++ b/src/transformers/visitors/toOrchestrationVisitor.ts @@ -83,6 +83,33 @@ const collectIncrements = (stateVarIndicator: StateVariableIndicator | MappingKe return { incrementsArray, incrementsString }; }; +//Finds a statement with the correct ID +const findStatementId = (statements: any, ID: number) => { + let expNode = statements.find((n:any) => n?.id === ID); + let index_expNode = statements.indexOf(expNode); + let location = {index: index_expNode, trueIndex: -1, falseIndex: -1, ifNode: null}; + statements.forEach((st:any) => { + if (st.trueBody){ + if (!expNode) { + expNode = st.trueBody.find((n:any) => n?.id === ID); + location.index = statements.indexOf(st); + location.trueIndex = st.trueBody.indexOf(expNode); + location.ifNode = st; + } + } + if (st.falseBody){ + if (!expNode) { + expNode = st.falseBody.find((n:any) => n?.id === ID); + location.index = statements.indexOf(st); + location.falseIndex = st.falseBody.indexOf(expNode); + location.ifNode = st; + } + } + }); + return {expNode, location}; +}; + + // gathers public inputs we need to extract from the contract // i.e. public 'accessed' variables const addPublicInput = (path: NodePath, state: any, IDnode: any) => { @@ -163,7 +190,7 @@ const addPublicInput = (path: NodePath, state: any, IDnode: any) => { modifiedBeforePaths?.forEach((p: NodePath) => { const expressionId = p.getAncestorOfType('ExpressionStatement')?.node?.id; if (expressionId) { - let expNode = statements.find((n:any) => n?.id === expressionId); + let {expNode, location} = findStatementId(statements, expressionId); if (path.containerName !== 'indexExpression') { num_modifiers++; } @@ -172,8 +199,39 @@ const addPublicInput = (path: NodePath, state: any, IDnode: any) => { // we have to go back and mark any editing statements as interactsWithSecret so they show up expNode.interactsWithSecret = true; const moveExpNode = cloneDeep(expNode); - fnDefNode.node._newASTPointer.body.preStatements.push(moveExpNode); - delete statements[statements.indexOf(expNode)]; + // We now move the statement in expNode to preStatements. + //If the statement is within an if statement we need to find the correct if statement in preStatements or create a new one. + let ifPreIndex = null; + if (location.ifNode) { + let {location: locIf } = findStatementId(fnDefNode.node._newASTPointer.body.preStatements, location.ifNode.id); + ifPreIndex = locIf.index; + if (locIf.index !== -1 && location.trueIndex !== -1) fnDefNode.node._newASTPointer.body.preStatements[locIf.index].trueBody.push(moveExpNode); + else if (locIf.index !== -1 && location.falseIndex !== -1) fnDefNode.node._newASTPointer.body.preStatements[locIf.index].falseBody.push(moveExpNode); + else if (!locIf.index || locIf.index === -1 ){ + let newIfNode = cloneDeep(location.ifNode); + newIfNode.inPreStatements = true; + newIfNode.trueBody = []; + newIfNode.falseBody = []; + if (location.trueIndex !== -1) newIfNode.trueBody.push(moveExpNode); + if (location.falseIndex !== -1) newIfNode.falseBody.push(moveExpNode); + fnDefNode.node._newASTPointer.body.preStatements.push(newIfNode); + ifPreIndex = fnDefNode.node._newASTPointer.body.preStatements.length -1; + } + } else{ + fnDefNode.node._newASTPointer.body.preStatements.push(moveExpNode); + } + // We now remove the statement from the statements array. + if (location.index!== -1) { + if (location.trueIndex !== -1){ delete statements[location.index].trueBody[location.trueIndex]; } + else if (location.falseIndex !== -1){ delete statements[location.index].falseBody[location.falseIndex]; } + else { + delete statements[location.index]; + } + } + if ((statements[location.index]?.trueBody && statements[location.index].trueBody.every(element => element === null || element === undefined)) && (statements[location.index]?.falseBody && statements[location.index].falseBody.every(element => element === null || element === undefined))) { + delete statements[location.index]; + } + if( (expNode.expression && expNode.expression.leftHandSide && expNode.expression.leftHandSide?.name === node.name) || (expNode.initialValue && expNode.initialValue.leftHandSide && expNode.initialValue.leftHandSide?.name === node.name) @@ -196,7 +254,11 @@ const addPublicInput = (path: NodePath, state: any, IDnode: any) => { interactsWithSecret: true, isModifiedDeclaration: true, }); - fnDefNode.node._newASTPointer.body.preStatements.push(newNode1); + if (location.ifNode) newNode1.outsideIf = true; + if (location.trueIndex !== -1){ fnDefNode.node._newASTPointer.body.preStatements[ifPreIndex].trueBody.push(newNode1); } + else if (location.falseIndex !== -1){ fnDefNode.node._newASTPointer.body.preStatements[ifPreIndex].falseBody.push(newNode1); } + else {fnDefNode.node._newASTPointer.body.preStatements.push(newNode1);} + } } else{ let name_new = expNode.expression?.initialValue?.leftHandSide?.name || expNode.initialValue?.leftHandSide.name || expNode.expression?.leftHandSide.name; @@ -209,6 +271,9 @@ const addPublicInput = (path: NodePath, state: any, IDnode: any) => { expression: InnerNode, interactsWithSecret: true, }); + if (location.trueIndex !== -1){ fnDefNode.node._newASTPointer.body.preStatements[ifPreIndex].trueBody.push(newNode1); } + else if (location.falseIndex !== -1){ fnDefNode.node._newASTPointer.body.preStatements[ifPreIndex].falseBody.push(newNode1); } + else {fnDefNode.node._newASTPointer.body.preStatements.push(newNode1);} fnDefNode.node._newASTPointer.body.preStatements.push(newNode1); if (`${name_new}` !== `${node.name}_${num_modifiers}` && num_modifiers !==0){ const decInnerNode1 = buildNode('VariableDeclaration', { @@ -228,7 +293,10 @@ const addPublicInput = (path: NodePath, state: any, IDnode: any) => { interactsWithSecret: true, isModifiedDeclaration: true, }); - fnDefNode.node._newASTPointer.body.preStatements.push(newNode2); + if (location.ifNode) newNode2.outsideIf = true; + if (location.trueIndex !== -1){ fnDefNode.node._newASTPointer.body.preStatements[ifPreIndex].trueBody.push(newNode2); } + else if (location.falseIndex !== -1){ fnDefNode.node._newASTPointer.body.preStatements[ifPreIndex].falseBody.push(newNode2); } + else {fnDefNode.node._newASTPointer.body.preStatements.push(newNode2);} } } } @@ -505,6 +573,8 @@ const visitor = { node._newASTPointer.msgSenderParam ??= state.msgSenderParam; node._newASTPointer.msgValueParam ??= state.msgValueParam; + + if(node.containsPublic && !scope.modifiesSecretState()){ interface PublicParam { name: string; @@ -848,6 +918,38 @@ const visitor = { const newFunctionDefinitionNode = node._newASTPointer; + // In If Statements we might have non-secret statements editing variables that later interact with a secret variable. + // We therefore have statements of the form b_6 = b so that b_6 can be used later. + // We need to add the final such statement from the false body to after the if statement so that e.g. b_6 can be used even if the true body is executed. + let nodesToAdd = []; + newFunctionDefinitionNode.body.preStatements.forEach((n: any, index: number) => { + if (n.nodeType === 'IfStatement'){ + let finalName; + let originalName; + n.falseBody.forEach((falseNode: any) => { + if (falseNode.outsideIf){ + finalName = falseNode.initialValue.leftHandSide.name; + originalName = falseNode.initialValue.rightHandSide.name; + } + }); + if (finalName && originalName){ + const InnerNode = buildNode('Assignment', { + leftHandSide: buildNode('Identifier', { name: finalName, subType: 'generalNumber' }), + operator: '=', + rightHandSide: buildNode('Identifier', { name: originalName, subType: 'generalNumber' }) + }); + const finalIfNode = buildNode('ExpressionStatement', { + expression: InnerNode, + interactsWithSecret: true, + }); + nodesToAdd.push({node: finalIfNode, index: index+1}); + } + } + }); + for (let i = nodesToAdd.length - 1; i >= 0; i--) { + newFunctionDefinitionNode.body.preStatements.splice(nodesToAdd[i].index, 0, nodesToAdd[i].node); + } + // this adds other values we need in the circuit for (const param of node._newASTPointer.parameters.parameters) { if (param.isPrivate || param.isSecret || param.interactsWithSecret) { @@ -1746,13 +1848,40 @@ const visitor = { IfStatement: { enter(path: NodePath , state: any) { - const { node, parent, } = path; - if(!node.containsSecret) { + const { node, parent, scope } = path; + let isIfStatementSecret; + let interactsWithSecret = false; + function bodyInteractsWithSecrets(statements) { + statements.forEach((st) => { + if (st.nodeType === 'ExpressionStatement') { + if (st.expression.nodeType === 'UnaryOperation') { + const { operator, subExpression } = st.expression; + if ((operator === '++' || operator === '--') && subExpression.nodeType === 'Identifier') { + const referencedIndicator = scope.getReferencedIndicator(subExpression); + if (referencedIndicator?.interactsWithSecret) { + interactsWithSecret = true; + } + } + } else { + const referencedIndicator = scope.getReferencedIndicator(st.expression.leftHandSide); + if (referencedIndicator?.interactsWithSecret) { + interactsWithSecret = true; + } + } + } + }); + } + if (node.trueBody?.statements) bodyInteractsWithSecrets(node.trueBody?.statements); + if (node.falseBody?.statements) bodyInteractsWithSecrets(node.falseBody?.statements); + if(node.falseBody?.containsSecret || node.trueBody?.containsSecret || interactsWithSecret || node.condition?.containsSecret) + isIfStatementSecret = true; + if(!isIfStatementSecret) { state.skipSubNodes = true; return; } const newNode = buildNode(node.nodeType); newNode.interactsWithSecret = true; + newNode.id = node.id; node._newASTPointer = newNode; parent._newASTPointer.push(newNode); }, diff --git a/src/types/orchestration-types.ts b/src/types/orchestration-types.ts index 88b4d72a4..35c44ae88 100644 --- a/src/types/orchestration-types.ts +++ b/src/types/orchestration-types.ts @@ -183,8 +183,9 @@ export default function buildNode(nodeType: string, fields: any = {}): any { } } case 'IfStatement': { - const { condition = {} , trueBody= [] , falseBody= [] } = fields; + const { condition = {} , trueBody= [] , falseBody= [], oldASTId, } = fields; return { + id: oldASTId, nodeType, condition, trueBody, diff --git a/test/contracts/If-Statement6.zol b/test/contracts/If-Statement6.zol new file mode 100644 index 000000000..22e1051fc --- /dev/null +++ b/test/contracts/If-Statement6.zol @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: CC0 + +pragma solidity ^0.8.0; + +contract MyContract { + + secret uint256 private a; + uint256 public b; + uint256 public c; + secret address private admin; + address private pubAdmin; + + + constructor() { + admin = msg.sender; + pubAdmin = msg.sender; + } + + function add(uint256 value) public { + if(b < 5) { + b += 1; + b++; + } else{ + b = b+2; + b++; + } + a += b + value; + } + + + function add1(uint256 value) public { + if(b < 5) { + b += 1; + b++; + } + a += value; + } + + + function add2(uint256 value) public { + if(b < 5) { + b += 1; + c += 1; + } + a += b + value; + } + + + + +} diff --git a/test/contracts/If-Statement7.zol b/test/contracts/If-Statement7.zol new file mode 100644 index 000000000..eaecc4259 --- /dev/null +++ b/test/contracts/If-Statement7.zol @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: CC0 + +pragma solidity ^0.8.0; + +contract MyContract { + + secret uint256 private a; + uint256 public b; + uint256 public c; + secret address private admin; + address private pubAdmin; + + + constructor() { + admin = msg.sender; + pubAdmin = msg.sender; + } + + + + function add3(uint256 value) public { + if(b < 5) { + b += 1; + b++; + } + a += b + value; + } + + function add4(uint256 value) public { + if(a < b) { + revert("a less than b"); + } + a+= value; + b += 20; + } + + + + +} diff --git a/test/error-checks/_Variable.zol b/test/error-checks/_Variable.zol new file mode 100644 index 000000000..0c610a011 --- /dev/null +++ b/test/error-checks/_Variable.zol @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: CC0 + +pragma solidity ^0.8.0; + +contract Assign { + + secret uint256 private a; + function add(secret uint256 _value) public { + unknown a += _value; + } + + function remove(secret uint256 value) public { + a -= value; + } +}