Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test(avm): Negative unit tests for AVM CAST opcode #5907

Merged
merged 1 commit into from
Apr 23, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 196 additions & 16 deletions barretenberg/cpp/src/barretenberg/vm/tests/avm_cast.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@ using namespace bb::avm_trace;
using namespace testing;

class AvmCastTests : public ::testing::Test {
public:
AvmTraceBuilder trace_builder;

protected:
AvmTraceBuilder trace_builder;
std::vector<Row> trace;
size_t main_idx;
size_t alu_idx;
size_t mem_idx_c;

// TODO(640): The Standard Honk on Grumpkin test suite fails unless the SRS is initialised for every test.
void SetUp() override { srs::init_crs_factory("../srs_db/ignition"); };
Expand All @@ -25,6 +26,28 @@ class AvmCastTests : public ::testing::Test {
trace_builder.op_cast(0, src_address, dst_address, dst_tag);
trace_builder.return_op(0, 0, 0);
trace = trace_builder.finalize();
gen_indices();
}

void gen_indices()
{
auto row =
std::ranges::find_if(trace.begin(), trace.end(), [](Row r) { return r.avm_main_sel_op_cast == FF(1); });
ASSERT_TRUE(row != trace.end());
main_idx = static_cast<size_t>(row - trace.begin());

// Find the corresponding Alu trace row
auto clk = row->avm_main_clk;
auto alu_row = std::ranges::find_if(trace.begin(), trace.end(), [clk](Row r) { return r.avm_alu_clk == clk; });
ASSERT_TRUE(alu_row != trace.end());
alu_idx = static_cast<size_t>(alu_row - trace.begin());

// Mem entry output ic write operation
auto mem_row_c = std::ranges::find_if(trace.begin(), trace.end(), [clk](Row r) {
return r.avm_mem_clk == clk && r.avm_mem_sub_clk == AvmMemTraceBuilder::SUB_CLK_STORE_C;
});
ASSERT_TRUE(mem_row_c != trace.end());
mem_idx_c = static_cast<size_t>(mem_row_c - trace.begin());
}

void validate_cast_trace(FF const& a,
Expand All @@ -36,11 +59,8 @@ class AvmCastTests : public ::testing::Test {

)
{
auto row =
std::ranges::find_if(trace.begin(), trace.end(), [](Row r) { return r.avm_main_sel_op_cast == FF(1); });
ASSERT_TRUE(row != trace.end());

EXPECT_THAT(*row,
auto const& row = trace.at(main_idx);
EXPECT_THAT(row,
AllOf(Field("sel_op_cast", &Row::avm_main_sel_op_cast, 1),
Field("ia", &Row::avm_main_ia, a),
Field("ib", &Row::avm_main_ib, 0),
Expand All @@ -59,12 +79,8 @@ class AvmCastTests : public ::testing::Test {
Field("sel_rng_8", &Row::avm_main_sel_rng_8, 1),
Field("sel_rng_16", &Row::avm_main_sel_rng_16, 1)));

// Find the corresponding Alu trace row
auto clk = row->avm_main_clk;
auto alu_row = std::ranges::find_if(trace.begin(), trace.end(), [clk](Row r) { return r.avm_alu_clk == clk; });
ASSERT_TRUE(alu_row != trace.end());

EXPECT_THAT(*alu_row,
auto const& alu_row = trace.at(alu_idx);
EXPECT_THAT(alu_row,
AllOf(Field("op_cast", &Row::avm_alu_op_cast, 1),
Field("alu_ia", &Row::avm_alu_ia, a),
Field("alu_ib", &Row::avm_alu_ib, 0),
Expand All @@ -81,15 +97,17 @@ class AvmCastTests : public ::testing::Test {
Field("alu_sel", &Row::avm_alu_alu_sel, 1)));

// Check that there is a second ALU row
auto alu_row_next = alu_row + 1;
auto alu_row_next = trace.at(alu_idx + 1);
EXPECT_THAT(
*alu_row_next,
alu_row_next,
AllOf(Field("op_cast", &Row::avm_alu_op_cast, 0), Field("op_cast_prev", &Row::avm_alu_op_cast_prev, 1)));

validate_trace(std::move(trace));
}
};

class AvmCastNegativeTests : public AvmCastTests {};

TEST_F(AvmCastTests, basicU8ToU16)
{
gen_trace(237, 0, 1, AvmMemoryTag::U8, AvmMemoryTag::U16);
Expand Down Expand Up @@ -135,6 +153,7 @@ TEST_F(AvmCastTests, truncationFFToU16ModMinus1)
trace_builder.op_cast(0, 0, 1, AvmMemoryTag::U16);
trace_builder.return_op(0, 0, 0);
trace = trace_builder.finalize();
gen_indices();

validate_cast_trace(FF::modulus - 1, 0, 0, 1, AvmMemoryTag::FF, AvmMemoryTag::U16);
}
Expand All @@ -145,6 +164,7 @@ TEST_F(AvmCastTests, truncationFFToU16ModMinus2)
trace_builder.op_cast(0, 0, 1, AvmMemoryTag::U16);
trace_builder.return_op(0, 0, 0);
trace = trace_builder.finalize();
gen_indices();

validate_cast_trace(FF::modulus_minus_two, UINT16_MAX, 0, 1, AvmMemoryTag::FF, AvmMemoryTag::U16);
}
Expand All @@ -168,6 +188,7 @@ TEST_F(AvmCastTests, indirectAddrTruncationU64ToU8)
trace_builder.op_cast(3, 0, 1, AvmMemoryTag::U8);
trace_builder.return_op(0, 0, 0);
trace = trace_builder.finalize();
gen_indices();

validate_cast_trace(256'000'000'203UL, 203, 10, 11, AvmMemoryTag::U64, AvmMemoryTag::U8);
}
Expand Down Expand Up @@ -205,4 +226,163 @@ TEST_F(AvmCastTests, indirectAddrWrongResolutionU64ToU8)
validate_trace(std::move(trace));
}

TEST_F(AvmCastNegativeTests, nonTruncatedOutputMainIc)
{
gen_trace(300, 0, 1, AvmMemoryTag::U16, AvmMemoryTag::U8);
ASSERT_EQ(trace.at(main_idx).avm_main_ic, 44);

// Replace the output in main trace with the non-truncated value
trace.at(main_idx).avm_main_ic = 300;

// Adapt the memory trace entry
trace.at(mem_idx_c).avm_mem_val = 300;

EXPECT_THROW_WITH_MESSAGE(validate_trace_check_circuit(std::move(trace)), "PERM_MAIN_ALU");
}

TEST_F(AvmCastNegativeTests, wrongOutputMainIc)
{
gen_trace(151515, 0, 1, AvmMemoryTag::U32, AvmMemoryTag::FF);
ASSERT_EQ(trace.at(main_idx).avm_main_ic, 151515);

// Replace the output in main trace with a wrong value
trace.at(main_idx).avm_main_ic = 151516;

// Adapt the memory trace entry
trace.at(mem_idx_c).avm_mem_val = 151516;

EXPECT_THROW_WITH_MESSAGE(validate_trace_check_circuit(std::move(trace)), "PERM_MAIN_ALU");
}

TEST_F(AvmCastNegativeTests, wrongOutputAluIc)
{
gen_trace(6582736, 0, 1, AvmMemoryTag::U128, AvmMemoryTag::U16);
ASSERT_EQ(trace.at(alu_idx).avm_alu_ic, 29136);

// Replace output in ALU, MAIN, and MEM trace
trace.at(alu_idx).avm_alu_ic = 33;
trace.at(main_idx).avm_main_ic = 33;
trace.at(mem_idx_c).avm_mem_val = 33;

EXPECT_THROW_WITH_MESSAGE(validate_trace_check_circuit(std::move(trace)), "ALU_OP_CAST");
}

TEST_F(AvmCastNegativeTests, wrongLimbDecompositionInput)
{
trace_builder.calldata_copy(0, 0, 1, 0, { FF(FF::modulus_minus_two) });
trace_builder.op_cast(0, 0, 1, AvmMemoryTag::U16);
trace_builder.return_op(0, 0, 0);
trace = trace_builder.finalize();
gen_indices();

trace.at(alu_idx).avm_alu_a_lo -= 23;

EXPECT_THROW_WITH_MESSAGE(validate_trace_check_circuit(std::move(trace)), "INPUT_DECOMP_1");
}

TEST_F(AvmCastNegativeTests, wrongPSubALo)
{
gen_trace(12345, 0, 1, AvmMemoryTag::U32, AvmMemoryTag::U16);
ASSERT_EQ(trace.at(alu_idx).avm_alu_ic, 12345);

trace.at(alu_idx).avm_alu_p_sub_a_lo += 3;

EXPECT_THROW_WITH_MESSAGE(validate_trace_check_circuit(std::move(trace)), "SUB_LO_1");
}

TEST_F(AvmCastNegativeTests, wrongPSubAHi)
{
trace_builder.calldata_copy(0, 0, 1, 0, { FF(FF::modulus_minus_two - 987) });
trace_builder.op_cast(0, 0, 1, AvmMemoryTag::U16);
trace_builder.return_op(0, 0, 0);
trace = trace_builder.finalize();
gen_indices();

trace.at(alu_idx).avm_alu_p_sub_a_hi += 3;

EXPECT_THROW_WITH_MESSAGE(validate_trace_check_circuit(std::move(trace)), "SUB_HI_1");
}

TEST_F(AvmCastNegativeTests, disableRangecheck)
{
gen_trace(123, 23, 43, AvmMemoryTag::U8, AvmMemoryTag::U8);

trace.at(alu_idx).avm_alu_rng_chk_lookup_selector = 0;
EXPECT_THROW_WITH_MESSAGE(validate_trace_check_circuit(std::move(trace)), "RNG_CHK_LOOKUP_SELECTOR");
}

TEST_F(AvmCastNegativeTests, disableRangecheckSub)
{
gen_trace(123, 23, 43, AvmMemoryTag::U8, AvmMemoryTag::U8);

trace.at(alu_idx + 1).avm_alu_rng_chk_lookup_selector = 0;
EXPECT_THROW_WITH_MESSAGE(validate_trace_check_circuit(std::move(trace)), "RNG_CHK_LOOKUP_SELECTOR");
}

TEST_F(AvmCastNegativeTests, wrongRangeCheckDecompositionLo)
{
gen_trace(987344323, 23, 43, AvmMemoryTag::FF, AvmMemoryTag::U128);
ASSERT_EQ(trace.at(alu_idx).avm_alu_ic, 987344323);

trace.at(alu_idx).avm_alu_u16_r0 = 5555;
EXPECT_THROW_WITH_MESSAGE(validate_trace_check_circuit(std::move(trace)), "LOWER_CMP_RNG_CHK");
}

TEST_F(AvmCastNegativeTests, wrongRangeCheckDecompositionHi)
{
trace_builder.calldata_copy(0, 0, 1, 0, { FF(FF::modulus_minus_two - 987) });
trace_builder.op_cast(0, 0, 1, AvmMemoryTag::U16);
trace_builder.return_op(0, 0, 0);
trace = trace_builder.finalize();
gen_indices();

trace.at(alu_idx).avm_alu_u16_r9 = 5555;
EXPECT_THROW_WITH_MESSAGE(validate_trace_check_circuit(std::move(trace)), "UPPER_CMP_RNG_CHK");
}

TEST_F(AvmCastNegativeTests, outOfRangeU8Registers)
{
gen_trace(987344323, 23, 43, AvmMemoryTag::FF, AvmMemoryTag::U128);
ASSERT_EQ(trace.at(alu_idx).avm_alu_ic, 987344323);

trace.at(alu_idx).avm_alu_u8_r0 += 256;
trace.at(alu_idx).avm_alu_u8_r1 -= 1; // Adjust so that the decomposition is correct.

EXPECT_THROW_WITH_MESSAGE(validate_trace_check_circuit(std::move(trace)), "Lookup LOOKUP_U8_0");
}

TEST_F(AvmCastNegativeTests, outOfRangeU16Registers)
{
gen_trace(987344323, 23, 43, AvmMemoryTag::FF, AvmMemoryTag::U128);
ASSERT_EQ(trace.at(alu_idx).avm_alu_ic, 987344323);

trace.at(alu_idx).avm_alu_u16_r0 += 65536;
trace.at(alu_idx).avm_alu_u16_r1 -= 1; // Adjust so that the decomposition is correct.

EXPECT_THROW_WITH_MESSAGE(validate_trace_check_circuit(std::move(trace)), "Lookup LOOKUP_U16_0");
}

TEST_F(AvmCastNegativeTests, wrongCopySubLoForRangeCheck)
{
gen_trace(987344323, 23, 43, AvmMemoryTag::U64, AvmMemoryTag::U128);
ASSERT_EQ(trace.at(alu_idx).avm_alu_ic, 987344323);

ASSERT_EQ(trace.at(alu_idx + 1).avm_alu_a_lo, trace.at(alu_idx).avm_alu_p_sub_a_lo);
trace.at(alu_idx + 1).avm_alu_a_lo -= 1;
EXPECT_THROW_WITH_MESSAGE(validate_trace_check_circuit(std::move(trace)), "OP_CAST_RNG_CHECK_P_SUB_A_LOW");
}

TEST_F(AvmCastNegativeTests, wrongCopySubHiForRangeCheck)
{
trace_builder.calldata_copy(0, 0, 1, 0, { FF(FF::modulus_minus_two - 972836) });
trace_builder.op_cast(0, 0, 1, AvmMemoryTag::U128);
trace_builder.return_op(0, 0, 0);
trace = trace_builder.finalize();
gen_indices();

ASSERT_EQ(trace.at(alu_idx + 1).avm_alu_a_hi, trace.at(alu_idx).avm_alu_p_sub_a_hi);
trace.at(alu_idx + 1).avm_alu_a_hi += 2;
EXPECT_THROW_WITH_MESSAGE(validate_trace_check_circuit(std::move(trace)), "OP_CAST_RNG_CHECK_P_SUB_A_HIGH");
}

} // namespace tests_avm
Loading