Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AOT tools #10

Merged
merged 51 commits into from
Nov 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
4f63c14
added tools to share with MR
Dec 1, 2022
980cb65
added operations checker.py
Dec 1, 2022
3879736
remove unnecessary file
Dec 1, 2022
06dc83b
clean up of methods that should be in another class
Dec 1, 2022
97b81cb
add test saved_model
Dec 2, 2022
2fe21d6
added tensorboard function
Dec 5, 2022
c60cc20
added scripts dummies
Dec 5, 2022
8b321a1
clean up repo
Dec 5, 2022
2686e89
more cleaning
Dec 5, 2022
3c6bb08
Added tests and markdown tables for TF2.6.4 and TF1.X
Dec 16, 2022
ea183ef
create tools and tests
Apr 11, 2023
476214a
delete refactor directory structure
Apr 11, 2023
9485cb7
Removed many function. This was possible by found a way using signatu…
May 12, 2023
885beaf
fixed path dst path bug
May 12, 2023
6a5f04f
finished unit tests for convert_model.py
May 12, 2023
3070307
Add more describtion to the functions, also refactor a bit
May 16, 2023
6c128d6
changed graph to concrete function, to be more clear. removed debugger
May 16, 2023
b2cbd3d
made out of test a proper module
May 16, 2023
56f23ba
Finished writing tests
Jun 1, 2023
45b90b8
added test for aot_compatility
Jun 1, 2023
dfe48a7
Start adjustments.
riga Sep 1, 2023
833d6d7
Start refactoring.
riga Sep 4, 2023
e2f9859
Merge pull request #1 from riga/AOT_tools_refactor
Bogdan-Wiederspan Sep 25, 2023
cb4b5b8
Add docstrings on all functions
riga Sep 26, 2023
f95959b
Rename save and load grpah functions.
riga Oct 27, 2023
20f3f6e
Linting.
riga Oct 27, 2023
a62238c
moved functions to handle the loading of Saved_models, and extraction…
riga Nov 2, 2023
51695c3
added tests for 'load_model' and 'load_graph_def'
riga Nov 2, 2023
98b6b56
Added tests for functions regarding aot-utility and aot-compilation s…
riga Nov 2, 2023
1ac967a
moved functions to handle the loading of Saved_models, and extraction…
Bogdan-Wiederspan Nov 2, 2023
001382b
added tests for 'load_model' and 'load_graph_def'
Bogdan-Wiederspan Nov 2, 2023
281cdf2
Added tests for functions regarding aot-utility and aot-compilation s…
Bogdan-Wiederspan Nov 2, 2023
253ce63
Merge branch 'AOT_tools' of github.com:Bogdan-Wiederspan/cmsml into A…
Bogdan-Wiederspan Nov 3, 2023
fa52e71
Merged cmsml/aot_helper into forked branch of AOT helper tools.
Bogdan-Wiederspan Nov 9, 2023
4572709
fixed usage of contextmanager in test_load_graph_def
Bogdan-Wiederspan Nov 9, 2023
730d407
fix bug when extracting nodes from graphdef of a concretefunction. Al…
Bogdan-Wiederspan Nov 10, 2023
6153ce1
Fixed unit tests for general tensorflow tools
Bogdan-Wiederspan Nov 10, 2023
aa2bc22
fixed test for aot compiltation and fixed most linting issues
Bogdan-Wiederspan Nov 13, 2023
1ae9684
Merge pull request #9 from Bogdan-Wiederspan/AOT_tools
riga Nov 15, 2023
795abab
Skip some tests if skip_if_no_tf2xla_supported_ops not available.
riga Nov 15, 2023
e1aa4e0
Adjust tests.
riga Nov 15, 2023
cc3cdf8
Fix lazy loader tests.
riga Nov 15, 2023
63ebe84
Polish tests.
riga Nov 15, 2023
6652e1f
Move to pytest.
riga Nov 15, 2023
a8e1be5
Update intallation in images.
riga Nov 15, 2023
e643c1b
Update deps.
riga Nov 15, 2023
4ce04e1
Update docker files.
riga Nov 15, 2023
d5fd175
Merge branch 'master' into aot_helpers.
riga Nov 15, 2023
8156ae4
Update type hints.
riga Nov 15, 2023
f77652c
Merge branch 'master' into aot_helpers.
riga Nov 15, 2023
0b0d42e
Polish code.
riga Nov 15, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[flake8]

max-line-length = 101
max-line-length = 120

# codes of errors to ignore
ignore = E128, E306, E402, E722, E731, E741, W504, Q003
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/lint_and_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout 🛎️
uses: actions/checkout@v3
uses: actions/checkout@v4
with:
persist-credentials: false

Expand Down Expand Up @@ -40,12 +40,12 @@ jobs:
name: test (image=${{ matrix.versions.tag }}, tf=${{ matrix.versions.tf }})
steps:
- name: Checkout 🛎️
uses: actions/checkout@v3
uses: actions/checkout@v4
with:
persist-credentials: false

- name: Pull docker image 🐳
run: docker pull cmsml/cmsml:${{ matrix.versions.tag }}

- name: Test 🎰
run: bash tests/docker.sh cmsml/cmsml:${{ matrix.versions.tag }} "[ '${{ matrix.versions.tf }}' = 'default' ] || pip install -U tensorflow=='${{ matrix.versions.tf }}'; python -m unittest tests"
run: bash tests/docker.sh cmsml/cmsml:${{ matrix.versions.tag }} "[ '${{ matrix.versions.tf }}' = 'default' ] || pip install -U tensorflow=='${{ matrix.versions.tf }}'; pytest -n 2 tests"
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ To use the cmsml package via docker, checkout our [DockerHub](https://hub.docker
The tests can be triggered with

```shell
python -m unittest tests
pytest -n auto tests
```

and in general, they should be run for Python 3.7 to 3.11.
Expand Down
5 changes: 5 additions & 0 deletions cmsml/scripts/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,6 @@
# coding: utf-8

__all__ = ["compile_tf_graph", "aot_compile"]

# provisioning imports
from cmsml.scripts.compile_tf_graph import compile_tf_graph, aot_compile
163 changes: 163 additions & 0 deletions cmsml/scripts/check_aot_compatibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
# coding: utf-8

"""
Script that provides insight on which TensorFlow operations are XLA / AOT compatible and whether a specified graph would
be supported.
"""

from __future__ import annotations

import tabulate

from cmsml.util import colored
from cmsml.tensorflow.aot import OpsData, load_graph_def, get_graph_ops


def check_aot_compatibility(
model_path: str,
serving_key: str = "serving_default",
devices: tuple[str] = ("cpu",),
table_format: str = "grid",
) -> None:
"""
Loads model stored in *model_path* and extracts the GraphDef saved under the specified *serving_key*. From this
GraphDef, all ops for specific *devices* are read and compared to all ops with XLA implementation. The matching
result is printed given the chosen *table_format* style.
"""
# open the graph
graph_def = load_graph_def(model_path, serving_key=serving_key)

# extract operation names
op_names = get_graph_ops(graph_def)

# remove trivial ops
op_names = [op_name for op_name in op_names if op_name not in ["Placeholder", "NoOp"]]

# print the op table
devices, ops = print_op_table(devices, filter_ops=op_names, table_format=table_format)

# print a final summary per device
for device in devices:
failed_ops = [
op_name
for op_name in op_names
if not ops.get(op_name, {}).get(device)
]

msg = f"\n{colored(device, 'magenta')}: "
if failed_ops:
msg += colored("not compatible", "red")
msg += f", {len(failed_ops)} incompatible ops: {', '.join(failed_ops)}"
else:
msg += colored("all ops compatible", "green")
print(msg)


def print_op_table(
devices: tuple[str],
filter_ops: list[str] | None = None,
table_format: str = "grid",
) -> tuple[list[str], OpsData]:
"""
Reads all ops for specific *devices* and prints a table given *table_format* style. Specific ops can be filtered
using *filter_ops*.
"""
# read ops
ops = OpsData(devices)

# get parsed devices
devices = [
device
for device in ops.device_ids
if any(
op_data.get(device)
for op_name, op_data in ops.items()
if not filter_ops or op_name in filter_ops
)
]
devices = sorted(set(devices), key=devices.index)

# prepare the table
headers = ["Operation"] + devices
content = []
str_flag = lambda b: "yes" if b else "NO"
for op_name, op_data in ops.items():
if filter_ops and op_name not in filter_ops:
continue

content.append([
op_name,
*(str_flag(bool(op_data.get(device))) for device in devices),
])

# print it
print(tabulate.tabulate(content, headers=headers, tablefmt=table_format))

return devices, ops


def main() -> None:
import os
import sys
from argparse import ArgumentParser

parser = ArgumentParser(
prog=f"cmsml_{os.path.splitext(os.path.basename(__file__))[0]}",
description="performs XLA / AOT compatiblity checks on a TensorFlow graph",
)

parser.add_argument(
"model_path",
nargs="?",
help="the path of the model to open",
)
parser.add_argument(
"--serving-key",
"-k",
default="serving_default",
help="serving key of the graph in --model-path; default: serving_default",
)
parser.add_argument(
"--table",
"-t",
action="store_true",
help="just print a table showing which operations are XLA / AOT supported for --devices",
)
parser.add_argument(
"--table-format",
"-f",
default="grid",
help="the tabulate format for printed tables; default: grid",
)
parser.add_argument(
"--devices",
"-d",
type=(lambda s: tuple(s.strip().split(","))),
help="comma separated list of devices to check; choices: cpu,gpu,tpu, default: cpu",
)

args = parser.parse_args()

if args.table:
# print the op table
print_op_table(
devices=args.devices,
table_format=args.table_format,
)

elif args.model_path:
# run the compatibility check
check_aot_compatibility(
model_path=args.model_path,
serving_key=args.serving_key,
devices=args.devices,
table_format=args.table_format,
)

else:
print("either '--model-path PATH' or '--table' must be set", file=sys.stderr)
sys.exit(1)


if __name__ == "__main__":
main()
Loading