Skip to content

Commit

Permalink
Pass DAG tests after merging master (#199)
Browse files Browse the repository at this point in the history
* Enhance dbg invoke and add dag as schemaful function

* Add basic sql parse to dag

* Column id starts from 1

* Fix value to ref

* Add basic dag test

* Fix dag bugs and pass 1st mock test

* Make dag go normal routine and add mock dag

* Add todo

* Add comment

* Fix gcc compile error

* Enhance dag test

* Address comments

* Enhance mock sql -> dag compiler and add project test

* Mock sql dag compiler support more expression types and add filter test

* Add topn and limit test

* Add agg for sql -> dag parser and agg test

* Add dag specific codec

* type

* Update codec accordingly

* Remove cop-test

* Pass tests after merging master
  • Loading branch information
zanmato1984 authored Aug 24, 2019
1 parent 960cc56 commit 08bacd7
Show file tree
Hide file tree
Showing 16 changed files with 396 additions and 256 deletions.
13 changes: 7 additions & 6 deletions dbms/src/Debug/dbgFuncCoprocessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ BlockInputStreamPtr dbgFuncDAG(Context & context, const ASTs & args)
region_id = safeGet<RegionID>(typeid_cast<const ASTLiteral &>(*args[1]).value);
Timestamp start_ts = context.getTMTContext().getPDClient()->getTS();

auto [table_id, schema, dag_request] = compileQuery(
context, query,
auto [table_id, schema, dag_request] = compileQuery(context, query,
[&](const String & database_name, const String & table_name) {
auto storage = context.getTable(database_name, table_name);
auto mmt = std::dynamic_pointer_cast<StorageMergeTree>(storage);
Expand Down Expand Up @@ -92,8 +91,7 @@ BlockInputStreamPtr dbgFuncMockDAG(Context & context, const ASTs & args)
if (start_ts == 0)
start_ts = context.getTMTContext().getPDClient()->getTS();

auto [table_id, schema, dag_request] = compileQuery(
context, query,
auto [table_id, schema, dag_request] = compileQuery(context, query,
[&](const String & database_name, const String & table_name) {
return MockTiDB::instance().getTableByName(database_name, table_name)->table_info;
},
Expand Down Expand Up @@ -210,9 +208,12 @@ void compileExpr(const DAGSchema & input, ASTPtr ast, tipb::Expr * expr, std::un
expr->set_tp(tipb::Float64);
encodeDAGFloat64(lit->value.get<Float64>(), ss);
break;
case Field::Types::Which::Decimal:
case Field::Types::Which::Decimal32:
case Field::Types::Which::Decimal64:
case Field::Types::Which::Decimal128:
case Field::Types::Which::Decimal256:
expr->set_tp(tipb::MysqlDecimal);
encodeDAGDecimal(lit->value.get<Decimal>(), ss);
encodeDAGDecimal(lit->value, ss);
break;
case Field::Types::Which::String:
expr->set_tp(tipb::String);
Expand Down
5 changes: 3 additions & 2 deletions dbms/src/Debug/dbgFuncRegion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
#include <Interpreters/executeQuery.h>
#include <Parsers/ASTIdentifier.h>
#include <Parsers/ASTLiteral.h>
#include <Storages/MutableSupport.h>
#include <Storages/StorageMergeTree.h>
#include <Storages/Transaction/KVStore.h>
#include <Storages/Transaction/Region.h>
Expand Down Expand Up @@ -140,7 +139,9 @@ void dbgFuncRegionSnapshotWithData(Context & context, const ASTs & args, DBGInvo
}

TiKVKey key = RecordKVFormat::genKey(table_id, handle_id);
TiKVValue value = RecordKVFormat::EncodeRow(table->table_info, fields);
std::stringstream ss;
RegionBench::encodeRow(table->table_info, fields, ss);
TiKVValue value(ss.str());
UInt64 commit_ts = tso;
UInt64 prewrite_ts = tso;
TiKVValue commit_value;
Expand Down
130 changes: 129 additions & 1 deletion dbms/src/Debug/dbgTools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <Raft/RaftContext.h>
#include <Storages/Transaction/Codec.h>
#include <Storages/Transaction/KVStore.h>
#include <Storages/Transaction/MyTimeParser.h>
#include <Storages/Transaction/Region.h>
#include <Storages/Transaction/TMTContext.h>
#include <Storages/Transaction/TiKVRange.h>
Expand All @@ -23,6 +24,9 @@ extern const int LOGICAL_ERROR;
namespace RegionBench
{

using TiDB::ColumnInfo;
using TiDB::TableInfo;

RegionPtr createRegion(TableID table_id, RegionID region_id, const HandleID & start, const HandleID & end)
{
enginepb::SnapshotRequest request;
Expand Down Expand Up @@ -121,6 +125,128 @@ void addRequestsToRaftCmd(enginepb::CommandRequest * cmd, RegionID region_id, co
}
}

/// Converts a numeric Field into the requested arithmetic type T.
///
/// Accepted source types are Int64, UInt64, Float64 and the four decimal
/// widths; the stored value is read with the matching accessor and then
/// static_cast to T. Any other field type raises a LOGICAL_ERROR exception.
/// NOTE(review): for decimal fields the result depends on DecimalField's
/// conversion operator, which is not visible here - confirm whether the
/// scale is applied.
template <typename T>
T convertNumber(const Field & field)
{
    const auto type = field.getType();

    if (type == Field::Types::Int64)
        return static_cast<T>(field.get<Int64>());
    if (type == Field::Types::UInt64)
        return static_cast<T>(field.get<UInt64>());
    if (type == Field::Types::Float64)
        return static_cast<T>(field.get<Float64>());
    if (type == Field::Types::Decimal32)
        return static_cast<T>(field.get<DecimalField<Decimal32>>());
    if (type == Field::Types::Decimal64)
        return static_cast<T>(field.get<DecimalField<Decimal64>>());
    if (type == Field::Types::Decimal128)
        return static_cast<T>(field.get<DecimalField<Decimal128>>());
    if (type == Field::Types::Decimal256)
        return static_cast<T>(field.get<DecimalField<Decimal256>>());

    // Anything else (strings, nulls, arrays, ...) cannot be treated as a number.
    throw Exception(String("Unable to convert field type ") + field.getTypeName() + " to number", ErrorCodes::LOGICAL_ERROR);
}

/// Converts a Field into a decimal Field with the given scale.
///
/// Int64, UInt64 and Float64 inputs are turned into a Decimal64-backed
/// DecimalField via ToDecimal; inputs that are already one of the decimal
/// types are returned unchanged (their own scale is kept, `scale` is not
/// applied). Any other field type raises a LOGICAL_ERROR exception.
Field convertDecimal(UInt32 scale, const Field & field)
{
    switch (field.getType())
    {
        case Field::Types::Int64:
            return DecimalField(ToDecimal<Int64, Decimal64>(field.get<Int64>(), scale), scale);
        case Field::Types::UInt64:
            // NOTE(review): the UInt64 value is passed through the Int64 instantiation of
            // ToDecimal, so values above the Int64 range presumably wrap - confirm.
            return DecimalField(ToDecimal<Int64, Decimal64>(field.get<UInt64>(), scale), scale);
        case Field::Types::Float64:
            return DecimalField(ToDecimal<Float64, Decimal64>(field.get<Float64>(), scale), scale);
        case Field::Types::Decimal32:
        case Field::Types::Decimal64:
        case Field::Types::Decimal128:
        case Field::Types::Decimal256:
            // Already a decimal: pass through untouched.
            return field;
        default:
            // The message names the actual target type (it previously said "to number").
            throw Exception(String("Unable to convert field type ") + field.getTypeName() + " to Decimal", ErrorCodes::LOGICAL_ERROR);
    }
}

/// Converts a Field into the UInt64 representation used for an Enum column.
///
/// Integer inputs (Int64 / UInt64) are converted with convertNumber<UInt64>;
/// String inputs are looked up through column_info.getEnumIndex.
/// Any other field type raises a LOGICAL_ERROR exception.
Field convertEnum(const ColumnInfo & column_info, const Field & field)
{
    const auto type = field.getType();

    if (type == Field::Types::Int64 || type == Field::Types::UInt64)
        return convertNumber<UInt64>(field);

    if (type == Field::Types::String)
        return static_cast<UInt64>(column_info.getEnumIndex(field.get<String>()));

    throw Exception(String("Unable to convert field type ") + field.getTypeName() + " to Enum", ErrorCodes::LOGICAL_ERROR);
}

/// Converts a user-supplied Field into the Field representation expected for
/// the TiDB column described by `column_info`.
///
/// - Null fields are returned unchanged regardless of column type.
/// - Integer-like columns become UInt64 or Int64 depending on the unsigned flag.
/// - Float/Double columns become Float64.
/// - Date/Datetime/Timestamp columns are parsed from a String field.
/// - String/blob columns are passed through as-is.
/// - Enum and Decimal columns are delegated to convertEnum / convertDecimal.
/// - Time, Year and Set columns are not supported and raise LOGICAL_ERROR.
/// - TypeNull and any unlisted column type yield an empty (null) Field.
Field convertField(const ColumnInfo & column_info, const Field & field)
{
    if (field.isNull())
        return field;

    switch (column_info.tp)
    {
        case TiDB::TypeTiny:
        case TiDB::TypeShort:
        case TiDB::TypeLong:
        case TiDB::TypeLongLong:
        case TiDB::TypeInt24:
        case TiDB::TypeBit:
            if (column_info.hasUnsignedFlag())
                return convertNumber<UInt64>(field);
            else
                return convertNumber<Int64>(field);
        case TiDB::TypeFloat:
        case TiDB::TypeDouble:
            return convertNumber<Float64>(field);
        case TiDB::TypeDate:
        case TiDB::TypeDatetime:
        case TiDB::TypeTimestamp:
            // Only a String can be parsed as a datetime; reject other field types up front
            // rather than reading the field's storage as a String it does not hold.
            if (field.getType() != Field::Types::String)
                throw Exception(
                    String("Unable to convert field type ") + field.getTypeName() + " to Datetime", ErrorCodes::LOGICAL_ERROR);
            return DB::parseMyDatetime(field.get<String>());
        case TiDB::TypeVarchar:
        case TiDB::TypeTinyBlob:
        case TiDB::TypeMediumBlob:
        case TiDB::TypeLongBlob:
        case TiDB::TypeBlob:
        case TiDB::TypeVarString:
        case TiDB::TypeString:
            return field;
        case TiDB::TypeEnum:
            return convertEnum(column_info, field);
        case TiDB::TypeNull:
            return Field();
        case TiDB::TypeDecimal:
        case TiDB::TypeNewDecimal:
            return convertDecimal(column_info.decimal, field);
        case TiDB::TypeTime:
            throw Exception(String("Unable to convert field type ") + field.getTypeName() + " to Time", ErrorCodes::LOGICAL_ERROR);
        case TiDB::TypeYear:
            throw Exception(String("Unable to convert field type ") + field.getTypeName() + " to Year", ErrorCodes::LOGICAL_ERROR);
        case TiDB::TypeSet:
            throw Exception(String("Unable to convert field type ") + field.getTypeName() + " to Set", ErrorCodes::LOGICAL_ERROR);
        default:
            // Column types not listed above fall back to a null Field.
            return Field();
    }
}

/// Encodes one row into `ss` as a sequence of (column id, value) datum pairs.
///
/// `fields` must contain exactly one value per column of `table_info`, in
/// column order; otherwise a LOGICAL_ERROR exception is thrown. Each value is
/// first converted with convertField to match its column's TiDB type.
void encodeRow(const TiDB::TableInfo & table_info, const std::vector<Field> & fields, std::stringstream & ss)
{
    const size_t value_count = fields.size();
    if (table_info.columns.size() != value_count)
        throw Exception("Encoding row has different sizes between columns and values", ErrorCodes::LOGICAL_ERROR);

    for (size_t idx = 0; idx < value_count; idx++)
    {
        const TiDB::ColumnInfo & column = table_info.columns[idx];

        // The column id is written first, using the integer codec flag.
        EncodeDatum(Field(column.id), TiDB::CodecFlagInt, ss);

        // Then the value, using the codec flag reported by the column.
        // The converted field is held in a named local, as in the original,
        // for as long as `datum` is in use.
        Field converted = convertField(column, fields[idx]);
        TiDB::DatumBumpy datum(converted, column.tp);
        EncodeDatum(datum.field(), column.getCodecFlag(), ss);
    }
}

void insert(const TiDB::TableInfo & table_info, RegionID region_id, HandleID handle_id, ASTs::const_iterator begin,
ASTs::const_iterator end, Context & context, const std::optional<std::tuple<Timestamp, UInt8>> & tso_del)
{
Expand All @@ -142,7 +268,9 @@ void insert(const TiDB::TableInfo & table_info, RegionID region_id, HandleID han
TableID table_id = RecordKVFormat::getTableId(region->getRange().first);

TiKVKey key = RecordKVFormat::genKey(table_id, handle_id);
TiKVValue value = RecordKVFormat::EncodeRow(table_info, fields);
std::stringstream ss;
encodeRow(table_info, fields, ss);
TiKVValue value(ss.str());

UInt64 prewrite_ts = pd_client->getTS();
UInt64 commit_ts = pd_client->getTS();
Expand Down
2 changes: 2 additions & 0 deletions dbms/src/Debug/dbgTools.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ RegionPtr createRegion(TableID table_id, RegionID region_id, const HandleID & st

Regions createRegions(TableID table_id, size_t region_num, size_t key_num_each_region, HandleID handle_begin, RegionID new_region_id_begin);

void encodeRow(const TiDB::TableInfo & table_info, const std::vector<Field> & fields, std::stringstream & ss);

void insert(const TiDB::TableInfo & table_info, RegionID region_id, HandleID handle_id, ASTs::const_iterator begin,
ASTs::const_iterator end, Context & context, const std::optional<std::tuple<Timestamp, UInt8>> & tso_del = {});

Expand Down
9 changes: 7 additions & 2 deletions dbms/src/Flash/Coprocessor/DAGBlockOutputStream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <DataTypes/DataTypeNullable.h>
#include <Storages/Transaction/Codec.h>
#include <Storages/Transaction/Datum.h>
#include <Storages/Transaction/TypeMapping.h>

namespace DB
Expand All @@ -13,6 +14,9 @@ extern const int UNSUPPORTED_PARAMETER;
extern const int LOGICAL_ERROR;
} // namespace ErrorCodes

using TiDB::DatumBumpy;
using TiDB::TP;

DAGBlockOutputStream::DAGBlockOutputStream(tipb::SelectResponse & dag_response_, Int64 records_per_chunk_, tipb::EncodeType encodeType_,
std::vector<tipb::FieldType> && result_field_types_, Block header_)
: dag_response(dag_response_),
Expand Down Expand Up @@ -71,8 +75,9 @@ void DAGBlockOutputStream::write(const Block & block)
}
for (size_t j = 0; j < block.columns(); j++)
{
auto field = (*block.getByPosition(j).column.get())[i];
EncodeDatum(field, getCodecFlagByFieldType(result_field_types[j]), current_ss);
const auto & field = (*block.getByPosition(j).column.get())[i];
DatumBumpy datum(field, static_cast<TP>(result_field_types[j].tp()));
EncodeDatum(datum.field(), getCodecFlagByFieldType(result_field_types[j]), current_ss);
}
// Encode current row
records_per_chunk++;
Expand Down
4 changes: 2 additions & 2 deletions dbms/src/Flash/Coprocessor/DAGCodec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ void encodeDAGString(const String & s, std::stringstream & ss) { ss << s; }

void encodeDAGBytes(const String & bytes, std::stringstream & ss) { ss << bytes; }

void encodeDAGDecimal(const Decimal & d, std::stringstream & ss) { EncodeDecimal(d, ss); }
void encodeDAGDecimal(const Field & field, std::stringstream & ss) { EncodeDecimal(field, ss); }

Int64 decodeDAGInt64(const String & s)
{
Expand Down Expand Up @@ -56,7 +56,7 @@ String decodeDAGString(const String & s) { return s; }

String decodeDAGBytes(const String & s) { return s; }

Decimal decodeDAGDecimal(const String & s)
Field decodeDAGDecimal(const String & s)
{
size_t cursor = 0;
return DecodeDecimal(cursor, s);
Expand Down
4 changes: 2 additions & 2 deletions dbms/src/Flash/Coprocessor/DAGCodec.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ void encodeDAGFloat32(Float32, std::stringstream &);
void encodeDAGFloat64(Float64, std::stringstream &);
void encodeDAGString(const String &, std::stringstream &);
void encodeDAGBytes(const String &, std::stringstream &);
void encodeDAGDecimal(const Decimal &, std::stringstream &);
void encodeDAGDecimal(const Field &, std::stringstream &);

Int64 decodeDAGInt64(const String &);
UInt64 decodeDAGUInt64(const String &);
Float32 decodeDAGFloat32(const String &);
Float64 decodeDAGFloat64(const String &);
String decodeDAGString(const String &);
String decodeDAGBytes(const String &);
Decimal decodeDAGDecimal(const String &);
Field decodeDAGDecimal(const String &);

} // namespace DB
14 changes: 13 additions & 1 deletion dbms/src/Flash/Coprocessor/DAGUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,19 @@ String exprToString(const tipb::Expr & expr, const NamesAndTypesList & input_col
case tipb::ExprType::Bytes:
return decodeDAGBytes(expr.val());
case tipb::ExprType::MysqlDecimal:
return decodeDAGDecimal(expr.val()).toString();
{
    // Decode the decimal literal and render it with the DecimalField width
    // that matches the decoded field's runtime type. Each branch must read
    // the field with its own width; every branch previously used Decimal32.
    auto field = decodeDAGDecimal(expr.val());
    if (field.getType() == Field::Types::Decimal32)
        return field.get<DecimalField<Decimal32>>().toString();
    else if (field.getType() == Field::Types::Decimal64)
        return field.get<DecimalField<Decimal64>>().toString();
    else if (field.getType() == Field::Types::Decimal128)
        return field.get<DecimalField<Decimal128>>().toString();
    else if (field.getType() == Field::Types::Decimal256)
        return field.get<DecimalField<Decimal256>>().toString();
    else
        // The decoder returned something that is not a decimal at all.
        throw Exception("Not decimal literal" + expr.DebugString(), ErrorCodes::COP_BAD_DAG_REQUEST);
}
case tipb::ExprType::ColumnRef:
column_id = decodeDAGInt64(expr.val());
if (column_id < 0 || column_id >= (ColumnID)input_col.size())
Expand Down
Loading

0 comments on commit 08bacd7

Please sign in to comment.