Skip to content

Commit

Permalink
make C extension lazy-import (#971)
Browse files Browse the repository at this point in the history
* make C extension lazy-import

* add lazy loading to roi_pool
  • Loading branch information
soumith authored May 30, 2019
1 parent 579eebe commit 220b69b
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 31 deletions.
28 changes: 0 additions & 28 deletions torchvision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,31 +33,3 @@ def get_image_backend():
Gets the name of the package used to load images
"""
return _image_backend


def _check_cuda_matches():
"""
Make sure that CUDA versions match between the pytorch install and torchvision install
"""
import torch
from torchvision import _C
if hasattr(_C, "CUDA_VERSION") and torch.version.cuda is not None:
tv_version = str(_C.CUDA_VERSION)
if int(tv_version) < 10000:
tv_major = int(tv_version[0])
tv_minor = int(tv_version[2])
else:
tv_major = int(tv_version[0:2])
tv_minor = int(tv_version[3])
t_version = torch.version.cuda
t_version = t_version.split('.')
t_major = int(t_version[0])
t_minor = int(t_version[1])
if t_major != tv_major or t_minor != tv_minor:
raise RuntimeError("Detected that PyTorch and torchvision were compiled with different CUDA versions. "
"PyTorch has CUDA Version={}.{} and torchvision has CUDA Version={}.{}. "
"Please reinstall the torchvision that matches your PyTorch install."
.format(t_major, t_minor, tv_major, tv_minor))


_check_cuda_matches()
31 changes: 31 additions & 0 deletions torchvision/extension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
_C = None


def _lazy_import():
"""
Make sure that CUDA versions match between the pytorch install and torchvision install
"""
global _C
if _C is not None:
return _C
import torch
from torchvision import _C as C
_C = C
if hasattr(_C, "CUDA_VERSION") and torch.version.cuda is not None:
tv_version = str(_C.CUDA_VERSION)
if int(tv_version) < 10000:
tv_major = int(tv_version[0])
tv_minor = int(tv_version[2])
else:
tv_major = int(tv_version[0:2])
tv_minor = int(tv_version[3])
t_version = torch.version.cuda
t_version = t_version.split('.')
t_major = int(t_version[0])
t_minor = int(t_version[1])
if t_major != tv_major or t_minor != tv_minor:
raise RuntimeError("Detected that PyTorch and torchvision were compiled with different CUDA versions. "
"PyTorch has CUDA Version={}.{} and torchvision has CUDA Version={}.{}. "
"Please reinstall the torchvision that matches your PyTorch install."
.format(t_major, t_minor, tv_major, tv_minor))
return _C
3 changes: 2 additions & 1 deletion torchvision/ops/boxes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch
from torchvision import _C
from torchvision.extension import _lazy_import


def nms(boxes, scores, iou_threshold):
Expand All @@ -22,6 +22,7 @@ def nms(boxes, scores, iou_threshold):
of the elements that have been kept
by NMS, sorted in decreasing order of scores
"""
_C = _lazy_import()
return _C.nms(boxes, scores, iou_threshold)


Expand Down
4 changes: 3 additions & 1 deletion torchvision/ops/roi_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from torch.nn.modules.utils import _pair

from torchvision import _C
from torchvision.extension import _lazy_import
from ._utils import convert_boxes_to_roi_format


Expand All @@ -18,6 +18,7 @@ def forward(ctx, input, roi, output_size, spatial_scale, sampling_ratio):
ctx.spatial_scale = spatial_scale
ctx.sampling_ratio = sampling_ratio
ctx.input_shape = input.size()
_C = _lazy_import()
output = _C.roi_align_forward(
input, roi, spatial_scale,
output_size[0], output_size[1], sampling_ratio)
Expand All @@ -31,6 +32,7 @@ def backward(ctx, grad_output):
spatial_scale = ctx.spatial_scale
sampling_ratio = ctx.sampling_ratio
bs, ch, h, w = ctx.input_shape
_C = _lazy_import()
grad_input = _C.roi_align_backward(
grad_output, rois, spatial_scale,
output_size[0], output_size[1], bs, ch, h, w, sampling_ratio)
Expand Down
4 changes: 3 additions & 1 deletion torchvision/ops/roi_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from torch.nn.modules.utils import _pair

from torchvision import _C
from torchvision.extension import _lazy_import
from ._utils import convert_boxes_to_roi_format


Expand All @@ -16,6 +16,7 @@ def forward(ctx, input, rois, output_size, spatial_scale):
ctx.output_size = _pair(output_size)
ctx.spatial_scale = spatial_scale
ctx.input_shape = input.size()
_C = _lazy_import()
output, argmax = _C.roi_pool_forward(
input, rois, spatial_scale,
output_size[0], output_size[1])
Expand All @@ -29,6 +30,7 @@ def backward(ctx, grad_output):
output_size = ctx.output_size
spatial_scale = ctx.spatial_scale
bs, ch, h, w = ctx.input_shape
_C = _lazy_import()
grad_input = _C.roi_pool_backward(
grad_output, rois, argmax, spatial_scale,
output_size[0], output_size[1], bs, ch, h, w)
Expand Down

0 comments on commit 220b69b

Please sign in to comment.