Skip to content

Commit

Permalink
#11084: Update ternary bw docs
Browse files Browse the repository at this point in the history
  • Loading branch information
VirdhatchaniKN committed Aug 7, 2024
1 parent 78f18d1 commit 297d892
Showing 1 changed file with 53 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ namespace ternary_backward {
namespace detail {

template <typename ternary_backward_operation_t>
void bind_ternary_backward(py::module& module, const ternary_backward_operation_t& operation, const std::string& description) {
void bind_ternary_backward(py::module& module, const ternary_backward_operation_t& operation, std::string_view description, std::string_view supported_dtype) {
auto doc = fmt::format(
R"doc({0}(grad_tensor: ttnn.Tensor, input_tensor_a: ttnn.Tensor, input_tensor_b: ttnn.Tensor, input_tensor_c: ttnn.Tensor, alpha: float, *, memory_config: ttnn.MemoryConfig) -> std::vector<Tensor>
Expand All @@ -36,6 +36,8 @@ void bind_ternary_backward(py::module& module, const ternary_backward_operation_
Keyword args:
* :attr:`memory_config` (Optional[ttnn.MemoryConfig]): memory config for the output tensor
{3}
Example:
>>> grad_tensor = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
Expand All @@ -46,6 +48,7 @@ void bind_ternary_backward(py::module& module, const ternary_backward_operation_
)doc",
operation.base_name(),
operation.python_fully_qualified_name(),
supported_dtype,
description);

bind_registered_operation(
Expand Down Expand Up @@ -74,7 +77,7 @@ void bind_ternary_backward(py::module& module, const ternary_backward_operation_
}

template <typename ternary_backward_operation_t>
void bind_ternary_backward_op(py::module& module, const ternary_backward_operation_t& operation, const std::string& description) {
void bind_ternary_backward_op(py::module& module, const ternary_backward_operation_t& operation, std::string_view description, std::string_view supported_dtype) {
auto doc = fmt::format(
R"doc({0}(grad_tensor: ttnn.Tensor, input_tensor_a: ttnn.Tensor, input_tensor_b: ttnn.Tensor, input_tensor_c: Union[ttnn.Tensor, float], *, memory_config: ttnn.MemoryConfig) -> std::vector<Tensor>
Expand All @@ -84,11 +87,13 @@ void bind_ternary_backward_op(py::module& module, const ternary_backward_operati
* :attr:`grad_tensor`
* :attr:`input_tensor_a`
* :attr:`input_tensor_b`
* :attr:`input_tensor_c`
* :attr:`input_tensor_c` (ttnn.Tensor or float)
Keyword args:
* :attr:`memory_config` (Optional[ttnn.MemoryConfig]): memory config for the output tensor
{3}
Example:
>>> grad_tensor = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
Expand All @@ -99,6 +104,7 @@ void bind_ternary_backward_op(py::module& module, const ternary_backward_operati
)doc",
operation.base_name(),
operation.python_fully_qualified_name(),
supported_dtype,
description);

bind_registered_operation(
Expand Down Expand Up @@ -140,7 +146,7 @@ void bind_ternary_backward_op(py::module& module, const ternary_backward_operati
}

template <typename ternary_backward_operation_t>
void bind_ternary_backward_optional_output(py::module& module, const ternary_backward_operation_t& operation, const std::string& description) {
void bind_ternary_backward_optional_output(py::module& module, const ternary_backward_operation_t& operation, std::string_view description, std::string_view supported_dtype) {
auto doc = fmt::format(
R"doc({0}(grad_tensor: ttnn.Tensor, input_tensor_a: ttnn.Tensor, input_tensor_b: ttnn.Tensor, input_tensor_c: ttnn.Tensor, *, memory_config: ttnn.MemoryConfig) -> std::vector<std::optional<Tensor>>
Expand All @@ -157,6 +163,8 @@ void bind_ternary_backward_optional_output(py::module& module, const ternary_bac
* :attr:`output_tensor` (Optional[ttnn.Tensor]): preallocated output tensor
* :attr:`queue_id` (Optional[uint8]): command queue id
{3}
Example:
>>> grad_tensor = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
Expand All @@ -167,6 +175,7 @@ void bind_ternary_backward_optional_output(py::module& module, const ternary_bac
)doc",
operation.base_name(),
operation.python_fully_qualified_name(),
supported_dtype,
description);

bind_registered_operation(
Expand Down Expand Up @@ -205,21 +214,61 @@ void py_module(py::module& module) {
detail::bind_ternary_backward(
module,
ttnn::addcmul_bw,
R"doc(Supported dtypes, layouts, and ranks:
+----------------------------+---------------------------------+-------------------+
| Dtypes | Layouts | Ranks |
+----------------------------+---------------------------------+-------------------+
| BFLOAT16, BFLOAT8_B | TILE | 2, 3, 4 |
+----------------------------+---------------------------------+-------------------+
Note : bfloat8_b/bfloat4_b supports only on TILE_LAYOUT)doc",
R"doc(Performs backward operations for addcmul of :attr:`input_tensor_a` , :attr:`input_tensor_b` and :attr:`input_tensor_c` with given :attr:`grad_tensor`.)doc");

detail::bind_ternary_backward(
module,
ttnn::addcdiv_bw,
R"doc(Supported dtypes, layouts, and ranks:
+----------------------------+---------------------------------+-------------------+
| Dtypes | Layouts | Ranks |
+----------------------------+---------------------------------+-------------------+
| BFLOAT16 | TILE | 2, 3, 4 |
+----------------------------+---------------------------------+-------------------+)doc",
R"doc(Performs backward operations for addcdiv of :attr:`input_tensor_a` , :attr:`input_tensor_b` and :attr:`input_tensor_c` with given :attr:`grad_tensor`.)doc");

detail::bind_ternary_backward_optional_output(
module,
ttnn::where_bw,
R"doc(Supported dtypes, layouts, and ranks:
+----------------------------+---------------------------------+-------------------+
| Dtypes | Layouts | Ranks |
+----------------------------+---------------------------------+-------------------+
| BFLOAT16 | TILE | 2, 3, 4 |
+----------------------------+---------------------------------+-------------------+)doc",
R"doc(Performs backward operations for where of :attr:`input_tensor_a` , :attr:`input_tensor_b` and :attr:`input_tensor_c` with given :attr:`grad_tensor`.)doc");

detail::bind_ternary_backward_op(
module,
ttnn::lerp_bw,
R"doc(Supported dtypes, layouts, and ranks: For Inputs : :attr:`input_tensor_a` , :attr:`input_tensor_b` and :attr:`input_tensor_c`
+----------------------------+---------------------------------+-------------------+
| Dtypes | Layouts | Ranks |
+----------------------------+---------------------------------+-------------------+
| BFLOAT16 | TILE | 2, 3, 4 |
+----------------------------+---------------------------------+-------------------+
Supported dtypes, layouts, and ranks: For Inputs : :attr:`input_tensor_a` , :attr:`input_tensor_b` and :attr:`scalar`
+----------------------------+---------------------------------+-------------------+
| Dtypes | Layouts | Ranks |
+----------------------------+---------------------------------+-------------------+
| BFLOAT16, BFLOAT8_B | TILE | 2, 3, 4 |
+----------------------------+---------------------------------+-------------------+
Note : bfloat8_b/bfloat4_b supports only on TILE_LAYOUT)doc",
R"doc(Performs backward operations for lerp of :attr:`input_tensor_a` , :attr:`input_tensor_b` and :attr:`input_tensor_c` or :attr:`scalar` with given :attr:`grad_tensor`.)doc");

}
Expand Down

0 comments on commit 297d892

Please sign in to comment.