diff --git a/examples/colab/train_xpinns_iso.ipynb b/examples/colab/train_xpinns_iso.ipynb index a204f7a..182b707 100644 --- a/examples/colab/train_xpinns_iso.ipynb +++ b/examples/colab/train_xpinns_iso.ipynb @@ -5,7 +5,7 @@ "colab": { "provenance": [], "gpuType": "T4", - "authorship_tag": "ABX9TyNMRppy0Ag7SZRjdAit9OS6", + "authorship_tag": "ABX9TyMWPNdkd3OZ5DzMPjt/tuVq", "include_colab_link": true }, "kernelspec": { @@ -63,22 +63,6 @@ "id": "jlnGqpgzgf_Z" } }, - { - "cell_type": "code", - "source": [ - "# Install specific version (0.4.23) of JAX and Jaxlib\n", - "!pip install --upgrade jax==0.4.23 jaxlib==0.4.23+cuda12.cudnn89 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n", - "\n", - "# Verify the installed version\n", - "import jax\n", - "print(jax.__version__)" - ], - "metadata": { - "id": "ei2r5MmucPiG" - }, - "execution_count": null, - "outputs": [] - }, { "cell_type": "markdown", "source": [ @@ -91,36 +75,17 @@ { "cell_type": "code", "source": [ - "# load the GitHub respository\n", - "!git clone https://github.com/YaoGroup/DIFFICE_jax\n", + "# install the DIFFICE_jax\n", + "!pip install DIFFICE_jax\n", "\n", - "# add the path of the github folder that contains the data\n", - "import os\n", - "os.chdir('DIFFICE_jax/data')\n" + "# download the data from GitHub respository\n", + "!wget https://github.com/YaoGroup/DIFFICE_jax/raw/main/examples/real_data/data_xpinns_RnFlch.mat\n" ], "metadata": { - "id": "eHc_W9sTUBF6", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "67414092-d6cb-4bf4-f8fd-826a86ff683e" + "id": "eHc_W9sTUBF6" }, "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Cloning into 'DIFFICE_jax'...\n", - "remote: Enumerating objects: 991, done.\u001b[K\n", - "remote: Counting objects: 100% (407/407), done.\u001b[K\n", - "remote: Compressing objects: 100% (227/227), done.\u001b[K\n", - "remote: Total 991 (delta 325), reused 239 (delta 180), pack-reused 584\u001b[K\n", - "Receiving objects: 100% (991/991), 91.60 MiB | 16.27 MiB/s, done.\n", - "Resolving deltas: 100% (583/583), done.\n" - ] - } - ] + "outputs": [] }, { "cell_type": "markdown", @@ -133,7 +98,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": { "id": "ZSe-qjr3Ju2N" }, @@ -146,14 +111,12 @@ "from scipy.io import loadmat\n", "import time\n", "\n", - "from DIFFICE_jax.data.xpinns.preprocessing import normalize_data\n", - "from DIFFICE_jax.data.xpinns.sampling import data_sample_create\n", - "from DIFFICE_jax.equation.ssa_eqn_iso import vectgrad, gov_eqn, front_eqn\n", - "from DIFFICE_jax.model.xpinns.initialization import init_xpinns\n", - "from DIFFICE_jax.model.xpinns.networks import solu_create\n", - "from DIFFICE_jax.model.xpinns.loss import loss_iso_create\n", - "from DIFFICE_jax.model.xpinns.prediction import predict\n", - "from DIFFICE_jax.optimizer.optimization import adam_optimizer, lbfgs_optimizer\n" + "from diffice_jax import normdata_xpinn, dsample_xpinn\n", + "from diffice_jax import vectgrad, ssa_iso, dbc_iso\n", + "from diffice_jax import init_xpinn, solu_xpinn\n", + "from diffice_jax import loss_iso_xpinn\n", + "from diffice_jax import predict_xpinn\n", + "from diffice_jax import adam_opt, lbfgs_opt\n" ] }, { @@ -198,7 +161,7 @@ "metadata": { "id": "wFBWogcjXcEe" }, - "execution_count": null, + "execution_count": 4, "outputs": [] }, { @@ -218,14 +181,14 @@ "rawdata = loadmat('data_xpinns_RnFlch.mat')\n", "\n", "# normalize the remote-sensing data for the XPINNs training\n", - "data_all, idxgall, posi_all, idxcrop_all = normalize_data(rawdata)\n", + "data_all, idxgall, posi_all, idxcrop_all = normdata_xpinn(rawdata)\n", "# extract the scale information for each variable\n", "scale = tree_map(lambda x: data_all[x][4][0:2], idxgall)\n" ], "metadata": { "id": "S0S1O_qHmc3d" }, - "execution_count": null, + "execution_count": 5, "outputs": [] }, { @@ -243,33 +206,33 @@ "cell_type": "code", "source": [ "# initialize the weights and biases of the network\n", - "trained_params = init_xpinns(keys[0], n_hl, n_unit,\n", + "trained_params = init_xpinn(keys[0], n_hl, n_unit,\n", " n_sub=len(idxgall))\n", "\n", "# create the solution function [tuple(callable, callable)]\n", - "solNN = solu_create(scale)\n", + "solNN = solu_xpinn(scale)\n", "\n", "# create the data function for Adam\n", - "dataf = data_sample_create(data_all, idxgall, n_pt)\n", + "dataf = dsample_xpinn(data_all, idxgall, n_pt)\n", "keys_adam = random.split(keys[1], 5)\n", "# generate the data\n", "data = dataf(keys_adam[0])\n", "\n", "# create the data function for L-BFGS\n", - "dataf_l = data_sample_create(data_all, idxgall, n_pt2)\n", + "dataf_l = dsample_xpinn(data_all, idxgall, n_pt2)\n", "key_lbfgs = keys[2]\n", "\n", "# group the gov. eqn and bd cond.\n", - "eqn_all = (gov_eqn, front_eqn)\n", + "eqn_all = (ssa_iso, dbc_iso)\n", "# calculate the loss function\n", - "NN_loss = loss_iso_create(solNN, eqn_all, scale, idxgall, lw)\n", + "NN_loss = loss_iso_xpinn(solNN, eqn_all, scale, idxgall, lw)\n", "# calculate the initial loss and set it as the loss reference value\n", "NN_loss.lref = NN_loss(trained_params, data)[0]\n" ], "metadata": { "id": "fzpbvcDDXunl" }, - "execution_count": null, + "execution_count": 6, "outputs": [] }, { @@ -294,7 +257,7 @@ "epoch1 = 10000\n", "\n", "# training with Adam with reducing w\n", - "trained_params, loss1 = adam_optimizer(\n", + "trained_params, loss1 = adam_opt(\n", " keys_adam[1], NN_loss, trained_params, dataf, epoch1, lr=lr)\n" ], "metadata": { @@ -335,7 +298,7 @@ "data_l = dataf_l(key_lbfgs[1])\n", "\n", "# training with L-bfgs\n", - "trained_params2, loss2 = lbfgs_optimizer(NN_loss, trained_params, data_l, epoch2)\n" + "trained_params2, loss2 = lbfgs_opt(NN_loss, trained_params, data_l, epoch2)\n" ], "metadata": { "id": "7gMWY2C8oAIw" @@ -358,18 +321,18 @@ "cell_type": "code", "source": [ "# create the function for trained solution and equation residues\n", - "f_u = lambda x, idx: solNN[0](trained_params, x, idx)\n", + "f_u = lambda x, idx: solNN[0](trained_params2, x, idx)\n", "\n", "# group all the function\n", - "func_all = (f_u, gov_eqn)\n", + "func_all = (f_u, ssa_iso)\n", "\n", "# calculate the solution and equation residue at given grids for visualization\n", - "results = predict(func_all, data_all, posi_all, idxcrop_all, idxgall)\n" + "results = predict_xpinn(func_all, data_all, posi_all, idxcrop_all, idxgall)\n" ], "metadata": { "id": "KnDbH7sZoRHu" }, - "execution_count": null, + "execution_count": 7, "outputs": [] }, {