Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle ShortDecimal correctly inside importFromArrow #8957

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 66 additions & 14 deletions velox/vector/arrow/Bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1011,6 +1011,44 @@ void exportToArrowImpl(
out.release = releaseArrowArray;
}

// Parses the velox decimal format from the given arrow format.
// The input format string should be in the form "d:precision,scale<,bitWidth>".
// bitWidth is not required and must be 128 if provided.
TypePtr parseDecimalFormat(const char* format) {
std::string invalidFormatMsg =
"Unable to convert '{}' ArrowSchema decimal format to Velox decimal";
try {
std::string::size_type sz;
std::string formatStr(format);

auto firstCommaIdx = formatStr.find(',', 2);
auto secondCommaIdx = formatStr.find(',', firstCommaIdx + 1);

if (firstCommaIdx == std::string::npos ||
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add extra check here to avoid ASAN issue, cc @pedroerp

formatStr.size() == firstCommaIdx + 1 ||
(secondCommaIdx != std::string::npos &&
formatStr.size() == secondCommaIdx + 1)) {
VELOX_USER_FAIL(invalidFormatMsg, format);
}

// Parse "d:".
int precision = std::stoi(&format[2], &sz);
int scale = std::stoi(&format[firstCommaIdx + 1], &sz);
// If bitwidth is provided, check if it is equal to 128.
if (secondCommaIdx != std::string::npos) {
int bitWidth = std::stoi(&format[secondCommaIdx + 1], &sz);
VELOX_USER_CHECK_EQ(
bitWidth,
128,
"Conversion failed for '{}'. Velox decimal does not support custom bitwidth.",
format);
}
return DECIMAL(precision, scale);
} catch (std::invalid_argument&) {
VELOX_USER_FAIL(invalidFormatMsg, format);
}
}

TypePtr importFromArrowImpl(
const char* format,
const ArrowSchema& arrowSchema) {
Expand Down Expand Up @@ -1056,20 +1094,9 @@ TypePtr importFromArrowImpl(
}
break;

case 'd': { // decimal types.
try {
std::string::size_type sz;
// Parse "d:".
int precision = std::stoi(&format[2], &sz);
// Parse ",".
int scale = std::stoi(&format[2 + sz + 1], &sz);
return DECIMAL(precision, scale);
} catch (std::invalid_argument&) {
VELOX_USER_FAIL(
"Unable to convert '{}' ArrowSchema decimal format to Velox decimal",
format);
}
}
case 'd':
// decimal types.
return parseDecimalFormat(format);

// Complex types.
case '+': {
Expand Down Expand Up @@ -1612,6 +1639,23 @@ VectorPtr createTimestampVector(
optionalNullCount(nullCount));
}

VectorPtr createShortDecimalVector(
memory::MemoryPool* pool,
const TypePtr& type,
BufferPtr nulls,
const int128_t* input,
size_t length,
int64_t nullCount) {
auto values = AlignedBuffer::allocate<int64_t>(length, pool);
auto rawValues = values->asMutable<int64_t>();
for (size_t i = 0; i < length; ++i) {
memcpy(rawValues + i, input + i, sizeof(int64_t));
}

return createFlatVector<TypeKind::BIGINT>(
pool, type, nulls, length, values, nullCount);
}

bool isREE(const ArrowSchema& arrowSchema) {
return arrowSchema.format[0] == '+' && arrowSchema.format[1] == 'r';
}
Expand Down Expand Up @@ -1691,6 +1735,14 @@ VectorPtr importFromArrowImpl(
static_cast<const int64_t*>(arrowArray.buffers[1]),
arrowArray.length,
arrowArray.null_count);
} else if (type->isShortDecimal()) {
return createShortDecimalVector(
pool,
type,
nulls,
static_cast<const int128_t*>(arrowArray.buffers[1]),
arrowArray.length,
arrowArray.null_count);
} else if (type->isRow()) {
// Row/structs.
return createRowVector(
Expand Down
25 changes: 25 additions & 0 deletions velox/vector/arrow/tests/ArrowBridgeArrayTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1141,6 +1141,10 @@ class ArrowBridgeArrayImportTest : public ArrowBridgeArrayExportTest {
std::is_same_v<TInput, int64_t> && std::is_same_v<TOutput, Timestamp>) {
assertTimestampVectorContent(
inputValues, output, arrowArray.null_count, format);
} else if constexpr (
std::is_same_v<TInput, int128_t> && std::is_same_v<TOutput, int64_t>) {
assertShortDecimalVectorContent(
inputValues, output, arrowArray.null_count);
} else {
assertVectorContent(inputValues, output, arrowArray.null_count);
}
Expand Down Expand Up @@ -1224,6 +1228,9 @@ class ArrowBridgeArrayImportTest : public ArrowBridgeArrayExportTest {
testArrowImport<Timestamp, int64_t>(
tsString.data(), {0, std::nullopt, Timestamp::kMaxSeconds});
}

testArrowImport<int64_t, int128_t>(
"d:5,2", {1, -1, 0, 12345, -12345, std::nullopt});
}

template <typename TOutput, typename TInput>
Expand Down Expand Up @@ -1304,6 +1311,24 @@ class ArrowBridgeArrayImportTest : public ArrowBridgeArrayExportTest {
}

private:
// Creates short decimals from int128 and asserts the content of actual vector
// with the expected values.
void assertShortDecimalVectorContent(
boneanxs marked this conversation as resolved.
Show resolved Hide resolved
const std::vector<std::optional<int128_t>>& expectedValues,
const VectorPtr& actual,
size_t nullCount) {
std::vector<std::optional<int64_t>> decValues;
decValues.reserve(expectedValues.size());
for (const auto& value : expectedValues) {
if (value) {
decValues.emplace_back(static_cast<int64_t>(*value));
} else {
decValues.emplace_back(std::nullopt);
}
}
assertVectorContent(decValues, actual, nullCount);
}

// Creates timestamp from bigint and asserts the content of actual vector with
// the expected timestamp values.
void assertTimestampVectorContent(
Expand Down
16 changes: 16 additions & 0 deletions velox/vector/arrow/tests/ArrowBridgeSchemaTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,22 @@ TEST_F(ArrowBridgeSchemaImportTest, scalar) {
VELOX_ASSERT_THROW(
*testSchemaImport("d2,15"),
"Unable to convert 'd2,15' ArrowSchema decimal format to Velox decimal");
EXPECT_EQ(*DECIMAL(10, 4), *testSchemaImport("d:10,4,128"));
VELOX_ASSERT_THROW(
*testSchemaImport("d:10,4,256"),
"Conversion failed for 'd:10,4,256'. Velox decimal does not support custom bitwidth.");
VELOX_ASSERT_THROW(
*testSchemaImport("d:10,4,"),
"Unable to convert 'd:10,4,' ArrowSchema decimal format to Velox decimal");
boneanxs marked this conversation as resolved.
Show resolved Hide resolved
VELOX_ASSERT_THROW(
*testSchemaImport("d:10"),
"Unable to convert 'd:10' ArrowSchema decimal format to Velox decimal");
VELOX_ASSERT_THROW(
*testSchemaImport("d:"),
"Unable to convert 'd:' ArrowSchema decimal format to Velox decimal");
VELOX_ASSERT_THROW(
*testSchemaImport("d:10,"),
"Unable to convert 'd:10,' ArrowSchema decimal format to Velox decimal");
}

TEST_F(ArrowBridgeSchemaImportTest, complexTypes) {
Expand Down