diff --git a/avm-transpiler/src/transpile.rs b/avm-transpiler/src/transpile.rs index a1ae1b716443..ace92a10337c 100644 --- a/avm-transpiler/src/transpile.rs +++ b/avm-transpiler/src/transpile.rs @@ -863,6 +863,7 @@ fn handle_black_box_function(avm_instrs: &mut Vec, operation: &B let num_points = points.size.0; let scalars_offset = scalars.pointer.0; // Output array is fixed to 3 + assert!(outputs.size == &3u32, "Output array size must be equal to 3"); let outputs_offset = outputs.pointer.0; avm_instrs.push(AvmInstruction { opcode: AvmOpcode::MSM, diff --git a/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_deserialization.cpp b/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_deserialization.cpp index 6b10ab40afc1..891d2af56957 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_deserialization.cpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_deserialization.cpp @@ -158,6 +158,8 @@ const std::unordered_map> OPCODE_WIRE_FORMAT = OperandType::UINT32, // rhs.y OperandType::UINT32, // rhs.is_infinite OperandType::UINT32 } }, // dst_offset + { OpCode::MSM, + { OperandType::INDIRECT, OperandType::UINT32, OperandType::UINT32, OperandType::UINT32, OperandType::UINT32 } }, // Gadget - Conversion { OpCode::TORADIXLE, { OperandType::INDIRECT, OperandType::UINT32, OperandType::UINT32, OperandType::UINT32, OperandType::UINT32 } }, diff --git a/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_opcode.hpp b/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_opcode.hpp index 83de32e65682..e3ced1a03e7a 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_opcode.hpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_opcode.hpp @@ -105,6 +105,7 @@ enum class OpCode : uint8_t { SHA256, PEDERSEN, ECADD, + MSM, // Conversions TORADIXLE, // Future Gadgets -- pending changes in noir diff --git a/noir-projects/noir-contracts/contracts/avm_test_contract/src/main.nr b/noir-projects/noir-contracts/contracts/avm_test_contract/src/main.nr index c3d827652f1d..d870e8564f88 100644 --- a/noir-projects/noir-contracts/contracts/avm_test_contract/src/main.nr +++ b/noir-projects/noir-contracts/contracts/avm_test_contract/src/main.nr @@ -148,7 +148,8 @@ contract AvmTest { fn variable_base_msm() -> [Field; 3] { let g = EmbeddedCurvePoint { x: 1, y: 17631683881184975370165255887551781615748388533673675138860, is_infinite: false }; let scalar = EmbeddedCurveScalar { lo: 3, hi: 0 }; - let triple_g = multi_scalar_mul([g], [scalar]); + let scalar2 = EmbeddedCurveScalar { lo: 20, hi: 0 }; + let triple_g = multi_scalar_mul([g, g], [scalar, scalar2]); triple_g } diff --git a/yarn-project/foundation/src/fields/point.ts b/yarn-project/foundation/src/fields/point.ts index 490c1fe2f48e..bf12faf3a130 100644 --- a/yarn-project/foundation/src/fields/point.ts +++ b/yarn-project/foundation/src/fields/point.ts @@ -139,14 +139,18 @@ export class Point { /** * Check if this is point at infinity. + * Check this is consistent with how bb is encoding the point at infinity */ - isInfPoint() { - // Check this - return this.x.isZero(); + public get inf() { + return this.x == Fr.ZERO; + } + public toFieldsWithInf() { + return [this.x, this.y, new Fr(this.inf)]; } isOnGrumpkin() { - if (this.isInfPoint()) { + // TODO: Check this against how bb handles curve check and infinity point check + if (this.inf) { return true; } diff --git a/yarn-project/simulator/src/avm/avm_simulator.test.ts b/yarn-project/simulator/src/avm/avm_simulator.test.ts index 86d1960577b4..1614305b0bea 100644 --- a/yarn-project/simulator/src/avm/avm_simulator.test.ts +++ b/yarn-project/simulator/src/avm/avm_simulator.test.ts @@ -117,7 +117,9 @@ describe('AVM simulator: transpiled Noir contracts', () => { expect(results.reverted).toBe(false); const grumpkin = new Grumpkin(); const g3 = grumpkin.mul(grumpkin.generator(), new Fq(3)); - expect(results.output).toEqual([g3.x, g3.y, Fr.ZERO]); + const g20 = grumpkin.mul(grumpkin.generator(), new Fq(20)); + const expectedResult = grumpkin.add(g3, g20); + expect(results.output).toEqual([expectedResult.x, expectedResult.y, Fr.ZERO]); }); describe('U128 addition and overflows', () => { diff --git a/yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.test.ts b/yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.test.ts index 83a9b79ca311..a0d5825c6cc1 100644 --- a/yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.test.ts +++ b/yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.test.ts @@ -2,141 +2,129 @@ import { Fq, Fr } from '@aztec/circuits.js'; import { Grumpkin } from '@aztec/circuits.js/barretenberg'; import { type AvmContext } from '../avm_context.js'; -import { Field, Uint8, Uint32 } from '../avm_memory_types.js'; +import { Field, type MemoryValue, Uint8, Uint32 } from '../avm_memory_types.js'; import { initContext } from '../fixtures/index.js'; import { MultiScalarMul } from './multi_scalar_mul.js'; describe('MultiScalarMul Opcode', () => { - let context: AvmContext; - - beforeEach(async () => { - context = initContext(); - }); - it('Should (de)serialize correctly', () => { - const buf = Buffer.from([ - MultiScalarMul.opcode, // opcode - 7, // indirect - ...Buffer.from('12345678', 'hex'), // pointsOffset - ...Buffer.from('23456789', 'hex'), // scalars Offset - ...Buffer.from('3456789a', 'hex'), // outputOffset - ...Buffer.from('456789ab', 'hex'), // pointsLengthOffset - ]); - const inst = new MultiScalarMul( + let context: AvmContext; + + beforeEach(async () => { + context = initContext(); + }); + it('Should (de)serialize correctly', () => { + const buf = Buffer.from([ + MultiScalarMul.opcode, // opcode + 7, // indirect + ...Buffer.from('12345678', 'hex'), // pointsOffset + ...Buffer.from('23456789', 'hex'), // scalars Offset + ...Buffer.from('3456789a', 'hex'), // outputOffset + ...Buffer.from('456789ab', 'hex'), // pointsLengthOffset + ]); + const inst = new MultiScalarMul( /*indirect=*/ 7, /*pointsOffset=*/ 0x12345678, /*scalarsOffset=*/ 0x23456789, /*outputOffset=*/ 0x3456789a, /*pointsLengthOffset=*/ 0x456789ab, - ); - - expect(MultiScalarMul.deserialize(buf)).toEqual(inst); - expect(inst.serialize()).toEqual(buf); - }); - - it('Should perform msm correctly - direct', async () => { - const indirect = 0; - const grumpkin = new Grumpkin(); - // We need to ensure points are actually on curve, so we just use the generator - // In future we could use a random point, for now we create an array of [G, 2G, 3G] - const points = Array.from({ length: 3 }, (_, i) => grumpkin.mul(grumpkin.generator(), new Fq(i + 1))); - - // Pick some big scalars to test the edge cases - const scalars = [new Fq(Fq.MODULUS - 1n), new Fq(Fq.MODULUS - 2n), new Fq(1n)]; - const pointsReadLength = points.length * 3; // multiplied by 3 since we will store them as triplet in avm memory - const scalarsLength = scalars.length * 2; // multiplied by 2 since we will store them as lo and hi limbs in avm memory - // Transform the points and scalars into the format that we will write to memory - // We just store the x and y coordinates here, and handle the infinities when we write to memory - const storedPoints: Field[] = points.flatMap(p => [new Field(p.x), new Field(p.y)]); - const storedScalars: Field[] = scalars.flatMap(s => [new Field(s.low), new Field(s.high)]); - - const pointsOffset = 0; - // Store points...awkwardly (This would be simpler if ts handled the infinities in the Point struct) - // Points are stored as [x1, y1, inf1, x2, y2, inf2, ...] where the types are [Field, Field, Uint8, Field, Field, Uint8, ...] - for (let i = 0; i < points.length; i++) { - const flatPointsOffset = pointsOffset + 2 * i; // 2 since we only have x and y - const memoryOffset = pointsOffset + 3 * i; // 3 since we store x, y, inf - context.machineState.memory.set(memoryOffset, storedPoints[flatPointsOffset]); - context.machineState.memory.set(memoryOffset + 1, storedPoints[flatPointsOffset + 1]); - context.machineState.memory.set(memoryOffset + 2, new Uint8(points[i].isInfPoint() ? 1 : 0)); - } - // Store scalars - const scalarsOffset = pointsOffset + pointsReadLength; - context.machineState.memory.setSlice(scalarsOffset, storedScalars); - // Store length of points to read - const pointsLengthOffset = scalarsOffset + scalarsLength; - context.machineState.memory.set(pointsLengthOffset, new Uint32(pointsReadLength)); - const outputOffset = pointsLengthOffset + 1; - - await new MultiScalarMul(indirect, pointsOffset, scalarsOffset, outputOffset, pointsLengthOffset).execute(context); - - const result = context.machineState.memory.getSlice(outputOffset, 3); - - // We write it out explicitly here - let expectedResult = grumpkin.mul(points[0], scalars[0]); - expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[1], scalars[1])); - expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[2], scalars[2])); - - expect(result.map(r => r.toFr())).toEqual([expectedResult.x, expectedResult.y, new Fr(0n)]); - }); - - it('Should perform msm correctly - indirect', async () => { - const indirect = 7; - const grumpkin = new Grumpkin(); - // We need to ensure points are actually on curve, so we just use the generator - // In future we could use a random point, for now we create an array of [G, 2G, 3G] - const points = Array.from({ length: 3 }, (_, i) => grumpkin.mul(grumpkin.generator(), new Fq(i + 1))); - - // Pick some big scalars to test the edge cases - const scalars = [new Fq(Fq.MODULUS - 1n), new Fq(Fq.MODULUS - 2n), new Fq(1n)]; - const pointsReadLength = points.length * 3; // multiplied by 3 since we will store them as triplet in avm memory - const scalarsLength = scalars.length * 2; // multiplied by 2 since we will store them as lo and hi limbs in avm memory - // Transform the points and scalars into the format that we will write to memory - // We just store the x and y coordinates here, and handle the infinities when we write to memory - const storedPoints: Field[] = points.flatMap(p => [new Field(p.x), new Field(p.y)]); - const storedScalars: Field[] = scalars.flatMap(s => [new Field(s.low), new Field(s.high)]); - - const pointsOffset = 0; - // Store points...awkwardly (This would be simpler if ts handled the infinities in the Point struct) - // Points are stored as [x1, y1, inf1, x2, y2, inf2, ...] where the types are [Field, Field, Uint8, Field, Field, Uint8, ...] - for (let i = 0; i < points.length; i++) { - const flatPointsOffset = pointsOffset + 2 * i; // 2 since we only have x and y - const memoryOffset = pointsOffset + 3 * i; // 3 since we store x, y, inf - context.machineState.memory.set(memoryOffset, storedPoints[flatPointsOffset]); - context.machineState.memory.set(memoryOffset + 1, storedPoints[flatPointsOffset + 1]); - context.machineState.memory.set(memoryOffset + 2, new Uint8(points[i].isInfPoint() ? 1 : 0)); - } - // Store scalars - const scalarsOffset = pointsOffset + pointsReadLength; - context.machineState.memory.setSlice(scalarsOffset, storedScalars); - // Store length of points to read - const pointsLengthOffset = scalarsOffset + scalarsLength; - context.machineState.memory.set(pointsLengthOffset, new Uint32(pointsReadLength)); - const outputOffset = pointsLengthOffset + 1; - - // Set up the indirect pointers - const pointsIndirectOffset = outputOffset + 3; /* 3 since the output is a triplet */ - const scalarsIndirectOffset = pointsIndirectOffset + 1; - const outputIndirectOffset = scalarsIndirectOffset + 1; - - context.machineState.memory.set(pointsIndirectOffset, new Uint32(pointsOffset)); - context.machineState.memory.set(scalarsIndirectOffset, new Uint32(scalarsOffset)); - context.machineState.memory.set(outputIndirectOffset, new Uint32(outputOffset)); - - await new MultiScalarMul( - indirect, - pointsIndirectOffset, - scalarsIndirectOffset, - outputIndirectOffset, - pointsLengthOffset, - ).execute(context); - - const result = context.machineState.memory.getSlice(outputOffset, 3); - - // We write it out explicitly here - let expectedResult = grumpkin.mul(points[0], scalars[0]); - expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[1], scalars[1])); - expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[2], scalars[2])); - - expect(result.map(r => r.toFr())).toEqual([expectedResult.x, expectedResult.y, new Fr(0n)]); - }); + ); + + expect(MultiScalarMul.deserialize(buf)).toEqual(inst); + expect(inst.serialize()).toEqual(buf); + }); + + it('Should perform msm correctly - direct', async () => { + const indirect = 0; + const grumpkin = new Grumpkin(); + // We need to ensure points are actually on curve, so we just use the generator + // In future we could use a random point, for now we create an array of [G, 2G, 3G] + const points = Array.from({ length: 3 }, (_, i) => grumpkin.mul(grumpkin.generator(), new Fq(i + 1))); + + // Pick some big scalars to test the edge cases + const scalars = [new Fq(Fq.MODULUS - 1n), new Fq(Fq.MODULUS - 2n), new Fq(1n)]; + const pointsReadLength = points.length * 3; // multiplied by 3 since we will store them as triplet in avm memory + const scalarsLength = scalars.length * 2; // multiplied by 2 since we will store them as lo and hi limbs in avm memory + // Transform the points and scalars into the format that we will write to memory + // We just store the x and y coordinates here, and handle the infinities when we write to memory + const storedScalars: Field[] = scalars.flatMap(s => [new Field(s.low), new Field(s.high)]); + // Points are stored as [x1, y1, inf1, x2, y2, inf2, ...] where the types are [Field, Field, Uint8, Field, Field, Uint8, ...] + const storedPoints: MemoryValue[] = points + .map(p => p.toFieldsWithInf()) + .flatMap(([x, y, inf]) => [new Field(x), new Field(y), new Uint8(inf.toNumber())]); + const pointsOffset = 0; + context.machineState.memory.setSlice(pointsOffset, storedPoints); + // Store scalars + const scalarsOffset = pointsOffset + pointsReadLength; + context.machineState.memory.setSlice(scalarsOffset, storedScalars); + // Store length of points to read + const pointsLengthOffset = scalarsOffset + scalarsLength; + context.machineState.memory.set(pointsLengthOffset, new Uint32(pointsReadLength)); + const outputOffset = pointsLengthOffset + 1; + + await new MultiScalarMul(indirect, pointsOffset, scalarsOffset, outputOffset, pointsLengthOffset).execute(context); + + const result = context.machineState.memory.getSlice(outputOffset, 3).map(r => r.toFr()); + + // We write it out explicitly here + let expectedResult = grumpkin.mul(points[0], scalars[0]); + expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[1], scalars[1])); + expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[2], scalars[2])); + + expect(result).toEqual([expectedResult.x, expectedResult.y, new Fr(0n)]); + }); + + it('Should perform msm correctly - indirect', async () => { + const indirect = 7; + const grumpkin = new Grumpkin(); + // We need to ensure points are actually on curve, so we just use the generator + // In future we could use a random point, for now we create an array of [G, 2G, 3G] + const points = Array.from({ length: 3 }, (_, i) => grumpkin.mul(grumpkin.generator(), new Fq(i + 1))); + + // Pick some big scalars to test the edge cases + const scalars = [new Fq(Fq.MODULUS - 1n), new Fq(Fq.MODULUS - 2n), new Fq(1n)]; + const pointsReadLength = points.length * 3; // multiplied by 3 since we will store them as triplet in avm memory + const scalarsLength = scalars.length * 2; // multiplied by 2 since we will store them as lo and hi limbs in avm memory + // Transform the points and scalars into the format that we will write to memory + // We just store the x and y coordinates here, and handle the infinities when we write to memory + const storedScalars: Field[] = scalars.flatMap(s => [new Field(s.low), new Field(s.high)]); + // Points are stored as [x1, y1, inf1, x2, y2, inf2, ...] where the types are [Field, Field, Uint8, Field, Field, Uint8, ...] + const storedPoints: MemoryValue[] = points + .map(p => p.toFieldsWithInf()) + .flatMap(([x, y, inf]) => [new Field(x), new Field(y), new Uint8(inf.toNumber())]); + const pointsOffset = 0; + context.machineState.memory.setSlice(pointsOffset, storedPoints); + // Store scalars + const scalarsOffset = pointsOffset + pointsReadLength; + context.machineState.memory.setSlice(scalarsOffset, storedScalars); + // Store length of points to read + const pointsLengthOffset = scalarsOffset + scalarsLength; + context.machineState.memory.set(pointsLengthOffset, new Uint32(pointsReadLength)); + const outputOffset = pointsLengthOffset + 1; + + // Set up the indirect pointers + const pointsIndirectOffset = outputOffset + 3; /* 3 since the output is a triplet */ + const scalarsIndirectOffset = pointsIndirectOffset + 1; + const outputIndirectOffset = scalarsIndirectOffset + 1; + + context.machineState.memory.set(pointsIndirectOffset, new Uint32(pointsOffset)); + context.machineState.memory.set(scalarsIndirectOffset, new Uint32(scalarsOffset)); + context.machineState.memory.set(outputIndirectOffset, new Uint32(outputOffset)); + + await new MultiScalarMul( + indirect, + pointsIndirectOffset, + scalarsIndirectOffset, + outputIndirectOffset, + pointsLengthOffset, + ).execute(context); + + const result = context.machineState.memory.getSlice(outputOffset, 3).map(r => r.toFr()); + + // We write it out explicitly here + let expectedResult = grumpkin.mul(points[0], scalars[0]); + expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[1], scalars[1])); + expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[2], scalars[2])); + + expect(result).toEqual([expectedResult.x, expectedResult.y, new Fr(0n)]); + }); }); diff --git a/yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.ts b/yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.ts index 70a370b231b2..29b6c106e656 100644 --- a/yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.ts +++ b/yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.ts @@ -1,6 +1,8 @@ import { Fq, Fr, Point } from '@aztec/circuits.js'; import { Grumpkin } from '@aztec/circuits.js/barretenberg'; +import { strict as assert } from 'assert'; + import { type AvmContext } from '../avm_context.js'; import { Field, TypeTag } from '../avm_memory_types.js'; import { InstructionExecutionError } from '../errors.js'; @@ -9,108 +11,104 @@ import { Addressing } from './addressing_mode.js'; import { Instruction } from './instruction.js'; export class MultiScalarMul extends Instruction { - static type: string = 'MultiScalarMul'; - static readonly opcode: Opcode = Opcode.MSM; + static type: string = 'MultiScalarMul'; + static readonly opcode: Opcode = Opcode.MSM; - // Informs (de)serialization. See Instruction.deserialize. - static readonly wireFormat: OperandType[] = [ - OperandType.UINT8 /* opcode */, - OperandType.UINT8 /* indirect */, - OperandType.UINT32 /* points vector offset */, - OperandType.UINT32 /* scalars vector offset */, - OperandType.UINT32 /* output offset (fixed triplet)*/, - OperandType.UINT32 /* points length offset */, - ]; + // Informs (de)serialization. See Instruction.deserialize. + static readonly wireFormat: OperandType[] = [ + OperandType.UINT8 /* opcode */, + OperandType.UINT8 /* indirect */, + OperandType.UINT32 /* points vector offset */, + OperandType.UINT32 /* scalars vector offset */, + OperandType.UINT32 /* output offset (fixed triplet) */, + OperandType.UINT32 /* points length offset */, + ]; - constructor( - private indirect: number, - private pointsOffset: number, - private scalarsOffset: number, - private outputOffset: number, - private pointsLengthOffset: number, - ) { - super(); - } + constructor( + private indirect: number, + private pointsOffset: number, + private scalarsOffset: number, + private outputOffset: number, + private pointsLengthOffset: number, + ) { + super(); + } - public async execute(context: AvmContext): Promise { - const memory = context.machineState.memory.track(this.type); - // Resolve indirects - const [pointsOffset, scalarsOffset, outputOffset] = Addressing.fromWire(this.indirect).resolve( - [this.pointsOffset, this.scalarsOffset, this.outputOffset], - memory, - ); + public async execute(context: AvmContext): Promise { + const memory = context.machineState.memory.track(this.type); + // Resolve indirects + const [pointsOffset, scalarsOffset, outputOffset] = Addressing.fromWire(this.indirect).resolve( + [this.pointsOffset, this.scalarsOffset, this.outputOffset], + memory, + ); - // Length of the points vector should be U32 - memory.checkTag(TypeTag.UINT32, this.pointsLengthOffset); + // Length of the points vector should be U32 + memory.checkTag(TypeTag.UINT32, this.pointsLengthOffset); + // Get the size of the unrolled (x, y , inf) points vector + const pointsReadLength = memory.get(this.pointsLengthOffset).toNumber(); + assert(pointsReadLength % 3 === 0, 'Points vector offset should be a multiple of 3'); + // Divide by 3 since each point is represented as a triplet to get the number of points + const numPoints = pointsReadLength / 3; + // The tag for each triplet will be (Field, Field, Uint8) + for (let i = 0; i < numPoints; i++) { + const offset = pointsOffset + i * 3; + // Check (Field, Field) + memory.checkTagsRange(TypeTag.FIELD, offset, 2); + // Check Uint8 (inf flag) + memory.checkTag(TypeTag.UINT8, offset + 2); + } + // Get the unrolled (x, y, inf) representing the points + const pointsVector = memory.getSlice(pointsOffset, pointsReadLength); - // Get the size of the unrolled (x, y , inf) points vector - // TODO: Do we need to assert that the length is a multiple of 3 (x, y, inf)? - const pointsReadLength = memory.get(this.pointsLengthOffset).toNumber(); - // Divide by 3 since each point is represented as a triplet to get the number of points - const numPoints = pointsReadLength / 3; - // The tag for each triplet will be (Field, Field, Uint8) - for (let i = 0; i < numPoints; i++) { - const offset = pointsOffset + i * 3; - // Check (Field, Field) - memory.checkTagsRange(TypeTag.FIELD, offset, 2); - // Check Uint8 (inf flag) - memory.checkTag(TypeTag.UINT8, offset + 2); - } - // Get the unrolled (x, y, inf) representing the points - const pointsVector = memory.getSlice(pointsOffset, pointsReadLength); + // The size of the scalars vector is twice the NUMBER of points because of the scalar limb decomposition + const scalarReadLength = numPoints * 2; + // Consume gas prior to performing work + const memoryOperations = { + reads: 1 + pointsReadLength + scalarReadLength /* points and scalars */, + writes: 3 /* output triplet */, + indirect: this.indirect, + }; + context.machineState.consumeGas(this.gasCost(memoryOperations)); + // Get the unrolled scalar (lo & hi) representing the scalars + const scalarsVector = memory.getSlice(scalarsOffset, scalarReadLength); + memory.checkTagsRange(TypeTag.FIELD, scalarsOffset, scalarReadLength); - // The size of the scalars vector is twice the NUMBER of points because of the scalar limb decomposition - const scalarReadLength = numPoints * 2; - // Get the unrolled scalar (lo & hi) representing the scalars - const scalarsVector = memory.getSlice(scalarsOffset, scalarReadLength); - memory.checkTagsRange(TypeTag.FIELD, scalarsOffset, scalarReadLength); + // Now we need to reconstruct the points and scalars into something we can operate on. + const grumpkinPoints: Point[] = []; + for (let i = 0; i < numPoints; i++) { + const p: Point = new Point(pointsVector[3 * i].toFr(), pointsVector[3 * i + 1].toFr()); + // Include this later when we have a standard for representing infinity + // const isInf = pointsVector[i + 2].toBoolean(); - // Now we need to reconstruct the points and scalars into something we can operate on. - const grumpkinPoints: Point[] = []; - for (let i = 0; i < numPoints; i++) { - const p: Point = new Point(pointsVector[3 * i].toFr(), pointsVector[3 * i + 1].toFr()); - // Include this later when we have a standard for representing infinity - // const isInf = pointsVector[i + 2].toBoolean(); + if (!p.isOnGrumpkin()) { + throw new InstructionExecutionError(`Point ${p.toString()} is not on the curve.`); + } + grumpkinPoints.push(p); + } + // The scalars are read from memory as Fr elements, which are limbs of Fq elements + // So we need to reconstruct them before performing the scalar multiplications + const scalarFqVector: Fq[] = []; + for (let i = 0; i < numPoints; i++) { + const scalarLo = scalarsVector[2 * i].toFr(); + const scalarHi = scalarsVector[2 * i + 1].toFr(); + const fqScalar = Fq.fromHighLow(scalarHi, scalarLo); + scalarFqVector.push(fqScalar); + } + // TODO: Is there an efficient MSM implementation in ts that we can replace this by? + const grumpkin = new Grumpkin(); + // Zip the points and scalars into pairs + const [firstBaseScalarPair, ...rest]: Array<[Point, Fq]> = grumpkinPoints.map((p, idx) => [p, scalarFqVector[idx]]); + // Fold the points and scalars into a single point + // We have to ensure get the first point, since the identity element (point at infinity) isn't quite working in ts + const outputPoint = rest.reduce( + (acc, curr) => grumpkin.add(acc, grumpkin.mul(curr[0], curr[1])), + grumpkin.mul(firstBaseScalarPair[0], firstBaseScalarPair[1]), + ); + const output = outputPoint.toFieldsWithInf().map(f => new Field(f)); - if (!p.isOnGrumpkin()) { - throw new InstructionExecutionError(`Point ${p.toString()} is not on the curve.`); - } - grumpkinPoints.push(p); - } - // The scalars are read from memory as Fr elements, which are limbs of Fq elements - // So we need to reconstruct them before performing the scalar multiplications - const scalarFqVector: Fq[] = []; - for (let i = 0; i < numPoints; i++) { - const scalarLo = scalarsVector[2 * i].toFr(); - const scalarHi = scalarsVector[2 * i + 1].toFr(); - const fqScalar = Fq.fromHighLow(scalarHi, scalarLo); - scalarFqVector.push(fqScalar); - } - // TODO: Is there an efficient MSM implementation in ts that we can replace this by? - const grumpkin = new Grumpkin(); - // Zip the points and scalars into pairs - const [firstBaseScalarPair, ...rest]: Array<[Point, Fq]> = grumpkinPoints.map((p, idx) => [p, scalarFqVector[idx]]); - // Fold the points and scalars into a single point - // We have to ensure get the first point, since the identity element (point at infinity) isn't quite working in ts - const outputPoint = rest.reduce( - (acc, curr) => grumpkin.add(acc, grumpkin.mul(curr[0], curr[1])), - grumpkin.mul(firstBaseScalarPair[0], firstBaseScalarPair[1]), - ); - // TODO: Check the Infinity flag here - const output: Fr[] = [outputPoint.x, outputPoint.y, outputPoint.isInfPoint() ? Fr.ONE : Fr.ZERO]; - - memory.setSlice( - outputOffset, - output.map(word => new Field(word)), - ); + memory.setSlice(outputOffset, output); - const memoryOperations = { - reads: 1 + pointsReadLength + scalarReadLength /* points and scalars */, - writes: 3 /* output triplet */, - indirect: this.indirect, - }; - context.machineState.consumeGas(this.gasCost(memoryOperations)); - memory.assert(memoryOperations); - context.machineState.incrementPc(); - } + memory.assert(memoryOperations); + context.machineState.incrementPc(); + } }