#6443: Update backward ops
Aswinmcw committed Apr 3, 2024
1 parent b688dd8 commit 13edc65
Showing 3 changed files with 20 additions and 29 deletions.
@@ -5,7 +5,7 @@
import torch
import pytest
import tt_lib
-from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import compare_results, data_gen_pt_tt
+from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import compare_pcc, data_gen_with_range


@pytest.mark.parametrize(
@@ -17,12 +17,8 @@
),
)
def test_bw_digamma(input_shapes, device):
-grad_data, grad_tensor = data_gen_pt_tt(input_shapes, device)
-in_data = torch.Tensor(size=input_shapes).uniform_()
-in_data.requires_grad = True
-input_tensor = (
-tt_lib.tensor.Tensor(in_data, tt_lib.tensor.DataType.BFLOAT16).to(tt_lib.tensor.Layout.TILE).to(device)
-)
+grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 199, device)
+in_data, input_tensor = data_gen_with_range(input_shapes, 1, 10, device, required_grad=True)

pyt_y = torch.digamma(in_data)

@@ -33,5 +29,5 @@ def test_bw_digamma(input_shapes, device):
pyt_y.backward(gradient=grad_data)

golden_tensor = [in_data.grad]
-comp_pass = compare_results(tt_output_tensor_on_device, golden_tensor)
+comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert comp_pass
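For reference, the identity this test's golden tensor relies on is d/dx digamma(x) = polygamma(1, x). A minimal standalone PyTorch sketch (not part of this commit, independent of the repo's data_gen_with_range and compare_pcc helpers; the shapes and value ranges are illustrative assumptions):

import torch

# Input range similar to the test's (1, 10); values are assumptions.
in_data = torch.linspace(1, 10, 32, requires_grad=True)
grad_data = torch.randn(32)

pyt_y = torch.digamma(in_data)
pyt_y.backward(gradient=grad_data)

# Autograd's digamma backward is grad * polygamma(1, x).
assert torch.allclose(in_data.grad, grad_data * torch.polygamma(1, in_data))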
@@ -5,7 +5,7 @@
import torch
import pytest
import tt_lib
-from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import compare_results, data_gen_pt_tt
+from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import compare_pcc, data_gen_with_range


@pytest.mark.parametrize(
@@ -17,12 +17,9 @@
),
)
def test_bw_i0(input_shapes, device):
-in_data = torch.Tensor(size=input_shapes).uniform_()
-in_data.requires_grad = True
-input_tensor = (
-tt_lib.tensor.Tensor(in_data, tt_lib.tensor.DataType.BFLOAT16).to(tt_lib.tensor.Layout.TILE).to(device)
-)
-grad_data, grad_tensor = data_gen_pt_tt(input_shapes, device)
+grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 199, device)
+in_data, input_tensor = data_gen_with_range(input_shapes, -10, 10, device, required_grad=True)
+in_data = in_data.float()

pyt_y = torch.i0(in_data)

@@ -33,5 +30,5 @@ def test_bw_i0(input_shapes, device):
pyt_y.backward(gradient=grad_data)

golden_tensor = [in_data.grad]
-comp_pass = compare_results(tt_output_tensor_on_device, golden_tensor)
+comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert comp_pass
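Similarly, the analytic gradient the i0 test encodes is d/dx i0(x) = i1(x), the first-order modified Bessel function. A minimal standalone sketch (not part of this commit; torch.special.i1 assumes a reasonably recent PyTorch, and the shapes and ranges are illustrative assumptions):

import torch

# Input range similar to the test's (-10, 10); values are assumptions.
in_data = torch.linspace(-10, 10, 32, requires_grad=True)
grad_data = torch.randn(32)

pyt_y = torch.i0(in_data)
pyt_y.backward(gradient=grad_data)

# i0's backward is grad * i1(x).
assert torch.allclose(in_data.grad, grad_data * torch.special.i1(in_data), atol=1e-5)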
24 changes: 11 additions & 13 deletions tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
@@ -891,19 +891,12 @@ float factorial(int n) {

std::vector<Tensor> _i0_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;

-Tensor result=zeros_like(input);
-Tensor term=zeros_like(input);
-Tensor final_res=zeros_like(input);
-
-float fact;
-for (int i=0; i<100; i++){
-fact=factorial(i);
-term = mul_unary(power(div_unary(input, 2.0, output_mem_config), 2*i-1, output_mem_config), i / (fact*fact), output_mem_config);
-result = add(result,term);
-}
-final_res= mul(result, grad, std::nullopt, output_mem_config);
-grad_tensor.emplace_back(final_res);
+float t_inf = std::numeric_limits<float>::infinity();
+Tensor value = mul_unary(0.5, mul(i0(input, output_mem_config), recip(input, output_mem_config), std::nullopt, output_mem_config), output_mem_config);
+Tensor result = where(ltz(input, output_mem_config), mul(grad, sub(neg(i0(input, output_mem_config), output_mem_config), value, std::nullopt, output_mem_config), std::nullopt, output_mem_config), mul(grad, sub(i0(input, output_mem_config), value, std::nullopt, output_mem_config), std::nullopt, output_mem_config), output_mem_config);
+result = where(gte_unary(abs(i0(input, output_mem_config), output_mem_config), 3.4e+38, output_mem_config), t_inf, result, output_mem_config);
+result = where(gte_unary(abs(result, output_mem_config), 3.4e+38, output_mem_config), t_inf, result, output_mem_config);
+grad_tensor.emplace_back(result);
return grad_tensor;
}
std::vector<Tensor> i0_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config)
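The rewritten _i0_bw appears to approximate the exact gradient i1(x) with the large-argument form i1(x) ~ i0(x) - i0(x)/(2x), with the sign flipped for x < 0 (i1 is odd while i0 is even) and overflow clamped to infinity. A minimal standalone sketch (not part of this commit) comparing that approximation against torch.special.i1:

import torch

x = torch.linspace(1, 10, 64)
# Mirrors i0(input) - 0.5 * i0(input) / input from the kernel above.
approx = torch.i0(x) - 0.5 * torch.i0(x) / x
exact = torch.special.i1(x)

# Relative error is largest near x = 1 and shrinks as x grows.
print(((approx - exact).abs() / exact).max())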
@@ -1335,7 +1328,12 @@ std::vector<Tensor> erfc_bw(const Tensor& grad, const Tensor& input, const Memor

std::vector<Tensor> _digamma_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) {
std::vector<Tensor> grad_tensor;
+float t_inf = std::numeric_limits<float>::infinity();
+float t_nan = std::nanf("");
Tensor grad_a = mul(grad, polygamma(input, 1, output_mem_config), std::nullopt, output_mem_config);
+grad_a = where(logical_and(eqz(input, output_mem_config), eqz(grad, output_mem_config), std::nullopt, output_mem_config), t_nan, grad_a, output_mem_config);
+grad_a = where(logical_and(eqz(input, output_mem_config), ltz(grad, output_mem_config), std::nullopt, output_mem_config), -t_inf, grad_a, output_mem_config);
+grad_a = where(logical_and(eqz(input, output_mem_config), gtz(grad, output_mem_config), std::nullopt, output_mem_config), t_inf, grad_a, output_mem_config);
grad_tensor.emplace_back(grad_a);
return grad_tensor;
}
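The new where() clauses handle input == 0, where polygamma(1, 0) diverges to +inf: the gradient becomes +inf for a positive upstream grad, -inf for a negative grad, and nan when the upstream grad is also zero. A minimal standalone PyTorch sketch (not part of this commit) of the same edge cases:

import torch

in_data = torch.zeros(3, requires_grad=True)
grad_data = torch.tensor([1.0, -1.0, 0.0])

pyt_y = torch.digamma(in_data)
pyt_y.backward(gradient=grad_data)

# grad * polygamma(1, 0) per element.
print(in_data.grad)  # expected: tensor([inf, -inf, nan])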