Skip to content

Commit

Permalink
More PM fixes (#444)
Browse files Browse the repository at this point in the history
* use original batch subsetting method

* add print statement

* use heuristic to determine modality

* make feature_name optional for now

* fix division
  • Loading branch information
rcannood authored Apr 26, 2024
1 parent f0ef558 commit f8c18d7
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 16 deletions.
1 change: 1 addition & 0 deletions src/common/check_dataset_schema/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def check_structure(slot, slot_info, adata_slot):
print("Checking slot", slot, flush=True)
missing = check_structure(slot, def_slots[slot], getattr(adata, slot))
if missing:
print(f"Dataset is missing {slot} {missing}", flush=True)
out['exit_code'] = 1
out['data_schema'] = 'not ok'
out['error'][slot] = missing
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ info:
name: feature_name
description: A human-readable name for the feature, usually a gene symbol.
# TODO: make this required once the dataloader supports it
required: true
required: false

- type: boolean
name: hvg
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ info:
name: feature_name
description: A human-readable name for the feature, usually a gene symbol.
# TODO: make this required once the dataloader supports it
required: true
required: false

- type: boolean
name: hvg
Expand Down
33 changes: 22 additions & 11 deletions src/tasks/predict_modality/methods/guanlab_dengkw_pm/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,25 +86,36 @@
del test_matrix

print('Running KRR model ...', flush=True)
y_pred = np.zeros((input_test_mod1.n_obs, input_train_mod2.n_vars), dtype=np.float32)

for _ in range(5):
np.random.shuffle(batches)
for batch in [batches[:batch_len//2], batches[batch_len//2:]]:
# for passing the test
if not batch:
batch = [batches[0]]
if batch_len == 1:
# just in case there is only one batch
batch_subsets = [batches]
elif mod1_type == "ADT" or mod2_type == "ADT":
# two fold consensus predictions
batch_subsets = [
batches[:batch_len//2],
batches[batch_len//2:]
]
else:
# leave-one-batch-out consensus predictions
batch_subsets = [
batches[:i] + batches[i+1:]
for i in range(batch_len)
]

y_pred = np.zeros((input_test_mod1.n_obs, input_train_mod2.n_vars), dtype=np.float32)
for batch in batch_subsets:
print(batch, flush=True)
kernel = RBF(length_scale = scale)
krr = KernelRidge(alpha=alpha, kernel=kernel)
print('Fitting KRR ... ', flush=True)
krr.fit(train_norm[input_train_mod1.obs.batch.isin(batch)], train_gs[input_train_mod2.obs.batch.isin(batch)])
krr.fit(
train_norm[input_train_mod1.obs.batch.isin(batch)],
train_gs[input_train_mod2.obs.batch.isin(batch)]
)
y_pred += (krr.predict(test_norm) @ embedder_mod2.components_)

np.clip(y_pred, a_min=0, a_max=None, out=y_pred)

y_pred /= 10
y_pred /= len(batch_subsets)

# Store as sparse matrix to be efficient.
# Note that this might require different classifiers/embedders before-hand.
Expand Down
22 changes: 19 additions & 3 deletions src/tasks/predict_modality/process_dataset/script.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,25 @@ cat("Reading input data\n")
ad1 <- anndata::read_h5ad(if (!par$swap) par$input_mod1 else par$input_mod2)
ad2 <- anndata::read_h5ad(if (!par$swap) par$input_mod2 else par$input_mod1)

# figure out modality types
ad1_mod <- unique(ad1$var[["feature_types"]])
ad2_mod <- unique(ad2$var[["feature_types"]])
# Use a heuristic to determine the modality of each dataset.
# TODO: remove this heuristic once the modality is stored in `uns`.
# Determine the modality label of a dataset.
#
# The first rule that applies wins:
#   1. `ad$uns[["modality"]]`, when that entry exists.
#   2. The unique values of `ad$var[["feature_types"]]`, when that column
#      exists.
#   3. "RNA" when `mod1` is TRUE.
#   4. "ADT" when the dataset id contains "cite".
#   5. "ATAC" when the dataset id contains "multiome".
#
# Args:
#   ad:   Object exposing `$uns` (named list) and `$var` (table with
#         column names).
#   mod1: Scalar logical; TRUE when `ad` should be treated as the first
#         modality.
#
# Returns:
#   The modality as found in `uns` / `var`, or one of "RNA", "ADT", "ATAC".
#
# Raises:
#   An error ("Could not determine modality") when no rule applies,
#   including when `uns$dataset_id` is missing or is not a single string.
determine_modality <- function(ad, mod1 = TRUE) {
  dataset_id <- ad$uns[["dataset_id"]]

  # `grepl()` yields `logical(0)` for a missing id, which would make a bare
  # `if` fail with "argument is of length zero". `isTRUE()` turns every
  # non-scalar / non-TRUE result into FALSE so we reach the explicit error.
  id_contains <- function(pattern) {
    isTRUE(grepl(pattern, dataset_id, fixed = TRUE))
  }

  if ("modality" %in% names(ad$uns)) {
    ad$uns[["modality"]]
  } else if ("feature_types" %in% colnames(ad$var)) {
    unique(ad$var[["feature_types"]])
  } else if (mod1) {
    "RNA"
  } else if (id_contains("cite")) {
    "ADT"
  } else if (id_contains("multiome")) {
    "ATAC"
  } else {
    stop("Could not determine modality")
  }
}
ad1_mod <- determine_modality(ad1, !par$swap)
ad2_mod <- determine_modality(ad2, par$swap)

# determine new uns
uns_vars <- c("dataset_id", "dataset_name", "dataset_url", "dataset_reference", "dataset_summary", "dataset_description", "dataset_organism", "normalization_id")
Expand Down

0 comments on commit f8c18d7

Please sign in to comment.