From db0e2f7670b3dc376221f5e7dd4f093d1ceecb47 Mon Sep 17 00:00:00 2001
From: yallup
Date: Wed, 10 Jul 2024 15:34:06 +0100
Subject: [PATCH] hyperparams for ex

---
 clax/clax.py              |  8 +++---
 clax/network.py           |  1 -
 examples/bayes_factors.py | 57 ++++++++++++++++++++++++++-------------
 3 files changed, 41 insertions(+), 25 deletions(-)

diff --git a/clax/clax.py b/clax/clax.py
index aba93e9..41b3378 100644
--- a/clax/clax.py
+++ b/clax/clax.py
@@ -116,11 +116,9 @@ def _init_state(self, **kwargs):
         if not optimizer:
             optimizer = optax.chain(
                 # optax.clip_by_global_norm(1.0),
-                # optax.adaptive_grad_clip(0.1),
-                optax.adaptive_grad_clip(1.0),
-                # optax.adam(lr),
-                # optax.adamw(self.schedule),
-                optax.adamw(lr),
+                optax.adaptive_grad_clip(0.01),
+                optax.adamw(self.schedule),
+                # optax.adamw(lr),
             )
 
         # self.state = train_state.TrainState.create(
diff --git a/clax/network.py b/clax/network.py
index 6a4504e..141efb7 100644
--- a/clax/network.py
+++ b/clax/network.py
@@ -33,7 +33,6 @@ class Network(nn.Module):
     @nn.compact
     def __call__(self, x, train: bool):
         x = nn.Dense(self.n_initial)(x)
-        # nn.BatchNorm(use_running_average=not train)(x)
         x = nn.BatchNorm(use_running_average=not train)(x)
         x = nn.silu(x)
         for i in range(self.n_layers):
diff --git a/examples/bayes_factors.py b/examples/bayes_factors.py
index 1f2be27..7e899b9 100644
--- a/examples/bayes_factors.py
+++ b/examples/bayes_factors.py
@@ -7,19 +7,18 @@
 import matplotlib.pyplot as plt
 import numpy as np
-import optax
+from flax import linen as nn
 from mpl_toolkits.axes_grid1.inset_locator import mark_inset, zoomed_inset_axes
 from scipy.stats import multivariate_normal
-from sklearn.datasets import make_sparse_spd_matrix
 from sklearn.model_selection import train_test_split
 
 from clax import Classifier
 
 # from clax.network import Network
 
-np.random.seed(2025)
+np.random.seed(2024)
 
 dim = 100
-n_sample = 100000
+n_sample = 500000
 
 c1 = np.random.rand(dim) - 0.5
@@ -31,29 +30,49 @@
 midpoint = (m1 + m2) / 2
 error = 0.025
 
-C1 = make_sparse_spd_matrix(dim, norm_diag=True, smallest_coef=0.01, largest_coef=0.25)
-C2 = make_sparse_spd_matrix(dim, norm_diag=True, smallest_coef=0.01, largest_coef=0.25)
-
 M_0 = multivariate_normal(mean=m1, cov=np.eye(dim) * error)
 M_1 = multivariate_normal(mean=m2, cov=np.eye(dim) * error)
 M_2 = multivariate_normal(mean=midpoint, cov=np.eye(dim) * error)
-
 X = np.concatenate((M_0.rvs(n_sample), M_1.rvs(n_sample)))
 y = np.concatenate((np.zeros(n_sample), np.ones(n_sample)))
 
 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.01)
-
-# # Arg is the number classes
 classifier = Classifier()
 
-chain = optax.chain(
-    optax.adaptive_grad_clip(1.0),
-    optax.adamw(1e-3),
-)
+# optionally specify the optimizer manually
+# chain = optax.chain(
+#     optax.adaptive_grad_clip(1.0),
+#     optax.adamw(1e-3),
+# )
+
+
+class Network(nn.Module):
+    """A simple MLP classifier."""
 
-classifier.fit(X_train, y_train, epochs=500, optimizer=chain, batch_size=1000)
+    n_initial: int = 256
+    n_hidden: int = 64
+    n_layers: int = 3
+    n_out: int = 1
+    # act = nn.silu
+
+    @nn.compact
+    def __call__(self, x, train: bool):
+        x = nn.Dense(self.n_initial)(x)
+        # hacky way to make batchnorm have no impact
+        nn.BatchNorm(use_running_average=not train)(x)
+        x = nn.silu(x)
+        for i in range(self.n_layers):
+            x = nn.Dense(self.n_hidden)(x)
+            x = nn.silu(x)
+        x = nn.Dense(self.n_out)(x)
+        return x
+
+
+lr = 1e-4
+classifier.network = Network(n_out=1, n_initial=1056, n_hidden=128, n_layers=3)
+classifier.fit(X_train, y_train, epochs=100, lr=lr, batch_size=10000)
 
 true_k = M_1.logpdf(X_test) - M_0.logpdf(X_test)
 network_k = classifier.predict(X_test).squeeze()
@@ -65,12 +84,12 @@ def plot():
-    f, a = plt.subplots(1, 1)
+    f, a = plt.subplots(1, 1, figsize=(6, 4))
     a.scatter(
         true_k_m2,
         network_k_m2,
         alpha=0.5,
-        c="C4",
+        c="C1",
         label=r"$M_2$ test",
         marker=".",
         rasterized=True,
@@ -172,7 +191,7 @@ def plot():
     a.set_xlabel(r"True $\ln K$")
     a.set_ylabel(r"Network $\ln K$")
     f.tight_layout()
-    f.savefig("en_metal.pdf")
+    f.savefig("en.pdf")
 
 
 plot()
@@ -180,4 +199,4 @@
 f, a = plt.subplots()
 a.plot(classifier.trace.losses)
 a.set_yscale("log")
-f.savefig("losses_metal.pdf")
+f.savefig("losses.pdf")