From c0c53ca08ca702ff49467a4916ff8dd67ab5e2d8 Mon Sep 17 00:00:00 2001
From: Parth Raut
Date: Sun, 22 Dec 2024 14:07:32 -0500
Subject: [PATCH] jax fixes

---
 examples/jax/train_single.py            | 125 ++++++++++++++++++++++
 examples/jax/train_single_NO_MONITOR.py | 114 +++++++++++++++++++++
 2 files changed, 239 insertions(+)
 create mode 100644 examples/jax/train_single.py
 create mode 100644 examples/jax/train_single_NO_MONITOR.py

diff --git a/examples/jax/train_single.py b/examples/jax/train_single.py
new file mode 100644
index 00000000..27892700
--- /dev/null
+++ b/examples/jax/train_single.py
@@ -0,0 +1,125 @@
+# Adapted from Training a simple neural network, with tensorflow/datasets data loading (https://jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html)
+
+import jax.numpy as jnp
+from jax import grad, jit, vmap
+from jax import random
+from jax.scipy.special import logsumexp
+import tensorflow as tf
+import tensorflow_datasets as tfds
+import time
+from zeus.monitor import ZeusMonitor
+from zeus.optimizer.power_limit import GlobalPowerLimitOptimizer
+
+# A helper function to randomly initialize weights and biases
+# for a dense neural network layer
+def random_layer_params(m, n, key, scale=1e-2):
+  w_key, b_key = random.split(key)
+  return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))
+
+# Initialize all layers for a fully-connected neural network with sizes "sizes"
+def init_network_params(sizes, key):
+  keys = random.split(key, len(sizes))
+  return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]
+
+layer_sizes = [784, 512, 512, 10]
+step_size = 0.01
+num_epochs = 10
+batch_size = 128
+n_targets = 10
+params = init_network_params(layer_sizes, random.key(0))
+
+def relu(x):
+  return jnp.maximum(0, x)
+
+def predict(params, image):
+  # per-example predictions
+  activations = image
+  for w, b in params[:-1]:
+    outputs = jnp.dot(w, activations) + b
+    activations = relu(outputs)
+
+  final_w, final_b = params[-1]
+  logits = jnp.dot(final_w, activations) + final_b
+  return logits - logsumexp(logits)
+
+def one_hot(x, k, dtype=jnp.float32):
+  """Create a one-hot encoding of x of size k."""
+  return jnp.array(x[:, None] == jnp.arange(k), dtype)
+
+def accuracy(params, images, targets):
+  target_class = jnp.argmax(targets, axis=1)
+  predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
+  return jnp.mean(predicted_class == target_class)
+
+# Make a batched version of the `predict` function
+batched_predict = vmap(predict, in_axes=(None, 0))
+
+def loss(params, images, targets):
+  preds = batched_predict(params, images)
+  return -jnp.mean(preds * targets)
+
+@jit
+def update(params, x, y):
+  grads = grad(loss)(params, x, y)
+  return [(w - step_size * dw, b - step_size * db)
+          for (w, b), (dw, db) in zip(params, grads)]
+
+
+# Ensure TF does not see GPU and grab all GPU memory.
+tf.config.set_visible_devices([], device_type='GPU')
+
+data_dir = '/tmp/tfds'
+
+# Fetch full datasets for evaluation
+# tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1)
+# You can convert them to NumPy arrays (or iterables of NumPy arrays) with tfds.dataset_as_numpy
+mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)
+mnist_data = tfds.as_numpy(mnist_data)
+train_data, test_data = mnist_data['train'], mnist_data['test']
+num_labels = info.features['label'].num_classes
+h, w, c = info.features['image'].shape
+num_pixels = h * w * c
+
+# Full train set
+train_images, train_labels = train_data['image'], train_data['label']
+train_images = jnp.reshape(train_images, (len(train_images), num_pixels))
+train_labels = one_hot(train_labels, num_labels)
+
+# Full test set
+test_images, test_labels = test_data['image'], test_data['label']
+test_images = jnp.reshape(test_images, (len(test_images), num_pixels))
+test_labels = one_hot(test_labels, num_labels)
+
+print('Train:', train_images.shape, train_labels.shape)
+print('Test:', test_images.shape, test_labels.shape)
+
+def get_train_batches():
+  # as_supervised=True gives us the (image, label) as a tuple instead of a dict
+  ds = tfds.load(name='mnist', split='train', as_supervised=True, data_dir=data_dir)
+  # You can build up an arbitrary tf.data input pipeline
+  ds = ds.batch(batch_size).prefetch(1)
+  # tfds.dataset_as_numpy converts the tf.data.Dataset into an iterable of NumPy arrays
+  return tfds.as_numpy(ds)
+
+monitor = ZeusMonitor()
+plo = GlobalPowerLimitOptimizer(monitor)
+
+for epoch in range(num_epochs):
+  start_time = time.time()
+
+  plo.on_epoch_begin()
+  for x, y in get_train_batches():
+    plo.on_step_begin()
+    x = jnp.reshape(x, (len(x), num_pixels))
+    y = one_hot(y, num_labels)
+    params = update(params, x, y)
+    plo.on_step_end()
+  plo.on_epoch_end()
+
+  epoch_time = time.time() - start_time
+
+  train_acc = accuracy(params, train_images, train_labels)
+  test_acc = accuracy(params, test_images, test_labels)
+  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
+  print("Training set accuracy {}".format(train_acc))
+  print("Test set accuracy {}".format(test_acc))
\ No newline at end of file
diff --git a/examples/jax/train_single_NO_MONITOR.py b/examples/jax/train_single_NO_MONITOR.py
new file mode 100644
index 00000000..65f1a6bd
--- /dev/null
+++ b/examples/jax/train_single_NO_MONITOR.py
@@ -0,0 +1,114 @@
+# Adapted from Training a simple neural network, with tensorflow/datasets data loading (https://jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html)
+
+import jax.numpy as jnp
+from jax import grad, jit, vmap
+from jax import random
+from jax.scipy.special import logsumexp
+import tensorflow as tf
+import tensorflow_datasets as tfds
+import time
+
+# A helper function to randomly initialize weights and biases
+# for a dense neural network layer
+def random_layer_params(m, n, key, scale=1e-2):
+  w_key, b_key = random.split(key)
+  return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))
+
+# Initialize all layers for a fully-connected neural network with sizes "sizes"
+def init_network_params(sizes, key):
+  keys = random.split(key, len(sizes))
+  return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]
+
+layer_sizes = [784, 512, 512, 10]
+step_size = 0.01
+num_epochs = 10
+batch_size = 128
+n_targets = 10
+params = init_network_params(layer_sizes, random.key(0))
+
+def relu(x):
+  return jnp.maximum(0, x)
+
+def predict(params, image):
+  # per-example predictions
+  activations = image
+  for w, b in params[:-1]:
+    outputs = jnp.dot(w, activations) + b
+    activations = relu(outputs)
+
+  final_w, final_b = params[-1]
+  logits = jnp.dot(final_w, activations) + final_b
+  return logits - logsumexp(logits)
+
+def one_hot(x, k, dtype=jnp.float32):
+  """Create a one-hot encoding of x of size k."""
+  return jnp.array(x[:, None] == jnp.arange(k), dtype)
+
+def accuracy(params, images, targets):
+  target_class = jnp.argmax(targets, axis=1)
+  predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
+  return jnp.mean(predicted_class == target_class)
+
+# Make a batched version of the `predict` function
+batched_predict = vmap(predict, in_axes=(None, 0))
+
+def loss(params, images, targets):
+  preds = batched_predict(params, images)
+  return -jnp.mean(preds * targets)
+
+@jit
+def update(params, x, y):
+  grads = grad(loss)(params, x, y)
+  return [(w - step_size * dw, b - step_size * db)
+          for (w, b), (dw, db) in zip(params, grads)]
+
+
+# Ensure TF does not see GPU and grab all GPU memory.
+tf.config.set_visible_devices([], device_type='GPU')
+
+data_dir = '/tmp/tfds'
+
+# Fetch full datasets for evaluation
+# tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1)
+# You can convert them to NumPy arrays (or iterables of NumPy arrays) with tfds.dataset_as_numpy
+mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)
+mnist_data = tfds.as_numpy(mnist_data)
+train_data, test_data = mnist_data['train'], mnist_data['test']
+num_labels = info.features['label'].num_classes
+h, w, c = info.features['image'].shape
+num_pixels = h * w * c
+
+# Full train set
+train_images, train_labels = train_data['image'], train_data['label']
+train_images = jnp.reshape(train_images, (len(train_images), num_pixels))
+train_labels = one_hot(train_labels, num_labels)
+
+# Full test set
+test_images, test_labels = test_data['image'], test_data['label']
+test_images = jnp.reshape(test_images, (len(test_images), num_pixels))
+test_labels = one_hot(test_labels, num_labels)
+
+print('Train:', train_images.shape, train_labels.shape)
+print('Test:', test_images.shape, test_labels.shape)
+
+def get_train_batches():
+  # as_supervised=True gives us the (image, label) as a tuple instead of a dict
+  ds = tfds.load(name='mnist', split='train', as_supervised=True, data_dir=data_dir)
+  # You can build up an arbitrary tf.data input pipeline
+  ds = ds.batch(batch_size).prefetch(1)
+  # tfds.dataset_as_numpy converts the tf.data.Dataset into an iterable of NumPy arrays
+  return tfds.as_numpy(ds)
+
+for epoch in range(num_epochs):
+  start_time = time.time()
+  for x, y in get_train_batches():
+    x = jnp.reshape(x, (len(x), num_pixels))
+    y = one_hot(y, num_labels)
+    params = update(params, x, y)
+  epoch_time = time.time() - start_time
+
+  train_acc = accuracy(params, train_images, train_labels)
+  test_acc = accuracy(params, test_images, test_labels)
+  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
+  print("Training set accuracy {}".format(train_acc))
+  print("Test set accuracy {}".format(test_acc))
\ No newline at end of file
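
--
Usage note: train_single.py drives Zeus's GlobalPowerLimitOptimizer entirely
through its callback hooks (on_epoch_begin, on_step_begin, on_step_end,
on_epoch_end), which is all it needs to profile and pick a GPU power limit.
To additionally report measured time and energy, a ZeusMonitor measurement
window could wrap the epoch loop. A minimal sketch, not part of this patch,
assuming training runs on GPU index 0 and using jax.block_until_ready to
flush JAX's asynchronous dispatch before the window closes (otherwise the
window may end before the GPU work actually finishes):

    import jax
    from zeus.monitor import ZeusMonitor

    monitor = ZeusMonitor(gpu_indices=[0])  # assumption: GPU 0 is the training device

    monitor.begin_window("epoch")
    for x, y in get_train_batches():
        x = jnp.reshape(x, (len(x), num_pixels))
        y = one_hot(y, num_labels)
        params = update(params, x, y)
    jax.block_until_ready(params)  # wait for queued-up JAX work before measuring
    measurement = monitor.end_window("epoch")
    print(f"Epoch used {measurement.total_energy:.1f} J over {measurement.time:.2f} s")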