From 976928ecd1391092a0857cc4959052bf5d5837a3 Mon Sep 17 00:00:00 2001 From: cqy123456 <39671710+cqy123456@users.noreply.github.com> Date: Thu, 28 Mar 2024 01:11:10 -0500 Subject: [PATCH] fix: fix fp16/bf16 some code missing and add more fp16/bf16 test (#31612) issue: #31534 Signed-off-by: cqy123456 --- go.mod | 1 + go.sum | 2 + internal/querynodev2/segments/mock_data.go | 40 +++++++++---- internal/querynodev2/segments/utils.go | 26 +++++++++ internal/storage/print_binlog.go | 31 ++++++++++ internal/storage/print_binlog_test.go | 30 ++++++++++ internal/storage/utils.go | 26 +++++++++ internal/storage/utils_test.go | 9 +++ internal/util/typeutil/schema.go | 9 +++ internal/util/typeutil/schema_test.go | 24 ++++++++ internal/util/typeutil/storage.go | 12 ++++ pkg/util/funcutil/func.go | 2 +- pkg/util/typeutil/schema.go | 26 +++++++++ pkg/util/typeutil/schema_test.go | 68 ++++++++++++++++++++++ 14 files changed, 293 insertions(+), 13 deletions(-) diff --git a/go.mod b/go.mod index f1856164d66f0..aaebe721b67b4 100644 --- a/go.mod +++ b/go.mod @@ -200,6 +200,7 @@ require ( github.com/twmb/murmur3 v1.1.3 // indirect github.com/uber/jaeger-client-go v2.30.0+incompatible // indirect github.com/ugorji/go/codec v1.2.11 // indirect + github.com/x448/float16 v0.8.4 // indirect github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2 // indirect github.com/yusufpapurcu/wmi v1.2.2 // indirect github.com/zeebo/xxh3 v1.0.2 // indirect diff --git a/go.sum b/go.sum index 935c4a60fe06c..07a53c20241b4 100644 --- a/go.sum +++ b/go.sum @@ -868,6 +868,8 @@ github.com/valyala/fasthttp v1.6.0/go.mod h1:FstJa9V+Pj9vQ7OJie2qMHdwemEDaDiSdBn github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8= github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a/go.mod h1:v3UYOV9WzVtRmSR+PDvWpU/qWl4Wa5LApYYX4ZtKbio= +github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= +github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= diff --git a/internal/querynodev2/segments/mock_data.go b/internal/querynodev2/segments/mock_data.go index 41658e67ff636..aa6fcfd6b6a55 100644 --- a/internal/querynodev2/segments/mock_data.go +++ b/internal/querynodev2/segments/mock_data.go @@ -28,6 +28,7 @@ import ( "github.com/cockroachdb/errors" "github.com/golang/protobuf/proto" + "github.com/x448/float16" "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" @@ -278,6 +279,8 @@ func GenTestCollectionSchema(collectionName string, pkType schemapb.DataType, wi fieldArray := genConstantFieldSchema(simpleArrayField) floatVecFieldSchema := genVectorFieldSchema(simpleFloatVecField) binVecFieldSchema := genVectorFieldSchema(simpleBinVecField) + float16VecFieldSchema := genVectorFieldSchema(simpleFloat16VecField) + bfloat16VecFieldSchema := genVectorFieldSchema(simpleBFloat16VecField) var pkFieldSchema *schemapb.FieldSchema switch pkType { @@ -302,6 +305,8 @@ func GenTestCollectionSchema(collectionName string, pkType schemapb.DataType, wi binVecFieldSchema, pkFieldSchema, fieldArray, + float16VecFieldSchema, + bfloat16VecFieldSchema, }, } @@ -330,7 +335,7 @@ func GenTestIndexInfoList(collectionID int64, schema *schemapb.CollectionSchema) TypeParams: field.GetTypeParams(), } switch field.GetDataType() { - case schemapb.DataType_FloatVector, schemapb.DataType_Float16Vector: + case schemapb.DataType_FloatVector, schemapb.DataType_Float16Vector, schemapb.DataType_BFloat16Vector: { index.IndexParams = []*commonpb.KeyValuePair{ {Key: common.MetricTypeKey, Value: metric.L2}, @@ -500,21 +505,28 @@ func generateBinaryVectors(numRows, dim int) []byte { } func generateFloat16Vectors(numRows, dim int) []byte { - total := numRows * dim * 2 - ret := make([]byte, total) - _, err := rand.Read(ret) - if err != nil { - panic(err) + total := numRows * dim + ret := make([]byte, total*2) + for i := 0; i < total; i++ { + v := float16.Fromfloat32(rand.Float32()).Bits() + binary.LittleEndian.PutUint16(ret[i*2:], v) } return ret } func generateBFloat16Vectors(numRows, dim int) []byte { - total := numRows * dim * 2 - ret := make([]byte, total) - _, err := rand.Read(ret) - if err != nil { - panic(err) + total := numRows * dim + ret16 := make([]uint16, 0, total) + for i := 0; i < total; i++ { + f := rand.Float32() + bits := math.Float32bits(f) + bits >>= 16 + bits &= 0x7FFF + ret16 = append(ret16, uint16(bits)) + } + ret := make([]byte, len(ret16)*2) + for i, value := range ret16 { + binary.LittleEndian.PutUint16(ret[i*2:], value) } return ret } @@ -1009,6 +1021,10 @@ func GenAndSaveIndexV2(collectionID, partitionID, segmentID, buildID int64, dataset = indexcgowrapper.GenBinaryVecDataset(generateBinaryVectors(msgLength, defaultDim)) case schemapb.DataType_FloatVector: dataset = indexcgowrapper.GenFloatVecDataset(generateFloatVectors(msgLength, defaultDim)) + case schemapb.DataType_Float16Vector: + dataset = indexcgowrapper.GenFloat16VecDataset(generateFloat16Vectors(msgLength, defaultDim)) + case schemapb.DataType_BFloat16Vector: + dataset = indexcgowrapper.GenBFloat16VecDataset(generateBFloat16Vectors(msgLength, defaultDim)) case schemapb.DataType_SparseFloatVector: data := testutils.GenerateSparseFloatVectors(msgLength) dataset = indexcgowrapper.GenSparseFloatVecDataset(&storage.SparseFloatVectorFieldData{ @@ -1260,7 +1276,7 @@ func genBruteForceDSL(schema *schemapb.CollectionSchema, topK int64, roundDecima roundDecimalStr := strconv.FormatInt(roundDecimal, 10) var fieldID int64 for _, f := range schema.Fields { - if f.DataType == schemapb.DataType_FloatVector || f.DataType == schemapb.DataType_Float16Vector || f.DataType == schemapb.DataType_BFloat16Vector { + if f.DataType == schemapb.DataType_FloatVector { vecFieldName = f.Name fieldID = f.FieldID for _, p := range f.IndexParams { diff --git a/internal/querynodev2/segments/utils.go b/internal/querynodev2/segments/utils.go index 640d1a4fbcfed..6bd8d58ec21cb 100644 --- a/internal/querynodev2/segments/utils.go +++ b/internal/querynodev2/segments/utils.go @@ -93,6 +93,28 @@ func getPKsFromRowBasedInsertMsg(msg *msgstream.InsertMsg, schema *schemapb.Coll break } } + case schemapb.DataType_Float16Vector: + for _, t := range field.TypeParams { + if t.Key == common.DimKey { + dim, err := strconv.Atoi(t.Value) + if err != nil { + return nil, fmt.Errorf("strconv wrong on get dim, err = %s", err) + } + offset += dim * 2 + break + } + } + case schemapb.DataType_BFloat16Vector: + for _, t := range field.TypeParams { + if t.Key == common.DimKey { + dim, err := strconv.Atoi(t.Value) + if err != nil { + return nil, fmt.Errorf("strconv wrong on get dim, err = %s", err) + } + offset += dim * 2 + break + } + } case schemapb.DataType_SparseFloatVector: return nil, fmt.Errorf("SparseFloatVector not support in row based message") } @@ -280,6 +302,10 @@ func fillFieldData(ctx context.Context, vcm storage.ChunkManager, dataPath strin return fillBinVecFieldData(ctx, vcm, dataPath, fieldData, i, offset, endian) case schemapb.DataType_FloatVector: return fillFloatVecFieldData(ctx, vcm, dataPath, fieldData, i, offset, endian) + case schemapb.DataType_Float16Vector: + return fillFloatVecFieldData(ctx, vcm, dataPath, fieldData, i, offset, endian) + case schemapb.DataType_BFloat16Vector: + return fillFloatVecFieldData(ctx, vcm, dataPath, fieldData, i, offset, endian) case schemapb.DataType_SparseFloatVector: return fillSparseFloatVecFieldData(ctx, vcm, dataPath, fieldData, i, offset, endian) case schemapb.DataType_Bool: diff --git a/internal/storage/print_binlog.go b/internal/storage/print_binlog.go index c27cc4715a213..61cb03cb6d8ef 100644 --- a/internal/storage/print_binlog.go +++ b/internal/storage/print_binlog.go @@ -307,6 +307,37 @@ func printPayloadValues(colType schemapb.DataType, reader PayloadReaderInterface } fmt.Println() } + case schemapb.DataType_Float16Vector: + val, dim, err := reader.GetFloat16VectorFromPayload() + if err != nil { + return err + } + dim = dim * 2 + length := len(val) / dim + for i := 0; i < length; i++ { + fmt.Printf("\t\t%d :", i) + for j := 0; j < dim; j++ { + idx := i*dim + j + fmt.Printf(" %02x", val[idx]) + } + fmt.Println() + } + case schemapb.DataType_BFloat16Vector: + val, dim, err := reader.GetBFloat16VectorFromPayload() + if err != nil { + return err + } + dim = dim * 2 + length := len(val) / dim + for i := 0; i < length; i++ { + fmt.Printf("\t\t%d :", i) + for j := 0; j < dim; j++ { + idx := i*dim + j + fmt.Printf(" %02x", val[idx]) + } + fmt.Println() + } + case schemapb.DataType_FloatVector: val, dim, err := reader.GetFloatVectorFromPayload() if err != nil { diff --git a/internal/storage/print_binlog_test.go b/internal/storage/print_binlog_test.go index 090a90a6fc38c..127b39e403c70 100644 --- a/internal/storage/print_binlog_test.go +++ b/internal/storage/print_binlog_test.go @@ -184,6 +184,20 @@ func TestPrintBinlogFiles(t *testing.T) { Description: "description_12", DataType: schemapb.DataType_JSON, }, + { + FieldID: 111, + Name: "field_bfloat16_vector", + IsPrimaryKey: false, + Description: "description_13", + DataType: schemapb.DataType_BFloat16Vector, + }, + { + FieldID: 112, + Name: "field_float16_vector", + IsPrimaryKey: false, + Description: "description_14", + DataType: schemapb.DataType_Float16Vector, + }, }, }, } @@ -234,6 +248,14 @@ func TestPrintBinlogFiles(t *testing.T) { []byte(`{"key":"hello"}`), }, }, + 111: &BFloat16VectorFieldData{ + Data: []byte("12345678"), + Dim: 4, + }, + 112: &Float16VectorFieldData{ + Data: []byte("12345678"), + Dim: 4, + }, }, } @@ -283,6 +305,14 @@ func TestPrintBinlogFiles(t *testing.T) { []byte(`{"key":"world"}`), }, }, + 111: &BFloat16VectorFieldData{ + Data: []byte("abcdefgh"), + Dim: 4, + }, + 112: &Float16VectorFieldData{ + Data: []byte("abcdefgh"), + Dim: 4, + }, }, } firstBlobs, err := insertCodec.Serialize(1, 1, insertDataFirst) diff --git a/internal/storage/utils.go b/internal/storage/utils.go index 991dae414c062..f3a90a88726e2 100644 --- a/internal/storage/utils.go +++ b/internal/storage/utils.go @@ -1201,6 +1201,32 @@ func TransferInsertDataToInsertRecord(insertData *InsertData) (*segcorepb.Insert }, }, } + case *Float16VectorFieldData: + fieldData = &schemapb.FieldData{ + Type: schemapb.DataType_Float16Vector, + FieldId: fieldID, + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Data: &schemapb.VectorField_Float16Vector{ + Float16Vector: rawData.Data, + }, + Dim: int64(rawData.Dim), + }, + }, + } + case *BFloat16VectorFieldData: + fieldData = &schemapb.FieldData{ + Type: schemapb.DataType_BFloat16Vector, + FieldId: fieldID, + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Data: &schemapb.VectorField_Bfloat16Vector{ + Bfloat16Vector: rawData.Data, + }, + Dim: int64(rawData.Dim), + }, + }, + } case *SparseFloatVectorFieldData: fieldData = &schemapb.FieldData{ Type: schemapb.DataType_SparseFloatVector, diff --git a/internal/storage/utils_test.go b/internal/storage/utils_test.go index a08c15bdc25fc..80872d17c2f7b 100644 --- a/internal/storage/utils_test.go +++ b/internal/storage/utils_test.go @@ -993,6 +993,15 @@ func TestRowBasedInsertMsgToInsertData(t *testing.T) { } } +func TestRowBasedTransferInsertMsgToInsertRecord(t *testing.T) { + numRows, fVecDim, bVecDim, f16VecDim, bf16VecDim := 10, 8, 8, 8, 8 + schema, _, _ := genAllFieldsSchema(fVecDim, bVecDim, f16VecDim, bf16VecDim, false) + msg, _, _ := genRowBasedInsertMsg(numRows, fVecDim, bVecDim, f16VecDim, bf16VecDim) + + _, err := TransferInsertMsgToInsertRecord(schema, msg) + assert.NoError(t, err) +} + func TestRowBasedInsertMsgToInsertFloat16VectorDataError(t *testing.T) { msg := &msgstream.InsertMsg{ BaseMsg: msgstream.BaseMsg{ diff --git a/internal/util/typeutil/schema.go b/internal/util/typeutil/schema.go index 3140b140b0548..399818410ab0d 100644 --- a/internal/util/typeutil/schema.go +++ b/internal/util/typeutil/schema.go @@ -89,6 +89,15 @@ func ConvertToArrowSchema(fields []*schemapb.FieldSchema) (*arrow.Schema, error) Name: field.Name, Type: &arrow.FixedSizeBinaryType{ByteWidth: dim * 2}, }) + case schemapb.DataType_BFloat16Vector: + dim, err := storage.GetDimFromParams(field.TypeParams) + if err != nil { + return nil, err + } + arrowFields = append(arrowFields, arrow.Field{ + Name: field.Name, + Type: &arrow.FixedSizeBinaryType{ByteWidth: dim * 2}, + }) default: return nil, merr.WrapErrParameterInvalidMsg("unknown type %v", field.DataType.String()) } diff --git a/internal/util/typeutil/schema_test.go b/internal/util/typeutil/schema_test.go index 93aca85738c22..9450c52b8424f 100644 --- a/internal/util/typeutil/schema_test.go +++ b/internal/util/typeutil/schema_test.go @@ -41,9 +41,33 @@ func TestConvertArrowSchema(t *testing.T) { {FieldID: 12, Name: "field11", DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64}, {FieldID: 13, Name: "field12", DataType: schemapb.DataType_JSON}, {FieldID: 14, Name: "field13", DataType: schemapb.DataType_Float16Vector, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}}, + {FieldID: 15, Name: "field14", DataType: schemapb.DataType_BFloat16Vector, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}}, } schema, err := ConvertToArrowSchema(fieldSchemas) assert.NoError(t, err) assert.Equal(t, len(fieldSchemas), len(schema.Fields())) } + +func TestConvertArrowSchemaWithoutDim(t *testing.T) { + fieldSchemas := []*schemapb.FieldSchema{ + {FieldID: 1, Name: "field0", DataType: schemapb.DataType_Bool}, + {FieldID: 2, Name: "field1", DataType: schemapb.DataType_Int8}, + {FieldID: 3, Name: "field2", DataType: schemapb.DataType_Int16}, + {FieldID: 4, Name: "field3", DataType: schemapb.DataType_Int32}, + {FieldID: 5, Name: "field4", DataType: schemapb.DataType_Int64}, + {FieldID: 6, Name: "field5", DataType: schemapb.DataType_Float}, + {FieldID: 7, Name: "field6", DataType: schemapb.DataType_Double}, + {FieldID: 8, Name: "field7", DataType: schemapb.DataType_String}, + {FieldID: 9, Name: "field8", DataType: schemapb.DataType_VarChar}, + {FieldID: 10, Name: "field9", DataType: schemapb.DataType_BinaryVector, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}}, + {FieldID: 11, Name: "field10", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}}, + {FieldID: 12, Name: "field11", DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64}, + {FieldID: 13, Name: "field12", DataType: schemapb.DataType_JSON}, + {FieldID: 14, Name: "field13", DataType: schemapb.DataType_Float16Vector, TypeParams: []*commonpb.KeyValuePair{}}, + {FieldID: 15, Name: "field14", DataType: schemapb.DataType_BFloat16Vector, TypeParams: []*commonpb.KeyValuePair{}}, + } + + _, err := ConvertToArrowSchema(fieldSchemas) + assert.Error(t, err) +} diff --git a/internal/util/typeutil/storage.go b/internal/util/typeutil/storage.go index 443b1f93c3291..6e3b44845e20b 100644 --- a/internal/util/typeutil/storage.go +++ b/internal/util/typeutil/storage.go @@ -108,6 +108,18 @@ func BuildRecord(b *array.RecordBuilder, data *storage.InsertData, fields []*sch byteLength := dim * 2 length := len(data) / byteLength + builder.Reserve(length) + for i := 0; i < length; i++ { + builder.Append(data[i*byteLength : (i+1)*byteLength]) + } + case schemapb.DataType_BFloat16Vector: + vecData := data.Data[field.FieldID].(*storage.BFloat16VectorFieldData) + builder := fBuilder.(*array.FixedSizeBinaryBuilder) + dim := vecData.Dim + data := vecData.Data + byteLength := dim * 2 + length := len(data) / byteLength + builder.Reserve(length) for i := 0; i < length; i++ { builder.Append(data[i*byteLength : (i+1)*byteLength]) diff --git a/pkg/util/funcutil/func.go b/pkg/util/funcutil/func.go index 71e67a55c34a4..231f59c18866f 100644 --- a/pkg/util/funcutil/func.go +++ b/pkg/util/funcutil/func.go @@ -146,7 +146,7 @@ func CheckCtxValid(ctx context.Context) bool { func GetVecFieldIDs(schema *schemapb.CollectionSchema) []int64 { var vecFieldIDs []int64 for _, field := range schema.Fields { - if field.DataType == schemapb.DataType_BinaryVector || field.DataType == schemapb.DataType_FloatVector || field.DataType == schemapb.DataType_Float16Vector || field.DataType == schemapb.DataType_SparseFloatVector { + if field.DataType == schemapb.DataType_BinaryVector || field.DataType == schemapb.DataType_FloatVector || field.DataType == schemapb.DataType_Float16Vector || field.DataType == schemapb.DataType_BFloat16Vector || field.DataType == schemapb.DataType_SparseFloatVector { vecFieldIDs = append(vecFieldIDs, field.FieldID) } } diff --git a/pkg/util/typeutil/schema.go b/pkg/util/typeutil/schema.go index a0937342cd286..c2d3bf05badd3 100644 --- a/pkg/util/typeutil/schema.go +++ b/pkg/util/typeutil/schema.go @@ -241,6 +241,10 @@ func EstimateEntitySize(fieldsData []*schemapb.FieldData, rowOffset int) (int, e res += int(fs.GetVectors().GetDim()) case schemapb.DataType_FloatVector: res += int(fs.GetVectors().GetDim() * 4) + case schemapb.DataType_Float16Vector: + res += int(fs.GetVectors().GetDim() * 2) + case schemapb.DataType_BFloat16Vector: + res += int(fs.GetVectors().GetDim() * 2) case schemapb.DataType_SparseFloatVector: vec := fs.GetVectors().GetSparseFloatVector() // counting only the size of the vector data, ignoring other @@ -527,6 +531,10 @@ func PrepareResultFieldData(sample []*schemapb.FieldData, topK int64) []*schemap vectors.Vectors.Data = &schemapb.VectorField_Float16Vector{ Float16Vector: make([]byte, 0, topK*dim*2), } + case *schemapb.VectorField_Bfloat16Vector: + vectors.Vectors.Data = &schemapb.VectorField_Bfloat16Vector{ + Bfloat16Vector: make([]byte, 0, topK*dim*2), + } case *schemapb.VectorField_BinaryVector: vectors.Vectors.Data = &schemapb.VectorField_BinaryVector{ BinaryVector: make([]byte, 0, topK*dim/8), @@ -957,6 +965,24 @@ func MergeFieldData(dst []*schemapb.FieldData, src []*schemapb.FieldData) error dstBinaryVector := dstVector.Data.(*schemapb.VectorField_BinaryVector) dstBinaryVector.BinaryVector = append(dstBinaryVector.BinaryVector, srcVector.BinaryVector...) } + case *schemapb.VectorField_Float16Vector: + if dstVector.GetFloat16Vector() == nil { + dstVector.Data = &schemapb.VectorField_Float16Vector{ + Float16Vector: srcVector.Float16Vector, + } + } else { + dstFloat16Vector := dstVector.Data.(*schemapb.VectorField_Float16Vector) + dstFloat16Vector.Float16Vector = append(dstFloat16Vector.Float16Vector, srcVector.Float16Vector...) + } + case *schemapb.VectorField_Bfloat16Vector: + if dstVector.GetBfloat16Vector() == nil { + dstVector.Data = &schemapb.VectorField_Bfloat16Vector{ + Bfloat16Vector: srcVector.Bfloat16Vector, + } + } else { + dstBfloat16Vector := dstVector.Data.(*schemapb.VectorField_Bfloat16Vector) + dstBfloat16Vector.Bfloat16Vector = append(dstBfloat16Vector.Bfloat16Vector, srcVector.Bfloat16Vector...) + } case *schemapb.VectorField_FloatVector: if dstVector.GetFloatVector() == nil { dstVector.Data = &schemapb.VectorField_FloatVector{ diff --git a/pkg/util/typeutil/schema_test.go b/pkg/util/typeutil/schema_test.go index 3835fa9317480..cb743d873b14a 100644 --- a/pkg/util/typeutil/schema_test.go +++ b/pkg/util/typeutil/schema_test.go @@ -984,6 +984,36 @@ func TestDeleteFieldData(t *testing.T) { assert.Equal(t, tmpSparseFloatVector, result2[SparseFloatVectorFieldID-common.StartOfUserFieldID].GetVectors().GetSparseFloatVector()) } +func TestEstimateEntitySize(t *testing.T) { + samples := []*schemapb.FieldData{ + { + FieldId: 111, + FieldName: "float16_vector", + Type: schemapb.DataType_Float16Vector, + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: 64, + Data: &schemapb.VectorField_Float16Vector{}, + }, + }, + }, + { + FieldId: 112, + FieldName: "bfloat16_vector", + Type: schemapb.DataType_BFloat16Vector, + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: 128, + Data: &schemapb.VectorField_Bfloat16Vector{}, + }, + }, + }, + } + size, error := EstimateEntitySize(samples, int(0)) + assert.NoError(t, error) + assert.True(t, size == 384) +} + func TestGetPrimaryFieldSchema(t *testing.T) { int64Field := &schemapb.FieldSchema{ FieldID: 1, @@ -1461,6 +1491,8 @@ func TestMergeFieldData(t *testing.T) { }, FieldId: 106, }, + genFieldData("float16_vector", 111, schemapb.DataType_Float16Vector, []byte("12345678"), 4), + genFieldData("bfloat16_vector", 112, schemapb.DataType_BFloat16Vector, []byte("12345678"), 4), } srcFields := []*schemapb.FieldData{ @@ -1520,6 +1552,8 @@ func TestMergeFieldData(t *testing.T) { }, FieldId: 106, }, + genFieldData("float16_vector", 111, schemapb.DataType_Float16Vector, []byte("abcdefgh"), 4), + genFieldData("bfloat16_vector", 112, schemapb.DataType_BFloat16Vector, []byte("ABCDEFGH"), 4), } err := MergeFieldData(dstFields, srcFields) @@ -1552,6 +1586,8 @@ func TestMergeFieldData(t *testing.T) { Dim: 2301, Contents: sparseFloatRows, }, dstFields[6].GetVectors().GetSparseFloatVector()) + assert.Equal(t, []byte("12345678abcdefgh"), dstFields[7].GetVectors().GetFloat16Vector()) + assert.Equal(t, []byte("12345678ABCDEFGH"), dstFields[8].GetVectors().GetBfloat16Vector()) }) t.Run("merge with nil", func(t *testing.T) { @@ -1584,6 +1620,8 @@ func TestMergeFieldData(t *testing.T) { }, FieldId: 104, }, + genFieldData("float16_vector", 111, schemapb.DataType_Float16Vector, []byte("12345678"), 4), + genFieldData("bfloat16_vector", 112, schemapb.DataType_BFloat16Vector, []byte("12345678"), 4), } dstFields := []*schemapb.FieldData{ @@ -1592,6 +1630,8 @@ func TestMergeFieldData(t *testing.T) { {Type: schemapb.DataType_JSON, FieldName: "json", Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_JsonData{}}}, FieldId: 102}, {Type: schemapb.DataType_Array, FieldName: "array", Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_ArrayData{}}}, FieldId: 103}, {Type: schemapb.DataType_SparseFloatVector, FieldName: "sparseFloat", Field: &schemapb.FieldData_Vectors{Vectors: &schemapb.VectorField{Data: &schemapb.VectorField_SparseFloatVector{}}}, FieldId: 104}, + {Type: schemapb.DataType_Float16Vector, FieldName: "float16_vector", Field: &schemapb.FieldData_Vectors{Vectors: &schemapb.VectorField{Data: &schemapb.VectorField_Float16Vector{}}}, FieldId: 111}, + {Type: schemapb.DataType_BFloat16Vector, FieldName: "bfloat16_vector", Field: &schemapb.FieldData_Vectors{Vectors: &schemapb.VectorField{Data: &schemapb.VectorField_Bfloat16Vector{}}}, FieldId: 112}, } err := MergeFieldData(dstFields, srcFields) @@ -1615,6 +1655,8 @@ func TestMergeFieldData(t *testing.T) { Dim: 521, Contents: sparseFloatRows[:3], }, dstFields[4].GetVectors().GetSparseFloatVector()) + assert.Equal(t, []byte("12345678"), dstFields[5].GetVectors().GetFloat16Vector()) + assert.Equal(t, []byte("12345678"), dstFields[6].GetVectors().GetBfloat16Vector()) }) t.Run("error case", func(t *testing.T) { @@ -1903,6 +1945,32 @@ func (s *FieldDataSuite) TestPrepareFieldData() { s.EqualValues(topK*128*2, cap(field.GetVectors().GetFloat16Vector())) }) + s.Run("bfloat16_vector", func() { + samples := []*schemapb.FieldData{ + { + FieldId: fieldID, + FieldName: fieldName, + Type: schemapb.DataType_BFloat16Vector, + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: 128, + Data: &schemapb.VectorField_Bfloat16Vector{}, + }, + }, + }, + } + + fields := PrepareResultFieldData(samples, topK) + s.Require().Len(fields, 1) + field := fields[0] + s.Equal(fieldID, field.GetFieldId()) + s.Equal(fieldName, field.GetFieldName()) + s.Equal(schemapb.DataType_BFloat16Vector, field.GetType()) + + s.EqualValues(128, field.GetVectors().GetDim()) + s.EqualValues(topK*128*2, cap(field.GetVectors().GetBfloat16Vector())) + }) + s.Run("binary_vector", func() { samples := []*schemapb.FieldData{ {