Skip to content

Commit

Permalink
fix base validation K without cutoffs
Browse files Browse the repository at this point in the history
  • Loading branch information
scne committed Mar 11, 2021
1 parent 98ec4d3 commit dc31120
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions elliot/recommender/base_recommender_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,17 @@ def __init__(self, data, config, params, *args, **kwargs):

self._restore = getattr(self._params.meta, "restore", False)

_cutoff_k = getattr(data.config.evaluation, "cutoffs", [data.config.top_k])
_cutoff_k = _cutoff_k if isinstance(_cutoff_k, list) else [_cutoff_k]
_first_metric = data.config.evaluation.simple_metrics[0] if data.config.evaluation.simple_metrics else ""
self._validation_metric = getattr(self._params.meta, "validation_metric", _first_metric + "@10").split("@")
_default_validation_k = getattr(data.config.evaluation, "cutoffs", [data.config.top_k])[0]
self._validation_metric = getattr(self._params.meta, "validation_metric",
_first_metric + "@" + str(_default_validation_k)).split("@")
if self._validation_metric[0].lower() not in [m.lower()
for m in data.config.evaluation.simple_metrics]:
for m in data.config.evaluation.simple_metrics]:
raise Exception("Validation metric must be in the list of simple metrics")

_cutoff_k = getattr(data.config.evaluation, "cutoffs", [data.config.top_k])
_cutoff_k = _cutoff_k if isinstance(_cutoff_k, list) else [_cutoff_k]
self._validation_k = int(self._validation_metric[1]) if len(self._validation_metric) > 1 else _cutoff_k[0]

if self._validation_k not in _cutoff_k:
raise Exception("Validation cutoff must be in general cutoff values")

Expand Down

0 comments on commit dc31120

Please sign in to comment.