Skip to content

Commit

Permalink
Applied Isort
Browse files Browse the repository at this point in the history
  • Loading branch information
Heeringa committed Aug 13, 2024
1 parent 4ffa9b8 commit ab97367
Show file tree
Hide file tree
Showing 8 changed files with 21 additions and 30 deletions.
7 changes: 3 additions & 4 deletions src/data/databases/pde_database.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from typing import Any
from dataclasses import dataclass, asdict, field
from pathlib import Path
import json

from dataclasses import asdict, dataclass, field
from os import PathLike
from pathlib import Path
from typing import Any

import torch

Expand Down
18 changes: 7 additions & 11 deletions src/data/generate_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,21 @@

import torch

from src.data.databases.diffusion_1d_database import (
Diffusion1DDatabase,
Diffusion1DParamSet,
)
from src.data.databases.advection_1d_database import (
Advection1DAnalyticDatabase,
Advection1DParamSet,
)
from src.data.databases.diffusion_1d_database import (
Diffusion1DDatabase,
Diffusion1DParamSet,
)
from src.data.databases.pde_database import Config
from src.data.databases.reaction_diffusion_2d_database import (
ReactionDiffusion2DParamSet,
ReactionDiffusion2DDatabase,
ReactionDiffusion2DParamSet,
)
from src.data.initial_conditions.gaussian import (
Gaussian,
GaussianParamSet,
)
from src.data.initial_conditions.spiral import SpiralParamSet, Spiral

from src.data.initial_conditions.gaussian import Gaussian, GaussianParamSet
from src.data.initial_conditions.spiral import Spiral, SpiralParamSet

BASE_DIFFUSION_CONFIG = Config(
Nt=5001,
Expand Down
5 changes: 2 additions & 3 deletions src/training/advection/train.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from pathlib import Path
import sys
from pathlib import Path

import torch
import wandb

sys.path.insert(0, str(Path.cwd()))

from src.training.train import train # noqa: E402
from src.data.generate_datasets import generate_advection_dataset # noqa: E402

from src.training.train import train # noqa: E402

TRAIN_PARAMETERS = torch.tensor([0.6, 0.9, 1.2])
VALIDATION_PARAMETERS = torch.tensor([0.75])
Expand Down
5 changes: 2 additions & 3 deletions src/training/diffusion/train.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from pathlib import Path
import sys
from pathlib import Path

import wandb

sys.path.insert(0, str(Path.cwd()))

from src.training.train import train # noqa: E402
from src.data.generate_datasets import generate_diffusion_dataset # noqa: E402

from src.training.train import train # noqa: E402

TRAIN_PARAMETERS = [0.1, 0.5, 1]
VALIDATION_PARAMETERS = [0.8]
Expand Down
5 changes: 2 additions & 3 deletions src/training/reaction_diffusion/train.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from pathlib import Path
import sys
from pathlib import Path

import torch.utils.data
import wandb

sys.path.insert(0, str(Path.cwd()))

from src.training.train import train # noqa: E402
from src.data.generate_datasets import generate_reaction_diffusion_dataset # noqa: E402

from src.training.train import train # noqa: E402

if __name__ == "__main__":
# If called by wandb.agent, as below,
Expand Down
7 changes: 3 additions & 4 deletions src/training/train.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import copy
import math
from pathlib import Path
import sys
from pathlib import Path

import wandb
import bregman
import torch
import wandb
from tqdm import tqdm

import bregman

sys.path.insert(0, str(Path.cwd()))

from src.utils import L12_nuclear, init_linear # noqa: E402
Expand Down
2 changes: 1 addition & 1 deletion src/utils/L12_nuclear.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import typing

from bregman import Null, L12, Nuclear
from bregman import L12, Nuclear, Null

from .get_bias import get_bias

Expand Down
2 changes: 1 addition & 1 deletion src/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .get_bias import get_bias # noqa: F401
from .get_weights_linear import get_weights_linear
from .L12_nuclear import L12_nuclear
from .init_linear import init_linear
from .L12_nuclear import L12_nuclear

0 comments on commit ab97367

Please sign in to comment.