Skip to content

Commit

Permalink
Merge branch 'main' into raft_model_arch
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug authored Dec 6, 2021
2 parents 9ae9e38 + 9b57de6 commit f077d7c
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 49 deletions.
54 changes: 19 additions & 35 deletions references/video_classification/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,13 @@
from torch.utils.data.dataloader import default_collate
from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler, RandomClipSampler

try:
from apex import amp
except ImportError:
amp = None


try:
from torchvision.prototype import models as PM
except ImportError:
PM = None


def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, apex=False):
def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, scaler=None):
model.train()
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
Expand All @@ -34,16 +28,19 @@ def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, devi
for video, target in metric_logger.log_every(data_loader, print_freq, header):
start_time = time.time()
video, target = video.to(device), target.to(device)
output = model(video)
loss = criterion(output, target)
with torch.cuda.amp.autocast(enabled=scaler is not None):
output = model(video)
loss = criterion(output, target)

optimizer.zero_grad()
if apex:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()

if scaler is not None:
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
else:
loss.backward()
optimizer.step()
optimizer.step()

acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
batch_size = video.shape[0]
Expand Down Expand Up @@ -101,11 +98,6 @@ def collate_fn(batch):
def main(args):
if args.weights and PM is None:
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
if args.apex and amp is None:
raise RuntimeError(
"Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
"to enable mixed-precision training."
)

if args.output_dir:
utils.mkdir(args.output_dir)
Expand Down Expand Up @@ -224,9 +216,7 @@ def main(args):

lr = args.lr * args.world_size
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=args.momentum, weight_decay=args.weight_decay)

if args.apex:
model, optimizer = amp.initialize(model, optimizer, opt_level=args.apex_opt_level)
scaler = torch.cuda.amp.GradScaler() if args.amp else None

# convert scheduler to be per iteration, not per epoch, for warmup that lasts
# between different epochs
Expand Down Expand Up @@ -267,6 +257,8 @@ def main(args):
optimizer.load_state_dict(checkpoint["optimizer"])
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
args.start_epoch = checkpoint["epoch"] + 1
if args.amp:
scaler.load_state_dict(checkpoint["scaler"])

if args.test_only:
evaluate(model, criterion, data_loader_test, device=device)
Expand All @@ -277,9 +269,7 @@ def main(args):
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
train_sampler.set_epoch(epoch)
train_one_epoch(
model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, args.print_freq, args.apex
)
train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, args.print_freq, scaler)
evaluate(model, criterion, data_loader_test, device=device)
if args.output_dir:
checkpoint = {
Expand All @@ -289,6 +279,8 @@ def main(args):
"epoch": epoch,
"args": args,
}
if args.amp:
checkpoint["scaler"] = scaler.state_dict()
utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

Expand Down Expand Up @@ -363,24 +355,16 @@ def parse_args():
action="store_true",
)

# Mixed precision training parameters
parser.add_argument("--apex", action="store_true", help="Use apex for mixed precision training")
parser.add_argument(
"--apex-opt-level",
default="O1",
type=str,
help="For apex mixed precision training"
"O0 for FP32 training, O1 for mixed precision training."
"For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet",
)

# distributed training parameters
parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")

# Prototype models only
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

# Mixed precision training parameters
parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")

args = parser.parse_args()

return args
Expand Down
1 change: 1 addition & 0 deletions test/test_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def test_roi_align(self):
model = ops.RoIAlign((5, 5), 1, -1)
self.run_model(model, [(x, single_roi)])

@pytest.mark.skip(reason="ROIAlign with aligned=True is not supported in ONNX, but will be supported in opset 16.")
def test_roi_align_aligned(self):
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 1.5, 1.5, 3, 3]], dtype=torch.float32)
Expand Down
2 changes: 1 addition & 1 deletion torchvision/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def __init__(self, path: str, stream: str = "video", num_threads: int = 0) -> No
raise RuntimeError(
"Not compiled with video_reader support, "
+ "to enable video_reader support, please install "
+ "ffmpeg (version 4.2 is currently supported) and"
+ "ffmpeg (version 4.2 is currently supported) and "
+ "build torchvision from source."
)
self._c = torch.classes.torchvision.Video(path, stream, num_threads)
Expand Down
36 changes: 23 additions & 13 deletions torchvision/prototype/datasets/utils/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import os.path
import pathlib
import pickle
import platform
from typing import BinaryIO
from typing import (
Sequence,
Expand Down Expand Up @@ -260,6 +261,11 @@ def _make_sharded_datapipe(root: str, dataset_size: int) -> IterDataPipe:
return dp


def _read_mutable_buffer_fallback(file: BinaryIO, count: int, item_size: int) -> bytearray:
    """Read raw item data from ``file`` into a mutable buffer by copying it.

    Args:
        file: Binary file object to read from.
        count: Number of items to read. ``-1`` requests ``file.read(-1)``, i.e. no size limit.
        item_size: Size of a single item in bytes.

    Returns:
        A ``bytearray`` containing the bytes returned by ``file.read()``.
    """
    num_bytes = -1 if count == -1 else count * item_size
    # A plain file.read() returns a read-only bytes object, so we copy it into a bytearray to make it mutable.
    return bytearray(file.read(num_bytes))


def fromfile(
file: BinaryIO,
*,
Expand Down Expand Up @@ -293,20 +299,24 @@ def fromfile(
item_size = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8
np_dtype = byte_order + char + str(item_size)

# PyTorch does not support tensors with underlying read-only memory. In case
# - the file has a .fileno(),
# - the file was opened for updating, i.e. 'r+b' or 'w+b',
# - the file is seekable
# we can avoid copying the data for performance. Otherwise we fall back to simply .read() the data and copy it to
# a mutable location afterwards.
buffer: Union[memoryview, bytearray]
try:
buffer = memoryview(mmap.mmap(file.fileno(), 0))[file.tell() :]
# Reading from the memoryview does not advance the file cursor, so we have to do it manually.
file.seek(*(0, io.SEEK_END) if count == -1 else (count * item_size, io.SEEK_CUR))
except (PermissionError, io.UnsupportedOperation):
# A plain file.read() will give a read-only bytes, so we convert it to bytearray to make it mutable
buffer = bytearray(file.read(-1 if count == -1 else count * item_size))
if platform.system() != "Windows":
# PyTorch does not support tensors with underlying read-only memory. In case
# - the file has a .fileno(),
# - the file was opened for updating, i.e. 'r+b' or 'w+b',
# - the file is seekable
# we can avoid copying the data for performance. Otherwise we fall back to simply .read() the data and copy it
# to a mutable location afterwards.
try:
buffer = memoryview(mmap.mmap(file.fileno(), 0))[file.tell() :]
# Reading from the memoryview does not advance the file cursor, so we have to do it manually.
file.seek(*(0, io.SEEK_END) if count == -1 else (count * item_size, io.SEEK_CUR))
except (PermissionError, io.UnsupportedOperation):
buffer = _read_mutable_buffer_fallback(file, count, item_size)
else:
# On Windows, merely calling mmap.mmap() on a file that does not support it may corrupt the internal state so that
# no data can be read afterwards. Thus, we simply forgo the possible speed-up.
buffer = _read_mutable_buffer_fallback(file, count, item_size)

# We cannot use torch.frombuffer() directly, since it only supports the native byte order of the system. Thus, we
# read the data with np.frombuffer() with the correct byte order and convert it to the native one with the
Expand Down

0 comments on commit f077d7c

Please sign in to comment.