diff --git a/detect.py b/detect.py index 8741e7f7fd6..c699a749a09 100644 --- a/detect.py +++ b/detect.py @@ -44,10 +44,10 @@ from utils.general import (LOGGER, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2, increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh) from utils.plots import Annotator, colors, save_one_box -from utils.torch_utils import select_device, time_sync +from utils.torch_utils import select_device, smart_inference_mode, time_sync -@torch.no_grad() +@smart_inference_mode() def run( weights=ROOT / 'yolov5s.pt', # model.pt path(s) source=ROOT / 'data/images', # file/dir/URL/glob, 0 for webcam diff --git a/export.py b/export.py index 6d70724daa2..c88a5107470 100644 --- a/export.py +++ b/export.py @@ -69,7 +69,7 @@ from utils.dataloaders import LoadImages from utils.general import (LOGGER, check_dataset, check_img_size, check_requirements, check_version, check_yaml, colorstr, file_size, print_args, url2file) -from utils.torch_utils import select_device +from utils.torch_utils import select_device, smart_inference_mode def export_formats(): @@ -455,7 +455,7 @@ def export_tfjs(file, prefix=colorstr('TensorFlow.js:')): LOGGER.info(f'\n{prefix} export failure: {e}') -@torch.no_grad() +@smart_inference_mode() def run( data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path' weights=ROOT / 'yolov5s.pt', # weights path diff --git a/models/common.py b/models/common.py index a1269c5f337..afb9323ce49 100644 --- a/models/common.py +++ b/models/common.py @@ -25,7 +25,7 @@ from utils.general import (LOGGER, check_requirements, check_suffix, check_version, colorstr, increment_path, make_divisible, non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh) from utils.plots import Annotator, colors, save_one_box -from utils.torch_utils import copy_attr, time_sync +from utils.torch_utils import copy_attr, smart_inference_mode, time_sync def autopad(k, p=None): # kernel, padding @@ -578,7 +578,7 @@ def _apply(self, fn): m.anchor_grid = list(map(fn, m.anchor_grid)) return self - @torch.no_grad() + @smart_inference_mode() def forward(self, imgs, size=640, augment=False, profile=False): # Inference from various sources. For height=640, width=1280, RGB images example inputs are: # file: imgs = 'data/images/zidane.jpg' # str or PosixPath diff --git a/models/yolo.py b/models/yolo.py index bc1893ccbc4..307b74844ca 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -76,12 +76,12 @@ def forward(self, x): return x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x) - def _make_grid(self, nx=20, ny=20, i=0): + def _make_grid(self, nx=20, ny=20, i=0, torch_1_10=check_version(torch.__version__, '1.10.0')): d = self.anchors[i].device t = self.anchors[i].dtype shape = 1, self.na, ny, nx, 2 # grid shape y, x = torch.arange(ny, device=d, dtype=t), torch.arange(nx, device=d, dtype=t) - if check_version(torch.__version__, '1.10.0'): # torch>=1.10.0 meshgrid workaround for torch>=0.7 compatibility + if torch_1_10: # torch>=1.10.0 meshgrid workaround for torch>=0.7 compatibility yv, xv = torch.meshgrid(y, x, indexing='ij') else: yv, xv = torch.meshgrid(y, x) diff --git a/utils/torch_utils.py b/utils/torch_utils.py index beb81442912..1ceb0aa346e 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -34,6 +34,14 @@ warnings.filterwarnings('ignore', message='User provided device_type of \'cuda\', but CUDA is not available. Disabling') +def smart_inference_mode(torch_1_9=check_version(torch.__version__, '1.9.0')): + # Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator + def decorate(fn): + return (torch.inference_mode if torch_1_9 else torch.no_grad)()(fn) + + return decorate + + def smart_DDP(model): # Model DDP creation with checks assert not check_version(torch.__version__, '1.12.0', pinned=True), \ @@ -364,17 +372,17 @@ def __init__(self, model, decay=0.9999, tau=2000, updates=0): for p in self.ema.parameters(): p.requires_grad_(False) + @smart_inference_mode() def update(self, model): # Update EMA parameters - with torch.no_grad(): - self.updates += 1 - d = self.decay(self.updates) - - msd = de_parallel(model).state_dict() # model state_dict - for k, v in self.ema.state_dict().items(): - if v.dtype.is_floating_point: - v *= d - v += (1 - d) * msd[k].detach() + self.updates += 1 + d = self.decay(self.updates) + + msd = de_parallel(model).state_dict() # model state_dict + for k, v in self.ema.state_dict().items(): + if v.dtype.is_floating_point: + v *= d + v += (1 - d) * msd[k].detach() def update_attr(self, model, include=(), exclude=('process_group', 'reducer')): # Update EMA attributes diff --git a/val.py b/val.py index 4265145b065..13049623346 100644 --- a/val.py +++ b/val.py @@ -42,7 +42,7 @@ scale_coords, xywh2xyxy, xyxy2xywh) from utils.metrics import ConfusionMatrix, ap_per_class, box_iou from utils.plots import output_to_target, plot_images, plot_val_study -from utils.torch_utils import select_device, time_sync +from utils.torch_utils import select_device, smart_inference_mode, time_sync def save_one_txt(predn, save_conf, shape, file): @@ -93,7 +93,7 @@ def process_batch(detections, labels, iouv): return torch.tensor(correct, dtype=torch.bool, device=iouv.device) -@torch.no_grad() +@smart_inference_mode() def run( data, weights=None, # model.pt path(s)