Skip to content

Commit

Permalink
Implement sycl::vec (without multi-component swizzles)
Browse files Browse the repository at this point in the history
  • Loading branch information
fknorr committed Dec 13, 2023
1 parent af78034 commit cf9cedd
Show file tree
Hide file tree
Showing 7 changed files with 523 additions and 8 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ add_library(simsycl
include/simsycl/sycl/reduction.hh
include/simsycl/sycl/sub_group.hh
include/simsycl/sycl/type_traits.hh
include/simsycl/sycl/usm.hh
include/simsycl/sycl/vec.hh
"${CMAKE_CURRENT_BINARY_DIR}/include/simsycl/config.hh"
src/simsycl/dummy.cc
)
Expand Down
26 changes: 25 additions & 1 deletion include/simsycl/detail/utils.hh
Original file line number Diff line number Diff line change
@@ -1,10 +1,34 @@
#pragma once

#include <utility>

namespace simsycl::detail {

template<typename T, typename T2>
auto div_ceil(T a, T2 b) {
constexpr auto div_ceil(T a, T2 b) {
return (a + b - 1) / b;
}

template<typename T>
constexpr T &&max(T &&x) {
return std::forward<T>(x);
}

template<typename T, typename... Ts>
constexpr decltype(auto) max(T &&x, Ts &&...ts) {
decltype(auto) rhs = max(std::forward<Ts>(ts)...);
return x < rhs ? std::forward<decltype(rhs)>(rhs) : std::forward<T>(x);
}

template<typename T>
constexpr T &&min(T &&x) {
return std::forward<T>(x);
}

template<typename T, typename... Ts>
constexpr decltype(auto) min(T &&x, Ts &&...ts) {
decltype(auto) rhs = min(std::forward<Ts>(ts)...);
return x < rhs ? std::forward<T>(x) : std::forward<decltype(rhs)>(rhs);
}

} // namespace simsycl::detail
1 change: 1 addition & 0 deletions include/simsycl/sycl.hh
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,4 @@
#include "sycl/sub_group.hh"
#include "sycl/type_traits.hh"
#include "sycl/usm.hh"
#include "sycl/vec.hh"
20 changes: 15 additions & 5 deletions include/simsycl/sycl/atomic_ref.hh
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

#include "enums.hh"

#include <algorithm>
#include "../detail/utils.hh"

#include <cstdlib>
#include <utility>

Expand Down Expand Up @@ -155,6 +156,9 @@ class atomic_ref<Integral, DefaultOrder, DefaultScope, AddressSpace>
using base::default_read_modify_write_order;
using base::default_scope;

using base::base;
using base::operator=;

Integral fetch_add(Integral operand, memory_order order = default_read_modify_write_order,
memory_scope scope = default_scope) const noexcept {
(void)order;
Expand Down Expand Up @@ -205,7 +209,7 @@ class atomic_ref<Integral, DefaultOrder, DefaultScope, AddressSpace>
(void)order;
(void)scope;
const auto original = m_ref;
m_ref = std::min(m_ref, operand);
m_ref = detail::min(m_ref, operand);
return original;
}

Expand All @@ -214,7 +218,7 @@ class atomic_ref<Integral, DefaultOrder, DefaultScope, AddressSpace>
(void)order;
(void)scope;
const auto original = m_ref;
m_ref = std::max(m_ref, operand);
m_ref = detail::max(m_ref, operand);
return original;
}

Expand Down Expand Up @@ -247,6 +251,9 @@ class atomic_ref<Floating, DefaultOrder, DefaultScope, AddressSpace>
using base::default_read_modify_write_order;
using base::default_scope;

using base::base;
using base::operator=;

Floating fetch_add(Floating operand, memory_order order = default_read_modify_write_order,
memory_scope scope = default_scope) const noexcept {
(void)order;
Expand All @@ -270,7 +277,7 @@ class atomic_ref<Floating, DefaultOrder, DefaultScope, AddressSpace>
(void)order;
(void)scope;
const auto original = m_ref;
m_ref = std::min(m_ref, operand);
m_ref = detail::min(m_ref, operand);
return original;
}

Expand All @@ -279,7 +286,7 @@ class atomic_ref<Floating, DefaultOrder, DefaultScope, AddressSpace>
(void)order;
(void)scope;
const auto original = m_ref;
m_ref = std::max(m_ref, operand);
m_ref = detail::max(m_ref, operand);
return original;
}

Expand All @@ -304,6 +311,9 @@ class atomic_ref<T *, DefaultOrder, DefaultScope, AddressSpace>
using base::default_read_modify_write_order;
using base::default_scope;

using base::base;
using base::operator=;

T *fetch_add(difference_type operand, memory_order order = default_read_modify_write_order,
memory_scope scope = default_scope) const noexcept {
(void)order;
Expand Down
3 changes: 3 additions & 0 deletions include/simsycl/sycl/forward.hh
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ class unsampled_image_accessor;
template<int Dimensions = 1, typename AllocatorT = image_allocator>
class unsampled_image;

template <typename DataT, int NumElements>
class vec;

} // namespace simsycl::sycl

namespace simsycl::detail {
Expand Down
9 changes: 7 additions & 2 deletions include/simsycl/sycl/type_traits.hh
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,17 @@ struct is_function_object : std::false_type {};
template<class Fn>
inline constexpr bool is_function_object_v = is_function_object<Fn>::value;

// TODO: Add sycl::half when we have that type
template<typename T>
struct is_arithmetic : std::bool_constant<std::is_arithmetic_v<T>> {};
struct is_arithmetic : std::bool_constant<std::is_arithmetic_v<T> || std::is_same_v<T, sycl::half>> {};
template<class T>
inline constexpr bool is_arithmetic_v = is_arithmetic<T>::value;

template<typename T>
struct is_floating_point
: std::bool_constant<std::is_same_v<T, sycl::half> || std::is_same_v<T, float> || std::is_same_v<T, double>> {};
template<class T>
inline constexpr bool is_floating_point_v = is_floating_point<T>::value;

template<typename...>
constexpr bool always_false = false;

Expand Down
Loading

0 comments on commit cf9cedd

Please sign in to comment.