Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(WIP) Vectorised compact mode #248

Draft
wants to merge 18 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion include/heyoka/detail/llvm_helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ HEYOKA_DLL_PUBLIC std::uint64_t get_alignment(llvm::Module &, llvm::Type *);

HEYOKA_DLL_PUBLIC llvm::Value *load_vector_from_memory(ir_builder &, llvm::Value *, std::uint32_t);
HEYOKA_DLL_PUBLIC void store_vector_to_memory(ir_builder &, llvm::Value *, llvm::Value *);
llvm::Value *gather_vector_from_memory(ir_builder &, llvm::Type *, llvm::Value *);

HEYOKA_DLL_PUBLIC llvm::Value *gather_vector_from_memory(ir_builder &, llvm::Type *, llvm::Value *);
HEYOKA_DLL_PUBLIC void scatter_vector_to_memory(ir_builder &, llvm::Value *, llvm::Value *);

HEYOKA_DLL_PUBLIC llvm::Value *vector_splat(ir_builder &, llvm::Value *, std::uint32_t);

Expand Down
30 changes: 2 additions & 28 deletions include/heyoka/expression.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -381,35 +381,9 @@ inline llvm::Value *taylor_diff(llvm_state &s, const expression &ex, const std::
}
}

HEYOKA_DLL_PUBLIC llvm::Function *taylor_c_diff_func_dbl(llvm_state &, const expression &, std::uint32_t, std::uint32_t,
bool);

HEYOKA_DLL_PUBLIC llvm::Function *taylor_c_diff_func_ldbl(llvm_state &, const expression &, std::uint32_t,
std::uint32_t, bool);

#if defined(HEYOKA_HAVE_REAL128)

HEYOKA_DLL_PUBLIC llvm::Function *taylor_c_diff_func_f128(llvm_state &, const expression &, std::uint32_t,
std::uint32_t, bool);

#endif

template <typename T>
inline llvm::Function *taylor_c_diff_func(llvm_state &s, const expression &ex, std::uint32_t n_uvars,
std::uint32_t batch_size, bool high_accuracy)
{
if constexpr (std::is_same_v<T, double>) {
return taylor_c_diff_func_dbl(s, ex, n_uvars, batch_size, high_accuracy);
} else if constexpr (std::is_same_v<T, long double>) {
return taylor_c_diff_func_ldbl(s, ex, n_uvars, batch_size, high_accuracy);
#if defined(HEYOKA_HAVE_REAL128)
} else if constexpr (std::is_same_v<T, mppp::real128>) {
return taylor_c_diff_func_f128(s, ex, n_uvars, batch_size, high_accuracy);
#endif
} else {
static_assert(detail::always_false_v<T>, "Unhandled type.");
}
}
HEYOKA_DLL_PUBLIC llvm::Function *taylor_c_diff_func(llvm_state &, const expression &, std::uint32_t, std::uint32_t,
bool, std::uint32_t);

HEYOKA_DLL_PUBLIC std::uint32_t get_param_size(const expression &);

Expand Down
33 changes: 18 additions & 15 deletions include/heyoka/func.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,13 @@ struct HEYOKA_DLL_PUBLIC func_inner_base {
const std::vector<llvm::Value *> &, llvm::Value *, llvm::Value *,
std::uint32_t, std::uint32_t, std::uint32_t, std::uint32_t, bool) const = 0;
#endif
virtual llvm::Function *taylor_c_diff_func_dbl(llvm_state &, std::uint32_t, std::uint32_t, bool) const = 0;
virtual llvm::Function *taylor_c_diff_func_ldbl(llvm_state &, std::uint32_t, std::uint32_t, bool) const = 0;
virtual llvm::Function *taylor_c_diff_func_dbl(llvm_state &, std::uint32_t, std::uint32_t, bool,
std::uint32_t) const = 0;
virtual llvm::Function *taylor_c_diff_func_ldbl(llvm_state &, std::uint32_t, std::uint32_t, bool,
std::uint32_t) const = 0;
#if defined(HEYOKA_HAVE_REAL128)
virtual llvm::Function *taylor_c_diff_func_f128(llvm_state &, std::uint32_t, std::uint32_t, bool) const = 0;
virtual llvm::Function *taylor_c_diff_func_f128(llvm_state &, std::uint32_t, std::uint32_t, bool,
std::uint32_t) const = 0;
#endif

private:
Expand Down Expand Up @@ -310,7 +313,7 @@ template <typename T>
using func_taylor_c_diff_func_dbl_t
= decltype(std::declval<std::add_lvalue_reference_t<const T>>().taylor_c_diff_func_dbl(
std::declval<llvm_state &>(), std::declval<std::uint32_t>(), std::declval<std::uint32_t>(),
std::declval<bool>()));
std::declval<bool>(), std::declval<std::uint32_t>()));

template <typename T>
inline constexpr bool func_has_taylor_c_diff_func_dbl_v
Expand All @@ -320,7 +323,7 @@ template <typename T>
using func_taylor_c_diff_func_ldbl_t
= decltype(std::declval<std::add_lvalue_reference_t<const T>>().taylor_c_diff_func_ldbl(
std::declval<llvm_state &>(), std::declval<std::uint32_t>(), std::declval<std::uint32_t>(),
std::declval<bool>()));
std::declval<bool>(), std::declval<std::uint32_t>()));

template <typename T>
inline constexpr bool func_has_taylor_c_diff_func_ldbl_v
Expand All @@ -332,7 +335,7 @@ template <typename T>
using func_taylor_c_diff_func_f128_t
= decltype(std::declval<std::add_lvalue_reference_t<const T>>().taylor_c_diff_func_f128(
std::declval<llvm_state &>(), std::declval<std::uint32_t>(), std::declval<std::uint32_t>(),
std::declval<bool>()));
std::declval<bool>(), std::declval<std::uint32_t>()));

template <typename T>
inline constexpr bool func_has_taylor_c_diff_func_f128_v
Expand Down Expand Up @@ -593,31 +596,31 @@ struct HEYOKA_DLL_PUBLIC_INLINE_CLASS func_inner final : func_inner_base {
}
#endif
llvm::Function *taylor_c_diff_func_dbl(llvm_state &s, std::uint32_t n_uvars, std::uint32_t batch_size,
bool high_accuracy) const final
bool high_accuracy, std::uint32_t vector_size) const final
{
if constexpr (func_has_taylor_c_diff_func_dbl_v<T>) {
return m_value.taylor_c_diff_func_dbl(s, n_uvars, batch_size, high_accuracy);
return m_value.taylor_c_diff_func_dbl(s, n_uvars, batch_size, high_accuracy, vector_size);
} else {
throw not_implemented_error("double Taylor diff in compact mode is not implemented for the function '"
+ get_name() + "'");
}
}
llvm::Function *taylor_c_diff_func_ldbl(llvm_state &s, std::uint32_t n_uvars, std::uint32_t batch_size,
bool high_accuracy) const final
bool high_accuracy, std::uint32_t vector_size) const final
{
if constexpr (func_has_taylor_c_diff_func_ldbl_v<T>) {
return m_value.taylor_c_diff_func_ldbl(s, n_uvars, batch_size, high_accuracy);
return m_value.taylor_c_diff_func_ldbl(s, n_uvars, batch_size, high_accuracy, vector_size);
} else {
throw not_implemented_error("long double Taylor diff in compact mode is not implemented for the function '"
+ get_name() + "'");
}
}
#if defined(HEYOKA_HAVE_REAL128)
llvm::Function *taylor_c_diff_func_f128(llvm_state &s, std::uint32_t n_uvars, std::uint32_t batch_size,
bool high_accuracy) const final
bool high_accuracy, std::uint32_t vector_size) const final
{
if constexpr (func_has_taylor_c_diff_func_f128_v<T>) {
return m_value.taylor_c_diff_func_f128(s, n_uvars, batch_size, high_accuracy);
return m_value.taylor_c_diff_func_f128(s, n_uvars, batch_size, high_accuracy, vector_size);
} else {
throw not_implemented_error("float128 Taylor diff in compact mode is not implemented for the function '"
+ get_name() + "'");
Expand Down Expand Up @@ -784,10 +787,10 @@ class HEYOKA_DLL_PUBLIC func
llvm::Value *, llvm::Value *, std::uint32_t, std::uint32_t, std::uint32_t,
std::uint32_t, bool) const;
#endif
llvm::Function *taylor_c_diff_func_dbl(llvm_state &, std::uint32_t, std::uint32_t, bool) const;
llvm::Function *taylor_c_diff_func_ldbl(llvm_state &, std::uint32_t, std::uint32_t, bool) const;
llvm::Function *taylor_c_diff_func_dbl(llvm_state &, std::uint32_t, std::uint32_t, bool, std::uint32_t) const;
llvm::Function *taylor_c_diff_func_ldbl(llvm_state &, std::uint32_t, std::uint32_t, bool, std::uint32_t) const;
#if defined(HEYOKA_HAVE_REAL128)
llvm::Function *taylor_c_diff_func_f128(llvm_state &, std::uint32_t, std::uint32_t, bool) const;
llvm::Function *taylor_c_diff_func_f128(llvm_state &, std::uint32_t, std::uint32_t, bool, std::uint32_t) const;
#endif
};

Expand Down
6 changes: 3 additions & 3 deletions include/heyoka/math/binary_op.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,10 @@ class HEYOKA_DLL_PUBLIC binary_op : public func_base
std::uint32_t, bool) const;
#endif

llvm::Function *taylor_c_diff_func_dbl(llvm_state &, std::uint32_t, std::uint32_t, bool) const;
llvm::Function *taylor_c_diff_func_ldbl(llvm_state &, std::uint32_t, std::uint32_t, bool) const;
llvm::Function *taylor_c_diff_func_dbl(llvm_state &, std::uint32_t, std::uint32_t, bool, std::uint32_t) const;
llvm::Function *taylor_c_diff_func_ldbl(llvm_state &, std::uint32_t, std::uint32_t, bool, std::uint32_t) const;
#if defined(HEYOKA_HAVE_REAL128)
llvm::Function *taylor_c_diff_func_f128(llvm_state &, std::uint32_t, std::uint32_t, bool) const;
llvm::Function *taylor_c_diff_func_f128(llvm_state &, std::uint32_t, std::uint32_t, bool, std::uint32_t) const;
#endif
};

Expand Down
6 changes: 3 additions & 3 deletions include/heyoka/math/pow.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,10 @@ class HEYOKA_DLL_PUBLIC pow_impl : public func_base
std::uint32_t, bool) const;
#endif

llvm::Function *taylor_c_diff_func_dbl(llvm_state &, std::uint32_t, std::uint32_t, bool) const;
llvm::Function *taylor_c_diff_func_ldbl(llvm_state &, std::uint32_t, std::uint32_t, bool) const;
llvm::Function *taylor_c_diff_func_dbl(llvm_state &, std::uint32_t, std::uint32_t, bool, std::uint32_t) const;
llvm::Function *taylor_c_diff_func_ldbl(llvm_state &, std::uint32_t, std::uint32_t, bool, std::uint32_t) const;
#if defined(HEYOKA_HAVE_REAL128)
llvm::Function *taylor_c_diff_func_f128(llvm_state &, std::uint32_t, std::uint32_t, bool) const;
llvm::Function *taylor_c_diff_func_f128(llvm_state &, std::uint32_t, std::uint32_t, bool, std::uint32_t) const;
#endif
};

Expand Down
6 changes: 3 additions & 3 deletions include/heyoka/math/sum.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ class HEYOKA_DLL_PUBLIC sum_impl : public func_base
llvm::Value *, llvm::Value *, std::uint32_t, std::uint32_t, std::uint32_t,
std::uint32_t, bool) const;
#endif
llvm::Function *taylor_c_diff_func_dbl(llvm_state &, std::uint32_t, std::uint32_t, bool) const;
llvm::Function *taylor_c_diff_func_ldbl(llvm_state &, std::uint32_t, std::uint32_t, bool) const;
llvm::Function *taylor_c_diff_func_dbl(llvm_state &, std::uint32_t, std::uint32_t, bool, std::uint32_t) const;
llvm::Function *taylor_c_diff_func_ldbl(llvm_state &, std::uint32_t, std::uint32_t, bool, std::uint32_t) const;
#if defined(HEYOKA_HAVE_REAL128)
llvm::Function *taylor_c_diff_func_f128(llvm_state &, std::uint32_t, std::uint32_t, bool) const;
llvm::Function *taylor_c_diff_func_f128(llvm_state &, std::uint32_t, std::uint32_t, bool, std::uint32_t) const;
#endif
};

Expand Down
6 changes: 3 additions & 3 deletions include/heyoka/math/sum_sq.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,10 @@ class HEYOKA_DLL_PUBLIC sum_sq_impl : public func_base
std::uint32_t, bool) const;
#endif

llvm::Function *taylor_c_diff_func_dbl(llvm_state &, std::uint32_t, std::uint32_t, bool) const;
llvm::Function *taylor_c_diff_func_ldbl(llvm_state &, std::uint32_t, std::uint32_t, bool) const;
llvm::Function *taylor_c_diff_func_dbl(llvm_state &, std::uint32_t, std::uint32_t, bool, std::uint32_t) const;
llvm::Function *taylor_c_diff_func_ldbl(llvm_state &, std::uint32_t, std::uint32_t, bool, std::uint32_t) const;
#if defined(HEYOKA_HAVE_REAL128)
llvm::Function *taylor_c_diff_func_f128(llvm_state &, std::uint32_t, std::uint32_t, bool) const;
llvm::Function *taylor_c_diff_func_f128(llvm_state &, std::uint32_t, std::uint32_t, bool, std::uint32_t) const;
#endif
};

Expand Down
11 changes: 11 additions & 0 deletions include/heyoka/taylor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,15 @@ llvm::Value *taylor_codegen_numparam(llvm_state &s, const U &n, llvm::Value *par
}
}

// TODO remove old overloads.
HEYOKA_DLL_PUBLIC llvm::Value *taylor_c_diff_numparam_codegen(llvm_state &, const number &, llvm::Value *,
llvm::Value *, std::uint32_t);
HEYOKA_DLL_PUBLIC llvm::Value *taylor_c_diff_numparam_codegen(llvm_state &, const number &, llvm::Value *,
llvm::Value *, std::uint32_t, std::uint32_t);
HEYOKA_DLL_PUBLIC llvm::Value *taylor_c_diff_numparam_codegen(llvm_state &, const param &, llvm::Value *, llvm::Value *,
std::uint32_t);
HEYOKA_DLL_PUBLIC llvm::Value *taylor_c_diff_numparam_codegen(llvm_state &, const param &, llvm::Value *, llvm::Value *,
std::uint32_t, std::uint32_t);

HEYOKA_DLL_PUBLIC llvm::Value *taylor_fetch_diff(const std::vector<llvm::Value *> &, std::uint32_t, std::uint32_t,
std::uint32_t);
Expand Down Expand Up @@ -172,6 +177,12 @@ taylor_c_diff_func_name_args(llvm::LLVMContext &c, const std::string &name, std:
return taylor_c_diff_func_name_args_impl(c, name, val_t, n_uvars, args, n_hidden_deps);
}

// TODO remove the other version, then rename?
template <typename T>
HEYOKA_DLL_PUBLIC std::pair<std::string, std::vector<llvm::Type *>>
taylor_c_diff_vfunc_name_args(llvm::LLVMContext &, const std::string &, std::uint32_t, std::uint32_t, std::uint32_t,
const std::vector<std::variant<variable, number, param>> &, std::uint32_t = 0);

// Add a function for computing the dense output
// via polynomial evaluation.
template <typename T>
Expand Down
51 changes: 47 additions & 4 deletions src/detail/llvm_helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,12 +276,16 @@ void store_vector_to_memory(ir_builder &builder, llvm::Value *ptr, llvm::Value *
}
}

// Gather a vector of type vec_tp from the vector of pointers ptrs.
// Gather a vector of type vec_tp from ptrs. If vec_tp is a vector type, then ptrs
// must be a vector of pointers of the same size and the returned value is also a vector
// of that size. Otherwise, ptrs must be a single scalar pointer and the returned value is a scalar.
llvm::Value *gather_vector_from_memory(ir_builder &builder, llvm::Type *vec_tp, llvm::Value *ptrs)
{
if (llvm::isa<llvm_vector_type>(vec_tp)) {
// LCOV_EXCL_START
assert(llvm::isa<llvm_vector_type>(ptrs->getType()));
assert(llvm::cast<llvm_vector_type>(vec_tp)->getNumElements()
== llvm::cast<llvm_vector_type>(ptrs->getType())->getNumElements());
assert(ptrs->getType()->getScalarType()->getPointerElementType() == vec_tp->getScalarType());
// LCOV_EXCL_STOP

Expand Down Expand Up @@ -311,11 +315,48 @@ llvm::Value *gather_vector_from_memory(ir_builder &builder, llvm::Type *vec_tp,
}
}

// Scatter val to ptrs. If val is a vector, then ptrs must be a vector of pointers
// and a vector scatter takes place. Otherwise, ptrs must be a single scalar pointer
// and a scalar store takes place.
void scatter_vector_to_memory(ir_builder &builder, llvm::Value *val, llvm::Value *ptrs)
{
if (llvm::isa<llvm_vector_type>(ptrs->getType())) {
// LCOV_EXCL_START
assert(llvm::isa<llvm_vector_type>(val->getType()));
assert(llvm::cast<llvm_vector_type>(val->getType())->getNumElements()
== llvm::cast<llvm_vector_type>(ptrs->getType())->getNumElements());
assert(val->getType()->getScalarType() == ptrs->getType()->getScalarType()->getPointerElementType());
// LCOV_EXCL_STOP

// Fetch the alignment of the scalar type.
const auto align = get_alignment(*builder.GetInsertBlock()->getModule(), val->getType()->getScalarType());

builder.CreateMaskedScatter(val, ptrs,
#if LLVM_VERSION_MAJOR == 10
boost::numeric_cast<unsigned>(align)
#else
llvm::Align(align)
#endif
);
} else {
// LCOV_EXCL_START
assert(!llvm::isa<llvm_vector_type>(val->getType()));
assert(ptrs->getType()->getPointerElementType() == val->getType());
// LCOV_EXCL_STOP

// Not a vector, store val directly.
builder.CreateStore(val, ptrs);
}
}

// Create a SIMD vector of size vector_size filled with the value c. If vector_size is 1,
// c will be returned.
llvm::Value *vector_splat(ir_builder &builder, llvm::Value *c, std::uint32_t vector_size)
{
// LCOV_EXCL_START
assert(vector_size > 0u);
assert(!llvm::isa<llvm_vector_type>(c->getType()));
// LCOV_EXCL_STOP

if (vector_size == 1u) {
return c;
Expand All @@ -326,15 +367,18 @@ llvm::Value *vector_splat(ir_builder &builder, llvm::Value *c, std::uint32_t vec

llvm::Type *make_vector_type(llvm::Type *t, std::uint32_t vector_size)
{
// LCOV_EXCL_START
assert(t != nullptr);
assert(vector_size > 0u);
assert(!llvm::isa<llvm_vector_type>(t));
// LCOV_EXCL_STOP

if (vector_size == 1u) {
return t;
} else {
auto retval = llvm_vector_type::get(t, boost::numeric_cast<unsigned>(vector_size));

assert(retval != nullptr);
assert(retval != nullptr); // LCOV_EXCL_LINE

return retval;
}
Expand Down Expand Up @@ -1380,8 +1424,7 @@ llvm::Function *llvm_add_csc_impl(llvm_state &s, llvm::Type *scal_t, std::uint32
vector_splat(builder, builder.getInt32(batch_size), batch_size)));
assert(llvm_depr_GEP_type_check(cf_ptr_v, scal_t)); // LCOV_EXCL_LINE
auto last_nz_ptr = builder.CreateInBoundsGEP(scal_t, cf_ptr_v, last_nz_ptr_idx);
auto last_nz_cf = batch_size > 1u ? gather_vector_from_memory(builder, cur_cf->getType(), last_nz_ptr)
: static_cast<llvm::Value *>(builder.CreateLoad(scal_t, last_nz_ptr));
auto last_nz_cf = gather_vector_from_memory(builder, cur_cf->getType(), last_nz_ptr);

// Compute the sign of the current coefficient(s).
auto cur_sgn = llvm_sgn(s, cur_cf);
Expand Down
Loading