-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for polars dataframes and series #7463
base: main
Are you sure you want to change the base?
Conversation
polars should be an optional dependency. For the dispatch it can be done with a try except import |
if pl is not None: | ||
@_as_tensor_variable.register(pd.Series) | ||
@_as_tensor_variable.register(pd.DataFrame) | ||
@_as_tensor_variable.register(pl.DataFrame) | ||
@_as_tensor_variable.register(pl.Series) | ||
def dataframe_to_tensor_variable(df: pd.DataFrame | pl.DataFrame, *args, **kwargs) -> TensorVariable: | ||
return pt.as_tensor_variable(df.to_numpy(), *args, **kwargs) | ||
else: | ||
@_as_tensor_variable.register(pd.Series) | ||
@_as_tensor_variable.register(pd.DataFrame) | ||
def dataframe_to_tensor_variable(df: pd.DataFrame, *args, **kwargs) -> TensorVariable: | ||
return pt.as_tensor_variable(df.to_numpy(), *args, **kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is more succinct. Also type hint of df was wrong, so I just removed it.
if pl is not None: | |
@_as_tensor_variable.register(pd.Series) | |
@_as_tensor_variable.register(pd.DataFrame) | |
@_as_tensor_variable.register(pl.DataFrame) | |
@_as_tensor_variable.register(pl.Series) | |
def dataframe_to_tensor_variable(df: pd.DataFrame | pl.DataFrame, *args, **kwargs) -> TensorVariable: | |
return pt.as_tensor_variable(df.to_numpy(), *args, **kwargs) | |
else: | |
@_as_tensor_variable.register(pd.Series) | |
@_as_tensor_variable.register(pd.DataFrame) | |
def dataframe_to_tensor_variable(df: pd.DataFrame, *args, **kwargs) -> TensorVariable: | |
return pt.as_tensor_variable(df.to_numpy(), *args, **kwargs) | |
@_as_tensor_variable.register(pd.Series) | |
@_as_tensor_variable.register(pd.DataFrame) | |
def dataframe_to_tensor_variable(df, *args, **kwargs) -> TensorVariable: | |
return pt.as_tensor_variable(df.to_numpy(), *args, **kwargs) | |
if pl is not None: | |
@_as_tensor_variable.register(pl.DataFrame) | |
@_as_tensor_variable.register(pl.Series) | |
def polars_dataframe_to_tensor_variable(df, *args, **kwargs) -> TensorVariable: | |
return pt.as_tensor_variable(df.to_numpy(), *args, **kwargs) | |
@@ -111,6 +115,18 @@ def convert_data(data) -> np.ndarray | Variable: | |||
ret = np.ma.MaskedArray(vals, mask) | |||
else: | |||
ret = vals | |||
elif hasattr(data, "to_numpy") and hasattr(data, "is_null"): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
elif hasattr(data, "to_numpy") and hasattr(data, "is_null"): | |
elif hasattr(data, "to_numpy") and hasattr(data, "is_null"): | |
# Probably polars object |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not a bit more explicit:
elif hasattr(data, "to_numpy") and hasattr(data, "is_null"): | |
elif pl is not None and isinstance(data, (pl.DataFrame, pl.Series)): |
The polars namespace is used anyway (in the except clause).
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #7463 +/- ##
==========================================
- Coverage 92.17% 92.10% -0.08%
==========================================
Files 103 103
Lines 17258 17279 +21
==========================================
+ Hits 15908 15914 +6
- Misses 1350 1365 +15
|
Description
Mostly superficial changes to recognize polars data structures.
Related Issue
Checklist
Type of change
📚 Documentation preview 📚: https://pymc--7463.org.readthedocs.build/en/7463/