Skip to content

Commit

Permalink
#6536: Fix asin backward op
Browse files Browse the repository at this point in the history
  • Loading branch information
umadevimcw committed Mar 22, 2024
1 parent 3528483 commit 88ab23f
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,5 @@ def test_bw_ldexp(input_shapes, device):
pyt_y.backward(gradient=grad_data)

# Golden reference: PyTorch autograd gradients w.r.t. both ldexp inputs.
golden_tensor = [in_data.grad, other_data.grad]
# NOTE(review): diff artifact — the next two lines are the pre-/post-change
# versions of the same statement; the second (default-tolerance) call is the
# current code. The relaxed pcc=0.97 was dropped once the op was fixed.
comp_pass = compare_results(tt_output_tensor_on_device, golden_tensor, pcc=0.97)
comp_pass = compare_results(tt_output_tensor_on_device, golden_tensor)
assert comp_pass
4 changes: 2 additions & 2 deletions tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,7 @@ std::vector<Tensor> _ldexp_bw(const Tensor& grad, const Tensor& input, const Ten
std::vector<Tensor> grad_tensor;
// Gradient w.r.t. input: ldexp(x, n) = x * 2^n, so d/dx = 2^n, scaled by
// the incoming grad. rpow(other, 2.0) presumably computes 2^other — TODO
// confirm rpow argument order against its definition.
Tensor tpow_o = mul(grad, rpow(other, 2.0, output_mem_config), std::nullopt, output_mem_config);
grad_tensor.emplace_back(tpow_o);
// Gradient w.r.t. other: d/dn = x * 2^n * ln(2). tpow_o already carries the
// grad factor, so no extra multiply by grad is needed.
// NOTE(review): diff artifact — the next two lines are the pre-/post-fix
// versions of the same statement; the second (no extra grad factor) is the
// current code. The first double-counted grad.
Tensor result = mul(grad, mul(input, mul_unary(tpow_o, M_LN2, output_mem_config), std::nullopt, output_mem_config), std::nullopt, output_mem_config);
Tensor result = mul(input, mul_unary(tpow_o, M_LN2, output_mem_config), std::nullopt, output_mem_config);
grad_tensor.emplace_back(result);
return grad_tensor;
}
Expand Down Expand Up @@ -980,7 +980,7 @@ std::vector<Tensor> atanh_bw(const Tensor& grad, const Tensor& input, const Memo
// result: grad * (-self * self + 1).rsqrt()
//
// Backward of asin: d/dx asin(x) = 1 / sqrt(1 - x^2), so the input gradient
// is grad * rsqrt(1 - input^2). Returns a one-element vector holding that
// tensor.
//
// The span previously contained two conflicting declarations of grad_result
// (a diff artifact); only the corrected statement — which squares `input`,
// not `grad` — is kept. Squaring the incoming gradient instead of the input
// was the defect this change fixes.
std::vector<Tensor> _asin_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
    std::vector<Tensor> grad_tensor;
    // rsqrt(1 - input^2), composed as rsqrt(add1(neg(square(input)))),
    // then scaled elementwise by the incoming grad.
    Tensor grad_result = mul(grad, rsqrt(add1(neg(square(input, output_mem_config), output_mem_config), output_mem_config), true, output_mem_config), std::nullopt, output_mem_config);
    grad_tensor.emplace_back(grad_result);
    return grad_tensor;
}
Expand Down

0 comments on commit 88ab23f

Please sign in to comment.