Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
caokai1073 authored Mar 7, 2023
1 parent 43402f9 commit 33d37e6
Showing 1 changed file with 7 additions and 50 deletions.
57 changes: 7 additions & 50 deletions uniport/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,11 @@

from .model.vae import VAE
from .model.utils import EarlyStopping
from .metrics import batch_entropy_mixing_score, silhouette
from .logger import create_logger
from .data_loader import load_data

from anndata import AnnData
from sklearn.preprocessing import maxabs_scale, MaxAbsScaler
from sklearn.preprocessing import MaxAbsScaler

from glob import glob

Expand Down Expand Up @@ -277,14 +276,11 @@ def Run(
pred_id=1,
seed=124,
num_workers=4,
patience=30,
batch_key='domain_id',
source_name='source',
rep_celltype='cell_type',
model_info=False,
umap=False,
verbose=False,
assess=False,
show=False,
):

"""
Expand Down Expand Up @@ -354,6 +350,8 @@ def Run(
Only used when out=='predict' to choose a decoder to predict data. Default: 1
seed
Random seed for torch and numpy. Default: 124
patience
early stopping patience. Default: 10
batch_key
Name of batch in AnnData. Default: domain_id
source_name
Expand Down Expand Up @@ -473,14 +471,14 @@ def Run(
use_specific=use_specific,
domain_name=batch_key,
batch_size=batch_size,
num_workers=4
num_workers=num_workers
)

early_stopping = EarlyStopping(patience=10, checkpoint_file=outdir+'/checkpoint/model.pt')
early_stopping = EarlyStopping(patience=patience, checkpoint_file=outdir+'/checkpoint/model.pt')

# encoder structure
if enc is None:
enc = [['fc', 1024, 1, 'relu'],['fc', 16, '', '']]
enc = [['fc', 1024, 1, 'relu'], ['fc', 16, '', '']]

# decoder structure
dec = {}
Expand Down Expand Up @@ -568,47 +566,6 @@ def Run(
elif out == 'predict':
adata_cm.obsm[out] = model.encodeBatch(testloader, num_gene, pred_id=pred_id, device=device, mode=mode, out=out)


if umap: #and adata.shape[0]<1e6:
log.info('Run UMAP')
sc.settings.figdir = outdir
# sc.set_figure_params(dpi=200, fontsize=10)

if mode == 'h':
sc.pp.neighbors(adata_cm, use_rep=out)
sc.tl.umap(adata_cm, min_dist=0.1)
sc.tl.leiden(adata_cm)

cols = [source_name, rep_celltype]
color = [c for c in cols if c in adata_cm.obs]

if len(color) > 0:
sc.pl.umap(adata_cm, color=color, save='result.pdf', show=show)

if assess:
if len(adata_cm.obs[batch_key].cat.categories) > 1:
entropy_score = batch_entropy_mixing_score(adata_cm.obsm['X_umap'], adata_cm.obs[batch_key])
log.info('batch_entropy_mixing_score: {:.3f}'.format(entropy_score))

if rep_celltype in adata_cm.obs:
sil_score = silhouette(adata_cm.obsm['X_umap'], adata_cm.obs[rep_celltype].cat.codes)
log.info("silhouette_score: {:.3f}".format(sil_score))

else:
log.info('Plot umap')
for i in range(n_domain-1):
adata_concat = adatas[i].concatenate(adatas[i+1])
sc.pp.neighbors(adata_concat, n_neighbors=30, use_rep=out)
sc.tl.umap(adata_concat, min_dist=0.1)
sc.pl.umap(adata_concat, color=[source_name, rep_celltype], save='specific.pdf', wspace=0.3, legend_fontsize=14, \
title=['',''], s=5, show=show)

entropy_score = batch_entropy_mixing_score(adata_concat.obsm['X_umap'], adata_concat.obs[source_name])
log.info('batch_entropy_mixing_score: {:.3f}'.format(entropy_score))
sil_score = silhouette(adata_concat.obsm['X_umap'], adata_concat.obs[rep_celltype].cat.codes)
log.info("silhouette_score: {:.3f}".format(sil_score))


if mode == 'h':
if save_OT:
return adata_cm, tran
Expand Down

0 comments on commit 33d37e6

Please sign in to comment.