Skip to content

Commit

Permalink
*: Vector Data types and Functions (#9341)
Browse files Browse the repository at this point in the history
ref #9032

*: Vector Data types and Functions

Support parsing vector data type written by TiDB
Support basic functions for vector data type: CastVectorAsText, VecDims, VecL1Distance, VecL2Distance, VecL2Norm, VecCosineDistance, VecNegativeInnerProduct

Signed-off-by: Lloyd-Pottiger <[email protected]>

Co-authored-by: Lloyd-Pottiger <[email protected]>
Co-authored-by: JaySon-Huang <[email protected]>
  • Loading branch information
JaySon-Huang and Lloyd-Pottiger authored Aug 21, 2024
1 parent cf5a204 commit 0d0463a
Show file tree
Hide file tree
Showing 32 changed files with 1,562 additions and 29 deletions.
71 changes: 69 additions & 2 deletions dbms/src/Columns/ColumnArray.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,13 @@
#include <Common/SipHash.h>
#include <Common/typeid_cast.h>
#include <DataStreams/ColumnGathererStream.h>
#include <Functions/FunctionHelpers.h>
#include <IO/Endian.h>
#include <IO/WriteHelpers.h>
#include <string.h> // memcpy

#include <memory>

namespace DB
{
namespace ErrorCodes
Expand Down Expand Up @@ -798,10 +803,44 @@ void ColumnArray::getPermutation(bool reverse, size_t limit, int nan_direction_h
}
}

// NOTE(review): this span looks like a rendered diff with old and new lines interleaved — the
// signature with commented-out parameter names and the "not implement." throw appear to be the
// removed version. Confirm against the real file; the comments below describe the remaining lines.
ColumnPtr ColumnArray::replicateRange(size_t /*start_row*/, size_t /*end_row*/, const IColumn::Offsets & /*offsets*/)
/// Replicates the rows of this array column according to `replicate_offsets`.
/// Throws SIZES_OF_COLUMNS_DOESNT_MATCH when `replicate_offsets` does not have one entry per row.
/// Only the full range is accepted: `start_row` must be 0 and `end_row` must equal
/// `replicate_offsets.size()`, otherwise the RUNTIME_CHECKs fail.
ColumnPtr ColumnArray::replicateRange(size_t start_row, size_t end_row, const IColumn::Offsets & replicate_offsets)
const
{
throw Exception("not implement.", ErrorCodes::NOT_IMPLEMENTED);
size_t col_size = size();
if (col_size != replicate_offsets.size())
throw Exception("Size of offsets doesn't match size of column.", ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH);

// Partial ranges are not supported; the requested range must cover [0, replicate_offsets.size()).
RUNTIME_CHECK(start_row == 0, start_row);
RUNTIME_CHECK(end_row == replicate_offsets.size(), end_row, replicate_offsets.size());

// Dispatch on the concrete type of the nested data column: numeric columns go to the typed
// replicateNumber<T>, Const and Nullable have dedicated helpers, everything else is generic.
if (typeid_cast<const ColumnUInt8 *>(data.get()))
return replicateNumber<UInt8>(replicate_offsets);
if (typeid_cast<const ColumnUInt16 *>(data.get()))
return replicateNumber<UInt16>(replicate_offsets);
if (typeid_cast<const ColumnUInt32 *>(data.get()))
return replicateNumber<UInt32>(replicate_offsets);
if (typeid_cast<const ColumnUInt64 *>(data.get()))
return replicateNumber<UInt64>(replicate_offsets);
if (typeid_cast<const ColumnUInt128 *>(data.get()))
return replicateNumber<UInt128>(replicate_offsets);
if (typeid_cast<const ColumnInt8 *>(data.get()))
return replicateNumber<Int8>(replicate_offsets);
if (typeid_cast<const ColumnInt16 *>(data.get()))
return replicateNumber<Int16>(replicate_offsets);
if (typeid_cast<const ColumnInt32 *>(data.get()))
return replicateNumber<Int32>(replicate_offsets);
if (typeid_cast<const ColumnInt64 *>(data.get()))
return replicateNumber<Int64>(replicate_offsets);
if (typeid_cast<const ColumnFloat32 *>(data.get()))
return replicateNumber<Float32>(replicate_offsets);
if (typeid_cast<const ColumnFloat64 *>(data.get()))
return replicateNumber<Float64>(replicate_offsets);
if (typeid_cast<const ColumnConst *>(data.get()))
return replicateConst(replicate_offsets);
if (typeid_cast<const ColumnNullable *>(data.get()))
return replicateNullable(replicate_offsets);
// Fallback for any nested column type not matched above.
return replicateGeneric(replicate_offsets);
}


Expand Down Expand Up @@ -1048,4 +1087,32 @@ void ColumnArray::gather(ColumnGathererStream & gatherer)
gatherer.gather(*this);
}

/// Appends one row decoded from the datum stored in `raw_value[cursor, cursor + length)`.
/// The byte range is validated against `raw_value` and then handed to insertFromDatumData.
/// `force_decode` is ignored. Always returns true; a failed check raises through RUNTIME_CHECK.
bool ColumnArray::decodeTiDBRowV2Datum(size_t cursor, const String & raw_value, size_t length, bool /* force_decode */)
{
    // Two separate comparisons instead of `cursor + length`: the sum of two size_t values can wrap
    // around and make an out-of-range request look valid.
    RUNTIME_CHECK(
        cursor <= raw_value.size() && length <= raw_value.size() - cursor,
        raw_value.size(),
        cursor,
        length);
    insertFromDatumData(raw_value.c_str() + cursor, length);
    return true;
}

/// Appends one row from a serialized vector datum.
/// The buffer is read as a little-endian UInt32 element count followed by that many Float32
/// values. `length` must cover at least the count plus the declared elements; any bytes past the
/// declared payload are not consumed.
/// Requires a little-endian host and a nested column that is a contiguous ColumnVector<Float32>.
void ColumnArray::insertFromDatumData(const char * data, size_t length)
{
    RUNTIME_CHECK(boost::endian::order::native == boost::endian::order::little);

    RUNTIME_CHECK(checkAndGetColumn<ColumnVector<Float32>>(&getData()));
    RUNTIME_CHECK(getData().isFixedAndContiguous());

    // Header: the number of Float32 elements that follow.
    RUNTIME_CHECK(length >= sizeof(UInt32), length);
    const auto n = readLittleEndian<UInt32>(data);
    const char * const payload = data + sizeof(UInt32);

    // Body: exactly n elements are passed on, regardless of how much larger `length` is.
    const auto precise_data_size = n * sizeof(Float32);
    RUNTIME_CHECK(length >= sizeof(UInt32) + precise_data_size, n, length);
    insertData(payload, precise_data_size);
}

std::pair<UInt32, StringRef> ColumnArray::getElementRef(size_t element_idx) const
{
return {static_cast<UInt32>(sizeAt(element_idx)), getDataAt(element_idx)};
}

} // namespace DB
9 changes: 9 additions & 0 deletions dbms/src/Columns/ColumnArray.h
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,15 @@ class ColumnArray final : public COWPtrHelper<IColumn, ColumnArray>
callback(data);
}

/// Array columns report that they may be wrapped in Nullable.
bool canBeInsideNullable() const override { return true; }

/// Appends one row decoded from `raw_value` starting at `cursor`.
/// The third parameter is the datum length in bytes; the implementation checks that the range
/// fits inside `raw_value` and forwards it to insertFromDatumData. Returns true.
bool decodeTiDBRowV2Datum(size_t cursor, const String & raw_value, size_t /* length */, bool /* force_decode */)
override;

/// Appends one row from a serialized datum: a UInt32 element count followed by that many
/// Float32 values. Only valid when the nested column is ColumnVector<Float32>.
void insertFromDatumData(const char * data, size_t length) override;

/// Returns {sizeAt(element_idx) as UInt32, getDataAt(element_idx)} for the given row.
std::pair<UInt32, StringRef> getElementRef(size_t element_idx) const;

private:
ColumnPtr data;
ColumnPtr offsets;
Expand Down
29 changes: 25 additions & 4 deletions dbms/src/Columns/ColumnNullable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,14 +204,29 @@ void ColumnNullable::get(size_t n, Field & res) const
getNestedColumn().get(n, res);
}

// NOTE(review): this span looks like a rendered diff with old and new lines interleaved — the
// signature with the commented-out parameter name and the first, unconditional throw appear to be
// the removed version. Confirm against the real file.
StringRef ColumnNullable::getDataAt(size_t /*n*/) const
/// Returns the nested column's data for row `n` when that row is not NULL.
/// For a NULL row there is no data to reference, so a NOT_IMPLEMENTED exception is thrown.
StringRef ColumnNullable::getDataAt(size_t n) const
{
throw Exception(fmt::format("Method getDataAt is not supported for {}", getName()), ErrorCodes::NOT_IMPLEMENTED);
// Non-NULL is hinted as the expected path.
if (likely(!isNullAt(n)))
return getNestedColumn().getDataAt(n);

throw Exception(
ErrorCodes::NOT_IMPLEMENTED,
"Method getDataAt is not supported for {} in case if value is NULL",
getName());
}

// NOTE(review): this span looks like a rendered diff with old and new lines interleaved — the
// signature with commented-out parameter names and the unconditional throw appear to be the
// removed version. Confirm against the real file.
void ColumnNullable::insertData(const char * /*pos*/, size_t /*length*/)
/// Appends one row.
/// `pos == nullptr` means NULL: the nested column receives a default value and the null map gets 1.
/// Otherwise the bytes are forwarded to the nested column's insertData and the null map gets 0.
void ColumnNullable::insertData(const char * pos, size_t length)
{
throw Exception(fmt::format("Method insertData is not supported for {}", getName()), ErrorCodes::NOT_IMPLEMENTED);
if (pos == nullptr)
{
// Keep the nested column and the null map the same length by inserting a placeholder value.
getNestedColumn().insertDefault();
getNullMapData().push_back(1);
}
else
{
getNestedColumn().insertData(pos, length);
getNullMapData().push_back(0);
}
}

bool ColumnNullable::decodeTiDBRowV2Datum(size_t cursor, const String & raw_value, size_t length, bool force_decode)
Expand All @@ -222,6 +237,12 @@ bool ColumnNullable::decodeTiDBRowV2Datum(size_t cursor, const String & raw_valu
return true;
}

/// Appends one non-NULL row: the datum bytes are decoded by the nested column first, and only
/// then is the row marked as not-null in the null map.
void ColumnNullable::insertFromDatumData(const char * cursor, size_t len)
{
    auto & nested = getNestedColumn();
    nested.insertFromDatumData(cursor, len);

    auto & null_map = getNullMapData();
    null_map.push_back(0);
}

StringRef ColumnNullable::serializeValueIntoArena(
size_t n,
Arena & arena,
Expand Down
4 changes: 3 additions & 1 deletion dbms/src/Columns/ColumnNullable.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,11 @@ class ColumnNullable final : public COWPtrHelper<IColumn, ColumnNullable>
Field operator[](size_t n) const override;
void get(size_t n, Field & res) const override;
UInt64 get64(size_t n) const override { return nested_column->get64(n); }
StringRef getDataAt(size_t n) const override;
StringRef getDataAt(size_t) const override;
/// Will insert null value if pos=nullptr
void insertData(const char * pos, size_t length) override;
bool decodeTiDBRowV2Datum(size_t cursor, const String & raw_value, size_t length, bool force_decode) override;
void insertFromDatumData(const char *, size_t) override;
StringRef serializeValueIntoArena(
size_t n,
Arena & arena,
Expand Down
5 changes: 5 additions & 0 deletions dbms/src/Columns/IColumn.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,11 @@ class IColumn : public COWPtr<IColumn>
throw Exception("Method decodeTiDBRowV2Datum is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
}

/// Default implementation for column types that do not support inserting from datum data:
/// always throws NOT_IMPLEMENTED, naming the concrete column type.
virtual void insertFromDatumData(const char *, size_t)
{
    const auto message = "Method insertFromDatumData is not supported for " + getName();
    throw Exception(message, ErrorCodes::NOT_IMPLEMENTED);
}

/// Like insertData, but has special behavior for columns that contain variable-length strings.
/// In this special case the inserted data must be zero-terminated (i.e. length is 1 byte greater than real string size).
/// The default implementation simply forwards to insertData.
virtual void insertDataWithTerminatingZero(const char * pos, size_t length) { insertData(pos, length); }
Expand Down
4 changes: 2 additions & 2 deletions dbms/src/DataTypes/DataTypeArray.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class DataTypeArray final : public IDataType

const char * getFamilyName() const override { return "Array"; }

bool canBeInsideNullable() const override { return false; }
bool canBeInsideNullable() const override { return true; }

TypeIndex getTypeId() const override { return TypeIndex::Array; }

Expand Down Expand Up @@ -98,7 +98,7 @@ class DataTypeArray final : public IDataType
bool haveSubtypes() const override { return true; }
bool cannotBeStoredInTables() const override { return nested->cannotBeStoredInTables(); }
bool textCanContainOnlyValidUTF8() const override { return nested->textCanContainOnlyValidUTF8(); }
bool isComparable() const override { return nested->isComparable(); };
bool isComparable() const override { return nested->isComparable(); }
bool canBeComparedWithCollation() const override { return nested->canBeComparedWithCollation(); }

bool isValueUnambiguouslyRepresentedInContiguousMemoryRegion() const override
Expand Down
7 changes: 7 additions & 0 deletions dbms/src/Debug/MockExecutor/AstToPB.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,13 @@ void literalFieldToTiPBExpr(const TiDB::ColumnInfo & ci, const Field & val_field
encodeDAGInt64(val, ss);
break;
}
case TiDB::TypeTiDBVectorFloat32:
{
expr->set_tp(tipb::ExprType::TiDBVectorFloat32);
const auto & val = val_field.safeGet<Array>();
encodeDAGVectorFloat32(val, ss);
break;
}
default:
throw Exception(fmt::format(
"Type {} does not support literal in function unit test",
Expand Down
4 changes: 4 additions & 0 deletions dbms/src/Debug/dbgTools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,10 @@ struct BatchCtrl
throw Exception(
"Not implented yet: BatchCtrl::encodeDatum, TiDB::CodecFlagJson",
ErrorCodes::LOGICAL_ERROR);
case TiDB::CodecFlagVectorFloat32:
throw Exception(
"Not implented yet: BatchCtrl::encodeDatum, TiDB::CodecFlagVectorFloat32",
ErrorCodes::LOGICAL_ERROR);
case TiDB::CodecFlagMax:
throw Exception("Not implented yet: BatchCtrl::encodeDatum, TiDB::CodecFlagMax", ErrorCodes::LOGICAL_ERROR);
case TiDB::CodecFlagDuration:
Expand Down
86 changes: 86 additions & 0 deletions dbms/src/Flash/Coprocessor/ArrowColCodec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <Columns/ColumnArray.h>
#include <Columns/ColumnDecimal.h>
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnVector.h>
#include <Common/TiFlashException.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeDecimal.h>
#include <DataTypes/DataTypeEnum.h>
#include <DataTypes/DataTypeMyDate.h>
Expand Down Expand Up @@ -296,6 +298,37 @@ void flashStringColToArrowCol(
}
}

/// Appends rows [start_index, end_index) of an array-of-Float32 column to `dag_column`.
/// `flash_col_untyped` may be the array column itself or a Nullable wrapper around it.
/// When `is_nullable` is true, NULL rows are emitted via appendNull; otherwise no null check is made.
/// Fails a RUNTIME_CHECK if the unwrapped column is not a ColumnArray over a contiguous
/// ColumnVector<Float32>.
template <bool is_nullable>
void flashArrayFloat32ColToArrowCol(
    TiDBColumn & dag_column,
    const IColumn * flash_col_untyped,
    size_t start_index,
    size_t end_index)
{
    // We only unwrap the NULLABLE() part.
    const IColumn * nested_col = getNestedCol(flash_col_untyped);
    const auto * flash_col = checkAndGetColumn<ColumnArray>(nested_col);
    // The cast result is used as a pointer below; reject a failed cast instead of dereferencing it.
    RUNTIME_CHECK(flash_col != nullptr);

    RUNTIME_CHECK(checkAndGetColumn<ColumnVector<Float32>>(&flash_col->getData()));
    RUNTIME_CHECK(flash_col->getData().isFixedAndContiguous());

    for (size_t i = start_index; i < end_index; i++)
    {
        // todo check if we can convert flash_col to DAG col directly since the internal representation is almost the same
        if constexpr (is_nullable)
        {
            // Nullness is tracked by the outer (untyped) column, not by the unwrapped array.
            if (flash_col_untyped->isNullAt(i))
            {
                dag_column.appendNull();
                continue;
            }
        }

        auto [num_elems, elem_bytes] = flash_col->getElementRef(i);
        dag_column.appendVectorF32(num_elems, elem_bytes);
    }
}

template <bool is_nullable>
void flashBitColToArrowCol(
TiDBColumn & dag_column,
Expand Down Expand Up @@ -465,6 +498,20 @@ void flashColToArrowCol(
else
flashStringColToArrowCol<true>(dag_column, col, start_index, end_index);
break;
case TiDB::TypeTiDBVectorFloat32:
{
const auto * data_type = checkAndGetDataType<DataTypeArray>(type);
if (!data_type || data_type->getNestedType()->getTypeId() != TypeIndex::Float32)
throw TiFlashException(
Errors::Coprocessor::Internal,
"Type un-matched during arrow encode, target col type is array<float32> and source column type is {}",
type->getName());
if (tidb_column_info.hasNotNullFlag())
flashArrayFloat32ColToArrowCol<false>(dag_column, col, start_index, end_index);
else
flashArrayFloat32ColToArrowCol<true>(dag_column, col, start_index, end_index);
break;
}
case TiDB::TypeBit:
if (!checkDataType<DataTypeUInt64>(type))
throw TiFlashException(
Expand Down Expand Up @@ -529,6 +576,35 @@ const char * arrowStringColToFlashCol(
return pos + offsets[length];
}

const char * arrowArrayFloat32ColToFlashCol(
const char * pos,
UInt8,
UInt32 null_count,
const std::vector<UInt8> & null_bitmap,
const std::vector<UInt64> & offsets,
const ColumnWithTypeAndName & col,
const ColumnInfo &,
UInt32 length)
{
const auto * data_type = checkAndGetDataType<DataTypeArray>(&*col.type);
if (!data_type || data_type->getNestedType()->getTypeId() != TypeIndex::Float32)
throw TiFlashException(
Errors::Coprocessor::Internal,
"Type un-matched during arrow decode, target col type is array<float32> and source column type is {}",
col.type->getName());

for (UInt32 i = 0; i < length; i++)
{
if (checkNull(i, null_count, null_bitmap, col))
continue;

auto arrow_data_size = offsets[i + 1] - offsets[i];
const auto * base_offset = pos + offsets[i];
col.column->assumeMutable()->insertFromDatumData(base_offset, arrow_data_size);
}
return pos + offsets[length];
}

const char * arrowEnumColToFlashCol(
const char * pos,
UInt8,
Expand Down Expand Up @@ -823,6 +899,16 @@ const char * arrowColToFlashCol(
length);
case TiDB::TypeBit:
return arrowBitColToFlashCol(pos, field_length, null_count, null_bitmap, offsets, flash_col, col_info, length);
case TiDB::TypeTiDBVectorFloat32:
return arrowArrayFloat32ColToFlashCol(
pos,
field_length,
null_count,
null_bitmap,
offsets,
flash_col,
col_info,
length);
case TiDB::TypeEnum:
return arrowEnumColToFlashCol(pos, field_length, null_count, null_bitmap, offsets, flash_col, col_info, length);
default:
Expand Down
11 changes: 11 additions & 0 deletions dbms/src/Flash/Coprocessor/DAGCodec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ void encodeDAGDecimal(const Field & field, WriteBuffer & ss)
EncodeDecimal(field, ss);
}

/// Writes `v` to `ss` by forwarding to EncodeVectorFloat32; the byte layout is defined by that helper.
void encodeDAGVectorFloat32(const Array & v, WriteBuffer & ss)
{
EncodeVectorFloat32(v, ss);
}

Int64 decodeDAGInt64(const String & s)
{
auto u = *(reinterpret_cast<const UInt64 *>(s.data()));
Expand Down Expand Up @@ -93,4 +98,10 @@ Field decodeDAGDecimal(const String & s)
return DecodeDecimal(cursor, s);
}

/// Decodes a value from the beginning of `s` with DecodeVectorFloat32 and returns it as a Field.
/// The read position that the decoder advances is local and discarded afterwards.
Field decodeDAGVectorFloat32(const String & s)
{
    size_t read_pos{0};
    return DecodeVectorFloat32(read_pos, s);
}

} // namespace DB
2 changes: 2 additions & 0 deletions dbms/src/Flash/Coprocessor/DAGCodec.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ void encodeDAGFloat64(Float64, WriteBuffer &);
void encodeDAGString(const String &, WriteBuffer &);
void encodeDAGBytes(const String &, WriteBuffer &);
void encodeDAGDecimal(const Field &, WriteBuffer &);
void encodeDAGVectorFloat32(const Array &, WriteBuffer &);

Int64 decodeDAGInt64(const String &);
UInt64 decodeDAGUInt64(const String &);
Expand All @@ -34,5 +35,6 @@ Float64 decodeDAGFloat64(const String &);
String decodeDAGString(const String &);
String decodeDAGBytes(const String &);
Field decodeDAGDecimal(const String &);
Field decodeDAGVectorFloat32(const String &);

} // namespace DB
Loading

0 comments on commit 0d0463a

Please sign in to comment.