From 3a85ad9c907b6469d32d4db77ef4786826b59321 Mon Sep 17 00:00:00 2001 From: Congqi Xia Date: Wed, 22 May 2024 15:21:33 +0800 Subject: [PATCH] enhance: Support Row-based insert for milvusclient See also #31293 Signed-off-by: Congqi Xia --- client/collection.go | 6 +- client/collection_options.go | 4 + client/entity/schema.go | 14 ++ client/entity/sparse.go | 2 +- client/row/data.go | 332 +++++++++++++++++++++++++++++++++++ client/row/data_test.go | 174 ++++++++++++++++++ client/row/schema.go | 185 +++++++++++++++++++ client/row/schema_test.go | 213 ++++++++++++++++++++++ client/write_options.go | 57 +++++- 9 files changed, 977 insertions(+), 10 deletions(-) create mode 100644 client/row/data.go create mode 100644 client/row/data_test.go create mode 100644 client/row/schema.go create mode 100644 client/row/schema_test.go diff --git a/client/collection.go b/client/collection.go index 039ff2460d64c..4031c687d9993 100644 --- a/client/collection.go +++ b/client/collection.go @@ -62,10 +62,6 @@ func (c *Client) CreateCollection(ctx context.Context, option CreateCollectionOp return nil } -type ListCollectionOption interface { - Request() *milvuspb.ShowCollectionsRequest -} - func (c *Client) ListCollections(ctx context.Context, option ListCollectionOption, callOptions ...grpc.CallOption) (collectionNames []string, err error) { req := option.Request() err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error { @@ -82,7 +78,7 @@ func (c *Client) ListCollections(ctx context.Context, option ListCollectionOptio return collectionNames, err } -func (c *Client) DescribeCollection(ctx context.Context, option *describeCollectionOption, callOptions ...grpc.CallOption) (collection *entity.Collection, err error) { +func (c *Client) DescribeCollection(ctx context.Context, option DescribeCollectionOption, callOptions ...grpc.CallOption) (collection *entity.Collection, err error) { req := option.Request() err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error { resp, err := milvusService.DescribeCollection(ctx, req, callOptions...) diff --git a/client/collection_options.go b/client/collection_options.go index adb59e37b5145..696fe702273a2 100644 --- a/client/collection_options.go +++ b/client/collection_options.go @@ -159,6 +159,10 @@ func NewCreateCollectionOption(name string, collectionSchema *entity.Schema) *cr } } +type ListCollectionOption interface { + Request() *milvuspb.ShowCollectionsRequest +} + type listCollectionOption struct{} func (opt *listCollectionOption) Request() *milvuspb.ShowCollectionsRequest { diff --git a/client/entity/schema.go b/client/entity/schema.go index ce30b53f51483..8225ba6c2fd3c 100644 --- a/client/entity/schema.go +++ b/client/entity/schema.go @@ -19,6 +19,8 @@ package entity import ( "strconv" + "github.com/cockroachdb/errors" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" ) @@ -293,6 +295,18 @@ func (f *Field) WithDim(dim int64) *Field { return f } +func (f *Field) GetDim() (int64, error) { + dimStr, has := f.TypeParams[TypeParamDim] + if !has { + return -1, errors.New("field with no dim") + } + dim, err := strconv.ParseInt(dimStr, 10, 64) + if err != nil { + return -1, errors.Newf("field with bad format dim: %s", err.Error()) + } + return dim, nil +} + func (f *Field) WithMaxLength(maxLen int64) *Field { if f.TypeParams == nil { f.TypeParams = make(map[string]string) diff --git a/client/entity/sparse.go b/client/entity/sparse.go index 2bded8f6e8f2b..56ca5f4dca265 100644 --- a/client/entity/sparse.go +++ b/client/entity/sparse.go @@ -88,7 +88,7 @@ func deserializeSliceSparceEmbedding(bs []byte) (sliceSparseEmbedding, error) { return sliceSparseEmbedding{}, errors.New("not valid sparse embedding bytes") } - length = length / 8 + length /= 8 result := sliceSparseEmbedding{ positions: make([]uint32, length), diff --git a/client/row/data.go b/client/row/data.go new file mode 100644 index 0000000000000..292661ade29be --- /dev/null +++ b/client/row/data.go @@ -0,0 +1,332 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package row + +import ( + "encoding/json" + "fmt" + "reflect" + "strconv" + + "github.com/cockroachdb/errors" + + "github.com/milvus-io/milvus/client/v2/column" + "github.com/milvus-io/milvus/client/v2/entity" +) + +const ( + // MilvusTag struct tag const for milvus row based struct + MilvusTag = `milvus` + + // MilvusSkipTagValue struct tag const for skip this field. + MilvusSkipTagValue = `-` + + // MilvusTagSep struct tag const for attribute separator + MilvusTagSep = `;` + + // MilvusTagName struct tag const for field name + MilvusTagName = `NAME` + + // VectorDimTag struct tag const for vector dimension + VectorDimTag = `DIM` + + // VectorTypeTag struct tag const for binary vector type + VectorTypeTag = `VECTOR_TYPE` + + // MilvusPrimaryKey struct tag const for primary key indicator + MilvusPrimaryKey = `PRIMARY_KEY` + + // MilvusAutoID struct tag const for auto id indicator + MilvusAutoID = `AUTO_ID` + + // DimMax dimension max value + DimMax = 65535 +) + +func AnyToColumns(rows []interface{}, schemas ...*entity.Schema) ([]column.Column, error) { + rowsLen := len(rows) + if rowsLen == 0 { + return []column.Column{}, errors.New("0 length column") + } + + var sch *entity.Schema + var err error + // if schema not provided, try to parse from row + if len(schemas) == 0 { + sch, err = ParseSchema(rows[0]) + if err != nil { + return []column.Column{}, err + } + } else { + // use first schema provided + sch = schemas[0] + } + + isDynamic := sch.EnableDynamicField + var dynamicCol *column.ColumnJSONBytes + + nameColumns := make(map[string]column.Column) + for _, field := range sch.Fields { + // skip auto id pk field + if field.PrimaryKey && field.AutoID { + continue + } + switch field.DataType { + case entity.FieldTypeBool: + data := make([]bool, 0, rowsLen) + col := column.NewColumnBool(field.Name, data) + nameColumns[field.Name] = col + case entity.FieldTypeInt8: + data := make([]int8, 0, rowsLen) + col := column.NewColumnInt8(field.Name, data) + nameColumns[field.Name] = col + case entity.FieldTypeInt16: + data := make([]int16, 0, rowsLen) + col := column.NewColumnInt16(field.Name, data) + nameColumns[field.Name] = col + case entity.FieldTypeInt32: + data := make([]int32, 0, rowsLen) + col := column.NewColumnInt32(field.Name, data) + nameColumns[field.Name] = col + case entity.FieldTypeInt64: + data := make([]int64, 0, rowsLen) + col := column.NewColumnInt64(field.Name, data) + nameColumns[field.Name] = col + case entity.FieldTypeFloat: + data := make([]float32, 0, rowsLen) + col := column.NewColumnFloat(field.Name, data) + nameColumns[field.Name] = col + case entity.FieldTypeDouble: + data := make([]float64, 0, rowsLen) + col := column.NewColumnDouble(field.Name, data) + nameColumns[field.Name] = col + case entity.FieldTypeString, entity.FieldTypeVarChar: + data := make([]string, 0, rowsLen) + col := column.NewColumnString(field.Name, data) + nameColumns[field.Name] = col + case entity.FieldTypeJSON: + data := make([][]byte, 0, rowsLen) + col := column.NewColumnJSONBytes(field.Name, data) + nameColumns[field.Name] = col + case entity.FieldTypeArray: + col := NewArrayColumn(field) + if col == nil { + return nil, errors.Newf("unsupported element type %s for Array", field.ElementType.String()) + } + nameColumns[field.Name] = col + case entity.FieldTypeFloatVector: + data := make([][]float32, 0, rowsLen) + dimStr, has := field.TypeParams[entity.TypeParamDim] + if !has { + return []column.Column{}, errors.New("vector field with no dim") + } + dim, err := strconv.ParseInt(dimStr, 10, 64) + if err != nil { + return []column.Column{}, fmt.Errorf("vector field with bad format dim: %s", err.Error()) + } + col := column.NewColumnFloatVector(field.Name, int(dim), data) + nameColumns[field.Name] = col + case entity.FieldTypeBinaryVector: + data := make([][]byte, 0, rowsLen) + dim, err := field.GetDim() + if err != nil { + return []column.Column{}, err + } + col := column.NewColumnBinaryVector(field.Name, int(dim), data) + nameColumns[field.Name] = col + case entity.FieldTypeFloat16Vector: + data := make([][]byte, 0, rowsLen) + dim, err := field.GetDim() + if err != nil { + return []column.Column{}, err + } + col := column.NewColumnFloat16Vector(field.Name, int(dim), data) + nameColumns[field.Name] = col + case entity.FieldTypeBFloat16Vector: + data := make([][]byte, 0, rowsLen) + dim, err := field.GetDim() + if err != nil { + return []column.Column{}, err + } + col := column.NewColumnBFloat16Vector(field.Name, int(dim), data) + nameColumns[field.Name] = col + case entity.FieldTypeSparseVector: + data := make([]entity.SparseEmbedding, 0, rowsLen) + col := column.NewColumnSparseVectors(field.Name, data) + nameColumns[field.Name] = col + } + } + + if isDynamic { + dynamicCol = column.NewColumnJSONBytes("", make([][]byte, 0, rowsLen)).WithIsDynamic(true) + } + + for _, row := range rows { + // collection schema name need not to be same, since receiver could has other names + v := reflect.ValueOf(row) + set, err := reflectValueCandi(v) + if err != nil { + return nil, err + } + + for idx, field := range sch.Fields { + // skip dynamic field if visible + if isDynamic && field.IsDynamic { + continue + } + // skip auto id pk field + if field.PrimaryKey && field.AutoID { + // remove pk field from candidates set, avoid adding it into dynamic column + delete(set, field.Name) + continue + } + column, ok := nameColumns[field.Name] + if !ok { + return nil, fmt.Errorf("expected unhandled field %s", field.Name) + } + + candi, ok := set[field.Name] + if !ok { + return nil, fmt.Errorf("row %d does not has field %s", idx, field.Name) + } + err := column.AppendValue(candi.v.Interface()) + if err != nil { + return nil, err + } + delete(set, field.Name) + } + + if isDynamic { + m := make(map[string]interface{}) + for name, candi := range set { + m[name] = candi.v.Interface() + } + bs, err := json.Marshal(m) + if err != nil { + return nil, fmt.Errorf("failed to marshal dynamic field %w", err) + } + err = dynamicCol.AppendValue(bs) + if err != nil { + return nil, fmt.Errorf("failed to append value to dynamic field %w", err) + } + } + } + columns := make([]column.Column, 0, len(nameColumns)) + for _, column := range nameColumns { + columns = append(columns, column) + } + if isDynamic { + columns = append(columns, dynamicCol) + } + return columns, nil +} + +func NewArrayColumn(f *entity.Field) column.Column { + switch f.ElementType { + case entity.FieldTypeBool: + return column.NewColumnBoolArray(f.Name, nil) + + case entity.FieldTypeInt8: + return column.NewColumnInt8Array(f.Name, nil) + + case entity.FieldTypeInt16: + return column.NewColumnInt16Array(f.Name, nil) + + case entity.FieldTypeInt32: + return column.NewColumnInt32Array(f.Name, nil) + + case entity.FieldTypeInt64: + return column.NewColumnInt64Array(f.Name, nil) + + case entity.FieldTypeFloat: + return column.NewColumnFloatArray(f.Name, nil) + + case entity.FieldTypeDouble: + return column.NewColumnDoubleArray(f.Name, nil) + + case entity.FieldTypeVarChar: + return column.NewColumnVarCharArray(f.Name, nil) + + default: + return nil + } +} + +type fieldCandi struct { + name string + v reflect.Value + options map[string]string +} + +func reflectValueCandi(v reflect.Value) (map[string]fieldCandi, error) { + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + + result := make(map[string]fieldCandi) + switch v.Kind() { + case reflect.Map: // map[string]any + iter := v.MapRange() + for iter.Next() { + key := iter.Key().String() + result[key] = fieldCandi{ + name: key, + v: iter.Value(), + } + } + return result, nil + case reflect.Struct: + for i := 0; i < v.NumField(); i++ { + ft := v.Type().Field(i) + name := ft.Name + tag, ok := ft.Tag.Lookup(MilvusTag) + + settings := make(map[string]string) + if ok { + if tag == MilvusSkipTagValue { + continue + } + settings = ParseTagSetting(tag, MilvusTagSep) + fn, has := settings[MilvusTagName] + if has { + // overwrite column to tag name + name = fn + } + } + _, ok = result[name] + // duplicated + if ok { + return nil, fmt.Errorf("column has duplicated name: %s when parsing field: %s", name, ft.Name) + } + + v := v.Field(i) + if v.Kind() == reflect.Array { + v = v.Slice(0, v.Len()) + } + + result[name] = fieldCandi{ + name: name, + v: v, + options: settings, + } + } + + return result, nil + default: + return nil, fmt.Errorf("unsupport row type: %s", v.Kind().String()) + } +} diff --git a/client/row/data_test.go b/client/row/data_test.go new file mode 100644 index 0000000000000..9e8b7fb216fbc --- /dev/null +++ b/client/row/data_test.go @@ -0,0 +1,174 @@ +package row + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus/client/v2/entity" +) + +type ValidStruct struct { + ID int64 `milvus:"primary_key"` + Attr1 int8 + Attr2 int16 + Attr3 int32 + Attr4 float32 + Attr5 float64 + Attr6 string + Attr7 bool + Vector []float32 `milvus:"dim:16"` + Vector2 []byte `milvus:"dim:32"` +} + +type ValidStruct2 struct { + ID int64 `milvus:"primary_key"` + Vector [16]float32 + Vector2 [4]byte + Ignored bool `milvus:"-"` +} + +type ValidStructWithNamedTag struct { + ID int64 `milvus:"primary_key;name:id"` + Vector [16]float32 `milvus:"name:vector"` +} + +type RowsSuite struct { + suite.Suite +} + +func (s *RowsSuite) TestRowsToColumns() { + s.Run("valid_cases", func() { + columns, err := AnyToColumns([]any{&ValidStruct{}}) + s.Nil(err) + s.Equal(10, len(columns)) + + columns, err = AnyToColumns([]any{&ValidStruct2{}}) + s.Nil(err) + s.Equal(3, len(columns)) + }) + + s.Run("auto_id_pk", func() { + type AutoPK struct { + ID int64 `milvus:"primary_key;auto_id"` + Vector []float32 `milvus:"dim:32"` + } + columns, err := AnyToColumns([]any{&AutoPK{}}) + s.Nil(err) + s.Require().Equal(1, len(columns)) + s.Equal("Vector", columns[0].Name()) + }) + + s.Run("fp16", func() { + type BF16Struct struct { + ID int64 `milvus:"primary_key;auto_id"` + Vector []byte `milvus:"dim:16;vector_type:bf16"` + } + columns, err := AnyToColumns([]any{&BF16Struct{}}) + s.Nil(err) + s.Require().Equal(1, len(columns)) + s.Equal("Vector", columns[0].Name()) + s.Equal(entity.FieldTypeBFloat16Vector, columns[0].Type()) + }) + + s.Run("fp16", func() { + type FP16Struct struct { + ID int64 `milvus:"primary_key;auto_id"` + Vector []byte `milvus:"dim:16;vector_type:fp16"` + } + columns, err := AnyToColumns([]any{&FP16Struct{}}) + s.Nil(err) + s.Require().Equal(1, len(columns)) + s.Equal("Vector", columns[0].Name()) + s.Equal(entity.FieldTypeFloat16Vector, columns[0].Type()) + }) + + s.Run("invalid_cases", func() { + // empty input + _, err := AnyToColumns([]any{}) + s.NotNil(err) + + // incompatible rows + _, err = AnyToColumns([]any{&ValidStruct{}, &ValidStruct2{}}) + s.NotNil(err) + + // schema & row not compatible + _, err = AnyToColumns([]any{&ValidStruct{}}, &entity.Schema{ + Fields: []*entity.Field{ + { + Name: "int64", + DataType: entity.FieldTypeInt64, + }, + }, + }) + s.NotNil(err) + }) +} + +func (s *RowsSuite) TestDynamicSchema() { + s.Run("all_fallback_dynamic", func() { + columns, err := AnyToColumns([]any{&ValidStruct{}}, + entity.NewSchema().WithDynamicFieldEnabled(true), + ) + s.NoError(err) + s.Equal(1, len(columns)) + }) + + s.Run("dynamic_not_found", func() { + _, err := AnyToColumns([]any{&ValidStruct{}}, + entity.NewSchema().WithField( + entity.NewField().WithName("ID").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true), + ).WithDynamicFieldEnabled(true), + ) + s.NoError(err) + }) +} + +func (s *RowsSuite) TestReflectValueCandi() { + cases := []struct { + tag string + v reflect.Value + expect map[string]fieldCandi + expectErr bool + }{ + { + tag: "MapRow", + v: reflect.ValueOf(map[string]interface{}{ + "A": "abd", "B": int64(8), + }), + expect: map[string]fieldCandi{ + "A": { + name: "A", + v: reflect.ValueOf("abd"), + }, + "B": { + name: "B", + v: reflect.ValueOf(int64(8)), + }, + }, + expectErr: false, + }, + } + + for _, c := range cases { + s.Run(c.tag, func() { + r, err := reflectValueCandi(c.v) + if c.expectErr { + s.Error(err) + return + } + s.NoError(err) + s.Equal(len(c.expect), len(r)) + for k, v := range c.expect { + rv, has := r[k] + s.Require().True(has) + s.Equal(v.name, rv.name) + } + }) + } +} + +func TestRows(t *testing.T) { + suite.Run(t, new(RowsSuite)) +} diff --git a/client/row/schema.go b/client/row/schema.go new file mode 100644 index 0000000000000..6022275653f17 --- /dev/null +++ b/client/row/schema.go @@ -0,0 +1,185 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package row + +import ( + "fmt" + "go/ast" + "reflect" + "strconv" + "strings" + + "github.com/cockroachdb/errors" + + "github.com/milvus-io/milvus/client/v2/entity" +) + +// ParseSchema parses schema from interface{}. +func ParseSchema(r interface{}) (*entity.Schema, error) { + sch := &entity.Schema{} + t := reflect.TypeOf(r) + if t.Kind() == reflect.Array || t.Kind() == reflect.Slice || t.Kind() == reflect.Ptr { + t = t.Elem() + } + + // MapRow is not supported for schema definition + // TODO add PrimaryKey() interface later + if t.Kind() == reflect.Map { + return nil, fmt.Errorf("map row is not supported for schema definition") + } + + if t.Kind() != reflect.Struct { + return nil, fmt.Errorf("unsupported data type: %+v", r) + } + + // Collection method not overwrited, try use Row type name + if sch.CollectionName == "" { + sch.CollectionName = t.Name() + if sch.CollectionName == "" { + return nil, errors.New("collection name not provided") + } + } + sch.Fields = make([]*entity.Field, 0, t.NumField()) + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + // ignore anonymous field for now + if f.Anonymous || !ast.IsExported(f.Name) { + continue + } + + field := &entity.Field{ + Name: f.Name, + } + ft := f.Type + if f.Type.Kind() == reflect.Ptr { + ft = ft.Elem() + } + fv := reflect.New(ft) + tag := f.Tag.Get(MilvusTag) + if tag == MilvusSkipTagValue { + continue + } + tagSettings := ParseTagSetting(tag, MilvusTagSep) + if _, has := tagSettings[MilvusPrimaryKey]; has { + field.PrimaryKey = true + } + if _, has := tagSettings[MilvusAutoID]; has { + field.AutoID = true + } + if name, has := tagSettings[MilvusTagName]; has { + field.Name = name + } + switch reflect.Indirect(fv).Kind() { + case reflect.Bool: + field.DataType = entity.FieldTypeBool + case reflect.Int8: + field.DataType = entity.FieldTypeInt8 + case reflect.Int16: + field.DataType = entity.FieldTypeInt16 + case reflect.Int32: + field.DataType = entity.FieldTypeInt32 + case reflect.Int64: + field.DataType = entity.FieldTypeInt64 + case reflect.Float32: + field.DataType = entity.FieldTypeFloat + case reflect.Float64: + field.DataType = entity.FieldTypeDouble + case reflect.String: + field.DataType = entity.FieldTypeString + case reflect.Array: + arrayLen := ft.Len() + elemType := ft.Elem() + switch elemType.Kind() { + case reflect.Uint8: + field.WithDataType(entity.FieldTypeBinaryVector) + field.WithDim(int64(arrayLen) * 8) + case reflect.Float32: + field.WithDataType(entity.FieldTypeFloatVector) + field.WithDim(int64(arrayLen)) + default: + return nil, fmt.Errorf("field %s is array of %v, which is not supported", f.Name, elemType) + } + case reflect.Slice: + dimStr, has := tagSettings[VectorDimTag] + if !has { + return nil, fmt.Errorf("field %s is slice but dim not provided", f.Name) + } + dim, err := strconv.ParseInt(dimStr, 10, 64) + if err != nil { + return nil, fmt.Errorf("dim value %s is not valid", dimStr) + } + if dim < 1 || dim > DimMax { + return nil, fmt.Errorf("dim value %d is out of range", dim) + } + field.WithDim(dim) + + elemType := ft.Elem() + switch elemType.Kind() { + case reflect.Uint8: // []byte, could be BinaryVector, fp16, bf 6 + switch tagSettings[VectorTypeTag] { + case "fp16": + field.DataType = entity.FieldTypeFloat16Vector + case "bf16": + field.DataType = entity.FieldTypeBFloat16Vector + default: + field.DataType = entity.FieldTypeBinaryVector + } + case reflect.Float32: + field.DataType = entity.FieldTypeFloatVector + default: + return nil, fmt.Errorf("field %s is slice of %v, which is not supported", f.Name, elemType) + } + default: + return nil, fmt.Errorf("field %s is %v, which is not supported", field.Name, ft) + } + sch.Fields = append(sch.Fields, field) + } + + return sch, nil +} + +// ParseTagSetting parses struct tag into map settings +func ParseTagSetting(str string, sep string) map[string]string { + settings := map[string]string{} + names := strings.Split(str, sep) + + for i := 0; i < len(names); i++ { + j := i + if len(names[j]) > 0 { + for { + if names[j][len(names[j])-1] == '\\' { + i++ + names[j] = names[j][0:len(names[j])-1] + sep + names[i] + names[i] = "" + } else { + break + } + } + } + + values := strings.Split(names[j], ":") + k := strings.TrimSpace(strings.ToUpper(values[0])) + + if len(values) >= 2 { + settings[k] = strings.Join(values[1:], ":") + } else if k != "" { + settings[k] = k + } + } + + return settings +} diff --git a/client/row/schema_test.go b/client/row/schema_test.go new file mode 100644 index 0000000000000..fbfdc19f27058 --- /dev/null +++ b/client/row/schema_test.go @@ -0,0 +1,213 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package row + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus/client/v2/entity" +) + +// ArrayRow test case type +type ArrayRow [16]float32 + +func (ar *ArrayRow) Collection() string { return "" } +func (ar *ArrayRow) Partition() string { return "" } +func (ar *ArrayRow) Description() string { return "" } + +type Uint8Struct struct { + Attr uint8 +} + +type StringArrayStruct struct { + Vector [8]string +} + +type StringSliceStruct struct { + Vector []string `milvus:"dim:8"` +} + +type SliceNoDimStruct struct { + Vector []float32 `milvus:""` +} + +type SliceBadDimStruct struct { + Vector []float32 `milvus:"dim:str"` +} + +type SliceBadDimStruct2 struct { + Vector []float32 `milvus:"dim:0"` +} + +func TestParseSchema(t *testing.T) { + t.Run("invalid cases", func(t *testing.T) { + // anonymous struct with default collection name ("") will cause error + anonymusStruct := struct{}{} + sch, err := ParseSchema(anonymusStruct) + assert.Nil(t, sch) + assert.NotNil(t, err) + + // non struct + arrayRow := ArrayRow([16]float32{}) + sch, err = ParseSchema(&arrayRow) + assert.Nil(t, sch) + assert.NotNil(t, err) + + // uint8 not supported + sch, err = ParseSchema(&Uint8Struct{}) + assert.Nil(t, sch) + assert.NotNil(t, err) + + // string array not supported + sch, err = ParseSchema(&StringArrayStruct{}) + assert.Nil(t, sch) + assert.NotNil(t, err) + + // string slice not supported + sch, err = ParseSchema(&StringSliceStruct{}) + assert.Nil(t, sch) + assert.NotNil(t, err) + + // slice vector with no dim + sch, err = ParseSchema(&SliceNoDimStruct{}) + assert.Nil(t, sch) + assert.NotNil(t, err) + + // slice vector with bad format dim + sch, err = ParseSchema(&SliceBadDimStruct{}) + assert.Nil(t, sch) + assert.NotNil(t, err) + + // slice vector with bad format dim 2 + sch, err = ParseSchema(&SliceBadDimStruct2{}) + assert.Nil(t, sch) + assert.NotNil(t, err) + }) + + t.Run("valid cases", func(t *testing.T) { + getVectorField := func(schema *entity.Schema) *entity.Field { + for _, field := range schema.Fields { + if field.DataType == entity.FieldTypeFloatVector || + field.DataType == entity.FieldTypeBinaryVector || + field.DataType == entity.FieldTypeBFloat16Vector || + field.DataType == entity.FieldTypeFloat16Vector { + return field + } + } + return nil + } + + type ValidStruct struct { + ID int64 `milvus:"primary_key"` + Attr1 int8 + Attr2 int16 + Attr3 int32 + Attr4 float32 + Attr5 float64 + Attr6 string + Vector []float32 `milvus:"dim:128"` + } + vs := &ValidStruct{} + sch, err := ParseSchema(vs) + assert.Nil(t, err) + assert.NotNil(t, sch) + assert.Equal(t, "ValidStruct", sch.CollectionName) + + type ValidFp16Struct struct { + ID int64 `milvus:"primary_key"` + Attr1 int8 + Attr2 int16 + Attr3 int32 + Attr4 float32 + Attr5 float64 + Attr6 string + Vector []byte `milvus:"dim:128;vector_type:fp16"` + } + fp16Vs := &ValidFp16Struct{} + sch, err = ParseSchema(fp16Vs) + assert.Nil(t, err) + assert.NotNil(t, sch) + assert.Equal(t, "ValidFp16Struct", sch.CollectionName) + vectorField := getVectorField(sch) + assert.Equal(t, entity.FieldTypeFloat16Vector, vectorField.DataType) + + type ValidBf16Struct struct { + ID int64 `milvus:"primary_key"` + Attr1 int8 + Attr2 int16 + Attr3 int32 + Attr4 float32 + Attr5 float64 + Attr6 string + Vector []byte `milvus:"dim:128;vector_type:bf16"` + } + bf16Vs := &ValidBf16Struct{} + sch, err = ParseSchema(bf16Vs) + assert.Nil(t, err) + assert.NotNil(t, sch) + assert.Equal(t, "ValidBf16Struct", sch.CollectionName) + vectorField = getVectorField(sch) + assert.Equal(t, entity.FieldTypeBFloat16Vector, vectorField.DataType) + + type ValidByteStruct struct { + ID int64 `milvus:"primary_key"` + Vector []byte `milvus:"dim:128"` + } + vs2 := &ValidByteStruct{} + sch, err = ParseSchema(vs2) + assert.Nil(t, err) + assert.NotNil(t, sch) + + type ValidArrayStruct struct { + ID int64 `milvus:"primary_key"` + Vector [64]float32 + } + vs3 := &ValidArrayStruct{} + sch, err = ParseSchema(vs3) + assert.Nil(t, err) + assert.NotNil(t, sch) + + type ValidArrayStructByte struct { + ID int64 `milvus:"primary_key;auto_id"` + Data *string `milvus:"extra:test\\;false"` + Vector [64]byte + } + vs4 := &ValidArrayStructByte{} + sch, err = ParseSchema(vs4) + assert.Nil(t, err) + assert.NotNil(t, sch) + + vs5 := &ValidStructWithNamedTag{} + sch, err = ParseSchema(vs5) + assert.Nil(t, err) + assert.NotNil(t, sch) + i64f, vecf := false, false + for _, field := range sch.Fields { + if field.Name == "id" { + i64f = true + } + if field.Name == "vector" { + vecf = true + } + } + + assert.True(t, i64f) + assert.True(t, vecf) + }) +} diff --git a/client/write_options.go b/client/write_options.go index 54139ef0b21fa..612cc7fe2d995 100644 --- a/client/write_options.go +++ b/client/write_options.go @@ -28,6 +28,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/client/v2/column" "github.com/milvus-io/milvus/client/v2/entity" + "github.com/milvus-io/milvus/client/v2/row" ) type InsertOption interface { @@ -71,10 +72,8 @@ func (opt *columnBasedDataOption) processInsertColumns(colSchema *entity.Schema, l := col.Len() if rowSize == 0 { rowSize = l - } else { - if rowSize != l { - return nil, 0, errors.New("column size not match") - } + } else if rowSize != l { + return nil, 0, errors.New("column size not match") } field, has := mNameField[col.Name()] if !has { @@ -247,6 +246,56 @@ func NewColumnBasedInsertOption(collName string, columns ...column.Column) *colu } } +type rowBasedDataOption struct { + *columnBasedDataOption + rows []any +} + +func NewRowBasedInsertOption(collName string, rows ...any) *rowBasedDataOption { + return &rowBasedDataOption{ + columnBasedDataOption: &columnBasedDataOption{ + collName: collName, + }, + rows: rows, + } +} + +func (opt *rowBasedDataOption) InsertRequest(coll *entity.Collection) (*milvuspb.InsertRequest, error) { + columns, err := row.AnyToColumns(opt.rows, coll.Schema) + if err != nil { + return nil, err + } + opt.columnBasedDataOption.columns = columns + fieldsData, rowNum, err := opt.processInsertColumns(coll.Schema, opt.columns...) + if err != nil { + return nil, err + } + return &milvuspb.InsertRequest{ + CollectionName: opt.collName, + PartitionName: opt.partitionName, + FieldsData: fieldsData, + NumRows: uint32(rowNum), + }, nil +} + +func (opt *rowBasedDataOption) UpsertRequest(coll *entity.Collection) (*milvuspb.UpsertRequest, error) { + columns, err := row.AnyToColumns(opt.rows, coll.Schema) + if err != nil { + return nil, err + } + opt.columnBasedDataOption.columns = columns + fieldsData, rowNum, err := opt.processInsertColumns(coll.Schema, opt.columns...) + if err != nil { + return nil, err + } + return &milvuspb.UpsertRequest{ + CollectionName: opt.collName, + PartitionName: opt.partitionName, + FieldsData: fieldsData, + NumRows: uint32(rowNum), + }, nil +} + type DeleteOption interface { Request() *milvuspb.DeleteRequest }