Using the Forward-Forward Algorithm for Image Classification to keras 3.0 (Tensorflow backend only) (#1932)

* migration to keras3

* add md and ipynb files
chunduriv authored Nov 16, 2024
1 parent 67f981b commit 98359d8
Showing 5 changed files with 678 additions and 381 deletions.
90 changes: 51 additions & 39 deletions examples/vision/forwardforward.py
@@ -2,7 +2,7 @@
 Title: Using the Forward-Forward Algorithm for Image Classification
 Author: [Suvaditya Mukherjee](https://twitter.com/halcyonrayes)
 Date created: 2023/01/08
-Last modified: 2023/01/08
+Last modified: 2024/09/17
 Description: Training a Dense-layer model using the Forward-Forward algorithm.
 Accelerator: GPU
 """
@@ -59,9 +59,13 @@
 """
 ## Setup imports
 """
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
 import tensorflow as tf
-from tensorflow import keras
+import keras
+from keras import ops
 import numpy as np
 import matplotlib.pyplot as plt
 from sklearn.metrics import accuracy_score
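
Note (reviewer sketch, not part of the diff): the mechanical pattern behind most of the changes in this file is swapping backend-specific `tf.*` / `tf.math.*` calls for the backend-agnostic `keras.ops` namespace, which largely mirrors the NumPy API. A minimal sanity check of the ops used below, assuming a TensorFlow install:

    import os
    os.environ["KERAS_BACKEND"] = "tensorflow"  # must be set before importing keras

    from keras import ops

    x = ops.convert_to_tensor([[3.0, 4.0]])
    print(ops.norm(x, ord=2, axis=1, keepdims=True))  # [[5.]]
    print(ops.mean(ops.power(x, 2), 1))               # [12.5]
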
@@ -143,7 +147,7 @@ class FFDense(keras.layers.Layer):
     def __init__(
         self,
         units,
-        optimizer,
+        init_optimizer,
         loss_metric,
         num_epochs=50,
         use_bias=True,
@@ -163,7 +167,7 @@ def __init__(
             bias_regularizer=bias_regularizer,
         )
         self.relu = keras.layers.ReLU()
-        self.optimizer = optimizer
+        self.optimizer = init_optimizer()
         self.loss_metric = loss_metric
         self.threshold = 1.5
         self.num_epochs = num_epochs
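
Note (reviewer sketch, not part of the diff): the constructor now takes a zero-argument factory rather than an optimizer instance, and calls it once per layer (`self.optimizer = init_optimizer()`). Each FFDense layer runs `apply_gradients` on its own weights only, and a Keras 3 optimizer builds its state for one fixed set of variables, so every layer needs a fresh instance. Roughly:

    import keras

    # One factory, many independent optimizer instances, each keeping its
    # own Adam moment estimates for a single layer's weights.
    make_optimizer = lambda: keras.optimizers.Adam(learning_rate=0.03)

    opt_a = make_optimizer()  # for one FFDense layer
    opt_b = make_optimizer()  # for another
    assert opt_a is not opt_b
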
@@ -172,7 +176,7 @@ def __init__(
     # layer.

     def call(self, x):
-        x_norm = tf.norm(x, ord=2, axis=1, keepdims=True)
+        x_norm = ops.norm(x, ord=2, axis=1, keepdims=True)
         x_norm = x_norm + 1e-4
         x_dir = x / x_norm
         res = self.dense(x_dir)
@@ -192,22 +196,24 @@ def call(self, x):
     def forward_forward(self, x_pos, x_neg):
         for i in range(self.num_epochs):
             with tf.GradientTape() as tape:
-                g_pos = tf.math.reduce_mean(tf.math.pow(self.call(x_pos), 2), 1)
-                g_neg = tf.math.reduce_mean(tf.math.pow(self.call(x_neg), 2), 1)
+                g_pos = ops.mean(ops.power(self.call(x_pos), 2), 1)
+                g_neg = ops.mean(ops.power(self.call(x_neg), 2), 1)

-                loss = tf.math.log(
+                loss = ops.log(
                     1
-                    + tf.math.exp(
-                        tf.concat([-g_pos + self.threshold, g_neg - self.threshold], 0)
+                    + ops.exp(
+                        ops.concatenate(
+                            [-g_pos + self.threshold, g_neg - self.threshold], 0
+                        )
                     )
                 )
-            mean_loss = tf.cast(tf.math.reduce_mean(loss), tf.float32)
+            mean_loss = ops.cast(ops.mean(loss), dtype="float32")
             self.loss_metric.update_state([mean_loss])
             gradients = tape.gradient(mean_loss, self.dense.trainable_weights)
             self.optimizer.apply_gradients(zip(gradients, self.dense.trainable_weights))
         return (
-            tf.stop_gradient(self.call(x_pos)),
-            tf.stop_gradient(self.call(x_neg)),
+            ops.stop_gradient(self.call(x_pos)),
+            ops.stop_gradient(self.call(x_neg)),
             self.loss_metric.result(),
         )
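
Note (reviewer sketch, not part of the diff): the loss above is a softplus that pushes positive-sample goodness above `self.threshold` (1.5) and negative-sample goodness below it. The same expression checked numerically in plain NumPy, with illustrative goodness values:

    import numpy as np

    threshold = 1.5
    g_pos = np.array([2.3])  # goodness of a positive sample (want > threshold)
    g_neg = np.array([0.4])  # goodness of a negative sample (want < threshold)

    loss = np.log(1 + np.exp(np.concatenate([-g_pos + threshold, g_neg - threshold])))
    print(loss)         # ~[0.371 0.287]: small, both samples on the right side
    print(loss.mean())  # the mean drives each layer's local gradient step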

@@ -248,25 +254,24 @@ class FFNetwork(keras.Model):
     # the `Adam` optimizer with a default learning rate of 0.03 as that was
     # found to be the best rate after experimentation.
     # Loss is tracked using `loss_var` and `loss_count` variables.
-    # Use legacy optimizer for Layer Optimizer to fix issue
-    # https://github.com/keras-team/keras-io/issues/1241

     def __init__(
         self,
         dims,
-        layer_optimizer=keras.optimizers.legacy.Adam(learning_rate=0.03),
+        init_layer_optimizer=lambda: keras.optimizers.Adam(learning_rate=0.03),
         **kwargs,
     ):
         super().__init__(**kwargs)
-        self.layer_optimizer = layer_optimizer
-        self.loss_var = tf.Variable(0.0, trainable=False, dtype=tf.float32)
-        self.loss_count = tf.Variable(0.0, trainable=False, dtype=tf.float32)
+        self.init_layer_optimizer = init_layer_optimizer
+        self.loss_var = keras.Variable(0.0, trainable=False, dtype="float32")
+        self.loss_count = keras.Variable(0.0, trainable=False, dtype="float32")
         self.layer_list = [keras.Input(shape=(dims[0],))]
+        self.metrics_built = False
         for d in range(len(dims) - 1):
             self.layer_list += [
                 FFDense(
                     dims[d + 1],
-                    optimizer=self.layer_optimizer,
+                    init_optimizer=self.init_layer_optimizer,
                     loss_metric=keras.metrics.Mean(),
                 )
             ]
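
Note (reviewer sketch, not part of the diff): with this constructor, each entry of `dims` after the first adds one FFDense layer. Later in the example the network is instantiated along these lines (the layer sizes here are illustrative):

    # 28 * 28 = 784 flattened MNIST pixels in, two forward-forward layers.
    model = FFNetwork(dims=[784, 500, 500])
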
@@ -280,9 +285,9 @@ def __init__(
     @tf.function(reduce_retracing=True)
     def overlay_y_on_x(self, data):
         X_sample, y_sample = data
-        max_sample = tf.reduce_max(X_sample, axis=0, keepdims=True)
-        max_sample = tf.cast(max_sample, dtype=tf.float64)
-        X_zeros = tf.zeros([10], dtype=tf.float64)
+        max_sample = ops.amax(X_sample, axis=0, keepdims=True)
+        max_sample = ops.cast(max_sample, dtype="float64")
+        X_zeros = ops.zeros([10], dtype="float64")
         X_update = xla.dynamic_update_slice(X_zeros, max_sample, [y_sample])
         X_sample = xla.dynamic_update_slice(X_sample, X_update, [0])
         return X_sample, y_sample
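
Note (reviewer sketch, not part of the diff): `xla.dynamic_update_slice` here writes the label into the first 10 pixels of the flattened image, which is how Forward-Forward feeds the label to the network without a separate head. The same operation in plain NumPy, for intuition:

    import numpy as np

    def overlay_y_on_x_np(x_flat, label, num_classes=10):
        # Zero the first `num_classes` pixels, then put the image's max
        # value at the label's index: the label now lives in the input.
        x = x_flat.copy()
        x[:num_classes] = 0.0
        x[label] = x_flat.max()
        return x

    x = np.random.rand(784)
    print(overlay_y_on_x_np(x, 3)[:10])  # only index 3 is non-zero
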
@@ -297,25 +302,23 @@ def overlay_y_on_x(self, data):
     @tf.function(reduce_retracing=True)
     def predict_one_sample(self, x):
         goodness_per_label = []
-        x = tf.reshape(x, [tf.shape(x)[0] * tf.shape(x)[1]])
+        x = ops.reshape(x, [ops.shape(x)[0] * ops.shape(x)[1]])
         for label in range(10):
             h, label = self.overlay_y_on_x(data=(x, label))
-            h = tf.reshape(h, [-1, tf.shape(h)[0]])
+            h = ops.reshape(h, [-1, ops.shape(h)[0]])
             goodness = []
             for layer_idx in range(1, len(self.layer_list)):
                 layer = self.layer_list[layer_idx]
                 h = layer(h)
-                goodness += [tf.math.reduce_mean(tf.math.pow(h, 2), 1)]
-            goodness_per_label += [
-                tf.expand_dims(tf.reduce_sum(goodness, keepdims=True), 1)
-            ]
+                goodness += [ops.mean(ops.power(h, 2), 1)]
+            goodness_per_label += [ops.expand_dims(ops.sum(goodness, keepdims=True), 1)]
         goodness_per_label = tf.concat(goodness_per_label, 1)
-        return tf.cast(tf.argmax(goodness_per_label, 1), tf.float64)
+        return ops.cast(ops.argmax(goodness_per_label, 1), dtype="float64")

     def predict(self, data):
         x = data
         preds = list()
-        preds = tf.map_fn(fn=self.predict_one_sample, elems=x)
+        preds = ops.vectorized_map(self.predict_one_sample, x)
         return np.asarray(preds, dtype=int)
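
Note (reviewer sketch, not part of the diff): inference tries all 10 label overlays and keeps the one that excites the network the most. The control flow in plain NumPy, reusing `overlay_y_on_x_np` from the note above (`layer_fns` is a hypothetical stand-in for the trained FFDense layers):

    import numpy as np

    def predict_one_sample_np(x_flat, layer_fns, num_classes=10):
        scores = []
        for label in range(num_classes):
            h = overlay_y_on_x_np(x_flat, label)[None, :]
            goodness = 0.0
            for fn in layer_fns:           # each fn: ndarray -> ndarray
                h = fn(h)
                goodness += (h**2).mean()  # per-layer goodness, summed
            scores.append(goodness)
        return int(np.argmax(scores))      # label whose overlay scores highest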

     # This custom `train_step` function overrides the internal `train_step`
@@ -328,17 +331,26 @@ def predict(self, data):
     # the Forward-Forward computation on it. The returned loss is the final
     # loss value over all the layers.

-    @tf.function(jit_compile=True)
+    @tf.function(jit_compile=False)
     def train_step(self, data):
         x, y = data

+        if not self.metrics_built:
+            # build metrics to ensure they can be queried without erroring out.
+            # We can't update the metrics' state, as we would usually do, since
+            # we do not perform predictions within the train step
+            for metric in self.metrics:
+                if hasattr(metric, "build"):
+                    metric.build(y, y)
+            self.metrics_built = True
+
         # Flatten op
-        x = tf.reshape(x, [-1, tf.shape(x)[1] * tf.shape(x)[2]])
+        x = ops.reshape(x, [-1, ops.shape(x)[1] * ops.shape(x)[2]])

-        x_pos, y = tf.map_fn(fn=self.overlay_y_on_x, elems=(x, y))
+        x_pos, y = ops.vectorized_map(self.overlay_y_on_x, (x, y))

         random_y = tf.random.shuffle(y)
-        x_neg, y = tf.map_fn(fn=self.overlay_y_on_x, elems=(x, random_y))
+        x_neg, y = tf.map_fn(self.overlay_y_on_x, (x, random_y))

         h_pos, h_neg = x_pos, x_neg
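
Note (reviewer sketch, not part of the diff): negative data is manufactured by pairing each image with a shuffled, almost always wrong, label before the overlay, e.g.:

    import numpy as np

    rng = np.random.default_rng(0)
    y = np.array([3, 1, 4, 1, 5])
    random_y = rng.permutation(y)  # analogue of tf.random.shuffle(y) above
    # Overlaying random_y instead of y yields "negative" inputs that the
    # layers learn to assign low goodness.
    print(random_y)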

@@ -351,7 +363,7 @@ def train_step(self, data):
             else:
                 print(f"Passing layer {idx+1} now : ")
                 x = layer(x)
-        mean_res = tf.math.divide(self.loss_var, self.loss_count)
+        mean_res = ops.divide(self.loss_var, self.loss_count)
         return {"FinalLoss": mean_res}


@@ -386,8 +398,8 @@ def train_step(self, data):
 model.compile(
     optimizer=keras.optimizers.Adam(learning_rate=0.03),
     loss="mse",
-    jit_compile=True,
-    metrics=[keras.metrics.Mean()],
+    jit_compile=False,
+    metrics=[],
 )

 epochs = 250
@@ -400,7 +412,7 @@ def train_step(self, data):
 test set. We calculate the Accuracy Score to understand the results closely.
 """

-preds = model.predict(tf.convert_to_tensor(x_test))
+preds = model.predict(ops.convert_to_tensor(x_test))

 preds = preds.reshape((preds.shape[0], preds.shape[1]))
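
Note (reviewer sketch, not part of the diff): a sketch of the evaluation step that follows, assuming `y_test` holds the integer test labels:

    from sklearn.metrics import accuracy_score

    # preds comes from model.predict above; flatten to match y_test.
    results = accuracy_score(y_test, preds.flatten())
    print(f"Test Accuracy score : {results * 100}%")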

Binary file modified examples/vision/img/forwardforward/forwardforward_15_1.png
Binary file modified examples/vision/img/forwardforward/forwardforward_5_1.png
