From 44696fed63d110890bbcadc3e63ca9039d87020b Mon Sep 17 00:00:00 2001 From: Hampus Linander Date: Thu, 21 Dec 2023 19:52:40 +0100 Subject: [PATCH] Add clifford ablation Only do artifacts when they don't already exist in db --- experiments/clifford/clifford.py | 163 +++++++++++++++++-------------- lib/render_psql.py | 24 +++++ 2 files changed, 116 insertions(+), 71 deletions(-) diff --git a/experiments/clifford/clifford.py b/experiments/clifford/clifford.py index 19dadab..b6b3d9d 100644 --- a/experiments/clifford/clifford.py +++ b/experiments/clifford/clifford.py @@ -6,6 +6,7 @@ import json from typing import List import math +import matplotlib.pyplot as plt from lib.train_dataclasses import TrainConfig from lib.train_dataclasses import TrainRun @@ -15,22 +16,30 @@ from lib.regression_metrics import create_regression_metrics from lib.ddp import ddp_setup -from lib.ensemble import create_ensemble_config -from lib.ensemble import create_ensemble -from lib.ensemble import request_ensemble -from lib.ensemble import symlink_checkpoint_files + +# from lib.ensemble import create_ensemble_config +# from lib.ensemble import create_ensemble + +# from lib.ensemble import request_ensemble +# from lib.ensemble import symlink_checkpoint_files from lib.files import prepare_results -from lib.render_psql import add_artifact, add_parameter +from lib.render_psql import add_artifact, add_parameter, has_artifact from lib.serialization import serialize_human +from lib.generic_ablation import generic_ablation # from lib.data_factory import register_dataset, get_factory import lib.data_factory as data_factory import lib.model_factory as model_factory -from lib.models.mlp import MLPConfig + +# from lib.models.mlp import MLPConfig from dataclasses import dataclass from lib.dataspec import DataSpec from lib.data_utils import create_sample_legacy from lib.distributed_trainer import distributed_train +from lib.serialization import ( + deserialize_model, + DeserializeConfig, +) class CliffordLinear(torch.nn.Module): @@ -217,15 +226,15 @@ def __len__(self): return 256 * 10 -def create_config(ensemble_id): +def create_config(clifford_width, clifford_depth, ensemble_id): loss = torch.nn.L1Loss() def reg_loss(outputs, batch): return loss(outputs["logits"], batch["target"]) train_config = TrainConfig( - model_config=MLPConfig(widths=[512, 512, 512]), - # model_config=CliffordModelConfig(widths=[256, 256]), + # model_config=MLPConfig(widths=[512, 512, 512]), + model_config=CliffordModelConfig(widths=[clifford_width] * clifford_depth), train_data_config=DataCableConfig(npoints=100), val_data_config=DataCableConfig(npoints=100), loss=reg_loss, @@ -242,9 +251,10 @@ def reg_loss(outputs, batch): compute_config=ComputeConfig(distributed=False, num_workers=10), train_config=train_config, train_eval=train_eval, - epochs=200, + epochs=50, save_nth_epoch=1, validate_nth_epoch=5, + visualize_terminal=False, ) return train_run @@ -270,67 +280,78 @@ def reg_loss(outputs, batch): mf = model_factory.get_factory() mf.register(CliffordModelConfig, CliffordModel) - ensemble_config = create_ensemble_config(create_config, 1) - request_ensemble(ensemble_config) - distributed_train(ensemble_config.members) + configs = generic_ablation( + create_config, + dict( + clifford_depth=[1, 2, 3], + clifford_width=[32, 64, 128, 256, 512], + ensemble_id=[0], + ), + ) + distributed_train(configs) + # ensemble_config = create_ensemble_config(create_config, 1) + # request_ensemble(ensemble_config) + # distributed_train(ensemble_config.members) device_id = 
ddp_setup() - ensemble = create_ensemble(ensemble_config, device_id) - n_params = sum(p.numel() for p in ensemble.members[0].parameters()) - add_parameter(ensemble_config.members[0], "parameters", f"{n_params}") - - ds = data_factory.get_factory().create(DataCableConfig(npoints=100)) - dl = torch.utils.data.DataLoader( - ds, - batch_size=9, - shuffle=False, - drop_last=False, - ) + # ensemble = create_ensemble(ensemble_config, device_id) + for config in configs: + if has_artifact(config, "clifford_rope_test"): + continue + deserialized_model = deserialize_model(DeserializeConfig(config, device_id)) + model = deserialized_model.model + n_params = sum(p.numel() for p in model.parameters()) + add_parameter(config, "parameters", f"{n_params}") + + ds = data_factory.get_factory().create(DataCableConfig(npoints=100)) + dl = torch.utils.data.DataLoader( + ds, + batch_size=9, + shuffle=False, + drop_last=False, + ) - result_path = prepare_results( - f"{Path(__file__).stem}", - ensemble_config, - ) - symlink_checkpoint_files(ensemble, result_path) - - import matplotlib.pyplot as plt - - np.random.seed(42) - for batch in tqdm.tqdm(dl): - ys = batch["target"] - xs = batch["input"] - - batch = {k: v.to(device_id) for k, v in batch.items()} - output = ensemble.members[0](batch)["logits"].cpu().detach().numpy() - fig, axs = plt.subplots(3, 3, figsize=(10, 10)) - for idx, (start_deltas, delta, target) in enumerate( - zip(xs.cpu().numpy(), output, ys.numpy()) - ): - start = np.zeros_like(delta) - # breakpoint() - for i in range(start_deltas.shape[0]): - start[i + 1] = start[i] + start_deltas[i] - # delta = np.concatenate([[[0, 0]], delta], axis=0) - ax = axs[idx // 3, idx % 3] - ax.plot(start[:, 0], start[:, 1], "g--", label="initial", alpha=0.2) - ax.plot( - start[:, 0] + delta[:, 0], - start[:, 1] + delta[:, 1], - "r-", - label=ensemble_config.members[ - 0 - ].train_config.model_config.__class__.__name__, - ) - ax.plot( - start[:, 0] + target[:, 0], - start[:, 1] + target[:, 1], - "b-", - label="target", - ) - ax.legend() - ax.set_title(f"{length_cable(start + target)}, {99 * 0.05}") - fig.suptitle("80 iteration constraint resolve") - path = result_path / "clifford_test.png" - fig.savefig(path) - add_artifact(ensemble_config.members[0], "clifford_rope_test", path) - raise Exception("exit") + result_path = prepare_results( + f"{Path(__file__).stem}", + config, + ) + # symlink_checkpoint_files(ensemble, result_path) + + np.random.seed(42) + for batch in tqdm.tqdm(dl): + ys = batch["target"] + xs = batch["input"] + + batch = {k: v.to(device_id) for k, v in batch.items()} + output = model(batch)["logits"].cpu().detach().numpy() + fig, axs = plt.subplots(3, 3, figsize=(10, 10)) + for idx, (start_deltas, delta, target) in enumerate( + zip(xs.cpu().numpy(), output, ys.numpy()) + ): + start = np.zeros_like(delta) + # breakpoint() + for i in range(start_deltas.shape[0]): + start[i + 1] = start[i] + start_deltas[i] + # delta = np.concatenate([[[0, 0]], delta], axis=0) + ax = axs[idx // 3, idx % 3] + ax.plot(start[:, 0], start[:, 1], "g--", label="initial", alpha=0.2) + ax.plot( + start[:, 0] + delta[:, 0], + start[:, 1] + delta[:, 1], + "r-", + label=config.train_config.model_config.__class__.__name__, + ) + ax.plot( + start[:, 0] + target[:, 0], + start[:, 1] + target[:, 1], + "b-", + label="target", + ) + ax.legend() + ax.set_title(f"{length_cable(start + target)}, {99 * 0.05}") + fig.suptitle("80 iteration constraint resolve") + path = result_path / "clifford_test.png" + fig.savefig(path) + 
add_artifact(config, "clifford_rope_test", path)
+            break
+    # raise Exception("exit")
diff --git a/lib/render_psql.py b/lib/render_psql.py
index 90f784e..b5d8293 100644
--- a/lib/render_psql.py
+++ b/lib/render_psql.py
@@ -443,6 +443,30 @@ def add_artifact(train_run: TrainRun, name: str, path: Union[str, Path]):
         print(f"[Database] Added artifact {name}: {path}")
 
 
+def has_artifact(train_run: TrainRun, name: str):
+    # Returns a plain bool so callers can use it directly in a truthiness check.
+    try:
+        setup_psql()
+    except psycopg.errors.OperationalError as e:
+        print("[Database] Could not connect to database, assuming artifact does not exist.")
+        return False
+
+    train_run_dict = train_run.serialize_human()
+    with psycopg.connect(
+        "dbname=equiv user=postgres password=postgres",
+        host=os.getenv("EP_POSTGRES", env().postgres_host),
+        port=int(os.getenv("EP_POSTGRES_PORT", env().postgres_port)),
+        autocommit=False,
+    ) as conn:
+        rows = conn.execute(
+            """
+            SELECT * FROM artifacts WHERE train_id=%(train_id)s AND name=%(name)s
+            """,
+            dict(train_id=train_run_dict["train_id"], name=name),
+        )
+        return rows.fetchone() is not None
+
+
 def add_ensemble_artifact(
     ensemble_config: EnsembleConfig, name: str, path: Union[str, Path]
 ):
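
The ablation grid above is what the commit title refers to: assuming generic_ablation takes the cartesian product of the supplied value lists and calls create_config once per combination, the run expands to 3 depths x 5 widths x 1 ensemble id = 15 train runs, and has_artifact lets a re-run skip any combination whose "clifford_rope_test" artifact already exists in the database. A minimal, self-contained sketch of that assumed expansion (standard library only, none of the repository helpers):

from itertools import product

# Same grid as passed to generic_ablation in the patch.
grid = dict(
    clifford_depth=[1, 2, 3],
    clifford_width=[32, 64, 128, 256, 512],
    ensemble_id=[0],
)

# Assumed expansion: cartesian product over the value lists, one kwargs dict
# per combination (dicts preserve key order in Python 3.7+).
combos = [dict(zip(grid, values)) for values in product(*grid.values())]

print(len(combos))  # 15
print(combos[0])    # {'clifford_depth': 1, 'clifford_width': 32, 'ensemble_id': 0}

Each of those dicts is presumably passed as keyword arguments to create_config(clifford_width, clifford_depth, ensemble_id) to build one TrainRun, which is then trained by distributed_train and plotted only if its artifact is not yet in the database.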