Skip to content

Commit

Permalink
Benchmark (#1617)
Browse files Browse the repository at this point in the history
* Benchmark wip

* fix

* prune prints
  • Loading branch information
wsmoses authored Jan 19, 2024
1 parent 14f4526 commit 11cc0f1
Show file tree
Hide file tree
Showing 9 changed files with 432 additions and 564 deletions.
104 changes: 81 additions & 23 deletions enzyme/Enzyme/FunctionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7063,6 +7063,7 @@ Constraints::allSolutions(SCEVExpander &Exp, llvm::Type *T, Instruction *IP,
return {};
}

constexpr bool SparseDebug = false;
std::shared_ptr<const Constraints>
getSparseConditions(bool &legal, Value *val,
std::shared_ptr<const Constraints> defaultFloat,
Expand All @@ -7077,11 +7078,13 @@ getSparseConditions(bool &legal, Value *val,
auto res = lhs->andB(rhs, ctx);
assert(res);
assert(ctx.seen.size() == 0);
llvm::errs() << " getSparse(and, " << *I << "), lhs(" << *I->getOperand(0)
<< ") = " << *lhs << "\n";
llvm::errs() << " getSparse(and, " << *I << "), rhs(" << *I->getOperand(1)
<< ") = " << *rhs << "\n";
llvm::errs() << " getSparse(and, " << *I << ") = " << *res << "\n";
if (SparseDebug) {
llvm::errs() << " getSparse(and, " << *I << "), lhs("
<< *I->getOperand(0) << ") = " << *lhs << "\n";
llvm::errs() << " getSparse(and, " << *I << "), rhs("
<< *I->getOperand(1) << ") = " << *rhs << "\n";
llvm::errs() << " getSparse(and, " << *I << ") = " << *res << "\n";
}
return res;
}

Expand All @@ -7092,11 +7095,13 @@ getSparseConditions(bool &legal, Value *val,
auto rhs = getSparseConditions(legal, I->getOperand(1),
Constraints::none(), I, ctx);
auto res = lhs->orB(rhs, ctx);
llvm::errs() << " getSparse(or, " << *I << "), lhs(" << *I->getOperand(0)
<< ") = " << *lhs << "\n";
llvm::errs() << " getSparse(or, " << *I << "), rhs(" << *I->getOperand(1)
<< ") = " << *rhs << "\n";
llvm::errs() << " getSparse(or, " << *I << ") = " << *res << "\n";
if (SparseDebug) {
llvm::errs() << " getSparse(or, " << *I << "), lhs("
<< *I->getOperand(0) << ") = " << *lhs << "\n";
llvm::errs() << " getSparse(or, " << *I << "), rhs("
<< *I->getOperand(1) << ") = " << *rhs << "\n";
llvm::errs() << " getSparse(or, " << *I << ") = " << *res << "\n";
}
return res;
}

Expand All @@ -7108,9 +7113,12 @@ getSparseConditions(bool &legal, Value *val,
getSparseConditions(legal, I->getOperand(1 - i),
defaultFloat->notB(ctx), scope, ctx);
auto res = pres->notB(ctx);
llvm::errs() << " getSparse(not, " << *I << "), prev ("
<< *I->getOperand(0) << ") = " << *pres << "\n";
llvm::errs() << " getSparse(not, " << *I << ") = " << *res << "\n";
if (SparseDebug) {
llvm::errs() << " getSparse(not, " << *I << "), prev ("
<< *I->getOperand(0) << ") = " << *pres << "\n";
llvm::errs() << " getSparse(not, " << *I << ") = " << *res
<< "\n";
}
return res;
}
}
Expand All @@ -7120,8 +7128,10 @@ getSparseConditions(bool &legal, Value *val,
auto L = ctx.loopToSolve;
auto lhs = ctx.SE.getSCEVAtScope(icmp->getOperand(0), L);
auto rhs = ctx.SE.getSCEVAtScope(icmp->getOperand(1), L);
llvm::errs() << " lhs: " << *lhs << "\n";
llvm::errs() << " rhs: " << *rhs << "\n";
if (SparseDebug) {
llvm::errs() << " lhs: " << *lhs << "\n";
llvm::errs() << " rhs: " << *rhs << "\n";
}

auto sub1 = ctx.SE.getMinusSCEV(lhs, rhs);

Expand All @@ -7145,8 +7155,10 @@ getSparseConditions(bool &legal, Value *val,
auto res = Constraints::make_compare(
div, icmp->getPredicate() == ICmpInst::ICMP_EQ,
add->getLoop(), ctx);
llvm::errs()
<< " getSparse(icmp, " << *I << ") = " << *res << "\n";
if (SparseDebug) {
llvm::errs()
<< " getSparse(icmp, " << *I << ") = " << *res << "\n";
}
return res;
}
}
Expand All @@ -7172,7 +7184,9 @@ getSparseConditions(bool &legal, Value *val,
// cmp x, 1.0 -> false/true
if (auto fcmp = dyn_cast<FCmpInst>(I)) {
auto res = defaultFloat;
llvm::errs() << " getSparse(fcmp, " << *I << ") = " << *res << "\n";
if (SparseDebug) {
llvm::errs() << " getSparse(fcmp, " << *I << ") = " << *res << "\n";
}
return res;

if (fcmp->getPredicate() == CmpInst::FCMP_OEQ ||
Expand Down Expand Up @@ -7263,13 +7277,16 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM,
// Full simplification
while (!Q.empty()) {
auto cur = Q.pop_back_val();
/*
std::set<Instruction *> prev;
for (auto v : Q)
prev.insert(v);
// llvm::errs() << "\n\n\n\n" << F << "\n";
llvm::errs() << "cur: " << *cur << "\n";
*/
auto changed = fixSparse_inner(cur, F, Q, DT, SE, LI, DL);
(void)changed;
/*
if (changed) {
llvm::errs() << "changed: " << *changed << "\n";
Expand All @@ -7278,6 +7295,7 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM,
llvm::errs() << " + " << *I << "\n";
// llvm::errs() << F << "\n\n";
}
*/
}

// llvm::errs() << " post fix inner " << F << "\n";
Expand Down Expand Up @@ -7872,6 +7890,7 @@ void replaceToDense(llvm::CallBase *CI, bool replaceAll, llvm::Function *F,
args.push_back(diff);
for (size_t i = argstart; i < num_args; i++)
args.push_back(CI->getArgOperand(i));

if (load_fn->getFunctionType()->getNumParams() != args.size()) {
auto fnName = load_fn->getName();
auto found_numargs = load_fn->getFunctionType()->getNumParams();
Expand All @@ -7893,7 +7912,7 @@ void replaceToDense(llvm::CallBase *CI, bool replaceAll, llvm::Function *F,
*args[i]->getType(), " found ",
load_fn->getFunctionType()->params()[i]);
tocontinue = true;
break;
args[i] = UndefValue::get(args[i]->getType());
}
}
if (tocontinue)
Expand All @@ -7902,8 +7921,18 @@ void replaceToDense(llvm::CallBase *CI, bool replaceAll, llvm::Function *F,
CallInst *call = B.CreateCall(load_fn, args);
call->setDebugLoc(LI->getDebugLoc());
Value *tmp = call;
if (tmp->getType() != LI->getType())
tmp = B.CreateBitCast(tmp, LI->getType());
if (tmp->getType() != LI->getType()) {
if (CastInst::castIsValid(Instruction::BitCast, tmp, LI->getType()))
tmp = B.CreateBitCast(tmp, LI->getType());
else {
auto fnName = load_fn->getName();
EmitFailure("IllegalSparse", CI->getDebugLoc(), CI,
" incorrect return type of loader function ", fnName,
" expected ", *LI->getType(), " found ",
*call->getType());
tmp = UndefValue::get(LI->getType());
}
}
LI->replaceAllUsesWith(tmp);

if (load_fn->hasFnAttribute(Attribute::AlwaysInline)) {
Expand All @@ -7927,15 +7956,44 @@ void replaceToDense(llvm::CallBase *CI, bool replaceAll, llvm::Function *F,
EmitFailure("IllegalSparse", CI->getDebugLoc(), CI,
" first argument of store function must be the type of "
"the store found fn arg type ",
sty, " expected ", args0ty);
*sty, " expected ", *args0ty);
args[0] = UndefValue::get(sty);
}
}
args.push_back(diff);
for (size_t i = argstart; i < num_args; i++)
args.push_back(CI->getArgOperand(i));

if (store_fn->getFunctionType()->getNumParams() != args.size()) {
auto fnName = store_fn->getName();
auto found_numargs = store_fn->getFunctionType()->getNumParams();
auto expected_numargs = args.size();
EmitFailure("IllegalSparse", CI->getDebugLoc(), CI,
" incorrect number of arguments to store function ", fnName,
" expected ", expected_numargs, " found ", found_numargs,
" - ", *store_fn->getFunctionType());
continue;
} else {
bool tocontinue = false;
for (size_t i = 0; i < args.size(); i++) {
if (store_fn->getFunctionType()->getParamType(i) !=
args[i]->getType()) {
auto fnName = store_fn->getName();
EmitFailure("IllegalSparse", CI->getDebugLoc(), CI,
" incorrect type of argument ", i,
" to storeer function ", fnName, " expected ",
*args[i]->getType(), " found ",
store_fn->getFunctionType()->params()[i]);
tocontinue = true;
args[i] = UndefValue::get(args[i]->getType());
}
}
if (tocontinue)
continue;
}
auto call = B.CreateCall(store_fn, args);
call->setDebugLoc(SI->getDebugLoc());
if (load_fn->hasFnAttribute(Attribute::AlwaysInline)) {
if (store_fn->hasFnAttribute(Attribute::AlwaysInline)) {
InlineFunctionInfo IFI;
InlineFunction(*call, IFI);
}
Expand Down
70 changes: 27 additions & 43 deletions enzyme/test/Integration/Sparse/eigen_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,41 +150,6 @@ static void gradient_ip(const T *__restrict__ pos0, const size_t num_faces, cons
enzyme_dup, x, out);
}


// Load callback producing a one-hot value.
//
// `offset` is divided (integer division) by sizeof(T) to obtain an element
// index; the result is T(1) when that index equals `i` and T(0) otherwise.
// NOTE(review): `offset` is presumably a byte offset handed in by
// __enzyme_todense -- confirm against the call site.
template<typename T>
__attribute__((always_inline))
static T ident_load(unsigned long long offset, size_t i) {
    const auto element = offset / sizeof(T);
    if (element == i)
        return T(1);
    return T(0);
}


// Store callback for a location that must never be written.
//
// All three arguments are ignored. When assertions are enabled the call
// always fails with "store is not legal"; when they are compiled out
// (NDEBUG) the function does nothing.
template<typename T>
__attribute__((always_inline))
static void err_store(T val, unsigned long long offset, size_t i) {
    // Explicitly discard the arguments: none of them influences the outcome.
    (void)val;
    (void)offset;
    (void)i;
    assert(0 && "store is not legal");
}


template<typename T>
__attribute__((always_inline))
static T zero_load(unsigned long long offset, size_t i, std::vector<Triple<T>> &hess) {
return T(0);
}


__attribute__((enzyme_sparse_accumulate))
void inner_store(size_t offset, size_t i, float val, std::vector<Triple<float>> &hess) {
hess.push_back(Triple<float>(offset, i, val));
}

// Store callback that records non-zero values via inner_store.
//
// A value comparing equal to 0.0 is dropped. Otherwise `offset` is divided
// (integer division) by sizeof(T) and the resulting index, together with
// `i` and `val`, is forwarded to inner_store along with `hess`.
template<typename T>
__attribute__((always_inline))
static void csr_store(T val, unsigned long long offset, size_t i, std::vector<Triple<T>> &hess) {
    const bool isZero = (val == 0.0);
    if (!isZero) {
        // Convert the incoming offset into an element index before recording.
        const unsigned long long element = offset / sizeof(T);
        inner_store(element, i, val, hess);
    }
}

template<typename T>
__attribute__((noinline))
std::vector<Triple<T>> hessian(const T*__restrict__ pos0, size_t num_faces, const int* faces, const T*__restrict__ x, size_t x_pts)
Expand Down Expand Up @@ -217,13 +182,20 @@ std::vector<Triple<T>> hessian(const T*__restrict__ pos0, size_t num_faces, cons
enzyme_const, pos02,
enzyme_const, num_faces,
enzyme_const, faces,
enzyme_dup, x2, __enzyme_todense<T*>(ident_load<T>, err_store<T>, i),
enzyme_dupnoneed, nullptr, __enzyme_todense<T*>(zero_load<T>, csr_store<T>, i, &hess));
enzyme_dup, x2, __enzyme_todense<T*>(ident_load<T>, ident_store<T>, i),
enzyme_dupnoneed, nullptr, __enzyme_todense<T*>(sparse_load<T>, sparse_store<T>, i, &hess));
return hess;
}

int main() {
const size_t x_pts = 1;
int main(int argc, char** argv) {
size_t x_pts = 8;

if (argc >= 2) {
x_pts = atoi(argv[1]);
}

// TODO generate data for more inputs
assert(x_pts == 8);
const float x[] = {0.0, 1.0, 0.0};


Expand All @@ -233,25 +205,37 @@ int main() {
const float pos0[] = {1.0, 2.0, 3.0, 4.0, 3.0, 2.0, 3.0, 1.0, 3.0};

// Call eigenstuffM_simple
struct timeval start, end;
gettimeofday(&start, NULL);
const float resultM = eigenstuffM(pos0, num_faces, faces, x);
printf("Result for eigenstuffM_simple: %f\n", resultM);
gettimeofday(&end, NULL);
printf("Result for eigenstuffM_simple: %f, runtime:%f\n", resultM, tdiff(&start, &end));

// Call eigenstuffL_simple
gettimeofday(&start, NULL);
const float resultL = eigenstuffL(pos0, num_faces, faces, x);
printf("Result for eigenstuffL_simple: %f\n", resultL);
gettimeofday(&end, NULL);
printf("Result for eigenstuffL_simple: %f, runtime:%f\n", resultL, tdiff(&start, &end));

float dx[sizeof(x)/sizeof(x[0])];
for (size_t i=0; i<sizeof(dx)/sizeof(x[0]); i++)
dx[i] = 0;
gradient_ip(pos0, num_faces, faces, x, dx);

if (x_pts < 30) {
for (size_t i=0; i<sizeof(dx)/sizeof(dx[0]); i++)
printf("eigenstuffM grad_vert[%zu]=%f\n", i, dx[i]);

size_t num_elts = sizeof(x)/sizeof(x[0]) * sizeof(x)/sizeof(x[0]);
}

gettimeofday(&start, NULL);
auto hess_x = hessian(pos0, num_faces, faces, x, x_pts);
gettimeofday(&end, NULL);

printf("Number of elements %ld\n", hess_x.size());

printf("Runtime %0.6f\n", tdiff(&start, &end));

if (x_pts <= 8)
for (auto &hess : hess_x) {
printf("i=%lu, j=%lu, val=%f\n", hess.row, hess.col, hess.val);
}
Expand Down
Loading

0 comments on commit 11cc0f1

Please sign in to comment.