Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Added cast opcode and cast calldata #4423

Merged
merged 3 commits into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -799,6 +799,16 @@ struct BrilligOpcode {
static BinaryIntOp bincodeDeserialize(std::vector<uint8_t>);
};

struct Cast {
Circuit::MemoryAddress destination;
Circuit::MemoryAddress source;
uint32_t bit_size;

friend bool operator==(const Cast&, const Cast&);
std::vector<uint8_t> bincodeSerialize() const;
static Cast bincodeDeserialize(std::vector<uint8_t>);
};

struct JumpIfNot {
Circuit::MemoryAddress condition;
uint64_t location;
Expand Down Expand Up @@ -921,6 +931,7 @@ struct BrilligOpcode {

std::variant<BinaryFieldOp,
BinaryIntOp,
Cast,
JumpIfNot,
JumpIf,
Jump,
Expand Down Expand Up @@ -5077,6 +5088,63 @@ Circuit::BrilligOpcode::BinaryIntOp serde::Deserializable<Circuit::BrilligOpcode

namespace Circuit {

inline bool operator==(const BrilligOpcode::Cast& lhs, const BrilligOpcode::Cast& rhs)
{
if (!(lhs.destination == rhs.destination)) {
return false;
}
if (!(lhs.source == rhs.source)) {
return false;
}
if (!(lhs.bit_size == rhs.bit_size)) {
return false;
}
return true;
}

inline std::vector<uint8_t> BrilligOpcode::Cast::bincodeSerialize() const
{
auto serializer = serde::BincodeSerializer();
serde::Serializable<BrilligOpcode::Cast>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline BrilligOpcode::Cast BrilligOpcode::Cast::bincodeDeserialize(std::vector<uint8_t> input)
{
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<BrilligOpcode::Cast>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw_or_abort("Some input bytes were not read");
}
return value;
}

} // end of namespace Circuit

template <>
template <typename Serializer>
void serde::Serializable<Circuit::BrilligOpcode::Cast>::serialize(const Circuit::BrilligOpcode::Cast& obj,
Serializer& serializer)
{
serde::Serializable<decltype(obj.destination)>::serialize(obj.destination, serializer);
serde::Serializable<decltype(obj.source)>::serialize(obj.source, serializer);
serde::Serializable<decltype(obj.bit_size)>::serialize(obj.bit_size, serializer);
}

template <>
template <typename Deserializer>
Circuit::BrilligOpcode::Cast serde::Deserializable<Circuit::BrilligOpcode::Cast>::deserialize(
Deserializer& deserializer)
{
Circuit::BrilligOpcode::Cast obj;
obj.destination = serde::Deserializable<decltype(obj.destination)>::deserialize(deserializer);
obj.source = serde::Deserializable<decltype(obj.source)>::deserialize(deserializer);
obj.bit_size = serde::Deserializable<decltype(obj.bit_size)>::deserialize(deserializer);
return obj;
}

namespace Circuit {

inline bool operator==(const BrilligOpcode::JumpIfNot& lhs, const BrilligOpcode::JumpIfNot& rhs)
{
if (!(lhs.condition == rhs.condition)) {
Expand Down
56 changes: 55 additions & 1 deletion noir/acvm-repo/acir/codegen/acir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,16 @@ namespace Circuit {
static BinaryIntOp bincodeDeserialize(std::vector<uint8_t>);
};

struct Cast {
Circuit::MemoryAddress destination;
Circuit::MemoryAddress source;
uint32_t bit_size;

friend bool operator==(const Cast&, const Cast&);
std::vector<uint8_t> bincodeSerialize() const;
static Cast bincodeDeserialize(std::vector<uint8_t>);
};

struct JumpIfNot {
Circuit::MemoryAddress condition;
uint64_t location;
Expand Down Expand Up @@ -874,7 +884,7 @@ namespace Circuit {
static Stop bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<BinaryFieldOp, BinaryIntOp, JumpIfNot, JumpIf, Jump, CalldataCopy, Call, Const, Return, ForeignCall, Mov, Load, Store, BlackBox, Trap, Stop> value;
std::variant<BinaryFieldOp, BinaryIntOp, Cast, JumpIfNot, JumpIf, Jump, CalldataCopy, Call, Const, Return, ForeignCall, Mov, Load, Store, BlackBox, Trap, Stop> value;

friend bool operator==(const BrilligOpcode&, const BrilligOpcode&);
std::vector<uint8_t> bincodeSerialize() const;
Expand Down Expand Up @@ -4205,6 +4215,50 @@ Circuit::BrilligOpcode::BinaryIntOp serde::Deserializable<Circuit::BrilligOpcode
return obj;
}

namespace Circuit {

inline bool operator==(const BrilligOpcode::Cast &lhs, const BrilligOpcode::Cast &rhs) {
if (!(lhs.destination == rhs.destination)) { return false; }
if (!(lhs.source == rhs.source)) { return false; }
if (!(lhs.bit_size == rhs.bit_size)) { return false; }
return true;
}

inline std::vector<uint8_t> BrilligOpcode::Cast::bincodeSerialize() const {
auto serializer = serde::BincodeSerializer();
serde::Serializable<BrilligOpcode::Cast>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline BrilligOpcode::Cast BrilligOpcode::Cast::bincodeDeserialize(std::vector<uint8_t> input) {
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<BrilligOpcode::Cast>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw serde::deserialization_error("Some input bytes were not read");
}
return value;
}

} // end of namespace Circuit

template <>
template <typename Serializer>
void serde::Serializable<Circuit::BrilligOpcode::Cast>::serialize(const Circuit::BrilligOpcode::Cast &obj, Serializer &serializer) {
serde::Serializable<decltype(obj.destination)>::serialize(obj.destination, serializer);
serde::Serializable<decltype(obj.source)>::serialize(obj.source, serializer);
serde::Serializable<decltype(obj.bit_size)>::serialize(obj.bit_size, serializer);
}

template <>
template <typename Deserializer>
Circuit::BrilligOpcode::Cast serde::Deserializable<Circuit::BrilligOpcode::Cast>::deserialize(Deserializer &deserializer) {
Circuit::BrilligOpcode::Cast obj;
obj.destination = serde::Deserializable<decltype(obj.destination)>::deserialize(deserializer);
obj.source = serde::Deserializable<decltype(obj.source)>::deserialize(deserializer);
obj.bit_size = serde::Deserializable<decltype(obj.bit_size)>::deserialize(deserializer);
return obj;
}

namespace Circuit {

inline bool operator==(const BrilligOpcode::JumpIfNot &lhs, const BrilligOpcode::JumpIfNot &rhs) {
Expand Down
26 changes: 13 additions & 13 deletions noir/acvm-repo/acir/tests/test_program_serialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,11 +204,11 @@ fn simple_brillig_foreign_call() {
let bytes = Circuit::serialize_circuit(&circuit);

let expected_serialization: Vec<u8> = vec![
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 173, 143, 177, 10, 192, 32, 12, 68, 207, 148, 150, 118,
234, 175, 216, 63, 232, 207, 116, 232, 226, 32, 226, 247, 171, 24, 225, 6, 113, 209, 7, 33,
199, 5, 194, 221, 9, 192, 160, 178, 145, 102, 154, 247, 234, 182, 115, 60, 102, 221, 47,
203, 121, 69, 59, 20, 246, 78, 254, 198, 149, 231, 80, 253, 187, 248, 249, 48, 106, 205,
220, 189, 187, 144, 33, 24, 144, 0, 93, 119, 243, 238, 108, 1, 0, 0,
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 173, 143, 49, 10, 0, 33, 16, 3, 227, 30, 28, 199, 85,
62, 69, 127, 224, 103, 44, 108, 44, 68, 124, 191, 136, 10, 41, 196, 70, 167, 217, 37, 129,
144, 124, 0, 20, 58, 15, 253, 204, 212, 220, 184, 230, 12, 171, 238, 101, 25, 238, 43, 99,
67, 227, 93, 244, 159, 252, 228, 135, 88, 124, 202, 187, 213, 140, 94, 249, 66, 130, 96,
67, 5, 171, 116, 175, 175, 108, 1, 0, 0,
];

assert_eq!(bytes, expected_serialization)
Expand Down Expand Up @@ -294,14 +294,14 @@ fn complex_brillig_foreign_call() {
let bytes = Circuit::serialize_circuit(&circuit);

let expected_serialization: Vec<u8> = vec![
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 213, 83, 81, 10, 131, 48, 12, 125, 105, 215, 205, 125,
237, 10, 131, 237, 0, 221, 78, 224, 93, 196, 63, 69, 63, 61, 190, 5, 95, 177, 6, 193, 15,
43, 104, 32, 164, 9, 175, 201, 107, 146, 22, 0, 4, 147, 216, 160, 134, 103, 161, 159, 74,
196, 149, 180, 126, 159, 252, 36, 95, 46, 127, 20, 71, 115, 1, 142, 246, 0, 142, 113, 31,
78, 58, 239, 156, 115, 201, 218, 63, 187, 242, 127, 110, 65, 93, 208, 59, 253, 7, 109, 193,
56, 104, 223, 170, 239, 80, 120, 16, 83, 102, 225, 250, 247, 14, 243, 46, 138, 170, 253,
76, 234, 86, 93, 219, 55, 245, 96, 21, 84, 83, 253, 36, 231, 47, 173, 217, 184, 19, 227,
47, 204, 207, 119, 26, 40, 76, 164, 251, 178, 144, 17, 127, 189, 34, 151, 201, 4, 0, 0,
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 213, 83, 203, 9, 128, 48, 12, 77, 90, 127, 120, 114, 4,
65, 7, 168, 78, 224, 46, 226, 77, 209, 163, 227, 91, 240, 21, 107, 16, 60, 88, 65, 31, 132,
52, 105, 154, 190, 38, 105, 70, 68, 76, 59, 180, 21, 133, 53, 195, 246, 225, 226, 58, 104,
243, 12, 13, 135, 203, 101, 222, 226, 168, 126, 192, 81, 191, 192, 209, 205, 195, 71, 251,
29, 178, 47, 65, 235, 167, 47, 254, 79, 100, 37, 182, 146, 192, 78, 161, 51, 248, 9, 123,
165, 168, 59, 121, 113, 14, 101, 48, 174, 173, 73, 232, 152, 69, 22, 119, 231, 30, 207,
126, 158, 150, 113, 88, 181, 8, 149, 84, 43, 111, 93, 67, 171, 155, 51, 206, 95, 208, 241,
252, 88, 6, 50, 18, 201, 186, 156, 176, 1, 136, 75, 233, 37, 201, 4, 0, 0,
];

assert_eq!(bytes, expected_serialization)
Expand Down
14 changes: 7 additions & 7 deletions noir/acvm-repo/acvm_js/test/shared/complex_foreign_call.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ import { WitnessMap } from '@noir-lang/acvm_js';

// See `complex_brillig_foreign_call` integration test in `acir/tests/test_program_serialization.rs`.
export const bytecode = Uint8Array.from([
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 213, 83, 81, 10, 131, 48, 12, 125, 105, 215, 205, 125, 237, 10, 131, 237, 0, 221,
78, 224, 93, 196, 63, 69, 63, 61, 190, 5, 95, 177, 6, 193, 15, 43, 104, 32, 164, 9, 175, 201, 107, 146, 22, 0, 4, 147,
216, 160, 134, 103, 161, 159, 74, 196, 149, 180, 126, 159, 252, 36, 95, 46, 127, 20, 71, 115, 1, 142, 246, 0, 142,
113, 31, 78, 58, 239, 156, 115, 201, 218, 63, 187, 242, 127, 110, 65, 93, 208, 59, 253, 7, 109, 193, 56, 104, 223,
170, 239, 80, 120, 16, 83, 102, 225, 250, 247, 14, 243, 46, 138, 170, 253, 76, 234, 86, 93, 219, 55, 245, 96, 21, 84,
83, 253, 36, 231, 47, 173, 217, 184, 19, 227, 47, 204, 207, 119, 26, 40, 76, 164, 251, 178, 144, 17, 127, 189, 34,
151, 201, 4, 0, 0,
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 213, 83, 203, 9, 128, 48, 12, 77, 90, 127, 120, 114, 4, 65, 7, 168, 78, 224, 46,
226, 77, 209, 163, 227, 91, 240, 21, 107, 16, 60, 88, 65, 31, 132, 52, 105, 154, 190, 38, 105, 70, 68, 76, 59, 180,
21, 133, 53, 195, 246, 225, 226, 58, 104, 243, 12, 13, 135, 203, 101, 222, 226, 168, 126, 192, 81, 191, 192, 209, 205,
195, 71, 251, 29, 178, 47, 65, 235, 167, 47, 254, 79, 100, 37, 182, 146, 192, 78, 161, 51, 248, 9, 123, 165, 168, 59,
121, 113, 14, 101, 48, 174, 173, 73, 232, 152, 69, 22, 119, 231, 30, 207, 126, 158, 150, 113, 88, 181, 8, 149, 84, 43,
111, 93, 67, 171, 155, 51, 206, 95, 208, 241, 252, 88, 6, 50, 18, 201, 186, 156, 176, 1, 136, 75, 233, 37, 201, 4, 0,
0,
]);
export const initialWitnessMap: WitnessMap = new Map([
[1, '0x0000000000000000000000000000000000000000000000000000000000000001'],
Expand Down
8 changes: 4 additions & 4 deletions noir/acvm-repo/acvm_js/test/shared/foreign_call.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ import { WitnessMap } from '@noir-lang/acvm_js';

// See `simple_brillig_foreign_call` integration test in `acir/tests/test_program_serialization.rs`.
export const bytecode = Uint8Array.from([
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 173, 143, 177, 10, 192, 32, 12, 68, 207, 148, 150, 118, 234, 175, 216, 63, 232,
207, 116, 232, 226, 32, 226, 247, 171, 24, 225, 6, 113, 209, 7, 33, 199, 5, 194, 221, 9, 192, 160, 178, 145, 102, 154,
247, 234, 182, 115, 60, 102, 221, 47, 203, 121, 69, 59, 20, 246, 78, 254, 198, 149, 231, 80, 253, 187, 248, 249, 48,
106, 205, 220, 189, 187, 144, 33, 24, 144, 0, 93, 119, 243, 238, 108, 1, 0, 0,
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 173, 143, 49, 10, 0, 33, 16, 3, 227, 30, 28, 199, 85, 62, 69, 127, 224, 103, 44,
108, 44, 68, 124, 191, 136, 10, 41, 196, 70, 167, 217, 37, 129, 144, 124, 0, 20, 58, 15, 253, 204, 212, 220, 184, 230,
12, 171, 238, 101, 25, 238, 43, 99, 67, 227, 93, 244, 159, 252, 228, 135, 88, 124, 202, 187, 213, 140, 94, 249, 66,
130, 96, 67, 5, 171, 116, 175, 175, 108, 1, 0, 0,
]);
export const initialWitnessMap: WitnessMap = new Map([
[1, '0x0000000000000000000000000000000000000000000000000000000000000005'],
Expand Down
5 changes: 5 additions & 0 deletions noir/acvm-repo/brillig/src/opcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ pub enum BrilligOpcode {
lhs: MemoryAddress,
rhs: MemoryAddress,
},
Cast {
destination: MemoryAddress,
source: MemoryAddress,
bit_size: u32,
},
JumpIfNot {
condition: MemoryAddress,
location: Label,
Expand Down
47 changes: 47 additions & 0 deletions noir/acvm-repo/brillig_vm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,12 @@ impl<'a, B: BlackBoxFunctionSolver> VM<'a, B> {
self.increment_program_counter()
}
}
Opcode::Cast { destination: destination_address, source: source_address, bit_size } => {
let source_value = self.memory.read(*source_address);
let casted_value = self.cast(*bit_size, source_value);
self.memory.write(*destination_address, casted_value);
self.increment_program_counter()
}
Opcode::Jump { location: destination } => self.set_program_counter(*destination),
Opcode::JumpIf { condition, location: destination } => {
// Check if condition is true
Expand Down Expand Up @@ -406,6 +412,13 @@ impl<'a, B: BlackBoxFunctionSolver> VM<'a, B> {
.write(result, FieldElement::from_be_bytes_reduce(&result_value.to_bytes_be()).into());
Ok(())
}

/// Casts a value to a different bit size.
fn cast(&self, bit_size: u32, value: Value) -> Value {
let lhs_big = BigUint::from_bytes_be(&value.to_field().to_be_bytes());
let mask = BigUint::from(2_u32).pow(bit_size) - 1_u32;
FieldElement::from_be_bytes_reduce(&(lhs_big & mask).to_bytes_be()).into()
}
}

pub(crate) struct DummyBlackBoxSolver;
Expand Down Expand Up @@ -603,6 +616,40 @@ mod tests {
assert_eq!(output_value, Value::from(false));
}

#[test]
fn cast_opcode() {
let calldata = vec![Value::from((2_u128.pow(32)) - 1)];

let opcodes = &[
Opcode::CalldataCopy {
destination_address: MemoryAddress::from(0),
size: 1,
offset: 0,
},
Opcode::Cast {
destination: MemoryAddress::from(1),
source: MemoryAddress::from(0),
bit_size: 8,
},
Opcode::Stop { return_data_offset: 1, return_data_size: 1 },
];
let mut vm = VM::new(calldata, opcodes, vec![], &DummyBlackBoxSolver);

let status = vm.process_opcode();
assert_eq!(status, VMStatus::InProgress);

let status = vm.process_opcode();
assert_eq!(status, VMStatus::InProgress);

let status = vm.process_opcode();
assert_eq!(status, VMStatus::Finished { return_data_offset: 1, return_data_size: 1 });

let VM { memory, .. } = vm;

let casted_value = memory.read(MemoryAddress::from(1));
assert_eq!(casted_value, Value::from(2_u128.pow(8) - 1));
}

#[test]
fn mov_opcode() {
let calldata = vec![Value::from(1u128), Value::from(2u128), Value::from(3u128)];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use num_bigint::BigUint;

use super::brillig_black_box::convert_black_box_call;
use super::brillig_block_variables::BlockVariables;
use super::brillig_fn::FunctionContext;
use super::brillig_fn::{get_bit_size_from_ssa_type, FunctionContext};

/// Generate the compilation artifacts for compiling a function into brillig bytecode.
pub(crate) struct BrilligBlock<'block> {
Expand Down Expand Up @@ -85,16 +85,6 @@ impl<'block> BrilligBlock<'block> {
self.convert_ssa_terminator(terminator_instruction, dfg);
}

fn get_bit_size_from_ssa_type(typ: &Type) -> u32 {
match typ {
Type::Numeric(num_type) => match num_type {
NumericType::Signed { bit_size } | NumericType::Unsigned { bit_size } => *bit_size,
NumericType::NativeField => FieldElement::max_num_bits(),
},
_ => unreachable!("ICE bitwise not on a non numeric type"),
}
}

/// Creates a unique global label for a block.
///
/// This uses the current functions's function ID and the block ID
Expand Down Expand Up @@ -322,7 +312,7 @@ impl<'block> BrilligBlock<'block> {
dfg.instruction_results(instruction_id)[0],
dfg,
);
let bit_size = Self::get_bit_size_from_ssa_type(&dfg.type_of_value(*value));
let bit_size = get_bit_size_from_ssa_type(&dfg.type_of_value(*value));
self.brillig_context.not_instruction(condition_register, bit_size, result_register);
}
Instruction::Call { func, arguments } => match &dfg[*func] {
Expand Down Expand Up @@ -535,7 +525,7 @@ impl<'block> BrilligBlock<'block> {
*bit_size,
);
}
Instruction::Cast(value, _) => {
Instruction::Cast(value, typ) => {
let result_ids = dfg.instruction_results(instruction_id);
let destination_register = self.variables.define_register_variable(
self.function_context,
Expand All @@ -544,7 +534,7 @@ impl<'block> BrilligBlock<'block> {
dfg,
);
let source_register = self.convert_ssa_register_value(*value, dfg);
self.convert_cast(destination_register, source_register);
self.convert_cast(destination_register, source_register, typ);
}
Instruction::ArrayGet { array, index } => {
let result_ids = dfg.instruction_results(instruction_id);
Expand Down Expand Up @@ -1124,11 +1114,11 @@ impl<'block> BrilligBlock<'block> {

/// Converts an SSA cast to a sequence of Brillig opcodes.
/// Casting is only necessary when shrinking the bit size of a numeric value.
fn convert_cast(&mut self, destination: MemoryAddress, source: MemoryAddress) {
fn convert_cast(&mut self, destination: MemoryAddress, source: MemoryAddress, typ: &Type) {
// We assume that `source` is a valid `target_type` as it's expected that a truncate instruction was emitted
// to ensure this is the case.

self.brillig_context.mov_instruction(destination, source);
self.brillig_context.cast_instruction(destination, source, get_bit_size_from_ssa_type(typ));
}

/// Converts the Binary instruction into a sequence of Brillig opcodes.
Expand Down Expand Up @@ -1174,7 +1164,7 @@ impl<'block> BrilligBlock<'block> {
self.brillig_context.const_instruction(
register_index,
(*constant).into(),
Self::get_bit_size_from_ssa_type(typ),
get_bit_size_from_ssa_type(typ),
);
new_variable
}
Expand Down
Loading
Loading