Skip to content

Commit

Permalink
Implement lazy union
Browse files Browse the repository at this point in the history
  • Loading branch information
RobinTF committed Oct 15, 2024
1 parent aa7b723 commit 52c76bc
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 8 deletions.
90 changes: 84 additions & 6 deletions src/engine/Union.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,20 @@ size_t Union::getCostEstimate() {
getSizeEstimateBeforeLimit();
}

ProtoResult Union::computeResult([[maybe_unused]] bool requestLaziness) {
ProtoResult Union::computeResult(bool requestLaziness) {
LOG(DEBUG) << "Union result computation..." << std::endl;
std::shared_ptr<const Result> subRes1 = _subtrees[0]->getResult();
std::shared_ptr<const Result> subRes2 = _subtrees[1]->getResult();
std::shared_ptr<const Result> subRes1 =
_subtrees[0]->getResult(requestLaziness);
std::shared_ptr<const Result> subRes2 =
_subtrees[1]->getResult(requestLaziness);

if (requestLaziness) {
auto localVocab = std::make_shared<LocalVocab>();
auto generator =
computeResultLazily(std::move(subRes1), std::move(subRes2), localVocab);
return {std::move(generator), resultSortedOn(), std::move(localVocab)};
}

LOG(DEBUG) << "Union subresult computation done." << std::endl;

IdTable idTable =
Expand All @@ -179,10 +189,10 @@ void Union::copyChunked(auto beg, auto end, auto target) const {
size_t total = end - beg;
for (size_t i = 0; i < total; i += chunkSize) {
checkCancellation();
size_t actualEnd = std::min(i + chunkSize, total);
std::copy(beg + i, beg + actualEnd, target + i);
size_t actualEnd = std::min(i + chunkSize, total);
std::copy(beg + i, beg + actualEnd, target + i);
}
};
}

// _____________________________________________________________________________
void Union::fillChunked(auto beg, auto end, const auto& value) const {
Expand Down Expand Up @@ -229,3 +239,71 @@ IdTable Union::computeUnion(
}
return res;
}

// _____________________________________________________________________________
template <bool left>
std::vector<size_t> Union::computePermutation() const {
size_t startOfUndefColumns = _subtrees[left ? 0 : 1]->getResultWidth();
std::vector<size_t> permutation(_columnOrigins.size());

Check warning on line 247 in src/engine/Union.cpp

View check run for this annotation

Codecov / codecov/patch

src/engine/Union.cpp#L245-L247

Added lines #L245 - L247 were not covered by tests
for (size_t targetColIdx = 0; targetColIdx < _columnOrigins.size();
++targetColIdx) {
size_t originIndex = _columnOrigins.at(targetColIdx)[left ? 0 : 1];

Check warning on line 250 in src/engine/Union.cpp

View check run for this annotation

Codecov / codecov/patch

src/engine/Union.cpp#L249-L250

Added lines #L249 - L250 were not covered by tests
if (originIndex == NO_COLUMN) {
originIndex = startOfUndefColumns;
startOfUndefColumns++;

Check warning on line 253 in src/engine/Union.cpp

View check run for this annotation

Codecov / codecov/patch

src/engine/Union.cpp#L252-L253

Added lines #L252 - L253 were not covered by tests
}
permutation[originIndex] = targetColIdx;

Check warning on line 255 in src/engine/Union.cpp

View check run for this annotation

Codecov / codecov/patch

src/engine/Union.cpp#L255

Added line #L255 was not covered by tests
}
return permutation;

Check warning on line 257 in src/engine/Union.cpp

View check run for this annotation

Codecov / codecov/patch

src/engine/Union.cpp#L257

Added line #L257 was not covered by tests
}

// _____________________________________________________________________________
// Pad `idTable` with columns of undefined ids until it is as wide as the
// result of this operation, then rearrange its columns in place (only column
// swaps, no copies of the data) such that the column that was at index
// `permutation[i]` ends up at index `i`. The transformed table is returned.
IdTable Union::transformToCorrectColumnFormat(
    IdTable idTable, const std::vector<size_t>& permutation) const {
  // Append the missing columns; every freshly added column is filled with
  // the undefined id.
  for (size_t newColumn = idTable.numColumns(); newColumn < getResultWidth();
       ++newColumn) {
    idTable.addEmptyColumn();
    std::ranges::fill(idTable.getColumn(newColumn), Id::makeUndefined());
  }

  // The last position needs no treatment: once everything except the last
  // column is in place, the last column is necessarily in place as well, so
  // the final iteration would only swap it with itself.
  for (size_t position = 0; position + 1 < permutation.size(); ++position) {
    // Positions smaller than `position` are already final. If the wanted
    // column was swapped away from such a position, follow the chain of
    // earlier swaps until we find where it currently lives.
    size_t source = position;
    do {
      source = permutation[source];
    } while (source < position);
    idTable.swapColumns(position, source);
  }
  return idTable;
}

// _____________________________________________________________________________
// Lazily yield the tables of the union: first everything that stems from
// `result1`, then everything that stems from `result2`. A fully materialized
// child contributes a single table (computed via `computeUnion` with an empty
// table for the other side), a lazy child contributes one table per table it
// yields, converted by `transformToCorrectColumnFormat`. After both children
// have been consumed, the merged local vocabs of both children are written to
// `*localVocab`.
cppcoro::generator<IdTable> Union::computeResultLazily(
    std::shared_ptr<const Result> result1,
    std::shared_ptr<const Result> result2,
    std::shared_ptr<LocalVocab> localVocab) const {
  // Left child.
  if (!result1->isFullyMaterialized()) {
    std::vector<size_t> leftPermutation = computePermutation<true>();
    for (IdTable& chunk : result1->idTables()) {
      co_yield transformToCorrectColumnFormat(std::move(chunk),
                                              leftPermutation);
    }
  } else {
    co_yield computeUnion(result1->idTable(),
                          IdTable{getResultWidth(), allocator()},
                          _columnOrigins);
  }

  // Right child.
  if (!result2->isFullyMaterialized()) {
    std::vector<size_t> rightPermutation = computePermutation<false>();
    for (IdTable& chunk : result2->idTables()) {
      co_yield transformToCorrectColumnFormat(std::move(chunk),
                                              rightPermutation);
    }
  } else {
    co_yield computeUnion(IdTable{getResultWidth(), allocator()},
                          result2->idTable(), _columnOrigins);
  }

  // Only now, after both inputs have been consumed, combine their local
  // vocabs and publish the result through the shared pointer.
  std::array<const LocalVocab*, 2> childVocabs{&result1->localVocab(),
                                               &result2->localVocab()};
  *localVocab = LocalVocab::merge(childVocabs);
}
17 changes: 15 additions & 2 deletions src/engine/Union.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,21 @@ class Union : public Operation {
// A similar timeout-checking replacement for `std::fill`.
void fillChunked(auto beg, auto end, const auto& value) const;

virtual ProtoResult computeResult(
[[maybe_unused]] bool requestLaziness) override;
virtual ProtoResult computeResult(bool requestLaziness) override;

VariableToColumnMap computeVariableToColumnMap() const override;

// Compute the permutation of the `IdTable` being yielded for the left or
// right child depending on `left`. This permutation can then be used to swap
// the columns without any copy operations.
template <bool left>
std::vector<size_t> computePermutation() const;

// Bring `idTable` into the column layout of the result of this operation:
// columns filled with undefined ids are appended until the table has
// `getResultWidth()` columns, then the columns are rearranged in place via
// column swaps such that the column previously at index `permutation[i]`
// ends up at index `i`. Returns the transformed table.
IdTable transformToCorrectColumnFormat(
IdTable idTable, const std::vector<size_t>& permutation) const;

// Lazily yield the tables of the union, first those stemming from `result1`,
// then those stemming from `result2`. A fully materialized input yields a
// single table, a lazy input yields one transformed table per table it
// produces. Once both inputs have been consumed, the merged local vocabs of
// `result1` and `result2` are assigned to `*localVocab`.
cppcoro::generator<IdTable> computeResultLazily(
std::shared_ptr<const Result> result1,
std::shared_ptr<const Result> result2,
std::shared_ptr<LocalVocab> localVocab) const;
};
41 changes: 41 additions & 0 deletions test/UnionTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,44 @@ TEST(UnionTest, computeUnionLarge) {

ASSERT_EQ(result, makeIdTableFromVector(expected));
}

// _____________________________________________________________________________
TEST(UnionTest, computeUnionLazy) {
  // Build the union of a one-column child (?x) and a two-column child
  // (?u ?x), request a lazy result and check that exactly two tables are
  // yielded: the left rows padded with UNDEF, then the right rows with their
  // columns swapped into the order of the union.
  auto runTest = [](bool nonLazyChildren,
                    ad_utility::source_location loc =
                        ad_utility::source_location::current()) {
    auto trace = generateLocationTrace(loc);
    auto* qec = ad_utility::testing::getQec();

    // Wrap a table and its variables into a `ValuesForTesting` child.
    auto makeChild = [&](IdTable table, Vars vars) {
      return ad_utility::makeExecutionTree<ValuesForTesting>(
          qec, std::move(table), std::move(vars), false,
          std::vector<ColumnIndex>{}, LocalVocab{}, std::nullopt,
          nonLazyChildren);
    };

    auto leftT = makeChild(makeIdTableFromVector({{V(1)}, {V(2)}, {V(3)}}),
                           Vars{Variable{"?x"}});
    auto rightT = makeChild(makeIdTableFromVector({{V(4), V(5)}, {V(6), V(7)}}),
                            Vars{Variable{"?u"}, Variable{"?x"}});

    Union u{ad_utility::testing::getQec(), std::move(leftT),
            std::move(rightT)};
    auto resultTable = u.computeResultOnlyForTesting(true);
    ASSERT_FALSE(resultTable.isFullyMaterialized());
    auto& result = resultTable.idTables();

    auto U = Id::makeUndefined();
    auto expectedLeft =
        makeIdTableFromVector({{V(1), U}, {V(2), U}, {V(3), U}});
    auto expectedRight = makeIdTableFromVector({{V(5), V(4)}, {V(7), V(6)}});

    // First yielded table: rows of the left child.
    auto it = result.begin();
    ASSERT_NE(it, result.end());
    ASSERT_EQ(*it, expectedLeft);

    // Second yielded table: rows of the right child.
    ++it;
    ASSERT_NE(it, result.end());
    ASSERT_EQ(*it, expectedRight);

    // Nothing else may follow.
    ASSERT_EQ(++it, result.end());
  };

  runTest(false);
  runTest(true);
}

0 comments on commit 52c76bc

Please sign in to comment.