forked from ultralytics/yolov5
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add EarlyStopping feature (ultralytics#4576)
* Add EarlyStopping feature * Add comment * Cleanup * Cleanup2 * debug * debug2 * debug3 * debug3 * debug4 * debug5 * debug6 * debug7 * debug8 * debug9 * debug10 * debug11 * debug12 * Cleanup * Add TODO for known DDP issue
- Loading branch information
1 parent
ce234ff
commit b2c8831
Showing
2 changed files
with
35 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -40,7 +40,8 @@ | |
from utils.downloads import attempt_download | ||
from utils.loss import ComputeLoss | ||
from utils.plots import plot_labels, plot_evolve | ||
from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, de_parallel | ||
from utils.torch_utils import EarlyStopping, ModelEMA, de_parallel, intersect_dicts, select_device, \ | ||
torch_distributed_zero_first | ||
from utils.loggers.wandb.wandb_utils import check_wandb_resume | ||
from utils.metrics import fitness | ||
from utils.loggers import Loggers | ||
|
@@ -255,6 +256,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary | |
results = (0, 0, 0, 0, 0, 0, 0) # P, R, [email protected], [email protected], val_loss(box, obj, cls) | ||
scheduler.last_epoch = start_epoch - 1 # do not move | ||
scaler = amp.GradScaler(enabled=cuda) | ||
stopper = EarlyStopping(patience=opt.patience) | ||
compute_loss = ComputeLoss(model) # init loss class | ||
LOGGER.info(f'Image sizes {imgsz} train, {imgsz} val\n' | ||
f'Using {train_loader.num_workers} dataloader workers\n' | ||
|
@@ -389,6 +391,20 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary | |
del ckpt | ||
callbacks.on_model_save(last, epoch, final_epoch, best_fitness, fi) | ||
|
||
# Stop Single-GPU | ||
if stopper(epoch=epoch, fitness=fi): | ||
break | ||
|
||
# Stop DDP TODO: known issues shttps://github.com/ultralytics/yolov5/pull/4576 | ||
# stop = stopper(epoch=epoch, fitness=fi) | ||
# if RANK == 0: | ||
# dist.broadcast_object_list([stop], 0) # broadcast 'stop' to all ranks | ||
|
||
# Stop DPP | ||
# with torch_distributed_zero_first(RANK): | ||
# if stop: | ||
# break # must break all DDP ranks | ||
|
||
# end epoch ---------------------------------------------------------------------------------------------------- | ||
# end training ----------------------------------------------------------------------------------------------------- | ||
if RANK in [-1, 0]: | ||
|
@@ -454,6 +470,7 @@ def parse_opt(known=False): | |
parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used') | ||
parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify') | ||
parser.add_argument('--freeze', type=int, default=0, help='Number of layers to freeze. backbone=10, all=24') | ||
parser.add_argument('--patience', type=int, default=30, help='EarlyStopping patience (epochs)') | ||
opt = parser.parse_known_args()[0] if known else parser.parse_args() | ||
return opt | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters