[CH-313] support functions position/locate #314

Merged (2 commits) on Mar 21, 2023
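This change adds a Spark-compatible positionUTF8Spark function to local-engine and maps Spark's position and locate to it in SerializedPlanParser. As a rough, standalone sketch of the semantics the new function targets (1-based character index, 0 for a start position of 0 or a miss, 1 for an empty needle), the snippet below is illustrative only: it is byte-based and ASCII-only, while the real implementation counts UTF-8 characters and operates on columns; sparkLocate and all literals are made up for this example.

#include <cstdint>
#include <iostream>
#include <string>

// Illustrative-only sketch of Spark's locate(substr, str, pos) behavior:
// 1-based result, 0 when pos is 0 or the substring is not found,
// 1 when the substring is empty. Byte-based; no UTF-8 handling here.
static uint64_t sparkLocate(const std::string & needle, const std::string & haystack, uint64_t start = 1)
{
    if (start == 0 || start > haystack.size() + 1)
        return 0;
    if (needle.empty())
        return 1;
    size_t found = haystack.find(needle, start - 1);
    return found == std::string::npos ? 0 : static_cast<uint64_t>(found) + 1;
}

int main()
{
    std::cout << sparkLocate("bar", "foobarbar") << '\n';     // 4
    std::cout << sparkLocate("bar", "foobarbar", 5) << '\n';  // 7
    std::cout << sparkLocate("bar", "foobarbar", 0) << '\n';  // 0
    std::cout << sparkLocate("", "foobarbar") << '\n';        // 1
}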
296 changes: 296 additions & 0 deletions utils/local-engine/Functions/positionUTF8Spark.cpp
@@ -0,0 +1,296 @@
#include <string>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionsStringSearch.h>
#include <Functions/PositionImpl.h>

namespace DB
{

namespace ErrorCodes
{
extern const int ILLEGAL_COLUMN;
}

}

namespace local_engine
{

using namespace DB;

// Spark-specific version of PositionImpl
template <typename Name, typename Impl>
struct PositionSparkImpl
{
static constexpr bool use_default_implementation_for_constants = false;
static constexpr bool supports_start_pos = true;
static constexpr auto name = Name::name;

static ColumnNumbers getArgumentsThatAreAlwaysConstant() { return {};}

using ResultType = UInt64;

/// Find one substring in many strings.
static void vectorConstant(
const ColumnString::Chars & data,
const ColumnString::Offsets & offsets,
const std::string & needle,
const ColumnPtr & start_pos,
PaddedPODArray<UInt64> & res,
[[maybe_unused]] ColumnUInt8 * res_null)
{

/// `res_null` serves as an output parameter for implementing an XYZOrNull variant.
assert(!res_null);

const UInt8 * begin = data.data();
const UInt8 * pos = begin;
const UInt8 * end = pos + data.size();

/// Current index in the array of strings.
size_t i = 0;

typename Impl::SearcherInBigHaystack searcher = Impl::createSearcherInBigHaystack(needle.data(), needle.size(), end - pos);

/// We will search for the next occurrence in all strings at once.
while (pos < end && end != (pos = searcher.search(pos, end - pos)))
{
/// Determine which index it refers to.
while (begin + offsets[i] <= pos)
{
res[i] = 0;
++i;
}
auto start = start_pos != nullptr ? start_pos->getUInt(i) : 0;

/// We check that the entry does not pass through the boundaries of strings.
// The result is 0 if start_pos is 0, in compliance with Spark semantics
if (start != 0 && pos + needle.size() < begin + offsets[i])
{
auto res_pos = 1 + Impl::countChars(reinterpret_cast<const char *>(begin + offsets[i - 1]), reinterpret_cast<const char *>(pos));
if (res_pos < start)
{
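/// The match lies before the requested start position: advance `pos` to character `start` of this row and search again.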
pos = reinterpret_cast<const UInt8 *>(Impl::advancePos(
reinterpret_cast<const char *>(pos),
reinterpret_cast<const char *>(begin + offsets[i]),
start - res_pos));
continue;
}
// The result is 1 if needle is empty, in compliance with Spark semantics
res[i] = needle.empty() ? 1 : res_pos;
}
else
{
res[i] = 0;
}
pos = begin + offsets[i];
++i;
}

if (i < res.size())
memset(&res[i], 0, (res.size() - i) * sizeof(res[0]));
}

/// Search for substring in string.
static void constantConstantScalar(
std::string data,
std::string needle,
UInt64 start_pos,
UInt64 & res)
{
size_t start_byte = Impl::advancePos(data.data(), data.data() + data.size(), start_pos - 1) - data.data();
res = data.find(needle, start_byte);
if (res == std::string::npos)
res = 0;
else
res = 1 + Impl::countChars(data.data(), data.data() + res);
}

/// Search for substring in string starting from different positions.
static void constantConstant(
std::string data,
std::string needle,
const ColumnPtr & start_pos,
PaddedPODArray<UInt64> & res,
[[maybe_unused]] ColumnUInt8 * res_null)
{
/// `res_null` serves as an output parameter for implementing an XYZOrNull variant.
assert(!res_null);

Impl::toLowerIfNeed(data);
Impl::toLowerIfNeed(needle);

if (start_pos == nullptr)
{
res[0] = 0;
return;
}

size_t haystack_size = Impl::countChars(data.data(), data.data() + data.size());

size_t size = start_pos != nullptr ? start_pos->size() : 0;
for (size_t i = 0; i < size; ++i)
{
auto start = start_pos->getUInt(i);

if (start == 0 || start > haystack_size + 1)
{
res[i] = 0;
continue;
}
if (needle.empty())
{
res[i] = 1;
continue;
}
constantConstantScalar(data, needle, start, res[i]);
}
}

/// Search each time for a different single substring inside each time different string.
static void vectorVector(
const ColumnString::Chars & haystack_data,
const ColumnString::Offsets & haystack_offsets,
const ColumnString::Chars & needle_data,
const ColumnString::Offsets & needle_offsets,
const ColumnPtr & start_pos,
PaddedPODArray<UInt64> & res,
[[maybe_unused]] ColumnUInt8 * res_null)
{
/// `res_null` serves as an output parameter for implementing an XYZOrNull variant.
assert(!res_null);

ColumnString::Offset prev_haystack_offset = 0;
ColumnString::Offset prev_needle_offset = 0;

size_t size = haystack_offsets.size();

for (size_t i = 0; i < size; ++i)
{
size_t needle_size = needle_offsets[i] - prev_needle_offset - 1;
size_t haystack_size = haystack_offsets[i] - prev_haystack_offset - 1;

auto start = start_pos != nullptr ? start_pos->getUInt(i) : UInt64(0);

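/// Per Spark semantics, a start position of 0, or one greater than the haystack length + 1, yields 0.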
if (start == 0 || start > haystack_size + 1)
{
res[i] = 0;
}
else if (0 == needle_size)
{
/// An empty string is always 1 in compliance with Spark semantics.
res[i] = 1;
}
else
{
/// It is assumed that the StringSearcher is not very difficult to initialize.
typename Impl::SearcherInSmallHaystack searcher = Impl::createSearcherInSmallHaystack(
reinterpret_cast<const char *>(&needle_data[prev_needle_offset]),
needle_offsets[i] - prev_needle_offset - 1); /// zero byte at the end

const char * beg = Impl::advancePos(
reinterpret_cast<const char *>(&haystack_data[prev_haystack_offset]),
reinterpret_cast<const char *>(&haystack_data[haystack_offsets[i] - 1]),
start - 1);
/// searcher returns a pointer to the found substring or to the end of `haystack`.
size_t pos = searcher.search(reinterpret_cast<const UInt8 *>(beg), &haystack_data[haystack_offsets[i] - 1])
- &haystack_data[prev_haystack_offset];

if (pos != haystack_size)
{
res[i] = 1
+ Impl::countChars(
reinterpret_cast<const char *>(&haystack_data[prev_haystack_offset]),
reinterpret_cast<const char *>(&haystack_data[prev_haystack_offset + pos]));
}
else
res[i] = 0;
}

prev_haystack_offset = haystack_offsets[i];
prev_needle_offset = needle_offsets[i];
}
}

/// Find many substrings in single string.
static void constantVector(
const String & haystack,
const ColumnString::Chars & needle_data,
const ColumnString::Offsets & needle_offsets,
const ColumnPtr & start_pos,
PaddedPODArray<UInt64> & res,
[[maybe_unused]] ColumnUInt8 * res_null)
{
/// `res_null` serves as an output parameter for implementing an XYZOrNull variant.
assert(!res_null);

/// NOTE You could use haystack indexing. But this is a rare case.
ColumnString::Offset prev_needle_offset = 0;

size_t size = needle_offsets.size();

for (size_t i = 0; i < size; ++i)
{
size_t needle_size = needle_offsets[i] - prev_needle_offset - 1;

auto start = start_pos != nullptr ? start_pos->getUInt(i) : UInt64(0);

if (start == 0 || start > haystack.size() + 1)
{
res[i] = 0;
}
else if (0 == needle_size)
{
res[i] = 1;
}
else
{
typename Impl::SearcherInSmallHaystack searcher = Impl::createSearcherInSmallHaystack(
reinterpret_cast<const char *>(&needle_data[prev_needle_offset]), needle_offsets[i] - prev_needle_offset - 1);

const char * beg = Impl::advancePos(haystack.data(), haystack.data() + haystack.size(), start - 1);
size_t pos = searcher.search(
reinterpret_cast<const UInt8 *>(beg),
reinterpret_cast<const UInt8 *>(haystack.data()) + haystack.size())
- reinterpret_cast<const UInt8 *>(haystack.data());

if (pos != haystack.size())
{
res[i] = 1 + Impl::countChars(haystack.data(), haystack.data() + pos);
}
else
res[i] = 0;
}

prev_needle_offset = needle_offsets[i];
}
}

template <typename... Args>
static void vectorFixedConstant(Args &&...)
{
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Function '{}' doesn't support FixedString haystack argument", name);
}

template <typename... Args>
static void vectorFixedVector(Args &&...)
{
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Function '{}' doesn't support FixedString haystack argument", name);
}
};

struct NamePositionUTF8Spark
{
static constexpr auto name = "positionUTF8Spark";
};


using FunctionPositionUTF8Spark = FunctionsStringSearch<PositionSparkImpl<NamePositionUTF8Spark, PositionCaseSensitiveUTF8>>;


void registerFunctionPositionUTF8Spark(FunctionFactory & factory)
{
factory.registerFunction<FunctionPositionUTF8Spark>();
}

}
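For a quick sanity check (assuming a ClickHouse build that links local-engine and exposes SQL; the queries are illustrative, not part of the diff), SELECT positionUTF8Spark('foobarbar', 'bar', 5) should return 7 and SELECT positionUTF8Spark('foobarbar', 'bar', 0) should return 0. Note that the ClickHouse-side argument order is (haystack, needle, start_pos), as with the built-in position.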
3 changes: 3 additions & 0 deletions utils/local-engine/Functions/registerFunctions.cpp
@@ -6,10 +6,13 @@ namespace local_engine
using namespace DB;
void registerFunctionSparkTrim(FunctionFactory &);
void registerFunctionsHashingExtended(FunctionFactory & factory);
void registerFunctionPositionUTF8Spark(FunctionFactory &);

void registerFunctions(FunctionFactory & factory)
{
registerFunctionSparkTrim(factory);
registerFunctionsHashingExtended(factory);
registerFunctionPositionUTF8Spark(factory);
}

}
18 changes: 18 additions & 0 deletions utils/local-engine/Parser/SerializedPlanParser.cpp
@@ -1604,6 +1604,24 @@ void SerializedPlanParser::parseFunctionArguments(
DB::ActionsDAG::NodeRawConstPtrs ifnull_func_args = {arg_node, add_column(std::make_shared<DataTypeInt32>(), 0)};
parsed_args.emplace_back(toFunctionNode(actions_dag, "IfNull", ifnull_func_args));
}
else if (function_name == "positionUTF8Spark")
{
if (args.size() >= 2)
{
// In Spark: position(substr, str, Int32)
// In CH: position(str, substr, UInt32)
parseFunctionArgument(actions_dag, parsed_args, required_columns, function_name, args[1]);
parseFunctionArgument(actions_dag, parsed_args, required_columns, function_name, args[0]);
}
if (args.size() >= 3)
{
// add cast: cast(start_pos as UInt32)
const auto * start_pos_node = parseFunctionArgument(actions_dag, required_columns, function_name, args[2]);
DB::DataTypeNullable target_type(std::make_shared<DB::DataTypeUInt32>());
start_pos_node = ActionsDAGUtil::convertNodeType(actions_dag, start_pos_node, target_type.getName());
parsed_args.emplace_back(start_pos_node);
}
}
else
{
// Default handle
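For reference, a hedged illustration of the rewrite this branch performs (the column name s and the literal 2 are made up): a Spark call locate('ar', s, 2) arrives with arguments (substr, str, pos) and is turned into, effectively, positionUTF8Spark(s, 'ar', CAST(2 AS Nullable(UInt32))): haystack and needle are swapped, and the start position is cast to a nullable UInt32.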
2 changes: 2 additions & 0 deletions utils/local-engine/Parser/SerializedPlanParser.h
@@ -131,6 +131,8 @@ static const std::map<std::string, std::string> SCALAR_FUNCTIONS = {
{"md5","MD5"},
{"translate", "translateUTF8"},
{"repeat","repeat"},
{"position", "positionUTF8Spark"},
{"locate", "positionUTF8Spark"},

/// hash functions
{"hash", "murmurHashSpark3_32"},