#6991: Remove cskip test
umadevimcw committed Sep 5, 2024
1 parent b2bcee9 commit 9540dee
Showing 7 changed files with 68 additions and 56 deletions.
4 changes: 2 additions & 2 deletions tests/ttnn/unit_tests/operations/test_composite.py
@@ -23,7 +23,7 @@ def test_unary_composite_acosh_ttnn(input_shapes, device):

output_tensor = ttnn.acosh(input_tensor1)
golden_function = ttnn.get_golden_function(ttnn.acosh)
golden_tensor = golden_function(in_data1)
golden_tensor = golden_function(in_data1, device=device)

comp_pass = compare_pcc([output_tensor], [golden_tensor])
assert comp_pass
@@ -738,7 +738,7 @@ def test_unary_logit(input_shapes, param, device):
in_data, input_tensor = data_gen_with_range(input_shapes, 0, 1, device)
output_tensor = ttnn.logit(input_tensor, eps=param)
golden_function = ttnn.get_golden_function(ttnn.logit)
golden_tensor = golden_function(in_data, eps=param)
golden_tensor = golden_function(in_data, eps=param, device=device)

comp_pass = compare_pcc([output_tensor], [golden_tensor])
assert comp_pass
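
For context, a minimal sketch of how one of these device-aware golden functions is exercised, assuming a `device` from the usual ttnn pytest fixture; the random input and tolerance check below are illustrative stand-ins for the repo's data_gen_with_range / compare_pcc helpers:

import torch
import ttnn


def check_acosh_against_golden(device, shape=(1, 1, 32, 32)):
    # acosh is defined for inputs >= 1, so shift the random data accordingly.
    torch_input = torch.rand(shape, dtype=torch.bfloat16) * 89 + 1

    input_tensor = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT, device=device)
    output_tensor = ttnn.acosh(input_tensor)

    # The golden function now takes the device so it can clamp NaN/Inf in the
    # torch reference to the device-representable sfpu_nan()/sfpu_inf() values.
    golden_function = ttnn.get_golden_function(ttnn.acosh)
    golden_tensor = golden_function(torch_input, device=device)

    return torch.allclose(ttnn.to_torch(output_tensor), golden_tensor, rtol=1e-2, atol=1e-2)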
4 changes: 2 additions & 2 deletions tests/ttnn/unit_tests/operations/test_math.py
@@ -22,7 +22,7 @@ def run_math_unary_test(device, h, w, ttnn_function, pcc=0.9999):
if "digamma" in str(ttnn_function):
torch_input_tensor += 100.0
golden_function = ttnn.get_golden_function(ttnn_function)
torch_output_tensor = golden_function(torch_input_tensor, device=device)
torch_output_tensor = golden_function(torch_input_tensor)

input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
output_tensor = ttnn_function(input_tensor)
@@ -277,7 +277,7 @@ def run_math_unary_test_range(device, h, w, ttnn_function, pcc=0.9999):

torch_input_tensor = torch_random((h, w), low, high, dtype=torch.bfloat16)
golden_function = ttnn.get_golden_function(ttnn_function)
torch_output_tensor = golden_function(torch_input_tensor, device=device)
torch_output_tensor = golden_function(torch_input_tensor)

input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
output_tensor = ttnn_function(input_tensor)
12 changes: 6 additions & 6 deletions tests/ttnn/unit_tests/operations/test_unary.py
@@ -57,7 +57,7 @@ def run_identity_test(device, h, w, data_type, pcc=0.9999):
# run torch
torch_input_tensor = torch_input_tensor + bias
golden_function = ttnn.get_golden_function(ttnn_function)
torch_output_tensor = golden_function(torch_input_tensor, device=device)
torch_output_tensor = golden_function(torch_input_tensor)

# run tt
input_tensor = ttnn.from_torch(torch_input_tensor, data_type, layout=ttnn.TILE_LAYOUT, device=device)
@@ -74,7 +74,7 @@ def run_identity_test(device, h, w, data_type, pcc=0.9999):
# run torch
torch_input_tensor = torch_input_tensor + bias
golden_function = ttnn.get_golden_function(ttnn_function)
torch_output_tensor = golden_function(torch_input_tensor, device=device)
torch_output_tensor = golden_function(torch_input_tensor)

# run tt
input_tensor = ttnn.from_torch(torch_input_tensor, data_type, layout=ttnn.TILE_LAYOUT, device=device)
@@ -92,7 +92,7 @@ def run_identity_test(device, h, w, data_type, pcc=0.9999):
# run torch
torch_input_tensor = torch_input_tensor + bias
golden_function = ttnn.get_golden_function(ttnn_function)
torch_output_tensor = golden_function(torch_input_tensor, device=device)
torch_output_tensor = golden_function(torch_input_tensor)

# run tt
input_tensor = ttnn.from_torch(torch_input_tensor, data_type, layout=ttnn.TILE_LAYOUT, device=device)
@@ -110,7 +110,7 @@ def run_identity_test(device, h, w, data_type, pcc=0.9999):
# run torch
torch_input_tensor = torch_input_tensor + bias
golden_function = ttnn.get_golden_function(ttnn_function)
torch_output_tensor = golden_function(torch_input_tensor, device=device)
torch_output_tensor = golden_function(torch_input_tensor)

# run tt
input_tensor = ttnn.from_torch(torch_input_tensor, data_type, layout=ttnn.TILE_LAYOUT, device=device)
@@ -127,7 +127,7 @@ def run_identity_test(device, h, w, data_type, pcc=0.9999):
# run torch
torch_input_tensor = torch_input_tensor
golden_function = ttnn.get_golden_function(ttnn_function)
torch_output_tensor = golden_function(torch_input_tensor, device=device)
torch_output_tensor = golden_function(torch_input_tensor)

# run tt
input_tensor = ttnn.from_torch(torch_input_tensor, data_type, layout=ttnn.TILE_LAYOUT, device=device)
@@ -268,7 +268,7 @@ def run_unary_test_range(device, h, w, ttnn_function, pcc=0.9999):
torch_input_tensor = torch_random((h, w), low, high, dtype=torch.bfloat16)

golden_function = ttnn.get_golden_function(ttnn_function)
torch_output_tensor = golden_function(torch_input_tensor, device=device)
torch_output_tensor = golden_function(torch_input_tensor)

input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
output_tensor = ttnn_function(input_tensor)
7 changes: 3 additions & 4 deletions tt_metal/impl/device/device.hpp
@@ -54,12 +54,11 @@ static constexpr float EPS_WHB0 = 1.19209e-7f;
static constexpr float EPS_BH = EPS_WHB0;

static constexpr float NAN_GS = 6.9752e19;
static constexpr float INF_GS = 1.6948e38;

static constexpr float NAN_WHB0 = 7.0040e+19;
static constexpr float INF_WHB0 = 1.7014e+38;

static constexpr float NAN_BH = NAN_WHB0;

static constexpr float INF_GS = 1.6948e38;
static constexpr float INF_WHB0 = 1.7014e+38;
static constexpr float INF_BH = INF_WHB0;

// A physical PCIexpress Tenstorrent device
36 changes: 9 additions & 27 deletions ttnn/cpp/pybind11/device.cpp
@@ -116,36 +116,18 @@ void device_module(py::module &m_device) {

m_device.attr("INF_GS") = INF_GS;
m_device.attr("INF_WHB0") = INF_WHB0;
m_device.attr("INF_BH") = INF_BH
m_device.attr("INF_BH") = INF_BH;

pyDevice.def("sfpu_eps", &Device::sfpu_eps, R"doc(
Machine epsilon value for current device.
+------------------+------------------------+-----------------------+-------------+----------+
| Argument | Description | Data type | Valid range | Required |
+==================+========================+=======================+=============+==========+
| device | return machine epsilon | tt_lib.device.Device | NA | Yes |
+------------------+------------------------+-----------------------+-------------+----------+
Returns machine epsilon value for current device.
)doc");

pyDevice.def("sfpu_nan", &Device::sfpu_nan, R"doc(
NaN value for current device.
+------------------+------------------------+-----------------------+-------------+----------+
| Argument | Description | Data type | Valid range | Required |
+==================+========================+=======================+=============+==========+
| device | return machine NaN | ttnn.device.Device | NA | Yes |
+------------------+------------------------+-----------------------+-------------+----------+
Returns NaN value for current device.
)doc");

pyDevice.def("sfpu_inf", &Device::sfpu_inf, R"doc(
Infinity value for current device.
+------------------+------------------------+-----------------------+-------------+----------+
| Argument | Description | Data type | Valid range | Required |
+==================+========================+=======================+=============+==========+
| device | return machine Inf | ttnn.device.Device | NA | Yes |
+------------------+------------------------+-----------------------+-------------+----------+
Returns Infinity value for current device.
)doc");

m_device.def(
@@ -190,7 +172,7 @@ void device_module(py::module &m_device) {
+------------------+------------------------+-----------------------+-------------+----------+
| Argument | Description | Data type | Valid range | Required |
+==================+========================+=======================+=============+==========+
| device | TT Device to close | tt_lib.device.Device | | Yes |
| device | TT Device to close | ttnn.Device | | Yes |
+------------------+------------------------+-----------------------+-------------+----------+
)doc");
m_device.def("CloseDevices", &tt::tt_metal::detail::CloseDevices, R"doc(
@@ -199,7 +181,7 @@ void device_module(py::module &m_device) {
+------------------+------------------------+-----------------------+-------------+----------+
| Argument | Description | Data type | Valid range | Required |
+==================+========================+=======================+=============+==========+
| device | TT Device to close | tt_lib.device.Device | | Yes |
| device | TT Device to close | ttnn.Device | | Yes |
+------------------+------------------------+-----------------------+-------------+----------+
)doc");

@@ -221,7 +203,7 @@ void device_module(py::module &m_device) {
+------------------+------------------------+-----------------------+-------------+----------+
| Argument | Description | Data type | Valid range | Required |
+==================+========================+=======================+=============+==========+
| device | TT Device to use | tt_lib.device.Device | | Yes |
| device | TT Device to use | ttnn.Device | | Yes |
+------------------+------------------------+-----------------------+-------------+----------+
)doc");

@@ -287,7 +269,7 @@ void device_module(py::module &m_device) {
+------------------+----------------------------------+-----------------------+-------------+----------+
| Argument | Description | Data type | Valid range | Required |
+==================+==================================+=======================+=============+==========+
| device | Device to dump memory state for | tt_lib.device.Device | | Yes |
| device | Device to dump memory state for | ttnn.Device | | Yes |
| prefix | Dumped report filename prefix | str | | No |
+------------------+----------------------------------+-----------------------+-------------+----------+
)doc");
@@ -319,7 +301,7 @@ void device_module(py::module &m_device) {
+------------------+----------------------------------+-----------------------+-------------+----------+
| Argument | Description | Data type | Valid range | Required |
+==================+==================================+=======================+=============+==========+
| device | Device to dump profiling data of | tt_lib.device.Device | | Yes |
| device | Device to dump profiling data of | ttnn.Device | | Yes |
| last_dump | Last dump before process dies | bool | | No |
+------------------+----------------------------------+-----------------------+-------------+----------+
)doc");
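
A minimal usage sketch of the three device query methods bound above, from the Python side; the open/close lifecycle calls and device_id=0 are assumptions for illustration:

import ttnn

device = ttnn.open_device(device_id=0)
try:
    # sfpu_eps / sfpu_nan / sfpu_inf are the Device methods documented above.
    print("machine epsilon:", device.sfpu_eps())
    print("device NaN     :", device.sfpu_nan())
    print("device Inf     :", device.sfpu_inf())
finally:
    ttnn.close_device(device)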
@@ -724,7 +724,7 @@ Tensor _logit(const Tensor& input_a, float eps, const std::optional<MemoryConfig
Tensor linput_m1 = ttnn::rsub(logit_input, 1.0, output_mem_config);
Tensor log_input = ttnn::multiply(logit_input, ttnn::reciprocal(linput_m1, output_mem_config), std::nullopt, output_mem_config);
linput_m1.deallocate();
Tensor t_inf = ttnn::multiply(ttnn::sign(input_a, output_mem_config), std::numeric_limits<float>::infinity(), std::nullopt, output_mem_config);
Tensor t_inf = ttnn::multiply(ttnn::sign(input_a, output_mem_config), input_a.device()->sfpu_inf(), std::nullopt, output_mem_config);
Tensor logit_result = ttnn::where(
ttnn::eq(logit_input, 1.0, std::nullopt, output_mem_config),
t_inf,
59 changes: 45 additions & 14 deletions ttnn/ttnn/operations/unary.py
@@ -19,11 +19,9 @@ def torch_multigammaln(x, *args, **kwargs):
result += 3.434189657547
return result

def _golden_function(input_tensor: ttnn.Tensor, *args, device, **kwargs):
def _golden_function(input_tensor: ttnn.Tensor, **_):
name_to_golden_function = {
"abs": torch.abs,
"acos": torch.acos,
"asin": torch.asin,
"atan": torch.atan,
"cos": torch.cos,
"erfinv": torch.erfinv,
@@ -50,7 +48,6 @@ def _golden_function(input_tensor: ttnn.Tensor, *args, device, **kwargs):
"ltz": lambda x: torch.lt(x, 0),
"neg": torch.neg,
"nez": lambda x: torch.ne(x, 0),
"reciprocal": torch.reciprocal,
"relu": torch.relu,
"relu6": torch.nn.functional.relu6,
"sigmoid": torch.sigmoid,
@@ -75,7 +72,6 @@ def _golden_function(input_tensor: ttnn.Tensor, *args, device, **kwargs):
# Other unaries (composite operations)
"softplus": torch.nn.functional.softplus,
"sigmoid_accurate": torch.sigmoid,
"acosh": torch.acosh,
"asinh": torch.asinh,
"atanh": torch.atanh,
"cbrt": torch_cbrt,
@@ -105,20 +101,13 @@ def _golden_function(input_tensor: ttnn.Tensor, *args, device, **kwargs):
)

torch_function = name_to_golden_function[unary_function.__name__.split(".")[-1]]
op_name = unary_function.__name__.split(".")[-1]
if op_name in ["reciprocal", "asin", "acos", "acosh"]:
return torch.nan_to_num(
torch_function(input_tensor), nan=device.sfpu_nan(), posinf=device.sfpu_inf(), neginf=-device.sfpu_inf()
)
return torch_function(input_tensor)

ttnn.attach_golden_function(unary_function, golden_function=_golden_function)


TTNN_ELTWISE_UNARY_CPP_FUNCTIONS = [
ttnn.abs,
ttnn.acos,
ttnn.asin,
ttnn.atan,
ttnn.cos,
ttnn.erfinv,
@@ -144,7 +133,6 @@ def _golden_function(input_tensor: ttnn.Tensor, *args, device, **kwargs):
ttnn.ltz,
ttnn.neg,
ttnn.nez,
ttnn.reciprocal,
ttnn.relu,
ttnn.relu6,
ttnn.sigmoid,
@@ -171,7 +159,6 @@ def _golden_function(input_tensor: ttnn.Tensor, *args, device, **kwargs):
ttnn.softplus,
ttnn.sigmoid_accurate,
# Other unaries (composite operations - tt_eager dependency)
ttnn.acosh,
ttnn.asinh,
ttnn.atanh,
ttnn.cbrt,
@@ -196,6 +183,50 @@ def _golden_function(input_tensor: ttnn.Tensor, *args, device, **kwargs):
register_ttnn_cpp_unary_function(unary_function)


def _golden_function_asin(input_tensor_a, *args, device, **kwargs):
import torch

return torch.nan_to_num(
torch.asin(input_tensor_a), nan=device.sfpu_nan(), posinf=device.sfpu_inf(), neginf=-device.sfpu_inf()
)


ttnn.attach_golden_function(ttnn.asin, golden_function=_golden_function_asin)


def _golden_function_acos(input_tensor_a, *args, device, **kwargs):
import torch

return torch.nan_to_num(
torch.acos(input_tensor_a), nan=device.sfpu_nan(), posinf=device.sfpu_inf(), neginf=-device.sfpu_inf()
)


ttnn.attach_golden_function(ttnn.acos, golden_function=_golden_function_acos)


def _golden_function_acosh(input_tensor_a, *args, device, **kwargs):
import torch

return torch.nan_to_num(
torch.acosh(input_tensor_a), nan=device.sfpu_nan(), posinf=device.sfpu_inf(), neginf=-device.sfpu_inf()
)


ttnn.attach_golden_function(ttnn.acosh, golden_function=_golden_function_acosh)


def _golden_function_reciprocal(input_tensor_a, *args, device, **kwargs):
import torch

return torch.nan_to_num(
torch.reciprocal(input_tensor_a), nan=device.sfpu_nan(), posinf=device.sfpu_inf(), neginf=-device.sfpu_inf()
)


ttnn.attach_golden_function(ttnn.reciprocal, golden_function=_golden_function_reciprocal)
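
The four helpers above share one pattern: run the torch reference, then clamp non-finite results to the device's NaN/Inf. A small standalone sketch of that clamping, using the Wormhole constants from device.hpp above as stand-ins for the live device.sfpu_nan() / device.sfpu_inf() values:

import torch

# Stand-ins for device.sfpu_nan() / device.sfpu_inf(); the live values come from
# the Device bindings (NAN_WHB0 / INF_WHB0 in device.hpp above).
SFPU_NAN = 7.0040e19
SFPU_INF = 1.7014e38

x = torch.tensor([0.0, 2.0, -0.5], dtype=torch.bfloat16)

# torch.reciprocal(0.0) is +inf; after nan_to_num it becomes the finite
# "infinity" value the SFPU actually produces on the device.
reference = torch.nan_to_num(torch.reciprocal(x), nan=SFPU_NAN, posinf=SFPU_INF, neginf=-SFPU_INF)
print(reference)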


def _golden_function_pow(input_tensor_a, exponent, *args, **kwargs):
import torch
