Skip to content

Commit

Permalink
Negative unit tests for AVM CAST opcode
Browse files Browse the repository at this point in the history
  • Loading branch information
jeanmon committed Apr 22, 2024
1 parent 2e55713 commit 292e488
Showing 1 changed file with 196 additions and 16 deletions.
212 changes: 196 additions & 16 deletions barretenberg/cpp/src/barretenberg/vm/tests/avm_cast.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@ using namespace bb::avm_trace;
using namespace testing;

class AvmCastTests : public ::testing::Test {
public:
AvmTraceBuilder trace_builder;

protected:
AvmTraceBuilder trace_builder;
std::vector<Row> trace;
size_t main_idx;
size_t alu_idx;
size_t mem_idx_c;

// TODO(640): The Standard Honk on Grumpkin test suite fails unless the SRS is initialised for every test.
void SetUp() override { srs::init_crs_factory("../srs_db/ignition"); };
Expand All @@ -25,6 +26,28 @@ class AvmCastTests : public ::testing::Test {
trace_builder.op_cast(0, src_address, dst_address, dst_tag);
trace_builder.return_op(0, 0, 0);
trace = trace_builder.finalize();
gen_indices();
}

void gen_indices()
{
auto row =
std::ranges::find_if(trace.begin(), trace.end(), [](Row r) { return r.avm_main_sel_op_cast == FF(1); });
ASSERT_TRUE(row != trace.end());
main_idx = static_cast<size_t>(row - trace.begin());

// Find the corresponding Alu trace row
auto clk = row->avm_main_clk;
auto alu_row = std::ranges::find_if(trace.begin(), trace.end(), [clk](Row r) { return r.avm_alu_clk == clk; });
ASSERT_TRUE(alu_row != trace.end());
alu_idx = static_cast<size_t>(alu_row - trace.begin());

// Mem entry output ic write operation
auto mem_row_c = std::ranges::find_if(trace.begin(), trace.end(), [clk](Row r) {
return r.avm_mem_clk == clk && r.avm_mem_sub_clk == AvmMemTraceBuilder::SUB_CLK_STORE_C;
});
ASSERT_TRUE(mem_row_c != trace.end());
mem_idx_c = static_cast<size_t>(mem_row_c - trace.begin());
}

void validate_cast_trace(FF const& a,
Expand All @@ -36,11 +59,8 @@ class AvmCastTests : public ::testing::Test {

)
{
auto row =
std::ranges::find_if(trace.begin(), trace.end(), [](Row r) { return r.avm_main_sel_op_cast == FF(1); });
ASSERT_TRUE(row != trace.end());

EXPECT_THAT(*row,
auto const& row = trace.at(main_idx);
EXPECT_THAT(row,
AllOf(Field("sel_op_cast", &Row::avm_main_sel_op_cast, 1),
Field("ia", &Row::avm_main_ia, a),
Field("ib", &Row::avm_main_ib, 0),
Expand All @@ -59,12 +79,8 @@ class AvmCastTests : public ::testing::Test {
Field("sel_rng_8", &Row::avm_main_sel_rng_8, 1),
Field("sel_rng_16", &Row::avm_main_sel_rng_16, 1)));

// Find the corresponding Alu trace row
auto clk = row->avm_main_clk;
auto alu_row = std::ranges::find_if(trace.begin(), trace.end(), [clk](Row r) { return r.avm_alu_clk == clk; });
ASSERT_TRUE(alu_row != trace.end());

EXPECT_THAT(*alu_row,
auto const& alu_row = trace.at(alu_idx);
EXPECT_THAT(alu_row,
AllOf(Field("op_cast", &Row::avm_alu_op_cast, 1),
Field("alu_ia", &Row::avm_alu_ia, a),
Field("alu_ib", &Row::avm_alu_ib, 0),
Expand All @@ -81,15 +97,17 @@ class AvmCastTests : public ::testing::Test {
Field("alu_sel", &Row::avm_alu_alu_sel, 1)));

// Check that there is a second ALU row
auto alu_row_next = alu_row + 1;
auto alu_row_next = trace.at(alu_idx + 1);
EXPECT_THAT(
*alu_row_next,
alu_row_next,
AllOf(Field("op_cast", &Row::avm_alu_op_cast, 0), Field("op_cast_prev", &Row::avm_alu_op_cast_prev, 1)));

validate_trace(std::move(trace));
}
};

class AvmCastNegativeTests : public AvmCastTests {};

TEST_F(AvmCastTests, basicU8ToU16)
{
gen_trace(237, 0, 1, AvmMemoryTag::U8, AvmMemoryTag::U16);
Expand Down Expand Up @@ -135,6 +153,7 @@ TEST_F(AvmCastTests, truncationFFToU16ModMinus1)
trace_builder.op_cast(0, 0, 1, AvmMemoryTag::U16);
trace_builder.return_op(0, 0, 0);
trace = trace_builder.finalize();
gen_indices();

validate_cast_trace(FF::modulus - 1, 0, 0, 1, AvmMemoryTag::FF, AvmMemoryTag::U16);
}
Expand All @@ -145,6 +164,7 @@ TEST_F(AvmCastTests, truncationFFToU16ModMinus2)
trace_builder.op_cast(0, 0, 1, AvmMemoryTag::U16);
trace_builder.return_op(0, 0, 0);
trace = trace_builder.finalize();
gen_indices();

validate_cast_trace(FF::modulus_minus_two, UINT16_MAX, 0, 1, AvmMemoryTag::FF, AvmMemoryTag::U16);
}
Expand All @@ -168,6 +188,7 @@ TEST_F(AvmCastTests, indirectAddrTruncationU64ToU8)
trace_builder.op_cast(3, 0, 1, AvmMemoryTag::U8);
trace_builder.return_op(0, 0, 0);
trace = trace_builder.finalize();
gen_indices();

validate_cast_trace(256'000'000'203UL, 203, 10, 11, AvmMemoryTag::U64, AvmMemoryTag::U8);
}
Expand Down Expand Up @@ -205,4 +226,163 @@ TEST_F(AvmCastTests, indirectAddrWrongResolutionU64ToU8)
validate_trace(std::move(trace));
}

TEST_F(AvmCastNegativeTests, nonTruncatedOutputMainIc)
{
gen_trace(300, 0, 1, AvmMemoryTag::U16, AvmMemoryTag::U8);
ASSERT_EQ(trace.at(main_idx).avm_main_ic, 44);

// Replace the output in main trace with the non-truncated value
trace.at(main_idx).avm_main_ic = 300;

// Adapt the memory trace entry
trace.at(mem_idx_c).avm_mem_val = 300;

EXPECT_THROW_WITH_MESSAGE(validate_trace_check_circuit(std::move(trace)), "PERM_MAIN_ALU");
}

TEST_F(AvmCastNegativeTests, wrongOutputMainIc)
{
gen_trace(151515, 0, 1, AvmMemoryTag::U32, AvmMemoryTag::FF);
ASSERT_EQ(trace.at(main_idx).avm_main_ic, 151515);

// Replace the output in main trace with a wrong value
trace.at(main_idx).avm_main_ic = 151516;

// Adapt the memory trace entry
trace.at(mem_idx_c).avm_mem_val = 151516;

EXPECT_THROW_WITH_MESSAGE(validate_trace_check_circuit(std::move(trace)), "PERM_MAIN_ALU");
}

TEST_F(AvmCastNegativeTests, wrongOutputAluIc)
{
gen_trace(6582736, 0, 1, AvmMemoryTag::U128, AvmMemoryTag::U16);
ASSERT_EQ(trace.at(alu_idx).avm_alu_ic, 29136);

// Replace output in ALU, MAIN, and MEM trace
trace.at(alu_idx).avm_alu_ic = 33;
trace.at(main_idx).avm_main_ic = 33;
trace.at(mem_idx_c).avm_mem_val = 33;

EXPECT_THROW_WITH_MESSAGE(validate_trace_check_circuit(std::move(trace)), "ALU_OP_CAST");
}

TEST_F(AvmCastNegativeTests, wrongLimbDecompositionInput)
{
trace_builder.calldata_copy(0, 0, 1, 0, { FF(FF::modulus_minus_two) });
trace_builder.op_cast(0, 0, 1, AvmMemoryTag::U16);
trace_builder.return_op(0, 0, 0);
trace = trace_builder.finalize();
gen_indices();

trace.at(alu_idx).avm_alu_a_lo -= 23;

EXPECT_THROW_WITH_MESSAGE(validate_trace_check_circuit(std::move(trace)), "INPUT_DECOMP_1");
}

TEST_F(AvmCastNegativeTests, wrongPSubALo)
{
gen_trace(12345, 0, 1, AvmMemoryTag::U32, AvmMemoryTag::U16);
ASSERT_EQ(trace.at(alu_idx).avm_alu_ic, 12345);

trace.at(alu_idx).avm_alu_p_sub_a_lo += 3;

EXPECT_THROW_WITH_MESSAGE(validate_trace_check_circuit(std::move(trace)), "SUB_LO_1");
}

TEST_F(AvmCastNegativeTests, wrongPSubAHi)
{
trace_builder.calldata_copy(0, 0, 1, 0, { FF(FF::modulus_minus_two - 987) });
trace_builder.op_cast(0, 0, 1, AvmMemoryTag::U16);
trace_builder.return_op(0, 0, 0);
trace = trace_builder.finalize();
gen_indices();

trace.at(alu_idx).avm_alu_p_sub_a_hi += 3;

EXPECT_THROW_WITH_MESSAGE(validate_trace_check_circuit(std::move(trace)), "SUB_HI_1");
}

TEST_F(AvmCastNegativeTests, disableRangecheck)
{
gen_trace(123, 23, 43, AvmMemoryTag::U8, AvmMemoryTag::U8);

trace.at(alu_idx).avm_alu_rng_chk_lookup_selector = 0;
EXPECT_THROW_WITH_MESSAGE(validate_trace_check_circuit(std::move(trace)), "RNG_CHK_LOOKUP_SELECTOR");
}

TEST_F(AvmCastNegativeTests, disableRangecheckSub)
{
gen_trace(123, 23, 43, AvmMemoryTag::U8, AvmMemoryTag::U8);

trace.at(alu_idx + 1).avm_alu_rng_chk_lookup_selector = 0;
EXPECT_THROW_WITH_MESSAGE(validate_trace_check_circuit(std::move(trace)), "RNG_CHK_LOOKUP_SELECTOR");
}

TEST_F(AvmCastNegativeTests, wrongRangeCheckDecompositionLo)
{
gen_trace(987344323, 23, 43, AvmMemoryTag::FF, AvmMemoryTag::U128);
ASSERT_EQ(trace.at(alu_idx).avm_alu_ic, 987344323);

trace.at(alu_idx).avm_alu_u16_r0 = 5555;
EXPECT_THROW_WITH_MESSAGE(validate_trace_check_circuit(std::move(trace)), "LOWER_CMP_RNG_CHK");
}

TEST_F(AvmCastNegativeTests, wrongRangeCheckDecompositionHi)
{
trace_builder.calldata_copy(0, 0, 1, 0, { FF(FF::modulus_minus_two - 987) });
trace_builder.op_cast(0, 0, 1, AvmMemoryTag::U16);
trace_builder.return_op(0, 0, 0);
trace = trace_builder.finalize();
gen_indices();

trace.at(alu_idx).avm_alu_u16_r9 = 5555;
EXPECT_THROW_WITH_MESSAGE(validate_trace_check_circuit(std::move(trace)), "UPPER_CMP_RNG_CHK");
}

TEST_F(AvmCastNegativeTests, outOfRangeU8Registers)
{
gen_trace(987344323, 23, 43, AvmMemoryTag::FF, AvmMemoryTag::U128);
ASSERT_EQ(trace.at(alu_idx).avm_alu_ic, 987344323);

trace.at(alu_idx).avm_alu_u8_r0 += 256;
trace.at(alu_idx).avm_alu_u8_r1 -= 1; // Adjust so that the decomposition is correct.

EXPECT_THROW_WITH_MESSAGE(validate_trace_check_circuit(std::move(trace)), "Lookup LOOKUP_U8_0");
}

TEST_F(AvmCastNegativeTests, outOfRangeU16Registers)
{
gen_trace(987344323, 23, 43, AvmMemoryTag::FF, AvmMemoryTag::U128);
ASSERT_EQ(trace.at(alu_idx).avm_alu_ic, 987344323);

trace.at(alu_idx).avm_alu_u16_r0 += 65536;
trace.at(alu_idx).avm_alu_u16_r1 -= 1; // Adjust so that the decomposition is correct.

EXPECT_THROW_WITH_MESSAGE(validate_trace_check_circuit(std::move(trace)), "Lookup LOOKUP_U16_0");
}

TEST_F(AvmCastNegativeTests, wrongCopySubLoForRangeCheck)
{
gen_trace(987344323, 23, 43, AvmMemoryTag::U64, AvmMemoryTag::U128);
ASSERT_EQ(trace.at(alu_idx).avm_alu_ic, 987344323);

ASSERT_EQ(trace.at(alu_idx + 1).avm_alu_a_lo, trace.at(alu_idx).avm_alu_p_sub_a_lo);
trace.at(alu_idx + 1).avm_alu_a_lo -= 1;
EXPECT_THROW_WITH_MESSAGE(validate_trace_check_circuit(std::move(trace)), "OP_CAST_RNG_CHECK_P_SUB_A_LOW");
}

TEST_F(AvmCastNegativeTests, wrongCopySubHiForRangeCheck)
{
trace_builder.calldata_copy(0, 0, 1, 0, { FF(FF::modulus_minus_two - 972836) });
trace_builder.op_cast(0, 0, 1, AvmMemoryTag::U128);
trace_builder.return_op(0, 0, 0);
trace = trace_builder.finalize();
gen_indices();

ASSERT_EQ(trace.at(alu_idx + 1).avm_alu_a_hi, trace.at(alu_idx).avm_alu_p_sub_a_hi);
trace.at(alu_idx + 1).avm_alu_a_hi += 2;
EXPECT_THROW_WITH_MESSAGE(validate_trace_check_circuit(std::move(trace)), "OP_CAST_RNG_CHECK_P_SUB_A_HIGH");
}

} // namespace tests_avm

0 comments on commit 292e488

Please sign in to comment.