Skip to content

Commit

Permalink
intrinsics: add support for X86 vector intrinsics
Browse files Browse the repository at this point in the history
This work was originally authored by Zhengyang Liu <[email protected]>,
and was motivated by the development of the Minotaur project
[https://arxiv.org/abs/2306.00229].
  • Loading branch information
artagnon committed Nov 27, 2024
1 parent 708948d commit 7c83975
Show file tree
Hide file tree
Showing 89 changed files with 2,034 additions and 27 deletions.
835 changes: 811 additions & 24 deletions ir/instr.cpp

Large diffs are not rendered by default.

139 changes: 137 additions & 2 deletions ir/instr.h
Original file line number Diff line number Diff line change
Expand Up @@ -532,8 +532,8 @@ class Phi final : public Instr {
void print(std::ostream &os) const override;
StateValue toSMT(State &s) const override;
smt::expr getTypeConstraints(const Function &f) const override;
std::unique_ptr<Instr>
dup(Function &f, const std::string &suffix) const override;
std::unique_ptr<Instr> dup(Function &f,
const std::string &suffix) const override;
};


Expand Down Expand Up @@ -1286,6 +1286,141 @@ class ShuffleVector final : public Instr {
dup(Function &f, const std::string &suffix) const override;
};

class FakeShuffle final : public Instr {
Value *v1, *v2, *mask;

public:
FakeShuffle(Type &type, std::string &&name, Value &v1, Value &v2, Value &mask)
: Instr(type, std::move(name)), v1(&v1), v2(&v2), mask(&mask) {}
std::vector<Value *> operands() const override;
bool propagatesPoison() const override;
bool hasSideEffects() const override;
void rauw(const Value &what, Value &with) override;
void print(std::ostream &os) const override;
StateValue toSMT(State &s) const override;
smt::expr getTypeConstraints(const Function &f) const override;
std::unique_ptr<Instr> dup(Function &f,
const std::string &suffix) const override;
};

// Two-operand x86 vector intrinsic.
//
// The supported intrinsics and their shapes are generated from the
// PROCESS(NAME, A, B, C, D, E, F) entries of "intrinsics_binop.h":
//   (A, B) -> shape of the return value
//   (C, D) -> shape of operand 0
//   (E, F) -> shape of operand 1
// Every table below is indexed by the Op enumerator of the intrinsic.
class X86IntrinBinOp final : public Instr {
public:
  // Number of intrinsics, counted from the list itself (one "+ 1u" per
  // PROCESS entry). Deriving it keeps the table sizes in lock-step with the
  // list: with a hand-maintained constant, a list shorter than the constant
  // would silently leave zero-initialised trailing table entries.
  static constexpr unsigned numOfX86Intrinsics = 0u
#define PROCESS(NAME, A, B, C, D, E, F) + 1u
#include "intrinsics_binop.h"
#undef PROCESS
      ;

  // One enumerator per PROCESS entry, in list order.
  enum Op {
#define PROCESS(NAME, A, B, C, D, E, F) NAME,
#include "intrinsics_binop.h"
#undef PROCESS
  };

  // the shape of a vector is stored as <# of lanes, element bits>
  static constexpr std::array<std::pair<unsigned, unsigned>, numOfX86Intrinsics>
      shape_op0 = {
#define PROCESS(NAME, A, B, C, D, E, F) std::make_pair(C, D),
#include "intrinsics_binop.h"
#undef PROCESS
      };
  static constexpr std::array<std::pair<unsigned, unsigned>, numOfX86Intrinsics>
      shape_op1 = {
#define PROCESS(NAME, A, B, C, D, E, F) std::make_pair(E, F),
#include "intrinsics_binop.h"
#undef PROCESS
      };
  static constexpr std::array<std::pair<unsigned, unsigned>, numOfX86Intrinsics>
      shape_ret = {
#define PROCESS(NAME, A, B, C, D, E, F) std::make_pair(A, B),
#include "intrinsics_binop.h"
#undef PROCESS
      };
  // Total bit width of the return value: lanes * element bits. The macro
  // arguments are parenthesised so that non-literal arguments cannot change
  // the precedence of the product.
  static constexpr std::array<unsigned, numOfX86Intrinsics> ret_width = {
#define PROCESS(NAME, A, B, C, D, E, F) (A) * (B),
#include "intrinsics_binop.h"
#undef PROCESS
  };

private:
  // Non-owning pointers to the two operands.
  Value *a, *b;
  // Which intrinsic this instruction represents.
  Op op;

public:
  // Returns the total bit width (lanes * element bits) of `op`'s return value.
  static unsigned getRetWidth(Op op) {
    return ret_width[op];
  }

  // Stores the addresses of the operands; `name` is moved into Instr.
  X86IntrinBinOp(Type &type, std::string &&name, Value &a, Value &b, Op op)
      : Instr(type, std::move(name)), a(&a), b(&b), op(op) {}

  // Instr interface and helpers; all of these are defined out of line.
  std::vector<Value *> operands() const override;
  bool propagatesPoison() const override;
  bool hasSideEffects() const override;
  void rauw(const Value &what, Value &with) override;
  static std::string getOpName(Op op);
  void print(std::ostream &os) const override;
  StateValue toSMT(State &s) const override;
  smt::expr getTypeConstraints(const Function &f) const override;
  std::unique_ptr<Instr> dup(Function &f,
                             const std::string &suffix) const override;
};

// Three-operand x86 vector intrinsic.
//
// The supported intrinsics and their shapes are generated from the
// PROCESS(NAME, A, B, C, D, E, F, G, H) entries of "intrinsics_terop.h":
//   (A, B) -> shape of the return value
//   (C, D) -> shape of operand 0
//   (E, F) -> shape of operand 1
//   (G, H) -> shape of operand 2
// Every table below is indexed by the Op enumerator of the intrinsic.
class X86IntrinTerOp final : public Instr {
public:
  // Number of intrinsics, counted from the list itself (one "+ 1u" per
  // PROCESS entry). Deriving it keeps the table sizes in lock-step with the
  // list: with a hand-maintained constant, adding an entry to the list without
  // bumping the constant breaks the build, and removing one would silently
  // leave zero-initialised trailing table entries.
  static constexpr unsigned numOfX86Intrinsics = 0u
#define PROCESS(NAME, A, B, C, D, E, F, G, H) + 1u
#include "intrinsics_terop.h"
#undef PROCESS
      ;

  // One enumerator per PROCESS entry, in list order.
  enum Op {
#define PROCESS(NAME, A, B, C, D, E, F, G, H) NAME,
#include "intrinsics_terop.h"
#undef PROCESS
  };

  // the shape of a vector is stored as <# of lanes, element bits>
  static constexpr std::array<std::pair<unsigned, unsigned>, numOfX86Intrinsics>
      shape_op0 = {
#define PROCESS(NAME, A, B, C, D, E, F, G, H) std::make_pair(C, D),
#include "intrinsics_terop.h"
#undef PROCESS
      };
  static constexpr std::array<std::pair<unsigned, unsigned>, numOfX86Intrinsics>
      shape_op1 = {
#define PROCESS(NAME, A, B, C, D, E, F, G, H) std::make_pair(E, F),
#include "intrinsics_terop.h"
#undef PROCESS
      };
  static constexpr std::array<std::pair<unsigned, unsigned>, numOfX86Intrinsics>
      shape_op2 = {
#define PROCESS(NAME, A, B, C, D, E, F, G, H) std::make_pair(G, H),
#include "intrinsics_terop.h"
#undef PROCESS
      };
  static constexpr std::array<std::pair<unsigned, unsigned>, numOfX86Intrinsics>
      shape_ret = {
#define PROCESS(NAME, A, B, C, D, E, F, G, H) std::make_pair(A, B),
#include "intrinsics_terop.h"
#undef PROCESS
      };
  // Total bit width of the return value: lanes * element bits. The macro
  // arguments are parenthesised so that non-literal arguments cannot change
  // the precedence of the product.
  static constexpr std::array<unsigned, numOfX86Intrinsics> ret_width = {
#define PROCESS(NAME, A, B, C, D, E, F, G, H) (A) * (B),
#include "intrinsics_terop.h"
#undef PROCESS
  };

private:
  // Non-owning pointers to the three operands.
  Value *a, *b, *c;
  // Which intrinsic this instruction represents.
  Op op;

public:
  // Returns the total bit width (lanes * element bits) of `op`'s return value.
  static unsigned getRetWidth(Op op) {
    return ret_width[op];
  }

  // Stores the addresses of the operands; `name` is moved into Instr.
  X86IntrinTerOp(Type &type, std::string &&name, Value &a, Value &b, Value &c,
                 Op op)
      : Instr(type, std::move(name)), a(&a), b(&b), c(&c), op(op) {}

  // Instr interface and helpers; all of these are defined out of line.
  std::vector<Value *> operands() const override;
  bool propagatesPoison() const override;
  bool hasSideEffects() const override;
  void rauw(const Value &what, Value &with) override;
  static std::string getOpName(Op op);
  void print(std::ostream &os) const override;
  StateValue toSMT(State &s) const override;
  smt::expr getTypeConstraints(const Function &f) const override;
  std::unique_ptr<Instr> dup(Function &f,
                             const std::string &suffix) const override;
};

const ConversionOp *isCast(ConversionOp::Op op, const Value &v);
Value *isNoOp(const Value &v);
Expand Down
135 changes: 135 additions & 0 deletions ir/intrinsics_binop.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
// X-macro table of the two-operand x86 vector intrinsics.
//
// Each entry is PROCESS(NAME, A, B, C, D, E, F). The including code defines
// PROCESS before the #include and #undefs it afterwards. As consumed by
// X86IntrinBinOp, a shape is <# of lanes, element bits> and
//   (A, B) is the shape of the return value,
//   (C, D) is the shape of operand 0,
//   (E, F) is the shape of operand 1.
// NAME is used both as the X86IntrinBinOp::Op enumerator and as the
// llvm::Intrinsic identifier, so it must match the LLVM intrinsic name.
PROCESS(x86_sse2_pavg_w, 8, 16, 8, 16, 8, 16)
PROCESS(x86_sse2_pavg_b, 16, 8, 16, 8, 16, 8)
PROCESS(x86_avx2_pavg_w, 16, 16, 16, 16, 16, 16)
PROCESS(x86_avx2_pavg_b, 32, 8, 32, 8, 32, 8)
PROCESS(x86_avx512_pavg_w_512, 32, 16, 32, 16, 32, 16)
PROCESS(x86_avx512_pavg_b_512, 64, 8, 64, 8, 64, 8)
PROCESS(x86_avx2_pshuf_b, 32, 8, 32, 8, 32, 8)
PROCESS(x86_ssse3_pshuf_b_128, 16, 8, 16, 8, 16, 8)
PROCESS(x86_avx512_pshuf_b_512, 64, 8, 64, 8, 64, 8)
PROCESS(x86_sse2_psrl_w, 8, 16, 8, 16, 8, 16)
PROCESS(x86_sse2_psrl_d, 4, 32, 4, 32, 4, 32)
PROCESS(x86_sse2_psrl_q, 2, 64, 2, 64, 2, 64)
PROCESS(x86_avx2_psrl_w, 16, 16, 16, 16, 8, 16)
PROCESS(x86_avx2_psrl_d, 8, 32, 8, 32, 4, 32)
PROCESS(x86_avx2_psrl_q, 4, 64, 4, 64, 2, 64)
PROCESS(x86_avx512_psrl_w_512, 32, 16, 32, 16, 8, 16)
PROCESS(x86_avx512_psrl_d_512, 16, 32, 16, 32, 4, 32)
PROCESS(x86_avx512_psrl_q_512, 8, 64, 8, 64, 2, 64)
PROCESS(x86_sse2_psrli_w, 8, 16, 8, 16, 1, 32)
PROCESS(x86_sse2_psrli_d, 4, 32, 4, 32, 1, 32)
PROCESS(x86_sse2_psrli_q, 2, 64, 2, 64, 1, 32)
PROCESS(x86_avx2_psrli_w, 16, 16, 16, 16, 1, 32)
PROCESS(x86_avx2_psrli_d, 8, 32, 8, 32, 1, 32)
PROCESS(x86_avx2_psrli_q, 4, 64, 4, 64, 1, 32)
PROCESS(x86_avx512_psrli_w_512, 32, 16, 32, 16, 1, 32)
PROCESS(x86_avx512_psrli_d_512, 16, 32, 16, 32, 1, 32)
PROCESS(x86_avx512_psrli_q_512, 8, 64, 8, 64, 1, 32)
PROCESS(x86_avx2_psrlv_d, 4, 32, 4, 32, 4, 32)
PROCESS(x86_avx2_psrlv_d_256, 8, 32, 8, 32, 8, 32)
PROCESS(x86_avx2_psrlv_q, 2, 64, 2, 64, 2, 64)
PROCESS(x86_avx2_psrlv_q_256, 4, 64, 4, 64, 4, 64)
PROCESS(x86_avx512_psrlv_d_512, 16, 32, 16, 32, 16, 32)
PROCESS(x86_avx512_psrlv_q_512, 8, 64, 8, 64, 8, 64)
PROCESS(x86_avx512_psrlv_w_128, 8, 16, 8, 16, 8, 16)
PROCESS(x86_avx512_psrlv_w_256, 16, 16, 16, 16, 16, 16)
PROCESS(x86_avx512_psrlv_w_512, 32, 16, 32, 16, 32, 16)
PROCESS(x86_sse2_psra_w, 8, 16, 8, 16, 8, 16)
PROCESS(x86_sse2_psra_d, 4, 32, 4, 32, 4, 32)
PROCESS(x86_avx2_psra_w, 16, 16, 16, 16, 8, 16)
PROCESS(x86_avx2_psra_d, 8, 32, 8, 32, 4, 32)
PROCESS(x86_avx512_psra_q_128, 2, 64, 2, 64, 2, 64)
PROCESS(x86_avx512_psra_q_256, 4, 64, 4, 64, 2, 64)
PROCESS(x86_avx512_psra_w_512, 32, 16, 32, 16, 8, 16)
PROCESS(x86_avx512_psra_d_512, 16, 32, 16, 32, 4, 32)
PROCESS(x86_avx512_psra_q_512, 8, 64, 8, 64, 2, 64)
PROCESS(x86_sse2_psrai_w, 8, 16, 8, 16, 1, 32)
PROCESS(x86_sse2_psrai_d, 4, 32, 4, 32, 1, 32)
PROCESS(x86_avx2_psrai_w, 16, 16, 16, 16, 1, 32)
PROCESS(x86_avx2_psrai_d, 8, 32, 8, 32, 1, 32)
PROCESS(x86_avx512_psrai_w_512, 32, 16, 32, 16, 1, 32)
PROCESS(x86_avx512_psrai_d_512, 16, 32, 16, 32, 1, 32)
PROCESS(x86_avx512_psrai_q_128, 2, 64, 2, 64, 1, 32)
PROCESS(x86_avx512_psrai_q_256, 4, 64, 4, 64, 1, 32)
PROCESS(x86_avx512_psrai_q_512, 8, 64, 8, 64, 1, 32)
PROCESS(x86_avx2_psrav_d, 4, 32, 4, 32, 4, 32)
PROCESS(x86_avx2_psrav_d_256, 8, 32, 8, 32, 8, 32)
PROCESS(x86_avx512_psrav_d_512, 16, 32, 16, 32, 16, 32)
PROCESS(x86_avx512_psrav_q_128, 2, 64, 2, 64, 2, 64)
PROCESS(x86_avx512_psrav_q_256, 4, 64, 4, 64, 4, 64)
PROCESS(x86_avx512_psrav_q_512, 8, 64, 8, 64, 8, 64)
PROCESS(x86_avx512_psrav_w_128, 8, 16, 8, 16, 8, 16)
PROCESS(x86_avx512_psrav_w_256, 16, 16, 16, 16, 16, 16)
PROCESS(x86_avx512_psrav_w_512, 32, 16, 32, 16, 32, 16)
PROCESS(x86_sse2_psll_w, 8, 16, 8, 16, 8, 16)
PROCESS(x86_sse2_psll_d, 4, 32, 4, 32, 4, 32)
PROCESS(x86_sse2_psll_q, 2, 64, 2, 64, 2, 64)
PROCESS(x86_avx2_psll_w, 16, 16, 16, 16, 8, 16)
PROCESS(x86_avx2_psll_d, 8, 32, 8, 32, 4, 32)
PROCESS(x86_avx2_psll_q, 4, 64, 4, 64, 2, 64)
PROCESS(x86_avx512_psll_w_512, 32, 16, 32, 16, 8, 16)
PROCESS(x86_avx512_psll_d_512, 16, 32, 16, 32, 4, 32)
PROCESS(x86_avx512_psll_q_512, 8, 64, 8, 64, 2, 64)
PROCESS(x86_sse2_pslli_w, 8, 16, 8, 16, 1, 32)
PROCESS(x86_sse2_pslli_d, 4, 32, 4, 32, 1, 32)
PROCESS(x86_sse2_pslli_q, 2, 64, 2, 64, 1, 32)
PROCESS(x86_avx2_pslli_w, 16, 16, 16, 16, 1, 32)
PROCESS(x86_avx2_pslli_d, 8, 32, 8, 32, 1, 32)
PROCESS(x86_avx2_pslli_q, 4, 64, 4, 64, 1, 32)
PROCESS(x86_avx512_pslli_w_512, 32, 16, 32, 16, 1, 32)
PROCESS(x86_avx512_pslli_d_512, 16, 32, 16, 32, 1, 32)
PROCESS(x86_avx512_pslli_q_512, 8, 64, 8, 64, 1, 32)
PROCESS(x86_avx2_psllv_d, 4, 32, 4, 32, 4, 32)
PROCESS(x86_avx2_psllv_d_256, 8, 32, 8, 32, 8, 32)
PROCESS(x86_avx2_psllv_q, 2, 64, 2, 64, 2, 64)
PROCESS(x86_avx2_psllv_q_256, 4, 64, 4, 64, 4, 64)
PROCESS(x86_avx512_psllv_d_512, 16, 32, 16, 32, 16, 32)
PROCESS(x86_avx512_psllv_q_512, 8, 64, 8, 64, 8, 64)
PROCESS(x86_avx512_psllv_w_128, 8, 16, 8, 16, 8, 16)
PROCESS(x86_avx512_psllv_w_256, 16, 16, 16, 16, 16, 16)
PROCESS(x86_avx512_psllv_w_512, 32, 16, 32, 16, 32, 16)
PROCESS(x86_ssse3_psign_b_128, 16, 8, 16, 8, 16, 8)
PROCESS(x86_ssse3_psign_w_128, 8, 16, 8, 16, 8, 16)
PROCESS(x86_ssse3_psign_d_128, 4, 32, 4, 32, 4, 32)
PROCESS(x86_avx2_psign_b, 32, 8, 32, 8, 32, 8)
PROCESS(x86_avx2_psign_w, 16, 16, 16, 16, 16, 16)
PROCESS(x86_avx2_psign_d, 8, 32, 8, 32, 8, 32)
PROCESS(x86_ssse3_phadd_w_128, 8, 16, 8, 16, 8, 16)
PROCESS(x86_ssse3_phadd_d_128, 4, 32, 4, 32, 4, 32)
PROCESS(x86_ssse3_phadd_sw_128, 8, 16, 8, 16, 8, 16)
PROCESS(x86_avx2_phadd_w, 16, 16, 16, 16, 16, 16)
PROCESS(x86_avx2_phadd_d, 8, 32, 8, 32, 8, 32)
PROCESS(x86_avx2_phadd_sw, 16, 16, 16, 16, 16, 16)
PROCESS(x86_ssse3_phsub_w_128, 8, 16, 8, 16, 8, 16)
PROCESS(x86_ssse3_phsub_d_128, 4, 32, 4, 32, 4, 32)
PROCESS(x86_ssse3_phsub_sw_128, 8, 16, 8, 16, 8, 16)
PROCESS(x86_avx2_phsub_w, 16, 16, 16, 16, 16, 16)
PROCESS(x86_avx2_phsub_d, 8, 32, 8, 32, 8, 32)
PROCESS(x86_avx2_phsub_sw, 16, 16, 16, 16, 16, 16)
PROCESS(x86_sse2_pmulh_w, 8, 16, 8, 16, 8, 16)
PROCESS(x86_avx2_pmulh_w, 16, 16, 16, 16, 16, 16)
PROCESS(x86_avx512_pmulh_w_512, 32, 16, 32, 16, 32, 16)
PROCESS(x86_sse2_pmulhu_w, 8, 16, 8, 16, 8, 16)
PROCESS(x86_avx2_pmulhu_w, 16, 16, 16, 16, 16, 16)
PROCESS(x86_avx512_pmulhu_w_512, 32, 16, 32, 16, 32, 16)
PROCESS(x86_sse2_pmadd_wd, 4, 32, 8, 16, 8, 16)
PROCESS(x86_avx2_pmadd_wd, 8, 32, 16, 16, 16, 16)
PROCESS(x86_avx512_pmaddw_d_512, 16, 32, 32, 16, 32, 16)
PROCESS(x86_ssse3_pmadd_ub_sw_128, 8, 16, 16, 8, 16, 8)
PROCESS(x86_avx2_pmadd_ub_sw, 16, 16, 32, 8, 32, 8)
PROCESS(x86_avx512_pmaddubs_w_512, 32, 16, 64, 8, 64, 8)
PROCESS(x86_sse2_packsswb_128, 16, 8, 8, 16, 8, 16)
PROCESS(x86_avx2_packsswb, 32, 8, 16, 16, 16, 16)
PROCESS(x86_avx512_packsswb_512, 64, 8, 32, 16, 32, 16)
PROCESS(x86_sse2_packuswb_128, 16, 8, 8, 16, 8, 16)
PROCESS(x86_avx2_packuswb, 32, 8, 16, 16, 16, 16)
PROCESS(x86_avx512_packuswb_512, 64, 8, 32, 16, 32, 16)
PROCESS(x86_sse2_packssdw_128, 8, 16, 4, 32, 4, 32)
PROCESS(x86_avx2_packssdw, 16, 16, 8, 32, 8, 32)
PROCESS(x86_avx512_packssdw_512, 32, 16, 16, 32, 16, 32)
PROCESS(x86_sse41_packusdw, 8, 16, 4, 32, 4, 32)
PROCESS(x86_avx2_packusdw, 16, 16, 8, 32, 8, 32)
PROCESS(x86_avx512_packusdw_512, 32, 16, 16, 32, 16, 32)
PROCESS(x86_sse2_psad_bw, 2, 64, 16, 8, 16, 8)
PROCESS(x86_avx2_psad_bw, 4, 64, 32, 8, 32, 8)
PROCESS(x86_avx512_psad_bw_512, 8, 64, 64, 8, 64, 8)
1 change: 1 addition & 0 deletions ir/intrinsics_terop.h
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
// X-macro table of the three-operand x86 vector intrinsics.
//
// Each entry is PROCESS(NAME, A, B, C, D, E, F, G, H). The including code
// defines PROCESS before the #include and #undefs it afterwards. As consumed
// by X86IntrinTerOp, a shape is <# of lanes, element bits> and
//   (A, B) is the shape of the return value,
//   (C, D) is the shape of operand 0,
//   (E, F) is the shape of operand 1,
//   (G, H) is the shape of operand 2.
PROCESS(x86_avx2_pblendvb, 32, 8, 32, 8, 32, 8, 32, 8)
4 changes: 4 additions & 0 deletions ir/type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "ir/state.h"
#include "smt/solver.h"
#include "util/compiler.h"
#include "util/config.h"
#include <array>
#include <cassert>
#include <numeric>
Expand Down Expand Up @@ -447,6 +448,9 @@ expr FloatType::getFloat(const expr &v) const {
expr FloatType::fromFloat(State &s, const expr &fp, const Type &from_type0,
unsigned nary, const expr &a, const expr &b,
const expr &c) const {
if (config::use_exact_fp)
return fp.float2BV();

expr isnan = fp.isNaN();
expr val = fp.float2BV();

Expand Down
6 changes: 6 additions & 0 deletions llvm_util/known_fns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,12 @@ known_call(llvm::CallInst &i, const llvm::TargetLibraryInfo &TLI,
RETURN_EXACT();

auto decl = i.getCalledFunction();

if (decl && decl->hasName() && decl->getName().starts_with("__fksv")) {
RETURN_VAL(make_unique<FakeShuffle>(*ty, value_name(i), *args[0], *args[1],
*args[2]));
}

llvm::LibFunc libfn;
if (!decl || !TLI.getLibFunc(*decl, libfn))
RETURN_EXACT();
Expand Down
40 changes: 40 additions & 0 deletions llvm_util/llvm2alive.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "llvm/IR/InstVisitor.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/Operator.h"
#include "llvm/Support/ModRef.h"
#include <sstream>
Expand Down Expand Up @@ -1222,6 +1223,45 @@ class llvm2alive_ : public llvm::InstVisitor<llvm2alive_, unique_ptr<Instr>> {
case llvm::Intrinsic::prefetch:
return NOP(i);

// intel x86 intrinsics
#define PROCESS(NAME, A, B, C, D, E, F) case llvm::Intrinsic::NAME:
#include "ir/intrinsics_binop.h"
#undef PROCESS
{
PARSE_BINOP();
X86IntrinBinOp::Op op;
switch (i.getIntrinsicID()) {
#define PROCESS(NAME, A, B, C, D, E, F) \
case llvm::Intrinsic::NAME: \
op = X86IntrinBinOp::NAME; \
break;
#include "ir/intrinsics_binop.h"
#undef PROCESS
default:
UNREACHABLE();
}
return make_unique<X86IntrinBinOp>(*ty, value_name(i), *a, *b, op);
}

#define PROCESS(NAME, A, B, C, D, E, F, G, H) case llvm::Intrinsic::NAME:
#include "ir/intrinsics_terop.h"
#undef PROCESS
{
PARSE_TRIOP();
X86IntrinTerOp::Op op;
switch (i.getIntrinsicID()) {
#define PROCESS(NAME, A, B, C, D, E, F, G, H) \
case llvm::Intrinsic::NAME: \
op = X86IntrinTerOp::NAME; \
break;
#include "ir/intrinsics_terop.h"
#undef PROCESS
default:
UNREACHABLE();
}
return make_unique<X86IntrinTerOp>(*ty, value_name(i), *a, *b, *c, op);
}

default:
break;
}
Expand Down
2 changes: 1 addition & 1 deletion llvm_util/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ Type* llvm_type2alive(const llvm::Type *ty) {
auto vty = cast<llvm::VectorType>(ty);
auto elems = vty->getElementCount().getKnownMinValue();
auto ety = llvm_type2alive(vty->getElementType());
if (!ety || elems > 1024)
if (!ety || elems > 2048)
return nullptr;
cache = make_unique<VectorType>("ty_" + to_string(type_id_counter++),
elems, *ety);
Expand Down
10 changes: 10 additions & 0 deletions tests/alive-tv/vector/x86/avx2_psign_w-0.srctgt.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
; Source function: applies the AVX2 psign.d intrinsic to %v with an all-zero
; second operand and returns the result.
; NOTE(review): the file is named avx2_psign_w-0 but the body exercises
; psign.d on <8 x i32> -- confirm which variant this test is meant to cover.
define <8 x i32> @src(<8 x i32> %v) {
%1 = call <8 x i32> @llvm.x86.avx2.psign.d(<8 x i32> %v, <8 x i32> zeroinitializer)
ret <8 x i32> %1
}

; Target function: ignores %v and returns the all-zero vector, i.e. the value
; the psign.d call in @src is expected to refine to.
define <8 x i32> @tgt(<8 x i32> %v) {
ret <8 x i32> zeroinitializer
}

declare <8 x i32> @llvm.x86.avx2.psign.d(<8 x i32>, <8 x i32>)
Loading

0 comments on commit 7c83975

Please sign in to comment.