diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index a5c2315f7..000000000 Binary files a/.DS_Store and /dev/null differ diff --git a/.gitignore b/.gitignore index b4b0db9c6..708d90640 100644 --- a/.gitignore +++ b/.gitignore @@ -108,3 +108,6 @@ ENV/ # Poetry lock file poetry.lock + +# MacOS files +.DS_Store \ No newline at end of file diff --git a/README.md b/README.md index 7abacab1b..d6cfcc289 100644 --- a/README.md +++ b/README.md @@ -66,13 +66,13 @@ The signature of the model building function will be used to dynamically determi from scikeras.wrappers import KerasRegressor -def model_building_function(X, n_outputs_, hidden_layer_sizes): +def model_building_function(meta, hidden_layer_sizes): """Dynamically build regressor.""" model = Sequential() - model.add(Dense(X.shape[1], activation="relu", input_shape=X.shape[1:])) + model.add(Dense(meta["X_shape_"][1], activation="relu", input_shape=meta["X_shape_"][1:])) for size in hidden_layer_sizes: model.add(Dense(size, activation="relu")) - model.add(Dense(n_outputs_)) + model.add(Dense(meta["n_outputs_"])) model.compile("adam", loss="mean_squared_error") return model @@ -103,15 +103,15 @@ class MLPRegressor(KerasRegressor): def __init__(self, hidden_layer_sizes=None): self.hidden_layer_sizes = hidden_layer_sizes - def _keras_build_fn(self, X, n_outputs_, hidden_layer_sizes): + def _keras_build_fn(self, meta, hidden_layer_sizes): """Dynamically build regressor.""" if hidden_layer_sizes is None: hidden_layer_sizes = (100, ) model = Sequential() - model.add(Dense(X.shape[1], activation="relu", input_shape=X.shape[1:])) + model.add(Dense(meta["X_shape_"][1], activation="relu", input_shape=meta["X_shape_"][1:])) for size in hidden_layer_sizes: model.add(Dense(size, activation="relu")) - model.add(Dense(n_outputs_)) + model.add(Dense(meta["n_outputs_"])) model.compile("adam", loss=KerasRegressor.r_squared) return model ``` @@ -132,7 +132,7 @@ class MLPRegressor(KerasRegressor): self.hidden_layer_sizes = hidden_layer_sizes super().__init__(**kwargs) # this is very important! - def _keras_build_fn(self, X, n_outputs_, hidden_layer_sizes): + def _keras_build_fn(self, meta, hidden_layer_sizes): ... estimator = MLPRegressor(hidden_layer_sizes=[200], a_kwarg="saveme") @@ -152,7 +152,7 @@ class ChildMLPRegressor(MLPRegressor): self.child_argument = child_argument super().__init__(**kwargs) # this is very important! - def _keras_build_fn(self, X, n_outputs_, hidden_layer_sizes): + def _keras_build_fn(self, meta, hidden_layer_sizes): ... estimator = ChildMLPRegressor(child_argument="hello", a_kwarg="saveme") @@ -218,13 +218,13 @@ class FunctionalAPIMultiOutputClassifier(KerasClassifier): """Functional API Classifier with 2 outputs of different type. """ - def _keras_build_fn(self, X, n_classes_): + def _keras_build_fn(self, meta): inp = Input((4,)) x1 = Dense(100)(inp) binary_out = Dense(1, activation="sigmoid")(x1) - cat_out = Dense(n_classes_[1], activation="softmax")(x1) + cat_out = Dense(meta["n_classes_"][1], activation="softmax")(x1) model = Model([inp], [binary_out, cat_out]) losses = ["binary_crossentropy", "categorical_crossentropy"] @@ -272,7 +272,7 @@ class FunctionalAPIMultiInputClassifier(KerasClassifier): """Functional API Classifier with 2 inputs. 
""" - def _keras_build_fn(self, n_classes_): + def _keras_build_fn(self, meta): inp1 = Input((1,)) inp2 = Input((3,)) @@ -281,7 +281,7 @@ class FunctionalAPIMultiInputClassifier(KerasClassifier): x3 = Concatenate(axis=-1)([x1, x2]) - cat_out = Dense(n_classes_, activation="softmax")(x3) + cat_out = Dense(meta["n_classes_"], activation="softmax")(x3) model = Model([inp1, inp2], [cat_out]) losses = ["categorical_crossentropy"] @@ -331,9 +331,10 @@ class ClassifierWithCallback(KerasClassifier): """ def __init__(self, tolerance, hidden_dim=None): + super().__init__() self.callbacks = [SentinalCallback(tolerance)] self.hidden_dim = hidden_dim - super().__init__() + def _keras_build_fn(self, hidden_dim): return build_fn_clf(hidden_dim) diff --git a/pyproject.toml b/pyproject.toml index 230efe854..5109924a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,6 @@ docs = ["sphinx", "sphinx_rtd_theme"] [tool.isort] line_length = 79 -force_single_line = true atomic = true include_trailing_comma = true lines_after_imports = 2 diff --git a/scikeras/_utils.py b/scikeras/_utils.py index dcf48e6d4..d4f641998 100644 --- a/scikeras/_utils.py +++ b/scikeras/_utils.py @@ -3,11 +3,12 @@ import random import warnings +from typing import Any, Callable, Dict, Iterable, List, Union + import numpy as np import tensorflow as tf -from sklearn.base import BaseEstimator -from sklearn.base import TransformerMixin +from sklearn.base import BaseEstimator, TransformerMixin from tensorflow.keras.layers import deserialize as deserialize_layer from tensorflow.keras.layers import serialize as serialize_layer from tensorflow.keras.metrics import deserialize as deserialize_metric @@ -155,19 +156,83 @@ def get_metric_full_name(name: str) -> str: # deserialize returns the actual function, then get it's name # to keep a single consistent name for the metric if name == "loss": - # may be passed "loss" from thre training history + # may be passed "loss" from training history return name return getattr(deserialize_metric(name), "__name__") -def _get_default_args(func): - signature = inspect.signature(func) - return { - k: v.default - for k, v in signature.parameters.items() - if v.default is not inspect.Parameter.empty - } +def _windows_upcast_ints( + arr: Union[List[np.ndarray], np.ndarray] +) -> Union[List[np.ndarray], np.ndarray]: + # see tensorflow/probability#886 + def _upcast(x): + return x.astype("int64") if x.dtype == np.int32 else x + + if isinstance(arr, np.ndarray): + return _upcast(arr) + else: + return [_upcast(x_) for x_ in arr] + +def route_params( + params: Dict[str, Any], destination: str, pass_filter: Iterable[str], +) -> Dict[str, Any]: + """Route and trim parameter names. + + Parameters + ---------- + params : Dict[str, Any] + Parameters to route/filter. + destination : str + Destination to route to, ex: `build` or `compile`. + pass_filter: Iterable[str] + Only keys from `params` that are in the iterable are passed. + This does not affect routed parameters. -def _windows_upcast_ints(x: np.ndarray) -> np.ndarray: - return x.astype("int64") if x.dtype == np.int32 else x + Returns + ------- + Dict[str, Any] + Filtered parameters, with any routing prefixes removed. 
+ """ + res = dict() + for key, val in params.items(): + if "__" in key: + # routed param + if key.startswith(destination): + new_key = key[len(destination + "__") :] + res[new_key] = val + else: + # non routed + if pass_filter is None or key in pass_filter: + res[key] = val + return res + + +def has_param(func: Callable, param: str) -> bool: + """[summary] + + Parameters + ---------- + func : Callable + [description] + param : str + [description] + + Returns + ------- + bool + [description] + """ + return any( + p.name == param + for p in inspect.signature(func).parameters.values() + if p.kind in (p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY) + ) + + +def accepts_kwargs(func: Callable) -> bool: + return any( + True + for param in inspect.signature(func).parameters.values() + if param.kind == param.VAR_KEYWORD + ) diff --git a/scikeras/wrappers.py b/scikeras/wrappers.py index f6aaf0ac1..c00c3ef94 100644 --- a/scikeras/wrappers.py +++ b/scikeras/wrappers.py @@ -5,6 +5,7 @@ import warnings from collections import defaultdict +from typing import Any, Dict import numpy as np import tensorflow as tf @@ -14,28 +15,30 @@ from sklearn.metrics import accuracy_score as sklearn_accuracy_score from sklearn.metrics import r2_score as sklearn_r2_score from sklearn.pipeline import make_pipeline -from sklearn.preprocessing import LabelEncoder -from sklearn.preprocessing import OneHotEncoder +from sklearn.preprocessing import LabelEncoder, OneHotEncoder from sklearn.utils.multiclass import type_of_target -from sklearn.utils.validation import _check_sample_weight -from sklearn.utils.validation import check_array -from sklearn.utils.validation import check_X_y +from sklearn.utils.validation import ( + _check_sample_weight, + check_array, + check_X_y, +) from tensorflow.keras.models import Model from tensorflow.python.keras.losses import is_categorical_crossentropy -from tensorflow.python.keras.utils.generic_utils import has_arg from tensorflow.python.keras.utils.generic_utils import ( + has_arg, register_keras_serializable, ) -from ._utils import LabelDimensionTransformer -from ._utils import TFRandomState -from ._utils import _get_default_args -from ._utils import _windows_upcast_ints -from ._utils import get_metric_full_name -from ._utils import make_model_picklable - - -OS_IS_WINDOWS = os.name == "nt" # see tensorflow/probability#886 +from ._utils import ( + LabelDimensionTransformer, + TFRandomState, + _windows_upcast_ints, + accepts_kwargs, + get_metric_full_name, + has_param, + make_model_picklable, + route_params, +) class BaseWrapper(BaseEstimator): @@ -64,10 +67,71 @@ class BaseWrapper(BaseEstimator): "multioutput": True, } + _fit_kwargs = { + # parameters destined to keras.Model.fit + "callbacks", + "batch_size", + "epochs", + "verbose", + "callbacks", + "validation_split", + "shuffle", + "class_weight", + "sample_weight", + "initial_epoch", + "validation_steps", + "validation_batch_size", + "validation_freq", + } + + _predict_kwargs = { + # parameters destined to keras.Model.predict + "batch_size", + "verbose", + "callbacks", + "steps", + } + + _compile_kwargs = { + # parameters destined to keras.Model.compile + "optimizer", + "loss", + "metrics", + "loss_weights", + "weighted_metrics", + "run_eagerly", + } + + _wrapper_params = { + # parameters consumed by the wrappers themselves + "warm_start", + "random_state", + } + + _meta = { + # parameters created by wrappers within `fit` + "_random_state", + "n_features_in_", + "X_dtype_", + "y_dtype_", + "X_shape_", + "y_shape_", + "model_", + "history_", + 
"is_fitted_", + "n_outputs_", + "model_n_outputs_", + "_user_params", + } + + _routing_prefixes = {"model", "fit", "compile", "predict"} + def __init__( self, - build_fn=None, + model=None, *, + build_fn=None, # for backwards compatibility + warm_start=False, random_state=None, optimizer="rmsprop", loss=None, @@ -78,18 +142,18 @@ def __init__( validation_split=0.0, shuffle=True, run_eagerly=False, + epochs=1, **kwargs, ): - # Get defaults from `build_fn` - if inspect.isfunction(build_fn): - vars(self).update(_get_default_args(build_fn)) if isinstance(build_fn, Model): # ensure prebuilt model can be serialized make_model_picklable(build_fn) # Parse hardcoded params + self.model = model self.build_fn = build_fn + self.warm_start = warm_start self.random_state = random_state self.optimizer = optimizer self.loss = loss @@ -100,119 +164,124 @@ def __init__( self.validation_split = validation_split self.shuffle = shuffle self.run_eagerly = run_eagerly + self.epochs = epochs # Unpack kwargs vars(self).update(**kwargs) + # Save names of kwargs into set + if kwargs: + self._user_params = set(kwargs) + @property def __name__(self): return self.__class__.__name__ - def _check_build_fn(self, build_fn): - """Checks `build_fn`. + @property + def _model_params(self): + return { + k[len("model__") :] + for k in self.get_params() + if "model__" == k[: len("model__")] + or k in getattr(self, "_user_params", set()) + } - Arguments: - build_fn : method or callable class as defined in __init__ + def _check_model_param(self): + """Checks `model` and returns model building + function to use. Raises: - ValueError: if `build_fn` is not valid. + ValueError: if `self.model` is not valid. """ - if build_fn is None: - # no build_fn, use this class' __call__method + model = self.model + build_fn = self.build_fn + if model is None and build_fn is not None: + model = build_fn + warnings.warn( + "`build_fn` will be renamed to `model` in a future release," + " at which point use of `build_fn` will raise an Error instead." + ) + if model is None: + # no model, use this class' _keras_build_fn if not hasattr(self, "_keras_build_fn"): raise ValueError( "If not using the `build_fn` param, " "you must implement `_keras_build_fn`" ) - final_build_fn = getattr(self, "_keras_build_fn") - elif isinstance(build_fn, Model): + final_build_fn = self._keras_build_fn + elif isinstance(model, Model): # pre-built Keras Model def final_build_fn(): - return build_fn + return model - elif inspect.isfunction(build_fn): + elif inspect.isfunction(model): if hasattr(self, "_keras_build_fn"): raise ValueError( "This class cannot implement `_keras_build_fn` if" - " using the `build_fn` parameter" + " using the `model` parameter" ) # a callable method/function - final_build_fn = build_fn - elif ( - callable(build_fn) - and hasattr(build_fn, "__class__") - and hasattr(build_fn.__class__, "__call__") - ): - if hasattr(self, "_keras_build_fn"): - raise ValueError( - "This class cannot implement `_keras_build_fn` if" - " using the `build_fn` parameter" - ) - # an instance of a class implementing __call__ - final_build_fn = build_fn.__call__ + final_build_fn = model else: - raise TypeError("`build_fn` must be a callable or None") + raise TypeError("`model` must be a callable or None") return final_build_fn - def _fit_build_keras_model(self, X, y, **kwargs): + def _build_keras_model(self): """Build the Keras model. This method will process all arguments and call the model building function with appropriate arguments. 
-        Arguments:
-            X : array-like, shape `(n_samples, n_features)`
-                Training samples where `n_samples` is the number of samples
-                and `n_features` is the number of features.
-            y : array-like, shape `(n_samples,)` or `(n_samples, n_outputs)`
-                True labels for `X`.
-            **kwargs: dictionary arguments
-                Legal arguments are the arguments `build_fn`.
         Returns:
-            self : object
-                a reference to the instance that can be chain called
-                (ex: instance.fit(X,y).transform(X) )
-        Raises:
-            ValueError : In case sample_weight != None and the Keras model's
-                `fit` method does not support that parameter.
+            model : tensorflow.keras.Model
+                Instantiated and compiled keras Model.
        """
        # dynamically build model, i.e. final_build_fn builds a Keras model
 
        # determine what type of build_fn to use
-        final_build_fn = self._check_build_fn(getattr(self, "build_fn", None))
+        final_build_fn = self._check_model_param()
 
-        # get model arguments
-        model_args = self._filter_params(final_build_fn)
-
-        # check if the model building function requires X and/or y to be passed
-        X_y_args = self._filter_params(
-            final_build_fn, params_to_check={"X": X, "y": y}
+        # collect parameters
+        params = self.get_params()
+        build_params = route_params(
+            params,
+            destination="model",
+            pass_filter=getattr(self, "_user_params", set()),
        )
-
-        # filter kwargs
-        kwargs = self._filter_params(final_build_fn, params_to_check=kwargs)
-
-        # combine all arguments
-        build_args = {
-            **model_args,  # noqa: E999
-            **X_y_args,  # noqa: E999
-            **kwargs,  # noqa: E999
-        }
+        if has_param(final_build_fn, "meta") or accepts_kwargs(final_build_fn):
+            # build_fn accepts `meta`, add it
+            meta = route_params(
+                self.get_meta(), destination=None, pass_filter=self._meta,
+            )
+            build_params["meta"] = meta
+        if has_param(final_build_fn, "compile_kwargs") or accepts_kwargs(
+            final_build_fn
+        ):
+            # build_fn accepts `compile_kwargs`, add it
+            compile_kwargs = route_params(
+                params, destination="compile", pass_filter=self._compile_kwargs
+            )
+            build_params["compile_kwargs"] = compile_kwargs
+        if has_param(final_build_fn, "params") or accepts_kwargs(
+            final_build_fn
+        ):
+            # build_fn accepts `params`, i.e. all of get_params()
+            build_params["params"] = self.get_params()
 
        # build model
        if self._random_state is not None:
            with TFRandomState(self._random_state):
-                model = final_build_fn(**build_args)
+                model = final_build_fn(**build_params)
        else:
-            model = final_build_fn(**build_args)
+            model = final_build_fn(**build_params)
 
        # make serializable
        make_model_picklable(model)
 
        return model
 
-    def _fit_keras_model(self, X, y, sample_weight, warm_start, **kwargs):
+    def _fit_keras_model(self, X, y, sample_weight, warm_start):
        """Fits the Keras model.
 
        This method will process all arguments and call the Keras
@@ -229,9 +298,6 @@ def _fit_keras_model(self, X, y, sample_weight, warm_start, **kwargs):
            warm_start : bool
                If ``warm_start`` is True, don't overwrite the
                ``history_`` attribute and append to it instead.
-            **kwargs: dictionary arguments
-                Legal arguments are the arguments of the keras model's
-                `fit` method.
        Returns:
            self : object
                a reference to the instance that can be chain called
                (ex: instance.fit(X,y).transform(X) )
        Raises:
            ValueError : In case sample_weight != None and the Keras model's
                `fit` method does not support that parameter.
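+        Example:
+            A sketch of the ``warm_start`` behavior described above
+            (assumes the default ``epochs=1``)::
+
+                est.fit(X, y)                    # history_["loss"] has 1 entry
+                est.set_params(warm_start=True)
+                est.fit(X, y)                    # now it has 2 entries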
""" - # add `sample_weight` param, required to be explicit by some sklearn - # functions that use inspect.signature on the `score` method - if sample_weight is not None: - # avoid pesky Keras warnings if sample_weight is not used - kwargs.update({"sample_weight": sample_weight}) - - # filter kwargs down to those accepted by self.model_.fit - kwargs = self._filter_params(self.model_.fit, params_to_check=kwargs) - - # get model.fit's arguments (allows arbitrary model use) - fit_args = self._filter_params(self.model_.fit) - - # fit model and save history - # order implies kwargs overwrites fit_args - fit_args = {**fit_args, **kwargs} - - if OS_IS_WINDOWS: + if os.name == "nt": # see tensorflow/probability#886 - X = ( - _windows_upcast_ints(X) - if isinstance(X, np.ndarray) - else [_windows_upcast_ints(x) for x in X] - ) - y = ( - _windows_upcast_ints(y) - if isinstance(y, np.ndarray) - else [_windows_upcast_ints(yi) for yi in y] - ) + X = _windows_upcast_ints(X) + y = _windows_upcast_ints(y) + + # collect parameters + params = self.get_params() + fit_args = route_params( + params, destination="fit", pass_filter=self._fit_kwargs + ) + fit_args["sample_weight"] = sample_weight if self._random_state is not None: with TFRandomState(self._random_state): @@ -296,7 +345,7 @@ def _check_output_model_compatibility(self, y): This is mainly in place to avoid cryptic TF errors. """ # check if this is a multi-output model - if self.keras_expected_n_ouputs_ != len(self.model_.outputs): + if self.model_n_outputs_ != len(self.model_.outputs): raise RuntimeError( "Detected an input of size " "{}, but {} has {} outputs".format( @@ -392,7 +441,10 @@ def preprocess_y(y): extra_args : dictionary of output attributes, ex: n_outputs_ """ - extra_args = dict() + extra_args = { + "y_dtype_": y.dtype, + "y_shape_": y.shape, + } return y, extra_args @@ -432,10 +484,35 @@ def preprocess_X(X): X : unchanged 2D numpy array extra_args : attributes of output `y`. """ - extra_args = dict() + extra_args = { + "X_dtype_": X.dtype, + "X_shape_": X.shape, + } return X, extra_args - def fit(self, X, y, sample_weight=None, warm_start=False, **kwargs): + def fit(self, X, y, sample_weight=None): + """Constructs a new model with `build_fn` & fit the model to `(X, y)`. + + Arguments: + X : array-like, shape `(n_samples, n_features)` + Training samples where `n_samples` is the number of samples + and `n_features` is the number of features. + y : array-like, shape `(n_samples,)` or `(n_samples, n_outputs)` + True labels for `X`. + sample_weight : array-like of shape (n_samples,), default=None + Sample weights. The Keras Model must support this. + Returns: + self : object + a reference to the instance that can be chain called + (ex: instance.fit(X,y).transform(X) ) + Raises: + ValueError : In case of invalid shape for `y` argument. + """ + return self._fit( + X=X, y=y, sample_weight=sample_weight, warm_start=self.warm_start + ) + + def _fit(self, X, y, sample_weight=None, warm_start=False): """Constructs a new model with `build_fn` & fit the model to `(X, y)`. Arguments: @@ -448,17 +525,12 @@ def fit(self, X, y, sample_weight=None, warm_start=False, **kwargs): Sample weights. The Keras Model must support this. warm_start : bool, default False If ``warm_start`` is True, don't rebuild the model. - **kwargs: dictionary arguments - Legal arguments are the arguments of the keras model's `fit` - method. 
Returns: self : object a reference to the instance that can be chain called (ex: instance.fit(X,y).transform(X) ) Raises: ValueError : In case of invalid shape for `y` argument. - ValueError : In case sample_weight != None and the Keras model's - `fit` method does not support that parameter. """ # Handle random state if hasattr(self, "random_state"): @@ -489,7 +561,7 @@ def fit(self, X, y, sample_weight=None, warm_start=False, **kwargs): X, y = self._validate_data(X=X, y=y, reset=reset) # Save input dtype - self.input_dtype_ = y.dtype + self.y_dtype_ = y.dtype if sample_weight is not None: sample_weight = _check_sample_weight( @@ -498,7 +570,7 @@ def fit(self, X, y, sample_weight=None, warm_start=False, **kwargs): # Scikit-Learn expects a 0 in sample_weight to mean # "ignore the sample", but because of how Keras applies # sample_weight to the loss function, this doesn't - # exacly work out (as in, sklearn estimator checks fail + # exactly work out (as in, sklearn estimator checks fail # because the predictions differ by a small margin). # To get around this, we manually delete these samples here zeros = sample_weight == 0 @@ -515,25 +587,28 @@ def fit(self, X, y, sample_weight=None, warm_start=False, **kwargs): ) # pre process X, y - X, _ = self.preprocess_X(X) + X, extra_args = self.preprocess_X(X) + # update self.X_dtype_, self.X_shape_ + for attr_name, attr_val in extra_args.items(): + setattr(self, attr_name, attr_val) y, extra_args = self.preprocess_y(y) # update self.classes_, self.n_outputs_, self.n_classes_ and - # self.cls_type_ + # self.target_type_ for attr_name, attr_val in extra_args.items(): setattr(self, attr_name, attr_val) # build model if (not warm_start) or (not hasattr(self, "model_")): - self.model_ = self._fit_build_keras_model(X, y, **kwargs) + self.model_ = self._build_keras_model() y = self._check_output_model_compatibility(y) # fit model return self._fit_keras_model( - X, y, sample_weight=sample_weight, warm_start=warm_start, **kwargs + X, y, sample_weight=sample_weight, warm_start=warm_start ) - def partial_fit(self, X, y, sample_weight=None, **kwargs): + def partial_fit(self, X, y, sample_weight=None): """ Partially fit a model. @@ -545,9 +620,6 @@ def partial_fit(self, X, y, sample_weight=None, **kwargs): True labels for `X`. sample_weight : array-like of shape (n_samples,), default=None Sample weights. The Keras Model must support this. - **kwargs: dictionary arguments - Legal arguments are the arguments of the keras model's `fit` - method. Returns: self : object @@ -555,22 +627,16 @@ def partial_fit(self, X, y, sample_weight=None, **kwargs): (ex: instance.partial_fit(X, y).transform(X) ) Raises: ValueError : In case of invalid shape for `y` argument. - ValuError : In case sample_weight != None and the Keras model's - `fit` method does not support that parameter. """ - return self.fit( - X, y, sample_weight=sample_weight, warm_start=True, **kwargs - ) + return self._fit(X, y, sample_weight=sample_weight, warm_start=True) - def predict(self, X, **kwargs): + def predict(self, X): """Returns predictions for the given test data. Arguments: X: array-like, shape `(n_samples, n_features)` Test samples where `n_samples` is the number of samples and `n_features` is the number of features. - **kwargs: dictionary arguments - Legal arguments are the arguments of `self.model_.predict`. 
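+            (Arguments bound for ``Model.predict`` are now routed from
+            constructor parameters instead, e.g. a hypothetical
+            ``KerasClassifier(model=build_fn, predict__batch_size=32)``.)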
Returns: preds: array-like, shape `(n_samples,)` @@ -589,20 +655,19 @@ def predict(self, X, **kwargs): X, _ = self.preprocess_X(X) # filter kwargs and get attributes for predict - kwargs = self._filter_params( - self.model_.predict, params_to_check=kwargs + params = self.get_params() + pred_args = route_params( + params, destination="predict", pass_filter=self._predict_kwargs ) - predict_args = self._filter_params(self.model_.predict) # predict with Keras model - pred_args = {**predict_args, **kwargs} y_pred = self.model_.predict(X, **pred_args) # post process y y, _ = self.postprocess_y(y_pred) return y - def score(self, X, y, sample_weight=None, **kwargs): + def score(self, X, y, sample_weight=None): """Returns the mean accuracy on the given test data and labels. Arguments: @@ -613,8 +678,6 @@ def score(self, X, y, sample_weight=None, **kwargs): True labels for `X`. sample_weight : array-like of shape (n_samples,), default=None Sample weights. The Keras Model must support this. - **kwargs: dictionary arguments - Legal arguments are those of `self.model_.evaluate`. Returns: score: float @@ -633,28 +696,51 @@ def score(self, X, y, sample_weight=None, **kwargs): y = check_array(y, ensure_2d=False) # compute Keras model score - y_pred = self.predict(X, **kwargs) + y_pred = self.predict(X) - return self._scorer(y, y_pred, sample_weight=sample_weight) + # filter kwargs and get attributes for score + params = self.get_params() + score_args = route_params( + params, destination="score", pass_filter=set() + ) - def _filter_params(self, fn, params_to_check=None): - """Filters all instance attributes (parameters) and - returns those in `fn`'s arguments. + return self.scorer( + y, y_pred, sample_weight=sample_weight, **score_args + ) - Arguments: - fn : arbitrary function - params_to_check : dictionary, parameters to check. - Defaults to checking all attributes of this estimator. + def get_meta(self) -> Dict[str, Any]: + """Get meta parameters (parameters created by fit, like + n_features_in_ or target_type_). - Returns: - res : dictionary containing variables - in both self and `fn`'s arguments. + Returns + ------- + Dict[str, Any] + Dictionary of meta parameters """ - res = {} - for name, value in (params_to_check or self.__dict__).items(): - if has_arg(fn, name): - res.update({name: value}) - return res + return { + k: v + for k, v in self.__dict__.items() + if ( + k not in type(self)().__dict__ + and k not in self.get_params() + and (k.startswith("_") or k.endswith("_")) + ) + } + + def set_params(self, **params) -> "BaseWrapper": + """Override BaseEstimator.set_params to allow setting of routed params. 
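+
+        For example, both plain and routed parameters can be set
+        (values here are hypothetical)::
+
+            est.set_params(epochs=2)                   # plain parameter
+            est.set_params(compile__optimizer="adam")  # routed parameter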
+ """ + passthrough = dict() + for param, value in params.items(): + if any( + param.startswith(prefix + "__") + for prefix in self._routing_prefixes + ): + # routed param + setattr(self, param, value) + else: + passthrough[param] = value + return super().set_params(**passthrough) def _get_param_names(self): """Get parameter names for the estimator""" @@ -687,21 +773,46 @@ class KerasClassifier(BaseWrapper): """ _estimator_type = "classifier" - _scorer = staticmethod(sklearn_accuracy_score) - _tags = BaseWrapper._tags.copy() - _tags.update( - { - "multilabel": True, - "_xfail_checks": { - "check_classifiers_classes": "can't meet \ - performance target", - "check_fit_idempotent": "tf does not use \ - sparse tensors", - "check_no_attributes_set_in_init": "can only \ - pass if all params are hardcoded in __init__", - }, - } - ) + _tags = { + "multilabel": True, + "_xfail_checks": { + "check_classifiers_classes": "can't meet \ + performance target", + "check_fit_idempotent": "tf does not use \ + sparse tensors", + "check_no_attributes_set_in_init": "can only \ + pass if all params are hardcoded in __init__", + }, + **BaseWrapper._tags, + } + + _meta = { + "n_classes_", + "target_type_", + "classes_", + "encoders_", + "n_outputs_", + "model_n_outputs_", + *BaseWrapper._meta, + } + + @staticmethod + def scorer(y_true, y_pred, **kwargs) -> float: + """Accuracy score based on true and predicted target values. + + Parameters + ---------- + y_true : array-like + True labels. + y_pred : array-like + Predicted labels. + + Returns + ------- + score + float + """ + return sklearn_accuracy_score(y_true, y_pred, **kwargs) @staticmethod def preprocess_y(y): @@ -716,21 +827,19 @@ def preprocess_y(y): y : modified 2D numpy array with 0 indexed integer class labels. extra_args : dictionary of output attributes, ex `n_outputs_` """ - y, _ = super(KerasClassifier, KerasClassifier).preprocess_y(y) - - cls_type_ = type_of_target(y) + y, extra_args = super(KerasClassifier, KerasClassifier).preprocess_y(y) - input_dtype_ = y.dtype + target_type_ = type_of_target(y) if len(y.shape) == 1: n_outputs_ = 1 else: n_outputs_ = y.shape[1] - if cls_type_ == "binary": + if target_type_ == "binary": # y = array([1, 0, 1, 0]) # single task, single label, binary classification - keras_expected_n_ouputs_ = 1 # single sigmoid output expected + model_n_outputs_ = 1 # single sigmoid output expected # encode encoder = LabelEncoder() # No need to reshape to 1D here, @@ -741,9 +850,9 @@ def preprocess_y(y): encoders_ = [encoder] classes_ = [classes_] y = [y] - elif cls_type_ == "multiclass": + elif target_type_ == "multiclass": # y = array([1, 5, 2]) - keras_expected_n_ouputs_ = 1 # single softmax output expected + model_n_outputs_ = 1 # single softmax output expected # encode encoder = LabelEncoder() if len(y.shape) > 1 and y.shape[1] == 1: @@ -755,13 +864,13 @@ def preprocess_y(y): encoders_ = [encoder] classes_ = [classes_] y = [y] - elif cls_type_ == "multilabel-indicator": + elif target_type_ == "multilabel-indicator": # y = array([1, 1, 1, 0], [0, 0, 1, 1]) # split into lists for multi-output Keras # will be processed as multiple binary classifications classes_ = [np.array([0, 1])] * y.shape[1] y = np.split(y, y.shape[1], axis=1) - keras_expected_n_ouputs_ = len(y) + model_n_outputs_ = len(y) # encode encoders_ = [LabelEncoder() for _ in range(len(y))] y = [ @@ -771,12 +880,12 @@ def preprocess_y(y): for encoder, y_ in zip(encoders_, y) ] classes_ = [encoder.classes_ for encoder in encoders_] - elif cls_type_ == 
"multiclass-multioutput": + elif target_type_ == "multiclass-multioutput": # y = array([1, 0, 5], [2, 1, 3]) # split into lists for multi-output Keras # each will be processesed as a seperate multiclass problem y = np.split(y, y.shape[1], axis=1) - keras_expected_n_ouputs_ = len(y) + model_n_outputs_ = len(y) # encode encoders_ = [LabelEncoder() for _ in range(len(y))] y = [ @@ -787,7 +896,7 @@ def preprocess_y(y): ] classes_ = [encoder.classes_ for encoder in encoders_] else: - raise ValueError("Unknown label type: {}".format(cls_type_)) + raise ValueError("Unknown label type: {}".format(target_type_)) # self.classes_ is kept as an array when n_outputs>1 for compatibility # with ensembles and other meta estimators @@ -800,15 +909,16 @@ def preprocess_y(y): n_classes_ = [class_.shape[0] for class_ in classes_] n_outputs_ = len(n_classes_) - extra_args = { - "classes_": classes_, - "encoders_": encoders_, - "n_outputs_": n_outputs_, - "keras_expected_n_ouputs_": keras_expected_n_ouputs_, - "n_classes_": n_classes_, - "cls_type_": cls_type_, - "input_dtype_": input_dtype_, - } + extra_args.update( + { + "classes_": classes_, + "encoders_": encoders_, + "n_outputs_": n_outputs_, + "model_n_outputs_": model_n_outputs_, + "n_classes_": n_classes_, + "target_type_": target_type_, + } + ) return y, extra_args @@ -821,13 +931,13 @@ def postprocess_y(self, y): # convert single-target y to a list for easier processing y = [y] - cls_type_ = self.cls_type_ + target_type_ = self.target_type_ class_predictions = [] for i in range(self.n_outputs_): - if cls_type_ == "binary": + if target_type_ == "binary": # array([0.9, 0.1], [.2, .8]) -> array(['yes', 'no']) if ( isinstance(self.encoders_[i], LabelEncoder) @@ -859,7 +969,7 @@ def postprocess_y(self, y): # result from a single sigmoid output # reformat so that we have 2 columns y[i] = np.column_stack([1 - y[i], y[i]]) - elif cls_type_ in ("multiclass", "multiclass-multioutput"): + elif target_type_ in ("multiclass", "multiclass-multioutput"): # array([0.8, 0.1, 0.1], [.1, .8, .1]) -> # array(['apple', 'orange']) idx = np.argmax(y[i], axis=-1) @@ -871,7 +981,7 @@ def postprocess_y(self, y): class_predictions.append( self.encoders_[i].inverse_transform(y_) ) - elif cls_type_ == "multilabel-indicator": + elif target_type_ == "multilabel-indicator": class_predictions.append( self.encoders_[i].inverse_transform( np.argmax(y[i], axis=1) @@ -883,7 +993,7 @@ def postprocess_y(self, y): y = np.squeeze(np.column_stack(class_predictions)) # type cast back to input dtype - y = y.astype(self.input_dtype_, copy=False) + y = y.astype(self.y_dtype_, copy=False) extra_args = {"class_probabilities": class_probabilities} @@ -912,16 +1022,42 @@ def _check_output_model_compatibility(self, y): return super()._check_output_model_compatibility(y) - def predict_proba(self, X, **kwargs): + def partial_fit(self, X, y, classes=None, sample_weight=None): + """ + Partially fit a model. + + Arguments: + X : array-like, shape `(n_samples, n_features)` + Training samples where `n_samples` is the number of samples + and `n_features` is the number of features. + y : array-like, shape `(n_samples,)` or `(n_samples, n_outputs)` + True labels for `X`. + classes: ndarray of shape (n_classes,), default=None + Classes across all calls to partial_fit. Can be obtained by via + np.unique(y_all), where y_all is the target vector of the entire dataset. + This argument is required for the first call to partial_fit and can be + omitted in the subsequent calls. 
Note that y doesn’t need to contain + all labels in classes. + sample_weight : array-like of shape (n_samples,), default=None + Sample weights. The Keras Model must support this. + + Returns: + self : object + a reference to the instance that can be chain called + (ex: instance.partial_fit(X, y).transform(X) ) + Raises: + ValueError : In case of invalid shape for `y` argument. + """ + self.classes_ = classes # TODO: don't swallow this param + return super().partial_fit(X, y, sample_weight=sample_weight) + + def predict_proba(self, X): """Returns class probability estimates for the given test data. Arguments: X: array-like, shape `(n_samples, n_features)` Test samples where `n_samples` is the number of samples and `n_features` is the number of features. - **kwargs: dictionary arguments - Legal arguments are the arguments - of `Sequential.predict_classes`. Returns: proba: array-like, shape `(n_samples, n_outputs)` @@ -943,19 +1079,20 @@ def predict_proba(self, X, **kwargs): # pre process X X, _ = self.preprocess_X(X) - # filter kwargs and get attributes that are inputs to model.predict - kwargs = self._filter_params( - self.model_.predict, params_to_check=kwargs + # collect arguments + predict_args = route_params( + self.get_params(), + destination="predict", + pass_filter=self._predict_kwargs, ) - predict_args = self._filter_params(self.model_.predict) - # call the Keras model - predict_args = {**predict_args, **kwargs} + # call the Keras model's predict outputs = self.model_.predict(X, **predict_args) # join list of outputs into single output array _, extra_args = self.postprocess_y(outputs) + # get class probabilities from postprocess_y's output class_probabilities = extra_args["class_probabilities"] return class_probabilities @@ -966,23 +1103,38 @@ class KerasRegressor(BaseWrapper): """ _estimator_type = "regressor" - _scorer = staticmethod(sklearn_r2_score) - _tags = BaseWrapper._tags.copy() - _tags.update( - { - "multilabel": True, - "_xfail_checks": { - "check_fit_idempotent": "tf does not use sparse tensors", - "check_methods_subset_invariance": "can't meet tol", - "check_no_attributes_set_in_init": "can only pass if all \ - params are hardcoded in __init__", - }, - } - ) + _tags = { + "multilabel": True, + "_xfail_checks": { + "check_fit_idempotent": "tf does not use sparse tensors", + "check_methods_subset_invariance": "can't meet tol", + "check_no_attributes_set_in_init": "can only pass if all \ + params are hardcoded in __init__", + }, + **BaseWrapper._tags, + } + + @staticmethod + def scorer(y_true, y_pred, **kwargs) -> float: + """R^2 score based on true and predicted target values. + + Parameters + ---------- + y_true : array-like + True labels. + y_pred : array-like + Predicted labels. + + Returns + ------- + score + float + """ + return sklearn_r2_score(y_true, y_pred, **kwargs) def postprocess_y(self, y): """Ensures output is floatx and squeeze.""" - if np.can_cast(self.input_dtype_, np.float32): + if np.can_cast(self.y_dtype_, np.float32): return np.squeeze(y.astype(np.float32, copy=False)), dict() else: return np.squeeze(y.astype(np.float64, copy=False)), dict() @@ -990,7 +1142,7 @@ def postprocess_y(self, y): def preprocess_y(self, y): """Split y for multi-output tasks. 
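+
+        For example (shapes are illustrative), a target of shape
+        ``(n, 2)`` yields ``n_outputs_ == 2`` but ``model_n_outputs_ == 1``,
+        since a single Keras output handles all columns.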
""" - y, _ = super().preprocess_y(y) + y, extra_args = super().preprocess_y(y) if len(y.shape) == 1: n_outputs_ = 1 @@ -998,18 +1150,17 @@ def preprocess_y(self, y): n_outputs_ = y.shape[1] # for regression, multi-output is handled by single Keras output - keras_expected_n_ouputs_ = 1 + model_n_outputs_ = 1 - extra_args = { - "n_outputs_": n_outputs_, - "keras_expected_n_ouputs_": keras_expected_n_ouputs_, - } + extra_args.update( + {"n_outputs_": n_outputs_, "model_n_outputs_": model_n_outputs_,} + ) y = [y] # pack into single output list return y, extra_args - def score(self, X, y, sample_weight=None, **kwargs): + def score(self, X, y, sample_weight=None): """Returns the mean loss on the given test data and labels. Arguments: @@ -1018,15 +1169,11 @@ def score(self, X, y, sample_weight=None, **kwargs): and `n_features` is the number of features. y: array-like, shape `(n_samples,)` True labels for `X`. - **kwargs: dictionary arguments - Legal arguments are the arguments of `Sequential.evaluate`. Returns: score: float Mean accuracy of predictions on `X` wrt. `y`. """ - res = super(KerasRegressor, self).score(X, y, sample_weight, **kwargs) - # check loss function and warn if it is not the same as score function if self.model_.loss not in ("mean_squared_error", self.r_squared,): warnings.warn( @@ -1036,7 +1183,7 @@ def score(self, X, y, sample_weight=None, **kwargs): "`KerasRegressor.r_squared`." ) - return res + return super().score(X, y, sample_weight=sample_weight) @staticmethod @register_keras_serializable() diff --git a/tests/mlp_models.py b/tests/mlp_models.py index 75a152362..cbc655b54 100644 --- a/tests/mlp_models.py +++ b/tests/mlp_models.py @@ -1,23 +1,22 @@ -from tensorflow.python.keras.layers import Dense -from tensorflow.python.keras.layers import Input -from tensorflow.python.keras.models import Model +from typing import Any, Dict + +from tensorflow.keras.layers import Dense, Input +from tensorflow.keras.models import Model from scikeras.wrappers import KerasRegressor def dynamic_classifier( - n_features_in_, - cls_type_, - n_classes_, - metrics=None, - keras_expected_n_ouputs_=1, - loss=None, - optimizer="sgd", - hidden_layer_sizes=(100,), -): + hidden_layer_sizes, meta: Dict[str, Any], compile_kwargs: Dict[str, Any], +) -> Model: """Creates a basic MLP classifier dynamically choosing binary/multiclass classification loss and ouput activations. 
""" + # get parameters + n_features_in_ = meta["n_features_in_"] + target_type_ = meta["target_type_"] + n_classes_ = meta["n_classes_"] + model_n_outputs_ = meta["model_n_outputs_"] inp = Input(shape=(n_features_in_,)) @@ -25,45 +24,48 @@ def dynamic_classifier( for layer_size in hidden_layer_sizes: hidden = Dense(layer_size, activation="relu")(hidden) - if cls_type_ == "binary": - loss = loss or "binary_crossentropy" + if target_type_ == "binary": + compile_kwargs["loss"] = ( + compile_kwargs["loss"] or "binary_crossentropy" + ) out = [Dense(1, activation="sigmoid")(hidden)] - elif cls_type_ == "multilabel-indicator": - loss = loss or "binary_crossentropy" + elif target_type_ == "multilabel-indicator": + compile_kwargs["loss"] = ( + compile_kwargs["loss"] or "binary_crossentropy" + ) out = [ Dense(1, activation="sigmoid")(hidden) - for _ in range(keras_expected_n_ouputs_) + for _ in range(model_n_outputs_) ] - elif cls_type_ == "multiclass-multioutput": - loss = loss or "binary_crossentropy" + elif target_type_ == "multiclass-multioutput": + compile_kwargs["loss"] = ( + compile_kwargs["loss"] or "binary_crossentropy" + ) out = [Dense(n, activation="softmax")(hidden) for n in n_classes_] else: # multiclass - loss = loss or "categorical_crossentropy" + compile_kwargs["loss"] = ( + compile_kwargs["loss"] or "categorical_crossentropy" + ) out = [Dense(n_classes_, activation="softmax")(hidden)] model = Model(inp, out) - model.compile( - loss=loss, optimizer=optimizer, metrics=metrics, - ) + model.compile(**compile_kwargs) return model def dynamic_regressor( - n_features_in_, - n_outputs_, - loss=KerasRegressor.r_squared, - optimizer="adam", - metrics=None, - hidden_layer_sizes=(100,), -): + hidden_layer_sizes, meta: Dict[str, Any], compile_kwargs: Dict[str, Any], +) -> Model: """Creates a basic MLP regressor dynamically. 
""" - if loss is None: - # Default Model loss, not appropriate for a classifier - loss = KerasRegressor.r_squared + # get parameters + n_features_in_ = meta["n_features_in_"] + n_outputs_ = meta["n_outputs_"] + + compile_kwargs["loss"] = compile_kwargs["loss"] or KerasRegressor.r_squared inp = Input(shape=(n_features_in_,)) @@ -75,9 +77,5 @@ def dynamic_regressor( model = Model(inp, out) - model.compile( - optimizer=optimizer, - loss=loss, # KerasRegressor.r_squared - metrics=metrics, - ) + model.compile(**compile_kwargs) return model diff --git a/tests/test_api.py b/tests/test_api.py index 29c8b5a63..b944fdad7 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,38 +1,34 @@ """Tests for Scikit-learn API wrapper.""" import pickle +from typing import Any, Dict + import numpy as np import pytest from sklearn.calibration import CalibratedClassifierCV -from sklearn.datasets import load_boston -from sklearn.datasets import load_digits -from sklearn.datasets import load_iris -from sklearn.ensemble import AdaBoostClassifier -from sklearn.ensemble import AdaBoostRegressor -from sklearn.ensemble import BaggingClassifier -from sklearn.ensemble import BaggingRegressor +from sklearn.datasets import load_boston, load_digits, load_iris +from sklearn.ensemble import ( + AdaBoostClassifier, + AdaBoostRegressor, + BaggingClassifier, + BaggingRegressor, +) from sklearn.exceptions import DataConversionWarning # noqa -from sklearn.model_selection import GridSearchCV -from sklearn.model_selection import RandomizedSearchCV +from sklearn.model_selection import GridSearchCV, RandomizedSearchCV from sklearn.pipeline import Pipeline from sklearn.preprocessing import StandardScaler +from tensorflow.keras.layers import Conv2D, Dense, Flatten, Input +from tensorflow.keras.models import Model, Sequential +from tensorflow.keras.optimizers import Adam from tensorflow.python import keras from tensorflow.python.keras import backend as K -from tensorflow.python.keras.layers import Conv2D -from tensorflow.python.keras.layers import Dense -from tensorflow.python.keras.layers import Flatten -from tensorflow.python.keras.layers import Input -from tensorflow.python.keras.models import Model -from tensorflow.python.keras.models import Sequential from tensorflow.python.keras.utils.np_utils import to_categorical from scikeras import wrappers -from scikeras.wrappers import KerasClassifier -from scikeras.wrappers import KerasRegressor +from scikeras.wrappers import KerasClassifier, KerasRegressor -from .mlp_models import dynamic_classifier -from .mlp_models import dynamic_regressor +from .mlp_models import dynamic_classifier, dynamic_regressor from .testing_utils import basic_checks @@ -43,13 +39,16 @@ def build_fn_clf( - n_features_in_, n_classes_, hidden_dim, -): + hidden_dim, meta: Dict[str, Any], compile_kwargs: Dict[str, Any], +) -> Model: """Builds a Sequential based classifier.""" + # extract parameters + n_features_in_ = meta["n_features_in_"] + X_shape_ = meta["X_shape_"] + n_classes_ = meta["n_classes_"] + model = keras.models.Sequential() - model.add( - keras.layers.Dense(n_features_in_, input_shape=(n_features_in_,)) - ) + model.add(keras.layers.Dense(n_features_in_, input_shape=X_shape_[1:])) model.add(keras.layers.Activation("relu")) model.add(keras.layers.Dense(hidden_dim)) model.add(keras.layers.Activation("relu")) @@ -61,8 +60,13 @@ def build_fn_clf( return model -def build_fn_reg(n_features_in_, hidden_dim): +def build_fn_reg( + hidden_dim, meta: Dict[str, Any], compile_kwargs: Dict[str, Any], +) -> Model: 
"""Builds a Sequential based regressor.""" + # extract parameters + n_features_in_ = meta["n_features_in_"] + model = keras.models.Sequential() model.add( keras.layers.Dense(n_features_in_, input_shape=(n_features_in_,)) @@ -78,35 +82,21 @@ def build_fn_reg(n_features_in_, hidden_dim): return model -class ClassBuildFnClf: - def __call__(self, hidden_dim, n_features_in_, n_classes_): - return build_fn_clf( - hidden_dim=hidden_dim, - n_features_in_=n_features_in_, - n_classes_=n_classes_, - ) - - class InheritClassBuildFnClf(wrappers.KerasClassifier): - def _keras_build_fn(self, hidden_dim, n_features_in_, n_classes_): + def _keras_build_fn( + self, hidden_dim, meta: Dict[str, Any], compile_kwargs: Dict[str, Any], + ) -> Model: return build_fn_clf( - hidden_dim=hidden_dim, - n_features_in_=n_features_in_, - n_classes_=n_classes_, - ) - - -class ClassBuildFnReg: - def __call__(self, hidden_dim, n_features_in_): - return build_fn_reg( - hidden_dim=hidden_dim, n_features_in_=n_features_in_ + hidden_dim=hidden_dim, meta=meta, compile_kwargs=compile_kwargs, ) class InheritClassBuildFnReg(wrappers.KerasRegressor): - def _keras_build_fn(self, hidden_dim, n_features_in_): + def _keras_build_fn( + self, hidden_dim, meta: Dict[str, Any], compile_kwargs: Dict[str, Any], + ) -> Model: return build_fn_reg( - hidden_dim=hidden_dim, n_features_in_=n_features_in_ + hidden_dim=hidden_dim, meta=meta, compile_kwargs=compile_kwargs, ) @@ -115,32 +105,18 @@ class TestBasicAPI: def test_classify_build_fn(self): """Tests a classification task for errors.""" - clf = wrappers.KerasClassifier(build_fn=build_fn_clf, hidden_dim=5,) - basic_checks(clf, load_iris) - - def test_classify_class_build_fn(self): - """Tests for errors using a class implementing __call__.""" - - clf = wrappers.KerasClassifier( - build_fn=ClassBuildFnClf(), hidden_dim=5, - ) + clf = wrappers.KerasClassifier(build_fn=build_fn_clf, hidden_dim=5) basic_checks(clf, load_iris) def test_classify_inherit_class_build_fn(self): """Tests for errors using an inherited class.""" - clf = InheritClassBuildFnClf(build_fn=None, hidden_dim=5,) + clf = InheritClassBuildFnClf(build_fn=None, hidden_dim=5) basic_checks(clf, load_iris) def test_regression_build_fn(self): """Tests for errors using KerasRegressor.""" - reg = wrappers.KerasRegressor(build_fn=build_fn_reg, hidden_dim=5,) - basic_checks(reg, load_boston) - - def test_regression_class_build_fn(self): - """Tests for errors using KerasRegressor implementing __call__.""" - - reg = KerasRegressor(build_fn=ClassBuildFnReg(), hidden_dim=5,) + reg = wrappers.KerasRegressor(build_fn=build_fn_reg, hidden_dim=5) basic_checks(reg, load_boston) def test_regression_inherit_class_build_fn(self): @@ -161,12 +137,16 @@ def load_digits8x8(): return data -def build_fn_regs(X, n_outputs_, hidden_layer_sizes=None, n_classes_=None): +def build_fn_regs( + hidden_layer_sizes, meta: Dict[str, Any], compile_kwargs: Dict[str, Any], +) -> Model: """Dynamically build regressor.""" - if hidden_layer_sizes is None: - hidden_layer_sizes = [] + # get params + X_shape_ = meta["X_shape_"] + n_outputs_ = meta["n_outputs_"] + model = Sequential() - model.add(Dense(X.shape[1], activation="relu", input_shape=X.shape[1:])) + model.add(Dense(X_shape_[1], activation="relu", input_shape=X_shape_[1:])) for size in hidden_layer_sizes: model.add(Dense(size, activation="relu")) model.add(Dense(n_outputs_)) @@ -174,12 +154,15 @@ def build_fn_regs(X, n_outputs_, hidden_layer_sizes=None, n_classes_=None): return model -def build_fn_clss(X, n_outputs_, 
hidden_layer_sizes=None, n_classes_=None): +def build_fn_clss( + hidden_layer_sizes, meta: Dict[str, Any], compile_kwargs: Dict[str, Any], +) -> Model: """Dynamically build classifier.""" - if hidden_layer_sizes is None: - hidden_layer_sizes = [] + # get params + X_shape_ = meta["X_shape_"] + model = Sequential() - model.add(Dense(X.shape[1], activation="relu", input_shape=X.shape[1:])) + model.add(Dense(X_shape_[1], activation="relu", input_shape=X_shape_[1:])) for size in hidden_layer_sizes: model.add(Dense(size, activation="relu")) model.add(Dense(1, activation="softmax")) @@ -187,12 +170,16 @@ def build_fn_clss(X, n_outputs_, hidden_layer_sizes=None, n_classes_=None): return model -def build_fn_clscs(X, n_outputs_, hidden_layer_sizes=None, n_classes_=None): +def build_fn_clscs( + hidden_layer_sizes, meta: Dict[str, Any], compile_kwargs: Dict[str, Any], +) -> Model: """Dynamically build functional API regressor.""" - if hidden_layer_sizes is None: - hidden_layer_sizes = [] + # get params + X_shape_ = meta["X_shape_"] + n_classes_ = meta["n_classes_"] + model = Sequential() - model.add(Conv2D(3, (3, 3), input_shape=X.shape[1:])) + model.add(Conv2D(3, (3, 3), input_shape=X_shape_[1:])) model.add(Flatten()) for size in hidden_layer_sizes: model.add(Dense(size, activation="relu")) @@ -203,11 +190,15 @@ def build_fn_clscs(X, n_outputs_, hidden_layer_sizes=None, n_classes_=None): return model -def build_fn_clscf(X, n_outputs_, hidden_layer_sizes=None, n_classes_=None): +def build_fn_clscf( + hidden_layer_sizes, meta: Dict[str, Any], compile_kwargs: Dict[str, Any], +) -> Model: """Dynamically build functional API classifier.""" - if hidden_layer_sizes is None: - hidden_layer_sizes = [] - x = Input(shape=X.shape[1:]) + # get params + X_shape_ = meta["X_shape_"] + n_classes_ = meta["n_classes_"] + + x = Input(shape=X_shape_[1:]) z = Conv2D(3, (3, 3))(x) z = Flatten()(z) for size in hidden_layer_sizes: @@ -258,14 +249,14 @@ class TestAdvancedAPIFuncs: def test_standalone(self, config): """Tests standalone estimator.""" loader, model, build_fn, _ = CONFIG[config] - estimator = model(build_fn, epochs=1) + estimator = model(build_fn, epochs=1, model__hidden_layer_sizes=[]) basic_checks(estimator, loader) @pytest.mark.parametrize("config", ["MLPRegressor", "MLPClassifier"]) def test_pipeline(self, config): """Tests compatibility with Scikit-learn's pipeline.""" loader, model, build_fn, _ = CONFIG[config] - estimator = model(build_fn, epochs=1) + estimator = model(build_fn, epochs=1, model__hidden_layer_sizes=[]) estimator = Pipeline([("s", StandardScaler()), ("e", estimator)]) basic_checks(estimator, loader) @@ -277,23 +268,49 @@ def test_searchcv_init_params(self, config): """Tests compatibility with Scikit-learn's hyperparameter search CV.""" loader, model, build_fn, _ = CONFIG[config] estimator = model( - build_fn, epochs=1, validation_split=0.1, hidden_layer_sizes=[] + build_fn, + epochs=1, + validation_split=0.1, + model__hidden_layer_sizes=[], ) basic_checks( - GridSearchCV(estimator, {"hidden_layer_sizes": [[], [5]]}), loader, + GridSearchCV(estimator, {"model__hidden_layer_sizes": [[], [5]]}), + loader, ) basic_checks( RandomizedSearchCV( - estimator, {"epochs": np.random.randint(1, 5, 2)}, n_iter=2, + estimator, + {"epochs": [1, 2, 3], "optimizer": ["rmsprop", "sgd"]}, + n_iter=2, ), loader, ) + @pytest.mark.parametrize( + "config", ["MLPClassifier"], + ) + def test_searchcv_routed_params(self, config): + """Tests compatibility with Scikit-learn's hyperparameter search CV.""" + loader, 
model, build_fn, _ = CONFIG[config] + estimator = model(build_fn, epochs=1, model__hidden_layer_sizes=[]) + params = { + "model__hidden_layer_sizes": [[], [5]], + "compile__optimizer": ["sgd", "adam"], + } + search = GridSearchCV(estimator, params) + basic_checks(search, loader) + assert search.best_estimator_.model_.optimizer._name.lower() in ( + "sgd", + "adam", + ) + @pytest.mark.parametrize("config", ["MLPRegressor", "MLPClassifier"]) def test_ensemble(self, config): """Tests compatibility with Scikit-learn's ensembles.""" loader, model, build_fn, ensembles = CONFIG[config] - base_estimator = model(build_fn, epochs=1) + base_estimator = model( + build_fn, epochs=1, model__hidden_layer_sizes=[] + ) for ensemble in ensembles: estimator = ensemble(base_estimator=base_estimator, n_estimators=2) basic_checks(estimator, loader) @@ -302,7 +319,9 @@ def test_ensemble(self, config): def test_calibratedclassifiercv(self, config): """Tests compatibility with Scikit-learn's calibrated classifier CV.""" loader, _, build_fn, _ = CONFIG[config] - base_estimator = KerasClassifier(build_fn, epochs=1) + base_estimator = KerasClassifier( + build_fn, epochs=1, model__hidden_layer_sizes=[] + ) estimator = CalibratedClassifierCV(base_estimator=base_estimator, cv=5) basic_checks(estimator, loader) @@ -323,14 +342,34 @@ def test_basic(self, config): # make y the same shape as will be used by .fit if config != "MLPRegressor": y_train = to_categorical(y_train) + meta = { + "n_classes_": n_classes_, + "target_type_": "multiclass", + "n_features_in_": x_train.shape[1], + "model_n_outputs_": 1, + } keras_model = build_fn( - n_classes_=n_classes_, - cls_type_="multiclass", - n_features_in_=x_train.shape[1], + meta=meta, + hidden_layer_sizes=(100,), + compile_kwargs={ + "optimizer": "adam", + "loss": None, + "metrics": None, + }, ) else: + meta = { + "n_outputs_": 1, + "n_features_in_": x_train.shape[1], + } keras_model = build_fn( - n_features_in_=x_train.shape[1], n_outputs_=1 + meta=meta, + hidden_layer_sizes=(100,), + compile_kwargs={ + "optimizer": "adam", + "loss": None, + "metrics": None, + }, ) estimator = model(build_fn=keras_model) @@ -347,14 +386,34 @@ def test_ensemble(self, config): # make y the same shape as will be used by .fit if config != "MLPRegressor": y_train = to_categorical(y_train) + meta = { + "n_classes_": n_classes_, + "target_type_": "multiclass", + "n_features_in_": x_train.shape[1], + "model_n_outputs_": 1, + } keras_model = build_fn( - n_classes_=n_classes_, - cls_type_="multiclass", - n_features_in_=x_train.shape[1], + meta=meta, + hidden_layer_sizes=(100,), + compile_kwargs={ + "optimizer": "adam", + "loss": None, + "metrics": None, + }, ) else: + meta = { + "n_outputs_": 1, + "n_features_in_": x_train.shape[1], + } keras_model = build_fn( - n_features_in_=x_train.shape[1], n_outputs_=1 + meta=meta, + hidden_layer_sizes=(100,), + compile_kwargs={ + "optimizer": "adam", + "loss": None, + "metrics": None, + }, ) base_estimator = model(build_fn=keras_model) @@ -370,25 +429,26 @@ def test_warm_start(): X, y = data.data[:100], data.target[:100] # Initial fit estimator = KerasRegressor( - build_fn=dynamic_regressor, loss=KerasRegressor.r_squared + build_fn=dynamic_regressor, + loss=KerasRegressor.r_squared, + model__hidden_layer_sizes=(100,), ) estimator.fit(X, y) model = estimator.model_ # With warm start, successive calls to fit # should NOT create a new model - estimator.fit(X, y, warm_start=True) + estimator.set_params(warm_start=True) + estimator.fit(X, y) assert model is 
estimator.model_ # Without warm start, each call to fit # should create a new model instance - estimator.fit(X, y, warm_start=False) - assert model is not estimator.model_ - model = estimator.model_ # for successive tests - - # The default should be warm_start=False - estimator.fit(X, y) - assert model is not estimator.model_ + estimator.set_params(warm_start=False) + for _ in range(3): + estimator.fit(X, y) + assert model is not estimator.model_ + model = estimator.model_ class TestPartialFit: @@ -396,7 +456,9 @@ def test_partial_fit(self): data = load_boston() X, y = data.data[:100], data.target[:100] estimator = KerasRegressor( - build_fn=dynamic_regressor, loss=KerasRegressor.r_squared, + build_fn=dynamic_regressor, + loss=KerasRegressor.r_squared, + model__hidden_layer_sizes=[100,], ) estimator.partial_fit(X, y) @@ -423,6 +485,7 @@ def test_partial_fit_history_len(self): build_fn=dynamic_regressor, loss=KerasRegressor.r_squared, metrics="mean_squared_error", + model__hidden_layer_sizes=[100,], ) for k in range(10): @@ -438,7 +501,7 @@ def test_partial_fit_history_len(self): ) def test_pf_pickle_pf(self, config): loader, model, build_fn, _ = CONFIG[config] - clf = model(build_fn, epochs=1) + clf = model(build_fn, epochs=1, model__hidden_layer_sizes=[]) data = loader() X, y = data.data[:100], data.target[:100] @@ -466,6 +529,7 @@ def test_pf_pickle_pf(self, config): # Make sure there's a decent number of weights # Also make sure that this network is "over-parameterized" (more # weights than examples) + # (these numbers are empirical and depend on model__hidden_layer_sizes=[]) assert 1000 <= sum(n_weights) <= 2000 assert 200 <= np.mean(n_weights) <= 300 assert max(n_weights) >= 1000 @@ -482,7 +546,7 @@ def test_pf_pickle_pf(self, config): # and rel_error > 0.9 to be completely different. 
assert all(0.01 < x for x in rel_errors)
     assert any(x > 0.5 for x in rel_errors)
-    # the rel_error is often higher than 0.5 but the tests are randomn
+    # the rel_error is often higher than 0.5 but the tests are random


 def test_history():
@@ -490,7 +554,9 @@
     """
     data = load_boston()
     X, y = data.data[:100], data.target[:100]
-    estimator = KerasRegressor(build_fn=dynamic_regressor,)
+    estimator = KerasRegressor(
+        build_fn=dynamic_regressor, model__hidden_layer_sizes=[]
+    )

     estimator.partial_fit(X, y)
diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py
index a0538b198..c9d30fc99 100644
--- a/tests/test_callbacks.py
+++ b/tests/test_callbacks.py
@@ -30,6 +30,7 @@ def test_callbacks():
         build_fn=dynamic_classifier,
         callbacks=(SentinalCallback(),),
         optimizer="adam",
+        model__hidden_layer_sizes=(100,),
     )
     # Check for pickling and partial fit
     check_estimators_pickle("KerasClassifier", estimator)
diff --git a/tests/test_errors.py b/tests/test_errors.py
index fc184618d..8d1970dcd 100644
--- a/tests/test_errors.py
+++ b/tests/test_errors.py
@@ -5,11 +5,9 @@
 from sklearn.exceptions import NotFittedError

-from scikeras.wrappers import KerasClassifier
-from scikeras.wrappers import KerasRegressor
+from scikeras.wrappers import BaseWrapper, KerasClassifier, KerasRegressor

-from .mlp_models import dynamic_classifier
-from .mlp_models import dynamic_regressor
+from .mlp_models import dynamic_classifier, dynamic_regressor


 def test_validate_data():
@@ -53,8 +51,10 @@ class TestInvalidBuildFn:
     """

     def test_invalid_build_fn(self):
-        clf = KerasClassifier(build_fn="invalid")
-        with pytest.raises(TypeError, match="build_fn"):
+        clf = KerasClassifier(model="invalid")
+        with pytest.raises(
+            TypeError, match="`model` must be a callable or None"
+        ):
             clf.fit(np.array([[0]]), np.array([0]))

     def test_no_build_fn(self):
@@ -85,26 +85,6 @@ def dummy_func():
         ):
             clf.fit(np.array([[0]]), np.array([0]))

-    def test_call_and_invalid_build_fn_class(self):
-        class Clf(KerasClassifier):
-            def _keras_build_fn(self, hidden_layer_sizes):
-                return dynamic_classifier(
-                    hidden_layer_sizes=hidden_layer_sizes
-                )
-
-        class DummyBuildClass:
-            def __call__(self, hidden_layer_sizes):
-                return dynamic_classifier(
-                    hidden_layer_sizes=hidden_layer_sizes
-                )
-
-        clf = Clf(build_fn=DummyBuildClass(),)
-
-        with pytest.raises(
-            ValueError, match="cannot implement `_keras_build_fn`"
-        ):
-            clf.fit(np.array([[0]]), np.array([0]))
-

 def test_sample_weights_all_zero():
     """Checks for a user-friendly error when sample_weights
@@ -113,7 +93,7 @@ def test_sample_weights_all_zero():
     # build estimator
     estimator = KerasClassifier(
         build_fn=dynamic_classifier,
-        hidden_layer_sizes=(100,),
+        model__hidden_layer_sizes=(100,),
         epochs=10,
         random_state=0,
     )
@@ -126,3 +106,40 @@

     with pytest.raises(RuntimeError, match="no samples left"):
         estimator.fit(X, y, sample_weight=sample_weight)
+
+
+def test_build_fn_deprecation():
+    """An appropriate warning is raised when using the `build_fn`
+    parameter instead of `model`.
+    """
+    clf = KerasClassifier(
+        build_fn=dynamic_regressor, model__hidden_layer_sizes=(100,)
+    )
+    with pytest.warns(
+        UserWarning, match="`build_fn` will be renamed to `model`"
+    ):
+        clf.fit([[1]], [1])
+
+
+@pytest.mark.parametrize("wrapper", [KerasClassifier, KerasRegressor])
+def test_build_fn_and_init_signature_do_not_agree(wrapper):
+    """Test that passing a kwarg not present in the model
+    building function's signature raises a TypeError.
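+
+    Both the routed spelling (`model__bar`) and the bare spelling (`bar`)
+    are exercised below, each alone and alongside the valid `foo` kwarg.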
+ """ + + def no_bar(foo=42): + pass + + # all attempts to pass `bar` should fail + est = wrapper(model=no_bar, model__bar=42) + with pytest.raises(TypeError, match="got an unexpected keyword argument"): + est.fit([[1]], [1]) + est = wrapper(model=no_bar, bar=42) + with pytest.raises(TypeError, match="got an unexpected keyword argument"): + est.fit([[1]], [1]) + est = wrapper(model=no_bar, model__bar=42, foo=43) + with pytest.raises(TypeError, match="got an unexpected keyword argument"): + est.fit([[1]], [1]) + est = wrapper(model=no_bar, bar=42, foo=43) + with pytest.raises(TypeError, match="got an unexpected keyword argument"): + est.fit([[1]], [1]) diff --git a/tests/test_input_outputs.py b/tests/test_input_outputs.py index 635acb6bb..2369f2a22 100644 --- a/tests/test_input_outputs.py +++ b/tests/test_input_outputs.py @@ -1,23 +1,18 @@ +from typing import Any, Dict + import numpy as np import pytest import tensorflow as tf -from sklearn.ensemble import RandomForestClassifier -from sklearn.ensemble import RandomForestRegressor +from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor from sklearn.preprocessing import MultiLabelBinarizer -from tensorflow.python.keras.layers import Concatenate -from tensorflow.python.keras.layers import Dense -from tensorflow.python.keras.layers import Input -from tensorflow.python.keras.models import Model -from tensorflow.python.keras.models import Sequential +from tensorflow.python.keras.layers import Concatenate, Dense, Input +from tensorflow.python.keras.models import Model, Sequential from tensorflow.python.keras.testing_utils import get_test_data -from scikeras.wrappers import BaseWrapper -from scikeras.wrappers import KerasClassifier -from scikeras.wrappers import KerasRegressor +from scikeras.wrappers import BaseWrapper, KerasClassifier, KerasRegressor -from .mlp_models import dynamic_classifier -from .mlp_models import dynamic_regressor +from .mlp_models import dynamic_classifier, dynamic_regressor # Defaults @@ -31,7 +26,12 @@ class FunctionalAPIMultiInputClassifier(KerasClassifier): """Tests Functional API Classifier with 2 inputs. """ - def _keras_build_fn(self, X, n_classes_): + def _keras_build_fn( + self, meta: Dict[str, Any], compile_kwargs: Dict[str, Any], + ) -> Model: + # get params + n_classes_ = meta["n_classes_"] + inp1 = Input((1,)) inp2 = Input((3,)) @@ -59,8 +59,14 @@ class FunctionalAPIMultiOutputClassifier(KerasClassifier): """Tests Functional API Classifier with 2 outputs of different type. """ - def _keras_build_fn(self, X, n_classes_): - inp = Input((4,)) + def _keras_build_fn( + self, meta: Dict[str, Any], compile_kwargs: Dict[str, Any], + ) -> Model: + # get params + n_features_in_ = meta["n_features_in_"] + n_classes_ = meta["n_classes_"] + + inp = Input((n_features_in_,)) x1 = Dense(100)(inp) @@ -84,7 +90,12 @@ class FunctionAPIMultiLabelClassifier(KerasClassifier): """Tests Functional API Classifier with multiple binary outputs. """ - def _keras_build_fn(self, X, n_outputs_): + def _keras_build_fn( + self, meta: Dict[str, Any], compile_kwargs: Dict[str, Any], + ) -> Model: + # get params + n_outputs_ = meta["n_outputs_"] + inp = Input((4,)) x1 = Dense(100)(inp) @@ -106,7 +117,12 @@ class FunctionAPIMultiOutputRegressor(KerasRegressor): """Tests Functional API Regressor with multiple outputs. 
""" - def _keras_build_fn(self, X, n_outputs_): + def _keras_build_fn( + self, meta: Dict[str, Any], compile_kwargs: Dict[str, Any], + ) -> Model: + # get params + n_outputs_ = meta["n_outputs_"] + inp = Input((INPUT_DIM,)) x1 = Dense(100)(inp) @@ -218,7 +234,9 @@ def test_incompatible_output_dimensions(): y = np.random.randint(low=0, high=3, size=(10, 4)) # create a model with 2 outputs - def build_fn_clf(): + def build_fn_clf( + meta: Dict[str, Any], compile_kwargs: Dict[str, Any], + ) -> Model: """Builds a Sequential based classifier.""" model = Sequential() model.add(Dense(20, input_shape=(20,), activation="relu")) @@ -269,7 +287,7 @@ def test_classifier_handles_dtypes(dtype): sample_weight = np.ones(y.shape).astype(dtype) class StrictClassifier(KerasClassifier): - def _fit_keras_model(self, X, y, sample_weight, warm_start, **kwargs): + def _fit_keras_model(self, X, y, sample_weight, warm_start): if dtype == "object": assert X.dtype == np.dtype(tf.keras.backend.floatx()) else: @@ -277,11 +295,11 @@ def _fit_keras_model(self, X, y, sample_weight, warm_start, **kwargs): # y is passed through encoders, it is likely not the original dtype # sample_weight should always be floatx assert sample_weight.dtype == np.dtype(tf.keras.backend.floatx()) - return super()._fit_keras_model( - X, y, sample_weight, warm_start, **kwargs - ) + return super()._fit_keras_model(X, y, sample_weight, warm_start) - clf = StrictClassifier(build_fn=dynamic_classifier) + clf = StrictClassifier( + build_fn=dynamic_classifier, model__hidden_layer_sizes=(100,) + ) clf.fit(X, y, sample_weight=sample_weight) assert clf.score(X, y) >= 0 if y.dtype.kind != "O": @@ -304,7 +322,7 @@ def test_regressor_handles_dtypes(dtype): sample_weight = np.ones(y.shape).astype(dtype) class StrictRegressor(KerasRegressor): - def _fit_keras_model(self, X, y, sample_weight, warm_start, **kwargs): + def _fit_keras_model(self, X, y, sample_weight, warm_start): if dtype == "object": assert X.dtype == np.dtype(tf.keras.backend.floatx()) assert y.dtype == np.dtype(tf.keras.backend.floatx()) @@ -313,11 +331,11 @@ def _fit_keras_model(self, X, y, sample_weight, warm_start, **kwargs): assert y.dtype == np.dtype(dtype) # sample_weight should always be floatx assert sample_weight.dtype == np.dtype(tf.keras.backend.floatx()) - return super()._fit_keras_model( - X, y, sample_weight, warm_start, **kwargs - ) + return super()._fit_keras_model(X, y, sample_weight, warm_start) - reg = StrictRegressor(build_fn=dynamic_regressor) + reg = StrictRegressor( + build_fn=dynamic_regressor, model__hidden_layer_sizes=(100,) + ) reg.fit(X, y, sample_weight=sample_weight) y_hat = reg.predict(X) if y.dtype.kind == "f": @@ -338,7 +356,7 @@ def test_mixed_dtypes(y_dtype, X_dtype, run_eagerly): y = np.random.choice(n_classes, size=n).astype(y_dtype) class StrictRegressor(KerasRegressor): - def _fit_keras_model(self, X, y, sample_weight, warm_start, **kwargs): + def _fit_keras_model(self, X, y, sample_weight, warm_start): if X_dtype == "object": assert X.dtype == np.dtype(tf.keras.backend.floatx()) else: @@ -347,11 +365,13 @@ def _fit_keras_model(self, X, y, sample_weight, warm_start, **kwargs): assert y.dtype == np.dtype(tf.keras.backend.floatx()) else: assert y.dtype == np.dtype(y_dtype) - return super()._fit_keras_model( - X, y, sample_weight, warm_start, **kwargs - ) + return super()._fit_keras_model(X, y, sample_weight, warm_start) - reg = StrictRegressor(build_fn=dynamic_regressor, run_eagerly=run_eagerly) + reg = StrictRegressor( + build_fn=dynamic_regressor, + 
run_eagerly=run_eagerly, + model__hidden_layer_sizes=(100,), + ) reg.fit(X, y) y_hat = reg.predict(X) if y.dtype.kind == "f": diff --git a/tests/test_param_routing.py b/tests/test_param_routing.py new file mode 100644 index 000000000..cf95c54c6 --- /dev/null +++ b/tests/test_param_routing.py @@ -0,0 +1,161 @@ +import inspect + +from distutils.version import LooseVersion +from typing import Any, Dict + +import numpy as np +import pytest +import tensorflow as tf + +from tensorflow.keras import Model + +from scikeras.wrappers import BaseWrapper, KerasClassifier, KerasRegressor + +from .mlp_models import dynamic_classifier, dynamic_regressor + + +@pytest.mark.parametrize( + "wrapper, builder", + [ + (KerasClassifier, dynamic_classifier), + (KerasRegressor, dynamic_regressor), + ], +) +def test_routing_basic(wrapper, builder): + n, d = 20, 3 + n_classes = 3 + X = np.random.uniform(size=(n, d)).astype(float) + y = np.random.choice(n_classes, size=n).astype(int) + + est = wrapper(build_fn=builder, model__hidden_layer_sizes=(100,)) + + def build_fn(hidden_layer_sizes, compile_kwargs, params, meta): + assert set(params.keys()) == set(est.get_params().keys()) + expected_meta = wrapper._meta - { + "model_", + "history_", + "is_fitted_", + } + assert set(meta.keys()) == expected_meta + assert set(compile_kwargs.keys()).issubset(wrapper._compile_kwargs) + return builder( + hidden_layer_sizes=hidden_layer_sizes, + compile_kwargs=compile_kwargs, + meta=meta, + ) + + est = wrapper(build_fn=build_fn, model__hidden_layer_sizes=(100,)) + est.fit(X, y) + + +@pytest.mark.parametrize( + "wrapper, builder", + [ + (KerasClassifier, dynamic_classifier), + (KerasRegressor, dynamic_regressor), + ], +) +def test_routing_kwargs(wrapper, builder): + """Tests that special parameters are passed if + build_fn accepts kwargs. + """ + n, d = 20, 3 + n_classes = 3 + X = np.random.uniform(size=(n, d)).astype(float) + y = np.random.choice(n_classes, size=n).astype(int) + + def build_fn(*args, **kwargs): + assert len(args) == 0, "No *args should be passed to `build_fn`" + assert tuple(kwargs.keys()) == ( + "hidden_layer_sizes", + "meta", + "compile_kwargs", + "params", + ), "The number and order of **kwargs passed to `build_fn` should be fixed" + meta = set(kwargs["meta"].keys()) + expected_meta = wrapper._meta - { + "model_", + "history_", + "is_fitted_", + } + assert meta == expected_meta + assert set(kwargs["compile_kwargs"].keys()).issubset( + wrapper._compile_kwargs + ) + kwargs.pop("params") # dynamic_classifier/regressor don't accept it + return builder(*args, **kwargs) + + est = wrapper(build_fn=build_fn, model__hidden_layer_sizes=(100,)) + est.fit(X, y) + + +@pytest.mark.parametrize( + "wrapper_class,build_fn", + [ + (KerasClassifier, dynamic_classifier), + (KerasRegressor, dynamic_regressor), + ], +) +def test_no_extra_meta(wrapper_class, build_fn): + """Check that wrappers do not create any unexpected meta parameters. 
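+
+    `get_meta()` is compared against the class-level `_meta` set, both with
+    and without user-supplied `model__` kwargs.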
+ """ + n, d = 20, 3 + n_classes = 3 + X = np.random.uniform(size=(n, d)).astype(float) + y = np.random.choice(n_classes, size=n).astype(int) + + # with user kwargs + clf = wrapper_class(build_fn=build_fn, model__hidden_layer_sizes=(100,)) + clf.fit(X, y) + assert set(clf.get_meta().keys()) == wrapper_class._meta + # without user kwargs + def build_fn_no_args(meta, compile_kwargs): + return build_fn( + hidden_layer_sizes=(100,), + meta=meta, + compile_kwargs=compile_kwargs, + ) + + clf = wrapper_class(build_fn=build_fn_no_args) + clf.fit(X, y) + assert set(clf.get_meta().keys()) == wrapper_class._meta - {"_user_params"} + + +def test_model_params_property(): + """Check that the `_model_params` property works as expected. + """ + clf = KerasRegressor(model="test", model__hidden_layer_sizes=(100,)) + assert clf._model_params == {"hidden_layer_sizes"} + + +@pytest.mark.parametrize("dest", ["fit", "compile", "predict"]) +def test_routing_sets(dest): + accepted_params = set( + inspect.signature(getattr(Model, dest)).parameters.keys() + ) - {"self", "kwargs"} + known_params = getattr(BaseWrapper, f"_{dest}_kwargs") + if LooseVersion(tf.__version__) <= "2.2.0": + # this parameter is a kwarg in TF 2.2.0 + # it will still work in practice, but breaks this test + known_params = known_params - {"run_eagerly"} + assert known_params.issubset(accepted_params) + + +def test_routed_unrouted_equivalence(): + """Test that `hidden_layer_sizes` and `model__hidden_layer_sizes` + both work. + """ + n, d = 20, 3 + n_classes = 3 + X = np.random.uniform(size=(n, d)).astype(float) + y = np.random.choice(n_classes, size=n).astype(int) + + clf = KerasClassifier( + build_fn=dynamic_classifier, model__hidden_layer_sizes=(100,) + ) + clf.fit(X, y) + + clf = KerasClassifier( + build_fn=dynamic_classifier, hidden_layer_sizes=(100,) + ) + clf.fit(X, y) diff --git a/tests/test_parameters.py b/tests/test_parameters.py index 6df38ed5a..531b92b1c 100644 --- a/tests/test_parameters.py +++ b/tests/test_parameters.py @@ -6,11 +6,9 @@ from sklearn.base import clone from tensorflow.python.keras.testing_utils import get_test_data -from scikeras.wrappers import KerasClassifier -from scikeras.wrappers import KerasRegressor +from scikeras.wrappers import KerasClassifier, KerasRegressor -from .mlp_models import dynamic_classifier -from .mlp_models import dynamic_regressor +from .mlp_models import dynamic_classifier, dynamic_regressor # Defaults @@ -31,9 +29,13 @@ class TestRandomState: "estimator", [ KerasRegressor( - build_fn=dynamic_regressor, loss=KerasRegressor.r_squared, + build_fn=dynamic_regressor, + loss=KerasRegressor.r_squared, + model__hidden_layer_sizes=(100,), + ), + KerasClassifier( + build_fn=dynamic_classifier, model__hidden_layer_sizes=(100,) ), - KerasClassifier(build_fn=dynamic_classifier), ], ) def test_random_states(self, random_state, estimator): @@ -70,9 +72,13 @@ def test_random_states(self, random_state, estimator): "estimator", [ KerasRegressor( - build_fn=dynamic_regressor, loss=KerasRegressor.r_squared, + build_fn=dynamic_regressor, + loss=KerasRegressor.r_squared, + model__hidden_layer_sizes=(100,), + ), + KerasClassifier( + build_fn=dynamic_classifier, model__hidden_layer_sizes=(100,) ), - KerasClassifier(build_fn=dynamic_classifier), ], ) @pytest.mark.parametrize("pyhash", [None, "0", "1"]) @@ -136,7 +142,7 @@ def test_sample_weights_fit(): # build estimator estimator = KerasClassifier( build_fn=dynamic_classifier, - hidden_layer_sizes=(100,), + model__hidden_layer_sizes=(100,), epochs=10, random_state=0, 
    )
@@ -182,7 +188,7 @@ def test_sample_weights_score():
     # build estimator
     estimator = KerasRegressor(
         build_fn=dynamic_regressor,
-        hidden_layer_sizes=(100,),
+        model__hidden_layer_sizes=(100,),
         epochs=10,
         random_state=0,
     )
@@ -211,13 +217,15 @@ def test_build_fn_default_params():
     """Tests that default arguments of `build_fn` are registered as
     hyperparameters.
     """
-    est = KerasClassifier(build_fn=dynamic_classifier)
+    est = KerasClassifier(
+        build_fn=dynamic_classifier, model__hidden_layer_sizes=(100,)
+    )
     params = est.get_params()
     # (100, ) is the default for dynamic_classifier
-    assert params["hidden_layer_sizes"] == (100,)
+    assert params["model__hidden_layer_sizes"] == (100,)

     est = KerasClassifier(
-        build_fn=dynamic_classifier, hidden_layer_sizes=(200,)
+        build_fn=dynamic_classifier, model__hidden_layer_sizes=(200,)
     )
     params = est.get_params()
-    assert params["hidden_layer_sizes"] == (200,)
+    assert params["model__hidden_layer_sizes"] == (200,)
diff --git a/tests/test_scikit_learn_checks.py b/tests/test_scikit_learn_checks.py
index 3ab2f3ca8..697dfedb6 100644
--- a/tests/test_scikit_learn_checks.py
+++ b/tests/test_scikit_learn_checks.py
@@ -1,20 +1,20 @@
 """Tests using Scikit-Learn's bundled estimator_checks."""
 from distutils.version import LooseVersion
+from typing import Any, Dict

+import numpy as np
 import pytest

 from sklearn import __version__ as sklearn_version
 from sklearn.datasets import load_iris
 from sklearn.utils.estimator_checks import check_no_attributes_set_in_init
+from tensorflow.keras import Model, Sequential, layers

-from scikeras.wrappers import KerasClassifier
-from scikeras.wrappers import KerasRegressor
+from scikeras.wrappers import KerasClassifier, KerasRegressor

-from .mlp_models import dynamic_classifier
-from .mlp_models import dynamic_regressor
-from .testing_utils import basic_checks
-from .testing_utils import parametrize_with_checks
+from .mlp_models import dynamic_classifier, dynamic_regressor
+from .testing_utils import basic_checks, parametrize_with_checks


 @parametrize_with_checks(
@@ -29,6 +29,7 @@
             # applicable to real world datasets
             batch_size=1000,
             optimizer="adam",
+            model__hidden_layer_sizes=(100,),
         ),
         KerasRegressor(
             build_fn=dynamic_regressor,
@@ -41,6 +42,7 @@
             batch_size=1000,
             optimizer="adam",
             loss=KerasRegressor.r_squared,
+            model__hidden_layer_sizes=(100,),
         ),
     ],
     ids=["KerasClassifier", "KerasRegressor"],
@@ -59,23 +61,52 @@ def test_fully_compliant_estimators(estimator, check):

 class SubclassedClassifier(KerasClassifier):
     def __init__(
-        self, hidden_layer_sizes=(100,),
+        self,
+        model__hidden_layer_sizes=(100,),
+        metrics=None,
+        loss=None,
+        **kwargs,
     ):
-        self.hidden_layer_sizes = hidden_layer_sizes
+        super().__init__(**kwargs)
+        self.model__hidden_layer_sizes = model__hidden_layer_sizes
+        self.metrics = metrics
+        self.loss = loss
+        self.optimizer = "sgd"

-    def _keras_build_fn(self, hidden_layer_sizes):
+    def _keras_build_fn(
+        self,
+        hidden_layer_sizes,
+        meta: Dict[str, Any],
+        compile_kwargs: Dict[str, Any],
+    ) -> Model:
         return dynamic_classifier(
-            n_features_in_=self.n_features_in_,
-            cls_type_=self.cls_type_,
-            n_classes_=self.n_classes_,
             hidden_layer_sizes=hidden_layer_sizes,
+            meta=meta,
+            compile_kwargs=compile_kwargs,
         )


-def test_no_attributes_set_init():
+def test_no_attributes_set_init_subclassed():
     """Tests that subclassed models can be made
     that set all parameters in a single __init__
     """
     estimator = SubclassedClassifier()
     check_no_attributes_set_in_init(estimator.__name__, estimator)
     basic_checks(estimator,
load_iris)
+
+
+def test_no_attributes_set_init_no_args():
+    """Tests that models with no build arguments
+    set all parameters in a single __init__
+    """
+
+    def build_fn():
+        model = Sequential()
+        model.add(layers.Dense(1, input_dim=1, activation="relu"))
+        model.add(layers.Dense(1))
+        model.compile(loss="mse")
+        return model
+
+    estimator = KerasRegressor(model=build_fn)
+    check_no_attributes_set_in_init(estimator.__name__, estimator)
+    estimator.fit([[1]], [1])
diff --git a/tests/test_serialization.py b/tests/test_serialization.py
index a2d9f093c..53c4c879d 100644
--- a/tests/test_serialization.py
+++ b/tests/test_serialization.py
@@ -1,12 +1,13 @@
 import pickle

+from typing import Any, Dict
+
 import numpy as np
 import pytest

 from sklearn.datasets import load_boston
 from tensorflow.python import keras
-from tensorflow.python.keras.layers import Dense
-from tensorflow.python.keras.layers import Input
+from tensorflow.python.keras.layers import Dense, Input
 from tensorflow.python.keras.models import Model

 from scikeras.wrappers import KerasRegressor
@@ -26,7 +27,7 @@ def check_pickle(estimator, loader):
     deserialized_estimator = pickle.loads(serialized_estimator)
     deserialized_estimator.predict(X)
     score_new = deserialized_estimator.score(X, y)
-    np.testing.assert_almost_equal(score, score_new)
+    np.testing.assert_almost_equal(score, score_new, decimal=2)


 # ---------------------- Custom Loss Test ----------------------
@@ -42,14 +43,20 @@ class CustomLoss(keras.losses.MeanSquaredError):
 def test_custom_loss_function():
     """Test that a custom loss function can be serialized.
     """
-    estimator = KerasRegressor(build_fn=dynamic_regressor, loss=CustomLoss(),)
+    estimator = KerasRegressor(
+        build_fn=dynamic_regressor,
+        loss=CustomLoss(),
+        model__hidden_layer_sizes=(100,),
+    )
     check_pickle(estimator, load_boston)


 # ---------------------- Subclassed Model Tests ------------------


-def build_fn_custom_model_registered(n_features_in_, n_outputs_):
+def build_fn_custom_model_registered(
+    meta: Dict[str, Any], compile_kwargs: Dict[str, Any],
+) -> Model:
     """Dummy custom Model subclass that is registered to be serializable.
     """

@@ -57,6 +64,10 @@
     class CustomModelRegistered(Model):
         pass

+    # get parameters
+    n_features_in_ = meta["n_features_in_"]
+    n_outputs_ = meta["n_outputs_"]
+
     inp = Input(shape=n_features_in_)
     x1 = Dense(n_features_in_, activation="relu")(inp)
     out = Dense(n_outputs_, activation="linear")(x1)
@@ -72,13 +83,19 @@ def test_custom_model_registered():
     check_pickle(estimator, load_boston)


-def build_fn_custom_model_unregistered(n_features_in_, n_outputs_):
+def build_fn_custom_model_unregistered(
+    meta: Dict[str, Any], compile_kwargs: Dict[str, Any],
+) -> Model:
     """Dummy custom Model subclass that is not registered to be serializable.
""" class CustomModelUnregistered(Model): pass + # get parameters + n_features_in_ = meta["n_features_in_"] + n_outputs_ = meta["n_outputs_"] + inp = Input(shape=n_features_in_) x1 = Dense(n_features_in_, activation="relu")(inp) out = Dense(n_outputs_, activation="linear")(x1) @@ -105,5 +122,6 @@ def test_run_eagerly(): build_fn=dynamic_regressor, run_eagerly=True, loss=KerasRegressor.r_squared, + model__hidden_layer_sizes=(100,), ) check_pickle(estimator, load_boston) diff --git a/tests/test_utils.py b/tests/test_utils.py index 457fdbbff..e377e79d8 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,7 +1,6 @@ import pytest -from scikeras._utils import pack_keras_model -from scikeras._utils import unpack_keras_model +from scikeras._utils import pack_keras_model, route_params, unpack_keras_model @pytest.mark.parametrize("obj", [None, "notamodel"]) @@ -10,3 +9,13 @@ def test_pack_unpack_not_model(obj): pack_keras_model(obj, 0) with pytest.raises(TypeError): unpack_keras_model(obj, 0) + + +def test_route_params(): + """Test the `route_params` function. + """ + params = {"model__foo": object()} + destination = "model" + pass_filter = set() + out = route_params(params, destination, pass_filter) + assert out["foo"] is params["model__foo"]