Skip to content

Commit

Permalink
list -> numpy constant creation
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi committed Aug 7, 2024
1 parent 0ba28c7 commit 9437a0a
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion sklego/meta/grouped_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def __add_shrinkage_column(self, frame, groups=None):

frame = frame.select(
nw.from_dict(
data={self._global_col_name: [self._global_col_value] * n_samples},
data={self._global_col_name: np.full(shape=n_samples, fill_value=self._global_col_value)},
native_namespace=nw.get_native_namespace(frame),
)[self._global_col_name],
nw.all(),
Expand Down
4 changes: 2 additions & 2 deletions sklego/meta/hierarchical_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def fit(self, X, y=None):

native_namespace = nw.get_native_namespace(X)
target_series = nw.from_dict({self._TARGET_NAME: y}, native_namespace=native_namespace)[self._TARGET_NAME]
global_series = nw.from_dict({self._GLOBAL_NAME: [1] * n_samples}, native_namespace=native_namespace)[
global_series = nw.from_dict({self._GLOBAL_NAME: np.ones(n_samples)}, native_namespace=native_namespace)[
self._GLOBAL_NAME
]
frame = X.with_columns(
Expand Down Expand Up @@ -322,7 +322,7 @@ def _predict_estimators(self, X, method_name):

n_samples = X.shape[0]
native_namespace = nw.get_native_namespace(X)
global_series = nw.from_dict({self._GLOBAL_NAME: [1] * n_samples}, native_namespace=native_namespace)[
global_series = nw.from_dict({self._GLOBAL_NAME: np.ones(n_samples)}, native_namespace=native_namespace)[
self._GLOBAL_NAME
]

Expand Down

0 comments on commit 9437a0a

Please sign in to comment.