-
Notifications
You must be signed in to change notification settings - Fork 118
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
GroupedPredictor
refactoring
#618
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,9 +16,10 @@ def constant_shrinkage(group_sizes: list, alpha: float) -> np.ndarray: | |
Let $\hat{y}_i$ be the prediction at level $i$, with $i=0$ being the root, than the augmented prediction | ||
$\hat{y}_i^* = \alpha \hat{y}_i + (1 - \alpha) \hat{y}_{i-1}^*$, with $\hat{y}_0^* = \hat{y}_0$. | ||
""" | ||
n_groups = len(group_sizes) | ||
return np.array( | ||
[alpha ** (len(group_sizes) - 1)] | ||
+ [alpha ** (len(group_sizes) - 1 - i) * (1 - alpha) for i in range(1, len(group_sizes) - 1)] | ||
[alpha ** (n_groups - 1)] | ||
+ [alpha ** (n_groups - 1 - i) * (1 - alpha) for i in range(1, n_groups - 1)] | ||
+ [(1 - alpha)] | ||
) | ||
|
||
|
@@ -45,13 +46,15 @@ def _split_groups_and_values( | |
_shape_check(X, min_value_cols) | ||
|
||
try: | ||
lgroups = as_list(groups) | ||
|
||
if isinstance(X, pd.DataFrame): | ||
X_group = X.loc[:, as_list(groups)] | ||
X_value = X.drop(columns=groups).values | ||
X_group = X.loc[:, lgroups] | ||
X_value = X.drop(columns=lgroups).values | ||
else: | ||
X_group = pd.DataFrame(X[:, as_list(groups)]) | ||
pos_indexes = range(X.shape[1]) | ||
X_value = np.delete(X, [pos_indexes[g] for g in as_list(groups)], axis=1) | ||
X_group = pd.DataFrame(X[:, lgroups]) | ||
X_value = np.delete(X, lgroups, axis=1) | ||
|
||
except (KeyError, IndexError): | ||
raise ValueError(f"Could not drop groups {groups} from columns of X") | ||
|
||
|
@@ -88,7 +91,46 @@ def _check_grouping_columns(X_group, **kwargs) -> pd.DataFrame: | |
|
||
# Only check missingness in object columns | ||
if X_group.select_dtypes(exclude="number").isnull().any(axis=None): | ||
raise ValueError("X has NaN values") | ||
raise ValueError("Group columns contain NaN values") | ||
|
||
# The grouping part we always want as a DataFrame with range index | ||
return X_group.reset_index(drop=True) | ||
|
||
|
||
def _get_estimator(estimators, grp_values, grp_names, return_level, fallback_method): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The point of this function is to determine which estimator to use to predict.
The point of returning a |
||
"""Recursive function to get the estimator for the given group values. | ||
|
||
Parameters | ||
---------- | ||
estimators : dict[tuple, scikit-learn compatible estimator/pipeline] | ||
Dictionary with group values as keys and estimators as values. | ||
grp_values : tuple | ||
List of group values - keys to the estimators dictionary. | ||
grp_names : list | ||
List of group names | ||
return_level : int | ||
The level of the group values to return the estimator for. | ||
fallback_method : Literal["global", "next", "raise"] | ||
Defines which fallback strategy to use if a group is not found at prediction time. | ||
""" | ||
if fallback_method == "raise": | ||
try: | ||
return estimators[grp_values], return_level | ||
except KeyError: | ||
raise KeyError(f"No fallback/parent estimator found for the given group values: {grp_names}={grp_values}") | ||
|
||
elif fallback_method == "next": | ||
try: | ||
return estimators[grp_values], return_level | ||
except KeyError: | ||
if len(grp_values) == 1: | ||
raise KeyError( | ||
f"No fallback/parent estimator found for the given group values: {grp_names}={grp_values}" | ||
) | ||
return _get_estimator(estimators, grp_values[:-1], grp_names[:-1], return_level - 1, fallback_method) | ||
|
||
else: # fallback_method == "global" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: technically the Just noticed we check this elsewhere, so it's probably fine to not check here. |
||
try: | ||
return estimators[grp_values], return_level | ||
except KeyError: | ||
return estimators[(1,)], 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess we should set
group_sizes
to also be a non-list type in the function definition? Bit of a nit this one.Also: maybe
groups_list
instead oflgroups
.