Skip to content

Commit

Permalink
Applied linting
Browse files Browse the repository at this point in the history
  • Loading branch information
Heeringa committed Aug 13, 2024
1 parent 89de684 commit 4ffa9b8
Show file tree
Hide file tree
Showing 21 changed files with 265 additions and 316 deletions.
4 changes: 2 additions & 2 deletions .flake8
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[flake8]
max-line-length=88
extend-ignore=E230,W503,W504
max-line-length=130
extend-ignore=F401
12 changes: 0 additions & 12 deletions .github/workflows/black.yml

This file was deleted.

36 changes: 36 additions & 0 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
name: lint
on:
- push
- pull_request

jobs:
isort:
runs-on: ubuntu-latest
steps:
- name: Check out source repository
uses: actions/checkout@v3
- name: Run isort
uses: isort/isort-action@v1
black:
runs-on: ubuntu-latest
steps:
- name: Check out source repository
uses: actions/checkout@v4
- name: Run black
uses: psf/black@stable
with:
jupyter: true
flake8:
runs-on: ubuntu-latest
name: Lint
steps:
- name: Check out source repository
uses: actions/checkout@v3
- name: Set up Python environment
uses: actions/setup-python@v4
with:
python-version: "3.10"
- name: flake8 Lint
uses: py-actions/flake8@v2
with:
plugins: "flake8-nb flake8-black"
63 changes: 36 additions & 27 deletions notebooks/advection.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
},
"source": [
"import sys\n",
"sys.path.insert(0, '../')"
"\n",
"sys.path.insert(0, \"../\")"
],
"outputs": [],
"execution_count": 22
Expand Down Expand Up @@ -53,7 +54,7 @@
"from tqdm import tqdm\n",
"\n",
"from src.data.generate_datasets import BASE_ADVECTION_CONFIG as CONFIG, generate_advection_dataset\n",
"from src.training.advection.train import TRAIN_PARAMETERS, TEST_PARAMETERS\n"
"from src.training.advection.train import TRAIN_PARAMETERS, TEST_PARAMETERS"
],
"metadata": {
"collapsed": false,
Expand All @@ -77,7 +78,7 @@
"cell_type": "code",
"source": [
"base_storage_path = Path(\"../data/advection\")\n",
"results_path = Path(\"../results/advection\")\n"
"results_path = Path(\"../results/advection\")"
],
"metadata": {
"collapsed": false,
Expand Down Expand Up @@ -124,7 +125,7 @@
"stacked_test_datasets = generate_advection_dataset(TEST_PARAMETERS, base_storage_path=base_storage_path, return_stacked=True)\n",
"\n",
"train_dataset = torch.utils.data.TensorDataset(stacked_train_datasets)\n",
"test_dataset = torch.utils.data.TensorDataset(stacked_test_datasets)\n"
"test_dataset = torch.utils.data.TensorDataset(stacked_test_datasets)"
],
"metadata": {
"collapsed": false,
Expand Down Expand Up @@ -269,7 +270,7 @@
"TTx, XX = torch.meshgrid(\n",
" [time, domain],\n",
" indexing=\"ij\",\n",
")\n"
")"
],
"metadata": {
"collapsed": false,
Expand Down Expand Up @@ -299,10 +300,9 @@
},
"cell_type": "code",
"source": [
"\n",
"fig, axs = plt.subplots(1, 3)\n",
"for i in range(len(TRAIN_PARAMETERS)):\n",
" axs[i].pcolormesh(TTx, XX, stacked_train_datasets[i*CONFIG.Nt:(i+1)*CONFIG.Nt].data.detach())\n"
" axs[i].pcolormesh(TTx, XX, stacked_train_datasets[i * CONFIG.Nt : (i + 1) * CONFIG.Nt].data.detach())"
],
"outputs": [
{
Expand Down Expand Up @@ -330,31 +330,31 @@
{
"cell_type": "code",
"source": [
"snapshot_matrix = stacked_train_datasets[:CONFIG.Nt].data\n",
"snapshot_matrix = stacked_train_datasets[: CONFIG.Nt].data\n",
"\n",
"plt.figure()\n",
"# plt.pcolormesh(TTx, XX, snapshot_matrix.detach(), cmap=mpl.colormaps[\"magma\"])\n",
"plt.pcolormesh(TTx, XX, snapshot_matrix.detach())\n",
"plt.colorbar()\n",
"plt.rc('xtick', labelsize=12) # fontsize of the tick labels\n",
"plt.rc('ytick', labelsize=12) # fontsize of the tick labels\n",
"plt.rc(\"xtick\", labelsize=12) # fontsize of the tick labels\n",
"plt.rc(\"ytick\", labelsize=12) # fontsize of the tick labels\n",
"plt.ylabel(\"Space\", fontsize=16)\n",
"plt.xlabel(\"Time\", fontsize=16)\n",
"# plt.savefig(results_path / \"advection_snapshot.png\", bbox_inches='tight')\n",
"\n",
"plt.figure()\n",
"plt.semilogy(torch.linalg.svdvals(snapshot_matrix), lw=5)\n",
"# plt.semilogy(torch.linalg.svdvals(snapshot_matrix), lw=5, c=\"magenta\")\n",
"plt.rc('xtick', labelsize=14) # fontsize of the tick labels\n",
"plt.rc('ytick', labelsize=14) # fontsize of the tick labels\n",
"plt.rc(\"xtick\", labelsize=14) # fontsize of the tick labels\n",
"plt.rc(\"ytick\", labelsize=14) # fontsize of the tick labels\n",
"plt.ylabel(\"Singular value\", fontsize=16)\n",
"plt.xlabel(\"Index\", fontsize=16)\n",
"# plt.savefig(results_path / \"advection_singular.png\", bbox_inches='tight')\n",
"\n",
"\n",
"plt.figure()\n",
"for k in [0, 10, 100, -1]:\n",
" plt.plot(domain, snapshot_matrix[k])\n"
" plt.plot(domain, snapshot_matrix[k])"
],
"metadata": {
"collapsed": false,
Expand Down Expand Up @@ -416,13 +416,15 @@
"cell_type": "code",
"source": [
"fig = plt.figure()\n",
"line, = plt.plot(domain, snapshot_matrix.data[0,:].detach(), lw=4)\n",
"(line,) = plt.plot(domain, snapshot_matrix.data[0, :].detach(), lw=4)\n",
"\n",
"\n",
"def animate(k):\n",
" line.set_ydata(snapshot_matrix.data[k,:].detach())\n",
" return line,\n",
" line.set_ydata(snapshot_matrix.data[k, :].detach())\n",
" return (line,)\n",
"\n",
"\n",
"# ani = FuncAnimation(fig, animate, interval=40, blit=True, repeat=True, frames=CONFIG.Nt) \n",
"# ani = FuncAnimation(fig, animate, interval=40, blit=True, repeat=True, frames=CONFIG.Nt)\n",
"# ani.save(results_path / \"advection.gif\", dpi=300, writer=PillowWriter(fps=25))"
],
"outputs": [
Expand Down Expand Up @@ -479,7 +481,7 @@
"plt.figure()\n",
"plt.semilogy(relative_error, lw=7)\n",
"plt.ylabel(\"Singular values\")\n",
"plt.xlabel(\"Index\")\n"
"plt.xlabel(\"Index\")"
],
"metadata": {
"collapsed": false,
Expand Down Expand Up @@ -539,9 +541,9 @@
"\n",
"plt.figure()\n",
"plt.semilogy(relative_error, lw=3)\n",
"plt.xlim((0,50))\n",
"plt.xlim((0, 50))\n",
"plt.ylabel(\"Relative error\")\n",
"plt.xlabel(\"index\")\n"
"plt.xlabel(\"index\")"
],
"outputs": [
{
Expand Down Expand Up @@ -616,16 +618,18 @@
"cell_type": "code",
"outputs": [],
"execution_count": 37,
"source": "Vhr = Vh[:ldim,:]"
"source": [
"Vhr = Vh[:ldim, :]"
]
},
{
"cell_type": "code",
"source": [
"plt.figure()\n",
"plt.plot(domain, Vh[1,:], label=\"mode 1\")\n",
"plt.plot(domain, Vh[2,:], label=\"mode 2\")\n",
"plt.plot(domain, Vh[3,:], label=\"mode 3\")\n",
"plt.plot(domain, Vh[4,:], label=\"mode 4\")\n"
"plt.plot(domain, Vh[1, :], label=\"mode 1\")\n",
"plt.plot(domain, Vh[2, :], label=\"mode 2\")\n",
"plt.plot(domain, Vh[3, :], label=\"mode 3\")\n",
"plt.plot(domain, Vh[4, :], label=\"mode 4\")"
],
"metadata": {
"collapsed": false,
Expand Down Expand Up @@ -671,7 +675,12 @@
}
},
"cell_type": "code",
"source": "print(torch.norm(stacked_train_datasets - torch.einsum(\"nm,bn->bm\", Vhr.T @ Vhr, stacked_train_datasets))**2 / torch.norm(stacked_train_datasets)**2)",
"source": [
"print(\n",
" torch.norm(stacked_train_datasets - torch.einsum(\"nm,bn->bm\", Vhr.T @ Vhr, stacked_train_datasets)) ** 2\n",
" / torch.norm(stacked_train_datasets) ** 2\n",
")"
],
"outputs": [
{
"name": "stdout",
Expand Down Expand Up @@ -737,7 +746,7 @@
"source": [
"plt.figure()\n",
"plt.plot(snapshot_matrix[-1].detach(), lw=5)\n",
"plt.plot((Vhr.T @ Vhr @ snapshot_matrix[-1]).detach())\n"
"plt.plot((Vhr.T @ Vhr @ snapshot_matrix[-1]).detach())"
],
"metadata": {
"collapsed": false,
Expand Down
Loading

0 comments on commit 4ffa9b8

Please sign in to comment.