Commit

fixing pylint version skew. rewrote some dicts as literals before disabling that check.
znado committed Feb 3, 2023
1 parent ab798f7 commit c7691e2
Showing 15 changed files with 49 additions and 44 deletions.
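
Two themes recur in the diffs below: pylint is pinned to 2.16.1 everywhere it runs (the GitHub Actions workflow, the pre-commit hook, and setup.cfg's dev dependencies) so the configurations cannot drift apart again, and keyword-argument dict(...) calls are rewritten as dict literals before the corresponding use-dict-literal check is added to the disable list in setup.cfg. A minimal sketch of the rewrite pattern (the names here are invented for illustration):

# Before: builds the mapping through a dict() call with keyword
# arguments; pylint's use-dict-literal check (R1735) flags this form.
config = dict(lr=1e-3, batch_size=128)

# After: the literal form, which avoids a name lookup and call and
# keeps the keys as plain strings rather than identifiers.
config = {'lr': 1e-3, 'batch_size': 128}
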
2 changes: 1 addition & 1 deletion .github/workflows/linting.yml
@@ -14,7 +14,7 @@ jobs:
     - name: Install pylint
       run: |
         python -m pip install --upgrade pip
-        pip install pylint
+        pip install pylint==2.16.1
     - name: Run pylint
       run: |
         pylint algorithmic_efficiency
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -9,6 +9,6 @@ repos:
   hooks:
   - id: isort
 - repo: https://github.com/pycqa/pylint
-  rev: v2.14.5
+  rev: v2.16.1
   hooks:
   - id: pylint
36 changes: 19 additions & 17 deletions algorithmic_efficiency/checkpoint_utils.py
@@ -68,14 +68,15 @@ def maybe_restore_checkpoint(framework: str,
 
   uninitialized_global_step = -1
   uninitialized_preemption_count = -1
-  checkpoint_state = dict(
-      model_params=model_params,
-      optimizer_state=opt_state,
-      model_state=model_state,
-      train_state=train_state,
-      eval_results=None,
-      global_step=uninitialized_global_step,
-      preemption_count=uninitialized_preemption_count)
+  checkpoint_state = {
+      'model_params': model_params,
+      'optimizer_state': opt_state,
+      'model_state': model_state,
+      'train_state': train_state,
+      'eval_results': None,
+      'global_step': uninitialized_global_step,
+      'preemption_count': uninitialized_preemption_count,
+  }
 
   if framework == 'jax':
     latest_ckpt = flax_checkpoints.restore_checkpoint(
@@ -90,7 +91,7 @@
 
   # Load_latest_checkpoint() will return checkpoint_state if
   # checkpoint_dir does not exist or if it exists and contains no checkpoints.
-  found_checkpoint = (latest_ckpt['global_step'] != uninitialized_global_step)
+  found_checkpoint = latest_ckpt['global_step'] != uninitialized_global_step
 
   if not found_checkpoint:
     return (optimizer_state,
@@ -209,14 +210,15 @@ def save_checkpoint(framework: str,
                      'method.')
     opt_state = optimizer_state_dict
 
-  checkpoint_state = dict(
-      model_params=model_params,
-      optimizer_state=opt_state,
-      model_state=model_state,
-      train_state=train_state,
-      eval_results=tuple(eval_results),
-      global_step=global_step,
-      preemption_count=preemption_count)
+  checkpoint_state = {
+      'model_params': model_params,
+      'optimizer_state': opt_state,
+      'model_state': model_state,
+      'train_state': train_state,
+      'eval_results': tuple(eval_results),
+      'global_step': global_step,
+      'preemption_count': preemption_count,
+  }
 
   save_path = os.path.join(checkpoint_dir, f'checkpoint_{global_step}')
   if framework == 'jax':
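
As context for the sentinel values above: flax's restore_checkpoint returns the target pytree unchanged when the checkpoint directory is missing or empty, so seeding global_step with -1 lets the caller detect that nothing was restored. A minimal sketch of the call pattern (the directory path is a placeholder):

from flax.training import checkpoints as flax_checkpoints

UNINITIALIZED_GLOBAL_STEP = -1
checkpoint_state = {
    'model_params': None,
    'global_step': UNINITIALIZED_GLOBAL_STEP,
}
# restore_checkpoint returns `checkpoint_state` as-is when /tmp/ckpts
# does not exist or contains no checkpoints.
latest = flax_checkpoints.restore_checkpoint(
    ckpt_dir='/tmp/ckpts', target=checkpoint_state)
found_checkpoint = latest['global_step'] != UNINITIALIZED_GLOBAL_STEP
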
@@ -39,10 +39,10 @@ class MlpBlock(nn.Module):
   @nn.compact
   def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
     """Applies Transformer MlpBlock module."""
-    inits = dict(
-        kernel_init=nn.initializers.xavier_uniform(),
-        bias_init=nn.initializers.normal(stddev=1e-6),
-    )
+    inits = {
+        'kernel_init': nn.initializers.xavier_uniform(),
+        'bias_init': nn.initializers.normal(stddev=1e-6),
+    }
 
     d = x.shape[2]
     x = nn.Dense(self.mlp_dim or 4 * d, **inits)(x)
@@ -239,7 +239,7 @@ def greedy_decode(
         fin_result.shape[1], device=result.device).view(
             1, -1) < result.count_nonzero(dim=1).view(-1, 1)
     fin_result.view(-1)[idxs[mask != 0]] = result[result != blank_id]
-    padding = (fin_result == 0)
+    padding = fin_result == 0
     return fin_result, padding
 
   def sync_sd(self, params: spec.ParameterContainer) -> None:
2 changes: 1 addition & 1 deletion algorithmic_efficiency/workloads/ogbg/metrics.py
@@ -22,7 +22,7 @@ def predictions_match_labels(*,
                              **kwargs) -> jnp.ndarray:
   """Returns a binary array indicating where predictions match the labels."""
   del kwargs  # Unused.
-  preds = (logits > 0)
+  preds = logits > 0
   return (preds == labels).astype(jnp.float32)
 
 
2 changes: 1 addition & 1 deletion algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py
@@ -88,7 +88,7 @@ def _binary_cross_entropy_with_mask(
 
   # Numerically stable implementation of BCE loss.
   # This mimics TensorFlow's tf.nn.sigmoid_cross_entropy_with_logits().
-  positive_logits = (logits >= 0)
+  positive_logits = logits >= 0
   relu_logits = jnp.where(positive_logits, logits, 0)
   abs_logits = jnp.where(positive_logits, logits, -logits)
   losses = relu_logits - (logits * smoothed_labels) + (
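
For reference, the stable form both the JAX and PyTorch workloads implement is the standard identity max(x, 0) - x*z + log(1 + exp(-|x|)) for logits x and labels z, which never exponentiates a large positive number. A self-contained JAX sketch (the function name is invented here):

import jax.numpy as jnp

def stable_bce_with_logits(logits, labels):
  """Elementwise BCE that cannot overflow for large |logits|."""
  # Algebraically equal to -z*log(sigmoid(x)) - (1-z)*log(1-sigmoid(x)),
  # but exp() only ever sees a non-positive argument.
  return (jnp.maximum(logits, 0) - logits * labels +
          jnp.log1p(jnp.exp(-jnp.abs(logits))))
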
@@ -192,7 +192,7 @@ def _binary_cross_entropy_with_mask(
 
   # Numerically stable implementation of BCE loss.
   # This mimics TensorFlow's tf.nn.sigmoid_cross_entropy_with_logits().
-  positive_logits = (logits >= 0)
+  positive_logits = logits >= 0
   relu_logits = torch.where(positive_logits, logits, 0)
   abs_logits = torch.where(positive_logits, logits, -logits)
   losses = relu_logits - (logits * smoothed_labels) + (
2 changes: 1 addition & 1 deletion algorithmic_efficiency/workloads/wmt/wmt_jax/decode.py
@@ -193,7 +193,7 @@ def beam_search(inputs,
   def beam_search_loop_cond_fn(state):
     """Beam search loop termination condition."""
     # Have we reached max decoding length?
-    not_at_end = (state.cur_index < max_decode_len - 1)
+    not_at_end = state.cur_index < max_decode_len - 1
 
     # Is no further progress in the beam search possible?
     # Get the best possible scores from alive sequences.
2 changes: 1 addition & 1 deletion algorithmic_efficiency/workloads/wmt/wmt_pytorch/decode.py
@@ -219,7 +219,7 @@ def beam_search(
   def beam_search_loop_cond_fn(state: BeamState) -> bool:
     """Beam search loop termination condition."""
     # Have we reached max decoding length?
-    not_at_end = (state.cur_index < max_decode_len - 1)
+    not_at_end = state.cur_index < max_decode_len - 1
 
     # Is no further progress in the beam search possible?
     # Get the best possible scores from alive sequences.
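
In the JAX decoder this predicate feeds the condition of a lax.while_loop (combined with a best-possible-score check); a toy sketch of that control-flow shape, with the state simplified to an invented (index, tokens) pair:

import jax
import jax.numpy as jnp

max_decode_len = 8

def cond_fn(state):
  cur_index, _ = state
  # Keep looping while there is room to emit at least one more token.
  return cur_index < max_decode_len - 1

def body_fn(state):
  cur_index, tokens = state
  # A real decoder would pick the next token here; we write a dummy 1.
  tokens = tokens.at[cur_index + 1].set(1)
  return cur_index + 1, tokens

init_state = (0, jnp.zeros(max_decode_len, dtype=jnp.int32))
final_index, final_tokens = jax.lax.while_loop(cond_fn, body_fn, init_state)
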
4 changes: 3 additions & 1 deletion baselines/nadamw/pytorch/submission.py
@@ -52,7 +52,9 @@ def __init__(self,
       raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}')
     if not 0.0 <= weight_decay:
       raise ValueError(f'Invalid weight_decay value: {weight_decay}')
-    defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+    defaults = {
+        'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay
+    }
     super().__init__(params, defaults)
 
   def __setstate__(self, state):
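
For background, torch.optim.Optimizer expects exactly such a defaults mapping and copies it into every parameter group that does not override it, so the literal is a drop-in replacement for the dict() call. A minimal sketch of the contract with a toy optimizer (not the repo's NAdamW):

import torch

class ToySGD(torch.optim.Optimizer):
  """Illustrative optimizer showing the defaults-dict contract."""

  def __init__(self, params, lr=1e-3):
    if lr <= 0.0:
      raise ValueError(f'Invalid learning rate: {lr}')
    # Each param group inherits these values unless it overrides them.
    defaults = {'lr': lr}
    super().__init__(params, defaults)

  @torch.no_grad()
  def step(self):
    for group in self.param_groups:
      for p in group['params']:
        if p.grad is not None:
          p.add_(p.grad, alpha=-group['lr'])
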
6 changes: 2 additions & 4 deletions datasets/dataset_setup.py
@@ -181,10 +181,9 @@ def _maybe_prompt_for_deletion(paths, interactive_deletion):
 
 
 def _download_url(url, data_dir):
-  url = url
   data_dir = os.path.expanduser(data_dir)
   file_path = os.path.join(data_dir, url.split('/')[-1])
-  response = requests.get(url, stream=True)
+  response = requests.get(url, stream=True, timeout=600)
   total_size_in_bytes = int(response.headers.get('Content-length', 0))
   total_size_in_mib = total_size_in_bytes / (2**20)
   progress_bar = tqdm.tqdm(total=total_size_in_mib, unit='MiB', unit_scale=True)
@@ -209,7 +208,7 @@ def _download_url(url, data_dir):
           f.write(chunk)
   progress_bar.close()
   if (progress_bar.total != 0 and progress_bar.n != progress_bar.total):
-    raise Exception(
+    raise RuntimeError(
         ('Download corrupted, size {n} MiB from {url} does not match '
         'expected size {size} MiB').format(
            url=url, n=progress_bar.n, size=progress_bar.total))
@@ -402,7 +401,6 @@ def setup_imagenet_pytorch(data_dir):
       os.path.join(imagenet_pytorch_data_dir, IMAGENET_VAL_TAR_FILENAME),
       os.path.join(imagenet_pytorch_data_dir, 'val'))
 
-  cwd = os.path.join(imagenet_pytorch_data_dir, 'train')
   valprep_command = [
       'wget',
       '-qO-',
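
Two of the fixes above are easy to miss: requests.get gains a timeout (without one a stalled connection can hang the download indefinitely, and pylint 2.16's missing-timeout check appears to flag the bare call), and the bare Exception becomes RuntimeError, likely for the broad-exception-raised warning added around the same release. A standalone sketch of the same streaming-download pattern (URL handling simplified, chunk size chosen arbitrarily):

import os

import requests
import tqdm

def download_file(url, data_dir, timeout=600):
  """Streams url into data_dir with a progress bar (sketch only)."""
  data_dir = os.path.expanduser(data_dir)
  file_path = os.path.join(data_dir, url.split('/')[-1])
  # The timeout bounds connect/read waits, so a stalled server raises
  # instead of blocking the loop below forever.
  response = requests.get(url, stream=True, timeout=timeout)
  response.raise_for_status()
  total_bytes = int(response.headers.get('Content-length', 0))
  with open(file_path, 'wb') as f, tqdm.tqdm(
      total=total_bytes, unit='B', unit_scale=True) as progress:
    for chunk in response.iter_content(chunk_size=2**20):
      f.write(chunk)
      progress.update(len(chunk))
    if total_bytes and progress.n != total_bytes:
      raise RuntimeError(f'Download corrupted: got {progress.n} of '
                         f'{total_bytes} bytes from {url}')
  return file_path
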
@@ -55,7 +55,9 @@ def __init__(self,
       raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}')
     if not 0.0 <= weight_decay:
       raise ValueError(f'Invalid weight_decay value: {weight_decay}')
-    defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+    defaults = {
+        'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay
+    }
     super().__init__(params, defaults)
 
   def __setstate__(self, state):
6 changes: 3 additions & 3 deletions setup.cfg
@@ -79,7 +79,7 @@ full_dev =
 # Dependencies for developing the package
 dev =
     isort==5.10.1
-    pylint==2.14.5
+    pylint==2.16.1
     pytest==7.1.2
     yapf==0.32.0
     pre-commit==2.20.0
@@ -252,13 +252,13 @@ defining-attr-methods=__init__,__new__,setUp
 # "class_" is also a valid for the first argument to a class method.
 valid-classmethod-first-arg=cls,class_
 [pylint.EXCEPTIONS]
-overgeneral-exceptions=StandardError,Exception,BaseException
+overgeneral-exceptions=builtins.StandardError,builtins.Exception,builtins.BaseException
 [pylint.IMPORTS]
 # Deprecated modules which should not be used, separated by a comma
 deprecated-modules=regsub,TERMIOS,Bastion,rexec,sets
 [pylint.FORMAT]
 # List of checkers and warnings to disable.
-disable=abstract-method,access-member-before-definition,arguments-differ,assignment-from-no-return,attribute-defined-outside-init,bad-mcs-classmethod-argument,bad-option-value,c-extension-no-member,consider-merging-isinstance,consider-using-dict-comprehension,consider-using-enumerate,consider-using-in,consider-using-set-comprehension,consider-using-ternary,deprecated-method,design,file-ignored,fixme,global-statement,import-error,inconsistent-return-statements,invalid-unary-operand-type,len-as-condition,locally-disabled,locally-enabled,misplaced-comparison-constant,missing-docstring,multiple-imports,no-else-return,no-member,no-name-in-module,no-self-use,no-value-for-parameter,not-an-iterable,not-context-manager,pointless-except,protected-access,redefined-argument-from-local,signature-differs,similarities,simplifiable-if-expression,star-args,super-init-not-called,suppressed-message,too-many-function-args,trailing-comma-tuple,trailing-newlines,ungrouped-imports,unnecessary-pass,unsubscriptable-object,unused-argument,useless-object-inheritance,useless-return,useless-suppression,wrong-import-order,wrong-import-position,unneeded-not,unexpected-keyword-arg,redundant-keyword-arg,unspecified-encoding,logging-fstring-interpolation,consider-using-f-string
+disable=abstract-method,access-member-before-definition,arguments-differ,assignment-from-no-return,attribute-defined-outside-init,bad-mcs-classmethod-argument,bad-option-value,c-extension-no-member,consider-merging-isinstance,consider-using-dict-comprehension,consider-using-enumerate,consider-using-in,consider-using-set-comprehension,consider-using-ternary,deprecated-method,design,file-ignored,fixme,global-statement,import-error,inconsistent-return-statements,invalid-unary-operand-type,len-as-condition,locally-disabled,locally-enabled,misplaced-comparison-constant,missing-docstring,multiple-imports,no-else-return,no-member,no-name-in-module,no-self-use,no-value-for-parameter,not-an-iterable,not-context-manager,pointless-except,protected-access,redefined-argument-from-local,signature-differs,similarities,simplifiable-if-expression,star-args,super-init-not-called,suppressed-message,too-many-function-args,trailing-comma-tuple,trailing-newlines,ungrouped-imports,unnecessary-pass,unsubscriptable-object,unused-argument,useless-object-inheritance,useless-return,useless-suppression,wrong-import-order,wrong-import-position,unneeded-not,unexpected-keyword-arg,redundant-keyword-arg,unspecified-encoding,logging-fstring-interpolation,consider-using-f-string,use-dict-literal
 
 # Maximum number of characters on a single line.
 max-line-length=80
13 changes: 7 additions & 6 deletions tests/reference_algorithm_tests.py
@@ -266,12 +266,13 @@ def _build_input_queue(self, *args, **kwargs):
 
     def _fake_iter():
       while True:
-        fake_batch = dict(
-            num_nodes=tf.ones((1,), dtype=tf.int64),
-            edge_index=tf.ones((1, 2), dtype=tf.int64),
-            node_feat=tf.random.normal((1, 9)),
-            edge_feat=tf.random.normal((1, 3)),
-            labels=tf.ones((self._num_outputs,)))
+        fake_batch = {
+            'num_nodes': tf.ones((1,), dtype=tf.int64),
+            'edge_index': tf.ones((1, 2), dtype=tf.int64),
+            'node_feat': tf.random.normal((1, 9)),
+            'edge_feat': tf.random.normal((1, 3)),
+            'labels': tf.ones((self._num_outputs,)),
+        }
         yield fake_batch
 
     fake_batch_iter = ogbg_input_pipeline._get_batch_iterator(
