diff --git a/.github/workflows/linting.yml b/.github/workflows/linting.yml
index f9e2f15a8..607a3dd79 100644
--- a/.github/workflows/linting.yml
+++ b/.github/workflows/linting.yml
@@ -14,7 +14,7 @@ jobs:
     - name: Install pylint
       run: |
         python -m pip install --upgrade pip
-        pip install pylint
+        pip install pylint==2.16.1
     - name: Run pylint
       run: |
         pylint algorithmic_efficiency
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 11c31a4d6..cc8f13d25 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -9,6 +9,6 @@ repos:
     hooks:
     - id: isort
   - repo: https://github.com/pycqa/pylint
-    rev: v2.14.5
+    rev: v2.16.1
     hooks:
     - id: pylint
diff --git a/algorithmic_efficiency/checkpoint_utils.py b/algorithmic_efficiency/checkpoint_utils.py
index f35bceb3e..fb7449b99 100644
--- a/algorithmic_efficiency/checkpoint_utils.py
+++ b/algorithmic_efficiency/checkpoint_utils.py
@@ -68,14 +68,15 @@ def maybe_restore_checkpoint(framework: str,
   uninitialized_global_step = -1
   uninitialized_preemption_count = -1
-  checkpoint_state = dict(
-      model_params=model_params,
-      optimizer_state=opt_state,
-      model_state=model_state,
-      train_state=train_state,
-      eval_results=None,
-      global_step=uninitialized_global_step,
-      preemption_count=uninitialized_preemption_count)
+  checkpoint_state = {
+      'model_params': model_params,
+      'optimizer_state': opt_state,
+      'model_state': model_state,
+      'train_state': train_state,
+      'eval_results': None,
+      'global_step': uninitialized_global_step,
+      'preemption_count': uninitialized_preemption_count,
+  }

   if framework == 'jax':
     latest_ckpt = flax_checkpoints.restore_checkpoint(
@@ -90,7 +91,7 @@ def maybe_restore_checkpoint(framework: str,
   # Load_latest_checkpoint() will return checkpoint_state if
   # checkpoint_dir does not exist or if it exists and contains no checkpoints.
-  found_checkpoint = (latest_ckpt['global_step'] != uninitialized_global_step)
+  found_checkpoint = latest_ckpt['global_step'] != uninitialized_global_step

   if not found_checkpoint:
     return (optimizer_state,
@@ -209,14 +210,15 @@ def save_checkpoint(framework: str,
                      'method.')
     opt_state = optimizer_state_dict
-  checkpoint_state = dict(
-      model_params=model_params,
-      optimizer_state=opt_state,
-      model_state=model_state,
-      train_state=train_state,
-      eval_results=tuple(eval_results),
-      global_step=global_step,
-      preemption_count=preemption_count)
+  checkpoint_state = {
+      'model_params': model_params,
+      'optimizer_state': opt_state,
+      'model_state': model_state,
+      'train_state': train_state,
+      'eval_results': tuple(eval_results),
+      'global_step': global_step,
+      'preemption_count': preemption_count,
+  }

   save_path = os.path.join(checkpoint_dir, f'checkpoint_{global_step}')
   if framework == 'jax':
diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py
index 007891dfe..510c4c2b7 100644
--- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py
+++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py
@@ -39,10 +39,10 @@ class MlpBlock(nn.Module):
   @nn.compact
   def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
     """Applies Transformer MlpBlock module."""
-    inits = dict(
-        kernel_init=nn.initializers.xavier_uniform(),
-        bias_init=nn.initializers.normal(stddev=1e-6),
-    )
+    inits = {
+        'kernel_init': nn.initializers.xavier_uniform(),
+        'bias_init': nn.initializers.normal(stddev=1e-6),
+    }

     d = x.shape[2]
     x = nn.Dense(self.mlp_dim or 4 * d, **inits)(x)
diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py
index c910e4425..95f9b8542 100644
--- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py
+++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py
@@ -239,7 +239,7 @@ def greedy_decode(
         fin_result.shape[1], device=result.device).view(
             1, -1) < result.count_nonzero(dim=1).view(-1, 1)
     fin_result.view(-1)[idxs[mask != 0]] = result[result != blank_id]
-    padding = (fin_result == 0)
+    padding = fin_result == 0
     return fin_result, padding

   def sync_sd(self, params: spec.ParameterContainer) -> None:
diff --git a/algorithmic_efficiency/workloads/ogbg/metrics.py b/algorithmic_efficiency/workloads/ogbg/metrics.py
index f971f2142..a654eb2ae 100644
--- a/algorithmic_efficiency/workloads/ogbg/metrics.py
+++ b/algorithmic_efficiency/workloads/ogbg/metrics.py
@@ -22,7 +22,7 @@ def predictions_match_labels(*,
                              **kwargs) -> jnp.ndarray:
   """Returns a binary array indicating where predictions match the labels."""
   del kwargs  # Unused.
-  preds = (logits > 0)
+  preds = logits > 0
   return (preds == labels).astype(jnp.float32)
diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py
index 140d3defe..43084d47f 100644
--- a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py
+++ b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py
@@ -88,7 +88,7 @@ def _binary_cross_entropy_with_mask(
   # Numerically stable implementation of BCE loss.
   # This mimics TensorFlow's tf.nn.sigmoid_cross_entropy_with_logits().
-  positive_logits = (logits >= 0)
+  positive_logits = logits >= 0
   relu_logits = jnp.where(positive_logits, logits, 0)
   abs_logits = jnp.where(positive_logits, logits, -logits)
   losses = relu_logits - (logits * smoothed_labels) + (
diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py
index b1427f7a4..ea17ed02a 100644
--- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py
+++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py
@@ -192,7 +192,7 @@ def _binary_cross_entropy_with_mask(
   # Numerically stable implementation of BCE loss.
   # This mimics TensorFlow's tf.nn.sigmoid_cross_entropy_with_logits().
-  positive_logits = (logits >= 0)
+  positive_logits = logits >= 0
   relu_logits = torch.where(positive_logits, logits, 0)
   abs_logits = torch.where(positive_logits, logits, -logits)
   losses = relu_logits - (logits * smoothed_labels) + (
diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/decode.py b/algorithmic_efficiency/workloads/wmt/wmt_jax/decode.py
index 4636e2757..93b2eeca7 100644
--- a/algorithmic_efficiency/workloads/wmt/wmt_jax/decode.py
+++ b/algorithmic_efficiency/workloads/wmt/wmt_jax/decode.py
@@ -193,7 +193,7 @@ def beam_search(inputs,
   def beam_search_loop_cond_fn(state):
     """Beam search loop termination condition."""
     # Have we reached max decoding length?
-    not_at_end = (state.cur_index < max_decode_len - 1)
+    not_at_end = state.cur_index < max_decode_len - 1

     # Is no further progress in the beam search possible?
     # Get the best possible scores from alive sequences.
diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/decode.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/decode.py
index f96153432..0488a144f 100644
--- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/decode.py
+++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/decode.py
@@ -219,7 +219,7 @@ def beam_search(
   def beam_search_loop_cond_fn(state: BeamState) -> bool:
     """Beam search loop termination condition."""
     # Have we reached max decoding length?
-    not_at_end = (state.cur_index < max_decode_len - 1)
+    not_at_end = state.cur_index < max_decode_len - 1

     # Is no further progress in the beam search possible?
     # Get the best possible scores from alive sequences.
diff --git a/baselines/nadamw/pytorch/submission.py b/baselines/nadamw/pytorch/submission.py
index 3c5d9b990..f16bbd4a8 100644
--- a/baselines/nadamw/pytorch/submission.py
+++ b/baselines/nadamw/pytorch/submission.py
@@ -52,7 +52,9 @@ def __init__(self,
       raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}')
     if not 0.0 <= weight_decay:
       raise ValueError(f'Invalid weight_decay value: {weight_decay}')
-    defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+    defaults = {
+        'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay
+    }
     super().__init__(params, defaults)

   def __setstate__(self, state):
diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py
index 676650377..e8cd66c7a 100644
--- a/datasets/dataset_setup.py
+++ b/datasets/dataset_setup.py
@@ -181,10 +181,9 @@ def _maybe_prompt_for_deletion(paths, interactive_deletion):

 def _download_url(url, data_dir):
-  url = url
   data_dir = os.path.expanduser(data_dir)
   file_path = os.path.join(data_dir, url.split('/')[-1])
-  response = requests.get(url, stream=True)
+  response = requests.get(url, stream=True, timeout=600)
   total_size_in_bytes = int(response.headers.get('Content-length', 0))
   total_size_in_mib = total_size_in_bytes / (2**20)
   progress_bar = tqdm.tqdm(total=total_size_in_mib, unit='MiB', unit_scale=True)
@@ -209,7 +208,7 @@ def _download_url(url, data_dir):
       f.write(chunk)
   progress_bar.close()
   if (progress_bar.total != 0 and progress_bar.n != progress_bar.total):
-    raise Exception(
+    raise RuntimeError(
        ('Download corrupted, size {n} MiB from {url} does not match '
         'expected size {size} MiB').format(
            url=url, n=progress_bar.n, size=progress_bar.total))
@@ -402,7 +401,6 @@ def setup_imagenet_pytorch(data_dir):
       os.path.join(imagenet_pytorch_data_dir, IMAGENET_VAL_TAR_FILENAME),
       os.path.join(imagenet_pytorch_data_dir, 'val'))

-  cwd = os.path.join(imagenet_pytorch_data_dir, 'train')
   valprep_command = [
       'wget',
       '-qO-',
diff --git a/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py b/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py
index beed93827..59a33c55c 100644
--- a/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py
+++ b/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py
@@ -55,7 +55,9 @@ def __init__(self,
       raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}')
     if not 0.0 <= weight_decay:
       raise ValueError(f'Invalid weight_decay value: {weight_decay}')
-    defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+    defaults = {
+        'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay
+    }
     super().__init__(params, defaults)

   def __setstate__(self, state):
diff --git a/setup.cfg b/setup.cfg
index e1be7c37f..600642c55 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -79,7 +79,7 @@ full_dev =
 # Dependencies for developing the package
 dev =
     isort==5.10.1
-    pylint==2.14.5
+    pylint==2.16.1
     pytest==7.1.2
     yapf==0.32.0
     pre-commit==2.20.0
@@ -252,13 +252,13 @@ defining-attr-methods=__init__,__new__,setUp
 # "class_" is also a valid for the first argument to a class method.
 valid-classmethod-first-arg=cls,class_

 [pylint.EXCEPTIONS]
-overgeneral-exceptions=StandardError,Exception,BaseException
+overgeneral-exceptions=builtins.StandardError,builtins.Exception,builtins.BaseException

 [pylint.IMPORTS]
 # Deprecated modules which should not be used, separated by a comma
 deprecated-modules=regsub,TERMIOS,Bastion,rexec,sets

 [pylint.FORMAT]
 # List of checkers and warnings to disable.
-disable=abstract-method,access-member-before-definition,arguments-differ,assignment-from-no-return,attribute-defined-outside-init,bad-mcs-classmethod-argument,bad-option-value,c-extension-no-member,consider-merging-isinstance,consider-using-dict-comprehension,consider-using-enumerate,consider-using-in,consider-using-set-comprehension,consider-using-ternary,deprecated-method,design,file-ignored,fixme,global-statement,import-error,inconsistent-return-statements,invalid-unary-operand-type,len-as-condition,locally-disabled,locally-enabled,misplaced-comparison-constant,missing-docstring,multiple-imports,no-else-return,no-member,no-name-in-module,no-self-use,no-value-for-parameter,not-an-iterable,not-context-manager,pointless-except,protected-access,redefined-argument-from-local,signature-differs,similarities,simplifiable-if-expression,star-args,super-init-not-called,suppressed-message,too-many-function-args,trailing-comma-tuple,trailing-newlines,ungrouped-imports,unnecessary-pass,unsubscriptable-object,unused-argument,useless-object-inheritance,useless-return,useless-suppression,wrong-import-order,wrong-import-position,unneeded-not,unexpected-keyword-arg,redundant-keyword-arg,unspecified-encoding,logging-fstring-interpolation,consider-using-f-string
+disable=abstract-method,access-member-before-definition,arguments-differ,assignment-from-no-return,attribute-defined-outside-init,bad-mcs-classmethod-argument,bad-option-value,c-extension-no-member,consider-merging-isinstance,consider-using-dict-comprehension,consider-using-enumerate,consider-using-in,consider-using-set-comprehension,consider-using-ternary,deprecated-method,design,file-ignored,fixme,global-statement,import-error,inconsistent-return-statements,invalid-unary-operand-type,len-as-condition,locally-disabled,locally-enabled,misplaced-comparison-constant,missing-docstring,multiple-imports,no-else-return,no-member,no-name-in-module,no-self-use,no-value-for-parameter,not-an-iterable,not-context-manager,pointless-except,protected-access,redefined-argument-from-local,signature-differs,similarities,simplifiable-if-expression,star-args,super-init-not-called,suppressed-message,too-many-function-args,trailing-comma-tuple,trailing-newlines,ungrouped-imports,unnecessary-pass,unsubscriptable-object,unused-argument,useless-object-inheritance,useless-return,useless-suppression,wrong-import-order,wrong-import-position,unneeded-not,unexpected-keyword-arg,redundant-keyword-arg,unspecified-encoding,logging-fstring-interpolation,consider-using-f-string,use-dict-literal

 # Maximum number of characters on a single line.
 max-line-length=80
diff --git a/tests/reference_algorithm_tests.py b/tests/reference_algorithm_tests.py
index 5d4fb774d..2d04a1758 100644
--- a/tests/reference_algorithm_tests.py
+++ b/tests/reference_algorithm_tests.py
@@ -266,12 +266,13 @@ def _build_input_queue(self, *args, **kwargs):

     def _fake_iter():
       while True:
-        fake_batch = dict(
-            num_nodes=tf.ones((1,), dtype=tf.int64),
-            edge_index=tf.ones((1, 2), dtype=tf.int64),
-            node_feat=tf.random.normal((1, 9)),
-            edge_feat=tf.random.normal((1, 3)),
-            labels=tf.ones((self._num_outputs,)))
+        fake_batch = {
+            'num_nodes': tf.ones((1,), dtype=tf.int64),
+            'edge_index': tf.ones((1, 2), dtype=tf.int64),
+            'node_feat': tf.random.normal((1, 9)),
+            'edge_feat': tf.random.normal((1, 3)),
+            'labels': tf.ones((self._num_outputs,)),
+        }
         yield fake_batch

     fake_batch_iter = ogbg_input_pipeline._get_batch_iterator(
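Note on the recurring pattern above: most of the Python hunks replace keyword-argument dict(...) calls with dict literals, the spelling preferred by the use-dict-literal check named in the setup.cfg hunk. A minimal standalone sketch of the two spellings, reusing the NAdamW defaults keys from above with placeholder values (the numbers are illustrative, not taken from the diff):

    # Constructor form being removed: builds the dict through a function call
    # and only allows keys that are valid Python identifiers.
    defaults = dict(lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01)

    # Literal form being introduced: same contents, but no name lookup or call,
    # and keys may be arbitrary hashable objects, not just identifiers.
    defaults = {'lr': 1e-3, 'betas': (0.9, 0.999), 'eps': 1e-8, 'weight_decay': 0.01}

    assert dict(lr=1e-3) == {'lr': 1e-3}  # The two spellings build equal dicts.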
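For context on the ogbg hunks, the comparisons losing their parentheses sit inside the numerically stable BCE-with-logits computation that the surrounding comments attribute to tf.nn.sigmoid_cross_entropy_with_logits. The sketch below is a self-contained illustration of that identity, max(x, 0) - x*z + log(1 + exp(-|x|)); it is not the workload code, which additionally applies label smoothing and a per-example mask:

    import jax.numpy as jnp

    def stable_bce_with_logits(logits, labels):
      # Equals -labels * log(sigmoid(logits)) - (1 - labels) * log(1 - sigmoid(logits)),
      # but never exponentiates a large positive number, so it cannot overflow.
      positive_logits = logits >= 0
      relu_logits = jnp.where(positive_logits, logits, 0.)
      abs_logits = jnp.where(positive_logits, logits, -logits)
      return relu_logits - logits * labels + jnp.log1p(jnp.exp(-abs_logits))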