Skip to content

Commit

Permalink
convert decay_rate to stop_lr from old inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
njzjz committed Aug 11, 2021
1 parent e8a9101 commit f9afdc2
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions deepmd/utils/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pathlib import Path
from typing import Any, Dict, Optional, Sequence, Union

import numpy as np
from deepmd.common import j_must_have


Expand Down Expand Up @@ -237,6 +238,25 @@ def _jcopy(src: Dict[str, Any], dst: Dict[str, Any], keys: Sequence[str]):
dst[k] = src[k]


def remove_decay_rate(jdata: Dict[str, Any]):
"""convert decay_rate to stop_lr.
Parameters
----------
jdata: Dict[str, Any]
input data
"""
lr = jdata["learning_rate"]
if "decay_rate" in lr:
decay_rate = lr["decay_rate"]
start_lr = lr["start_lr"]
stop_step = jdata["training"]["stop_batch"]
decay_steps = lr["decay_steps"]
stop_lr = np.exp(np.log(decay_rate) * (stop_step / decay_steps)) * start_lr
lr["stop_lr"] = stop_lr
lr.pop("decay_rate")


def convert_input_v1_v2(jdata: Dict[str, Any],
warning: bool = True,
dump: Optional[Union[str, Path]] = None) -> Dict[str, Any]:
Expand All @@ -259,6 +279,9 @@ def convert_input_v1_v2(jdata: Dict[str, Any],

jdata["training"] = new_tr_cfg

# remove deprecated arguments
remove_decay_rate(jdata)

if warning:
_warning_input_v1_v2(dump)
if dump is not None:
Expand Down

0 comments on commit f9afdc2

Please sign in to comment.