Skip to content

Commit

Permalink
Fixed UUID serde to match Presto Java
Browse files Browse the repository at this point in the history
  • Loading branch information
BryanCutler committed Oct 8, 2024
1 parent 08dd2d4 commit af1deb8
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 10 deletions.
7 changes: 5 additions & 2 deletions velox/functions/prestosql/tests/UuidFunctionsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,13 @@ TEST_F(UuidFunctionsTest, castAsVarchar) {
// Verify that CAST results as the same as boost::lexical_cast. We do not use
// boost::lexical_cast to implement CAST because it is too slow.
auto expected = makeFlatVector<std::string>(size, [&](auto row) {
const auto uuid = uuids->valueAt(row);
auto uuid = uuids->valueAt(row);
auto charPtr = reinterpret_cast<const char*>(&uuid);

boost::uuids::uuid u;
memcpy(&u, &uuid, 16);
for (size_t i = 0; i < 16; ++i) {
u.data[15 - i] = charPtr[i];
}

return boost::lexical_cast<std::string>(u);
});
Expand Down
9 changes: 6 additions & 3 deletions velox/functions/prestosql/types/UuidType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ class UuidCastOperator : public exec::CastOperator {

size_t offset = 0;
for (auto i = 0; i < 16; ++i) {
result.data()[offset] = kHexTable[uuidBytes[i] * 2];
result.data()[offset + 1] = kHexTable[uuidBytes[i] * 2 + 1];
result.data()[offset] = kHexTable[uuidBytes[15 - i] * 2];
result.data()[offset + 1] = kHexTable[uuidBytes[15 - i] * 2 + 1];

offset += 2;
if (i == 3 || i == 5 || i == 7 || i == 9) {
Expand All @@ -125,7 +125,10 @@ class UuidCastOperator : public exec::CastOperator {
auto uuid = boost::lexical_cast<boost::uuids::uuid>(uuidString);

int128_t u;
memcpy(&u, &uuid, 16);
auto charPtr = reinterpret_cast<char*>(&u);
for (size_t i = 0; i < 16; ++i) {
charPtr[i] = uuid.data[15 - i];
}

flatResult->set(row, u);
});
Expand Down
92 changes: 87 additions & 5 deletions velox/serializers/PrestoSerializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
*/
#include "velox/serializers/PrestoSerializer.h"

#include <iostream>
#include <optional>

#include <folly/lang/Bits.h>

#include "velox/common/base/Crc.h"
#include "velox/common/base/RawVector.h"
#include "velox/common/memory/ByteStream.h"
#include "velox/functions/prestosql/types/UuidType.h"
#include "velox/vector/BiasVector.h"
#include "velox/vector/ComplexVector.h"
#include "velox/vector/DictionaryVector.h"
Expand Down Expand Up @@ -442,6 +444,42 @@ void readDecimalValues(
}
}

int128_t readUuidValue(ByteInputStream* source) {
// ByteInputStream does not support reading int128_t values.
// UUIDs are serialized as 2 int64 values with msb int64 value first.
auto high = source->read<uint64_t>();
auto low = source->read<uint64_t>();
return HugeInt::build(high, low);
}

void readUuidValues(
ByteInputStream* source,
vector_size_t size,
vector_size_t offset,
const BufferPtr& nulls,
vector_size_t nullCount,
const BufferPtr& values) {
auto rawValues = values->asMutable<int128_t>();
if (nullCount) {
checkValuesSize<int128_t>(values, nulls, size, offset);

int32_t toClear = offset;
bits::forEachSetBit(
nulls->as<uint64_t>(), offset, offset + size, [&](int32_t row) {
// Set the values between the last non-null and this to type default.
for (; toClear < row; ++toClear) {
rawValues[toClear] = 0;
}
rawValues[row] = readUuidValue(source);
toClear = row + 1;
});
} else {
for (int32_t row = 0; row < size; ++row) {
rawValues[offset + row] = readUuidValue(source);
}
}
}

/// When deserializing vectors under row vectors that introduce
/// nulls, the child vector must have a gap at the place where a
/// parent RowVector has a null. So, if there is a parent RowVector
Expand Down Expand Up @@ -565,6 +603,16 @@ void read(
values);
return;
}
if (isUuidType(type)) {
readUuidValues(
source,
numNewValues,
resultOffset,
flatResult->nulls(),
nullCount,
values);
return;
}
readValues<T>(
source,
numNewValues,
Expand Down Expand Up @@ -1364,6 +1412,7 @@ class VectorStream {
useLosslessTimestamp_(opts.useLosslessTimestamp),
nullsFirst_(opts.nullsFirst),
isLongDecimal_(type_->isLongDecimal()),
isUuid_(isUuidType(type_)),
opts_(opts),
encoding_(getEncoding(encoding, vector)),
nulls_(streamArena, true, true),
Expand Down Expand Up @@ -1709,6 +1758,10 @@ class VectorStream {
return isLongDecimal_;
}

bool isUuid() const {
return isUuid_;
}

void clear() {
encoding_ = std::nullopt;
initializeHeader(typeToEncodingName(type_), *streamArena_);
Expand Down Expand Up @@ -1784,6 +1837,7 @@ class VectorStream {
const bool useLosslessTimestamp_;
const bool nullsFirst_;
const bool isLongDecimal_;
const bool isUuid_;
const SerdeOpts opts_;
std::optional<VectorEncoding::Simple> encoding_;
int32_t nonNullCount_{0};
Expand Down Expand Up @@ -1841,13 +1895,24 @@ FOLLY_ALWAYS_INLINE int128_t toJavaDecimalValue(int128_t value) {
return value;
}

FOLLY_ALWAYS_INLINE int128_t toJavaUuidValue(int128_t value) {
// Presto Java UuidType uses java.util.UUID that expects 2 long values
// with most significant bits first, swap upper and lower to adjust.
auto low = HugeInt::upper(value);
auto high = HugeInt::lower(value);
return HugeInt::build(high, low);
}

template <>
void VectorStream::append(folly::Range<const int128_t*> values) {
for (auto& value : values) {
int128_t val = value;
if (isLongDecimal_) {
val = toJavaDecimalValue(value);
}
else if (isUuid_) {
val = toJavaUuidValue(value);
}
values_.append<int128_t>(folly::Range(&val, 1));
}
}
Expand Down Expand Up @@ -2392,14 +2457,22 @@ void copyWords(
const int32_t* indices,
int32_t numIndices,
const T* values,
bool isLongDecimal = false) {
bool isLongDecimal = false,
bool isUuid = false) {
if (std::is_same_v<T, int128_t> && isLongDecimal) {
for (auto i = 0; i < numIndices; ++i) {
reinterpret_cast<int128_t*>(destination)[i] = toJavaDecimalValue(
reinterpret_cast<const int128_t*>(values)[indices[i]]);
}
return;
}
if (std::is_same_v<T, int128_t> && isUuid) {
for (auto i = 0; i < numIndices; ++i) {
reinterpret_cast<int128_t*>(destination)[i] = toJavaUuidValue(
reinterpret_cast<const int128_t*>(values)[indices[i]]);
}
return;
}
for (auto i = 0; i < numIndices; ++i) {
destination[i] = values[indices[i]];
}
Expand All @@ -2412,9 +2485,10 @@ void copyWordsWithRows(
const int32_t* indices,
int32_t numIndices,
const T* values,
bool isLongDecimal = false) {
bool isLongDecimal = false,
bool isUuid = false) {
if (!indices) {
copyWords(destination, rows, numIndices, values, isLongDecimal);
copyWords(destination, rows, numIndices, values, isLongDecimal, isUuid);
return;
}
if (std::is_same_v<T, int128_t> && isLongDecimal) {
Expand All @@ -2424,6 +2498,13 @@ void copyWordsWithRows(
}
return;
}
else if (std::is_same_v<T, int128_t> && isUuid) {
for (auto i = 0; i < numIndices; ++i) {
reinterpret_cast<int128_t*>(destination)[i] = toJavaUuidValue(
reinterpret_cast<const int128_t*>(values)[rows[indices[i]]]);
}
return;
}
for (auto i = 0; i < numIndices; ++i) {
destination[i] = values[rows[indices[i]]];
}
Expand Down Expand Up @@ -2484,7 +2565,8 @@ void appendNonNull(
nonNullIndices,
numNonNull,
values,
stream->isLongDecimal());
stream->isLongDecimal(),
stream->isUuid());
}
}

Expand Down Expand Up @@ -2577,7 +2659,7 @@ void serializeFlatVector(
AppendWindow<T> window(stream->values(), scratch);
T* output = window.get(rows.size());
copyWords(
output, rows.data(), rows.size(), rawValues, stream->isLongDecimal());
output, rows.data(), rows.size(), rawValues, stream->isLongDecimal(), stream->isUuid());
return;
}

Expand Down
18 changes: 18 additions & 0 deletions velox/serializers/tests/PrestoSerializerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
#include "velox/serializers/PrestoSerializer.h"
#include <folly/Random.h>
#include <functions/prestosql/types/UuidType.h>
#include <gtest/gtest.h>
#include <vector>
#include "velox/common/base/tests/GTestUtils.h"
Expand Down Expand Up @@ -1054,6 +1055,23 @@ TEST_P(PrestoSerializerTest, longDecimal) {
testRoundTrip(vector);
}

TEST_P(PrestoSerializerTest, uuid) {
std::vector<int128_t> uuidValues(200);

for (int row = 0; row < uuidValues.size(); row++) {
uuidValues[row] = (int128_t) 0xD1 << row % 120;
}
auto vector = makeFlatVector<int128_t>(uuidValues, UUID());

testRoundTrip(vector);

// Add some nulls.
for (auto i = 0; i < uuidValues.size(); i += 7) {
vector->setNull(i, true);
}
testRoundTrip(vector);
}

// Test that hierarchically encoded columns (rows) have their encodings
// preserved by the PrestoBatchVectorSerializer.
TEST_P(PrestoSerializerTest, encodingsBatchVectorSerializer) {
Expand Down

0 comments on commit af1deb8

Please sign in to comment.