diff --git a/src/asm_cfg.cpp b/src/asm_cfg.cpp index d922c33b3..37afdbf29 100644 --- a/src/asm_cfg.cpp +++ b/src/asm_cfg.cpp @@ -11,7 +11,6 @@ #include "asm_syntax.hpp" #include "crab/cfg.hpp" -#include "crab_utils/debug.hpp" using std::optional; using std::set; @@ -20,19 +19,21 @@ using std::to_string; using std::vector; static optional get_jump(Instruction ins) { - if (std::holds_alternative(ins)) { - return std::get(ins).target; + if (const auto pins = std::get_if(&ins)) { + return pins->target; } return {}; } -static bool has_fall(Instruction ins) { +static bool has_fall(const Instruction& ins) { if (std::holds_alternative(ins)) { return false; } - if (std::holds_alternative(ins) && !std::get(ins).cond) { - return false; + if (const auto pins = std::get_if(&ins)) { + if (!pins->cond) { + return false; + } } return true; @@ -51,8 +52,8 @@ static void add_cfg_nodes(cfg_t& cfg, const label_t& caller_label, const label_t basic_block_t& caller_node = cfg.get_node(caller_label); std::string stack_frame_prefix = to_string(caller_label); for (auto& inst : caller_node) { - if (std::holds_alternative(inst)) { - std::get(inst).stack_frame_prefix = stack_frame_prefix; + if (const auto pcall = std::get_if(&inst)) { + pcall->stack_frame_prefix = stack_frame_prefix; } } @@ -72,10 +73,10 @@ static void add_cfg_nodes(cfg_t& cfg, const label_t& caller_label, const label_t const label_t label(macro_label.from, macro_label.to, stack_frame_prefix); auto& bb = cfg.insert(label); for (auto inst : cfg.get_node(macro_label)) { - if (std::holds_alternative(inst)) { - std::get(inst).stack_frame_prefix = label.stack_frame_prefix; - } else if (std::holds_alternative(inst)) { - std::get(inst).stack_frame_prefix = label.stack_frame_prefix; + if (const auto pins = std::get_if(&inst)) { + pins->stack_frame_prefix = label.stack_frame_prefix; + } else if (const auto pins = std::get_if(&inst)) { + pins->stack_frame_prefix = label.stack_frame_prefix; } bb.insert(inst); } @@ -121,18 +122,18 @@ static void add_cfg_nodes(cfg_t& cfg, const label_t& caller_label, const label_t for (auto& macro_label : seen_labels) { for (const label_t label(macro_label.from, macro_label.to, caller_label_str); const auto& inst : cfg.get_node(label)) { - if (std::holds_alternative(inst)) { + if (const auto pins = std::get_if(&inst)) { if (stack_frame_depth >= MAX_CALL_STACK_FRAMES) { throw std::runtime_error{"too many call stack frames"}; } - add_cfg_nodes(cfg, label, std::get(inst).target); + add_cfg_nodes(cfg, label, pins->target); } } } } /// Convert an instruction sequence to a control-flow graph (CFG). -static cfg_t instruction_seq_to_cfg(const InstructionSeq& insts, bool must_have_exit) { +static cfg_t instruction_seq_to_cfg(const InstructionSeq& insts, const bool must_have_exit) { cfg_t cfg; std::optional falling_from = {}; bool first = true; @@ -159,8 +160,7 @@ static cfg_t instruction_seq_to_cfg(const InstructionSeq& insts, bool must_have_ if (has_fall(inst)) { falling_from = label; } - auto jump_target = get_jump(inst); - if (jump_target) { + if (auto jump_target = get_jump(inst)) { bb >> cfg.insert(*jump_target); } @@ -180,8 +180,8 @@ static cfg_t instruction_seq_to_cfg(const InstructionSeq& insts, bool must_have_ // we only add new nodes that are actually reachable, based on the // results of the first pass. for (auto& [label, inst, _] : insts) { - if (std::holds_alternative(inst)) { - add_cfg_nodes(cfg, label, std::get(inst).target); + if (const auto pins = std::get_if(&inst)) { + add_cfg_nodes(cfg, label, pins->target); } } @@ -189,7 +189,7 @@ static cfg_t instruction_seq_to_cfg(const InstructionSeq& insts, bool must_have_ } /// Get the inverse of a given comparison operation. -static Condition::Op reverse(Condition::Op op) { +static Condition::Op reverse(const Condition::Op op) { switch (op) { case Condition::Op::EQ: return Condition::Op::NE; case Condition::Op::NE: return Condition::Op::EQ; @@ -214,7 +214,7 @@ static Condition::Op reverse(Condition::Op op) { } /// Get the inverse of a given comparison condition. -static Condition reverse(Condition cond) { +static Condition reverse(const Condition& cond) { return {.op = reverse(cond.op), .left = cond.left, .right = cond.right, .is64 = cond.is64}; } @@ -274,31 +274,30 @@ static cfg_t to_nondet(const cfg_t& cfg) { return res; } -/// Get the type of a given instruction. +/// Get the type of given instruction. /// Most of these type names are also statistics header labels. static std::string instype(Instruction ins) { - if (std::holds_alternative(ins)) { - auto call = std::get(ins); - if (call.is_map_lookup) { + if (const auto pcall = std::get_if(&ins)) { + if (pcall->is_map_lookup) { return "call_1"; } - if (call.pairs.empty()) { - if (std::all_of(call.singles.begin(), call.singles.end(), - [](ArgSingle kr) { return kr.kind == ArgSingle::Kind::ANYTHING; })) { + if (pcall->pairs.empty()) { + if (std::ranges::all_of(pcall->singles, + [](const ArgSingle kr) { return kr.kind == ArgSingle::Kind::ANYTHING; })) { return "call_nomem"; } } return "call_mem"; } else if (std::holds_alternative(ins)) { return "callx"; - } else if (std::holds_alternative(ins)) { - return std::get(ins).is_load ? "load" : "store"; + } else if (const auto pimm = std::get_if(&ins)) { + return pimm->is_load ? "load" : "store"; } else if (std::holds_alternative(ins)) { return "load_store"; } else if (std::holds_alternative(ins)) { return "packet_access"; - } else if (std::holds_alternative(ins)) { - switch (std::get(ins).op) { + } else if (const auto pins = std::get_if(&ins)) { + switch (pins->op) { case Bin::Op::MOV: case Bin::Op::MOVSX8: case Bin::Op::MOVSX16: @@ -334,20 +333,18 @@ std::map collect_stats(const cfg_t& cfg) { basic_block_t const& bb = cfg.get_node(this_label); for (Instruction ins : bb) { - if (std::holds_alternative(ins)) { - if (std::get(ins).mapfd == -1) { + if (const auto pins = std::get_if(&ins)) { + if (pins->mapfd == -1) { res["map_in_map"] = 1; } } - if (std::holds_alternative(ins)) { - auto call = std::get(ins); - if (call.reallocate_packet) { + if (const auto pins = std::get_if(&ins)) { + if (pins->reallocate_packet) { res["reallocate"] = 1; } } - if (std::holds_alternative(ins)) { - auto const& bin = std::get(ins); - res[bin.is64 ? "arith64" : "arith32"]++; + if (const auto pins = std::get_if(&ins)) { + res[pins->is64 ? "arith64" : "arith32"]++; } res[instype(ins)]++; } diff --git a/src/asm_marshal.cpp b/src/asm_marshal.cpp index 987b3423f..5e8854474 100644 --- a/src/asm_marshal.cpp +++ b/src/asm_marshal.cpp @@ -109,8 +109,9 @@ struct MarshalVisitor { vector operator()(Bin const& b) { if (b.lddw) { - assert(std::holds_alternative(b.v)); - auto [imm, next_imm] = split(std::get(b.v).v); + const auto pimm = std::get_if(&b.v); + assert(pimm != nullptr); + auto [imm, next_imm] = split(pimm->v); return makeLddw(b.dst, false, imm, next_imm); } @@ -248,9 +249,9 @@ struct MarshalVisitor { } else { res.opcode |= INST_CLS_ST; res.dst = access.basereg.v; - if (std::holds_alternative(b.value)) { + if (const auto preg = std::get_if(&b.value)) { res.opcode |= 0x1; - res.src = std::get(b.value).v; + res.src = preg->v; } else { res.opcode |= 0x0; res.imm = static_cast(std::get(b.value).v); @@ -308,9 +309,9 @@ vector marshal(const vector& insts) { return res; } -static int size(Instruction inst) { - if (std::holds_alternative(inst)) { - if (std::get(inst).lddw) { +static int size(const Instruction& inst) { + if (const auto pins = std::get_if(&inst)) { + if (pins->lddw) { return 2; } } @@ -336,9 +337,8 @@ vector marshal(const InstructionSeq& insts) { pc_t pc = 0; for (auto [label, ins, _] : insts) { (void)label; // unused - if (std::holds_alternative(ins)) { - Jmp& jmp = std::get(ins); - jmp.target = label_t(pc_of_label.at(jmp.target)); + if (const auto pins = std::get_if(&ins)) { + pins->target = label_t(pc_of_label.at(pins->target)); } for (auto e : marshal(ins, pc)) { pc++; diff --git a/src/asm_ostream.cpp b/src/asm_ostream.cpp index 682dba033..e2e9f3169 100644 --- a/src/asm_ostream.cpp +++ b/src/asm_ostream.cpp @@ -120,7 +120,7 @@ std::ostream& operator<<(std::ostream& os, ValidAccess const& a) { os << a.offset; } - if (a.width == (Value)Imm{0}) { + if (a.width == static_cast(Imm{0})) { // a.width == 0, meaning we only care it's an in-bound pointer, // so it can be compared with another pointer to the same region. os << ") for comparison/subtraction"; @@ -138,12 +138,12 @@ std::ostream& operator<<(std::ostream& os, ValidAccess const& a) { static crab::variable_t typereg(const Reg& r) { return crab::variable_t::reg(crab::data_kind_t::types, r.v); } std::ostream& operator<<(std::ostream& os, ValidSize const& a) { - auto op = a.can_be_zero ? " >= " : " > "; + const auto op = a.can_be_zero ? " >= " : " > "; return os << a.reg << ".value" << op << 0; } std::ostream& operator<<(std::ostream& os, ValidCall const& a) { - EbpfHelperPrototype proto = global_program_info->platform->get_helper_prototype(a.func); + const EbpfHelperPrototype proto = global_program_info->platform->get_helper_prototype(a.func); return os << "valid call(" << proto.name << ")"; } @@ -224,8 +224,7 @@ struct InstructionPrinterVisitor { os_ << "r0 = " << call.name << ":" << call.func << "("; for (uint8_t r = 1; r <= 5; r++) { // Look for a singleton. - std::vector::const_iterator single = - std::find_if(call.singles.begin(), call.singles.end(), [r](ArgSingle arg) { return arg.reg.v == r; }); + auto single = std::ranges::find_if(call.singles, [r](ArgSingle arg) { return arg.reg.v == r; }); if (single != call.singles.end()) { if (r > 1) { os_ << ", "; @@ -235,8 +234,7 @@ struct InstructionPrinterVisitor { } // Look for the start of a pair. - std::vector::const_iterator pair = - std::find_if(call.pairs.begin(), call.pairs.end(), [r](ArgPair arg) { return arg.mem.v == r; }); + auto pair = std::ranges::find_if(call.pairs, [r](ArgPair arg) { return arg.mem.v == r; }); if (pair != call.pairs.end()) { if (r > 1) { os_ << ", "; @@ -270,8 +268,8 @@ struct InstructionPrinterVisitor { } void operator()(Jmp const& b, int offset) { - string sign = offset > 0 ? "+" : ""; - string target = sign + std::to_string(offset) + " <" + to_string(b.target) + ">"; + const string sign = offset > 0 ? "+" : ""; + const string target = sign + std::to_string(offset) + " <" + to_string(b.target) + ">"; if (b.cond) { os_ << "if "; @@ -284,7 +282,7 @@ struct InstructionPrinterVisitor { void operator()(Packet const& b) { /* Direct packet access, R0 = *(uint *) (skb->data + imm32) */ /* Indirect packet access, R0 = *(uint *) (skb->data + src_reg + imm32) */ - string s = size(b.width); + const string s = size(b.width); os_ << "r0 = "; os_ << "*(" << s << " *)skb["; if (b.regoffset) { @@ -300,7 +298,7 @@ struct InstructionPrinterVisitor { } void print(Deref const& access) { - string sign = access.offset < 0 ? " - " : " + "; + const string sign = access.offset < 0 ? " - " : " + "; int offset = std::abs(access.offset); // what about INT_MIN? os_ << "*(" << size(access.width) << " *)"; os_ << "(" << access.basereg << sign << offset << ")"; @@ -379,9 +377,9 @@ string to_string(AssertionConstraint const& constraint) { return str.str(); } -int size(Instruction inst) { - if (std::holds_alternative(inst)) { - if (std::get(inst).lddw) { +int size(const Instruction& inst) { + if (const auto bin = std::get_if(&inst)) { + if (bin->lddw) { return 2; } } @@ -401,9 +399,9 @@ auto get_labels(const InstructionSeq& insts) { return pc_of_label; } -void print(const InstructionSeq& insts, std::ostream& out, std::optional label_to_print, +void print(const InstructionSeq& insts, std::ostream& out, const std::optional& label_to_print, bool print_line_info) { - auto pc_of_label = get_labels(insts); + const auto pc_of_label = get_labels(insts); pc_t pc = 0; std::string previous_source; InstructionPrinterVisitor visitor{out}; @@ -427,13 +425,12 @@ void print(const InstructionSeq& insts, std::ostream& out, std::optional(ins)) { - auto const& jmp = std::get(ins); - if (pc_of_label.count(jmp.target) == 0) { - throw std::runtime_error(string("Cannot find label ") + to_string(jmp.target)); + if (const auto jmp = std::get_if(&ins)) { + if (!pc_of_label.contains(jmp->target)) { + throw std::runtime_error(string("Cannot find label ") + to_string(jmp->target)); } - pc_t target_pc = pc_of_label.at(jmp.target); - visitor(jmp, target_pc - (int)pc - 1); + const pc_t target_pc = pc_of_label.at(jmp->target); + visitor(*jmp, target_pc - static_cast(pc) - 1); } else { std::visit(visitor, ins); } @@ -498,7 +495,7 @@ std::ostream& operator<<(std::ostream& o, const basic_block_t& bb) { if (it != et) { o << " " << "goto "; - for (; it != et;) { + while (it != et) { o << *it; ++it; if (it == et) { diff --git a/src/asm_ostream.hpp b/src/asm_ostream.hpp index ae55a6e9c..db8f998a7 100644 --- a/src/asm_ostream.hpp +++ b/src/asm_ostream.hpp @@ -29,7 +29,7 @@ inline std::function label_to_offset32(pc_t pc) { std::ostream& operator<<(std::ostream& os, const btf_line_info_t& line_info); -void print(const InstructionSeq& insts, std::ostream& out, std::optional label_to_print, +void print(const InstructionSeq& insts, std::ostream& out, const std::optional& label_to_print, bool print_line_info = false); std::string to_string(label_t const& label); @@ -43,8 +43,8 @@ std::ostream& operator<<(std::ostream& os, Condition::Op op); inline std::ostream& operator<<(std::ostream& os, Imm imm) { return os << (int64_t)imm.v; } inline std::ostream& operator<<(std::ostream& os, Reg const& a) { return os << "r" << (int)a.v; } inline std::ostream& operator<<(std::ostream& os, Value const& a) { - if (std::holds_alternative(a)) { - return os << std::get(a); + if (auto pa = std::get_if(&a)) { + return os << *pa; } return os << std::get(a); } diff --git a/src/asm_unmarshal.cpp b/src/asm_unmarshal.cpp index dc9457d43..e319b3e3f 100644 --- a/src/asm_unmarshal.cpp +++ b/src/asm_unmarshal.cpp @@ -5,14 +5,13 @@ #include #include -#include "ebpf_vm_isa.hpp" - #include "asm_unmarshal.hpp" +#include "ebpf_vm_isa.hpp" using std::string; using std::vector; -int opcode_to_width(uint8_t opcode) { +int opcode_to_width(const uint8_t opcode) { switch (opcode & INST_SIZE_MASK) { case INST_SIZE_B: return 1; case INST_SIZE_H: return 2; @@ -23,7 +22,7 @@ int opcode_to_width(uint8_t opcode) { return {}; } -uint8_t width_to_opcode(int width) { +uint8_t width_to_opcode(const int width) { switch (width) { case 1: return INST_SIZE_B; case 2: return INST_SIZE_H; @@ -37,21 +36,22 @@ uint8_t width_to_opcode(int width) { template void compare(const string& field, T actual, T expected) { if (actual != expected) { - std::cerr << field << ": (actual) " << std::hex << (int)actual << " != " << (int)expected << " (expected)\n"; + std::cerr << field << ": (actual) " << std::hex << static_cast(actual) + << " != " << static_cast(expected) << " (expected)\n"; } } -static std::string make_opcode_message(const char* msg, uint8_t opcode) { +static std::string make_opcode_message(const char* msg, const uint8_t opcode) { std::ostringstream oss; - oss << msg << " op 0x" << std::hex << (int)opcode; + oss << msg << " op 0x" << std::hex << static_cast(opcode); return oss.str(); } struct InvalidInstruction : std::invalid_argument { size_t pc; - explicit InvalidInstruction(size_t pc, const char* what) : std::invalid_argument{what}, pc{pc} {} - InvalidInstruction(size_t pc, std::string what) : std::invalid_argument{what}, pc{pc} {} - InvalidInstruction(size_t pc, uint8_t opcode) + explicit InvalidInstruction(const size_t pc, const char* what) : std::invalid_argument{what}, pc{pc} {} + InvalidInstruction(const size_t pc, const std::string& what) : std::invalid_argument{what}, pc{pc} {} + InvalidInstruction(const size_t pc, const uint8_t opcode) : std::invalid_argument{make_opcode_message("bad instruction", opcode)}, pc{pc} {} }; @@ -59,7 +59,7 @@ struct UnsupportedMemoryMode : std::invalid_argument { explicit UnsupportedMemoryMode(const char* what) : std::invalid_argument{what} {} }; -static auto getMemIsLoad(uint8_t opcode) -> bool { +static auto getMemIsLoad(const uint8_t opcode) -> bool { switch (opcode & INST_CLS_MASK) { case INST_CLS_LD: case INST_CLS_LDX: return true; @@ -69,7 +69,7 @@ static auto getMemIsLoad(uint8_t opcode) -> bool { return {}; } -static auto getMemWidth(uint8_t opcode) -> int { +static auto getMemWidth(const uint8_t opcode) -> int { switch (opcode & INST_SIZE_MASK) { case INST_SIZE_B: return 1; case INST_SIZE_H: return 2; @@ -89,8 +89,8 @@ static auto getMemWidth(uint8_t opcode) -> int { // return {}; // } -static Instruction shift32(Reg dst, Bin::Op op) { - return (Instruction)Bin{.op = op, .dst = dst, .v = Imm{32}, .is64 = true, .lddw = false}; +static Instruction shift32(const Reg dst, const Bin::Op op) { + return Bin{.op = op, .dst = dst, .v = Imm{32}, .is64 = true, .lddw = false}; } struct Unmarshaller { @@ -102,9 +102,9 @@ struct Unmarshaller { note_next_pc(); } - auto getAluOp(size_t pc, ebpf_inst inst) -> std::variant { + auto getAluOp(const size_t pc, const ebpf_inst inst) -> std::variant { // First handle instructions that support a non-zero offset. - bool is64 = (inst.opcode & INST_CLS_MASK) == INST_CLS_ALU64; + const bool is64 = (inst.opcode & INST_CLS_MASK) == INST_CLS_ALU64; switch (inst.opcode & INST_ALU_OP_MASK) { case INST_ALU_OP_DIV: if (!info.platform->supports_group(is64 ? bpf_conformance_groups_t::divmul64 @@ -226,9 +226,8 @@ struct Unmarshaller { return {}; } - auto getAtomicOp(size_t pc, ebpf_inst inst) -> Atomic::Op { - Atomic::Op op = (Atomic::Op)(inst.imm & ~INST_FETCH); - switch (op) { + static auto getAtomicOp(const size_t pc, const ebpf_inst inst) -> Atomic::Op { + switch (const auto op = static_cast(inst.imm & ~INST_FETCH)) { case Atomic::Op::XCHG: case Atomic::Op::CMPXCHG: if ((inst.imm & INST_FETCH) == 0) { @@ -242,11 +241,11 @@ struct Unmarshaller { throw InvalidInstruction(pc, "unsupported immediate"); } - uint64_t sign_extend(int32_t imm) { return (uint64_t)(int64_t)imm; } + static uint64_t sign_extend(const int32_t imm) { return static_cast(static_cast(imm)); } - uint64_t zero_extend(int32_t imm) { return (uint64_t)(uint32_t)imm; } + static uint64_t zero_extend(const int32_t imm) { return static_cast(static_cast(imm)); } - auto getBinValue(pc_t pc, ebpf_inst inst) -> Value { + static auto getBinValue(const pc_t pc, const ebpf_inst inst) -> Value { if (inst.opcode & INST_SRC_REG) { if (inst.imm != 0) { throw InvalidInstruction(pc, make_opcode_message("nonzero imm for", inst.opcode)); @@ -261,7 +260,7 @@ struct Unmarshaller { } } - static auto getJmpOp(size_t pc, uint8_t opcode) -> Condition::Op { + static auto getJmpOp(const size_t pc, const uint8_t opcode) -> Condition::Op { using Op = Condition::Op; switch ((opcode >> 4) & 0xF) { case 0x0: return {}; // goto @@ -284,17 +283,17 @@ struct Unmarshaller { return {}; } - auto makeMemOp(pc_t pc, ebpf_inst inst) -> Instruction { + auto makeMemOp(const pc_t pc, const ebpf_inst inst) -> Instruction { if (inst.dst > R10_STACK_POINTER || inst.src > R10_STACK_POINTER) { throw InvalidInstruction(pc, "bad register"); } - int width = getMemWidth(inst.opcode); + const int width = getMemWidth(inst.opcode); if (!info.platform->supports_group((width == sizeof(uint64_t)) ? bpf_conformance_groups_t::base64 : bpf_conformance_groups_t::base32)) { throw InvalidInstruction(pc, inst.opcode); } - bool isLD = (inst.opcode & INST_CLS_MASK) == INST_CLS_LD; + const bool isLD = (inst.opcode & INST_CLS_MASK) == INST_CLS_LD; switch (inst.opcode & INST_MODE_MASK) { case INST_MODE_IMM: throw InvalidInstruction(pc, inst.opcode); @@ -332,11 +331,11 @@ struct Unmarshaller { if (isLD) { throw InvalidInstruction(pc, inst.opcode); } - bool isLoad = getMemIsLoad(inst.opcode); + const bool isLoad = getMemIsLoad(inst.opcode); if (isLoad && inst.dst == R10_STACK_POINTER) { throw InvalidInstruction(pc, "cannot modify r10"); } - bool isImm = !(inst.opcode & 1); + const bool isImm = !(inst.opcode & 1); if (isImm && inst.src != 0) { throw InvalidInstruction(pc, inst.opcode); } @@ -345,7 +344,7 @@ struct Unmarshaller { } assert(!(isLoad && isImm)); - uint8_t basereg = isLoad ? inst.src : inst.dst; + const uint8_t basereg = isLoad ? inst.src : inst.dst; if (basereg == R10_STACK_POINTER && (inst.offset + opcode_to_width(inst.opcode) > 0 || inst.offset < -EBPF_STACK_SIZE)) { @@ -358,8 +357,9 @@ struct Unmarshaller { .basereg = Reg{basereg}, .offset = inst.offset, }, - .value = - isLoad ? (Value)Reg{inst.dst} : (isImm ? (Value)Imm{zero_extend(inst.imm)} : (Value)Reg{inst.src}), + .value = isLoad ? static_cast(Reg{inst.dst}) + : (isImm ? static_cast(Imm{zero_extend(inst.imm)}) + : static_cast(Reg{inst.src})), .is_load = isLoad, }; return res; @@ -388,11 +388,10 @@ struct Unmarshaller { }; default: throw InvalidInstruction(pc, inst.opcode); } - return {}; } - auto makeAluOp(size_t pc, ebpf_inst inst) -> Instruction { - bool is64 = (inst.opcode & INST_CLS_MASK) == INST_CLS_ALU64; + auto makeAluOp(const size_t pc, const ebpf_inst inst) -> Instruction { + const bool is64 = (inst.opcode & INST_CLS_MASK) == INST_CLS_ALU64; if (!info.platform->supports_group(is64 ? bpf_conformance_groups_t::base64 : bpf_conformance_groups_t::base32)) { throw InvalidInstruction(pc, inst.opcode); @@ -404,8 +403,8 @@ struct Unmarshaller { throw InvalidInstruction(pc, "bad register"); } return std::visit( - overloaded{[&](Un::Op op) -> Instruction { return Un{.op = op, .dst = Reg{inst.dst}, .is64 = is64}; }, - [&](Bin::Op op) -> Instruction { + overloaded{[&](const Un::Op op) -> Instruction { return Un{.op = op, .dst = Reg{inst.dst}, .is64 = is64}; }, + [&](const Bin::Op op) -> Instruction { Bin res{ .op = op, .dst = Reg{inst.dst}, @@ -414,8 +413,10 @@ struct Unmarshaller { }; if (!thread_local_options.allow_division_by_zero && (op == Bin::Op::UDIV || op == Bin::Op::UMOD)) { - if (std::holds_alternative(res.v) && std::get(res.v).v == 0) { - note("division by zero"); + if (const auto pimm = std::get_if(&res.v)) { + if (pimm->v == 0) { + note("division by zero"); + } } } return res; @@ -423,14 +424,15 @@ struct Unmarshaller { getAluOp(pc, inst)); } - auto makeLddw(ebpf_inst inst, int32_t next_imm, const vector& insts, pc_t pc) -> Instruction { + auto makeLddw(const ebpf_inst inst, const int32_t next_imm, const vector& insts, const pc_t pc) const + -> Instruction { if (!info.platform->supports_group(bpf_conformance_groups_t::base64)) { throw InvalidInstruction{pc, inst.opcode}; } if (pc >= insts.size() - 1) { throw InvalidInstruction(pc, "incomplete lddw"); } - ebpf_inst next = insts[pc + 1]; + const ebpf_inst next = insts[pc + 1]; if (next.opcode != 0 || next.dst != 0 || next.src != 0 || next.offset != 0) { throw InvalidInstruction(pc, "invalid lddw"); } @@ -462,7 +464,7 @@ struct Unmarshaller { }; } - static ArgSingle::Kind toArgSingleKind(ebpf_argument_type_t t) { + static ArgSingle::Kind toArgSingleKind(const ebpf_argument_type_t t) { switch (t) { case EBPF_ARGUMENT_TYPE_ANYTHING: return ArgSingle::Kind::ANYTHING; case EBPF_ARGUMENT_TYPE_PTR_TO_MAP: return ArgSingle::Kind::MAP_FD; @@ -475,7 +477,7 @@ struct Unmarshaller { return {}; } - static ArgPair::Kind toArgPairKind(ebpf_argument_type_t t) { + static ArgPair::Kind toArgPairKind(const ebpf_argument_type_t t) { switch (t) { case EBPF_ARGUMENT_TYPE_PTR_TO_READABLE_MEM_OR_NULL: return ArgPair::Kind::PTR_TO_READABLE_MEM_OR_NULL; case EBPF_ARGUMENT_TYPE_PTR_TO_READABLE_MEM: return ArgPair::Kind::PTR_TO_READABLE_MEM; @@ -485,8 +487,8 @@ struct Unmarshaller { return {}; } - auto makeCall(int32_t imm) const { - EbpfHelperPrototype proto = info.platform->get_helper_prototype(imm); + auto makeCall(const int32_t imm) const { + const EbpfHelperPrototype proto = info.platform->get_helper_prototype(imm); if (proto.return_type == EBPF_RETURN_TYPE_UNSUPPORTED) { throw std::runtime_error(std::string("unsupported function: ") + proto.name); } @@ -495,7 +497,7 @@ struct Unmarshaller { res.name = proto.name; res.reallocate_packet = proto.reallocate_packet; res.is_map_lookup = proto.return_type == EBPF_RETURN_TYPE_PTR_TO_MAP_VALUE_OR_NULL; - std::array args = { + const std::array args = { {EBPF_ARGUMENT_TYPE_DONTCARE, proto.argument_type[0], proto.argument_type[1], proto.argument_type[2], proto.argument_type[3], proto.argument_type[4], EBPF_ARGUMENT_TYPE_DONTCARE}}; for (size_t i = 1; i < args.size() - 1; i++) { @@ -510,7 +512,7 @@ struct Unmarshaller { case EBPF_ARGUMENT_TYPE_PTR_TO_MAP_KEY: case EBPF_ARGUMENT_TYPE_PTR_TO_MAP_VALUE: case EBPF_ARGUMENT_TYPE_PTR_TO_CTX: - res.singles.push_back({toArgSingleKind(args[i]), Reg{(uint8_t)i}}); + res.singles.push_back({toArgSingleKind(args[i]), Reg{static_cast(i)}}); break; case EBPF_ARGUMENT_TYPE_CONST_SIZE: { // Sanity check: This argument should never be seen in isolation. @@ -542,8 +544,9 @@ struct Unmarshaller { "EBPF_ARGUMENT_TYPE_CONST_SIZE_OR_ZERO: ") + proto.name); } - bool can_be_zero = (args[i + 1] == EBPF_ARGUMENT_TYPE_CONST_SIZE_OR_ZERO); - res.pairs.push_back({toArgPairKind(args[i]), Reg{(uint8_t)i}, Reg{(uint8_t)(i + 1)}, can_be_zero}); + const bool can_be_zero = (args[i + 1] == EBPF_ARGUMENT_TYPE_CONST_SIZE_OR_ZERO); + res.pairs.push_back({toArgPairKind(args[i]), Reg{static_cast(i)}, + Reg{static_cast(i + 1)}, can_be_zero}); i++; break; } @@ -552,17 +555,17 @@ struct Unmarshaller { } /// Given a program counter and an offset, get the label of the target instruction. - label_t getJumpTarget(int32_t offset, const vector& insts, pc_t pc) const { - pc_t new_pc = pc + 1 + offset; + static label_t getJumpTarget(const int32_t offset, const vector& insts, const pc_t pc) { + const pc_t new_pc = pc + 1 + offset; if (new_pc >= insts.size()) { throw InvalidInstruction(pc, "jump out of bounds"); } else if (insts[new_pc].opcode == 0) { throw InvalidInstruction(pc, "jump to middle of lddw"); } - return label_t{(int)new_pc}; + return label_t{static_cast(new_pc)}; } - auto makeCallLocal(ebpf_inst inst, const vector& insts, pc_t pc) const { + static auto makeCallLocal(const ebpf_inst inst, const vector& insts, const pc_t pc) { if (inst.opcode & INST_SRC_REG) { throw InvalidInstruction(pc, inst.opcode); } @@ -572,7 +575,7 @@ struct Unmarshaller { return CallLocal{.target = getJumpTarget(inst.imm, insts, pc)}; } - auto makeCallx(ebpf_inst inst, pc_t pc) const { + static auto makeCallx(const ebpf_inst inst, const pc_t pc) { // callx puts the register number in the 'dst' field rather than the 'src' field. if (inst.dst > R10_STACK_POINTER) { throw InvalidInstruction(pc, "bad register"); @@ -585,12 +588,12 @@ struct Unmarshaller { if (inst.imm < 0 || inst.imm > R10_STACK_POINTER) { throw InvalidInstruction(pc, "bad register"); } - return Callx{(uint8_t)inst.imm}; + return Callx{static_cast(inst.imm)}; } return Callx{inst.dst}; } - auto makeJmp(ebpf_inst inst, const vector& insts, pc_t pc) -> Instruction { + auto makeJmp(const ebpf_inst inst, const vector& insts, const pc_t pc) const -> Instruction { switch ((inst.opcode >> 4) & 0xF) { case INST_CALL: if ((inst.opcode & INST_CLS_MASK) != INST_CLS_JMP) { @@ -662,12 +665,12 @@ struct Unmarshaller { } default: { // First validate the opcode, src, and imm. - auto is64 = (inst.opcode & INST_CLS_MASK) == INST_CLS_JMP; + const auto is64 = (inst.opcode & INST_CLS_MASK) == INST_CLS_JMP; if (!info.platform->supports_group(is64 ? bpf_conformance_groups_t::base64 : bpf_conformance_groups_t::base32)) { throw InvalidInstruction(pc, inst.opcode); } - auto op = getJmpOp(pc, inst.opcode); + const auto op = getJmpOp(pc, inst.opcode); if (!(inst.opcode & INST_SRC_REG) && (inst.src != 0)) { throw InvalidInstruction(pc, inst.opcode); } @@ -675,8 +678,8 @@ struct Unmarshaller { throw InvalidInstruction(pc, make_opcode_message("nonzero imm for", inst.opcode)); } - int32_t offset = (inst.opcode == INST_OP_JA32) ? inst.imm : inst.offset; - label_t target = getJumpTarget(offset, insts, pc); + const int32_t offset = (inst.opcode == INST_OP_JA32) ? inst.imm : inst.offset; + const label_t target = getJumpTarget(offset, insts, pc); if (inst.opcode != INST_OP_JA16 && inst.opcode != INST_OP_JA32) { if (inst.dst > R10_STACK_POINTER) { throw InvalidInstruction(pc, "bad register"); @@ -686,13 +689,13 @@ struct Unmarshaller { } } - auto cond = (inst.opcode == INST_OP_JA16 || inst.opcode == INST_OP_JA32) - ? std::optional{} - : Condition{.op = op, - .left = Reg{inst.dst}, - .right = (inst.opcode & INST_SRC_REG) ? (Value)Reg{inst.src} - : Imm{sign_extend(inst.imm)}, - .is64 = ((inst.opcode & INST_CLS_MASK) == INST_CLS_JMP)}; + const auto cond = (inst.opcode == INST_OP_JA16 || inst.opcode == INST_OP_JA32) + ? std::optional{} + : Condition{.op = op, + .left = Reg{inst.dst}, + .right = (inst.opcode & INST_SRC_REG) ? static_cast(Reg{inst.src}) + : Imm{sign_extend(inst.imm)}, + .is64 = ((inst.opcode & INST_CLS_MASK) == INST_CLS_JMP)}; return Jmp{.cond = cond, .target = target}; } } @@ -705,14 +708,14 @@ struct Unmarshaller { throw std::invalid_argument("Zero length programs are not allowed"); } for (size_t pc = 0; pc < insts.size();) { - ebpf_inst inst = insts[pc]; + const ebpf_inst inst = insts[pc]; Instruction new_ins; bool skip_instruction = false; bool fallthrough = true; switch (inst.opcode & INST_CLS_MASK) { case INST_CLS_LD: if (inst.opcode == INST_OP_LDDW_IMM) { - int32_t next_imm = pc < insts.size() - 1 ? insts[pc + 1].imm : 0; + const int32_t next_imm = pc < insts.size() - 1 ? insts[pc + 1].imm : 0; new_ins = makeLddw(inst, next_imm, insts, static_cast(pc)); skip_instruction = true; break; @@ -731,7 +734,7 @@ struct Unmarshaller { if (pc >= insts.size() - 1) { break; } - ebpf_inst next = insts[pc + 1]; + const ebpf_inst next = insts[pc + 1]; auto dst = Reg{inst.dst}; if (new_ins != shift32(dst, Bin::Op::LSH)) { @@ -755,13 +758,13 @@ struct Unmarshaller { case INST_CLS_JMP32: case INST_CLS_JMP: { - new_ins = makeJmp(inst, insts, static_cast(pc)); + new_ins = makeJmp(inst, insts, pc); if (std::holds_alternative(new_ins)) { fallthrough = false; exit_count++; } - if (std::holds_alternative(new_ins)) { - if (!std::get(new_ins).cond) { + if (const auto pjmp = std::get_if(&new_ins)) { + if (!pjmp->cond) { fallthrough = false; } } @@ -810,8 +813,8 @@ std::variant unmarshal(const raw_program& raw_prog) return unmarshal(raw_prog, notes); } -Call make_call(int imm, const ebpf_platform_t& platform) { +Call make_call(const int imm, const ebpf_platform_t& platform) { vector> notes; - program_info info{.platform = &platform}; + const program_info info{.platform = &platform}; return Unmarshaller{notes, info}.makeCall(imm); } diff --git a/src/assertions.cpp b/src/assertions.cpp index 4a53e68ca..5282ce100 100644 --- a/src/assertions.cpp +++ b/src/assertions.cpp @@ -18,9 +18,7 @@ class AssertExtractor { program_info info; std::optional current_label; ///< Pre-simplification label this assert is part of. - static Reg reg(Value v) { return std::get(v); } - - static Imm imm(Value v) { return std::get(v); } + static Imm imm(const Value& v) { return std::get(v); } static vector zero_offset_ctx(Reg reg) { vector res; @@ -129,13 +127,13 @@ class AssertExtractor { } [[nodiscard]] - vector explicate(Condition cond) const { + vector explicate(const Condition& cond) const { if (info.type.is_privileged) { return {}; } vector res; - if (std::holds_alternative(cond.right)) { - if (imm(cond.right).v != 0) { + if (const auto pimm = std::get_if(&cond.right)) { + if (pimm->v != 0) { // no need to check for valid access, it must be a number res.emplace_back(TypeConstraint{cond.left, TypeGroup::number}); } else { @@ -144,33 +142,34 @@ class AssertExtractor { // Anything can be compared to 0 } } else { + const auto reg_right = get(cond.right); res.emplace_back(ValidAccess{cond.left}); - res.emplace_back(ValidAccess{reg(cond.right)}); + res.emplace_back(ValidAccess{reg_right}); if (cond.op != Condition::Op::EQ && cond.op != Condition::Op::NE) { res.emplace_back(TypeConstraint{cond.left, TypeGroup::ptr_or_num}); } - res.emplace_back(Comparable{.r1 = cond.left, .r2 = reg(cond.right), .or_r2_is_number = false}); + res.emplace_back(Comparable{.r1 = cond.left, .r2 = reg_right, .or_r2_is_number = false}); } return res; } - vector operator()(Assume ins) const { return explicate(ins.cond); } + vector operator()(const Assume& ins) const { return explicate(ins.cond); } - vector operator()(Jmp ins) const { + vector operator()(const Jmp& ins) const { if (!ins.cond) { return {}; } return explicate(*ins.cond); } - vector operator()(Mem ins) const { + vector operator()(const Mem& ins) const { vector res; - Reg basereg = ins.access.basereg; + const Reg basereg = ins.access.basereg; Imm width{static_cast(ins.access.width)}; - int offset = ins.access.offset; + const int offset = ins.access.offset; if (basereg.v == R10_STACK_POINTER) { // We know we are accessing the stack. - if (offset < -EBPF_STACK_SIZE || offset + (int)width.v >= 0) { + if (offset < -EBPF_STACK_SIZE || offset + static_cast(width.v) >= 0) { // This assertion will fail res.emplace_back( ValidAccess{basereg, offset, width, false, ins.is_load ? AccessType::read : AccessType::write}); @@ -179,18 +178,20 @@ class AssertExtractor { res.emplace_back(TypeConstraint{basereg, TypeGroup::pointer}); res.emplace_back( ValidAccess{basereg, offset, width, false, ins.is_load ? AccessType::read : AccessType::write}); - if (!info.type.is_privileged && !ins.is_load && std::holds_alternative(ins.value)) { - if (width.v != 8) { - res.emplace_back(TypeConstraint{reg(ins.value), TypeGroup::number}); - } else { - res.emplace_back(ValidStore{ins.access.basereg, reg(ins.value)}); + if (!info.type.is_privileged && !ins.is_load) { + if (const auto preg = std::get_if(&ins.value)) { + if (width.v != 8) { + res.emplace_back(TypeConstraint{*preg, TypeGroup::number}); + } else { + res.emplace_back(ValidStore{ins.access.basereg, *preg}); + } } } } return res; } - vector operator()(Atomic ins) const { + vector operator()(const Atomic& ins) const { vector res; res.emplace_back(TypeConstraint{ins.access.basereg, TypeGroup::pointer}); res.emplace_back( @@ -204,58 +205,54 @@ class AssertExtractor { return res; } - vector operator()(Un ins) { return {Assert{TypeConstraint{ins.dst, TypeGroup::number}}}; } + vector operator()(const Un ins) const { return {Assert{TypeConstraint{ins.dst, TypeGroup::number}}}; } - vector operator()(Bin ins) const { + vector operator()(const Bin& ins) const { switch (ins.op) { case Bin::Op::MOV: return {}; case Bin::Op::MOVSX8: case Bin::Op::MOVSX16: case Bin::Op::MOVSX32: - if (std::holds_alternative(ins.v)) { - auto src = reg(ins.v); - return {Assert{TypeConstraint{src, TypeGroup::number}}}; + if (const auto src = std::get_if(&ins.v)) { + return {Assert{TypeConstraint{*src, TypeGroup::number}}}; } return {}; - case Bin::Op::ADD: - if (std::holds_alternative(ins.v)) { - auto src = reg(ins.v); + case Bin::Op::ADD: { + if (const auto src = std::get_if(&ins.v)) { return {Assert{TypeConstraint{ins.dst, TypeGroup::ptr_or_num}}, - Assert{TypeConstraint{src, TypeGroup::ptr_or_num}}, Assert{Addable{src, ins.dst}}, - Assert{Addable{ins.dst, src}}}; - } else { - return {Assert{TypeConstraint{ins.dst, TypeGroup::ptr_or_num}}}; + Assert{TypeConstraint{*src, TypeGroup::ptr_or_num}}, Assert{Addable{*src, ins.dst}}, + Assert{Addable{ins.dst, *src}}}; } - case Bin::Op::SUB: - if (std::holds_alternative(ins.v)) { + return {Assert{TypeConstraint{ins.dst, TypeGroup::ptr_or_num}}}; + } + case Bin::Op::SUB: { + if (const auto reg = std::get_if(&ins.v)) { vector res; // disallow map-map since same type does not mean same offset // TODO: map identities res.emplace_back(TypeConstraint{ins.dst, TypeGroup::ptr_or_num}); - res.emplace_back(Comparable{.r1 = ins.dst, .r2 = reg(ins.v), .or_r2_is_number = true}); + res.emplace_back(Comparable{.r1 = ins.dst, .r2 = *reg, .or_r2_is_number = true}); return res; - } else { - return {Assert{TypeConstraint{ins.dst, TypeGroup::ptr_or_num}}}; } + return {Assert{TypeConstraint{ins.dst, TypeGroup::ptr_or_num}}}; + } case Bin::Op::UDIV: case Bin::Op::UMOD: case Bin::Op::SDIV: - case Bin::Op::SMOD: - if (std::holds_alternative(ins.v)) { - auto src = reg(ins.v); - bool is_signed = (ins.op == Bin::Op::SDIV || ins.op == Bin::Op::SMOD); - return {Assert{TypeConstraint{ins.dst, TypeGroup::number}}, Assert{ValidDivisor{src, is_signed}}}; - } else { - return {Assert{TypeConstraint{ins.dst, TypeGroup::number}}}; + case Bin::Op::SMOD: { + if (const auto src = std::get_if(&ins.v)) { + const bool is_signed = (ins.op == Bin::Op::SDIV || ins.op == Bin::Op::SMOD); + return {Assert{TypeConstraint{ins.dst, TypeGroup::number}}, Assert{ValidDivisor{*src, is_signed}}}; } + return {Assert{TypeConstraint{ins.dst, TypeGroup::number}}}; + } default: return {Assert{TypeConstraint{ins.dst, TypeGroup::number}}}; } assert(false); - return {}; } }; -vector get_assertions(Instruction ins, const program_info& info, std::optional label) { +vector get_assertions(Instruction ins, const program_info& info, const std::optional& label) { return std::visit(AssertExtractor{info, label}, ins); } diff --git a/src/crab/cfg.hpp b/src/crab/cfg.hpp index d93130855..8ff0792ff 100644 --- a/src/crab/cfg.hpp +++ b/src/crab/cfg.hpp @@ -616,7 +616,7 @@ std::map collect_stats(const cfg_t&); cfg_t prepare_cfg(const InstructionSeq& prog, const program_info& info, bool simplify, bool must_have_exit = true); void explicate_assertions(cfg_t& cfg, const program_info& info); -std::vector get_assertions(Instruction ins, const program_info& info, std::optional label); +std::vector get_assertions(Instruction ins, const program_info& info, const std::optional& label); void print_dot(const cfg_t& cfg, std::ostream& out); void print_dot(const cfg_t& cfg, const std::string& outfile); diff --git a/src/crab/ebpf_domain.cpp b/src/crab/ebpf_domain.cpp index 98b42cfc6..a4534c9e8 100644 --- a/src/crab/ebpf_domain.cpp +++ b/src/crab/ebpf_domain.cpp @@ -1222,8 +1222,8 @@ void ebpf_domain_t::check_access_shared(NumAbsDomain& inv, const linear_expressi void ebpf_domain_t::operator()(const Assume& s) { const Condition cond = s.cond; const auto dst = reg_pack(cond.left); - if (std::holds_alternative(cond.right)) { - const auto src_reg = std::get(cond.right); + if (const auto psrc_reg = std::get_if(&cond.right)) { + const auto src_reg = *psrc_reg; const auto src = reg_pack(src_reg); if (type_inv.same_type(m_inv, cond.left, std::get(cond.right))) { m_inv = type_inv.join_over_types(m_inv, cond.left, [&](NumAbsDomain& inv, const type_encoding_t type) { @@ -1639,7 +1639,14 @@ void ebpf_domain_t::operator()(const ValidMapKeyValue& s) { } }); } - +static std::tuple lb_ub_access_pair(const ValidAccess& s, + const variable_t offset_var) { + using namespace crab::dsl_syntax; + linear_expression_t lb = offset_var + s.offset; + linear_expression_t ub = std::holds_alternative(s.width) ? lb + std::get(s.width).v + : lb + reg_pack(std::get(s.width)).svalue; + return {lb, ub}; +} void ebpf_domain_t::operator()(const ValidAccess& s) { using namespace crab::dsl_syntax; @@ -1650,10 +1657,7 @@ void ebpf_domain_t::operator()(const ValidAccess& s) { m_inv = type_inv.join_over_types(m_inv, s.reg, [&](NumAbsDomain& inv, type_encoding_t type) { switch (type) { case T_PACKET: { - linear_expression_t lb = reg.packet_offset + s.offset; - linear_expression_t ub = std::holds_alternative(s.width) - ? lb + std::get(s.width).v - : lb + reg_pack(std::get(s.width)).svalue; + auto [lb, ub] = lb_ub_access_pair(s, reg.packet_offset); check_access_packet(inv, lb, ub, is_comparison_check ? std::optional{} : variable_t::packet_size()); // if within bounds, it can never be null @@ -1661,10 +1665,7 @@ void ebpf_domain_t::operator()(const ValidAccess& s) { break; } case T_STACK: { - linear_expression_t lb = reg.stack_offset + s.offset; - linear_expression_t ub = std::holds_alternative(s.width) - ? lb + std::get(s.width).v - : lb + reg_pack(std::get(s.width)).svalue; + auto [lb, ub] = lb_ub_access_pair(s, reg.stack_offset); check_access_stack(inv, lb, ub); // if within bounds, it can never be null if (s.access_type == AccessType::read) { @@ -1672,9 +1673,8 @@ void ebpf_domain_t::operator()(const ValidAccess& s) { if (!stack.all_num(inv, lb, ub)) { if (s.offset < 0) { require(inv, linear_constraint_t::false_const(), "Stack content is not numeric"); - } else if (std::holds_alternative(s.width)) { - if (!inv.entail(static_cast(std::get(s.width).v) <= - reg.stack_numeric_size - s.offset)) { + } else if (const auto pimm = std::get_if(&s.width)) { + if (!inv.entail(static_cast(pimm->v) <= reg.stack_numeric_size - s.offset)) { require(inv, linear_constraint_t::false_const(), "Stack content is not numeric"); } } else { @@ -1687,15 +1687,21 @@ void ebpf_domain_t::operator()(const ValidAccess& s) { break; } case T_CTX: { - linear_expression_t lb = reg.ctx_offset + s.offset; - linear_expression_t ub = std::holds_alternative(s.width) - ? lb + std::get(s.width).v - : lb + reg_pack(std::get(s.width)).svalue; + auto [lb, ub] = lb_ub_access_pair(s, reg.ctx_offset); check_access_context(inv, lb, ub); // if within bounds, it can never be null // The context is both readable and writable. break; } + case T_SHARED: { + auto [lb, ub] = lb_ub_access_pair(s, reg.shared_offset); + check_access_shared(inv, lb, ub, reg.shared_region_size); + if (!is_comparison_check && !s.or_null) { + require(inv, reg.svalue > 0, "Possible null access"); + } + // Shared memory is zero-initialized when created so is safe to read and write. + break; + } case T_NUM: if (!is_comparison_check) { if (s.or_null) { @@ -1711,18 +1717,6 @@ void ebpf_domain_t::operator()(const ValidAccess& s) { require(inv, linear_constraint_t::false_const(), "FDs cannot be dereferenced directly"); } break; - case T_SHARED: { - linear_expression_t lb = reg.shared_offset + s.offset; - linear_expression_t ub = std::holds_alternative(s.width) - ? lb + std::get(s.width).v - : lb + reg_pack(std::get(s.width)).svalue; - check_access_shared(inv, lb, ub, reg.shared_region_size); - if (!is_comparison_check && !s.or_null) { - require(inv, reg.svalue > 0, "Possible null access"); - } - // Shared memory is zero-initialized when created so is safe to read and write. - break; - } default: require(inv, linear_constraint_t::false_const(), "Invalid type"); break; } }); @@ -2016,13 +2010,12 @@ void ebpf_domain_t::operator()(const Mem& b) { if (m_inv.is_bottom()) { return; } - if (std::holds_alternative(b.value)) { + if (const auto preg = std::get_if(&b.value)) { if (b.is_load) { - do_load(b, std::get(b.value)); + do_load(b, *preg); } else { - const auto data = std::get(b.value); - auto data_reg = reg_pack(data); - do_mem_store(b, data, data_reg.svalue, data_reg.uvalue, data_reg); + auto data_reg = reg_pack(*preg); + do_mem_store(b, *preg, data_reg.svalue, data_reg.uvalue, data_reg); } } else { do_mem_store(b, number_t{T_NUM}, number_t{static_cast(std::get(b.value).v)}, @@ -2498,15 +2491,15 @@ void ebpf_domain_t::operator()(const Bin& bin) { auto dst = reg_pack(bin.dst); int finite_width = bin.is64 ? 64 : 32; - if (std::holds_alternative(bin.v)) { + if (auto pimm = std::get_if(&bin.v)) { // dst += K int64_t imm; if (bin.is64) { // Use the full signed value. - imm = static_cast(std::get(bin.v).v); + imm = static_cast(pimm->v); } else { // Use only the low 32 bits of the value. - imm = static_cast(std::get(bin.v).v); + imm = static_cast(pimm->v); bitwise_and(dst.svalue, dst.uvalue, std::numeric_limits::max()); } switch (bin.op) { diff --git a/src/crab/fwd_analyzer.cpp b/src/crab/fwd_analyzer.cpp index ffebe5105..e0ccc0162 100644 --- a/src/crab/fwd_analyzer.cpp +++ b/src/crab/fwd_analyzer.cpp @@ -19,7 +19,7 @@ class member_component_visitor final { bool _found; public: - explicit member_component_visitor(label_t node) : _node(node), _found(false) {} + explicit member_component_visitor(const label_t& node) : _node(node), _found(false) {} void operator()(const label_t& vertex) { if (!_found) { @@ -48,7 +48,7 @@ class member_component_visitor final { }; class interleaved_fwd_fixpoint_iterator_t final { - using iterator = typename invariant_table_t::iterator; + using iterator = invariant_table_t::iterator; cfg_t& _cfg; wto_t _wto; @@ -74,7 +74,7 @@ class interleaved_fwd_fixpoint_iterator_t final { } [[nodiscard]] - static ebpf_domain_t extrapolate(ebpf_domain_t before, const ebpf_domain_t& after, unsigned int iteration) { + static ebpf_domain_t extrapolate(const ebpf_domain_t& before, const ebpf_domain_t& after, unsigned int iteration) { /// number of iterations until triggering widening constexpr auto _widening_delay = 2; @@ -84,7 +84,7 @@ class interleaved_fwd_fixpoint_iterator_t final { return before.widen(after, iteration == _widening_delay); } - static ebpf_domain_t refine(ebpf_domain_t before, const ebpf_domain_t& after, unsigned int iteration) { + static ebpf_domain_t refine(const ebpf_domain_t& before, const ebpf_domain_t& after, unsigned int iteration) { if (iteration == 1) { return before & after; } else { @@ -125,8 +125,8 @@ std::pair run_forward_analyzer(cfg_t& cfg, if (thread_local_options.check_termination) { std::vector cycle_heads; for (auto& component : analyzer._wto) { - if (std::holds_alternative>(component)) { - cycle_heads.push_back(std::get>(component)->head()); + if (const auto pc = std::get_if>(&component)) { + cycle_heads.push_back((*pc)->head()); } } for (const label_t& label : cycle_heads) { @@ -190,7 +190,8 @@ void interleaved_fwd_fixpoint_iterator_t::operator()(const std::shared_ptr(component) || (std::get(component) != head)) { + const auto plabel = std::get_if(&component); + if (!plabel || *plabel != head) { std::visit(*this, component); } } @@ -210,7 +211,8 @@ void interleaved_fwd_fixpoint_iterator_t::operator()(const std::shared_ptr(component) || (std::get(component) != head)) { + const auto plabel = std::get_if(&component); + if (!plabel || *plabel != head) { std::visit(*this, component); } } diff --git a/src/ebpf_yaml.cpp b/src/ebpf_yaml.cpp index cc7752f3a..6a523d24e 100644 --- a/src/ebpf_yaml.cpp +++ b/src/ebpf_yaml.cpp @@ -338,8 +338,8 @@ ConformanceTestResult run_conformance_test_case(const std::vector& memo // Convert the raw program section to a set of instructions. std::variant prog_or_error = unmarshal(raw_prog); - if (std::holds_alternative(prog_or_error)) { - std::cerr << "unmarshaling error at " << std::get(prog_or_error) << "\n"; + if (auto prog = std::get_if(&prog_or_error)) { + std::cerr << "unmarshaling error at " << *prog << "\n"; return {}; } diff --git a/src/main/check.cpp b/src/main/check.cpp index 5c43972ec..a7cd49955 100644 --- a/src/main/check.cpp +++ b/src/main/check.cpp @@ -220,8 +220,8 @@ int main(int argc, char** argv) { // Convert the raw program section to a set of instructions. std::variant prog_or_error = unmarshal(raw_prog); - if (std::holds_alternative(prog_or_error)) { - std::cout << "unmarshaling error at " << std::get(prog_or_error) << "\n"; + if (auto prog = std::get_if(&prog_or_error)) { + std::cout << "unmarshaling error at " << *prog << "\n"; return 1; } diff --git a/src/test/test_marshal.cpp b/src/test/test_marshal.cpp index 2d84c76c2..ec54569d0 100644 --- a/src/test/test_marshal.cpp +++ b/src/test/test_marshal.cpp @@ -299,33 +299,33 @@ static void check_marshal_unmarshal_fail(const Instruction& ins, std::string exp static void check_unmarshal_fail(ebpf_inst inst, std::string expected_error_message, const ebpf_platform_t& platform = g_ebpf_platform_linux) { program_info info{.platform = &platform, .type = platform.get_program_type("unspec", "unspec")}; - std::vector insns = {inst}; + std::vector insns = {inst}; auto result = unmarshal(raw_program{"", "", 0, "", insns, info}); - REQUIRE(std::holds_alternative(result)); - std::string error_message = std::get(result); - REQUIRE(error_message == expected_error_message); + std::string* error_message = std::get_if(&result); + REQUIRE(error_message != nullptr); + REQUIRE(*error_message == expected_error_message); } static void check_unmarshal_fail_goto(ebpf_inst inst, const std::string& expected_error_message, const ebpf_platform_t& platform = g_ebpf_platform_linux) { program_info info{.platform = &platform, .type = platform.get_program_type("unspec", "unspec")}; - const ebpf_inst exit{.opcode = INST_OP_EXIT}; - std::vector insns{inst, exit, exit}; + constexpr ebpf_inst exit{.opcode = INST_OP_EXIT}; + std::vector insns{inst, exit, exit}; auto result = unmarshal(raw_program{"", "", 0, "", insns, info}); - REQUIRE(std::holds_alternative(result)); - std::string error_message = std::get(result); - REQUIRE(error_message == expected_error_message); + std::string* error_message = std::get_if(&result); + REQUIRE(error_message != nullptr); + REQUIRE(*error_message == expected_error_message); } // Check that unmarshaling a 64-bit immediate instruction fails. static void check_unmarshal_fail(ebpf_inst inst1, ebpf_inst inst2, std::string expected_error_message, const ebpf_platform_t& platform = g_ebpf_platform_linux) { program_info info{.platform = &platform, .type = platform.get_program_type("unspec", "unspec")}; - std::vector insns = {inst1, inst2}; + std::vector insns{inst1, inst2}; auto result = unmarshal(raw_program{"", "", 0, "", insns, info}); - REQUIRE(std::holds_alternative(result)); - std::string error_message = std::get(result); - REQUIRE(error_message == expected_error_message); + std::string* error_message = std::get_if(&result); + REQUIRE(error_message != nullptr); + REQUIRE(*error_message == expected_error_message); } static const auto ws = {1, 2, 4, 8}; diff --git a/src/test/test_print.cpp b/src/test/test_print.cpp index e567f15a6..b7e8eaf94 100644 --- a/src/test/test_print.cpp +++ b/src/test/test_print.cpp @@ -26,9 +26,9 @@ void verify_printed_string(const std::string& file) { read_elf(std::string(TEST_OBJECT_FILE_DIRECTORY) + file + ".o", "", nullptr, &g_ebpf_platform_linux); const raw_program& raw_prog = raw_progs.back(); std::variant prog_or_error = unmarshal(raw_prog); - REQUIRE(std::holds_alternative(prog_or_error)); - auto& program = std::get(prog_or_error); - print(program, generated_output, {}); + auto program = std::get_if(&prog_or_error); + REQUIRE(program != nullptr); + print(*program, generated_output, {}); print_map_descriptors(raw_prog.info.map_descriptors, generated_output); std::ifstream expected_stream(std::string(TEST_ASM_FILE_DIRECTORY) + file + std::string(".asm")); REQUIRE(expected_stream); diff --git a/src/test/test_verify.cpp b/src/test/test_verify.cpp index be7fa62b2..7c3160b3c 100644 --- a/src/test/test_verify.cpp +++ b/src/test/test_verify.cpp @@ -38,10 +38,10 @@ FAIL_UNMARSHAL("invalid", "invalid-lddw.o", ".text") auto raw_progs = read_elf("ebpf-samples/" dirname "/" filename, sectionname, nullptr, platform); \ REQUIRE(raw_progs.size() == 1); \ raw_program raw_prog = raw_progs.back(); \ - std::variant prog_or_error = unmarshal(raw_prog); \ - REQUIRE(std::holds_alternative(prog_or_error)); \ - auto& prog = std::get(prog_or_error); \ - bool res = ebpf_verify_program(std::cout, prog, raw_prog.info, options, nullptr); \ + auto prog_or_error = unmarshal(raw_prog); \ + auto prog = std::get_if(&prog_or_error); \ + REQUIRE(prog != nullptr); \ + bool res = ebpf_verify_program(std::cout, *prog, raw_prog.info, options, nullptr); \ if (pass) \ REQUIRE(res); \ else \ @@ -54,10 +54,10 @@ FAIL_UNMARSHAL("invalid", "invalid-lddw.o", ".text") auto raw_progs = read_elf("ebpf-samples/" dirname "/" filename, section_name, nullptr, platform); \ for (auto& raw_prog : raw_progs) { \ if (raw_prog.function_name == program_name) { \ - std::variant prog_or_error = unmarshal(raw_prog); \ - REQUIRE(std::holds_alternative(prog_or_error)); \ - auto& prog = std::get(prog_or_error); \ - bool res = ebpf_verify_program(std::cout, prog, raw_prog.info, options, nullptr); \ + auto prog_or_error = unmarshal(raw_prog); \ + auto prog = std::get_if(&prog_or_error); \ + REQUIRE(prog != nullptr); \ + bool res = ebpf_verify_program(std::cout, *prog, raw_prog.info, options, nullptr); \ if (pass) \ REQUIRE(res); \ else \ @@ -598,18 +598,18 @@ TEST_CASE("multithreading", "[verify][multithreading]") { auto raw_progs1 = read_elf("ebpf-samples/bpf_cilium_test/bpf_netdev.o", "2/1", nullptr, &g_ebpf_platform_linux); REQUIRE(raw_progs1.size() == 1); raw_program raw_prog1 = raw_progs1.back(); - std::variant prog_or_error1 = unmarshal(raw_prog1); - REQUIRE(std::holds_alternative(prog_or_error1)); - auto& prog1 = std::get(prog_or_error1); - cfg_t cfg1 = prepare_cfg(prog1, raw_prog1.info, true); + auto prog_or_error1 = unmarshal(raw_prog1); + auto prog1 = std::get_if(&prog_or_error1); + REQUIRE(prog1 != nullptr); + cfg_t cfg1 = prepare_cfg(*prog1, raw_prog1.info, true); auto raw_progs2 = read_elf("ebpf-samples/bpf_cilium_test/bpf_netdev.o", "2/2", nullptr, &g_ebpf_platform_linux); REQUIRE(raw_progs2.size() == 1); raw_program raw_prog2 = raw_progs2.back(); - std::variant prog_or_error2 = unmarshal(raw_prog2); - REQUIRE(std::holds_alternative(prog_or_error2)); - auto& prog2 = std::get(prog_or_error2); - cfg_t cfg2 = prepare_cfg(prog2, raw_prog2.info, true); + auto prog_or_error2 = unmarshal(raw_prog2); + auto prog2 = std::get_if(&prog_or_error2); + REQUIRE(prog2 != nullptr); + cfg_t cfg2 = prepare_cfg(*prog2, raw_prog2.info, true); bool res1, res2; std::thread a(test_analyze_thread, &cfg1, &raw_prog1.info, &res1);