diff --git a/include/decord/runtime/ndarray.h b/include/decord/runtime/ndarray.h index 52da737..7dece08 100644 --- a/include/decord/runtime/ndarray.h +++ b/include/decord/runtime/ndarray.h @@ -34,6 +34,8 @@ namespace runtime { */ class NDArray { public: + // pts of the frame + int pts=0; // internal container type struct Container; /*! \brief default constructor */ @@ -56,7 +58,7 @@ class NDArray { * \param other The value to be moved */ NDArray(NDArray&& other) // NOLINT(*) - : data_(other.data_) { + : pts(other.pts), data_(other.data_) { other.data_ = nullptr; } /*! \brief destructor */ @@ -69,6 +71,7 @@ class NDArray { */ void swap(NDArray& other) { // NOLINT(*) std::swap(data_, other.data_); + std::swap(pts, other.pts); } /*! * \brief copy assignmemt @@ -305,7 +308,7 @@ inline NDArray::NDArray(Container* data) } inline NDArray::NDArray(const NDArray& other) - : data_(other.data_) { + : pts(other.pts), data_(other.data_) { if (data_ != nullptr) { data_->IncRef(); } @@ -368,6 +371,7 @@ inline void NDArray::CopyFrom(DLTensor* other) { inline void NDArray::CopyFrom(const NDArray& other) { CHECK(data_ != nullptr); CHECK(other.data_ != nullptr); + pts = other.pts; CopyFromTo(&(other.data_->dl_tensor), &(data_->dl_tensor)); } @@ -379,7 +383,7 @@ inline void NDArray::CopyFrom(std::vector& other, std::vector& shape size *= s; } CHECK(other.size() == size) << "other: " << other.size() << " this: " << size; - DLTensor dlt = CreateDLTensorView(other, shape); + DLTensor dlt = CreateDLTensorView(other, shape); CopyFromTo(&dlt, &(data_->dl_tensor)); } @@ -391,6 +395,7 @@ inline void NDArray::CopyTo(DLTensor* other) const { inline void NDArray::CopyTo(const NDArray& other) const { CHECK(data_ != nullptr); CHECK(other.data_ != nullptr); + // no copy of pts CopyFromTo(&(data_->dl_tensor), &(other.data_->dl_tensor)); } diff --git a/src/video/ffmpeg/threaded_decoder.cc b/src/video/ffmpeg/threaded_decoder.cc index 5a59d28..774afa8 100644 --- a/src/video/ffmpeg/threaded_decoder.cc +++ b/src/video/ffmpeg/threaded_decoder.cc @@ -146,7 +146,9 @@ void FFMPEGThreadedDecoder::ProcessFrame(AVFramePtr frame, NDArray out_buf) { } if (skip) { // skip resize/filtering - frame_queue_->Push(NDArray::Empty({1}, kUInt8, kCPU)); + NDArray empty = NDArray::Empty({1}, kUInt8, kCPU); + empty.pts = frame->pts; + frame_queue_->Push(empty); ++frame_count_; return; } @@ -260,6 +262,7 @@ NDArray FFMPEGThreadedDecoder::CopyToNDArray(AVFramePtr p) { to_ptr, i * linesize, linesize, ctx, ctx, kUInt8, nullptr); } + arr.pts = p->pts; return arr; } @@ -279,6 +282,7 @@ NDArray FFMPEGThreadedDecoder::AsNDArray(AVFramePtr p) { ToDLTensor(p, manager->dl_tensor, av_manager->shape); manager->deleter = AVFrameManagerDeleter; NDArray arr = NDArray::FromDLPack(manager); + arr.pts = p->pts; return arr; } diff --git a/src/video/video_reader.cc b/src/video/video_reader.cc index 706130e..15d2657 100644 --- a/src/video/video_reader.cc +++ b/src/video/video_reader.cc @@ -25,7 +25,7 @@ static const int AVIO_BUFFER_SIZE = 40960; VideoReader::VideoReader(std::string fn, DLContext ctx, int width, int height, int nb_thread, int io_type) - : ctx_(ctx), key_indices_(), frame_ts_(), codecs_(), actv_stm_idx_(-1), fmt_ctx_(nullptr), decoder_(nullptr), curr_frame_(0), + : ctx_(ctx), key_indices_(), pts_frame_map_(), tmp_key_frame_(), overrun_(false), frame_ts_(), codecs_(), actv_stm_idx_(-1), fmt_ctx_(nullptr), decoder_(nullptr), curr_frame_(0), nb_thread_decoding_(nb_thread), width_(width), height_(height), eof_(false), io_ctx_() { // av_register_all deprecated in latest versions #if ( LIBAVFORMAT_VERSION_INT < AV_VERSION_INT(58,9,100) ) @@ -75,7 +75,7 @@ VideoReader::VideoReader(std::string fn, DLContext ctx, int width, int height, i } return; } - + fmt_ctx_.reset(fmt_ctx); // find stream info @@ -249,17 +249,15 @@ int64_t VideoReader::GetCurrentPosition() const { } int64_t VideoReader::FrameToPTS(int64_t pos) { - int64_t ts = pos * fmt_ctx_->streams[actv_stm_idx_]->duration / GetFrameCount(); + int64_t ts = frame_ts_[pos].pts; return ts; } std::vector VideoReader::FramesToPTS(const std::vector& positions) { - auto nframe = GetFrameCount(); - auto duration = fmt_ctx_->streams[actv_stm_idx_]->duration; std::vector ret; ret.reserve(positions.size()); for (auto pos : positions) { - ret.emplace_back(pos * duration / nframe); + ret.emplace_back(frame_ts_[pos].pts); } return ret; } @@ -271,7 +269,14 @@ bool VideoReader::Seek(int64_t pos) { eof_ = false; int64_t ts = FrameToPTS(pos); - int ret = av_seek_frame(fmt_ctx_.get(), actv_stm_idx_, ts, AVSEEK_FLAG_BACKWARD); + int flag = curr_frame_ > pos ? AVSEEK_FLAG_BACKWARD : 0; + + // std::cout << "Seek " << pos << " at pts " << ts << ", flag " << flag << std::endl; + int ret = av_seek_frame(fmt_ctx_.get(), actv_stm_idx_, ts, flag); + if (flag != AVSEEK_FLAG_BACKWARD && ret < 0){ + // std::cout << "seek wrong, retry with flag " << AVSEEK_FLAG_BACKWARD << std::endl; + ret = av_seek_frame(fmt_ctx_.get(), actv_stm_idx_, ts, AVSEEK_FLAG_BACKWARD); + } if (ret < 0) LOG(WARNING) << "Failed to seek file to position: " << pos; // LOG(INFO) << "seek return: " << ret; decoder_->Start(); @@ -294,19 +299,37 @@ bool VideoReader::SeekAccurate(int64_t pos) { if (curr_frame_ == pos) return true; int64_t key_pos = LocateKeyframe(pos); int64_t curr_key_pos = LocateKeyframe(curr_frame_); - if (key_pos != curr_key_pos) { + overrun_ = false; + // std::cout << "seek " << pos << "(" << frame_ts_[pos].pts << "), nearest key " << key_pos << "(" << frame_ts_[key_pos].pts << "), current pos " + // << curr_frame_ << "(" << frame_ts_[curr_frame_].pts << "), current key " << curr_key_pos << "(" << frame_ts_[curr_key_pos].pts << ")" << std:: endl; + if (key_pos != curr_key_pos || pos < curr_frame_) { // need to seek to keyframes first - bool ret = Seek(key_pos); + // std::cout << "need to seek to keyframe " << key_pos << " first " << std::endl; + // first rewind to 0, in order to increase seek accuracy + bool ret = Seek(0); if (!ret) return false; - SkipFrames(pos - key_pos); - } else if (pos < curr_frame_) { - // need seek backwards to the nearest keyframe - bool ret = Seek(key_pos); + ret = Seek(key_pos); if (!ret) return false; - SkipFrames(pos - key_pos); + // double check if keyframe was jumpped correctly + if(CheckKeyFrame()){ + if(pos - key_pos > 0){ + SkipFramesImpl(pos - curr_frame_); + } else if(pos - key_pos == 0){ + overrun_ = true; + } + } else { + if(curr_frame_ < pos){ + SkipFramesImpl(pos - curr_frame_); + } else { + key_pos = LocateKeyframe(pos); + // since curr_frame_ is larger, Seek will use AVSEEK_FLAG_BACKWARD + Seek(key_pos); + SkipFramesImpl(pos - key_pos); + } + } } else { // no need to seek to keyframe, since both current and seek position belong to same keyframe - SkipFrames(pos - curr_frame_); + SkipFramesImpl(pos - curr_frame_); } return true; } @@ -351,11 +374,17 @@ void VideoReader::PushNext() { } NDArray VideoReader::NextFrameImpl() { + if (overrun_) + { + overrun_ = false; + return tmp_key_frame_; + } NDArray frame; decoder_->Start(); bool ret = false; int rewind_offset = 0; while (!ret) { + // std::cout << "!!" << std::endl; PushNext(); if (curr_frame_ >= GetFrameCount()) { return NDArray::Empty({}, kUInt8, ctx_); @@ -377,7 +406,7 @@ NDArray VideoReader::NextFrameImpl() { NDArray VideoReader::NextFrame() { if (!fmt_ctx_) return NDArray(); - return NextFrameImpl(); + return NextFrameImpl(); } void VideoReader::IndexKeyframes() { @@ -409,6 +438,7 @@ void VideoReader::IndexKeyframes() { auto start_pts = (packet->pts - start_sec) * ts_factor; auto stop_pts = (packet->pts + packet->duration - start_sec) * ts_factor; frame_ts_.emplace_back(AVFrameTime(packet->pts, packet->dts, start_pts, stop_pts)); + // std::cout << ((packet->flags & AV_PKT_FLAG_KEY) ? "*" : "") << cnt << ": pts " << packet->pts << ", dts " << packet->dts << ", start pts " << start_pts << ", stop pts " << stop_pts << std::endl; if (packet->flags & AV_PKT_FLAG_KEY) { key_indices_.emplace_back(cnt); } @@ -419,6 +449,11 @@ void VideoReader::IndexKeyframes() { std::sort(std::begin(frame_ts_), std::end(frame_ts_), [](const AVFrameTime& a, const AVFrameTime& b) -> bool {return a.pts < b.pts;}); + + for (size_t i = 0; i < frame_ts_.size(); ++i){ + pts_frame_map_.insert(std::pair(frame_ts_[i].pts, i)); + // std::cout << i << ": pts " << frame_ts_[i].pts << ", dts " << frame_ts_[i].dts << ", start pts " << frame_ts_[i].start << ", stop pts " << frame_ts_[i].stop << std::endl; + } curr_frame_ = GetFrameCount(); ret = Seek(0); } @@ -497,9 +532,46 @@ void VideoReader::SkipFrames(int64_t num) { // LOG(INFO) << "current: " << curr_frame_ << ", adjust skip from " << num << " to " << num + old_frame - *it2; num += old_frame - *it2; } + + SkipFramesImpl(num); +} + +bool VideoReader::CheckKeyFrame() +{ + // check curr_frame_ is correct or not, by decoding the current frame + NDArray frame; + decoder_->Start(); + bool ret = false; + int64_t cf = curr_frame_; + while (!ret) + { + PushNext(); + ret = decoder_->Pop(&frame); + } + + // find the real current frame after decoding + auto iter = pts_frame_map_.find(frame.pts); + if (iter != pts_frame_map_.end()) + cf = iter->second; + if (curr_frame_ != cf) + { + curr_frame_ = cf + 1; + return false; + } else{ + ++curr_frame_; + tmp_key_frame_ = frame; + return true; + } + +} + +void VideoReader::SkipFramesImpl(int64_t num) +{ + if (!fmt_ctx_) + return; + num = std::min(GetFrameCount() - curr_frame_, num); if (num < 1) return; - // LOG(INFO) << "started skipping with: " << num; NDArray frame; decoder_->Start(); bool ret = false; @@ -507,16 +579,16 @@ void VideoReader::SkipFrames(int64_t num) { std::iota(frame_pos.begin(), frame_pos.end(), curr_frame_); auto pts = FramesToPTS(frame_pos); decoder_->SuggestDiscardPTS(pts); - curr_frame_ += num; + while (num > 0) { PushNext(); ret = decoder_->Pop(&frame); if (!ret) continue; + ++curr_frame_; // LOG(INFO) << "skip: " << num; --num; } decoder_->ClearDiscardPTS(); - // LOG(INFO) << " stopped skipframes: " << curr_frame_; } NDArray VideoReader::GetBatch(std::vector indices, NDArray buf) { @@ -554,13 +626,7 @@ NDArray VideoReader::GetBatch(std::vector indices, NDArray buf) { else { CHECK_LT(pos, frame_count); CHECK_GE(pos, 0); - if (curr_frame_ == pos) { - // no need to seek - } else if (pos > curr_frame_) { - // skip positive number of frames - SkipFrames(pos - curr_frame_); - } else { - // seek no matter what + if (curr_frame_ != pos) { SeekAccurate(pos); } NDArray frame = NextFrameImpl(); diff --git a/src/video/video_reader.h b/src/video/video_reader.h index 7bf8bb5..c220aad 100644 --- a/src/video/video_reader.h +++ b/src/video/video_reader.h @@ -57,12 +57,17 @@ class VideoReader : public VideoReaderInterface { void IndexKeyframes(); void PushNext(); int64_t LocateKeyframe(int64_t pos); + void SkipFramesImpl(int64_t num = 1); + bool CheckKeyFrame(); NDArray NextFrameImpl(); int64_t FrameToPTS(int64_t pos); std::vector FramesToPTS(const std::vector& positions); - + DLContext ctx_; std::vector key_indices_; + std::map pts_frame_map_; + NDArray tmp_key_frame_; + bool overrun_; /*! \brief a lookup table for per frame pts/dts */ std::vector frame_ts_; /*! \brief Video Streams Codecs in original videos */ @@ -79,7 +84,7 @@ class VideoReader : public VideoReaderInterface { bool eof_; // end of file indicator NDArrayPool ndarray_pool_; std::unique_ptr io_ctx_; // avio context for raw memory access - + }; // class VideoReader } // namespace decord #endif // DECORD_VIDEO_VIDEO_READER_H_