Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support other integer types for SubstringUTF8 & RightUTF8 functions (#9507) #9512

Open
wants to merge 1 commit into
base: release-6.5
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
295 changes: 265 additions & 30 deletions dbms/src/Functions/FunctionsString.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1673,6 +1673,7 @@ class FunctionSubstringUTF8 : public IFunction

bool implicit_length = (arguments.size() == 2);

<<<<<<< HEAD
bool is_start_type_valid = getNumberType(block.getByPosition(arguments[1]).type, [&](const auto & start_type, bool) {
using StartType = std::decay_t<decltype(start_type)>;
// Int64 / UInt64
Expand All @@ -1692,6 +1693,48 @@ class FunctionSubstringUTF8 : public IFunction
length = getValueFromLengthField<LengthFieldType>((*block.getByPosition(arguments[2]).column)[0]);
return true;
});
=======
bool is_start_type_valid
= getNumberType(block.getByPosition(arguments[1]).type, [&](const auto & start_type, bool) {
using StartType = std::decay_t<decltype(start_type)>;
using StartFieldType = typename StartType::FieldType;
const ColumnVector<StartFieldType> * column_vector_start
= getInnerColumnVector<StartFieldType>(column_start);
if unlikely (!column_vector_start)
throw Exception(
fmt::format(
"Illegal type {} of argument 2 of function {}",
block.getByPosition(arguments[1]).type->getName(),
getName()),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

// vector const const
if (!column_string->isColumnConst() && column_start->isColumnConst()
&& (implicit_length || block.getByPosition(arguments[2]).column->isColumnConst()))
{
auto [is_positive, start_abs] = getValueFromStartColumn<StartFieldType>(*column_vector_start, 0);
UInt64 length = 0;
if (!implicit_length)
{
bool is_length_type_valid = getNumberType(
block.getByPosition(arguments[2]).type,
[&](const auto & length_type, bool) {
using LengthType = std::decay_t<decltype(length_type)>;
using LengthFieldType = typename LengthType::FieldType;
const ColumnVector<LengthFieldType> * column_vector_length
= getInnerColumnVector<LengthFieldType>(block.getByPosition(arguments[2]).column);
if unlikely (!column_vector_length)
throw Exception(
fmt::format(
"Illegal type {} of argument 3 of function {}",
block.getByPosition(arguments[2]).type->getName(),
getName()),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

length = getValueFromLengthColumn<LengthFieldType>(*column_vector_length, 0);
return true;
});
>>>>>>> 9f85bb2f16 (Support other integer types for SubstringUTF8 & RightUTF8 functions (#9507))

if (!is_length_type_valid)
throw Exception(fmt::format("3nd argument of function {} must have UInt/Int type.", getName()));
Expand All @@ -1704,6 +1747,7 @@ class FunctionSubstringUTF8 : public IFunction
return true;
}

<<<<<<< HEAD
const auto * col = checkAndGetColumn<ColumnString>(column_string.get());
assert(col);
auto col_res = ColumnString::create();
Expand Down Expand Up @@ -1757,6 +1801,80 @@ class FunctionSubstringUTF8 : public IFunction
if (!is_length_type_valid)
throw Exception(fmt::format("3nd argument of function {} must have UInt/Int type.", getName()));
}
=======
const auto * col = checkAndGetColumn<ColumnString>(column_string.get());
assert(col);
auto col_res = ColumnString::create();
getVectorConstConstFunc(implicit_length, is_positive)(
col->getChars(),
col->getOffsets(),
start_abs,
length,
col_res->getChars(),
col_res->getOffsets());
block.getByPosition(result).column = std::move(col_res);
}
else // all other cases are converted to vector vector vector
{
std::function<std::pair<bool, size_t>(size_t)> get_start_func;
if (column_start->isColumnConst())
{
// func always return const value
auto start_const = getValueFromStartColumn<StartFieldType>(*column_vector_start, 0);
get_start_func = [start_const](size_t) {
return start_const;
};
}
else
{
get_start_func = [column_vector_start](size_t i) {
return getValueFromStartColumn<StartFieldType>(*column_vector_start, i);
};
}

// if implicit_length, get_length_func be nil is ok.
std::function<size_t(size_t)> get_length_func;
if (!implicit_length)
{
const ColumnPtr & column_length = block.getByPosition(arguments[2]).column;
bool is_length_type_valid = getNumberType(
block.getByPosition(arguments[2]).type,
[&](const auto & length_type, bool) {
using LengthType = std::decay_t<decltype(length_type)>;
using LengthFieldType = typename LengthType::FieldType;
const ColumnVector<LengthFieldType> * column_vector_length
= getInnerColumnVector<LengthFieldType>(column_length);
if unlikely (!column_vector_length)
throw Exception(
fmt::format(
"Illegal type {} of argument 3 of function {}",
block.getByPosition(arguments[2]).type->getName(),
getName()),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

if (column_length->isColumnConst())
{
// func always return const value
auto length_const
= getValueFromLengthColumn<LengthFieldType>(*column_vector_length, 0);
get_length_func = [length_const](size_t) {
return length_const;
};
}
else
{
get_length_func = [column_vector_length](size_t i) {
return getValueFromLengthColumn<LengthFieldType>(*column_vector_length, i);
};
}
return true;
});

if unlikely (!is_length_type_valid)
throw Exception(
fmt::format("3nd argument of function {} must have UInt/Int type.", getName()));
}
>>>>>>> 9f85bb2f16 (Support other integer types for SubstringUTF8 & RightUTF8 functions (#9507))

// convert to vector if string is const.
ColumnPtr full_column_string = column_string->isColumnConst() ? column_string->convertToFullColumnIfConst() : column_string;
Expand All @@ -1777,10 +1895,38 @@ class FunctionSubstringUTF8 : public IFunction
return true;
});

if (!is_start_type_valid)
if unlikely (!is_start_type_valid)
throw Exception(fmt::format("2nd argument of function {} must have UInt/Int type.", getName()));
}

template <typename Integer>
static const ColumnVector<Integer> * getInnerColumnVector(const ColumnPtr & column)
{
if (column->isColumnConst())
return checkAndGetColumn<ColumnVector<Integer>>(
checkAndGetColumn<ColumnConst>(column.get())->getDataColumnPtr().get());
return checkAndGetColumn<ColumnVector<Integer>>(column.get());
}

template <typename Integer>
static size_t getValueFromLengthColumn(const ColumnVector<Integer> & column, size_t index)
{
Integer val = column.getElement(index);
if constexpr (
std::is_same_v<Integer, Int8> || std::is_same_v<Integer, Int16> || std::is_same_v<Integer, Int32>
|| std::is_same_v<Integer, Int64>)
{
return val < 0 ? 0 : val;
}
else
{
static_assert(
std::is_same_v<Integer, UInt8> || std::is_same_v<Integer, UInt16> || std::is_same_v<Integer, UInt32>
|| std::is_same_v<Integer, UInt64>);
return val;
}
}

private:
using VectorConstConstFunc = std::function<void(
const ColumnString::Chars_t &,
Expand All @@ -1802,51 +1948,45 @@ class FunctionSubstringUTF8 : public IFunction
}
}

template <typename Integer>
static size_t getValueFromLengthField(const Field & length_field)
{
if constexpr (std::is_same_v<Integer, Int64>)
{
Int64 signed_length = length_field.get<Int64>();
return signed_length < 0 ? 0 : signed_length;
}
else
{
static_assert(std::is_same_v<Integer, UInt64>);
return length_field.get<UInt64>();
}
}

// return {is_positive, abs}
template <typename Integer>
static std::pair<bool, size_t> getValueFromStartField(const Field & start_field)
static std::pair<bool, size_t> getValueFromStartColumn(const ColumnVector<Integer> & column, size_t index)
{
if constexpr (std::is_same_v<Integer, Int64>)
Integer val = column.getElement(index);
if constexpr (
std::is_same_v<Integer, Int8> || std::is_same_v<Integer, Int16> || std::is_same_v<Integer, Int32>
|| std::is_same_v<Integer, Int64>)
{
Int64 signed_length = start_field.get<Int64>();

if (signed_length < 0)
{
return {false, static_cast<size_t>(-signed_length)};
}
else
{
return {true, static_cast<size_t>(signed_length)};
}
if (val < 0)
return {false, static_cast<size_t>(-val)};
return {true, static_cast<size_t>(val)};
}
else
{
static_assert(std::is_same_v<Integer, UInt64>);
return {true, start_field.get<UInt64>()};
static_assert(
std::is_same_v<Integer, UInt8> || std::is_same_v<Integer, UInt16> || std::is_same_v<Integer, UInt32>
|| std::is_same_v<Integer, UInt64>);
return {true, val};
}
}

template <typename F>
static bool getNumberType(DataTypePtr type, F && f)
{
return castTypeToEither<
<<<<<<< HEAD
DataTypeInt64,
DataTypeUInt64>(type.get(), std::forward<F>(f));
=======
DataTypeUInt8,
DataTypeUInt16,
DataTypeUInt32,
DataTypeUInt64,
DataTypeInt8,
DataTypeInt16,
DataTypeInt32,
DataTypeInt64>(type.get(), std::forward<F>(f));
>>>>>>> 9f85bb2f16 (Support other integer types for SubstringUTF8 & RightUTF8 functions (#9507))
}
};

Expand Down Expand Up @@ -1891,6 +2031,7 @@ class FunctionRightUTF8 : public IFunction
const ColumnPtr column_string = block.getByPosition(arguments[0]).column;
const ColumnPtr column_length = block.getByPosition(arguments[1]).column;

<<<<<<< HEAD
bool is_length_type_valid = getLengthType(block.getByPosition(arguments[1]).type, [&](const auto & length_type, bool) {
using LengthType = std::decay_t<decltype(length_type)>;
// Int64 / UInt64
Expand All @@ -1903,6 +2044,33 @@ class FunctionRightUTF8 : public IFunction
{
// vector const
size_t length = getValueFromLengthField<LengthFieldType>((*column_length)[0]);
=======
bool is_length_type_valid
= getLengthType(block.getByPosition(arguments[1]).type, [&](const auto & length_type, bool) {
using LengthType = std::decay_t<decltype(length_type)>;
using LengthFieldType = typename LengthType::FieldType;

const ColumnVector<LengthFieldType> * column_vector_length
= FunctionSubstringUTF8::getInnerColumnVector<LengthFieldType>(column_length);
if unlikely (!column_vector_length)
throw Exception(
fmt::format(
"Illegal type {} of argument 2 of function {}",
block.getByPosition(arguments[1]).type->getName(),
getName()),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);


auto col_res = ColumnString::create();
if (const auto * col_string = checkAndGetColumn<ColumnString>(column_string.get()))
{
if (column_length->isColumnConst())
{
// vector const
size_t length = FunctionSubstringUTF8::getValueFromLengthColumn<LengthFieldType>(
*column_vector_length,
0);
>>>>>>> 9f85bb2f16 (Support other integer types for SubstringUTF8 & RightUTF8 functions (#9507))

// for const 0, return const blank string.
if (0 == length)
Expand All @@ -1911,6 +2079,7 @@ class FunctionRightUTF8 : public IFunction
return true;
}

<<<<<<< HEAD
RightUTF8Impl::vectorConst(col_string->getChars(), col_string->getOffsets(), length, col_res->getChars(), col_res->getOffsets());
}
else
Expand Down Expand Up @@ -1942,6 +2111,61 @@ class FunctionRightUTF8 : public IFunction
block.getByPosition(result).column = std::move(col_res);
return true;
});
=======
RightUTF8Impl::vectorConst(
col_string->getChars(),
col_string->getOffsets(),
length,
col_res->getChars(),
col_res->getOffsets());
}
else
{
// vector vector
auto get_length_func = [column_vector_length](size_t i) {
return FunctionSubstringUTF8::getValueFromLengthColumn<LengthFieldType>(
*column_vector_length,
i);
};
RightUTF8Impl::vectorVector(
col_string->getChars(),
col_string->getOffsets(),
get_length_func,
col_res->getChars(),
col_res->getOffsets());
}
}
else if (
const ColumnConst * col_const_string = checkAndGetColumnConst<ColumnString>(column_string.get()))
{
// const vector
const auto * col_string_from_const
= checkAndGetColumn<ColumnString>(col_const_string->getDataColumnPtr().get());
assert(col_string_from_const);
// When useDefaultImplementationForConstants is true, string and length are not both constants
assert(!column_length->isColumnConst());
auto get_length_func = [column_vector_length](size_t i) {
return FunctionSubstringUTF8::getValueFromLengthColumn<LengthFieldType>(
*column_vector_length,
i);
};
RightUTF8Impl::constVector(
column_length->size(),
col_string_from_const->getChars(),
col_string_from_const->getOffsets(),
get_length_func,
col_res->getChars(),
col_res->getOffsets());
}
else
{
// Impossible to reach here
return false;
}
block.getByPosition(result).column = std::move(col_res);
return true;
});
>>>>>>> 9f85bb2f16 (Support other integer types for SubstringUTF8 & RightUTF8 functions (#9507))

if (!is_length_type_valid)
throw Exception(fmt::format("2nd argument of function {} must have UInt/Int type.", getName()));
Expand All @@ -1953,6 +2177,7 @@ class FunctionRightUTF8 : public IFunction
getLengthType(DataTypePtr type, F && f)
{
return castTypeToEither<
<<<<<<< HEAD
DataTypeInt64,
DataTypeUInt64>(type.get(), std::forward<F>(f));
}
Expand All @@ -1970,6 +2195,16 @@ class FunctionRightUTF8 : public IFunction
static_assert(std::is_same_v<Integer, UInt64>);
return length_field.get<UInt64>();
}
=======
DataTypeUInt8,
DataTypeUInt16,
DataTypeUInt32,
DataTypeUInt64,
DataTypeInt8,
DataTypeInt16,
DataTypeInt32,
DataTypeInt64>(type.get(), std::forward<F>(f));
>>>>>>> 9f85bb2f16 (Support other integer types for SubstringUTF8 & RightUTF8 functions (#9507))
}
};

Expand Down
Loading