From 973ff2f0ad11b78b5dcab1537abe5cb611af8db2 Mon Sep 17 00:00:00 2001 From: Facundo Date: Mon, 29 Jan 2024 10:18:01 +0000 Subject: [PATCH] feat(avm): implement comparator opcodes (#4232) - Added `lt` and `equals` to AVM's memory types - Implemented `eq`, `lt`, `lte` opcodes to work with integral types, fields, and inTag checking - Added tests Refers to #4120. --- .../src/avm/avm_memory_types.ts | 21 +++ .../src/avm/opcodes/comparators.test.ts | 147 ++++++++++++++++++ .../src/avm/opcodes/comparators.ts | 29 ++-- .../src/avm/opcodes/control_flow.test.ts | 6 +- 4 files changed, 190 insertions(+), 13 deletions(-) create mode 100644 yarn-project/acir-simulator/src/avm/opcodes/comparators.test.ts diff --git a/yarn-project/acir-simulator/src/avm/avm_memory_types.ts b/yarn-project/acir-simulator/src/avm/avm_memory_types.ts index 8d6d240c922..9f831bfd21e 100644 --- a/yarn-project/acir-simulator/src/avm/avm_memory_types.ts +++ b/yarn-project/acir-simulator/src/avm/avm_memory_types.ts @@ -8,6 +8,9 @@ export abstract class MemoryValue { public abstract mul(rhs: MemoryValue): MemoryValue; public abstract div(rhs: MemoryValue): MemoryValue; + public abstract equals(rhs: MemoryValue): boolean; + public abstract lt(rhs: MemoryValue): boolean; + // We need this to be able to build an instance of the subclasses. public abstract build(n: bigint): MemoryValue; @@ -92,6 +95,16 @@ abstract class UnsignedInteger extends IntegralValue { return this.build(~this.n & this.bitmask); } + public equals(rhs: UnsignedInteger): boolean { + assert(this.bits == rhs.bits); + return this.n === rhs.n; + } + + public lt(rhs: UnsignedInteger): boolean { + assert(this.bits == rhs.bits); + return this.n < rhs.n; + } + public toBigInt(): bigint { return this.n; } @@ -176,6 +189,14 @@ export class Field extends MemoryValue { return new Field(this.rep.div(rhs.rep)); } + public equals(rhs: Field): boolean { + return this.rep.equals(rhs.rep); + } + + public lt(rhs: Field): boolean { + return this.rep.lt(rhs.rep); + } + public toBigInt(): bigint { return this.rep.toBigInt(); } diff --git a/yarn-project/acir-simulator/src/avm/opcodes/comparators.test.ts b/yarn-project/acir-simulator/src/avm/opcodes/comparators.test.ts new file mode 100644 index 00000000000..ff604dbf2cb --- /dev/null +++ b/yarn-project/acir-simulator/src/avm/opcodes/comparators.test.ts @@ -0,0 +1,147 @@ +import { MockProxy, mock } from 'jest-mock-extended'; + +import { AvmMachineState } from '../avm_machine_state.js'; +import { Field, TypeTag, Uint16, Uint32 } from '../avm_memory_types.js'; +import { initExecutionEnvironment } from '../fixtures/index.js'; +import { AvmJournal } from '../journal/journal.js'; +import { Eq, Lt, Lte } from './comparators.js'; +import { InstructionExecutionError } from './instruction.js'; + +describe('Comparators', () => { + let machineState: AvmMachineState; + let journal: MockProxy; + + beforeEach(async () => { + machineState = new AvmMachineState(initExecutionEnvironment()); + journal = mock(); + }); + + describe('Eq', () => { + it('Works on integral types', async () => { + machineState.memory.setSlice(0, [new Uint32(1), new Uint32(2), new Uint32(3), new Uint32(1)]); + + [ + new Eq(TypeTag.UINT32, /*aOffset=*/ 0, /*bOffset=*/ 1, /*dstOffset=*/ 10), + new Eq(TypeTag.UINT32, /*aOffset=*/ 0, /*bOffset=*/ 2, /*dstOffset=*/ 11), + new Eq(TypeTag.UINT32, /*aOffset=*/ 0, /*bOffset=*/ 3, /*dstOffset=*/ 12), + ].forEach(i => i.execute(machineState, journal)); + + const actual = machineState.memory.getSlice(/*offset=*/ 10, /*size=*/ 4); + expect(actual).toEqual([new Uint32(0), new Uint32(0), new Uint32(1)]); + }); + + it('Works on field elements', async () => { + machineState.memory.setSlice(0, [new Field(1), new Field(2), new Field(3), new Field(1)]); + + [ + new Eq(TypeTag.FIELD, /*aOffset=*/ 0, /*bOffset=*/ 1, /*dstOffset=*/ 10), + new Eq(TypeTag.FIELD, /*aOffset=*/ 0, /*bOffset=*/ 2, /*dstOffset=*/ 11), + new Eq(TypeTag.FIELD, /*aOffset=*/ 0, /*bOffset=*/ 3, /*dstOffset=*/ 12), + ].forEach(i => i.execute(machineState, journal)); + + const actual = machineState.memory.getSlice(/*offset=*/ 10, /*size=*/ 4); + expect(actual).toEqual([new Field(0), new Field(0), new Field(1)]); + }); + + it('InTag is checked', async () => { + machineState.memory.setSlice(0, [new Field(1), new Uint32(2), new Uint16(3)]); + + const ops = [ + new Eq(TypeTag.FIELD, /*aOffset=*/ 0, /*bOffset=*/ 1, /*dstOffset=*/ 10), + new Eq(TypeTag.UINT32, /*aOffset=*/ 0, /*bOffset=*/ 2, /*dstOffset=*/ 10), + new Eq(TypeTag.UINT16, /*aOffset=*/ 1, /*bOffset=*/ 2, /*dstOffset=*/ 10), + new Eq(TypeTag.UINT16, /*aOffset=*/ 1, /*bOffset=*/ 1, /*dstOffset=*/ 10), + ]; + + for (const o of ops) { + await expect(() => o.execute(machineState, journal)).rejects.toThrow(InstructionExecutionError); + } + }); + }); + + describe('Lt', () => { + it('Works on integral types', async () => { + machineState.memory.setSlice(0, [new Uint32(1), new Uint32(2), new Uint32(0)]); + + [ + new Lt(TypeTag.UINT32, /*aOffset=*/ 0, /*bOffset=*/ 0, /*dstOffset=*/ 10), + new Lt(TypeTag.UINT32, /*aOffset=*/ 0, /*bOffset=*/ 1, /*dstOffset=*/ 11), + new Lt(TypeTag.UINT32, /*aOffset=*/ 0, /*bOffset=*/ 2, /*dstOffset=*/ 12), + ].forEach(i => i.execute(machineState, journal)); + + const actual = machineState.memory.getSlice(/*offset=*/ 10, /*size=*/ 4); + expect(actual).toEqual([new Uint32(0), new Uint32(1), new Uint32(0)]); + }); + + it('Works on field elements', async () => { + machineState.memory.setSlice(0, [new Field(1), new Field(2), new Field(0)]); + + [ + new Lt(TypeTag.FIELD, /*aOffset=*/ 0, /*bOffset=*/ 0, /*dstOffset=*/ 10), + new Lt(TypeTag.FIELD, /*aOffset=*/ 0, /*bOffset=*/ 1, /*dstOffset=*/ 11), + new Lt(TypeTag.FIELD, /*aOffset=*/ 0, /*bOffset=*/ 2, /*dstOffset=*/ 12), + ].forEach(i => i.execute(machineState, journal)); + + const actual = machineState.memory.getSlice(/*offset=*/ 10, /*size=*/ 4); + expect(actual).toEqual([new Field(0), new Field(1), new Field(0)]); + }); + + it('InTag is checked', async () => { + machineState.memory.setSlice(0, [new Field(1), new Uint32(2), new Uint16(3)]); + + const ops = [ + new Lt(TypeTag.FIELD, /*aOffset=*/ 0, /*bOffset=*/ 1, /*dstOffset=*/ 10), + new Lt(TypeTag.UINT32, /*aOffset=*/ 0, /*bOffset=*/ 2, /*dstOffset=*/ 10), + new Lt(TypeTag.UINT16, /*aOffset=*/ 1, /*bOffset=*/ 2, /*dstOffset=*/ 10), + new Lt(TypeTag.UINT16, /*aOffset=*/ 1, /*bOffset=*/ 1, /*dstOffset=*/ 10), + ]; + + for (const o of ops) { + await expect(() => o.execute(machineState, journal)).rejects.toThrow(InstructionExecutionError); + } + }); + }); + + describe('Lte', () => { + it('Works on integral types', async () => { + machineState.memory.setSlice(0, [new Uint32(1), new Uint32(2), new Uint32(0)]); + + [ + new Lte(TypeTag.UINT32, /*aOffset=*/ 0, /*bOffset=*/ 0, /*dstOffset=*/ 10), + new Lte(TypeTag.UINT32, /*aOffset=*/ 0, /*bOffset=*/ 1, /*dstOffset=*/ 11), + new Lte(TypeTag.UINT32, /*aOffset=*/ 0, /*bOffset=*/ 2, /*dstOffset=*/ 12), + ].forEach(i => i.execute(machineState, journal)); + + const actual = machineState.memory.getSlice(/*offset=*/ 10, /*size=*/ 4); + expect(actual).toEqual([new Uint32(1), new Uint32(1), new Uint32(0)]); + }); + + it('Works on field elements', async () => { + machineState.memory.setSlice(0, [new Field(1), new Field(2), new Field(0)]); + + [ + new Lte(TypeTag.FIELD, /*aOffset=*/ 0, /*bOffset=*/ 0, /*dstOffset=*/ 10), + new Lte(TypeTag.FIELD, /*aOffset=*/ 0, /*bOffset=*/ 1, /*dstOffset=*/ 11), + new Lte(TypeTag.FIELD, /*aOffset=*/ 0, /*bOffset=*/ 2, /*dstOffset=*/ 12), + ].forEach(i => i.execute(machineState, journal)); + + const actual = machineState.memory.getSlice(/*offset=*/ 10, /*size=*/ 4); + expect(actual).toEqual([new Field(1), new Field(1), new Field(0)]); + }); + + it('InTag is checked', async () => { + machineState.memory.setSlice(0, [new Field(1), new Uint32(2), new Uint16(3)]); + + const ops = [ + new Lte(TypeTag.FIELD, /*aOffset=*/ 0, /*bOffset=*/ 1, /*dstOffset=*/ 10), + new Lte(TypeTag.UINT32, /*aOffset=*/ 0, /*bOffset=*/ 2, /*dstOffset=*/ 10), + new Lte(TypeTag.UINT16, /*aOffset=*/ 1, /*bOffset=*/ 2, /*dstOffset=*/ 10), + new Lte(TypeTag.UINT16, /*aOffset=*/ 1, /*bOffset=*/ 1, /*dstOffset=*/ 10), + ]; + + for (const o of ops) { + await expect(() => o.execute(machineState, journal)).rejects.toThrow(InstructionExecutionError); + } + }); + }); +}); diff --git a/yarn-project/acir-simulator/src/avm/opcodes/comparators.ts b/yarn-project/acir-simulator/src/avm/opcodes/comparators.ts index 8bdb406f9c6..f89d450dbb9 100644 --- a/yarn-project/acir-simulator/src/avm/opcodes/comparators.ts +++ b/yarn-project/acir-simulator/src/avm/opcodes/comparators.ts @@ -1,5 +1,5 @@ import { AvmMachineState } from '../avm_machine_state.js'; -import { Field } from '../avm_memory_types.js'; +import { TypeTag } from '../avm_memory_types.js'; import { AvmJournal } from '../journal/index.js'; import { Instruction } from './instruction.js'; @@ -7,16 +7,19 @@ export class Eq extends Instruction { static type: string = 'EQ'; static numberOfOperands = 3; - constructor(private aOffset: number, private bOffset: number, private destOffset: number) { + constructor(private inTag: TypeTag, private aOffset: number, private bOffset: number, private dstOffset: number) { super(); } async execute(machineState: AvmMachineState, _journal: AvmJournal): Promise { + Instruction.checkTags(machineState, this.inTag, this.aOffset, this.bOffset); + const a = machineState.memory.get(this.aOffset); const b = machineState.memory.get(this.bOffset); - const dest = new Field(a.toBigInt() == b.toBigInt() ? 1 : 0); - machineState.memory.set(this.destOffset, dest); + // Result will be of the same type as 'a'. + const dest = a.build(a.equals(b) ? 1n : 0n); + machineState.memory.set(this.dstOffset, dest); this.incrementPc(machineState); } @@ -26,16 +29,19 @@ export class Lt extends Instruction { static type: string = 'Lt'; static numberOfOperands = 3; - constructor(private aOffset: number, private bOffset: number, private destOffset: number) { + constructor(private inTag: TypeTag, private aOffset: number, private bOffset: number, private dstOffset: number) { super(); } async execute(machineState: AvmMachineState, _journal: AvmJournal): Promise { + Instruction.checkTags(machineState, this.inTag, this.aOffset, this.bOffset); + const a = machineState.memory.get(this.aOffset); const b = machineState.memory.get(this.bOffset); - const dest = new Field(a.toBigInt() < b.toBigInt() ? 1 : 0); - machineState.memory.set(this.destOffset, dest); + // Result will be of the same type as 'a'. + const dest = a.build(a.lt(b) ? 1n : 0n); + machineState.memory.set(this.dstOffset, dest); this.incrementPc(machineState); } @@ -45,16 +51,19 @@ export class Lte extends Instruction { static type: string = 'LTE'; static numberOfOperands = 3; - constructor(private aOffset: number, private bOffset: number, private destOffset: number) { + constructor(private inTag: TypeTag, private aOffset: number, private bOffset: number, private dstOffset: number) { super(); } async execute(machineState: AvmMachineState, _journal: AvmJournal): Promise { + Instruction.checkTags(machineState, this.inTag, this.aOffset, this.bOffset); + const a = machineState.memory.get(this.aOffset); const b = machineState.memory.get(this.bOffset); - const dest = new Field(a.toBigInt() < b.toBigInt() ? 1 : 0); - machineState.memory.set(this.destOffset, dest); + // Result will be of the same type as 'a'. + const dest = a.build(a.equals(b) || a.lt(b) ? 1n : 0n); + machineState.memory.set(this.dstOffset, dest); this.incrementPc(machineState); } diff --git a/yarn-project/acir-simulator/src/avm/opcodes/control_flow.test.ts b/yarn-project/acir-simulator/src/avm/opcodes/control_flow.test.ts index ca4cdd45a5c..973655279e6 100644 --- a/yarn-project/acir-simulator/src/avm/opcodes/control_flow.test.ts +++ b/yarn-project/acir-simulator/src/avm/opcodes/control_flow.test.ts @@ -121,9 +121,9 @@ describe('Control Flow Opcodes', () => { new Add(0, 1, 2), new Sub(0, 1, 2), new Mul(0, 1, 2), - new Lt(0, 1, 2), - new Lte(0, 1, 2), - new Eq(0, 1, 2), + new Lt(TypeTag.UINT16, 0, 1, 2), + new Lte(TypeTag.UINT16, 0, 1, 2), + new Eq(TypeTag.UINT16, 0, 1, 2), new Xor(0, 1, 2, TypeTag.UINT16), new And(0, 1, 2, TypeTag.UINT16), new Or(0, 1, 2, TypeTag.UINT16),