Skip to content

Commit

Permalink
tl.mofa: add train_kwargs argument
Browse files Browse the repository at this point in the history
  • Loading branch information
ilia-kats committed Nov 8, 2024
1 parent 0135489 commit a706e35
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion muon/_core/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ def mofa(
use_float32: bool = False,
gpu_mode: bool = False,
gpu_device: Optional[bool] = None,
train_kwargs: Optional[Mapping[str, Any]] = None,
svi_mode: bool = False,
svi_batch_size: float = 0.5,
svi_learning_rate: float = 1.0,
Expand Down Expand Up @@ -370,8 +371,11 @@ def mofa(
use reduced precision (float32)
gpu_mode : optional
if to use GPU mode
gpu_mode : optional
gpu_device : optional
which GPU device to use
train_kwargs: optional
additional parameters for MOFA (startELBO, freqELBO, startSparsity, tolerance, startDrop, freqDrop,
dropR2, nostop, schedule, weight_views)
svi_mode : optional
if to use Stochastic Variational Inference (SVI)
svi_batch_size : optional
Expand Down Expand Up @@ -489,6 +493,8 @@ def mofa(
)
logging.info(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Setting training options...")

if train_kwargs is None:
train_kwargs = {}
try:
ent.set_train_options(
iter=n_iterations,
Expand All @@ -500,6 +506,7 @@ def mofa(
quiet=quiet,
outfile=outfile,
save_interrupted=save_interrupted,
**train_kwargs
)
except TypeError:
# mofapy2 <0.7 does not have a gpu_device argument
Expand All @@ -516,6 +523,7 @@ def mofa(
quiet=quiet,
outfile=outfile,
save_interrupted=save_interrupted,
**train_kwargs
)

if svi_mode:
Expand Down

0 comments on commit a706e35

Please sign in to comment.