Skip to content

Commit

Permalink
⚡️Optimized TrainMonitor
Browse files Browse the repository at this point in the history
1) `_score_before_overfit` should be current running mean, not new score
2) `res` should be calculated with `_score_before_overfit`
  • Loading branch information
carefree0910 committed Dec 20, 2020
1 parent bd051a9 commit 91dfc43
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions cflearn/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,16 @@ def _log_descend_counter(self, new_score: float, res: float, std: float) -> None
msg_level=logging.DEBUG,
)

def _handle_overfitting(self, new_score: float, res: float, std: float) -> None:
def _handle_overfitting(
self,
new_score: float,
res: float,
mean: float,
std: float,
) -> None:
if self._descend_counter == 0.0:
self.info["save_best"] = True
self._score_before_overfit = new_score
self._score_before_overfit = mean
self._descend_counter += min(self.tolerance_ratio, max(0.0, -res / std - 1.0))
self._log_descend_counter(new_score, res, std)
self.over_fitting_flag = 1
Expand Down Expand Up @@ -311,9 +317,12 @@ def check_terminate(self, new_score: float) -> bool:
if self._plateau_counter > 0:
self._plateau_counter = max(self._plateau_counter - 1, 0)
plateau_updated = True
res = new_score - mean
if res < -std and new_score < self._score_before_overfit - std:
self._handle_overfitting(new_score, res, std)
if math.isinf(self._score_before_overfit):
res = new_score - mean
else:
res = new_score - self._score_before_overfit
if res < -std:
self._handle_overfitting(new_score, res, mean, std)
elif res > std:
self._handle_recovering(improvement, new_score, res, std)
if plateau_updated:
Expand Down

0 comments on commit 91dfc43

Please sign in to comment.