Skip to content

Commit

Permalink
Fix linting (due to shorter package name in imports)
Browse files Browse the repository at this point in the history
  • Loading branch information
fsschneider committed Jan 15, 2025
1 parent 37f556d commit bc666a7
Show file tree
Hide file tree
Showing 21 changed files with 31 additions and 78 deletions.
3 changes: 1 addition & 2 deletions algoperf/workloads/cifar/cifar_jax/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
import jax.numpy as jnp

from algoperf import spec
from algoperf.workloads.imagenet_resnet.imagenet_jax.models import \
ResNetBlock
from algoperf.workloads.imagenet_resnet.imagenet_jax.models import ResNetBlock

ModuleDef = nn.Module

Expand Down
3 changes: 1 addition & 2 deletions algoperf/workloads/cifar/cifar_jax/workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
from algoperf import param_utils
from algoperf import spec
from algoperf.workloads.cifar.cifar_jax import models
from algoperf.workloads.cifar.cifar_jax.input_pipeline import \
create_input_iter
from algoperf.workloads.cifar.cifar_jax.input_pipeline import create_input_iter
from algoperf.workloads.cifar.workload import BaseCifarWorkload


Expand Down
3 changes: 1 addition & 2 deletions algoperf/workloads/cifar/cifar_pytorch/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
BasicBlock
from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import \
Bottleneck
from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import \
conv1x1
from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import conv1x1


class ResNet(nn.Module):
Expand Down
3 changes: 1 addition & 2 deletions algoperf/workloads/cifar/cifar_pytorch/workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
from algoperf import param_utils
from algoperf import pytorch_utils
from algoperf import spec
from algoperf.workloads.cifar.cifar_pytorch.models import \
resnet18
from algoperf.workloads.cifar.cifar_pytorch.models import resnet18
from algoperf.workloads.cifar.workload import BaseCifarWorkload

USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup()
Expand Down
3 changes: 1 addition & 2 deletions algoperf/workloads/fastmri/fastmri_jax/workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
import algoperf.random_utils as prng
from algoperf.workloads.fastmri.fastmri_jax.models import UNet
from algoperf.workloads.fastmri.fastmri_jax.ssim import ssim
from algoperf.workloads.fastmri.workload import \
BaseFastMRIWorkload
from algoperf.workloads.fastmri.workload import BaseFastMRIWorkload


class FastMRIWorkload(BaseFastMRIWorkload):
Expand Down
6 changes: 2 additions & 4 deletions algoperf/workloads/fastmri/fastmri_pytorch/workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,9 @@
from algoperf import pytorch_utils
from algoperf import spec
import algoperf.random_utils as prng
from algoperf.workloads.fastmri.fastmri_pytorch.models import \
UNet
from algoperf.workloads.fastmri.fastmri_pytorch.models import UNet
from algoperf.workloads.fastmri.fastmri_pytorch.ssim import ssim
from algoperf.workloads.fastmri.workload import \
BaseFastMRIWorkload
from algoperf.workloads.fastmri.workload import BaseFastMRIWorkload

USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@

from algoperf import data_utils
from algoperf import spec
from algoperf.workloads.imagenet_resnet.imagenet_jax import \
randaugment
from algoperf.workloads.imagenet_resnet.imagenet_jax import randaugment

TFDS_SPLIT_NAME = {
'train': 'train', 'eval_train': 'train', 'validation': 'validation'
Expand Down
6 changes: 2 additions & 4 deletions algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,8 @@
from algoperf import random_utils as prng
from algoperf import spec
from algoperf.workloads.imagenet_resnet import imagenet_v2
from algoperf.workloads.imagenet_resnet.imagenet_jax import \
input_pipeline
from algoperf.workloads.imagenet_resnet.imagenet_jax import \
models
from algoperf.workloads.imagenet_resnet.imagenet_jax import input_pipeline
from algoperf.workloads.imagenet_resnet.imagenet_jax import models
from algoperf.workloads.imagenet_resnet.workload import \
BaseImagenetResNetWorkload

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@
from algoperf import spec
import algoperf.random_utils as prng
from algoperf.workloads.imagenet_resnet import imagenet_v2
from algoperf.workloads.imagenet_resnet.imagenet_pytorch import \
randaugment
from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import \
resnet50
from algoperf.workloads.imagenet_resnet.imagenet_pytorch import randaugment
from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import resnet50
from algoperf.workloads.imagenet_resnet.workload import \
BaseImagenetResNetWorkload

Expand Down
3 changes: 1 addition & 2 deletions algoperf/workloads/imagenet_resnet/imagenet_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@

from algoperf import data_utils
from algoperf import spec
from algoperf.workloads.imagenet_resnet.imagenet_jax import \
input_pipeline
from algoperf.workloads.imagenet_resnet.imagenet_jax import input_pipeline


def get_imagenet_v2_iter(data_dir: str,
Expand Down
6 changes: 2 additions & 4 deletions algoperf/workloads/imagenet_vit/imagenet_jax/workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,8 @@
from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import \
ImagenetResNetWorkload
from algoperf.workloads.imagenet_vit.imagenet_jax import models
from algoperf.workloads.imagenet_vit.workload import \
BaseImagenetVitWorkload
from algoperf.workloads.imagenet_vit.workload import \
decode_variant
from algoperf.workloads.imagenet_vit.workload import BaseImagenetVitWorkload
from algoperf.workloads.imagenet_vit.workload import decode_variant


# Make sure we inherit from the ViT base workload first.
Expand Down
3 changes: 1 addition & 2 deletions algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@

from algoperf import init_utils
from algoperf import spec
from algoperf.workloads.wmt.wmt_pytorch.models import \
MultiheadAttention
from algoperf.workloads.wmt.wmt_pytorch.models import MultiheadAttention


def posemb_sincos_2d(patches: spec.Tensor, temperature=10_000.) -> spec.Tensor:
Expand Down
9 changes: 3 additions & 6 deletions algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,9 @@
from algoperf import spec
from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import \
ImagenetResNetWorkload
from algoperf.workloads.imagenet_vit.imagenet_pytorch import \
models
from algoperf.workloads.imagenet_vit.workload import \
BaseImagenetVitWorkload
from algoperf.workloads.imagenet_vit.workload import \
decode_variant
from algoperf.workloads.imagenet_vit.imagenet_pytorch import models
from algoperf.workloads.imagenet_vit.workload import BaseImagenetVitWorkload
from algoperf.workloads.imagenet_vit.workload import decode_variant

USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@
from algoperf.workloads.librispeech_conformer import workload
from algoperf.workloads.librispeech_conformer.input_pipeline import \
LibriSpeechDataset
from algoperf.workloads.librispeech_conformer.librispeech_jax import \
models
from algoperf.workloads.librispeech_conformer.librispeech_jax import models


class LibriSpeechConformerWorkload(workload.BaseLibrispeechWorkload):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@
from algoperf.workloads.librispeech_conformer import workload
from algoperf.workloads.librispeech_conformer.input_pipeline import \
LibriSpeechDataset
from algoperf.workloads.librispeech_conformer.librispeech_pytorch import \
models
from algoperf.workloads.librispeech_conformer.librispeech_pytorch import models

USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
from algoperf import spec
from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \
LibriSpeechConformerWorkload
from algoperf.workloads.librispeech_deepspeech.librispeech_jax import \
models
from algoperf.workloads.librispeech_deepspeech.librispeech_jax import models


class LibriSpeechDeepSpeechWorkload(LibriSpeechConformerWorkload):
Expand Down
3 changes: 1 addition & 2 deletions tests/modeldiffs/wmt/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
import torch

from algoperf import spec
from algoperf.workloads.wmt.wmt_jax.workload import \
WmtWorkload as JaxWorkload
from algoperf.workloads.wmt.wmt_jax.workload import WmtWorkload as JaxWorkload
from algoperf.workloads.wmt.wmt_pytorch.workload import \
WmtWorkload as PyTorchWorkload
from tests.modeldiffs.diff import out_diff
Expand Down
6 changes: 2 additions & 4 deletions tests/reference_algorithm_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,8 @@
from algoperf import random_utils as prng
from algoperf.profiler import PassThroughProfiler
from algoperf.workloads import workloads
from algoperf.workloads.ogbg import \
input_pipeline as ogbg_input_pipeline
from algoperf.workloads.ogbg.ogbg_pytorch.workload import \
_graph_map
from algoperf.workloads.ogbg import input_pipeline as ogbg_input_pipeline
from algoperf.workloads.ogbg.ogbg_pytorch.workload import _graph_map
import submission_runner
from tests.modeldiffs import diff as diff_utils

Expand Down
15 changes: 5 additions & 10 deletions tests/test_num_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
resnet18 as PyTorchResNet_c10
from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import \
resnet50 as PyTorchResNet
from algoperf.workloads.imagenet_vit.imagenet_jax.models import \
ViT as JaxViT
from algoperf.workloads.imagenet_vit.imagenet_jax.models import ViT as JaxViT
from algoperf.workloads.imagenet_vit.imagenet_pytorch.models import \
ViT as PyTorchViT
from algoperf.workloads.librispeech_conformer.librispeech_jax.models import \
Expand All @@ -29,17 +28,13 @@
ConformerConfig as PytorchConformerConfig
from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \
ConformerEncoderDecoder as PytorchConformer
from algoperf.workloads.mnist.mnist_jax.workload import \
_Model as JaxMLP
from algoperf.workloads.mnist.mnist_jax.workload import _Model as JaxMLP
from algoperf.workloads.mnist.mnist_pytorch.workload import \
_Model as PyTorchMLP
from algoperf.workloads.ogbg.ogbg_jax.models import GNN as JaxGNN
from algoperf.workloads.ogbg.ogbg_pytorch.models import \
GNN as PyTorchGNN
from algoperf.workloads.wmt.wmt_jax.models import \
Transformer as JaxTransformer
from algoperf.workloads.wmt.wmt_jax.models import \
TransformerConfig
from algoperf.workloads.ogbg.ogbg_pytorch.models import GNN as PyTorchGNN
from algoperf.workloads.wmt.wmt_jax.models import Transformer as JaxTransformer
from algoperf.workloads.wmt.wmt_jax.models import TransformerConfig
from algoperf.workloads.wmt.wmt_pytorch.models import \
Transformer as PyTorchTransformer

Expand Down
3 changes: 1 addition & 2 deletions tests/test_ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
from algoperf.pytorch_utils import pytorch_setup
from algoperf.workloads.fastmri.fastmri_jax.ssim import \
_uniform_filter as _jax_uniform_filter
from algoperf.workloads.fastmri.fastmri_jax.ssim import \
ssim as jax_ssim
from algoperf.workloads.fastmri.fastmri_jax.ssim import ssim as jax_ssim
from algoperf.workloads.fastmri.fastmri_pytorch.ssim import \
_uniform_filter as _pytorch_uniform_filter
from algoperf.workloads.fastmri.fastmri_pytorch.ssim import \
Expand Down
16 changes: 0 additions & 16 deletions tests/version_test.py

This file was deleted.

0 comments on commit bc666a7

Please sign in to comment.