diff --git a/py/trtorch/csrc/trtorch_py.cpp b/py/trtorch/csrc/trtorch_py.cpp index 84e97a0ecb..39fdad5cde 100644 --- a/py/trtorch/csrc/trtorch_py.cpp +++ b/py/trtorch/csrc/trtorch_py.cpp @@ -309,6 +309,7 @@ PYBIND11_MODULE(_C, m) { .value("WARNING", core::util::logging::LogLevel::kWARNING) .value("INFO", core::util::logging::LogLevel::kINFO) .value("DEBUG", core::util::logging::LogLevel::kDEBUG) + .value("GRAPH", core::util::logging::LogLevel::kGRAPH) .export_values(); } diff --git a/py/trtorch/logging.py b/py/trtorch/logging.py index c14f7b402b..36a9ec3a2c 100644 --- a/py/trtorch/logging.py +++ b/py/trtorch/logging.py @@ -13,6 +13,7 @@ class Level(Enum): Warning = LogLevel.WARNING Info = LogLevel.INFO Debug = LogLevel.DEBUG + Graph = LogLevel.GRAPH @staticmethod def _to_internal_level(external) -> LogLevel: @@ -26,6 +27,8 @@ def _to_internal_level(external) -> LogLevel: return LogLevel.INFO if external == Level.Debug: return LogLevel.DEBUG + if external == Level.Graph: + return LogLevel.GRAPH def get_logging_prefix() -> str: