Commit: conditional_classifier
yallup committed Jun 10, 2024
1 parent 14ab48d commit 3be37bf
Showing 6 changed files with 114 additions and 23 deletions.
2 changes: 1 addition & 1 deletion clax/__init__.py
@@ -1 +1 @@
- from clax.clax import Classifier, Regressor
+ from clax.clax import Classifier, ConditionalClassifier, Regressor
92 changes: 73 additions & 19 deletions clax/clax.py
@@ -40,7 +40,7 @@ def __init__(self, n=1, **kwargs):
self.network = Network(n_out=n)
self.state = None

- def loss(self, params, batch_stats, batch, labels):
+ def loss(self, params, batch_stats, batch, labels, rng):
"""Loss function for training the classifier."""
output, updates = self.state.apply_fn(
{"params": params, "batch_stats": batch_stats},
@@ -57,16 +57,14 @@ def loss(self, params, batch_stats, batch, labels):
def _train(self, samples, labels, batches_per_epoch, **kwargs):
"""Internal wrapping of training loop."""
self.trace = Trace()
- labels = jnp.array(labels, dtype=int)
- samples = jnp.array(samples, dtype=jnp.float32)
batch_size = kwargs.get("batch_size", 1024)
epochs = kwargs.get("epochs", 10)
epochs *= batches_per_epoch

@jit
- def update_step(state, samples, labels):
+ def update_step(state, samples, labels, rng):
(val, updates), grads = jax.value_and_grad(self.loss, has_aux=True)(
- state.params, state.batch_stats, samples, labels
+ state.params, state.batch_stats, samples, labels, rng
)
state = state.apply_gradients(grads=grads) # , scale_value=val)
state = state.replace(batch_stats=updates["batch_stats"])
@@ -80,15 +78,15 @@ def update_step(state, samples, labels):
for k in tepochs:
self.rng, step_rng = random.split(self.rng)
perm, _ = map.sample(batch_size)
- batch = samples[perm, :]
+ batch = samples[perm]
batch_label = labels[perm]
- loss, self.state = update_step(self.state, batch, batch_label)
+ loss, self.state = update_step(self.state, batch, batch_label, step_rng)
losses.append(loss)
# self.state.losses.append(loss)
if (k + 1) % 50 == 0:
ma = jnp.mean(jnp.array(losses[-50:]))
self.trace.losses.append(ma)
- tepochs.set_postfix(loss=ma)
+ tepochs.set_postfix(loss="{:.2e}".format(ma))
self.trace.iteration += 1
# lr_scale = otu.tree_get(self.state, "scale")
# self.trace.lr.append(lr_scale)
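Note on the new rng plumbing: each iteration now splits a fresh step_rng from self.rng and threads it through update_step into the loss. Neither loss in this commit consumes the key yet; it reads as a hook for stochastic loss terms. A minimal sketch of the splitting pattern, assuming a trainer that holds a jax.random key (names here are illustrative):

import jax.random as random

rng = random.PRNGKey(0)
num_steps = 100
for step in range(num_steps):
    # Split off a fresh subkey per update so randomness never repeats.
    rng, step_rng = random.split(rng)
    # step_rng would be handed to the loss for any stochastic terms.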
@@ -147,6 +145,8 @@ def fit(self, samples, labels, **kwargs):
self.ndims = samples.shape[-1]
if (not self.state) | restart:
self._init_state(**kwargs)
+ labels = jnp.array(labels, dtype=int)
+ samples = jnp.array(samples, dtype=jnp.float32)
self._train(samples, labels, batches_per_epoch, **kwargs)
self._predict_weight = lambda x: self.state.apply_fn(
{
@@ -157,23 +157,77 @@
train=False,
)

- def predict(self, samples, log=True):
+ def predict(self, samples):
"""Predict the class (log-)probabilities for the provided samples.
Args:
samples (np.ndarray): Samples to predict on.
- log (bool): If True, return the log-probabilities. Defaults to True.
"""
+ return self._predict_weight(samples)


class ConditionalClassifier(Classifier):
def loss(self, params, batch_stats, batch, labels, rng):
"""Loss function for training the classifier."""

batch = jnp.concatenate([batch, labels], axis=0)
labels = jnp.concatenate(
[jnp.ones(batch.shape[0] // 2), jnp.zeros(batch.shape[0] // 2)]
)

output, updates = self.state.apply_fn(
{"params": params, "batch_stats": batch_stats},
batch,
train=True,
mutable=["batch_stats"],
)
# loss = optax.softmax_cross_entropy_with_integer_labels(
# output.squeeze(), labels
# ).mean()
loss = optax.sigmoid_binary_cross_entropy(output.squeeze(), labels).mean()
return loss, updates

def fit(self, samples_a, samples_b, **kwargs):
"""Fit the classifier on provided samples.
Args:
samples_a (np.ndarray): Samples drawn from the first distribution (labelled 1).
samples_b (np.ndarray): Samples drawn from the second distribution (labelled 0).
Keyword Args:
restart (bool): If True, reinitialise the model before training. Defaults to False.
batch_size (int): Size of the training batches. Defaults to 1024.
epochs (int): Number of training epochs. Defaults to 10.
lr (float): Learning rate. Defaults to 1e-2.
transition_steps (int): Number of steps to transition the learning rate.
Defaults to 100.
"""
restart = kwargs.get("restart", False)
batch_size = kwargs.get("batch_size", 1024)
self.ndims = kwargs.get("ndims", samples_a.shape[-1])
data_size = samples_a.shape[0]
batches_per_epoch = data_size // batch_size
if (not self.state) | restart:
self._init_state(**kwargs)
self._train(samples_a, samples_b, batches_per_epoch, **kwargs)
self._predict_weight = lambda x: self.state.apply_fn(
{
"params": self.state.params,
"batch_stats": self.state.batch_stats,
},
x,
train=False,
)

def predict(self, samples):
"""Predict the classifier logits for the provided samples.
Args:
samples (np.ndarray): Samples to predict on.
"""
- if log:
- return self._predict_weight(samples)
- else:
- return nn.softmax(self._predict_weight(samples))
+ return self._predict_weight(samples)

def log_score(self, samples):
- logits = self._predict_weight(samples)
- log_p = jax.nn.log_sigmoid(logits)
- log_not_p = jax.nn.log_sigmoid(-logits)
- return log_p - (1 - log_not_p)
+ return self._predict_weight(samples)


class Regressor(Classifier):
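What the new class is doing: ConditionalClassifier.loss stacks samples_a and samples_b into one batch, labels the first half 1 and the second half 0, and trains the network as a binary discriminator with sigmoid_binary_cross_entropy. By the standard density-ratio argument, the optimal discriminator's logit at a point x is log p_a(x) - log p_b(x), which is why log_score now returns the raw network output. A minimal usage sketch under that reading (the toy data and the sigmoid post-processing are illustrative, not part of the commit):

import numpy as np
import jax

from clax import ConditionalClassifier

rng = np.random.default_rng(0)
samples_a = rng.normal(loc=0.0, size=(1024, 10))  # draws from p_a
samples_b = rng.normal(loc=0.5, size=(1024, 10))  # draws from p_b

classifier = ConditionalClassifier()
classifier.fit(samples_a, samples_b, epochs=10)

# Raw logit: an estimate of log p_a(x) / p_b(x).
log_ratio = classifier.log_score(samples_a)

# Probability that a point was drawn from p_a rather than p_b.
prob_a = jax.nn.sigmoid(classifier.predict(samples_a))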
4 changes: 2 additions & 2 deletions clax/network.py
@@ -121,11 +121,11 @@ class ConditionalNetwork(nn.Module):

@nn.compact
def __call__(self, x, y):
- x = nn.concatenate([x, y])
+ x = nn.concatenate([x, y], axis=-1)
x = nn.Dense(self.n_initial)(x)
x = nn.silu(x)
for i in range(self.n_layers):
- x = jnp.concatenate([x, y])
+ # x = jnp.concatenate([x, y])
x = nn.Dense(self.n_hidden)(x)
x = nn.silu(x)
x = nn.Dense(self.n_out)(x)
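Why the axis=-1 change matters: for batched inputs, the default axis=0 tries to stack the condition below the features along the batch dimension, which fails outright when the feature widths differ and silently doubles the batch when they match. Concatenating along the last axis appends the condition to each sample's feature vector. A small shape check (shapes are illustrative):

import jax.numpy as jnp

x = jnp.ones((32, 10))  # batch of features
y = jnp.ones((32, 5))   # batch of conditioning variables

print(jnp.concatenate([x, y], axis=-1).shape)  # (32, 15): condition appended per sample
# jnp.concatenate([x, y]) raises here: axis-0 concat needs matching trailing dims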
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "clax"
version = "0.0.2"
version = "0.0.3"
description = "Prebuilt jax classifiers"
authors = ["David Yallup <[email protected]>"]
readme = "README.md"
Empty file added tests/___init__.py
Empty file.
37 changes: 37 additions & 0 deletions tests/test_class.py
@@ -0,0 +1,37 @@
import jax.numpy as jnp
import numpy as np
import pytest
from jax import nn

from clax import Classifier, ConditionalClassifier, Regressor

# @pytest.mark.parametrize("n_classes", [2, 10])
# class TestClassifier:
# @pytest.fixture
# def classifier(self, n_classes):
# return Classifier(n_classes)

# def test_fit(self, classifier, n_classes):
# data_x = np.random.rand(100, 10)
# data_y = np.random.randint(0, n_classes, 100)
# classifier.fit(data_x, data_y)


@pytest.mark.parametrize("n_classes", [2, 10])
def test_classifier(n_classes):
classifier = Classifier(n_classes)
data_x = np.random.rand(100, 10)
data_y = np.random.randint(0, n_classes, 100)
classifier.fit(data_x, data_y)
y = classifier.predict(data_x)
assert y.shape == (100, n_classes)
assert np.isclose(nn.softmax(y).sum(axis=-1), 1).all()


def test_conditional_classifier():
classifier = ConditionalClassifier()
data_x = np.random.rand(100, 10)
data_y = np.random.rand(100, 10)
classifier.fit(data_x, data_y)
y = classifier.predict(data_x)
assert y.shape == (100, 1)
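
A natural follow-up test for the new log_score (hypothetical, not part of this commit) would exercise the density-ratio path end to end:

def test_conditional_log_score():
    classifier = ConditionalClassifier()
    data_x = np.random.rand(100, 10)
    data_y = np.random.rand(100, 10)
    classifier.fit(data_x, data_y)
    score = classifier.log_score(data_x)
    assert score.shape == (100, 1)
    assert jnp.isfinite(score).all()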
