diff --git a/uniport/function.py b/uniport/function.py index c23f788..bf84733 100644 --- a/uniport/function.py +++ b/uniport/function.py @@ -16,12 +16,11 @@ from .model.vae import VAE from .model.utils import EarlyStopping -from .metrics import batch_entropy_mixing_score, silhouette from .logger import create_logger from .data_loader import load_data from anndata import AnnData -from sklearn.preprocessing import maxabs_scale, MaxAbsScaler +from sklearn.preprocessing import MaxAbsScaler from glob import glob @@ -277,14 +276,11 @@ def Run( pred_id=1, seed=124, num_workers=4, + patience=30, batch_key='domain_id', source_name='source', - rep_celltype='cell_type', model_info=False, - umap=False, verbose=False, - assess=False, - show=False, ): """ @@ -354,6 +350,8 @@ def Run( Only used when out=='predict' to choose a decoder to predict data. Default: 1 seed Random seed for torch and numpy. Default: 124 + patience + early stopping patience. Default: 30 batch_key Name of batch in AnnData. Default: domain_id source_name @@ -473,14 +471,14 @@ def Run( use_specific=use_specific, domain_name=batch_key, batch_size=batch_size, - num_workers=4 + num_workers=num_workers ) - early_stopping = EarlyStopping(patience=10, checkpoint_file=outdir+'/checkpoint/model.pt') + early_stopping = EarlyStopping(patience=patience, checkpoint_file=outdir+'/checkpoint/model.pt') # encoder structure if enc is None: - enc = [['fc', 1024, 1, 'relu'],['fc', 16, '', '']] + enc = [['fc', 1024, 1, 'relu'], ['fc', 16, '', '']] # decoder structure dec = {} @@ -568,47 +566,6 @@ def Run( elif out == 'predict': adata_cm.obsm[out] = model.encodeBatch(testloader, num_gene, pred_id=pred_id, device=device, mode=mode, out=out) - - if umap: #and adata.shape[0]<1e6: - log.info('Run UMAP') - sc.settings.figdir = outdir - # sc.set_figure_params(dpi=200, fontsize=10) - - if mode == 'h': - sc.pp.neighbors(adata_cm, use_rep=out) - sc.tl.umap(adata_cm, min_dist=0.1) - sc.tl.leiden(adata_cm) - - cols = [source_name, 
rep_celltype] - color = [c for c in cols if c in adata_cm.obs] - - if len(color) > 0: - sc.pl.umap(adata_cm, color=color, save='result.pdf', show=show) - - if assess: - if len(adata_cm.obs[batch_key].cat.categories) > 1: - entropy_score = batch_entropy_mixing_score(adata_cm.obsm['X_umap'], adata_cm.obs[batch_key]) - log.info('batch_entropy_mixing_score: {:.3f}'.format(entropy_score)) - - if rep_celltype in adata_cm.obs: - sil_score = silhouette(adata_cm.obsm['X_umap'], adata_cm.obs[rep_celltype].cat.codes) - log.info("silhouette_score: {:.3f}".format(sil_score)) - - else: - log.info('Plot umap') - for i in range(n_domain-1): - adata_concat = adatas[i].concatenate(adatas[i+1]) - sc.pp.neighbors(adata_concat, n_neighbors=30, use_rep=out) - sc.tl.umap(adata_concat, min_dist=0.1) - sc.pl.umap(adata_concat, color=[source_name, rep_celltype], save='specific.pdf', wspace=0.3, legend_fontsize=14, \ - title=['',''], s=5, show=show) - - entropy_score = batch_entropy_mixing_score(adata_concat.obsm['X_umap'], adata_concat.obs[source_name]) - log.info('batch_entropy_mixing_score: {:.3f}'.format(entropy_score)) - sil_score = silhouette(adata_concat.obsm['X_umap'], adata_concat.obs[rep_celltype].cat.codes) - log.info("silhouette_score: {:.3f}".format(sil_score)) - - if mode == 'h': if save_OT: return adata_cm, tran