Implement the lowering for HardSigmoid and backward
JackCaoG committed on Apr 20, 2020 (commit a2335ec, 1 parent: 16ed467)
Showing 11 changed files with 139 additions and 10 deletions.
1 change: 1 addition & 0 deletions scripts/gen.py
@@ -103,6 +103,7 @@ class ArgTemplate(string.Template):
'div_out': FuncOpts(),
'gather_out': FuncOpts(),
'ger_out': FuncOpts(),
'hardsigmoid_out': FuncOpts(),
'kthvalue_out': FuncOpts(),
'index_select_out': FuncOpts(),
'inverse_out': FuncOpts(),
52 changes: 44 additions & 8 deletions test/cpp/test_aten_xla_tensor.cpp
@@ -4888,8 +4888,7 @@ TEST_F(AtenXlaTensorTest, TestReluInPlace) {
}

TEST_F(AtenXlaTensorTest, TestHardshrink) {
-  torch::Tensor input =
-      torch::randn({100}, torch::TensorOptions(torch::kFloat));
+  torch::Tensor input = torch::randn({10}, torch::TensorOptions(torch::kFloat));
torch::Tensor output = torch::hardshrink(input);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_input = CopyToDevice(input, device);
@@ -4898,9 +4897,48 @@ TEST_F(AtenXlaTensorTest, TestHardshrink) {
});
}

TEST_F(AtenXlaTensorTest, TestHardSigmoid) {
torch::Tensor input = torch::randn({10}, torch::TensorOptions(torch::kFloat));
torch::Tensor output = torch::hardsigmoid(input);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_input = CopyToDevice(input, device);
torch::Tensor xla_output = torch::hardsigmoid(xla_input);
AllClose(output, xla_output);
});
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::hardsigmoid", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestHardSigmoidInPlace) {
ForEachDevice([&](const torch::Device& device) {
torch::Tensor input =
torch::randn({10}, torch::TensorOptions(torch::kFloat));
torch::Tensor xla_input = CopyToDevice(input, device);
torch::Tensor output = torch::hardsigmoid_(input);
torch::Tensor xla_output = torch::hardsigmoid_(xla_input);
AllClose(input, xla_input);
AllClose(output, xla_output);
});
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::hardsigmoid_", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestHardSigmoidBackward) {
auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
return torch::hardsigmoid(inputs[0]);
};
ForEachDevice([&](const torch::Device& device) {
TestBackward(
{torch::randn({10},
torch::TensorOptions(torch::kFloat).requires_grad(true))},
device, testfn);
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestSoftshrink) {
-  torch::Tensor input =
-      torch::randn({100}, torch::TensorOptions(torch::kFloat));
+  torch::Tensor input = torch::randn({10}, torch::TensorOptions(torch::kFloat));
torch::Tensor output = torch::softshrink(input);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_input = CopyToDevice(input, device);
@@ -4910,8 +4948,7 @@ TEST_F(AtenXlaTensorTest, TestSoftshrink) {
}

TEST_F(AtenXlaTensorTest, TestHardtanh) {
-  torch::Tensor input =
-      torch::randn({100}, torch::TensorOptions(torch::kFloat));
+  torch::Tensor input = torch::randn({10}, torch::TensorOptions(torch::kFloat));
torch::Tensor output = torch::hardtanh(input);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_input = CopyToDevice(input, device);
@@ -4921,8 +4958,7 @@ TEST_F(AtenXlaTensorTest, TestHardtanh) {
}

TEST_F(AtenXlaTensorTest, TestHardtanhInPlace) {
-  torch::Tensor input =
-      torch::randn({100}, torch::TensorOptions(torch::kFloat));
+  torch::Tensor input = torch::randn({10}, torch::TensorOptions(torch::kFloat));
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_input = CopyToDevice(input, device);
torch::Tensor output = torch::hardtanh_(input);
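The three new tests above cover the functional, in-place, and backward variants of hardsigmoid. From Python they correspond roughly to the following sketch (a hypothetical snippet, not part of this commit; it assumes a working torch_xla install, the usual helpers xm.xla_device() and torch_xla.debug.metrics, and the xla::hardsigmoid counter asserted by ExpectCounterChanged above):

import torch
import torch.nn.functional as F
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()

# Forward: should hit the new xla::hardsigmoid lowering rather than
# falling back to an aten:: kernel.
x = torch.randn(10, device=device, requires_grad=True)
y = F.hardsigmoid(x)

# Backward: exercises the hardsigmoid_backward lowering.
y.sum().backward()
xm.mark_step()

# Mirrors ExpectCounterChanged("xla::hardsigmoid", ...) in the C++ test.
print(met.counter_value('xla::hardsigmoid'))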
2 changes: 1 addition & 1 deletion test/pytorch_test_base.py
@@ -202,6 +202,7 @@
'test_nll_loss_empty_tensor_reduction_mean', # floating point division 0 by 0, expecting nan but get 0
'test_fold', # The gradient check code errors out on type() call, and code is slow on XLA
'test_unfold', # The gradient check code errors out on type() call, and code is slow on XLA
+'test_hardsigmoid_grad_xla', # gradient check is slow

# test_type_promotion.py
# TestTypePromotion
@@ -212,7 +213,6 @@
'test_half', # half support
'test_complex_promotion', # complex support
'test_complex_scalar_mult_tensor_promotion', # complex support
-'test_hardsigmoid_grad_xla', # FIXEME: accessing storage
}

DISABLED_TORCH_TESTS_TPU = DISABLED_TORCH_TESTS_ANY | {
20 changes: 20 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -1399,6 +1399,26 @@ at::Tensor AtenXlaType::hardshrink(const at::Tensor& self, at::Scalar lambda) {
XLATensor::hardshrink(bridge::GetXlaTensor(self), lambda));
}

at::Tensor AtenXlaType::hardsigmoid(const at::Tensor& self) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(
XLATensor::hardsigmoid(bridge::GetXlaTensor(self)));
}

at::Tensor& AtenXlaType::hardsigmoid_(at::Tensor& self) {
XLA_FN_COUNTER("xla::");
XLATensor self_tensor = bridge::GetXlaTensor(self);
XLATensor::hardsigmoid_(self_tensor);
return self;
}

at::Tensor AtenXlaType::hardsigmoid_backward(const at::Tensor& grad_output,
const at::Tensor& self) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(XLATensor::hardsigmoid_backward(
bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self)));
}

at::Tensor AtenXlaType::hardshrink_backward(const at::Tensor& grad_out,
const at::Tensor& self,
at::Scalar lambda) {
7 changes: 7 additions & 0 deletions torch_xla/csrc/aten_xla_type.h
@@ -412,6 +412,13 @@ class AtenXlaType {
const at::Tensor& self,
at::Scalar lambda);

static at::Tensor hardsigmoid(const at::Tensor& self);

static at::Tensor& hardsigmoid_(at::Tensor& self);

static at::Tensor hardsigmoid_backward(const at::Tensor& grad_output,
const at::Tensor& self);

static at::Tensor hardtanh(const at::Tensor& self, at::Scalar min_val,
at::Scalar max_val);

20 changes: 19 additions & 1 deletion torch_xla/csrc/elementwise.cpp
@@ -72,10 +72,28 @@ xla::XlaOp BuildHardshrink(xla::XlaOp input, at::Scalar lambda) {
input);
}

xla::XlaOp BuildHardSigmoid(xla::XlaOp input) {
const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input);
xla::XlaOp zero = xla::Zero(input.builder(), shape.element_type());
xla::XlaOp three = XlaHelpers::ScalarValue<float>(3.0, shape.element_type(),
input.builder());
xla::XlaOp six = XlaHelpers::ScalarValue<float>(6.0, shape.element_type(),
input.builder());
return xla::Min(xla::Max(input + three, zero), six) / six;
}

xla::XlaOp BuildHardSigmoidBackward(xla::XlaOp grad_output, xla::XlaOp input) {
const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input);
xla::XlaOp six = XlaHelpers::ScalarValue<float>(6.0, shape.element_type(),
input.builder());
return xla::Select(Between(input, -3.0, 3.0), grad_output / six,
XlaHelpers::ScalarBroadcast(0, shape, input.builder()));
}

xla::XlaOp BuildSoftshrink(xla::XlaOp input, at::Scalar lambda) {
xla::XlaBuilder* builder = input.builder();
const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input);
-  xla::XlaOp zero = XlaHelpers::ScalarBroadcast(0, shape, builder);
+  xla::XlaOp zero = xla::Zero(input.builder(), shape.element_type());
xla::XlaOp xla_lambd =
XlaHelpers::ScalarBroadcast(lambda.to<double>(), shape, builder);
xla::XlaOp le_lambda_branch =
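For reference, BuildHardSigmoid and BuildHardSigmoidBackward above implement the standard piecewise-linear definition and its gradient (a restatement of the code, not part of the diff):

$$\operatorname{hardsigmoid}(x) = \frac{\min(\max(x + 3,\, 0),\, 6)}{6} = \begin{cases} 0 & x \le -3 \\ (x + 3)/6 & -3 < x < 3 \\ 1 & x \ge 3 \end{cases}$$

$$\frac{d}{dx}\operatorname{hardsigmoid}(x) = \begin{cases} 1/6 & -3 < x < 3 \\ 0 & \text{otherwise} \end{cases}$$

so the backward builder selects grad_output / 6 inside the (-3, 3) band and zero outside, which is what the xla::Select encodes.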
2 changes: 2 additions & 0 deletions torch_xla/csrc/elementwise.h
@@ -27,6 +27,8 @@ xla::XlaOp BuildRreluBackward(xla::XlaOp grad_output, xla::XlaOp input,
at::Scalar upper, bool training);

xla::XlaOp BuildHardshrink(xla::XlaOp input, at::Scalar lambda);
xla::XlaOp BuildHardSigmoid(xla::XlaOp input);
xla::XlaOp BuildHardSigmoidBackward(xla::XlaOp grad_output, xla::XlaOp input);
xla::XlaOp BuildSoftshrink(xla::XlaOp input, at::Scalar lambda);
xla::XlaOp BuildShrinkBackward(xla::XlaOp grad_output, xla::XlaOp input,
at::Scalar lambda);
20 changes: 20 additions & 0 deletions torch_xla/csrc/ops/ops.cpp
@@ -160,6 +160,26 @@ NodePtr TransposeOp(const Value& input, xla::int64 dim0, xla::int64 dim1) {
/*rank=*/input.shape().rank()));
}

NodePtr HardSigmoid(const Value& input) {
auto lower_fn = [](const Node& node, LoweringContext* loctx) -> XlaOpVector {
xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
return node.ReturnOp(BuildHardSigmoid(xla_input), loctx);
};
return GenericOp(OpKind(at::aten::hardsigmoid), {input}, input.shape(),
std::move(lower_fn));
}

NodePtr HardSigmoidBackward(const Value& grad_output, const Value& input) {
auto lower_fn = [](const Node& node, LoweringContext* loctx) -> XlaOpVector {
xla::XlaOp xla_grad_output = loctx->GetOutputOp(node.operand(0));
xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(1));
return node.ReturnOp(BuildHardSigmoidBackward(xla_grad_output, xla_input),
loctx);
};
return GenericOp(OpKind(at::aten::hardsigmoid_backward), {grad_output, input},
input.shape(), std::move(lower_fn));
}

std::tuple<NodePtr, NodePtr> LogSigmoid(const Value& input) {
ScopePusher ir_scope(at::aten::log_sigmoid.toQualString());
// Use log-sum-exp trick to avoid overflow.
4 changes: 4 additions & 0 deletions torch_xla/csrc/ops/ops.h
@@ -109,6 +109,10 @@ NodePtr Not(const Value& input);

NodePtr TransposeOp(const Value& input, xla::int64 dim0, xla::int64 dim1);

NodePtr HardSigmoid(const Value& input);

NodePtr HardSigmoidBackward(const Value& grad_output, const Value& input);

std::tuple<NodePtr, NodePtr> LogSigmoid(const Value& input);

NodePtr LogSigmoidBackward(const Value& grad_output, const Value& input,
7 changes: 7 additions & 0 deletions torch_xla/csrc/tensor.h
@@ -589,6 +589,13 @@ class XLATensor {
const XLATensor& input,
at::Scalar lambda);

static XLATensor hardsigmoid(const XLATensor& input);

static void hardsigmoid_(XLATensor& input);

static XLATensor hardsigmoid_backward(const XLATensor& grad_output,
const XLATensor& input);

static XLATensor hardtanh_backward(const XLATensor& grad_output,
const XLATensor& input, at::Scalar min_val,
at::Scalar max_val);
14 changes: 14 additions & 0 deletions torch_xla/csrc/tensor_methods.cpp
@@ -1366,6 +1366,20 @@ XLATensor XLATensor::hardshrink_backward(const XLATensor& grad_out,
input.GetIrValue(), lambda));
}

XLATensor XLATensor::hardsigmoid(const XLATensor& input) {
return input.CreateFrom(ir::ops::HardSigmoid(input.GetIrValue()));
}

void XLATensor::hardsigmoid_(XLATensor& input) {
input.SetIrValue(ir::ops::HardSigmoid(input.GetIrValue()));
}

XLATensor XLATensor::hardsigmoid_backward(const XLATensor& grad_output,
const XLATensor& input) {
return input.CreateFrom(ir::ops::HardSigmoidBackward(grad_output.GetIrValue(),
input.GetIrValue()));
}

XLATensor XLATensor::hardtanh_backward(const XLATensor& grad_output,
const XLATensor& input,
at::Scalar min_val, at::Scalar max_val) {
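As a rough end-to-end check of the backward path added in tensor_methods.cpp (again a hypothetical snippet, not part of the commit), the XLA-computed gradient can be compared against the analytic rule of 1/6 inside (-3, 3) and zero outside:

import torch
import torch.nn.functional as F
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.randn(64, device=device, requires_grad=True)
F.hardsigmoid(x).sum().backward()

# d/dx sum(hardsigmoid(x)) is 1/6 where |x| < 3 and 0 elsewhere.
xd = x.detach()
expected = torch.where(xd.abs() < 3,
                       torch.full_like(xd, 1.0 / 6),
                       torch.zeros_like(xd))
assert torch.allclose(x.grad.cpu(), expected.cpu(), atol=1e-6)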
