diff --git a/diffice_jax/__init__.py b/diffice_jax/__init__.py index d12519f..1221c6f 100644 --- a/diffice_jax/__init__.py +++ b/diffice_jax/__init__.py @@ -10,6 +10,9 @@ from .equation.eqn_aniso_zz import gov_eqn as ssa_aniso from .equation.eqn_aniso_zz import front_eqn as dbc_aniso +from .model.pinns.initialization import init_nets as init_pinn +from .model.xpinns.initialization import init_nets as init_xpinn + from .model.pinns.networks import solu_create as solu_pinn from .model.xpinns.networks import solu_create as solu_xpinn @@ -24,7 +27,8 @@ from .optimizer.optimization import adam_optimizer as adam_opt from .optimizer.optimization import lbfgs_optimizer as lbfgs_opt -__all__ = ["normdata_pinn", "normdata_xpinn", "dsample_pinn", "dsample_xpinn", "vectgrad", - "ssa_iso", "dbc_iso", "ssa_aniso", "dbc_aniso", "solu_pinn", "solu_xpinn", +__all__ = ["normdata_pinn", "normdata_xpinn", "dsample_pinn", "dsample_xpinn", + "vectgrad", "ssa_iso", "dbc_iso", "ssa_aniso", "dbc_aniso", + "init_pinn", "init_xpinn", "solu_pinn", "solu_xpinn", "loss_iso_pinn", "loss_aniso_pinn", "loss_iso_xpinn", "loss_aniso_xpinn", "predict_pinn", "predict_xpinn", "adam_opt", "lbfgs_opt"]