Skip to content

Commit

Permalink
Address #40
Browse files Browse the repository at this point in the history
Use "boolean" dtype instead of bool to deal with nullable bool arrays
  • Loading branch information
gtca committed May 25, 2023
1 parent 6ab1ade commit bc7a066
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 8 deletions.
23 changes: 15 additions & 8 deletions mudata/_core/mudata.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from anndata._core.views import DataFrameView

from .file_backing import MuDataFileManager
from .utils import _make_index_unique, _restore_index
from .utils import _make_index_unique, _restore_index, _maybe_coerce_to_boolean

from .repr import *
from .config import OPTIONS
Expand Down Expand Up @@ -541,7 +541,10 @@ def _update_attr(self, attr: str, axis: int, join_common: bool = False):
sort=False,
)
data_common = pd.concat(
[getattr(a, attr)[columns_common] for m, a in self.mod.items()],
[
_maybe_coerce_to_boolean(getattr(a, attr)[columns_common])
for m, a in self.mod.items()
],
join="outer",
axis=0,
sort=False,
Expand Down Expand Up @@ -587,11 +590,13 @@ def _update_attr(self, attr: str, axis: int, join_common: bool = False):
else:
if join_common:
dfs = [
_make_index_unique(
getattr(a, attr)
.drop(columns_common, axis=1)
.assign(**{rowcol: np.arange(getattr(a, attr).shape[0])})
.add_prefix(m + ":")
_maybe_coerce_to_boolean(
_make_index_unique(
getattr(a, attr)
.drop(columns_common, axis=1)
.assign(**{rowcol: np.arange(getattr(a, attr).shape[0])})
.add_prefix(m + ":")
)
)
for m, a in self.mod.items()
]
Expand All @@ -606,7 +611,9 @@ def _update_attr(self, attr: str, axis: int, join_common: bool = False):

data_common = pd.concat(
[
_make_index_unique(getattr(a, attr)[columns_common])
_maybe_coerce_to_boolean(
_make_index_unique(getattr(a, attr)[columns_common])
)
for m, a in self.mod.items()
],
join="outer",
Expand Down
16 changes: 16 additions & 0 deletions mudata/_core/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from collections import Counter
from typing import TypeVar
import pandas as pd
import numpy as np
from anndata.utils import make_index_unique

T = TypeVar("T", pd.Series, pd.DataFrame)


def _make_index_unique(df: pd.DataFrame) -> pd.DataFrame:
dup_idx = np.zeros((df.shape[0],), dtype=np.uint8)
Expand All @@ -19,3 +22,16 @@ def _make_index_unique(df: pd.DataFrame) -> pd.DataFrame:

def _restore_index(df: pd.DataFrame) -> pd.DataFrame:
return df.reset_index(level=-1, drop=True)


def _maybe_coerce_to_boolean(df: T) -> T:
if isinstance(df, pd.Series):
if df.dtype == bool:
return df.astype("boolean")
return df

for col in df.columns:
if df[col].dtype == bool:
df = df.assign(**{col: df[col].astype("boolean")})

return df

0 comments on commit bc7a066

Please sign in to comment.