Skip to content

Commit

Permalink
Merge pull request #287 from EYBlockchain/lydia/unknownStructs
Browse files Browse the repository at this point in the history
Lydia/unknown structs
  • Loading branch information
lydiagarms authored Jul 2, 2024
2 parents 1b9e233 + f4cbabb commit 02f10e5
Show file tree
Hide file tree
Showing 8 changed files with 150 additions and 40 deletions.
11 changes: 6 additions & 5 deletions src/boilerplate/circuit/zokrates/nodes/BoilerplateGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ const collectIncrements = (bpg: BoilerplateGenerator) => {
}
if (inc.nodeType === 'MemberAccess') inc.name ??= `${inc.expression.name}.${inc.memberName}`;
if (!inc.name) inc.name = inc.value;

let modName = inc.modName ? inc.modName : inc.name;
if (incrementsArray.some(existingInc => inc.name === existingInc.name))
continue;
incrementsArray.push({
Expand All @@ -54,9 +54,9 @@ const collectIncrements = (bpg: BoilerplateGenerator) => {
});

if (inc === stateVarIndicator.increments[0]) {
incrementsString += `${inc.name}`;
incrementsString += `${modName}`;
} else {
incrementsString += ` ${inc.precedingOperator} ${inc.name}`;
incrementsString += ` ${inc.precedingOperator} ${modName}`;
}
}
for (const dec of stateVarIndicator.decrements) {
Expand All @@ -72,17 +72,18 @@ const collectIncrements = (bpg: BoilerplateGenerator) => {
if (!dec.name) dec.name = dec.value;
if (incrementsArray.some(existingInc => dec.name === existingInc.name))
continue;
let modName = dec.modName ? dec.modName : dec.name;
incrementsArray.push({
name: dec.name,
precedingOperator: dec.precedingOperator,
});

if (!stateVarIndicator.decrements[1] && !stateVarIndicator.increments[0]) {
incrementsString += `${dec.name}`;
incrementsString += `${modName}`;
} else {
// if we have decrements, this str represents the value we must take away
// => it's a positive value with +'s
incrementsString += ` + ${dec.name}`;
incrementsString += ` + ${modName}`;
}
}
return { incrementsArray, incrementsString };
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ class BoilerplateGenerator {
`${structProperties.map(p => newCommitmentValue[p] === '0' ? '' : `assert(${x0}.${p} + ${x1}.${p} >= ${y[p]})`).join('\n')}
// TODO: assert no under/overflows
${typeName} ${x}_newCommitment_value = ${typeName} { ${structProperties.map(p => ` ${p}: (${x0}.${p} + ${x1}.${p}) - ${y[p]}`)} }`
${typeName} ${x}_newCommitment_value = ${typeName} { ${structProperties.map(p => ` ${p}: (${x0}.${p} + ${x1}.${p}) - (${y[p]})`)} }`
);
}
} else {
Expand Down
5 changes: 3 additions & 2 deletions src/transformers/visitors/checks/incrementedVisitor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,12 @@ const literalOneNode = {
const collectIncrements = (increments: any, incrementedIdentifier: any) => {
const { operands, precedingOperator } = increments;
const newIncrements: any[] = [];
const Idname = incrementedIdentifier.name || incrementedIdentifier.expression?.name;
for (const [index, operand] of operands.entries()) {
operand.precedingOperator = precedingOperator[index];
if (
operand.name !== incrementedIdentifier.name &&
operand.baseExpression?.name !== incrementedIdentifier.name &&
operand.name !== Idname &&
operand.baseExpression?.name !== Idname &&
!newIncrements.some(inc => inc.id === operand.id)
)
newIncrements.push(operand);
Expand Down
96 changes: 80 additions & 16 deletions src/transformers/visitors/toCircuitVisitor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import cloneDeep from 'lodash.clonedeep';
import { buildNode } from '../../types/zokrates-types.js';
import { TODOError } from '../../error/errors.js';
import { traversePathsFast } from '../../traverse/traverse.js';
import { traversePathsFast, traverseNodesFast } from '../../traverse/traverse.js';
import NodePath from '../../traverse/NodePath.js';
import explode from './explode.js';
import internalCallVisitor from './circuitInternalFunctionCallVisitor.js';
Expand All @@ -14,8 +14,51 @@ import { interactsWithSecretVisitor, internalFunctionCallVisitor, parentnewASTPo




// Adjusts names of indicator.increment so that they match the names of the corresponding indicators i.e. index_1 instead of index
const incrementNames = (node: any, indicator: any) => {
if (node.bpType === 'incrementation'){
let rhsNode = node.addend;
const adjustIncrementsVisitor = (thisNode: any) => {
if (thisNode.nodeType === 'Identifier'){
if (!indicator.increments.some((inc: any) => inc.name === thisNode.name)){
let lastUnderscoreIndex = thisNode.name.lastIndexOf("_");
let origName = thisNode.name.substring(0, lastUnderscoreIndex);
let count =0;
indicator.increments.forEach((inc: any) => {
if (origName === inc.name && !inc.modName && count === 0){
inc.modName = thisNode.name;
count++;
}
});
}
}
}
if (rhsNode) traverseNodesFast(rhsNode, adjustIncrementsVisitor);
} else if (node.bpType === 'decrementation'){
let rhsNode = node.subtrahend;
const adjustDecrementsVisitor = (thisNode: any) => {
if (thisNode.nodeType === 'Identifier'){
if (!indicator.decrements.some((dec: any) => dec.name === thisNode.name)){
let lastUnderscoreIndex = thisNode.name.lastIndexOf("_");
let origName = thisNode.name.substring(0, lastUnderscoreIndex);
let count =0;
indicator.decrements.forEach((dec: any) => {
if (origName === dec.name && !dec.modName && count === 0){
dec.modName = thisNode.name;
count++;
}
});
}
}
}
if (rhsNode) traverseNodesFast(rhsNode, adjustDecrementsVisitor);
}
};


// public variables that interact with the secret also need to be modified within the circuit.
const publicVariablesVisitor = (path: NodePath, state: any, IDnode: any) => {
const publicVariables = (path: NodePath, state: any, IDnode: any) => {
const {parent, node } = path;
// Break if the identifier is a mapping or array.
if ( parent.indexExpression && parent.baseExpression === node ) {
Expand All @@ -35,7 +78,6 @@ const publicVariablesVisitor = (path: NodePath, state: any, IDnode: any) => {
if (!fnDefNode) throw new Error(`Not in a function`);

const modifiedBeforePaths = path.scope.getReferencedIndicator(node, true)?.modifyingPaths?.filter((p: NodePath) => p.node.id < node.id);

const statements = fnDefNode.node._newASTPointer.body.statements;

let num_modifiers=0;
Expand Down Expand Up @@ -131,9 +173,27 @@ const publicVariablesVisitor = (path: NodePath, state: any, IDnode: any) => {
fnDefNode.node._newASTPointer.body.statements.push(endNode);
}
// We no longer need this because index expression nodes are not input.
//if (['Identifier', 'IndexAccess'].includes(node.indexExpression?.nodeType)) publicVariablesVisitor(NodePath.getPath(node.indexExpression), state, null);
//if (['Identifier', 'IndexAccess'].includes(node.indexExpression?.nodeType)) publicVariables(NodePath.getPath(node.indexExpression), state, null);
}

//Visitor for publicVariables
const publicVariablesVisitor = (thisPath: NodePath, thisState: any) => {
const { node } = thisPath;
let { name } = node;
if (!['Identifier', 'IndexAccess'].includes(thisPath.nodeType)) return;
const binding = thisPath.getReferencedBinding(node);
if ( (binding instanceof VariableBinding) && !binding.isSecret &&
binding.stateVariable && thisPath.getAncestorContainedWithin('rightHandSide') ){
} else{
name = thisPath.scope.getIdentifierMappingKeyName(node);
}
const newNode = buildNode(
node.nodeType,
{ name, type: node.typeDescriptions?.typeString },
);
publicVariables(thisPath, thisState, newNode);
};


// below stub will only work with a small subtree - passing a whole AST will always give true!
// useful for subtrees like ExpressionStatements
Expand Down Expand Up @@ -310,13 +370,15 @@ const visitor = {
// a non secret function - we skip it for circuits
state.skipSubNodes = true;
}

},

exit(path: NodePath, state: any) {
const { node, parent, scope } = path;
const { indicators } = scope;
const newFunctionDefinitionNode = node._newASTPointer;


// We need to ensure the correctness of the circuitImport flag for each internal function call. The state may have been updated due to later function calls that modify the same secret state.
let importStatementList: any;
parent._newASTPointer.forEach((file: any) => {
Expand Down Expand Up @@ -776,14 +838,12 @@ let childOfSecret = path.getAncestorOfType('ForStatement')?.containsSecret;
case 'Assignment': {
const { leftHandSide: lhs, rightHandSide: rhs } = expression;
const lhsIndicator = scope.getReferencedIndicator(lhs);

if (!lhsIndicator?.isPartitioned) break;

const rhsPath = NodePath.getPath(rhs);
// We need to _clone_ the path, because we want to temporarily modify some of its properties for this traversal. For future AST transformations, we'll want to revert to the original path.
const tempRHSPath = cloneDeep(rhsPath);
const tempRHSParent = tempRHSPath.parent;

if (isDecremented) {
newNode = buildNode('BoilerplateStatement', {
bpType: 'decrementation',
Expand Down Expand Up @@ -820,8 +880,15 @@ let childOfSecret = path.getAncestorOfType('ForStatement')?.containsSecret;

tempRHSPath.traverse(visitor, { skipPublicInputs: true });
rhsPath.traversePathsFast(publicInputsVisitor, {});
rhsPath.traversePathsFast(publicVariablesVisitor, {});
path.traversePathsFast(p => {
if (p.node.nodeType === 'Identifier' && p.isStruct(p.node)){
addStructDefinition(p);
}
}, state);
state.skipSubNodes = true;
parent._newASTPointer.push(newNode);
incrementNames(newNode, lhsIndicator);
return;
}
default:
Expand Down Expand Up @@ -897,12 +964,11 @@ let childOfSecret = path.getAncestorOfType('ForStatement')?.containsSecret;
const fnDefNode = path.getAncestorOfType('FunctionDefinition');
// We ensure the original variable name is set to the initial value only at the end of the statements.
//E.g index = index_init should only appear at the end of all the modifying statements.
let ind = fnDefNode.node._newASTPointer.body.statements.length - 2;
while (ind >= 0 && fnDefNode.node._newASTPointer.body.statements[ind].expression?.rightHandSide?.name && fnDefNode.node._newASTPointer.body.statements[ind].expression?.rightHandSide?.name.includes("_init")){
let temp = fnDefNode.node._newASTPointer.body.statements[ind+1];
fnDefNode.node._newASTPointer.body.statements[ind+1] = fnDefNode.node._newASTPointer.body.statements[ind];
fnDefNode.node._newASTPointer.body.statements[ind] = temp;
ind--;
for (let i = fnDefNode.node._newASTPointer.body.statements.length - 1; i >= 0; i--) {
if (fnDefNode.node._newASTPointer.body.statements[i].isEndInit) {
let element = fnDefNode.node._newASTPointer.body.statements.splice(i, 1)[0];
fnDefNode.node._newASTPointer.body.statements.push(element);
}
}
}
},
Expand Down Expand Up @@ -1009,7 +1075,6 @@ let childOfSecret = path.getAncestorOfType('ForStatement')?.containsSecret;
declarationType,
});


if (path.isStruct(node)) {
state.structNode = addStructDefinition(path);
newNode.typeName.name = state.structNode.name;
Expand Down Expand Up @@ -1100,7 +1165,7 @@ let childOfSecret = path.getAncestorOfType('ForStatement')?.containsSecret;
{ name, type: node.typeDescriptions?.typeString },
);
if (path.isStruct(node)) addStructDefinition(path);
publicVariablesVisitor(path, state,newNode);
publicVariables(path, state,newNode);
if (path.getAncestorOfType('IfStatement')) node._newASTPointer = newNode;
// no pointer needed, because this is a leaf, so we won't be recursing any further.
// UNLESS we must add and rename if conditionals
Expand Down Expand Up @@ -1276,7 +1341,6 @@ let childOfSecret = path.getAncestorOfType('ForStatement')?.containsSecret;
MemberAccess: {
enter(path: NodePath, state: any) {
const { parent, node } = path;

let newNode: any;

if (path.isMsgSender()) {
Expand Down Expand Up @@ -1311,7 +1375,7 @@ let childOfSecret = path.getAncestorOfType('ForStatement')?.containsSecret;
const newNode = buildNode('IndexAccess');
if (path.isConstantArray(node) && (path.isLocalStackVariable(node) || path.isFunctionParameter(node))) newNode.isConstantArray = true;
// We don't need this because index access expressions always contain identifiers.
//publicVariablesVisitor(path, state,newNode);
//publicVariables(path, state,newNode);
node._newASTPointer = newNode;
parent._newASTPointer[path.containerName] = newNode;
},
Expand Down
30 changes: 16 additions & 14 deletions src/transformers/visitors/toOrchestrationVisitor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,11 @@ const collectIncrements = (stateVarIndicator: StateVariableIndicator | MappingKe
return structIncs;
}
for (const inc of stateVarIndicator.increments || []) {

if (inc.nodeType === 'IndexAccess' || inc.nodeType === 'MemberAccess') inc.name = getIndexAccessName(inc);
if (!inc.name) inc.name = inc.value;
if (incrementsArray.some(existingInc => inc.name === existingInc.name))
// Note: modName defined in circuit
let modName = inc.modName ? inc.modName : inc.name;
if (incrementsArray.some(existingInc => inc.name === existingInc.name ))
continue;
incrementsArray.push({
name: inc.name,
Expand All @@ -48,17 +49,18 @@ const collectIncrements = (stateVarIndicator: StateVariableIndicator | MappingKe

if (inc === stateVarIndicator.increments?.[0]) {
incrementsString += inc.value
? `parseInt(${inc.name}, 10)`
: `parseInt(${inc.name}.integer, 10)`;
? `parseInt(${modName}, 10)`
: `parseInt(${modName}.integer, 10)`;
} else {
incrementsString += inc.value
? ` ${inc.precedingOperator} parseInt(${inc.name}, 10)`
: ` ${inc.precedingOperator} parseInt(${inc.name}.integer, 10)`;
? ` ${inc.precedingOperator} parseInt(${modName}, 10)`
: ` ${inc.precedingOperator} parseInt(${modName}.integer, 10)`;
}
}
for (const dec of stateVarIndicator.decrements || []) {
if (dec.nodeType === 'IndexAccess' || dec.nodeType === 'MemberAccess') dec.name = getIndexAccessName(dec);
if (!dec.name) dec.name = dec.value;
let modName = dec.modName ? dec.modName : dec.name;
if (incrementsArray.some(existingInc => dec.name === existingInc.name))
continue;
incrementsArray.push({
Expand All @@ -68,14 +70,14 @@ const collectIncrements = (stateVarIndicator: StateVariableIndicator | MappingKe

if (!stateVarIndicator.decrements?.[1] && !stateVarIndicator.increments?.[0]) {
incrementsString += dec.value
? `parseInt(${dec.name}, 10)`
: `parseInt(${dec.name}.integer, 10)`;
? `parseInt(${modName}, 10)`
: `parseInt(${modName}.integer, 10)`;
} else {
// if we have decrements, this str represents the value we must take away
// => it's a positive value with +'s
incrementsString += dec.value
? ` + parseInt(${dec.name}, 10)`
: ` + parseInt(${dec.name}.integer, 10)`;
? ` + parseInt(${modName}, 10)`
: ` + parseInt(${modName}.integer, 10)`;
}
}
return { incrementsArray, incrementsString };
Expand Down Expand Up @@ -676,10 +678,10 @@ const visitor = {
state.wholeNullified ??= [];
if (!state.wholeNullified.includes(name)) state.wholeNullified.push(name)
}

let { incrementsArray, incrementsString } = isIncremented
let increments = isIncremented
? collectIncrements(stateVarIndicator)
: { incrementsArray: null, incrementsString: null };
let {incrementsArray, incrementsString} = increments;
if (!incrementsString) incrementsString = null;
if (!incrementsArray) incrementsArray = null;

Expand Down Expand Up @@ -1475,8 +1477,8 @@ const visitor = {
// reset
delete state.interactsWithSecret;
if (node._newASTPointer?.incrementsSecretState && indicator) {
const increments = collectIncrements(indicator).incrementsString;
path.node._newASTPointer.increments = increments;
let increments = collectIncrements(indicator);
path.node._newASTPointer.increments = increments.incrementsString;
} else if (indicator?.isWhole && node._newASTPointer) {
// we add a general number statement after each whole state edit
const tempNode = node._newASTPointer;
Expand Down
4 changes: 2 additions & 2 deletions src/traverse/Scope.ts
Original file line number Diff line number Diff line change
Expand Up @@ -756,10 +756,10 @@ export class Scope {
if (keyBinding?.isModified) {
let i = 0;
// Consider each time the variable (which becomes the mapping's key) is edited throughout the scope:
for (const modifyingPath of keyBinding.modifyingPaths) {
let filteredModifyingPaths = keyBinding.modifyingPaths.filter(modifyingPath => modifyingPath.scope.scopeName === indexAccessPath.scope.scopeName);
for (const modifyingPath of filteredModifyingPaths) {
// we have found the 'current' state (relative to the input node), so we don't need to move any further
if (indexAccessNode.id < modifyingPath.node.id && i === 0) break;

i++;

if (
Expand Down
1 change: 1 addition & 0 deletions src/types/solidity-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ export function getVisitableKeys(nodeType: string): string[] {
case 'ModifierDefinition':
case 'Break':
case 'Continue':
case 'MsgSender':
return [];

// And again, if we haven't recognized the nodeType then we'll throw an
Expand Down
Loading

0 comments on commit 02f10e5

Please sign in to comment.