-
-
Notifications
You must be signed in to change notification settings - Fork 16.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add EarlyStopping feature * Add comment * Cleanup * Cleanup2 * debug * debug2 * debug3 * debug3 * debug4 * debug5 * debug6 * debug7 * debug8 * debug9 * debug10 * debug11 * debug12 * Cleanup * Add TODO for known DDP issue
- Loading branch information
1 parent
8b18b66
commit 93cc015
Showing
2 changed files
with
35 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -40,7 +40,8 @@ | |
from utils.downloads import attempt_download | ||
from utils.loss import ComputeLoss | ||
from utils.plots import plot_labels, plot_evolve | ||
from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, de_parallel | ||
from utils.torch_utils import EarlyStopping, ModelEMA, de_parallel, intersect_dicts, select_device, \ | ||
torch_distributed_zero_first | ||
from utils.loggers.wandb.wandb_utils import check_wandb_resume | ||
from utils.metrics import fitness | ||
from utils.loggers import Loggers | ||
|
@@ -255,6 +256,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary | |
results = (0, 0, 0, 0, 0, 0, 0) # P, R, [email protected], [email protected], val_loss(box, obj, cls) | ||
scheduler.last_epoch = start_epoch - 1 # do not move | ||
scaler = amp.GradScaler(enabled=cuda) | ||
stopper = EarlyStopping(patience=opt.patience) | ||
compute_loss = ComputeLoss(model) # init loss class | ||
LOGGER.info(f'Image sizes {imgsz} train, {imgsz} val\n' | ||
f'Using {train_loader.num_workers} dataloader workers\n' | ||
|
@@ -389,6 +391,20 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary | |
del ckpt | ||
callbacks.on_model_save(last, epoch, final_epoch, best_fitness, fi) | ||
|
||
# Stop Single-GPU | ||
if stopper(epoch=epoch, fitness=fi): | ||
break | ||
|
||
# Stop DDP TODO: known issues shttps://github.com/ultralytics/yolov5/pull/4576 | ||
# stop = stopper(epoch=epoch, fitness=fi) | ||
# if RANK == 0: | ||
# dist.broadcast_object_list([stop], 0) # broadcast 'stop' to all ranks | ||
|
||
# Stop DPP | ||
# with torch_distributed_zero_first(RANK): | ||
# if stop: | ||
# break # must break all DDP ranks | ||
|
||
# end epoch ---------------------------------------------------------------------------------------------------- | ||
# end training ----------------------------------------------------------------------------------------------------- | ||
if RANK in [-1, 0]: | ||
|
@@ -454,6 +470,7 @@ def parse_opt(known=False): | |
parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used') | ||
parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify') | ||
parser.add_argument('--freeze', type=int, default=0, help='Number of layers to freeze. backbone=10, all=24') | ||
parser.add_argument('--patience', type=int, default=30, help='EarlyStopping patience (epochs)') | ||
opt = parser.parse_known_args()[0] if known else parser.parse_args() | ||
return opt | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters