From 775647dbaea3aae436e0f1c9471bbd7aea9b2ae7 Mon Sep 17 00:00:00 2001 From: kevinyhzou Date: Mon, 18 Mar 2024 17:10:01 +0800 Subject: [PATCH] Fix function conversions --- ...p => SparkFunctionDateToUnixTimestamp.cpp} | 6 ++-- ...p.h => SparkFunctionDateToUnixTimestamp.h} | 36 +++++++++++-------- .../Functions/SparkFunctionToDate.cpp | 19 ++++++++-- .../Functions/SparkFunctionToDateTime.h | 28 +++++++++++++-- .../scalar_function_parser/unixTimestamp.cpp | 2 +- .../benchmark_unix_timestamp_function.cpp | 4 +-- 6 files changed, 70 insertions(+), 25 deletions(-) rename cpp-ch/local-engine/Functions/{SparkFunctionUnixTimestamp.cpp => SparkFunctionDateToUnixTimestamp.cpp} (81%) rename cpp-ch/local-engine/Functions/{SparkFunctionUnixTimestamp.h => SparkFunctionDateToUnixTimestamp.h} (71%) diff --git a/cpp-ch/local-engine/Functions/SparkFunctionUnixTimestamp.cpp b/cpp-ch/local-engine/Functions/SparkFunctionDateToUnixTimestamp.cpp similarity index 81% rename from cpp-ch/local-engine/Functions/SparkFunctionUnixTimestamp.cpp rename to cpp-ch/local-engine/Functions/SparkFunctionDateToUnixTimestamp.cpp index c3040c60d40c..e7ae17feb3cf 100644 --- a/cpp-ch/local-engine/Functions/SparkFunctionUnixTimestamp.cpp +++ b/cpp-ch/local-engine/Functions/SparkFunctionDateToUnixTimestamp.cpp @@ -15,14 +15,14 @@ * limitations under the License. 
*/ -#include +#include namespace local_engine { -REGISTER_FUNCTION(SparkFunctionUnixTimestamp) +REGISTER_FUNCTION(SparkFunctionDateToUnixTimestamp) { - factory.registerFunction(); + factory.registerFunction(); } } \ No newline at end of file diff --git a/cpp-ch/local-engine/Functions/SparkFunctionUnixTimestamp.h b/cpp-ch/local-engine/Functions/SparkFunctionDateToUnixTimestamp.h similarity index 71% rename from cpp-ch/local-engine/Functions/SparkFunctionUnixTimestamp.h rename to cpp-ch/local-engine/Functions/SparkFunctionDateToUnixTimestamp.h index 356519afa531..cdf0460e0e9d 100644 --- a/cpp-ch/local-engine/Functions/SparkFunctionUnixTimestamp.h +++ b/cpp-ch/local-engine/Functions/SparkFunctionDateToUnixTimestamp.h @@ -15,8 +15,13 @@ * limitations under the License. */ +#include #include -#include +#include +#include +#include +#include +#include #include namespace DB @@ -32,12 +37,12 @@ using namespace DB; namespace local_eingine { -class SparkFunctionUnixTimestamp : public FunctionToUnixTimestamp +class SparkFunctionDateToUnixTimestamp : public IFunction { public: - static constexpr auto name = "sparkToUnixTimestamp"; - static FunctionPtr create(ContextPtr) { return std::make_shared(); } - SparkFunctionUnixTimestamp() + static constexpr auto name = "sparkDateToUnixTimestamp"; + static FunctionPtr create(ContextPtr) { return std::make_shared(); } + SparkFunctionDateToUnixTimestamp() { const DateLUTImpl * date_lut = &DateLUT::instance("UTC"); UInt32 utc_timestamp = static_cast(0); @@ -45,21 +50,24 @@ class SparkFunctionUnixTimestamp : public FunctionToUnixTimestamp UInt32 unix_timestamp = date_time.to_time_t(); delta_timestamp_from_utc = unix_timestamp - utc_timestamp; } - ~SparkFunctionUnixTimestamp() override = default; + ~SparkFunctionDateToUnixTimestamp() override = default; String getName() const override { return name; } + bool isSuitableForShortCircuitArgumentsExecution(const DB::DataTypesWithConstInfo &) const override { return true; } + size_t 
getNumberOfArguments() const override { return 0; } + bool isVariadic() const override { return true; } + bool useDefaultImplementationForConstants() const override { return true; } + DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName &) const override + { + return std::make_shared(); + } ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows) const override { - if (arguments.size() != 1 && arguments.size() != 2) + if (arguments.size() != 1 && arguments.size() != 2) throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} argument size must be 1 or 2", name); - ColumnWithTypeAndName first_arg = arguments[0]; - - if (!isDateOrDate32(first_arg.type)) - { - return FunctionToUnixTimestamp::executeImpl(arguments, result_type, input_rows); - } - else if (isDate(first_arg.type)) + ColumnWithTypeAndName first_arg = arguments[0]; + if (isDate(first_arg.type)) return executeInternal(first_arg.column, input_rows); else return executeInternal(first_arg.column, input_rows); diff --git a/cpp-ch/local-engine/Functions/SparkFunctionToDate.cpp b/cpp-ch/local-engine/Functions/SparkFunctionToDate.cpp index 0b963e769efd..3a25e383d7a0 100644 --- a/cpp-ch/local-engine/Functions/SparkFunctionToDate.cpp +++ b/cpp-ch/local-engine/Functions/SparkFunctionToDate.cpp @@ -17,8 +17,13 @@ #include #include #include +#include +#include +#include +#include +#include #include -#include +#include #include #include #include @@ -35,14 +40,18 @@ namespace ErrorCodes namespace local_engine { -class SparkFunctionConvertToDate : public DB::FunctionToDate32OrNull +class SparkFunctionConvertToDate : public DB::IFunction { public: static constexpr auto name = "sparkToDate"; static DB::FunctionPtr create(DB::ContextPtr) { return std::make_shared(); } SparkFunctionConvertToDate() = default; ~SparkFunctionConvertToDate() override = default; + bool isSuitableForShortCircuitArgumentsExecution(const DB::DataTypesWithConstInfo 
&) const override { return true; } + size_t getNumberOfArguments() const override { return 0; } String getName() const override { return name; } + bool isVariadic() const override { return true; } + bool useDefaultImplementationForConstants() const override { return true; } bool checkAndGetDate32(DB::ReadBuffer & buf, DB::DataTypeDate32::FieldType &x, const DateLUTImpl & date_lut) const { @@ -99,6 +108,12 @@ class SparkFunctionConvertToDate : public DB::FunctionToDate32OrNull } } + DB::DataTypePtr getReturnTypeImpl(const DB::ColumnsWithTypeAndName &) const override + { + DB::DataTypePtr date32_type = std::make_shared(); + return makeNullable(date32_type); + } + DB::ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr & result_type, size_t) const override { if (arguments.size() != 1) diff --git a/cpp-ch/local-engine/Functions/SparkFunctionToDateTime.h b/cpp-ch/local-engine/Functions/SparkFunctionToDateTime.h index d185b850fa1f..ae9ebc6e72f2 100644 --- a/cpp-ch/local-engine/Functions/SparkFunctionToDateTime.h +++ b/cpp-ch/local-engine/Functions/SparkFunctionToDateTime.h @@ -18,9 +18,12 @@ #include #include #include +#include +#include #include -#include #include +#include +#include #include #include #include @@ -40,13 +43,17 @@ namespace ErrorCodes namespace local_engine { -class SparkFunctionConvertToDateTime : public DB::FunctionToDateTime64OrNull +class SparkFunctionConvertToDateTime : public IFunction { public: static constexpr auto name = "sparkToDateTime"; static DB::FunctionPtr create(DB::ContextPtr) { return std::make_shared(); } SparkFunctionConvertToDateTime() = default; ~SparkFunctionConvertToDateTime() override = default; + bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo &) const override { return true; } + size_t getNumberOfArguments() const override { return 0; } + bool isVariadic() const override { return true; } + bool useDefaultImplementationForConstants() const override { return 
true; } String getName() const override { return name; } bool checkDateTimeFormat(DB::ReadBuffer & buf, size_t buf_size, UInt8 & can_be_parsed) const @@ -109,11 +116,26 @@ class SparkFunctionConvertToDateTime : public DB::FunctionToDateTime64OrNull return true; } + inline UInt32 extractDecimalScale(const ColumnWithTypeAndName & named_column) const + { + const auto * arg_type = named_column.type.get(); + bool ok = checkAndGetDataType(arg_type) + || checkAndGetDataType(arg_type) + || checkAndGetDataType(arg_type) + || checkAndGetDataType(arg_type); + if (!ok) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type of toDecimal() scale {}", named_column.type->getName()); + + Field field; + named_column.column->get(0, field); + return static_cast(field.get()); + } + DB::DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override { UInt32 scale = 6; if (arguments.size() > 1) - scale = extractToDecimalScale(arguments[1]); + scale = extractDecimalScale(arguments[1]); const auto timezone = extractTimeZoneNameFromFunctionArguments(arguments, 2, 0, false); return makeNullable(std::make_shared(scale, timezone)); } diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/unixTimestamp.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/unixTimestamp.cpp index 41b0864d580d..84a7d394e1d4 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/unixTimestamp.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/unixTimestamp.cpp @@ -67,7 +67,7 @@ class FunctionParserUnixTimestamp : public FunctionParser if (isString(expr_type)) result_node = toFunctionNode(actions_dag, "parseDateTimeInJodaSyntaxOrNull", {expr_arg, fmt_arg, time_zone_node}); else if (isDateOrDate32(expr_type)) - result_node = toFunctionNode(actions_dag, "sparkToUnixTimestamp", {expr_arg, time_zone_node}); + result_node = toFunctionNode(actions_dag, "sparkDateToUnixTimestamp", {expr_arg, time_zone_node}); else if (isDateTime(expr_type) || 
isDateTime64(expr_type)) result_node = toFunctionNode(actions_dag, "toUnixTimestamp", {expr_arg, time_zone_node}); else diff --git a/cpp-ch/local-engine/tests/benchmark_unix_timestamp_function.cpp b/cpp-ch/local-engine/tests/benchmark_unix_timestamp_function.cpp index bad233d27053..dcfa7f69d931 100644 --- a/cpp-ch/local-engine/tests/benchmark_unix_timestamp_function.cpp +++ b/cpp-ch/local-engine/tests/benchmark_unix_timestamp_function.cpp @@ -74,7 +74,7 @@ static void BM_SparkUnixTimestamp_For_Date32(benchmark::State & state) { using namespace DB; auto & factory = FunctionFactory::instance(); - auto function = factory.get("sparkToUnixTimestamp", local_engine::SerializedPlanParser::global_context); + auto function = factory.get("sparkDateToUnixTimestamp", local_engine::SerializedPlanParser::global_context); Block block = createDataBlock("Date32", 30000000); auto executable = function->build(block.getColumnsWithTypeAndName()); for (auto _ : state)[[maybe_unused]] @@ -85,7 +85,7 @@ static void BM_SparkUnixTimestamp_For_Date(benchmark::State & state) { using namespace DB; auto & factory = FunctionFactory::instance(); - auto function = factory.get("sparkToUnixTimestamp", local_engine::SerializedPlanParser::global_context); + auto function = factory.get("sparkDateToUnixTimestamp", local_engine::SerializedPlanParser::global_context); Block block = createDataBlock("Date", 30000000); auto executable = function->build(block.getColumnsWithTypeAndName()); for (auto _ : state)[[maybe_unused]]