Setup amortized network for Cell2location #264
Replies: 9 comments
-
The main parameter that needs to be changed is the number of hidden nodes per variable. This can be fixed in the dictionary returned by ... Any thoughts?
-
We also use ELU and LayerNorm rather than BatchNorm. We don't know the optimal number of layers, but 1-2 give good results. The encoder NN is also conditioned on batch but not on other covariates. Maybe you can help with a more systematic benchmark of these choices?
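For illustration, a hedged sketch of what such a benchmark grid could look like, reusing the encoder_kwargs keys shown later in this thread; the training and evaluation step is only indicated, since the metric depends on the benchmark design:

from itertools import product

from torch import nn

# Candidate encoder configurations to compare; activation and depth vary,
# LayerNorm is kept fixed as discussed above.
configs = [
    {"activation_fn": act, "n_layers": n_layers,
     "use_batch_norm": False, "use_layer_norm": True}
    for act, n_layers in product([nn.ELU, nn.ReLU, nn.Mish], [1, 2])
]

# for cfg in configs:
#     mod = cell2location.models.Cell2location(
#         adata_vis, cell_state_df=inf_aver,
#         amortised=True, encoder_mode="multiple", encoder_kwargs=cfg,
#     )
#     mod.train(...)  # then compare held-out ELBO or mapping accuracy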
-
My idea would be not manual interaction but rather setting variables whose dimensionality is on the order of n_genes to a different n_hidden than other variables (something like 128 vs 16 neurons). Relying on user input feels like a bad choice and there should be some reasonable default.
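For illustration, a rough sketch of such a size-dependent default; the helper name, sizes, and threshold here are hypothetical and not part of cell2location:

# Hypothetical helper: variables whose dimensionality is on the order of
# n_genes get a wide hidden layer, small per-location variables a narrow one.
def default_n_hidden(var_sizes, n_genes, wide=128, narrow=16):
    return {
        name: (wide if size >= n_genes / 10 else narrow)
        for name, size in var_sizes.items()
    }

# Example with made-up sizes:
# default_n_hidden({"w_sf": 1200, "detection_y_s": 1}, n_genes=6000)
# -> {"w_sf": 128, "detection_y_s": 16}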
-
1-2 layers sound reasonable to me. I wouldn't use parameter-specific dropout, and I generally see little difference between different activation functions in scVI models (ELU vs ReLU vs Mish), though most likely some configurations will give slightly better results.
-
The amortising NN here is actually modified such that dropout is applied to the input of each layer (not the output). I found that it is important to apply dropout to the input data to improve accuracy - probably due to "selection for redundancy" when the first layer learns informative genes.
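A minimal PyTorch sketch of that layout (not the actual cell2location encoder), with Dropout applied to the input of each layer, followed by Linear, LayerNorm and ELU:

from torch import nn

def make_encoder(n_in, n_hidden, n_layers=1, dropout_rate=0.1):
    layers = []
    for i in range(n_layers):
        layers += [
            nn.Dropout(p=dropout_rate),  # dropout on the layer input, incl. the input data
            nn.Linear(n_in if i == 0 else n_hidden, n_hidden),
            nn.LayerNorm(n_hidden),      # LayerNorm rather than BatchNorm
            nn.ELU(),
        ]
    return nn.Sequential(*layers)

encoder = make_encoder(n_in=6000, n_hidden=256)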
-
Re aggressive training: I also recommend changing the number of aggressive steps (10 seems to be generally more useful than 5 or 20):

import pyro.optim

model.train_aggressive(
    plan_kwargs={
        'optim': pyro.optim.Adam(optim_args={'lr': 0.001}),
        'n_aggressive_epochs': 900,
        'n_aggressive_steps': 10,
    },
)
-
I generally use the following encoder NN settings:

import cell2location
from torch import nn

mod = cell2location.models.Cell2location(
    ... ,
    encoder_kwargs={
        'dropout_rate': 0.1,
        'n_hidden': {
            "single": 256,
            "n_s_cells_per_location": 10,
            "b_s_groups_per_location": 10,
            "z_sr_groups_factors": 64,
            "w_sf": 256,
            "detection_y_s": 20,
            "w_sf_residual_factors": 64,
            "b_s_residual_factors_per_location": 10,
        },
        'use_batch_norm': False,
        'use_layer_norm': True,
        'n_layers': 1,
        'activation_fn': nn.ELU,
    },
)
-
Do you have any insights into RAM usage? For 6000 genes and ~40k cells, 20 GB of free GPU memory is not enough. Would you recommend training on each slide separately or using mini-batching?
-
Hi @cane11. For standard inference (amortised=False), we recommend 1) using a GPU with more memory, or 2) splitting the dataset into several parts (not each batch separately, but more meaningful divisions). Mini-batching needs to go through the same number of epochs as standard inference (e.g. 20k), which can take several days. For amortised inference, you need to use mini-batching and should expect 2-10 GB of GPU memory usage, depending on batch size and the number of genes.
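For example, a hedged sketch of such a division, assuming the adata_vis object from the setup below and a hypothetical obs column such as 'tissue_region':

# 'tissue_region' is a hypothetical obs column used only for illustration;
# split by a biologically meaningful grouping rather than by individual batch.
subsets = {
    group: adata_vis[adata_vis.obs["tissue_region"] == group].copy()
    for group in adata_vis.obs["tissue_region"].unique()
}
# Each subset is then set up and trained as its own Cell2location model.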
Re these warnings:

/home/eecs/cergen/anaconda3/envs/reticulate/lib/python3.8/site-packages/pyro/primitives.py:443: UserWarning: Layer 0.1.weight was not registered in the param store because requires_grad=False. You can silence this warning by calling my_module.train(False)
  warnings.warn(
/home/eecs/cergen/anaconda3/envs/reticulate/lib/python3.8/site-packages/pyro/primitives.py:443: UserWarning: Layer 0.1.bias was not registered in the param store because requires_grad=False. You can silence this warning by calling my_module.train(False)

This is an expected consequence/side effect of aggressive training periodically hiding and exposing amortised and non-amortised variables. You can silence the warnings:

import warnings

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    mod.train_aggressive()
-
As discussed this morning, can you provide a better parameter heuristic for the amortized network setup? This is meant to be along the lines of:
mod = cell2location.models.Cell2location(
    adata_vis,
    cell_state_df=inf_aver,
    # the expected average cell abundance: tissue-dependent
    # hyper-prior which can be estimated from paired histology:
    N_cells_per_location=30,
    # hyperparameter controlling normalisation of
    # within-experiment variation in RNA detection:
    detection_alpha=10,
    detection_mean_per_sample=True,
    amortised=True,
    encoder_mode="multiple",
)
mod.view_anndata_setup()