diff --git a/.travis.yml b/.travis.yml
index 44fd9ae3c16..789c355765f 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -38,7 +38,6 @@ before_install:
     fi
   - conda install av -c conda-forge
-
 install:
   # Using pip instead of setup.py ensures we install a non-compressed version of the package
   # (as opposed to an egg), which is necessary to collect coverage.
@@ -55,7 +54,7 @@ install:
       cd -
 
 script:
-  - pytest --cov-config .coveragerc --cov torchvision --cov $TV_INSTALL_PATH -k 'not TestVideoReader and not TestVideoTransforms and not TestIO' test --ignore=test/test_datasets_download.py
+  - pytest --cov-config .coveragerc --cov torchvision --cov $TV_INSTALL_PATH -k 'not TestVideo and not TestVideoReader and not TestVideoTransforms and not TestIO' test --ignore=test/test_datasets_download.py
  - pytest test/test_hub.py
 
 after_success:
diff --git a/setup.py b/setup.py
index 4e927923fcb..d6674465405 100644
--- a/setup.py
+++ b/setup.py
@@ -347,10 +347,13 @@ def get_extensions():
         base_decoder_src_dir = os.path.join(this_dir, 'torchvision', 'csrc', 'cpu', 'decoder')
         base_decoder_src = glob.glob(
             os.path.join(base_decoder_src_dir, "*.cpp"))
+        # Torchvision video API
+        videoapi_src_dir = os.path.join(this_dir, 'torchvision', 'csrc', 'cpu', 'video')
+        videoapi_src = glob.glob(os.path.join(videoapi_src_dir, "*.cpp"))
         # exclude tests
         base_decoder_src = [x for x in base_decoder_src if '_test.cpp' not in x]
 
-        combined_src = video_reader_src + base_decoder_src
+        combined_src = video_reader_src + base_decoder_src + videoapi_src
 
         ext_modules.append(
             CppExtension(
@@ -359,6 +362,7 @@ def get_extensions():
                 include_dirs=[
                     base_decoder_src_dir,
                     video_reader_src_dir,
+                    videoapi_src_dir,
                     ffmpeg_include_dir,
                     extensions_dir,
                 ],
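With the sources above added to the extension build, the decoder introduced later in this patch is exposed as a TorchScript custom class rather than a Python module. A minimal sketch of how it is reached from Python, assuming only names that appear in this patch (the torchvision "Video" registration and the `next` method); the file path is a placeholder:

    import torch
    import torchvision  # importing torchvision loads the extension that registers the class

    # construct a reader over the default video stream of a (hypothetical) file
    reader = torch.classes.torchvision.Video("path/to/video.mp4", "video")
    frame, pts = reader.next()  # first decoded frame (CHW uint8) and its timestamp in seconds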
diff --git a/test/test_video.py b/test/test_video.py
new file mode 100644
index 00000000000..63434fa9c1f
--- /dev/null
+++ b/test/test_video.py
@@ -0,0 +1,407 @@
+import os
+import collections
+import contextlib
+import tempfile
+import unittest
+import random
+
+
+import numpy as np
+
+import torch
+import torchvision
+from torchvision.io import _HAS_VIDEO_OPT
+
+try:
+    import av
+
+    # Do a version test too
+    torchvision.io.video._check_av_available()
+except ImportError:
+    av = None
+
+
+VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos")
+
+CheckerConfig = [
+    "duration",
+    "video_fps",
+    "audio_sample_rate",
+    # We find that for some videos (e.g. HMDB51 videos), the decoded audio
+    # frames and pts are slightly different between the TorchVision decoder
+    # and the PyAV decoder, so we omit them during the check.
+    "check_aframes",
+    "check_aframe_pts",
+]
+GroundTruth = collections.namedtuple("GroundTruth", " ".join(CheckerConfig))
+
+all_check_config = GroundTruth(
+    duration=0,
+    video_fps=0,
+    audio_sample_rate=0,
+    check_aframes=True,
+    check_aframe_pts=True,
+)
+
+test_videos = {
+    "RATRACE_wave_f_nm_np1_fr_goo_37.avi": GroundTruth(
+        duration=2.0,
+        video_fps=30.0,
+        audio_sample_rate=None,
+        check_aframes=True,
+        check_aframe_pts=True,
+    ),
+    "SchoolRulesHowTheyHelpUs_wave_f_nm_np1_ba_med_0.avi": GroundTruth(
+        duration=2.0,
+        video_fps=30.0,
+        audio_sample_rate=None,
+        check_aframes=True,
+        check_aframe_pts=True,
+    ),
+    "TrumanShow_wave_f_nm_np1_fr_med_26.avi": GroundTruth(
+        duration=2.0,
+        video_fps=30.0,
+        audio_sample_rate=None,
+        check_aframes=True,
+        check_aframe_pts=True,
+    ),
+    "v_SoccerJuggling_g23_c01.avi": GroundTruth(
+        duration=8.0,
+        video_fps=29.97,
+        audio_sample_rate=None,
+        check_aframes=True,
+        check_aframe_pts=True,
+    ),
+    "v_SoccerJuggling_g24_c01.avi": GroundTruth(
+        duration=8.0,
+        video_fps=29.97,
+        audio_sample_rate=None,
+        check_aframes=True,
+        check_aframe_pts=True,
+    ),
+    # The last three tests segfault on the video reader (see issues)
+    "R6llTwEh07w.mp4": GroundTruth(
+        duration=10.0,
+        video_fps=30.0,
+        audio_sample_rate=44100,
+        # PyAV misses one audio frame at the beginning (pts=0)
+        check_aframes=False,
+        check_aframe_pts=False,
+    ),
+    "SOX5yA1l24A.mp4": GroundTruth(
+        duration=11.0,
+        video_fps=29.97,
+        audio_sample_rate=48000,
+        # PyAV misses one audio frame at the beginning (pts=0)
+        check_aframes=False,
+        check_aframe_pts=False,
+    ),
+    "WUzgd7C1pWA.mp4": GroundTruth(
+        duration=11.0,
+        video_fps=29.97,
+        audio_sample_rate=48000,
+        # PyAV misses one audio frame at the beginning (pts=0)
+        check_aframes=False,
+        check_aframe_pts=False,
+    ),
+}
+
+DecoderResult = collections.namedtuple(
+    "DecoderResult", "vframes vframe_pts vtimebase aframes aframe_pts atimebase"
+)
+
+
+def _read_from_stream(
+    container, start_pts, end_pts, stream, stream_name, buffer_size=4
+):
+    """
+    Args:
+        container: pyav container
+        start_pts/end_pts: the starting/ending Presentation TimeStamps between
+            which frames are read
+        stream: pyav stream
+        stream_name: a dictionary of streams. For example, {"video": 0} means
+            the video stream at stream index 0
+        buffer_size: the pts of frames decoded by PyAV are not guaranteed to be
+            in ascending order, so we keep decoding a few extra frames even
+            after passing the end pts
+    """
+    # Seeking in the stream is imprecise, so seek to an earlier PTS by a margin.
+    margin = 1
+    seek_offset = max(start_pts - margin, 0)
+
+    container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
+    frames = {}
+    buffer_count = 0
+    for frame in container.decode(**stream_name):
+        if frame.pts < start_pts:
+            continue
+        if frame.pts <= end_pts:
+            frames[frame.pts] = frame
+        else:
+            buffer_count += 1
+            if buffer_count >= buffer_size:
+                break
+    result = [frames[pts] for pts in sorted(frames)]
+
+    return result
+
+
+def _fraction_to_tensor(fraction):
+    # Represent a fractions.Fraction (e.g. a stream time base) as an
+    # int32 tensor [numerator, denominator].
+    ret = torch.zeros([2], dtype=torch.int32)
+    ret[0] = fraction.numerator
+    ret[1] = fraction.denominator
+    return ret
+def _decode_frames_by_av_module(
+    full_path,
+    video_start_pts=0,
+    video_end_pts=None,
+    audio_start_pts=0,
+    audio_end_pts=None,
+):
+    """
+    Use PyAV to decode video frames. This provides a reference for the decoder
+    under test to compare its decoding results against.
+    Input arguments:
+        full_path: video file path
+        video_start_pts/video_end_pts: the starting/ending Presentation
+            TimeStamps between which frames are read
+    """
+    if video_end_pts is None:
+        video_end_pts = float("inf")
+    if audio_end_pts is None:
+        audio_end_pts = float("inf")
+    container = av.open(full_path)
+
+    video_frames = []
+    vtimebase = torch.zeros([0], dtype=torch.int32)
+    if container.streams.video:
+        video_frames = _read_from_stream(
+            container,
+            video_start_pts,
+            video_end_pts,
+            container.streams.video[0],
+            {"video": 0},
+        )
+        # container.streams.video[0].average_rate is not a reliable estimator
+        # of the frame rate. It can be wrong for certain codecs, such as VP8,
+        # so we do not return the video fps here.
+        vtimebase = _fraction_to_tensor(container.streams.video[0].time_base)
+
+    audio_frames = []
+    atimebase = torch.zeros([0], dtype=torch.int32)
+    if container.streams.audio:
+        audio_frames = _read_from_stream(
+            container,
+            audio_start_pts,
+            audio_end_pts,
+            container.streams.audio[0],
+            {"audio": 0},
+        )
+        atimebase = _fraction_to_tensor(container.streams.audio[0].time_base)
+
+    container.close()
+    vframes = [frame.to_rgb().to_ndarray() for frame in video_frames]
+    vframes = torch.as_tensor(np.stack(vframes))
+
+    vframe_pts = torch.tensor([frame.pts for frame in video_frames], dtype=torch.int64)
+
+    aframes = [frame.to_ndarray() for frame in audio_frames]
+    if aframes:
+        aframes = np.transpose(np.concatenate(aframes, axis=1))
+        aframes = torch.as_tensor(aframes)
+    else:
+        aframes = torch.empty((1, 0), dtype=torch.float32)
+
+    aframe_pts = torch.tensor(
+        [audio_frame.pts for audio_frame in audio_frames], dtype=torch.int64
+    )
+
+    return DecoderResult(
+        vframes=vframes.permute(0, 3, 1, 2),
+        vframe_pts=vframe_pts,
+        vtimebase=vtimebase,
+        aframes=aframes,
+        aframe_pts=aframe_pts,
+        atimebase=atimebase,
+    )
+
+
+def _template_read_video(video_object, s=0, e=None):
+    if e is None:
+        e = float("inf")
+    if e < s:
+        raise ValueError(
+            "end time should be larger than start time, got "
+            "start time={} and end time={}".format(s, e)
+        )
+    video_object.set_current_stream("video")
+    video_object.seek(s)
+    video_frames = torch.empty(0)
+    frames = []
+    video_pts = []
+    t, pts = video_object.next()
+    while t.numel() > 0 and s <= pts <= e:
+        frames.append(t)
+        video_pts.append(pts)
+        t, pts = video_object.next()
+    if len(frames) > 0:
+        video_frames = torch.stack(frames, 0)
+
+    video_object.set_current_stream("audio")
+    video_object.seek(s)
+    audio_frames = torch.empty(0)
+    frames = []
+    audio_pts = []
+    t, pts = video_object.next()
+    while t.numel() > 0 and s < pts <= e:
+        frames.append(t)
+        audio_pts.append(pts)
+        t, pts = video_object.next()
+    if len(frames) > 0:
+        audio_frames = torch.stack(frames, 0)
+
+    return DecoderResult(
+        vframes=video_frames,
+        vframe_pts=video_pts,
+        vtimebase=None,
+        aframes=audio_frames,
+        aframe_pts=audio_pts,
+        atimebase=None,
+    )
+
+
+@unittest.skipIf(_HAS_VIDEO_OPT is False, "Didn't compile with ffmpeg")
+class TestVideo(unittest.TestCase):
+    @unittest.skipIf(av is None, "PyAV unavailable")
+    def test_read_video_tensor(self):
+        """
+        Check that reading the video using the `next`-based API yields
+        tensors of the same size as the pyav alternative.
+ """ + torchvision.set_video_backend("pyav") + for test_video, config in test_videos.items(): + full_path = os.path.join(VIDEO_DIR, test_video) + # pass 1: decode all frames using existing TV decoder + tv_result, _, _ = torchvision.io.read_video(full_path, pts_unit="sec") + tv_result = tv_result.permute(0, 3, 1, 2) + # pass 2: decode all frames using new api + reader = torch.classes.torchvision.Video(full_path, "video") + frames = [] + t, _ = reader.next() + while t.numel() > 0: + frames.append(t) + t, _ = reader.next() + new_api = torch.stack(frames, 0) + self.assertEqual(tv_result.size(), new_api.size()) + + # def test_partial_video_reading_fn(self): + # torchvision.set_video_backend("video_reader") + # for test_video, config in test_videos.items(): + # full_path = os.path.join(VIDEO_DIR, test_video) + + # # select two random points between 0 and duration + # r = [] + # r.append(random.uniform(0, config.duration)) + # r.append(random.uniform(0, config.duration)) + # s = min(r) + # e = max(r) + + # reader = torch.classes.torchvision.Video(full_path, "video") + # results = _template_read_video(reader, s, e) + # tv_video, tv_audio, info = torchvision.io.read_video( + # full_path, start_pts=s, end_pts=e, pts_unit="sec" + # ) + # self.assertAlmostEqual(tv_video.size(0), results.vframes.size(0), delta=2.0) + + # def test_pts(self): + # """ + # Check if every frame read from + # """ + # torchvision.set_video_backend("video_reader") + # for test_video, config in test_videos.items(): + # full_path = os.path.join(VIDEO_DIR, test_video) + + # tv_timestamps, _ = torchvision.io.read_video_timestamps( + # full_path, pts_unit="sec" + # ) + # # pass 2: decode all frames using new api + # reader = torch.classes.torchvision.Video(full_path, "video") + # pts = [] + # t, p = reader.next() + # while t.numel() > 0: + # pts.append(p) + # t, p = reader.next() + + # tv_timestamps = [float(p) for p in tv_timestamps] + # napi_pts = [float(p) for p in pts] + # for i in range(len(napi_pts)): + # self.assertAlmostEqual(napi_pts[i], tv_timestamps[i], delta=0.001) + # # check if pts of video frames are sorted in ascending order + # for i in range(len(napi_pts) - 1): + # self.assertEqual(napi_pts[i] < napi_pts[i + 1], True) + + @unittest.skipIf(av is None, "PyAV unavailable") + def test_metadata(self): + """ + Test that the metadata returned via pyav corresponds to the one returned + by the new video decoder API + """ + torchvision.set_video_backend("pyav") + for test_video, config in test_videos.items(): + full_path = os.path.join(VIDEO_DIR, test_video) + reader = torch.classes.torchvision.Video(full_path, "video") + reader_md = reader.get_metadata() + self.assertAlmostEqual( + config.video_fps, reader_md["video"]["fps"][0], delta=0.0001 + ) + self.assertAlmostEqual( + config.duration, reader_md["video"]["duration"][0], delta=0.5 + ) + + @unittest.skipIf(av is None, "PyAV unavailable") + def test_video_reading_fn(self): + """ + Test that the outputs of the pyav and ffmpeg outputs are mostly the same + """ + for test_video, config in test_videos.items(): + full_path = os.path.join(VIDEO_DIR, test_video) + + ref_result = _decode_frames_by_av_module(full_path) + + reader = torch.classes.torchvision.Video(full_path, "video") + newapi_result = _template_read_video(reader) + + # First we check if the frames are approximately the same + # (note that every codec context has signature artefacts which + # make a direct comparison not feasible) + if newapi_result.vframes.numel() > 0 and ref_result.vframes.numel() > 0: + 
diff --git a/torchvision/csrc/cpu/video/Video.cpp b/torchvision/csrc/cpu/video/Video.cpp
new file mode 100644
index 00000000000..f3c55fd6dea
--- /dev/null
+++ b/torchvision/csrc/cpu/video/Video.cpp
@@ -0,0 +1,337 @@
+
+#include "Video.h"
+#include <regex>
+#include <c10/util/Logging.h>
+#include "defs.h"
+#include "memory_buffer.h"
+#include "sync_decoder.h"
+
+using namespace std;
+using namespace ffmpeg;
+
+// If we are in a Windows environment, we need to define
+// initialization functions for the _custom_ops extension
+// #ifdef _WIN32
+// #if PY_MAJOR_VERSION < 3
+// PyMODINIT_FUNC init_video_reader(void) {
+//   // No need to do anything.
+//   return NULL;
+// }
+// #else
+// PyMODINIT_FUNC PyInit_video_reader(void) {
+//   // No need to do anything.
+//   return NULL;
+// }
+// #endif
+// #endif
+
+const size_t decoderTimeoutMs = 600000;
+const AVPixelFormat defaultVideoPixelFormat = AV_PIX_FMT_RGB24;
+const AVSampleFormat defaultAudioSampleFormat = AV_SAMPLE_FMT_FLT;
+
+// returns the number of written bytes
+template <typename T>
+size_t fillTensorList(DecoderOutputMessage& msgs, torch::Tensor& frame) {
+  const auto& msg = msgs;
+  T* frameData = frame.numel() > 0 ? frame.data_ptr<T>() : nullptr;
+  if (frameData) {
+    auto sizeInBytes = msg.payload->length();
+    memcpy(frameData, msg.payload->data(), sizeInBytes);
+  }
+  return sizeof(T);
+}
+
+size_t fillVideoTensor(DecoderOutputMessage& msgs, torch::Tensor& videoFrame) {
+  return fillTensorList<uint8_t>(msgs, videoFrame);
+}
+
+size_t fillAudioTensor(DecoderOutputMessage& msgs, torch::Tensor& audioFrame) {
+  return fillTensorList<float>(msgs, audioFrame);
+}
+
+std::pair<std::string, MediaType> const* _parse_type(
+    const std::string& stream_string) {
+  static const std::array<std::pair<std::string, MediaType>, 4> types = {{
+      {"video", TYPE_VIDEO},
+      {"audio", TYPE_AUDIO},
+      {"subtitle", TYPE_SUBTITLE},
+      {"cc", TYPE_CC},
+  }};
+  auto device = std::find_if(
+      types.begin(),
+      types.end(),
+      [stream_string](const std::pair<std::string, MediaType>& p) {
+        return p.first == stream_string;
+      });
+  if (device != types.end()) {
+    return device;
+  }
+  AT_ERROR("Expected one of [audio, video, subtitle, cc] ", stream_string);
+}
+
+std::string parse_type_to_string(const std::string& stream_string) {
+  auto device = _parse_type(stream_string);
+  return device->first;
+}
+
+MediaType parse_type_to_mt(const std::string& stream_string) {
+  auto device = _parse_type(stream_string);
+  return device->second;
+}
+
+std::tuple<std::string, long> _parseStream(const std::string& streamString) {
+  TORCH_CHECK(!streamString.empty(), "Stream string must not be empty");
+  static const std::regex regex("([a-zA-Z_]+)(?::([1-9]\\d*|0))?");
+  std::smatch match;
+
+  TORCH_CHECK(
+      std::regex_match(streamString, match, regex),
+      "Invalid stream string: '",
+      streamString,
+      "'");
+
+  std::string type_ = parse_type_to_string(match[1].str());
+  long index_ = -1;
+  if (match[2].matched) {
+    try {
+      index_ = c10::stoi(match[2].str());
+    } catch (const std::exception&) {
+      AT_ERROR(
+          "Could not parse device index '",
+          match[2].str(),
+          "' in device string '",
+          streamString,
+          "'");
+    }
+  }
+  return std::make_tuple(type_, index_);
+}
+
+void Video::_getDecoderParams(
+    double videoStartS,
+    int64_t getPtsOnly,
+    std::string stream,
+    long stream_id = -1,
+    bool all_streams = false,
+    double seekFrameMarginUs = 10) {
+  int64_t videoStartUs = int64_t(videoStartS * 1e6);
+
+  params.timeoutMs = decoderTimeoutMs;
+  params.startOffset = videoStartUs;
+  params.seekAccuracy = seekFrameMarginUs;
+  params.headerOnly = false;
+
+  params.preventStaleness = false; // not sure what this is about
+
+  if (all_streams) {
+    MediaFormat format;
+    format.stream = -2;
+    format.type = TYPE_AUDIO;
+    params.formats.insert(format);
+
+    format.type = TYPE_VIDEO;
+    format.stream = -2;
+    format.format.video.width = 0;
+    format.format.video.height = 0;
+    format.format.video.cropImage = 0;
+    format.format.video.format = defaultVideoPixelFormat;
+    params.formats.insert(format);
+
+    format.type = TYPE_SUBTITLE;
+    format.stream = -2;
+    params.formats.insert(format);
+
+    format.type = TYPE_CC;
+    format.stream = -2;
+    params.formats.insert(format);
+  } else {
+    // parse the stream type
+    MediaType stream_type = parse_type_to_mt(stream);
+
+    // reset params.formats
+    std::set<MediaFormat> formats;
+    params.formats = formats;
+    // define the new format
+    MediaFormat format;
+    format.type = stream_type;
+    format.stream = stream_id;
+    if (stream_type == TYPE_VIDEO) {
+      format.format.video.width = 0;
+      format.format.video.height = 0;
+      format.format.video.cropImage = 0;
+      format.format.video.format = defaultVideoPixelFormat;
+    }
+    params.formats.insert(format);
+  }
+} // Video::_getDecoderParams
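The stream specifier parsed by _parseStream above is the string the Python constructor and set_current_stream receive: a stream type, optionally followed by ':' and a zero-based stream index. With no index, the index defaults to -1 and the decoder picks a default stream. A small sketch (the asset name is borrowed from the tests in this patch):

    import torch
    import torchvision  # noqa: F401  # registers the custom class

    path = "test/assets/videos/R6llTwEh07w.mp4"
    reader = torch.classes.torchvision.Video(path, "video")  # bare type, default stream
    reader.set_current_stream("audio:0")  # explicit "type:index" form
    t, pts = reader.next()  # now yields audio samples from audio stream 0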
+Video::Video(std::string videoPath, std::string stream) {
+  // parse stream information
+  current_stream = _parseStream(stream);
+  // note that in the initial call we want to get all streams
+  Video::_getDecoderParams(
+      0, // video start
+      0, // headerOnly
+      get<0>(current_stream), // stream info - remove that
+      long(-1), // stream_id parsed from info above change to -2
+      true // read all streams
+  );
+
+  std::string logMessage, logType;
+
+  // TODO: add read from memory option
+  params.uri = videoPath;
+  logType = "file";
+  logMessage = videoPath;
+
+  // locals
+  std::vector<double> audioFPS, videoFPS, ccFPS, subsFPS;
+  std::vector<double> audioDuration, videoDuration, ccDuration, subsDuration;
+  std::vector<double> audioTB, videoTB, ccTB, subsTB;
+  c10::Dict<std::string, std::vector<double>> audioMetadata;
+  c10::Dict<std::string, std::vector<double>> videoMetadata;
+
+  // callback and metadata defined in the struct
+  succeeded = decoder.init(params, std::move(callback), &metadata);
+  if (succeeded) {
+    for (const auto& header : metadata) {
+      double fps = double(header.fps);
+      double duration = double(header.duration) * 1e-6; // * timeBase;
+
+      if (header.format.type == TYPE_VIDEO) {
+        videoFPS.push_back(fps);
+        videoDuration.push_back(duration);
+      } else if (header.format.type == TYPE_AUDIO) {
+        audioFPS.push_back(fps);
+        audioDuration.push_back(duration);
+      } else if (header.format.type == TYPE_CC) {
+        ccFPS.push_back(fps);
+        ccDuration.push_back(duration);
+      } else if (header.format.type == TYPE_SUBTITLE) {
+        subsFPS.push_back(fps);
+        subsDuration.push_back(duration);
+      };
+    }
+  }
+  audioMetadata.insert("duration", audioDuration);
+  audioMetadata.insert("framerate", audioFPS);
+  videoMetadata.insert("duration", videoDuration);
+  videoMetadata.insert("fps", videoFPS);
+  streamsMetadata.insert("video", videoMetadata);
+  streamsMetadata.insert("audio", audioMetadata);
+
+  succeeded = Video::setCurrentStream(stream);
+  LOG(INFO) << "\nDecoder initialized with: " << succeeded << "\n";
+  if (get<1>(current_stream) != -1) {
+    LOG(INFO)
+        << "Stream index set to " << get<1>(current_stream)
+        << ". If you encounter trouble, consider switching it to automatic stream discovery. \n";
+  }
+} // Video::Video
+
+bool Video::setCurrentStream(std::string stream = "video") {
+  if ((!stream.empty()) && (_parseStream(stream) != current_stream)) {
+    current_stream = _parseStream(stream);
+  }
+
+  double ts = 0;
+  if (seekTS > 0) {
+    ts = seekTS;
+  }
+
+  _getDecoderParams(
+      ts, // video start
+      0, // headerOnly
+      get<0>(current_stream), // stream
+      long(get<1>(current_stream)), // stream_id parsed from info above change to -2
+      false // read all streams
+  );
+
+  // callback and metadata defined in Video.h
+  return (decoder.init(params, std::move(callback), &metadata));
+}
+
+std::tuple<std::string, long> Video::getCurrentStream() const {
+  return current_stream;
+}
+
+c10::Dict<std::string, c10::Dict<std::string, std::vector<double>>> Video::
+    getStreamMetadata() const {
+  return streamsMetadata;
+}
+
+void Video::Seek(double ts) {
+  // initialize the class variables used for seeking and return
+  _getDecoderParams(
+      ts, // video start
+      0, // headerOnly
+      get<0>(current_stream), // stream
+      long(get<1>(current_stream)), // stream_id parsed from info above change to -2
+      false // read all streams
+  );
+
+  // callback and metadata defined in Video.h
+  succeeded = decoder.init(params, std::move(callback), &metadata);
+  LOG(INFO) << "Decoder init at seek " << succeeded << "\n";
+}
+
+std::tuple<torch::Tensor, double> Video::Next() {
+  // If decoding fails, simply return a null tensor (note: should we
+  // raise an exception instead?)
+  double frame_pts_s = 0; // initialized so a failed decode returns a defined pts
+  torch::Tensor outFrame = torch::zeros({0}, torch::kByte);
+
+  // decode a single frame
+  DecoderOutputMessage out;
+  int64_t res = decoder.decode(&out, decoderTimeoutMs);
+  // if successful
+  if (res == 0) {
+    frame_pts_s = double(double(out.header.pts) * 1e-6);
+
+    auto header = out.header;
+    const auto& format = header.format;
+
+    // initialize the output variables based on type
+
+    if (format.type == TYPE_VIDEO) {
+      // note: this can potentially be optimized
+      // by having a global tensor that we fill at decode time
+      // (would avoid allocations)
+      int outHeight = format.format.video.height;
+      int outWidth = format.format.video.width;
+      int numChannels = 3;
+      outFrame = torch::zeros({outHeight, outWidth, numChannels}, torch::kByte);
+      auto numberWrittenBytes = fillVideoTensor(out, outFrame);
+      outFrame = outFrame.permute({2, 0, 1});
+
+    } else if (format.type == TYPE_AUDIO) {
+      int outAudioChannels = format.format.audio.channels;
+      int bytesPerSample = av_get_bytes_per_sample(
+          static_cast<AVSampleFormat>(format.format.audio.format));
+      int frameSizeTotal = out.payload->length();
+
+      CHECK_EQ(frameSizeTotal % (outAudioChannels * bytesPerSample), 0);
+      int numAudioSamples =
+          frameSizeTotal / (outAudioChannels * bytesPerSample);
+
+      outFrame =
+          torch::zeros({numAudioSamples, outAudioChannels}, torch::kFloat);
+
+      auto numberWrittenBytes = fillAudioTensor(out, outFrame);
+    }
+    // currently not supporting other formats (will do soon)
+
+    out.payload.reset();
+  } else if (res == 61) {
+    LOG(INFO) << "Decoder ran out of frames (error 61)\n";
+  } else {
+    LOG(ERROR) << "Decoder failed with ERROR_CODE " << res;
+  }
+
+  std::tuple<torch::Tensor, double> result = {outFrame, frame_pts_s};
+  return result;
+}
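The nested metadata dict assembled in the constructor above surfaces in Python via get_metadata, keyed first by stream type and then by field name, with one list entry per stream of that type. A sketch of reading it, with values matching the WUzgd7C1pWA.mp4 ground truth in the tests (values approximate):

    import torch
    import torchvision  # noqa: F401

    reader = torch.classes.torchvision.Video("test/assets/videos/WUzgd7C1pWA.mp4", "video")
    md = reader.get_metadata()
    # md looks roughly like:
    # {"video": {"duration": [11.0], "fps": [29.97]},
    #  "audio": {"duration": [11.0], "framerate": [48000.0]}}
    fps = md["video"]["fps"][0]
    duration_s = md["video"]["duration"][0]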
diff --git a/torchvision/csrc/cpu/video/Video.h b/torchvision/csrc/cpu/video/Video.h
new file mode 100644
index 00000000000..8060adfcfce
--- /dev/null
+++ b/torchvision/csrc/cpu/video/Video.h
@@ -0,0 +1,61 @@
+#pragma once
+
+#include <ATen/ATen.h>
+#include <c10/util/Logging.h>
+#include <torch/custom_class.h>
+#include <torch/script.h>
+
+#include <map>
+#include <string>
+#include <tuple>
+#include <vector>
+
+#include <regex>
+#include "defs.h"
+#include "memory_buffer.h"
+#include "sync_decoder.h"
+
+using namespace ffmpeg;
+
+struct Video : torch::CustomClassHolder {
+  std::tuple<std::string, long> current_stream; // stream type, id
+  // global video metadata
+  c10::Dict<std::string, c10::Dict<std::string, std::vector<double>>>
+      streamsMetadata;
+
+ public:
+  Video(std::string videoPath, std::string stream);
+  std::tuple<std::string, long> getCurrentStream() const;
+  c10::Dict<std::string, c10::Dict<std::string, std::vector<double>>>
+      getStreamMetadata() const;
+  void Seek(double ts);
+  bool setCurrentStream(std::string stream);
+  std::tuple<torch::Tensor, double> Next();
+
+ private:
+  bool video_any_frame = false; // add this to input parameters?
+  bool succeeded = false; // decoder init flag
+  // seekTS and doSeek act as a flag - if it is not set, the Next function
+  // simply returns the next frame. If it is set, we look at the global seek
+  // time in combination with the any_frame setting.
+  double seekTS = -1;
+  bool doSeek = false;
+
+  void _getDecoderParams(
+      double videoStartS,
+      int64_t getPtsOnly,
+      std::string stream,
+      long stream_id,
+      bool all_streams,
+      double seekFrameMarginUs); // this needs to be improved
+
+  std::map<std::string, std::vector<double>> streamTimeBase; // not used
+
+  DecoderInCallback callback = nullptr;
+  std::vector<DecoderMetadata> metadata;
+
+ protected:
+  SyncDecoder decoder;
+  DecoderParameters params;
+
+}; // struct Video
diff --git a/torchvision/csrc/cpu/video/register.cpp b/torchvision/csrc/cpu/video/register.cpp
new file mode 100644
index 00000000000..a88615987bf
--- /dev/null
+++ b/torchvision/csrc/cpu/video/register.cpp
@@ -0,0 +1,18 @@
+#ifndef REGISTER_H
+#define REGISTER_H
+
+#include "Video.h"
+
+namespace {
+
+static auto registerVideo =
+    torch::class_