Skip to content

Commit

Permalink
Fix function conversions
Browse files Browse the repository at this point in the history
  • Loading branch information
KevinyhZou committed Mar 18, 2024
1 parent f93189e commit 775647d
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
* limitations under the License.
*/

#include <Functions/SparkFunctionUnixTimestamp.h>
#include <Functions/SparkFunctionDateToUnixTimestamp.h>

namespace local_engine
{

/// Registration hook: adds the Spark unix-timestamp function class to `factory`.
/// NOTE(review): two REGISTER_FUNCTION heads and two registerFunction calls appear back to back
/// (SparkFunctionUnixTimestamp / SparkFunctionDateToUnixTimestamp). This looks like the before and
/// after sides of a rename shown without +/- markers - confirm which pair is live.
REGISTER_FUNCTION(SparkFunctionUnixTimestamp)
REGISTER_FUNCTION(SparkFunctionDateToUnixTimestamp)
{
// The class is qualified with `local_eingine`, while the enclosing namespace here is `local_engine`.
// The spelling matches the namespace the class is declared in, so it is intentional as written.
factory.registerFunction<local_eingine::SparkFunctionUnixTimestamp>();
factory.registerFunction<local_eingine::SparkFunctionDateToUnixTimestamp>();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,13 @@
* limitations under the License.
*/

#include <Common/DateLUT.h>
#include <Common/DateLUTImpl.h>
#include <Functions/FunctionsConversion.h>
#include <Common/LocalDateTime.h>
#include <Columns/ColumnVector.h>
#include <DataTypes/IDataType.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionFactory.h>

namespace DB
Expand All @@ -32,34 +37,37 @@ using namespace DB;
namespace local_eingine
{

class SparkFunctionUnixTimestamp : public FunctionToUnixTimestamp
class SparkFunctionDateToUnixTimestamp : public IFunction
{
public:
static constexpr auto name = "sparkToUnixTimestamp";
static FunctionPtr create(ContextPtr) { return std::make_shared<SparkFunctionUnixTimestamp>(); }
SparkFunctionUnixTimestamp()
static constexpr auto name = "sparkDateToUnixTimestamp";
static FunctionPtr create(ContextPtr) { return std::make_shared<SparkFunctionDateToUnixTimestamp>(); }
SparkFunctionDateToUnixTimestamp()
{
const DateLUTImpl * date_lut = &DateLUT::instance("UTC");
UInt32 utc_timestamp = static_cast<UInt32>(0);
LocalDateTime date_time(utc_timestamp, *date_lut);
UInt32 unix_timestamp = date_time.to_time_t();
delta_timestamp_from_utc = unix_timestamp - utc_timestamp;
}
~SparkFunctionUnixTimestamp() override = default;
~SparkFunctionDateToUnixTimestamp() override = default;
String getName() const override { return name; }
bool isSuitableForShortCircuitArgumentsExecution(const DB::DataTypesWithConstInfo &) const override { return true; }
size_t getNumberOfArguments() const override { return 0; }
bool isVariadic() const override { return true; }
bool useDefaultImplementationForConstants() const override { return true; }
/// The result type is always UInt32; the argument list is not inspected.
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & /*arguments*/) const override
{
    DataTypePtr uint32_type = std::make_shared<DataTypeUInt32>();
    return uint32_type;
}

ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows) const override
{
if (arguments.size() != 1 && arguments.size() != 2)
if (arguments.size() != 1 && arguments.size() != 2)
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} argument size must be 1 or 2", name);

ColumnWithTypeAndName first_arg = arguments[0];

if (!isDateOrDate32(first_arg.type))
{
return FunctionToUnixTimestamp::executeImpl(arguments, result_type, input_rows);
}
else if (isDate(first_arg.type))
ColumnWithTypeAndName first_arg = arguments[0];
if (isDate(first_arg.type))
return executeInternal<UInt16>(first_arg.column, input_rows);
else
return executeInternal<Int32>(first_arg.column, input_rows);
Expand Down
19 changes: 17 additions & 2 deletions cpp-ch/local-engine/Functions/SparkFunctionToDate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,13 @@
#include <Common/LocalDate.h>
#include <Common/DateLUT.h>
#include <Common/DateLUTImpl.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnVector.h>
#include <Columns/ColumnsNumber.h>
#include <Columns/ColumnNullable.h>
#include <DataTypes/IDataType.h>
#include <DataTypes/DataTypeDate32.h>
#include <Functions/FunctionsConversion.h>
#include <DataTypes/DataTypeNullable.h>
#include <Functions/FunctionFactory.h>
#include <IO/ReadBufferFromMemory.h>
#include <IO/ReadHelpers.h>
Expand All @@ -35,14 +40,18 @@ namespace ErrorCodes

namespace local_engine
{
class SparkFunctionConvertToDate : public DB::FunctionToDate32OrNull
class SparkFunctionConvertToDate : public DB::IFunction
{
public:
static constexpr auto name = "sparkToDate";
static DB::FunctionPtr create(DB::ContextPtr) { return std::make_shared<SparkFunctionConvertToDate>(); }
SparkFunctionConvertToDate() = default;
~SparkFunctionConvertToDate() override = default;
bool isSuitableForShortCircuitArgumentsExecution(const DB::DataTypesWithConstInfo &) const override { return true; }
size_t getNumberOfArguments() const override { return 0; }
String getName() const override { return name; }
bool isVariadic() const override { return true; }
bool useDefaultImplementationForConstants() const override { return true; }

bool checkAndGetDate32(DB::ReadBuffer & buf, DB::DataTypeDate32::FieldType &x, const DateLUTImpl & date_lut) const
{
Expand Down Expand Up @@ -99,6 +108,12 @@ class SparkFunctionConvertToDate : public DB::FunctionToDate32OrNull
}
}

/// The result type is Date32 wrapped by makeNullable(); the argument list is not inspected.
DB::DataTypePtr getReturnTypeImpl(const DB::ColumnsWithTypeAndName & /*arguments*/) const override
{
    return makeNullable(std::make_shared<DB::DataTypeDate32>());
}

DB::ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr & result_type, size_t) const override
{
if (arguments.size() != 1)
Expand Down
28 changes: 25 additions & 3 deletions cpp-ch/local-engine/Functions/SparkFunctionToDateTime.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@
#include <Common/DateLUT.h>
#include <Common/DateLUTImpl.h>
#include <Columns/ColumnsDateTime.h>
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnsNumber.h>
#include <DataTypes/DataTypeDateTime64.h>
#include <Functions/FunctionsConversion.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/extractTimeZoneFromFunctionArguments.h>
#include <IO/ReadBufferFromMemory.h>
#include <IO/parseDateTimeBestEffort.h>
#include <IO/ReadHelpers.h>
Expand All @@ -40,13 +43,17 @@ namespace ErrorCodes

namespace local_engine
{
class SparkFunctionConvertToDateTime : public DB::FunctionToDateTime64OrNull
class SparkFunctionConvertToDateTime : public IFunction
{
public:
static constexpr auto name = "sparkToDateTime";
static DB::FunctionPtr create(DB::ContextPtr) { return std::make_shared<SparkFunctionConvertToDateTime>(); }
SparkFunctionConvertToDateTime() = default;
~SparkFunctionConvertToDateTime() override = default;
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo &) const override { return true; }
size_t getNumberOfArguments() const override { return 0; }
bool isVariadic() const override { return true; }
bool useDefaultImplementationForConstants() const override { return true; }
String getName() const override { return name; }

bool checkDateTimeFormat(DB::ReadBuffer & buf, size_t buf_size, UInt8 & can_be_parsed) const
Expand Down Expand Up @@ -109,11 +116,26 @@ class SparkFunctionConvertToDateTime : public DB::FunctionToDateTime64OrNull
return true;
}

/// Reads the DateTime64 scale from a function argument.
///
/// @param named_column  The argument carrying the scale. Its type must be UInt8, UInt16, UInt32 or
///                      UInt64, and a column holding the value must be attached.
/// @return The value stored in row 0 of the argument's column, as UInt32.
/// @throws Exception(ILLEGAL_TYPE_OF_ARGUMENT) if the type is not one of the unsigned integer types
///         above, or if the argument has no column to read the value from.
inline UInt32 extractDecimalScale(const ColumnWithTypeAndName & named_column) const
{
    const auto * arg_type = named_column.type.get();
    const bool is_unsigned_integer = checkAndGetDataType<DataTypeUInt64>(arg_type)
        || checkAndGetDataType<DataTypeUInt32>(arg_type)
        || checkAndGetDataType<DataTypeUInt16>(arg_type)
        || checkAndGetDataType<DataTypeUInt8>(arg_type);
    if (!is_unsigned_integer)
        throw Exception(
            ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
            "Illegal type {} of scale argument of function {}",
            named_column.type->getName(),
            name);

    /// Without an attached column there is no value to read; fail with a clear error rather than
    /// dereferencing a null pointer.
    if (!named_column.column)
        throw Exception(
            ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
            "Scale argument of function {} has no value to read, a constant is expected",
            name);

    Field field;
    named_column.column->get(0, field);
    return static_cast<UInt32>(field.get<UInt32>());
}

/// Builds the result type: DateTime64(scale, timezone) wrapped by makeNullable().
/// The scale starts at 6 and is taken from the second argument when more than one argument is given.
/// The time zone name comes from extractTimeZoneNameFromFunctionArguments(arguments, 2, 0, false).
DB::DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
UInt32 scale = 6;
if (arguments.size() > 1)
// NOTE(review): two consecutive assignments to `scale` follow, and only the first one is governed
// by the `if` above. They look like the before (extractToDecimalScale) and after
// (extractDecimalScale) sides of one changed line shown without +/- markers. Confirm which is live:
// read literally, the second call accesses arguments[1] even when only one argument was passed.
scale = extractToDecimalScale(arguments[1]);
scale = extractDecimalScale(arguments[1]);
const auto timezone = extractTimeZoneNameFromFunctionArguments(arguments, 2, 0, false);
return makeNullable(std::make_shared<DataTypeDateTime64>(scale, timezone));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class FunctionParserUnixTimestamp : public FunctionParser
if (isString(expr_type))
result_node = toFunctionNode(actions_dag, "parseDateTimeInJodaSyntaxOrNull", {expr_arg, fmt_arg, time_zone_node});
else if (isDateOrDate32(expr_type))
result_node = toFunctionNode(actions_dag, "sparkToUnixTimestamp", {expr_arg, time_zone_node});
result_node = toFunctionNode(actions_dag, "sparkDateToUnixTimestamp", {expr_arg, time_zone_node});
else if (isDateTime(expr_type) || isDateTime64(expr_type))
result_node = toFunctionNode(actions_dag, "toUnixTimestamp", {expr_arg, time_zone_node});
else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ static void BM_SparkUnixTimestamp_For_Date32(benchmark::State & state)
{
using namespace DB;
auto & factory = FunctionFactory::instance();
auto function = factory.get("sparkToUnixTimestamp", local_engine::SerializedPlanParser::global_context);
auto function = factory.get("sparkDateToUnixTimestamp", local_engine::SerializedPlanParser::global_context);
Block block = createDataBlock("Date32", 30000000);
auto executable = function->build(block.getColumnsWithTypeAndName());
for (auto _ : state)[[maybe_unused]]
Expand All @@ -85,7 +85,7 @@ static void BM_SparkUnixTimestamp_For_Date(benchmark::State & state)
{
using namespace DB;
auto & factory = FunctionFactory::instance();
auto function = factory.get("sparkToUnixTimestamp", local_engine::SerializedPlanParser::global_context);
auto function = factory.get("sparkDateToUnixTimestamp", local_engine::SerializedPlanParser::global_context);
Block block = createDataBlock("Date", 30000000);
auto executable = function->build(block.getColumnsWithTypeAndName());
for (auto _ : state)[[maybe_unused]]
Expand Down

0 comments on commit 775647d

Please sign in to comment.