-
Notifications
You must be signed in to change notification settings - Fork 3
/
MarrowTM.py
47 lines (41 loc) · 2.11 KB
/
MarrowTM.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# Harmonization benchmark on Tabula Muris data: concatenates the 'facs' and
# 'droplet' TabulaMuris datasets, writes per-batch cell-type proportions to
# ../MarrowTM/celltypeprop.txt, then hands off to CompareModels.
#
# Usage: python MarrowTM.py <models>
#   <models> is forwarded (as a string) to CompareModels.
import os
import sys

import numpy as np

# NOTE(review): not referenced anywhere in this script; kept for compatibility.
use_cuda = True

from scvi.dataset.dataset import GeneExpressionDataset
from scvi.harmonization.utils_chenling import CompareModels

# NOTE(review): this is appended *after* the scvi imports above, so it can only
# influence the TabulaMuris import below -- confirm this ordering is intended.
sys.path.append("/data/yosef2/users/chenling/HarmonizationSCANVI")

if len(sys.argv) < 2:
    # Fail with a usage hint instead of a bare IndexError.
    sys.exit("usage: python MarrowTM.py <models>")
models = str(sys.argv[1])
plotname = 'MarrowTM'
DATA_DIR = '/data/yosef2/scratch/chenling/scanvi_data/'


def _label_proportions(labels, n_types):
    """Return the fraction of *labels* equal to each index 0..n_types-1.

    Iterating over the full index range (rather than only the labels that
    occur) keeps every row aligned with the cell-type header even when some
    cell type is absent. An empty *labels* array yields NaN for every type.
    """
    if labels.size == 0:
        return [float("nan")] * n_types
    return [float(np.mean(labels == i)) for i in range(n_types)]


def _write_celltype_proportions(path, cell_types, labels, batch_indices):
    """Write a tab-separated cell-type proportion table to *path*.

    Row 1 holds the cell-type names, row 2 the proportions over all cells,
    rows 3 and 4 the proportions within batch 0 and batch 1. Every field is
    followed by a tab and every row ends with a newline. The parent
    directory is created if it does not exist.
    """
    labels = np.asarray(labels).ravel()
    batch_indices = np.asarray(batch_indices).ravel()
    n_types = len(cell_types)
    rows = [
        _label_proportions(labels, n_types),
        _label_proportions(labels[batch_indices == 0], n_types),
        _label_proportions(labels[batch_indices == 1], n_types),
    ]
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(path, "w") as handle:
        handle.write("".join("%s\t" % (name,) for name in cell_types) + "\n")
        for row in rows:
            handle.write("".join("%f\t" % value for value in row) + "\n")


from scvi.dataset.muris_tabula import TabulaMuris
dataset1 = TabulaMuris('facs', save_path=DATA_DIR)
dataset2 = TabulaMuris('droplet', save_path=DATA_DIR)
# Presumably keeps every gene (nb_genes is the dataset's current gene count)
# -- confirm against GeneExpressionDataset.subsample_genes.
dataset1.subsample_genes(dataset1.nb_genes)
dataset2.subsample_genes(dataset2.nb_genes)
gene_dataset = GeneExpressionDataset.concat_datasets(dataset1, dataset2)

# Path is relative to the current working directory.
_write_celltype_proportions(
    "../%s/celltypeprop.txt" % plotname,
    gene_dataset.cell_types,
    gene_dataset.labels,
    gene_dataset.batch_indices,
)
CompareModels(gene_dataset, dataset1, dataset2, plotname, models)
# 438.05354394990854
#
#
# from scvi.harmonization.utils_chenling import run_model
# from scvi.harmonization.utils_chenling import scmap_eval
# latent, batch_indices, labels, keys, stats = run_model('scmap', gene_dataset, dataset1, dataset2, filename=plotname)
# pred1 = latent
# pred2 = stats
# res1 = scmap_eval(pred1, labels[batch_indices == 1],labels)
# res2 = scmap_eval(pred2, labels[batch_indices == 0],labels)
# print(res1)
# print(res2)
# temp1, batch_indices, labels, keys, temp2 = run_model('scmap', gene_dataset, dataset1, dataset2, filename=plotname)
# pred1 = temp1
# pred2 = temp2
# res1 = scmap_eval(pred1, labels[batch_indices == 1],labels)
# res2 = scmap_eval(pred2, labels[batch_indices == 0],labels)