Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ExternalSource refactoring and fixing #5690

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 38 additions & 15 deletions dali/core/access_order.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -25,6 +25,13 @@ AccessOrder::AccessOrder(cudaStream_t stream) : stream_(stream) {
device_id_ = DeviceFromStream(stream);
}

constexpr bool is_ambiguous_handle(cudaStream_t stream) {
return
stream == 0 ||
stream == cudaStreamPerThread ||
stream == cudaStreamLegacy;
}

void AccessOrder::wait(const AccessOrder &other) const {
if (*this == other)
return;
Expand All @@ -33,44 +40,60 @@ void AccessOrder::wait(const AccessOrder &other) const {
// always considered up-to-date.
if (!has_value() || !other.is_device())
return;

auto current_dev = []() {
int dev;
CUDA_CALL(cudaGetDevice(&dev));
return dev;
};

auto need_device_switch = [&]() {
return is_ambiguous_handle(other.stream_) && other.device_id() != current_dev();
};

if (is_device()) {
auto &pool = CUDAEventPool::instance();
int other_dev = other.device_id();
auto event = pool.Get(other_dev);
// Record an event in the preceding stream

auto current_dev = []() {
int dev;
CUDA_CALL(cudaGetDevice(&dev));
return dev;
};

// If the stream handle has a special value, we can't refer to it directly - it is
// inherently associated with the concept of "current device" and it must be switched
if (other_dev != device_id_ ||
((other.stream_ == 0 ||
other.stream_ == cudaStreamPerThread ||
other.stream_ == cudaStreamLegacy) &&
other_dev != current_dev())) {
if (need_device_switch()) {
DeviceGuard dg(other.device_id_);
CUDA_CALL(cudaEventRecord(event, other.stream()));
} else {
CUDA_CALL(cudaEventRecord(event, other.stream()));
}
// and wait for it in this stream
CUDA_CALL(cudaStreamWaitEvent(stream(), event, 0));
if (is_ambiguous_handle(stream())) {
DeviceGuard dg(device_id_);
CUDA_CALL(cudaStreamWaitEvent(stream(), event, 0));
} else {
CUDA_CALL(cudaStreamWaitEvent(stream(), event, 0));
}
pool.Put(std::move(event), other_dev);
} else {
// host order - wait for the preceding stream on host
CUDA_CALL(cudaStreamSynchronize(other.stream()));
if (need_device_switch()) {
DeviceGuard dg(device_id_);
CUDA_CALL(cudaStreamSynchronize(other.stream()));
} else {
CUDA_CALL(cudaStreamSynchronize(other.stream()));
}
}
}

void AccessOrder::wait(cudaEvent_t event) const {
if (!has_value())
throw std::logic_error("A null AccessOrder cannot wait for an event.");
if (is_device()) {
CUDA_DTOR_CALL(cudaStreamWaitEvent(stream(), event, 0));
if (is_ambiguous_handle(stream())) {
DeviceGuard dg(device_id_);
CUDA_DTOR_CALL(cudaStreamWaitEvent(stream(), event, 0));
} else {
CUDA_DTOR_CALL(cudaStreamWaitEvent(stream(), event, 0));
}
} else {
CUDA_DTOR_CALL(cudaEventSynchronize(event));
}
Expand Down
84 changes: 40 additions & 44 deletions dali/pipeline/operator/builtin/caching_list.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -15,11 +15,13 @@
#ifndef DALI_PIPELINE_OPERATOR_BUILTIN_CACHING_LIST_H_
#define DALI_PIPELINE_OPERATOR_BUILTIN_CACHING_LIST_H_

#include <stdexcept>
#include <list>
#include <memory>
#include <utility>
#include <stdexcept>

namespace dali {

/**
* CachingList differs from std::List by the ability to recycle empty elements. When allocating
* memory is expensive it is better to store already allocated but no longer needed element in the
Expand Down Expand Up @@ -47,6 +49,19 @@ class CachingList {
public:
CachingList() : prophet_(full_data_.end()) {}

class Item {
public:
Item() = default;
T &operator*() const & noexcept { return l_.front(); }
T &&operator*() && noexcept { return l_.front(); }

T *operator->() const & noexcept { return &l_.front(); }
private:
explicit Item(std::list<T> &&l) : l_(std::move(l)) {}
mutable std::list<T> l_;
friend class CachingList<T>;
};


bool IsEmpty() const {
return full_data_.empty();
Expand All @@ -58,50 +73,43 @@ class CachingList {
}


std::list<T> PopFront() {
assert(!full_data_.empty()); // Can't pop from an empty list
Item PopFront() {
if (full_data_.empty())
throw std::out_of_range("Cannot pop an item from an empty list");
std::list<T> tmp;
tmp.splice(tmp.begin(), full_data_, full_data_.begin());
if (tmp.begin() == prophet_)
prophet_ = full_data_.begin();
return tmp;
assert(tmp.size() == 1u);
return Item(std::move(tmp));
}


void Recycle(std::list<T> &elm) {
empty_data_.splice(empty_data_.end(), elm, elm.begin());
void Recycle(Item &&elm) {
empty_data_.splice(empty_data_.end(), elm.l_, elm.l_.begin(), elm.l_.end());
}


std::list<T> GetEmpty() {
Item GetEmpty() {
std::list<T> tmp;
if (empty_data_.empty()) {
tmp.emplace_back(std::make_unique<typename T::element_type>());
tmp.emplace_back();
} else {
tmp.splice(tmp.begin(), empty_data_, empty_data_.begin());
}
return tmp;
return Item(std::move(tmp));
}


void PushBack(std::list<T> &elm) {
full_data_.splice(full_data_.end(), elm, elm.begin());
/*
* When the prophet is dead and needs to be resurrected,
* he shall be resurrected by the apprentice.
* In the special scenario, when prophet is dead and the data list is empty
* (hence the apprentice is dead too), the prophet will be resurrected
* from scratch, by assigning him to the element that was just added to the data list.
* Sic mundus creatus est.
*/
if (resurrect_prophet_) {
if (full_data_.size() == 1) {
prophet_ = full_data_.begin();
} else {
prophet_ = std::next(apprentice_);
}
resurrect_prophet_ = false;
}
void PushBack(Item &&elm) {
if (elm.l_.empty())
throw std::logic_error("The element is empty - has it been moved out?");

// If the "prophet" is at the end of the list, we'll need to restore it to point to the
// beginning of the newly appended item.
if (prophet_ == full_data_.end() || full_data_.empty())
prophet_ = elm.l_.begin();
full_data_.splice(full_data_.end(), elm.l_, elm.l_.begin(), elm.l_.end());
}


Expand All @@ -119,8 +127,7 @@ class CachingList {
throw std::out_of_range(
"Attempted to move to the data batch that doesn't exist. Add more elements to"
" the DALI input operator.");
apprentice_ = prophet_++;
resurrect_prophet_ = prophet_ == full_data_.end();
++prophet_;
}


Expand All @@ -132,20 +139,9 @@ class CachingList {
std::list<T> full_data_;
std::list<T> empty_data_;

/**
* Prophet dies when he hits the end() iterator of the list with the data.
* Prophet can be resurrected, iff there is a data record for him, i.e.
* when user calls PushBack and therefore inserts the data at the end
* of the CachingList
*/
bool resurrect_prophet_ = true;

/**
* The apprentice follows the prophet and is always one step behind him.
* Apprentice is used to resurrect the prophet, so that the prophet might
* again point to the last actual element of the list.
*/
typename std::list<T>::iterator prophet_, apprentice_;
// The "prophet" is a separate lookahead pointer into the list, used for peeking into
// future items without altering the contents of the list.
typename std::list<T>::iterator prophet_;
};

} // namespace dali
Expand Down
40 changes: 20 additions & 20 deletions dali/pipeline/operator/builtin/caching_list_test.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -14,6 +14,7 @@

#include "dali/pipeline/operator/builtin/caching_list.h"
#include <gtest/gtest.h>
#include <utility>

namespace dali::test {

Expand All @@ -33,50 +34,49 @@ struct TestType {


TEST(CachingListTest, ProphetTest) {
CachingList<std::unique_ptr<TestType<int>>> cl;
CachingList<TestType<int>> cl;

auto push = [&](int val) {
auto elem = cl.GetEmpty();
elem.emplace_back(std::make_unique<TestType<int>>());
elem.front()->val = val;
cl.PushBack(elem);
elem->val = val;
cl.PushBack(std::move(elem));
};

ASSERT_THROW(cl.PeekProphet(), std::out_of_range);
push(6);
EXPECT_EQ(*cl.PeekProphet(), 6);
EXPECT_EQ(cl.PeekProphet(), 6);
push(9);
EXPECT_EQ(*cl.PeekProphet(), 6);
EXPECT_EQ(cl.PeekProphet(), 6);
cl.AdvanceProphet();
EXPECT_EQ(*cl.PeekProphet(), 9);
EXPECT_EQ(cl.PeekProphet(), 9);
push(13);
EXPECT_EQ(*cl.PeekProphet(), 9);
EXPECT_EQ(cl.PeekProphet(), 9);
cl.AdvanceProphet();
EXPECT_EQ(*cl.PeekProphet(), 13);
EXPECT_EQ(cl.PeekProphet(), 13);
push(42);
EXPECT_EQ(*cl.PeekProphet(), 13);
EXPECT_EQ(cl.PeekProphet(), 13);
push(69);
EXPECT_EQ(*cl.PeekProphet(), 13);
EXPECT_EQ(cl.PeekProphet(), 13);
cl.AdvanceProphet();
EXPECT_EQ(*cl.PeekProphet(), 42);
EXPECT_EQ(cl.PeekProphet(), 42);
cl.AdvanceProphet();
EXPECT_EQ(*cl.PeekProphet(), 69);
EXPECT_EQ(cl.PeekProphet(), 69);
cl.AdvanceProphet();
ASSERT_THROW(cl.PeekProphet(), std::out_of_range);
push(666);
EXPECT_EQ(*cl.PeekProphet(), 666);
EXPECT_EQ(cl.PeekProphet(), 666);
push(1337);
EXPECT_EQ(*cl.PeekProphet(), 666);
EXPECT_EQ(cl.PeekProphet(), 666);
cl.AdvanceProphet();
EXPECT_EQ(*cl.PeekProphet(), 1337);
EXPECT_EQ(cl.PeekProphet(), 1337);
cl.AdvanceProphet();
ASSERT_THROW(cl.PeekProphet(), std::out_of_range);
push(1234);
EXPECT_EQ(*cl.PeekProphet(), 1234);
EXPECT_EQ(cl.PeekProphet(), 1234);
push(4321);
EXPECT_EQ(*cl.PeekProphet(), 1234);
EXPECT_EQ(cl.PeekProphet(), 1234);
cl.AdvanceProphet();
EXPECT_EQ(*cl.PeekProphet(), 4321);
EXPECT_EQ(cl.PeekProphet(), 4321);
cl.AdvanceProphet();
ASSERT_THROW(cl.PeekProphet(), std::out_of_range);
ASSERT_THROW(cl.AdvanceProphet(), std::out_of_range);
Expand Down
Loading
Loading