diff --git a/autoPyTorch/evaluation/abstract_evaluator.py b/autoPyTorch/evaluation/abstract_evaluator.py
index 2f792b7a8..2af333d11 100644
--- a/autoPyTorch/evaluation/abstract_evaluator.py
+++ b/autoPyTorch/evaluation/abstract_evaluator.py
@@ -433,34 +433,16 @@ def __init__(self, backend: Backend,
         self.backend: Backend = backend
         self.queue = queue

-        self.datamanager: BaseDataset = self.backend.load_datamanager()
-
-        assert self.datamanager.task_type is not None, \
-            "Expected dataset {} to have task_type got None".format(self.datamanager.__class__.__name__)
-        self.task_type = STRING_TO_TASK_TYPES[self.datamanager.task_type]
-        self.output_type = STRING_TO_OUTPUT_TYPES[self.datamanager.output_type]
-        self.issparse = self.datamanager.issparse
-
         self.include = include
         self.exclude = exclude
         self.search_space_updates = search_space_updates

-        self.X_train, self.y_train = self.datamanager.train_tensors
-
-        if self.datamanager.val_tensors is not None:
-            self.X_valid, self.y_valid = self.datamanager.val_tensors
-        else:
-            self.X_valid, self.y_valid = None, None
-
-        if self.datamanager.test_tensors is not None:
-            self.X_test, self.y_test = self.datamanager.test_tensors
-        else:
-            self.X_test, self.y_test = None, None
-
         self.metric = metric

         self.seed = seed

+        self._init_datamanager_info()
+
         # Flag to save target for ensemble
         self.output_y_hat_optimization = output_y_hat_optimization

@@ -497,12 +479,6 @@ def __init__(self, backend: Backend,
             else:
                 raise ValueError('task {} not available'.format(self.task_type))
             self.predict_function = self._predict_proba
-        self.dataset_properties = self.datamanager.get_dataset_properties(
-            get_dataset_requirements(info=self.datamanager.get_required_dataset_info(),
-                                     include=self.include,
-                                     exclude=self.exclude,
-                                     search_space_updates=self.search_space_updates
-                                     ))

         self.additional_metrics: Optional[List[autoPyTorchMetric]] = None
         metrics_dict: Optional[Dict[str, List[str]]] = None
@@ -542,6 +518,53 @@ def __init__(self, backend: Backend,
         self.logger.debug("Fit dictionary in Abstract evaluator: {}".format(dict_repr(self.fit_dictionary)))
         self.logger.debug("Search space updates :{}".format(self.search_space_updates))

+    def _init_datamanager_info(
+        self,
+    ) -> None:
+        """
+        Initialises instance attributes that come from the datamanager.
+        For example,
+        X_train, y_train, etc.
+        """
+
+        datamanager: BaseDataset = self.backend.load_datamanager()
+
+        assert datamanager.task_type is not None, \
+            "Expected dataset {} to have task_type got None".format(datamanager.__class__.__name__)
+        self.task_type = STRING_TO_TASK_TYPES[datamanager.task_type]
+        self.output_type = STRING_TO_OUTPUT_TYPES[datamanager.output_type]
+        self.issparse = datamanager.issparse
+
+        self.X_train, self.y_train = datamanager.train_tensors
+
+        if datamanager.val_tensors is not None:
+            self.X_valid, self.y_valid = datamanager.val_tensors
+        else:
+            self.X_valid, self.y_valid = None, None
+
+        if datamanager.test_tensors is not None:
+            self.X_test, self.y_test = datamanager.test_tensors
+        else:
+            self.X_test, self.y_test = None, None
+
+        self.resampling_strategy = datamanager.resampling_strategy
+
+        self.num_classes: Optional[int] = getattr(datamanager, "num_classes", None)
+
+        self.dataset_properties = datamanager.get_dataset_properties(
+            get_dataset_requirements(info=datamanager.get_required_dataset_info(),
+                                     include=self.include,
+                                     exclude=self.exclude,
+                                     search_space_updates=self.search_space_updates
+                                     ))
+        self.splits = datamanager.splits
+        if self.splits is None:
+            raise AttributeError(f"create_splits on {datamanager.__class__.__name__} must be called "
+                                 f"before the instantiation of {self.__class__.__name__}")
+
+        # delete datamanager from memory
+        del datamanager
+
     def _init_fit_dictionary(
         self,
         logger_port: int,
@@ -988,21 +1011,20 @@ def _ensure_prediction_array_sizes(self, prediction: np.ndarray,
             (np.ndarray):
                 The formatted prediction
         """
-        assert self.datamanager.num_classes is not None, "Called function on wrong task"
-        num_classes: int = self.datamanager.num_classes
+        assert self.num_classes is not None, "Called function on wrong task"

         if self.output_type == MULTICLASS and \
-                prediction.shape[1] < num_classes:
+                prediction.shape[1] < self.num_classes:
             if Y_train is None:
                 raise ValueError('Y_train must not be None!')
             classes = list(np.unique(Y_train))

             mapping = dict()
-            for class_number in range(num_classes):
+            for class_number in range(self.num_classes):
                 if class_number in classes:
                     index = classes.index(class_number)
                     mapping[index] = class_number
-            new_predictions = np.zeros((prediction.shape[0], num_classes),
+            new_predictions = np.zeros((prediction.shape[0], self.num_classes),
                                        dtype=np.float32)
             for index in mapping:
diff --git a/autoPyTorch/evaluation/test_evaluator.py b/autoPyTorch/evaluation/test_evaluator.py
index 0c6da71a9..4d5b0ae91 100644
--- a/autoPyTorch/evaluation/test_evaluator.py
+++ b/autoPyTorch/evaluation/test_evaluator.py
@@ -145,17 +145,12 @@ def __init__(
             search_space_updates=search_space_updates
         )

-        if not isinstance(self.datamanager.resampling_strategy, (NoResamplingStrategyTypes)):
-            resampling_strategy = self.datamanager.resampling_strategy
+        if not isinstance(self.resampling_strategy, (NoResamplingStrategyTypes)):
             raise ValueError(
                 f'resampling_strategy for TestEvaluator must be in '
-                f'NoResamplingStrategyTypes, but got {resampling_strategy}'
+                f'NoResamplingStrategyTypes, but got {self.resampling_strategy}'
             )

-        self.splits = self.datamanager.splits
-        if self.splits is None:
-            raise AttributeError("create_splits must be called in {}".format(self.datamanager.__class__.__name__))
-
     def fit_predict_and_loss(self) -> None:

         split_id = 0
diff --git a/autoPyTorch/evaluation/train_evaluator.py b/autoPyTorch/evaluation/train_evaluator.py
index a9313ee9e..9f5150889 100644
--- a/autoPyTorch/evaluation/train_evaluator.py
+++ b/autoPyTorch/evaluation/train_evaluator.py
@@ -152,16 +152,12 @@ def __init__(self, backend: Backend, queue: Queue,
             search_space_updates=search_space_updates
         )

-        if not isinstance(self.datamanager.resampling_strategy, (CrossValTypes, HoldoutValTypes)):
-            resampling_strategy = self.datamanager.resampling_strategy
+        if not isinstance(self.resampling_strategy, (CrossValTypes, HoldoutValTypes)):
             raise ValueError(
                 f'resampling_strategy for TrainEvaluator must be in '
-                f'(CrossValTypes, HoldoutValTypes), but got {resampling_strategy}'
+                f'(CrossValTypes, HoldoutValTypes), but got {self.resampling_strategy}'
             )

-        self.splits = self.datamanager.splits
-        if self.splits is None:
-            raise AttributeError("Must have called create_splits on {}".format(self.datamanager.__class__.__name__))
         self.num_folds: int = len(self.splits)
         self.Y_targets: List[Optional[np.ndarray]] = [None] * self.num_folds
         self.Y_train_targets: np.ndarray = np.ones(self.y_train.shape) * np.NaN
diff --git a/autoPyTorch/optimizer/smbo.py b/autoPyTorch/optimizer/smbo.py
index 7407f6ba5..898afd7f5 100644
--- a/autoPyTorch/optimizer/smbo.py
+++ b/autoPyTorch/optimizer/smbo.py
@@ -18,7 +18,6 @@
 from smac.utils.io.traj_logging import TrajEntry

 from autoPyTorch.automl_common.common.utils.backend import Backend
-from autoPyTorch.datasets.base_dataset import BaseDataset
 from autoPyTorch.datasets.resampling_strategy import (
     CrossValTypes,
     DEFAULT_RESAMPLING_PARAMETERS,
@@ -194,9 +193,8 @@ def __init__(self,
         super(AutoMLSMBO, self).__init__()
         # data related
         self.dataset_name = dataset_name
-        self.datamanager: Optional[BaseDataset] = None
         self.metric = metric
-        self.task: Optional[str] = None
+
         self.backend = backend
         self.all_supported_metrics = all_supported_metrics

@@ -252,21 +250,11 @@ def __init__(self,
         self.initial_configurations = initial_configurations \
             if len(initial_configurations) > 0 else None

-    def reset_data_manager(self) -> None:
-        if self.datamanager is not None:
-            del self.datamanager
-        self.datamanager = self.backend.load_datamanager()
-
-        if self.datamanager is not None and self.datamanager.task_type is not None:
-            self.task = self.datamanager.task_type
-
     def run_smbo(self, func: Optional[Callable] = None
                  ) -> Tuple[RunHistory, List[TrajEntry], str]:

         self.watcher.start_task('SMBO')
         self.logger.info("Started run of SMBO")
-        # == first things first: load the datamanager
-        self.reset_data_manager()

         # == Initialize non-SMBO stuff
         # first create a scenario