diff --git a/deepmd/entrypoints/freeze.py b/deepmd/entrypoints/freeze.py index 172c765646..1f816d083e 100755 --- a/deepmd/entrypoints/freeze.py +++ b/deepmd/entrypoints/freeze.py @@ -9,6 +9,7 @@ import logging import google.protobuf.message from deepmd.env import tf, FITTING_NET_PATTERN +from deepmd.utils.errors import GraphTooLargeError from deepmd.utils.sess import run_sess from deepmd.utils.graph import get_pattern_nodes_from_graph_def from os.path import abspath diff --git a/deepmd/env.py b/deepmd/env.py index f8088c69d4..5942b4c062 100644 --- a/deepmd/env.py +++ b/deepmd/env.py @@ -21,6 +21,10 @@ tf.disable_v2_behavior() except ImportError: import tensorflow as tf +try: + import tensorflow.compat.v2 as tfv2 +except ImportError: + tfv2 = None __all__ = [ "GLOBAL_CONFIG", diff --git a/deepmd/train/trainer.py b/deepmd/train/trainer.py index ea75b30bd1..927074c1c6 100644 --- a/deepmd/train/trainer.py +++ b/deepmd/train/trainer.py @@ -9,7 +9,7 @@ import numpy as np from packaging.version import Version -from deepmd.env import tf +from deepmd.env import tf, tfv2 from deepmd.env import get_tf_session_config from deepmd.env import GLOBAL_TF_FLOAT_PRECISION from deepmd.env import GLOBAL_ENER_FLOAT_PRECISION @@ -226,6 +226,7 @@ def _init_param(self, jdata): self.timing_in_training = tr_data.get('time_training', True) self.profiling = self.run_opt.is_chief and tr_data.get('profiling', False) self.profiling_file = tr_data.get('profiling_file', 'timeline.json') + self.enable_profiler = tr_data.get('enable_profiler', False) self.tensorboard = self.run_opt.is_chief and tr_data.get('tensorboard', False) self.tensorboard_log_dir = tr_data.get('tensorboard_log_dir', 'log') self.tensorboard_freq = tr_data.get('tensorboard_freq', 1) @@ -480,6 +481,9 @@ def train (self, train_data = None, valid_data=None) : else: tb_train_writer = None tb_valid_writer = None + if self.enable_profiler: + # https://www.tensorflow.org/guide/profiler + tfv2.profiler.experimental.start(self.tensorboard_log_dir) train_time = 0 @@ -550,6 +554,8 @@ def train (self, train_data = None, valid_data=None) : chrome_trace = fetched_timeline.generate_chrome_trace_format() with open(self.profiling_file, 'w') as f: f.write(chrome_trace) + if self.enable_profiler and self.run_opt.is_chief: + tfv2.profiler.experimental.stop() def get_feed_dict(self, batch, is_training): feed_dict = {} diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 847eccc52e..3c99b58196 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -629,6 +629,7 @@ def training_args(): # ! modified by Ziyao: data configuration isolated. doc_time_training = 'Timing durining training.' doc_profiling = 'Profiling during training.' doc_profiling_file = 'Output file for profiling.' + doc_enable_profiler = 'Enable TensorFlow Profiler (available in TensorFlow 2.3) to analyze performance. The log will be saved to `tensorboard_log_dir`.' doc_tensorboard = 'Enable tensorboard' doc_tensorboard_log_dir = 'The log directory of tensorboard outputs' doc_tensorboard_freq = 'The frequency of writing tensorboard events.' @@ -651,6 +652,7 @@ def training_args(): # ! modified by Ziyao: data configuration isolated. Argument("time_training", bool, optional=True, default=True, doc=doc_time_training), Argument("profiling", bool, optional=True, default=False, doc=doc_profiling), Argument("profiling_file", str, optional=True, default='timeline.json', doc=doc_profiling_file), + Argument("enable_profiler", bool, optional=True, default=False, doc=doc_enable_profiler), Argument("tensorboard", bool, optional=True, default=False, doc=doc_tensorboard), Argument("tensorboard_log_dir", str, optional=True, default='log', doc=doc_tensorboard_log_dir), Argument("tensorboard_freq", int, optional=True, default=1, doc=doc_tensorboard_freq),