Skip to content

Commit

Permalink
torch.empty() for speed improvements (ultralytics#9025)
Browse files Browse the repository at this point in the history
`torch.empty()` for speed improvement

Signed-off-by: Glenn Jocher <[email protected]>
  • Loading branch information
glenn-jocher authored and Clay Januhowski committed Sep 8, 2022
1 parent 4c21590 commit 07c2b4e
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 8 deletions.
4 changes: 2 additions & 2 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ def warmup(self, imgsz=(1, 3, 640, 640)):
# Warmup model by running inference once
warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb
if any(warmup_types) and self.device.type != 'cpu':
im = torch.zeros(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input
im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input
for _ in range(2 if self.jit else 1): #
self.forward(im) # warmup

Expand Down Expand Up @@ -600,7 +600,7 @@ def forward(self, ims, size=640, augment=False, profile=False):

dt = (Profile(), Profile(), Profile())
with dt[0]:
p = next(self.model.parameters()) if self.pt else torch.zeros(1, device=self.model.device) # param
p = next(self.model.parameters()) if self.pt else torch.empty(1, device=self.model.device) # param
autocast = self.amp and (p.device.type != 'cpu') # Automatic Mixed Precision (AMP) inference
if isinstance(ims, torch.Tensor): # torch
with amp.autocast(autocast):
Expand Down
6 changes: 3 additions & 3 deletions models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def __init__(self, nc=80, anchors=(), ch=(), inplace=True): # detection layer
self.no = nc + 5 # number of outputs per anchor
self.nl = len(anchors) # number of detection layers
self.na = len(anchors[0]) // 2 # number of anchors
self.grid = [torch.zeros(1)] * self.nl # init grid
self.anchor_grid = [torch.zeros(1)] * self.nl # init anchor grid
self.grid = [torch.empty(1)] * self.nl # init grid
self.anchor_grid = [torch.empty(1)] * self.nl # init anchor grid
self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2)) # shape(nl,na,2)
self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
self.inplace = inplace # use inplace ops (e.g. slice assignment)
Expand Down Expand Up @@ -175,7 +175,7 @@ def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None): # model, i
if isinstance(m, Detect):
s = 256 # 2x min stride
m.inplace = self.inplace
m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward
m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.empty(1, ch, s, s))]) # forward
check_anchor_order(m) # must be in pixel-space (not grid-space)
m.anchors /= m.stride.view(-1, 1, 1)
self.stride = m.stride
Expand Down
2 changes: 1 addition & 1 deletion utils/autobatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def autobatch(model, imgsz=640, fraction=0.9, batch_size=16):
# Profile batch sizes
batch_sizes = [1, 2, 4, 8, 16]
try:
img = [torch.zeros(b, 3, imgsz, imgsz) for b in batch_sizes]
img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes]
results = profile(img, model, n=3, device=device)
except Exception as e:
LOGGER.warning(f'{prefix}{e}')
Expand Down
2 changes: 1 addition & 1 deletion utils/loggers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def log_tensorboard_graph(tb, model, imgsz=(640, 640)):
try:
p = next(model.parameters()) # for device, type
imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz # expand
im = torch.zeros((1, 3, *imgsz)).to(p.device).type_as(p) # input image
im = torch.empty((1, 3, *imgsz)).to(p.device).type_as(p) # input image
with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress jit trace warning
tb.add_graph(torch.jit.trace(de_parallel(model), im, strict=False), [])
Expand Down
2 changes: 1 addition & 1 deletion utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def model_info(model, verbose=False, imgsz=640):
try: # FLOPs
p = next(model.parameters())
stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32 # max stride
im = torch.zeros((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format
im = torch.empty((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format
flops = thop.profile(deepcopy(model), inputs=(im,), verbose=False)[0] / 1E9 * 2 # stride GFLOPs
imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz] # expand if int/float
fs = f', {flops * imgsz[0] / stride * imgsz[1] / stride:.1f} GFLOPs' # 640x640 GFLOPs
Expand Down

0 comments on commit 07c2b4e

Please sign in to comment.