diff --git a/runtime/include/tt/runtime/detail/ttnn.h b/runtime/include/tt/runtime/detail/ttnn.h index a07ebf61d..5c7c348d9 100644 --- a/runtime/include/tt/runtime/detail/ttnn.h +++ b/runtime/include/tt/runtime/detail/ttnn.h @@ -32,6 +32,7 @@ #pragma clang diagnostic ignored "-Wunused-local-typedef" #pragma clang diagnostic ignored "-Wunused-function" #pragma clang diagnostic ignored "-Wpessimizing-move" +#pragma clang diagnostic ignored "-Wparentheses" #define FMT_HEADER_ONLY #include "ttnn/device.hpp" #include "ttnn/operations/binary.hpp" diff --git a/runtime/tools/python/ttrt/__init__.py b/runtime/tools/python/ttrt/__init__.py index f7883c980..34d307a81 100644 --- a/runtime/tools/python/ttrt/__init__.py +++ b/runtime/tools/python/ttrt/__init__.py @@ -17,7 +17,7 @@ import shutil import ttrt.binary -from ttrt.common.api import read, run, query, perf +from ttrt.common.api import read, run, query, perf, init_fns from ttrt.common.util import read_actions ####################################################################################### @@ -80,6 +80,17 @@ def main(): action="store_true", help="save all artifacts during run", ) + run_parser.add_argument( + "--init", + default="randn", + choices=init_fns, + help="Function to initialize tensors with", + ) + run_parser.add_argument( + "--identity", + action="store_true", + help="Do a golden identity test on the output tensors", + ) run_parser.add_argument( "--seed", default=0, diff --git a/runtime/tools/python/ttrt/common/api.py b/runtime/tools/python/ttrt/common/api.py index a8e7f7ee4..1356850c7 100644 --- a/runtime/tools/python/ttrt/common/api.py +++ b/runtime/tools/python/ttrt/common/api.py @@ -20,6 +20,31 @@ import ttrt.binary from ttrt.common.util import * + +def randn(shape, dtype): + import torch + + return torch.randn(shape, dtype=dtype) + + +def arange(shape, dtype): + import torch + + def volume(shape): + v = 1 + for i in shape: + v *= i + return v + + return torch.arange(volume(shape), dtype=dtype).reshape(shape) + + +init_fns_map = { + "randn": randn, + "arange": arange, +} +init_fns = sorted(list(init_fns_map.keys())) + ####################################################################################### ########################################**API**######################################## ####################################################################################### @@ -173,7 +198,7 @@ def run(args): ) for i in program["inputs"]: - torch_tensor = torch.randn( + torch_tensor = init_fns_map[args.init]( i["desc"]["shape"], dtype=fromDataType(i["desc"]["layout"]["memory_desc"]["data_type"]), ) @@ -225,6 +250,16 @@ def run(args): print(f"finished loop={loop}") ttrt.runtime.wait(event) print("outputs:\n", torch_outputs) + if args.identity: + for i, o in zip(torch_inputs[binary_name], torch_outputs[binary_name]): + if not torch.allclose(i, o): + print( + "Failed: inputs and outputs do not match in binary", + binary_name, + ) + print(i - o) + else: + print("Passed:", binary_name) # save artifacts if arg_save_artifacts: