Commit cf3e233: Lower Lerp

JackCaoG committed Jun 2, 2021
1 parent f34281f commit cf3e233
Showing 7 changed files with 227 additions and 0 deletions.
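Summary: this change lowers aten::lerp for XLA as a composite of existing elementwise IR ops, using the identity lerp(start, end, weight) = start + weight * (end - start). All six ATen overloads (Tensor and Scalar weight, each in functional, in-place, and out= form) are wired through the bridge, registered in xla_native_functions.yaml, and covered by C++ tests.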
109 changes: 109 additions & 0 deletions test/cpp/test_aten_xla_tensor.cpp
@@ -10272,5 +10272,114 @@ TEST_F(AtenXlaTensorTest, TestEarlySyncLiveTensors) {
cpp_test::GetIgnoredCounters());
}

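// Shared pattern for the lerp tests below: compute a CPU reference with
// torch::lerp, replay the op on each XLA device, and compare with AllClose.
// The counter checks assert the op ran through the new XLA lowering
// (xla::lerp and friends) rather than falling back to an aten:: kernel.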
TEST_F(AtenXlaTensorTest, TestLerp) {
torch::Tensor start =
torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
torch::Tensor end = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
torch::Tensor weight =
torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
torch::Tensor res = torch::lerp(start, end, weight);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_start = CopyToDevice(start, device);
torch::Tensor xla_end = CopyToDevice(end, device);
torch::Tensor xla_weight = CopyToDevice(weight, device);
torch::Tensor xla_res = torch::lerp(xla_start, xla_end, xla_weight);
AllClose(res, xla_res);
});
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::lerp", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestLerpScalar) {
torch::Tensor start =
torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
torch::Tensor end = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
torch::Scalar weight = torch::Scalar(3.0);
torch::Tensor res = torch::lerp(start, end, weight);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_start = CopyToDevice(start, device);
torch::Tensor xla_end = CopyToDevice(end, device);
torch::Tensor xla_res = torch::lerp(xla_start, xla_end, weight);
AllClose(res, xla_res);
});
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::lerp", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestLerpInplace) {
torch::Tensor input =
torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
torch::Tensor end = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
torch::Tensor weight =
torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
torch::Tensor input_copy = input.clone();
input.lerp_(end, weight);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_input = CopyToDevice(input_copy, device);
torch::Tensor xla_end = CopyToDevice(end, device);
torch::Tensor xla_weight = CopyToDevice(weight, device);
xla_input.lerp_(xla_end, xla_weight);
AllClose(xla_input, input);
});
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::lerp_", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestLerpScalarInplace) {
torch::Tensor input =
torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
torch::Tensor end = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
torch::Scalar weight = torch::Scalar(3.0);
torch::Tensor input_copy = input.clone();
input.lerp_(end, weight);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_input = CopyToDevice(input_copy, device);
torch::Tensor xla_end = CopyToDevice(end, device);
xla_input.lerp_(xla_end, weight);
AllClose(xla_input, input);
});
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::lerp_", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestLerpOut) {
torch::Tensor start =
torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
torch::Tensor end = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
torch::Tensor weight =
torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
torch::Tensor res = torch::empty({3, 4}, torch::TensorOptions(torch::kFloat));
torch::lerp_out(res, start, end, weight);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_start = CopyToDevice(start, device);
torch::Tensor xla_end = CopyToDevice(end, device);
torch::Tensor xla_weight = CopyToDevice(weight, device);
torch::Tensor xla_res = torch::empty({3, 4}, xla_start.options());
torch::lerp_out(xla_res, xla_start, xla_end, xla_weight);
AllClose(res, xla_res);
});
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::lerp_out", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestLerpScalarOut) {
torch::Tensor start =
torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
torch::Tensor end = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
torch::Scalar weight = torch::Scalar(3.0);
torch::Tensor res = torch::empty({3, 4}, torch::TensorOptions(torch::kFloat));
torch::lerp_out(res, start, end, weight);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_start = CopyToDevice(start, device);
torch::Tensor xla_end = CopyToDevice(end, device);
torch::Tensor xla_res = torch::empty({3, 4}, xla_start.options());
torch::lerp_out(xla_res, xla_start, xla_end, weight);
AllClose(res, xla_res);
});
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::lerp_out", cpp_test::GetIgnoredCounters());
}

} // namespace cpp_test
} // namespace torch_xla
50 changes: 50 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -1513,6 +1513,56 @@ at::Tensor leaky_relu_backward(const at::Tensor& grad_output,
negative_slope.to<double>()));
}

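// lerp bindings for the XLA dispatch key. Each overload bumps the xla::
// counter, unwraps the at::Tensor arguments to XLATensor handles through the
// bridge, and defers to the corresponding XLATensor method; the in-place and
// out= variants return the original aten tensor reference.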
at::Tensor lerp(const at::Tensor& self, const at::Tensor& end,
const at::Tensor& weight) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(
XLATensor::lerp(bridge::GetXlaTensor(self), bridge::GetXlaTensor(end),
bridge::GetXlaTensor(weight)));
}

at::Tensor lerp(const at::Tensor& self, const at::Tensor& end,
const at::Scalar& weight) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(XLATensor::lerp(
bridge::GetXlaTensor(self), bridge::GetXlaTensor(end), weight));
}

at::Tensor& lerp_(at::Tensor& self, const at::Tensor& end,
const at::Tensor& weight) {
XLA_FN_COUNTER("xla::");
XLATensor self_tensor = bridge::GetXlaTensor(self);
XLATensor::lerp_(self_tensor, bridge::GetXlaTensor(end),
bridge::GetXlaTensor(weight));
return self;
}

at::Tensor& lerp_(at::Tensor& self, const at::Tensor& end,
const at::Scalar& weight) {
XLA_FN_COUNTER("xla::");
XLATensor self_tensor = bridge::GetXlaTensor(self);
XLATensor::lerp_(self_tensor, bridge::GetXlaTensor(end), weight);
return self;
}

at::Tensor& lerp_out(const at::Tensor& self, const at::Tensor& end,
const at::Tensor& weight, at::Tensor& out) {
XLA_FN_COUNTER("xla::");
XLATensor out_tensor = bridge::GetXlaTensor(out);
XLATensor::lerp_out(out_tensor, bridge::GetXlaTensor(self),
bridge::GetXlaTensor(end), bridge::GetXlaTensor(weight));
return out;
}

at::Tensor& lerp_out(const at::Tensor& self, const at::Tensor& end,
const at::Scalar& weight, at::Tensor& out) {
XLA_FN_COUNTER("xla::");
XLATensor out_tensor = bridge::GetXlaTensor(out);
XLATensor::lerp_out(out_tensor, bridge::GetXlaTensor(self),
bridge::GetXlaTensor(end), weight);
return out;
}

at::Tensor log(const at::Tensor& self) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(XLATensor::log(bridge::GetXlaTensor(self)));
5 changes: 5 additions & 0 deletions torch_xla/csrc/ops/ops.cpp
@@ -737,6 +737,11 @@ NodePtr BaddBmm(const Value& lhs, const Value& rhs, const Value& bias,
std::move(lower_fn));
}

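// Composite lowering: no dedicated XLA op is emitted; lerp is expressed with
// existing elementwise IR nodes as
//   lerp(start, end, weight) = start + weight * (end - start)
// (equivalently (1 - weight) * start + weight * end), matching the ATen
// definition.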
NodePtr Lerp(const Value& start, const Value& end, const Value& weight) {
ScopePusher ir_scope(at::aten::lerp.toQualString());
return start + weight * (end - start);
}

} // namespace ops
} // namespace ir
} // namespace torch_xla
2 changes: 2 additions & 0 deletions torch_xla/csrc/ops/ops.h
@@ -210,6 +210,8 @@ NodePtr Inverse(const Value& input);
NodePtr BaddBmm(const Value& lhs, const Value& rhs, const Value& bias,
const Value& product_multiplier, const Value& bias_multiplier);

NodePtr Lerp(const Value& start, const Value& end, const Value& weight);

} // namespace ops
} // namespace ir
} // namespace torch_xla
13 changes: 13 additions & 0 deletions torch_xla/csrc/tensor.h
@@ -649,6 +649,19 @@ class XLATensor {
const XLATensor& input,
double negative_slope);

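// Computes input + weight * (end - input). The in-place overloads update
// `input`; the out= overloads write into `out`. Scalar-weight variants take
// the weight as an at::Scalar instead of a tensor.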
static XLATensor lerp(const XLATensor& input, const XLATensor& end,
const XLATensor& weight);
static XLATensor lerp(const XLATensor& input, const XLATensor& end,
const at::Scalar& weight);
static void lerp_(XLATensor& input, const XLATensor& end,
const XLATensor& weight);
static void lerp_(XLATensor& input, const XLATensor& end,
const at::Scalar& weight);
static void lerp_out(XLATensor& out, const XLATensor& input,
const XLATensor& end, const XLATensor& weight);
static void lerp_out(XLATensor& out, const XLATensor& input,
const XLATensor& end, const at::Scalar& weight);

static XLATensor log(const XLATensor& input);

static XLATensor log_base(const XLATensor& input, ir::OpKind op, double base);
42 changes: 42 additions & 0 deletions torch_xla/csrc/tensor_methods.cpp
@@ -1453,6 +1453,48 @@ XLATensor XLATensor::leaky_relu_backward(const XLATensor& grad_output,
grad_output.GetIrValue(), input.GetIrValue(), negative_slope));
}

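// lerp implementations. The Scalar-weight overloads first materialize the
// scalar as an IR constant of the input's element type on the input's device
// (GetIrValueForScalar); all variants then reuse the same ir::ops::Lerp
// composite.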
XLATensor XLATensor::lerp(const XLATensor& input, const XLATensor& end,
const XLATensor& weight) {
return input.CreateFrom(
ir::ops::Lerp(input.GetIrValue(), end.GetIrValue(), weight.GetIrValue()));
}

XLATensor XLATensor::lerp(const XLATensor& input, const XLATensor& end,
const at::Scalar& weight) {
ir::Value weight_val = GetIrValueForScalar(
weight, input.shape().get().element_type(), input.GetDevice());
return input.CreateFrom(
ir::ops::Lerp(input.GetIrValue(), end.GetIrValue(), weight_val));
}

void XLATensor::lerp_(XLATensor& input, const XLATensor& end,
const XLATensor& weight) {
input.SetInPlaceIrValue(
ir::ops::Lerp(input.GetIrValue(), end.GetIrValue(), weight.GetIrValue()));
}

void XLATensor::lerp_(XLATensor& input, const XLATensor& end,
const at::Scalar& weight) {
ir::Value weight_val = GetIrValueForScalar(
weight, input.shape().get().element_type(), input.GetDevice());
input.SetInPlaceIrValue(
ir::ops::Lerp(input.GetIrValue(), end.GetIrValue(), weight_val));
}

void XLATensor::lerp_out(XLATensor& out, const XLATensor& input,
const XLATensor& end, const XLATensor& weight) {
out.SetInPlaceIrValue(
ir::ops::Lerp(input.GetIrValue(), end.GetIrValue(), weight.GetIrValue()));
}

void XLATensor::lerp_out(XLATensor& out, const XLATensor& input,
const XLATensor& end, const at::Scalar& weight) {
ir::Value weight_val = GetIrValueForScalar(
weight, input.shape().get().element_type(), input.GetDevice());
out.SetInPlaceIrValue(
ir::ops::Lerp(input.GetIrValue(), end.GetIrValue(), weight_val));
}

XLATensor XLATensor::log(const XLATensor& input) {
return input.CreateFrom(ir::ops::Log(input.GetIrValue()));
}
6 changes: 6 additions & 0 deletions xla_native_functions.yaml
@@ -306,6 +306,12 @@ supported:
- sigmoid_backward
- tanh_backward
- ger
- lerp_.Scalar
- lerp_.Tensor
- lerp.Scalar_out
- lerp.Tensor_out
- lerp.Scalar
- lerp.Tensor
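# The six lerp entries above route the Tensor/Scalar, in-place, and out=
# overloads to the XLA implementations in aten_xla_type.cpp instead of the
# aten fallback.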
autograd:
- max_pool2d
- max_pool3d
