Add tensor compare ops (#80)
e.g. where, clamp, clamp_min, clamp_max

Co-authored-by: Feng Yuan <[email protected]>
xytintel and fengyuan14 authored Mar 27, 2024
1 parent 40cff1f commit 2a992c5
Showing 5 changed files with 288 additions and 0 deletions.
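The new ops are exposed through the standard torch API. A minimal usage sketch, assuming an XPU-enabled PyTorch build where torch.device("xpu") is available:

# Minimal usage sketch for the ops added here (assumes an XPU build).
import torch

dev = torch.device("xpu")
x = torch.randn(4, device=dev)
y = torch.zeros(4, device=dev)

# where: elementwise select between x and y on a boolean condition
print(torch.where(x > 0, x, y))

# clamp family: two-sided or one-sided bounds
print(torch.clamp(x, min=-0.5, max=0.5))
print(torch.clamp_min(x, 0.0))
print(torch.clamp_max(x, 0.0))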
146 changes: 146 additions & 0 deletions src/aten/TensorCompare.cpp
@@ -0,0 +1,146 @@
#include <ATen/ScalarOps.h>
#include <ATen/TensorIndexing.h>
#include <ATen/XPUNativeFunctions.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/TensorCompare.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/TypeProperties.h>
#include <aten/sycl/TensorCompare.h>

namespace at {

// Returns the first non-CPU device among the inputs, falling back to CPU.
template <typename... Args>
Device out_device(Args&... inps) {
  for (const auto& i : {inps...}) {
    if (i.device() != at::kCPU) {
      return i.device();
    }
  }
  return at::kCPU;
}

Tensor& where_self_out(
const Tensor& condition,
const Tensor& self,
const Tensor& other,
Tensor& out) {
const auto result_type = at::native::result_type(self, other);
TORCH_CHECK(
out.scalar_type() == result_type,
"Expected out type to be ",
result_type,
" but got ",
out.scalar_type());

auto self_ = self.scalar_type() != result_type ? self.to(result_type) : self;
auto other_ =
other.scalar_type() != result_type ? other.to(result_type) : other;
auto condition_ = condition;
auto device = out_device(condition, self_, other_);
if (device != at::kCPU) { // allow CPU scalars on non-cpu device
if (condition.device() != device && condition.ndimension() == 0) {
condition_ = condition.to(device);
}
if (self_.device() != device && self_.ndimension() == 0) {
self_ = self_.to(device);
}
if (other_.device() != device && other_.ndimension() == 0) {
other_ = other_.to(device);
}
}
if (condition_.scalar_type() == ScalarType::Byte) {
TORCH_WARN_ONCE(
"where received a uint8 condition tensor. This behavior is deprecated and will be removed in a future version of PyTorch. Use a boolean condition instead.");
condition_ = condition_.to(kBool);
}
TORCH_CHECK(
condition_.scalar_type() == kBool,
"where expected condition to be a boolean tensor, but got a tensor with dtype ",
condition_.scalar_type());
// if there's still a device mismatch, let tensoriterator error out with it
auto iter = at::TensorIteratorConfig()
.check_all_same_dtype(false)
.add_output(out)
.add_const_input(condition_)
.add_const_input(self_)
.add_const_input(other_)
.build();
native::xpu::where_kernel(iter);
return out;
}

Tensor& XPUNativeFunctions::where_out(
    const Tensor& condition,
    const Tensor& self,
    const Tensor& other,
    Tensor& out) {
  return where_self_out(condition, self, other, out);
}

Tensor XPUNativeFunctions::where(
    const Tensor& condition,
    const Tensor& self,
    const Tensor& other) {
  auto device = out_device(condition, self, other);
  auto result_type = at::native::result_type(self, other);
  Tensor ret =
      at::empty({0}, self.options().dtype(result_type).device(device));
  where_self_out(condition, self, other, ret);
  return ret;
}

Tensor& XPUNativeFunctions::clamp_out(
    const Tensor& self,
    const c10::optional<Scalar>& min,
    const c10::optional<Scalar>& max,
    Tensor& result) {
  using at::native::detail::ClampLimits;
  if (min && max) {
    // x != x is true only for NaN; a NaN bound poisons the whole result.
    if ((*min).toDouble() != (*min).toDouble() ||
        (*max).toDouble() != (*max).toDouble()) {
      at::fill_(
          const_cast<Tensor&>(result),
          std::numeric_limits<double>::quiet_NaN());
    } else {
      auto iter = TensorIterator::unary_op(result, self);
      native::xpu::clamp_scalar_kernel(iter, *min, *max);
    }
  } else if (max) {
    auto iter = TensorIterator::unary_op(result, self);
    native::xpu::clamp_max_scalar_kernel(iter, *max);
  } else if (min) {
    auto iter = TensorIterator::unary_op(result, self);
    native::xpu::clamp_min_scalar_kernel(iter, *min);
  }
  return result;
}

Tensor& XPUNativeFunctions::clamp_min_out(
    const Tensor& self,
    const Scalar& min,
    Tensor& result) {
  if (min.toDouble() != min.toDouble()) { // NaN bound: fill with NaN
    at::fill_(const_cast<Tensor&>(result), min);
  } else {
    auto iter = TensorIterator::unary_op(result, self);
    native::xpu::clamp_min_scalar_kernel(iter, min);
  }
  return result;
}

Tensor& XPUNativeFunctions::clamp_max_out(
    const Tensor& self,
    const Scalar& max,
    Tensor& result) {
  if (max.toDouble() != max.toDouble()) { // NaN bound: fill with NaN
    // TODO: this is not great; building a TensorIterator again is expensive,
    // but fill_stub can't be used because fill is not structured.
    // This is a corner case anyway.
    at::fill_(const_cast<Tensor&>(result), native::wrapped_scalar_tensor(max));
  } else {
    auto iter = TensorIterator::unary_op(result, self);
    native::xpu::clamp_max_scalar_kernel(iter, max);
  }
  return result;
}

} // namespace at
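The NaN checks above use the self-inequality idiom (x != x holds only for NaN). An illustrative sketch of the resulting semantics, assuming an XPU build (the reference CPU/CUDA ops behave the same way):

import torch

t = torch.tensor([1.0, float("nan"), 3.0], device="xpu")

# A NaN input element propagates through clamp unchanged.
torch.clamp(t, min=0.0, max=2.0)           # tensor([1., nan, 2.])

# A NaN scalar bound short-circuits clamp_out: the result is filled with NaN.
torch.clamp(t, min=float("nan"), max=2.0)  # tensor([nan, nan, nan])

# A zero-dim CPU tensor may be mixed with XPU tensors in where,
# per the promotion logic in where_self_out.
cond = torch.tensor([True, False, True], device="xpu")
torch.where(cond, t, torch.tensor(0.0))    # tensor([1., 0., 3.])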
110 changes: 110 additions & 0 deletions src/aten/sycl/TensorCompare.cpp
@@ -0,0 +1,110 @@
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/NumericUtils.h>
#include <ATen/native/TensorCompare.h>
#include <ATen/native/TensorIterator.h>

#include <aten/sycl/Loops.h>

namespace at {
namespace native {
namespace xpu {

// Elementwise select: cond ? self : other.
template <typename scalar_t>
struct WhereFunctor {
  scalar_t operator()(bool cond_val, scalar_t self_val, scalar_t other_val)
      const {
    return cond_val ? self_val : other_val;
  }
};

// Tensor-bound clamp; NaN in the value or either bound propagates to the
// output.
template <typename scalar_t>
struct ClampFunctor {
  scalar_t operator()(scalar_t v, scalar_t lower, scalar_t upper) const {
    if (at::_isnan(v)) {
      return v;
    }
    if (at::_isnan(lower)) {
      return lower;
    }
    if (at::_isnan(upper)) {
      return upper;
    } else {
      return std::min(std::max(v, lower), upper);
    }
  }
};

// Scalar-bound clamp; the ClampLimits tag selects min-only, max-only, or
// two-sided clamping, so one functor serves clamp, clamp_min, and clamp_max.
template <typename scalar_t>
struct ClampScalarFunctor {
  using opmath_t = at::opmath_type<scalar_t>;
  scalar_t operator()(scalar_t v) const {
    if (_isnan(static_cast<opmath_t>(v))) {
      return v;
    } else if (minmax_ == at::native::detail::ClampLimits::Min) {
      return std::max(static_cast<opmath_t>(v), lim0_val_);
    } else if (minmax_ == at::native::detail::ClampLimits::Max) {
      return std::min(static_cast<opmath_t>(v), lim0_val_);
    } else {
      return std::min(
          std::max(static_cast<opmath_t>(v), lim0_val_), lim1_val_);
    }
  }
  ClampScalarFunctor(
      opmath_t lim0_val,
      opmath_t lim1_val,
      at::native::detail::ClampLimits minmax)
      : lim0_val_(lim0_val), lim1_val_(lim1_val), minmax_(minmax) {}

 private:
  opmath_t lim0_val_;
  opmath_t lim1_val_;
  at::native::detail::ClampLimits minmax_;
};

void where_kernel(TensorIterator& iter) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
      kComplexHalf, kHalf, kBFloat16, kBool, iter.dtype(), "where_xpu", [&] {
        gpu_kernel(iter, WhereFunctor<scalar_t>());
      });
}

void clamp_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_ALL_TYPES_AND2(
      kHalf, kBFloat16, iter.common_dtype(), "clamp_xpu", [&] {
        gpu_kernel(iter, ClampFunctor<scalar_t>());
      });
}

inline void launch_clamp_scalar(
    TensorIteratorBase& iter,
    Scalar lim0,
    Scalar lim1,
    at::native::detail::ClampLimits minmax) {
  AT_DISPATCH_ALL_TYPES_AND2(
      kHalf, kBFloat16, iter.common_dtype(), "clamp_scalar_xpu", [&] {
        using opmath_t = at::opmath_type<scalar_t>;
        auto lim0_val = lim0.to<opmath_t>();
        auto lim1_val = lim1.to<opmath_t>();
        gpu_kernel(
            iter, ClampScalarFunctor<scalar_t>(lim0_val, lim1_val, minmax));
      });
}

void clamp_scalar_kernel(
    TensorIteratorBase& iter,
    const Scalar& min,
    const Scalar& max) {
  launch_clamp_scalar(iter, min, max, at::native::detail::ClampLimits::MinMax);
}

// For one-sided clamps the unused limit slot is filled with the same scalar;
// the ClampLimits tag tells the functor which bound to apply.
void clamp_min_scalar_kernel(TensorIteratorBase& iter, Scalar min) {
  launch_clamp_scalar(iter, min, min, at::native::detail::ClampLimits::Min);
}

void clamp_max_scalar_kernel(TensorIteratorBase& iter, Scalar max) {
  launch_clamp_scalar(iter, max, max, at::native::detail::ClampLimits::Max);
}

} // namespace xpu
} // namespace native
} // namespace at
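The dispatch macros above register the clamp kernels for kHalf and kBFloat16 on top of the standard types (where additionally covers bool and complex dtypes). A quick smoke test, again assuming an XPU build:

import torch

# Reduced-precision dtypes go through the same clamp_scalar path.
for dtype in (torch.float16, torch.bfloat16, torch.float32):
    t = torch.arange(-3, 4, device="xpu").to(dtype)
    print(dtype, torch.clamp(t, min=-1, max=1))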
24 changes: 24 additions & 0 deletions src/aten/sycl/TensorCompare.h
@@ -0,0 +1,24 @@
#pragma once

#include <ATen/native/TensorIterator.h>

namespace at {
namespace native {
namespace xpu {

void where_kernel(TensorIterator& iter);

void clamp_kernel(TensorIteratorBase& iter);

void clamp_scalar_kernel(
    TensorIteratorBase& iter,
    const Scalar& min,
    const Scalar& max);

void clamp_min_scalar_kernel(TensorIteratorBase& iter, Scalar min);

void clamp_max_scalar_kernel(TensorIteratorBase& iter, Scalar max);

} // namespace xpu
} // namespace native
} // namespace at
3 changes: 3 additions & 0 deletions test/xpu/test_ops.py
@@ -97,6 +97,9 @@
"bitwise_or",
"bitwise_xor",
"bitwise_not",
"where",
"clamp_min",
"clamp_max",
"clamp",
]
_xpu_tensor_factory_op_list = [
5 changes: 5 additions & 0 deletions yaml/xpu_functions.yaml
@@ -156,3 +156,8 @@ supported:
- bitwise_or.Tensor_out
- bitwise_xor.Tensor_out
- bitwise_not.out
- where.self_out
- where.self
- clamp.out
- clamp_min.out
- clamp_max.out
