Skip to content

Commit

Permalink
greatly improve the efficiency of the bottleneck - adding awkward arr…
Browse files Browse the repository at this point in the history
…ay to adata.obsm using dataframe
  • Loading branch information
xinyuejohn committed Feb 21, 2024
1 parent 4a9e960 commit a4b96b2
Showing 1 changed file with 22 additions and 21 deletions.
43 changes: 22 additions & 21 deletions ehrdata/io/_omop.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import awkward as ak
import ehrapy as ep
import pandas as pd
import pyarrow as pa
from rich import print as rprint

from ehrdata.utils._omop_utils import (
Expand Down Expand Up @@ -407,29 +408,29 @@ def extract_note(


def from_dataframe(adata, feature: str, df):
# Purpose: build an awkward array from the rows of `df`, keyed by the
# "visit_occurrence_id" column, store it in `adata.obsm[feature]`, and return `adata`.
#
# NOTE(review): this listing looks like a rendered commit diff -- indentation is
# stripped and two implementations appear side by side (a groupby / per-id
# `ak.Array` construction and an Arrow / `ak.unflatten` construction).
# Presumably the former is the removed code and the latter the added code;
# confirm against the repository file before treating this as runnable.

# Group the rows of df by id and collect the ids known to adata as a set.
grouped = df.groupby("visit_occurrence_id")
unique_visit_occurrence_ids = set(adata.obs.index)
# Append placeholder rows (None in every other column) for each id in
# adata.obs.index that does not occur in df.
# NOTE(review): the `else` branch reads len(new_row_dict["visit_occurrence_id"]),
# which is still 0 for any column iterated before "visit_occurrence_id";
# this appears to rely on that column coming first in df.columns -- confirm.
new_row_dict = {col: [] for col in df.columns}
for key in new_row_dict.keys():
if key == "visit_occurrence_id":
new_row_dict[key] = list(set(adata.obs.index) - set(df.visit_occurrence_id))
else:
new_row_dict[key] = [None] * len(new_row_dict["visit_occurrence_id"])
new_rows = pd.DataFrame(new_row_dict)
df = pd.concat([df, new_rows], ignore_index=True)

# Ids that occur both in adata.obs.index and in the grouped data, plus the
# record used for ids without data: every column except
# "visit_occurrence_id" mapped to an empty list.
feature_ids = unique_visit_occurrence_ids.intersection(grouped.groups.keys())
empty_entry = {
source_table_column: []
for source_table_column in set(df.columns)
if source_table_column not in ["visit_occurrence_id"]
}
# Convert df to an Arrow table, wrap it as an awkward array, then split the
# flat records into sublists using the per-id row counts.
# NOTE(review): presumably assumes rows sharing an id are contiguous in df and
# that value_counts(sort=False) yields counts in that same order -- confirm.
ak_array = ak.from_arrow(pa.Table.from_pandas(df), highlevel=True)
ak_array = ak.unflatten(ak_array, df["visit_occurrence_id"].value_counts(sort=False).values)

# The sublists must be reordered according to the sequence of ids in
# adata.obs.index: map each id to its position among df's unique ids, then
# look up that position for every id in adata.
id_in_df = list(df["visit_occurrence_id"].unique())
id_in_adata = list(adata.obs.index)
index_dict = {value: index for index, value in enumerate(id_in_df)}
index = [index_dict[x] for x in id_in_adata]

# Reorder ak_array so that its entries align with the rows of adata.
ak_array = ak_array[index]
# All column names except the id column (built from a set, so their order is
# not fixed).
columns_in_ak_array = list(set(df.columns) - {"visit_occurrence_id"})
# Build one record per id in unique_visit_occurrence_ids: a column -> list
# dict for ids that have data, `empty_entry` otherwise.
# NOTE(review): this iterates a set, so the entry order is not tied to
# adata.obs.index -- verify alignment if this path is still in use.
ak_array = ak.Array(
[
(
grouped.get_group(visit_occurrence_id)[columns_in_ak_array].to_dict(orient="list")
if visit_occurrence_id in feature_ids
else empty_entry
)
for visit_occurrence_id in unique_visit_occurrence_ids
]
)
# Two consecutive assignments to the same key: the second overwrites the
# first and keeps only the non-id columns.
adata.obsm[feature] = ak_array
adata.obsm[feature] = ak_array[columns_in_ak_array]

return adata

Expand Down

0 comments on commit a4b96b2

Please sign in to comment.