Skip to content

Commit

Permalink
Use default equality for asm_syntax structs (#592)
Browse files Browse the repository at this point in the history
* Use default equality for asm_syntax structs
* Add test for NEG where .id64=true

Signed-off-by: Elazar Gershuni <[email protected]>
Co-authored-by: Dave Thaler <[email protected]>
  • Loading branch information
elazarg and dthaler authored Feb 21, 2024
1 parent 1056586 commit 3db1b62
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 102 deletions.
17 changes: 8 additions & 9 deletions src/asm_marshal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ static int16_t offset(Bin::Op op) {
return 0;
}

static uint8_t imm(Un::Op op) {
static uint8_t imm_endian(Un::Op op) {
using Op = Un::Op;
switch (op) {
case Op::NEG: return 0;
case Op::NEG: assert(false); return 0;
case Op::BE16:
case Op::LE16:
case Op::SWAP16: return 16;
Expand Down Expand Up @@ -132,22 +132,21 @@ struct MarshalVisitor {
switch (b.op) {
case Un::Op::NEG:
return {ebpf_inst{
// FIX: should be INST_CLS_ALU / INST_CLS_ALU64
.opcode = static_cast<uint8_t>(INST_CLS_ALU | 0x3 | INST_ALU_OP_NEG),
.opcode = static_cast<uint8_t>((b.is64 ? INST_CLS_ALU64 : INST_CLS_ALU) | INST_ALU_OP_NEG),
.dst = b.dst.v,
.src = 0,
.offset = 0,
.imm = imm(b.op),
.imm = 0,
}};
case Un::Op::LE16:
case Un::Op::LE32:
case Un::Op::LE64:
return {ebpf_inst{
.opcode = static_cast<uint8_t>(INST_CLS_ALU | INST_ALU_OP_END),
.opcode = static_cast<uint8_t>(INST_CLS_ALU | INST_END_LE | INST_ALU_OP_END),
.dst = b.dst.v,
.src = 0,
.offset = 0,
.imm = imm(b.op),
.imm = imm_endian(b.op),
}};
case Un::Op::BE16:
case Un::Op::BE32:
Expand All @@ -157,7 +156,7 @@ struct MarshalVisitor {
.dst = b.dst.v,
.src = 0,
.offset = 0,
.imm = imm(b.op),
.imm = imm_endian(b.op),
}};
case Un::Op::SWAP16:
case Un::Op::SWAP32:
Expand All @@ -167,7 +166,7 @@ struct MarshalVisitor {
.dst = b.dst.v,
.src = 0,
.offset = 0,
.imm = imm(b.op),
.imm = imm_endian(b.op),
}};
}
assert(false);
Expand Down
109 changes: 36 additions & 73 deletions src/asm_syntax.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ struct label_t {
return label_t{src_label.from, target_label.from};
}

constexpr bool operator==(const label_t& other) const { return from == other.from && to == other.to; }
constexpr bool operator!=(const label_t& other) const { return !(*this == other); }
constexpr bool operator==(const label_t&) const = default;

constexpr bool operator<(const label_t& other) const {
if (this == &other) return false;
if (*this == label_t::exit) return false;
Expand Down Expand Up @@ -61,11 +61,13 @@ namespace asm_syntax {
/// Immediate argument.
struct Imm {
uint64_t v{};
constexpr bool operator==(const Imm&) const = default;
};

/// Register argument.
struct Reg {
uint8_t v{};
constexpr bool operator==(const Reg&) const = default;
};

using Value = std::variant<Imm, Reg>;
Expand Down Expand Up @@ -97,6 +99,7 @@ struct Bin {
Value v;
bool is64{};
bool lddw{};
constexpr bool operator==(const Bin&) const = default;
};

/// Unary operation.
Expand All @@ -117,13 +120,15 @@ struct Un {
Op op;
Reg dst;
bool is64{};
constexpr bool operator==(const Un&) const = default;
};

/// This instruction is encoded similarly to LDDW.
/// See comment in makeLddw() at asm_unmarshal.cpp
struct LoadMapFd {
Reg dst;
int32_t mapfd{};
constexpr bool operator==(const LoadMapFd&) const = default;
};

struct Condition {
Expand All @@ -146,11 +151,13 @@ struct Condition {
Reg left;
Value right;
bool is64{};
constexpr bool operator==(const Condition&) const = default;
};

struct Jmp {
std::optional<Condition> cond;
label_t target;
constexpr bool operator==(const Jmp&) const = default;
};

struct ArgSingle {
Expand All @@ -164,6 +171,7 @@ struct ArgSingle {
ANYTHING,
} kind{};
Reg reg;
constexpr bool operator==(const ArgSingle&) const = default;
};

/// Pair of arguments to a function for pointer and size.
Expand All @@ -176,30 +184,40 @@ struct ArgPair {
Reg mem; ///< Pointer.
Reg size; ///< Size of space pointed to.
bool can_be_zero{};
constexpr bool operator==(const ArgPair&) const = default;
};

struct Call {
int32_t func{};
constexpr bool operator==(const Call& other) const {
return func == other.func;
}

// TODO: move name and signature information somewhere else
std::string name;
bool is_map_lookup{};
bool reallocate_packet{};
std::vector<ArgSingle> singles;
std::vector<ArgPair> pairs;
};

struct Exit {};
struct Exit {
constexpr bool operator==(const Exit&) const = default;
};

struct Deref {
int32_t width{};
Reg basereg;
int32_t offset{};
constexpr bool operator==(const Deref&) const = default;
};

/// Load/store instruction.
struct Mem {
Deref access;
Value value;
bool is_load{};
constexpr bool operator==(const Mem&) const = default;
};

/// A deprecated instruction for checked access to packets; it is actually a
Expand All @@ -209,24 +227,28 @@ struct Packet {
int32_t width{};
int32_t offset{};
std::optional<Reg> regoffset;
constexpr bool operator==(const Packet&) const = default;
};

/// Special instruction for incrementing values inside shared memory.
struct LockAdd {
Deref access;
Reg valreg;
constexpr bool operator==(const LockAdd&) const = default;
};

/// Not an instruction, just used for failure cases.
struct Undefined {
int opcode{};
constexpr bool operator==(const Undefined&) const = default;
};

/// When a CFG is translated to its nondeterministic form, Conditional Jump
/// instructions are replaced by two Assume instructions, immediately after
/// the branch and before each jump target.
struct Assume {
Condition cond;
constexpr bool operator==(const Assume&) const = default;
};

enum class TypeGroup {
Expand All @@ -250,6 +272,7 @@ enum class TypeGroup {
struct ValidSize {
Reg reg;
bool can_be_zero{};
constexpr bool operator==(const ValidSize&) const = default;
};

/// Condition check whether two registers can be compared with each other.
Expand All @@ -259,18 +282,21 @@ struct Comparable {
Reg r1;
Reg r2;
bool or_r2_is_number{}; ///< true for subtraction, false for comparison
constexpr bool operator==(const Comparable&) const = default;
};

// ptr: ptr -> num : num
struct Addable {
Reg ptr;
Reg num;
constexpr bool operator==(const Addable&) const = default;
};

// Condition check whether a register contains a non-zero number.
struct ValidDivisor {
Reg reg;
bool is_signed{};
constexpr bool operator==(const ValidDivisor&) const = default;
};

enum class AccessType {
Expand All @@ -285,29 +311,34 @@ struct ValidAccess {
Value width{Imm{0}};
bool or_null{};
AccessType access_type{};
constexpr bool operator==(const ValidAccess&) const = default;
};

/// Condition check whether something is a valid key value.
struct ValidMapKeyValue {
Reg access_reg;
Reg map_fd_reg;
bool key{};
constexpr bool operator==(const ValidMapKeyValue&) const = default;
};

// "if mem is not stack, val is num"
struct ValidStore {
Reg mem;
Reg val;
constexpr bool operator==(const ValidStore&) const = default;
};

struct TypeConstraint {
Reg reg;
TypeGroup types;
constexpr bool operator==(const TypeConstraint&) const = default;
};

/// Condition check whether something is a valid size.
struct ZeroCtxOffset {
Reg reg;
constexpr bool operator==(const ZeroCtxOffset&) const = default;
};

using AssertionConstraint =
Expand All @@ -316,88 +347,22 @@ using AssertionConstraint =
struct Assert {
AssertionConstraint cst;
Assert(AssertionConstraint cst): cst(cst) { }
constexpr bool operator==(const Assert&) const = default;
};

struct IncrementLoopCounter {
label_t name;
constexpr bool operator==(const IncrementLoopCounter&) const = default;
};

using Instruction = std::variant<Undefined, Bin, Un, LoadMapFd, Call, Exit, Jmp, Mem, Packet, LockAdd, Assume, Assert, IncrementLoopCounter>;

using LabeledInstruction = std::tuple<label_t, Instruction, std::optional<btf_line_info_t>>;
using InstructionSeq = std::vector<LabeledInstruction>;


#define DECLARE_EQ5(T, f1, f2, f3, f4, f5) \
inline bool operator==(T const& a, T const& b) { \
return a.f1 == b.f1 && a.f2 == b.f2 && a.f3 == b.f3 && a.f4 == b.f4 && a.f5 == b.f5; \
}
#define DECLARE_EQ3(T, f1, f2, f3) \
inline bool operator==(T const& a, T const& b) { return a.f1 == b.f1 && a.f2 == b.f2 && a.f3 == b.f3; }
#define DECLARE_EQ2(T, f1, f2) \
inline bool operator==(T const& a, T const& b) { return a.f1 == b.f1 && a.f2 == b.f2; }
#define DECLARE_EQ1(T, f1) \
inline bool operator==(T const& a, T const& b) { return a.f1 == b.f1; }

// cpu=v4 supports 32-bit PC offsets so we need a large enough type.
using pc_t = size_t;

// Helpers:

struct InstructionVisitorPrototype {
void operator()(Undefined const& a);
void operator()(LoadMapFd const& a);
void operator()(Bin const& a);
void operator()(Un const& a);
void operator()(Call const& a);
void operator()(Exit const& a);
void operator()(Jmp const& a);
void operator()(Assume const& a);
void operator()(Assert const& a);
void operator()(Packet const& a);
void operator()(Mem const& a);
void operator()(LockAdd const& a);
};

inline bool operator==(Imm const& a, Imm const& b) { return a.v == b.v; }
inline bool operator==(Reg const& a, Reg const& b) { return a.v == b.v; }
inline bool operator==(Deref const& a, Deref const& b) {
return a.basereg == b.basereg && a.offset == b.offset && a.width == b.width;
}
inline bool operator==(Condition const& a, Condition const& b) {
return a.left == b.left && a.op == b.op && a.right == b.right;
}
inline bool operator==(Undefined const& a, Undefined const& b) { return a.opcode == b.opcode; }
inline bool operator==(LoadMapFd const& a, LoadMapFd const& b) { return a.dst == b.dst && a.mapfd == b.mapfd; }
inline bool operator==(Bin const& a, Bin const& b) {
return a.op == b.op && a.dst == b.dst && a.is64 == b.is64 && a.v == b.v && a.lddw == b.lddw;
}
inline bool operator==(Un const& a, Un const& b) { return a.op == b.op && a.dst == b.dst; }
inline bool operator==(Call const& a, Call const& b) { return a.func == b.func; }
inline bool operator==(Exit const& a, Exit const& b) { return true; }
inline bool operator==(Jmp const& a, Jmp const& b) { return a.cond == b.cond && a.target == b.target; }
inline bool operator==(Packet const& a, Packet const& b) {
return a.offset == b.offset && a.regoffset == b.regoffset && a.width == b.width;
}
inline bool operator==(Mem const& a, Mem const& b) {
return a.access == b.access && a.value == b.value && a.is_load == b.is_load;
}
inline bool operator==(LockAdd const& a, LockAdd const& b) { return a.access == b.access && a.valreg == b.valreg; }
inline bool operator==(Assume const& a, Assume const& b) { return a.cond == b.cond; }
bool operator==(Assert const& a, Assert const& b);

DECLARE_EQ2(TypeConstraint, reg, types)
DECLARE_EQ2(ValidSize, reg, can_be_zero)
DECLARE_EQ2(Comparable, r1, r2)
DECLARE_EQ2(Addable, ptr, num)
DECLARE_EQ2(ValidDivisor, reg, is_signed)
DECLARE_EQ2(ValidStore, mem, val)
DECLARE_EQ5(ValidAccess, reg, offset, width, or_null, access_type)
DECLARE_EQ3(ValidMapKeyValue, access_reg, map_fd_reg, key)
DECLARE_EQ1(ZeroCtxOffset, reg)
DECLARE_EQ1(Assert, cst)
DECLARE_EQ1(IncrementLoopCounter, name)

}

using namespace asm_syntax;
Expand All @@ -406,5 +371,3 @@ template <class... Ts>
struct overloaded : Ts... {
using Ts::operator()...;
};
template <class... Ts>
overloaded(Ts...)->overloaded<Ts...>;
5 changes: 3 additions & 2 deletions src/asm_unmarshal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -293,13 +293,14 @@ struct Unmarshaller {
throw InvalidInstruction(pc, "Invalid target r10");
if (inst.dst > R10_STACK_POINTER || inst.src > R10_STACK_POINTER)
throw InvalidInstruction(pc, "Bad register");
return std::visit(overloaded{[&](Un::Op op) -> Instruction { return Un{.op = op, .dst = Reg{inst.dst}, .is64 = (inst.opcode & INST_CLS_MASK) == INST_CLS_ALU64}; },
bool is64 = (inst.opcode & INST_CLS_MASK) == INST_CLS_ALU64;
return std::visit(overloaded{[&](Un::Op op) -> Instruction { return Un{.op = op, .dst = Reg{inst.dst}, .is64 = is64}; },
[&](Bin::Op op) -> Instruction {
Bin res{
.op = op,
.dst = Reg{inst.dst},
.v = getBinValue(pc, inst),
.is64 = (inst.opcode & INST_CLS_MASK) == INST_CLS_ALU64,
.is64 = is64,
};
if (!thread_local_options.allow_division_by_zero && (op == Bin::Op::UDIV || op == Bin::Op::UMOD))
if (std::holds_alternative<Imm>(res.v) && std::get<Imm>(res.v).v == 0)
Expand Down
Loading

0 comments on commit 3db1b62

Please sign in to comment.