diff --git a/sklearn/compose/_column_transformer.py b/sklearn/compose/_column_transformer.py index 2e5c070062ce5..5da308ba05ac1 100644 --- a/sklearn/compose/_column_transformer.py +++ b/sklearn/compose/_column_transformer.py @@ -443,7 +443,7 @@ def fit_transform(self, X, y=None): self._update_fitted_transformers(transformers) self._validate_output(Xs) - return _hstack(list(Xs), self.sparse_output_) + return self._hstack(list(Xs), self.sparse_output_) def transform(self, X): """Transform X separately by each transformer, concatenate results. @@ -471,7 +471,25 @@ def transform(self, X): # All transformers are None return np.zeros((X.shape[0], 0)) - return _hstack(list(Xs), self.sparse_output_) + return self._hstack(list(Xs), self.sparse_output_) + + @staticmethod + def _hstack(X, sparse_): + """ + Stacks X horizontally. + + Supports input types (X): list of + numpy arrays, sparse arrays and DataFrames + + This is implemented as a staticmethod to enable subclasses to control + the stacking behavior, while reusing everything else from + ColumnTransformer. + """ + if sparse_: + return sparse.hstack(X).tocsr() + else: + X = [f.toarray() if sparse.issparse(f) else f for f in X] + return np.hstack(X) def _check_key_type(key, superclass): @@ -505,20 +523,6 @@ def _check_key_type(key, superclass): return False -def _hstack(X, sparse_): - """ - Stacks X horizontally. - - Supports input types (X): list of - numpy arrays, sparse arrays and DataFrames - """ - if sparse_: - return sparse.hstack(X).tocsr() - else: - X = [f.toarray() if sparse.issparse(f) else f for f in X] - return np.hstack(X) - - def _get_column(X, key): """ Get feature column(s) from input data X.