Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Moving from Python3.8 to Python 3.11 #811

Open
wants to merge 29 commits into
base: python_upgrades
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
ce99901
feat: package updates with python311
init-22 Nov 14, 2024
21fb3f9
fix: absl package version change
init-22 Nov 14, 2024
67b9f15
fix: pytorch version change
init-22 Nov 14, 2024
78df36f
fix: tf version to use numpy < 2
init-22 Nov 14, 2024
2584416
fix: librispeech requirement of tf-text rolled back to v2.17
init-22 Nov 15, 2024
d603ce9
fix: using the main repo and branch for testing
init-22 Nov 15, 2024
be68f8c
fix: overflow error resolved and PRNGKey to key
init-22 Nov 16, 2024
e890c89
fix: minor changes in docs
init-22 Nov 20, 2024
1bc2a7b
fix: changing the python versions in workflow to pass the tests
init-22 Nov 30, 2024
7a0fee3
fix: changing numpy compatible version
init-22 Nov 30, 2024
7cdea16
adding key_data to check the CI tests
init-22 Nov 30, 2024
7264c3f
fix: updated packge of sacrebleu changed the way it used to work, hen…
init-22 Dec 1, 2024
abbdc82
fix: temporarily commenting tfa
init-22 Dec 1, 2024
86029a7
fix: explicitly using mask kwarg to use MultiHeadDotProductAttention …
init-22 Dec 2, 2024
aca45a2
fix: using flax.core.pop instead of variables.pop, better way to upda…
init-22 Dec 2, 2024
2618c5e
fix: changing the traindiffs_tests branch to main again
init-22 Dec 2, 2024
8c90625
fix: unfreeze() in test_param_shapes expect FrozenDict also added fla…
init-22 Dec 2, 2024
1b587b7
fix: formatting changes with yapf
init-22 Dec 3, 2024
c65d93e
fix: running yapf again with 0.32, earlier using 0.43
init-22 Dec 3, 2024
3afd1df
fix: running yapf again with 0.32, earlier using 0.43
init-22 Dec 3, 2024
6ff2010
fix: latest versions of typing dont support Text instead str is recom…
init-22 Dec 3, 2024
55bacbd
fix: minor yapf
init-22 Dec 3, 2024
cfd5a00
Merge branch 'python_upgrades' into python311
priyakasimbeg Dec 6, 2024
5eac985
fix: going back to sacrebleu v1.3.1
init-22 Dec 7, 2024
7867711
feat: custom tf_addons support in TF2.18
init-22 Dec 17, 2024
d6dd2e8
fix: resolving pylint issues in custom_tf_addons
init-22 Dec 17, 2024
a0b587a
resolved pyline and changed the pylint version to current version of …
init-22 Dec 17, 2024
9393145
fix: removing tensorflow addons from setup cfg
init-22 Dec 18, 2024
53eff1d
fix: adding absolute paths for custom_tf_addons in randaugment
init-22 Dec 19, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 24 additions & 24 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11.10
cache: 'pip' # Cache pip dependencies\.
cache-dependency-path: '**/setup.py'
- name: Install Modules and Run
Expand All @@ -25,10 +25,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11.10
cache: 'pip' # Cache pip dependencies\.
cache-dependency-path: '**/setup.py'
- name: Install Modules and Run
Expand All @@ -42,10 +42,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11.10
cache: 'pip' # Cache pip dependencies\.
cache-dependency-path: '**/setup.py'
- name: Install Modules and Run
Expand All @@ -59,10 +59,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11.10
cache: 'pip' # Cache pip dependencies\.
cache-dependency-path: '**/setup.py'
- name: Install Modules and Run
Expand All @@ -77,10 +77,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11.10
cache: 'pip' # Cache pip dependencies\.
cache-dependency-path: '**/setup.py'
- name: Install Modules and Run
Expand All @@ -96,10 +96,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11.10
cache: 'pip' # Cache pip dependencies\.
cache-dependency-path: '**/setup.py'
- name: Install Modules and Run
Expand All @@ -113,10 +113,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11.10
cache: 'pip' # Cache pip dependencies\.
cache-dependency-path: '**/setup.py'
- name: Install Modules and Run
Expand All @@ -130,10 +130,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11.10
cache: 'pip' # Cache pip dependencies\.
cache-dependency-path: '**/setup.py'
- name: Install Modules and Run
Expand All @@ -148,10 +148,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11.10
cache: 'pip' # Cache pip dependencies\.
cache-dependency-path: '**/setup.py'
- name: Install Modules and Run
Expand All @@ -166,10 +166,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11.10
cache: 'pip' # Cache pip dependencies\.
cache-dependency-path: '**/setup.py'
- name: Install Modules and Run
Expand All @@ -184,10 +184,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11.10
cache: 'pip' # Cache pip dependencies\.
cache-dependency-path: '**/setup.py'
- name: Install pytest
Expand All @@ -208,10 +208,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.11.10
cache: 'pip' # Cache pip dependencies\.
cache-dependency-path: '**/setup.py'
- name: Install pytest
Expand Down
12 changes: 6 additions & 6 deletions .github/workflows/linting.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v2
with:
python-version: 3.9
python-version: 3.11.10
- name: Install pylint
run: |
python -m pip install --upgrade pip
Expand All @@ -27,10 +27,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v2
with:
python-version: 3.9
python-version: 3.11.10
- name: Install isort
run: |
python -m pip install --upgrade pip
Expand All @@ -43,10 +43,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.9
- name: Set up Python 3.11.10
uses: actions/setup-python@v2
with:
python-version: 3.9
python-version: 3.11.10
- name: Install yapf
run: |
python -m pip install --upgrade pip
Expand Down
2 changes: 1 addition & 1 deletion GETTING_STARTED.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ The specs on the benchmarking machines are:

> **Prerequisites:**
>
> - Python minimum requirement >= 3.8
> - Python minimum requirement >= 3.11
> - CUDA 12.1
> - NVIDIA Driver version 535.104.05
Expand Down
2 changes: 1 addition & 1 deletion algorithmic_efficiency/checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def save_checkpoint(framework: str,
target=checkpoint_state,
step=global_step,
overwrite=True,
keep=np.Inf if save_intermediate_checkpoints else 1)
keep=np.inf if save_intermediate_checkpoints else 1)
else:
if not save_intermediate_checkpoints:
checkpoint_files = gfile.glob(
Expand Down
16 changes: 8 additions & 8 deletions algorithmic_efficiency/halton.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
import functools
import itertools
import math
from typing import Any, Callable, Dict, List, Sequence, Text, Tuple, Union
from typing import Any, Callable, Dict, List, Sequence, Tuple, Union

from absl import logging
from numpy import random

_SweepSequence = List[Dict[Text, Any]]
_GeneratorFn = Callable[[float], Tuple[Text, float]]
_SweepSequence = List[Dict[str, Any]]
_GeneratorFn = Callable[[float], Tuple[str, float]]


def generate_primes(n: int) -> List[int]:
Expand Down Expand Up @@ -195,10 +195,10 @@ def generate_sequence(num_samples: int,
return halton_sequence


def _generate_double_point(name: Text,
def _generate_double_point(name: str,
min_val: float,
max_val: float,
scaling: Text,
scaling: str,
halton_point: float) -> Tuple[str, float]:
"""Generate a float hyperparameter value from a Halton sequence point."""
if scaling not in ['linear', 'log']:
Expand Down Expand Up @@ -234,7 +234,7 @@ def interval(start: int, end: int) -> Tuple[int, int]:
return start, end


def loguniform(name: Text, range_endpoints: Tuple[int, int]) -> _GeneratorFn:
def loguniform(name: str, range_endpoints: Tuple[int, int]) -> _GeneratorFn:
min_val, max_val = range_endpoints
return functools.partial(_generate_double_point,
name,
Expand All @@ -244,8 +244,8 @@ def loguniform(name: Text, range_endpoints: Tuple[int, int]) -> _GeneratorFn:


def uniform(
name: Text, search_points: Union[_DiscretePoints,
Tuple[int, int]]) -> _GeneratorFn:
name: str, search_points: Union[_DiscretePoints,
Tuple[int, int]]) -> _GeneratorFn:
if isinstance(search_points, _DiscretePoints):
return functools.partial(_generate_discrete_point,
name,
Expand Down
2 changes: 1 addition & 1 deletion algorithmic_efficiency/logger_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def _get_system_software_info() -> Dict:
system_software_info['os_platform'] = \
platform.platform() # Ex. 'Linux-5.4.48-x86_64-with-glibc2.29'
system_software_info['python_version'] = platform.python_version(
) # Ex. '3.8.10'
) # Ex. '3.11.10'
system_software_info['python_compiler'] = platform.python_compiler(
) # Ex. 'GCC 9.3.0'
# Note: do not store hostname as that may be sensitive
Expand Down
16 changes: 8 additions & 8 deletions algorithmic_efficiency/random_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,30 +18,30 @@

# Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 32 - 1] (an
# unsigned int), while RandomState.randint only accepts and returns signed ints.
MAX_INT32 = 2**31
MIN_INT32 = -MAX_INT32
MAX_UINT32 = 2**32 - 1
runame marked this conversation as resolved.
Show resolved Hide resolved
MIN_UINT32 = 0

SeedType = Union[int, list, np.ndarray]


def _signed_to_unsigned(seed: SeedType) -> SeedType:
if isinstance(seed, int):
return seed % 2**32
return seed % MAX_UINT32
if isinstance(seed, list):
return [s % 2**32 for s in seed]
return [s % MAX_UINT32 for s in seed]
if isinstance(seed, np.ndarray):
return np.array([s % 2**32 for s in seed.tolist()])
return np.array([s % MAX_UINT32 for s in seed.tolist()])


def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]:
rng = np.random.RandomState(seed=_signed_to_unsigned(seed))
new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32)
new_seed = rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32)
return [new_seed, data]


def _split(seed: SeedType, num: int = 2) -> SeedType:
rng = np.random.RandomState(seed=_signed_to_unsigned(seed))
return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num, 2])
return rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32, size=[num, 2])


def _PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name
Expand Down Expand Up @@ -75,5 +75,5 @@ def split(seed: SeedType, num: int = 2) -> SeedType:
def PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name
if FLAGS.framework == 'jax':
_check_jax_install()
return jax_rng.PRNGKey(seed)
return jax_rng.key(seed)
return _PRNGKey(seed)
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from flax import jax_utils
from flax import linen as nn
from flax.core import pop
import jax
from jax import lax
import jax.numpy as jnp
Expand Down Expand Up @@ -75,8 +76,8 @@ def sync_batch_stats(
# In this case each device has its own version of the batch statistics
# and we average them.
avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x')
new_model_state = model_state.copy(
{'batch_stats': avg_fn(model_state['batch_stats'])})
new_model_state = model_state.copy()
new_model_state['batch_stats'] = avg_fn(model_state['batch_stats'])
return new_model_state

def init_model_fn(
Expand All @@ -93,7 +94,7 @@ def init_model_fn(
input_shape = (1, 32, 32, 3)
variables = jax.jit(model.init)({'params': rng},
jnp.ones(input_shape, model.dtype))
model_state, params = variables.pop('params')
model_state, params = pop(variables, 'params')
self._param_shapes = param_utils.jax_param_shapes(params)
self._param_types = param_utils.jax_param_types(self._param_shapes)
model_state = jax_utils.replicate(model_state)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import math

import tensorflow as tf
from tensorflow_addons import image as contrib_image

#from tensorflow_addons import image as contrib_image
runame marked this conversation as resolved.
Show resolved Hide resolved

# This signifies the max integer that the controller RNN could predict for the
# augmentation scheme.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from flax import jax_utils
from flax import linen as nn
from flax.core import pop
import jax
from jax import lax
import jax.numpy as jnp
Expand Down Expand Up @@ -79,8 +80,8 @@ def sync_batch_stats(
# In this case each device has its own version of the batch statistics and
# we average them.
avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x')
new_model_state = model_state.copy(
{'batch_stats': avg_fn(model_state['batch_stats'])})
new_model_state = model_state.copy() # Create a shallow copy
new_model_state['batch_stats'] = avg_fn(model_state['batch_stats'])
return new_model_state

def init_model_fn(
Expand Down Expand Up @@ -111,7 +112,7 @@ def init_model_fn(
input_shape = (1, 224, 224, 3)
variables = jax.jit(model.init)({'params': rng},
jnp.ones(input_shape, model.dtype))
model_state, params = variables.pop('params')
model_state, params = pop(variables, "params")
self._param_shapes = param_utils.jax_param_shapes(params)
self._param_types = param_utils.jax_param_types(self._param_shapes)
model_state = jax_utils.replicate(model_state)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from flax import jax_utils
from flax import linen as nn
from flax.core import pop
import jax
import jax.numpy as jnp

Expand All @@ -28,7 +29,7 @@ def initialized(self, key: spec.RandomState,
variables = jax.jit(
model.init)({'params': params_rng, 'dropout': dropout_rng},
jnp.ones(input_shape))
model_state, params = variables.pop('params')
model_state, params = pop(variables, "params")
return params, model_state

def init_model_fn(
Expand Down
Loading
Loading