From f982128d7c7cca9c774468f3b4ae30827bb5b1a8 Mon Sep 17 00:00:00 2001 From: Hampus Linander Date: Sat, 6 Jan 2024 13:59:38 +0100 Subject: [PATCH] Fix error for new trainings in LORA --- lib/distributed_trainer.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/lib/distributed_trainer.py b/lib/distributed_trainer.py index dad9140..9e9976e 100644 --- a/lib/distributed_trainer.py +++ b/lib/distributed_trainer.py @@ -66,14 +66,15 @@ def distributed_train(requested_configs: List[TrainRun] = None): if distributed_train_run is not None: try: last_aquired_training = time.time() - if ( - get_serialization_epoch( - DeserializeConfig( - train_run=distributed_train_run.train_run, - device_id=device_id, - ) + serialized_epoch = get_serialization_epoch( + DeserializeConfig( + train_run=distributed_train_run.train_run, + device_id=device_id, ) - < distributed_train_run.train_run.epochs + ) + if ( + serialized_epoch is None + or serialized_epoch < distributed_train_run.train_run.epochs ): do_train_run(distributed_train_run, device_id) else: