Skip to content

Commit

Permalink
#13373: Update golden function
Browse files Browse the repository at this point in the history
  • Loading branch information
VirdhatchaniKN committed Oct 16, 2024
1 parent 2f6a632 commit bec8221
Show file tree
Hide file tree
Showing 12 changed files with 25 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,8 @@ def run(
partial(torch_random, low=-100, high=100, dtype=torch.float32), input_b_dtype
)(other)

torch_output_tensor = torch.div(torch_input_tensor_a, torch_other_tensor)
golden_function = ttnn.get_golden_function(ttnn.div)
torch_output_tensor = golden_function(torch_input_tensor_a, torch_other_tensor)

input_tensor_a = ttnn.from_torch(
torch_input_tensor_a,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ def run(
partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype
)(input_specs["shape"])

torch_output_tensor = torch.lt(torch_input_tensor_a, input_specs["other"])
golden_function = ttnn.get_golden_function(ttnn.lt)
torch_output_tensor = golden_function(torch_input_tensor_a, input_specs["other"])

input_tensor_a = ttnn.from_torch(
torch_input_tensor_a,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ def run(
torch_input_tensor_b = gen_func_with_cast_tt(
partial(torch_random, low=-100, high=100, dtype=torch.float32), input_b_dtype
)(input_shape_b)
torch_output_tensor = torch.lt(torch_input_tensor_a, torch_input_tensor_b)

golden_function = ttnn.get_golden_function(ttnn.lt)
torch_output_tensor = golden_function(torch_input_tensor_a, torch_input_tensor_b)

input_tensor_a = ttnn.from_torch(
torch_input_tensor_a,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ def run(
partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype
)(input_specs["shape"])

torch_output_tensor = torch.mul(torch_input_tensor_a, input_specs["other"])
golden_function = ttnn.get_golden_function(ttnn.mul)
torch_output_tensor = golden_function(torch_input_tensor_a, input_specs["other"])

input_tensor_a = ttnn.from_torch(
torch_input_tensor_a,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ def run(
partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype
)(input_specs["shape"])

torch_output_tensor = torch.ne(torch_input_tensor_a, input_specs["other"])
golden_function = ttnn.get_golden_function(ttnn.ne)
torch_output_tensor = golden_function(torch_input_tensor_a, input_specs["other"])

input_tensor_a = ttnn.from_torch(
torch_input_tensor_a,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,8 @@ def run(
partial(torch_random, low=-100, high=100, dtype=torch.float32), input_b_dtype
)(other)

torch_output_tensor = torch.sub(torch_input_tensor_a, torch_other_tensor)
golden_function = ttnn.get_golden_function(ttnn.sub)
torch_output_tensor = golden_function(torch_input_tensor_a, torch_other_tensor)

input_tensor_a = ttnn.from_torch(
torch_input_tensor_a,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ def run(
min_val = input_specs.get("min", None)
max_val = input_specs.get("max", None)

torch_output_tensor = torch.clamp(torch_input_tensor_a, min_val, max_val)
golden_function = ttnn.get_golden_function(ttnn.clamp)
torch_output_tensor = golden_function(torch_input_tensor_a, min_val, max_val)

input_tensor_a = ttnn.from_torch(
torch_input_tensor_a,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ def run(
partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype
)(input_specs["shape"])

torch_output_tensor = torch.clamp(torch_input_tensor_a, input_specs["min"])
golden_function = ttnn.get_golden_function(ttnn.clamp)
torch_output_tensor = golden_function(torch_input_tensor_a, input_specs["min"])

input_tensor_a = ttnn.from_torch(
torch_input_tensor_a,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ def run(
partial(torch_random, low=-100, high=100, dtype=torch.float32), input_dtype
)(input_shape)

torch_output_tensor = torch.nn.functional.hardswish(torch_input_tensor_a)
golden_function = ttnn.get_golden_function(ttnn.hardswish)
torch_output_tensor = golden_function(torch_input_tensor_a)

input_tensor_a = ttnn.from_torch(
torch_input_tensor_a,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,8 @@ def run(
min_val = input_specs.get("min_val")
max_val = input_specs.get("max_val")

torch_output_tensor = torch.nn.functional.hardtanh(torch_input_tensor_a, min_val, max_val)
golden_function = ttnn.get_golden_function(ttnn.hardtanh)
torch_output_tensor = golden_function(torch_input_tensor_a, min_val, max_val)

input_tensor_a = ttnn.from_torch(
torch_input_tensor_a,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ def run(
partial(torch_random, low=-100, high=100, dtype=torch.float32), input_dtype
)(input_shape)

torch_output_tensor = torch.logical_not(torch_input_tensor_a)
golden_function = ttnn.get_golden_function(ttnn.logical_not)
torch_output_tensor = golden_function(torch_input_tensor_a)

input_tensor_a = ttnn.from_torch(
torch_input_tensor_a,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ def run(
partial(torch_random, low=-100, high=100, dtype=torch.float32), input_dtype
)(input_shape)

torch_output_tensor = torch.neg(torch_input_tensor_a)
golden_function = ttnn.get_golden_function(ttnn.neg)
torch_output_tensor = golden_function(torch_input_tensor_a)

input_tensor_a = ttnn.from_torch(
torch_input_tensor_a,
Expand Down

0 comments on commit bec8221

Please sign in to comment.