Skip to content

Commit

Permalink
Add optional per-epoch serialization
Browse files Browse the repository at this point in the history
Enable with
  train_run.keep_epoch_checkpoints = True
  • Loading branch information
hlinander committed Jan 11, 2024
1 parent 6521c02 commit ad02a31
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 4 deletions.
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,9 @@ slurm_tmp/
locks/
experiments/lora_ensembles/hf_env.sh
distributed_training_requests/*
LLM
artifacts
rust/vis/target
rust/vis/flamegraph.svg
rust/vis/perf.data
rust/vis/perf.data.old
2 changes: 1 addition & 1 deletion experiments/lora_ensembles/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def print_sample_comparison(input_ids, logits, labels, tokenizer):


def create_inference_config(ensemble_id):
config = create_config(ensemble_id)
config = create_config(ensemble_id, epochs=1)
config.compute_config.distributed = False
config.compute_config.num_gpus = 1
return config
Expand Down
6 changes: 4 additions & 2 deletions experiments/lora_ensembles/lora_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def create_config(
lora_l2=0.1,
regular_l2=0,
target_modules=["q_proj", "v_proj"],
epochs=4,
):
train_config = TrainConfig(
model_config=LLaMA2GenerativeConfig(
Expand Down Expand Up @@ -91,11 +92,12 @@ def create_config(
data_visualizer=None,
)
train_run = TrainRun(
compute_config=ComputeConfig(distributed=True, num_workers=2, num_gpus=2),
compute_config=ComputeConfig(distributed=False, num_workers=2, num_gpus=1),
train_config=train_config,
train_eval=train_eval,
epochs=2,
epochs=epochs,
save_nth_epoch=1,
keep_epoch_checkpoints=True,
validate_nth_epoch=1,
)
return train_run
Expand Down
17 changes: 17 additions & 0 deletions experiments/lora_ensembles/lora_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from lib.serialization import DeserializeConfig, get_checkpoint_path, deserialize_model
from lib.train_dataclasses import TrainRun
from lib.paths import get_model_epoch_checkpoint_path


@dataclass
Expand All @@ -30,6 +31,22 @@ def deserialize_model_state_dict(config: DeserializeConfig):
)

model_epoch = torch.load(checkpoint_path / "epoch")

if model_epoch != config.train_run.epochs:
model_epoch_checkpoint = get_model_epoch_checkpoint_path(
config.train_run.train_config, config.train_run.epochs
)
if not (model_epoch_checkpoint).is_file():
raise Exception(
f"The requested epoch ({config.train_run.epochs}) is not available in {model_epoch_checkpoint}."
)
model_state_dict = torch.load(
model_epoch_checkpoint, map_location=torch.device(config.device_id)
)
print(
f"Loaded earlier epoch {config.train_run.epochs}, the latest epoch is {model_epoch}."
)

except Exception as e:
print(f"Failed to deserialize_model: {e}")
return None
Expand Down
6 changes: 6 additions & 0 deletions lib/paths.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pathlib import Path
from lib.stable_hash import stable_hash
from lib.compute_env import env
from lib.train_dataclasses import TrainConfig


def get_or_create_checkpoint_path(train_config) -> Path:
Expand All @@ -23,3 +24,8 @@ def get_checkpoint_path(train_config) -> Path:
checkpoint_dir.mkdir(exist_ok=True, parents=True)
checkpoint = checkpoint_dir / f"checkpoint_{config_hash}"
return checkpoint


def get_model_epoch_checkpoint_path(train_config: TrainConfig, epoch: int):
    """Return the path of the per-epoch model checkpoint for *train_config*.

    The file lives directly under the path returned by
    ``get_checkpoint_path(train_config)`` and is named ``model_epoch_NNNN``,
    where ``NNNN`` is *epoch* zero-padded to at least four digits.
    """
    return get_checkpoint_path(train_config) / f"model_epoch_{epoch:04d}"
16 changes: 15 additions & 1 deletion lib/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
import lib.data_factory as data_factory
from lib.metric import Metric
from lib.timing_metric import Timing
from lib.paths import get_checkpoint_path, get_or_create_checkpoint_path
from lib.paths import (
get_checkpoint_path,
get_or_create_checkpoint_path,
get_model_epoch_checkpoint_path,
)
import lib.model_factory as model_factory
from lib.data_utils import get_sampler
from lib.ddp import get_rank
Expand Down Expand Up @@ -81,6 +85,16 @@ def serialize(config: SerializeConfig):
)

checkpoint_path = get_or_create_checkpoint_path(train_config)

if config.train_run.keep_epoch_checkpoints:
model_epoch_checkpoint = get_model_epoch_checkpoint_path(
config.train_run.train_config, train_epoch_state.epoch
)
torch.save(model, f"{model_epoch_checkpoint}_tmp")
shutil.move(
checkpoint_path / f"{model_epoch_checkpoint}_tmp",
checkpoint_path / model_epoch_checkpoint,
)
for key, value in file_data.__dict__.items():
torch.save(value, checkpoint_path / f"{key}_tmp")
shutil.move(checkpoint_path / f"{key}_tmp", checkpoint_path / key)
Expand Down
1 change: 1 addition & 0 deletions lib/train_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ class TrainRun:
epochs: int
save_nth_epoch: int
validate_nth_epoch: int
keep_epoch_checkpoints: bool = False
visualize_terminal: bool = True

def serialize_human(self):
Expand Down

0 comments on commit ad02a31

Please sign in to comment.