Skip to content

Commit

Permalink
#6633: Update data gen func that fill the constant val
Browse files Browse the repository at this point in the history
  • Loading branch information
umadevimcw committed Mar 22, 2024
1 parent 8623395 commit 8b4d575
Showing 1 changed file with 8 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@ def data_gen_pt_tt(input_shapes, low, high, device, required_grad=False):
return pt_tensor, tt_tensor


def data_gen_with_val(input_shapes, device, required_grad=False, val=1):
pt_tensor = (torch.ones(input_shapes, requires_grad=required_grad) * val).bfloat16()
tt_tensor = (
tt_lib.tensor.Tensor(pt_tensor, tt_lib.tensor.DataType.BFLOAT16).to(tt_lib.tensor.Layout.TILE).to(device)
)
return pt_tensor, tt_tensor


def compare_pcc(tt_tensor, golden_tensor, pcc=0.99):
status = True
for i in range(len(tt_tensor)):
Expand Down

0 comments on commit 8b4d575

Please sign in to comment.