-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
8 changed files
with
203 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
import subprocess | ||
import os | ||
import tempfile | ||
|
||
# TODO(odjuricic) Cleaner to implement ttrt --quiet flag. | ||
os.environ["TTRT_LOGGER_LEVEL"] = "ERROR" | ||
from ttrt import API as ttrt | ||
import ttmlir.passes | ||
from . import utils | ||
import pandas as pd | ||
|
||
|
||
class ModelRunner:
    """
    ModelRunner is a singleton class used for compilation and running of models. Ensuring only one can be run at a time.
    This is necessary because the adaptor class is reinitialized on every request from the frontend, so it cannot keep state.
    """

    # Shared instance; stays None until initialize() has completed successfully.
    _instance = None
    # Directory holding ttrt artifacts (system descriptor, perf results); set in initialize().
    _explorer_artifacts_dir = None

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, creating and initializing it on first use.

        Positional and keyword arguments are accepted for backward compatibility
        but are not used.
        """
        if cls._instance is None:
            print("Creating a new ModelRunner instance.")
            # object.__new__ raises TypeError when given extra arguments, so
            # they must not be forwarded.
            instance = super().__new__(cls)
            # Publish the instance only after initialize() succeeds; otherwise a
            # failed initialization would leave a half-configured singleton that
            # every later call silently returns.
            instance.initialize()
            cls._instance = instance
        return cls._instance

    def initialize(self):
        """Initialize ttrt, create the artifacts directory and save the system descriptor.

        Raises:
            RuntimeError: If the ``TT_MLIR_HOME`` environment variable is not set.
        """
        # Initialize machine to generate SystemDesc and load up functionality to begin
        print("Running ttrt initialization.")
        ttrt.initialize_apis()

        if "TT_MLIR_HOME" not in os.environ:
            raise RuntimeError("TT_MLIR_HOME not set. Did you run source env/activate?")

        # TODO(odjuricic, #1200) ttrt perf breaks if artifacts dir is changed from default.
        # self._explorer_artifacts_dir = os.environ['TT_MLIR_HOME'] + '/explorer-artifacts'
        self._explorer_artifacts_dir = os.environ["TT_MLIR_HOME"] + "/ttrt-artifacts"
        os.makedirs(self._explorer_artifacts_dir, exist_ok=True)

        # Save the system descriptor.
        ttrt.Query(
            args={
                "--save-artifacts": True,
                "--artifact-dir": self._explorer_artifacts_dir,
            }
        )()

    def run(self, model_path):
        """Compile the model at ``model_path`` to TTNN, run ``ttrt perf`` on it and print the results.

        The MLIR module is parsed, lowered with the TTIR to TTNN backend
        pipeline, serialized to the relative path ``<model name>.ttnn`` and then
        executed with ``ttrt perf``. Selected columns of the generated
        ``ops_perf_results.csv`` are printed together with the summed device
        firmware duration.

        Args:
            model_path: Path to the MLIR file to compile and run.

        Raises:
            RuntimeError: If ``ttrt perf`` cannot be started or exits with a
                non-zero return code.
            FileNotFoundError: If the perf results CSV is missing after the run.
        """
        # Local import: only needed for removing artifacts of a previous run.
        import shutil

        # TODO(odjuricic, #1174) This should be in a separate thread later.
        model_name = os.path.basename(model_path).split(".")[0]

        ttir_to_ttnn_options = " ".join(
            [
                f"system-desc-path={self._explorer_artifacts_dir}/system_desc.ttsys",
                "enable-optimizer=true",
                "memory-layout-analysis-enabled=true",
            ]
        )

        module = utils.parse_mlir_file(model_path)

        try:
            print("Running MLIR compile: TTIR to TTNN Backend Pipeline")
            print("With options: ", ttir_to_ttnn_options)
            # TODO(odjuricic) When we hit compiler assert it terminates the process. We should catch this and return an error to the frontend.
            ttmlir.passes.ttir_to_ttnn_backend_pipeline(module, ttir_to_ttnn_options)
        except Exception:
            print("Error running MLIR compile: TTIR to TTNN Backend Pipeline")
            raise

        # TODO(odjuricic) Move this file somewhere else, but keep the name.
        flatbuffer_file = model_name + ".ttnn"
        try:
            print("Running TTNN to Flatbuffer File")
            ttmlir.passes.ttnn_to_flatbuffer_file(module, flatbuffer_file, {})
        except Exception:
            print("Error running TTNN to Flatbuffer File")
            raise

        # TODO(odjuricic) validate that the module was converted to TTNN without fail

        previous_artifacts = f"{self._explorer_artifacts_dir}/{flatbuffer_file}"
        if os.path.exists(previous_artifacts):
            print("Removing artifacts of previous run.")
            # Remove without going through a shell so that unusual characters in
            # the model name cannot be interpreted as shell syntax. A failure is
            # raised instead of ignored: stale results must not be reported as
            # the outcome of this run.
            if os.path.isdir(previous_artifacts) and not os.path.islink(previous_artifacts):
                shutil.rmtree(previous_artifacts)
            else:
                os.remove(previous_artifacts)

        # Argument list (no shell) keeps paths with spaces or metacharacters intact.
        ttrt_perf_command = [
            "ttrt",
            "perf",
            flatbuffer_file,
            f"--artifact-dir={self._explorer_artifacts_dir}",
            "--save-artifacts",
        ]

        print("Running", " ".join(ttrt_perf_command))
        try:
            # The context manager closes the pipe and waits for the process.
            with subprocess.Popen(
                ttrt_perf_command,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
            ) as process:
                # Stream the combined stdout/stderr of ttrt as it is produced.
                for line in process.stdout:
                    print(line, end="")
        except OSError as err:
            # Keep the exception type callers saw when the shell reported a
            # launch failure through a non-zero exit code.
            print(f"Error: TTRT process could not be started: {err}")
            raise RuntimeError("Error running TTRT") from err

        if process.returncode != 0:
            print(f"Error: TTRT process exited with code {process.returncode}")
            raise RuntimeError("Error running TTRT")

        op_perf_file = f"{self._explorer_artifacts_dir}/{flatbuffer_file}/perf/ops_perf_results.csv"
        if not os.path.exists(op_perf_file):
            raise FileNotFoundError(f"Performance file {op_perf_file} not found.")
        perf = pd.read_csv(op_perf_file)
        columns = [
            "GLOBAL CALL COUNT",
            "OP CODE",
            "DEVICE FW DURATION [ns]",
            "CORE COUNT",
            "OUTPUT_0_MEMORY",
        ]
        perf = perf[columns]
        print(perf)

        print("Total device duration: ", perf["DEVICE FW DURATION [ns]"].sum(), "ns")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
import ttmlir | ||
|
||
|
||
def parse_mlir_file(model_path):
    """Parse the MLIR text file at ``model_path`` and return the resulting module.

    A fresh ``ttmlir.ir.Context`` is created and the ``ttkernel``, ``ttir`` and
    ``tt`` dialects are registered on it before the file contents are parsed.

    Args:
        model_path: Path to the MLIR file to read.

    Returns:
        The value returned by ``ttmlir.ir.Module.parse`` for the file contents.
    """
    with ttmlir.ir.Context() as context:
        with open(model_path, "r") as mlir_source:
            # Dialects are registered on the context before parsing.
            ttmlir.dialects.ttkernel.register_dialect(context)
            ttmlir.dialects.ttir.register_dialect(context)
            ttmlir.dialects.tt.register_dialect(context)
            mlir_text = mlir_source.read()
            return ttmlir.ir.Module.parse(mlir_text, context)