Skip to content

Commit

Permalink
fix: fix fp16/bf16 some code missing and add more fp16/bf16 test (#31612
Browse files Browse the repository at this point in the history
)

issue: #31534

Signed-off-by: cqy123456 <[email protected]>
  • Loading branch information
cqy123456 authored Mar 28, 2024
1 parent 3d66670 commit 976928e
Show file tree
Hide file tree
Showing 14 changed files with 293 additions and 13 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ require (
github.com/twmb/murmur3 v1.1.3 // indirect
github.com/uber/jaeger-client-go v2.30.0+incompatible // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
github.com/x448/float16 v0.8.4 // indirect
github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2 // indirect
github.com/yusufpapurcu/wmi v1.2.2 // indirect
github.com/zeebo/xxh3 v1.0.2 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -868,6 +868,8 @@ github.com/valyala/fasthttp v1.6.0/go.mod h1:FstJa9V+Pj9vQ7OJie2qMHdwemEDaDiSdBn
github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8=
github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ=
github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a/go.mod h1:v3UYOV9WzVtRmSR+PDvWpU/qWl4Wa5LApYYX4ZtKbio=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU=
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ=
github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y=
Expand Down
40 changes: 28 additions & 12 deletions internal/querynodev2/segments/mock_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (

"github.com/cockroachdb/errors"
"github.com/golang/protobuf/proto"
"github.com/x448/float16"
"go.uber.org/zap"

"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
Expand Down Expand Up @@ -278,6 +279,8 @@ func GenTestCollectionSchema(collectionName string, pkType schemapb.DataType, wi
fieldArray := genConstantFieldSchema(simpleArrayField)
floatVecFieldSchema := genVectorFieldSchema(simpleFloatVecField)
binVecFieldSchema := genVectorFieldSchema(simpleBinVecField)
float16VecFieldSchema := genVectorFieldSchema(simpleFloat16VecField)
bfloat16VecFieldSchema := genVectorFieldSchema(simpleBFloat16VecField)
var pkFieldSchema *schemapb.FieldSchema

switch pkType {
Expand All @@ -302,6 +305,8 @@ func GenTestCollectionSchema(collectionName string, pkType schemapb.DataType, wi
binVecFieldSchema,
pkFieldSchema,
fieldArray,
float16VecFieldSchema,
bfloat16VecFieldSchema,
},
}

Expand Down Expand Up @@ -330,7 +335,7 @@ func GenTestIndexInfoList(collectionID int64, schema *schemapb.CollectionSchema)
TypeParams: field.GetTypeParams(),
}
switch field.GetDataType() {
case schemapb.DataType_FloatVector, schemapb.DataType_Float16Vector:
case schemapb.DataType_FloatVector, schemapb.DataType_Float16Vector, schemapb.DataType_BFloat16Vector:
{
index.IndexParams = []*commonpb.KeyValuePair{
{Key: common.MetricTypeKey, Value: metric.L2},
Expand Down Expand Up @@ -500,21 +505,28 @@ func generateBinaryVectors(numRows, dim int) []byte {
}

func generateFloat16Vectors(numRows, dim int) []byte {
total := numRows * dim * 2
ret := make([]byte, total)
_, err := rand.Read(ret)
if err != nil {
panic(err)
total := numRows * dim
ret := make([]byte, total*2)
for i := 0; i < total; i++ {
v := float16.Fromfloat32(rand.Float32()).Bits()
binary.LittleEndian.PutUint16(ret[i*2:], v)
}
return ret
}

func generateBFloat16Vectors(numRows, dim int) []byte {
total := numRows * dim * 2
ret := make([]byte, total)
_, err := rand.Read(ret)
if err != nil {
panic(err)
total := numRows * dim
ret16 := make([]uint16, 0, total)
for i := 0; i < total; i++ {
f := rand.Float32()
bits := math.Float32bits(f)
bits >>= 16
bits &= 0x7FFF
ret16 = append(ret16, uint16(bits))
}
ret := make([]byte, len(ret16)*2)
for i, value := range ret16 {
binary.LittleEndian.PutUint16(ret[i*2:], value)
}
return ret
}
Expand Down Expand Up @@ -1009,6 +1021,10 @@ func GenAndSaveIndexV2(collectionID, partitionID, segmentID, buildID int64,
dataset = indexcgowrapper.GenBinaryVecDataset(generateBinaryVectors(msgLength, defaultDim))
case schemapb.DataType_FloatVector:
dataset = indexcgowrapper.GenFloatVecDataset(generateFloatVectors(msgLength, defaultDim))
case schemapb.DataType_Float16Vector:
dataset = indexcgowrapper.GenFloat16VecDataset(generateFloat16Vectors(msgLength, defaultDim))
case schemapb.DataType_BFloat16Vector:
dataset = indexcgowrapper.GenBFloat16VecDataset(generateBFloat16Vectors(msgLength, defaultDim))
case schemapb.DataType_SparseFloatVector:
data := testutils.GenerateSparseFloatVectors(msgLength)
dataset = indexcgowrapper.GenSparseFloatVecDataset(&storage.SparseFloatVectorFieldData{
Expand Down Expand Up @@ -1260,7 +1276,7 @@ func genBruteForceDSL(schema *schemapb.CollectionSchema, topK int64, roundDecima
roundDecimalStr := strconv.FormatInt(roundDecimal, 10)
var fieldID int64
for _, f := range schema.Fields {
if f.DataType == schemapb.DataType_FloatVector || f.DataType == schemapb.DataType_Float16Vector || f.DataType == schemapb.DataType_BFloat16Vector {
if f.DataType == schemapb.DataType_FloatVector {
vecFieldName = f.Name
fieldID = f.FieldID
for _, p := range f.IndexParams {
Expand Down
26 changes: 26 additions & 0 deletions internal/querynodev2/segments/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,28 @@ func getPKsFromRowBasedInsertMsg(msg *msgstream.InsertMsg, schema *schemapb.Coll
break
}
}
case schemapb.DataType_Float16Vector:
for _, t := range field.TypeParams {
if t.Key == common.DimKey {
dim, err := strconv.Atoi(t.Value)
if err != nil {
return nil, fmt.Errorf("strconv wrong on get dim, err = %s", err)
}
offset += dim * 2
break
}
}
case schemapb.DataType_BFloat16Vector:
for _, t := range field.TypeParams {
if t.Key == common.DimKey {
dim, err := strconv.Atoi(t.Value)
if err != nil {
return nil, fmt.Errorf("strconv wrong on get dim, err = %s", err)
}
offset += dim * 2
break
}
}
case schemapb.DataType_SparseFloatVector:
return nil, fmt.Errorf("SparseFloatVector not support in row based message")
}
Expand Down Expand Up @@ -280,6 +302,10 @@ func fillFieldData(ctx context.Context, vcm storage.ChunkManager, dataPath strin
return fillBinVecFieldData(ctx, vcm, dataPath, fieldData, i, offset, endian)
case schemapb.DataType_FloatVector:
return fillFloatVecFieldData(ctx, vcm, dataPath, fieldData, i, offset, endian)
case schemapb.DataType_Float16Vector:
return fillFloatVecFieldData(ctx, vcm, dataPath, fieldData, i, offset, endian)
case schemapb.DataType_BFloat16Vector:
return fillFloatVecFieldData(ctx, vcm, dataPath, fieldData, i, offset, endian)
case schemapb.DataType_SparseFloatVector:
return fillSparseFloatVecFieldData(ctx, vcm, dataPath, fieldData, i, offset, endian)
case schemapb.DataType_Bool:
Expand Down
31 changes: 31 additions & 0 deletions internal/storage/print_binlog.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,37 @@ func printPayloadValues(colType schemapb.DataType, reader PayloadReaderInterface
}
fmt.Println()
}
case schemapb.DataType_Float16Vector:
val, dim, err := reader.GetFloat16VectorFromPayload()
if err != nil {
return err
}
dim = dim * 2
length := len(val) / dim
for i := 0; i < length; i++ {
fmt.Printf("\t\t%d :", i)
for j := 0; j < dim; j++ {
idx := i*dim + j
fmt.Printf(" %02x", val[idx])
}
fmt.Println()
}
case schemapb.DataType_BFloat16Vector:
val, dim, err := reader.GetBFloat16VectorFromPayload()
if err != nil {
return err
}
dim = dim * 2
length := len(val) / dim
for i := 0; i < length; i++ {
fmt.Printf("\t\t%d :", i)
for j := 0; j < dim; j++ {
idx := i*dim + j
fmt.Printf(" %02x", val[idx])
}
fmt.Println()
}

case schemapb.DataType_FloatVector:
val, dim, err := reader.GetFloatVectorFromPayload()
if err != nil {
Expand Down
30 changes: 30 additions & 0 deletions internal/storage/print_binlog_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,20 @@ func TestPrintBinlogFiles(t *testing.T) {
Description: "description_12",
DataType: schemapb.DataType_JSON,
},
{
FieldID: 111,
Name: "field_bfloat16_vector",
IsPrimaryKey: false,
Description: "description_13",
DataType: schemapb.DataType_BFloat16Vector,
},
{
FieldID: 112,
Name: "field_float16_vector",
IsPrimaryKey: false,
Description: "description_14",
DataType: schemapb.DataType_Float16Vector,
},
},
},
}
Expand Down Expand Up @@ -234,6 +248,14 @@ func TestPrintBinlogFiles(t *testing.T) {
[]byte(`{"key":"hello"}`),
},
},
111: &BFloat16VectorFieldData{
Data: []byte("12345678"),
Dim: 4,
},
112: &Float16VectorFieldData{
Data: []byte("12345678"),
Dim: 4,
},
},
}

Expand Down Expand Up @@ -283,6 +305,14 @@ func TestPrintBinlogFiles(t *testing.T) {
[]byte(`{"key":"world"}`),
},
},
111: &BFloat16VectorFieldData{
Data: []byte("abcdefgh"),
Dim: 4,
},
112: &Float16VectorFieldData{
Data: []byte("abcdefgh"),
Dim: 4,
},
},
}
firstBlobs, err := insertCodec.Serialize(1, 1, insertDataFirst)
Expand Down
26 changes: 26 additions & 0 deletions internal/storage/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -1201,6 +1201,32 @@ func TransferInsertDataToInsertRecord(insertData *InsertData) (*segcorepb.Insert
},
},
}
case *Float16VectorFieldData:
fieldData = &schemapb.FieldData{
Type: schemapb.DataType_Float16Vector,
FieldId: fieldID,
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Data: &schemapb.VectorField_Float16Vector{
Float16Vector: rawData.Data,
},
Dim: int64(rawData.Dim),
},
},
}
case *BFloat16VectorFieldData:
fieldData = &schemapb.FieldData{
Type: schemapb.DataType_BFloat16Vector,
FieldId: fieldID,
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Data: &schemapb.VectorField_Bfloat16Vector{
Bfloat16Vector: rawData.Data,
},
Dim: int64(rawData.Dim),
},
},
}
case *SparseFloatVectorFieldData:
fieldData = &schemapb.FieldData{
Type: schemapb.DataType_SparseFloatVector,
Expand Down
9 changes: 9 additions & 0 deletions internal/storage/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -993,6 +993,15 @@ func TestRowBasedInsertMsgToInsertData(t *testing.T) {
}
}

func TestRowBasedTransferInsertMsgToInsertRecord(t *testing.T) {
numRows, fVecDim, bVecDim, f16VecDim, bf16VecDim := 10, 8, 8, 8, 8
schema, _, _ := genAllFieldsSchema(fVecDim, bVecDim, f16VecDim, bf16VecDim, false)
msg, _, _ := genRowBasedInsertMsg(numRows, fVecDim, bVecDim, f16VecDim, bf16VecDim)

_, err := TransferInsertMsgToInsertRecord(schema, msg)
assert.NoError(t, err)
}

func TestRowBasedInsertMsgToInsertFloat16VectorDataError(t *testing.T) {
msg := &msgstream.InsertMsg{
BaseMsg: msgstream.BaseMsg{
Expand Down
9 changes: 9 additions & 0 deletions internal/util/typeutil/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,15 @@ func ConvertToArrowSchema(fields []*schemapb.FieldSchema) (*arrow.Schema, error)
Name: field.Name,
Type: &arrow.FixedSizeBinaryType{ByteWidth: dim * 2},
})
case schemapb.DataType_BFloat16Vector:
dim, err := storage.GetDimFromParams(field.TypeParams)
if err != nil {
return nil, err
}
arrowFields = append(arrowFields, arrow.Field{
Name: field.Name,
Type: &arrow.FixedSizeBinaryType{ByteWidth: dim * 2},
})
default:
return nil, merr.WrapErrParameterInvalidMsg("unknown type %v", field.DataType.String())
}
Expand Down
24 changes: 24 additions & 0 deletions internal/util/typeutil/schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,33 @@ func TestConvertArrowSchema(t *testing.T) {
{FieldID: 12, Name: "field11", DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64},
{FieldID: 13, Name: "field12", DataType: schemapb.DataType_JSON},
{FieldID: 14, Name: "field13", DataType: schemapb.DataType_Float16Vector, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}},
{FieldID: 15, Name: "field14", DataType: schemapb.DataType_BFloat16Vector, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}},
}

schema, err := ConvertToArrowSchema(fieldSchemas)
assert.NoError(t, err)
assert.Equal(t, len(fieldSchemas), len(schema.Fields()))
}

func TestConvertArrowSchemaWithoutDim(t *testing.T) {
fieldSchemas := []*schemapb.FieldSchema{
{FieldID: 1, Name: "field0", DataType: schemapb.DataType_Bool},
{FieldID: 2, Name: "field1", DataType: schemapb.DataType_Int8},
{FieldID: 3, Name: "field2", DataType: schemapb.DataType_Int16},
{FieldID: 4, Name: "field3", DataType: schemapb.DataType_Int32},
{FieldID: 5, Name: "field4", DataType: schemapb.DataType_Int64},
{FieldID: 6, Name: "field5", DataType: schemapb.DataType_Float},
{FieldID: 7, Name: "field6", DataType: schemapb.DataType_Double},
{FieldID: 8, Name: "field7", DataType: schemapb.DataType_String},
{FieldID: 9, Name: "field8", DataType: schemapb.DataType_VarChar},
{FieldID: 10, Name: "field9", DataType: schemapb.DataType_BinaryVector, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}},
{FieldID: 11, Name: "field10", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}},
{FieldID: 12, Name: "field11", DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64},
{FieldID: 13, Name: "field12", DataType: schemapb.DataType_JSON},
{FieldID: 14, Name: "field13", DataType: schemapb.DataType_Float16Vector, TypeParams: []*commonpb.KeyValuePair{}},
{FieldID: 15, Name: "field14", DataType: schemapb.DataType_BFloat16Vector, TypeParams: []*commonpb.KeyValuePair{}},
}

_, err := ConvertToArrowSchema(fieldSchemas)
assert.Error(t, err)
}
12 changes: 12 additions & 0 deletions internal/util/typeutil/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,18 @@ func BuildRecord(b *array.RecordBuilder, data *storage.InsertData, fields []*sch
byteLength := dim * 2
length := len(data) / byteLength

builder.Reserve(length)
for i := 0; i < length; i++ {
builder.Append(data[i*byteLength : (i+1)*byteLength])
}
case schemapb.DataType_BFloat16Vector:
vecData := data.Data[field.FieldID].(*storage.BFloat16VectorFieldData)
builder := fBuilder.(*array.FixedSizeBinaryBuilder)
dim := vecData.Dim
data := vecData.Data
byteLength := dim * 2
length := len(data) / byteLength

builder.Reserve(length)
for i := 0; i < length; i++ {
builder.Append(data[i*byteLength : (i+1)*byteLength])
Expand Down
2 changes: 1 addition & 1 deletion pkg/util/funcutil/func.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ func CheckCtxValid(ctx context.Context) bool {
func GetVecFieldIDs(schema *schemapb.CollectionSchema) []int64 {
var vecFieldIDs []int64
for _, field := range schema.Fields {
if field.DataType == schemapb.DataType_BinaryVector || field.DataType == schemapb.DataType_FloatVector || field.DataType == schemapb.DataType_Float16Vector || field.DataType == schemapb.DataType_SparseFloatVector {
if field.DataType == schemapb.DataType_BinaryVector || field.DataType == schemapb.DataType_FloatVector || field.DataType == schemapb.DataType_Float16Vector || field.DataType == schemapb.DataType_BFloat16Vector || field.DataType == schemapb.DataType_SparseFloatVector {
vecFieldIDs = append(vecFieldIDs, field.FieldID)
}
}
Expand Down
Loading

0 comments on commit 976928e

Please sign in to comment.