Skip to content

Commit

Permalink
Start work on S3ScanIterator
Browse files Browse the repository at this point in the history
  • Loading branch information
shaunrd0 committed Nov 14, 2023
1 parent b7d1c4c commit 7b10129
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 61 deletions.
38 changes: 0 additions & 38 deletions tiledb/sm/filesystem/ls_callback.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ template <class F>
concept DirectoryPredicate = true;

namespace tiledb::sm {

using FileFilter = std::function<bool(const std::string_view&, uint64_t)>;
// TODO: rename or remove
[[maybe_unused]] static bool no_file_filter(const std::string_view&, uint64_t) {
Expand Down Expand Up @@ -110,43 +109,6 @@ class LsScanner {
LsObjects results_;
};

/**
 * Signature of the C API callback invoked for every object collected by ls.
 *
 * @param path[in] The path of a visited object for the relative filesystem.
 * @param path_len[in] The length of the path string.
 * @param object_size[in] The size of the object at the path.
 * @param data[in] Cast to user defined struct to store paths and offsets.
 * @return `1` if the walk should continue to the next object, `0` if the walk
 *    should stop, and `-1` on error.
 */
using LsCallbackCAPI =
    std::function<int32_t(const char*, size_t, uint64_t, void*)>;

/**
 * Adapter that exposes a C API ls callback, together with its associated
 * user data pointer, as a C++ callable taking (path, size).
 */
class LsCallbackWrapperCAPI {
 public:
  /**
   * Constructor.
   *
   * @param cb The C API callback to forward each call to.
   * @param data Opaque user pointer handed back to `cb` on every call.
   */
  LsCallbackWrapperCAPI(LsCallbackCAPI cb, void* data)
      : cb_(cb)
      , data_(data) {
  }

  /**
   * Forwards one visited object to the wrapped C API callback.
   *
   * @param path The path of the visited object.
   * @param size The size of the object at the path.
   * @return `true` only when the callback returned `1`; any other
   *    non-error value yields `false`.
   * @throws std::runtime_error when the callback returns `-1`.
   */
  bool operator()(const std::string_view& path, uint64_t size) {
    switch (cb_(path.data(), path.size(), size, data_)) {
      case -1:
        throw std::runtime_error("Error in ls callback");
      case 1:
        return true;
      default:
        return false;
    }
  }

 private:
  /** The wrapped C API callback. */
  LsCallbackCAPI cb_;

  /** User data passed through to the callback unchanged. */
  void* data_;
};
} // namespace tiledb::sm

#endif // TILEDB_LS_CALLBACK_H
92 changes: 72 additions & 20 deletions tiledb/sm/filesystem/s3.h
Original file line number Diff line number Diff line change
Expand Up @@ -359,9 +359,19 @@ class S3Scanner : public LsScanner<F, D> {
}

void fetch_results() {
if (list_objects_outcome_.GetResult().GetIsTruncated()) {
Aws::String next_marker =
list_objects_outcome_.GetResult().GetNextContinuationToken();
if (next_marker.empty()) {
throw S3Exception(
"Failed to retrieve next continuation token for ListObjectsV2 "
"request.");
}
list_objects_request_.SetContinuationToken(std::move(next_marker));
}

list_objects_outcome_ = client_->ListObjectsV2(list_objects_request_);
it_ = list_objects_outcome_.GetResult().GetContents().begin();

if (!list_objects_outcome_.IsSuccess()) {
// TODO: Use outcome_error_message
throw S3Exception(
Expand All @@ -372,16 +382,56 @@ class S3Scanner : public LsScanner<F, D> {
}
}

class S3ScanIterator {
public:
using value_type = Aws::S3::Model::Object;
using difference_type = ptrdiff_t;
using pointer = Aws::S3::Model::Object*;
using reference = Aws::S3::Model::Object&;
using iterator_category = std::forward_iterator_tag;

S3ScanIterator() = default;
S3ScanIterator(pointer begin, pointer end)
: ptr_(begin)
, begin_(begin)
, end_(end) {
}

reference operator*() {
return *ptr_;
}

S3ScanIterator& operator++() {
if (ptr_ == end_) {
throw std::out_of_range("S3ScanIterator out of range");
}
++ptr_;
return *this;
}

inline S3ScanIterator begin() {
return begin_;
}
inline S3ScanIterator end() {
return end_;
}

private:
pointer ptr_, begin_, end_;
};
friend class S3ScanIterator;

private:
shared_ptr<TileDBS3Client> client_;
std::string delimiter_;
Aws::S3::Model::ListObjectsV2Request list_objects_request_;

/** The current request outcome being scanned. */
Aws::S3::Model::ListObjectsV2Outcome list_objects_outcome_;
// std::vector<Aws::S3::Model::Object>& objects_;

std::vector<Aws::S3::Model::Object>::const_iterator it_;

bool found_;
};

/**
Expand Down Expand Up @@ -547,7 +597,7 @@ class S3 {
* @param recursive Whether to recursively list subdirectories.
*/
template <FilePredicate F, DirectoryPredicate D>
LsObjects ls_filtered(
void ls_filtered(
const URI& parent,
F f,
D d = tiledb::sm::no_filter,
Expand All @@ -557,7 +607,6 @@ class S3 {
while (!s3_scanner.end()) {
s3_scanner.next();
}
return std::move(s3_scanner.results());
}

/**
Expand Down Expand Up @@ -1340,7 +1389,8 @@ S3Scanner<F, D>::S3Scanner(
bool recursive)
: LsScanner<F, D>(prefix, file_filter, dir_filter, recursive)
, client_(client)
, delimiter_(this->is_recursive_ ? "" : "/") {
, delimiter_(this->is_recursive_ ? "" : "/")
, found_(false) {
const auto prefix_dir = prefix.add_trailing_slash();
auto prefix_str = prefix_dir.to_string();
if (!prefix_dir.is_s3()) {
Expand All @@ -1362,7 +1412,17 @@ S3Scanner<F, D>::S3Scanner(

template <FilePredicate F, DirectoryPredicate D>
void S3Scanner<F, D>::next() {
static uint64_t c = 0;
// Increment the iterator if we found a result on the last call.
if (found_) {
it_++;
found_ = false;
if (end() && list_objects_outcome_.GetResult().GetIsTruncated()) {
fetch_results();
} else if (end()) {
std::cout << "Collected " << this->results_.size() << " total results."
<< std::endl;
}
}
while (!end()) {
auto object = *it_;
uint64_t size = object.GetSize();
Expand All @@ -1372,29 +1432,21 @@ void S3Scanner<F, D>::next() {
it_++;
} else {
// TODO: Remove print debugs, results_ member.
// If the file filter predicate is true, add the file to the results.
this->results_.emplace_back(path, size);
std::cout << path << std::endl;
c++;

// iterator is at the next object within results accepted by the filters.
found_ = true;
return;
}
// TODO: Add support for directory pruning.
}

if (end() && list_objects_outcome_.GetResult().GetIsTruncated()) {
Aws::String next_marker =
list_objects_outcome_.GetResult().GetNextContinuationToken();
if (next_marker.empty()) {
throw S3Exception(
"Failed to retrieve next continuation token for ListObjectsV2 "
"request.");
if (end() && list_objects_outcome_.GetResult().GetIsTruncated()) {
fetch_results();
} else if (end()) {
std::cout << "Collected " << this->results_.size() << " total results."
<< std::endl;
}
list_objects_request_.SetContinuationToken(std::move(next_marker));
fetch_results();
} else if (end()) {
std::cout << "Collected " << c << " total results." << std::endl;
}
}

Expand Down
5 changes: 2 additions & 3 deletions tiledb/sm/filesystem/vfs.h
Original file line number Diff line number Diff line change
Expand Up @@ -467,11 +467,10 @@ class VFS : private VFSBase, protected S3_within_VFS {
* @param f The file predicate to invoke on each object collected.
*/
template <FilePredicate F, DirectoryPredicate D = DirectoryFilter>
const LsObjects& ls_recursive(
const URI& parent, F f, D d = tiledb::sm::no_filter) const {
void ls_recursive(const URI& parent, F f, D d = tiledb::sm::no_filter) const {
if (parent.is_s3()) {
#ifdef HAVE_S3
return s3().ls_filtered(parent, f, d, true);
s3().ls_filtered(parent, f, d, true);
#else
throw VFSException("TileDB was built without S3 support");
#endif
Expand Down

0 comments on commit 7b10129

Please sign in to comment.