
Commit 44696fe
Add clifford ablation
Only do artifacts when they don't already exist in db
hlinander committed Dec 21, 2023
1 parent 6df56ed commit 44696fe
Showing 2 changed files with 116 additions and 71 deletions.
163 changes: 92 additions & 71 deletions experiments/clifford/clifford.py
@@ -6,6 +6,7 @@
import json
from typing import List
import math
import matplotlib.pyplot as plt

from lib.train_dataclasses import TrainConfig
from lib.train_dataclasses import TrainRun
@@ -15,22 +16,30 @@
from lib.regression_metrics import create_regression_metrics

from lib.ddp import ddp_setup
from lib.ensemble import create_ensemble_config
from lib.ensemble import create_ensemble
from lib.ensemble import request_ensemble
from lib.ensemble import symlink_checkpoint_files

# from lib.ensemble import create_ensemble_config
# from lib.ensemble import create_ensemble

# from lib.ensemble import request_ensemble
# from lib.ensemble import symlink_checkpoint_files
from lib.files import prepare_results
from lib.render_psql import add_artifact, add_parameter
from lib.render_psql import add_artifact, add_parameter, has_artifact
from lib.serialization import serialize_human
from lib.generic_ablation import generic_ablation

# from lib.data_factory import register_dataset, get_factory
import lib.data_factory as data_factory
import lib.model_factory as model_factory
from lib.models.mlp import MLPConfig

# from lib.models.mlp import MLPConfig
from dataclasses import dataclass
from lib.dataspec import DataSpec
from lib.data_utils import create_sample_legacy
from lib.distributed_trainer import distributed_train
from lib.serialization import (
deserialize_model,
DeserializeConfig,
)


class CliffordLinear(torch.nn.Module):
@@ -217,15 +226,15 @@ def __len__(self):
return 256 * 10


def create_config(ensemble_id):
def create_config(clifford_width, clifford_depth, ensemble_id):
loss = torch.nn.L1Loss()

def reg_loss(outputs, batch):
return loss(outputs["logits"], batch["target"])

train_config = TrainConfig(
model_config=MLPConfig(widths=[512, 512, 512]),
# model_config=CliffordModelConfig(widths=[256, 256]),
# model_config=MLPConfig(widths=[512, 512, 512]),
model_config=CliffordModelConfig(widths=[clifford_width] * clifford_depth),
train_data_config=DataCableConfig(npoints=100),
val_data_config=DataCableConfig(npoints=100),
loss=reg_loss,
@@ -242,9 +251,10 @@ def reg_loss(outputs, batch):
compute_config=ComputeConfig(distributed=False, num_workers=10),
train_config=train_config,
train_eval=train_eval,
epochs=200,
epochs=50,
save_nth_epoch=1,
validate_nth_epoch=5,
visualize_terminal=False,
)
return train_run

@@ -270,67 +280,78 @@

mf = model_factory.get_factory()
mf.register(CliffordModelConfig, CliffordModel)
ensemble_config = create_ensemble_config(create_config, 1)
request_ensemble(ensemble_config)
distributed_train(ensemble_config.members)
configs = generic_ablation(
create_config,
dict(
clifford_depth=[1, 2, 3],
clifford_width=[32, 64, 128, 256, 512],
ensemble_id=[0],
),
)
distributed_train(configs)
# ensemble_config = create_ensemble_config(create_config, 1)
# request_ensemble(ensemble_config)
# distributed_train(ensemble_config.members)

device_id = ddp_setup()
ensemble = create_ensemble(ensemble_config, device_id)
n_params = sum(p.numel() for p in ensemble.members[0].parameters())
add_parameter(ensemble_config.members[0], "parameters", f"{n_params}")

ds = data_factory.get_factory().create(DataCableConfig(npoints=100))
dl = torch.utils.data.DataLoader(
ds,
batch_size=9,
shuffle=False,
drop_last=False,
)
# ensemble = create_ensemble(ensemble_config, device_id)
for config in configs:
if has_artifact(config, "clifford_rope_test"):
continue
deserialized_model = deserialize_model(DeserializeConfig(config, device_id))
model = deserialized_model.model
n_params = sum(p.numel() for p in model.parameters())
add_parameter(config, "parameters", f"{n_params}")

ds = data_factory.get_factory().create(DataCableConfig(npoints=100))
dl = torch.utils.data.DataLoader(
ds,
batch_size=9,
shuffle=False,
drop_last=False,
)

result_path = prepare_results(
f"{Path(__file__).stem}",
ensemble_config,
)
symlink_checkpoint_files(ensemble, result_path)

import matplotlib.pyplot as plt

np.random.seed(42)
for batch in tqdm.tqdm(dl):
ys = batch["target"]
xs = batch["input"]

batch = {k: v.to(device_id) for k, v in batch.items()}
output = ensemble.members[0](batch)["logits"].cpu().detach().numpy()
fig, axs = plt.subplots(3, 3, figsize=(10, 10))
for idx, (start_deltas, delta, target) in enumerate(
zip(xs.cpu().numpy(), output, ys.numpy())
):
start = np.zeros_like(delta)
# breakpoint()
for i in range(start_deltas.shape[0]):
start[i + 1] = start[i] + start_deltas[i]
# delta = np.concatenate([[[0, 0]], delta], axis=0)
ax = axs[idx // 3, idx % 3]
ax.plot(start[:, 0], start[:, 1], "g--", label="initial", alpha=0.2)
ax.plot(
start[:, 0] + delta[:, 0],
start[:, 1] + delta[:, 1],
"r-",
label=ensemble_config.members[
0
].train_config.model_config.__class__.__name__,
)
ax.plot(
start[:, 0] + target[:, 0],
start[:, 1] + target[:, 1],
"b-",
label="target",
)
ax.legend()
ax.set_title(f"{length_cable(start + target)}, {99 * 0.05}")
fig.suptitle("80 iteration constraint resolve")
path = result_path / "clifford_test.png"
fig.savefig(path)
add_artifact(ensemble_config.members[0], "clifford_rope_test", path)
raise Exception("exit")
result_path = prepare_results(
f"{Path(__file__).stem}",
config,
)
# symlink_checkpoint_files(ensemble, result_path)

np.random.seed(42)
for batch in tqdm.tqdm(dl):
ys = batch["target"]
xs = batch["input"]

batch = {k: v.to(device_id) for k, v in batch.items()}
output = model(batch)["logits"].cpu().detach().numpy()
fig, axs = plt.subplots(3, 3, figsize=(10, 10))
for idx, (start_deltas, delta, target) in enumerate(
zip(xs.cpu().numpy(), output, ys.numpy())
):
start = np.zeros_like(delta)
# breakpoint()
for i in range(start_deltas.shape[0]):
start[i + 1] = start[i] + start_deltas[i]
# delta = np.concatenate([[[0, 0]], delta], axis=0)
ax = axs[idx // 3, idx % 3]
ax.plot(start[:, 0], start[:, 1], "g--", label="initial", alpha=0.2)
ax.plot(
start[:, 0] + delta[:, 0],
start[:, 1] + delta[:, 1],
"r-",
label=config.train_config.model_config.__class__.__name__,
)
ax.plot(
start[:, 0] + target[:, 0],
start[:, 1] + target[:, 1],
"b-",
label="target",
)
ax.legend()
ax.set_title(f"{length_cable(start + target)}, {99 * 0.05}")
fig.suptitle("80 iteration constraint resolve")
path = result_path / "clifford_test.png"
fig.savefig(path)
add_artifact(config, "clifford_rope_test", path)
break
# raise Exception("exit")
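
The ablation driver above leaves the actual expansion to lib.generic_ablation.generic_ablation, which is not shown in this diff. A minimal sketch of the assumed behavior, expanding the keyword lists into the Cartesian product of combinations and calling create_config once per combination (expand_ablation is an illustrative name, not the real API):

import itertools
from typing import Any, Callable, Dict, List


def expand_ablation(
    create_config: Callable[..., Any], grid: Dict[str, List[Any]]
) -> List[Any]:
    # Expand e.g. {"clifford_depth": [1, 2, 3], "clifford_width": [32, ...]}
    # into one config per combination of values.
    names = list(grid.keys())
    return [
        create_config(**dict(zip(names, values)))
        for values in itertools.product(*(grid[name] for name in names))
    ]

With clifford_depth=[1, 2, 3], clifford_width=[32, 64, 128, 256, 512] and ensemble_id=[0], this would yield 15 train runs, matching the list handed to distributed_train.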
24 changes: 24 additions & 0 deletions lib/render_psql.py
@@ -443,6 +443,30 @@ def add_artifact(train_run: TrainRun, name: str, path: Union[str, Path]):
print(f"[Database] Added artifact {name}: {path}")


def has_artifact(train_run: TrainRun, name: str):
# train_run_dict = train_run.serialize_human()
try:
setup_psql()
except psycopg.errors.OperationalError:
print("[Database] Could not connect to database, artifact check skipped.")
return False

train_run_dict = train_run.serialize_human()
with psycopg.connect(
"dbname=equiv user=postgres password=postgres",
host=os.getenv("EP_POSTGRES", env().postgres_host),
port=int(os.getenv("EP_POSTGRES_PORT", env().postgres_port)),
autocommit=False,
) as conn:
rows = conn.execute(
"""
SELECT * FROM artifacts WHERE train_id=%(train_id)s AND name=%(name)s
""",
dict(train_id=train_run_dict["train_id"], name=name),
)
return rows.fetchone() is not None


def add_ensemble_artifact(
ensemble_config: EnsembleConfig, name: str, path: Union[str, Path]
):
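
Together with add_artifact, the new has_artifact check supports the idempotent pattern described in the commit message: only produce and register an artifact when it is not already recorded in the database. A minimal sketch of that pattern (ensure_artifact and build_artifact are illustrative names, not part of lib.render_psql):

def ensure_artifact(train_run, name, build_artifact):
    # Skip regeneration if the artifact is already recorded for this train run.
    if has_artifact(train_run, name):
        return
    # build_artifact is a hypothetical callable that writes the file and returns its path.
    path = build_artifact(train_run)
    add_artifact(train_run, name, path)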
