Skip to content

Commit

Permalink
remove scipy dependency for sparse while still supporting it
Browse files Browse the repository at this point in the history
Signed-off-by: Buqian Zheng <[email protected]>
  • Loading branch information
zhengbuqian committed May 8, 2024
1 parent b771a87 commit 8389df2
Show file tree
Hide file tree
Showing 12 changed files with 157 additions and 114 deletions.
17 changes: 12 additions & 5 deletions examples/hello_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import time

import numpy as np
from scipy.sparse import rand
import random
from pymilvus import (
connections,
utility,
Expand All @@ -20,7 +20,9 @@

fmt = "=== {:30} ==="
search_latency_fmt = "search latency = {:.4f}s"
num_entities, dim, density = 1000, 3000, 0.005
num_entities, dim = 1000, 3000
# non zero count of randomly generated sparse vectors
nnz = 30

def log(msg):
print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + " " + msg)
Expand Down Expand Up @@ -54,11 +56,16 @@ def log(msg):
# insert
log(fmt.format("Start creating entities to insert"))
rng = np.random.default_rng(seed=19530)
# this step is so damn slow
matrix_csr = rand(num_entities, dim, density=density, format='csr')

def generate_sparse_vector(dimension: int, non_zero_count: int) -> dict:
indices = random.sample(range(dimension), non_zero_count)
values = [random.random() for _ in range(non_zero_count)]
sparse_vector = {index: value for index, value in zip(indices, values)}
return sparse_vector

entities = [
rng.random(num_entities).tolist(),
matrix_csr,
[generate_sparse_vector(dim, nnz) for _ in range(num_entities)],
]

log(fmt.format("Start inserting entities"))
Expand Down
4 changes: 2 additions & 2 deletions pymilvus/client/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pymilvus.grpc_gen import schema_pb2
from pymilvus.settings import Config

from . import entity_helper
from . import entity_helper, utils
from .constants import DEFAULT_CONSISTENCY_LEVEL, RANKER_TYPE_RRF, RANKER_TYPE_WEIGHTED
from .types import DataType

Expand Down Expand Up @@ -327,7 +327,7 @@ def dict(self):
class AnnSearchRequest:
def __init__(
self,
data: Union[List, entity_helper.SparseMatrixInputType],
data: Union[List, utils.SparseMatrixInputType],
anns_field: str,
param: Dict,
limit: int,
Expand Down
111 changes: 26 additions & 85 deletions pymilvus/client/entity_helper.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import math
import struct
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional

import numpy as np
import ujson
from scipy import sparse

from pymilvus.exceptions import (
DataNotMatchException,
Expand All @@ -16,67 +15,13 @@
from pymilvus.settings import Config

from .types import DataType
from .utils import SciPyHelper, SparseMatrixInputType, SparseRowOutputType

CHECK_STR_ARRAY = True

# in search results, if output fields includes a sparse float vector field, we
# will return a SparseRowOutputType for each entity. Using Dict for readability.
# TODO(SPARSE): to allow the user to specify output format.
SparseRowOutputType = Dict[int, float]

# we accept the following types as input for sparse matrix in user facing APIs
# such as insert, search, etc.:
# - scipy sparse array/matrix family: csr, csc, coo, bsr, dia, dok, lil
# - iterable of iterables, each element(iterable) is a sparse vector with index
# as key and value as float.
# dict example: [{2: 0.33, 98: 0.72, ...}, {4: 0.45, 198: 0.52, ...}, ...]
# list of tuple example: [[(2, 0.33), (98, 0.72), ...], [(4, 0.45), ...], ...]
# both index/value can be str numbers: {'2': '3.1'}
SparseMatrixInputType = Union[
Iterable[
Union[
SparseRowOutputType,
Iterable[Tuple[int, float]], # only type hint, we accept int/float like types
]
],
sparse.csc_array,
sparse.coo_array,
sparse.bsr_array,
sparse.dia_array,
sparse.dok_array,
sparse.lil_array,
sparse.csr_array,
sparse.spmatrix,
]


def sparse_is_scipy_matrix(data: Any):
return isinstance(data, sparse.spmatrix)


def sparse_is_scipy_array(data: Any):
# sparse.sparray, the common superclass of sparse.*_array, is introduced in
# scipy 1.11.0, which requires python 3.9, higher than pymilvus's current requirement.
return isinstance(
data,
(
sparse.bsr_array,
sparse.coo_array,
sparse.csc_array,
sparse.csr_array,
sparse.dia_array,
sparse.dok_array,
sparse.lil_array,
),
)


def sparse_is_scipy_format(data: Any):
return sparse_is_scipy_matrix(data) or sparse_is_scipy_array(data)


def entity_is_sparse_matrix(entity: Any):
if sparse_is_scipy_format(entity):
if SciPyHelper.is_scipy_sparse(entity):
return True
try:

Expand Down Expand Up @@ -143,34 +88,30 @@ def sparse_float_row_to_bytes(indices: Iterable[int], values: Iterable[float]):
data += struct.pack("f", v)
return data

def unify_sparse_input(data: SparseMatrixInputType) -> sparse.csr_array:
if isinstance(data, sparse.csr_array):
return data
if sparse_is_scipy_array(data):
return data.tocsr()
if sparse_is_scipy_matrix(data):
return sparse.csr_array(data.tocsr())
row_indices = []
col_indices = []
values = []
for row_id, row_data in enumerate(data):
row = row_data.items() if isinstance(row_data, dict) else row_data
row_indices.extend([row_id] * len(row))
col_indices.extend(
[int(col_id) if isinstance(col_id, str) else col_id for col_id, _ in row]
)
values.extend([float(value) if isinstance(value, str) else value for _, value in row])
return sparse.csr_array((values, (row_indices, col_indices)))

if not entity_is_sparse_matrix(data):
raise ParamError(message="input must be a sparse matrix in supported format")
csr = unify_sparse_input(data)

result = schema_types.SparseFloatArray()
result.dim = csr.shape[1]
for start, end in zip(csr.indptr[:-1], csr.indptr[1:]):
result.contents.append(
sparse_float_row_to_bytes(csr.indices[start:end], csr.data[start:end])
)

if SciPyHelper.is_scipy_sparse(data):
csr = data.tocsr()
result.dim = csr.shape[1]
for start, end in zip(csr.indptr[:-1], csr.indptr[1:]):
result.contents.append(
sparse_float_row_to_bytes(csr.indices[start:end], csr.data[start:end])
)
else:
dim = 0
for _, row_data in enumerate(data):
indices = []
values = []
row = row_data.items() if isinstance(row_data, dict) else row_data
for index, value in row:
indices.append(index)
values.append(value)
result.contents.append(sparse_float_row_to_bytes(indices, values))
dim = max(dim, indices[-1] + 1)
result.dim = dim
return result


Expand All @@ -186,7 +127,7 @@ def sparse_proto_to_rows(


def get_input_num_rows(entity: Any) -> int:
if sparse_is_scipy_format(entity):
if SciPyHelper.is_scipy_sparse(entity):
return entity.shape[0]
return len(entity)

Expand Down Expand Up @@ -354,7 +295,7 @@ def pack_field_value_to_field_data(
field_data.vectors.bfloat16_vector += v_bytes
elif field_type == DataType.SPARSE_FLOAT_VECTOR:
# field_value is a single row of sparse float vector in user provided format
if not sparse_is_scipy_format(field_value):
if not SciPyHelper.is_scipy_sparse(field_value):
field_value = [field_value]
elif field_value.shape[0] != 1:
raise ParamError(message="invalid input for sparse float vector: expect 1 row")
Expand Down
4 changes: 2 additions & 2 deletions pymilvus/client/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from pymilvus.grpc_gen import milvus_pb2 as milvus_types
from pymilvus.settings import Config

from . import entity_helper, interceptor, ts_utils
from . import entity_helper, interceptor, ts_utils, utils
from .abstract import AnnSearchRequest, BaseRanker, CollectionSchema, MutationResult, SearchResult
from .asynch import (
CreateIndexFuture,
Expand Down Expand Up @@ -761,7 +761,7 @@ def _execute_hybrid_search(
def search(
self,
collection_name: str,
data: Union[List[List[float]], entity_helper.SparseMatrixInputType],
data: Union[List[List[float]], utils.SparseMatrixInputType],
anns_field: str,
param: Dict,
limit: int,
Expand Down
5 changes: 2 additions & 3 deletions pymilvus/client/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@

import numpy as np

from pymilvus.client import __version__, entity_helper
from pymilvus.exceptions import DataNotMatchException, ExceptionsMessage, ParamError
from pymilvus.grpc_gen import common_pb2 as common_types
from pymilvus.grpc_gen import milvus_pb2 as milvus_types
from pymilvus.grpc_gen import schema_pb2 as schema_types
from pymilvus.orm.schema import CollectionSchema

from . import blob, ts_utils, utils
from . import __version__, blob, entity_helper, ts_utils, utils
from .check import check_pass_param, is_legal_collection_properties
from .constants import (
DEFAULT_CONSISTENCY_LEVEL,
Expand Down Expand Up @@ -626,7 +625,7 @@ def _prepare_placeholder_str(cls, data: Any):
def search_requests_with_expr(
cls,
collection_name: str,
data: Union[List, entity_helper.SparseMatrixInputType],
data: Union[List, utils.SparseMatrixInputType],
anns_field: str,
param: Dict,
limit: int,
Expand Down
100 changes: 99 additions & 1 deletion pymilvus/client/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
import importlib.util
from datetime import timedelta
from typing import Any, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union

import ujson

Expand Down Expand Up @@ -270,3 +271,100 @@ def get_server_type(host: str):

def dumps(v: Union[dict, str]) -> str:
return ujson.dumps(v) if isinstance(v, dict) else str(v)


class SciPyHelper:
_checked = False

# whether scipy.sparse.*_matrix classes exists
_matrix_available = False
# whether scipy.sparse.*_array classes exists
_array_available = False

@classmethod
def _init(cls):
if cls._checked:
return
scipy_spec = importlib.util.find_spec("scipy")
if scipy_spec is not None:
# when scipy is not installed, find_spec("scipy.sparse") directly
# throws exception instead of returning None.
sparse_spec = importlib.util.find_spec("scipy.sparse")
if sparse_spec is not None:
scipy_sparse = importlib.util.module_from_spec(sparse_spec)
sparse_spec.loader.exec_module(scipy_sparse)
# all scipy.sparse.*_matrix classes are introduced in the same scipy
# version, so we only need to check one of them.
cls._matrix_available = hasattr(scipy_sparse, "csr_matrix")
# all scipy.sparse.*_array classes are introduced in the same scipy
# version, so we only need to check one of them.
cls._array_available = hasattr(scipy_sparse, "csr_array")

cls._checked = True

@classmethod
def is_spmatrix(cls, data: Any):
cls._init()
if not cls._matrix_available:
return False
from scipy.sparse import isspmatrix

return isspmatrix(data)

@classmethod
def is_sparray(cls, data: Any):
cls._init()
if not cls._array_available:
return False
from scipy.sparse import issparse, isspmatrix

return issparse(data) and not isspmatrix(data)

@classmethod
def is_scipy_sparse(cls, data: Any):
return cls.is_spmatrix(data) or cls.is_sparray(data)


# in search results, if output fields includes a sparse float vector field, we
# will return a SparseRowOutputType for each entity. Using Dict for readability.
# TODO(SPARSE): to allow the user to specify output format.
SparseRowOutputType = Dict[int, float]


# this import will be called only during static type checking
if TYPE_CHECKING:
from scipy.sparse import (
bsr_array,
coo_array,
csc_array,
csr_array,
dia_array,
dok_array,
lil_array,
spmatrix,
)

# we accept the following types as input for sparse matrix in user facing APIs
# such as insert, search, etc.:
# - scipy sparse array/matrix family: csr, csc, coo, bsr, dia, dok, lil
# - iterable of iterables, each element(iterable) is a sparse vector with index
# as key and value as float.
# dict example: [{2: 0.33, 98: 0.72, ...}, {4: 0.45, 198: 0.52, ...}, ...]
# list of tuple example: [[(2, 0.33), (98, 0.72), ...], [(4, 0.45), ...], ...]
# both index/value can be str numbers: {'2': '3.1'}
SparseMatrixInputType = Union[
Iterable[
Union[
SparseRowOutputType,
Iterable[Tuple[int, float]], # only type hint, we accept int/float like types
]
],
"csc_array",
"coo_array",
"bsr_array",
"dia_array",
"dok_array",
"lil_array",
"csr_array",
"spmatrix",
]
Loading

0 comments on commit 8389df2

Please sign in to comment.