Skip to content

Commit

Permalink
[SYCL] Add support for half type
Browse files Browse the repository at this point in the history
Signed-off-by: Mariya Podchishchaeva <[email protected]>
  • Loading branch information
Fznamznon authored and romanovvlad committed Apr 16, 2019
1 parent e87838c commit 83403c9
Show file tree
Hide file tree
Showing 10 changed files with 718 additions and 83 deletions.
1 change: 1 addition & 0 deletions sycl/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ add_library("${SYCLLibrary}" SHARED
"${sourceRootPath}/device_selector.cpp"
"${sourceRootPath}/event.cpp"
"${sourceRootPath}/exception.cpp"
"${sourceRootPath}/half_type.cpp"
"${sourceRootPath}/kernel.cpp"
"${sourceRootPath}/platform.cpp"
"${sourceRootPath}/queue.cpp"
Expand Down
84 changes: 84 additions & 0 deletions sycl/include/CL/sycl/half_type.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
//==-------------- half_type.hpp --- SYCL half type ------------------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#pragma once

#include <cstdint>
#include <functional>

namespace cl {
namespace sycl {
namespace detail {
namespace half_impl {

class half {
public:
half() = default;
half(const half &) = default;
half(half &&) = default;

half(const float &rhs);

half &operator=(const half &rhs) = default;

// Operator +=, -=, *=, /=
half &operator+=(const half &rhs);

half &operator-=(const half &rhs);

half &operator*=(const half &rhs);

half &operator/=(const half &rhs);

// Operator ++, --
half &operator++() {
*this += 1;
return *this;
}

half operator++(int) {
half ret(*this);
operator++();
return ret;
}

half &operator--() {
*this -= 1;
return *this;
}

half operator--(int) {
half ret(*this);
operator--();
return ret;
}

// Operator float
operator float() const;

template <typename Key> friend struct std::hash;

private:
uint16_t Buf;
};
} // namespace half_impl
} // namespace detail

} // namespace sycl
} // namespace cl

namespace std {

template <> struct hash<cl::sycl::detail::half_impl::half> {
size_t operator()(cl::sycl::detail::half_impl::half const &key) const
noexcept {
return hash<uint16_t>()(key.Buf);
}
};

} // namespace std
21 changes: 14 additions & 7 deletions sycl/include/CL/sycl/intel/sub_group.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,11 +139,18 @@ struct sub_group {
return BinaryOperation::template calc<T, cl::__spirv::InclusiveScan>(x);
}

template <typename T>
using EnableIfIsArithmeticOrHalf = typename std::enable_if<
(std::is_arithmetic<T>::value ||
std::is_same<typename std::remove_const<T>::type, half>::value),
T>::type;


/* --- one - input shuffles --- */
/* indices in [0 , sub - group size ) */

template <typename T>
typename std::enable_if<std::is_arithmetic<T>::value, T>::type
EnableIfIsArithmeticOrHalf<T>
shuffle(T x, id<1> local_id) {
return cl::__spirv::OpSubgroupShuffleINTEL(x, local_id.get(0));
}
Expand All @@ -156,7 +163,7 @@ struct sub_group {
}

template <typename T>
typename std::enable_if<std::is_arithmetic<T>::value, T>::type
EnableIfIsArithmeticOrHalf<T>
shuffle_down(T x, uint32_t delta) {
return shuffle_down(x, x, delta);
}
Expand All @@ -168,7 +175,7 @@ struct sub_group {
}

template <typename T>
typename std::enable_if<std::is_arithmetic<T>::value, T>::type
EnableIfIsArithmeticOrHalf<T>
shuffle_up(T x, uint32_t delta) {
return shuffle_up(x, x, delta);
}
Expand All @@ -180,7 +187,7 @@ struct sub_group {
}

template <typename T>
typename std::enable_if<std::is_arithmetic<T>::value, T>::type
EnableIfIsArithmeticOrHalf<T>
shuffle_xor(T x, id<1> value) {
return cl::__spirv::OpSubgroupShuffleXorINTEL(x, (uint32_t)value.get(0));
}
Expand All @@ -195,7 +202,7 @@ struct sub_group {
/* --- two - input shuffles --- */
/* indices in [0 , 2* sub - group size ) */
template <typename T>
typename std::enable_if<std::is_arithmetic<T>::value, T>::type
EnableIfIsArithmeticOrHalf<T>
shuffle(T x, T y, id<1> local_id) {
return cl::__spirv::OpSubgroupShuffleDownINTEL(
x, y, local_id.get(0) - get_local_id().get(0));
Expand All @@ -210,7 +217,7 @@ struct sub_group {
}

template <typename T>
typename std::enable_if<std::is_arithmetic<T>::value, T>::type
EnableIfIsArithmeticOrHalf<T>
shuffle_down(T current, T next, uint32_t delta) {
return cl::__spirv::OpSubgroupShuffleDownINTEL(current, next, delta);
}
Expand All @@ -223,7 +230,7 @@ struct sub_group {
}

template <typename T>
typename std::enable_if<std::is_arithmetic<T>::value, T>::type
EnableIfIsArithmeticOrHalf<T>
shuffle_up(T previous, T current, uint32_t delta) {
return cl::__spirv::OpSubgroupShuffleUpINTEL(previous, current, delta);
}
Expand Down
Loading

0 comments on commit 83403c9

Please sign in to comment.