Skip to content

Commit

Permalink
#8158: Update heaviside sweep config and doc (#13240)
Browse files Browse the repository at this point in the history
* #8158: Update heaviside sweep config and doc

* #8158: Move sweep test
  • Loading branch information
mcw-anasuya authored Oct 16, 2024
1 parent b0d1a37 commit ccb884a
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
17 changes: 14 additions & 3 deletions tests/sweep_framework/sweeps/eltwise/unary/heaviside/heaviside.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,23 @@
"input_shape": gen_shapes([1, 1, 32, 32], [6, 12, 256, 256], [1, 1, 32, 32], 16)
+ gen_shapes([1, 32, 32], [12, 256, 256], [1, 32, 32], 16)
+ gen_shapes([32, 32], [256, 256], [32, 32], 16),
"input_a_dtype": [ttnn.bfloat16],
"input_a_layout": [ttnn.TILE_LAYOUT],
"input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b],
"input_a_layout": [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT],
"input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG],
"output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG],
},
}


# Invoked once per vector during the generation phase. A vector flagged here
# is still stored in the sweep database but will be skipped by the runner.
# Contract: (False, None) means the vector is runnable; (True, reason) means
# it is invalid, with a human-readable explanation.
def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]:
    layout = test_vector["input_a_layout"]
    if layout != ttnn.ROW_MAJOR_LAYOUT:
        return False, None
    return True, "Row Major layout is not supported"


# This is the run instructions for the test, defined by the developer.
# The run function must take the above-defined parameters as inputs.
# The runner will call this run function with each test vector, and the returned results from this function will be stored.
Expand All @@ -58,7 +67,9 @@ def run(
)(input_shape)

scalar = torch.tensor(1, dtype=torch.bfloat16).uniform_(-100, 100)
torch_output_tensor = torch.heaviside(torch_input_tensor_a, scalar)

golden_function = ttnn.get_golden_function(ttnn.heaviside)
torch_output_tensor = golden_function(torch_input_tensor_a, value=scalar)

input_tensor_a = ttnn.from_torch(
torch_input_tensor_a,
Expand Down
11 changes: 10 additions & 1 deletion ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1565,7 +1565,16 @@ void py_module(py::module& module) {
)doc");

detail::bind_unary_operation_with_float_parameter(module, ttnn::heaviside, "value", "The value parameter for the Heaviside function", "");
detail::bind_unary_operation_with_float_parameter(module, ttnn::heaviside, "value", "The value parameter for the Heaviside function", "",
R"doc(Supported dtypes, layouts, and ranks:
+----------------------------+---------------------------------+-------------------+
| Dtypes | Layouts | Ranks |
+----------------------------+---------------------------------+-------------------+
| BFLOAT16, BFLOAT8_B | TILE | 2, 3, 4 |
+----------------------------+---------------------------------+-------------------+
)doc");

detail::bind_unary_operation_with_float_parameter(module, ttnn::leaky_relu, "negative_slope", "The slope parameter for the Leaky ReLU function", "");
detail::bind_unary_operation_with_float_parameter(module, ttnn::relu_max, "upper_limit", "The max value for ReLU function", "This function caps off the input to a max value and a min value of 0");
detail::bind_unary_operation_with_float_parameter(module, ttnn::relu_min, "lower_limit", "The min value for ReLU function", "This will carry out ReLU operation at min value instead of the standard 0");
Expand Down

0 comments on commit ccb884a

Please sign in to comment.