Skip to content

Commit

Permalink
使用 Colab 创建而成
Browse files Browse the repository at this point in the history
  • Loading branch information
wangyji committed Aug 8, 2024
1 parent dea1624 commit 79424d3
Showing 1 changed file with 30 additions and 67 deletions.
97 changes: 30 additions & 67 deletions examples/colab/train_xpinns_iso.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"colab": {
"provenance": [],
"gpuType": "T4",
"authorship_tag": "ABX9TyNMRppy0Ag7SZRjdAit9OS6",
"authorship_tag": "ABX9TyMWPNdkd3OZ5DzMPjt/tuVq",
"include_colab_link": true
},
"kernelspec": {
Expand Down Expand Up @@ -63,22 +63,6 @@
"id": "jlnGqpgzgf_Z"
}
},
{
"cell_type": "code",
"source": [
"# Install specific version (0.4.23) of JAX and Jaxlib\n",
"!pip install --upgrade jax==0.4.23 jaxlib==0.4.23+cuda12.cudnn89 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
"\n",
"# Verify the installed version\n",
"import jax\n",
"print(jax.__version__)"
],
"metadata": {
"id": "ei2r5MmucPiG"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
Expand All @@ -91,36 +75,17 @@
{
"cell_type": "code",
"source": [
"# load the GitHub respository\n",
"!git clone https://github.com/YaoGroup/DIFFICE_jax\n",
"# install the DIFFICE_jax\n",
"!pip install DIFFICE_jax\n",
"\n",
"# add the path of the github folder that contains the data\n",
"import os\n",
"os.chdir('DIFFICE_jax/data')\n"
"# download the data from GitHub respository\n",
"!wget https://github.com/YaoGroup/DIFFICE_jax/raw/main/examples/real_data/data_xpinns_RnFlch.mat\n"
],
"metadata": {
"id": "eHc_W9sTUBF6",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "67414092-d6cb-4bf4-f8fd-826a86ff683e"
"id": "eHc_W9sTUBF6"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Cloning into 'DIFFICE_jax'...\n",
"remote: Enumerating objects: 991, done.\u001b[K\n",
"remote: Counting objects: 100% (407/407), done.\u001b[K\n",
"remote: Compressing objects: 100% (227/227), done.\u001b[K\n",
"remote: Total 991 (delta 325), reused 239 (delta 180), pack-reused 584\u001b[K\n",
"Receiving objects: 100% (991/991), 91.60 MiB | 16.27 MiB/s, done.\n",
"Resolving deltas: 100% (583/583), done.\n"
]
}
]
"outputs": []
},
{
"cell_type": "markdown",
Expand All @@ -133,7 +98,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {
"id": "ZSe-qjr3Ju2N"
},
Expand All @@ -146,14 +111,12 @@
"from scipy.io import loadmat\n",
"import time\n",
"\n",
"from DIFFICE_jax.data.xpinns.preprocessing import normalize_data\n",
"from DIFFICE_jax.data.xpinns.sampling import data_sample_create\n",
"from DIFFICE_jax.equation.ssa_eqn_iso import vectgrad, gov_eqn, front_eqn\n",
"from DIFFICE_jax.model.xpinns.initialization import init_xpinns\n",
"from DIFFICE_jax.model.xpinns.networks import solu_create\n",
"from DIFFICE_jax.model.xpinns.loss import loss_iso_create\n",
"from DIFFICE_jax.model.xpinns.prediction import predict\n",
"from DIFFICE_jax.optimizer.optimization import adam_optimizer, lbfgs_optimizer\n"
"from diffice_jax import normdata_xpinn, dsample_xpinn\n",
"from diffice_jax import vectgrad, ssa_iso, dbc_iso\n",
"from diffice_jax import init_xpinn, solu_xpinn\n",
"from diffice_jax import loss_iso_xpinn\n",
"from diffice_jax import predict_xpinn\n",
"from diffice_jax import adam_opt, lbfgs_opt\n"
]
},
{
Expand Down Expand Up @@ -198,7 +161,7 @@
"metadata": {
"id": "wFBWogcjXcEe"
},
"execution_count": null,
"execution_count": 4,
"outputs": []
},
{
Expand All @@ -218,14 +181,14 @@
"rawdata = loadmat('data_xpinns_RnFlch.mat')\n",
"\n",
"# normalize the remote-sensing data for the XPINNs training\n",
"data_all, idxgall, posi_all, idxcrop_all = normalize_data(rawdata)\n",
"data_all, idxgall, posi_all, idxcrop_all = normdata_xpinn(rawdata)\n",
"# extract the scale information for each variable\n",
"scale = tree_map(lambda x: data_all[x][4][0:2], idxgall)\n"
],
"metadata": {
"id": "S0S1O_qHmc3d"
},
"execution_count": null,
"execution_count": 5,
"outputs": []
},
{
Expand All @@ -243,33 +206,33 @@
"cell_type": "code",
"source": [
"# initialize the weights and biases of the network\n",
"trained_params = init_xpinns(keys[0], n_hl, n_unit,\n",
"trained_params = init_xpinn(keys[0], n_hl, n_unit,\n",
" n_sub=len(idxgall))\n",
"\n",
"# create the solution function [tuple(callable, callable)]\n",
"solNN = solu_create(scale)\n",
"solNN = solu_xpinn(scale)\n",
"\n",
"# create the data function for Adam\n",
"dataf = data_sample_create(data_all, idxgall, n_pt)\n",
"dataf = dsample_xpinn(data_all, idxgall, n_pt)\n",
"keys_adam = random.split(keys[1], 5)\n",
"# generate the data\n",
"data = dataf(keys_adam[0])\n",
"\n",
"# create the data function for L-BFGS\n",
"dataf_l = data_sample_create(data_all, idxgall, n_pt2)\n",
"dataf_l = dsample_xpinn(data_all, idxgall, n_pt2)\n",
"key_lbfgs = keys[2]\n",
"\n",
"# group the gov. eqn and bd cond.\n",
"eqn_all = (gov_eqn, front_eqn)\n",
"eqn_all = (ssa_iso, dbc_iso)\n",
"# calculate the loss function\n",
"NN_loss = loss_iso_create(solNN, eqn_all, scale, idxgall, lw)\n",
"NN_loss = loss_iso_xpinn(solNN, eqn_all, scale, idxgall, lw)\n",
"# calculate the initial loss and set it as the loss reference value\n",
"NN_loss.lref = NN_loss(trained_params, data)[0]\n"
],
"metadata": {
"id": "fzpbvcDDXunl"
},
"execution_count": null,
"execution_count": 6,
"outputs": []
},
{
Expand All @@ -294,7 +257,7 @@
"epoch1 = 10000\n",
"\n",
"# training with Adam with reducing w\n",
"trained_params, loss1 = adam_optimizer(\n",
"trained_params, loss1 = adam_opt(\n",
" keys_adam[1], NN_loss, trained_params, dataf, epoch1, lr=lr)\n"
],
"metadata": {
Expand Down Expand Up @@ -335,7 +298,7 @@
"data_l = dataf_l(key_lbfgs[1])\n",
"\n",
"# training with L-bfgs\n",
"trained_params2, loss2 = lbfgs_optimizer(NN_loss, trained_params, data_l, epoch2)\n"
"trained_params2, loss2 = lbfgs_opt(NN_loss, trained_params, data_l, epoch2)\n"
],
"metadata": {
"id": "7gMWY2C8oAIw"
Expand All @@ -358,18 +321,18 @@
"cell_type": "code",
"source": [
"# create the function for trained solution and equation residues\n",
"f_u = lambda x, idx: solNN[0](trained_params, x, idx)\n",
"f_u = lambda x, idx: solNN[0](trained_params2, x, idx)\n",
"\n",
"# group all the function\n",
"func_all = (f_u, gov_eqn)\n",
"func_all = (f_u, ssa_iso)\n",
"\n",
"# calculate the solution and equation residue at given grids for visualization\n",
"results = predict(func_all, data_all, posi_all, idxcrop_all, idxgall)\n"
"results = predict_xpinn(func_all, data_all, posi_all, idxcrop_all, idxgall)\n"
],
"metadata": {
"id": "KnDbH7sZoRHu"
},
"execution_count": null,
"execution_count": 7,
"outputs": []
},
{
Expand Down

0 comments on commit 79424d3

Please sign in to comment.