Skip to content
This repository has been archived by the owner on Mar 10, 2024. It is now read-only.

Commit

Permalink
[Fork from automl#105] Made CrossValFuncs and HoldOutFuncs class to g…
Browse files Browse the repository at this point in the history
…roup the functions
  • Loading branch information
nabenabe0928 committed Mar 15, 2021
1 parent 5c6ce0b commit b8738b7
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 82 deletions.
22 changes: 10 additions & 12 deletions autoPyTorch/datasets/base_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABCMeta
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast, Callable

import numpy as np

Expand All @@ -13,18 +13,16 @@

from autoPyTorch.constants import CLASSIFICATION_OUTPUTS, STRING_TO_OUTPUT_TYPES
from autoPyTorch.datasets.resampling_strategy import (
CROSS_VAL_FN,
CrossValFuncs,
CrossValTypes,
DEFAULT_RESAMPLING_PARAMETERS,
HOLDOUT_FN,
HoldoutValTypes,
get_cross_validators,
get_holdout_validators,
is_stratified,
HoldOutFuncs,
)
from autoPyTorch.utils.common import FitRequirement, hash_array_or_matrix

BaseDatasetType = Union[Tuple[np.ndarray, np.ndarray], Dataset]
SplitFunc = Callable[[Union[int, float], np.ndarray, Any], List[Tuple[np.ndarray, np.ndarray]]]


def check_valid_data(data: Any) -> None:
Expand Down Expand Up @@ -112,8 +110,8 @@ def __init__(
if not hasattr(train_tensors[0], 'shape'):
type_check(train_tensors, val_tensors)
self.train_tensors, self.val_tensors, self.test_tensors = train_tensors, val_tensors, test_tensors
self.cross_validators: Dict[str, CROSS_VAL_FN] = {}
self.holdout_validators: Dict[str, HOLDOUT_FN] = {}
self.cross_validators: Dict[str, SplitFunc] = {}
self.holdout_validators: Dict[str, SplitFunc] = {}
self.rng = np.random.RandomState(seed=seed)
self.shuffle = shuffle
self.resampling_strategy = resampling_strategy
Expand All @@ -134,8 +132,8 @@ def __init__(
self.is_small_preprocess = True

# Make sure cross validation splits are created once
self.cross_validators = get_cross_validators(*CrossValTypes)
self.holdout_validators = get_holdout_validators(*HoldoutValTypes)
self.cross_validators = CrossValFuncs.get_cross_validators(*CrossValTypes)
self.holdout_validators = HoldOutFuncs.get_holdout_validators(*HoldoutValTypes)
self.splits = self.get_splits_from_resampling_strategy()

# We also need to be able to transform the data, be it for pre-processing
Expand Down Expand Up @@ -263,7 +261,7 @@ def create_cross_val_splits(
if not isinstance(cross_val_type, CrossValTypes):
raise NotImplementedError(f'The selected `cross_val_type` "{cross_val_type}" is not implemented.')
kwargs = {}
if is_stratified(cross_val_type):
if cross_val_type.is_stratified():
# we need additional information about the data for stratification
kwargs["stratify"] = self.train_tensors[-1]
splits = self.cross_validators[cross_val_type.name](
Expand Down Expand Up @@ -298,7 +296,7 @@ def create_holdout_val_split(
if not isinstance(holdout_val_type, HoldoutValTypes):
raise NotImplementedError(f'The specified `holdout_val_type` "{holdout_val_type}" is not supported.')
kwargs = {}
if is_stratified(holdout_val_type):
if holdout_val_type.is_stratified():
# we need additional information about the data for stratification
kwargs["stratify"] = self.train_tensors[-1]
train, val = self.holdout_validators[holdout_val_type.name](val_share, self._get_indices(), **kwargs)
Expand Down
193 changes: 123 additions & 70 deletions autoPyTorch/datasets/resampling_strategy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import IntEnum
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union, Callable

import numpy as np

Expand All @@ -15,8 +15,12 @@
from typing_extensions import Protocol


SplitFunc = Callable[[Union[int, float], np.ndarray, Any], List[Tuple[np.ndarray, np.ndarray]]]


# Use callback protocol as workaround, since callable with function fields count 'self' as argument
class CROSS_VAL_FN(Protocol):
"""TODO: deprecate soon"""
def __call__(self,
num_splits: int,
indices: np.ndarray,
Expand All @@ -25,26 +29,59 @@ def __call__(self,


class HOLDOUT_FN(Protocol):
"""TODO: deprecate soon"""
def __call__(self, val_share: float, indices: np.ndarray, stratify: Optional[Any]
) -> Tuple[np.ndarray, np.ndarray]:
...


class CrossValTypes(IntEnum):
"""The type of cross validation
This class is used to specify the cross validation function
and is not supposed to be instantiated.
Examples: This class is supposed to be used as follows
>>> cv_type = CrossValTypes.k_fold_cross_validation
>>> print(cv_type.name)
k_fold_cross_validation
>>> for cross_val_type in CrossValTypes:
print(cross_val_type.name, cross_val_type.value)
stratified_k_fold_cross_validation 1
k_fold_cross_validation 2
stratified_shuffle_split_cross_validation 3
shuffle_split_cross_validation 4
time_series_cross_validation 5
"""
stratified_k_fold_cross_validation = 1
k_fold_cross_validation = 2
stratified_shuffle_split_cross_validation = 3
shuffle_split_cross_validation = 4
time_series_cross_validation = 5

def is_stratified(self) -> bool:
stratified = [self.stratified_k_fold_cross_validation,
self.stratified_shuffle_split_cross_validation]
return getattr(self, self.name) in stratified


class HoldoutValTypes(IntEnum):
"""The type of hold out validation (refer to CrossValTypes' doc-string)"""
holdout_validation = 6
stratified_holdout_validation = 7

def is_stratified(self) -> bool:
stratified = [self.stratified_holdout_validation]
return getattr(self, self.name) in stratified


"""TODO: deprecate soon"""
RESAMPLING_STRATEGIES = [CrossValTypes, HoldoutValTypes]

"""TODO: deprecate soon"""
DEFAULT_RESAMPLING_PARAMETERS = {
HoldoutValTypes.holdout_validation: {
'val_share': 0.33,
Expand All @@ -67,15 +104,8 @@ class HoldoutValTypes(IntEnum):
} # type: Dict[Union[HoldoutValTypes, CrossValTypes], Dict[str, Any]]


def get_cross_validators(*cross_val_types: CrossValTypes) -> Dict[str, CROSS_VAL_FN]:
cross_validators = {} # type: Dict[str, CROSS_VAL_FN]
for cross_val_type in cross_val_types:
cross_val_fn = globals()[cross_val_type.name]
cross_validators[cross_val_type.name] = cross_val_fn
return cross_validators


def get_holdout_validators(*holdout_val_types: HoldoutValTypes) -> Dict[str, HOLDOUT_FN]:
"""TODO: deprecate soon"""
holdout_validators = {} # type: Dict[str, HOLDOUT_FN]
for holdout_val_type in holdout_val_types:
holdout_val_fn = globals()[holdout_val_type.name]
Expand All @@ -84,70 +114,93 @@ def get_holdout_validators(*holdout_val_types: HoldoutValTypes) -> Dict[str, HOL


def is_stratified(val_type: Union[str, CrossValTypes, HoldoutValTypes]) -> bool:
"""TODO: deprecate soon"""
if isinstance(val_type, str):
return val_type.lower().startswith("stratified")
else:
return val_type.name.lower().startswith("stratified")


def holdout_validation(val_share: float, indices: np.ndarray, **kwargs: Any) -> Tuple[np.ndarray, np.ndarray]:
train, val = train_test_split(indices, test_size=val_share, shuffle=False)
return train, val


def stratified_holdout_validation(val_share: float, indices: np.ndarray, **kwargs: Any) \
-> Tuple[np.ndarray, np.ndarray]:
train, val = train_test_split(indices, test_size=val_share, shuffle=False, stratify=kwargs["stratify"])
return train, val


def shuffle_split_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \
-> List[Tuple[np.ndarray, np.ndarray]]:
cv = ShuffleSplit(n_splits=num_splits)
splits = list(cv.split(indices))
return splits


def stratified_shuffle_split_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \
-> List[Tuple[np.ndarray, np.ndarray]]:
cv = StratifiedShuffleSplit(n_splits=num_splits)
splits = list(cv.split(indices, kwargs["stratify"]))
return splits


def stratified_k_fold_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \
-> List[Tuple[np.ndarray, np.ndarray]]:
cv = StratifiedKFold(n_splits=num_splits)
splits = list(cv.split(indices, kwargs["stratify"]))
return splits


def k_fold_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) -> List[Tuple[np.ndarray, np.ndarray]]:
"""
Standard k fold cross validation.
:param indices: array of indices to be split
:param num_splits: number of cross validation splits
:return: list of tuples of training and validation indices
"""
cv = KFold(n_splits=num_splits)
splits = list(cv.split(indices))
return splits


def time_series_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \
-> List[Tuple[np.ndarray, np.ndarray]]:
"""
Returns train and validation indices respecting the temporal ordering of the data.
Dummy example: [0, 1, 2, 3] with 3 folds yields
[0] [1]
[0, 1] [2]
[0, 1, 2] [3]
:param indices: array of indices to be split
:param num_splits: number of cross validation splits
:return: list of tuples of training and validation indices
"""
cv = TimeSeriesSplit(n_splits=num_splits)
splits = list(cv.split(indices))
return splits
class HoldOutFuncs():
@staticmethod
def holdout_validation(val_share: float, indices: np.ndarray, **kwargs: Any) -> Tuple[np.ndarray, np.ndarray]:
train, val = train_test_split(indices, test_size=val_share, shuffle=False)
return train, val

@staticmethod
def stratified_holdout_validation(val_share: float, indices: np.ndarray, **kwargs: Any) \
-> Tuple[np.ndarray, np.ndarray]:
train, val = train_test_split(indices, test_size=val_share, shuffle=False, stratify=kwargs["stratify"])
return train, val

@classmethod
def get_holdout_validators(cls, *holdout_val_types: Tuple[HoldoutValTypes]) -> Dict[str, SplitFunc]:

holdout_validators = {
holdout_val_type.name: getattr(cls, holdout_val_type.name)
for holdout_val_type in holdout_val_types
}
return holdout_validators


class CrossValFuncs():
@staticmethod
def shuffle_split_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \
-> List[Tuple[np.ndarray, np.ndarray]]:
cv = ShuffleSplit(n_splits=num_splits)
splits = list(cv.split(indices))
return splits

@staticmethod
def stratified_shuffle_split_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \
-> List[Tuple[np.ndarray, np.ndarray]]:
cv = StratifiedShuffleSplit(n_splits=num_splits)
splits = list(cv.split(indices, kwargs["stratify"]))
return splits

@staticmethod
def stratified_k_fold_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \
-> List[Tuple[np.ndarray, np.ndarray]]:
cv = StratifiedKFold(n_splits=num_splits)
splits = list(cv.split(indices, kwargs["stratify"]))
return splits

@staticmethod
def k_fold_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \
-> List[Tuple[np.ndarray, np.ndarray]]:
"""
Standard k fold cross validation.
:param indices: array of indices to be split
:param num_splits: number of cross validation splits
:return: list of tuples of training and validation indices
"""
cv = KFold(n_splits=num_splits)
splits = list(cv.split(indices))
return splits

@staticmethod
def time_series_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \
-> List[Tuple[np.ndarray, np.ndarray]]:
"""
Returns train and validation indices respecting the temporal ordering of the data.
Dummy example: [0, 1, 2, 3] with 3 folds yields
[0] [1]
[0, 1] [2]
[0, 1, 2] [3]
:param indices: array of indices to be split
:param num_splits: number of cross validation splits
:return: list of tuples of training and validation indices
"""
cv = TimeSeriesSplit(n_splits=num_splits)
splits = list(cv.split(indices))
return splits

@classmethod
def get_cross_validators(cls, *cross_val_types: CrossValTypes) -> Dict[str, SplitFunc]:
cross_validators = {
cross_val_type.name: getattr(cls, cross_val_type.name)
for cross_val_type in cross_val_types
}
return cross_validators

0 comments on commit b8738b7

Please sign in to comment.