Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Exact Seek #78

Merged
merged 36 commits into from
Jul 13, 2020
Merged
Show file tree
Hide file tree
Changes from 35 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions include/decord/runtime/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ namespace runtime {
*/
class NDArray {
public:
// pts (presentation timestamp) of the frame held by this array; 0 when unset.
// Kept as a 64-bit value: it is assigned from the decoder's frame pts and used as a
// key into 64-bit pts lookup tables, so a plain int would truncate large timestamps.
int64_t pts = 0;
// internal container type
struct Container;
/*! \brief default constructor */
Expand All @@ -56,7 +58,7 @@ class NDArray {
* \param other The value to be moved
*/
NDArray(NDArray&& other) noexcept  // NOLINT(*)
    : pts(other.pts), data_(other.data_) {
  // Ownership of the container pointer moves to this object; the source is left
  // without a container. Its pts value is copied, not reset.
  other.data_ = nullptr;
}
/*! \brief destructor */
Expand All @@ -69,6 +71,7 @@ class NDArray {
*/
void swap(NDArray& other) {  // NOLINT(*)
  // The frame timestamp is exchanged together with the container pointer so that
  // each array keeps the pts belonging to the data it ends up holding.
  std::swap(pts, other.pts);
  std::swap(data_, other.data_);
}
/*!
* \brief copy assignment
Expand Down Expand Up @@ -305,7 +308,7 @@ inline NDArray::NDArray(Container* data)
}

inline NDArray::NDArray(const NDArray& other)
: data_(other.data_) {
: pts(other.pts), data_(other.data_) {
if (data_ != nullptr) {
data_->IncRef();
}
Expand Down Expand Up @@ -368,6 +371,7 @@ inline void NDArray::CopyFrom(DLTensor* other) {
/*!
 * \brief Copy the tensor contents and the frame pts of another array into this one.
 * \param other The source array. Both this array and \p other must hold a container.
 */
inline void NDArray::CopyFrom(const NDArray& other) {
  CHECK(data_ != nullptr);
  CHECK(other.data_ != nullptr);
  // The timestamp is taken over from the source before the tensor data is copied.
  pts = other.pts;
  auto* src_tensor = &(other.data_->dl_tensor);
  auto* dst_tensor = &(data_->dl_tensor);
  CopyFromTo(src_tensor, dst_tensor);
}

Expand All @@ -379,7 +383,7 @@ inline void NDArray::CopyFrom(std::vector<T>& other, std::vector<int64_t>& shape
size *= s;
}
CHECK(other.size() == size) << "other: " << other.size() << " this: " << size;
DLTensor dlt = CreateDLTensorView(other, shape);
DLTensor dlt = CreateDLTensorView(other, shape);
CopyFromTo(&dlt, &(data_->dl_tensor));
}

Expand All @@ -391,6 +395,7 @@ inline void NDArray::CopyTo(DLTensor* other) const {
/*!
 * \brief Copy the tensor contents of this array into another array.
 * \param other The destination array. Both this array and \p other must hold a container.
 * \note Only the tensor data is copied; the pts of \p other is left as it was.
 */
inline void NDArray::CopyTo(const NDArray& other) const {
  CHECK(data_ != nullptr);
  CHECK(other.data_ != nullptr);
  // no copy of pts
  auto* src_tensor = &(data_->dl_tensor);
  auto* dst_tensor = &(other.data_->dl_tensor);
  CopyFromTo(src_tensor, dst_tensor);
}

Expand Down
6 changes: 5 additions & 1 deletion src/video/ffmpeg/threaded_decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,9 @@ void FFMPEGThreadedDecoder::ProcessFrame(AVFramePtr frame, NDArray out_buf) {
}
if (skip) {
// skip resize/filtering
frame_queue_->Push(NDArray::Empty({1}, kUInt8, kCPU));
NDArray empty = NDArray::Empty({1}, kUInt8, kCPU);
empty.pts = frame->pts;
frame_queue_->Push(empty);
++frame_count_;
return;
}
Expand Down Expand Up @@ -260,6 +262,7 @@ NDArray FFMPEGThreadedDecoder::CopyToNDArray(AVFramePtr p) {
to_ptr, i * linesize,
linesize, ctx, ctx, kUInt8, nullptr);
}
arr.pts = p->pts;
return arr;
}

Expand All @@ -279,6 +282,7 @@ NDArray FFMPEGThreadedDecoder::AsNDArray(AVFramePtr p) {
ToDLTensor(p, manager->dl_tensor, av_manager->shape);
manager->deleter = AVFrameManagerDeleter;
NDArray arr = NDArray::FromDLPack(manager);
arr.pts = p->pts;
return arr;
}

Expand Down
134 changes: 107 additions & 27 deletions src/video/video_reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ static const int AVIO_BUFFER_SIZE = 40960;


VideoReader::VideoReader(std::string fn, DLContext ctx, int width, int height, int nb_thread, int io_type)
: ctx_(ctx), key_indices_(), frame_ts_(), codecs_(), actv_stm_idx_(-1), fmt_ctx_(nullptr), decoder_(nullptr), curr_frame_(0),
: ctx_(ctx), key_indices_(), pts_frame_map_(), tmp_key_frame_(), overrun_(false), frame_ts_(), codecs_(), actv_stm_idx_(-1), fmt_ctx_(nullptr), decoder_(nullptr), curr_frame_(0),
nb_thread_decoding_(nb_thread), width_(width), height_(height), eof_(false), io_ctx_() {
// av_register_all deprecated in latest versions
#if ( LIBAVFORMAT_VERSION_INT < AV_VERSION_INT(58,9,100) )
Expand Down Expand Up @@ -75,7 +75,7 @@ VideoReader::VideoReader(std::string fn, DLContext ctx, int width, int height, i
}
return;
}

fmt_ctx_.reset(fmt_ctx);

// find stream info
Expand Down Expand Up @@ -249,17 +249,15 @@ int64_t VideoReader::GetCurrentPosition() const {
}

/*!
 * \brief Translate a frame position into a presentation timestamp.
 *
 * Positions covered by the per-frame table frame_ts_ return the pts recorded there.
 * Any other position (negative, or at/after the end of the table) cannot be looked up
 * without reading out of bounds, so a linear estimate over the active stream's duration
 * is returned instead.
 *
 * \param pos Frame position.
 * \return The pts for \p pos; 0 when no estimate can be made because the frame count
 *         is not positive.
 */
int64_t VideoReader::FrameToPTS(int64_t pos) {
    // Exact path: the pts stored for this frame.
    if (pos >= 0 && pos < static_cast<int64_t>(frame_ts_.size())) {
        return frame_ts_[pos].pts;
    }
    // Fallback for positions outside the table: spread the stream duration evenly
    // over the reported frame count.
    const int64_t nframe = GetFrameCount();
    if (nframe <= 0) return 0;  // avoid dividing by zero
    return pos * fmt_ctx_->streams[actv_stm_idx_]->duration / nframe;
}

/*!
 * \brief Translate a list of frame positions into presentation timestamps.
 *
 * Produces exactly one pts per input position, in the same order. Positions covered by
 * the per-frame table frame_ts_ use the pts recorded there; positions outside the table
 * use a linear estimate over the active stream's duration (0 if the frame count is not
 * positive), since indexing the table with them would read out of bounds.
 *
 * \param positions Frame positions to translate.
 * \return A vector with the same length as \p positions.
 */
std::vector<int64_t> VideoReader::FramesToPTS(const std::vector<int64_t>& positions) {
    const int64_t table_size = static_cast<int64_t>(frame_ts_.size());
    std::vector<int64_t> ret;
    ret.reserve(positions.size());
    for (auto pos : positions) {
        if (pos >= 0 && pos < table_size) {
            // Exact path: the pts stored for this frame.
            ret.emplace_back(frame_ts_[pos].pts);
            continue;
        }
        // Fallback for positions outside the table.
        const int64_t nframe = GetFrameCount();
        const int64_t duration = fmt_ctx_->streams[actv_stm_idx_]->duration;
        ret.emplace_back(nframe > 0 ? pos * duration / nframe : 0);
    }
    return ret;
}
Expand All @@ -271,7 +269,14 @@ bool VideoReader::Seek(int64_t pos) {
eof_ = false;

int64_t ts = FrameToPTS(pos);
int ret = av_seek_frame(fmt_ctx_.get(), actv_stm_idx_, ts, AVSEEK_FLAG_BACKWARD);
int flag = curr_frame_ > pos ? AVSEEK_FLAG_BACKWARD : 0;

// std::cout << "Seek " << pos << " at pts " << ts << ", flag " << flag << std::endl;
int ret = av_seek_frame(fmt_ctx_.get(), actv_stm_idx_, ts, flag);
if (flag != AVSEEK_FLAG_BACKWARD && ret < 0){
// std::cout << "seek wrong, retry with flag " << AVSEEK_FLAG_BACKWARD << std::endl;
ret = av_seek_frame(fmt_ctx_.get(), actv_stm_idx_, ts, AVSEEK_FLAG_BACKWARD);
}
if (ret < 0) LOG(WARNING) << "Failed to seek file to position: " << pos;
// LOG(INFO) << "seek return: " << ret;
decoder_->Start();
Expand All @@ -281,6 +286,20 @@ bool VideoReader::Seek(int64_t pos) {
return ret >= 0;
}

bool VideoReader::SeekStart() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we merge this into Seek, as a special case of Seek(0)?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good suggestion! I have replaced SeekStart() with Seek(0), and the decoding speed remains similar.

if (!fmt_ctx_) return false;
if (curr_frame_ == 0) return true;
decoder_->Clear();
eof_ = false;
int64_t ts = FrameToPTS(0);
int ret = av_seek_frame(fmt_ctx_.get(), actv_stm_idx_, ts, AVSEEK_FLAG_BACKWARD);
if (ret < 0) LOG(WARNING) << "Failed to seek file to position: 0";
if (ret >= 0) {
curr_frame_ = 0;
}
return ret >= 0;
}

int64_t VideoReader::LocateKeyframe(int64_t pos) {
if (key_indices_.size() < 1) return 0;
if (pos <= key_indices_[0]) return 0;
Expand All @@ -294,19 +313,37 @@ bool VideoReader::SeekAccurate(int64_t pos) {
if (curr_frame_ == pos) return true;
int64_t key_pos = LocateKeyframe(pos);
int64_t curr_key_pos = LocateKeyframe(curr_frame_);
if (key_pos != curr_key_pos) {
overrun_ = false;
// std::cout << "seek " << pos << "(" << frame_ts_[pos].pts << "), nearest key " << key_pos << "(" << frame_ts_[key_pos].pts << "), current pos "
// << curr_frame_ << "(" << frame_ts_[curr_frame_].pts << "), current key " << curr_key_pos << "(" << frame_ts_[curr_key_pos].pts << ")" << std:: endl;
if (key_pos != curr_key_pos || pos < curr_frame_) {
// need to seek to keyframes first
bool ret = Seek(key_pos);
// std::cout << "need to seek to keyframe " << key_pos << " first " << std::endl;
// first rewind to 0, in order to increase seek accuracy
bool ret = SeekStart();
if (!ret) return false;
SkipFrames(pos - key_pos);
} else if (pos < curr_frame_) {
// need seek backwards to the nearest keyframe
bool ret = Seek(key_pos);
ret = Seek(key_pos);
if (!ret) return false;
SkipFrames(pos - key_pos);
// double check that the seek actually landed on the keyframe
if(CheckKeyFrame()){
if(pos - key_pos > 0){
SkipFramesImpl(pos - curr_frame_);
} else if(pos - key_pos == 0){
overrun_ = true;
}
} else {
if(curr_frame_ < pos){
SkipFramesImpl(pos - curr_frame_);
} else {
key_pos = LocateKeyframe(pos);
// since curr_frame_ is larger, Seek will use AVSEEK_FLAG_BACKWARD
Seek(key_pos);
SkipFramesImpl(pos - key_pos);
}
}
} else {
// no need to seek to keyframe, since both current and seek position belong to same keyframe
SkipFrames(pos - curr_frame_);
SkipFramesImpl(pos - curr_frame_);
}
return true;
}
Expand Down Expand Up @@ -351,11 +388,17 @@ void VideoReader::PushNext() {
}

NDArray VideoReader::NextFrameImpl() {
if (overrun_)
{
overrun_ = false;
return tmp_key_frame_;
}
NDArray frame;
decoder_->Start();
bool ret = false;
int rewind_offset = 0;
while (!ret) {
// std::cout << "!!" << std::endl;
PushNext();
if (curr_frame_ >= GetFrameCount()) {
return NDArray::Empty({}, kUInt8, ctx_);
Expand All @@ -377,7 +420,7 @@ NDArray VideoReader::NextFrameImpl() {

/*!
 * \brief Return the next frame of the video.
 * \return A default-constructed NDArray when no format context is open, otherwise
 *         whatever NextFrameImpl() returns.
 */
NDArray VideoReader::NextFrame() {
    if (!fmt_ctx_) return NDArray();
    return NextFrameImpl();
}

void VideoReader::IndexKeyframes() {
Expand Down Expand Up @@ -409,6 +452,7 @@ void VideoReader::IndexKeyframes() {
auto start_pts = (packet->pts - start_sec) * ts_factor;
auto stop_pts = (packet->pts + packet->duration - start_sec) * ts_factor;
frame_ts_.emplace_back(AVFrameTime(packet->pts, packet->dts, start_pts, stop_pts));
// std::cout << ((packet->flags & AV_PKT_FLAG_KEY) ? "*" : "") << cnt << ": pts " << packet->pts << ", dts " << packet->dts << ", start pts " << start_pts << ", stop pts " << stop_pts << std::endl;
if (packet->flags & AV_PKT_FLAG_KEY) {
key_indices_.emplace_back(cnt);
}
Expand All @@ -419,8 +463,13 @@ void VideoReader::IndexKeyframes() {
std::sort(std::begin(frame_ts_), std::end(frame_ts_),
[](const AVFrameTime& a, const AVFrameTime& b) -> bool
{return a.pts < b.pts;});

for (size_t i = 0; i < frame_ts_.size(); ++i){
pts_frame_map_.insert(std::pair<int64_t, int64_t>(frame_ts_[i].pts, i));
// std::cout << i << ": pts " << frame_ts_[i].pts << ", dts " << frame_ts_[i].dts << ", start pts " << frame_ts_[i].start << ", stop pts " << frame_ts_[i].stop << std::endl;
}
curr_frame_ = GetFrameCount();
ret = Seek(0);
ret = SeekStart();
}

runtime::NDArray VideoReader::GetKeyIndices() {
Expand Down Expand Up @@ -497,26 +546,63 @@ void VideoReader::SkipFrames(int64_t num) {
// LOG(INFO) << "current: " << curr_frame_ << ", adjust skip from " << num << " to " << num + old_frame - *it2;
num += old_frame - *it2;
}

SkipFramesImpl(num);
}

bool VideoReader::CheckKeyFrame()
{
// check curr_frame_ is correct or not, by decoding the current frame
NDArray frame;
decoder_->Start();
bool ret = false;
int64_t cf = curr_frame_;
while (!ret)
{
PushNext();
ret = decoder_->Pop(&frame);
}

// find the real current frame after decoding
auto iter = pts_frame_map_.find(frame.pts);
if (iter != pts_frame_map_.end())
cf = iter->second;
if (curr_frame_ != cf)
{
curr_frame_ = cf + 1;
return false;
} else{
++curr_frame_;
tmp_key_frame_ = frame;
return true;
}

}

/*!
 * \brief Decode and drop frames, starting at the current position.
 *
 * The pts values of the frames about to be skipped are handed to the decoder through
 * SuggestDiscardPTS() before decoding and cleared again afterwards. curr_frame_ is
 * advanced once for every frame the decoder actually returns, so it grows by exactly
 * the number of skipped frames.
 *
 * \param num Number of frames to skip. Clamped to the number of frames remaining after
 *            curr_frame_; nothing happens when the result is less than 1 or when no
 *            format context is open.
 */
void VideoReader::SkipFramesImpl(int64_t num)
{
    if (!fmt_ctx_)
        return;
    num = std::min(GetFrameCount() - curr_frame_, num);
    if (num < 1) return;

    // LOG(INFO) << "started skipping with: " << num;
    NDArray frame;
    decoder_->Start();

    // Positions curr_frame_, curr_frame_ + 1, ... of the frames to be dropped.
    std::vector<int64_t> frame_pos(num);
    std::iota(frame_pos.begin(), frame_pos.end(), curr_frame_);
    auto pts = FramesToPTS(frame_pos);
    decoder_->SuggestDiscardPTS(pts);

    while (num > 0) {
        PushNext();
        // No frame available yet: push more input and try again.
        if (!decoder_->Pop(&frame)) continue;
        // Count the position per popped frame only; also adding num up front
        // would advance curr_frame_ twice for every skipped frame.
        ++curr_frame_;
        // LOG(INFO) << "skip: " << num;
        --num;
    }
    decoder_->ClearDiscardPTS();
    // LOG(INFO) << " stopped skipframes: " << curr_frame_;
}

NDArray VideoReader::GetBatch(std::vector<int64_t> indices, NDArray buf) {
Expand Down Expand Up @@ -554,13 +640,7 @@ NDArray VideoReader::GetBatch(std::vector<int64_t> indices, NDArray buf) {
else {
CHECK_LT(pos, frame_count);
CHECK_GE(pos, 0);
if (curr_frame_ == pos) {
// no need to seek
} else if (pos > curr_frame_) {
// skip positive number of frames
SkipFrames(pos - curr_frame_);
} else {
// seek no matter what
if (curr_frame_ != pos) {
SeekAccurate(pos);
}
NDArray frame = NextFrameImpl();
Expand Down
10 changes: 8 additions & 2 deletions src/video/video_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,19 @@ class VideoReader : public VideoReaderInterface {
private:
void IndexKeyframes();
void PushNext();
bool SeekStart();
int64_t LocateKeyframe(int64_t pos);
void SkipFramesImpl(int64_t num = 1);
bool CheckKeyFrame();
NDArray NextFrameImpl();
int64_t FrameToPTS(int64_t pos);
std::vector<int64_t> FramesToPTS(const std::vector<int64_t>& positions);

DLContext ctx_;
std::vector<int64_t> key_indices_;
std::map<int64_t, int64_t> pts_frame_map_;
NDArray tmp_key_frame_;
bool overrun_;
/*! \brief a lookup table for per frame pts/dts */
std::vector<AVFrameTime> frame_ts_;
/*! \brief Video Streams Codecs in original videos */
Expand All @@ -79,7 +85,7 @@ class VideoReader : public VideoReaderInterface {
bool eof_; // end of file indicator
NDArrayPool ndarray_pool_;
std::unique_ptr<ffmpeg::AVIOBytesContext> io_ctx_; // avio context for raw memory access

}; // class VideoReader
} // namespace decord
#endif // DECORD_VIDEO_VIDEO_READER_H_