Skip to content

Commit

Permalink
Initial support for expression templates in array and array_ref class
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak authored and vgvassilev committed Sep 26, 2023
1 parent 1bdc579 commit b54fcd1
Show file tree
Hide file tree
Showing 3 changed files with 288 additions and 106 deletions.
150 changes: 96 additions & 54 deletions include/clad/Differentiator/Array.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef CLAD_ARRAY_H
#define CLAD_ARRAY_H

#include "clad/Differentiator/ArrayExpression.h"
#include "clad/Differentiator/CladConfig.h"

#include <assert.h>
Expand Down Expand Up @@ -36,23 +37,31 @@ template <typename T> class array {

template <typename U>
CUDA_HOST_DEVICE array(U* a, std::size_t size)
: m_arr(new T[size]{static_cast<T>(T())}), m_size(size) {
: m_arr(new T[size]), m_size(size) {
for (std::size_t i = 0; i < size; ++i)
m_arr[i] = static_cast<T>(a[i]);
}

CUDA_HOST_DEVICE array(const array<T>& arr) : array(arr.m_arr, arr.m_size) {}

CUDA_HOST_DEVICE array(std::size_t size, const clad::array<T>& arr)
: m_arr(new T[size]{static_cast<T>(T())}), m_size(size) {
: m_arr(new T[size]), m_size(size) {
for (std::size_t i = 0; i < size; ++i)
m_arr[i] = arr[i];
}

template <typename L, typename BinaryOp, typename R>
CUDA_HOST_DEVICE array(std::size_t size,
const array_expression<L, BinaryOp, R>& expression)
: m_arr(new T[size]), m_size(size) {
for (std::size_t i = 0; i < size; ++i)
m_arr[i] = expression[i];
}

// initializing all entries using the same value
template <typename U>
CUDA_HOST_DEVICE array(std::size_t size, U val)
: m_arr(new T[size]{static_cast<T>(T())}), m_size(size) {
: m_arr(new T[size]), m_size(size) {
for (std::size_t i = 0; i < size; ++i)
m_arr[i] = static_cast<T>(val);
}
Expand Down Expand Up @@ -229,6 +238,15 @@ template <typename T> class array {
m_arr[i] *= static_cast<T>(arr[i]);
return *this;
}
/// Initializes the clad::array from the given clad::array_expression
template <typename L, typename BinaryOp, typename R>
CUDA_HOST_DEVICE array<T>&
operator=(const array_expression<L, BinaryOp, R>& arr_exp) {
assert(arr_exp.size() == m_size);
for (std::size_t i = 0; i < m_size; i++)
m_arr[i] = arr_exp[i];
return *this;
}
/// Performs element wise division
template <typename U>
CUDA_HOST_DEVICE array<T>& operator/=(const array<U>& arr) {
Expand All @@ -237,25 +255,55 @@ template <typename T> class array {
m_arr[i] /= static_cast<T>(arr[i]);
return *this;
}
/// Performs element wise addition with array_expression
template <typename L, typename BinaryOp, typename R>
CUDA_HOST_DEVICE array<T>&
operator+=(const array_expression<L, BinaryOp, R>& arr_exp) {
assert(arr_exp.size() == m_size);
for (std::size_t i = 0; i < m_size; i++)
m_arr[i] += arr_exp[i];
return *this;
}
/// Performs element wise subtraction with array_expression
template <typename L, typename BinaryOp, typename R>
CUDA_HOST_DEVICE array<T>&
operator-=(const array_expression<L, BinaryOp, R>& arr_exp) {
assert(arr_exp.size() == m_size);
for (std::size_t i = 0; i < m_size; i++)
m_arr[i] -= arr_exp[i];
return *this;
}
/// Performs element wise multiplication with array_expression
template <typename L, typename BinaryOp, typename R>
CUDA_HOST_DEVICE array<T>&
operator*=(const array_expression<L, BinaryOp, R>& arr_exp) {
assert(arr_exp.size() == m_size);
for (std::size_t i = 0; i < m_size; i++)
m_arr[i] *= arr_exp[i];
return *this;
}
/// Performs element wise division with array_expression
template <typename L, typename BinaryOp, typename R>
CUDA_HOST_DEVICE array<T>&
operator/=(const array_expression<L, BinaryOp, R>& arr_exp) {
assert(arr_exp.size() == m_size);
for (std::size_t i = 0; i < m_size; i++)
m_arr[i] /= arr_exp[i];
return *this;
}

/// Negate the array and return a new array.
CUDA_HOST_DEVICE array<T> operator-() const {
array<T> arr2(m_size);
for (std::size_t i = 0; i < m_size; i++)
arr2[i] = -m_arr[i];
return arr2;
CUDA_HOST_DEVICE array_expression<T, BinarySub, array<T>> operator-() const {
return array_expression<T, BinarySub, array<T>>(static_cast<T>(0), *this);
}

/// Subtracts the number from every element in the array and returns a new
/// array, when the number is on the left side.
template <typename U, typename std::enable_if<std::is_arithmetic<U>::value,
int>::type = 0>
CUDA_HOST_DEVICE friend array<T> operator-(U n, const array<T>& arr) {
size_t size = arr.size();
array<T> arr2(size);
for (std::size_t i = 0; i < size; i++)
arr2[i] = n - arr[i];
return arr2;
CUDA_HOST_DEVICE friend array_expression<U, BinarySub, array<T>>
operator-(U n, const array<T>& arr) {
return array_expression<U, BinarySub, array<T>>(n, arr);
}

/// Implicitly converts from clad::array to pointer to an array of type T
Expand All @@ -281,79 +329,73 @@ template <typename T> CUDA_HOST_DEVICE array<T> zero_vector(std::size_t n) {

/// Overloaded operators for clad::array which return a new array.

/// Multiplies the number to every element in the array and returns a new
/// array.
/// Multiplies the number to every element in the array and returns an array
/// expression.
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
CUDA_HOST_DEVICE array<T> operator*(const array<T>& arr, U n) {
array<T> arr2(arr);
arr2 *= n;
return arr2;
CUDA_HOST_DEVICE array_expression<array<T>, BinaryMul, U>
operator*(const array<T>& arr, U n) {
return array_expression<array<T>, BinaryMul, U>(arr, n);
}

/// Multiplies the number to every element in the array and returns a new
/// array, when the number is on the left side.
/// Multiplies the number to every element in the array and returns an array
/// expression, when the number is on the left side.
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
CUDA_HOST_DEVICE array<T> operator*(U n, const array<T>& arr) {
return arr * n;
CUDA_HOST_DEVICE array_expression<array<T>, BinaryMul, U>
operator*(U n, const array<T>& arr) {
return array_expression<array<T>, BinaryMul, U>(arr, n);
}

/// Divides the number from every element in the array and returns a new
/// array
/// Divides the number from every element in the array and returns an array
/// expression.
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
CUDA_HOST_DEVICE array<T> operator/(const array<T>& arr, U n) {
array<T> arr2(arr);
arr2 /= n;
return arr2;
CUDA_HOST_DEVICE array_expression<array<T>, BinaryDiv, U>
operator/(const array<T>& arr, U n) {
return array_expression<array<T>, BinaryDiv, U>(arr, n);
}

/// Adds the number to every element in the array and returns a new array
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
CUDA_HOST_DEVICE array<T> operator+(const array<T>& arr, U n) {
array<T> arr2(arr);
arr2 += n;
return arr2;
CUDA_HOST_DEVICE array_expression<array<T>, BinaryAdd, U>
operator+(const array<T>& arr, U n) {
return array_expression<array<T>, BinaryAdd, U>(arr, n);
}

/// Adds the number to every element in the array and returns a new array,
/// when the number is on the left side.
/// Adds the number to every element in the array and returns an array
/// expression, when the number is on the left side.
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
CUDA_HOST_DEVICE array<T> operator+(U n, const array<T>& arr) {
return arr + n;
CUDA_HOST_DEVICE array_expression<array<T>, BinaryAdd, U>
operator+(U n, const array<T>& arr) {
return array_expression<array<T>, BinaryAdd, U>(arr, n);
}

/// Subtracts the number from every element in the array and returns a new
/// array
/// Subtracts the number from every element in the array and returns an array
/// expression.
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
CUDA_HOST_DEVICE array<T> operator-(const array<T>& arr, U n) {
array<T> arr2(arr);
arr2 -= n;
return arr2;
CUDA_HOST_DEVICE array_expression<array<T>, BinarySub, U>
operator-(const array<T>& arr, U n) {
return array_expression<array<T>, BinarySub, U>(arr, n);
}

/// Function to define element wise adding of two arrays.
template <typename T, typename U>
CUDA_HOST_DEVICE array<T> operator+(const array<T>& arr1,
const array<U>& arr2) {
CUDA_HOST_DEVICE array_expression<array<T>, BinaryAdd, array<U>>
operator+(const array<T>& arr1, const array<U>& arr2) {
assert(arr1.size() == arr2.size());
array<T> arr(arr1);
arr += arr2;
return arr;
return array_expression<array<T>, BinaryAdd, array<U>>(arr1, arr2);
}

/// Function to define element wise subtraction of two arrays.
template <typename T, typename U>
CUDA_HOST_DEVICE array<T> operator-(const array<T>& arr1,
const array<U>& arr2) {
CUDA_HOST_DEVICE array_expression<array<T>, BinarySub, array<U>>
operator-(const array<T>& arr1, const array<U>& arr2) {
assert(arr1.size() == arr2.size());
array<T> arr(arr1);
arr -= arr2;
return arr;
return array_expression<array<T>, BinarySub, array<U>>(arr1, arr2);
}

} // namespace clad
Expand Down
150 changes: 150 additions & 0 deletions include/clad/Differentiator/ArrayExpression.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
#ifndef CLAD_DIFFERENTIATOR_ARRAYEXPRESSION_H
#define CLAD_DIFFERENTIATOR_ARRAYEXPRESSION_H

#include <algorithm>
#include <type_traits>

// This is a helper class to implement expression templates for clad::array.

// NOLINTBEGIN(*-pointer-arithmetic)
namespace clad {

// Operator to add two elements.
struct BinaryAdd {
template <typename T, typename U>
static auto apply(T const& t, U const& u) -> decltype(t + u) {
return t + u;
}
};

// Operator to add two elements.
struct BinaryMul {
template <typename T, typename U>
static auto apply(T const& t, U const& u) -> decltype(t * u) {
return t * u;
}
};

// Operator to divide two elements.
struct BinaryDiv {
template <typename T, typename U>
static auto apply(T const& t, U const& u) -> decltype(t / u) {
return t / u;
}
};

// Operator to subtract two elements.
struct BinarySub {
template <typename T, typename U>
static auto apply(T const& t, U const& u) -> decltype(t - u) {
return t - u;
}
};

// Class to represent an array expression using templates.
template <typename LeftExp, typename BinaryOp, typename RightExp>
class array_expression {
LeftExp l;
RightExp r;

public:
array_expression(LeftExp const& l, RightExp const& r) : l(l), r(r) {}

// for scalars
template <typename T, typename std::enable_if<std::is_arithmetic<T>::value,
int>::type = 0>
std::size_t get_size(T const& t) const {
return 1;
}
template <typename T, typename std::enable_if<std::is_arithmetic<T>::value,
int>::type = 0>
T get(T const& t, std::size_t i) const {
return t;
}

// for vectors
template <typename T, typename std::enable_if<!std::is_arithmetic<T>::value,
int>::type = 0>
std::size_t get_size(T const& t) const {
return t.size();
}
template <typename T, typename std::enable_if<!std::is_arithmetic<T>::value,
int>::type = 0>
auto get(T const& t, std::size_t i) const -> decltype(t[i]) {
return t[i];
}

// We also need to handle the case when any of the operands is a scalar.
auto operator[](std::size_t i) const
-> decltype(BinaryOp::apply(get(l, i), get(r, i))) {
return BinaryOp::apply(get(l, i), get(r, i));
}

std::size_t size() const { return std::max(get_size(l), get_size(r)); }

// Operator overload for addition.
template <typename RE>
array_expression<array_expression<LeftExp, BinaryOp, RightExp>, BinaryAdd, RE>
operator+(RE const& r) const {
return array_expression<array_expression<LeftExp, BinaryOp, RightExp>,
BinaryAdd, RE>(*this, r);
}

// Operator overload for multiplication.
template <typename RE>
array_expression<array_expression<LeftExp, BinaryOp, RightExp>, BinaryMul, RE>
operator*(RE const& r) const {
return array_expression<array_expression<LeftExp, BinaryOp, RightExp>,
BinaryMul, RE>(*this, r);
}

// Operator overload for subtraction.
template <typename RE>
array_expression<array_expression<LeftExp, BinaryOp, RightExp>, BinarySub, RE>
operator-(RE const& r) const {
return array_expression<array_expression<LeftExp, BinaryOp, RightExp>,
BinarySub, RE>(*this, r);
}

// Operator overload for division.
template <typename RE>
array_expression<array_expression<LeftExp, BinaryOp, RightExp>, BinaryDiv, RE>
operator/(RE const& r) const {
return array_expression<array_expression<LeftExp, BinaryOp, RightExp>,
BinaryDiv, RE>(*this, r);
}
};

// Operator overload for addition, when the right operand is an array_expression
// and the left operand is a scalar.
template <typename T, typename LeftExp, typename BinaryOp, typename RightExp,
typename std::enable_if<std::is_arithmetic<T>::value, int>::type = 0>
array_expression<T, BinaryAdd, array_expression<LeftExp, BinaryOp, RightExp>>
operator+(T const& l, array_expression<LeftExp, BinaryOp, RightExp> const& r) {
return array_expression<T, BinaryAdd,
array_expression<LeftExp, BinaryOp, RightExp>>(l, r);
}

// Operator overload for multiplication, when the right operand is an
// array_expression and the left operand is a scalar.
template <typename T, typename LeftExp, typename BinaryOp, typename RightExp,
typename std::enable_if<std::is_arithmetic<T>::value, int>::type = 0>
array_expression<T, BinaryMul, array_expression<LeftExp, BinaryOp, RightExp>>
operator*(T const& l, array_expression<LeftExp, BinaryOp, RightExp> const& r) {
return array_expression<T, BinaryMul,
array_expression<LeftExp, BinaryOp, RightExp>>(l, r);
}

// Operator overload for subtraction, when the right operand is an
// array_expression and the left operand is a scalar.
template <typename T, typename LeftExp, typename BinaryOp, typename RightExp,
typename std::enable_if<std::is_arithmetic<T>::value, int>::type = 0>
array_expression<T, BinarySub, array_expression<LeftExp, BinaryOp, RightExp>>
operator-(T const& l, array_expression<LeftExp, BinaryOp, RightExp> const& r) {
return array_expression<T, BinarySub,
array_expression<LeftExp, BinaryOp, RightExp>>(l, r);
}
} // namespace clad
// NOLINTEND(*-pointer-arithmetic)

#endif // CLAD_DIFFERENTIATOR_ARRAYEXPRESSION_H
Loading

0 comments on commit b54fcd1

Please sign in to comment.