Skip to content

Commit

Permalink
ttrt add support for initialization function + ident check (#296)
Browse files Browse the repository at this point in the history
- New tensor initialization functions, for certain kinds of tests it's
useful to initialize with a range of values vs random.  In the future we
could add more like uniform distribution.
- Some tests, like for data movement, can be implemented to perform a
trivial golden check where the input exactly matches the output. Add an
identity flag which does this check.
  • Loading branch information
nsmithtt authored Aug 8, 2024
1 parent e61ea78 commit 52ecaf0
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 2 deletions.
1 change: 1 addition & 0 deletions runtime/include/tt/runtime/detail/ttnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#pragma clang diagnostic ignored "-Wunused-local-typedef"
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wpessimizing-move"
#pragma clang diagnostic ignored "-Wparentheses"
#define FMT_HEADER_ONLY
#include "ttnn/device.hpp"
#include "ttnn/operations/binary.hpp"
Expand Down
13 changes: 12 additions & 1 deletion runtime/tools/python/ttrt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import shutil

import ttrt.binary
from ttrt.common.api import read, run, query, perf
from ttrt.common.api import read, run, query, perf, init_fns
from ttrt.common.util import read_actions

#######################################################################################
Expand Down Expand Up @@ -80,6 +80,17 @@ def main():
action="store_true",
help="save all artifacts during run",
)
run_parser.add_argument(
"--init",
default="randn",
choices=init_fns,
help="Function to initialize tensors with",
)
run_parser.add_argument(
"--identity",
action="store_true",
help="Do a golden identity test on the output tensors",
)
run_parser.add_argument(
"--seed",
default=0,
Expand Down
37 changes: 36 additions & 1 deletion runtime/tools/python/ttrt/common/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,31 @@
import ttrt.binary
from ttrt.common.util import *


def randn(shape, dtype):
import torch

return torch.randn(shape, dtype=dtype)


def arange(shape, dtype):
import torch

def volume(shape):
v = 1
for i in shape:
v *= i
return v

return torch.arange(volume(shape), dtype=dtype).reshape(shape)


init_fns_map = {
"randn": randn,
"arange": arange,
}
init_fns = sorted(list(init_fns_map.keys()))

#######################################################################################
########################################**API**########################################
#######################################################################################
Expand Down Expand Up @@ -173,7 +198,7 @@ def run(args):
)

for i in program["inputs"]:
torch_tensor = torch.randn(
torch_tensor = init_fns_map[args.init](
i["desc"]["shape"],
dtype=fromDataType(i["desc"]["layout"]["memory_desc"]["data_type"]),
)
Expand Down Expand Up @@ -225,6 +250,16 @@ def run(args):
print(f"finished loop={loop}")
ttrt.runtime.wait(event)
print("outputs:\n", torch_outputs)
if args.identity:
for i, o in zip(torch_inputs[binary_name], torch_outputs[binary_name]):
if not torch.allclose(i, o):
print(
"Failed: inputs and outputs do not match in binary",
binary_name,
)
print(i - o)
else:
print("Passed:", binary_name)

# save artifacts
if arg_save_artifacts:
Expand Down

0 comments on commit 52ecaf0

Please sign in to comment.