Merge branch 'master' into gb_feature_fetcher_refactor_read_async
mfbalin authored Jul 22, 2024
2 parents 51f5ce5 + 8ef4a44 commit 239f313
Showing 4 changed files with 32 additions and 12 deletions.
1 change: 1 addition & 0 deletions graphbolt/src/cnumpy.cc
@@ -160,6 +160,7 @@ torch::Tensor OnDiskNpyArray::IndexSelectIOUringImpl(torch::Tensor index) {
// Indicator for index error.
std::atomic<int> error_flag{};
std::atomic<int64_t> work_queue{};
+ std::lock_guard lock(mtx_);
torch::parallel_for(0, num_thread_, 1, [&](int64_t begin, int64_t end) {
if (begin >= end) return;
const auto thread_id = begin;
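The single added line is the whole change for this file: mtx_ is held for the full duration of the parallel read, so concurrent calls to IndexSelectIOUringImpl are serialized rather than driving the shared io_uring queues and the temporary read buffer at the same time (that motivation is inferred from the members the mutex sits next to in cnumpy.h below). A minimal sketch of the locking shape, with an invented class body; at::parallel_for is the same primitive the real code reaches as torch::parallel_for:

#include <ATen/Parallel.h>
#include <torch/torch.h>

#include <mutex>

// Invented stand-in for OnDiskNpyArray; only the locking pattern mirrors the diff.
class OnDiskArraySketch {
 public:
  torch::Tensor IndexSelectIOUringImpl(torch::Tensor index) {
    // Held until the function returns: one reader at a time owns the
    // per-thread queues and the shared read buffer.
    std::lock_guard lock(mtx_);
    auto result = torch::empty({index.size(0)}, torch::kFloat32);
    at::parallel_for(0, num_thread_, 1, [&](int64_t begin, int64_t end) {
      if (begin >= end) return;
      const auto thread_id = begin;
      (void)thread_id;  // The real code submits io_uring reads per thread here.
    });
    return result;
  }

 private:
  int num_thread_ = 4;
  std::mutex mtx_;  // The member introduced by this commit.
};

Because the lock is taken before torch::parallel_for rather than inside the worker lambda, the worker threads never contend on it; only whole calls are ordered against each other.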
2 changes: 2 additions & 0 deletions graphbolt/src/cnumpy.h
@@ -20,6 +20,7 @@
#include <fstream>
#include <iostream>
#include <memory>
+ #include <mutex>
#include <string>
#include <vector>

@@ -113,6 +114,7 @@ class OnDiskNpyArray : public torch::CustomClassHolder {
int64_t aligned_length_; // Aligned feature_size.
int num_thread_; // Default thread number.
torch::Tensor read_tensor_; // Provides temporary read buffer.
+ std::mutex mtx_;

#ifdef HAVE_LIBRARY_LIBURING
std::unique_ptr<io_uring[]> io_uring_queue_; // io_uring queue.
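A small side effect of the new member is worth noting: std::mutex is neither copyable nor movable, so a class holding one loses its implicit copy and move operations. That is harmless for objects that are only handled through pointers, as these intrusive-pointer-managed classes are, but the effect is easy to see in isolation (HasMutex below is invented purely for illustration):

#include <mutex>

struct HasMutex {
  std::mutex mtx_;  // Non-copyable member: HasMutex becomes non-copyable too.
};

int main() {
  HasMutex a;
  // HasMutex b = a;      // Would not compile: the copy constructor is deleted.
  HasMutex* alias = &a;   // Sharing by pointer or reference still works.
  (void)alias;
  return 0;
}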
39 changes: 27 additions & 12 deletions graphbolt/src/partitioned_cache_policy.cc
@@ -116,7 +116,10 @@ PartitionedCachePolicy::Partition(torch::Tensor keys) {

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
PartitionedCachePolicy::Query(torch::Tensor keys) {
- if (policies_.size() == 1) return policies_[0]->Query(keys);
+ if (policies_.size() == 1) {
+   std::lock_guard lock(mtx_);
+   return policies_[0]->Query(keys);
+ };
torch::Tensor offsets, indices, permuted_keys;
std::tie(offsets, indices, permuted_keys) = Partition(keys);
auto offsets_ptr = offsets.data_ptr<int64_t>();
@@ -127,16 +130,22 @@ PartitionedCachePolicy::Query(torch::Tensor keys) {
torch::Tensor result_offsets_tensor =
torch::empty(policies_.size() * 2 + 1, offsets.options());
auto result_offsets = result_offsets_tensor.data_ptr<int64_t>();
- torch::parallel_for(0, policies_.size(), 1, [&](int64_t begin, int64_t end) {
-   if (begin == end) return;
-   TORCH_CHECK(end - begin == 1);
-   const auto tid = begin;
-   begin = offsets_ptr[tid];
-   end = offsets_ptr[tid + 1];
-   results[tid] = policies_.at(tid)->Query(permuted_keys.slice(0, begin, end));
-   result_offsets[tid] = std::get<0>(results[tid]).size(0);
-   result_offsets[tid + policies_.size()] = std::get<2>(results[tid]).size(0);
- });
+ {
+   std::lock_guard lock(mtx_);
+   torch::parallel_for(
+       0, policies_.size(), 1, [&](int64_t begin, int64_t end) {
+         if (begin == end) return;
+         TORCH_CHECK(end - begin == 1);
+         const auto tid = begin;
+         begin = offsets_ptr[tid];
+         end = offsets_ptr[tid + 1];
+         results[tid] =
+             policies_.at(tid)->Query(permuted_keys.slice(0, begin, end));
+         result_offsets[tid] = std::get<0>(results[tid]).size(0);
+         result_offsets[tid + policies_.size()] =
+             std::get<2>(results[tid]).size(0);
+       });
+ }
std::exclusive_scan(
result_offsets, result_offsets + result_offsets_tensor.size(0),
result_offsets, 0);
@@ -198,7 +207,10 @@ PartitionedCachePolicy::QueryAsync(torch::Tensor keys) {
}

torch::Tensor PartitionedCachePolicy::Replace(torch::Tensor keys) {
- if (policies_.size() == 1) return policies_[0]->Replace(keys);
+ if (policies_.size() == 1) {
+   std::lock_guard lock(mtx_);
+   return policies_[0]->Replace(keys);
+ }
torch::Tensor offsets, indices, permuted_keys;
std::tie(offsets, indices, permuted_keys) = Partition(keys);
auto output_positions = torch::empty_like(
@@ -208,6 +220,7 @@ torch::Tensor PartitionedCachePolicy::Replace(torch::Tensor keys) {
auto offsets_ptr = offsets.data_ptr<int64_t>();
auto indices_ptr = indices.data_ptr<int64_t>();
auto output_positions_ptr = output_positions.data_ptr<int64_t>();
+ std::lock_guard lock(mtx_);
torch::parallel_for(0, policies_.size(), 1, [&](int64_t begin, int64_t end) {
if (begin == end) return;
const auto tid = begin;
@@ -231,12 +244,14 @@ c10::intrusive_ptr<Future<torch::Tensor>> PartitionedCachePolicy::ReplaceAsync(

void PartitionedCachePolicy::ReadingCompleted(torch::Tensor keys) {
if (policies_.size() == 1) {
+ std::lock_guard lock(mtx_);
policies_[0]->ReadingCompleted(keys);
return;
}
torch::Tensor offsets, indices, permuted_keys;
std::tie(offsets, indices, permuted_keys) = Partition(keys);
auto offsets_ptr = offsets.data_ptr<int64_t>();
+ std::lock_guard lock(mtx_);
torch::parallel_for(0, policies_.size(), 1, [&](int64_t begin, int64_t end) {
if (begin == end) return;
const auto tid = begin;
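Every method touched in this file follows the same pattern: Query, Replace, and ReadingCompleted each take mtx_ before touching the per-partition policies, on the single-policy fast path as well as around the torch::parallel_for fan-out. Concurrent callers are therefore ordered against each other while the partitions are still processed in parallel within a single call. A condensed sketch of that layout, with an invented ToyPolicy standing in for BaseCachePolicy; only the locking structure mirrors the diff:

#include <ATen/Parallel.h>

#include <cstdint>
#include <memory>
#include <mutex>
#include <vector>

// Invented stand-in for BaseCachePolicy.
struct ToyPolicy {
  void Touch(int64_t /*partition*/) {}
};

class PartitionedCachePolicySketch {
 public:
  explicit PartitionedCachePolicySketch(int64_t num_partitions) {
    for (int64_t i = 0; i < num_partitions; ++i)
      policies_.emplace_back(std::make_unique<ToyPolicy>());
  }

  // Mirrors the shape of Query/Replace/ReadingCompleted above.
  void Update() {
    if (policies_.size() == 1) {
      std::lock_guard lock(mtx_);  // Fast path still takes the lock.
      policies_[0]->Touch(0);
      return;
    }
    std::lock_guard lock(mtx_);  // One lock per object, held across the fan-out.
    at::parallel_for(0, policies_.size(), 1, [&](int64_t begin, int64_t end) {
      if (begin == end) return;
      policies_.at(begin)->Touch(begin);  // Each worker handles one partition.
    });
  }

 private:
  std::vector<std::unique_ptr<ToyPolicy>> policies_;
  std::mutex mtx_;
};

Because the lock lives outside the lambda, the fan-out threads never block on it; the mutex only serializes whole Query/Replace/ReadingCompleted calls against each other.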
2 changes: 2 additions & 0 deletions graphbolt/src/partitioned_cache_policy.h
@@ -24,6 +24,7 @@
#include <torch/custom_class.h>
#include <torch/torch.h>

+ #include <mutex>
#include <pcg_random.hpp>
#include <random>
#include <type_traits>
@@ -118,6 +119,7 @@ class PartitionedCachePolicy : public BaseCachePolicy,

int64_t capacity_;
std::vector<std::unique_ptr<BaseCachePolicy>> policies_;
+ std::mutex mtx_;
};

} // namespace storage
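The header side is just the declaration: <mutex> is included and mtx_ sits next to policies_, so there is exactly one lock per PartitionedCachePolicy instance, shared by all of the methods above. A standalone illustration of the caller-side scenario the lock protects against, two threads entering the same object at once; SharedCounterPolicy is invented and only stands in for whatever state the real policies mutate:

#include <iostream>
#include <mutex>
#include <thread>

// Invented toy: without the lock, the two threads below race on hits_.
class SharedCounterPolicy {
 public:
  void Query() {
    std::lock_guard lock(mtx_);
    ++hits_;
  }
  int hits() const { return hits_; }

 private:
  int hits_ = 0;
  std::mutex mtx_;
};

int main() {
  SharedCounterPolicy policy;
  auto worker = [&policy] {
    for (int i = 0; i < 100000; ++i) policy.Query();
  };
  std::thread a(worker), b(worker);
  a.join();
  b.join();
  std::cout << policy.hits() << "\n";  // Always 200000 with the lock in place.
  return 0;
}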
