Skip to content

Commit

Permalink
Disallow retrieving raw vector data of a BM25 function output field
Browse files Browse the repository at this point in the history
Signed-off-by: Buqian Zheng <[email protected]>
  • Loading branch information
zhengbuqian committed Nov 19, 2024
1 parent 5a23c80 commit d60ccd6
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 14 deletions.
4 changes: 4 additions & 0 deletions internal/proxy/meta_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,10 @@ func (s *schemaInfo) IsFieldLoaded(fieldID int64) bool {
return s.schemaHelper.IsFieldLoaded(fieldID)
}

// CanRetrieveRawFieldData reports whether the raw data of field may be
// returned to the caller. The decision is delegated to the schema helper
// held by this schemaInfo.
func (s *schemaInfo) CanRetrieveRawFieldData(field *schemapb.FieldSchema) bool {
	retrievable := s.schemaHelper.CanRetrieveRawFieldData(field)
	return retrievable
}

// partitionInfos contains the cached partition information of a collection.
type partitionInfos struct {
partitionInfos []*partitionInfo
Expand Down
33 changes: 26 additions & 7 deletions internal/proxy/task_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -460,12 +460,13 @@ func constructSearchRequest(

func TestTranslateOutputFields(t *testing.T) {
const (
idFieldName = "id"
tsFieldName = "timestamp"
floatVectorFieldName = "float_vector"
binaryVectorFieldName = "binary_vector"
float16VectorFieldName = "float16_vector"
bfloat16VectorFieldName = "bfloat16_vector"
idFieldName = "id"
tsFieldName = "timestamp"
floatVectorFieldName = "float_vector"
binaryVectorFieldName = "binary_vector"
float16VectorFieldName = "float16_vector"
bfloat16VectorFieldName = "bfloat16_vector"
sparseFloatVectorFieldName = "sparse_float_vector"
)
var outputFields []string
var userOutputFields []string
Expand All @@ -483,6 +484,15 @@ func TestTranslateOutputFields(t *testing.T) {
{Name: binaryVectorFieldName, FieldID: 101, DataType: schemapb.DataType_BinaryVector},
{Name: float16VectorFieldName, FieldID: 102, DataType: schemapb.DataType_Float16Vector},
{Name: bfloat16VectorFieldName, FieldID: 103, DataType: schemapb.DataType_BFloat16Vector},
{Name: sparseFloatVectorFieldName, FieldID: 104, DataType: schemapb.DataType_SparseFloatVector, IsFunctionOutput: true},
},
Functions: []*schemapb.FunctionSchema{
{
Name: "bm25",
Type: schemapb.FunctionType_BM25,
OutputFieldNames: []string{sparseFloatVectorFieldName},
// omit other fields for brevity
},
},
}
schema := newSchemaInfo(collSchema)
Expand Down Expand Up @@ -511,6 +521,7 @@ func TestTranslateOutputFields(t *testing.T) {
assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName}, userOutputFields)
assert.ElementsMatch(t, []string{}, userDynamicFields)

// sparse_float_vector is a BM25 function output field, so it should not be included in the output fields
outputFields, userOutputFields, userDynamicFields, err = translateOutputFields([]string{"*"}, schema, false)
assert.Equal(t, nil, err)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName, float16VectorFieldName, bfloat16VectorFieldName}, outputFields)
Expand All @@ -535,6 +546,14 @@ func TestTranslateOutputFields(t *testing.T) {
assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName, float16VectorFieldName, bfloat16VectorFieldName}, userOutputFields)
assert.ElementsMatch(t, []string{}, userDynamicFields)

// sparse_float_vector is a BM25 function output field, so explicitly requesting it as an output field must return an error
_, _, _, err = translateOutputFields([]string{"*", sparseFloatVectorFieldName}, schema, false)
assert.Error(t, err)
_, _, _, err = translateOutputFields([]string{sparseFloatVectorFieldName}, schema, false)
assert.Error(t, err)
_, _, _, err = translateOutputFields([]string{sparseFloatVectorFieldName}, schema, true)
assert.Error(t, err)

//=========================================================================
outputFields, userOutputFields, userDynamicFields, err = translateOutputFields([]string{}, schema, true)
assert.Equal(t, nil, err)
Expand Down Expand Up @@ -578,7 +597,7 @@ func TestTranslateOutputFields(t *testing.T) {
assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName, float16VectorFieldName, bfloat16VectorFieldName}, userOutputFields)
assert.ElementsMatch(t, []string{}, userDynamicFields)

outputFields, userOutputFields, userDynamicFields, err = translateOutputFields([]string{"A"}, schema, true)
_, _, _, err = translateOutputFields([]string{"A"}, schema, true)
assert.Error(t, err)

t.Run("enable dynamic schema", func(t *testing.T) {
Expand Down
17 changes: 10 additions & 7 deletions internal/proxy/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -1204,7 +1204,7 @@ func translatePkOutputFields(schema *schemapb.CollectionSchema) ([]string, []int
func translateOutputFields(outputFields []string, schema *schemaInfo, addPrimary bool) ([]string, []string, []string, error) {
var primaryFieldName string
var dynamicField *schemapb.FieldSchema
allFieldNameMap := make(map[string]int64)
allFieldNameMap := make(map[string]*schemapb.FieldSchema)
resultFieldNameMap := make(map[string]bool)
resultFieldNames := make([]string, 0)
userOutputFieldsMap := make(map[string]bool)
Expand All @@ -1219,23 +1219,26 @@ func translateOutputFields(outputFields []string, schema *schemaInfo, addPrimary
if field.IsDynamic {
dynamicField = field
}
allFieldNameMap[field.Name] = field.GetFieldID()
allFieldNameMap[field.Name] = field
}

for _, outputFieldName := range outputFields {
outputFieldName = strings.TrimSpace(outputFieldName)
if outputFieldName == "*" {
for fieldName, fieldID := range allFieldNameMap {
// skip Cold field
if schema.IsFieldLoaded(fieldID) {
for fieldName, field := range allFieldNameMap {
// skip Cold field and fields that can't be output
if schema.IsFieldLoaded(field.GetFieldID()) && schema.CanRetrieveRawFieldData(field) {
resultFieldNameMap[fieldName] = true
userOutputFieldsMap[fieldName] = true
}
}
useAllDyncamicFields = true
} else {
if fieldID, ok := allFieldNameMap[outputFieldName]; ok {
if schema.IsFieldLoaded(fieldID) {
if field, ok := allFieldNameMap[outputFieldName]; ok {
if !schema.CanRetrieveRawFieldData(field) {
return nil, nil, nil, fmt.Errorf("not allowed to retrieve raw data of field %s", outputFieldName)
}
if schema.IsFieldLoaded(field.GetFieldID()) {
resultFieldNameMap[outputFieldName] = true
userOutputFieldsMap[outputFieldName] = true
} else {
Expand Down
14 changes: 14 additions & 0 deletions pkg/util/typeutil/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,20 @@ func (helper *SchemaHelper) GetFunctionByOutputField(field *schemapb.FieldSchema
return nil, fmt.Errorf("function not exist")
}

// As of now, only BM25 function output field is not supported to retrieve raw field data
func (helper *SchemaHelper) CanRetrieveRawFieldData(field *schemapb.FieldSchema) bool {
if !field.IsFunctionOutput {
return true
}

f, err := helper.GetFunctionByOutputField(field)
if err != nil {
return false
}

return f.GetType() != schemapb.FunctionType_BM25

Check warning on line 458 in pkg/util/typeutil/schema.go

View check run for this annotation

Codecov / codecov/patch

pkg/util/typeutil/schema.go#L458

Added line #L458 was not covered by tests
}

// GetCollectionName returns the Name of the schema wrapped by this helper.
func (helper *SchemaHelper) GetCollectionName() string {
	collectionName := helper.schema.Name
	return collectionName
}
Expand Down

0 comments on commit d60ccd6

Please sign in to comment.