Add tensor compare ops #80

Merged: 3 commits, merged Mar 27, 2024
146 changes: 146 additions & 0 deletions src/aten/TensorCompare.cpp
@@ -0,0 +1,146 @@
#include <ATen/ScalarOps.h>
#include <ATen/TensorIndexing.h>
#include <ATen/XPUNativeFunctions.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/TensorCompare.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/TypeProperties.h>
#include <aten/sycl/TensorCompare.h>

namespace at {

// Returns the device of the first non-CPU input, or CPU if every input lives
// on the CPU; used to pick the output device when CPU scalars are mixed in.
template <typename... Args>
Device out_device(Args&... inps) {
for (const auto& i : {inps...}) {
if (i.device() != at::kCPU) {
return i.device();
}
}
return at::kCPU;
}

Tensor& where_self_out(
const Tensor& condition,
const Tensor& self,
const Tensor& other,
Tensor& out) {
const auto result_type = at::native::result_type(self, other);
TORCH_CHECK(
out.scalar_type() == result_type,
"Expected out type to be ",
result_type,
" but got ",
out.scalar_type());

auto self_ = self.scalar_type() != result_type ? self.to(result_type) : self;
auto other_ =
other.scalar_type() != result_type ? other.to(result_type) : other;
auto condition_ = condition;
auto device = out_device(condition, self_, other_);
if (device != at::kCPU) { // allow CPU scalars on a non-CPU device
if (condition.device() != device && condition.ndimension() == 0) {
condition_ = condition.to(device);
}
if (self_.device() != device && self_.ndimension() == 0) {
self_ = self_.to(device);
}
if (other_.device() != device && other_.ndimension() == 0) {
other_ = other_.to(device);
}
}
if (condition_.scalar_type() == ScalarType::Byte) {
TORCH_WARN_ONCE(
"where received a uint8 condition tensor. This behavior is deprecated and will be removed in a future version of PyTorch. Use a boolean condition instead.");
condition_ = condition_.to(kBool);
}
TORCH_CHECK(
condition_.scalar_type() == kBool,
"where expected condition to be a boolean tensor, but got a tensor with dtype ",
condition_.scalar_type());
// If there's still a device mismatch, let TensorIterator error out with it.
auto iter = at::TensorIteratorConfig()
.check_all_same_dtype(false)
.add_output(out)
.add_const_input(condition_)
.add_const_input(self_)
.add_const_input(other_)
.build();
native::xpu::where_kernel(iter);
return out;
}

Tensor& XPUNativeFunctions::where_out(
const Tensor& condition,
const Tensor& self,
const Tensor& other,
Tensor& out) {
return where_self_out(condition, self, other, out);
}

Tensor XPUNativeFunctions::where(
const Tensor& condition,
const Tensor& self,
const Tensor& other) {
auto device = out_device(condition, self, other);
auto result_type = at::native::result_type(self, other);
Tensor ret = at::empty({0}, self.options().dtype(result_type).device(device));
where_self_out(condition, self, other, ret);
return ret;
}

Tensor& XPUNativeFunctions::clamp_out(
const Tensor& self,
const c10::optional<Scalar>& min,
const c10::optional<Scalar>& max,
Tensor& result) {
using at::native::detail::ClampLimits;
if (min && max) {
// A bound compares unequal to itself only when it is NaN; in that case the
// whole output is filled with NaN.
if ((*min).toDouble() != (*min).toDouble() ||
(*max).toDouble() != (*max).toDouble()) {
at::fill_(
const_cast<Tensor&>(result),
std::numeric_limits<double>::quiet_NaN());
} else {
auto iter = TensorIterator::unary_op(result, self);
native::xpu::clamp_scalar_kernel(iter, *min, *max);
}
} else if (max) {
auto iter = TensorIterator::unary_op(result, self);
native::xpu::clamp_max_scalar_kernel(iter, *max);
} else if (min) {
auto iter = TensorIterator::unary_op(result, self);
native::xpu::clamp_min_scalar_kernel(iter, *min);
}
return result;
}

Tensor& XPUNativeFunctions::clamp_min_out(
const Tensor& self,
const Scalar& min,
Tensor& result) {
if (min.toDouble() != min.toDouble()) {
at::fill_(const_cast<Tensor&>(result), min);
} else {
auto iter = TensorIterator::unary_op(result, self);
native::xpu::clamp_min_scalar_kernel(iter, min);
}
return result;
}

Tensor& XPUNativeFunctions::clamp_max_out(
const Tensor& self,
const Scalar& max,
Tensor& result) {
if (max.toDouble() != max.toDouble()) {
// TODO: this is not great; building the TensorIterator again is expensive,
// but we can't use fill_stub because fill is not structured. This is a
// corner case anyway.
at::fill_(const_cast<Tensor&>(result), native::wrapped_scalar_tensor(max));
} else {
auto iter = TensorIterator::unary_op(result, self);
native::xpu::clamp_max_scalar_kernel(iter, max);
}
return result;
}

} // namespace at
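
For context, a minimal Python sketch of the behavior these entry points are expected to provide, assuming an XPU-enabled PyTorch build where `torch.xpu` is available; the CPU-scalar handling and the NaN fill mirror the branches in `where_self_out` and `clamp_out` above.

```python
import torch

# Illustrative only; assumes an XPU-enabled PyTorch build ("xpu" device available).
device = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cpu"

cond = torch.tensor([True, False, True], device=device)
a = torch.tensor([1.0, 2.0, 3.0], device=device)
b = torch.tensor([10.0, 20.0, 30.0], device=device)

# where picks from `a` where cond is True, otherwise from `b`.
print(torch.where(cond, a, b))                    # -> [1., 20., 3.]

# A 0-dim CPU scalar tensor is allowed next to device tensors (see out_device above).
print(torch.where(cond, a, torch.tensor(0.0)))    # -> [1., 0., 3.]

# A NaN bound fills the whole output with NaN, mirroring the branch in clamp_out.
x = torch.tensor([0.5, 1.5, 2.5], device=device)
print(torch.clamp(x, min=float("nan"), max=2.0))  # -> [nan, nan, nan]
```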
110 changes: 110 additions & 0 deletions src/aten/sycl/TensorCompare.cpp
@@ -0,0 +1,110 @@
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/NumericUtils.h>
#include <ATen/native/TensorCompare.h>
#include <ATen/native/TensorIterator.h>

#include <aten/sycl/Loops.h>

namespace at {
namespace native {
namespace xpu {

template <typename scalar_t>
struct WhereFunctor {
scalar_t operator()(bool cond_val, scalar_t self_val, scalar_t other_val)
const {
return cond_val ? self_val : other_val;
}
};

template <typename scalar_t>
struct ClampFunctor {
scalar_t operator()(scalar_t v, scalar_t lower, scalar_t upper) const {
if (at::_isnan(v)) {
return v;
}
if (at::_isnan(lower)) {
return lower;
}
if (at::_isnan(upper)) {
return upper;
} else {
return std::min(std::max(v, lower), upper);
}
}
};

// Scalar clamp: lim0_val_ is the active bound when only a min or a max is
// given; with ClampLimits::MinMax, lim0_val_ is the min and lim1_val_ the max.
template <typename scalar_t>
struct ClampScalarFunctor {
using opmath_t = at::opmath_type<scalar_t>;
scalar_t operator()(scalar_t v) const {
if (_isnan(static_cast<opmath_t>(v))) {
return v;
} else if (minmax_ == at::native::detail::ClampLimits::Min) {
return std::max(static_cast<opmath_t>(v), lim0_val_);
} else if (minmax_ == at::native::detail::ClampLimits::Max) {
return std::min(static_cast<opmath_t>(v), lim0_val_);
} else {
return std::min(std::max(static_cast<opmath_t>(v), lim0_val_), lim1_val_);
}
}
ClampScalarFunctor(
opmath_t lim0_val,
opmath_t lim1_val,
at::native::detail::ClampLimits minmax)
: lim0_val_(lim0_val), lim1_val_(lim1_val), minmax_(minmax) {}

private:
opmath_t lim0_val_;
opmath_t lim1_val_;
at::native::detail::ClampLimits minmax_;
};

void where_kernel(TensorIterator& iter) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
kComplexHalf, kHalf, kBFloat16, kBool, iter.dtype(), "where_xpu", [&] {
gpu_kernel(iter, WhereFunctor<scalar_t>());
});
}

void clamp_kernel(TensorIteratorBase& iter) {
AT_DISPATCH_ALL_TYPES_AND2(
kHalf, kBFloat16, iter.common_dtype(), "clamp_xpu", [&] {
gpu_kernel(iter, ClampFunctor<scalar_t>());
});
}

void inline launch_clamp_scalar(
TensorIteratorBase& iter,
Scalar lim0,
Scalar lim1,
at::native::detail::ClampLimits minmax) {
AT_DISPATCH_ALL_TYPES_AND2(
kHalf, kBFloat16, iter.common_dtype(), "clamp_scalar_xpu", [&] {
using opmath_t = at::opmath_type<scalar_t>;
auto lim0_val = lim0.to<opmath_t>();
auto lim1_val = lim1.to<opmath_t>();
gpu_kernel(
iter, ClampScalarFunctor<scalar_t>(lim0_val, lim1_val, minmax));
});
}

void clamp_scalar_kernel(
TensorIteratorBase& iter,
const Scalar& min,
const Scalar& max) {
launch_clamp_scalar(iter, min, max, at::native::detail::ClampLimits::MinMax);
}

void clamp_min_scalar_kernel(TensorIteratorBase& iter, Scalar min) {
launch_clamp_scalar(iter, min, min, at::native::detail::ClampLimits::Min);
}

void clamp_max_scalar_kernel(TensorIteratorBase& iter, Scalar max) {
launch_clamp_scalar(iter, max, max, at::native::detail::ClampLimits::Max);
}

} // namespace xpu
} // namespace native
} // namespace at
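
As an illustration of the per-element semantics, a small pure-Python reference of `ClampScalarFunctor`'s logic; `clamp_scalar_ref` and its mode strings are our own naming for this sketch, not part of the PR.

```python
import math

# Pure-Python reference of ClampScalarFunctor's per-element logic (illustrative only).
def clamp_scalar_ref(v, lim0, lim1, mode):
    """mode mirrors ClampLimits: 'min', 'max', or 'minmax'."""
    if math.isnan(v):
        return v                        # NaN inputs propagate unchanged
    if mode == "min":
        return max(v, lim0)             # lower bound only
    if mode == "max":
        return min(v, lim0)             # upper bound only
    return min(max(v, lim0), lim1)      # both bounds

assert clamp_scalar_ref(5.0, 1.0, 3.0, "minmax") == 3.0
assert clamp_scalar_ref(-2.0, 1.0, 3.0, "minmax") == 1.0
assert math.isnan(clamp_scalar_ref(float("nan"), 1.0, 3.0, "minmax"))
```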
24 changes: 24 additions & 0 deletions src/aten/sycl/TensorCompare.h
@@ -0,0 +1,24 @@
#pragma once

#include <ATen/native/TensorIterator.h>

namespace at {
namespace native {
namespace xpu {

void where_kernel(TensorIterator& iter);

void clamp_kernel(TensorIteratorBase& iter);

void clamp_scalar_kernel(
TensorIteratorBase& iter,
const Scalar& min,
const Scalar& max);

void clamp_min_scalar_kernel(TensorIteratorBase& iter, Scalar min);

void clamp_max_scalar_kernel(TensorIteratorBase& iter, Scalar max);

} // namespace xpu
} // namespace native
} // namespace at
3 changes: 3 additions & 0 deletions test/xpu/test_ops.py
@@ -97,6 +97,9 @@
"bitwise_or",
"bitwise_xor",
"bitwise_not",
"where",
"clamp_min",
"clamp_max",
"clamp",
]
_xpu_tensor_factory_op_list = [
5 changes: 5 additions & 0 deletions yaml/xpu_functions.yaml
@@ -156,3 +156,8 @@ supported:
- bitwise_or.Tensor_out
- bitwise_xor.Tensor_out
- bitwise_not.out
- where.self_out
- where.self
- clamp.out
- clamp_min.out
- clamp_max.out
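
A short smoke check (illustrative only, assuming an XPU-enabled build) that the ops registered here produce results on the xpu device rather than falling back to CPU:

```python
import torch

# Illustrative smoke check; assumes an XPU-enabled PyTorch build.
if hasattr(torch, "xpu") and torch.xpu.is_available():
    x = torch.randn(4, device="xpu")
    y = torch.randn(4, device="xpu")
    cond = x > 0
    for out in (torch.where(cond, x, y),
                torch.clamp(x, -0.5, 0.5),
                torch.clamp_min(x, 0.0),
                torch.clamp_max(x, 0.0)):
        assert out.device.type == "xpu"
```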