Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use get_if instead of holds_alternative to avoid repetition #703

Merged
merged 1 commit into from
Sep 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 37 additions & 40 deletions src/asm_cfg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

#include "asm_syntax.hpp"
#include "crab/cfg.hpp"
#include "crab_utils/debug.hpp"

using std::optional;
using std::set;
Expand All @@ -20,19 +19,21 @@ using std::to_string;
using std::vector;

static optional<label_t> get_jump(Instruction ins) {
if (std::holds_alternative<Jmp>(ins)) {
return std::get<Jmp>(ins).target;
if (const auto pins = std::get_if<Jmp>(&ins)) {
return pins->target;
}
return {};
}

static bool has_fall(Instruction ins) {
static bool has_fall(const Instruction& ins) {
elazarg marked this conversation as resolved.
Show resolved Hide resolved
if (std::holds_alternative<Exit>(ins)) {
return false;
}

if (std::holds_alternative<Jmp>(ins) && !std::get<Jmp>(ins).cond) {
return false;
if (const auto pins = std::get_if<Jmp>(&ins)) {
if (!pins->cond) {
return false;
}
}

return true;
Expand All @@ -51,8 +52,8 @@ static void add_cfg_nodes(cfg_t& cfg, const label_t& caller_label, const label_t
basic_block_t& caller_node = cfg.get_node(caller_label);
std::string stack_frame_prefix = to_string(caller_label);
for (auto& inst : caller_node) {
if (std::holds_alternative<CallLocal>(inst)) {
std::get<CallLocal>(inst).stack_frame_prefix = stack_frame_prefix;
if (const auto pcall = std::get_if<CallLocal>(&inst)) {
pcall->stack_frame_prefix = stack_frame_prefix;
}
}

Expand All @@ -72,10 +73,10 @@ static void add_cfg_nodes(cfg_t& cfg, const label_t& caller_label, const label_t
const label_t label(macro_label.from, macro_label.to, stack_frame_prefix);
auto& bb = cfg.insert(label);
for (auto inst : cfg.get_node(macro_label)) {
if (std::holds_alternative<Exit>(inst)) {
std::get<Exit>(inst).stack_frame_prefix = label.stack_frame_prefix;
} else if (std::holds_alternative<Call>(inst)) {
std::get<Call>(inst).stack_frame_prefix = label.stack_frame_prefix;
if (const auto pins = std::get_if<Exit>(&inst)) {
pins->stack_frame_prefix = label.stack_frame_prefix;
} else if (const auto pins = std::get_if<Call>(&inst)) {
pins->stack_frame_prefix = label.stack_frame_prefix;
}
bb.insert(inst);
}
Expand Down Expand Up @@ -121,18 +122,18 @@ static void add_cfg_nodes(cfg_t& cfg, const label_t& caller_label, const label_t
for (auto& macro_label : seen_labels) {
for (const label_t label(macro_label.from, macro_label.to, caller_label_str);
const auto& inst : cfg.get_node(label)) {
if (std::holds_alternative<CallLocal>(inst)) {
if (const auto pins = std::get_if<CallLocal>(&inst)) {
elazarg marked this conversation as resolved.
Show resolved Hide resolved
if (stack_frame_depth >= MAX_CALL_STACK_FRAMES) {
throw std::runtime_error{"too many call stack frames"};
}
add_cfg_nodes(cfg, label, std::get<CallLocal>(inst).target);
add_cfg_nodes(cfg, label, pins->target);
}
}
}
}

/// Convert an instruction sequence to a control-flow graph (CFG).
static cfg_t instruction_seq_to_cfg(const InstructionSeq& insts, bool must_have_exit) {
static cfg_t instruction_seq_to_cfg(const InstructionSeq& insts, const bool must_have_exit) {
cfg_t cfg;
std::optional<label_t> falling_from = {};
bool first = true;
Expand All @@ -159,8 +160,7 @@ static cfg_t instruction_seq_to_cfg(const InstructionSeq& insts, bool must_have_
if (has_fall(inst)) {
falling_from = label;
}
auto jump_target = get_jump(inst);
if (jump_target) {
if (auto jump_target = get_jump(inst)) {
bb >> cfg.insert(*jump_target);
}

Expand All @@ -180,16 +180,16 @@ static cfg_t instruction_seq_to_cfg(const InstructionSeq& insts, bool must_have_
// we only add new nodes that are actually reachable, based on the
// results of the first pass.
for (auto& [label, inst, _] : insts) {
if (std::holds_alternative<CallLocal>(inst)) {
add_cfg_nodes(cfg, label, std::get<CallLocal>(inst).target);
if (const auto pins = std::get_if<CallLocal>(&inst)) {
add_cfg_nodes(cfg, label, pins->target);
}
}

return cfg;
}

/// Get the inverse of a given comparison operation.
static Condition::Op reverse(Condition::Op op) {
static Condition::Op reverse(const Condition::Op op) {
switch (op) {
case Condition::Op::EQ: return Condition::Op::NE;
case Condition::Op::NE: return Condition::Op::EQ;
Expand All @@ -214,7 +214,7 @@ static Condition::Op reverse(Condition::Op op) {
}

/// Get the inverse of a given comparison condition.
static Condition reverse(Condition cond) {
static Condition reverse(const Condition& cond) {
return {.op = reverse(cond.op), .left = cond.left, .right = cond.right, .is64 = cond.is64};
}

Expand Down Expand Up @@ -274,31 +274,30 @@ static cfg_t to_nondet(const cfg_t& cfg) {
return res;
}

/// Get the type of a given instruction.
/// Get the type of given instruction.
/// Most of these type names are also statistics header labels.
static std::string instype(Instruction ins) {
if (std::holds_alternative<Call>(ins)) {
auto call = std::get<Call>(ins);
if (call.is_map_lookup) {
if (const auto pcall = std::get_if<Call>(&ins)) {
if (pcall->is_map_lookup) {
return "call_1";
}
if (call.pairs.empty()) {
if (std::all_of(call.singles.begin(), call.singles.end(),
[](ArgSingle kr) { return kr.kind == ArgSingle::Kind::ANYTHING; })) {
if (pcall->pairs.empty()) {
if (std::ranges::all_of(pcall->singles,
[](const ArgSingle kr) { return kr.kind == ArgSingle::Kind::ANYTHING; })) {
elazarg marked this conversation as resolved.
Show resolved Hide resolved
return "call_nomem";
}
}
return "call_mem";
} else if (std::holds_alternative<Callx>(ins)) {
return "callx";
} else if (std::holds_alternative<Mem>(ins)) {
return std::get<Mem>(ins).is_load ? "load" : "store";
} else if (const auto pimm = std::get_if<Mem>(&ins)) {
return pimm->is_load ? "load" : "store";
} else if (std::holds_alternative<Atomic>(ins)) {
return "load_store";
} else if (std::holds_alternative<Packet>(ins)) {
return "packet_access";
} else if (std::holds_alternative<Bin>(ins)) {
switch (std::get<Bin>(ins).op) {
} else if (const auto pins = std::get_if<Bin>(&ins)) {
switch (pins->op) {
case Bin::Op::MOV:
case Bin::Op::MOVSX8:
case Bin::Op::MOVSX16:
Expand Down Expand Up @@ -334,20 +333,18 @@ std::map<std::string, int> collect_stats(const cfg_t& cfg) {
basic_block_t const& bb = cfg.get_node(this_label);

for (Instruction ins : bb) {
if (std::holds_alternative<LoadMapFd>(ins)) {
if (std::get<LoadMapFd>(ins).mapfd == -1) {
if (const auto pins = std::get_if<LoadMapFd>(&ins)) {
if (pins->mapfd == -1) {
res["map_in_map"] = 1;
}
}
if (std::holds_alternative<Call>(ins)) {
auto call = std::get<Call>(ins);
if (call.reallocate_packet) {
if (const auto pins = std::get_if<Call>(&ins)) {
if (pins->reallocate_packet) {
res["reallocate"] = 1;
}
}
if (std::holds_alternative<Bin>(ins)) {
auto const& bin = std::get<Bin>(ins);
res[bin.is64 ? "arith64" : "arith32"]++;
if (const auto pins = std::get_if<Bin>(&ins)) {
res[pins->is64 ? "arith64" : "arith32"]++;
}
res[instype(ins)]++;
}
Expand Down
20 changes: 10 additions & 10 deletions src/asm_marshal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,9 @@ struct MarshalVisitor {

vector<ebpf_inst> operator()(Bin const& b) {
if (b.lddw) {
assert(std::holds_alternative<Imm>(b.v));
auto [imm, next_imm] = split(std::get<Imm>(b.v).v);
const auto pimm = std::get_if<Imm>(&b.v);
assert(pimm != nullptr);
auto [imm, next_imm] = split(pimm->v);
return makeLddw(b.dst, false, imm, next_imm);
}

Expand Down Expand Up @@ -248,9 +249,9 @@ struct MarshalVisitor {
} else {
res.opcode |= INST_CLS_ST;
res.dst = access.basereg.v;
if (std::holds_alternative<Reg>(b.value)) {
if (const auto preg = std::get_if<Reg>(&b.value)) {
res.opcode |= 0x1;
res.src = std::get<Reg>(b.value).v;
res.src = preg->v;
elazarg marked this conversation as resolved.
Show resolved Hide resolved
} else {
res.opcode |= 0x0;
res.imm = static_cast<int32_t>(std::get<Imm>(b.value).v);
Expand Down Expand Up @@ -308,9 +309,9 @@ vector<ebpf_inst> marshal(const vector<Instruction>& insts) {
return res;
}

static int size(Instruction inst) {
if (std::holds_alternative<Bin>(inst)) {
if (std::get<Bin>(inst).lddw) {
static int size(const Instruction& inst) {
if (const auto pins = std::get_if<Bin>(&inst)) {
if (pins->lddw) {
return 2;
}
}
Expand All @@ -336,9 +337,8 @@ vector<ebpf_inst> marshal(const InstructionSeq& insts) {
pc_t pc = 0;
for (auto [label, ins, _] : insts) {
(void)label; // unused
if (std::holds_alternative<Jmp>(ins)) {
Jmp& jmp = std::get<Jmp>(ins);
jmp.target = label_t(pc_of_label.at(jmp.target));
if (const auto pins = std::get_if<Jmp>(&ins)) {
pins->target = label_t(pc_of_label.at(pins->target));
elazarg marked this conversation as resolved.
Show resolved Hide resolved
}
for (auto e : marshal(ins, pc)) {
pc++;
Expand Down
43 changes: 20 additions & 23 deletions src/asm_ostream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ std::ostream& operator<<(std::ostream& os, ValidAccess const& a) {
os << a.offset;
}

if (a.width == (Value)Imm{0}) {
if (a.width == static_cast<Value>(Imm{0})) {
// a.width == 0, meaning we only care it's an in-bound pointer,
// so it can be compared with another pointer to the same region.
os << ") for comparison/subtraction";
Expand All @@ -138,12 +138,12 @@ std::ostream& operator<<(std::ostream& os, ValidAccess const& a) {
static crab::variable_t typereg(const Reg& r) { return crab::variable_t::reg(crab::data_kind_t::types, r.v); }

std::ostream& operator<<(std::ostream& os, ValidSize const& a) {
auto op = a.can_be_zero ? " >= " : " > ";
const auto op = a.can_be_zero ? " >= " : " > ";
return os << a.reg << ".value" << op << 0;
}

std::ostream& operator<<(std::ostream& os, ValidCall const& a) {
EbpfHelperPrototype proto = global_program_info->platform->get_helper_prototype(a.func);
const EbpfHelperPrototype proto = global_program_info->platform->get_helper_prototype(a.func);
return os << "valid call(" << proto.name << ")";
}

Expand Down Expand Up @@ -224,8 +224,7 @@ struct InstructionPrinterVisitor {
os_ << "r0 = " << call.name << ":" << call.func << "(";
for (uint8_t r = 1; r <= 5; r++) {
// Look for a singleton.
std::vector<ArgSingle>::const_iterator single =
std::find_if(call.singles.begin(), call.singles.end(), [r](ArgSingle arg) { return arg.reg.v == r; });
auto single = std::ranges::find_if(call.singles, [r](ArgSingle arg) { return arg.reg.v == r; });
if (single != call.singles.end()) {
if (r > 1) {
os_ << ", ";
Expand All @@ -235,8 +234,7 @@ struct InstructionPrinterVisitor {
}

// Look for the start of a pair.
std::vector<ArgPair>::const_iterator pair =
std::find_if(call.pairs.begin(), call.pairs.end(), [r](ArgPair arg) { return arg.mem.v == r; });
auto pair = std::ranges::find_if(call.pairs, [r](ArgPair arg) { return arg.mem.v == r; });
if (pair != call.pairs.end()) {
if (r > 1) {
os_ << ", ";
Expand Down Expand Up @@ -270,8 +268,8 @@ struct InstructionPrinterVisitor {
}

void operator()(Jmp const& b, int offset) {
string sign = offset > 0 ? "+" : "";
string target = sign + std::to_string(offset) + " <" + to_string(b.target) + ">";
const string sign = offset > 0 ? "+" : "";
const string target = sign + std::to_string(offset) + " <" + to_string(b.target) + ">";

if (b.cond) {
os_ << "if ";
Expand All @@ -284,7 +282,7 @@ struct InstructionPrinterVisitor {
void operator()(Packet const& b) {
/* Direct packet access, R0 = *(uint *) (skb->data + imm32) */
/* Indirect packet access, R0 = *(uint *) (skb->data + src_reg + imm32) */
string s = size(b.width);
const string s = size(b.width);
os_ << "r0 = ";
os_ << "*(" << s << " *)skb[";
if (b.regoffset) {
Expand All @@ -300,7 +298,7 @@ struct InstructionPrinterVisitor {
}

void print(Deref const& access) {
string sign = access.offset < 0 ? " - " : " + ";
const string sign = access.offset < 0 ? " - " : " + ";
int offset = std::abs(access.offset); // what about INT_MIN?
os_ << "*(" << size(access.width) << " *)";
os_ << "(" << access.basereg << sign << offset << ")";
Expand Down Expand Up @@ -379,9 +377,9 @@ string to_string(AssertionConstraint const& constraint) {
return str.str();
}

int size(Instruction inst) {
if (std::holds_alternative<Bin>(inst)) {
if (std::get<Bin>(inst).lddw) {
int size(const Instruction& inst) {
if (const auto bin = std::get_if<Bin>(&inst)) {
if (bin->lddw) {
return 2;
}
}
Expand All @@ -401,9 +399,9 @@ auto get_labels(const InstructionSeq& insts) {
return pc_of_label;
}

void print(const InstructionSeq& insts, std::ostream& out, std::optional<const label_t> label_to_print,
void print(const InstructionSeq& insts, std::ostream& out, const std::optional<const label_t>& label_to_print,
bool print_line_info) {
auto pc_of_label = get_labels(insts);
const auto pc_of_label = get_labels(insts);
pc_t pc = 0;
std::string previous_source;
InstructionPrinterVisitor visitor{out};
Expand All @@ -427,13 +425,12 @@ void print(const InstructionSeq& insts, std::ostream& out, std::optional<const l
} else {
out << std::setw(8) << pc << ":\t";
}
if (std::holds_alternative<Jmp>(ins)) {
auto const& jmp = std::get<Jmp>(ins);
if (pc_of_label.count(jmp.target) == 0) {
throw std::runtime_error(string("Cannot find label ") + to_string(jmp.target));
if (const auto jmp = std::get_if<Jmp>(&ins)) {
if (!pc_of_label.contains(jmp->target)) {
throw std::runtime_error(string("Cannot find label ") + to_string(jmp->target));
}
pc_t target_pc = pc_of_label.at(jmp.target);
visitor(jmp, target_pc - (int)pc - 1);
const pc_t target_pc = pc_of_label.at(jmp->target);
visitor(*jmp, target_pc - static_cast<int>(pc) - 1);
} else {
std::visit(visitor, ins);
}
Expand Down Expand Up @@ -498,7 +495,7 @@ std::ostream& operator<<(std::ostream& o, const basic_block_t& bb) {
if (it != et) {
o << " "
<< "goto ";
for (; it != et;) {
while (it != et) {
o << *it;
++it;
if (it == et) {
Expand Down
6 changes: 3 additions & 3 deletions src/asm_ostream.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ inline std::function<int32_t(label_t)> label_to_offset32(pc_t pc) {

std::ostream& operator<<(std::ostream& os, const btf_line_info_t& line_info);

void print(const InstructionSeq& insts, std::ostream& out, std::optional<const label_t> label_to_print,
void print(const InstructionSeq& insts, std::ostream& out, const std::optional<const label_t>& label_to_print,
bool print_line_info = false);

std::string to_string(label_t const& label);
Expand All @@ -43,8 +43,8 @@ std::ostream& operator<<(std::ostream& os, Condition::Op op);
inline std::ostream& operator<<(std::ostream& os, Imm imm) { return os << (int64_t)imm.v; }
inline std::ostream& operator<<(std::ostream& os, Reg const& a) { return os << "r" << (int)a.v; }
inline std::ostream& operator<<(std::ostream& os, Value const& a) {
if (std::holds_alternative<Imm>(a)) {
return os << std::get<Imm>(a);
if (auto pa = std::get_if<Imm>(&a)) {
return os << *pa;
elazarg marked this conversation as resolved.
Show resolved Hide resolved
}
return os << std::get<Reg>(a);
}
Expand Down
Loading
Loading