Skip to content

Commit

Permalink
Allow ignoring cache per-estimator, closes #16
Browse files Browse the repository at this point in the history
  • Loading branch information
dunnkers committed Jun 12, 2021
1 parent c1aa265 commit 07e631d
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 5 deletions.
26 changes: 22 additions & 4 deletions fseval/pipeline/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from fseval.types import (
AbstractEstimator,
AbstractStorageProvider,
CacheUsage,
IncompatibilityError,
Task,
)
Expand All @@ -17,7 +18,8 @@
@dataclass
class EstimatorConfig:
estimator: Any = None # must have _target_ of type BaseEstimator.
use_cache_if_available: bool = True
load_cache: CacheUsage = CacheUsage.allow
save_cache: CacheUsage = CacheUsage.allow
# tags
multioutput: Optional[bool] = None
multioutput_only: Optional[bool] = None
Expand All @@ -36,7 +38,8 @@ class TaskedEstimatorConfig(EstimatorConfig):
name: str = MISSING
classifier: Optional[EstimatorConfig] = None
regressor: Optional[EstimatorConfig] = None
use_cache_if_available: bool = True
load_cache: CacheUsage = CacheUsage.allow
save_cache: CacheUsage = CacheUsage.allow
# tags
multioutput: Optional[bool] = False
multioutput_only: Optional[bool] = False
Expand Down Expand Up @@ -75,16 +78,31 @@ def _get_class_repr(cls, estimator):
return f"{module_name}.{class_name}"

def _load_cache(self, filename: str, storage_provider: AbstractStorageProvider):
    """Try to restore a previously pickled estimator from the storage provider.

    Behaviour depends on `self.load_cache`:
    - `CacheUsage.never`: the storage provider is not consulted and no
      attribute is modified.
    - `CacheUsage.allow`: a restored estimator replaces `self.estimator`; if
      nothing (truthy) was restored the current estimator is kept.
    - `CacheUsage.must`: same as `allow`, but failing to restore is an error.

    Unless loading is disabled, `self._is_fitted` is set to whether a truthy
    object was returned by `restore_pickle`.

    Args:
        filename: name of the pickle file, passed to `restore_pickle`.
        storage_provider: provider whose `restore_pickle` is called.

    Raises:
        AssertionError: if `self.load_cache` is `CacheUsage.must` and no
            estimator could be restored.
    """
    if self.load_cache == CacheUsage.never:
        self.logger.info("ignoring cache load completely.")
        return

    restored = storage_provider.restore_pickle(filename)
    # Fall back to the current (unfitted) estimator when nothing was restored.
    self.estimator = restored or self.estimator
    self._is_fitted = bool(restored)

    # Raise explicitly rather than via an `assert` statement: asserts are
    # removed when Python runs with `-O`, which would silently disable the
    # `must` guarantee. AssertionError is kept so existing handlers still work.
    if self.load_cache == CacheUsage.must and not self._is_fitted:
        raise AssertionError(
            "Cache usage was set to 'must' but loading cached estimator failed."
            " Pickle file might be corrupt or could not be found."
        )

def _save_cache(self, filename: str, storage_provider: AbstractStorageProvider):
    """Pickle the current estimator through the storage provider, unless disabled.

    Nothing is written when `self.save_cache` is `CacheUsage.never`. For any
    other value `storage_provider.save_pickle` is called exactly once; the
    result of the save is not verified here, so `CacheUsage.must` currently
    behaves the same as `CacheUsage.allow`.

    Args:
        filename: name of the pickle file, passed to `save_pickle`.
        storage_provider: provider whose `save_pickle` is called.
    """
    # Check the policy *before* touching storage, so that `never` really
    # means no write happens and the estimator is not saved twice otherwise.
    if self.save_cache == CacheUsage.never:
        self.logger.info("ignoring cache save completely.")
        return

    storage_provider.save_pickle(filename, self.estimator)
    # TODO check whether file was successfully saved.

def fit(self, X, y):
# don't refit if cache available and `use_cache_if_available` is enabled
if self._is_fitted and self.use_cache_if_available:
if self._is_fitted:
self.logger.debug("using estimator from cache.")
return self

Expand Down
2 changes: 1 addition & 1 deletion fseval/storage_providers/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ class WandbStorageProvider(AbstractStorageProvider):
run_id: Optional[str] - recover from a specific run id."""

resume: Optional[str] = None
local_dir: Optional[str] = None
resume: Optional[str] = None
entity: Optional[str] = None
project: Optional[str] = None
run_id: Optional[str] = None
Expand Down
19 changes: 19 additions & 0 deletions fseval/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,25 @@ class Task(Enum):
classification = 2


class CacheUsage(Enum):
    """Policy that determines how cache usage is handled.

    When **loading** a cache:

    - `allow`: the program may use the cache, if it is found and can be
      restored.
    - `must`: the program should fail if no cache is found.
    - `never`: the program should not load the cache, even if one is found.

    When **saving** a cache:

    - `allow`: the program may save a cache; no fatal error is thrown when
      saving fails.
    - `must`: the program must save a cache; an error is thrown if saving
      fails (e.g. due to running out of memory).
    - `never`: the program does not try to save a cached version.
    """

    allow = 1
    must = 2
    never = 3


class IncompatibilityError(Exception):
    """Project-specific exception deriving directly from `Exception`.

    NOTE(review): presumably raised when an estimator cannot be used for the
    requested task or data — confirm against the raise sites.
    """

Expand Down
3 changes: 3 additions & 0 deletions tests/integration/pipelines/test_rank_and_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,9 @@ def cfg(dataset, cv, resample, classifier, ranker, validator):
n_bootstraps=2,
n_jobs=None,
all_features_to_select="range(1, min(50, p) + 1)",
upload_ranking_scores=True,
upload_validation_scores=True,
upload_best_scores=True,
)

cfg = OmegaConf.create(config.__dict__)
Expand Down

0 comments on commit 07e631d

Please sign in to comment.