From 4ffa9b89e711b1c542deaf61ad2ab956bcead0d7 Mon Sep 17 00:00:00 2001 From: Heeringa Date: Tue, 13 Aug 2024 10:05:46 +0200 Subject: [PATCH] Applied linting --- .flake8 | 4 +- .github/workflows/black.yml | 12 -- .github/workflows/lint.yml | 36 +++++ notebooks/advection.ipynb | 63 ++++---- notebooks/diffusion.ipynb | 80 +++++----- notebooks/reaction_diffusion.ipynb | 147 +++++++++--------- pyproject.toml | 6 + src/data/databases/advection_1d_database.py | 24 +-- src/data/databases/diffusion_1d_database.py | 13 +- src/data/databases/pde_database.py | 20 +-- .../reaction_diffusion_2d_database.py | 50 +++--- src/data/generate_datasets.py | 24 +-- src/data/initial_conditions/gaussian.py | 4 +- src/models.py | 31 ++-- src/training/advection/train.py | 10 +- src/training/diffusion/train.py | 11 +- src/training/reaction_diffusion/train.py | 12 +- src/training/train.py | 16 +- src/utils/L12_nuclear.py | 10 +- src/utils/__init__.py | 2 +- src/utils/get_bias.py | 6 +- 21 files changed, 265 insertions(+), 316 deletions(-) delete mode 100644 .github/workflows/black.yml create mode 100644 .github/workflows/lint.yml create mode 100644 pyproject.toml diff --git a/.flake8 b/.flake8 index 8b6b4c7..2a18a1b 100644 --- a/.flake8 +++ b/.flake8 @@ -1,3 +1,3 @@ [flake8] -max-line-length=88 -extend-ignore=E230,W503,W504 +max-line-length=130 +extend-ignore=F401 diff --git a/.github/workflows/black.yml b/.github/workflows/black.yml deleted file mode 100644 index bfc6712..0000000 --- a/.github/workflows/black.yml +++ /dev/null @@ -1,12 +0,0 @@ -name: Lint - -on: [push, pull_request] - -jobs: - lint: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: psf/black@stable - with: - jupyter: true diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..815408f --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,36 @@ +name: lint +on: + - push + - pull_request + +jobs: + isort: + runs-on: ubuntu-latest + steps: + - name: Check out source repository + uses: actions/checkout@v3 + - name: Run isort + uses: isort/isort-action@v1 + black: + runs-on: ubuntu-latest + steps: + - name: Check out source repository + uses: actions/checkout@v4 + - name: Run black + uses: psf/black@stable + with: + jupyter: true + flake8: + runs-on: ubuntu-latest + name: Lint + steps: + - name: Check out source repository + uses: actions/checkout@v3 + - name: Set up Python environment + uses: actions/setup-python@v4 + with: + python-version: "3.10" + - name: flake8 Lint + uses: py-actions/flake8@v2 + with: + plugins: "flake8-nb flake8-black" diff --git a/notebooks/advection.ipynb b/notebooks/advection.ipynb index 843ac8b..44db80a 100644 --- a/notebooks/advection.ipynb +++ b/notebooks/advection.ipynb @@ -11,7 +11,8 @@ }, "source": [ "import sys\n", - "sys.path.insert(0, '../')" + "\n", + "sys.path.insert(0, \"../\")" ], "outputs": [], "execution_count": 22 @@ -53,7 +54,7 @@ "from tqdm import tqdm\n", "\n", "from src.data.generate_datasets import BASE_ADVECTION_CONFIG as CONFIG, generate_advection_dataset\n", - "from src.training.advection.train import TRAIN_PARAMETERS, TEST_PARAMETERS\n" + "from src.training.advection.train import TRAIN_PARAMETERS, TEST_PARAMETERS" ], "metadata": { "collapsed": false, @@ -77,7 +78,7 @@ "cell_type": "code", "source": [ "base_storage_path = Path(\"../data/advection\")\n", - "results_path = Path(\"../results/advection\")\n" + "results_path = Path(\"../results/advection\")" ], "metadata": { "collapsed": false, @@ -124,7 +125,7 @@ "stacked_test_datasets = generate_advection_dataset(TEST_PARAMETERS, base_storage_path=base_storage_path, return_stacked=True)\n", "\n", "train_dataset = torch.utils.data.TensorDataset(stacked_train_datasets)\n", - "test_dataset = torch.utils.data.TensorDataset(stacked_test_datasets)\n" + "test_dataset = torch.utils.data.TensorDataset(stacked_test_datasets)" ], "metadata": { "collapsed": false, @@ -269,7 +270,7 @@ "TTx, XX = torch.meshgrid(\n", " [time, domain],\n", " indexing=\"ij\",\n", - ")\n" + ")" ], "metadata": { "collapsed": false, @@ -299,10 +300,9 @@ }, "cell_type": "code", "source": [ - "\n", "fig, axs = plt.subplots(1, 3)\n", "for i in range(len(TRAIN_PARAMETERS)):\n", - " axs[i].pcolormesh(TTx, XX, stacked_train_datasets[i*CONFIG.Nt:(i+1)*CONFIG.Nt].data.detach())\n" + " axs[i].pcolormesh(TTx, XX, stacked_train_datasets[i * CONFIG.Nt : (i + 1) * CONFIG.Nt].data.detach())" ], "outputs": [ { @@ -330,14 +330,14 @@ { "cell_type": "code", "source": [ - "snapshot_matrix = stacked_train_datasets[:CONFIG.Nt].data\n", + "snapshot_matrix = stacked_train_datasets[: CONFIG.Nt].data\n", "\n", "plt.figure()\n", "# plt.pcolormesh(TTx, XX, snapshot_matrix.detach(), cmap=mpl.colormaps[\"magma\"])\n", "plt.pcolormesh(TTx, XX, snapshot_matrix.detach())\n", "plt.colorbar()\n", - "plt.rc('xtick', labelsize=12) # fontsize of the tick labels\n", - "plt.rc('ytick', labelsize=12) # fontsize of the tick labels\n", + "plt.rc(\"xtick\", labelsize=12) # fontsize of the tick labels\n", + "plt.rc(\"ytick\", labelsize=12) # fontsize of the tick labels\n", "plt.ylabel(\"Space\", fontsize=16)\n", "plt.xlabel(\"Time\", fontsize=16)\n", "# plt.savefig(results_path / \"advection_snapshot.png\", bbox_inches='tight')\n", @@ -345,8 +345,8 @@ "plt.figure()\n", "plt.semilogy(torch.linalg.svdvals(snapshot_matrix), lw=5)\n", "# plt.semilogy(torch.linalg.svdvals(snapshot_matrix), lw=5, c=\"magenta\")\n", - "plt.rc('xtick', labelsize=14) # fontsize of the tick labels\n", - "plt.rc('ytick', labelsize=14) # fontsize of the tick labels\n", + "plt.rc(\"xtick\", labelsize=14) # fontsize of the tick labels\n", + "plt.rc(\"ytick\", labelsize=14) # fontsize of the tick labels\n", "plt.ylabel(\"Singular value\", fontsize=16)\n", "plt.xlabel(\"Index\", fontsize=16)\n", "# plt.savefig(results_path / \"advection_singular.png\", bbox_inches='tight')\n", @@ -354,7 +354,7 @@ "\n", "plt.figure()\n", "for k in [0, 10, 100, -1]:\n", - " plt.plot(domain, snapshot_matrix[k])\n" + " plt.plot(domain, snapshot_matrix[k])" ], "metadata": { "collapsed": false, @@ -416,13 +416,15 @@ "cell_type": "code", "source": [ "fig = plt.figure()\n", - "line, = plt.plot(domain, snapshot_matrix.data[0,:].detach(), lw=4)\n", + "(line,) = plt.plot(domain, snapshot_matrix.data[0, :].detach(), lw=4)\n", + "\n", "\n", "def animate(k):\n", - " line.set_ydata(snapshot_matrix.data[k,:].detach())\n", - " return line,\n", + " line.set_ydata(snapshot_matrix.data[k, :].detach())\n", + " return (line,)\n", + "\n", "\n", - "# ani = FuncAnimation(fig, animate, interval=40, blit=True, repeat=True, frames=CONFIG.Nt) \n", + "# ani = FuncAnimation(fig, animate, interval=40, blit=True, repeat=True, frames=CONFIG.Nt)\n", "# ani.save(results_path / \"advection.gif\", dpi=300, writer=PillowWriter(fps=25))" ], "outputs": [ @@ -479,7 +481,7 @@ "plt.figure()\n", "plt.semilogy(relative_error, lw=7)\n", "plt.ylabel(\"Singular values\")\n", - "plt.xlabel(\"Index\")\n" + "plt.xlabel(\"Index\")" ], "metadata": { "collapsed": false, @@ -539,9 +541,9 @@ "\n", "plt.figure()\n", "plt.semilogy(relative_error, lw=3)\n", - "plt.xlim((0,50))\n", + "plt.xlim((0, 50))\n", "plt.ylabel(\"Relative error\")\n", - "plt.xlabel(\"index\")\n" + "plt.xlabel(\"index\")" ], "outputs": [ { @@ -616,16 +618,18 @@ "cell_type": "code", "outputs": [], "execution_count": 37, - "source": "Vhr = Vh[:ldim,:]" + "source": [ + "Vhr = Vh[:ldim, :]" + ] }, { "cell_type": "code", "source": [ "plt.figure()\n", - "plt.plot(domain, Vh[1,:], label=\"mode 1\")\n", - "plt.plot(domain, Vh[2,:], label=\"mode 2\")\n", - "plt.plot(domain, Vh[3,:], label=\"mode 3\")\n", - "plt.plot(domain, Vh[4,:], label=\"mode 4\")\n" + "plt.plot(domain, Vh[1, :], label=\"mode 1\")\n", + "plt.plot(domain, Vh[2, :], label=\"mode 2\")\n", + "plt.plot(domain, Vh[3, :], label=\"mode 3\")\n", + "plt.plot(domain, Vh[4, :], label=\"mode 4\")" ], "metadata": { "collapsed": false, @@ -671,7 +675,12 @@ } }, "cell_type": "code", - "source": "print(torch.norm(stacked_train_datasets - torch.einsum(\"nm,bn->bm\", Vhr.T @ Vhr, stacked_train_datasets))**2 / torch.norm(stacked_train_datasets)**2)", + "source": [ + "print(\n", + " torch.norm(stacked_train_datasets - torch.einsum(\"nm,bn->bm\", Vhr.T @ Vhr, stacked_train_datasets)) ** 2\n", + " / torch.norm(stacked_train_datasets) ** 2\n", + ")" + ], "outputs": [ { "name": "stdout", @@ -737,7 +746,7 @@ "source": [ "plt.figure()\n", "plt.plot(snapshot_matrix[-1].detach(), lw=5)\n", - "plt.plot((Vhr.T @ Vhr @ snapshot_matrix[-1]).detach())\n" + "plt.plot((Vhr.T @ Vhr @ snapshot_matrix[-1]).detach())" ], "metadata": { "collapsed": false, diff --git a/notebooks/diffusion.ipynb b/notebooks/diffusion.ipynb index b0c8d7d..318798e 100644 --- a/notebooks/diffusion.ipynb +++ b/notebooks/diffusion.ipynb @@ -29,7 +29,8 @@ "source": [ "import copy\n", "import sys\n", - "sys.path.insert(0, '../')" + "\n", + "sys.path.insert(0, \"../\")" ], "metadata": { "collapsed": false, @@ -256,21 +257,21 @@ "# plt.pcolormesh(TTx, XX, snapshot_matrix.detach(), cmap=mpl.colormaps[\"plasma_r\"])\n", "plt.pcolormesh(TTx, XX, snapshot_matrix.detach())\n", "plt.colorbar()\n", - "plt.rc('xtick', labelsize=12) # fontsize of the tick labels\n", - "plt.rc('ytick', labelsize=12) # fontsize of the tick labels\n", + "plt.rc(\"xtick\", labelsize=12) # fontsize of the tick labels\n", + "plt.rc(\"ytick\", labelsize=12) # fontsize of the tick labels\n", "plt.ylabel(\"Space\", fontsize=16)\n", "plt.xlabel(\"Time\", fontsize=16)\n", - "plt.savefig(results_path / \"sol.png\", bbox_inches='tight')\n", + "plt.savefig(results_path / \"sol.png\", bbox_inches=\"tight\")\n", "\n", "plt.figure()\n", "plt.semilogy(torch.linalg.svdvals(snapshot_matrix), lw=5)\n", "# plt.semilogy(torch.linalg.svdvals(snapshot_matrix), lw=5, c=\"magenta\")\n", "# plt.xlim((0,25))\n", - "plt.rc('xtick', labelsize=14) # fontsize of the tick labels\n", - "plt.rc('ytick', labelsize=14) # fontsize of the tick labels\n", + "plt.rc(\"xtick\", labelsize=14) # fontsize of the tick labels\n", + "plt.rc(\"ytick\", labelsize=14) # fontsize of the tick labels\n", "plt.ylabel(\"Singular value\", fontsize=16)\n", "plt.xlabel(\"Index\", fontsize=16)\n", - "plt.savefig(results_path / \"singular.png\", bbox_inches='tight')\n" + "plt.savefig(results_path / \"singular.png\", bbox_inches=\"tight\")" ], "metadata": { "collapsed": false, @@ -306,15 +307,17 @@ ], "source": [ "fig = plt.figure()\n", - "line, = plt.plot(domain, snapshot_matrix[0,:].detach(), lw=4)\n", - "plt.xlim((-1,1))\n", + "(line,) = plt.plot(domain, snapshot_matrix[0, :].detach(), lw=4)\n", + "plt.xlim((-1, 1))\n", "plt.axhline(0, color=\"black\", linestyle=\"--\")\n", "\n", + "\n", "def animate(k):\n", - " line.set_ydata(snapshot_matrix[5*k,:].detach())\n", - " return line,\n", + " line.set_ydata(snapshot_matrix[5 * k, :].detach())\n", + " return (line,)\n", + "\n", "\n", - "# ani = FuncAnimation(fig, animate, interval=40, blit=True, frames=Nt // 5) \n", + "# ani = FuncAnimation(fig, animate, interval=40, blit=True, frames=Nt // 5)\n", "# ani.save(results_path / \"diffusion.gif\", dpi=300, writer=PillowWriter(fps=25))" ], "metadata": { @@ -371,7 +374,7 @@ "execution_count": 137, "outputs": [], "source": [ - "relative_error = ( torch.sum(Sigma, dim=0) - torch.cumsum(Sigma, dim=0) ) / torch.sum(Sigma, dim=0)" + "relative_error = (torch.sum(Sigma, dim=0) - torch.cumsum(Sigma, dim=0)) / torch.sum(Sigma, dim=0)" ], "metadata": { "collapsed": false, @@ -418,9 +421,9 @@ "\n", "plt.figure()\n", "plt.semilogy(S, lw=3)\n", - "plt.xlim((0,20))\n", + "plt.xlim((0, 20))\n", "plt.ylabel(\"Singular value\")\n", - "plt.xlabel(\"index\")\n" + "plt.xlabel(\"index\")" ], "metadata": { "collapsed": false, @@ -468,9 +471,9 @@ "\n", "plt.figure()\n", "plt.semilogy(relative_error, lw=3)\n", - "plt.xlim((0,10))\n", + "plt.xlim((0, 10))\n", "plt.ylabel(\"Relative error\")\n", - "plt.xlabel(\"index\")\n" + "plt.xlabel(\"index\")" ], "metadata": { "collapsed": false, @@ -527,7 +530,7 @@ "execution_count": 141, "outputs": [], "source": [ - "Vhr = Vh[:ldim,:]" + "Vhr = Vh[:ldim, :]" ], "metadata": { "collapsed": false, @@ -559,10 +562,10 @@ } ], "source": [ - "plt.plot(Vh[2,:])\n", - "plt.plot(Vh[3,:])\n", - "plt.plot(Vh[4,:])\n", - "plt.plot(Vh[5,:])\n" + "plt.plot(Vh[2, :])\n", + "plt.plot(Vh[3, :])\n", + "plt.plot(Vh[4, :])\n", + "plt.plot(Vh[5, :])" ], "metadata": { "collapsed": false, @@ -594,7 +597,10 @@ } ], "source": [ - "print(torch.norm(stacked_train_databases - torch.einsum(\"nm,bn->bm\", Vhr.T @ Vhr, stacked_train_databases))**2 / torch.norm(stacked_train_databases)**2)" + "print(\n", + " torch.norm(stacked_train_databases - torch.einsum(\"nm,bn->bm\", Vhr.T @ Vhr, stacked_train_databases)) ** 2\n", + " / torch.norm(stacked_train_databases) ** 2\n", + ")" ], "metadata": { "collapsed": false, @@ -661,7 +667,7 @@ "\n", "for (batch,) in tqdm(train_loader):\n", " loss = loss_functional(batch, torch.zeros_like(batch))\n", - " total_batch_test_loss += loss.item()\n" + " total_batch_test_loss += loss.item()" ], "metadata": { "collapsed": false, @@ -805,7 +811,7 @@ "plt.figure()\n", "plt.plot(domain, stacked_train_databases[index].detach(), label=\"org\", lw=7, c=\"lightsteelblue\")\n", "plt.plot(domain, (Vhr.T @ Vhr @ stacked_train_databases[index]).detach(), label=\"POD\", c=\"darkorange\")\n", - "plt.legend()\n" + "plt.legend()" ], "metadata": { "collapsed": false, @@ -838,7 +844,7 @@ " decoder_layers=[5, 25, 50, 101],\n", ")\n", "model.apply(init_linear)\n", - "bregman.sparsify(model, 0.20)\n" + "bregman.sparsify(model, 0.20)" ], "metadata": { "collapsed": false, @@ -923,7 +929,7 @@ ], "source": [ "U, S, Vh = torch.linalg.svd(model.encoder[-1].weight, full_matrices=False)\n", - "Sigma = S ** 2\n", + "Sigma = S**2\n", "relative_error = (torch.sum(Sigma, dim=0) - torch.cumsum(Sigma, dim=0)) / torch.sum(Sigma, dim=0)\n", "ldim: int = torch.count_nonzero(relative_error >= 1e-6).item() + 1\n", "\n", @@ -938,7 +944,7 @@ "print(\"l 1e-6\", model.latent_size(\"latent POD\", 1e-6))\n", "print(\"l 1e-8\", model.latent_size(\"latent POD\", 1e-8))\n", "print(\"es\", model.latent_size(\"encoder spectral\"))\n", - "print(\"ds\", model.latent_size(\"decoder spectral\"))\n" + "print(\"ds\", model.latent_size(\"decoder spectral\"))" ], "metadata": { "collapsed": false, @@ -1995,7 +2001,7 @@ } ], "source": [ - "epochs=1000\n", + "epochs = 1000\n", "tic = timer()\n", "\n", "for epoch in range(1, epochs):\n", @@ -2016,7 +2022,9 @@ " ldims.append(torch.linalg.svdvals(model.encoder[-1].weight.data))\n", "\n", "toc = timer()\n", - "print(f\"{epochs} epochs took {toc-tic:.04f} seconds, meaning approx {(toc-tic)/epochs:.04f} sec/epoch or {epochs/(toc-tic):.04f} epochs/sec.\")\n" + "print(\n", + " f\"{epochs} epochs took {toc-tic:.04f} seconds, meaning approx {(toc-tic)/epochs:.04f} sec/epoch or {epochs/(toc-tic):.04f} epochs/sec.\"\n", + ")" ], "metadata": { "collapsed": false, @@ -2076,7 +2084,7 @@ "plt.figure()\n", "plt.plot(ldims)\n", "plt.xlabel(\"Epochs\")\n", - "plt.ylabel(\"Singular vectors\")\n" + "plt.ylabel(\"Singular vectors\")" ], "metadata": { "collapsed": false, @@ -2123,7 +2131,7 @@ ], "source": [ "U, S, Vh = torch.linalg.svd(best_model.encoder[-1].weight, full_matrices=False)\n", - "Sigma = S ** 2\n", + "Sigma = S**2\n", "relative_error = (torch.sum(Sigma, dim=0) - torch.cumsum(Sigma, dim=0)) / torch.sum(Sigma, dim=0)\n", "ldim: int = torch.count_nonzero(relative_error >= 1e-6).item() + 1\n", "\n", @@ -2138,7 +2146,7 @@ "print(\"l 1e-6\", best_model.latent_size(\"latent POD\", 1e-6))\n", "print(\"l 1e-8\", best_model.latent_size(\"latent POD\", 1e-8))\n", "print(\"es\", best_model.latent_size(\"encoder spectral\"))\n", - "print(\"ds\", best_model.latent_size(\"decoder spectral\"))\n" + "print(\"ds\", best_model.latent_size(\"decoder spectral\"))" ], "metadata": { "collapsed": false, @@ -2200,7 +2208,7 @@ ], "source": [ "U, S, Vh = torch.linalg.svd(best_model.encoder[-1].weight, full_matrices=False)\n", - "Sigma = S ** 2\n", + "Sigma = S**2\n", "relative_error = (torch.sum(Sigma, dim=0) - torch.cumsum(Sigma, dim=0)) / torch.sum(Sigma, dim=0)\n", "ldim: int = torch.count_nonzero(relative_error >= 1e-6).item() + 1\n", "\n", @@ -2271,7 +2279,7 @@ ], "source": [ "U, S, Vh = torch.linalg.svd(pruned_model.encoder[-1].weight, full_matrices=False)\n", - "Sigma = S ** 2\n", + "Sigma = S**2\n", "relative_error = (torch.sum(Sigma, dim=0) - torch.cumsum(Sigma, dim=0)) / torch.sum(Sigma, dim=0)\n", "ldim: int = torch.count_nonzero(relative_error >= 1e-6).item() + 1\n", "\n", @@ -2416,7 +2424,7 @@ " test_loss_model += loss.item()\n", "\n", "print(\"model: \", test_loss_model)\n", - "print(\"POD: \", test_loss_pod)\n" + "print(\"POD: \", test_loss_pod)" ], "metadata": { "collapsed": false, diff --git a/notebooks/reaction_diffusion.ipynb b/notebooks/reaction_diffusion.ipynb index 94fd12a..cff3073 100644 --- a/notebooks/reaction_diffusion.ipynb +++ b/notebooks/reaction_diffusion.ipynb @@ -27,7 +27,8 @@ "cell_type": "code", "source": [ "import sys\n", - "sys.path.insert(0, '../')" + "\n", + "sys.path.insert(0, \"../\")" ], "outputs": [], "execution_count": 2 @@ -105,7 +106,7 @@ { "cell_type": "code", "source": [ - "dataset = generate_reaction_diffusion_dataset(base_storage_path=storage_path, return_database=True)\n" + "dataset = generate_reaction_diffusion_dataset(base_storage_path=storage_path, return_database=True)" ], "metadata": { "collapsed": false, @@ -141,11 +142,11 @@ "v_max = []\n", "u_min = []\n", "v_min = []\n", - "for n, (u,v) in enumerate(dataset.data):\n", + "for n, (u, v) in enumerate(dataset.data):\n", " u_max.append(torch.max(u))\n", " v_max.append(torch.max(v))\n", " u_min.append(torch.min(u))\n", - " v_min.append(torch.min(v))\n" + " v_min.append(torch.min(v))" ], "metadata": { "collapsed": false, @@ -163,25 +164,25 @@ "plt.figure()\n", "vline = 4_000\n", "\n", - "plt.subplot(2,2,1)\n", + "plt.subplot(2, 2, 1)\n", "plt.title(r\"$u_{max}$\")\n", "plt.plot(u_max)\n", "plt.axvline(vline, c=\"red\")\n", "\n", - "plt.subplot(2,2,2)\n", + "plt.subplot(2, 2, 2)\n", "plt.title(r\"$v_{max}$\")\n", "plt.plot(v_max)\n", "plt.axvline(vline, c=\"red\")\n", "\n", - "plt.subplot(2,2,3)\n", + "plt.subplot(2, 2, 3)\n", "plt.title(r\"$u_{min}$\")\n", "plt.plot(u_min)\n", "plt.axvline(vline, c=\"red\")\n", "\n", - "plt.subplot(2,2,4)\n", + "plt.subplot(2, 2, 4)\n", "plt.title(r\"$v_{min}$\")\n", "plt.plot(v_min)\n", - "plt.axvline(vline, c=\"red\")\n" + "plt.axvline(vline, c=\"red\")" ], "metadata": { "collapsed": false, @@ -223,15 +224,15 @@ "cell_type": "code", "source": [ "max_indices2 = set()\n", - "for i in range(4000, dataset.config.Nt-2500, 1000):\n", - " max_indices2.add(torch.argmax(torch.tensor(u_max[i:i+2_500])).item()+i)\n", + "for i in range(4000, dataset.config.Nt - 2500, 1000):\n", + " max_indices2.add(torch.argmax(torch.tensor(u_max[i : i + 2_500])).item() + i)\n", "max_indices2 = list(max_indices2)\n", "max_indices2.sort()\n", "\n", "plt.figure()\n", - "plt.plot(np.diff(max_indices2), marker='o')\n", + "plt.plot(np.diff(max_indices2), marker=\"o\")\n", "\n", - "print(np.diff(max_indices2))\n" + "print(np.diff(max_indices2))" ], "metadata": { "collapsed": false, @@ -285,13 +286,13 @@ "source": [ "max_indices = set()\n", "for i in range(4000, 12000, 1000):\n", - " max_indices.add(torch.argmax(torch.tensor(u_max[i:i+2_500])).item()+i)\n", + " max_indices.add(torch.argmax(torch.tensor(u_max[i : i + 2_500])).item() + i)\n", "max_indices = list(max_indices)\n", "max_indices.sort()\n", "print(max_indices)\n", "print(np.diff(max_indices))\n", "post_transience = max_indices[0]\n", - "period = max_indices[-1]-max_indices[0]\n", + "period = max_indices[-1] - max_indices[0]\n", "print(f\"Period is {period} times steps\")" ] }, @@ -300,59 +301,59 @@ "source": [ "plt.figure()\n", "\n", - "plt.subplot(2,2,1)\n", + "plt.subplot(2, 2, 1)\n", "plt.plot(u_max)\n", "for index in max_indices:\n", " plt.axvline(index, c=\"red\")\n", "plt.title(r\"$u_{max}$\")\n", - "plt.xlim((0,15_000))\n", + "plt.xlim((0, 15_000))\n", "\n", - "plt.subplot(2,2,2)\n", + "plt.subplot(2, 2, 2)\n", "plt.plot(v_max)\n", "for index in max_indices:\n", " plt.axvline(index, c=\"red\")\n", "plt.title(r\"$v_{max}$\")\n", - "plt.xlim((0,15_000))\n", + "plt.xlim((0, 15_000))\n", "\n", - "plt.subplot(2,2,3)\n", + "plt.subplot(2, 2, 3)\n", "plt.plot(u_min)\n", "for index in max_indices:\n", " plt.axvline(index, c=\"red\")\n", "plt.title(r\"$u_{min}$\")\n", - "plt.xlim((0,15_000))\n", + "plt.xlim((0, 15_000))\n", "\n", - "plt.subplot(2,2,4)\n", + "plt.subplot(2, 2, 4)\n", "plt.plot(v_min)\n", "for index in max_indices:\n", " plt.axvline(index, c=\"red\")\n", "plt.title(r\"$v_{min}$\")\n", - "plt.xlim((0,15_000))\n", + "plt.xlim((0, 15_000))\n", "\n", "plt.figure()\n", "\n", - "plt.subplot(2,2,1)\n", + "plt.subplot(2, 2, 1)\n", "plt.plot(u_max)\n", "for index in max_indices:\n", " plt.axvline(index, c=\"red\")\n", "plt.title(r\"$u_{max}$\")\n", "\n", - "plt.subplot(2,2,2)\n", + "plt.subplot(2, 2, 2)\n", "plt.plot(v_max)\n", "for index in max_indices:\n", " plt.axvline(index, c=\"red\")\n", "plt.title(r\"$v_{max}$\")\n", "\n", - "plt.subplot(2,2,3)\n", + "plt.subplot(2, 2, 3)\n", "plt.plot(u_min)\n", "for index in max_indices:\n", " plt.axvline(index, c=\"red\")\n", "plt.title(r\"$u_{min}$\")\n", "\n", - "plt.subplot(2,2,4)\n", + "plt.subplot(2, 2, 4)\n", "plt.plot(v_min)\n", "for index in max_indices:\n", " plt.axvline(index, c=\"red\")\n", - "plt.title(r\"$v_{min}$\")\n" + "plt.title(r\"$v_{min}$\")" ], "metadata": { "collapsed": false, @@ -469,19 +470,19 @@ "\n", " plt.subplot(1, 2, 1)\n", " plt.pcolormesh(dataset.meshgrid_x, dataset.meshgrid_y, dataset.data[index][0].detach(), cmap=\"PuOr\")\n", - " plt.axis('square')\n", + " plt.axis(\"square\")\n", " plt.title(f\"$u$\")\n", " plt.colorbar()\n", " plt.xticks([])\n", " plt.yticks([])\n", - " \n", + "\n", " plt.subplot(1, 2, 2)\n", " plt.pcolormesh(dataset.meshgrid_x, dataset.meshgrid_y, dataset.data[index][1].detach(), cmap=\"PiYG\")\n", - " plt.axis('square')\n", + " plt.axis(\"square\")\n", " plt.title(f\"$v$\")\n", " plt.colorbar()\n", " plt.xticks([])\n", - " plt.yticks([])\n" + " plt.yticks([])" ], "metadata": { "collapsed": false, @@ -547,24 +548,24 @@ { "cell_type": "code", "source": [ - "for index in range(post_transience, post_transience+6*period+1, period):\n", + "for index in range(post_transience, post_transience + 6 * period + 1, period):\n", " plt.figure()\n", - " \n", + "\n", " plt.subplot(1, 2, 1)\n", " plt.pcolormesh(dataset.meshgrid_x, dataset.meshgrid_y, dataset.data[index][0].detach(), cmap=\"PuOr\")\n", - " plt.axis('square')\n", + " plt.axis(\"square\")\n", " plt.title(f\"$u$\")\n", " plt.colorbar()\n", " plt.xticks([])\n", " plt.yticks([])\n", - " \n", + "\n", " plt.subplot(1, 2, 2)\n", " plt.pcolormesh(dataset.meshgrid_x, dataset.meshgrid_y, dataset.data[index][1].detach(), cmap=\"PiYG\")\n", - " plt.axis('square')\n", + " plt.axis(\"square\")\n", " plt.title(f\"$v$\")\n", " plt.colorbar()\n", " plt.xticks([])\n", - " plt.yticks([])\n" + " plt.yticks([])" ], "metadata": { "collapsed": false, @@ -682,11 +683,11 @@ ], "execution_count": 32, "source": [ - "plt.figure(figsize=(19,12))\n", + "plt.figure(figsize=(19, 12))\n", "\n", "plt.subplot(1, 3, 1)\n", "plt.pcolormesh(dataset.meshgrid_x, dataset.meshgrid_y, dataset.data[-1].view(2, 100, 100)[0].detach(), cmap=\"PuOr\")\n", - "plt.axis('square')\n", + "plt.axis(\"square\")\n", "plt.title(f\"$u$\")\n", "plt.colorbar()\n", "plt.xticks([])\n", @@ -694,19 +695,23 @@ "\n", "plt.subplot(1, 3, 2)\n", "plt.pcolormesh(dataset.meshgrid_x, dataset.meshgrid_y, dataset.data[-1].view(2, 100, 100)[1].detach(), cmap=\"PiYG\")\n", - "plt.axis('square')\n", + "plt.axis(\"square\")\n", "plt.title(f\"$v$\")\n", "plt.colorbar()\n", "plt.xticks([])\n", "plt.yticks([])\n", "\n", "plt.subplot(1, 3, 3)\n", - "plt.pcolormesh(dataset.meshgrid_x, dataset.meshgrid_y, dataset.data[-1].view(2, 100, 100)[0].detach()-np.flip(dataset.data[-1].view(2, 100, 100)[1].T.detach().numpy(), axis=1))\n", - "plt.axis('square')\n", + "plt.pcolormesh(\n", + " dataset.meshgrid_x,\n", + " dataset.meshgrid_y,\n", + " dataset.data[-1].view(2, 100, 100)[0].detach() - np.flip(dataset.data[-1].view(2, 100, 100)[1].T.detach().numpy(), axis=1),\n", + ")\n", + "plt.axis(\"square\")\n", "plt.title(f\"$u-flip_y(v^\\intercal)$\")\n", "plt.colorbar()\n", "plt.xticks([])\n", - "plt.yticks([])\n" + "plt.yticks([])" ] }, { @@ -727,11 +732,15 @@ "plt.savefig(results_path / \"solution_v.png\", bbox_inches=\"tight\")\n", "\n", "plt.figure()\n", - "plt.pcolormesh(dataset.meshgrid_x, dataset.meshgrid_y, dataset.data[-1].view(2, 100, 100)[0].detach()-np.flip(dataset.data[-1].view(2, 100, 100)[1].T.detach().numpy(), axis=1))\n", + "plt.pcolormesh(\n", + " dataset.meshgrid_x,\n", + " dataset.meshgrid_y,\n", + " dataset.data[-1].view(2, 100, 100)[0].detach() - np.flip(dataset.data[-1].view(2, 100, 100)[1].T.detach().numpy(), axis=1),\n", + ")\n", "# plt.axis('square')\n", "plt.title(f\"$u-flip_y(v^\\intercal)$\")\n", "plt.colorbar()\n", - "plt.savefig(results_path / \"solution_diff.png\", bbox_inches=\"tight\")\n" + "plt.savefig(results_path / \"solution_diff.png\", bbox_inches=\"tight\")" ], "metadata": { "collapsed": false, @@ -805,11 +814,13 @@ "\n", "num_frames = 500\n", "\n", + "\n", "def animate(k):\n", - " quad.set_array(dataset.data[:, 0, :, :][5000+36*k].detach())\n", - " return quad,\n", + " quad.set_array(dataset.data[:, 0, :, :][5000 + 36 * k].detach())\n", + " return (quad,)\n", "\n", - "# ani = FuncAnimation(fig, animate, blit=True, frames=num_frames, repeat=False) \n", + "\n", + "# ani = FuncAnimation(fig, animate, blit=True, frames=num_frames, repeat=False)\n", "# ani.save(results_path / \"reaction_diffusion.gif\", dpi=300, writer=PillowWriter(fps=25))" ] }, @@ -829,18 +840,16 @@ "execution_count": 18, "source": [ "subsampled_dataset = torch.utils.data.Subset(\n", - " generate_reaction_diffusion_dataset(base_storage_path=storage_path, return_database=False, u_only=True), \n", - " torch.arange(start=5000, end=50_000, step=36)\n", - ")\n" + " generate_reaction_diffusion_dataset(base_storage_path=storage_path, return_database=False, u_only=True),\n", + " torch.arange(start=5000, end=50_000, step=36),\n", + ")" ] }, { "cell_type": "code", "source": [ "train_dataset, test_dataset, validation_dataset = torch.utils.data.random_split(\n", - " subsampled_dataset,\n", - " [750, 250, 250],\n", - " generator=torch.Generator().manual_seed(42)\n", + " subsampled_dataset, [750, 250, 250], generator=torch.Generator().manual_seed(42)\n", ")" ], "metadata": { @@ -877,7 +886,7 @@ "\n", "for (batch,) in tqdm(train_loader):\n", " loss = loss_functional(batch, torch.zeros_like(batch))\n", - " total_batch_test_loss += loss.item()\n" + " total_batch_test_loss += loss.item()" ], "metadata": { "collapsed": false, @@ -902,9 +911,7 @@ "cell_type": "code", "source": [ "# Concatenate the tensors into one tensor\n", - "full_train_dataset = torch.vstack([\n", - " train_dataset[i][0] for i in range(len(train_dataset))\n", - "])" + "full_train_dataset = torch.vstack([train_dataset[i][0] for i in range(len(train_dataset))])" ], "metadata": { "collapsed": false, @@ -922,7 +929,7 @@ "plt.figure()\n", "plt.pcolormesh(dataset.meshgrid_x, dataset.meshgrid_y, train_dataset[0][0].view(100, 100).detach(), cmap=\"PuOr\")\n", "plt.colorbar()\n", - "plt.savefig(results_path / \"solution.png\", bbox_inches=\"tight\")\n" + "plt.savefig(results_path / \"solution.png\", bbox_inches=\"tight\")" ], "metadata": { "collapsed": false, @@ -948,9 +955,9 @@ { "cell_type": "code", "source": [ - "U, S, Vh=torch.linalg.svd(full_train_dataset, full_matrices=False)\n", - "Sigma = S ** 2\n", - "relative_error = ( torch.sum(Sigma, dim=0) - torch.cumsum(Sigma, dim=0) ) / torch.sum(Sigma, dim=0)" + "U, S, Vh = torch.linalg.svd(full_train_dataset, full_matrices=False)\n", + "Sigma = S**2\n", + "relative_error = (torch.sum(Sigma, dim=0) - torch.cumsum(Sigma, dim=0)) / torch.sum(Sigma, dim=0)" ], "metadata": { "collapsed": false, @@ -967,10 +974,10 @@ "source": [ "plt.figure()\n", "plt.semilogy(S, lw=3)\n", - "plt.xlim((0,60))\n", + "plt.xlim((0, 60))\n", "plt.ylabel(\"Singular values\")\n", "plt.xlabel(\"Index\")\n", - "plt.savefig(\"../results/reaction_diffusion/singular_values.png\", bbox_inches=\"tight\")\n" + "plt.savefig(\"../results/reaction_diffusion/singular_values.png\", bbox_inches=\"tight\")" ], "metadata": { "collapsed": false, @@ -1004,7 +1011,7 @@ "\n", "plt.subplot(1, 2, 2)\n", "plt.semilogy(S, lw=3)\n", - "plt.xlim((0,60))\n", + "plt.xlim((0, 60))\n", "\n", "# ------\n", "\n", @@ -1016,7 +1023,7 @@ "\n", "plt.subplot(1, 2, 2)\n", "plt.semilogy(relative_error, lw=3)\n", - "plt.xlim((0,60))\n" + "plt.xlim((0, 60))" ], "metadata": { "collapsed": false, @@ -1066,7 +1073,7 @@ "print(\"finfo: \", torch.count_nonzero(relative_error >= torch.finfo().eps).item() + 1)\n", "print(\"ldim: \", ldim)\n", "\n", - "Vhr = Vh[:ldim,:]" + "Vhr = Vh[:ldim, :]" ], "metadata": { "collapsed": false, @@ -1096,14 +1103,14 @@ "for i, ax in enumerate(axes.flatten()):\n", " pcm = ax.pcolormesh(dataset.meshgrid_x, dataset.meshgrid_y, Vh[i, :].reshape(100, 100), cmap=\"PuOr\")\n", " ax.set_title(f\"$\\phi_{{{i}}}$\")\n", - " ax.axis('square')\n", + " ax.axis(\"square\")\n", " ax.set_xticks([]) # Hide x-axis ticks\n", " ax.set_yticks([]) # Hide y-axis ticks\n", "\n", "# # Add a big colorbar\n", "fig.subplots_adjust(right=0.8)\n", "cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])\n", - "cbar = fig.colorbar(pcm, cax=cbar_ax)\n" + "cbar = fig.colorbar(pcm, cax=cbar_ax)" ], "metadata": { "collapsed": false, diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..0a52268 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,6 @@ +[tool.black] +line-length = 130 +target-version = ['py310'] + +[tool.isort] +profile = "black" diff --git a/src/data/databases/advection_1d_database.py b/src/data/databases/advection_1d_database.py index a0ebd2e..9f10213 100644 --- a/src/data/databases/advection_1d_database.py +++ b/src/data/databases/advection_1d_database.py @@ -19,11 +19,7 @@ def __init__(self, *args, **kwargs): self.config.xmax - self.config.dx, self.config.Nx, ) - self.cfl = ( - torch.abs(self.config.pde_params.advection) - * self.config.dt - / self.config.dx - ) + self.cfl = torch.abs(self.config.pde_params.advection) * self.config.dt / self.config.dx self.direction = torch.sgn(self.config.pde_params.advection) self.a = 1 - self.direction * self.cfl**2 self.b = self.direction * self.cfl / 2 + self.cfl**2 / 2 @@ -31,9 +27,7 @@ def __init__(self, *args, **kwargs): def initialize(self): self.data = torch.zeros((self.config.Nt, self.config.Nx)) - self.data[0, :] = self.initial_condition(self.config.initial_condition_params)( - self.domain - ) + self.data[0, :] = self.initial_condition(self.config.initial_condition_params)(self.domain) def forward(self, n: int, t: float): """ @@ -57,17 +51,11 @@ def __init__(self, *args, **kwargs): self.config.xmax - self.config.dx, self.config.Nx, ) - self.shift_per_time = ( - self.config.pde_params.advection - * self.config.dt - / self.config.dx - ) + self.shift_per_time = self.config.pde_params.advection * self.config.dt / self.config.dx def initialize(self): self.data = torch.zeros((self.config.Nt, self.config.Nx)) - self.data[0, :] = self.initial_condition(self.config.initial_condition_params)( - self.domain - ) + self.data[0, :] = self.initial_condition(self.config.initial_condition_params)(self.domain) def forward(self, n: int, t: float): """ @@ -83,6 +71,4 @@ def forward(self, n: int, t: float): wave_params.mu += n * self.config.pde_params.advection * self.config.dt if n <= 10: print(wave_params) - self.data[n, :] = self.initial_condition(wave_params)( - self.domain - ) + self.data[n, :] = self.initial_condition(wave_params)(self.domain) diff --git a/src/data/databases/diffusion_1d_database.py b/src/data/databases/diffusion_1d_database.py index f4eeb1d..30e596e 100644 --- a/src/data/databases/diffusion_1d_database.py +++ b/src/data/databases/diffusion_1d_database.py @@ -19,10 +19,7 @@ def __init__(self, *args, **kwargs): self.c = self.config.pde_params.diffusion * self.dt / self.config.dx**2 self.c_tilde = 1 - 2 * self.c if self.c >= 0.5: - print( - f"Warning: c value is {self.c}, which is bigger than 0.5. " - f"This means the solution will be unstable." - ) + print(f"Warning: c value is {self.c}, which is bigger than 0.5. " f"This means the solution will be unstable.") def parameter_set(self): return Diffusion1DParamSet(diffusion=self.diffusion) @@ -33,9 +30,7 @@ def parameter_set_class(): def initialize(self): self.data = torch.zeros((self.config.Nt, self.config.Nx)) - self.data[0, :] = self.initial_condition(self.config.initial_condition_params)( - self.domain - ) + self.data[0, :] = self.initial_condition(self.config.initial_condition_params)(self.domain) def forward(self, n: int, t: float): """ @@ -45,6 +40,4 @@ def forward(self, n: int, t: float): Returns: """ - self.data[n, 1:-1] = self.c_tilde * self.data[n - 1, 1:-1] + self.c * ( - self.data[n - 1, 0:-2] + self.data[n - 1, 2:] - ) + self.data[n, 1:-1] = self.c_tilde * self.data[n - 1, 1:-1] + self.c * (self.data[n - 1, 0:-2] + self.data[n - 1, 2:]) diff --git a/src/data/databases/pde_database.py b/src/data/databases/pde_database.py index 2d496d4..62a0e78 100644 --- a/src/data/databases/pde_database.py +++ b/src/data/databases/pde_database.py @@ -42,9 +42,7 @@ def __post_init__(self): self.Nd = 1 if self.ymin is not None or self.ymax is not None or self.Ny is not None: - assert ( - self.ymin is not None and self.ymax is not None and self.Ny is not None - ) + assert self.ymin is not None and self.ymax is not None and self.Ny is not None assert self.ymin <= self.ymax assert self.Ny > 1 self.dy = (self.ymax - self.ymin) / (self.Ny - 1) @@ -52,28 +50,18 @@ def __post_init__(self): def to_json(self) -> str: return json.dumps( - asdict( - self, dict_factory=lambda x: {k: v for (k, v) in x if v is not None} - ), + asdict(self, dict_factory=lambda x: {k: v for (k, v) in x if v is not None}), cls=TorchEncoder, indent=4, ) @classmethod def from_json(cls, json_): - return cls( - **{ - k: v - for (k, v) in json.loads(json_).items() - if k not in ["Nd", "dt", "dx", "dy"] - } - ) + return cls(**{k: v for (k, v) in json.loads(json_).items() if k not in ["Nd", "dt", "dx", "dy"]}) class PdeDatabase: - def __init__( - self, initial_condition, config: Config, storage_path: str | PathLike[str] - ): + def __init__(self, initial_condition, config: Config, storage_path: str | PathLike[str]): self.initial_condition = initial_condition self._config = config self.storage_path = storage_path diff --git a/src/data/databases/reaction_diffusion_2d_database.py b/src/data/databases/reaction_diffusion_2d_database.py index d5ac4f2..b1a3031 100644 --- a/src/data/databases/reaction_diffusion_2d_database.py +++ b/src/data/databases/reaction_diffusion_2d_database.py @@ -28,16 +28,12 @@ def __init__(self, *args, **kwargs): def initialize(self): self.data = torch.zeros((self.config.Nt, 2, self.config.Nx, self.config.Ny)) - self.data[0, :] = self.initial_condition(self.config.initial_condition_params)( - self.meshgrid_x, self.meshgrid_y - ) + self.data[0, :] = self.initial_condition(self.config.initial_condition_params)(self.meshgrid_x, self.meshgrid_y) def _laplacian2d_b(self, f): lap = torch.zeros_like(f) - lap[1:-1, 1:-1] = ( - f[0:-2, 1:-1] - 2 * f[1:-1, 1:-1] + f[2:, 1:-1] - ) / self.config.dx**2 + ( + lap[1:-1, 1:-1] = (f[0:-2, 1:-1] - 2 * f[1:-1, 1:-1] + f[2:, 1:-1]) / self.config.dx**2 + ( f[1:-1, 0:-2] - 2 * f[1:-1, 1:-1] + f[1:-1, 2:] ) / self.config.dy**2 @@ -50,30 +46,24 @@ def _laplacian2d_b(self, f): return lap def _laplacian2d(self, f): - laplacian_f_y = ( - torch.diff( - f, - n=2, - dim=0, - # prepend=torch.zeros_like(f[0, :]).reshape(1, -1), - # append=torch.zeros_like(f[-1, :]).reshape(1, -1) - prepend=f[0, :].reshape(1, -1), - append=f[-1, :].reshape(1, -1), - ) - / (self.config.dy**2) - ) - laplacian_f_x = ( - torch.diff( - f, - n=2, - dim=1, - # prepend=torch.zeros_like(f[:, 0]).reshape(-1, 1), - # append=torch.zeros_like(f[:, -1]).reshape(-1, 1) - prepend=f[:, 0].reshape(-1, 1), - append=f[:, -1].reshape(-1, 1), - ) - / (self.config.dx**2) - ) + laplacian_f_y = torch.diff( + f, + n=2, + dim=0, + # prepend=torch.zeros_like(f[0, :]).reshape(1, -1), + # append=torch.zeros_like(f[-1, :]).reshape(1, -1) + prepend=f[0, :].reshape(1, -1), + append=f[-1, :].reshape(1, -1), + ) / (self.config.dy**2) + laplacian_f_x = torch.diff( + f, + n=2, + dim=1, + # prepend=torch.zeros_like(f[:, 0]).reshape(-1, 1), + # append=torch.zeros_like(f[:, -1]).reshape(-1, 1) + prepend=f[:, 0].reshape(-1, 1), + append=f[:, -1].reshape(-1, 1), + ) / (self.config.dx**2) return laplacian_f_x + laplacian_f_y def forward(self, n: int, t: float): diff --git a/src/data/generate_datasets.py b/src/data/generate_datasets.py index b259e72..2c36766 100644 --- a/src/data/generate_datasets.py +++ b/src/data/generate_datasets.py @@ -26,9 +26,7 @@ Nt=5001, T=1, pde_params=None, - initial_condition_params=GaussianParamSet( - mu=torch.tensor([0]), sigma=torch.tensor([0.02]) - ), + initial_condition_params=GaussianParamSet(mu=torch.tensor([0]), sigma=torch.tensor([0.02])), xmin=-1, xmax=1, Nx=101, @@ -38,9 +36,7 @@ Nt=200, T=1, pde_params=None, - initial_condition_params=GaussianParamSet( - mu=torch.tensor([0.2]), sigma=torch.tensor([1e-3]) - ), + initial_condition_params=GaussianParamSet(mu=torch.tensor([0.2]), sigma=torch.tensor([1e-3])), xmin=0, xmax=2, Nx=256, @@ -60,9 +56,7 @@ ) -def generate_diffusion_dataset( - parameters, base_storage_path=Path("data/diffusion"), return_stacked=False -): +def generate_diffusion_dataset(parameters, base_storage_path=Path("data/diffusion"), return_stacked=False): databases = [] for diff in parameters: config_ = BASE_DIFFUSION_CONFIG @@ -78,18 +72,14 @@ def generate_diffusion_dataset( databases.append(database) # The [0:-1:20] is there to subsample every 20th point - stacked_datasets = torch.vstack( - [database.data[0:-1:20, :] for database in databases] - ) + stacked_datasets = torch.vstack([database.data[0:-1:20, :] for database in databases]) if return_stacked: return stacked_datasets return torch.utils.data.TensorDataset(stacked_datasets) -def generate_advection_dataset( - parameters, base_storage_path=Path("data/advection"), return_stacked=False -): +def generate_advection_dataset(parameters, base_storage_path=Path("data/advection"), return_stacked=False): databases = [] for adv in parameters: config_ = BASE_ADVECTION_CONFIG @@ -134,8 +124,6 @@ def generate_reaction_diffusion_dataset( return database if u_only: - return torch.utils.data.TensorDataset( - database.data[:, 0, :, :].view(database.config.Nt, -1) - ) + return torch.utils.data.TensorDataset(database.data[:, 0, :, :].view(database.config.Nt, -1)) return torch.utils.data.TensorDataset(database.data.view(database.config.Nt, -1)) diff --git a/src/data/initial_conditions/gaussian.py b/src/data/initial_conditions/gaussian.py index 3157ecd..85ec8c2 100644 --- a/src/data/initial_conditions/gaussian.py +++ b/src/data/initial_conditions/gaussian.py @@ -46,6 +46,4 @@ def __call__(self, x): mahalanobis_sqrd = torch.dot(delta, torch.matmul(delta, self.inv_sigma)) pi_constant = (2 * torch.tensor(torch.pi)) ** self.dim - return ( - 1 / torch.sqrt(pi_constant * self.det) * torch.exp(-0.5 * mahalanobis_sqrd) - ) + return 1 / torch.sqrt(pi_constant * self.det) * torch.exp(-0.5 * mahalanobis_sqrd) diff --git a/src/models.py b/src/models.py index 32f04e1..f740c73 100644 --- a/src/models.py +++ b/src/models.py @@ -1,15 +1,11 @@ import igraph import torch -from src.utils.get_weights_linear import get_weights_linear +from src.utils import get_weights_linear class DeepAutoEncoder(torch.nn.Module): - def __init__( - self, - encoder_layers: list[int], - decoder_layers: list[int] - ): + def __init__(self, encoder_layers: list[int], decoder_layers: list[int]): super(DeepAutoEncoder, self).__init__() # latent dimension must match assert encoder_layers[-1] == decoder_layers[0] @@ -23,13 +19,9 @@ def __init__( num_of_relu_needed = len(encoder_layers) - 2 shifted_encoder_layers = encoder_layers[1:] + encoder_layers[:1] layers = [] - for k, (current_width, next_width) in enumerate( - zip(encoder_layers, shifted_encoder_layers) - ): + for k, (current_width, next_width) in enumerate(zip(encoder_layers, shifted_encoder_layers)): if k <= num_of_relu_needed - 1: - layers.extend( - [torch.nn.Linear(current_width, next_width), torch.nn.ReLU()] - ) + layers.extend([torch.nn.Linear(current_width, next_width), torch.nn.ReLU()]) elif k <= num_of_relu_needed: layers.append(torch.nn.Linear(current_width, next_width)) self.encoder = torch.nn.Sequential(*layers) @@ -38,13 +30,9 @@ def __init__( num_of_relu_needed = len(decoder_layers) - 2 shifted_decoder_layers = decoder_layers[1:] + decoder_layers[:1] layers = [] - for k, (current_width, next_width) in enumerate( - zip(decoder_layers, shifted_decoder_layers) - ): + for k, (current_width, next_width) in enumerate(zip(decoder_layers, shifted_decoder_layers)): if k <= num_of_relu_needed - 1: - layers.extend( - [torch.nn.Linear(current_width, next_width), torch.nn.ReLU()] - ) + layers.extend([torch.nn.Linear(current_width, next_width), torch.nn.ReLU()]) elif k <= num_of_relu_needed: layers.append(torch.nn.Linear(current_width, next_width)) self.decoder = torch.nn.Sequential(*layers) @@ -56,9 +44,7 @@ def fom_size(self): def latent_size(self, direction="encoder spectral"): match direction: case "encoder spectral": - return torch.count_nonzero( - torch.linalg.svdvals(self.encoder[-1].weight) - ) + return torch.count_nonzero(torch.linalg.svdvals(self.encoder[-1].weight)) case "decoder spectral": return torch.count_nonzero(torch.linalg.svdvals(self.decoder[0].weight)) case "decoder row": @@ -88,6 +74,7 @@ def _color_vertices(self, vertices: igraph.VertexSeq) -> list[str]: Returns: List with colour, where the first element corresponds to the first vertex etc. """ + def color_vertex(index: int, vertex_: igraph.Vertex): if index < self.encoder_layers[0]: return "blue" @@ -147,6 +134,6 @@ def plot(self, save_to: str | None = None) -> None: "vertex_size": 10, "edge_arrow_size": 0.1, "vertex_color": self._color_vertices(graph.vs), - "edge_color": "#778899" + "edge_color": "#778899", } return igraph.plot(graph, save_to, layout=layout, **visual_style) diff --git a/src/training/advection/train.py b/src/training/advection/train.py index 0e425f4..f01f2b4 100644 --- a/src/training/advection/train.py +++ b/src/training/advection/train.py @@ -10,7 +10,7 @@ from src.data.generate_datasets import generate_advection_dataset # noqa: E402 -TRAIN_PARAMETERS = torch.tensor([0.6, 0.9, 1.2]) +TRAIN_PARAMETERS = torch.tensor([0.6, 0.9, 1.2]) VALIDATION_PARAMETERS = torch.tensor([0.75]) TEST_PARAMETERS = torch.tensor([1.05]) @@ -28,10 +28,4 @@ test_dataset = generate_advection_dataset(TEST_PARAMETERS) validation_dataset = generate_advection_dataset(VALIDATION_PARAMETERS) - train( - train_dataset, - test_dataset, - validation_dataset, - Path("models/advection"), - config - ) + train(train_dataset, test_dataset, validation_dataset, Path("models/advection"), config) diff --git a/src/training/diffusion/train.py b/src/training/diffusion/train.py index 9268b9e..8876c4c 100644 --- a/src/training/diffusion/train.py +++ b/src/training/diffusion/train.py @@ -1,6 +1,5 @@ from pathlib import Path import sys -import torch import wandb @@ -18,8 +17,6 @@ if __name__ == "__main__": # If called by wandb.agent, as below, # this config will be set by Sweep Controller - #torch.set_default_dtype(torch.float64) - config = wandb.config print(config) @@ -28,10 +25,4 @@ test_dataset = generate_diffusion_dataset(TEST_PARAMETERS) validation_dataset = generate_diffusion_dataset(VALIDATION_PARAMETERS) - train( - train_dataset, - test_dataset, - validation_dataset, - Path("models/diffusion"), - config - ) + train(train_dataset, test_dataset, validation_dataset, Path("models/diffusion"), config) diff --git a/src/training/reaction_diffusion/train.py b/src/training/reaction_diffusion/train.py index d721403..63f1e09 100644 --- a/src/training/reaction_diffusion/train.py +++ b/src/training/reaction_diffusion/train.py @@ -24,15 +24,7 @@ subsampled_dataset = torch.utils.data.Subset(dataset, torch.arange(start=5000, end=50_000, step=36)) train_dataset, test_dataset, validation_dataset = torch.utils.data.random_split( - subsampled_dataset, - [750, 250, 250], - generator=torch.Generator().manual_seed(42) + subsampled_dataset, [750, 250, 250], generator=torch.Generator().manual_seed(42) ) - train( - train_dataset, - test_dataset, - validation_dataset, - Path("models/reaction_diffusion"), - config - ) + train(train_dataset, test_dataset, validation_dataset, Path("models/reaction_diffusion"), config) diff --git a/src/training/train.py b/src/training/train.py index a855d0d..e2da900 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -37,10 +37,7 @@ def train(train_dataset, test_dataset, validation_dataset, model_storage_folder, ) # define network to use - model = bregman.AutoEncoder( - encoder_layers=config.encoder_layers, - decoder_layers=config.decoder_layers - ) + model = bregman.AutoEncoder(encoder_layers=config.encoder_layers, decoder_layers=config.decoder_layers) # initialize weights and biases model.apply(init_linear) @@ -99,10 +96,7 @@ def train(train_dataset, test_dataset, validation_dataset, model_storage_folder, for epoch in range(1, config.epochs + 1): if epoch % 20 == 1: print(50 * "-") - print( - f"epochs {epoch} to {min(epoch + 19, config.epochs)} " - f"of the {config.epochs} total epochs" - ) + print(f"epochs {epoch} to {min(epoch + 19, config.epochs)} " f"of the {config.epochs} total epochs") print(50 * "-") best_model_updated = False @@ -163,8 +157,8 @@ def train(train_dataset, test_dataset, validation_dataset, model_storage_folder, "Latent dimension (post)": model.latent_size("latent POD", 1e-8), "Latent dimension (best)": best_model.latent_size("minimal"), "Latent dimension (best post)": best_model.latent_size("latent POD", 1e-8), - "AIC/2": model.latent_size("minimal")+math.log(train_loss, 10), - "AIC/2 (best)": best_model.latent_size("minimal")+math.log(best_training_loss, 10), + "AIC/2": model.latent_size("minimal") + math.log(train_loss, 10), + "AIC/2 (best)": best_model.latent_size("minimal") + math.log(best_training_loss, 10), } if epoch == config.epochs: pruned_model = bregman.simplify(bregman.latent_pod(best_model, 1e-8)) @@ -185,8 +179,6 @@ def train(train_dataset, test_dataset, validation_dataset, model_storage_folder, wandb.log(log_dict, step=epoch) - - # create folder for models to be stored, if it doesn't already exist model_storage_folder.mkdir(parents=True, exist_ok=True) diff --git a/src/utils/L12_nuclear.py b/src/utils/L12_nuclear.py index dbd79a9..23e1126 100644 --- a/src/utils/L12_nuclear.py +++ b/src/utils/L12_nuclear.py @@ -13,14 +13,14 @@ def L12_nuclear(model: "AutoEncoder", rc1, rc2=None): rc2 = rc1 preset = [ - {'params': get_bias(model), 'reg': Null()}, + {"params": get_bias(model), "reg": Null()}, {"params": model.encoder[-1].weight, "reg": Nuclear(rc=rc2)}, ] - for i in range(len(model.encoder_layers)-2): - preset.append({"params": model.encoder[2*i].weight, "reg": L12(rc=rc1)}) + for i in range(len(model.encoder_layers) - 2): + preset.append({"params": model.encoder[2 * i].weight, "reg": L12(rc=rc1)}) - for i in range(len(model.decoder_layers)-1): - preset.append({"params": model.decoder[2*i].weight, "reg": L12(rc=rc1)}) + for i in range(len(model.decoder_layers) - 1): + preset.append({"params": model.decoder[2 * i].weight, "reg": L12(rc=rc1)}) return preset diff --git a/src/utils/__init__.py b/src/utils/__init__.py index 80b7a07..e8ee40e 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -1,4 +1,4 @@ from .get_bias import get_bias # noqa: F401 -from .get_weights_linear import get_weights_linear # noqa: F401 +from .get_weights_linear import get_weights_linear from .L12_nuclear import L12_nuclear from .init_linear import init_linear diff --git a/src/utils/get_bias.py b/src/utils/get_bias.py index b6ca71c..8a70596 100644 --- a/src/utils/get_bias.py +++ b/src/utils/get_bias.py @@ -3,11 +3,7 @@ def get_bias(model): for m in model.modules(): - if ( - isinstance(m, torch.nn.Linear) - or isinstance(m, torch.nn.Conv2d) - or isinstance(m, torch.nn.BatchNorm2d) - ): + if isinstance(m, torch.nn.Linear) or isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.BatchNorm2d): if not (m.bias is None): yield m.bias else: