Skip to content

Commit

Permalink
[config] turn exponential notation back on for config dump (#955)
Browse files Browse the repository at this point in the history
* e-notation for large floats

* handle ints too

* readability

* handle bool

Co-authored-by: Olatunji Ruwase <[email protected]>
  • Loading branch information
stas00 and tjruwase authored Apr 14, 2021
1 parent adac058 commit c87118b
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 3 deletions.
3 changes: 2 additions & 1 deletion deepspeed/runtime/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from .constants import *
from .fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, DELAYED_SHIFT, MIN_LOSS_SCALE
from .config_utils import get_scalar_param, dict_raise_error_on_duplicate_keys
from .config_utils import get_scalar_param, dict_raise_error_on_duplicate_keys, ScientificNotationEncoder
from .zero.config import DeepSpeedZeroConfig
from .zero.constants import *
from .activation_checkpointing.config import DeepSpeedActivationCheckpointingConfig
Expand Down Expand Up @@ -744,6 +744,7 @@ def print(self, name):
json.dumps(self._param_dict,
sort_keys=True,
indent=4,
cls=ScientificNotationEncoder,
separators=(',',
':'))))

Expand Down
42 changes: 40 additions & 2 deletions deepspeed/runtime/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,40 @@
Collection of DeepSpeed configuration utilities
"""
import json
from collections import Counter
from collections.abc import Mapping, Sequence


# adapted from https://stackoverflow.com/a/50701137/9201239
class ScientificNotationEncoder(json.JSONEncoder):
    """
    ``json.JSONEncoder`` subclass used to pretty-print configuration dumps.

    It keeps everything as normal except that ints and floats greater than
    1e3 are written in scientific notation (``f"{value:e}"``, i.e. six
    fractional digits, so very precise values are rounded for display).
    Just pass ``cls=ScientificNotationEncoder`` to ``json.dumps`` to
    activate it.

    The encoder options ``indent``, ``sort_keys`` and ``ensure_ascii`` are
    honored; the item and key separators are always ``", "`` and ``": "``.
    """
    def iterencode(self, o, _one_shot=False, level=0):
        """
        Encode ``o`` and return the result as a single string.

        Unlike the base class this returns one ``str`` rather than an
        iterator of chunks; ``json.dumps`` joins whatever it gets back, so
        the result is the same.

        Args:
            o: the object to encode.
            _one_shot: forwarded to the base class for objects that are not
                handled here (strings, ``None``, ``default=`` conversions).
            level: nesting depth of ``o``, used to compute indentation.

        Returns:
            The JSON text for ``o``.
        """
        indent = self.indent if self.indent is not None else 4
        # ``json`` accepts either a number of spaces or a literal indent string.
        unit = indent if isinstance(indent, str) else " " * indent
        prefix_close = unit * level
        prefix = unit * (level + 1)

        # bool must be tested before int because ``bool`` subclasses ``int``.
        if isinstance(o, bool):
            return "true" if o else "false"
        if isinstance(o, (float, int)):
            return f"{o:e}" if o > 1e3 else f"{o}"
        if isinstance(o, Mapping):
            items = o.items()
            if self.sort_keys:
                # Keys are emitted as ``str(key)``, so order by that same text;
                # this also avoids a TypeError on mixed key types.
                items = sorted(items, key=lambda kv: str(kv[0]))
            # ``self.encode`` on a str only quotes and escapes it, so keys
            # containing quotes or backslashes still produce valid JSON.
            x = [
                f"\n{prefix}{self.encode(str(k))}: "
                f"{self.iterencode(v, level=level + 1)}"
                for k, v in items
            ]
            return "{" + ", ".join(x) + f"\n{prefix_close}" + "}"
        if isinstance(o, Sequence) and not isinstance(o, str):
            # Sequences stay on one line; elements inherit this level so that
            # mappings nested inside them are indented relative to it.
            return "[" + ", ".join(self.iterencode(v, level=level) for v in o) + "]"
        # Anything else (str, None, objects handled by ``default``) goes to the
        # stock encoder, whose chunks must be concatenated unchanged.
        return "".join(super().iterencode(o, _one_shot))


class DeepSpeedConfigObject(object):
Expand All @@ -17,7 +50,12 @@ def repr(self):
return self.__dict__

def __repr__(self):
return json.dumps(self.__dict__, sort_keys=True, indent=4)
return json.dumps(
self.__dict__,
sort_keys=True,
indent=4,
cls=ScientificNotationEncoder,
)


def get_scalar_param(param_dict, param_name, param_default_value):
Expand Down

0 comments on commit c87118b

Please sign in to comment.