Skip to content

Commit

Permalink
copy config in optuna (#844)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez authored Jan 9, 2024
1 parent 3113a20 commit 714c036
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
5 changes: 3 additions & 2 deletions nbs/common.base_auto.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -342,15 +342,16 @@
" import optuna\n",
"\n",
" def objective(trial):\n",
" cfg = config(trial)\n",
" user_cfg = config(trial)\n",
" cfg = deepcopy(user_cfg)\n",
" fitted_model = self._fit_model(\n",
" cls_model=cls_model,\n",
" config=cfg,\n",
" dataset=dataset,\n",
" val_size=val_size,\n",
" test_size=test_size,\n",
" )\n",
" trial.set_user_attr('ALL_PARAMS', cfg)\n",
" trial.set_user_attr('ALL_PARAMS', user_cfg)\n",
" metrics = fitted_model.trainer.callback_metrics\n",
" trial.set_user_attr('METRICS', {\n",
" \"loss\": metrics[\"ptl/val_loss\"],\n",
Expand Down
5 changes: 3 additions & 2 deletions neuralforecast/common/_base_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,15 +307,16 @@ def _optuna_tune_model(
import optuna

def objective(trial):
cfg = config(trial)
user_cfg = config(trial)
cfg = deepcopy(user_cfg)
fitted_model = self._fit_model(
cls_model=cls_model,
config=cfg,
dataset=dataset,
val_size=val_size,
test_size=test_size,
)
trial.set_user_attr("ALL_PARAMS", cfg)
trial.set_user_attr("ALL_PARAMS", user_cfg)
metrics = fitted_model.trainer.callback_metrics
trial.set_user_attr(
"METRICS",
Expand Down

0 comments on commit 714c036

Please sign in to comment.