Skip to content

Commit

Permalink
enhance: [GoSDK] Support nullable generic column (milvus-io#38076)
Browse files Browse the repository at this point in the history
Related to milvus-io#31728 milvus-io#31293

---------

Signed-off-by: Congqi Xia <[email protected]>
  • Loading branch information
congqixia authored and wayblink committed Nov 29, 2024
1 parent c6dcef7 commit f166cd4
Show file tree
Hide file tree
Showing 7 changed files with 618 additions and 185 deletions.
284 changes: 101 additions & 183 deletions client/column/columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,12 @@ import (
"fmt"

"github.com/cockroachdb/errors"
"github.com/samber/lo"

"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/client/v2/entity"
)

//go:generate go run gen/gen.go

// Column interface field type for column-based data frame
type Column interface {
Name() string
Expand All @@ -40,6 +39,10 @@ type Column interface {
GetAsString(int) (string, error)
GetAsDouble(int) (float64, error)
GetAsBool(int) (bool, error)
// nullable related API
AppendNull() error
IsNull(int) (bool, error)
Nullable() bool
}

// errFieldDataTypeNotMatch is returned when the payload carried by a FieldData
// does not match the data type the field declares.
var errFieldDataTypeNotMatch = errors.New("FieldData type not matched")
Expand Down Expand Up @@ -79,135 +82,128 @@ func IDColumns(schema *entity.Schema, ids *schemapb.IDs, begin, end int) (Column
return idColumn, nil
}

// parseScalarData builds a Column named name from the [start, end) window of data.
//
// A negative end means "through the last element of data". When validData is
// non-empty the column is built with nullableCreator, otherwise with creator.
//
// An out-of-range window is reported as an error instead of panicking, and a
// nil Column is returned whenever the returned error is non-nil.
func parseScalarData[T any, COL Column, NCOL Column](
	name string,
	data []T,
	start, end int,
	validData []bool,
	creator func(string, []T) COL,
	nullableCreator func(string, []T, []bool) (NCOL, error),
) (Column, error) {
	total := len(data)
	if end < 0 {
		end = total
	}
	// Validate explicitly: a bad window would otherwise panic inside the slice expression.
	if start < 0 || start > end || end > total {
		return nil, fmt.Errorf("invalid range [%d, %d) for field %s with %d rows", start, end, name, total)
	}
	window := data[start:end]

	if len(validData) == 0 {
		return creator(name, window), nil
	}

	// When validData carries exactly one flag per row of data, it has to be
	// narrowed to the same window, otherwise the flags no longer line up with
	// the values handed to the creator.
	// NOTE(review): when the lengths differ the flags are passed through
	// untouched, as before — confirm the layout nullableCreator expects.
	if len(validData) == total {
		validData = validData[start:end]
	}

	ncol, err := nullableCreator(name, window, validData)
	if err != nil {
		// Return an untyped nil so callers never observe a non-nil Column
		// interface wrapping a nil or half-built value.
		return nil, err
	}
	return ncol, nil
}

// parseArrayData turns the per-row scalar payloads of an array field into a
// Column whose rows are slices of elementType.
//
// Each entry of fieldDataList contributes one row. The collected rows, together
// with validData and the begin/end positions, are handed to parseScalarData,
// which selects the plain or nullable column constructor.
// An error is returned when elementType is not a supported array element type.
func parseArrayData(fieldName string, elementType schemapb.DataType, fieldDataList []*schemapb.ScalarField, validData []bool, begin, end int) (Column, error) {
	rowCount := len(fieldDataList)

	switch elementType {
	case schemapb.DataType_Bool:
		rows := make([][]bool, rowCount)
		for idx, field := range fieldDataList {
			rows[idx] = field.GetBoolData().GetData()
		}
		return parseScalarData(fieldName, rows, begin, end, validData, NewColumnBoolArray, NewNullableColumnBoolArray)

	case schemapb.DataType_Int8:
		// int8 values travel as int32 on the wire and are narrowed per row.
		rows := make([][]int8, rowCount)
		for idx, field := range fieldDataList {
			rows[idx] = int32ToType[int8](field.GetIntData().GetData())
		}
		return parseScalarData(fieldName, rows, begin, end, validData, NewColumnInt8Array, NewNullableColumnInt8Array)

	case schemapb.DataType_Int16:
		// int16 values travel as int32 on the wire and are narrowed per row.
		rows := make([][]int16, rowCount)
		for idx, field := range fieldDataList {
			rows[idx] = int32ToType[int16](field.GetIntData().GetData())
		}
		return parseScalarData(fieldName, rows, begin, end, validData, NewColumnInt16Array, NewNullableColumnInt16Array)

	case schemapb.DataType_Int32:
		rows := make([][]int32, rowCount)
		for idx, field := range fieldDataList {
			rows[idx] = field.GetIntData().GetData()
		}
		return parseScalarData(fieldName, rows, begin, end, validData, NewColumnInt32Array, NewNullableColumnInt32Array)

	case schemapb.DataType_Int64:
		rows := make([][]int64, rowCount)
		for idx, field := range fieldDataList {
			rows[idx] = field.GetLongData().GetData()
		}
		return parseScalarData(fieldName, rows, begin, end, validData, NewColumnInt64Array, NewNullableColumnInt64Array)

	case schemapb.DataType_Float:
		rows := make([][]float32, rowCount)
		for idx, field := range fieldDataList {
			rows[idx] = field.GetFloatData().GetData()
		}
		return parseScalarData(fieldName, rows, begin, end, validData, NewColumnFloatArray, NewNullableColumnFloatArray)

	case schemapb.DataType_Double:
		rows := make([][]float64, rowCount)
		for idx, field := range fieldDataList {
			rows[idx] = field.GetDoubleData().GetData()
		}
		return parseScalarData(fieldName, rows, begin, end, validData, NewColumnDoubleArray, NewNullableColumnDoubleArray)

	case schemapb.DataType_VarChar, schemapb.DataType_String:
		rows := make([][]string, rowCount)
		for idx, field := range fieldDataList {
			rows[idx] = field.GetStringData().GetData()
		}
		return parseScalarData(fieldName, rows, begin, end, validData, NewColumnVarCharArray, NewNullableColumnVarCharArray)
	}

	return nil, fmt.Errorf("unsupported element type %s", elementType)
}

// int32ToType converts every element of data to the narrower integer type T
// and returns the results in a new slice of the same length.
//
// The conversion is a plain Go integer conversion, so values that do not fit
// into T are truncated rather than reported. The result is always non-nil,
// even for nil or empty input.
//
// Both union terms use the approximation form, so named types based on int16
// are accepted the same way named types based on int8 are.
func int32ToType[T ~int8 | ~int16](data []int32) []T {
	out := make([]T, len(data))
	for i, v := range data {
		out[i] = T(v)
	}
	return out
}

// FieldDataColumn converts schemapb.FieldData to Column, used in search result conversion logic.
// begin and end specify the start and end positions; a negative end means "through the last element".
func FieldDataColumn(fd *schemapb.FieldData, begin, end int) (Column, error) {
validData := fd.GetValidData()

switch fd.GetType() {
case schemapb.DataType_Bool:
data, ok := fd.GetScalars().GetData().(*schemapb.ScalarField_BoolData)
if !ok {
return nil, errFieldDataTypeNotMatch
}
if end < 0 {
return NewColumnBool(fd.GetFieldName(), data.BoolData.GetData()[begin:]), nil
}
return NewColumnBool(fd.GetFieldName(), data.BoolData.GetData()[begin:end]), nil
return parseScalarData(fd.GetFieldName(), fd.GetScalars().GetBoolData().GetData(), begin, end, validData, NewColumnBool, NewNullableColumnBool)

case schemapb.DataType_Int8:
data, ok := getIntData(fd)
if !ok {
return nil, errFieldDataTypeNotMatch
}
values := make([]int8, 0, len(data.IntData.GetData()))
for _, v := range data.IntData.GetData() {
values = append(values, int8(v))
}

if end < 0 {
return NewColumnInt8(fd.GetFieldName(), values[begin:]), nil
}

return NewColumnInt8(fd.GetFieldName(), values[begin:end]), nil
data := int32ToType[int8](fd.GetScalars().GetIntData().GetData())
return parseScalarData(fd.GetFieldName(), data, begin, end, validData, NewColumnInt8, NewNullableColumnInt8)

case schemapb.DataType_Int16:
data, ok := getIntData(fd)
if !ok {
return nil, errFieldDataTypeNotMatch
}
values := make([]int16, 0, len(data.IntData.GetData()))
for _, v := range data.IntData.GetData() {
values = append(values, int16(v))
}
if end < 0 {
return NewColumnInt16(fd.GetFieldName(), values[begin:]), nil
}

return NewColumnInt16(fd.GetFieldName(), values[begin:end]), nil
data := int32ToType[int16](fd.GetScalars().GetIntData().GetData())
return parseScalarData(fd.GetFieldName(), data, begin, end, validData, NewColumnInt16, NewNullableColumnInt16)

case schemapb.DataType_Int32:
data, ok := getIntData(fd)
if !ok {
return nil, errFieldDataTypeNotMatch
}
if end < 0 {
return NewColumnInt32(fd.GetFieldName(), data.IntData.GetData()[begin:]), nil
}
return NewColumnInt32(fd.GetFieldName(), data.IntData.GetData()[begin:end]), nil
return parseScalarData(fd.GetFieldName(), fd.GetScalars().GetIntData().GetData(), begin, end, validData, NewColumnInt32, NewNullableColumnInt32)

case schemapb.DataType_Int64:
data, ok := fd.GetScalars().GetData().(*schemapb.ScalarField_LongData)
if !ok {
return nil, errFieldDataTypeNotMatch
}
if end < 0 {
return NewColumnInt64(fd.GetFieldName(), data.LongData.GetData()[begin:]), nil
}
return NewColumnInt64(fd.GetFieldName(), data.LongData.GetData()[begin:end]), nil
return parseScalarData(fd.GetFieldName(), fd.GetScalars().GetLongData().GetData(), begin, end, validData, NewColumnInt64, NewNullableColumnInt64)

case schemapb.DataType_Float:
data, ok := fd.GetScalars().GetData().(*schemapb.ScalarField_FloatData)
if !ok {
return nil, errFieldDataTypeNotMatch
}
if end < 0 {
return NewColumnFloat(fd.GetFieldName(), data.FloatData.GetData()[begin:]), nil
}
return NewColumnFloat(fd.GetFieldName(), data.FloatData.GetData()[begin:end]), nil
return parseScalarData(fd.GetFieldName(), fd.GetScalars().GetFloatData().GetData(), begin, end, validData, NewColumnFloat, NewNullableColumnFloat)

case schemapb.DataType_Double:
data, ok := fd.GetScalars().GetData().(*schemapb.ScalarField_DoubleData)
if !ok {
return nil, errFieldDataTypeNotMatch
}
if end < 0 {
return NewColumnDouble(fd.GetFieldName(), data.DoubleData.GetData()[begin:]), nil
}
return NewColumnDouble(fd.GetFieldName(), data.DoubleData.GetData()[begin:end]), nil
return parseScalarData(fd.GetFieldName(), fd.GetScalars().GetDoubleData().GetData(), begin, end, validData, NewColumnDouble, NewNullableColumnDouble)

case schemapb.DataType_String:
data, ok := fd.GetScalars().GetData().(*schemapb.ScalarField_StringData)
if !ok {
return nil, errFieldDataTypeNotMatch
}
if end < 0 {
return NewColumnString(fd.GetFieldName(), data.StringData.GetData()[begin:]), nil
}
return NewColumnString(fd.GetFieldName(), data.StringData.GetData()[begin:end]), nil
return parseScalarData(fd.GetFieldName(), fd.GetScalars().GetStringData().GetData(), begin, end, validData, NewColumnString, NewNullableColumnString)

case schemapb.DataType_VarChar:
data, ok := fd.GetScalars().GetData().(*schemapb.ScalarField_StringData)
if !ok {
return nil, errFieldDataTypeNotMatch
}
if end < 0 {
return NewColumnVarChar(fd.GetFieldName(), data.StringData.GetData()[begin:]), nil
}
return NewColumnVarChar(fd.GetFieldName(), data.StringData.GetData()[begin:end]), nil
return parseScalarData(fd.GetFieldName(), fd.GetScalars().GetStringData().GetData(), begin, end, validData, NewColumnVarChar, NewNullableColumnVarChar)

case schemapb.DataType_Array:
data := fd.GetScalars().GetArrayData()
if data == nil {
return nil, errFieldDataTypeNotMatch
}
var arrayData []*schemapb.ScalarField
if end < 0 {
arrayData = data.GetData()[begin:]
} else {
arrayData = data.GetData()[begin:end]
}

return parseArrayData(fd.GetFieldName(), data.GetElementType(), arrayData)
return parseArrayData(fd.GetFieldName(), data.GetElementType(), data.GetData(), validData, begin, end)

case schemapb.DataType_JSON:
data, ok := fd.GetScalars().GetData().(*schemapb.ScalarField_JsonData)
isDynamic := fd.GetIsDynamic()
if !ok {
return nil, errFieldDataTypeNotMatch
}
if end < 0 {
return NewColumnJSONBytes(fd.GetFieldName(), data.JsonData.GetData()[begin:]).WithIsDynamic(isDynamic), nil
}
return NewColumnJSONBytes(fd.GetFieldName(), data.JsonData.GetData()[begin:end]).WithIsDynamic(isDynamic), nil
return parseScalarData(fd.GetFieldName(), fd.GetScalars().GetJsonData().GetData(), begin, end, validData, NewColumnJSONBytes, NewNullableColumnJSONBytes)

case schemapb.DataType_FloatVector:
vectors := fd.GetVectors()
Expand Down Expand Up @@ -312,84 +308,6 @@ func FieldDataColumn(fd *schemapb.FieldData, begin, end int) (Column, error) {
}
}

// parseArrayData turns the per-row scalar payloads of an array field into an
// array Column named fieldName whose rows are slices of elementType.
//
// Each entry of fieldDataList contributes one row. Int8 and Int16 rows are
// narrowed from the int32 payload, and string rows are copied into a fresh slice.
// An error is returned when elementType is not a supported array element type.
func parseArrayData(fieldName string, elementType schemapb.DataType, fieldDataList []*schemapb.ScalarField) (Column, error) {
	switch elementType {
	case schemapb.DataType_Bool:
		var rows [][]bool
		for _, field := range fieldDataList {
			rows = append(rows, field.GetBoolData().GetData())
		}
		return NewColumnBoolArray(fieldName, rows), nil

	case schemapb.DataType_Int8:
		var rows [][]int8
		for _, field := range fieldDataList {
			src := field.GetIntData().GetData()
			narrowed := make([]int8, len(src))
			for i, v := range src {
				narrowed[i] = int8(v)
			}
			rows = append(rows, narrowed)
		}
		return NewColumnInt8Array(fieldName, rows), nil

	case schemapb.DataType_Int16:
		var rows [][]int16
		for _, field := range fieldDataList {
			src := field.GetIntData().GetData()
			narrowed := make([]int16, len(src))
			for i, v := range src {
				narrowed[i] = int16(v)
			}
			rows = append(rows, narrowed)
		}
		return NewColumnInt16Array(fieldName, rows), nil

	case schemapb.DataType_Int32:
		var rows [][]int32
		for _, field := range fieldDataList {
			rows = append(rows, field.GetIntData().GetData())
		}
		return NewColumnInt32Array(fieldName, rows), nil

	case schemapb.DataType_Int64:
		var rows [][]int64
		for _, field := range fieldDataList {
			rows = append(rows, field.GetLongData().GetData())
		}
		return NewColumnInt64Array(fieldName, rows), nil

	case schemapb.DataType_Float:
		var rows [][]float32
		for _, field := range fieldDataList {
			rows = append(rows, field.GetFloatData().GetData())
		}
		return NewColumnFloatArray(fieldName, rows), nil

	case schemapb.DataType_Double:
		var rows [][]float64
		for _, field := range fieldDataList {
			rows = append(rows, field.GetDoubleData().GetData())
		}
		return NewColumnDoubleArray(fieldName, rows), nil

	case schemapb.DataType_VarChar, schemapb.DataType_String:
		var rows [][]string
		for _, field := range fieldDataList {
			src := field.GetStringData().GetData()
			// Copy the row so the column does not share storage with the payload.
			copied := make([]string, len(src))
			copy(copied, src)
			rows = append(rows, copied)
		}
		return NewColumnVarCharArray(fieldName, rows), nil
	}

	return nil, fmt.Errorf("unsupported element type %s", elementType)
}

// getIntData get int32 slice from result field data
// also handles LongData bug (see also https://github.com/milvus-io/milvus/issues/23850)
func getIntData(fd *schemapb.FieldData) (*schemapb.ScalarField_IntData, bool) {
Expand Down
Loading

0 comments on commit f166cd4

Please sign in to comment.