Skip to content

Commit

Permalink
Update aot tool tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
riga committed Mar 27, 2024
1 parent e07f8bd commit 7c477f1
Showing 1 changed file with 23 additions and 53 deletions.
76 changes: 23 additions & 53 deletions PhysicsTools/TensorFlowAOT/test/testAOTTools.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""

import os
import sys
import re
import shlex
import subprocess
import tempfile
Expand Down Expand Up @@ -33,66 +33,36 @@ def wrapper(self):
class TFAOTTests(unittest.TestCase):

@run_in_tmp
def test_compilation(self, tmp_dir):
# create the test model
cmd = [
sys.executable,
"-W", "ignore",
os.path.join(this_dir, "create_model.py"),
"-d", os.path.join(tmp_dir, "testmodel"),
]
run_cmd(cmd)

# compile it
cmd = [
"PYTHONWARNINGS=ignore",
"cmsml_compile_tf_graph",
os.path.join(tmp_dir, "testmodel"),
os.path.join(tmp_dir, "testmodel_compiled"),
"-c", r"testmodel_bs{}", r"testmodel_bs{}",
"-b", "1,2",
]
run_cmd(cmd)
def test_dev_workflow(self, tmp_dir):
import cms_tfaot

# check files
exists = lambda *p: os.path.exists(os.path.join(tmp_dir, "testmodel_compiled", "aot", *p))
self.assertTrue(exists("testmodel_bs1.h"))
self.assertTrue(exists("testmodel_bs1.o"))
self.assertTrue(exists("testmodel_bs2.h"))
self.assertTrue(exists("testmodel_bs2.o"))
# find the cms_tfaot install dir to locate the test model
m = re.match(r"(.+/\d+\.\d+\.\d+\-[^/]+)/lib/.+$", cms_tfaot.__file__)
self.assertIsNotNone(m)
config_file = os.path.join(m.group(1), "share", "test_models", "simple", "aot_config.yaml")
self.assertTrue(os.path.exists(config_file))

@run_in_tmp
def test_dev_workflow(self, tmp_dir):
# run the dev workflow
# create the test model
cmd = [
sys.executable,
"-W", "ignore",
os.path.join(this_dir, "create_model.py"),
"-d", os.path.join(tmp_dir, "testmodel"),
]
run_cmd(cmd)

# compile it
cmd = [
sys.executable,
"-W", "ignore",
os.path.normpath(os.path.join(this_dir, "..", "scripts", "compile_model.py")),
"-m", os.path.join(tmp_dir, "testmodel"),
"-s", "PhysicsTools",
"-p", "TensorFlowAOT",
"-b", "1,2",
"-o", os.path.join(tmp_dir, "testmodel_compiled"),
"cms_tfaot_compile",
"-c", config_file,
"-o", tmp_dir,
"--tool-name", "tfaot-model-test",
"--dev",
]
run_cmd(cmd)

# check files
exists = lambda *p: os.path.exists(os.path.join(tmp_dir, "testmodel_compiled", *p))
self.assertTrue(exists("tfaot-dev-physicstools-tensorflowaot-testmodel.xml"))
self.assertTrue(exists("include", "testmodel.h"))
self.assertTrue(exists("include", "testmodel_bs1.h"))
self.assertTrue(exists("include", "testmodel_bs2.h"))
self.assertTrue(exists("lib", "testmodel_bs1.o"))
self.assertTrue(exists("lib", "testmodel_bs2.o"))
exists = lambda *p: os.path.exists(os.path.join(tmp_dir, *p))
self.assertTrue(exists("tfaot-model-test.xml"))
self.assertTrue(exists("include", "tfaot-model-test"))
self.assertTrue(exists("include", "tfaot-model-test", "test_simple_bs1.h"))
self.assertTrue(exists("include", "tfaot-model-test", "test_simple_bs2.h"))
self.assertTrue(exists("include", "tfaot-model-test", "test_simple.h"))
self.assertTrue(exists("include", "tfaot-model-test", "model.h"))
self.assertTrue(exists("lib", "test_simple_bs1.o"))
self.assertTrue(exists("lib", "test_simple_bs2.o"))


if __name__ == "__main__":
Expand Down

0 comments on commit 7c477f1

Please sign in to comment.