Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Symbolically propagate assertions inside a basic block #266

Closed
wants to merge 11 commits into from
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ jobs:
strategy:
matrix:
configurations: [Debug, Release]
runs-on: ubuntu-20.04
runs-on: ubuntu-21.04
env:
# Configuration type to build. For documentation on how build matrices work, see
# https://docs.github.com/actions/learn-github-actions/managing-complex-workflows#using-a-build-matrix
Expand Down
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" OR
"${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang")
find_package(yaml-cpp REQUIRED)
find_package(Boost REQUIRED)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD 20)
elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC")
find_program(NUGET nuget)
if (NOT NUGET)
Expand Down
32 changes: 26 additions & 6 deletions src/asm_cfg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,8 @@ static cfg_t to_nondet(const cfg_t& cfg) {
basic_block_t& newbb = res.insert(this_label);

for (const auto& ins : bb) {
if (!std::holds_alternative<Jmp>(ins)) {
newbb.insert(ins);
}
// We can avoid inserting Jmp to newbb, but then blocks will end with an assertion instead of an instruction
newbb.insert(ins);
}

for (const label_t& prev_label : bb.prev_blocks_set()) {
Expand All @@ -147,6 +146,7 @@ static cfg_t to_nondet(const cfg_t& cfg) {
for (auto const& [next_label, cond1] : jumps) {
label_t jump_label = label_t::make_jump(mid_label, next_label);
basic_block_t& jump_bb = res.insert(jump_label);
jump_bb.insert<Assert>(); // maintain parity - alternating assert/cmd
jump_bb.insert<Assume>(cond1);
newbb >> jump_bb;
jump_bb >> res.insert(next_label);
Expand Down Expand Up @@ -180,9 +180,9 @@ static std::string instype(Instruction ins) {
return "load_store";
} else if (std::holds_alternative<Packet>(ins)) {
return "packet_access";
} else if (std::holds_alternative<Mov>(ins)) {
return "assign";
} else if (std::holds_alternative<Bin>(ins)) {
if (std::get<Bin>(ins).op == Bin::Op::MOV)
return "assign";
return "arith";
} else if (std::holds_alternative<Un>(ins)) {
return "arith";
Expand Down Expand Up @@ -237,12 +237,27 @@ std::map<std::string, int> collect_stats(const cfg_t& cfg) {
return res;
}

/// Sum the number of constraints carried by every Assert instruction in `cfg`.
///
/// Each Assert holds a collection `csts`; an Assert contributes `csts.size()`
/// to the total, and non-Assert instructions contribute nothing.
static int count_assertions(const cfg_t& cfg) {
    int constraint_count = 0;
    for (const auto& [label, block] : cfg) {
        for (const auto& instruction : block) {
            // A single get_if both tests for and retrieves the Assert alternative.
            if (const auto* assertion = std::get_if<Assert>(&instruction)) {
                constraint_count += static_cast<int>(assertion->csts.size());
            }
        }
    }
    return constraint_count;
}

cfg_t prepare_cfg(const InstructionSeq& prog, const program_info& info, bool simplify, bool must_have_exit) {
// Convert the instruction sequence to a deterministic control-flow graph.
cfg_t det_cfg = instruction_seq_to_cfg(prog, must_have_exit);

// Annotate the CFG by adding in assertions before every memory instruction.
explicate_assertions(det_cfg, info);
for (auto& [label, bb] : det_cfg) {
explicate_assertions(bb, info);
}
int total_assertions = count_assertions(det_cfg);

// Translate conditional jumps to non-deterministic jumps.
cfg_t cfg = to_nondet(det_cfg);
Expand All @@ -257,5 +272,10 @@ cfg_t prepare_cfg(const InstructionSeq& prog, const program_info& info, bool sim
cfg.simplify();
}

for (auto& [label, bb] : cfg) {
propagate_assertions_backwards(bb);
}
// std::cout << "Total assertions: " << total_assertions << "\n";
// std::cout << "Resolved: " << count_assertions(cfg) << "\n";
return cfg;
}
21 changes: 20 additions & 1 deletion src/asm_marshal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ static uint8_t op(Bin::Op op) {
case Op::RSH: return 0x7;
case Op::MOD: return 0x9;
case Op::XOR: return 0xa;
case Op::MOV: return 0xb;
case Op::ARSH: return 0xc;
}
assert(false);
Expand Down Expand Up @@ -83,6 +82,26 @@ struct MarshalVisitor {

vector<ebpf_inst> operator()(LoadMapFd const& b) { return makeLddw(b.dst, true, b.mapfd, 0); }

/// Marshal a Mov instruction into its raw ebpf_inst encoding.
///
/// If `b.lddw` is set, the source operand must be an immediate; its value is
/// passed through split() and the two resulting parts are handed to makeLddw().
/// Otherwise a single ALU (32-bit) or ALU64 (64-bit) instruction is produced
/// with 0xb in the operation nibble of the opcode. A register source sets the
/// INST_SRC_REG bit and fills `src`; an immediate source fills `imm`.
vector<ebpf_inst> operator()(Mov const& b) {
    if (b.lddw) {
        assert(std::holds_alternative<Imm>(b.v));
        auto [first_imm, second_imm] = split(std::get<Imm>(b.v).v);
        return makeLddw(b.dst, false, first_imm, second_imm);
    }

    // Instruction class is chosen by operand width.
    const auto alu_class = b.is64 ? INST_CLS_ALU64 : INST_CLS_ALU;
    ebpf_inst encoded{.opcode = static_cast<uint8_t>(alu_class | (0xb << 4)),
                      .dst = b.dst.v,
                      .src = 0,
                      .offset = 0,
                      .imm = 0};

    // Fill in the source operand according to which alternative `b.v` holds.
    auto from_register = [&](Reg source) {
        encoded.opcode |= INST_SRC_REG;
        encoded.src = source.v;
    };
    auto from_immediate = [&](Imm source) { encoded.imm = static_cast<int32_t>(source.v); };
    std::visit(overloaded{from_register, from_immediate}, b.v);

    return {encoded};
}
vector<ebpf_inst> operator()(Bin const& b) {
if (b.lddw) {
assert(std::holds_alternative<Imm>(b.v));
Expand Down
46 changes: 32 additions & 14 deletions src/asm_ostream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ std::ostream& operator<<(std::ostream& os, ArgPair arg) {
std::ostream& operator<<(std::ostream& os, Bin::Op op) {
using Op = Bin::Op;
switch (op) {
case Op::MOV: return os;
case Op::ADD: return os << "+";
case Op::SUB: return os << "-";
case Op::MUL: return os << "*";
Expand Down Expand Up @@ -154,7 +153,7 @@ std::ostream& operator<<(std::ostream& os, ZeroOffset const& a) {
return os << crab::variable_t::reg(crab::data_kind_t::offsets, a.reg.v) << " == 0";
}

std::ostream& operator<<(std::ostream& os, Comparable const& a) {
std::ostream& operator<<(std::ostream& os, SameType const& a) {
return os << crab::variable_t::reg(crab::data_kind_t::types, a.r1.v) << " == "
<< crab::variable_t::reg(crab::data_kind_t::types, a.r2.v);
}
Expand Down Expand Up @@ -185,6 +184,14 @@ struct InstructionPrinterVisitor {

void operator()(LoadMapFd const& b) { os_ << b.dst << " = map_fd " << b.mapfd; }

/// Print a Mov as "<dst> = <value>".
///
/// Appends " ll" when the instruction is flagged `lddw`, and
/// " & 0xFFFFFFFF" when it is not a 64-bit move (`is64` is false).
void operator()(Mov const& b) {
    os_ << b.dst << " = " << b.v;
    if (b.lddw) {
        os_ << " ll";
    }
    if (!b.is64) {
        os_ << " & 0xFFFFFFFF";
    }
}

void operator()(Bin const& b) {
os_ << b.dst << " " << b.op << "= " << b.v;
if (b.lddw)
Expand Down Expand Up @@ -306,7 +313,9 @@ struct InstructionPrinterVisitor {
}

void operator()(Assert const& a) {
os_ << "assert " << a.cst;
os_ << "assert all ";
for (const auto& cst: a.csts)
os_ << cst << " & ";
}
};

Expand Down Expand Up @@ -358,8 +367,7 @@ void print(const InstructionSeq& insts, std::ostream& out, std::optional<const l
auto pc_of_label = get_labels(insts);
pc_t pc = 0;
InstructionPrinterVisitor visitor{out};
for (const LabeledInstruction& labeled_inst : insts) {
const auto& [label, ins] = labeled_inst;
for (const auto& [label, ins] : insts) {
if (!label_to_print.has_value() || (label == label_to_print)) {
if (label.isjump()) {
out << "\n";
Expand All @@ -376,6 +384,10 @@ void print(const InstructionSeq& insts, std::ostream& out, std::optional<const l
throw std::runtime_error(string("Cannot find label ") + to_string(jmp.target));
pc_t target_pc = pc_of_label.at(jmp.target);
visitor(jmp, target_pc - (int)pc - 1);
} else if (std::holds_alternative<Assert>(ins)) {
for (const auto& cst : std::get<Assert>(ins).csts) {
out << "assert " << cst << "\n";
}
} else {
std::visit(visitor, ins);
}
Expand All @@ -387,13 +399,13 @@ void print(const InstructionSeq& insts, std::ostream& out, std::optional<const l

std::ostream& operator<<(std::ostream& o, const EbpfMapDescriptor& desc) {
return o << "("
<< "original_fd = " << desc.original_fd << ", "
<< "inner_map_fd = " << desc.inner_map_fd << ", "
<< "type = " << desc.type << ", "
<< "max_entries = " << desc.max_entries << ", "
<< "value_size = " << desc.value_size << ", "
<< "key_size = " << desc.key_size <<
")";
<< "original_fd = " << desc.original_fd << ", "
<< "inner_map_fd = " << desc.inner_map_fd << ", "
<< "type = " << desc.type << ", "
<< "max_entries = " << desc.max_entries << ", "
<< "value_size = " << desc.value_size << ", "
<< "key_size = " << desc.key_size <<
")";
}

void print_map_descriptors(const std::vector<EbpfMapDescriptor>& descriptors, std::ostream& o) {
Expand Down Expand Up @@ -432,8 +444,14 @@ void print_dot(const cfg_t& cfg, const std::string& outfile) {

std::ostream& operator<<(std::ostream& o, const basic_block_t& bb) {
o << bb.label() << ":\n";
for (auto const& s : bb) {
o << " " << s << ";\n";
for (auto const& ins : bb) {
if (std::holds_alternative<Assert>(ins)) {
for (const auto& cst : std::get<Assert>(ins).csts) {
o << " assert " << cst << ";\n";
}
} else if (!std::holds_alternative<Jmp>(ins)) {
o << " " << ins << ";\n";
}
}
auto [it, et] = bb.next_blocks();
if (it != et) {
Expand Down
8 changes: 7 additions & 1 deletion src/asm_parse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ using std::regex_match;
#define WRAPPED_LABEL "\\s*" LABEL "\\s*"

static const std::map<std::string, Bin::Op> str_to_binop = {
{"", Bin::Op::MOV}, {"+", Bin::Op::ADD}, {"-", Bin::Op::SUB}, {"*", Bin::Op::MUL},
{"+", Bin::Op::ADD}, {"-", Bin::Op::SUB}, {"*", Bin::Op::MUL},
{"/", Bin::Op::DIV}, {"%", Bin::Op::MOD}, {"|", Bin::Op::OR}, {"&", Bin::Op::AND},
{"<<", Bin::Op::LSH}, {">>", Bin::Op::RSH}, {">>>", Bin::Op::ARSH}, {"^", Bin::Op::XOR},
};
Expand Down Expand Up @@ -101,13 +101,19 @@ Instruction parse_instruction(const std::string& line, const std::map<std::strin
int func = boost::lexical_cast<int>(m[1]);
return Call{.func = func};
}
if (regex_match(text, m, regex(REG ASSIGN REG))) {
return Mov{.dst = reg(m[1]), .v = reg(m[2]), .is64 = true, .lddw = false};
}
if (regex_match(text, m, regex(REG OPASSIGN REG))) {
return Bin{.op = str_to_binop.at(m[2]), .dst = reg(m[1]), .v = reg(m[3]), .is64 = true, .lddw = false};
}
if (regex_match(text, m, regex(REG ASSIGN UNOP REG))) {
if (m[1] != m[3]) throw std::invalid_argument(std::string("Invalid unary operation: ") + text);
return Un{.op = str_to_unop.at(m[2]), .dst = reg(m[1])};
}
if (regex_match(text, m, regex(REG ASSIGN IMM LONGLONG))) {
return Mov{.dst = reg(m[1]), .v = imm(m[2]), .is64 = true, .lddw = !m[3].str().empty()};
}
if (regex_match(text, m, regex(REG OPASSIGN IMM LONGLONG))) {
return Bin{
.op = str_to_binop.at(m[2]), .dst = reg(m[1]), .v = imm(m[3]), .is64 = true, .lddw = !m[4].str().empty()};
Expand Down
Loading