Skip to content

Commit

Permalink
pprint head and tail
Browse files Browse the repository at this point in the history
Signed-off-by: Wenqi Li <[email protected]>
  • Loading branch information
wyli committed Feb 9, 2023
1 parent 94feae5 commit e8c807e
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 3 deletions.
6 changes: 3 additions & 3 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import ast
import json
import os
import pprint
import re
import time
import warnings
Expand All @@ -37,7 +36,7 @@
from monai.data import load_net_with_metadata, save_net_with_metadata
from monai.networks import convert_to_torchscript, copy_model_state, get_state_dict, save_state
from monai.utils import check_parent_dir, get_equivalent_dtype, min_version, optional_import
from monai.utils.misc import ensure_tuple
from monai.utils.misc import ensure_tuple, pprint_edges

validate, _ = optional_import("jsonschema", name="validate")
ValidationError, _ = optional_import("jsonschema.exceptions", name="ValidationError")
Expand All @@ -48,6 +47,7 @@

# set BUNDLE_DOWNLOAD_SRC="ngc" to use NGC source in default for bundle download
download_source = os.environ.get("BUNDLE_DOWNLOAD_SRC", "github")
PPRINT_CONFIG_N = 5


def _update_args(args: str | dict | None = None, ignore_none: bool = True, **kwargs: Any) -> dict:
Expand Down Expand Up @@ -88,7 +88,7 @@ def _pop_args(src: dict, *args: Any, **kwargs: Any) -> tuple:
def _log_input_summary(tag: str, args: dict) -> None:
logger.info(f"--- input summary of monai.bundle.scripts.{tag} ---")
for name, val in args.items():
logger.info(f"> {name}: {pprint.pformat(val)}")
logger.info(f"> {name}: {pprint_edges(val, PPRINT_CONFIG_N)}")
logger.info("---\n\n")


Expand Down
1 change: 1 addition & 0 deletions monai/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
issequenceiterable,
list_to_dict,
path_to_uri,
pprint_edges,
progress_bar,
sample_slices,
save_obj,
Expand Down
16 changes: 16 additions & 0 deletions monai/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import inspect
import itertools
import os
import pprint
import random
import shutil
import tempfile
Expand Down Expand Up @@ -60,6 +61,7 @@
"save_obj",
"label_union",
"path_to_uri",
"pprint_edges",
]

_seed = None
Expand Down Expand Up @@ -626,3 +628,17 @@ def path_to_uri(path: PathLike) -> str:
"""
return Path(path).absolute().as_uri()


def pprint_edges(val, n_lines=20) -> str:
"""
Pretty print the head and tail ``n_lines`` of ``val``, and omit the middle part if the part has more than 3 lines.
Returns: the formatted string.
"""
val_str = pprint.pformat(val).splitlines(True)
n_lines = max(n_lines, 1)
if len(val_str) > n_lines * 2 + 3:
hidden_n = len(val_str) - n_lines * 2
val_str = val_str[:n_lines] + [f"\n ... omitted {hidden_n} line(s)\n\n"] + val_str[-n_lines:]
return "".join(val_str)
12 changes: 12 additions & 0 deletions tests/test_bundle_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from monai.bundle.utils import load_bundle_config
from monai.networks.nets import UNet
from monai.utils import pprint_edges
from tests.utils import command_line_tests, skip_if_windows

metadata = """
Expand Down Expand Up @@ -117,5 +118,16 @@ def test_load_config_ts(self):
self.assertEqual(p["test_dict"]["b"], "c")


class TestPPrintEdges(unittest.TestCase):
def test_str(self):
self.assertEqual(pprint_edges("", 0), "''")
self.assertEqual(pprint_edges({"a": 1, "b": 2}, 0), "{'a': 1, 'b': 2}")
self.assertEqual(
pprint_edges([{"a": 1, "b": 2}] * 20, 1),
"[{'a': 1, 'b': 2},\n\n ... omitted 18 line(s)\n\n {'a': 1, 'b': 2}]",
)
self.assertEqual(pprint_edges([{"a": 1, "b": 2}] * 8, 4), pprint_edges([{"a": 1, "b": 2}] * 8, 3))


if __name__ == "__main__":
unittest.main()

0 comments on commit e8c807e

Please sign in to comment.