diff --git a/algoperf/workloads/cifar/cifar_jax/models.py b/algoperf/workloads/cifar/cifar_jax/models.py index 4d5df766e..957079272 100644 --- a/algoperf/workloads/cifar/cifar_jax/models.py +++ b/algoperf/workloads/cifar/cifar_jax/models.py @@ -11,8 +11,7 @@ import jax.numpy as jnp from algoperf import spec -from algoperf.workloads.imagenet_resnet.imagenet_jax.models import \ - ResNetBlock +from algoperf.workloads.imagenet_resnet.imagenet_jax.models import ResNetBlock ModuleDef = nn.Module diff --git a/algoperf/workloads/cifar/cifar_jax/workload.py b/algoperf/workloads/cifar/cifar_jax/workload.py index f4bcffbc3..952bb977d 100644 --- a/algoperf/workloads/cifar/cifar_jax/workload.py +++ b/algoperf/workloads/cifar/cifar_jax/workload.py @@ -14,8 +14,7 @@ from algoperf import param_utils from algoperf import spec from algoperf.workloads.cifar.cifar_jax import models -from algoperf.workloads.cifar.cifar_jax.input_pipeline import \ - create_input_iter +from algoperf.workloads.cifar.cifar_jax.input_pipeline import create_input_iter from algoperf.workloads.cifar.workload import BaseCifarWorkload diff --git a/algoperf/workloads/cifar/cifar_pytorch/models.py b/algoperf/workloads/cifar/cifar_pytorch/models.py index 393d568b9..e6a7a8a81 100644 --- a/algoperf/workloads/cifar/cifar_pytorch/models.py +++ b/algoperf/workloads/cifar/cifar_pytorch/models.py @@ -16,8 +16,7 @@ BasicBlock from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import \ Bottleneck -from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import \ - conv1x1 +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import conv1x1 class ResNet(nn.Module): diff --git a/algoperf/workloads/cifar/cifar_pytorch/workload.py b/algoperf/workloads/cifar/cifar_pytorch/workload.py index 2ba92f0b9..b16d62204 100644 --- a/algoperf/workloads/cifar/cifar_pytorch/workload.py +++ b/algoperf/workloads/cifar/cifar_pytorch/workload.py @@ -16,8 +16,7 @@ from algoperf import param_utils from algoperf import pytorch_utils from algoperf import spec -from algoperf.workloads.cifar.cifar_pytorch.models import \ - resnet18 +from algoperf.workloads.cifar.cifar_pytorch.models import resnet18 from algoperf.workloads.cifar.workload import BaseCifarWorkload USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() diff --git a/algoperf/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py index 393aa19d7..1156cf30a 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/workload.py +++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py @@ -13,8 +13,7 @@ import algoperf.random_utils as prng from algoperf.workloads.fastmri.fastmri_jax.models import UNet from algoperf.workloads.fastmri.fastmri_jax.ssim import ssim -from algoperf.workloads.fastmri.workload import \ - BaseFastMRIWorkload +from algoperf.workloads.fastmri.workload import BaseFastMRIWorkload class FastMRIWorkload(BaseFastMRIWorkload): diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py index f40654678..58943de2f 100644 --- a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py +++ b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py @@ -13,11 +13,9 @@ from algoperf import pytorch_utils from algoperf import spec import algoperf.random_utils as prng -from algoperf.workloads.fastmri.fastmri_pytorch.models import \ - UNet +from algoperf.workloads.fastmri.fastmri_pytorch.models import UNet from algoperf.workloads.fastmri.fastmri_pytorch.ssim import ssim -from algoperf.workloads.fastmri.workload import \ - BaseFastMRIWorkload +from algoperf.workloads.fastmri.workload import BaseFastMRIWorkload USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py index 709a318c2..66105335b 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py @@ -14,8 +14,7 @@ from algoperf import data_utils from algoperf import spec -from algoperf.workloads.imagenet_resnet.imagenet_jax import \ - randaugment +from algoperf.workloads.imagenet_resnet.imagenet_jax import randaugment TFDS_SPLIT_NAME = { 'train': 'train', 'eval_train': 'train', 'validation': 'validation' diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py index b445e9f00..9494fd63c 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -21,10 +21,8 @@ from algoperf import random_utils as prng from algoperf import spec from algoperf.workloads.imagenet_resnet import imagenet_v2 -from algoperf.workloads.imagenet_resnet.imagenet_jax import \ - input_pipeline -from algoperf.workloads.imagenet_resnet.imagenet_jax import \ - models +from algoperf.workloads.imagenet_resnet.imagenet_jax import input_pipeline +from algoperf.workloads.imagenet_resnet.imagenet_jax import models from algoperf.workloads.imagenet_resnet.workload import \ BaseImagenetResNetWorkload diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py index 7a08f325e..92b651ba2 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -22,10 +22,8 @@ from algoperf import spec import algoperf.random_utils as prng from algoperf.workloads.imagenet_resnet import imagenet_v2 -from algoperf.workloads.imagenet_resnet.imagenet_pytorch import \ - randaugment -from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import \ - resnet50 +from algoperf.workloads.imagenet_resnet.imagenet_pytorch import randaugment +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import resnet50 from algoperf.workloads.imagenet_resnet.workload import \ BaseImagenetResNetWorkload diff --git a/algoperf/workloads/imagenet_resnet/imagenet_v2.py b/algoperf/workloads/imagenet_resnet/imagenet_v2.py index f63ddbc34..84d364586 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_v2.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_v2.py @@ -10,8 +10,7 @@ from algoperf import data_utils from algoperf import spec -from algoperf.workloads.imagenet_resnet.imagenet_jax import \ - input_pipeline +from algoperf.workloads.imagenet_resnet.imagenet_jax import input_pipeline def get_imagenet_v2_iter(data_dir: str, diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py index 2261aac6d..9a6190f5e 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py @@ -12,10 +12,8 @@ from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import \ ImagenetResNetWorkload from algoperf.workloads.imagenet_vit.imagenet_jax import models -from algoperf.workloads.imagenet_vit.workload import \ - BaseImagenetVitWorkload -from algoperf.workloads.imagenet_vit.workload import \ - decode_variant +from algoperf.workloads.imagenet_vit.workload import BaseImagenetVitWorkload +from algoperf.workloads.imagenet_vit.workload import decode_variant # Make sure we inherit from the ViT base workload first. diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py index 4fac8bd35..fcf0992d3 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py @@ -14,8 +14,7 @@ from algoperf import init_utils from algoperf import spec -from algoperf.workloads.wmt.wmt_pytorch.models import \ - MultiheadAttention +from algoperf.workloads.wmt.wmt_pytorch.models import MultiheadAttention def posemb_sincos_2d(patches: spec.Tensor, temperature=10_000.) -> spec.Tensor: diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py index 20b294b47..97bb38515 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py @@ -11,12 +11,9 @@ from algoperf import spec from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import \ ImagenetResNetWorkload -from algoperf.workloads.imagenet_vit.imagenet_pytorch import \ - models -from algoperf.workloads.imagenet_vit.workload import \ - BaseImagenetVitWorkload -from algoperf.workloads.imagenet_vit.workload import \ - decode_variant +from algoperf.workloads.imagenet_vit.imagenet_pytorch import models +from algoperf.workloads.imagenet_vit.workload import BaseImagenetVitWorkload +from algoperf.workloads.imagenet_vit.workload import decode_variant USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py index b4fdb0811..8d9872461 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -18,8 +18,7 @@ from algoperf.workloads.librispeech_conformer import workload from algoperf.workloads.librispeech_conformer.input_pipeline import \ LibriSpeechDataset -from algoperf.workloads.librispeech_conformer.librispeech_jax import \ - models +from algoperf.workloads.librispeech_conformer.librispeech_jax import models class LibriSpeechConformerWorkload(workload.BaseLibrispeechWorkload): diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py index 592e63989..974b3bb19 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -19,8 +19,7 @@ from algoperf.workloads.librispeech_conformer import workload from algoperf.workloads.librispeech_conformer.input_pipeline import \ LibriSpeechDataset -from algoperf.workloads.librispeech_conformer.librispeech_pytorch import \ - models +from algoperf.workloads.librispeech_conformer.librispeech_pytorch import models USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 3e0781deb..9fd0898b4 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -10,8 +10,7 @@ from algoperf import spec from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \ LibriSpeechConformerWorkload -from algoperf.workloads.librispeech_deepspeech.librispeech_jax import \ - models +from algoperf.workloads.librispeech_deepspeech.librispeech_jax import models class LibriSpeechDeepSpeechWorkload(LibriSpeechConformerWorkload): diff --git a/tests/modeldiffs/wmt/compare.py b/tests/modeldiffs/wmt/compare.py index 73bc03f78..64401ef7f 100644 --- a/tests/modeldiffs/wmt/compare.py +++ b/tests/modeldiffs/wmt/compare.py @@ -7,8 +7,7 @@ import torch from algoperf import spec -from algoperf.workloads.wmt.wmt_jax.workload import \ - WmtWorkload as JaxWorkload +from algoperf.workloads.wmt.wmt_jax.workload import WmtWorkload as JaxWorkload from algoperf.workloads.wmt.wmt_pytorch.workload import \ WmtWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/reference_algorithm_tests.py b/tests/reference_algorithm_tests.py index 3f279a605..0a17e470c 100644 --- a/tests/reference_algorithm_tests.py +++ b/tests/reference_algorithm_tests.py @@ -45,10 +45,8 @@ from algoperf import random_utils as prng from algoperf.profiler import PassThroughProfiler from algoperf.workloads import workloads -from algoperf.workloads.ogbg import \ - input_pipeline as ogbg_input_pipeline -from algoperf.workloads.ogbg.ogbg_pytorch.workload import \ - _graph_map +from algoperf.workloads.ogbg import input_pipeline as ogbg_input_pipeline +from algoperf.workloads.ogbg.ogbg_pytorch.workload import _graph_map import submission_runner from tests.modeldiffs import diff as diff_utils diff --git a/tests/test_num_params.py b/tests/test_num_params.py index 83a23c9a4..b0633025e 100644 --- a/tests/test_num_params.py +++ b/tests/test_num_params.py @@ -17,8 +17,7 @@ resnet18 as PyTorchResNet_c10 from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import \ resnet50 as PyTorchResNet -from algoperf.workloads.imagenet_vit.imagenet_jax.models import \ - ViT as JaxViT +from algoperf.workloads.imagenet_vit.imagenet_jax.models import ViT as JaxViT from algoperf.workloads.imagenet_vit.imagenet_pytorch.models import \ ViT as PyTorchViT from algoperf.workloads.librispeech_conformer.librispeech_jax.models import \ @@ -29,17 +28,13 @@ ConformerConfig as PytorchConformerConfig from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \ ConformerEncoderDecoder as PytorchConformer -from algoperf.workloads.mnist.mnist_jax.workload import \ - _Model as JaxMLP +from algoperf.workloads.mnist.mnist_jax.workload import _Model as JaxMLP from algoperf.workloads.mnist.mnist_pytorch.workload import \ _Model as PyTorchMLP from algoperf.workloads.ogbg.ogbg_jax.models import GNN as JaxGNN -from algoperf.workloads.ogbg.ogbg_pytorch.models import \ - GNN as PyTorchGNN -from algoperf.workloads.wmt.wmt_jax.models import \ - Transformer as JaxTransformer -from algoperf.workloads.wmt.wmt_jax.models import \ - TransformerConfig +from algoperf.workloads.ogbg.ogbg_pytorch.models import GNN as PyTorchGNN +from algoperf.workloads.wmt.wmt_jax.models import Transformer as JaxTransformer +from algoperf.workloads.wmt.wmt_jax.models import TransformerConfig from algoperf.workloads.wmt.wmt_pytorch.models import \ Transformer as PyTorchTransformer diff --git a/tests/test_ssim.py b/tests/test_ssim.py index ba0b2ca7f..920556964 100644 --- a/tests/test_ssim.py +++ b/tests/test_ssim.py @@ -12,8 +12,7 @@ from algoperf.pytorch_utils import pytorch_setup from algoperf.workloads.fastmri.fastmri_jax.ssim import \ _uniform_filter as _jax_uniform_filter -from algoperf.workloads.fastmri.fastmri_jax.ssim import \ - ssim as jax_ssim +from algoperf.workloads.fastmri.fastmri_jax.ssim import ssim as jax_ssim from algoperf.workloads.fastmri.fastmri_pytorch.ssim import \ _uniform_filter as _pytorch_uniform_filter from algoperf.workloads.fastmri.fastmri_pytorch.ssim import \ diff --git a/tests/version_test.py b/tests/version_test.py deleted file mode 100644 index 2205b305f..000000000 --- a/tests/version_test.py +++ /dev/null @@ -1,16 +0,0 @@ -"""Check whether the __version__ attribute is set correctly.""" - -import algoperf - - -def test_version_attribute(): - """Check whether __version__ exists and is a valid string.""" - - assert hasattr(algoperf, "__version__") - version = algoperf.__version__ - assert isinstance(version, str) - version_elements = version.split(".") - print(version_elements) - # Only check the first three elements, i.e. major, minor, patch. - # The remaining elements contain commit hash and dirty status. - assert all(el.isnumeric() for el in version_elements[0:3])