diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index d7a46994a1..0a4c3139b1 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -14,7 +14,6 @@ import ast import json import os -import pprint import re import time import warnings @@ -37,7 +36,7 @@ from monai.data import load_net_with_metadata, save_net_with_metadata from monai.networks import convert_to_torchscript, copy_model_state, get_state_dict, save_state from monai.utils import check_parent_dir, get_equivalent_dtype, min_version, optional_import -from monai.utils.misc import ensure_tuple +from monai.utils.misc import ensure_tuple, pprint_edges validate, _ = optional_import("jsonschema", name="validate") ValidationError, _ = optional_import("jsonschema.exceptions", name="ValidationError") @@ -48,6 +47,7 @@ # set BUNDLE_DOWNLOAD_SRC="ngc" to use NGC source in default for bundle download download_source = os.environ.get("BUNDLE_DOWNLOAD_SRC", "github") +PPRINT_CONFIG_N = 5 def _update_args(args: str | dict | None = None, ignore_none: bool = True, **kwargs: Any) -> dict: @@ -88,7 +88,7 @@ def _pop_args(src: dict, *args: Any, **kwargs: Any) -> tuple: def _log_input_summary(tag: str, args: dict) -> None: logger.info(f"--- input summary of monai.bundle.scripts.{tag} ---") for name, val in args.items(): - logger.info(f"> {name}: {pprint.pformat(val)}") + logger.info(f"> {name}: {pprint_edges(val, PPRINT_CONFIG_N)}") logger.info("---\n\n") diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 601a5f10ae..318b16c47c 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -77,6 +77,7 @@ issequenceiterable, list_to_dict, path_to_uri, + pprint_edges, progress_bar, sample_slices, save_obj, diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 583acd3a54..05ef1cb4c7 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -14,6 +14,7 @@ import inspect import itertools import os +import pprint import random import shutil import tempfile @@ -60,6 +61,7 @@ "save_obj", "label_union", "path_to_uri", + "pprint_edges", ] _seed = None @@ -626,3 +628,17 @@ def path_to_uri(path: PathLike) -> str: """ return Path(path).absolute().as_uri() + + +def pprint_edges(val: Any, n_lines: int = 20) -> str: + """ + Pretty print the head and tail ``n_lines`` of ``val``, and omit the middle part if the part has more than 3 lines. + + Returns: the formatted string. + """ + val_str = pprint.pformat(val).splitlines(True) + n_lines = max(n_lines, 1) + if len(val_str) > n_lines * 2 + 3: + hidden_n = len(val_str) - n_lines * 2 + val_str = val_str[:n_lines] + [f"\n ... omitted {hidden_n} line(s)\n\n"] + val_str[-n_lines:] + return "".join(val_str) diff --git a/tests/test_bundle_utils.py b/tests/test_bundle_utils.py index 9d28903f2f..d92f6e517f 100644 --- a/tests/test_bundle_utils.py +++ b/tests/test_bundle_utils.py @@ -20,6 +20,7 @@ from monai.bundle.utils import load_bundle_config from monai.networks.nets import UNet +from monai.utils import pprint_edges from tests.utils import command_line_tests, skip_if_windows metadata = """ @@ -117,5 +118,16 @@ def test_load_config_ts(self): self.assertEqual(p["test_dict"]["b"], "c") +class TestPPrintEdges(unittest.TestCase): + def test_str(self): + self.assertEqual(pprint_edges("", 0), "''") + self.assertEqual(pprint_edges({"a": 1, "b": 2}, 0), "{'a': 1, 'b': 2}") + self.assertEqual( + pprint_edges([{"a": 1, "b": 2}] * 20, 1), + "[{'a': 1, 'b': 2},\n\n ... omitted 18 line(s)\n\n {'a': 1, 'b': 2}]", + ) + self.assertEqual(pprint_edges([{"a": 1, "b": 2}] * 8, 4), pprint_edges([{"a": 1, "b": 2}] * 8, 3)) + + if __name__ == "__main__": unittest.main()