From 60baaed83227feef5943a9579d7289759056cf6c Mon Sep 17 00:00:00 2001 From: congqixia Date: Fri, 7 Jun 2024 11:31:53 +0800 Subject: [PATCH] fix: Make fp16&bf16 column use correct byte length (#33684) See also milvus-io/milvus-sdk-go#756 #31293 Signed-off-by: Congqi Xia Signed-off-by: Congqi Xia --- client/column/columns.go | 10 +++++----- client/column/columns_test.go | 13 ++++++------- client/column/vector_gen.go | 4 ++-- 3 files changed, 13 insertions(+), 14 deletions(-) diff --git a/client/column/columns.go b/client/column/columns.go index a30b064e15235..0a9eb1a556f6f 100644 --- a/client/column/columns.go +++ b/client/column/columns.go @@ -281,12 +281,12 @@ func FieldDataColumn(fd *schemapb.FieldData, begin, end int) (Column, error) { data := x.Float16Vector dim := int(vectors.GetDim()) if end < 0 { - end = len(data) / dim + end = len(data) / dim / 2 } vector := make([][]byte, 0, end-begin) for i := begin; i < end; i++ { - v := make([]byte, dim) - copy(v, data[i*dim:(i+1)*dim]) + v := make([]byte, dim*2) + copy(v, data[i*dim*2:(i+1)*dim*2]) vector = append(vector, v) } return NewColumnFloat16Vector(fd.GetFieldName(), dim, vector), nil @@ -300,12 +300,12 @@ func FieldDataColumn(fd *schemapb.FieldData, begin, end int) (Column, error) { data := x.Bfloat16Vector dim := int(vectors.GetDim()) if end < 0 { - end = len(data) / dim + end = len(data) / dim / 2 } vector := make([][]byte, 0, end-begin) // shall not have remanunt for i := begin; i < end; i++ { v := make([]byte, dim) - copy(v, data[i*dim:(i+1)*dim]) + copy(v, data[i*dim*2:(i+1)*dim*2]) vector = append(vector, v) } return NewColumnBFloat16Vector(fd.GetFieldName(), dim, vector), nil diff --git a/client/column/columns_test.go b/client/column/columns_test.go index 38547f384ca45..1a4b3f1605bf5 100644 --- a/client/column/columns_test.go +++ b/client/column/columns_test.go @@ -39,14 +39,13 @@ func TestIDColumns(t *testing.T) { ) t.Run("nil id", func(t *testing.T) { - col, err := IDColumns(intPKCol, nil, 0, -1) - assert.NoError(t, err) - assert.EqualValues(t, 0, col.Len()) - col, err = IDColumns(strPKCol, nil, 0, -1) - assert.NoError(t, err) - assert.EqualValues(t, 0, col.Len()) + _, err := IDColumns(intPKCol, nil, 0, -1) + assert.Error(t, err) + _, err = IDColumns(strPKCol, nil, 0, -1) + assert.Error(t, err) + idField := &schemapb.IDs{} - col, err = IDColumns(intPKCol, idField, 0, -1) + col, err := IDColumns(intPKCol, idField, 0, -1) assert.NoError(t, err) assert.EqualValues(t, 0, col.Len()) col, err = IDColumns(strPKCol, idField, 0, -1) diff --git a/client/column/vector_gen.go b/client/column/vector_gen.go index dca82eba23583..e2ab3e3f872ea 100644 --- a/client/column/vector_gen.go +++ b/client/column/vector_gen.go @@ -244,7 +244,7 @@ func (c *ColumnFloat16Vector) FieldData() *schemapb.FieldData { FieldName: c.name, } - data := make([]byte, 0, len(c.values)*c.dim) + data := make([]byte, 0, len(c.values)*c.dim*2) for _, vector := range c.values { data = append(data, vector...) @@ -330,7 +330,7 @@ func (c *ColumnBFloat16Vector) FieldData() *schemapb.FieldData { FieldName: c.name, } - data := make([]byte, 0, len(c.values)*c.dim) + data := make([]byte, 0, len(c.values)*c.dim*2) for _, vector := range c.values { data = append(data, vector...)