enhance: disallow get raw vector data of a BM25 Function output field (#37800)

issue: #35853

Signed-off-by: Buqian Zheng <[email protected]>
zhengbuqian authored Nov 20, 2024
1 parent 7ba8550 commit 511edd2
Showing 5 changed files with 66 additions and 26 deletions.
4 changes: 4 additions & 0 deletions internal/proxy/meta_cache.go
@@ -252,6 +252,10 @@ func (s *schemaInfo) IsFieldLoaded(fieldID int64) bool {
return s.schemaHelper.IsFieldLoaded(fieldID)
}

func (s *schemaInfo) CanRetrieveRawFieldData(field *schemapb.FieldSchema) bool {
return s.schemaHelper.CanRetrieveRawFieldData(field)
}

// partitionInfos contains the cached collection partition informations.
type partitionInfos struct {
partitionInfos []*partitionInfo
33 changes: 26 additions & 7 deletions internal/proxy/task_test.go
@@ -460,12 +460,13 @@ func constructSearchRequest(

func TestTranslateOutputFields(t *testing.T) {
const (
idFieldName = "id"
tsFieldName = "timestamp"
floatVectorFieldName = "float_vector"
binaryVectorFieldName = "binary_vector"
float16VectorFieldName = "float16_vector"
bfloat16VectorFieldName = "bfloat16_vector"
idFieldName = "id"
tsFieldName = "timestamp"
floatVectorFieldName = "float_vector"
binaryVectorFieldName = "binary_vector"
float16VectorFieldName = "float16_vector"
bfloat16VectorFieldName = "bfloat16_vector"
sparseFloatVectorFieldName = "sparse_float_vector"
)
var outputFields []string
var userOutputFields []string
@@ -483,6 +484,15 @@ func TestTranslateOutputFields(t *testing.T) {
{Name: binaryVectorFieldName, FieldID: 101, DataType: schemapb.DataType_BinaryVector},
{Name: float16VectorFieldName, FieldID: 102, DataType: schemapb.DataType_Float16Vector},
{Name: bfloat16VectorFieldName, FieldID: 103, DataType: schemapb.DataType_BFloat16Vector},
{Name: sparseFloatVectorFieldName, FieldID: 104, DataType: schemapb.DataType_SparseFloatVector, IsFunctionOutput: true},
},
Functions: []*schemapb.FunctionSchema{
{
Name: "bm25",
Type: schemapb.FunctionType_BM25,
OutputFieldNames: []string{sparseFloatVectorFieldName},
// omit other fields for brevity
},
},
}
schema := newSchemaInfo(collSchema)
@@ -511,6 +521,7 @@ func TestTranslateOutputFields(t *testing.T) {
assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName}, userOutputFields)
assert.ElementsMatch(t, []string{}, userDynamicFields)

// sparse_float_vector is a BM25 function output field, so it should not be included in the output fields
outputFields, userOutputFields, userDynamicFields, err = translateOutputFields([]string{"*"}, schema, false)
assert.Equal(t, nil, err)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName, float16VectorFieldName, bfloat16VectorFieldName}, outputFields)
@@ -535,6 +546,14 @@ func TestTranslateOutputFields(t *testing.T) {
assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName, float16VectorFieldName, bfloat16VectorFieldName}, userOutputFields)
assert.ElementsMatch(t, []string{}, userDynamicFields)

// sparse_float_vector is a BM25 function output field, so it should not be included in the output fields
_, _, _, err = translateOutputFields([]string{"*", sparseFloatVectorFieldName}, schema, false)
assert.Error(t, err)
_, _, _, err = translateOutputFields([]string{sparseFloatVectorFieldName}, schema, false)
assert.Error(t, err)
_, _, _, err = translateOutputFields([]string{sparseFloatVectorFieldName}, schema, true)
assert.Error(t, err)

//=========================================================================
outputFields, userOutputFields, userDynamicFields, err = translateOutputFields([]string{}, schema, true)
assert.Equal(t, nil, err)
@@ -578,7 +597,7 @@ func TestTranslateOutputFields(t *testing.T) {
assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName, float16VectorFieldName, bfloat16VectorFieldName}, userOutputFields)
assert.ElementsMatch(t, []string{}, userDynamicFields)

outputFields, userOutputFields, userDynamicFields, err = translateOutputFields([]string{"A"}, schema, true)
_, _, _, err = translateOutputFields([]string{"A"}, schema, true)
assert.Error(t, err)

t.Run("enable dynamic schema", func(t *testing.T) {
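For context on the schema wiring exercised by this test: the sparse vector field is registered with IsFunctionOutput and listed in a BM25 function's OutputFieldNames, while the test omits the rest of the function declaration for brevity. The sketch below is illustrative only; the collection name, field names, field IDs, and the VARCHAR input wiring are invented here (they are not part of the test above), and it assumes the schemapb Go package used elsewhere in this diff.

```go
package main

import (
	"fmt"

	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)

// exampleSchema builds an illustrative collection: a VARCHAR "text" field feeds
// a BM25 function whose output is the sparse vector field; that output field is
// what this change stops users from selecting in output_fields.
func exampleSchema() *schemapb.CollectionSchema {
	return &schemapb.CollectionSchema{
		Name: "docs",
		Fields: []*schemapb.FieldSchema{
			{FieldID: 100, Name: "id", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
			{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
			// Produced by the BM25 function; raw data of this field can no longer be requested.
			{FieldID: 102, Name: "text_sparse_emb", DataType: schemapb.DataType_SparseFloatVector, IsFunctionOutput: true},
		},
		Functions: []*schemapb.FunctionSchema{
			{
				Name:             "bm25",
				Type:             schemapb.FunctionType_BM25,
				InputFieldNames:  []string{"text"},
				OutputFieldNames: []string{"text_sparse_emb"},
			},
		},
	}
}

func main() {
	fmt.Println(exampleSchema().GetName())
}
```

With a schema like this, asking for the sparse output field in output_fields is rejected, while "*" simply omits it, as the assertions above check.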
17 changes: 10 additions & 7 deletions internal/proxy/util.go
@@ -1204,7 +1204,7 @@ func translatePkOutputFields(schema *schemapb.CollectionSchema) ([]string, []int
func translateOutputFields(outputFields []string, schema *schemaInfo, addPrimary bool) ([]string, []string, []string, error) {
var primaryFieldName string
var dynamicField *schemapb.FieldSchema
allFieldNameMap := make(map[string]int64)
allFieldNameMap := make(map[string]*schemapb.FieldSchema)
resultFieldNameMap := make(map[string]bool)
resultFieldNames := make([]string, 0)
userOutputFieldsMap := make(map[string]bool)
@@ -1219,23 +1219,26 @@ func translateOutputFields(outputFields []string, schema *schemaInfo, addPrimary
if field.IsDynamic {
dynamicField = field
}
allFieldNameMap[field.Name] = field.GetFieldID()
allFieldNameMap[field.Name] = field
}

for _, outputFieldName := range outputFields {
outputFieldName = strings.TrimSpace(outputFieldName)
if outputFieldName == "*" {
for fieldName, fieldID := range allFieldNameMap {
// skip Cold field
if schema.IsFieldLoaded(fieldID) {
for fieldName, field := range allFieldNameMap {
// skip Cold field and fields that can't be output
if schema.IsFieldLoaded(field.GetFieldID()) && schema.CanRetrieveRawFieldData(field) {
resultFieldNameMap[fieldName] = true
userOutputFieldsMap[fieldName] = true
}
}
useAllDyncamicFields = true
} else {
if fieldID, ok := allFieldNameMap[outputFieldName]; ok {
if schema.IsFieldLoaded(fieldID) {
if field, ok := allFieldNameMap[outputFieldName]; ok {
if !schema.CanRetrieveRawFieldData(field) {
return nil, nil, nil, fmt.Errorf("not allowed to retrieve raw data of field %s", outputFieldName)
}
if schema.IsFieldLoaded(field.GetFieldID()) {
resultFieldNameMap[outputFieldName] = true
userOutputFieldsMap[outputFieldName] = true
} else {
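The reworked translateOutputFields keeps the full *schemapb.FieldSchema in allFieldNameMap so it can consult CanRetrieveRawFieldData on both paths: a "*" wildcard silently drops fields whose raw data cannot be returned, while naming such a field explicitly fails the whole request. A minimal, self-contained sketch of that gating, with simplified stand-in types rather than the actual proxy code:

```go
package main

import "fmt"

// Simplified stand-in for schemapb.FieldSchema plus the schemaInfo helpers;
// this only mirrors the gating logic, not the real implementation.
type field struct {
	name           string
	loaded         bool // IsFieldLoaded
	rawRetrievable bool // CanRetrieveRawFieldData: false for a BM25 output field
}

func expandOutputFields(requested []string, all map[string]field) ([]string, error) {
	result := make([]string, 0, len(all))
	for _, name := range requested {
		if name == "*" {
			// wildcard: skip cold fields and fields whose raw data can't be output
			for _, f := range all {
				if f.loaded && f.rawRetrievable {
					result = append(result, f.name)
				}
			}
			continue
		}
		f, ok := all[name]
		if !ok {
			return nil, fmt.Errorf("field %s not exist", name)
		}
		// explicit request for a blocked field is an error, mirroring
		// "not allowed to retrieve raw data of field %s"
		if !f.rawRetrievable {
			return nil, fmt.Errorf("not allowed to retrieve raw data of field %s", name)
		}
		result = append(result, f.name)
	}
	return result, nil
}

func main() {
	all := map[string]field{
		"id":              {name: "id", loaded: true, rawRetrievable: true},
		"text":            {name: "text", loaded: true, rawRetrievable: true},
		"text_sparse_emb": {name: "text_sparse_emb", loaded: true, rawRetrievable: false},
	}
	fmt.Println(expandOutputFields([]string{"*"}, all))               // text_sparse_emb omitted
	fmt.Println(expandOutputFields([]string{"text_sparse_emb"}, all)) // error
}
```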
14 changes: 14 additions & 0 deletions pkg/util/typeutil/schema.go
@@ -444,6 +444,20 @@ func (helper *SchemaHelper) GetFunctionByOutputField(field *schemapb.FieldSchema
return nil, fmt.Errorf("function not exist")
}

// As of now, only BM25 function output field is not supported to retrieve raw field data
func (helper *SchemaHelper) CanRetrieveRawFieldData(field *schemapb.FieldSchema) bool {
if !field.IsFunctionOutput {
return true
}

f, err := helper.GetFunctionByOutputField(field)
if err != nil {
return false
}

return f.GetType() != schemapb.FunctionType_BM25
}

func (helper *SchemaHelper) GetCollectionName() string {
return helper.schema.Name
}
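CanRetrieveRawFieldData only blocks a field when it is a function output and the producing function is BM25; ordinary fields stay retrievable, and an output field whose function cannot be resolved is conservatively refused. A minimal sketch of the same decision, again with hypothetical stand-in types rather than the schemapb messages:

```go
package main

import "fmt"

// Hypothetical stand-ins for schemapb.FunctionSchema / schemapb.FieldSchema.
type functionSchema struct {
	typ          string // e.g. "BM25"
	outputFields []string
}

type fieldSchema struct {
	name             string
	isFunctionOutput bool
}

func canRetrieveRawFieldData(f fieldSchema, functions []functionSchema) bool {
	if !f.isFunctionOutput {
		return true // ordinary fields are never blocked
	}
	for _, fn := range functions {
		for _, out := range fn.outputFields {
			if out == f.name {
				// as of this change, only BM25 output fields are blocked
				return fn.typ != "BM25"
			}
		}
	}
	// producing function not found: refuse, matching the error path above
	return false
}

func main() {
	fns := []functionSchema{{typ: "BM25", outputFields: []string{"text_sparse_emb"}}}
	fmt.Println(canRetrieveRawFieldData(fieldSchema{name: "text"}, fns))                                    // true
	fmt.Println(canRetrieveRawFieldData(fieldSchema{name: "text_sparse_emb", isFunctionOutput: true}, fns)) // false
}
```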
24 changes: 12 additions & 12 deletions tests/python_client/testcases/test_full_text_search.py
@@ -952,7 +952,7 @@ def test_insert_for_full_text_search_with_part_of_empty_string(self, tokenizer):
# query with expr
res, _ = collection_w.query(
expr="id >= 0",
output_fields=["text_sparse_emb", "text"]
output_fields=["text"]
)
assert len(res) == len(data)

@@ -965,7 +965,7 @@ def test_insert_for_full_text_search_with_part_of_empty_string(self, tokenizer):
anns_field="text_sparse_emb",
param={},
limit=limit,
output_fields=["id", "text", "text_sparse_emb"])
output_fields=["id", "text"])
assert len(res_list) == nq
for i in range(nq):
assert len(res_list[i]) == limit
@@ -1536,7 +1536,7 @@ def test_delete_for_full_text_search(self, tokenizer):
anns_field="text_sparse_emb",
param={},
limit=100,
output_fields=["id", "text", "text_sparse_emb"])
output_fields=["id", "text"])
for i in range(len(res_list)):
query_text = search_data[i]
result_texts = [r.text for r in res_list[i]]
@@ -2262,7 +2262,7 @@ def test_full_text_search_default(
param={},
limit=limit + offset,
offset=0,
output_fields=["id", "text", "text_sparse_emb"])
output_fields=["id", "text"])
full_res_id_list = []
for i in range(nq):
res = full_res_list[i]
@@ -2278,7 +2278,7 @@ def test_full_text_search_default(
param={},
limit=limit,
offset=offset,
output_fields=["id", "text", "text_sparse_emb"])
output_fields=["id", "text"])

# verify correctness
for i in range(nq):
@@ -2462,7 +2462,7 @@ def test_full_text_search_with_jieba_tokenizer(
param={},
limit=limit + offset,
offset=0,
output_fields=["id", "text", "text_sparse_emb"])
output_fields=["id", "text"])
full_res_id_list = []
for i in range(nq):
res = full_res_list[i]
@@ -2478,7 +2478,7 @@ def test_full_text_search_with_jieba_tokenizer(
param={},
limit=limit,
offset=offset,
output_fields=["id", "text", "text_sparse_emb"])
output_fields=["id", "text"])

# verify correctness
for i in range(nq):
@@ -2637,7 +2637,7 @@ def test_full_text_search_with_range_search(
param={
},
limit=limit, # get a wider range of search result
output_fields=["id", "text", "text_sparse_emb"])
output_fields=["id", "text"])

distance_list = []
for i in range(nq):
@@ -2660,7 +2660,7 @@ def test_full_text_search_with_range_search(
}
},
limit=limit,
output_fields=["id", "text", "text_sparse_emb"])
output_fields=["id", "text"])
# verify correctness
for i in range(nq):
log.info(f"res: {len(res_list[i])}")
@@ -2804,7 +2804,7 @@ def test_full_text_search_with_search_iterator(
param={
"metric_type": "BM25",
},
output_fields=["id", "text", "text_sparse_emb"],
output_fields=["id", "text"],
limit=limit
)
iter_result = []
@@ -2948,7 +2948,7 @@ def test_search_for_full_text_search_with_empty_string_search_data(
anns_field="text_sparse_emb",
param={},
limit=limit,
output_fields=["id", "text", "text_sparse_emb"],
output_fields=["id", "text"],
)
assert len(res) == nq
for r in res:
@@ -3089,7 +3089,7 @@ def test_search_for_full_text_search_with_invalid_search_data(
anns_field="text_sparse_emb",
param={},
limit=limit,
output_fields=["id", "text", "text_sparse_emb"],
output_fields=["id", "text"],
check_task=CheckTasks.err_res,
check_items=error
)
