
Stochastic Weight Averaging (SWA) transforms for Optax with JAX


activatedgeek/optax-swag


SWAG in Optax


This package implements SWAG (Stochastic Weight Averaging-Gaussian) as an Optax gradient transformation, so it can be used in JAX training loops.
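For orientation, SWAG summarizes the collected iterates with a running mean, a running mean of squares (for a diagonal variance), and a buffer of the most recent deviations from the running mean (for a low-rank covariance term). The sketch below shows those updates for a single flat parameter vector. `update_swag_stats` is a hypothetical helper, not part of optax_swag; the package itself works on parameter pytrees and may organize its state differently.

```python
import jax.numpy as jnp


def update_swag_stats(n, mean, sq_mean, dev, params):
    """Fold one collected iterate into SWAG statistics (hypothetical, flat-vector sketch).

    n:       number of iterates collected so far.
    mean:    running first moment, shape (d,).
    sq_mean: running second moment, shape (d,).
    dev:     deviation buffer, shape (rank, d); the last row is the newest deviation.
    params:  the new iterate, shape (d,).
    """
    # Running averages of the iterates and of their element-wise squares.
    new_mean = (n * mean + params) / (n + 1)
    new_sq_mean = (n * sq_mean + params**2) / (n + 1)
    # Drop the oldest deviation row and append the newest one, so only the
    # last `rank` deviations from the running mean are kept.
    new_dev = jnp.concatenate([dev[1:], (params - new_mean)[None]], axis=0)
    return n + 1, new_mean, new_sq_mean, new_dev


# Usage: rank 3, two parameters, four collected iterates.
rank, d = 3, 2
n, mean, sq_mean, dev = 0, jnp.zeros(d), jnp.zeros(d), jnp.zeros((rank, d))
for p in [jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0]),
          jnp.array([5.0, 6.0]), jnp.array([7.0, 8.0])]:
    n, mean, sq_mean, dev = update_swag_stats(n, mean, sq_mean, dev, p)

# Diagonal variance, clamped at zero to guard against round-off.
var = jnp.maximum(sq_mean - mean**2, 0.0)
```

Until `rank` iterates have been collected, the buffer still holds zero rows, so a low-rank term built from it is only meaningful once the buffer is full.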

Installation

Install from pip as:

pip install optax-swag

To install the latest version directly from source, run:

pip install git+https://github.com/activatedgeek/optax-swag.git

Usage

To collect the iterate statistics during training, chain swag with your optimizer:

import optax
from optax_swag import swag

freq, rank = 100, 20  ## Example values only; tune for your setup.
optimizer = optax.chain(
    optax.sgd(learning_rate=1e-2),  ## Other optimizer and transform config.
    swag(freq, rank),  ## Always add as the last transform.
)

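The SWAGState object can be accessed from the optimizer state for downstream use: optax.chain keeps one state per transform in a tuple, so with swag as the last transform its state is the final entry.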
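The "last transform" rule is presumably there because of how optax.chain works: each transform receives the updates emitted by the one before it, so only the last transform sees the final update, and only params + updates is the iterate the optimizer actually moves to. A minimal pure-JAX sketch of a gated collection step follows. It assumes freq means "collect once every freq steps" (an assumption about the argument's semantics), tracks only the running mean, and `collect_step` is a hypothetical helper.

```python
import jax.numpy as jnp


def collect_step(step, freq, n, mean, params, updates):
    """Hypothetical gated collection step for a flat parameter vector.

    `updates` must be the chain's *final* update, so that params + updates
    is the iterate the optimizer actually moves to.
    """
    next_params = params + updates
    # Collect on every freq-th step (steps counted from 0; assumed semantics).
    due = (step + 1) % freq == 0
    new_n = jnp.where(due, n + 1, n)
    new_mean = jnp.where(due, (n * mean + next_params) / (n + 1), mean)
    return new_n, new_mean


# Usage: a constant update of -1 applied for 6 steps, collecting every 2 steps.
params = jnp.array([10.0])
n, mean = jnp.array(0), jnp.zeros(1)
freq = 2
for step in range(6):
    updates = jnp.array([-1.0])  # stand-in for the chain's final update
    n, mean = collect_step(step, freq, n, mean, params, updates)
    params = params + updates
```

Here the iterates 8, 6, and 4 are collected, so the running mean ends at 6. Using `jnp.where` instead of a Python `if` keeps the step traceable under `jax.jit` when `step` is a traced value.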

Sampling

Reference code for drawing samples from the collected statistics is given below.

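import jax
import jax.numpy as jnp

from optax_swag import sample_swag

# opt_state comes from your training loop (optimizer.init / optimizer.update).
# With swag as the last transform in the chain, its SWAGState is the last entry.
swa_opt_state = opt_state[-1]
n_samples = 10

rng = jax.random.PRNGKey(42)
rng, *samples_rng = jax.random.split(rng, 1 + n_samples)

# Draw all samples in parallel: vmap over the keys, share the SWAG state.
swag_sample_params = jax.vmap(sample_swag, in_axes=(0, None))(
    jnp.array(samples_rng), swa_opt_state)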

Because of jax.vmap, every leaf of the resulting swag_sample_params carries a leading axis of size n_samples; index or map over that axis for downstream evaluation.
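Each sample is a draw from SWAG's Gaussian approximation. In the SWAG paper (Maddox et al., 2019), a draw is θ = θ_SWA + (1/√2) Σ_diag^(1/2) z₁ + (1/√(2(K−1))) D z₂, with z₁ ~ N(0, I_d) and z₂ ~ N(0, I_K), where θ_SWA is the running mean, Σ_diag the diagonal variance, and D the d×K matrix of the last K deviations. The from-scratch sketch below implements that rule for a flat vector. `sample_swag_flat` is hypothetical, and whether sample_swag uses exactly this scaling internally is an assumption.

```python
import jax
import jax.numpy as jnp


def sample_swag_flat(rng, mean, sq_mean, dev):
    """Draw one SWAG sample for a flat parameter vector (hypothetical helper).

    mean, sq_mean: shape (d,); dev: deviation buffer of shape (K, d), K >= 2.
    """
    k = dev.shape[0]
    rng_diag, rng_low = jax.random.split(rng)
    # Diagonal part: Sigma_diag = E[theta^2] - E[theta]^2, clamped at zero.
    var = jnp.maximum(sq_mean - mean**2, 0.0)
    z1 = jax.random.normal(rng_diag, mean.shape)
    z2 = jax.random.normal(rng_low, (k,))
    diag_term = jnp.sqrt(var) * z1 / jnp.sqrt(2.0)
    # Low-rank part: D z2 with D of shape (d, K), i.e. dev.T @ z2.
    low_rank_term = dev.T @ z2 / jnp.sqrt(2.0 * (k - 1))
    return mean + diag_term + low_rank_term


# Usage: ten samples, vmapped over keys as in the reference code.
mean = jnp.array([4.0, 5.0])
sq_mean = jnp.array([21.0, 30.0])
dev = jnp.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
keys = jax.random.split(jax.random.PRNGKey(0), 10)
samples = jax.vmap(sample_swag_flat, in_axes=(0, None, None, None))(
    keys, mean, sq_mean, dev)
```

The diagonal and low-rank terms each carry a 1/√2 factor so that their covariances combine as the average of the two approximations.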

NOTE: Make sure to update non-parameter variables (e.g. BatchNorm running statistics) for each generated sample.
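Per the note above, normalization statistics gathered along the training trajectory do not match a freshly sampled parameter vector, so they should be recomputed for each sample before evaluation. The sketch below does this for a hypothetical model (a linear layer x @ w followed by standardization with dataset statistics), then averages predictions across samples. The model, data, and helper names are illustrative assumptions. Since jax.vmap maps over the leading axis of every leaf, the same pattern applies to a stacked pytree such as swag_sample_params.

```python
import jax
import jax.numpy as jnp


def recompute_norm_stats(w, batches):
    """Recompute per-feature mean/variance of x @ w over all training batches."""
    feats = jnp.concatenate([x @ w for x in batches], axis=0)
    return feats.mean(axis=0), feats.var(axis=0)


def predict(w, stats, x, eps=1e-5):
    """Hypothetical model: linear layer followed by standardization."""
    mu, var = stats
    return (x @ w - mu) / jnp.sqrt(var + eps)


def swag_ensemble_predict(sample_ws, train_batches, x_test):
    """Average predictions over parameter samples stacked on a leading axis."""
    def one_sample(w):
        # Fresh normalization statistics for this sampled parameter.
        stats = recompute_norm_stats(w, train_batches)
        return predict(w, stats, x_test)
    return jax.vmap(one_sample)(sample_ws).mean(axis=0)


# Usage: two samples of a (1, 1) weight, two small training batches.
train_batches = [jnp.array([[1.0], [2.0]]), jnp.array([[3.0], [4.0]])]
sample_ws = jnp.array([[[1.0]], [[2.0]]])
preds = swag_ensemble_predict(sample_ws, train_batches, jnp.array([[4.0]]))
```

Reusing statistics computed for one sample (say w = 1) with another sample (w = 2) would shift that sample's prediction, which is the mismatch the note warns about.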

License

Apache 2.0