Skip to content

Commit

Permalink
update Function interface to use input/output_field_names
Browse files Browse the repository at this point in the history
Signed-off-by: Buqian Zheng <[email protected]>
  • Loading branch information
zhengbuqian committed Sep 19, 2024
1 parent 26cc8dd commit f70c29b
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 48 deletions.
14 changes: 6 additions & 8 deletions examples/hello_bm25.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,8 @@
bm25_function = Function(
name="bm25",
function_type=FunctionType.BM25,
inputs=["document"],
outputs=["sparse"],
params={"bm25_k1": 1.2, "bm25_b": 0.75},
input_field_names=["document"],
output_field_names="sparse",
)

schema = CollectionSchema(fields, "hello_bm25 demo")
Expand Down Expand Up @@ -95,13 +94,12 @@

################################################################################
# 4. create index
# We are going to create an SPARSE_INVERTED_INDEX index for hello_bm25 collection.
# create_index() can only be applied to `FloatVector` and `BinaryVector` fields.
print(fmt.format("Start Creating index SPARSE_INVERTED_INDEX"))
# We are going to create an index for the hello_bm25 collection; here we simply
# use AUTOINDEX so Milvus can use the default parameters.
print(fmt.format("Start Creating index AUTOINDEX"))
index = {
"index_type": "SPARSE_INVERTED_INDEX",
"index_type": "AUTOINDEX",
"metric_type": "BM25",
'params': {"bm25_k1": 1.2, "bm25_b": 0.75},
}

hello_bm25.create_index("sparse", index)
Expand Down
5 changes: 2 additions & 3 deletions examples/hello_hybrid_bm25.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,8 @@ def random_embedding(texts):
Function(
name="bm25",
function_type=FunctionType.BM25,
inputs=["text"],
outputs=["sparse_vector"],
params={"bm25_k1": 1.2, "bm25_b": 0.75},
input_field_names=["text"],
output_field_names="sparse_vector",
)
]
schema = CollectionSchema(fields, "", functions=functions)
Expand Down
5 changes: 2 additions & 3 deletions examples/milvus_client/bm25.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,9 @@

bm25_function = Function(
name="bm25_fn",
inputs=["document_content"],
outputs=["sparse_vector"],
input_field_names=["document_content"],
output_field_names="sparse_vector",
function_type=FunctionType.BM25,
params={"bm25_k1": 1.2, "bm25_b": 0.75},
)
schema.add_function(bm25_function)

Expand Down
1 change: 0 additions & 1 deletion pymilvus/orm/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
# or implied. See the License for the specific language governing permissions and limitations under
# the License.

import copy
from typing import List, Tuple, Union

import numpy as np
Expand Down
69 changes: 37 additions & 32 deletions pymilvus/orm/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import pandas as pd
from pandas.api.types import is_list_like, is_scalar

from pymilvus.client.types import FunctionType
from pymilvus.exceptions import (
AutoIDException,
CannotInferSchemaException,
Expand All @@ -24,19 +25,18 @@
DataTypeNotSupportException,
ExceptionsMessage,
FieldsTypeException,
FunctionsTypeException,
FieldTypeException,
FunctionsTypeException,
ParamError,
PartitionKeyException,
PrimaryKeyException,
SchemaNotReadyException,
)
from pymilvus.grpc_gen import schema_pb2 as schema_types

from .constants import COMMON_TYPE_PARAMS, BM25_k1, BM25_b
from .constants import COMMON_TYPE_PARAMS
from .types import (
DataType,
FunctionType,
infer_dtype_by_scalar_data,
infer_dtype_bydata,
map_numpy_dtype_to_datatype,
Expand Down Expand Up @@ -87,7 +87,9 @@ def validate_clustering_key(clustering_key_field_name: Any, clustering_key_field


class CollectionSchema:
def __init__(self, fields: List, description: str = "", functions: List = [], **kwargs):
def __init__(
self, fields: List, description: str = "", functions: Optional[List] = None, **kwargs
):
self._kwargs = copy.deepcopy(kwargs)
self._fields = []
self._description = description
Expand All @@ -97,6 +99,9 @@ def __init__(self, fields: List, description: str = "", functions: List = [], **
self._partition_key_field = None
self._clustering_key_field = None

if functions is None:
functions = []

if not isinstance(functions, list):
raise FunctionsTypeException(message=ExceptionsMessage.FunctionsType)
for function in functions:
Expand Down Expand Up @@ -352,7 +357,7 @@ def add_field(self, field_name: str, datatype: DataType, **kwargs):
self._mark_output_fields()
return self

def add_function(self, function):
def add_function(self, function: "Function"):
if not isinstance(function, Function):
raise ParamError(message=ExceptionsMessage.FunctionIncorrectType)
self._functions.append(function)
Expand Down Expand Up @@ -459,7 +464,6 @@ def to_dict(self):
"name": self.name,
"description": self._description,
"type": self.dtype,
"is_function_output": self.is_function_output,
}
if self._type_params:
_dict["params"] = copy.deepcopy(self.params)
Expand All @@ -480,6 +484,8 @@ def to_dict(self):
_dict["element_type"] = self.element_type
if self.is_clustering_key:
_dict["is_clustering_key"] = True
if self.is_function_output:
_dict["is_function_output"] = True
return _dict

def __getattr__(self, item: str):
Expand Down Expand Up @@ -537,36 +543,38 @@ def __init__(
self,
name: str,
function_type: FunctionType,
inputs: List[str],
outputs: List[str],
input_field_names: Union[str, List[str]],
output_field_names: Union[str, List[str]],
description: str = "",
params: Dict = {},
params: Optional[Dict] = None,
):
self._name = name
self._description = description
input_field_names = (
[input_field_names] if isinstance(input_field_names, str) else input_field_names
)
output_field_names = (
[output_field_names] if isinstance(output_field_names, str) else output_field_names
)
try:
self._type = FunctionType(function_type)
except ValueError:
raise ParamError(message=ExceptionsMessage.UnknownFunctionType)
except ValueError as err:
raise ParamError(message=ExceptionsMessage.UnknownFunctionType) from err

for field_name in list(inputs) + list(outputs):
for field_name in list(input_field_names) + list(output_field_names):
if not isinstance(field_name, str):
raise ParamError(message=ExceptionsMessage.FunctionIncorrectInputOutputType)
if len(inputs) != len(set(inputs)):
if len(input_field_names) != len(set(input_field_names)):
raise ParamError(message=ExceptionsMessage.FunctionDuplicateInputs)
if len(outputs) != len(set(outputs)):
if len(output_field_names) != len(set(output_field_names)):
raise ParamError(message=ExceptionsMessage.FunctionDuplicateOutputs)

if set(inputs) & set(outputs):
if set(input_field_names) & set(output_field_names):
raise ParamError(message=ExceptionsMessage.FunctionCommonInputOutput)

self._input_field_names = inputs
self._output_field_names = outputs
if BM25_k1 in params:
params[BM25_k1] = str(params[BM25_k1])
if BM25_b in params:
params[BM25_b] = str(params[BM25_b])
self._params = params
self._input_field_names = input_field_names
self._output_field_names = output_field_names
self._params = params if params is not None else {}

@property
def name(self):
Expand Down Expand Up @@ -598,16 +606,13 @@ def verify(self, schema: CollectionSchema):
raise ParamError(message=ExceptionsMessage.BM25FunctionIncorrectInputOutputCount)

for field in schema.fields:
if field.name == self._input_field_names[0]:
if field.dtype != DataType.VARCHAR:
raise ParamError(
message=ExceptionsMessage.BM25FunctionIncorrectInputFieldType
)
if field.name == self._output_field_names[0]:
if field.dtype != DataType.SPARSE_FLOAT_VECTOR:
raise ParamError(
message=ExceptionsMessage.BM25FunctionIncorrectOutputFieldType
)
if field.name == self._input_field_names[0] and field.dtype != DataType.VARCHAR:
raise ParamError(message=ExceptionsMessage.BM25FunctionIncorrectInputFieldType)
if (
field.name == self._output_field_names[0]
and field.dtype != DataType.SPARSE_FLOAT_VECTOR
):
raise ParamError(message=ExceptionsMessage.BM25FunctionIncorrectOutputFieldType)

elif self._type == FunctionType.UNKNOWN:
raise ParamError(message=ExceptionsMessage.UnknownFunctionType)
Expand Down
2 changes: 1 addition & 1 deletion pymilvus/orm/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
is_scalar,
)

from pymilvus.client.types import DataType, FunctionType
from pymilvus.client.types import DataType

dtype_str_map = {
"string": DataType.VARCHAR,
Expand Down

0 comments on commit f70c29b

Please sign in to comment.