Add fixes for benchmark results #14

Merged 14 commits on Sep 20, 2024.
2 changes: 2 additions & 0 deletions CHANGELOG.md

```diff
@@ -12,6 +12,8 @@
 
 * Add `CHANGELOG.md` (PR #7).
 
+* Update `process_dataset` component to subsample large datasets (PR #14).
+
 ## MAJOR CHANGES
 
 * Revamp `scripts` directory (PR #13).
```
6 changes: 6 additions & 0 deletions src/api/file_common_dataset.yaml

```diff
@@ -10,6 +10,12 @@ info:
       name: counts
       description: Raw counts
       required: true
+  obs:
+    - type: string
+      name: batch
+      description: Batch information
+      required: false
+
   uns:
     - type: string
      name: dataset_id
```
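With this addition, `batch` becomes an optional string column in `obs`, so downstream code cannot assume it exists. A minimal sketch of how a consumer would guard the access (the file path is a hypothetical placeholder):

```python
import anndata as ad

adata = ad.read_h5ad("dataset.h5ad")  # hypothetical path

# `batch` is declared `required: false`, so check before using it
if "batch" in adata.obs:
    print(adata.obs["batch"].value_counts())
else:
    print("no batch annotation; fall back to random subsampling")
```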
4 changes: 4 additions & 0 deletions src/data_processors/process_dataset/config.vsh.yaml

```diff
@@ -19,6 +19,10 @@ arguments:
     type: "integer"
     description: "A seed for the subsampling."
     example: 123
+  - name: "--n_obs_limit"
+    type: "integer"
+    description: "The maximum number of cells the dataset may have before subsampling according to `obs.batch`."
+    default: 20000
 resources:
   - type: python_script
     path: script.py
```
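For context, viash renders the arguments above into a `par` dictionary inside `script.py`, which is where the new default surfaces. A sketch of the resulting dict (values are illustrative; the input path is a hypothetical placeholder):

```python
# Sketch of the `par` dict viash injects into script.py.
par = {
    "input": "dataset.h5ad",   # hypothetical input path
    "seed": 123,               # matches the `example` in the config
    "n_obs_limit": 20000,      # new default introduced by this PR
}
```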
36 changes: 27 additions & 9 deletions src/data_processors/process_dataset/script.py
```diff
@@ -1,4 +1,5 @@
 import sys
+import random
 import anndata as ad
 import numpy as np
 
@@ -26,13 +27,30 @@
 print(">> Load Data", flush=True)
 adata = ad.read_h5ad(par["input"])
 
+# limit to max number of observations
+adata_output = adata.copy()
+if adata.n_obs > par["n_obs_limit"]:
+    print(">> Subsampling the observations", flush=True)
+    print(f">> Setting seed to {par['seed']}")
+    random.seed(par["seed"])
+    if "batch" not in adata.obs:
+        obs_filt = np.ones(dtype=np.bool_, shape=adata.n_obs)
+        obs_index = np.random.choice(np.where(obs_filt)[0], par["n_obs_limit"], replace=False)
+        adata_output = adata[obs_index].copy()
+    else:
+        batch_counts = adata.obs.groupby('batch').size()
+        filtered_batches = batch_counts[batch_counts <= par["n_obs_limit"]]
+        sorted_filtered_batches = filtered_batches.sort_values(ascending=False)
+        selected_batch = sorted_filtered_batches.index[0]
+        adata_output = adata[adata.obs["batch"] == selected_batch, :].copy()
+
 # remove all layers except for counts
-for key in list(adata.layers.keys()):
+for key in list(adata_output.layers.keys()):
     if key != "counts":
-        del adata.layers[key]
+        del adata_output.layers[key]
 
 # round counts and convert to int
-counts = np.array(adata.layers["counts"]).round().astype(int)
+counts = np.array(adata_output.layers["counts"]).round().astype(int)
 
 print(">> process and split data", flush=True)
 train_data, test_data = split_molecules(
@@ -49,16 +67,16 @@
 # copy adata to train_set, test_set
 output_train = ad.AnnData(
     layers={"counts": X_train},
-    obs=adata.obs[[]],
-    var=adata.var[[]],
-    uns={"dataset_id": adata.uns["dataset_id"]}
+    obs=adata_output.obs[[]],
+    var=adata_output.var[[]],
+    uns={"dataset_id": adata_output.uns["dataset_id"]}
 )
 test_uns_keys = ["dataset_id", "dataset_name", "dataset_url", "dataset_reference", "dataset_summary", "dataset_description", "dataset_organism"]
 output_test = ad.AnnData(
     layers={"counts": X_test},
-    obs=adata.obs[[]],
-    var=adata.var[[]],
-    uns={key: adata.uns[key] for key in test_uns_keys}
+    obs=adata_output.obs[[]],
+    var=adata_output.var[[]],
+    uns={key: adata_output.uns[key] for key in test_uns_keys}
 )
 
 # add additional information for the train set
```
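Two caveats in the subsampling hunk are worth flagging. First, `random.seed(par["seed"])` seeds Python's built-in RNG, but the draw is made with `np.random.choice`, which uses NumPy's global generator, so the subsample is not actually reproducible from the seed. Second, `obs_filt` is all-`True` (so `np.where(obs_filt)[0]` is just `np.arange(adata.n_obs)`), and if every batch exceeds `n_obs_limit`, `sorted_filtered_batches.index[0]` raises an `IndexError`. A minimal reproducible sketch of the same logic using NumPy's seeded `Generator` API; the function name and the fallback behavior for oversized batches are assumptions, not part of the merged PR:

```python
import anndata as ad
import numpy as np

def subsample(adata: ad.AnnData, n_obs_limit: int, seed: int) -> ad.AnnData:
    """Sketch of the PR's subsampling logic with a seeded NumPy Generator."""
    if adata.n_obs <= n_obs_limit:
        return adata.copy()
    rng = np.random.default_rng(seed)  # seeds the generator actually used below
    if "batch" not in adata.obs:
        # uniform random subsample of cell indices, without replacement
        obs_index = rng.choice(adata.n_obs, n_obs_limit, replace=False)
        return adata[obs_index].copy()
    # keep the largest batch that still fits under the limit
    batch_counts = adata.obs.groupby("batch").size()
    fitting = batch_counts[batch_counts <= n_obs_limit]
    if fitting.empty:
        # assumed fallback (the merged code would raise IndexError here)
        obs_index = rng.choice(adata.n_obs, n_obs_limit, replace=False)
        return adata[obs_index].copy()
    selected_batch = fitting.sort_values(ascending=False).index[0]
    return adata[adata.obs["batch"] == selected_batch, :].copy()
```

As a side note, `adata_output.obs[[]]` and `adata_output.var[[]]` in the last hunk select zero columns while keeping the index, so the train and test sets retain cell and gene names without copying any metadata columns.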