Commit 2a992c5 (parent 40cff1f)

e.g. where, clamp, clamp_min, clamp_max

Co-authored-by: Feng Yuan <[email protected]>

Showing 5 changed files with 288 additions and 0 deletions.
@@ -0,0 +1,146 @@
#include <ATen/ScalarOps.h>
#include <ATen/TensorIndexing.h>
#include <ATen/XPUNativeFunctions.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/TensorCompare.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/TypeProperties.h>
#include <aten/sycl/TensorCompare.h>

namespace at {

// Returns the first non-CPU device among the inputs, or kCPU if every
// input lives on the CPU.
template <typename... Args>
Device out_device(Args&... inps) {
  for (const auto& i : {inps...}) {
    if (i.device() != at::kCPU) {
      return i.device();
    }
  }
  return at::kCPU;
}

Tensor& where_self_out(
    const Tensor& condition,
    const Tensor& self,
    const Tensor& other,
    Tensor& out) {
  const auto result_type = at::native::result_type(self, other);
  TORCH_CHECK(
      out.scalar_type() == result_type,
      "Expected out type to be ",
      result_type,
      " but got ",
      out.scalar_type());

  auto self_ = self.scalar_type() != result_type ? self.to(result_type) : self;
  auto other_ =
      other.scalar_type() != result_type ? other.to(result_type) : other;
  auto condition_ = condition;
  auto device = out_device(condition, self_, other_);
  if (device != at::kCPU) { // allow CPU scalars on a non-CPU device
    if (condition.device() != device && condition.ndimension() == 0) {
      condition_ = condition.to(device);
    }
    if (self_.device() != device && self_.ndimension() == 0) {
      self_ = self_.to(device);
    }
    if (other_.device() != device && other_.ndimension() == 0) {
      other_ = other_.to(device);
    }
  }
  if (condition_.scalar_type() == ScalarType::Byte) {
    TORCH_WARN_ONCE(
        "where received a uint8 condition tensor. This behavior is deprecated and will be removed in a future version of PyTorch. Use a boolean condition instead.");
    condition_ = condition_.to(kBool);
  }
  TORCH_CHECK(
      condition_.scalar_type() == kBool,
      "where expected condition to be a boolean tensor, but got a tensor with dtype ",
      condition_.scalar_type());
  // if there's still a device mismatch, let TensorIterator error out with it
  auto iter = at::TensorIteratorConfig()
                  .check_all_same_dtype(false)
                  .add_output(out)
                  .add_const_input(condition_)
                  .add_const_input(self_)
                  .add_const_input(other_)
                  .build();
  native::xpu::where_kernel(iter);
  return out;
}

Tensor& XPUNativeFunctions::where_out(
    const Tensor& condition,
    const Tensor& self,
    const Tensor& other,
    Tensor& out) {
  return where_self_out(condition, self, other, out);
}

Tensor XPUNativeFunctions::where(
    const Tensor& condition,
    const Tensor& self,
    const Tensor& other) {
  auto device = out_device(condition, self, other);
  auto result_type = at::native::result_type(self, other);
  Tensor ret = at::empty({0}, self.options().dtype(result_type).device(device));
  where_self_out(condition, self, other, ret);
  return ret;
}

Tensor& XPUNativeFunctions::clamp_out(
    const Tensor& self,
    const c10::optional<Scalar>& min,
    const c10::optional<Scalar>& max,
    Tensor& result) {
  using at::native::detail::ClampLimits;
  if (min && max) {
    // x != x is true only for NaN; a NaN bound clamps every element to NaN.
    if ((*min).toDouble() != (*min).toDouble() ||
        (*max).toDouble() != (*max).toDouble()) {
      at::fill_(
          const_cast<Tensor&>(result),
          std::numeric_limits<double>::quiet_NaN());
    } else {
      auto iter = TensorIterator::unary_op(result, self);
      native::xpu::clamp_scalar_kernel(iter, *min, *max);
    }
  } else if (max) {
    auto iter = TensorIterator::unary_op(result, self);
    native::xpu::clamp_max_scalar_kernel(iter, *max);
  } else if (min) {
    auto iter = TensorIterator::unary_op(result, self);
    native::xpu::clamp_min_scalar_kernel(iter, *min);
  }
  return result;
}

Tensor& XPUNativeFunctions::clamp_min_out(
    const Tensor& self,
    const Scalar& min,
    Tensor& result) {
  if (min.toDouble() != min.toDouble()) { // NaN min: result is all NaN
    at::fill_(const_cast<Tensor&>(result), min);
  } else {
    auto iter = TensorIterator::unary_op(result, self);
    native::xpu::clamp_min_scalar_kernel(iter, min);
  }
  return result;
}

Tensor& XPUNativeFunctions::clamp_max_out(
    const Tensor& self,
    const Scalar& max,
    Tensor& result) {
  if (max.toDouble() != max.toDouble()) { // NaN max: result is all NaN
    // TODO this is not great, building TI again is expensive, but I can't use
    // fill_stub because fill is not structured
    // this is a corner case anyway
    at::fill_(const_cast<Tensor&>(result), native::wrapped_scalar_tensor(max));
  } else {
    auto iter = TensorIterator::unary_op(result, self);
    native::xpu::clamp_max_scalar_kernel(iter, max);
  }
  return result;
}

} // namespace at
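For context, here is a minimal caller-side sketch. It is an illustration, not part of the commit: it assumes a PyTorch build with XPU support, in which at::where and at::clamp dispatch to the XPUNativeFunctions entry points above.

// Usage sketch (illustrative; assumes an XPU-enabled PyTorch build).
#include <ATen/ATen.h>

void example() {
  auto opts = at::TensorOptions().device(at::kXPU).dtype(at::kFloat);
  auto x = at::arange(6, opts);                   // {0, 1, 2, 3, 4, 5}
  auto cond = x.gt(2);                            // boolean condition tensor
  auto w = at::where(cond, x, at::zeros_like(x)); // {0, 0, 0, 3, 4, 5}
  auto c = at::clamp(x, 1.0, 4.0);                // {1, 1, 2, 3, 4, 4}
  auto lo = at::clamp_min(x, 2.0);                // {2, 2, 2, 3, 4, 5}
  auto hi = at::clamp_max(x, 3.0);                // {0, 1, 2, 3, 3, 3}
}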
@@ -0,0 +1,110 @@
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/NumericUtils.h>
#include <ATen/native/TensorCompare.h>
#include <ATen/native/TensorIterator.h>

#include <aten/sycl/Loops.h>

namespace at {
namespace native {
namespace xpu {

template <typename scalar_t>
struct WhereFunctor {
  scalar_t operator()(bool cond_val, scalar_t self_val, scalar_t other_val)
      const {
    return cond_val ? self_val : other_val;
  }
};

// Tensor-bound clamp: NaN in any operand propagates, with the input value's
// NaN taking precedence over the bounds'.
template <typename scalar_t>
struct ClampFunctor {
  scalar_t operator()(scalar_t v, scalar_t lower, scalar_t upper) const {
    if (at::_isnan(v)) {
      return v;
    }
    if (at::_isnan(lower)) {
      return lower;
    }
    if (at::_isnan(upper)) {
      return upper;
    } else {
      return std::min(std::max(v, lower), upper);
    }
  }
};

// Scalar-bound clamp: lim0_val_ holds the min (or the only) bound and
// lim1_val_ the max bound, selected by minmax_.
template <typename scalar_t>
struct ClampScalarFunctor {
  using opmath_t = at::opmath_type<scalar_t>;
  scalar_t operator()(scalar_t v) const {
    if (_isnan(static_cast<opmath_t>(v))) {
      return v;
    } else if (minmax_ == at::native::detail::ClampLimits::Min) {
      return std::max(static_cast<opmath_t>(v), lim0_val_);
    } else if (minmax_ == at::native::detail::ClampLimits::Max) {
      return std::min(static_cast<opmath_t>(v), lim0_val_);
    } else {
      return std::min(std::max(static_cast<opmath_t>(v), lim0_val_), lim1_val_);
    }
  }
  ClampScalarFunctor(
      opmath_t lim0_val,
      opmath_t lim1_val,
      at::native::detail::ClampLimits minmax)
      : lim0_val_(lim0_val), lim1_val_(lim1_val), minmax_(minmax) {}

 private:
  opmath_t lim0_val_;
  opmath_t lim1_val_;
  at::native::detail::ClampLimits minmax_;
};

void where_kernel(TensorIterator& iter) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
      kComplexHalf, kHalf, kBFloat16, kBool, iter.dtype(), "where_xpu", [&] {
        gpu_kernel(iter, WhereFunctor<scalar_t>());
      });
}

void clamp_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_ALL_TYPES_AND2(
      kHalf, kBFloat16, iter.common_dtype(), "clamp_xpu", [&] {
        gpu_kernel(iter, ClampFunctor<scalar_t>());
      });
}

inline void launch_clamp_scalar(
    TensorIteratorBase& iter,
    Scalar lim0,
    Scalar lim1,
    at::native::detail::ClampLimits minmax) {
  AT_DISPATCH_ALL_TYPES_AND2(
      kHalf, kBFloat16, iter.common_dtype(), "clamp_scalar_xpu", [&] {
        using opmath_t = at::opmath_type<scalar_t>;
        auto lim0_val = lim0.to<opmath_t>();
        auto lim1_val = lim1.to<opmath_t>();
        gpu_kernel(
            iter, ClampScalarFunctor<scalar_t>(lim0_val, lim1_val, minmax));
      });
}

void clamp_scalar_kernel(
    TensorIteratorBase& iter,
    const Scalar& min,
    const Scalar& max) {
  launch_clamp_scalar(iter, min, max, at::native::detail::ClampLimits::MinMax);
}

void clamp_min_scalar_kernel(TensorIteratorBase& iter, Scalar min) {
  launch_clamp_scalar(iter, min, min, at::native::detail::ClampLimits::Min);
}

void clamp_max_scalar_kernel(TensorIteratorBase& iter, Scalar max) {
  launch_clamp_scalar(iter, max, max, at::native::detail::ClampLimits::Max);
}

} // namespace xpu
} // namespace native
} // namespace at
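The NaN behavior is split between host and device: a NaN scalar bound is caught on the host by the fill_ fast paths in the clamp_*_out wrappers, while NaN elements of the input propagate through the functors on the device. A short illustrative sketch, again assuming an XPU-enabled build:

// NaN-propagation sketch (illustrative, not part of the commit).
#include <ATen/ATen.h>
#include <limits>

void nan_example() {
  auto opts = at::TensorOptions().device(at::kXPU).dtype(at::kFloat);
  auto nan = std::numeric_limits<double>::quiet_NaN();
  auto x = at::arange(3, opts);     // {0, 1, 2}
  auto a = at::clamp(x, nan, 1.0);  // all NaN: host-side fill_ fast path
  auto b = at::clamp(at::full({3}, nan, opts), 0.0, 1.0); // NaN inputs stay NaN
}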
@@ -0,0 +1,24 @@
#pragma once

#include <ATen/native/TensorIterator.h>

namespace at {
namespace native {
namespace xpu {

void where_kernel(TensorIterator& iter);

void clamp_kernel(TensorIteratorBase& iter);

void clamp_scalar_kernel(
    TensorIteratorBase& iter,
    const Scalar& min,
    const Scalar& max);

void clamp_min_scalar_kernel(TensorIteratorBase& iter, Scalar min);

void clamp_max_scalar_kernel(TensorIteratorBase& iter, Scalar max);

} // namespace xpu
} // namespace native
} // namespace at
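This header is the kernel-level API other operators can reuse. As a hypothetical sketch (the function name and tensor-bound clamp wiring below are illustrative, not part of the commit), a caller drives the tensor-bound kernel with a TensorIterator, the same way the clamp_out wrapper drives the scalar variants:

// Hypothetical caller sketch wiring a TensorIterator to clamp_kernel.
#include <ATen/native/TensorIterator.h>
#include <aten/sycl/TensorCompare.h>

at::Tensor& clamp_tensor_out_sketch(
    const at::Tensor& self,
    const at::Tensor& min,
    const at::Tensor& max,
    at::Tensor& out) {
  auto iter = at::TensorIteratorConfig()
                  .add_output(out)
                  .add_const_input(self)
                  .add_const_input(min)
                  .add_const_input(max)
                  .build();
  at::native::xpu::clamp_kernel(iter); // elementwise min/max with NaN propagation
  return out;
}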