From bd043b20ca5a1c48dbea68e0f0af199b55413946 Mon Sep 17 00:00:00 2001
From: Kai Waldrant
Date: Fri, 20 Sep 2024 12:38:06 +0200
Subject: [PATCH] Add fixes for benchmark results (#14)

* update submodule

* add sharedmem to alra

* Add subsample to cellxgene datasets

* add shared mem to knn_smoothing

* update changelog

* use hard cutoff for batch selection

* update process_dataset to include random subsample

* script

* [WIP] remove shared mem

* Update changelog

* update changelog

---
 CHANGELOG.md                                   |  2 ++
 src/api/file_common_dataset.yaml               |  6 ++++
 .../process_dataset/config.vsh.yaml            |  4 +++
 src/data_processors/process_dataset/script.py  | 36 ++++++++++++++-----
 4 files changed, 39 insertions(+), 9 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index b220bab..afdf427 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -12,6 +12,8 @@
 
 * Add `CHANGELOG.md` (PR #7).
 
+* Update `process_dataset` component to subsample large datasets (PR #14).
+
 ## MAJOR CHANGES
 
 * Revamp `scripts` directory (PR #13).

diff --git a/src/api/file_common_dataset.yaml b/src/api/file_common_dataset.yaml
index 4e5db66..8ad021f 100644
--- a/src/api/file_common_dataset.yaml
+++ b/src/api/file_common_dataset.yaml
@@ -10,6 +10,12 @@ info:
         name: counts
         description: Raw counts
         required: true
+    obs:
+      - type: string
+        name: batch
+        description: Batch information
+        required: false
+
     uns:
       - type: string
         name: dataset_id

diff --git a/src/data_processors/process_dataset/config.vsh.yaml b/src/data_processors/process_dataset/config.vsh.yaml
index 3914251..4587c23 100644
--- a/src/data_processors/process_dataset/config.vsh.yaml
+++ b/src/data_processors/process_dataset/config.vsh.yaml
@@ -19,6 +19,10 @@ arguments:
     type: "integer"
     description: "A seed for the subsampling."
     example: 123
+  - name: "--n_obs_limit"
+    type: "integer"
+    description: "The maximum number of cells the dataset may have before subsampling according to `obs.batch`."
+    default: 20000
 resources:
   - type: python_script
     path: script.py

diff --git a/src/data_processors/process_dataset/script.py b/src/data_processors/process_dataset/script.py
index 9d34051..0c0c7e1 100644
--- a/src/data_processors/process_dataset/script.py
+++ b/src/data_processors/process_dataset/script.py
@@ -1,4 +1,5 @@
 import sys
+import random
 
 import anndata as ad
 import numpy as np
@@ -26,13 +27,30 @@
 print(">> Load Data", flush=True)
 adata = ad.read_h5ad(par["input"])
 
+# limit to max number of observations
+adata_output = adata.copy()
+if adata.n_obs > par["n_obs_limit"]:
+    print(">> Subsampling the observations", flush=True)
+    print(f">> Setting seed to {par['seed']}")
+    random.seed(par["seed"])
+    if "batch" not in adata.obs:
+        obs_filt = np.ones(dtype=np.bool_, shape=adata.n_obs)
+        obs_index = np.random.choice(np.where(obs_filt)[0], par["n_obs_limit"], replace=False)
+        adata_output = adata[obs_index].copy()
+    else:
+        batch_counts = adata.obs.groupby('batch').size()
+        filtered_batches = batch_counts[batch_counts <= par["n_obs_limit"]]
+        sorted_filtered_batches = filtered_batches.sort_values(ascending=False)
+        selected_batch = sorted_filtered_batches.index[0]
+        adata_output = adata[adata.obs["batch"]==selected_batch,:].copy()
+
 # remove all layers except for counts
-for key in list(adata.layers.keys()):
+for key in list(adata_output.layers.keys()):
     if key != "counts":
-        del adata.layers[key]
+        del adata_output.layers[key]
 
 # round counts and convert to int
-counts = np.array(adata.layers["counts"]).round().astype(int)
+counts = np.array(adata_output.layers["counts"]).round().astype(int)
 
 print(">> process and split data", flush=True)
 train_data, test_data = split_molecules(
@@ -49,16 +67,16 @@
 # copy adata to train_set, test_set
 output_train = ad.AnnData(
     layers={"counts": X_train},
-    obs=adata.obs[[]],
-    var=adata.var[[]],
-    uns={"dataset_id": adata.uns["dataset_id"]}
+    obs=adata_output.obs[[]],
+    var=adata_output.var[[]],
+    uns={"dataset_id": adata_output.uns["dataset_id"]}
 )
 
 test_uns_keys = ["dataset_id", "dataset_name", "dataset_url", "dataset_reference",
     "dataset_summary", "dataset_description", "dataset_organism"]
 
 output_test = ad.AnnData(
     layers={"counts": X_test},
-    obs=adata.obs[[]],
-    var=adata.var[[]],
-    uns={key: adata.uns[key] for key in test_uns_keys}
+    obs=adata_output.obs[[]],
+    var=adata_output.var[[]],
+    uns={key: adata_output.uns[key] for key in test_uns_keys}
 )
 
 # add additional information for the train set
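
Editor's note: one subtlety in the subsampling hunk is that `random.seed(par["seed"])` seeds Python's stdlib RNG, while the draw itself goes through `np.random.choice`, which uses NumPy's global RNG and is therefore not affected by that seed. Below is a minimal standalone sketch of the same subsampling logic using a seeded NumPy `Generator` instead; the `par` dict and its values are hypothetical stand-ins for the component's parameters, not the PR's actual code.

```python
import anndata as ad
import numpy as np

# Hypothetical parameters mirroring the component's `par` dict (assumption).
par = {"input": "dataset.h5ad", "n_obs_limit": 20000, "seed": 123}

adata = ad.read_h5ad(par["input"])
adata_output = adata

if adata.n_obs > par["n_obs_limit"]:
    # Seed NumPy's RNG directly so the subsample is reproducible;
    # random.seed() alone would not affect NumPy's sampling.
    rng = np.random.default_rng(par["seed"])
    if "batch" not in adata.obs:
        # No batch annotation: draw a uniform random subset of cells.
        obs_index = rng.choice(adata.n_obs, par["n_obs_limit"], replace=False)
        adata_output = adata[obs_index].copy()
    else:
        # Hard cutoff for batch selection: keep the largest batch that
        # still fits within the observation limit.
        batch_counts = adata.obs.groupby("batch").size()
        eligible = batch_counts[batch_counts <= par["n_obs_limit"]]
        selected_batch = eligible.sort_values(ascending=False).index[0]
        adata_output = adata[adata.obs["batch"] == selected_batch].copy()
```

The batch branch deliberately keeps one intact batch rather than mixing cells across batches, which preserves within-batch structure for the denoising benchmark. Note that if every batch exceeds `n_obs_limit`, `eligible` is empty and the `index[0]` lookup raises an `IndexError`; this edge case applies to the patched script as well.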