diff --git a/.flake8 b/.flake8
index 7530bcb..12f7e57 100644
--- a/.flake8
+++ b/.flake8
@@ -1,6 +1,6 @@
 [flake8]
-max-line-length = 101
+max-line-length = 120
 
 # codes of errors to ignore
 ignore = E128, E306, E402, E722, E731, E741, W504, Q003
diff --git a/.github/workflows/lint_and_test.yml b/.github/workflows/lint_and_test.yml
index d86608d..5ec6aaa 100644
--- a/.github/workflows/lint_and_test.yml
+++ b/.github/workflows/lint_and_test.yml
@@ -9,7 +9,7 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Checkout 🛎️
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
        with:
           persist-credentials: false
 
@@ -40,7 +40,7 @@ jobs:
     name: test (image=${{ matrix.versions.tag }}, tf=${{ matrix.versions.tf }})
     steps:
       - name: Checkout 🛎️
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
        with:
           persist-credentials: false
 
@@ -48,4 +48,4 @@
         run: docker pull cmsml/cmsml:${{ matrix.versions.tag }}
 
       - name: Test 🎰
-        run: bash tests/docker.sh cmsml/cmsml:${{ matrix.versions.tag }} "[ '${{ matrix.versions.tf }}' = 'default' ] || pip install -U tensorflow=='${{ matrix.versions.tf }}'; python -m unittest tests"
+        run: bash tests/docker.sh cmsml/cmsml:${{ matrix.versions.tag }} "[ '${{ matrix.versions.tf }}' = 'default' ] || pip install -U tensorflow=='${{ matrix.versions.tf }}'; pytest -n 2 tests"
diff --git a/README.md b/README.md
index 62456f0..266a272 100644
--- a/README.md
+++ b/README.md
@@ -62,7 +62,7 @@ To use the cmsml package via docker, checkout our [DockerHub](https://hub.docker
 The tests can be triggered with
 
 ```shell
-python -m unittest tests
+pytest -n auto tests
 ```
 
 and in general, they should be run for Python 3.7 to 3.11.
diff --git a/cmsml/scripts/__init__.py b/cmsml/scripts/__init__.py
index 57d631c..eb7757e 100644
--- a/cmsml/scripts/__init__.py
+++ b/cmsml/scripts/__init__.py
@@ -1 +1,6 @@
 # coding: utf-8
+
+__all__ = ["compile_tf_graph", "aot_compile"]
+
+# provisioning imports
+from cmsml.scripts.compile_tf_graph import compile_tf_graph, aot_compile
diff --git a/cmsml/scripts/check_aot_compatibility.py b/cmsml/scripts/check_aot_compatibility.py
new file mode 100644
index 0000000..2783443
--- /dev/null
+++ b/cmsml/scripts/check_aot_compatibility.py
@@ -0,0 +1,163 @@
+# coding: utf-8
+
+"""
+Script that provides insight into which TensorFlow operations are XLA / AOT compatible, and whether a given graph
+would be supported.
+"""
+
+from __future__ import annotations
+
+import tabulate
+
+from cmsml.util import colored
+from cmsml.tensorflow.aot import OpsData, get_graph_ops
+from cmsml.tensorflow.tools import load_graph_def
+
+
+def check_aot_compatibility(
+    model_path: str,
+    serving_key: str = "serving_default",
+    devices: tuple[str] = ("cpu",),
+    table_format: str = "grid",
+) -> None:
+    """
+    Loads the model stored in *model_path* and extracts the GraphDef saved under the given *serving_key*. From this
+    GraphDef, all ops for the selected *devices* are read and compared to all ops with an XLA implementation. The
+    matching result is printed in the chosen *table_format* style.
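+
+    For illustration, a minimal invocation might look like this (the model path is hypothetical):
+
+    .. code-block:: python
+
+        check_aot_compatibility("/path/to/saved_model", devices=("cpu", "gpu"))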
+    """
+    # open the graph
+    graph_def = load_graph_def(model_path, serving_key=serving_key)
+
+    # extract operation names
+    op_names = get_graph_ops(graph_def)
+
+    # remove trivial ops
+    op_names = [op_name for op_name in op_names if op_name not in ["Placeholder", "NoOp"]]
+
+    # print the op table
+    devices, ops = print_op_table(devices, filter_ops=op_names, table_format=table_format)
+
+    # print a final summary per device
+    for device in devices:
+        failed_ops = [
+            op_name
+            for op_name in op_names
+            if not ops.get(op_name, {}).get(device)
+        ]
+
+        msg = f"\n{colored(device, 'magenta')}: "
+        if failed_ops:
+            msg += colored("not compatible", "red")
+            msg += f", {len(failed_ops)} incompatible ops: {', '.join(failed_ops)}"
+        else:
+            msg += colored("all ops compatible", "green")
+        print(msg)
+
+
+def print_op_table(
+    devices: tuple[str],
+    filter_ops: list[str] | None = None,
+    table_format: str = "grid",
+) -> tuple[list[str], OpsData]:
+    """
+    Reads all ops for the selected *devices* and prints a table in the given *table_format* style. Specific ops can
+    be filtered using *filter_ops*.
+    """
+    # read ops
+    ops = OpsData(devices)
+
+    # get parsed devices
+    devices = [
+        device
+        for device in ops.device_ids
+        if any(
+            op_data.get(device)
+            for op_name, op_data in ops.items()
+            if not filter_ops or op_name in filter_ops
+        )
+    ]
+    devices = sorted(set(devices), key=devices.index)
+
+    # prepare the table
+    headers = ["Operation"] + devices
+    content = []
+    str_flag = lambda b: "yes" if b else "NO"
+    for op_name, op_data in ops.items():
+        if filter_ops and op_name not in filter_ops:
+            continue
+
+        content.append([
+            op_name,
+            *(str_flag(bool(op_data.get(device))) for device in devices),
+        ])
+
+    # print it
+    print(tabulate.tabulate(content, headers=headers, tablefmt=table_format))
+
+    return devices, ops
+
+
+def main() -> None:
+    import os
+    import sys
+    from argparse import ArgumentParser
+
+    parser = ArgumentParser(
+        prog=f"cmsml_{os.path.splitext(os.path.basename(__file__))[0]}",
+        description="performs XLA / AOT compatibility checks on a TensorFlow graph",
+    )
+
+    parser.add_argument(
+        "model_path",
+        nargs="?",
+        help="the path of the model to open",
+    )
+    parser.add_argument(
+        "--serving-key",
+        "-k",
+        default="serving_default",
+        help="serving key of the graph in model_path; default: serving_default",
+    )
+    parser.add_argument(
+        "--table",
+        "-t",
+        action="store_true",
+        help="just print a table showing which operations are XLA / AOT supported for --devices",
+    )
+    parser.add_argument(
+        "--table-format",
+        "-f",
+        default="grid",
+        help="the tabulate format for printed tables; default: grid",
+    )
+    parser.add_argument(
+        "--devices",
+        "-d",
+        type=(lambda s: tuple(s.strip().split(","))),
+        default=("cpu",),
+        help="comma-separated list of devices to check; choices: cpu,gpu; default: cpu",
+    )
+
+    args = parser.parse_args()
+
+    if args.table:
+        # print the op table
+        print_op_table(
+            devices=args.devices,
+            table_format=args.table_format,
+        )
+
+    elif args.model_path:
+        # run the compatibility check
+        check_aot_compatibility(
+            model_path=args.model_path,
+            serving_key=args.serving_key,
+            devices=args.devices,
+            table_format=args.table_format,
+        )
+
+    else:
+        print("either 'model_path' or '--table' must be set", file=sys.stderr)
+        sys.exit(1)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/cmsml/scripts/compile_tf_graph.py b/cmsml/scripts/compile_tf_graph.py
new file mode 100644
index 0000000..6d990fb
--- /dev/null
+++ b/cmsml/scripts/compile_tf_graph.py
@@ -0,0 +1,227 @@
+# coding: utf-8
+
+"""
+Script that reads a TensorFlow graph from a model file and ahead-of-time compiles it for selected batch sizes using
+XLA.
+"""
+
+from __future__ import annotations
+
+import os
+
+from cmsml.util import colored, interruptable_popen
+from cmsml.tensorflow.tools import import_tf, load_model
+
+
+def compile_tf_graph(
+    model_path: str,
+    output_path: str,
+    batch_sizes: tuple[int] = (1,),
+    input_serving_key: str = "serving_default",
+    output_serving_key: str | None = None,
+    compile_prefix: str | None = None,
+    compile_class: str | None = None,
+) -> None:
+    """
+    AOT compilation requires a static memory layout at runtime, so this function prepares the given input SavedModel
+    to make it ready for AOT compilation.
+
+    It takes the subgraph saved under the *input_serving_key* signature within the SavedModel stored in *model_path*
+    and creates a ``ConcreteFunction`` with a static shape for the given *batch_sizes*. If no *input_serving_key* is
+    given, the TensorFlow default "serving_default" is used.
+
+    The resulting static ``ConcreteFunction`` is saved as a subgraph under a new *output_serving_key* signature in a
+    SavedModel stored at *output_path*. If no *output_serving_key* is given, the ``ConcreteFunction`` is saved with
+    the signature "{*input_serving_key*}_bs{*batch_size*}".
+
+    An optional AOT compilation is initiated if *compile_class* and *compile_prefix* are given. In this case,
+    *compile_prefix* is the file prefix, while *compile_class* is the name of the AOT class within the generated
+    files.
+    """
+    tf = import_tf()[0]
+
+    # default output serving key
+    if not output_serving_key:
+        output_serving_key = input_serving_key + "_bs{}"
+
+    # check compile values
+    if compile_prefix and not compile_class:
+        raise ValueError("when compile_prefix is set, compile_class must not be empty")
+    if compile_class and not compile_prefix:
+        raise ValueError("when compile_class is set, compile_prefix must not be empty")
+
+    # get the model object
+    model = load_model(model_path)
+
+    # get the tf function
+    func = model.signatures[input_serving_key]
+
+    # prepare the output directory
+    output_path = os.path.expandvars(os.path.expanduser(str(output_path)))
+    if os.path.isfile(output_path):
+        raise OSError(f"output_path exists and points to a file: {output_path}")
+    if not os.path.exists(output_path):
+        os.makedirs(output_path)
+
+    # create concrete functions per batch size
+    c_funcs = {}
+    for bs in sorted(set(map(int, batch_sizes))):
+        # create a fully defined signature, filling leading None's in shapes with the batch size
+        specs = {}
+        for key, spec in func.structured_input_signature[1].items():
+            # skip inputs without undefined axes
+            if None not in spec.shape:
+                continue
+            # create the new shape and name
+            shape = [
+                (bs if n is None else n)
+                for n in spec.shape
+            ]
+            # ":" is the delimiter of the op numbering scheme
+            name = f"{spec.name.replace(':', '_')}_bs{bs}"
+            # store the new spec
+            specs[key] = type(spec)(type(spec.shape)(shape), dtype=spec.dtype, name=name)
+
+        # concrete function
+        c_funcs[output_serving_key.format(bs)] = tf.function(func).get_concrete_function(**specs)
+
+    # save concrete functions as signatures of the model
+    tf.saved_model.save(model, output_path, signatures=c_funcs)
+    print(f"saved model at '{colored(output_path, 'magenta')}'")
+
+    # optionally compile
+    if compile_prefix and compile_class:
+        aot_compile(
+            output_path,
+            os.path.join(output_path, "aot"),
+            compile_prefix,
+            compile_class,
+            batch_sizes=batch_sizes,
+            serving_key=output_serving_key,
+        )
+
+
+def aot_compile(
+    model_path: str,
+    output_path: str,
+    prefix: str,
+    class_name: str,
+    batch_sizes: tuple[int] = (1,),
+    serving_key: str = r"serving_default_bs{}",
+) -> None:
+    """
+    Loads the graph from the SavedModel located at *model_path*, extracts the static graph specified by *serving_key*
+    from it, and AOT compiles it.
+
+    This process generates header and object files at *output_path*. The *class_name* is used as the class name
+    within the header to access the AOT-compiled network.
+    """
+    # prepare the model path
+    model_path = os.path.abspath(os.path.expandvars(os.path.expanduser(str(model_path))))
+
+    # merge output path and prefix, and split them again with prefix being the basename
+    output_path = os.path.expandvars(os.path.expanduser(str(output_path)))
+    prefix = os.path.expandvars(os.path.expanduser(str(prefix)))
+    output_path, prefix = os.path.split(os.path.join(output_path, prefix))
+
+    # prepare the output directory
+    if os.path.isfile(output_path):
+        raise OSError(f"output_path exists and points to a file: {output_path}")
+    if not os.path.exists(output_path):
+        os.makedirs(output_path)
+
+    # get the compilation executable
+    exe = _which_saved_model_cli()
+
+    # compile for each batch size
+    for bs in sorted(set(map(int, batch_sizes))):
+        cmd = (
+            f"{exe} aot_compile_cpu"
+            f" --dir {model_path}"
+            f" --signature_def_key {serving_key.format(bs)}"
+            f" --output_prefix {prefix.format(bs)}"
+            f" --cpp_class {class_name.format(bs)}"
+            " --tag_set serve"
+        )
+
+        print(f"compiling for batch size {colored(bs, 'magenta')}")
+        code = interruptable_popen(cmd, executable="/bin/bash", shell=True, cwd=output_path)[0]
+        if code != 0:
+            raise Exception(f"aot compilation using {exe} failed with exit code {code}")
+
+
+def _which_saved_model_cli() -> str:
+    """
+    Determines the ``saved_model_cli`` executable that is used for the AOT compilation.
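+
+    The lookup can be overridden by setting the ``CMSML_SAVED_MODEL_CLI`` environment variable, e.g. (the path below
+    is illustrative):
+
+    .. code-block:: python
+
+        os.environ["CMSML_SAVED_MODEL_CLI"] = "/opt/tensorflow/bin/saved_model_cli"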
+ """ + # prefer executable set by CMSML_SAVED_MODEL_CLI + exe = os.getenv("CMSML_SAVED_MODEL_CLI") + if exe: + return exe + + # try usual candidates + for exe in ["saved_model_cli", "saved_model_cli3"]: + cmd = f"type {exe} &> /dev/null" + code = interruptable_popen(cmd, executable="/bin/bash", shell=True)[0] + if code == 0: + return exe + + # return default and let subsequent tools potentially fail + return "saved_model_cli" + + +def main() -> None: + from argparse import ArgumentParser + + parser = ArgumentParser( + prog=f"cmsml_{os.path.splitext(os.path.basename(__file__))[0]}", + description="ahead-of-time (AOT) compiles TensorFlow graphs for fixed batch sizes with XLA", + ) + + parser.add_argument( + "model_path", + help="the path of the model to open", + ) + parser.add_argument( + "output_path", + help="the path where compiled models should be stored", + ) + parser.add_argument( + "--batch-sizes", + "-b", + default=(1,), + type=(lambda s: tuple(map(int, s.strip().split(",")))), + help="comma-separated list of batch sizes to convert the model for; default: 1", + ) + parser.add_argument( + "--input-serving-key", + default="serving_default", + help="serving key of the model in --src; default: serving_default", + ) + parser.add_argument( + "--output-serving-key", + help=r"serving key pattern for concrete models in --output-path, with {} being replaced by " + r"the batch size; default: __bs{}", + ) + parser.add_argument( + "--compile", + "-c", + nargs=2, + help=r"file name prefix and class name of the AOT compiled objects; in both values, {} is " + "replaced by the batch size; no AOT compilation is triggered when empty; files will be " + "saved at /aot/{.h,.o,_metadata.o,_makefile.inc}", + ) + + args = parser.parse_args() + + compile_tf_graph( + model_path=args.model_path, + output_path=args.output_path, + batch_sizes=args.batch_sizes, + input_serving_key=args.input_serving_key, + output_serving_key=args.output_serving_key, + compile_prefix=args.compile and args.compile[0], + compile_class=args.compile and args.compile[1], + ) + + +if __name__ == "__main__": + main() diff --git a/cmsml/scripts/open_tf_graph.py b/cmsml/scripts/open_tf_graph.py index a5dedcc..5b4febf 100644 --- a/cmsml/scripts/open_tf_graph.py +++ b/cmsml/scripts/open_tf_graph.py @@ -24,7 +24,7 @@ def main(): parser.add_argument( "graph_path", - help="the path to the graph to open", + help="the path of the graph to open", ) parser.add_argument( "--log-dir", diff --git a/cmsml/tensorflow/__init__.py b/cmsml/tensorflow/__init__.py index 9b0dbfb..9042869 100644 --- a/cmsml/tensorflow/__init__.py +++ b/cmsml/tensorflow/__init__.py @@ -5,8 +5,15 @@ Classes, functions and tools for efficiently working with TensorFlow. """ -__all__ = ["import_tf", "save_graph", "load_graph", "write_graph_summary"] +__all__ = ["import_tf", "save_frozen_graph", "load_frozen_graph", "write_graph_summary", +"load_model", "load_graph_def","OpsData", "get_graph_ops"] # provisioning imports -from cmsml.tensorflow.tools import import_tf, save_graph, load_graph, write_graph_summary +from cmsml.tensorflow.tools import ( + import_tf, save_frozen_graph, load_frozen_graph, write_graph_summary, load_model, load_graph_def +) + +from cmsml.tensorflow.aot import ( + OpsData, get_graph_ops +) diff --git a/cmsml/tensorflow/aot.py b/cmsml/tensorflow/aot.py new file mode 100644 index 0000000..5f97e7c --- /dev/null +++ b/cmsml/tensorflow/aot.py @@ -0,0 +1,236 @@ +# coding: utf-8 + +""" +Tools and objects for working with AOT / XLA. 
+""" + +from __future__ import annotations + +import sys +import re +from subprocess import PIPE + +from cmsml.util import interruptable_popen +from cmsml.tensorflow.tools import import_tf + +tf = import_tf()[0] + +from tensorflow.core.framework.graph_pb2 import GraphDef + + +class OpsData(object): + """ + AOT needs two requirements to work: + 1) the outcome of an ops-kernel needs to be deterministic + 2) the ops-kernel needs to have an XLA implementation. + + Tensorflow can return a markdown table containing all XLA compatible ops. + This class is a wrapper to create this table and consequently read it. + """ + + device_ids = { + "cpu": "XLA_CPU_JIT", + "gpu": "XLA_GPU_JIT", + } + + def __init__(self: OpsData, devices: tuple[str] | None = None) -> None: + """ + Sets an iterable of *devices* for which the XLA operations table should be generate. + """ + super().__init__() + + # store operation data in a nested dict + self._ops = {} + + # determine ops + if not devices: + devices = () + elif not isinstance(devices, (list, tuple, set)): + devices = (devices,) + self._determine_ops(devices) + + @classmethod + def _assert_device_supported(cls, device: str) -> None: + if device not in cls.device_ids: + raise ValueError( + f"{device} not in supported devices {list(cls.device_ids.keys())}", + ) + + @classmethod + def read_ops_table( + cls, + device: str = "cpu", + ) -> str: + """ + Generate a markdown table for *device* and returns it. + """ + cls._assert_device_supported(device) + + # tf2xla_supported_ops prints the table + # catch the stdout put stream and decode into str + cmd = f"tf2xla_supported_ops --device={cls.device_ids[device]}" + code, out, _ = interruptable_popen(cmd, stdout=PIPE, executable="/bin/bash", shell=True) + if code != 0: + raise Exception(f"tf2xla_supported_ops command failed with exit code {code}") + + return out + + @classmethod + def parse_ops_table( + cls, + table: str | None = None, + *, + device: str = "cpu", + ) -> dict[str, dict]: + """ + Read a given markdown-*table* generated with 'tf2xla_supported_ops' and returns a dictionary contaning all ops + with XLA implementation. For a given table the *device* information is ignored and extracted from the table. If + no table is given one will be generate for given *device*. + """ + cls._assert_device_supported(device) + + # create the table if empty + if not table: + table = cls.read_ops_table(device) + else: + with open(table, "r") as txt_file: + table = txt_file.read() + + # split into lines + lines = table.splitlines() + + # first line contains device information + for device, device_id in cls.device_ids.items(): + if device_id in lines[0]: + break + else: + raise ValueError(f"no device string found in table header '{lines[0]}'") + + # read op infos from table lines + ops = {} + content_started = False + cre = re.compile(r"^\`([^\`]+)\`\s+\|\s*(.*)$") + for line in lines[1:]: + line = line.strip() + + # find the beginning of the table + if not content_started: + if line.startswith("---"): + content_started = True + continue + + # check if the end is reached + if not line: + break + + # parse the line + m = cre.match(line) + if not m: + print(f"error parsing table line: {line}", file=sys.stderr) + continue + + op_name, allowed_types = m.groups() + allowed_types = allowed_types.replace("`", "").replace("
", "") + + # save op data + ops[op_name] = { + "name": op_name, + "device": device, + "allowed_types": allowed_types, + } + + return ops + + def _determine_ops(self: OpsData, devices: tuple[str] | None = None) -> None: + """ + Merges multiple tables of different devices into 1 dictionary. + + WARNING: Since its not possible to see from which version the markdown table is generated, try to not mix tables + from different tensorflow versions. + """ + if not devices: + devices = tuple(self.device_ids.keys()) + + # read op dictionaries + all_op_dicts = [ + self.parse_ops_table(device=device) + for device in devices + ] + + # merge + ops = {} + for op_dicts in all_op_dicts: + for op_data in op_dicts.values(): + op_name = op_data["name"] + if op_name not in ops: + ops[op_name] = {} + ops[op_name][op_data["device"]] = op_data["allowed_types"] + + self._ops = ops + + def _get_unique_ops(self: OpsData, device: str | None = None) -> set[str]: + self._assert_device_supported(device) + + return { + op_name + for op_name, op_data in self._ops.items() + if device is None or op_data.get(device) + } + + @property + def cpu_ops(self: OpsData) -> set[str]: + # get unique XLA compatible results for CPU only + return self._get_unique_ops("cpu") + + @property + def gpu_ops(self: OpsData) -> set[str]: + # get unique XLA compatible results for GPU only + return self._get_unique_ops("gpu") + + @property + def ops(self: OpsData) -> set[str]: + # get unique ops that have CPU or GPU implementation + return self._ops + + def __len__(self: OpsData) -> int: + # number of ops + return len(self._ops) + + def __getitem__(self: OpsData, key: str) -> dict: + return self._ops[key] + + def keys(self: OpsData) -> list[str]: + return list(self._ops.keys()) + + def values(self: OpsData) -> list[dict]: + return list(self._ops.values()) + + def items(self: OpsData) -> list[tuple[str, dict]]: + return list(self._ops.items()) + + def get(self: OpsData, *args, **kwargs) -> tuple[str, dict]: + return self._ops.get(*args, **kwargs) + + +def get_graph_ops(graph_def: GraphDef, node_def_number: int = 0) -> list[str]: + """ + Extracts all ops from a *graph_def* and returns them as a list. + If there are multiple ``FunctionDef`` instances in the graph, set *node_def_number* to specify from which GraphDef + the ops should be extracted. + """ + # extract node definition from the graph "library for savedmodels" + num_funcs = len(graph_def.library.function) + # library is empty for graph.pb, but not for SavedModels + if num_funcs == 0: + node_def = graph_def.node + else: + if node_def_number + 1 > num_funcs: + raise AttributeError( + f"node_def_number {node_def_number} does not match amount of {num_funcs} " + "FunctionDef objects in graph", + ) + node_def = graph_def.library.function[node_def_number].node_def + + op_names = [node.op for node in node_def] + + return sorted(set(op_names), key=op_names.index) diff --git a/cmsml/tensorflow/tools.py b/cmsml/tensorflow/tools.py index 8cfd331..08b08a3 100644 --- a/cmsml/tensorflow/tools.py +++ b/cmsml/tensorflow/tools.py @@ -9,8 +9,10 @@ __all__ = [] import os +import warnings from types import ModuleType from typing import Any +from tensorflow.core.framework.graph_pb2 import GraphDef from cmsml.util import MockModule @@ -76,6 +78,31 @@ def save_graph( output_names: list[str] | None = None, *args, **kwargs, +) -> None: + """ + Deprecated. Please use :py:func:`save_frozen_graph`. 
+ """ + warnings.warn( + "save_graph() is deprecated, please use save_frozen_graph() instead", + DeprecationWarning, + ) + return save_frozen_graph( + path, + obj, + variables_to_constants=variables_to_constants, + output_names=output_names, + *args, + **kwargs, + ) + + +def save_frozen_graph( + path: str, + obj: Any, + variables_to_constants: bool = False, + output_names: list[str] | None = None, + *args, + **kwargs, ) -> None: """ Extracts a TensorFlow graph from an object *obj* and saves it at *path*. The graph is optionally @@ -179,6 +206,27 @@ def load_graph( create_session: bool | None = None, session_kwargs: dict | None = None, as_text: bool | None = None, +) -> tf.Graph | tuple[tf.Graph, tf.Session]: + """ + Deprecated. Please use :py:func:`load_frozen_graph`. + """ + warnings.warn( + "load_graph() is deprecated, please use load_frozen_graph() instead", + DeprecationWarning, + ) + return load_frozen_graph( + path=path, + create_session=create_session, + session_kwargs=session_kwargs, + as_text=as_text, + ) + + +def load_frozen_graph( + path: str, + create_session: bool | None = None, + session_kwargs: dict | None = None, + as_text: bool | None = None, ) -> tf.Graph | tuple[tf.Graph, tf.Session]: """ Reads a saved TensorFlow graph from *path* and returns it. When *create_session* is *True*, @@ -192,9 +240,9 @@ def load_graph( .. code-block:: python - graph = load_graph("path/to/model.pb", create_session=False) + graph = load_frozen_graph("path/to/model.pb", create_session=False) - graph, session = load_graph("path/to/model.pb", create_session=True) + graph, session = load_frozen_graph("path/to/model.pb", create_session=True) """ tf, tf1, tf_version = import_tf() path = os.path.expandvars(os.path.expanduser(str(path))) @@ -242,6 +290,61 @@ def load_graph( return graph +def load_graph_def( + model_path: str, + serving_key: str = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY, +) -> GraphDef: + """ + Loads the model saved at *model_path* and returns the GraphDef of it. Supported input types are tensorflow and keras + SavedModels, as well as frozen graphs. + """ + tf, tf1, tf_version = import_tf() + + model_path = os.path.expandvars(os.path.expanduser(str(model_path))) + + # if model_path is directory try load as saved model + if os.path.isdir(model_path) and tf.saved_model.contains_saved_model(model_path): + # if keras model try to load as keras model + # else load as tensorflow saved model + loaded_saved_model = load_model(model_path) + + # extract graph + if serving_key not in loaded_saved_model.signatures: + raise KeyError( + f"no graph with serving key '{serving_key}' in model, " + f"existing keys: {', '.join(list(loaded_saved_model.signatures))}", + ) + # loaded_saved_model.signatures[serving_key].function_def.node_def + return loaded_saved_model.signatures[serving_key].graph.as_graph_def() + + # load as frozen graph + if os.path.splitext(model_path)[1] == ".pb": # pb.txt pbtxt?? TODO + with tf.io.gfile.GFile(str(model_path), "rb") as f: + graph_def = tf.compat.v1.GraphDef() + graph_def.ParseFromString(f.read()) + + return graph_def + + raise FileNotFoundError(f"{model_path} contains neither frozen graph nor SavedModel") + + +def load_model(model_path: str) -> tf.Model: + """ + Load and return the SavedModel stored at *model_path*. If the model was saved using keras it will be loaded using + keras SavedModel API, otherwise tensorflow's SavedModel API is used. 
+ """ + tf, tf1, tf_version = import_tf() + + model_path = os.path.expandvars(os.path.expanduser(str(model_path))) + + if os.path.isdir(model_path) and os.path.exists(os.path.join(model_path, "keras_metadata.pb")): + model = tf.keras.models.load_model(model_path) + else: + model = tf.saved_model.load(model_path) + + return model + + def write_graph_summary( graph: tf.Graph, summary_dir: str, @@ -251,7 +354,7 @@ def write_graph_summary( Writes the summary of a *graph* to a directory *summary_dir* using a ``tf.summary.FileWriter`` (v1) or ``tf.summary.create_file_writer`` (v2). This summary can be used later on to visualize the graph via tensorboard. *graph* can be either a graph object or a path to a protobuf file. In - the latter case, :py:func:`load_graph` is used and all *kwargs* are forwarded. + the latter case, :py:func:`load_frozen_graph` is used and all *kwargs* are forwarded. .. note:: When used with TensorFlow v1, eager mode must be disabled. @@ -263,7 +366,7 @@ def write_graph_summary( # read the graph when a string is passed if isinstance(graph, str): - graph = load_graph(graph, create_session=False, **kwargs) + graph = load_frozen_graph(graph, create_session=False, **kwargs) # further handling is version dependent tf, tf1, tf_version = import_tf() diff --git a/cmsml/util.py b/cmsml/util.py index a430892..e152834 100644 --- a/cmsml/util.py +++ b/cmsml/util.py @@ -11,10 +11,15 @@ ] import os +import time import shutil import tempfile import contextlib +import subprocess +import signal import importlib +import six + from collections.abc import MappingView from types import GeneratorType, ModuleType from typing import Any @@ -110,6 +115,102 @@ def tmp_dir(create=True, delete=True, **kwargs): shutil.rmtree(path) +_shell_colors = { + "default": 39, + "black": 30, + "red": 31, + "green": 32, + "yellow": 33, + "blue": 34, + "magenta": 35, + "cyan": 36, +} + + +def colored(s: str, color: str = "white") -> str: + """ + Returns a string *s* in a shell-colored representation. + """ + color_id = _shell_colors.get(color, 39) + return "\033[{}m{}\033[0m".format(color_id, s) + + +def interruptable_popen(*args, **kwargs): + """ interruptable_popen(*args, stdin_callback=None, stdin_delay=0, interrupt_callback=None, kill_timeout=None, **kwargs) # noqa + Shorthand to :py:class:`Popen` followed by :py:meth:`Popen.communicate` which can be interrupted + by *KeyboardInterrupt*. The return code, standard output and standard error are returned in a + 3-tuple. + + *stdin_callback* can be a function accepting no arguments and whose return value is passed to + ``communicate`` after a delay of *stdin_delay* to feed data input to the subprocess. + + *interrupt_callback* can be a function, accepting the process instance as an argument, that is + called immediately after a *KeyboardInterrupt* occurs. After that, a SIGTERM signal is send to + the subprocess to allow it to gracefully shutdown. + + When *kill_timeout* is set, and the process is still alive after that period (in seconds), a + SIGKILL signal is sent to force the process termination. + + All other *args* and *kwargs* are forwarded to the :py:class:`Popen` constructor. 
+ """ + # get kwargs not being passed to Popen + stdin_callback = kwargs.pop("stdin_callback", None) + stdin_delay = kwargs.pop("stdin_delay", 0) + interrupt_callback = kwargs.pop("interrupt_callback", None) + kill_timeout = kwargs.pop("kill_timeout", None) + + # start the subprocess in a new process group + kwargs["preexec_fn"] = os.setsid + p = subprocess.Popen(*args, **kwargs) + + # get stdin + stdin_data = None + if callable(stdin_callback): + if stdin_delay > 0: + time.sleep(stdin_delay) + stdin_data = stdin_callback() + if isinstance(stdin_data, six.string_types): + stdin_data = (stdin_data + "\n").encode("utf-8") + + # handle interrupts + try: + out, err = p.communicate(stdin_data) + except KeyboardInterrupt: + # allow the interrupt_callback to perform a custom process termination + if callable(interrupt_callback): + interrupt_callback(p) + + # when the process is still alive, send SIGTERM to gracefully terminate it + pgid = os.getpgid(p.pid) + if p.poll() is None: + os.killpg(pgid, signal.SIGTERM) + + # when a kill_timeout is set, and the process is still running after that period, + # send SIGKILL to force its termination + if kill_timeout is not None: + target_time = time.perf_counter() + kill_timeout + while target_time > time.perf_counter(): + time.sleep(0.05) + if p.poll() is not None: + # the process terminated, exit the loop + break + else: + # check the status again to avoid race conditions + if p.poll() is None: + os.killpg(pgid, signal.SIGKILL) + + # transparently reraise + raise + + # decode outputs + if out is not None: + out = out.decode("utf-8") + if err is not None: + err = err.decode("utf-8") + + return p.returncode, out, err + + class MockModule(object): """ Mockup object that resembles a module with arbitrarily deep structure such that, e.g., @@ -120,12 +221,13 @@ class MockModule(object): print(tf.Graph) # -> "" - will always succeed at declaration. + will always succeed at declaration time. .. py:attribute:: _name - type: str - The name of the mock module. + type: str + + The name of the mock module. """ def __init__(self, name: str): diff --git a/requirements_dev.txt b/requirements_dev.txt index f16e85e..61c3cc4 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -1,5 +1,5 @@ -flake8>=4.0,<6.0 -flake8-commas>=2.1 -flake8-quotes>=3.3 +flake8~=5.0 +flake8-commas~=2.1 +flake8-quotes~=3.3 pytest-cov>=3.0 pytest-xdist~=3.4.0 diff --git a/tests/__init__.py b/tests/__init__.py index 24bf311..0d80fad 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -52,3 +52,5 @@ def require_nvml(*args, **kwargs): from .test_util import * from .test_tensorflow import * from .test_keras_callbacks import * +from .test_aot import * +from .test_compile_tf_graph import * diff --git a/tests/docker.sh b/tests/docker.sh index 2020896..c391ff2 100755 --- a/tests/docker.sh +++ b/tests/docker.sh @@ -4,7 +4,7 @@ # Arguments: # 1. The docker image, defaults to "cmsml/cmsml". # 2. The test command. When just "i", an interactive bash is started instead of running the tests -# and exiting. Defaults to "python -m unittest tests". +# and exiting. Defaults to "pytest -n auto tests". action() { local this_file="$( [ ! 
-z "${ZSH_VERSION}" ] && echo "${(%):-%x}" || echo "${BASH_SOURCE[0]}" )" @@ -13,7 +13,7 @@ action() { local image="${1:-cmsml/cmsml}" local cmd="${@:2}" - cmd="${cmd:-python -m unittest tests}" + cmd="${cmd:-pytest -n auto tests}" # tty options local tty_opts="$( [ -t 0 ] && echo "-ti" || echo "-t" )" diff --git a/tests/test_aot.py b/tests/test_aot.py new file mode 100644 index 0000000..5e9abfd --- /dev/null +++ b/tests/test_aot.py @@ -0,0 +1,183 @@ +# coding: utf-8 + +from __future__ import annotations + +import os +import functools +import subprocess +os.environ["CUDA_VISIBLE_DEVICES"] = "" + +import cmsml +from cmsml.util import tmp_dir, tmp_file +# from cmsml.tensorflow.aot import get_graph_ops, OpsData + +from . import CMSMLTestCase + + +# check if the tf2xla_supported_ops command exists +p = subprocess.run("type tf2xla_supported_ops", shell=True) +HAS_TF2XLA_SUPPORTED_OPS = p.returncode == 0 + + +def skip_if_no_tf2xla_supported_ops(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if not HAS_TF2XLA_SUPPORTED_OPS: + print(f"skipping {func.__name__} because tf2xla_supported_ops is not available") + return + return func(*args, **kwargs) + return wrapper + + +class AOTTestCase(CMSMLTestCase): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._tf = None + self._tf1 = None + self._tf_version = None + + @property + def tf(self): + if self._tf is None: + self._tf, self._tf1, self._tf_version = cmsml.tensorflow.import_tf() + return self._tf + + @property + def tf1(self): + if self._tf1 is None: + self._tf, self._tf1, self._tf_version = cmsml.tensorflow.import_tf() + return self._tf1 + + @property + def tf_version(self): + if self._tf_version is None: + self._tf, self._tf1, self._tf_version = cmsml.tensorflow.import_tf() + return self._tf_version + + def create_graph_def(self, create="saved_model", **kwargs): + import cmsml.tensorflow.tools as cmsml_tools + + # helper function to create GraphDef from SavedModel or Graph + tf = self.tf + + model = tf.keras.Sequential() + model.add(tf.keras.layers.InputLayer(input_shape=(10,), dtype=tf.float32, name="input")) + model.add(tf.keras.layers.BatchNormalization(axis=1, renorm=True)) + model.add(tf.keras.layers.Dense(100, activation="tanh")) + model.add(tf.keras.layers.BatchNormalization(axis=1, renorm=True)) + model.add(tf.keras.layers.Dense(3, activation="softmax", name="output")) + + if create == "saved_model": + with tmp_dir(create=False) as keras_path, tmp_dir(create=False) as tf_path: + + tf.saved_model.save(model, tf_path) + model.save(keras_path, overwrite=True, include_optimizer=False) + + default_signature = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY + tf_graph_def = cmsml_tools.load_graph_def(tf_path, default_signature) + keras_graph_def = cmsml_tools.load_graph_def(keras_path, default_signature) + return tf_graph_def, keras_graph_def + + if create == "graph": + concrete_func = tf.function(model).get_concrete_function(tf.ones((2, 10))) + + with tmp_file(suffix=".pb") as pb_path: + cmsml_tools.save_graph(pb_path, concrete_func, variables_to_constants=False) + graph_graph_def = cmsml.tensorflow.load_graph_def( + pb_path, + tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY, + ) + return graph_graph_def + + self.assertTrue(False) + + @skip_if_no_tf2xla_supported_ops + def test_get_graph_ops_saved_model(self): + from cmsml.tensorflow.aot import get_graph_ops + + tf_graph_def, keras_graph_def = self.create_graph_def(create="saved_model") + + graph_ops = set(get_graph_ops(tf_graph_def, 
node_def_number=0)) + expected_ops = { + "AddV2", "BiasAdd", "Const", "Identity", "MatMul", "Mul", "NoOp", "Rsqrt", "Softmax", + "Sub", "Tanh", + } + io_ops = {"ReadVariableOp", "Placeholder"} + + ops_without_io = graph_ops - io_ops + self.assertSetEqual(ops_without_io, expected_ops) + + @skip_if_no_tf2xla_supported_ops + def test_get_graph_ops_graph(self): + from cmsml.tensorflow.aot import get_graph_ops + + concrete_function_graph_def = self.create_graph_def(create="graph") + graph_ops = set(get_graph_ops(concrete_function_graph_def, node_def_number=0)) + + expected_ops = { + "AddV2", "BiasAdd", "Const", "Identity", "MatMul", "Mul", "NoOp", "Rsqrt", "Softmax", + "Sub", "Tanh", + } + io_ops = {"ReadVariableOp", "Placeholder"} + + ops_without_io = graph_ops - io_ops + self.assertSetEqual(ops_without_io, expected_ops) + + +class OpsTestCase(CMSMLTestCase): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._tf = None + self._tf1 = None + self._tf_version = None + + @property + def tf(self): + if self._tf is None: + self._tf, self._tf1, self._tf_version = cmsml.tensorflow.import_tf() + return self._tf + + @property + def tf1(self): + if self._tf1 is None: + self._tf, self._tf1, self._tf_version = cmsml.tensorflow.import_tf() + return self._tf1 + + @property + def tf_version(self): + if self._tf_version is None: + self._tf, self._tf1, self._tf_version = cmsml.tensorflow.import_tf() + return self._tf_version + + @skip_if_no_tf2xla_supported_ops + def test_parse_ops_table(self): + from cmsml.tensorflow.aot import OpsData + + ops_dict = OpsData.parse_ops_table(device="cpu") + expected_ops = ("Abs", "Acosh", "Add", "Atan", "BatchMatMul", "Conv2D") + + # check if ops name and content exist + # since content changes with every version only naiv test is done + for op in expected_ops: + self.assertTrue(bool(ops_dict[op]["allowed_types"])) + + @skip_if_no_tf2xla_supported_ops + def test_determine_ops(self): + from cmsml.tensorflow.aot import OpsData + + # function to merge multiple tables + devices = ("cpu", "gpu") + + ops_data = OpsData(devices) + ops_data_ops = ops_data.ops + # for these ops cpu and gpu implentation are guaranteed + expected_ops = ("Abs", "Acosh", "Add", "Atan", "BatchMatMul", "Conv2D") + + # content for cpu and gpu should not be empty + for op in expected_ops: + for device in devices: + self.assertTrue(bool(ops_data_ops[op][device])) diff --git a/tests/test_compile_tf_graph.py b/tests/test_compile_tf_graph.py new file mode 100644 index 0000000..f912921 --- /dev/null +++ b/tests/test_compile_tf_graph.py @@ -0,0 +1,142 @@ +# coding: utf-8 + +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + +import cmsml +from cmsml.util import tmp_dir + +from . 
import CMSMLTestCase + + +class TfCompileTestCase(CMSMLTestCase): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._tf = None + self._tf1 = None + self._tf_version = None + + @property + def tf(self): + if self._tf is None: + self._tf, self._tf1, self._tf_version = cmsml.tensorflow.import_tf() + return self._tf + + @property + def tf1(self): + if self._tf1 is None: + self._tf, self._tf1, self._tf_version = cmsml.tensorflow.import_tf() + return self._tf1 + + @property + def tf_version(self): + if self._tf_version is None: + self._tf, self._tf1, self._tf_version = cmsml.tensorflow.import_tf() + return self._tf_version + + def create_test_model(self, tf): + # model with 2 input nodes and 1 output node, with non-static batchsize + x1 = tf.keras.Input(shape=(2,), name="first") + x2 = tf.keras.Input(shape=(3,), name="second") + x3 = tf.keras.Input(shape=(10,), name="third") + + x = tf.concat([x1, x2], axis=1) + a1 = tf.keras.layers.Dense(10, activation="elu")(x) + y = tf.keras.layers.Dense(5, activation="softmax")(a1) + + model = tf.keras.Model(inputs=(x1, x2, x3), outputs=y) + return model + + def test_compile_tf_graph_static_preparation(self): + from cmsml.scripts.compile_tf_graph import compile_tf_graph + + # check only preparation process for aot, but do not aot compile + tf = self.tf + + model = self.create_test_model(tf) + + with tmp_dir(create=False) as model_path, tmp_dir(create=False) as static_saved_model_path: + tf.saved_model.save(model, model_path) + + # throw error if compilation happens with illegal batch size + with self.assertRaises(ValueError): + compile_tf_graph( + model_path=model_path, + output_path=static_saved_model_path, + batch_sizes=[-1], + input_serving_key=tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY, + output_serving_key=None, + compile_prefix=None, + compile_class=None, + ) + + batch_sizes = [1, 2] + compile_tf_graph( + model_path=model_path, + output_path=static_saved_model_path, + batch_sizes=batch_sizes, + input_serving_key=tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY, + output_serving_key=None, + compile_prefix=None, + compile_class=None, + ) + + # load model + loaded_static_model = cmsml.tensorflow.load_model(static_saved_model_path) + + # check input shape + for batch_size in batch_sizes: + # first entry is empty, second contains inputs tuple(tensorspecs) + key = f"serving_default_bs{batch_size}" + model_static_inputs = loaded_static_model.signatures[key].structured_input_signature[1] + + expected_model_static_inputs = { + f"first_bs{batch_size}": tf.TensorSpec( + shape=(batch_size, 2), + dtype=tf.float32, + name=f"first_bs{batch_size}", + ), + f"second_bs{batch_size}": tf.TensorSpec( + shape=(batch_size, 3), + dtype=tf.float32, + name=f"second_bs{batch_size}", + ), + f"third_bs{batch_size}": tf.TensorSpec( + shape=(batch_size, 10), + dtype=tf.float32, + name=f"third_bs{batch_size}", + ), + } + + self.assertDictEqual(model_static_inputs, expected_model_static_inputs) + + def test_compile_tf_graph_static_aot_compilation(self): + from cmsml.scripts.compile_tf_graph import compile_tf_graph + + # check aot compilation + tf = self.tf + model = self.create_test_model(tf) + + with tmp_dir(create=False) as model_path, tmp_dir(create=False) as static_saved_model_path: + tf.saved_model.save(model, model_path) + + batch_sizes = [1, 2] + compile_tf_graph( + model_path=model_path, + output_path=static_saved_model_path, + batch_sizes=batch_sizes, + input_serving_key=tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY, + 
output_serving_key=None, + compile_prefix="aot_model_bs{}", + compile_class="bs_{}", + ) + + aot_dir = os.path.join(static_saved_model_path, "aot") + for batch_size in batch_sizes: + aot_model_header = os.path.join(aot_dir, "aot_model_bs{}.h".format(batch_size)) + aot_model_object = os.path.join(aot_dir, "aot_model_bs{}.o".format(batch_size)) + + self.assertTrue(os.path.exists(aot_model_object)) + self.assertTrue(os.path.exists(aot_model_header)) diff --git a/tests/test_lazy_loader.py b/tests/test_lazy_loader.py index c445e23..0e0b3cb 100644 --- a/tests/test_lazy_loader.py +++ b/tests/test_lazy_loader.py @@ -13,18 +13,3 @@ class LazyLoaderTestCase(CMSMLTestCase): def test_started(self): self.assertTrue(cmsml.lazy_loader.started()) - - def test_modules(self): - # check if placeholders are in place - for module_name in cmsml.lazy_loader._lazy_modules: - self.assertIsInstance(getattr(cmsml, module_name), cmsml.lazy_loader.ModulePlaceholder) - - # access modules - for module_name in cmsml.lazy_loader._lazy_modules: - module = getattr(cmsml, module_name) - self.assertIsInstance(module.__all__, list) - - # after the first access, the placeholders should have been replaced - for module_name in cmsml.lazy_loader._lazy_modules: - module = getattr(cmsml, module_name) - self.assertNotIsInstance(module, cmsml.lazy_loader.ModulePlaceholder) diff --git a/tests/test_tensorflow.py b/tests/test_tensorflow.py index 9955a39..b2b3f00 100644 --- a/tests/test_tensorflow.py +++ b/tests/test_tensorflow.py @@ -5,6 +5,8 @@ """ import os +import contextlib +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" import cmsml from cmsml.util import tmp_file, tmp_dir @@ -15,9 +17,7 @@ class TensorFlowTestCase(CMSMLTestCase): def __init__(self, *args, **kwargs): - super(TensorFlowTestCase, self).__init__(*args, **kwargs) - - os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + super().__init__(*args, **kwargs) self._tf = None self._tf1 = None @@ -145,35 +145,41 @@ def test_import_tf(self): if tf_version[0] == "1": self.assertEqual(tf, tf1) - def test_save_graph(self): + def test_save_frozen_graph(self): graph, session = self.create_tf1_graph() if graph is None or session is None: return with tmp_file(suffix=".pb") as path: - cmsml.tensorflow.save_graph(path, graph, variables_to_constants=False) + cmsml.tensorflow.save_frozen_graph(path, graph, variables_to_constants=False) self.assertTrue(os.path.exists(path)) with tmp_file(suffix=".pb.txt") as path: - cmsml.tensorflow.save_graph(path, graph, variables_to_constants=False) + cmsml.tensorflow.save_frozen_graph(path, graph, variables_to_constants=False) self.assertTrue(os.path.exists(path)) with tmp_file(suffix=".pb") as path: - cmsml.tensorflow.save_graph(path, graph.as_graph_def(), variables_to_constants=False) + cmsml.tensorflow.save_frozen_graph( + path, graph.as_graph_def(), variables_to_constants=False, + ) self.assertTrue(os.path.exists(path)) with tmp_file(suffix=".pb") as path: - cmsml.tensorflow.save_graph(path, session, variables_to_constants=False) + cmsml.tensorflow.save_frozen_graph(path, session, variables_to_constants=False) self.assertTrue(os.path.exists(path)) with tmp_file(suffix=".pb") as path: - cmsml.tensorflow.save_graph(path, session, variables_to_constants=True, - output_names=["output"]) + cmsml.tensorflow.save_frozen_graph( + path, + session, + variables_to_constants=True, + output_names=["output"], + ) self.assertTrue(os.path.exists(path)) with tmp_file(suffix=".pb") as path: with self.assertRaises(ValueError): - cmsml.tensorflow.save_graph(path, session, 
variables_to_constants=True) + cmsml.tensorflow.save_frozen_graph(path, session, variables_to_constants=True) self.assertFalse(os.path.exists(path)) def test_save_polymorphic_function_error(self): @@ -181,70 +187,70 @@ def test_save_polymorphic_function_error(self): with self.assertRaises(ValueError): with tmp_file(suffix=".pb") as path: - cmsml.tensorflow.save_graph(path, poly_func, variables_to_constants=False) + cmsml.tensorflow.save_frozen_graph(path, poly_func, variables_to_constants=False) with self.assertRaises(ValueError): with tmp_file(suffix=".pb") as path: - cmsml.tensorflow.save_graph(path, poly_func, variables_to_constants=True) + cmsml.tensorflow.save_frozen_graph(path, poly_func, variables_to_constants=True) def test_save_empty_polymorphic_function(self): empty_poly_func = self.create_tf_function(no_input=True) with tmp_file(suffix=".pb") as path: - cmsml.tensorflow.save_graph(path, empty_poly_func, variables_to_constants=False) + cmsml.tensorflow.save_frozen_graph(path, empty_poly_func, variables_to_constants=False) self.assertTrue(os.path.exists(path)) with tmp_file(suffix=".pb") as path: - cmsml.tensorflow.save_graph(path, empty_poly_func, variables_to_constants=True) + cmsml.tensorflow.save_frozen_graph(path, empty_poly_func, variables_to_constants=True) self.assertTrue(os.path.exists(path)) def test_save_frozen_polymorphic_function(self): frozen_poly_func = self.create_tf_function(frozen=True) with tmp_file(suffix=".pb") as path: - cmsml.tensorflow.save_graph(path, frozen_poly_func, variables_to_constants=False) + cmsml.tensorflow.save_frozen_graph(path, frozen_poly_func, variables_to_constants=False) self.assertTrue(os.path.exists(path)) with tmp_file(suffix=".pb.txt") as path: - cmsml.tensorflow.save_graph(path, frozen_poly_func, variables_to_constants=False) + cmsml.tensorflow.save_frozen_graph(path, frozen_poly_func, variables_to_constants=False) self.assertTrue(os.path.exists(path)) with tmp_file(suffix=".pb") as path: - cmsml.tensorflow.save_graph(path, frozen_poly_func, variables_to_constants=True) + cmsml.tensorflow.save_frozen_graph(path, frozen_poly_func, variables_to_constants=True) self.assertTrue(os.path.exists(path)) def test_save_concrete_function(self): concrete_func = self.create_tf_function(concrete=True) with tmp_file(suffix=".pb") as path: - cmsml.tensorflow.save_graph(path, concrete_func, variables_to_constants=False) + cmsml.tensorflow.save_frozen_graph(path, concrete_func, variables_to_constants=False) self.assertTrue(os.path.exists(path)) with tmp_file(suffix=".pb.txt") as path: - cmsml.tensorflow.save_graph(path, concrete_func, variables_to_constants=False) + cmsml.tensorflow.save_frozen_graph(path, concrete_func, variables_to_constants=False) self.assertTrue(os.path.exists(path)) with tmp_file(suffix=".pb") as path: - cmsml.tensorflow.save_graph(path, concrete_func, variables_to_constants=True) + cmsml.tensorflow.save_frozen_graph(path, concrete_func, variables_to_constants=True) self.assertTrue(os.path.exists(path)) def test_save_keras_model_v1(self): model = self.create_keras_model(self.tf1) with tmp_file(suffix=".pb") as path: - cmsml.tensorflow.save_graph(path, model, variables_to_constants=False) + cmsml.tensorflow.save_frozen_graph(path, model, variables_to_constants=False) self.assertTrue(os.path.exists(path)) with tmp_file(suffix=".pb.txt") as path: - cmsml.tensorflow.save_graph(path, model, variables_to_constants=False) + cmsml.tensorflow.save_frozen_graph(path, model, variables_to_constants=False) self.assertTrue(os.path.exists(path)) 
         with tmp_file(suffix=".pb") as path:
-            cmsml.tensorflow.save_graph(path, model, variables_to_constants=True)
+            cmsml.tensorflow.save_frozen_graph(path, model, variables_to_constants=True)
             self.assertTrue(os.path.exists(path))
 
         with tmp_file(suffix=".pb") as path:
-            cmsml.tensorflow.save_graph(
+            cmsml.tensorflow.save_frozen_graph(
                 path,
                 self.tf1.keras.backend.get_session(),
                 variables_to_constants=False,
@@ -255,37 +261,37 @@ def test_save_keras_model_v2(self):
         model = self.create_keras_model(self.tf)
 
         with tmp_file(suffix=".pb") as path:
-            cmsml.tensorflow.save_graph(path, model, variables_to_constants=False)
+            cmsml.tensorflow.save_frozen_graph(path, model, variables_to_constants=False)
             self.assertTrue(os.path.exists(path))
 
         with tmp_file(suffix=".pb") as path:
-            cmsml.tensorflow.save_graph(path, model, variables_to_constants=True)
+            cmsml.tensorflow.save_frozen_graph(path, model, variables_to_constants=True)
             self.assertTrue(os.path.exists(path))
 
-    def test_load_graph(self):
+    def test_load_frozen_graph(self):
         import google.protobuf as pb
 
         concrete_func = self.create_tf_function(concrete=True)
 
         with tmp_file(suffix=".pb") as path_pb, tmp_file(suffix=".pb.txt") as path_txt:
-            cmsml.tensorflow.save_graph(path_txt, concrete_func, variables_to_constants=True)
-            cmsml.tensorflow.save_graph(path_pb, concrete_func, variables_to_constants=False)
+            cmsml.tensorflow.save_frozen_graph(path_txt, concrete_func, variables_to_constants=True)
+            cmsml.tensorflow.save_frozen_graph(path_pb, concrete_func, variables_to_constants=False)
             self.assertTrue(os.path.exists(path_pb))
             self.assertTrue(os.path.exists(path_txt))
 
-            graph = cmsml.tensorflow.load_graph(path_txt)
+            graph = cmsml.tensorflow.load_frozen_graph(path_txt)
             self.assertIsInstance(graph, self.tf.Graph)
 
-            graph = cmsml.tensorflow.load_graph(path_pb)
+            graph = cmsml.tensorflow.load_frozen_graph(path_pb)
             self.assertIsInstance(graph, self.tf.Graph)
 
             with self.assertRaises(pb.text_format.ParseError):
-                cmsml.tensorflow.load_graph(path_pb, as_text=True)
+                cmsml.tensorflow.load_frozen_graph(path_pb, as_text=True)
 
             with self.assertRaises(pb.message.DecodeError):
-                cmsml.tensorflow.load_graph(path_txt, as_text=False)
+                cmsml.tensorflow.load_frozen_graph(path_txt, as_text=False)
 
-    def test_load_graph_and_run(self):
+    def test_load_frozen_graph_and_run(self):
         import numpy as np
 
         tf = self.tf1
@@ -294,13 +300,13 @@
         _, session = self.create_tf1_graph()
 
         with tmp_file(suffix=".pb.txt") as path:
-            cmsml.tensorflow.save_graph(
+            cmsml.tensorflow.save_frozen_graph(
                 path,
                 session,
                 variables_to_constants=True,
                 output_names=["output"],
             )
-            graph = cmsml.tensorflow.load_graph(path)
+            graph = cmsml.tensorflow.load_frozen_graph(path)
             session = self.create_tf1_session(graph)
 
         with graph.as_default():
@@ -320,9 +326,45 @@ def test_write_summary(self):
             self.assertGreater(len(os.listdir(path)), 0)
 
         with tmp_file(suffix=".pb") as graph_path:
-            cmsml.tensorflow.save_graph(graph_path, concrete_func)
+            cmsml.tensorflow.save_frozen_graph(graph_path, concrete_func)
             with tmp_dir(create=False) as path:
                 cmsml.tensorflow.write_graph_summary(graph_path, path)
                 self.assertTrue(os.path.exists(path))
                 self.assertGreater(len(os.listdir(path)), 0)
             self.assertTrue(os.path.exists(path))
+
+    @contextlib.contextmanager
+    def create_saved_model(self, **kwargs):
+        # helper function to create a SavedModel, saved through both Keras and plain TensorFlow
+        model = self.create_keras_model(self.tf)
+
+        with tmp_dir(create=False) as keras_path, tmp_dir(create=False) as tf_path:
+            self.tf.saved_model.save(model, tf_path)
+            model.save(keras_path, overwrite=True, include_optimizer=False)
+
+            yield keras_path, tf_path
+
+    def test_load_model(self):
+        with self.create_saved_model() as paths:
+            keras_path, tf_path = paths
+            keras_model = cmsml.tensorflow.load_model(keras_path)
+            tf_model = cmsml.tensorflow.load_model(tf_path)
+
+            inp = self.tf.ones(shape=(2, 10))
+            keras_out, tf_out = keras_model(inp), tf_model(inp)
+
+            expected_shape = self.tf.TensorShape([2, 3])
+
+            self.assertEqual(keras_out.shape, expected_shape)
+            self.assertEqual(tf_out.shape, expected_shape)
+
+    def test_load_graph_def(self):
+        with self.create_saved_model() as paths:
+            keras_path, tf_path = paths
+            default_serving_key = self.tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+            tf_graph_def = cmsml.tensorflow.load_graph_def(tf_path, default_serving_key)
+            keras_graph_def = cmsml.tensorflow.load_graph_def(keras_path, default_serving_key)
+
+            self.assertIsInstance(tf_graph_def, self.tf.compat.v1.GraphDef)
+            self.assertIsInstance(keras_graph_def, self.tf.compat.v1.GraphDef)