Skip to content

Commit

Permalink
#13928: update binary bw doc
Browse files Browse the repository at this point in the history
  • Loading branch information
KalaivaniMCW committed Oct 24, 2024
1 parent 002ca3a commit ecaa212
Show file tree
Hide file tree
Showing 12 changed files with 280 additions and 35 deletions.
8 changes: 4 additions & 4 deletions docs/source/ttnn/ttnn/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -193,12 +193,10 @@ Pointwise Unary
ttnn.tanhshrink
ttnn.threshold
ttnn.trunc
ttnn.mul_bw
ttnn.clamp_bw
ttnn.hardtanh_bw
ttnn.threshold_bw
ttnn.softplus_bw
ttnn.div_bw
ttnn.rdiv_bw
ttnn.bias_gelu_bw
ttnn.pow_bw
Expand All @@ -207,7 +205,6 @@ Pointwise Unary
ttnn.sqrt_bw
ttnn.assign_bw
ttnn.multigammaln_bw
ttnn.add_bw
ttnn.lgamma_bw
ttnn.fill_bw
ttnn.hardsigmoid_bw
Expand All @@ -216,7 +213,6 @@ Pointwise Unary
ttnn.acos_bw
ttnn.atan_bw
ttnn.rad2deg_bw
ttnn.sub_bw
ttnn.frac_bw
ttnn.trunc_bw
ttnn.log_sigmoid_bw
Expand Down Expand Up @@ -344,7 +340,9 @@ Pointwise Binary
ttnn.polyval
ttnn.scatter
ttnn.atan2
ttnn.add_bw
ttnn.atan2_bw
ttnn.div_bw
ttnn.embedding_bw
ttnn.addalpha_bw
ttnn.subalpha_bw
Expand All @@ -353,6 +351,8 @@ Pointwise Binary
ttnn.ldexp_bw
ttnn.logaddexp_bw
ttnn.logaddexp2_bw
ttnn.mul_bw
ttnn.sub_bw
ttnn.squared_difference_bw
ttnn.concat_bw
ttnn.rsub_bw
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
import torch
import pytest
import ttnn
from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc
from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import (
data_gen_with_range,
compare_pcc,
data_gen_with_range_dtype,
)


@pytest.mark.parametrize(
Expand All @@ -30,6 +34,28 @@ def test_bw_add(input_shapes, device):
assert status


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 32, 32])),
(torch.Size([1, 1, 320, 384])),
(torch.Size([1, 3, 320, 384])),
),
)
def test_bw_add_bf8b(input_shapes, device):
in_data, input_tensor = data_gen_with_range_dtype(input_shapes, -100, 100, device, True, False, ttnn.bfloat8_b)
other_data, other_tensor = data_gen_with_range_dtype(input_shapes, -100, 100, device, True, False, ttnn.bfloat8_b)
grad_data, grad_tensor = data_gen_with_range_dtype(input_shapes, -100, 100, device, False, False, ttnn.bfloat8_b)

tt_output_tensor_on_device = ttnn.add_bw(grad_tensor, input_tensor, other_tensor)

golden_function = ttnn.get_golden_function(ttnn.add_bw)
golden_tensor = golden_function(grad_data, in_data, other_data)

status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert status


@pytest.mark.parametrize(
"input_shapes",
(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
import torch
import pytest
import ttnn
from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc
from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import (
data_gen_with_range,
compare_pcc,
data_gen_with_range_dtype,
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -51,6 +55,27 @@ def test_bw_binary_assign(input_shapes, device):
assert status


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 32, 32])),
(torch.Size([1, 1, 320, 384])),
(torch.Size([1, 3, 320, 384])),
),
)
def test_bw_binary_assign_bf8b(input_shapes, device):
in_data, input_tensor = data_gen_with_range_dtype(input_shapes, -100, 100, device, True, False, ttnn.bfloat8_b)
other_data, other_tensor = data_gen_with_range_dtype(input_shapes, -100, 100, device, True, False, ttnn.bfloat8_b)
grad_data, grad_tensor = data_gen_with_range_dtype(input_shapes, -100, 100, device, False, False, ttnn.bfloat8_b)

tt_output_tensor_on_device = ttnn.assign_bw(grad_tensor, input_tensor, other_tensor)

golden_function = ttnn.get_golden_function(ttnn.assign_bw)
golden_tensor = golden_function(grad_data, in_data, other_data)
status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert status


@pytest.mark.parametrize(
"input_shapes",
(
Expand Down Expand Up @@ -78,6 +103,33 @@ def test_bw_unary_assign_opt_output(input_shapes, device):
assert status


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 32, 32])),
(torch.Size([1, 1, 320, 384])),
(torch.Size([1, 3, 320, 384])),
),
)
def test_bw_unary_assign_opt_output_rm(input_shapes, device):
in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True, True)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device, False, True)
opt_tensor = torch.zeros(input_shapes, dtype=torch.bfloat16)
input_grad = ttnn.from_torch(
opt_tensor, ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device, memory_config=ttnn.L1_MEMORY_CONFIG
)
pages_before = ttnn._ttnn.reports.get_buffer_pages()
ttnn.assign_bw(grad_tensor, input_tensor, input_a_grad=input_grad, queue_id=0)
assert len(pages_before) == len(ttnn._ttnn.reports.get_buffer_pages())

tt_output_tensor_on_device = [input_grad]
golden_function = ttnn.get_golden_function(ttnn.assign_bw)
golden_tensor = golden_function(grad_data, in_data)

status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert status


@pytest.mark.parametrize(
"input_shapes",
(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import (
data_gen_with_range,
data_gen_with_val,
data_gen_with_range_dtype,
compare_pcc,
)
from models.utility_functions import (
Expand Down Expand Up @@ -182,6 +183,28 @@ def test_bw_unary_div_default(input_shapes, scalar, device):
assert status


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 32, 32])),
(torch.Size([1, 1, 320, 384])),
(torch.Size([1, 3, 320, 384])),
),
)
@pytest.mark.parametrize("scalar", [0.05, 1.0, 0.5, 0.12, 0.0, -0.05, -1.0, -0.5, -0.12])
def test_bw_unary_div_bf8b(input_shapes, scalar, device):
in_data, input_tensor = data_gen_with_range_dtype(input_shapes, -100, 100, device, True, False, ttnn.bfloat8_b)
grad_data, grad_tensor = data_gen_with_range_dtype(input_shapes, -1, 1, device, False, False, ttnn.bfloat8_b)

tt_output_tensor_on_device = ttnn.div_bw(grad_tensor, input_tensor, scalar)

golden_function = ttnn.get_golden_function(ttnn.div_bw)
golden_tensor = golden_function(grad_data, in_data, scalar)

status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert status


@pytest.mark.parametrize(
"input_shapes",
(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
import torch
import pytest
import ttnn
from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc
from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import (
data_gen_with_range,
compare_pcc,
data_gen_with_range_dtype,
)


@pytest.mark.parametrize(
Expand All @@ -28,3 +32,25 @@ def test_bw_ldexp(input_shapes, device):
golden_tensor = golden_function(grad_data, in_data, other_data)
comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert comp_pass


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 32, 32])),
(torch.Size([1, 1, 320, 384])),
(torch.Size([1, 3, 320, 384])),
),
)
def test_bw_ldexp_bf8b(input_shapes, device):
in_data, input_tensor = data_gen_with_range_dtype(input_shapes, -10, 10, device, True, False, ttnn.bfloat8_b)
other_data, other_tensor = data_gen_with_range_dtype(input_shapes, -20, 20, device, True, False, ttnn.bfloat8_b)

grad_data, grad_tensor = data_gen_with_range_dtype(input_shapes, -5, 5, device, False, False, ttnn.bfloat8_b)

tt_output_tensor_on_device = ttnn.ldexp_bw(grad_tensor, input_tensor, other_tensor)

golden_function = ttnn.get_golden_function(ttnn.ldexp_bw)
golden_tensor = golden_function(grad_data, in_data, other_data)
comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert comp_pass
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
import torch
import pytest
import ttnn
from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc
from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import (
data_gen_with_range,
compare_pcc,
data_gen_with_range_dtype,
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -100,6 +104,28 @@ def test_bw_mul_scalar(input_shapes, scalar, device):
assert status


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 32, 32])),
(torch.Size([1, 1, 320, 384])),
(torch.Size([1, 3, 320, 384])),
),
)
@pytest.mark.parametrize("scalar", [0.05, 1.0, 0.5, 0.12, 0.0, -0.05, -1.0, -0.5, -0.12])
def test_bw_mul_scalar_bf8b(input_shapes, scalar, device):
in_data, input_tensor = data_gen_with_range_dtype(input_shapes, -100, 100, device, True, False, ttnn.bfloat8_b)
grad_data, grad_tensor = data_gen_with_range_dtype(input_shapes, -5, 5, device, False, False, ttnn.bfloat8_b)

tt_output_tensor_on_device = ttnn.mul_bw(grad_tensor, input_tensor, scalar)

golden_function = ttnn.get_golden_function(ttnn.mul_bw)
golden_tensor = golden_function(grad_data, in_data, scalar)

status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert status


@pytest.mark.parametrize(
"input_shapes",
(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
import torch
import pytest
import ttnn
from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc
from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import (
data_gen_with_range,
compare_pcc,
data_gen_with_range_dtype,
)


@pytest.mark.parametrize(
Expand All @@ -30,6 +34,28 @@ def test_bw_sub(input_shapes, device):
assert status


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 32, 32])),
(torch.Size([1, 1, 320, 384])),
(torch.Size([1, 3, 320, 384])),
),
)
def test_bw_sub_bf8b(input_shapes, device):
in_data, input_tensor = data_gen_with_range_dtype(input_shapes, -100, 100, device, True, False, ttnn.bfloat8_b)
other_data, other_tensor = data_gen_with_range_dtype(input_shapes, -100, 100, device, True, False, ttnn.bfloat8_b)
grad_data, grad_tensor = data_gen_with_range_dtype(input_shapes, -100, 100, device, False, False, ttnn.bfloat8_b)

tt_output_tensor_on_device = ttnn.sub_bw(grad_tensor, input_tensor, other_tensor)

golden_function = ttnn.get_golden_function(ttnn.sub_bw)
golden_tensor = golden_function(grad_data, in_data, other_data)

status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert status


@pytest.mark.parametrize(
"input_shapes",
(
Expand Down
13 changes: 13 additions & 0 deletions tests/ttnn/unit_tests/operations/eltwise/backward/utility_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,19 @@ def data_gen_with_range(input_shapes, low, high, device, required_grad=False, is
return pt_tensor, tt_tensor


def data_gen_with_range_dtype(
input_shapes, low, high, device, required_grad=False, is_row_major=False, ttnn_dtype=ttnn.bfloat16
):
assert high > low, "Incorrect range provided"
torch.manual_seed(213919)
pt_tensor = torch.rand(input_shapes, requires_grad=required_grad).bfloat16() * (high - low) + low
if is_row_major:
tt_tensor = ttnn.Tensor(pt_tensor, ttnn_dtype).to(ttnn.ROW_MAJOR_LAYOUT).to(device)
else:
tt_tensor = ttnn.Tensor(pt_tensor, ttnn_dtype).to(ttnn.TILE_LAYOUT).to(device)
return pt_tensor, tt_tensor


def data_gen_with_range_int(input_shapes, low, high, device, required_grad=False, is_row_major=False):
assert high > low, "Incorrect range provided"
torch.manual_seed(213919)
Expand Down
Loading

0 comments on commit ecaa212

Please sign in to comment.