Skip to content

Commit

Permalink
ttrt add support for initialization function + ident check
Browse files Browse the repository at this point in the history
- New tensor initialization functions, for certain kinds of tests it's
useful to initialize with a range of values vs random.  In the future we
could add more like uniform distribution.
- Some tests, like for data movement, can be implemented to perform a
trivial golden check where the input exactly matches the output. Add an
identity flag which does this check.
  • Loading branch information
nsmithtt committed Aug 5, 2024
1 parent 5a99c93 commit 67a4764
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 4 deletions.
1 change: 1 addition & 0 deletions runtime/include/tt/runtime/detail/ttnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#pragma clang diagnostic ignored "-Wunused-local-typedef"
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wpessimizing-move"
#pragma clang diagnostic ignored "-Wparentheses"
#define FMT_HEADER_ONLY
#include "ttnn/device.hpp"
#include "ttnn/operations/binary.hpp"
Expand Down
13 changes: 12 additions & 1 deletion runtime/tools/python/ttrt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import shutil

import ttrt.binary
from ttrt.common.api import read, run, query, perf
from ttrt.common.api import read, run, query, perf, init_fns
from ttrt.common.util import read_actions

#######################################################################################
Expand Down Expand Up @@ -80,6 +80,17 @@ def main():
action="store_true",
help="save all artifacts during run",
)
run_parser.add_argument(
"--init",
default="randn",
choices=init_fns,
help="Function to initialize tensors with",
)
run_parser.add_argument(
"--identity",
action="store_true",
help="Do a golden identity test on the output tensors",
)
run_parser.add_argument(
"--seed",
default=0,
Expand Down
39 changes: 36 additions & 3 deletions runtime/tools/python/ttrt/common/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,31 @@
import ttrt.binary
from ttrt.common.util import *


def randn(shape, dtype):
import torch

return torch.randn(shape, dtype=dtype)


def arange(shape, dtype):
import torch

def volume(shape):
v = 1
for i in shape:
v *= i
return v

return torch.arange(volume(shape), dtype=dtype).reshape(shape)


init_fns_map = {
"randn": randn,
"arange": arange,
}
init_fns = sorted(list(init_fns_map.keys()))

#######################################################################################
########################################**API**########################################
#######################################################################################
Expand Down Expand Up @@ -134,7 +159,6 @@ def run(args):
atexit.register(lambda: ttrt.runtime.close_device(device))

torch.manual_seed(args.seed)

for (binary_name, fbb, fbb_dict) in fbb_list:
torch_inputs[binary_name] = []
torch_outputs[binary_name] = []
Expand All @@ -145,9 +169,9 @@ def run(args):
)

for i in program["inputs"]:
torch_tensor = torch.randn(
torch_tensor = init_fns_map[args.init](
i["desc"]["shape"],
dtype=fromDataType(i["desc"]["layout"]["memory_desc"]["data_type"]),
fromDataType(i["desc"]["layout"]["memory_desc"]["data_type"]),
)
torch_inputs[binary_name].append(torch_tensor)
for i in program["outputs"]:
Expand Down Expand Up @@ -193,6 +217,15 @@ def run(args):
ttrt.runtime.submit(device, fbb, 0, total_inputs[loop], total_outputs[loop])
print(f"finished loop={loop}")
print("outputs:\n", torch_outputs)
if args.identity:
for i, o in zip(torch_inputs[binary_name], torch_outputs[binary_name]):
if not torch.allclose(i, o):
print(
"Failed: inputs and outputs do not match in binary", binary_name
)
print(i - o)
else:
print("Passed:", binary_name)

# save artifacts
if arg_save_artifacts:
Expand Down

0 comments on commit 67a4764

Please sign in to comment.