From dd5ef981085fe24278aca50002619c94ea394676 Mon Sep 17 00:00:00 2001 From: yangxuan Date: Tue, 27 Jun 2023 14:52:10 +0800 Subject: [PATCH] Change pylint to ruff and black Signed-off-by: yangxuan --- .github/mergify.yml | 9 +- .github/workflows/check_milvus_proto.yml | 2 +- .github/workflows/code_checker.yml | 2 +- .github/workflows/nightly_ci.yml | 2 +- .github/workflows/pull_request.yml | 2 +- Makefile | 9 +- README.md | 22 +- pylint.conf | 612 -------------- pymilvus/__init__.py | 156 ++-- pymilvus/client/__init__.py | 41 +- pymilvus/client/abstract.py | 559 +++---------- pymilvus/client/asynch.py | 99 ++- pymilvus/client/blob.py | 35 +- pymilvus/client/check.py | 101 ++- pymilvus/client/constants.py | 2 +- pymilvus/client/entity_helper.py | 202 ++--- pymilvus/client/grpc_handler.py | 969 ++++++++++++++++------- pymilvus/client/interceptor.py | 81 +- pymilvus/client/prepare.py | 845 ++++++++++++-------- pymilvus/client/singleton_utils.py | 3 +- pymilvus/client/stub.py | 222 ++++-- pymilvus/client/ts_utils.py | 34 +- pymilvus/client/types.py | 215 +++-- pymilvus/client/utils.py | 130 +-- pymilvus/decorators.py | 120 ++- pymilvus/exceptions.py | 72 +- pymilvus/grpc_gen/__init__.py | 12 +- pymilvus/milvus_client/milvus_client.py | 266 ++++--- pymilvus/orm/collection.py | 694 ++++++++++------ pymilvus/orm/connections.py | 141 ++-- pymilvus/orm/constants.py | 2 +- pymilvus/orm/db.py | 22 +- pymilvus/orm/future.py | 24 +- pymilvus/orm/index.py | 166 ++-- pymilvus/orm/iterator.py | 201 +++-- pymilvus/orm/mutation.py | 6 +- pymilvus/orm/partition.py | 248 ++++-- pymilvus/orm/prepare.py | 95 ++- pymilvus/orm/role.py | 73 +- pymilvus/orm/schema.py | 266 +++---- pymilvus/orm/search.py | 50 +- pymilvus/orm/types.py | 34 +- pymilvus/orm/utility.py | 465 +++++++---- pymilvus/settings.py | 75 +- pyproject.toml | 121 +++ requirements.txt | 3 +- test_requirements.txt | 3 +- tests/test_schema.py | 2 +- tests/test_types.py | 3 +- 49 files changed, 4004 insertions(+), 3514 deletions(-) delete mode 100644 pylint.conf diff --git a/.github/mergify.yml b/.github/mergify.yml index 3028c252b..4fc75cbb6 100644 --- a/.github/mergify.yml +++ b/.github/mergify.yml @@ -4,9 +4,12 @@ pull_request_rules: - or: - base=master - base~=2\.\d - - "status-success=Run Python Tests (3.8)" - - "status-success=Run Check Proto (3.8)" - - "status-success=Code lint check (3.8)" + - "status-success=Run Python Tests (3.7)" + - "status-success=Run Check Proto (3.7)" + - "status-success=Code lint check (3.7)" + - "status-success=Run Python Tests (3.11)" + - "status-success=Run Check Proto (3.11)" + - "status-success=Code lint check (3.11)" actions: label: add: diff --git a/.github/workflows/check_milvus_proto.yml b/.github/workflows/check_milvus_proto.yml index 74c3b11d3..870af92dd 100644 --- a/.github/workflows/check_milvus_proto.yml +++ b/.github/workflows/check_milvus_proto.yml @@ -11,7 +11,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.8] + python-version: [3.7, 3.11] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} diff --git a/.github/workflows/code_checker.yml b/.github/workflows/code_checker.yml index 96e38c3b7..ca6aff38e 100644 --- a/.github/workflows/code_checker.yml +++ b/.github/workflows/code_checker.yml @@ -11,7 +11,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.8, 3.11] + python-version: [3.7, 3.11] steps: - name: Checkout code uses: actions/checkout@v2 diff --git a/.github/workflows/nightly_ci.yml b/.github/workflows/nightly_ci.yml index 9498f33d9..ed0d2d42e 100644 --- a/.github/workflows/nightly_ci.yml +++ b/.github/workflows/nightly_ci.yml @@ -20,7 +20,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.6] + python-version: [3.7] env: IMAGE_REPO: "milvusdb" TAG_PREFIX: "master-" diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml index ac6a31017..d9ce774ed 100644 --- a/.github/workflows/pull_request.yml +++ b/.github/workflows/pull_request.yml @@ -11,7 +11,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.8, 3.11] + python-version: [3.7, 3.11] steps: - name: Checkout code diff --git a/Makefile b/Makefile index 332c476b2..8fcea583c 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,12 @@ unittest: PYTHONPATH=`pwd` python3 -m pytest tests --cov=pymilvus -v lint: - PYTHONPATH=`pwd` pylint --rcfile=pylint.conf pymilvus + PYTHONPATH=`pwd` black pymilvus --check + PYTHONPATH=`pwd` ruff check pymilvus + +format: + PYTHONPATH=`pwd` black pymilvus + PYTHONPATH=`pwd` ruff check pymilvus --fix codecov: PYTHONPATH=`pwd` pytest --cov=pymilvus --cov-report=xml tests -x -v -rxXs @@ -24,7 +29,7 @@ gen_proto: check_proto_product: gen_proto ./check_proto_product.sh - + version: python -m setuptools_scm diff --git a/README.md b/README.md index cc18c8054..18bd5271a 100644 --- a/README.md +++ b/README.md @@ -21,12 +21,12 @@ The following collection shows Milvus versions and recommended PyMilvus versions | 1.1.\* | 1.1.2 | | 2.0.\* | 2.0.2 | | 2.1.\* | 2.1.3 | -| 2.2.\* | 2.2.0 | +| 2.2.\* | 2.2.13 | ## Installation -You can install PyMilvus via `pip` or `pip3` for Python 3.6+: +You can install PyMilvus via `pip` or `pip3` for Python 3.7+: ```shell $ pip3 install pymilvus @@ -35,7 +35,7 @@ $ pip3 install pymilvus You can install a specific version of PyMilvus by: ```shell -$ pip3 install pymilvus==2.2.0 +$ pip3 install pymilvus==2.2.13 ``` You can upgrade PyMilvus to the latest version by: @@ -66,7 +66,21 @@ Q3. How to use the local PyMilvus repository for Milvus server? A3. ```shell -$ python setup.py install +$ make install +``` + +Q4. How to check coding styles? + +A4. +```shell +make lint +``` + +Q5. How to fix the coding styles? + +Q5 +```shell +make format ``` diff --git a/pylint.conf b/pylint.conf deleted file mode 100644 index 771fe3c20..000000000 --- a/pylint.conf +++ /dev/null @@ -1,612 +0,0 @@ -[MASTER] - -# A comma-separated list of package or module names from where C extensions may -# be loaded. Extensions are loading into the active Python interpreter and may -# run arbitrary code. -extension-pkg-whitelist=grpc,ujson - -# Add files or directories to the blacklist. They should be base names, not -# paths. -ignore=CVS,grpc_gen - -# Add files or directories matching the regex patterns to the blacklist. The -# regex matches against base names, not paths. -ignore-patterns= - -# Python code to execute, usually for sys.path manipulation such as -# pygtk.require(). -init-hook="import sys;sys.path.append('..');sys.path.append('.')" - -# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the -# number of processors available to use. -jobs=1 - -# Control the amount of potential inferred values when inferring a single -# object. This can help the performance when dealing with large functions or -# complex, nested conditions. -limit-inference-results=100 - -# List of plugins (as comma separated values of python module names) to load, -# usually to register additional checkers. -load-plugins= - -# Pickle collected data for later comparisons. -persistent=no - -# Specify a configuration file. -#rcfile= - -# When enabled, pylint would attempt to guess common misconfiguration and emit -# user-friendly hints instead of false-positive error messages. -suggestion-mode=yes - -# Allow loading of arbitrary C extensions. Extensions are imported into the -# active Python interpreter and may run arbitrary code. -unsafe-load-any-extension=no - - -[MESSAGES CONTROL] - -# Only show warnings with the listed confidence levels. Leave empty to show -# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED. -confidence= - -# Disable the message, report, category or checker with the given id(s). You -# can either give multiple identifiers separated by comma (,) or put this -# option multiple times (only on the command line, not in the configuration -# file where it should appear only once). You can also use "--disable=all" to -# disable everything first and then reenable specific checks. For example, if -# you want to run only the similarities checker, you can use "--disable=all -# --enable=similarities". If you want to run only the classes checker, but have -# no Warning level messages displayed, use "--disable=all --enable=classes -# --disable=W". -disable=print-statement, - parameter-unpacking, - unpacking-in-except, - old-raise-syntax, - backtick, - long-suffix, - old-ne-operator, - old-octal-literal, - import-star-module-level, - non-ascii-bytes-literal, - raw-checker-failed, - bad-inline-option, - locally-disabled, - file-ignored, - suppressed-message, - useless-suppression, - deprecated-pragma, - use-symbolic-message-instead, - apply-builtin, - basestring-builtin, - buffer-builtin, - cmp-builtin, - coerce-builtin, - execfile-builtin, - file-builtin, - long-builtin, - raw_input-builtin, - reduce-builtin, - standarderror-builtin, - unicode-builtin, - xrange-builtin, - coerce-method, - delslice-method, - getslice-method, - setslice-method, - no-absolute-import, - old-division, - dict-iter-method, - dict-view-method, - next-method-called, - metaclass-assignment, - indexing-exception, - raising-string, - reload-builtin, - oct-method, - hex-method, - nonzero-method, - cmp-method, - input-builtin, - round-builtin, - intern-builtin, - unichr-builtin, - map-builtin-not-iterating, - zip-builtin-not-iterating, - range-builtin-not-iterating, - filter-builtin-not-iterating, - using-cmp-argument, - eq-without-hash, - div-method, - idiv-method, - rdiv-method, - exception-message-attribute, - invalid-str-codec, - sys-max-int, - bad-python3-import, - deprecated-string-function, - deprecated-str-translate-call, - deprecated-itertools-function, - deprecated-types-field, - next-method-defined, - dict-items-not-iterating, - dict-keys-not-iterating, - dict-values-not-iterating, - deprecated-operator-function, - deprecated-urllib-function, - xreadlines-attribute, - deprecated-sys-function, - exception-escape, - comprehension-escape, - c-extension-no-member, - protected-access, - C0415, # Import outside toplevel (collection) (import-outside-toplevel) - C0103, # invalid-name - C0114, # missing-module-docstring - C0115, # missing-class-docstring - C0116, # missing-function-docstring - C0330, # Wrong hanging indentation before block (add 4 spaces) - R0201, # no-self-use - R0903, # too-few-public-methods - R0904, # too-many-public-methods - R0912, # too-many-branches - R0913, # too-many-arguments - R0914, # many local variables (too-many-locals) - # R1721, unnecessary-comprehension - W0511, # (fixme) - W0107, # Unnecessary pass statement (unnecessary-pass) - W0613, # unused-argument - E1101, # no-member - W1202, # logging-format-interpolation - W1203, # Use lazy % or .format() formatting in logging functions (logging-fstring-interpolation) - W0102, # TODO, dagerous-default-value - W0201, # TODO, Attribute 'compaction_id' defined outside __init__ (attribute-defined-outside-init) - W0622, # TODO, Redefining built-in 'object' (redefined-builtin) - W0703, # Catching too general exception Exception (broad-except) - R1710, # TODO, Either all return statements in a function should return an expression, or none of them should. (inconsistent-return-statements) - R0801, # TODO, Similar lines in 2 files - R0401, # TODO, Cyclic import (pymilvus.orm.collection -> pymilvus.orm.partition) (cyclic-import) - - - -# Enable the message, report, category or checker with the given id(s). You can -# either give multiple identifier separated by comma (,) or put this option -# multiple time (only on the command line, not in the configuration file where -# it should appear only once). See also the "--disable" option for examples. -enable=c-extension-no-member - - -[REPORTS] - -# Python expression which should return a score less than or equal to 10. You -# have access to the variables 'error', 'warning', 'refactor', and 'convention' -# which contain the number of messages in each category, as well as 'statement' -# which is the total number of statements analyzed. This score is used by the -# global evaluation report (RP0004). -evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) - -# Template used to display messages. This is a python new-style format string -# used to format the message information. See doc for all details. -#msg-template= - -# Set the output format. Available formats are text, parseable, colorized, json -# and msvs (visual studio). You can also give a reporter class, e.g. -# mypackage.mymodule.MyReporterClass. -output-format=text - -# Tells whether to display a full report or only the messages. -reports=no - -# Activate the evaluation score. -score=yes - - -[REFACTORING] - -# Maximum number of nested blocks for function / method body -max-nested-blocks=5 - -# Complete name of functions that never returns. When checking for -# inconsistent-return-statements if a never returning function is called then -# it will be considered as an explicit return statement and no message will be -# printed. -never-returning-functions=sys.exit - - -[STRING] - -# This flag controls whether the implicit-str-concat-in-sequence should -# generate a warning on implicit string concatenation in sequences defined over -# several lines. -check-str-concat-over-line-jumps=no - - -[MISCELLANEOUS] - -# List of note tags to take in consideration, separated by a comma. -notes=FIXME, - XXX, - TODO - - -[BASIC] - -# Naming style matching correct argument names. -argument-naming-style=snake_case - -# Regular expression matching correct argument names. Overrides argument- -# naming-style. -#argument-rgx= - -# Naming style matching correct attribute names. -attr-naming-style=snake_case - -# Regular expression matching correct attribute names. Overrides attr-naming- -# style. -#attr-rgx= - -# Bad variable names which should always be refused, separated by a comma. -bad-names=foo, - bar, - baz, - toto, - tutu, - tata - -# Naming style matching correct class attribute names. -class-attribute-naming-style=any - -# Regular expression matching correct class attribute names. Overrides class- -# attribute-naming-style. -#class-attribute-rgx= - -# Naming style matching correct class names. -class-naming-style=PascalCase - -# Regular expression matching correct class names. Overrides class-naming- -# style. -#class-rgx= - -# Naming style matching correct constant names. -const-naming-style=UPPER_CASE - -# Regular expression matching correct constant names. Overrides const-naming- -# style. -#const-rgx= - -# Minimum line length for functions/classes that require docstrings, shorter -# ones are exempt. -docstring-min-length=-1 - -# Naming style matching correct function names. -function-naming-style=snake_case - -# Regular expression matching correct function names. Overrides function- -# naming-style. -#function-rgx= - -# Good variable names which should always be accepted, separated by a comma. -good-names=i, - j, - k, - ex, - Run, - _ - -# Include a hint for the correct naming format with invalid-name. -include-naming-hint=no - -# Naming style matching correct inline iteration names. -inlinevar-naming-style=any - -# Regular expression matching correct inline iteration names. Overrides -# inlinevar-naming-style. -#inlinevar-rgx= - -# Naming style matching correct method names. -method-naming-style=snake_case - -# Regular expression matching correct method names. Overrides method-naming- -# style. -#method-rgx= - -# Naming style matching correct module names. -module-naming-style=snake_case - -# Regular expression matching correct module names. Overrides module-naming- -# style. -#module-rgx= - -# Colon-delimited sets of names that determine each other's naming style when -# the name regexes allow several styles. -name-group= - -# Regular expression which should only match function or class names that do -# not require a docstring. -no-docstring-rgx=^_ - -# List of decorators that produce properties, such as abc.abstractproperty. Add -# to this list to register other decorators that produce valid properties. -# These decorators are taken in consideration only for invalid-name. -property-classes=abc.abstractproperty - -# Naming style matching correct variable names. -variable-naming-style=snake_case - -# Regular expression matching correct variable names. Overrides variable- -# naming-style. -#variable-rgx= - - -[LOGGING] - -# Format style used to check logging format string. `old` means using % -# formatting, `new` is for `{}` formatting. -logging-format-style=new - -# Logging modules to check that the string format arguments are in logging -# function parameter format. -logging-modules=logging - - -[SIMILARITIES] - -# Ignore comments when computing similarities. -ignore-comments=yes - -# Ignore docstrings when computing similarities. -ignore-docstrings=yes - -# Ignore imports when computing similarities. -ignore-imports=no - -# Minimum lines number of a similarity. -min-similarity-lines=4 - - -[VARIABLES] - -# List of additional names supposed to be defined in builtins. Remember that -# you should avoid defining new builtins when possible. -additional-builtins= - -# Tells whether unused global variables should be treated as a violation. -allow-global-unused-variables=yes - -# List of strings which can identify a callback function by name. A callback -# name must start or end with one of those strings. -callbacks=cb_, - _cb - -# A regular expression matching the name of dummy variables (i.e. expected to -# not be used). -dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ - -# Argument names that match this expression will be ignored. Default to name -# with leading underscore. -ignored-argument-names=_.*|^ignored_|^unused_ - -# Tells whether we should check for unused import in __init__ files. -init-import=no - -# List of qualified module names which can have objects that can redefine -# builtins. -redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io - - -[FORMAT] - -# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. -expected-line-ending-format= - -# Regexp for a line that is allowed to be longer than the limit. -ignore-long-lines=(^\s*def\s.*)|(f\".*\") # ignore string formating and functions -# ^\s*(# )??$, - -# Number of spaces of indent required inside a hanging or continued line. -indent-after-paren=4 - -# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 -# tab). -indent-string=' ' - -# Maximum number of characters on a single line. -max-line-length=120 - -# Maximum number of lines in a module. -max-module-lines=1500 # TODO GOOSE - -# List of optional constructs for which whitespace checking is disabled. `dict- -# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. -# `trailing-comma` allows a space between comma and closing bracket: (a, ). -# `empty-line` allows space-only lines. -no-space-check=trailing-comma, - dict-separator - -# Allow the body of a class to be on the same line as the declaration if body -# contains single statement. -single-line-class-stmt=no - -# Allow the body of an if to be on the same line as the test if there is no -# else. -single-line-if-stmt=no - - -[SPELLING] - -# Limits count of emitted suggestions for spelling mistakes. -max-spelling-suggestions=4 - -# Spelling dictionary name. Available dictionaries: none. To make it work, -# install the python-enchant package. -spelling-dict= - -# List of comma separated words that should not be checked. -spelling-ignore-words= - -# A path to a file that contains the private dictionary; one word per line. -spelling-private-dict-file= - -# Tells whether to store unknown words to the private dictionary (see the -# --spelling-private-dict-file option) instead of raising a message. -spelling-store-unknown-words=no - - -[TYPECHECK] - -# List of decorators that produce context managers, such as -# contextlib.contextmanager. Add to this list to register other decorators that -# produce valid context managers. -contextmanager-decorators=contextlib.contextmanager - -# List of members which are set dynamically and missed by pylint inference -# system, and so shouldn't trigger E1101 when accessed. Python regular -# expressions are accepted. -generated-members= - -# Tells whether missing members accessed in mixin class should be ignored. A -# mixin class is detected if its name ends with "mixin" (case insensitive). -ignore-mixin-members=yes - -# Tells whether to warn about missing members when the owner of the attribute -# is inferred to be None. -ignore-none=yes - -# This flag controls whether pylint should warn about no-member and similar -# checks whenever an opaque object is returned when inferring. The inference -# can return multiple potential results while evaluating a Python object, but -# some branches might not be evaluated, which results in partial inference. In -# that case, it might be useful to still emit no-member and other checks for -# the rest of the inferred objects. -ignore-on-opaque-inference=yes - -# List of class names for which member attributes should not be checked (useful -# for classes with dynamically set attributes). This supports the use of -# qualified names. -ignored-classes=optparse.Values,thread._local,_thread._local - -# List of module names for which member attributes should not be checked -# (useful for modules/projects where namespaces are manipulated during runtime -# and thus existing member attributes cannot be deduced by static analysis). It -# supports qualified module names, as well as Unix pattern matching. -ignored-modules= - -# Show a hint with possible names when a member name was not found. The aspect -# of finding the hint is based on edit distance. -missing-member-hint=yes - -# The minimum edit distance a name should have in order to be considered a -# similar match for a missing member name. -missing-member-hint-distance=1 - -# The total number of similar names that should be taken in consideration when -# showing a hint for a missing member. -missing-member-max-choices=1 - -# List of decorators that change the signature of a decorated function. -signature-mutators= - - -[CLASSES] - -# List of method names used to declare (i.e. assign) instance attributes. -defining-attr-methods=__init__, - __new__, - setUp, - __post_init__ - -# List of member names, which should be excluded from the protected access -# warning. -exclude-protected=_asdict, - _fields, - _replace, - _source, - _make - -# List of valid names for the first argument in a class method. -valid-classmethod-first-arg=cls - -# List of valid names for the first argument in a metaclass class method. -valid-metaclass-classmethod-first-arg=cls - - -[DESIGN] - -# Maximum number of arguments for function / method. -max-args=5 - -# Maximum number of attributes for a class (see R0902). -max-attributes=15 - -# Maximum number of boolean expressions in an if statement (see R0916). -max-bool-expr=5 - -# Maximum number of branch for function / method body. -max-branches=12 - -# Maximum number of locals for function / method body. -max-locals=15 - -# Maximum number of parents for a class (see R0901). -max-parents=7 - -# Maximum number of public methods for a class (see R0904). -max-public-methods=20 - -# Maximum number of return / yield for function / method body. -max-returns=10 - -# Maximum number of statements in function / method body. -max-statements=50 - -# Minimum number of public methods for a class (see R0903). -min-public-methods=2 - - -[IMPORTS] - -# List of modules that can be imported at any level, not just the top level -# one. -allow-any-import-level= - -# Allow wildcard imports from modules that define __all__. -allow-wildcard-with-all=no - -# Analyse import fallback blocks. This can be used to support both Python 2 and -# 3 compatible code, which means that the block might have code that exists -# only in one or another interpreter, leading to false positives when analysed. -analyse-fallback-blocks=no - -# Deprecated modules which should not be used, separated by a comma. -deprecated-modules=optparse,tkinter.tix - -# Create a graph of external dependencies in the given file (report RP0402 must -# not be disabled). -ext-import-graph= - -# Create a graph of every (i.e. internal and external) dependencies in the -# given file (report RP0402 must not be disabled). -import-graph= - -# Create a graph of internal dependencies in the given file (report RP0402 must -# not be disabled). -int-import-graph= - -# Force import order to recognize a module as part of the standard -# compatibility libraries. -known-standard-library= - -# Force import order to recognize a module as part of a third party library. -known-third-party=enchant - -# Couples of modules and preferred modules, separated by a comma. -preferred-modules= - - -[EXCEPTIONS] - -# Exceptions that will emit a warning when being caught. Defaults to -# "BaseException, Exception". -overgeneral-exceptions=BaseException, - Exception diff --git a/pymilvus/__init__.py b/pymilvus/__init__.py index d45af9051..cc3a19597 100644 --- a/pymilvus/__init__.py +++ b/pymilvus/__init__.py @@ -10,90 +10,118 @@ # or implied. See the License for the specific language governing permissions and limitations under # the License. -from .client.stub import Milvus +from .client import __version__ from .client.prepare import Prepare +from .client.stub import Milvus from .client.types import ( - Status, + BulkInsertState, DataType, - RangeType, + Group, IndexType, Replica, - Group, + ResourceGroupInfo, Shard, - BulkInsertState, - ResourceGroupInfo + Status, ) from .exceptions import ( - ParamError, + ExceptionsMessage, MilvusException, MilvusUnavailableException, - ExceptionsMessage ) -from .client import __version__ - -from .settings import ( - DEBUG_LOG_LEVEL, - INFO_LOG_LEVEL, - WARN_LOG_LEVEL, - ERROR_LOG_LEVEL, -) -# Compatiable -from .settings import Config as DefaultConfig - -from .client.constants import DEFAULT_RESOURCE_GROUP - +from .milvus_client.milvus_client import MilvusClient +from .orm import db, utility from .orm.collection import Collection -from .orm.connections import connections, Connections - +from .orm.connections import Connections, connections +from .orm.future import MutationFuture, SearchFuture from .orm.index import Index from .orm.partition import Partition +from .orm.role import Role +from .orm.schema import CollectionSchema, FieldSchema +from .orm.search import Hit, Hits, SearchResult from .orm.utility import ( - loading_progress, - index_building_progress, - wait_for_loading_complete, - wait_for_index_building_complete, + create_resource_group, + create_user, + delete_user, + describe_resource_group, + drop_collection, + drop_resource_group, has_collection, has_partition, + hybridts_to_datetime, + hybridts_to_unixtime, + index_building_progress, list_collections, - drop_collection, - get_query_segment_info, - load_balance, - mkts_from_hybridts, mkts_from_unixtime, mkts_from_datetime, - hybridts_to_unixtime, hybridts_to_datetime, - do_bulk_insert, get_bulk_insert_state, list_bulk_insert_tasks, - reset_password, create_user, update_password, delete_user, list_usernames, - create_resource_group, drop_resource_group, describe_resource_group, - list_resource_groups, transfer_node, transfer_replica + list_resource_groups, + list_usernames, + loading_progress, + mkts_from_datetime, + mkts_from_hybridts, + mkts_from_unixtime, + reset_password, + transfer_node, + transfer_replica, + update_password, + wait_for_index_building_complete, + wait_for_loading_complete, ) -from .orm import utility, db - -from .orm.search import SearchResult, Hits, Hit -from .orm.schema import FieldSchema, CollectionSchema -from .orm.future import SearchFuture, MutationFuture -from .orm.role import Role - -from .milvus_client.milvus_client import MilvusClient +# Compatiable +from .settings import Config as DefaultConfig __all__ = [ - 'Collection', 'Index', 'Partition', - 'connections', - 'loading_progress', 'index_building_progress', 'wait_for_loading_complete', 'has_collection', 'has_partition', - 'list_collections', 'wait_for_loading_complete', 'wait_for_index_building_complete', 'drop_collection', - 'mkts_from_hybridts', 'mkts_from_unixtime', 'mkts_from_datetime', - 'hybridts_to_unixtime', 'hybridts_to_datetime', - 'reset_password', 'create_user', 'update_password', 'delete_user', 'list_usernames', - 'SearchResult', 'Hits', 'Hit', 'Replica', 'Group', 'Shard', - 'FieldSchema', 'CollectionSchema', - 'SearchFuture', 'MutationFuture', - 'utility', 'db', 'DefaultConfig', 'ExceptionsMessage', 'MilvusUnavailableException', 'BulkInsertState', - 'Role', - 'create_resource_group', 'drop_resource_group', 'describe_resource_group', - 'list_resource_groups', 'transfer_node', 'transfer_replica', - - 'Milvus', 'Prepare', 'Status', 'DataType', - 'MilvusException', - '__version__', - - 'MilvusClient' + "Collection", + "Index", + "Partition", + "connections", + "loading_progress", + "index_building_progress", + "wait_for_index_building_complete", + "drop_collection", + "has_collection", + "list_collections", + "wait_for_loading_complete", + "has_partition", + "mkts_from_hybridts", + "mkts_from_unixtime", + "mkts_from_datetime", + "hybridts_to_unixtime", + "hybridts_to_datetime", + "reset_password", + "create_user", + "update_password", + "delete_user", + "list_usernames", + "SearchResult", + "Hits", + "Hit", + "Replica", + "Group", + "Shard", + "FieldSchema", + "CollectionSchema", + "SearchFuture", + "MutationFuture", + "utility", + "db", + "DefaultConfig", + "Role", + "ExceptionsMessage", + "MilvusUnavailableException", + "BulkInsertState", + "create_resource_group", + "drop_resource_group", + "describe_resource_group", + "list_resource_groups", + "transfer_node", + "transfer_replica", + "Milvus", + "Prepare", + "Status", + "DataType", + "MilvusException", + "__version__", + "MilvusClient", + "ResourceGroupInfo", + "Connections", + "IndexType", ] diff --git a/pymilvus/client/__init__.py b/pymilvus/client/__init__.py index c7405010d..800d36b25 100644 --- a/pymilvus/client/__init__.py +++ b/pymilvus/client/__init__.py @@ -1,22 +1,25 @@ -import subprocess +import logging import re -from pkg_resources import get_distribution, DistributionNotFound +import subprocess +from contextlib import suppress + +from pkg_resources import DistributionNotFound, get_distribution + +log = logging.getLogger(__name__) + + +__version__ = "0.0.0.dev" -__version__ = '0.0.0.dev' -try: - __version__ = get_distribution('pymilvus').version -except DistributionNotFound: - # package is not installed - pass +with suppress(DistributionNotFound): + __version__ = get_distribution("pymilvus").version -def get_commit(version="", short=True) -> str: - """get commit return the commit for a specific version like `xxxxxx.dev12` """ +def get_commit(version: str = "", short: bool = True) -> str: + """get commit return the commit for a specific version like `xxxxxx.dev12`""" - version_info = r'((\d+)\.(\d+)\.(\d+))((rc)(\d+))?(\.dev(\d+))?' + version_info = r"((\d+)\.(\d+)\.(\d+))((rc)(\d+))?(\.dev(\d+))?" # 2.0.0rc9.dev12 - # ('2.0.0', '2', '0', '0', 'rc9', 'rc', '9', '.dev12', '12') p = re.compile(version_info) target_v = __version__ if version == "" else version @@ -27,23 +30,23 @@ def get_commit(version="", short=True) -> str: if match_version[7] is not None: if match_version[4] is not None: v = str(int(match_version[6]) - 1) - target_tag = 'v' + match_version[0] + match_version[5] + v + target_tag = "v" + match_version[0] + match_version[5] + v else: - target_tag = 'v' + ".".join(str(int("".join(match_version[1:4])) - 1).split("")) + target_tag = "v" + ".".join(str(int("".join(match_version[1:4])) - 1).split("")) target_num = int(match_version[-1]) elif match_version[4] is not None: - target_tag = 'v' + match_version[0] + match_version[4] + target_tag = "v" + match_version[0] + match_version[4] target_num = 0 else: - target_tag = 'v' + match_version[0] + target_tag = "v" + match_version[0] target_num = 0 else: return f"Version: {target_v} isn't the right form" try: - cmd = ['git', 'rev-list', '--reverse', '--ancestry-path', f'{target_tag}^..HEAD'] - print(f"git cmd: {' '.join(cmd)}") - result = subprocess.check_output(cmd).decode('ascii').strip().split('\n') + cmd = ["git", "rev-list", "--reverse", "--ancestry-path", f"{target_tag}^..HEAD"] + log.info(f"git cmd: {' '.join(cmd)}") + result = subprocess.check_output(cmd).decode("ascii").strip().split("\n") length = 7 if short else 40 return result[target_num][:length] diff --git a/pymilvus/client/abstract.py b/pymilvus/client/abstract.py index d0e1e573b..419848fa4 100644 --- a/pymilvus/client/abstract.py +++ b/pymilvus/client/abstract.py @@ -1,11 +1,13 @@ import abc +from typing import Any, Dict, List + +from pymilvus.exceptions import MilvusException +from pymilvus.grpc_gen import schema_pb2 +from pymilvus.settings import Config -from ..settings import Config -from .types import DataType -from .constants import DEFAULT_CONSISTENCY_LEVEL -from ..grpc_gen import schema_pb2 -from ..exceptions import MilvusException from . import entity_helper +from .constants import DEFAULT_CONSISTENCY_LEVEL +from .types import DataType class LoopBase: @@ -15,17 +17,17 @@ def __init__(self): def __iter__(self): return self - def __getitem__(self, item): + def __getitem__(self, item: Any): if isinstance(item, slice): _start = item.start or 0 _end = min(item.stop, self.__len__()) if item.stop else self.__len__() _step = item.step or 1 - elements = [self.get__item(i) for i in range(_start, _end, _step)] - return elements + return [self.get__item(i) for i in range(_start, _end, _step)] if item >= self.__len__(): - raise IndexError("Index out of range") + msg = "Index out of range" + raise IndexError(msg) return self.get__item(item) @@ -36,27 +38,27 @@ def __next__(self): # iterate stop, raise Exception self.__index = 0 - raise StopIteration() + raise StopIteration def __str__(self): return str(list(map(str, self.__getitem__(slice(0, 10))))) @abc.abstractmethod - def get__item(self, item): - raise NotImplementedError() + def get__item(self, item: Any): + raise NotImplementedError class LoopCache: def __init__(self): self._array = [] - def fill(self, index, obj): + def fill(self, index: int, obj: Any): if len(self._array) + 1 < index: pass class FieldSchema: - def __init__(self, raw): + def __init__(self, raw: Any): self._raw = raw # self.field_id = 0 @@ -73,7 +75,7 @@ def __init__(self, raw): ## self.__pack(self._raw) - def __pack(self, raw): + def __pack(self, raw: Any): self.field_id = raw.fieldID self.name = raw.name self.is_primary = raw.is_primary_key @@ -82,19 +84,17 @@ def __pack(self, raw): self.type = raw.data_type self.is_partition_key = raw.is_partition_key self.default_value = raw.default_value - if raw.default_value is not None: - if raw.default_value.WhichOneof("data") is None: - self.default_value = None + if raw.default_value is not None and raw.default_value.WhichOneof("data") is None: + self.default_value = None try: self.is_dynamic = raw.is_dynamic except Exception: self.is_dynamic = False - # self.type = DataType(int(raw.type)) - for type_param in raw.type_params: if type_param.key == "params": import json + self.params[type_param.key] = json.loads(type_param.value) else: self.params[type_param.key] = type_param.value @@ -106,6 +106,7 @@ def __pack(self, raw): for index_param in raw.index_params: if index_param.key == "params": import json + index_dict[index_param.key] = json.loads(index_param.value) else: index_dict[index_param.key] = index_param.value @@ -113,9 +114,8 @@ def __pack(self, raw): self.indexes.extend([index_dict]) def dict(self): - if self.default_value is not None: - if self.default_value.WhichOneof("data") is None: - self.default_value = None + if self.default_value is not None and self.default_value.WhichOneof("data") is None: + self.default_value = None _dict = { "field_id": self.field_id, "name": self.name, @@ -137,7 +137,7 @@ def dict(self): class CollectionSchema: - def __init__(self, raw): + def __init__(self, raw: Any): self._raw = raw # @@ -159,7 +159,7 @@ def __init__(self, raw): if self._raw: self.__pack(self._raw) - def __pack(self, raw): + def __pack(self, raw: Any): self.collection_name = raw.schema.name self.description = raw.schema.description self.aliases = raw.aliases @@ -178,22 +178,17 @@ def __pack(self, raw): except Exception: self.enable_dynamic_field = False - # self.params = dict() # TODO: extra_params here # for kv in raw.extra_params: - # par = ujson.loads(kv.value) - # self.params.update(par) - # # self.params[kv.key] = kv.value self.fields = [FieldSchema(f) for f in raw.schema.fields] # for s in raw.statistics: - # self.statistics[s.key] = s.value self.properties = raw.properties @classmethod - def _rewrite_schema_dict(cls, schema_dict): + def _rewrite_schema_dict(cls, schema_dict: Dict): fields = schema_dict.get("fields", []) if not fields: return @@ -228,16 +223,16 @@ def __str__(self): class Entity: - def __init__(self, entity_id, entity_row_data, entity_score): + def __init__(self, entity_id: int, entity_row_data: Any, entity_score: float): self._id = entity_id self._row_data = entity_row_data self._score = entity_score self._distance = entity_score def __str__(self): - return f'id: {self._id}, distance: {self._distance}, entity: {self._row_data}' + return f"id: {self._id}, distance: {self._distance}, entity: {self._row_data}" - def __getattr__(self, item): + def __getattr__(self, item: Any): return self.value_of_field(item) @property @@ -246,26 +241,26 @@ def id(self): @property def fields(self): - fields = [k for k, v in self._row_data.items()] - return fields + return [k for k, v in self._row_data.items()] - def get(self, field): + def get(self, field: Any): return self.value_of_field(field) - def value_of_field(self, field): + def value_of_field(self, field: Any): if field not in self._row_data: raise MilvusException(message=f"Field {field} is not in return entity") return self._row_data[field] - def type_of_field(self, field): - raise NotImplementedError('TODO: support field in Hits') + def type_of_field(self, field: Any): + msg = "TODO: support field in Hits" + raise NotImplementedError(msg) def to_dict(self): return {"id": self._id, "distance": self._distance, "entity": self._row_data} class Hit: - def __init__(self, entity_id, entity_row_data, entity_score): + def __init__(self, entity_id: int, entity_row_data: Any, entity_score: float): self._id = entity_id self._row_data = entity_row_data self._score = entity_score @@ -297,7 +292,7 @@ def to_dict(self): class Hits(LoopBase): - def __init__(self, raw, round_decimal=-1): + def __init__(self, raw: Any, round_decimal: int = -1): super().__init__() self._raw = raw if round_decimal != -1: @@ -307,7 +302,10 @@ def __init__(self, raw, round_decimal=-1): self._dynamic_field_name = None self._dynamic_fields = set() - self._dynamic_field_name, self._dynamic_fields = entity_helper.extract_dynamic_field_from_result(self._raw) + ( + self._dynamic_field_name, + self._dynamic_fields, + ) = entity_helper.extract_dynamic_field_from_result(self._raw) def __len__(self): if self._raw.ids.HasField("int_id"): @@ -316,7 +314,7 @@ def __len__(self): return len(self._raw.ids.str_id.data) return 0 - def get__item(self, item): + def get__item(self, item: Any): if self._raw.ids.HasField("int_id"): entity_id = self._raw.ids.int_id.data[item] elif self._raw.ids.HasField("str_id"): @@ -324,8 +322,9 @@ def get__item(self, item): else: raise MilvusException(message="Unsupported ids type") - entity_row_data = entity_helper.extract_row_data_from_fields_data(self._raw.fields_data, item, - self._dynamic_fields) + entity_row_data = entity_helper.extract_row_data_from_fields_data( + self._raw.fields_data, item, self._dynamic_fields + ) entity_score = self._distances[item] return Hit(entity_id, entity_row_data, entity_score) @@ -343,7 +342,7 @@ def distances(self): class MutationResult: - def __init__(self, raw): + def __init__(self, raw: Any): self._raw = raw self._primary_keys = [] self._insert_cnt = 0 @@ -392,8 +391,10 @@ def err_index(self): return self._err_index def __str__(self): - return f"(insert count: {self._insert_cnt}, delete count: {self._delete_cnt}, upsert count: {self._upsert_cnt}, " \ - f"timestamp: {self._timestamp}, success count: {self.succ_count}, err count: {self.err_count})" + return ( + f"(insert count: {self._insert_cnt}, delete count: {self._delete_cnt}, upsert count: {self._upsert_cnt}, " + f"timestamp: {self._timestamp}, success count: {self.succ_count}, err count: {self.err_count})" + ) __repr__ = __str__ @@ -404,8 +405,7 @@ def __str__(self): # def error_reason(self): # pass - def _pack(self, raw): - # self._primary_keys = getattr(raw.IDs, raw.IDs.WhichOneof('id_field')).value.data + def _pack(self, raw: Any): which = raw.IDs.WhichOneof("id_field") if which == "int_id": self._primary_keys = raw.IDs.int_id.data @@ -421,7 +421,7 @@ def _pack(self, raw): class QueryResult(LoopBase): - def __init__(self, raw): + def __init__(self, raw: Any): super().__init__() self._raw = raw self._pack(raw.hits) @@ -429,7 +429,7 @@ def __init__(self, raw): def __len__(self): return self._nq - def _pack(self, raw): + def _pack(self, raw: Any): self._nq = raw.results.num_queries self._topk = raw.results.top_k self._hits = [] @@ -438,52 +438,69 @@ def _pack(self, raw): hit = schema_pb2.SearchResultData() start_pos = offset end_pos = offset + raw.results.topks[i] - hit.scores.append(raw.results.scores[start_pos: end_pos]) + hit.scores.append(raw.results.scores[start_pos:end_pos]) if raw.results.ids.HasField("int_id"): - hit.ids.append(raw.results.ids.int_id.data[start_pos: end_pos]) + hit.ids.append(raw.results.ids.int_id.data[start_pos:end_pos]) elif raw.results.ids.HasField("str_id"): - hit.ids.append(raw.results.ids.str_id.data[start_pos: end_pos]) + hit.ids.append(raw.results.ids.str_id.data[start_pos:end_pos]) for field_data in raw.result.fields_data: field = schema_pb2.FieldData() field.type = field_data.type field.field_name = field_data.field_name if field_data.type == DataType.BOOL: - field.scalars.bool_data.data.extend(field_data.scalars.bool_data.data[start_pos: end_pos]) + field.scalars.bool_data.data.extend( + field_data.scalars.bool_data.data[start_pos:end_pos] + ) elif field_data.type in (DataType.INT8, DataType.INT16, DataType.INT32): - field.scalars.int_data.data.extend(field_data.scalars.int_data.data[start_pos: end_pos]) + field.scalars.int_data.data.extend( + field_data.scalars.int_data.data[start_pos:end_pos] + ) elif field_data.type == DataType.INT64: - field.scalars.long_data.data.extend(field_data.scalars.long_data.data[start_pos: end_pos]) + field.scalars.long_data.data.extend( + field_data.scalars.long_data.data[start_pos:end_pos] + ) elif field_data.type == DataType.FLOAT: - field.scalars.float_data.data.extend(field_data.scalars.float_data.data[start_pos: end_pos]) + field.scalars.float_data.data.extend( + field_data.scalars.float_data.data[start_pos:end_pos] + ) elif field_data.type == DataType.DOUBLE: - field.scalars.double_data.data.extend(field_data.scalars.double_data.data[start_pos: end_pos]) + field.scalars.double_data.data.extend( + field_data.scalars.double_data.data[start_pos:end_pos] + ) elif field_data.type == DataType.VARCHAR: - field.scalars.string_data.data.extend(field_data.scalars.string_data.data[start_pos: end_pos]) + field.scalars.string_data.data.extend( + field_data.scalars.string_data.data[start_pos:end_pos] + ) elif field_data.type == DataType.STRING: raise MilvusException(message="Not support string yet") - # result[field_data.field_name] = field_data.scalars.string_data.data[index] elif field_data.type == DataType.JSON: - field.scalars.json_data.data.extend(field_data.scalars.json_data.data[start_pos: end_pos]) + field.scalars.json_data.data.extend( + field_data.scalars.json_data.data[start_pos:end_pos] + ) elif field_data.type == DataType.FLOAT_VECTOR: dim = field.vectors.dim field.vectors.dim = dim field.vectors.float_vector.data.extend( - field_data.vectors.float_data.data[start_pos * dim: end_pos * dim]) + field_data.vectors.float_data.data[start_pos * dim : end_pos * dim] + ) elif field_data.type == DataType.BINARY_VECTOR: dim = field_data.vectors.dim field.vectors.dim = dim - field.vectors.binary_vector.data.extend(field_data.vectors.binary_vector.data[ - start_pos * (dim / 8): end_pos * (dim / 8)]) + field.vectors.binary_vector.data.extend( + field_data.vectors.binary_vector.data[ + start_pos * (dim / 8) : end_pos * (dim / 8) + ] + ) hit.fields_data.append(field) self._hits.append(hit) offset += raw.results.topks[i] - def get__item(self, item): + def get__item(self, item: Any): return Hits(self._hits[item]) class ChunkedQueryResult(LoopBase): - def __init__(self, raw_list, round_decimal=-1): + def __init__(self, raw_list: List, round_decimal: int = -1): super().__init__() self._raw_list = raw_list self._nq = 0 @@ -494,7 +511,7 @@ def __init__(self, raw_list, round_decimal=-1): def __len__(self): return self._nq - def _pack(self, raw_list): + def _pack(self, raw_list: List): self._hits = [] for raw in raw_list: nq = raw.results.num_queries @@ -506,11 +523,11 @@ def _pack(self, raw_list): hit = schema_pb2.SearchResultData() start_pos = offset end_pos = offset + raw.results.topks[i] - hit.scores.extend(raw.results.scores[start_pos: end_pos]) + hit.scores.extend(raw.results.scores[start_pos:end_pos]) if raw.results.ids.HasField("int_id"): - hit.ids.int_id.data.extend(raw.results.ids.int_id.data[start_pos: end_pos]) + hit.ids.int_id.data.extend(raw.results.ids.int_id.data[start_pos:end_pos]) elif raw.results.ids.HasField("str_id"): - hit.ids.str_id.data.extend(raw.results.ids.str_id.data[start_pos: end_pos]) + hit.ids.str_id.data.extend(raw.results.ids.str_id.data[start_pos:end_pos]) hit.output_fields.extend(raw.results.output_fields) for field_data in raw.results.fields_data: field = schema_pb2.FieldData() @@ -518,389 +535,57 @@ def _pack(self, raw_list): field.field_name = field_data.field_name field.is_dynamic = field_data.is_dynamic if field_data.type == DataType.BOOL: - field.scalars.bool_data.data.extend(field_data.scalars.bool_data.data[start_pos: end_pos]) + field.scalars.bool_data.data.extend( + field_data.scalars.bool_data.data[start_pos:end_pos] + ) elif field_data.type in (DataType.INT8, DataType.INT16, DataType.INT32): - field.scalars.int_data.data.extend(field_data.scalars.int_data.data[start_pos: end_pos]) + field.scalars.int_data.data.extend( + field_data.scalars.int_data.data[start_pos:end_pos] + ) elif field_data.type == DataType.INT64: - field.scalars.long_data.data.extend(field_data.scalars.long_data.data[start_pos: end_pos]) + field.scalars.long_data.data.extend( + field_data.scalars.long_data.data[start_pos:end_pos] + ) elif field_data.type == DataType.FLOAT: - field.scalars.float_data.data.extend(field_data.scalars.float_data.data[start_pos: end_pos]) + field.scalars.float_data.data.extend( + field_data.scalars.float_data.data[start_pos:end_pos] + ) elif field_data.type == DataType.DOUBLE: - field.scalars.double_data.data.extend(field_data.scalars.double_data.data[start_pos: end_pos]) + field.scalars.double_data.data.extend( + field_data.scalars.double_data.data[start_pos:end_pos] + ) elif field_data.type == DataType.VARCHAR: - field.scalars.string_data.data.extend(field_data.scalars.string_data.data[start_pos: end_pos]) + field.scalars.string_data.data.extend( + field_data.scalars.string_data.data[start_pos:end_pos] + ) elif field_data.type == DataType.STRING: raise MilvusException(message="Not support string yet") - # result[field_data.field_name] = field_data.scalars.string_data.data[index] elif field_data.type == DataType.JSON: - field.scalars.json_data.data.extend(field_data.scalars.json_data.data[start_pos: end_pos]) + field.scalars.json_data.data.extend( + field_data.scalars.json_data.data[start_pos:end_pos] + ) elif field_data.type == DataType.FLOAT_VECTOR: dim = field_data.vectors.dim field.vectors.dim = dim - field.vectors.float_vector.data.extend(field_data.vectors.float_vector.data[ - start_pos * dim: end_pos * dim]) + field.vectors.float_vector.data.extend( + field_data.vectors.float_vector.data[start_pos * dim : end_pos * dim] + ) elif field_data.type == DataType.BINARY_VECTOR: dim = field_data.vectors.dim field.vectors.dim = dim - field.vectors.binary_vector.data.extend(field_data.vectors.binary_vector.data[ - start_pos * (dim / 8): end_pos * (dim / 8)]) + field.vectors.binary_vector.data.extend( + field_data.vectors.binary_vector.data[ + start_pos * (dim / 8) : end_pos * (dim / 8) + ] + ) hit.fields_data.append(field) self._hits.append(hit) offset += raw.results.topks[i] - def get__item(self, item): + def get__item(self, item: Any): return Hits(self._hits[item], self.round_decimal) def _abstract(): - raise NotImplementedError('You need to override this function') - - -class ConnectIntf: - """SDK client abstract class - - Connection is a abstract class - - """ - - def connect(self, host, port, uri, timeout): - """ - Connect method should be called before any operations - Server will be connected after connect return OK - Should be implemented - - :type host: str - :param host: host - - :type port: str - :param port: port - - :type uri: str - :param uri: (Optional) uri - - :type timeout: int - :param timeout: - - :return: Status, indicate if connect is successful - """ - _abstract() - - def connected(self): - """ - connected, connection status - Should be implemented - - :return: Status, indicate if connect is successful - """ - _abstract() - - def disconnect(self): - """ - Disconnect, server will be disconnected after disconnect return SUCCESS - Should be implemented - - :return: Status, indicate if connect is successful - """ - _abstract() - - def create_table(self, param, timeout): - """ - Create table - Should be implemented - - :type param: TableSchema - :param param: provide table information to be created - - :type timeout: int - :param timeout: - - :return: Status, indicate if connect is successful - """ - _abstract() - - def has_table(self, table_name, timeout): - """ - - This method is used to test table existence. - Should be implemented - - :type table_name: str - :param table_name: table name is going to be tested. - - :type timeout: int - :param timeout: - - :return: - has_table: bool, if given table_name exists - - """ - _abstract() - - def delete_table(self, table_name, timeout): - """ - Delete table - Should be implemented - - :type table_name: str - :param table_name: table_name of the deleting table - - :type timeout: int - :param timeout: - - :return: Status, indicate if connect is successful - """ - _abstract() - - def add_vectors(self, table_name, records, ids, timeout, **kwargs): - """ - Add vectors to table - Should be implemented - - :type table_name: str - :param table_name: table name been inserted - - :type records: list[RowRecord] - :param records: list of vectors been inserted - - :type ids: list[int] - :param ids: list of ids - - :type timeout: int - :param timeout: - - :returns: - Status : indicate if vectors inserted successfully - ids :list of id, after inserted every vector is given a id - """ - _abstract() - - def search_vectors(self, table_name, top_k, nprobe, query_records, query_ranges, **kwargs): - """ - Query vectors in a table - Should be implemented - - :type table_name: str - :param table_name: table name been queried - - :type query_records: list[RowRecord] - :param query_records: all vectors going to be queried - - :type query_ranges: list[Range] - :param query_ranges: Optional ranges for conditional search. - If not specified, search whole table - - :type top_k: int - :param top_k: how many similar vectors will be searched - - :returns: - Status: indicate if query is successful - query_results: list[TopKQueryResult] - """ - _abstract() - - def search_vectors_in_files(self, table_name, file_ids, query_records, - top_k, nprobe, query_ranges, **kwargs): - """ - Query vectors in a table, query vector in specified files - Should be implemented - - :type table_name: str - :param table_name: table name been queried - - :type file_ids: list[str] - :param file_ids: Specified files id array - - :type query_records: list[RowRecord] - :param query_records: all vectors going to be queried - - :type query_ranges: list[Range] - :param query_ranges: Optional ranges for conditional search. - If not specified, search whole table - - :type top_k: int - :param top_k: how many similar vectors will be searched - - :returns: - Status: indicate if query is successful - query_results: list[TopKQueryResult] - """ - _abstract() - - def describe_table(self, table_name, timeout): - """ - Show table information - Should be implemented - - :type table_name: str - :param table_name: which table to be shown - - :type timeout: int - :param timeout: - - :returns: - Status: indicate if query is successful - table_schema: TableSchema, given when operation is successful - """ - _abstract() - - def get_table_row_count(self, table_name, timeout): - """ - Get table row count - Should be implemented - - :type table_name, str - :param table_name, target table name. - - :type timeout: int - :param timeout: how many similar vectors will be searched - - :returns: - Status: indicate if operation is successful - count: int, table row count - """ - _abstract() - - def show_tables(self, timeout): - """ - Show all tables in database - should be implemented - - :type timeout: int - :param timeout: how many similar vectors will be searched - - :return: - Status: indicate if this operation is successful - tables: list[str], list of table names - """ - _abstract() - - def create_index(self, table_name, index, timeout): - """ - Create specified index in a table - should be implemented - - :type table_name: str - :param table_name: table name - - :type index: dict - :param index: index information dict - - example: index = { - "index_type": IndexType.FLAT, - "nlist": 18384 - } - - :type timeout: int - :param timeout: how many similar vectors will be searched - - :return: - Status: indicate if this operation is successful - - :rtype: Status - """ - _abstract() - - def server_version(self, timeout): - """ - Provide server version - should be implemented - - :type timeout: int - :param timeout: how many similar vectors will be searched - - :return: - Status: indicate if operation is successful - - str : Server version - - :rtype: (Status, str) - """ - _abstract() - - def server_status(self, timeout): - """ - Provide server status. When cmd !='version', provide 'OK' - should be implemented - - :type timeout: int - :param timeout: how many similar vectors will be searched - - :return: - Status: indicate if operation is successful - - str : Server version - - :rtype: (Status, str) - """ - _abstract() - - def preload_table(self, table_name, timeout): - """ - load table to memory cache in advance - should be implemented - - :param table_name: target table name. - :type table_name: str - - :type timeout: int - :param timeout: how many similar vectors will be searched - - :return: - Status: indicate if operation is successful - - ::rtype: Status - """ - - _abstract() - - def describe_index(self, table_name, timeout): - """ - Show index information - should be implemented - - :param table_name: target table name. - :type table_name: str - - :type timeout: int - :param timeout: how many similar vectors will be searched - - :return: - Status: indicate if operation is successful - - TableSchema: table detail information - - :rtype: (Status, TableSchema) - """ - - _abstract() - - def drop_index(self, table_name, timeout): - """ - Show index information - should be implemented - - :param table_name: target table name. - :type table_name: str - - :type timeout: int - :param timeout: how many similar vectors will be searched - - :return: - Status: indicate if operation is successful - - ::rtype: Status - """ - - _abstract() - - def load_collection(self, collection_name, timeout): - _abstract() - - def release_collection(self, collection_name, timeout): - _abstract() - - def load_partitions(self, collection_name, timeout): - _abstract() - - def release_partitions(self, collection_name, timeout): - _abstract() + msg = "You need to override this function" + raise NotImplementedError(msg) diff --git a/pymilvus/client/asynch.py b/pymilvus/client/asynch.py index 6dbe756e5..544d917eb 100644 --- a/pymilvus/client/asynch.py +++ b/pymilvus/client/asynch.py @@ -1,61 +1,67 @@ import abc import threading +from typing import Any, Callable, List, Optional -from .abstract import QueryResult, ChunkedQueryResult, MutationResult -from ..exceptions import MilvusException +from pymilvus.exceptions import MilvusException + +from .abstract import ChunkedQueryResult, MutationResult, QueryResult from .types import Status # TODO: remove this to a common util -def _parameter_is_empty(func): +def _parameter_is_empty(func: Callable): import inspect + sig = inspect.signature(func) - # params = sig.parameters # todo: add more check to parameter, such as `default parameter`, # `positional-only`, `positional-or-keyword`, `keyword-only`, `var-positional`, `var-keyword` # if len(params) == 0: - # return True # for param in params.values(): # if (param.kind == inspect.Parameter.POSITIONAL_ONLY or - # param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD) and \ # param.default == inspect._empty: - # return False return len(sig.parameters) == 0 class AbstractFuture: @abc.abstractmethod def result(self, **kwargs): - '''Return deserialized result. + """Return deserialized result. It's a synchronous interface. It will wait executing until server respond or timeout occur(if specified). This API is thread-safe. - ''' - raise NotImplementedError() + """ + raise NotImplementedError @abc.abstractmethod def cancel(self): - '''Cancle gRPC future. + """Cancle gRPC future. This API is thread-safe. - ''' - raise NotImplementedError() + """ + raise NotImplementedError @abc.abstractmethod def done(self): - '''Wait for request done. + """Wait for request done. This API is thread-safe. - ''' - raise NotImplementedError() + """ + raise NotImplementedError class Future(AbstractFuture): - def __init__(self, future, done_callback=None, pre_exception=None, **kwargs): + def __init__( + self, + future: Any, + done_callback: Optional[Callable] = None, + pre_exception: Optional[Callable] = None, + **kwargs, + ) -> None: self._future = future - self._done_cb = done_callback # keep compatible (such as Future(future, done_callback)), deprecated later + # keep compatible (such as Future(future, done_callback)), deprecated later + self._done_cb = done_callback self._done_cb_list = [] self.add_callback(done_callback) self._condition = threading.Condition() @@ -67,19 +73,18 @@ def __init__(self, future, done_callback=None, pre_exception=None, **kwargs): self._callback_called = False # callback function should be called only once self._kwargs = kwargs - def add_callback(self, func): + def add_callback(self, func: Callable): self._done_cb_list.append(func) - def __del__(self): + def __del__(self) -> None: self._future = None @abc.abstractmethod - def on_response(self, response): - ''' Parse response from gRPC server and return results. - ''' - raise NotImplementedError() + def on_response(self, response: Callable): + """Parse response from gRPC server and return results.""" + raise NotImplementedError - def _callback(self, **kwargs): + def _callback(self): if not self._callback_called: for cb in self._done_cb_list: if cb: @@ -134,7 +139,6 @@ def is_done(self): return self._done def done(self): - # self.exception() with self._condition: if self._future and self._results is None: try: @@ -156,7 +160,7 @@ def exception(self): class SearchFuture(Future): - def on_response(self, response): + def on_response(self, response: Any): if response.status.error_code == 0: return QueryResult(response) @@ -165,9 +169,15 @@ def on_response(self, response): # TODO: if ChunkedFuture is more common later, consider using ChunkedFuture as Base Class, -# then Future(future, done_cb, pre_exception) equal to ChunkedFuture([future], done_cb, pre_exception) +# then Future(future, done_cb, pre_exception) equal +# to ChunkedFuture([future], done_cb, pre_exception) class ChunkedSearchFuture(Future): - def __init__(self, future_list, done_callback=None, pre_exception=None): + def __init__( + self, + future_list: List, + done_callback: Optional[Callable] = None, + pre_exception: Optional[Callable] = None, + ) -> None: super().__init__(None, done_callback, pre_exception) self._future_list = future_list self._response = [] @@ -193,7 +203,8 @@ def result(self, **kwargs): self.exception() if kwargs.get("raw", False) is True: # just return response object received from gRPC - raise AttributeError("Not supported to return response object received from gRPC") + msg = "Not supported to return response object received from gRPC" + raise AttributeError(msg) if self._results: return self._results @@ -210,7 +221,6 @@ def is_done(self): return self._done def done(self): - # self.exception() with self._condition: if self._results is None: try: @@ -234,7 +244,7 @@ def exception(self): if future: future.exception() - def on_response(self, response): + def on_response(self, response: Any): for raw in response: if raw.status.error_code != 0: raise MilvusException(raw.status.error_code, raw.status.reason) @@ -243,7 +253,7 @@ def on_response(self, response): class MutationFuture(Future): - def on_response(self, response): + def on_response(self, response: Any): status = response.status if status.error_code == 0: return MutationResult(response) @@ -253,7 +263,7 @@ def on_response(self, response): class CreateIndexFuture(Future): - def on_response(self, response): + def on_response(self, response: Any): if response.error_code != 0: raise MilvusException(response.error_code, response.reason) @@ -261,7 +271,12 @@ def on_response(self, response): class CreateFlatIndexFuture(AbstractFuture): - def __init__(self, res, done_callback=None, pre_exception=None): + def __init__( + self, + res: Any, + done_callback: Optional[Callable] = None, + pre_exception: Optional[Callable] = None, + ) -> None: self._results = res self._done_cb = done_callback self._done_cb_list = [] @@ -269,16 +284,16 @@ def __init__(self, res, done_callback=None, pre_exception=None): self._condition = threading.Condition() self._exception = pre_exception - def add_callback(self, func): + def add_callback(self, func: Callable): self._done_cb_list.append(func) - def __del__(self): + def __del__(self) -> None: self._results = None - def on_response(self, response): + def on_response(self, response: Any): pass - def result(self, **kwargs): + def result(self): self.exception() with self._condition: for cb in self._done_cb_list: @@ -311,18 +326,18 @@ def exception(self): class FlushFuture(Future): - def on_response(self, response): + def on_response(self, response: Any): if response.status.error_code != 0: raise MilvusException(response.status.error_code, response.status.reason) class LoadCollectionFuture(Future): - def on_response(self, response): + def on_response(self, response: Any): if response.error_code != 0: raise MilvusException(response.error_code, response.reason) class LoadPartitionsFuture(Future): - def on_response(self, response): + def on_response(self, response: Any): if response.error_code != 0: raise MilvusException(response.error_code, response.reason) diff --git a/pymilvus/client/blob.py b/pymilvus/client/blob.py index f82001e83..56fcbc02c 100644 --- a/pymilvus/client/blob.py +++ b/pymilvus/client/blob.py @@ -1,38 +1,13 @@ import struct +from typing import List # reference: https://docs.python.org/3/library/struct.html#struct.pack -def boolToBytes(b): - return struct.pack("?", b) -def int8ToBytes(i): - return struct.pack("b", i) - -def int16ToBytes(i): - return struct.pack("h", i) - -def int32ToBytes(i): - return struct.pack("i", i) - -def int64ToBytes(i): - return struct.pack("q", i) - -def floatToBytes(f): - return struct.pack("f", f) - -def doubleToBytes(d): - return struct.pack("d", d) - -def stringToBytes(s): - return bytes(s, encoding='utf8') - -def vectorBinaryToBytes(v): +def vector_binary_to_bytes(v: bytes): return bytes(v) -def vectorFloatToBytes(v): - # pack len(v) number of float - bs = struct.pack(f'{len(v)}f', *v) - return bs -def bytesToInt64(v): - return struct.unpack("q", v)[0] +def vector_float_to_bytes(v: List[float]): + # pack len(v) number of float + return struct.pack(f"{len(v)}f", *v) diff --git a/pymilvus/client/check.py b/pymilvus/client/check.py index 913dda5ab..3a4e3fdee 100644 --- a/pymilvus/client/check.py +++ b/pymilvus/client/check.py @@ -1,8 +1,10 @@ -import sys import datetime -from typing import Any, Union -from ..exceptions import ParamError -from ..grpc_gen import milvus_pb2 as milvus_types +import sys +from typing import Any, Callable, Union + +from pymilvus.exceptions import ParamError +from pymilvus.grpc_gen import milvus_pb2 as milvus_types + from .singleton_utils import Singleton @@ -42,22 +44,14 @@ def is_legal_port(port: Any) -> bool: def is_legal_vector(array: Any) -> bool: - if not array or \ - not isinstance(array, list) or \ - len(array) == 0: + if not array or not isinstance(array, list) or len(array) == 0: return False - # for v in array: - # if not isinstance(v, float): - # return False - return True def is_legal_bin_vector(array: Any) -> bool: - if not array or \ - not isinstance(array, bytes) or \ - len(array) == 0: + if not array or not isinstance(array, bytes) or len(array) == 0: return False return True @@ -76,7 +70,7 @@ def int_or_str(item: Union[int, str]) -> str: def is_correct_date_str(param: str) -> bool: try: - datetime.datetime.strptime(param, '%Y-%m-%d') + datetime.datetime.strptime(param, "%Y-%m-%d") except ValueError: return False @@ -140,16 +134,18 @@ def is_legal_cmd(cmd: Any) -> bool: def parser_range_date(date: Union[str, datetime.date]) -> str: if isinstance(date, datetime.date): - return date.strftime('%Y-%m-%d') + return date.strftime("%Y-%m-%d") if isinstance(date, str): if not is_correct_date_str(date): - raise ParamError(message='Date string should be YY-MM-DD format!') + raise ParamError(message="Date string should be YY-MM-DD format!") return date - raise ParamError(message='Date should be YY-MM-DD format string or datetime.date, ' - 'or datetime.datetime object') + raise ParamError( + message="Date should be YY-MM-DD format string or datetime.date, " + "or datetime.datetime object" + ) def is_legal_date_range(start: str, end: str) -> bool: @@ -168,21 +164,18 @@ def is_legal_partition_name(tag: Any) -> bool: def is_legal_limit(limit: Any) -> bool: return isinstance(limit, int) and limit > 0 + def is_legal_anns_field(field: Any) -> bool: return field is None or isinstance(field, str) + def is_legal_search_data(data: Any) -> bool: import numpy as np + if not isinstance(data, (list, np.ndarray)): return False - for vector in data: - # list -> float vector - # bytes -> byte vector - if not isinstance(vector, (list, bytes, np.ndarray)): - return False - - return True + return all(isinstance(vector, (list, bytes, np.ndarray)) for vector in data) def is_legal_output_fields(output_fields: Any) -> bool: @@ -192,11 +185,7 @@ def is_legal_output_fields(output_fields: Any) -> bool: if not isinstance(output_fields, list): return False - for field in output_fields: - if not is_legal_field_name(field): - return False - - return True + return all(is_legal_field_name(field) for field in output_fields) def is_legal_partition_name_array(tag_array: Any) -> bool: @@ -206,26 +195,26 @@ def is_legal_partition_name_array(tag_array: Any) -> bool: if not isinstance(tag_array, list): return False - for tag in tag_array: - if not is_legal_partition_name(tag): - return False + return all(is_legal_partition_name(tag) for tag in tag_array) - return True def is_legal_replica_number(replica_number: int) -> bool: return isinstance(replica_number, int) + # https://milvus.io/cn/docs/v1.0.0/metric.md#floating def is_legal_index_metric_type(index_type: str, metric_type: str) -> bool: - if index_type not in ("GPU_IVF_FLAT", - "GPU_IVF_PQ", - "FLAT", - "IVF_FLAT", - "IVF_SQ8", - "IVF_PQ", - "HNSW", - "AUTOINDEX", - "DISKANN"): + if index_type not in ( + "GPU_IVF_FLAT", + "GPU_IVF_PQ", + "FLAT", + "IVF_FLAT", + "IVF_SQ8", + "IVF_PQ", + "HNSW", + "AUTOINDEX", + "DISKANN", + ): return False if metric_type not in ("L2", "IP", "COSINE"): return False @@ -237,9 +226,8 @@ def is_legal_binary_index_metric_type(index_type: str, metric_type: str) -> bool if index_type == "BIN_FLAT": if metric_type in ("JACCARD", "TANIMOTO", "HAMMING", "SUBSTRUCTURE", "SUPERSTRUCTURE"): return True - elif index_type == "BIN_IVF_FLAT": - if metric_type in ("JACCARD", "TANIMOTO", "HAMMING"): - return True + elif index_type == "BIN_IVF_FLAT" and metric_type in ("JACCARD", "TANIMOTO", "HAMMING"): + return True return False @@ -258,11 +246,12 @@ def is_legal_travel_timestamp(ts: Any) -> bool: def is_legal_guarantee_timestamp(ts: Any) -> bool: return ts is None or isinstance(ts, int) and ts >= 0 -def is_legal_user(user) -> bool: + +def is_legal_user(user: Any) -> bool: return isinstance(user, str) -def is_legal_password(password) -> bool: +def is_legal_password(password: Any) -> bool: return isinstance(password, str) @@ -271,8 +260,10 @@ def is_legal_role_name(role_name: Any) -> bool: def is_legal_operate_user_role_type(operate_user_role_type: Any) -> bool: - return operate_user_role_type in \ - (milvus_types.OperateUserRoleType.AddUserToRole, milvus_types.OperateUserRoleType.RemoveUserFromRole) + return operate_user_role_type in ( + milvus_types.OperateUserRoleType.AddUserToRole, + milvus_types.OperateUserRoleType.RemoveUserFromRole, + ) def is_legal_include_user_info(include_user_info: Any) -> bool: @@ -300,8 +291,10 @@ def is_legal_collection_properties(properties: Any) -> bool: def is_legal_operate_privilege_type(operate_privilege_type: Any) -> bool: - return operate_privilege_type in \ - (milvus_types.OperatePrivilegeType.Grant, milvus_types.OperatePrivilegeType.Revoke) + return operate_privilege_type in ( + milvus_types.OperatePrivilegeType.Grant, + milvus_types.OperatePrivilegeType.Revoke, + ) class ParamChecker(metaclass=Singleton): @@ -341,7 +334,7 @@ def __init__(self) -> None: "resource_group_name": is_legal_table_name, } - def check(self, key, value): + def check(self, key: str, value: Callable): if key in self.check_dict: if not self.check_dict[key](value): _raise_param_error(key, value) diff --git a/pymilvus/client/constants.py b/pymilvus/client/constants.py index b65cc34bf..1ac430170 100644 --- a/pymilvus/client/constants.py +++ b/pymilvus/client/constants.py @@ -1,4 +1,4 @@ -from ..grpc_gen import common_pb2 +from pymilvus.grpc_gen import common_pb2 ConsistencyLevel = common_pb2.ConsistencyLevel diff --git a/pymilvus/client/entity_helper.py b/pymilvus/client/entity_helper.py index 26ccbbba1..f8ae9fa06 100644 --- a/pymilvus/client/entity_helper.py +++ b/pymilvus/client/entity_helper.py @@ -1,13 +1,18 @@ -import ujson +from typing import Any, Dict, List, Optional + import numpy as np +import ujson + +from pymilvus.exceptions import MilvusException, ParamError +from pymilvus.grpc_gen import schema_pb2 as schema_types +from pymilvus.settings import Config -from ..grpc_gen import schema_pb2 as schema_types from .types import DataType -from ..exceptions import ParamError, MilvusException -from ..settings import Config + +CHECK_STR_ARRAY = True -def entity_type_to_dtype(entity_type): +def entity_type_to_dtype(entity_type: Any): if isinstance(entity_type, int): return entity_type if isinstance(entity_type, str): @@ -16,24 +21,26 @@ def entity_type_to_dtype(entity_type): raise ParamError(message=f"invalid entity type: {entity_type}") -def get_max_len_of_var_char(field_info) -> int: +def get_max_len_of_var_char(field_info: Dict) -> int: k = Config.MaxVarCharLengthKey v = Config.MaxVarCharLength return field_info.get("params", {}).get(k, v) -def check_str_arr(str_arr, max_len): +def check_str_arr(str_arr: Any, max_len: int): for s in str_arr: if not isinstance(s, str): raise ParamError(message=f"expect string input, got: {type(s)}") if len(s) > max_len: - raise ParamError(message=f"invalid input, length of string exceeds max length. length: {len(s)}, " - f"max length: {max_len}") + raise ParamError( + message=f"invalid input, length of string exceeds max length. " + f"length: {len(s)}, max length: {max_len}" + ) -def convert_to_str_array(orig_str_arr, field_info, check=True): +def convert_to_str_array(orig_str_arr: Any, field_info: Any, check: bool = True): arr = [] - if Config.EncodeProtocol.lower() != 'utf-8'.lower(): + if Config.EncodeProtocol.lower() != "utf-8".lower(): for s in orig_str_arr: arr.append(s.encode(Config.EncodeProtocol)) else: @@ -44,88 +51,83 @@ def convert_to_str_array(orig_str_arr, field_info, check=True): return arr -def entity_to_str_arr(entity, field_info, check=True): +def entity_to_str_arr(entity: Any, field_info: Any, check: bool = True): return convert_to_str_array(entity.get("values", []), field_info, check=check) -def convert_to_json(obj): +def convert_to_json(obj: object): return ujson.dumps(obj, ensure_ascii=False).encode(Config.EncodeProtocol) -def convert_to_json_arr(objs): +def convert_to_json_arr(objs: List[object]): arr = [] for obj in objs: arr.append(ujson.dumps(obj, ensure_ascii=False).encode(Config.EncodeProtocol)) return arr -def entity_to_json_arr(entity): +def entity_to_json_arr(entity: Dict): return convert_to_json_arr(entity.get("values", [])) -def pack_field_value_to_field_data(field_value, field_data, field_info): +def pack_field_value_to_field_data(field_value: Any, field_data: Any, field_info: Any): field_type = field_data.type - if field_type in (DataType.BOOL,): + if field_type == DataType.BOOL: field_data.scalars.bool_data.data.append(field_value) - elif field_type in (DataType.INT8,): + elif field_type in (DataType.INT8, DataType.INT16, DataType.INT32): field_data.scalars.int_data.data.append(field_value) - elif field_type in (DataType.INT16,): - field_data.scalars.int_data.data.append(field_value) - elif field_type in (DataType.INT32,): - field_data.scalars.int_data.data.append(field_value) - elif field_type in (DataType.INT64,): + elif field_type == DataType.INT64: field_data.scalars.long_data.data.append(field_value) - elif field_type in (DataType.FLOAT,): + elif field_type == DataType.FLOAT: field_data.scalars.float_data.data.append(field_value) - elif field_type in (DataType.DOUBLE,): + elif field_type == DataType.DOUBLE: field_data.scalars.double_data.data.append(field_value) - elif field_type in (DataType.FLOAT_VECTOR,): + elif field_type == DataType.FLOAT_VECTOR: field_data.vectors.dim = len(field_value) field_data.vectors.float_vector.data.extend(field_value) - elif field_type in (DataType.BINARY_VECTOR,): + elif field_type == DataType.BINARY_VECTOR: field_data.vectors.dim = len(field_value) * 8 field_data.vectors.binary_vector += bytes(field_value) - elif field_type in (DataType.VARCHAR,): + elif field_type == DataType.VARCHAR: field_data.scalars.string_data.data.append( - convert_to_str_array(field_value, field_info, True)) - elif field_type in (DataType.JSON,): + convert_to_str_array(field_value, field_info, CHECK_STR_ARRAY) + ) + elif field_type == DataType.JSON: field_data.scalars.json_data.data.append(convert_to_json(field_value)) else: raise ParamError(message=f"UnSupported data type: {field_type}") # TODO: refactor here. -def entity_to_field_data(entity, field_info): +def entity_to_field_data(entity: Any, field_info: Any): field_data = schema_types.FieldData() entity_type = entity.get("type") field_data.field_name = entity.get("name") field_data.type = entity_type_to_dtype(entity_type) - if entity_type in (DataType.BOOL,): + if entity_type == DataType.BOOL: field_data.scalars.bool_data.data.extend(entity.get("values")) - elif entity_type in (DataType.INT8,): - field_data.scalars.int_data.data.extend(entity.get("values")) - elif entity_type in (DataType.INT16,): - field_data.scalars.int_data.data.extend(entity.get("values")) - elif entity_type in (DataType.INT32,): + elif entity_type in (DataType.INT8, DataType.INT16, DataType.INT32): field_data.scalars.int_data.data.extend(entity.get("values")) - elif entity_type in (DataType.INT64,): + elif entity_type == DataType.INT64: field_data.scalars.long_data.data.extend(entity.get("values")) - elif entity_type in (DataType.FLOAT,): + elif entity_type == DataType.FLOAT: field_data.scalars.float_data.data.extend(entity.get("values")) - elif entity_type in (DataType.DOUBLE,): + elif entity_type == DataType.DOUBLE: field_data.scalars.double_data.data.extend(entity.get("values")) - elif entity_type in (DataType.FLOAT_VECTOR,): + elif entity_type == DataType.FLOAT_VECTOR: field_data.vectors.dim = len(entity.get("values")[0]) all_floats = [f for vector in entity.get("values") for f in vector] field_data.vectors.float_vector.data.extend(all_floats) - elif entity_type in (DataType.BINARY_VECTOR,): + elif entity_type == DataType.BINARY_VECTOR: field_data.vectors.dim = len(entity.get("values")[0]) * 8 - field_data.vectors.binary_vector = b''.join(entity.get("values")) - elif entity_type in (DataType.VARCHAR,): - field_data.scalars.string_data.data.extend(entity_to_str_arr(entity, field_info, True)) - elif entity_type in (DataType.JSON,): + field_data.vectors.binary_vector = b"".join(entity.get("values")) + elif entity_type == DataType.VARCHAR: + field_data.scalars.string_data.data.extend( + entity_to_str_arr(entity, field_info, CHECK_STR_ARRAY) + ) + elif entity_type == DataType.JSON: field_data.scalars.json_data.data.extend(entity_to_json_arr(entity)) else: raise ParamError(message=f"UnSupported data type: {entity_type}") @@ -133,7 +135,7 @@ def entity_to_field_data(entity, field_info): return field_data -def extract_dynamic_field_from_result(raw): +def extract_dynamic_field_from_result(raw: Any): dynamic_field_name = None field_names = set() if raw.fields_data: @@ -153,61 +155,81 @@ def extract_dynamic_field_from_result(raw): # pylint: disable=R1702 (too-many-nested-blocks) -def extract_row_data_from_fields_data(fields_data, index, dynamic_output_fields=None): +def extract_row_data_from_fields_data( + fields_data: Any, + index: Any, + dynamic_output_fields: Optional[List] = None, +): if not fields_data: return {} entity_row_data = {} dynamic_fields = dynamic_output_fields or set() - for field_data in fields_data: - if field_data.type == DataType.BOOL: - if len(field_data.scalars.bool_data.data) >= index: - entity_row_data[field_data.field_name] = field_data.scalars.bool_data.data[index] - elif field_data.type in (DataType.INT8, DataType.INT16, DataType.INT32): - if len(field_data.scalars.int_data.data) >= index: - entity_row_data[field_data.field_name] = field_data.scalars.int_data.data[index] - elif field_data.type == DataType.INT64: - if len(field_data.scalars.long_data.data) >= index: - entity_row_data[field_data.field_name] = field_data.scalars.long_data.data[index] - elif field_data.type == DataType.FLOAT: - if len(field_data.scalars.float_data.data) >= index: - entity_row_data[field_data.field_name] = np.single(field_data.scalars.float_data.data[index]) - elif field_data.type == DataType.DOUBLE: - if len(field_data.scalars.double_data.data) >= index: - entity_row_data[field_data.field_name] = field_data.scalars.double_data.data[index] - elif field_data.type == DataType.VARCHAR: - if len(field_data.scalars.string_data.data) >= index: - entity_row_data[field_data.field_name] = field_data.scalars.string_data.data[index] - elif field_data.type == DataType.STRING: + + def check_append(field_data: Any): + if field_data.type == DataType.STRING: raise MilvusException(message="Not support string yet") - # result[field_data.field_name] = field_data.scalars.string_data.data[index] - elif field_data.type == DataType.JSON: - if len(field_data.scalars.json_data.data) >= index: - json_value = field_data.scalars.json_data.data[index] - json_dict = ujson.loads(json_value) - if field_data.is_dynamic: - if dynamic_fields: - for key in json_dict: - if key in dynamic_fields: - entity_row_data[key] = json_dict[key] - else: - entity_row_data.update(json_dict) - continue + + if field_data.type == DataType.BOOL and len(field_data.scalars.bool_data.data) >= index: + entity_row_data[field_data.field_name] = field_data.scalars.bool_data.data[index] + return + + if ( + field_data.type in (DataType.INT8, DataType.INT16, DataType.INT32) + and len(field_data.scalars.int_data.data) >= index + ): + entity_row_data[field_data.field_name] = field_data.scalars.int_data.data[index] + return + + if field_data.type == DataType.INT64 and len(field_data.scalars.long_data.data) >= index: + entity_row_data[field_data.field_name] = field_data.scalars.long_data.data[index] + return + + if field_data.type == DataType.FLOAT and len(field_data.scalars.float_data.data) >= index: + entity_row_data[field_data.field_name] = np.single( + field_data.scalars.float_data.data[index] + ) + return + + if field_data.type == DataType.DOUBLE and len(field_data.scalars.double_data.data) >= index: + entity_row_data[field_data.field_name] = field_data.scalars.double_data.data[index] + return + + if ( + field_data.type == DataType.VARCHAR + and len(field_data.scalars.string_data.data) >= index + ): + entity_row_data[field_data.field_name] = field_data.scalars.string_data.data[index] + return + + if field_data.type == DataType.JSON and len(field_data.scalars.json_data.data) >= index: + json_value = field_data.scalars.json_data.data[index] + json_dict = ujson.loads(json_value) + + if not field_data.is_dynamic: entity_row_data[field_data.field_name] = json_dict - elif field_data.type == DataType.FLOAT_VECTOR: + return + + tmp_dict = {k: v for k, v in json_dict.items() if k in dynamic_fields} + entity_row_data.update(tmp_dict) + return + + if field_data.type == DataType.FLOAT_VECTOR: dim = field_data.vectors.dim if len(field_data.vectors.float_vector.data) >= index * dim: - start_pos = index * dim - end_pos = index * dim + dim - entity_row_data[field_data.field_name] = [np.single(x) for x in - field_data.vectors.float_vector.data[ - start_pos:end_pos]] + start_pos, end_pos = index * dim, (index + 1) * dim + entity_row_data[field_data.field_name] = [ + np.single(x) for x in field_data.vectors.float_vector.data[start_pos:end_pos] + ] elif field_data.type == DataType.BINARY_VECTOR: dim = field_data.vectors.dim if len(field_data.vectors.binary_vector) >= index * (dim // 8): - start_pos = index * (dim // 8) - end_pos = (index + 1) * (dim // 8) + start_pos, end_pos = index * (dim // 8), (index + 1) * (dim // 8) entity_row_data[field_data.field_name] = [ - field_data.vectors.binary_vector[start_pos:end_pos]] + field_data.vectors.binary_vector[start_pos:end_pos] + ] + + for field_data in fields_data: + check_append(field_data) return entity_row_data diff --git a/pymilvus/client/grpc_handler.py b/pymilvus/client/grpc_handler.py index c3a774234..02088fbce 100644 --- a/pymilvus/client/grpc_handler.py +++ b/pymilvus/client/grpc_handler.py @@ -1,75 +1,79 @@ -import time -import json -import copy import base64 -from urllib import parse +import copy +import json import socket +import time +from pathlib import Path +from typing import Any, Dict, List, Optional, Union +from urllib import parse import grpc from grpc._cython import cygrpc -from ..grpc_gen import milvus_pb2_grpc -from ..grpc_gen import milvus_pb2 as milvus_types -from ..grpc_gen import common_pb2 +from pymilvus.decorators import retry_on_rpc_failure, upgrade_reminder +from pymilvus.exceptions import ( + AmbiguousIndexName, + DescribeCollectionException, + ExceptionsMessage, + MilvusException, + ParamError, +) +from pymilvus.grpc_gen import common_pb2, milvus_pb2_grpc +from pymilvus.grpc_gen import milvus_pb2 as milvus_types +from pymilvus.settings import Config -from .abstract import CollectionSchema, ChunkedQueryResult, MutationResult +from . import entity_helper, interceptor, ts_utils +from .abstract import ChunkedQueryResult, CollectionSchema, MutationResult +from .asynch import ( + ChunkedSearchFuture, + CreateIndexFuture, + FlushFuture, + LoadPartitionsFuture, + MutationFuture, + SearchFuture, +) from .check import ( + check_pass_param, is_legal_host, is_legal_port, - check_pass_param, ) from .prepare import Prepare from .types import ( - Status, + BulkInsertState, + CompactionPlans, + CompactionState, + DataType, + GrantInfo, + Group, IndexState, LoadState, - DataType, - CompactionState, - State, - CompactionPlans, Plan, - Replica, Shard, Group, - GrantInfo, UserInfo, RoleInfo, - BulkInsertState, + Replica, ResourceGroupInfo, + RoleInfo, + Shard, + State, + Status, + UserInfo, ) - from .utils import ( check_invalid_binary_vector, - len_of, get_server_type, + len_of, ) -from ..settings import Config -from . import ts_utils -from . import interceptor - -from .asynch import ( - SearchFuture, - MutationFuture, - CreateIndexFuture, - FlushFuture, - LoadPartitionsFuture, - ChunkedSearchFuture -) - -from ..exceptions import ( - ExceptionsMessage, - ParamError, - DescribeCollectionException, - MilvusException, - AmbiguousIndexName, -) - -from ..decorators import retry_on_rpc_failure, upgrade_reminder -from . import entity_helper - class GrpcHandler: - # pylint: disable=too-many-instance-attributes - def __init__(self, uri=Config.GRPC_URI, host="", port="", channel=None, **kwargs): + def __init__( + self, + uri: str = Config.GRPC_URI, + host: str = "", + port: str = "", + channel: Optional[grpc.Channel] = None, + **kwargs, + ) -> None: self._stub = None self._channel = channel @@ -88,7 +92,7 @@ def __get_address(self, uri: str, host: str, port: str) -> str: try: parsed_uri = parse.urlparse(uri) - except (Exception) as e: + except Exception as e: raise ParamError(message=f"Illegal uri: [{uri}], {e}") from e return parsed_uri.netloc @@ -104,51 +108,57 @@ def _set_authorization(self, **kwargs): self._server_name = kwargs.get("server_name", "") self._authorization_interceptor = None - self._setup_authorization_interceptor(kwargs.get("user", None), - kwargs.get("password", None), - kwargs.get("token", None)) + self._setup_authorization_interceptor( + kwargs.get("user", None), + kwargs.get("password", None), + kwargs.get("token", None), + ) def __enter__(self): return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): pass - def _wait_for_channel_ready(self, timeout=10): - if self._channel is not None: - try: - grpc.channel_ready_future(self._channel).result(timeout=timeout) - self._setup_identifier_interceptor(self._user) - return - except (grpc.FutureTimeoutError, MilvusException) as e: - raise MilvusException(Status.CONNECT_FAILED, - f'Fail connecting to server on {self._address}. Timeout') from e + def _wait_for_channel_ready(self, timeout: Union[int, float] = 10): + if self._channel is None: + raise MilvusException( + code=Status.CONNECT_FAILED, + message="No channel in handler, please setup grpc channel first", + ) - raise MilvusException(Status.CONNECT_FAILED, 'No channel in handler, please setup grpc channel first') + try: + grpc.channel_ready_future(self._channel).result(timeout=timeout) + self._setup_identifier_interceptor(self._user) + except (grpc.FutureTimeoutError, MilvusException) as e: + raise MilvusException( + code=Status.CONNECT_FAILED, + message=f"Fail connecting to server on {self._address}. Timeout", + ) from e def close(self): self._channel.close() - def reset_db_name(self, db_name): + def reset_db_name(self, db_name: str): self._setup_db_interceptor(db_name) self._setup_grpc_channel() self._setup_identifier_interceptor(self._user) - def _setup_authorization_interceptor(self, user, password, token): + def _setup_authorization_interceptor(self, user: str, password: str, token: str): keys = [] values = [] if token: - authorization = base64.b64encode(f"{token}".encode('utf-8')) + authorization = base64.b64encode(f"{token}".encode()) keys.append("authorization") values.append(authorization) elif user and password: - authorization = base64.b64encode(f"{user}:{password}".encode('utf-8')) + authorization = base64.b64encode(f"{user}:{password}".encode()) keys.append("authorization") values.append(authorization) if len(keys) > 0 and len(values) > 0: self._authorization_interceptor = interceptor.header_adder_interceptor(keys, values) - def _setup_db_interceptor(self, db_name): + def _setup_db_interceptor(self, db_name: str): if db_name is None: self._db_interceptor = None else: @@ -156,82 +166,108 @@ def _setup_db_interceptor(self, db_name): self._db_interceptor = interceptor.header_adder_interceptor(["dbname"], [db_name]) def _setup_grpc_channel(self): - """ Create a ddl grpc channel """ + """Create a ddl grpc channel""" if self._channel is None: - opts = [(cygrpc.ChannelArgKey.max_send_message_length, -1), - (cygrpc.ChannelArgKey.max_receive_message_length, -1), - ('grpc.enable_retries', 1), - ('grpc.keepalive_time_ms', 55000), - ] + opts = [ + (cygrpc.ChannelArgKey.max_send_message_length, -1), + (cygrpc.ChannelArgKey.max_receive_message_length, -1), + ("grpc.enable_retries", 1), + ("grpc.keepalive_time_ms", 55000), + ] if not self._secure: self._channel = grpc.insecure_channel( self._address, options=opts, ) else: - if self._client_pem_path != "" and self._client_key_path != "" and self._ca_pem_path != "" \ - and self._server_name != "": - opts.append(('grpc.ssl_target_name_override', self._server_name,), ) - with open(self._client_pem_path, 'rb') as f: + if ( + self._client_pem_path != "" + and self._client_key_path != "" + and self._ca_pem_path != "" + and self._server_name != "" + ): + opts.append(("grpc.ssl_target_name_override", self._server_name)) + with Path(self._client_pem_path).open("rb") as f: certificate_chain = f.read() - with open(self._client_key_path, 'rb') as f: + with Path(self._client_key_path).open("rb") as f: private_key = f.read() - with open(self._ca_pem_path, 'rb') as f: + with Path(self._ca_pem_path).open("rb") as f: root_certificates = f.read() - creds = grpc.ssl_channel_credentials(root_certificates, private_key, certificate_chain) + creds = grpc.ssl_channel_credentials( + root_certificates, private_key, certificate_chain + ) elif self._server_pem_path != "" and self._server_name != "": - opts.append(('grpc.ssl_target_name_override', self._server_name,), ) - with open(self._server_pem_path, 'rb') as f: + opts.append(("grpc.ssl_target_name_override", self._server_name)) + with Path(self._server_pem_path).open("rb") as f: server_pem = f.read() creds = grpc.ssl_channel_credentials(root_certificates=server_pem) else: - creds = grpc.ssl_channel_credentials(root_certificates=None, private_key=None, - certificate_chain=None) + creds = grpc.ssl_channel_credentials( + root_certificates=None, private_key=None, certificate_chain=None + ) self._channel = grpc.secure_channel( self._address, creds, - options=opts + options=opts, ) # avoid to add duplicate headers. self._final_channel = self._channel if self._authorization_interceptor: - self._final_channel = grpc.intercept_channel(self._final_channel, self._authorization_interceptor) + self._final_channel = grpc.intercept_channel( + self._final_channel, self._authorization_interceptor + ) if self._db_interceptor: self._final_channel = grpc.intercept_channel(self._final_channel, self._db_interceptor) if self._log_level: - log_level_interceptor = interceptor.header_adder_interceptor(["log_level"], [self._log_level]) + log_level_interceptor = interceptor.header_adder_interceptor( + ["log_level"], [self._log_level] + ) self._final_channel = grpc.intercept_channel(self._final_channel, log_level_interceptor) self._log_level = None if self._request_id: - request_id_interceptor = interceptor.header_adder_interceptor(["client_request_id"], [self._request_id]) - self._final_channel = grpc.intercept_channel(self._final_channel, request_id_interceptor) + request_id_interceptor = interceptor.header_adder_interceptor( + ["client_request_id"], [self._request_id] + ) + self._final_channel = grpc.intercept_channel( + self._final_channel, request_id_interceptor + ) self._request_id = None self._stub = milvus_pb2_grpc.MilvusServiceStub(self._final_channel) - def set_onetime_loglevel(self, log_level): + def set_onetime_loglevel(self, log_level: str): self._log_level = log_level self._setup_grpc_channel() - def set_onetime_request_id(self, req_id): + def set_onetime_request_id(self, req_id: int): self._request_id = req_id self._setup_grpc_channel() - def _setup_identifier_interceptor(self, user): + def _setup_identifier_interceptor(self, user: str): host = socket.gethostname() self._identifier = self.__internal_register(user, host) - self._identifier_interceptor = interceptor.header_adder_interceptor(["identifier"], [str(self._identifier)]) - self._final_channel = grpc.intercept_channel(self._final_channel, self._identifier_interceptor) + self._identifier_interceptor = interceptor.header_adder_interceptor( + ["identifier"], [str(self._identifier)] + ) + self._final_channel = grpc.intercept_channel( + self._final_channel, self._identifier_interceptor + ) self._stub = milvus_pb2_grpc.MilvusServiceStub(self._final_channel) @property def server_address(self): - """ Server network address """ + """Server network address""" return self._address def get_server_type(self): - return get_server_type(self.server_address.split(':')[0]) - - def reset_password(self, user, old_password, new_password, timeout=None): + return get_server_type(self.server_address.split(":")[0]) + + def reset_password( + self, + user: str, + old_password: str, + new_password: str, + timeout: Optional[float] = None, + ): """ reset password and then setup the grpc channel. """ @@ -240,7 +276,9 @@ def reset_password(self, user, old_password, new_password, timeout=None): self._setup_grpc_channel() @retry_on_rpc_failure() - def create_collection(self, collection_name, fields, timeout=None, **kwargs): + def create_collection( + self, collection_name: str, fields: List, timeout: Optional[float] = None, **kwargs + ): check_pass_param(collection_name=collection_name) request = Prepare.create_collection_request(collection_name, fields, **kwargs) @@ -250,9 +288,10 @@ def create_collection(self, collection_name, fields, timeout=None, **kwargs): status = rf.result() if status.error_code != 0: raise MilvusException(status.error_code, status.reason) + return None @retry_on_rpc_failure() - def drop_collection(self, collection_name, timeout=None): + def drop_collection(self, collection_name: str, timeout: Optional[float] = None): check_pass_param(collection_name=collection_name) request = Prepare.drop_collection_request(collection_name) @@ -262,7 +301,9 @@ def drop_collection(self, collection_name, timeout=None): raise MilvusException(status.error_code, status.reason) @retry_on_rpc_failure() - def alter_collection(self, collection_name, properties, timeout=None, **kwargs): + def alter_collection( + self, collection_name: str, properties: List, timeout: Optional[float] = None, **kwargs + ): check_pass_param(collection_name=collection_name, properties=properties) request = Prepare.alter_collection_request(collection_name, properties) rf = self._stub.AlterCollection.future(request, timeout=timeout) @@ -271,7 +312,7 @@ def alter_collection(self, collection_name, properties, timeout=None, **kwargs): raise MilvusException(status.error_code, status.reason) @retry_on_rpc_failure() - def has_collection(self, collection_name, timeout=None, **kwargs): + def has_collection(self, collection_name: str, timeout: Optional[float] = None, **kwargs): check_pass_param(collection_name=collection_name) request = Prepare.describe_collection_request(collection_name) rf = self._stub.DescribeCollection.future(request, timeout=timeout) @@ -281,13 +322,16 @@ def has_collection(self, collection_name, timeout=None, **kwargs): return True # TODO: Workaround for unreasonable describe collection results and error_code - if reply.status.error_code == common_pb2.UnexpectedError and "can\'t find collection" in reply.status.reason: + if ( + reply.status.error_code == common_pb2.UnexpectedError + and "can't find collection" in reply.status.reason + ): return False raise MilvusException(reply.status.error_code, reply.status.reason) @retry_on_rpc_failure() - def describe_collection(self, collection_name, timeout=None, **kwargs): + def describe_collection(self, collection_name: str, timeout: Optional[float] = None, **kwargs): check_pass_param(collection_name=collection_name) request = Prepare.describe_collection_request(collection_name) rf = self._stub.DescribeCollection.future(request, timeout=timeout) @@ -300,7 +344,7 @@ def describe_collection(self, collection_name, timeout=None, **kwargs): raise DescribeCollectionException(status.error_code, status.reason) @retry_on_rpc_failure() - def list_collections(self, timeout=None): + def list_collections(self, timeout: Optional[float] = None): request = Prepare.show_collections_request() rf = self._stub.ShowCollections.future(request, timeout=timeout) response = rf.result() @@ -311,10 +355,10 @@ def list_collections(self, timeout=None): raise MilvusException(status.error_code, status.reason) @retry_on_rpc_failure() - def rename_collections(self, old_name=None, new_name=None, timeout=None): + def rename_collections(self, old_name: str, new_name: str, timeout: Optional[float] = None): check_pass_param(collection_name=new_name) check_pass_param(collection_name=old_name) - request = Prepare().rename_collections_request(old_name, new_name) + request = Prepare.rename_collections_request(old_name, new_name) rf = self._stub.RenameCollection.future(request, timeout=timeout) response = rf.result() @@ -322,7 +366,9 @@ def rename_collections(self, old_name=None, new_name=None, timeout=None): raise MilvusException(response.error_code, response.reason) @retry_on_rpc_failure() - def create_partition(self, collection_name, partition_name, timeout=None): + def create_partition( + self, collection_name: str, partition_name: str, timeout: Optional[float] = None + ): check_pass_param(collection_name=collection_name, partition_name=partition_name) request = Prepare.create_partition_request(collection_name, partition_name) rf = self._stub.CreatePartition.future(request, timeout=timeout) @@ -331,7 +377,9 @@ def create_partition(self, collection_name, partition_name, timeout=None): raise MilvusException(response.error_code, response.reason) @retry_on_rpc_failure() - def drop_partition(self, collection_name, partition_name, timeout=None): + def drop_partition( + self, collection_name: str, partition_name: str, timeout: Optional[float] = None + ): check_pass_param(collection_name=collection_name, partition_name=partition_name) request = Prepare.drop_partition_request(collection_name, partition_name) @@ -342,7 +390,9 @@ def drop_partition(self, collection_name, partition_name, timeout=None): raise MilvusException(response.error_code, response.reason) @retry_on_rpc_failure() - def has_partition(self, collection_name, partition_name, timeout=None, **kwargs): + def has_partition( + self, collection_name: str, partition_name: str, timeout: Optional[float] = None, **kwargs + ): check_pass_param(collection_name=collection_name, partition_name=partition_name) request = Prepare.has_partition_request(collection_name, partition_name) rf = self._stub.HasPartition.future(request, timeout=timeout) @@ -355,7 +405,9 @@ def has_partition(self, collection_name, partition_name, timeout=None, **kwargs) # TODO: this is not inuse @retry_on_rpc_failure() - def get_partition_info(self, collection_name, partition_name, timeout=None): + def get_partition_info( + self, collection_name: str, partition_name: str, timeout: Optional[float] = None + ): request = Prepare.partition_stats_request(collection_name, partition_name) rf = self._stub.DescribePartition.future(request, timeout=timeout) response = rf.result() @@ -369,7 +421,7 @@ def get_partition_info(self, collection_name, partition_name, timeout=None): raise MilvusException(status.error_code, status.reason) @retry_on_rpc_failure() - def list_partitions(self, collection_name, timeout=None): + def list_partitions(self, collection_name: str, timeout: Optional[float] = None): check_pass_param(collection_name=collection_name) request = Prepare.show_partitions_request(collection_name) @@ -382,7 +434,9 @@ def list_partitions(self, collection_name, timeout=None): raise MilvusException(status.error_code, status.reason) @retry_on_rpc_failure() - def get_partition_stats(self, collection_name, partition_name, timeout=None, **kwargs): + def get_partition_stats( + self, collection_name: str, partition_name: str, timeout: Optional[float] = None, **kwargs + ): check_pass_param(collection_name=collection_name) req = Prepare.get_partition_stats_request(collection_name, partition_name) future = self._stub.GetPartitionStatistics.future(req, timeout=timeout) @@ -393,28 +447,46 @@ def get_partition_stats(self, collection_name, partition_name, timeout=None, **k raise MilvusException(status.error_code, status.reason) - def _prepare_row_insert_or_upsert_request(self, collection_name, rows, partition_name=None, timeout=None, - is_insert=True, **kwargs): + def _prepare_row_insert_or_upsert_request( + self, + collection_name: str, + rows: List, + partition_name: Optional[str] = None, + timeout: Optional[float] = None, + is_insert: bool = True, + **kwargs, + ): if not isinstance(rows, list): raise ParamError(message="None rows, please provide valid row data.") collection_schema = kwargs.get("schema", None) if not collection_schema: - collection_schema = self.describe_collection( - collection_name, timeout=timeout, **kwargs) + collection_schema = self.describe_collection(collection_name, timeout=timeout, **kwargs) fields_info = collection_schema["fields"] enable_dynamic = collection_schema.get("enable_dynamic_field", False) - request = Prepare.row_insert_or_upsert_param(collection_name, rows, partition_name, fields_info, is_insert, - enable_dynamic=enable_dynamic) - return request + return Prepare.row_insert_or_upsert_param( + collection_name, + rows, + partition_name, + fields_info, + is_insert, + enable_dynamic=enable_dynamic, + ) - def _prepare_batch_insert_or_upsert_request(self, collection_name, entities, partition_name=None, timeout=None, - is_insert=True, **kwargs): - param = kwargs.get('insert_param', None) + def _prepare_batch_insert_or_upsert_request( + self, + collection_name: str, + entities: List, + partition_name: Optional[str] = None, + timeout: Optional[float] = None, + is_insert: bool = True, + **kwargs, + ): + param = kwargs.get("insert_param", None) if not is_insert: - param = kwargs.get('upsert_param', None) + param = kwargs.get("upsert_param", None) if param and not isinstance(param, milvus_types.RowBatch): if is_insert: @@ -425,42 +497,57 @@ def _prepare_batch_insert_or_upsert_request(self, collection_name, entities, par collection_schema = kwargs.get("schema", None) if not collection_schema: - collection_schema = self.describe_collection( - collection_name, timeout=timeout, **kwargs) + collection_schema = self.describe_collection(collection_name, timeout=timeout, **kwargs) fields_info = collection_schema["fields"] - request = param if param \ - else Prepare.batch_insert_or_upsert_param(collection_name, entities, partition_name, fields_info, is_insert) - - return request + return ( + param + if param + else Prepare.batch_insert_or_upsert_param( + collection_name, entities, partition_name, fields_info, is_insert + ) + ) @retry_on_rpc_failure() - def insert_rows(self, collection_name, entities, partition_name=None, timeout=None, **kwargs): + def insert_rows( + self, + collection_name: str, + entities: List, + partition_name: Optional[str] = None, + timeout: Optional[float] = None, + **kwargs, + ): if isinstance(entities, dict): entities = [entities] - try: - request = self._prepare_row_insert_or_upsert_request( - collection_name, entities, partition_name, timeout, **kwargs) - rf = self._stub.Insert.future(request, timeout=timeout) - response = rf.result() - if response.status.error_code == 0: - m = MutationResult(response) - ts_utils.update_collection_ts(collection_name, m.timestamp) - return m + request = self._prepare_row_insert_or_upsert_request( + collection_name, entities, partition_name, timeout, **kwargs + ) + rf = self._stub.Insert.future(request, timeout=timeout) + response = rf.result() + if response.status.error_code == 0: + m = MutationResult(response) + ts_utils.update_collection_ts(collection_name, m.timestamp) + return m - raise MilvusException(response.status.error_code, response.status.reason) - except Exception as err: - raise err + raise MilvusException(response.status.error_code, response.status.reason) @retry_on_rpc_failure() - def batch_insert(self, collection_name, entities, partition_name=None, timeout=None, **kwargs): + def batch_insert( + self, + collection_name: str, + entities: List, + partition_name: Optional[str] = None, + timeout: Optional[float] = None, + **kwargs, + ): if not check_invalid_binary_vector(entities): raise ParamError(message="Invalid binary vector data exists") try: request = self._prepare_batch_insert_or_upsert_request( - collection_name, entities, partition_name, timeout, **kwargs) + collection_name, entities, partition_name, timeout, **kwargs + ) rf = self._stub.Insert.future(request, timeout=timeout) if kwargs.get("_async", False) is True: cb = kwargs.get("_callback", None) @@ -478,10 +565,17 @@ def batch_insert(self, collection_name, entities, partition_name=None, timeout=N except Exception as err: if kwargs.get("_async", False): return MutationFuture(None, None, err) - raise err + raise err from err @retry_on_rpc_failure() - def delete(self, collection_name, expression, partition_name=None, timeout=None, **kwargs): + def delete( + self, + collection_name: str, + expression: str, + partition_name: Optional[str] = None, + timeout: Optional[float] = None, + **kwargs, + ): check_pass_param(collection_name=collection_name) try: req = Prepare.delete_request(collection_name, partition_name, expression) @@ -503,16 +597,24 @@ def delete(self, collection_name, expression, partition_name=None, timeout=None, except Exception as err: if kwargs.get("_async", False): return MutationFuture(None, None, err) - raise err + raise err from err @retry_on_rpc_failure() - def upsert(self, collection_name, entities, partition_name=None, timeout=None, **kwargs): + def upsert( + self, + collection_name: str, + entities: List, + partition_name: Optional[str] = None, + timeout: Optional[float] = None, + **kwargs, + ): if not check_invalid_binary_vector(entities): raise ParamError(message="Invalid binary vector data exists") try: request = self._prepare_batch_insert_or_upsert_request( - collection_name, entities, partition_name, timeout, False, **kwargs) + collection_name, entities, partition_name, timeout, False, **kwargs + ) rf = self._stub.Upsert.future(request, timeout=timeout) if kwargs.get("_async", False) is True: cb = kwargs.get("_callback", None) @@ -526,33 +628,36 @@ def upsert(self, collection_name, entities, partition_name=None, timeout=None, * ts_utils.update_collection_ts(collection_name, m.timestamp) return m - raise MilvusException( - response.status.error_code, response.status.reason) + raise MilvusException(response.status.error_code, response.status.reason) except Exception as err: if kwargs.get("_async", False): return MutationFuture(None, None, err) - raise err + raise err from err @retry_on_rpc_failure() - def upsert_rows(self, collection_name, entities, partition_name=None, timeout=None, **kwargs): + def upsert_rows( + self, + collection_name: str, + entities: List, + partition_name: Optional[str] = None, + timeout: Optional[float] = None, + **kwargs, + ): if isinstance(entities, dict): entities = [entities] - try: - request = self._prepare_row_insert_or_upsert_request( - collection_name, entities, partition_name, timeout, False, **kwargs) - rf = self._stub.Upsert.future(request, timeout=timeout) - response = rf.result() - if response.status.error_code == 0: - m = MutationResult(response) - ts_utils.update_collection_ts(collection_name, m.timestamp) - return m + request = self._prepare_row_insert_or_upsert_request( + collection_name, entities, partition_name, timeout, False, **kwargs + ) + rf = self._stub.Upsert.future(request, timeout=timeout) + response = rf.result() + if response.status.error_code == 0: + m = MutationResult(response) + ts_utils.update_collection_ts(collection_name, m.timestamp) + return m - raise MilvusException( - response.status.error_code, response.status.reason) - except Exception as err: - raise err + raise MilvusException(response.status.error_code, response.status.reason) - def _execute_search_requests(self, requests, timeout=None, **kwargs): + def _execute_search_requests(self, requests: Any, timeout: Optional[float] = None, **kwargs): try: if kwargs.get("_async", False): futures = [] @@ -576,12 +681,23 @@ def _execute_search_requests(self, requests, timeout=None, **kwargs): except Exception as pre_err: if kwargs.get("_async", False): return SearchFuture(None, None, pre_err) - raise pre_err - - @retry_on_rpc_failure() - def search(self, collection_name, data, anns_field, param, limit, - expression=None, partition_names=None, output_fields=None, - round_decimal=-1, timeout=None, **kwargs): + raise pre_err from pre_err + + @retry_on_rpc_failure() + def search( + self, + collection_name: str, + data: List[List[float]], + anns_field: str, + param: Dict, + limit: int, + expression: Optional[str] = None, + partition_names: Optional[List[str]] = None, + output_fields: Optional[List[str]] = None, + round_decimal: int = -1, + timeout: Optional[float] = None, + **kwargs, + ): check_pass_param( limit=limit, round_decimal=round_decimal, @@ -590,16 +706,27 @@ def search(self, collection_name, data, anns_field, param, limit, partition_name_array=partition_names, output_fields=output_fields, travel_timestamp=kwargs.get("travel_timestamp", 0), - guarantee_timestamp=kwargs.get("guarantee_timestamp", None) + guarantee_timestamp=kwargs.get("guarantee_timestamp", None), ) - requests = Prepare.search_requests_with_expr(collection_name, data, anns_field, param, limit, - expression, partition_names, output_fields, round_decimal, - **kwargs) - return self._execute_search_requests(requests, timeout, round_decimal=round_decimal, **kwargs) + requests = Prepare.search_requests_with_expr( + collection_name, + data, + anns_field, + param, + limit, + expression, + partition_names, + output_fields, + round_decimal, + **kwargs, + ) + return self._execute_search_requests( + requests, timeout, round_decimal=round_decimal, **kwargs + ) @retry_on_rpc_failure() - def get_query_segment_info(self, collection_name, timeout=30, **kwargs): + def get_query_segment_info(self, collection_name: str, timeout: float = 30, **kwargs): req = Prepare.get_query_segment_info_request(collection_name) future = self._stub.GetQuerySegmentInfo.future(req, timeout=timeout) response = future.result() @@ -609,7 +736,9 @@ def get_query_segment_info(self, collection_name, timeout=30, **kwargs): raise MilvusException(status.error_code, status.reason) @retry_on_rpc_failure() - def create_alias(self, collection_name, alias, timeout=None, **kwargs): + def create_alias( + self, collection_name: str, alias: str, timeout: Optional[float] = None, **kwargs + ): check_pass_param(collection_name=collection_name) request = Prepare.create_alias_request(collection_name, alias) rf = self._stub.CreateAlias.future(request, timeout=timeout) @@ -618,7 +747,7 @@ def create_alias(self, collection_name, alias, timeout=None, **kwargs): raise MilvusException(response.error_code, response.reason) @retry_on_rpc_failure() - def drop_alias(self, alias, timeout=None, **kwargs): + def drop_alias(self, alias: str, timeout: Optional[float] = None, **kwargs): request = Prepare.drop_alias_request(alias) rf = self._stub.DropAlias.future(request, timeout=timeout) response = rf.result() @@ -626,7 +755,9 @@ def drop_alias(self, alias, timeout=None, **kwargs): raise MilvusException(response.error_code, response.reason) @retry_on_rpc_failure() - def alter_alias(self, collection_name, alias, timeout=None, **kwargs): + def alter_alias( + self, collection_name: str, alias: str, timeout: Optional[float] = None, **kwargs + ): check_pass_param(collection_name=collection_name) request = Prepare.alter_alias_request(collection_name, alias) rf = self._stub.AlterAlias.future(request, timeout=timeout) @@ -635,7 +766,14 @@ def alter_alias(self, collection_name, alias, timeout=None, **kwargs): raise MilvusException(response.error_code, response.reason) @retry_on_rpc_failure() - def create_index(self, collection_name, field_name, params, timeout=None, **kwargs): + def create_index( + self, + collection_name: str, + field_name: str, + params: Dict, + timeout: Optional[float] = None, + **kwargs, + ): # for historical reason, index_name contained in kwargs. index_name = kwargs.pop("index_name", Config.IndexName) copy_kwargs = copy.deepcopy(kwargs) @@ -657,15 +795,21 @@ def create_index(self, collection_name, field_name, params, timeout=None, **kwar _async = kwargs.get("_async", False) kwargs["_async"] = False - index_param = Prepare.create_index_request(collection_name, field_name, params, index_name=index_name) + index_param = Prepare.create_index_request( + collection_name, field_name, params, index_name=index_name + ) future = self._stub.CreateIndex.future(index_param, timeout=timeout) if _async: + def _check(): if kwargs.get("sync", True): - index_success, fail_reason = self.wait_for_creating_index(collection_name=collection_name, - index_name=index_name, - timeout=timeout, field_name=field_name) + index_success, fail_reason = self.wait_for_creating_index( + collection_name=collection_name, + index_name=index_name, + timeout=timeout, + field_name=field_name, + ) if not index_success: raise MilvusException(message=fail_reason) @@ -682,16 +826,19 @@ def _check(): raise MilvusException(status.error_code, status.reason) if kwargs.get("sync", True): - index_success, fail_reason = self.wait_for_creating_index(collection_name=collection_name, - index_name=index_name, - timeout=timeout, field_name=field_name) + index_success, fail_reason = self.wait_for_creating_index( + collection_name=collection_name, + index_name=index_name, + timeout=timeout, + field_name=field_name, + ) if not index_success: raise MilvusException(message=fail_reason) return Status(status.error_code, status.reason) @retry_on_rpc_failure() - def list_indexes(self, collection_name, timeout=None, **kwargs): + def list_indexes(self, collection_name: str, timeout: Optional[float] = None, **kwargs): check_pass_param(collection_name=collection_name) request = Prepare.describe_index_request(collection_name, "") @@ -705,7 +852,9 @@ def list_indexes(self, collection_name, timeout=None, **kwargs): raise MilvusException(status.error_code, status.reason) @retry_on_rpc_failure() - def describe_index(self, collection_name, index_name, timeout=None, **kwargs): + def describe_index( + self, collection_name: str, index_name: str, timeout: Optional[float] = None, **kwargs + ): check_pass_param(collection_name=collection_name) request = Prepare.describe_index_request(collection_name, index_name) @@ -718,8 +867,8 @@ def describe_index(self, collection_name, index_name, timeout=None, **kwargs): raise MilvusException(status.error_code, status.reason) if len(response.index_descriptions) == 1: info_dict = {kv.key: kv.value for kv in response.index_descriptions[0].params} - info_dict['field_name'] = response.index_descriptions[0].field_name - info_dict['index_name'] = response.index_descriptions[0].index_name + info_dict["field_name"] = response.index_descriptions[0].field_name + info_dict["index_name"] = response.index_descriptions[0].index_name if info_dict.get("params", None): info_dict["params"] = json.loads(info_dict["params"]) return info_dict @@ -727,7 +876,9 @@ def describe_index(self, collection_name, index_name, timeout=None, **kwargs): raise AmbiguousIndexName(message=ExceptionsMessage.AmbiguousIndexName) @retry_on_rpc_failure() - def get_index_build_progress(self, collection_name, index_name, timeout=None): + def get_index_build_progress( + self, collection_name: str, index_name: str, timeout: Optional[float] = None + ): request = Prepare.describe_index_request(collection_name, index_name) rf = self._stub.DescribeIndex.future(request, timeout=timeout) response = rf.result() @@ -735,13 +886,18 @@ def get_index_build_progress(self, collection_name, index_name, timeout=None): if status.error_code == 0: if len(response.index_descriptions) == 1: index_desc = response.index_descriptions[0] - return {'total_rows': index_desc.total_rows, 'indexed_rows': index_desc.indexed_rows, - "pending_index_rows": index_desc.pending_index_rows} + return { + "total_rows": index_desc.total_rows, + "indexed_rows": index_desc.indexed_rows, + "pending_index_rows": index_desc.pending_index_rows, + } raise AmbiguousIndexName(message=ExceptionsMessage.AmbiguousIndexName) raise MilvusException(status.error_code, status.reason) @retry_on_rpc_failure() - def get_index_state(self, collection_name: str, index_name: str, timeout=None, **kwargs): + def get_index_state( + self, collection_name: str, index_name: str, timeout: Optional[float] = None, **kwargs + ): request = Prepare.describe_index_request(collection_name, index_name) rf = self._stub.DescribeIndex.future(request, timeout=timeout) response = rf.result() @@ -762,58 +918,79 @@ def get_index_state(self, collection_name: str, index_name: str, timeout=None, * raise AmbiguousIndexName(message=ExceptionsMessage.AmbiguousIndexName) @retry_on_rpc_failure() - def wait_for_creating_index(self, collection_name, index_name, timeout=None, **kwargs): + def wait_for_creating_index( + self, collection_name: str, index_name: str, timeout: Optional[float] = None, **kwargs + ): start = time.time() while True: time.sleep(0.5) - state, fail_reason = self.get_index_state(collection_name, index_name, timeout=timeout, **kwargs) + state, fail_reason = self.get_index_state( + collection_name, index_name, timeout=timeout, **kwargs + ) if state == IndexState.Finished: return True, fail_reason if state == IndexState.Failed: return False, fail_reason end = time.time() if isinstance(timeout, int) and end - start > timeout: - raise MilvusException( - message=f"collection {collection_name} create index {index_name} timeout in {timeout}s") + msg = ( + f"collection {collection_name} create index {index_name} " + f"timeout in {timeout}s" + ) + raise MilvusException(message=msg) @retry_on_rpc_failure() - def load_collection(self, collection_name, replica_number=1, timeout=None, **kwargs): + def load_collection( + self, + collection_name: str, + replica_number: int = 1, + timeout: Optional[float] = None, + **kwargs, + ): check_pass_param(collection_name=collection_name, replica_number=replica_number) _refresh = kwargs.get("_refresh", False) _resource_groups = kwargs.get("_resource_groups") - request = Prepare.load_collection("", collection_name, replica_number, _refresh, _resource_groups) + request = Prepare.load_collection( + "", collection_name, replica_number, _refresh, _resource_groups + ) rf = self._stub.LoadCollection.future(request, timeout=timeout) response = rf.result() if response.error_code != 0: raise MilvusException(response.error_code, response.reason) _async = kwargs.get("_async", False) if not _async: - self.wait_for_loading_collection(collection_name, timeout, isRefresh=_refresh) + self.wait_for_loading_collection(collection_name, timeout, is_refresh=_refresh) @retry_on_rpc_failure() - def load_collection_progress(self, collection_name, timeout=None): - """ Return loading progress of collection """ + def load_collection_progress(self, collection_name: str, timeout: Optional[float] = None): + """Return loading progress of collection""" progress = self.get_loading_progress(collection_name, timeout=timeout) return { "loading_progress": f"{progress:.0f}%", } @retry_on_rpc_failure() - def wait_for_loading_collection(self, collection_name, timeout=None, isRefresh=False): + def wait_for_loading_collection( + self, collection_name: str, timeout: Optional[float] = None, is_refresh: bool = False + ): start = time.time() - def can_loop(t) -> bool: + def can_loop(t: int) -> bool: return True if timeout is None else t <= (start + timeout) while can_loop(time.time()): - progress = self.get_loading_progress(collection_name, timeout=timeout, isRefresh=isRefresh) + progress = self.get_loading_progress( + collection_name, timeout=timeout, is_refresh=is_refresh + ) if progress >= 100: return time.sleep(Config.WaitTimeDurationWhenLoad) - raise MilvusException(message=f"wait for loading collection timeout, collection: {collection_name}") + raise MilvusException( + message=f"wait for loading collection timeout, collection: {collection_name}" + ) @retry_on_rpc_failure() - def release_collection(self, collection_name, timeout=None): + def release_collection(self, collection_name: str, timeout: Optional[float] = None): check_pass_param(collection_name=collection_name) request = Prepare.release_collection("", collection_name) rf = self._stub.ReleaseCollection.future(request, timeout=timeout) @@ -822,21 +999,33 @@ def release_collection(self, collection_name, timeout=None): raise MilvusException(response.error_code, response.reason) @retry_on_rpc_failure() - def load_partitions(self, collection_name, partition_names, replica_number=1, timeout=None, **kwargs): + def load_partitions( + self, + collection_name: str, + partition_names: List[str], + replica_number: int = 1, + timeout: Optional[float] = None, + **kwargs, + ): check_pass_param( collection_name=collection_name, partition_name_array=partition_names, - replica_number=replica_number) + replica_number=replica_number, + ) _refresh = kwargs.get("_refresh", False) _resource_groups = kwargs.get("_resource_groups") - request = Prepare.load_partitions("", collection_name, partition_names, replica_number, _refresh, - _resource_groups) + request = Prepare.load_partitions( + "", collection_name, partition_names, replica_number, _refresh, _resource_groups + ) future = self._stub.LoadPartitions.future(request, timeout=timeout) if kwargs.get("_async", False): + def _check(): if kwargs.get("sync", True): - self.wait_for_loading_partitions(collection_name, partition_names, isRefresh=_refresh) + self.wait_for_loading_partitions( + collection_name, partition_names, is_refresh=_refresh + ) load_partitions_future = LoadPartitionsFuture(future) load_partitions_future.add_callback(_check) @@ -852,50 +1041,67 @@ def _check(): raise MilvusException(response.error_code, response.reason) sync = kwargs.get("sync", True) if sync: - self.wait_for_loading_partitions(collection_name, partition_names, isRefresh=_refresh) + self.wait_for_loading_partitions(collection_name, partition_names, is_refresh=_refresh) + return None + return None @retry_on_rpc_failure() - def wait_for_loading_partitions(self, collection_name, partition_names, timeout=None, isRefresh=False): + def wait_for_loading_partitions( + self, + collection_name: str, + partition_names: List[str], + timeout: Optional[float] = None, + is_refresh: bool = False, + ): start = time.time() - def can_loop(t) -> bool: + def can_loop(t: int) -> bool: return True if timeout is None else t <= (start + timeout) while can_loop(time.time()): - progress = self.get_loading_progress(collection_name, partition_names, timeout=timeout, isRefresh=isRefresh) + progress = self.get_loading_progress( + collection_name, partition_names, timeout=timeout, is_refresh=is_refresh + ) if progress >= 100: return time.sleep(Config.WaitTimeDurationWhenLoad) raise MilvusException( - message=f"wait for loading partition timeout, collection: {collection_name}, partitions: {partition_names}") + message=f"wait for loading partition timeout, collection: {collection_name}, partitions: {partition_names}" + ) @retry_on_rpc_failure() - def get_loading_progress(self, collection_name, partition_names=None, timeout=None, isRefresh=False): + def get_loading_progress( + self, + collection_name: str, + partition_names: Optional[List[str]] = None, + timeout: Optional[float] = None, + is_refresh: bool = False, + ): request = Prepare.get_loading_progress(collection_name, partition_names) response = self._stub.GetLoadingProgress.future(request, timeout=timeout).result() if response.status.error_code != 0: raise MilvusException(response.status.error_code, response.status.reason) - if isRefresh: + if is_refresh: return response.refresh_progress return response.progress @retry_on_rpc_failure() - def create_database(self, db_name, timeout=None): + def create_database(self, db_name: str, timeout: Optional[float] = None): request = Prepare.create_database_req(db_name) status = self._stub.CreateDatabase(request, timeout=timeout) if status.error_code != 0: raise MilvusException(status.error_code, status.reason) @retry_on_rpc_failure() - def drop_database(self, db_name, timeout=None): + def drop_database(self, db_name: str, timeout: Optional[float] = None): request = Prepare.drop_database_req(db_name) status = self._stub.DropDatabase(request, timeout=timeout) if status.error_code != 0: raise MilvusException(status.error_code, status.reason) @retry_on_rpc_failure() - def list_database(self, timeout=None): + def list_database(self, timeout: Optional[float] = None): request = Prepare.list_database_req() response = self._stub.ListDatabases(request, timeout=timeout) if response.status.error_code != 0: @@ -903,7 +1109,12 @@ def list_database(self, timeout=None): return list(response.db_names) @retry_on_rpc_failure() - def get_load_state(self, collection_name, partition_names=None, timeout=None): + def get_load_state( + self, + collection_name: str, + partition_names: Optional[List[str]] = None, + timeout: Optional[float] = None, + ): request = Prepare.get_load_state(collection_name, partition_names) response = self._stub.GetLoadState.future(request, timeout=timeout).result() if response.status.error_code != 0: @@ -911,15 +1122,19 @@ def get_load_state(self, collection_name, partition_names=None, timeout=None): return LoadState(response.state) @retry_on_rpc_failure() - def load_partitions_progress(self, collection_name, partition_names, timeout=None): - """ Return loading progress of partitions """ + def load_partitions_progress( + self, collection_name: str, partition_names: List[str], timeout: Optional[float] = None + ): + """Return loading progress of partitions""" progress = self.get_loading_progress(collection_name, partition_names, timeout) return { "loading_progress": f"{progress:.0f}%", } @retry_on_rpc_failure() - def release_partitions(self, collection_name, partition_names, timeout=None): + def release_partitions( + self, collection_name: str, partition_names: List[str], timeout: Optional[float] = None + ): check_pass_param(collection_name=collection_name, partition_name_array=partition_names) request = Prepare.release_partitions("", collection_name, partition_names) rf = self._stub.ReleasePartitions.future(request, timeout=timeout) @@ -928,7 +1143,7 @@ def release_partitions(self, collection_name, partition_names, timeout=None): raise MilvusException(response.error_code, response.reason) @retry_on_rpc_failure() - def get_collection_stats(self, collection_name, timeout=None, **kwargs): + def get_collection_stats(self, collection_name: str, timeout: Optional[float] = None, **kwargs): check_pass_param(collection_name=collection_name) index_param = Prepare.get_collection_stats_request(collection_name) future = self._stub.GetCollectionStatistics.future(index_param, timeout=timeout) @@ -940,7 +1155,7 @@ def get_collection_stats(self, collection_name, timeout=None, **kwargs): raise MilvusException(status.error_code, status.reason) @retry_on_rpc_failure() - def get_flush_state(self, segment_ids, timeout=None, **kwargs): + def get_flush_state(self, segment_ids: List[int], timeout: Optional[float] = None, **kwargs): req = Prepare.get_flush_state_request(segment_ids) future = self._stub.GetFlushState.future(req, timeout=timeout) response = future.result() @@ -951,7 +1166,9 @@ def get_flush_state(self, segment_ids, timeout=None, **kwargs): # TODO seem not in use @retry_on_rpc_failure() - def get_persistent_segment_infos(self, collection_name, timeout=None, **kwargs): + def get_persistent_segment_infos( + self, collection_name: str, timeout: Optional[float] = None, **kwargs + ): req = Prepare.get_persistent_segment_info_request(collection_name) future = self._stub.GetPersistentSegmentInfo.future(req, timeout=timeout) response = future.result() @@ -960,21 +1177,20 @@ def get_persistent_segment_infos(self, collection_name, timeout=None, **kwargs): return response.infos # todo: A wrapper class of PersistentSegmentInfo raise MilvusException(status.error_code, status.reason) - def _wait_for_flushed(self, segment_ids, timeout=None, **kwargs): + def _wait_for_flushed(self, segment_ids: List[int], timeout: Optional[float] = None, **kwargs): flush_ret = False start = time.time() while not flush_ret: flush_ret = self.get_flush_state(segment_ids, timeout, **kwargs) end = time.time() - if timeout is not None: - if end - start > timeout: - raise MilvusException(message=f"wait for flush timeout, segment ids: {segment_ids}") + if timeout is not None and end - start > timeout: + raise MilvusException(message=f"wait for flush timeout, segment ids: {segment_ids}") if not flush_ret: time.sleep(0.5) @retry_on_rpc_failure() - def flush(self, collection_names: list, timeout=None, **kwargs): + def flush(self, collection_names: list, timeout: Optional[float] = None, **kwargs): if collection_names in (None, []) or not isinstance(collection_names, list): raise ParamError(message="Collection name list can not be None or empty") @@ -1003,9 +1219,17 @@ def _check(): return flush_future _check() + return None @retry_on_rpc_failure() - def drop_index(self, collection_name, field_name, index_name, timeout=None, **kwargs): + def drop_index( + self, + collection_name: str, + field_name: str, + index_name: str, + timeout: Optional[float] = None, + **kwargs, + ): check_pass_param(collection_name=collection_name, field_name=field_name) request = Prepare.drop_index_request(collection_name, field_name, index_name) future = self._stub.DropIndex.future(request, timeout=timeout) @@ -1014,31 +1238,48 @@ def drop_index(self, collection_name, field_name, index_name, timeout=None, **kw raise MilvusException(response.error_code, response.reason) @retry_on_rpc_failure() - def dummy(self, request_type, timeout=None, **kwargs): + def dummy(self, request_type: Any, timeout: Optional[float] = None, **kwargs): request = Prepare.dummy_request(request_type) future = self._stub.Dummy.future(request, timeout=timeout) return future.result() # TODO seems not in use @retry_on_rpc_failure() - def fake_register_link(self, timeout=None): + def fake_register_link(self, timeout: Optional[float] = None): request = Prepare.register_link_request() future = self._stub.RegisterLink.future(request, timeout=timeout) return future.result().status # TODO seems not in use @retry_on_rpc_failure() - def get(self, collection_name, ids, output_fields=None, partition_names=None, timeout=None): + def get( + self, + collection_name: str, + ids: List[int], + output_fields: Optional[List[str]] = None, + partition_names: Optional[List[str]] = None, + timeout: Optional[float] = None, + ): # TODO: some check request = Prepare.retrieve_request(collection_name, ids, output_fields, partition_names) future = self._stub.Retrieve.future(request, timeout=timeout) return future.result() @retry_on_rpc_failure() - def query(self, collection_name, expr, output_fields=None, partition_names=None, timeout=None, **kwargs): + def query( + self, + collection_name: str, + expr: str, + output_fields: Optional[List[str]] = None, + partition_names: Optional[List[str]] = None, + timeout: Optional[float] = None, + **kwargs, + ): if output_fields is not None and not isinstance(output_fields, (list,)): raise ParamError(message="Invalid query format. 'output_fields' must be a list") - request = Prepare.query_request(collection_name, expr, output_fields, partition_names, **kwargs) + request = Prepare.query_request( + collection_name, expr, output_fields, partition_names, **kwargs + ) future = self._stub.Query.future(request, timeout=timeout) response = future.result() @@ -1062,21 +1303,32 @@ def query(self, collection_name, expr, output_fields=None, partition_names=None, results = [] for index in range(0, num_entities): - entity_row_data = entity_helper.extract_row_data_from_fields_data(response.fields_data, index, - dynamic_fields) + entity_row_data = entity_helper.extract_row_data_from_fields_data( + response.fields_data, index, dynamic_fields + ) results.append(entity_row_data) return results @retry_on_rpc_failure() - def load_balance(self, collection_name: str, src_node_id, dst_node_ids, sealed_segment_ids, timeout=None, **kwargs): - req = Prepare.load_balance_request(collection_name, src_node_id, dst_node_ids, sealed_segment_ids) + def load_balance( + self, + collection_name: str, + src_node_id: int, + dst_node_ids: List[int], + sealed_segment_ids: List[int], + timeout: Optional[float] = None, + **kwargs, + ): + req = Prepare.load_balance_request( + collection_name, src_node_id, dst_node_ids, sealed_segment_ids + ) future = self._stub.LoadBalance.future(req, timeout=timeout) status = future.result() if status.error_code != 0: raise MilvusException(status.error_code, status.reason) @retry_on_rpc_failure() - def compact(self, collection_name, timeout=None, **kwargs) -> int: + def compact(self, collection_name: str, timeout: Optional[float] = None, **kwargs) -> int: request = Prepare.describe_collection_request(collection_name) rf = self._stub.DescribeCollection.future(request, timeout=timeout) response = rf.result() @@ -1092,7 +1344,9 @@ def compact(self, collection_name, timeout=None, **kwargs) -> int: return response.compactionID @retry_on_rpc_failure() - def get_compaction_state(self, compaction_id, timeout=None, **kwargs) -> CompactionState: + def get_compaction_state( + self, compaction_id: int, timeout: Optional[float] = None, **kwargs + ) -> CompactionState: req = Prepare.get_compaction_state(compaction_id) future = self._stub.GetCompactionState.future(req, timeout=timeout) @@ -1105,11 +1359,13 @@ def get_compaction_state(self, compaction_id, timeout=None, **kwargs) -> Compact State.new(response.state), response.executingPlanNo, response.timeoutPlanNo, - response.completedPlanNo + response.completedPlanNo, ) @retry_on_rpc_failure() - def wait_for_compaction_completed(self, compaction_id, timeout=None, **kwargs): + def wait_for_compaction_completed( + self, compaction_id: int, timeout: Optional[float] = None, **kwargs + ): start = time.time() while True: time.sleep(0.5) @@ -1119,12 +1375,15 @@ def wait_for_compaction_completed(self, compaction_id, timeout=None, **kwargs): if compaction_state == State.UndefiedState: return False end = time.time() - if timeout is not None: - if end - start > timeout: - raise MilvusException(message=f"get compaction state timeout, compaction id: {compaction_id}") + if timeout is not None and end - start > timeout: + raise MilvusException( + message=f"get compaction state timeout, compaction id: {compaction_id}" + ) @retry_on_rpc_failure() - def get_compaction_plans(self, compaction_id, timeout=None, **kwargs) -> CompactionPlans: + def get_compaction_plans( + self, compaction_id: int, timeout: Optional[float] = None, **kwargs + ) -> CompactionPlans: req = Prepare.get_compaction_state_with_plans(compaction_id) future = self._stub.GetCompactionStateWithPlans.future(req, timeout=timeout) @@ -1139,8 +1398,12 @@ def get_compaction_plans(self, compaction_id, timeout=None, **kwargs) -> Compact return cp @retry_on_rpc_failure() - def get_replicas(self, collection_name, timeout=None, **kwargs) -> Replica: - collection_id = self.describe_collection(collection_name, timeout, **kwargs)["collection_id"] + def get_replicas( + self, collection_name: str, timeout: Optional[float] = None, **kwargs + ) -> Replica: + collection_id = self.describe_collection(collection_name, timeout, **kwargs)[ + "collection_id" + ] req = Prepare.get_replicas(collection_id) future = self._stub.GetReplicas.future(req, timeout=timeout) @@ -1150,14 +1413,30 @@ def get_replicas(self, collection_name, timeout=None, **kwargs) -> Replica: groups = [] for replica in response.replicas: - shards = [Shard(s.dm_channel_name, s.node_ids, s.leaderID) for s in replica.shard_replicas] - groups.append(Group(replica.replicaID, shards, replica.node_ids, replica.resource_group_name, - replica.num_outbound_node)) + shards = [ + Shard(s.dm_channel_name, s.node_ids, s.leaderID) for s in replica.shard_replicas + ] + groups.append( + Group( + replica.replicaID, + shards, + replica.node_ids, + replica.resource_group_name, + replica.num_outbound_node, + ) + ) return Replica(groups) @retry_on_rpc_failure() - def do_bulk_insert(self, collection_name, partition_name, files: list, timeout=None, **kwargs) -> int: + def do_bulk_insert( + self, + collection_name: str, + partition_name: str, + files: List[str], + timeout: Optional[float] = None, + **kwargs, + ) -> int: req = Prepare.do_bulk_insert(collection_name, partition_name, files, **kwargs) future = self._stub.Import.future(req, timeout=timeout) response = future.result() @@ -1168,29 +1447,35 @@ def do_bulk_insert(self, collection_name, partition_name, files: list, timeout=N return response.tasks[0] @retry_on_rpc_failure() - def get_bulk_insert_state(self, task_id, timeout=None, **kwargs) -> BulkInsertState: + def get_bulk_insert_state( + self, task_id: int, timeout: Optional[float] = None, **kwargs + ) -> BulkInsertState: req = Prepare.get_bulk_insert_state(task_id) future = self._stub.GetImportState.future(req, timeout=timeout) resp = future.result() if resp.status.error_code != 0: raise MilvusException(resp.status.error_code, resp.status.reason) - state = BulkInsertState(task_id, resp.state, resp.row_count, resp.id_list, resp.infos, resp.create_ts) - return state + return BulkInsertState( + task_id, resp.state, resp.row_count, resp.id_list, resp.infos, resp.create_ts + ) @retry_on_rpc_failure() - def list_bulk_insert_tasks(self, limit, collection_name, timeout=None, **kwargs) -> list: + def list_bulk_insert_tasks( + self, limit: int, collection_name: str, timeout: Optional[float] = None, **kwargs + ) -> list: req = Prepare.list_bulk_insert_tasks(limit, collection_name) future = self._stub.ListImportTasks.future(req, timeout=timeout) resp = future.result() if resp.status.error_code != 0: raise MilvusException(resp.status.error_code, resp.status.reason) - tasks = [BulkInsertState(t.id, t.state, t.row_count, t.id_list, t.infos, t.create_ts) - for t in resp.tasks] - return tasks + return [ + BulkInsertState(t.id, t.state, t.row_count, t.id_list, t.infos, t.create_ts) + for t in resp.tasks + ] @retry_on_rpc_failure() - def create_user(self, user, password, timeout=None, **kwargs): + def create_user(self, user: str, password: str, timeout: Optional[float] = None, **kwargs): check_pass_param(user=user, password=password) req = Prepare.create_user_request(user, password) resp = self._stub.CreateCredential(req, timeout=timeout) @@ -1198,21 +1483,28 @@ def create_user(self, user, password, timeout=None, **kwargs): raise MilvusException(resp.error_code, resp.reason) @retry_on_rpc_failure() - def update_password(self, user, old_password, new_password, timeout=None, **kwargs): + def update_password( + self, + user: str, + old_password: str, + new_password: str, + timeout: Optional[float] = None, + **kwargs, + ): req = Prepare.update_password_request(user, old_password, new_password) resp = self._stub.UpdateCredential(req, timeout=timeout) if resp.error_code != 0: raise MilvusException(resp.error_code, resp.reason) @retry_on_rpc_failure() - def delete_user(self, user, timeout=None, **kwargs): + def delete_user(self, user: str, timeout: Optional[float] = None, **kwargs): req = Prepare.delete_user_request(user) resp = self._stub.DeleteCredential(req, timeout=timeout) if resp.error_code != 0: raise MilvusException(resp.error_code, resp.reason) @retry_on_rpc_failure() - def list_usernames(self, timeout=None, **kwargs): + def list_usernames(self, timeout: Optional[float] = None, **kwargs): req = Prepare.list_usernames_request() resp = self._stub.ListCredUsers(req, timeout=timeout) if resp.status.error_code != 0: @@ -1220,36 +1512,45 @@ def list_usernames(self, timeout=None, **kwargs): return resp.usernames @retry_on_rpc_failure() - def create_role(self, role_name, timeout=None, **kwargs): + def create_role(self, role_name: str, timeout: Optional[float] = None, **kwargs): req = Prepare.create_role_request(role_name) resp = self._stub.CreateRole(req, wait_for_ready=True, timeout=timeout) if resp.error_code != 0: raise MilvusException(resp.error_code, resp.reason) @retry_on_rpc_failure() - def drop_role(self, role_name, timeout=None, **kwargs): + def drop_role(self, role_name: str, timeout: Optional[float] = None, **kwargs): req = Prepare.drop_role_request(role_name) resp = self._stub.DropRole(req, wait_for_ready=True, timeout=timeout) if resp.error_code != 0: raise MilvusException(resp.error_code, resp.reason) @retry_on_rpc_failure() - def add_user_to_role(self, username, role_name, timeout=None, **kwargs): - req = Prepare.operate_user_role_request(username, role_name, milvus_types.OperateUserRoleType.AddUserToRole) + def add_user_to_role( + self, username: str, role_name: str, timeout: Optional[float] = None, **kwargs + ): + req = Prepare.operate_user_role_request( + username, role_name, milvus_types.OperateUserRoleType.AddUserToRole + ) resp = self._stub.OperateUserRole(req, wait_for_ready=True, timeout=timeout) if resp.error_code != 0: raise MilvusException(resp.error_code, resp.reason) @retry_on_rpc_failure() - def remove_user_from_role(self, username, role_name, timeout=None, **kwargs): - req = Prepare.operate_user_role_request(username, role_name, - milvus_types.OperateUserRoleType.RemoveUserFromRole) + def remove_user_from_role( + self, username: str, role_name: str, timeout: Optional[float] = None, **kwargs + ): + req = Prepare.operate_user_role_request( + username, role_name, milvus_types.OperateUserRoleType.RemoveUserFromRole + ) resp = self._stub.OperateUserRole(req, wait_for_ready=True, timeout=timeout) if resp.error_code != 0: raise MilvusException(resp.error_code, resp.reason) @retry_on_rpc_failure() - def select_one_role(self, role_name, include_user_info, timeout=None, **kwargs): + def select_one_role( + self, role_name: str, include_user_info: bool, timeout: Optional[float] = None, **kwargs + ): req = Prepare.select_role_request(role_name, include_user_info) resp = self._stub.SelectRole(req, wait_for_ready=True, timeout=timeout) if resp.status.error_code != 0: @@ -1257,7 +1558,7 @@ def select_one_role(self, role_name, include_user_info, timeout=None, **kwargs): return RoleInfo(resp.results) @retry_on_rpc_failure() - def select_all_role(self, include_user_info, timeout=None, **kwargs): + def select_all_role(self, include_user_info: bool, timeout: Optional[float] = None, **kwargs): req = Prepare.select_role_request(None, include_user_info) resp = self._stub.SelectRole(req, wait_for_ready=True, timeout=timeout) if resp.status.error_code != 0: @@ -1265,7 +1566,9 @@ def select_all_role(self, include_user_info, timeout=None, **kwargs): return RoleInfo(resp.results) @retry_on_rpc_failure() - def select_one_user(self, username, include_role_info, timeout=None, **kwargs): + def select_one_user( + self, username: str, include_role_info: bool, timeout: Optional[float] = None, **kwargs + ): req = Prepare.select_user_request(username, include_role_info) resp = self._stub.SelectUser(req, wait_for_ready=True, timeout=timeout) if resp.status.error_code != 0: @@ -1273,7 +1576,7 @@ def select_one_user(self, username, include_role_info, timeout=None, **kwargs): return UserInfo(resp.results) @retry_on_rpc_failure() - def select_all_user(self, include_role_info, timeout=None, **kwargs): + def select_all_user(self, include_role_info: bool, timeout: Optional[float] = None, **kwargs): req = Prepare.select_user_request(None, include_role_info) resp = self._stub.SelectUser(req, wait_for_ready=True, timeout=timeout) if resp.status.error_code != 0: @@ -1281,23 +1584,55 @@ def select_all_user(self, include_role_info, timeout=None, **kwargs): return UserInfo(resp.results) @retry_on_rpc_failure() - def grant_privilege(self, role_name, object, object_name, privilege, db_name, timeout=None, **kwargs): - req = Prepare.operate_privilege_request(role_name, object, object_name, privilege, db_name, - milvus_types.OperatePrivilegeType.Grant) + def grant_privilege( + self, + role_name: str, + object: str, + object_name: str, + privilege: str, + db_name: str, + timeout: Optional[float] = None, + **kwargs, + ): + req = Prepare.operate_privilege_request( + role_name, + object, + object_name, + privilege, + db_name, + milvus_types.OperatePrivilegeType.Grant, + ) resp = self._stub.OperatePrivilege(req, wait_for_ready=True, timeout=timeout) if resp.error_code != 0: raise MilvusException(resp.error_code, resp.reason) @retry_on_rpc_failure() - def revoke_privilege(self, role_name, object, object_name, privilege, db_name, timeout=None, **kwargs): - req = Prepare.operate_privilege_request(role_name, object, object_name, privilege, db_name, - milvus_types.OperatePrivilegeType.Revoke) + def revoke_privilege( + self, + role_name: str, + object: str, + object_name: str, + privilege: str, + db_name: str, + timeout: Optional[float] = None, + **kwargs, + ): + req = Prepare.operate_privilege_request( + role_name, + object, + object_name, + privilege, + db_name, + milvus_types.OperatePrivilegeType.Revoke, + ) resp = self._stub.OperatePrivilege(req, wait_for_ready=True, timeout=timeout) if resp.error_code != 0: raise MilvusException(resp.error_code, resp.reason) @retry_on_rpc_failure() - def select_grant_for_one_role(self, role_name, db_name, timeout=None, **kwargs): + def select_grant_for_one_role( + self, role_name: str, db_name: str, timeout: Optional[float] = None, **kwargs + ): req = Prepare.select_grant_request(role_name, None, None, db_name) resp = self._stub.SelectGrant(req, wait_for_ready=True, timeout=timeout) if resp.status.error_code != 0: @@ -1306,7 +1641,15 @@ def select_grant_for_one_role(self, role_name, db_name, timeout=None, **kwargs): return GrantInfo(resp.entities) @retry_on_rpc_failure() - def select_grant_for_role_and_object(self, role_name, object, object_name, db_name, timeout=None, **kwargs): + def select_grant_for_role_and_object( + self, + role_name: str, + object: str, + object_name: str, + db_name: str, + timeout: Optional[float] = None, + **kwargs, + ): req = Prepare.select_grant_request(role_name, object, object_name, db_name) resp = self._stub.SelectGrant(req, wait_for_ready=True, timeout=timeout) if resp.status.error_code != 0: @@ -1315,7 +1658,7 @@ def select_grant_for_role_and_object(self, role_name, object, object_name, db_na return GrantInfo(resp.entities) @retry_on_rpc_failure() - def get_server_version(self, timeout=None, **kwargs) -> str: + def get_server_version(self, timeout: Optional[float] = None, **kwargs) -> str: req = Prepare.get_server_version() resp = self._stub.GetVersion(req, timeout=timeout) if resp.status.error_code != 0: @@ -1324,21 +1667,21 @@ def get_server_version(self, timeout=None, **kwargs) -> str: return resp.version @retry_on_rpc_failure() - def create_resource_group(self, name, timeout=None, **kwargs): + def create_resource_group(self, name: str, timeout: Optional[float] = None, **kwargs): req = Prepare.create_resource_group(name) resp = self._stub.CreateResourceGroup(req, wait_for_ready=True, timeout=timeout) if resp.error_code != 0: raise MilvusException(resp.error_code, resp.reason) @retry_on_rpc_failure() - def drop_resource_group(self, name, timeout=None, **kwargs): + def drop_resource_group(self, name: str, timeout: Optional[float] = None, **kwargs): req = Prepare.drop_resource_group(name) resp = self._stub.DropResourceGroup(req, wait_for_ready=True, timeout=timeout) if resp.error_code != 0: raise MilvusException(resp.error_code, resp.reason) @retry_on_rpc_failure() - def list_resource_groups(self, timeout=None, **kwargs): + def list_resource_groups(self, timeout: Optional[float] = None, **kwargs): req = Prepare.list_resource_groups() resp = self._stub.ListResourceGroups(req, wait_for_ready=True, timeout=timeout) if resp.status.error_code != 0: @@ -1346,7 +1689,9 @@ def list_resource_groups(self, timeout=None, **kwargs): return list(resp.resource_groups) @retry_on_rpc_failure() - def describe_resource_group(self, name, timeout=None, **kwargs) -> ResourceGroupInfo: + def describe_resource_group( + self, name: str, timeout: Optional[float] = None, **kwargs + ) -> ResourceGroupInfo: req = Prepare.describe_resource_group(name) resp = self._stub.DescribeResourceGroup(req, wait_for_ready=True, timeout=timeout) if resp.status.error_code != 0: @@ -1354,21 +1699,31 @@ def describe_resource_group(self, name, timeout=None, **kwargs) -> ResourceGroup return ResourceGroupInfo(resp.resource_group) @retry_on_rpc_failure() - def transfer_node(self, source, target, num_node, timeout=None, **kwargs): + def transfer_node( + self, source: str, target: str, num_node: int, timeout: Optional[float] = None, **kwargs + ): req = Prepare.transfer_node(source, target, num_node) resp = self._stub.TransferNode(req, wait_for_ready=True, timeout=timeout) if resp.error_code != 0: raise MilvusException(resp.error_code, resp.reason) @retry_on_rpc_failure() - def transfer_replica(self, source, target, collection_name, num_replica, timeout=None, **kwargs): + def transfer_replica( + self, + source: str, + target: str, + collection_name: str, + num_replica: int, + timeout: Optional[float] = None, + **kwargs, + ): req = Prepare.transfer_replica(source, target, collection_name, num_replica) resp = self._stub.TransferReplica(req, wait_for_ready=True, timeout=timeout) if resp.error_code != 0: raise MilvusException(resp.error_code, resp.reason) @retry_on_rpc_failure() - def get_flush_all_state(self, flush_all_ts, timeout=None, **kwargs): + def get_flush_all_state(self, flush_all_ts: int, timeout: Optional[float] = None, **kwargs): req = Prepare.get_flush_all_state_request(flush_all_ts) response = self._stub.GetFlushAllState(req, timeout=timeout) status = response.status @@ -1376,21 +1731,22 @@ def get_flush_all_state(self, flush_all_ts, timeout=None, **kwargs): return response.flushed raise MilvusException(status.error_code, status.reason) - def _wait_for_flush_all(self, flush_all_ts, timeout=None, **kwargs): + def _wait_for_flush_all(self, flush_all_ts: int, timeout: Optional[float] = None, **kwargs): flush_ret = False start = time.time() while not flush_ret: flush_ret = self.get_flush_all_state(flush_all_ts, timeout, **kwargs) end = time.time() - if timeout is not None: - if end - start > timeout: - raise MilvusException(message=f"wait for flush all timeout, flush_all_ts: {flush_all_ts}") + if timeout is not None and end - start > timeout: + raise MilvusException( + message=f"wait for flush all timeout, flush_all_ts: {flush_all_ts}" + ) if not flush_ret: time.sleep(5) @retry_on_rpc_failure() - def flush_all(self, timeout=None, **kwargs): + def flush_all(self, timeout: Optional[float] = None, **kwargs): request = Prepare.flush_all_request() future = self._stub.FlushAll.future(request, timeout=timeout) response = future.result() @@ -1411,10 +1767,11 @@ def _check(): return flush_future _check() + return None @retry_on_rpc_failure() @upgrade_reminder - def __internal_register(self, user, host) -> int: + def __internal_register(self, user: str, host: str) -> int: req = Prepare.register_request(user, host) response = self._stub.Connect(request=req) if response.status.error_code != common_pb2.Success: diff --git a/pymilvus/client/interceptor.py b/pymilvus/client/interceptor.py index f4cff2bd3..29a9fb24d 100644 --- a/pymilvus/client/interceptor.py +++ b/pymilvus/client/interceptor.py @@ -14,69 +14,88 @@ """Base class for interceptors that operate on all RPC types.""" import collections +from typing import Any, Callable, List import grpc -class _GenericClientInterceptor(grpc.UnaryUnaryClientInterceptor, - grpc.UnaryStreamClientInterceptor, - grpc.StreamUnaryClientInterceptor, - grpc.StreamStreamClientInterceptor): - - def __init__(self, interceptor_function): +class _GenericClientInterceptor( + grpc.UnaryUnaryClientInterceptor, + grpc.UnaryStreamClientInterceptor, + grpc.StreamUnaryClientInterceptor, + grpc.StreamStreamClientInterceptor, +): + def __init__(self, interceptor_function: Callable) -> None: super().__init__() self._fn = interceptor_function - def intercept_unary_unary(self, continuation, client_call_details, request): + def intercept_unary_unary(self, continuation: Callable, client_call_details: Any, request: Any): new_details, new_request_iterator, postprocess = self._fn( - client_call_details, iter((request,)), False, False) + client_call_details, iter((request,)) + ) response = continuation(new_details, next(new_request_iterator)) return postprocess(response) if postprocess else response - def intercept_unary_stream(self, continuation, client_call_details, - request): + def intercept_unary_stream( + self, + continuation: Callable, + client_call_details: Any, + request: Any, + ): new_details, new_request_iterator, postprocess = self._fn( - client_call_details, iter((request,)), False, True) + client_call_details, iter((request,)) + ) response_it = continuation(new_details, next(new_request_iterator)) return postprocess(response_it) if postprocess else response_it - def intercept_stream_unary(self, continuation, client_call_details, - request_iterator): + def intercept_stream_unary( + self, + continuation: Callable, + client_call_details: Any, + request_iterator: Any, + ): new_details, new_request_iterator, postprocess = self._fn( - client_call_details, request_iterator, True, False) + client_call_details, request_iterator + ) response = continuation(new_details, new_request_iterator) return postprocess(response) if postprocess else response - def intercept_stream_stream(self, continuation, client_call_details, - request_iterator): + def intercept_stream_stream( + self, + continuation: Callable, + client_call_details: Any, + request_iterator: Any, + ): new_details, new_request_iterator, postprocess = self._fn( - client_call_details, request_iterator, True, True) + client_call_details, request_iterator + ) response_it = continuation(new_details, new_request_iterator) return postprocess(response_it) if postprocess else response_it -def create(intercept_call): - return _GenericClientInterceptor(intercept_call) - - class _ClientCallDetails( - collections.namedtuple( - '_ClientCallDetails', - ('method', 'timeout', 'metadata', 'credentials')), - grpc.ClientCallDetails): + collections.namedtuple("_ClientCallDetails", ("method", "timeout", "metadata", "credentials")), + grpc.ClientCallDetails, +): pass -def header_adder_interceptor(headers, values): - def intercept_call(client_call_details, request_iterator, request_streaming, - response_streaming): + +def header_adder_interceptor(headers: List, values: List): + def intercept_call( + client_call_details: Any, + request_iterator: Any, + ): metadata = [] if client_call_details.metadata is not None: metadata = list(client_call_details.metadata) for item in zip(headers, values): metadata.append(item) client_call_details = _ClientCallDetails( - client_call_details.method, client_call_details.timeout, metadata, - client_call_details.credentials) + client_call_details.method, + client_call_details.timeout, + metadata, + client_call_details.credentials, + ) return client_call_details, request_iterator, None - return create(intercept_call) + return _GenericClientInterceptor(intercept_call) diff --git a/pymilvus/client/prepare.py b/pymilvus/client/prepare.py index c540c9f69..1b473fad4 100644 --- a/pymilvus/client/prepare.py +++ b/pymilvus/client/prepare.py @@ -1,30 +1,31 @@ import base64 -from typing import Dict, Iterable, Union - import datetime +from typing import Any, Dict, Iterable, List, Optional, Union + import ujson -from . import blob -from . import entity_helper -from . import ts_utils +from pymilvus.client import __version__ +from pymilvus.exceptions import DataNotMatchException, ExceptionsMessage, ParamError +from pymilvus.grpc_gen import common_pb2 as common_types +from pymilvus.grpc_gen import milvus_pb2 as milvus_types +from pymilvus.grpc_gen import schema_pb2 as schema_types +from pymilvus.orm.schema import CollectionSchema + +from . import blob, entity_helper, ts_utils from .check import check_pass_param, is_legal_collection_properties +from .constants import DEFAULT_CONSISTENCY_LEVEL, ITERATION_EXTENSION_REDUCE_RATE from .types import DataType, PlaceholderType, get_consistency_level from .utils import traverse_info, traverse_rows_info -from .constants import DEFAULT_CONSISTENCY_LEVEL, ITERATION_EXTENSION_REDUCE_RATE -from ..exceptions import ParamError, DataNotMatchException, ExceptionsMessage -from ..orm.schema import CollectionSchema - -from ..client import __version__ - -from ..grpc_gen import common_pb2 as common_types -from ..grpc_gen import schema_pb2 as schema_types -from ..grpc_gen import milvus_pb2 as milvus_types class Prepare: @classmethod - def create_collection_request(cls, collection_name: str, fields: Union[Dict[str, Iterable], CollectionSchema], - **kwargs) -> milvus_types.CreateCollectionRequest: + def create_collection_request( + cls, + collection_name: str, + fields: Union[Dict[str, Iterable], CollectionSchema], + **kwargs, + ) -> milvus_types.CreateCollectionRequest: """ Args: fields (Union(Dict[str, Iterable], CollectionSchema)). @@ -41,63 +42,82 @@ def create_collection_request(cls, collection_name: str, fields: Union[Dict[str, """ if isinstance(fields, CollectionSchema): - schema = cls.get_schema_from_collection_schema(collection_name, fields, **kwargs) + schema = cls.get_schema_from_collection_schema(collection_name, fields) else: schema = cls.get_schema(collection_name, fields, **kwargs) - consistency_level = get_consistency_level(kwargs.get("consistency_level", DEFAULT_CONSISTENCY_LEVEL)) + consistency_level = get_consistency_level( + kwargs.get("consistency_level", DEFAULT_CONSISTENCY_LEVEL) + ) - req = milvus_types.CreateCollectionRequest(collection_name=collection_name, - schema=bytes(schema.SerializeToString()), - consistency_level=consistency_level) + req = milvus_types.CreateCollectionRequest( + collection_name=collection_name, + schema=bytes(schema.SerializeToString()), + consistency_level=consistency_level, + ) properties = kwargs.get("properties") if is_legal_collection_properties(properties): - properties = [common_types.KeyValuePair(key=str(k), value=str(v)) for k, v in properties.items()] + properties = [ + common_types.KeyValuePair(key=str(k), value=str(v)) for k, v in properties.items() + ] req.properties.extend(properties) same_key = set(kwargs.keys()).intersection({"num_shards", "shards_num"}) if len(same_key) > 0: if len(same_key) > 1: - raise ParamError(message="got both num_shards and shards_num in kwargs, expected only one of them") + msg = "got both num_shards and shards_num in kwargs, expected only one of them" + raise ParamError(message=msg) num_shards = kwargs[list(same_key)[0]] if not isinstance(num_shards, int): - raise ParamError(message=f"invalid num_shards type, got {type(num_shards)}, expected int") + msg = f"invalid num_shards type, got {type(num_shards)}, expected int" + raise ParamError(message=msg) req.shards_num = num_shards num_partitions = kwargs.get("num_partitions", None) if num_partitions is not None: if not isinstance(num_partitions, int) or isinstance(num_partitions, bool): - raise ParamError(message=f"invalid num_partitions type, got {type(num_partitions)}, expected int") + msg = f"invalid num_partitions type, got {type(num_partitions)}, expected int" + raise ParamError(message=msg) if num_partitions < 1: - raise ParamError( - message=f"The specified num_partitions should be greater than or equal to 1, got {num_partitions}") + msg = f"The specified num_partitions should be greater than or equal to 1, got {num_partitions}" + raise ParamError(message=msg) req.num_partitions = num_partitions return req @classmethod - def get_schema_from_collection_schema(cls, collection_name: str, fields: CollectionSchema, - **kwargs) -> schema_types.CollectionSchema: + def get_schema_from_collection_schema( + cls, + collection_name: str, + fields: CollectionSchema, + ) -> schema_types.CollectionSchema: coll_description = fields.description if not isinstance(coll_description, (str, bytes)): - raise ParamError( - message=f"description [{coll_description}] has type {type(coll_description).__name__}, but expected one of: bytes, str") - - schema = schema_types.CollectionSchema(name=collection_name, - autoID=fields.auto_id, - description=coll_description, - enable_dynamic_field=fields.enable_dynamic_field) + msg = ( + f"description [{coll_description}] has type {type(coll_description).__name__}, " + "but expected one of: bytes, str" + ) + raise ParamError(message=msg) + + schema = schema_types.CollectionSchema( + name=collection_name, + autoID=fields.auto_id, + description=coll_description, + enable_dynamic_field=fields.enable_dynamic_field, + ) for f in fields.fields: - field_schema = schema_types.FieldSchema(name=f.name, - data_type=f.dtype, - description=f.description, - is_primary_key=f.is_primary, - autoID=f.auto_id, - is_partition_key=f.is_partition_key, - default_value=f.default_value, - is_dynamic=f.is_dynamic) + field_schema = schema_types.FieldSchema( + name=f.name, + data_type=f.dtype, + description=f.description, + is_primary_key=f.is_primary, + autoID=f.auto_id, + is_partition_key=f.is_partition_key, + default_value=f.default_value, + is_dynamic=f.is_dynamic, + ) for k, v in f.params.items(): kv_pair = common_types.KeyValuePair(key=str(k), value=str(v)) field_schema.type_params.append(kv_pair) @@ -105,8 +125,69 @@ def get_schema_from_collection_schema(cls, collection_name: str, fields: Collect schema.fields.append(field_schema) return schema + @staticmethod + def get_field_schema( + field: Dict, + primary_field: Any, + auto_id_field: Any, + ) -> (schema_types.FieldSchema, Any, Any): + field_name = field.get("name") + if field_name is None: + raise ParamError(message="You should specify the name of field!") + + data_type = field.get("type") + if data_type is None: + raise ParamError(message="You should specify the data type of field!") + if not isinstance(data_type, (int, DataType)): + raise ParamError(message="Field type must be of DataType!") + + is_primary = field.get("is_primary", False) + if not isinstance(is_primary, bool): + raise ParamError(message="is_primary must be boolean") + if is_primary: + if primary_field is not None: + raise ParamError(message="A collection should only have one primary field") + if DataType(data_type) not in [DataType.INT64, DataType.VARCHAR]: + msg = "int64 and varChar are the only supported types of primary key" + raise ParamError(message=msg) + primary_field = field_name + + auto_id = field.get("auto_id", False) + if not isinstance(auto_id, bool): + raise ParamError(message="auto_id must be boolean") + if auto_id: + if auto_id_field is not None: + raise ParamError(message="A collection should only have one autoID field") + if DataType(data_type) != DataType.INT64: + msg = "int64 is the only supported type of automatic generated id" + raise ParamError(message=msg) + auto_id_field = field_name + + field_schema = schema_types.FieldSchema( + name=field_name, + data_type=data_type, + default_value=field.get("default_value", None), + description=field.get("description", ""), + is_primary_key=is_primary, + autoID=auto_id, + is_partition_key=field.get("is_partition_key", False), + ) + + type_params = field.get("params", {}) + if not isinstance(type_params, dict): + raise ParamError(message="params should be dictionary type") + kvs = [common_types.KeyValuePair(key=str(k), value=str(v)) for k, v in type_params.items()] + field_schema.type_params.extend(kvs) + + return field_schema, primary_field, auto_id_field + @classmethod - def get_schema(cls, collection_name: str, fields: Dict[str, Iterable], **kwargs) -> schema_types.CollectionSchema: + def get_schema( + cls, + collection_name: str, + fields: Dict[str, Iterable], + **kwargs, + ) -> schema_types.CollectionSchema: if not isinstance(fields, dict): raise ParamError(message="Param fields must be a dict") @@ -120,93 +201,54 @@ def get_schema(cls, collection_name: str, fields: Dict[str, Iterable], **kwargs) if "enable_dynamic_field" in fields: enable_dynamic_field = fields["enable_dynamic_field"] - schema = schema_types.CollectionSchema(name=collection_name, - autoID=False, - description=fields.get('description', ''), - enable_dynamic_field=enable_dynamic_field) + schema = schema_types.CollectionSchema( + name=collection_name, + autoID=False, + description=fields.get("description", ""), + enable_dynamic_field=enable_dynamic_field, + ) - primary_field = None - auto_id_field = None + primary_field, auto_id_field = None, None for field in all_fields: - field_name = field.get('name') - if field_name is None: - raise ParamError(message="You should specify the name of field!") - - data_type = field.get('type') - if data_type is None: - raise ParamError(message="You should specify the data type of field!") - if not isinstance(data_type, (int, DataType)): - raise ParamError(message="Field type must be of DataType!") - - is_primary = field.get("is_primary", False) - if not isinstance(is_primary, bool): - raise ParamError(message="is_primary must be boolean") - if is_primary: - if primary_field is not None: - raise ParamError(message="A collection should only have one primary field") - if DataType(data_type) not in [DataType.INT64, DataType.VARCHAR]: - raise ParamError(message="int64 and varChar are the only supported types of primary key") - primary_field = field_name - - auto_id = field.get('auto_id', False) - if not isinstance(auto_id, bool): - raise ParamError(message="auto_id must be boolean") - if auto_id: - if auto_id_field is not None: - raise ParamError(message="A collection should only have one autoID field") - if DataType(data_type) != DataType.INT64: - raise ParamError(message="int64 is the only supported type of automatic generated id") - auto_id_field = field_name - - field_schema = schema_types.FieldSchema(name=field_name, - data_type=data_type, - default_value=field.get("default_value", None), - description=field.get('description', ''), - is_primary_key=is_primary, - autoID=auto_id, - is_partition_key=field.get("is_partition_key", False)) - - type_params = field.get('params', {}) - if not isinstance(type_params, dict): - raise ParamError(message="params should be dictionary type") - kvs = [common_types.KeyValuePair(key=str(k), value=str(v)) for k, v in type_params.items()] - field_schema.type_params.extend(kvs) + (field_schema, primary_field, auto_id_field) = cls.get_field_schema( + field, primary_field, auto_id_field + ) schema.fields.append(field_schema) return schema @classmethod - def drop_collection_request(cls, collection_name): + def drop_collection_request(cls, collection_name: str) -> milvus_types.DropCollectionRequest: return milvus_types.DropCollectionRequest(collection_name=collection_name) @classmethod - # TODO remove - def has_collection_request(cls, collection_name): - return milvus_types.HasCollectionRequest(collection_name=collection_name) - - @classmethod - def describe_collection_request(cls, collection_name): + def describe_collection_request( + cls, + collection_name: str, + ) -> milvus_types.DescribeCollectionRequest: return milvus_types.DescribeCollectionRequest(collection_name=collection_name) @classmethod - def alter_collection_request(cls, collection_name, properties): - kvs = [] - for k in properties: - kv = common_types.KeyValuePair(key=k, value=str(properties[k])) - kvs.append(kv) + def alter_collection_request( + cls, + collection_name: str, + properties: Dict, + ) -> milvus_types.AlterCollectionRequest: + kvs = [common_types.KeyDataPair(key=k, value=str(v)) for k, v in properties.items()] return milvus_types.AlterCollectionRequest(collection_name=collection_name, properties=kvs) @classmethod - def collection_stats_request(cls, collection_name): + def collection_stats_request(cls, collection_name: str): return milvus_types.CollectionStatsRequest(collection_name=collection_name) @classmethod - def show_collections_request(cls, collection_names=None): + def show_collections_request(cls, collection_names: Optional[List[str]] = None): req = milvus_types.ShowCollectionsRequest() if collection_names: if not isinstance(collection_names, (list,)): - raise ParamError(message=f"collection_names must be a list of strings, but got: {collection_names}") + msg = f"collection_names must be a list of strings, but got: {collection_names}" + raise ParamError(message=msg) for collection_name in collection_names: check_pass_param(collection_name=collection_name) req.collection_names.extend(collection_names) @@ -214,32 +256,46 @@ def show_collections_request(cls, collection_names=None): return req @classmethod - def rename_collections_request(cls, old_name=None, new_name=None): + def rename_collections_request(cls, old_name: str, new_name: str): return milvus_types.RenameCollectionRequest(oldName=old_name, newName=new_name) @classmethod - def create_partition_request(cls, collection_name, partition_name): - return milvus_types.CreatePartitionRequest(collection_name=collection_name, partition_name=partition_name) + def create_partition_request(cls, collection_name: str, partition_name: str): + return milvus_types.CreatePartitionRequest( + collection_name=collection_name, partition_name=partition_name + ) @classmethod - def drop_partition_request(cls, collection_name, partition_name): - return milvus_types.DropPartitionRequest(collection_name=collection_name, partition_name=partition_name) + def drop_partition_request(cls, collection_name: str, partition_name: str): + return milvus_types.DropPartitionRequest( + collection_name=collection_name, partition_name=partition_name + ) @classmethod - def has_partition_request(cls, collection_name, partition_name): - return milvus_types.HasPartitionRequest(collection_name=collection_name, partition_name=partition_name) + def has_partition_request(cls, collection_name: str, partition_name: str): + return milvus_types.HasPartitionRequest( + collection_name=collection_name, partition_name=partition_name + ) @classmethod - def partition_stats_request(cls, collection_name, partition_name): - return milvus_types.PartitionStatsRequest(collection_name=collection_name, partition_name=partition_name) + def partition_stats_request(cls, collection_name: str, partition_name: str): + return milvus_types.PartitionStatsRequest( + collection_name=collection_name, partition_name=partition_name + ) @classmethod - def show_partitions_request(cls, collection_name, partition_names=None, type_in_memory=False): + def show_partitions_request( + cls, + collection_name: str, + partition_names: Optional[List[str]] = None, + type_in_memory: bool = False, + ): check_pass_param(collection_name=collection_name, partition_name_array=partition_names) req = milvus_types.ShowPartitionsRequest(collection_name=collection_name) if partition_names: if not isinstance(partition_names, (list,)): - raise ParamError(message=f"partition_names must be a list of strings, but got: {partition_names}") + msg = f"partition_names must be a list of strings, but got: {partition_names}" + raise ParamError(msg) for partition_name in partition_names: check_pass_param(partition_name=partition_name) req.partition_names.extend(partition_names) @@ -250,7 +306,9 @@ def show_partitions_request(cls, collection_name, partition_names=None, type_in_ return req @classmethod - def get_loading_progress(cls, collection_name, partition_names=None): + def get_loading_progress( + cls, collection_name: str, partition_names: Optional[List[str]] = None + ): check_pass_param(collection_name=collection_name, partition_name_array=partition_names) req = milvus_types.GetLoadingProgressRequest(collection_name=collection_name) if partition_names: @@ -258,7 +316,7 @@ def get_loading_progress(cls, collection_name, partition_names=None): return req @classmethod - def get_load_state(cls, collection_name, partition_names=None): + def get_load_state(cls, collection_name: str, partition_names: Optional[List[str]] = None): check_pass_param(collection_name=collection_name, partition_name_array=partition_names) req = milvus_types.GetLoadStateRequest(collection_name=collection_name) if partition_names: @@ -267,111 +325,112 @@ def get_load_state(cls, collection_name, partition_names=None): @classmethod def empty(cls): - raise DeprecationWarning("no empty request later") - # return common_types.Empty() + msg = "no empty request later" + raise DeprecationWarning(msg) @classmethod def register_link_request(cls): return milvus_types.RegisterLinkRequest() @classmethod - def partition_name(cls, collection_name, partition_name): + def partition_name(cls, collection_name: str, partition_name: str): if not isinstance(collection_name, str): raise ParamError(message="collection_name must be of str type") if not isinstance(partition_name, str): raise ParamError(message="partition_name must be of str type") - return milvus_types.PartitionName(collection_name=collection_name, - tag=partition_name) + return milvus_types.PartitionName(collection_name=collection_name, tag=partition_name) # pylint: disable=too-many-statements @classmethod - def row_insert_or_upsert_param(cls, collection_name, entities, partition_name, fields_info=None, is_insert=True, - enable_dynamic=False, **kwargs): - # insert_request.hash_keys and upsert_request.hash_keys won't be filled in client. It will be filled in proxy. - - tag = partition_name if isinstance(partition_name, str) else "" - request = milvus_types.InsertRequest(collection_name=collection_name, partition_name=tag) - - if not is_insert: - request = milvus_types.UpsertRequest(collection_name=collection_name, partition_name=tag) - + def row_insert_or_upsert_param( + cls, + collection_name: str, + entities: List, + partition_name: str, + fields_info: Any, + is_insert: bool = True, + enable_dynamic: bool = False, + ): if not fields_info: raise ParamError(message="Missing collection meta to validate entities") - _, _, auto_id_loc = traverse_rows_info(fields_info, entities) - - meta_field = schema_types.FieldData() - fields_data, field_info_map = {}, {} - for field in fields_info: - if field.get("auto_id", False): - continue - field_name, field_type = field["name"], field["type"] - field_info_map[field_name] = field - field_data = schema_types.FieldData() - field_data.field_name = field_name - field_data.type = field_type - fields_data[field_name] = field_data + # insert_request.hash_keys and upsert_request.hash_keys won't be filled in client. + tag = partition_name if isinstance(partition_name, str) else "" + if is_insert: + request = milvus_types.InsertRequest( + collection_name=collection_name, partition_name=tag, num_rows=len(entities) + ) + else: + request = milvus_types.UpsertRequest( + collection_name=collection_name, partition_name=tag, num_rows=len(entities) + ) + + fields_data = { + field["name"]: schema_types.FieldData(field_name=field["name"], type=field["type"]) + for field in fields_info + if not field.get("auto_id", False) + } + field_info_map = { + field["name"]: field for field in fields_info if not field.get("auto_id", False) + } - if enable_dynamic: - meta_field.is_dynamic, meta_field.type = True, DataType.JSON + meta_field = ( + schema_types.FieldData(is_dynamic=True, type=DataType.JSON) if enable_dynamic else None + ) + if meta_field is not None: field_info_map[meta_field.field_name] = meta_field fields_data[meta_field.field_name] = meta_field try: for entity in entities: - json_dict = {} - for key in entity: - if key in fields_data: - field_info, field_data = field_info_map[key], fields_data[key] - entity_helper.pack_field_value_to_field_data(entity[key], field_data, field_info) - elif enable_dynamic: - json_dict[key] = entity[key] - else: + for k, v in entity.items(): + if k not in fields_data and not enable_dynamic: raise DataNotMatchException(message=ExceptionsMessage.InsertUnexpectedField) - if enable_dynamic: + if k in fields_data: + field_info, field_data = field_info_map[k], fields_data[k] + entity_helper.pack_field_value_to_field_data(v, field_data, field_info) + + json_dict = { + k: v for k, v in entity.items() if k not in fields_data and enable_dynamic + } + + if enable_dynamic and len(json_dict) > 0: json_value = entity_helper.convert_to_json(json_dict) meta_field.scalars.json_data.data.append(json_value) except (TypeError, ValueError) as e: raise DataNotMatchException(message=ExceptionsMessage.DataTypeInconsistent) from e - request.num_rows = len(entities) - for field in fields_info: - if not field.get("auto_id", False): - field_name = field["name"] - field_data = fields_data[field_name] - request.fields_data.append(field_data) + request.fields_data.extend( + [fields_data[field["name"]] for field in fields_info if not field.get("auto_id", False)] + ) if enable_dynamic: request.fields_data.append(meta_field) + _, _, auto_id_loc = traverse_rows_info(fields_info, entities) if auto_id_loc is not None: - if enable_dynamic: - # len(fields_data) = len(fields_info) - 1(auto_ID) + 1 (dynamic_field) - if len(fields_data) != len(fields_info): - raise ParamError(ExceptionsMessage.FieldsNumInconsistent) - # len(fields_data) = len(fields_info) - 1(auto_ID) - elif len(fields_data) + 1 != len(fields_info): - raise ParamError(ExceptionsMessage.FieldsNumInconsistent) - elif enable_dynamic: - if len(fields_data) != len(fields_info) + 1: + if (enable_dynamic and len(fields_data) != len(fields_info)) or ( + not enable_dynamic and len(fields_data) + 1 != len(fields_info) + ): raise ParamError(ExceptionsMessage.FieldsNumInconsistent) + elif enable_dynamic and len(fields_data) != len(fields_info) + 1: + raise ParamError(ExceptionsMessage.FieldsNumInconsistent) return request @classmethod - def batch_insert_or_upsert_param(cls, collection_name, entities, partition_name, fields_info=None, is_insert=True, - **kwargs): - # insert_request.hash_keys and upsert_request.hash_keys won't be filled in client. It will be filled in proxy. - - tag = partition_name if isinstance(partition_name, str) else "" - request = milvus_types.InsertRequest(collection_name=collection_name, partition_name=tag) - - if not is_insert: - request = milvus_types.UpsertRequest(collection_name=collection_name, partition_name=tag) - + def batch_insert_or_upsert_param( + cls, + collection_name: str, + entities: List, + partition_name: str, + fields_info: Any, + is_insert: bool = True, + ): + # insert_request.hash_keys and upsert_request.hash_keys won't be filled in client. for entity in entities: - if not entity.get("name", None) or not entity.get("type", None): + if not entity.get("name") or not entity.get("type"): raise ParamError(message="Missing param in entities, a field must have type, name") if not fields_info: raise ParamError(message="Missing collection meta to validate entities") @@ -383,10 +442,22 @@ def batch_insert_or_upsert_param(cls, collection_name, entities, partition_name, raise ParamError(message="primary key not found") if auto_id_loc is None and len(entities) != len(fields_info): - raise ParamError(message=f"number of fields: {len(fields_info)}, number of entities: {len(entities)}") + msg = f"number of fields: {len(fields_info)}, number of entities: {len(entities)}" + raise ParamError(msg) if auto_id_loc is not None and len(entities) + 1 != len(fields_info): - raise ParamError(message=f"number of fields: {len(fields_info)}, number of entities: {len(entities)}") + msg = f"number of fields: {len(fields_info)}, number of entities: {len(entities)}" + raise ParamError(msg) + + tag = partition_name if isinstance(partition_name, str) else "" + if is_insert: + request = milvus_types.InsertRequest( + collection_name=collection_name, partition_name=tag + ) + else: + request = milvus_types.UpsertRequest( + collection_name=collection_name, partition_name=tag + ) row_num = 0 try: @@ -395,9 +466,12 @@ def batch_insert_or_upsert_param(cls, collection_name, entities, partition_name, # if current length is zero, consider use default value if current != 0: if row_num not in (0, current): - raise ParamError(message="row num misaligned current[{current}]!= previous[{row_num}]") + msg = f"row num misaligned current[{current}]!= previous[{row_num}]" + raise ParamError(msg) row_num = current - field_data = entity_helper.entity_to_field_data(entity, fields_info[location[entity.get("name")]]) + field_data = entity_helper.entity_to_field_data( + entity, fields_info[location[entity.get("name")]] + ) request.fields_data.append(field_data) except (TypeError, ValueError) as e: raise DataNotMatchException(message=ExceptionsMessage.DataTypeInconsistent) from e @@ -409,13 +483,13 @@ def batch_insert_or_upsert_param(cls, collection_name, entities, partition_name, return request @classmethod - def delete_request(cls, collection_name, partition_name, expr): - def check_str(instr, prefix): + def delete_request(cls, collection_name: str, partition_name: str, expr: str): + def check_str(instr: str, prefix: str): if instr is None: raise ParamError(message=f"{prefix} cannot be None") if not isinstance(instr, str): raise ParamError(message=f"{prefix} value {instr} is illegal") - if instr == "": + if len(instr) == 0: raise ParamError(message=f"{prefix} cannot be empty") check_str(collection_name, "collection_name") @@ -423,23 +497,35 @@ def check_str(instr, prefix): check_str(partition_name, "partition_name") check_str(expr, "expr") - request = milvus_types.DeleteRequest(collection_name=collection_name, expr=expr, partition_name=partition_name) - return request + return milvus_types.DeleteRequest( + collection_name=collection_name, partition_name=partition_name, expr=expr + ) @classmethod - def _prepare_placeholders(cls, vectors, nq, tag, pl_type, is_binary): + def _prepare_placeholders(cls, vectors: Any, nq: int, tag: Any, pl_type: Any, is_binary: bool): pl = common_types.PlaceholderValue(tag=tag) pl.type = pl_type for i in range(0, nq): if is_binary: - pl.values.append(blob.vectorBinaryToBytes(vectors[i])) + pl.values.append(blob.vector_binary_to_bytes(vectors[i])) else: - pl.values.append(blob.vectorFloatToBytes(vectors[i])) + pl.values.append(blob.vector_float_to_bytes(vectors[i])) return pl @classmethod - def search_requests_with_expr(cls, collection_name, data, anns_field, param, limit, expr=None, - partition_names=None, output_fields=None, round_decimal=-1, **kwargs): + def search_requests_with_expr( + cls, + collection_name: str, + data: List, + anns_field: str, + param: Dict, + limit: int, + expr: Optional[str] = None, + partition_names: Optional[List[str]] = None, + output_fields: Optional[List[str]] = None, + round_decimal: int = -1, + **kwargs, + ): requests = [] if len(data) <= 0: return requests @@ -471,7 +557,7 @@ def search_requests_with_expr(cls, collection_name, data, anns_field, param, lim if anns_field: search_params["anns_field"] = anns_field - def dump(v): + def dump(v: Dict): if isinstance(v, dict): return ujson.dumps(v) return str(v) @@ -497,139 +583,193 @@ def dump(v): request.dsl_type = common_types.DslType.BoolExprV1 if expr is not None: request.dsl = expr - request.search_params.extend([common_types.KeyValuePair(key=str(key), value=dump(value)) - for key, value in search_params.items()]) + request.search_params.extend( + [ + common_types.KeyValuePair(key=str(key), value=dump(value)) + for key, value in search_params.items() + ] + ) requests.append(request) return requests @classmethod - def create_alias_request(cls, collection_name, alias): + def create_alias_request(cls, collection_name: str, alias: str): return milvus_types.CreateAliasRequest(collection_name=collection_name, alias=alias) @classmethod - def drop_alias_request(cls, alias): + def drop_alias_request(cls, alias: str): return milvus_types.DropAliasRequest(alias=alias) @classmethod - def alter_alias_request(cls, collection_name, alias): + def alter_alias_request(cls, collection_name: str, alias: str): return milvus_types.AlterAliasRequest(collection_name=collection_name, alias=alias) @classmethod - def create_index_request(cls, collection_name, field_name, params, **kwargs): - index_params = milvus_types.CreateIndexRequest(collection_name=collection_name, field_name=field_name, - index_name=kwargs.get("index_name", "")) - - # index_params.collection_name = collection_name - # index_params.field_name = field_name + def create_index_request(cls, collection_name: str, field_name: str, params: Dict, **kwargs): + index_params = milvus_types.CreateIndexRequest( + collection_name=collection_name, + field_name=field_name, + index_name=kwargs.get("index_name", ""), + ) - def dump(tv): + def dump(tv: Dict): if isinstance(tv, dict): return ujson.dumps(tv) return str(tv) if isinstance(params, dict): for tk, tv in params.items(): - if tk == "dim": - if not tv or not isinstance(tv, int): - raise ParamError(message="dim must be of int!") + if tk == "dim" and (not tv or not isinstance(tv, int)): + raise ParamError(message="dim must be of int!") kv_pair = common_types.KeyValuePair(key=str(tk), value=dump(tv)) index_params.extra_params.append(kv_pair) return index_params @classmethod - def describe_index_request(cls, collection_name, index_name): - return milvus_types.DescribeIndexRequest(collection_name=collection_name, index_name=index_name) + def describe_index_request(cls, collection_name: str, index_name: str): + return milvus_types.DescribeIndexRequest( + collection_name=collection_name, index_name=index_name + ) @classmethod def get_index_build_progress(cls, collection_name: str, index_name: str): - return milvus_types.GetIndexBuildProgressRequest(collection_name=collection_name, index_name=index_name) + return milvus_types.GetIndexBuildProgressRequest( + collection_name=collection_name, index_name=index_name + ) @classmethod def get_index_state_request(cls, collection_name: str, index_name: str): - return milvus_types.GetIndexStateRequest(collection_name=collection_name, index_name=index_name) + return milvus_types.GetIndexStateRequest( + collection_name=collection_name, index_name=index_name + ) @classmethod - def load_collection(cls, db_name, collection_name, replica_number, refresh, resource_groups): - return milvus_types.LoadCollectionRequest(db_name=db_name, collection_name=collection_name, - replica_number=replica_number, refresh=refresh, - resource_groups=resource_groups) + def load_collection( + cls, + db_name: str, + collection_name: str, + replica_number: int, + refresh: bool, + resource_groups: List[str], + ): + return milvus_types.LoadCollectionRequest( + db_name=db_name, + collection_name=collection_name, + replica_number=replica_number, + refresh=refresh, + resource_groups=resource_groups, + ) @classmethod - def release_collection(cls, db_name, collection_name): - return milvus_types.ReleaseCollectionRequest(db_name=db_name, collection_name=collection_name) + def release_collection(cls, db_name: str, collection_name: str): + return milvus_types.ReleaseCollectionRequest( + db_name=db_name, collection_name=collection_name + ) @classmethod - def load_partitions(cls, db_name, collection_name, partition_names, replica_number, refresh, resource_groups): - return milvus_types.LoadPartitionsRequest(db_name=db_name, collection_name=collection_name, - partition_names=partition_names, - replica_number=replica_number, - refresh=refresh, - resource_groups=resource_groups) + def load_partitions( + cls, + db_name: str, + collection_name: str, + partition_names: List[str], + replica_number: int, + refresh: bool, + resource_groups: List[str], + ): + return milvus_types.LoadPartitionsRequest( + db_name=db_name, + collection_name=collection_name, + partition_names=partition_names, + replica_number=replica_number, + refresh=refresh, + resource_groups=resource_groups, + ) @classmethod - def release_partitions(cls, db_name, collection_name, partition_names): - return milvus_types.ReleasePartitionsRequest(db_name=db_name, collection_name=collection_name, - partition_names=partition_names) + def release_partitions(cls, db_name: str, collection_name: str, partition_names: List[str]): + return milvus_types.ReleasePartitionsRequest( + db_name=db_name, collection_name=collection_name, partition_names=partition_names + ) @classmethod - def get_collection_stats_request(cls, collection_name): + def get_collection_stats_request(cls, collection_name: str): return milvus_types.GetCollectionStatisticsRequest(collection_name=collection_name) @classmethod - def get_persistent_segment_info_request(cls, collection_name): + def get_persistent_segment_info_request(cls, collection_name: str): return milvus_types.GetPersistentSegmentInfoRequest(collectionName=collection_name) @classmethod - def get_flush_state_request(cls, segment_ids): + def get_flush_state_request(cls, segment_ids: List[int]): return milvus_types.GetFlushStateRequest(segmentIDs=segment_ids) @classmethod - def get_query_segment_info_request(cls, collection_name): + def get_query_segment_info_request(cls, collection_name: str): return milvus_types.GetQuerySegmentInfoRequest(collectionName=collection_name) @classmethod - def flush_param(cls, collection_names): + def flush_param(cls, collection_names: List[str]): return milvus_types.FlushRequest(collection_names=collection_names) @classmethod - def drop_index_request(cls, collection_name, field_name, index_name): - return milvus_types.DropIndexRequest(db_name="", collection_name=collection_name, field_name=field_name, - index_name=index_name) + def drop_index_request(cls, collection_name: str, field_name: str, index_name: str): + return milvus_types.DropIndexRequest( + db_name="", + collection_name=collection_name, + field_name=field_name, + index_name=index_name, + ) @classmethod - def get_partition_stats_request(cls, collection_name, partition_name): - return milvus_types.GetPartitionStatisticsRequest(db_name="", collection_name=collection_name, - partition_name=partition_name) + def get_partition_stats_request(cls, collection_name: str, partition_name: str): + return milvus_types.GetPartitionStatisticsRequest( + db_name="", collection_name=collection_name, partition_name=partition_name + ) @classmethod - def dummy_request(cls, request_type): + def dummy_request(cls, request_type: Any): return milvus_types.DummyRequest(request_type=request_type) @classmethod - def retrieve_request(cls, collection_name, ids, output_fields, partition_names): + def retrieve_request( + cls, + collection_name: str, + ids: List[str], + output_fields: List[str], + partition_names: List[str], + ): ids = schema_types.IDs(int_id=schema_types.LongArray(data=ids)) - return milvus_types.RetrieveRequest(db_name="", - collection_name=collection_name, - ids=ids, - output_fields=output_fields, - partition_names=partition_names) + return milvus_types.RetrieveRequest( + db_name="", + collection_name=collection_name, + ids=ids, + output_fields=output_fields, + partition_names=partition_names, + ) @classmethod - def query_request(cls, collection_name, expr, output_fields, partition_names, **kwargs): - + def query_request( + cls, + collection_name: str, + expr: str, + output_fields: List[str], + partition_names: List[str], + **kwargs, + ): use_default_consistency = ts_utils.construct_guarantee_ts(collection_name, kwargs) - req = milvus_types.QueryRequest(db_name="", - collection_name=collection_name, - expr=expr, - output_fields=output_fields, - partition_names=partition_names, - guarantee_timestamp=kwargs.get("guarantee_timestamp", 0), - travel_timestamp=kwargs.get("travel_timestamp", 0), - use_default_consistency=use_default_consistency, - consistency_level=kwargs.get("consistency_level", 0) - ) + req = milvus_types.QueryRequest( + db_name="", + collection_name=collection_name, + expr=expr, + output_fields=output_fields, + partition_names=partition_names, + guarantee_timestamp=kwargs.get("guarantee_timestamp", 0), + travel_timestamp=kwargs.get("travel_timestamp", 0), + use_default_consistency=use_default_consistency, + consistency_level=kwargs.get("consistency_level", 0), + ) limit = kwargs.get("limit", None) if limit is not None: @@ -640,25 +780,35 @@ def query_request(cls, collection_name, expr, output_fields, partition_names, ** req.query_params.append(common_types.KeyValuePair(key="offset", value=str(offset))) ignore_growing = kwargs.get("ignore_growing", False) - req.query_params.append(common_types.KeyValuePair(key="ignore_growing", value=str(ignore_growing))) + req.query_params.append( + common_types.KeyValuePair(key="ignore_growing", value=str(ignore_growing)) + ) use_iteration_extension_reduce_rate = kwargs.get(ITERATION_EXTENSION_REDUCE_RATE, 0) - req.query_params.append(common_types.KeyValuePair(key=ITERATION_EXTENSION_REDUCE_RATE, - value=str(use_iteration_extension_reduce_rate))) + req.query_params.append( + common_types.KeyValuePair( + key=ITERATION_EXTENSION_REDUCE_RATE, value=str(use_iteration_extension_reduce_rate) + ) + ) return req @classmethod - def load_balance_request(cls, collection_name, src_node_id, dst_node_ids, sealed_segment_ids): - request = milvus_types.LoadBalanceRequest( + def load_balance_request( + cls, + collection_name: str, + src_node_id: int, + dst_node_ids: List[int], + sealed_segment_ids: List[int], + ): + return milvus_types.LoadBalanceRequest( collectionName=collection_name, src_nodeID=src_node_id, dst_nodeIDs=dst_node_ids, sealed_segmentIDs=sealed_segment_ids, ) - return request @classmethod - def manual_compaction(cls, collection_id, timetravel): + def manual_compaction(cls, collection_id: int, timetravel: int): if collection_id is None or not isinstance(collection_id, int): raise ParamError(message=f"collection_id value {collection_id} is illegal") @@ -694,11 +844,10 @@ def get_replicas(cls, collection_id: int): if collection_id is None or not isinstance(collection_id, int): raise ParamError(message=f"collection_id value {collection_id} is illegal") - request = milvus_types.GetReplicasRequest( + return milvus_types.GetReplicasRequest( collectionID=collection_id, with_shard_nodes=True, ) - return request @classmethod def do_bulk_insert(cls, collection_name: str, partition_name: str, files: list, **kwargs): @@ -719,41 +868,44 @@ def do_bulk_insert(cls, collection_name: str, partition_name: str, files: list, return req @classmethod - def get_bulk_insert_state(cls, task_id): + def get_bulk_insert_state(cls, task_id: int): if task_id is None or not isinstance(task_id, int): - raise ParamError(f"task_id value {task_id} is not an integer") + msg = f"task_id value {task_id} is not an integer" + raise ParamError(msg) - req = milvus_types.GetImportStateRequest(task=task_id) - return req + return milvus_types.GetImportStateRequest(task=task_id) @classmethod - def list_bulk_insert_tasks(cls, limit, collection_name): + def list_bulk_insert_tasks(cls, limit: int, collection_name: str): if limit is None or not isinstance(limit, int): - raise ParamError(f"limit value {limit} is not an integer") + msg = f"limit value {limit} is not an integer" + raise ParamError(msg) - request = milvus_types.ListImportTasksRequest( + return milvus_types.ListImportTasksRequest( collection_name=collection_name, limit=limit, ) - return request @classmethod - def create_user_request(cls, user, password): + def create_user_request(cls, user: str, password: str): check_pass_param(user=user, password=password) - return milvus_types.CreateCredentialRequest(username=user, password=base64.b64encode(password.encode('utf-8'))) + return milvus_types.CreateCredentialRequest( + username=user, password=base64.b64encode(password.encode("utf-8")) + ) @classmethod - def update_password_request(cls, user, old_password, new_password): + def update_password_request(cls, user: str, old_password: str, new_password: str): check_pass_param(user=user) check_pass_param(password=old_password) check_pass_param(password=new_password) - return milvus_types.UpdateCredentialRequest(username=user, - oldPassword=base64.b64encode(old_password.encode('utf-8')), - newPassword=base64.b64encode(new_password.encode('utf-8')), - ) + return milvus_types.UpdateCredentialRequest( + username=user, + oldPassword=base64.b64encode(old_password.encode("utf-8")), + newPassword=base64.b64encode(new_password.encode("utf-8")), + ) @classmethod - def delete_user_request(cls, user): + def delete_user_request(cls, user: str): if not isinstance(user, str): raise ParamError(message=f"invalid user {user}") return milvus_types.DeleteCredentialRequest(username=user) @@ -763,79 +915,99 @@ def list_usernames_request(cls): return milvus_types.ListCredUsersRequest() @classmethod - def create_role_request(cls, role_name): + def create_role_request(cls, role_name: str): check_pass_param(role_name=role_name) return milvus_types.CreateRoleRequest(entity=milvus_types.RoleEntity(name=role_name)) @classmethod - def drop_role_request(cls, role_name): + def drop_role_request(cls, role_name: str): check_pass_param(role_name=role_name) return milvus_types.DropRoleRequest(role_name=role_name) @classmethod - def operate_user_role_request(cls, username, role_name, operate_user_role_type): + def operate_user_role_request(cls, username: str, role_name: str, operate_user_role_type: Any): check_pass_param(user=username) check_pass_param(role_name=role_name) check_pass_param(operate_user_role_type=operate_user_role_type) - return milvus_types.OperateUserRoleRequest(username=username, role_name=role_name, type=operate_user_role_type) + return milvus_types.OperateUserRoleRequest( + username=username, role_name=role_name, type=operate_user_role_type + ) @classmethod - def select_role_request(cls, role_name, include_user_info): + def select_role_request(cls, role_name: str, include_user_info: bool): if role_name: check_pass_param(role_name=role_name) check_pass_param(include_user_info=include_user_info) - return milvus_types.SelectRoleRequest(role=milvus_types.RoleEntity(name=role_name) if role_name else None, - include_user_info=include_user_info) + return milvus_types.SelectRoleRequest( + role=milvus_types.RoleEntity(name=role_name) if role_name else None, + include_user_info=include_user_info, + ) @classmethod - def select_user_request(cls, username, include_role_info): + def select_user_request(cls, username: str, include_role_info: bool): if username: check_pass_param(user=username) check_pass_param(include_role_info=include_role_info) - return milvus_types.SelectUserRequest(user=milvus_types.UserEntity(name=username) if username else None, - include_role_info=include_role_info) + return milvus_types.SelectUserRequest( + user=milvus_types.UserEntity(name=username) if username else None, + include_role_info=include_role_info, + ) @classmethod - def operate_privilege_request(cls, role_name, object, object_name, privilege, db_name, operate_privilege_type): + def operate_privilege_request( + cls, + role_name: str, + object: Any, + object_name: str, + privilege: str, + db_name: str, + operate_privilege_type: Any, + ): check_pass_param(role_name=role_name) check_pass_param(object=object) check_pass_param(object_name=object_name) check_pass_param(privilege=privilege) check_pass_param(operate_privilege_type=operate_privilege_type) return milvus_types.OperatePrivilegeRequest( - entity=milvus_types.GrantEntity(role=milvus_types.RoleEntity(name=role_name), - object=milvus_types.ObjectEntity(name=object), - object_name=object_name, - db_name=db_name, - grantor=milvus_types.GrantorEntity( - privilege=milvus_types.PrivilegeEntity(name=privilege))), - type=operate_privilege_type) + entity=milvus_types.GrantEntity( + role=milvus_types.RoleEntity(name=role_name), + object=milvus_types.ObjectEntity(name=object), + object_name=object_name, + db_name=db_name, + grantor=milvus_types.GrantorEntity( + privilege=milvus_types.PrivilegeEntity(name=privilege) + ), + ), + type=operate_privilege_type, + ) @classmethod - def select_grant_request(cls, role_name, object, object_name, db_name): + def select_grant_request(cls, role_name: str, object: str, object_name: str, db_name: str): check_pass_param(role_name=role_name) if object: check_pass_param(object=object) if object_name: check_pass_param(object_name=object_name) return milvus_types.SelectGrantRequest( - entity=milvus_types.GrantEntity(role=milvus_types.RoleEntity(name=role_name), - object=milvus_types.ObjectEntity(name=object) if object else None, - object_name=object_name if object_name else None, - db_name=db_name, - )) + entity=milvus_types.GrantEntity( + role=milvus_types.RoleEntity(name=role_name), + object=milvus_types.ObjectEntity(name=object) if object else None, + object_name=object_name if object_name else None, + db_name=db_name, + ), + ) @classmethod def get_server_version(cls): return milvus_types.GetVersionRequest() @classmethod - def create_resource_group(cls, name): + def create_resource_group(cls, name: str): check_pass_param(resource_group_name=name) return milvus_types.CreateResourceGroupRequest(resource_group=name) @classmethod - def drop_resource_group(cls, name): + def drop_resource_group(cls, name: str): check_pass_param(resource_group_name=name) return milvus_types.DropResourceGroupRequest(resource_group=name) @@ -844,37 +1016,39 @@ def list_resource_groups(cls): return milvus_types.ListResourceGroupsRequest() @classmethod - def describe_resource_group(cls, name): + def describe_resource_group(cls, name: str): check_pass_param(resource_group_name=name) return milvus_types.DescribeResourceGroupRequest(resource_group=name) @classmethod - def transfer_node(cls, source, target, num_node): + def transfer_node(cls, source: str, target: str, num_node: int): check_pass_param(resource_group_name=source) check_pass_param(resource_group_name=target) - return milvus_types.TransferNodeRequest(source_resource_group=source, - target_resource_group=target, - num_node=num_node) + return milvus_types.TransferNodeRequest( + source_resource_group=source, target_resource_group=target, num_node=num_node + ) @classmethod - def transfer_replica(cls, source, target, collection_name, num_replica): + def transfer_replica(cls, source: str, target: str, collection_name: str, num_replica: int): check_pass_param(resource_group_name=source) check_pass_param(resource_group_name=target) - return milvus_types.TransferReplicaRequest(source_resource_group=source, - target_resource_group=target, - collection_name=collection_name, - num_replica=num_replica) + return milvus_types.TransferReplicaRequest( + source_resource_group=source, + target_resource_group=target, + collection_name=collection_name, + num_replica=num_replica, + ) @classmethod def flush_all_request(cls): return milvus_types.FlushAllRequest() @classmethod - def get_flush_all_state_request(cls, flush_all_ts): + def get_flush_all_state_request(cls, flush_all_ts: int): return milvus_types.GetFlushAllStateRequest(flush_all_ts=flush_all_ts) @classmethod - def register_request(cls, user, host, **kwargs): + def register_request(cls, user: str, host: str, **kwargs): reserved = {} for k, v in kwargs.items(): reserved[k] = v @@ -894,18 +1068,15 @@ def register_request(cls, user, host, **kwargs): ) @classmethod - def create_database_req(cls, db_name): + def create_database_req(cls, db_name: str): check_pass_param(db_name=db_name) - req = milvus_types.CreateDatabaseRequest(db_name=db_name) - return req + return milvus_types.CreateDatabaseRequest(db_name=db_name) @classmethod - def drop_database_req(cls, db_name): + def drop_database_req(cls, db_name: str): check_pass_param(db_name=db_name) - req = milvus_types.DropDatabaseRequest(db_name=db_name) - return req + return milvus_types.DropDatabaseRequest(db_name=db_name) @classmethod def list_database_req(cls): - req = milvus_types.ListDatabasesRequest() - return req + return milvus_types.ListDatabasesRequest() diff --git a/pymilvus/client/singleton_utils.py b/pymilvus/client/singleton_utils.py index 0e62b097d..a4597573b 100644 --- a/pymilvus/client/singleton_utils.py +++ b/pymilvus/client/singleton_utils.py @@ -1,8 +1,9 @@ import threading +from typing import ClassVar, Dict class Singleton(type): - _ins = {} + _ins: ClassVar[Dict] = {} _lock = threading.Lock() def __call__(cls, *args, **kwargs): diff --git a/pymilvus/client/stub.py b/pymilvus/client/stub.py index 3f5faefc9..3f327e3f2 100644 --- a/pymilvus/client/stub.py +++ b/pymilvus/client/stub.py @@ -1,17 +1,25 @@ from urllib import parse -from .grpc_handler import GrpcHandler -from ..exceptions import MilvusException, ParamError -from .types import CompactionState, CompactionPlans, Replica, BulkInsertState, ResourceGroupInfo -from ..settings import Config -from ..decorators import deprecated +from pymilvus.decorators import deprecated +from pymilvus.exceptions import MilvusException, ParamError +from pymilvus.settings import Config from .check import is_legal_host, is_legal_port +from .grpc_handler import GrpcHandler +from .types import ( + BulkInsertState, + CompactionPlans, + CompactionState, + Replica, + ResourceGroupInfo, +) class Milvus: @deprecated - def __init__(self, host=None, port=Config.GRPC_PORT, uri=Config.GRPC_URI, channel=None, **kwargs): + def __init__( + self, host=None, port=Config.GRPC_PORT, uri=Config.GRPC_URI, channel=None, **kwargs + ) -> None: self.address = self.__get_address(host, port, uri) self._handler = GrpcHandler(address=self.address, channel=channel, **kwargs) @@ -20,12 +28,12 @@ def __init__(self, host=None, port=Config.GRPC_PORT, uri=Config.GRPC_URI, channe def __get_address(self, host=None, port=Config.GRPC_PORT, uri=Config.GRPC_URI): if host is None and uri is None: - raise ParamError(message='Host and uri cannot both be None') + raise ParamError(message="Host and uri cannot both be None") if host is None: try: parsed_uri = parse.urlparse(uri, "tcp") - except (Exception) as e: + except Exception as e: raise ParamError(message=f"Illegal uri [{uri}]: {e}") from e host, port = parsed_uri.hostname, parsed_uri.port @@ -60,7 +68,7 @@ def close(self): self._handler = None def create_collection(self, collection_name, fields, timeout=None, **kwargs): - """ Creates a collection. + """Creates a collection. :param collection_name: The name of the collection. A collection name can only include numbers, letters, and underscores, and must not begin with a number. @@ -193,7 +201,9 @@ def load_collection(self, collection_name, replica_number=1, timeout=None, **kwa :raises MilvusException: If the return result from server is not ok """ with self._connection() as handler: - return handler.load_collection(collection_name, replica_number, timeout=timeout, **kwargs) + return handler.load_collection( + collection_name, replica_number, timeout=timeout, **kwargs + ) def release_collection(self, collection_name, timeout=None): """ @@ -360,9 +370,12 @@ def load_partitions(self, collection_name, partition_names, replica_number=1, ti :raises MilvusException: If the return result from server is not ok """ with self._connection() as handler: - return handler.load_partitions(collection_name=collection_name, - partition_names=partition_names, - replica_number=replica_number, timeout=timeout) + return handler.load_partitions( + collection_name=collection_name, + partition_names=partition_names, + replica_number=replica_number, + timeout=timeout, + ) def release_partitions(self, collection_name, partition_names, timeout=None): """ @@ -386,8 +399,9 @@ def release_partitions(self, collection_name, partition_names, timeout=None): :raises MilvusException: If the return result from server is not ok """ with self._connection() as handler: - return handler.release_partitions(collection_name=collection_name, - partition_names=partition_names, timeout=timeout) + return handler.release_partitions( + collection_name=collection_name, partition_names=partition_names, timeout=timeout + ) def list_partitions(self, collection_name, timeout=None): """ @@ -433,7 +447,9 @@ def get_partition_stats(self, collection_name, partition_name, timeout=None, **k :raises MilvusException: If the return result from server is not ok """ with self._connection() as handler: - stats = handler.get_partition_stats(collection_name, partition_name, timeout=timeout, **kwargs) + stats = handler.get_partition_stats( + collection_name, partition_name, timeout=timeout, **kwargs + ) result = {stat.key: stat.value for stat in stats} result["row_count"] = int(result["row_count"]) return result @@ -632,7 +648,9 @@ def create_index(self, collection_name, field_name, params, timeout=None, **kwar :raises MilvusException: If the return result from server is not ok """ with self._connection() as handler: - return handler.create_index(collection_name, field_name, params, timeout=timeout, **kwargs) + return handler.create_index( + collection_name, field_name, params, timeout=timeout, **kwargs + ) def drop_index(self, collection_name, field_name, timeout=None): """ @@ -656,8 +674,12 @@ def drop_index(self, collection_name, field_name, timeout=None): :raises MilvusException: If the return result from server is not ok """ with self._connection() as handler: - return handler.drop_index(collection_name=collection_name, - field_name=field_name, index_name="", timeout=timeout) + return handler.drop_index( + collection_name=collection_name, + field_name=field_name, + index_name="", + timeout=timeout, + ) def describe_index(self, collection_name, index_name="", timeout=None): """ @@ -718,7 +740,9 @@ def insert(self, collection_name, entities, partition_name=None, timeout=None, * :raises MilvusException: If the return result from server is not ok """ with self._connection() as handler: - return handler.batch_insert(collection_name, entities, partition_name, timeout=timeout, **kwargs) + return handler.batch_insert( + collection_name, entities, partition_name, timeout=timeout, **kwargs + ) def delete(self, collection_name, expr, partition_name=None, timeout=None, **kwargs): """ @@ -783,8 +807,20 @@ def flush(self, collection_names=None, timeout=None, **kwargs): with self._connection() as handler: return handler.flush(collection_names, timeout=timeout, **kwargs) - def search(self, collection_name, data, anns_field, param, limit, expression=None, partition_names=None, - output_fields=None, timeout=None, round_decimal=-1, **kwargs): + def search( + self, + collection_name, + data, + anns_field, + param, + limit, + expression=None, + partition_names=None, + output_fields=None, + timeout=None, + round_decimal=-1, + **kwargs, + ): """ Searches a collection based on the given expression and returns query results. @@ -842,8 +878,19 @@ def search(self, collection_name, data, anns_field, param, limit, expression=Non :raises MilvusException: If the return result from server is not ok """ with self._connection() as handler: - return handler.search(collection_name, data, anns_field, param, limit, expression, partition_names, - output_fields, round_decimal=round_decimal, timeout=timeout, **kwargs) + return handler.search( + collection_name, + data, + anns_field, + param, + limit, + expression, + partition_names, + output_fields, + round_decimal=round_decimal, + timeout=timeout, + **kwargs, + ) def get_query_segment_info(self, collection_name, timeout=None, **kwargs): """ @@ -862,7 +909,7 @@ def get_query_segment_info(self, collection_name, timeout=None, **kwargs): return handler.get_query_segment_info(collection_name, timeout=timeout, **kwargs) def load_collection_progress(self, collection_name, timeout=None): - """ { + """{ 'loading_progress': '100%', 'num_loaded_partitions': 3, 'not_loaded_partitions': [], @@ -872,14 +919,16 @@ def load_collection_progress(self, collection_name, timeout=None): return handler.load_collection_progress(collection_name, timeout=timeout) def load_partitions_progress(self, collection_name, partition_names, timeout=None): - """ { + """{ 'loading_progress': '100%', 'num_loaded_partitions': 3, 'not_loaded_partitions': [], } """ with self._connection() as handler: - return handler.load_partitions_progress(collection_name, partition_names, timeout=timeout) + return handler.load_partitions_progress( + collection_name, partition_names, timeout=timeout + ) def wait_for_loading_collection_complete(self, collection_name, timeout=None): with self._connection() as handler: @@ -887,7 +936,9 @@ def wait_for_loading_collection_complete(self, collection_name, timeout=None): def wait_for_loading_partitions_complete(self, collection_name, partition_names, timeout=None): with self._connection() as handler: - return handler.wait_for_loading_partitions(collection_name, partition_names, timeout=timeout) + return handler.wait_for_loading_partitions( + collection_name, partition_names, timeout=timeout + ) def get_index_build_progress(self, collection_name, index_name, timeout=None): with self._connection() as handler: @@ -901,7 +952,15 @@ def dummy(self, request_type, timeout=None): with self._connection() as handler: return handler.dummy(request_type, timeout=timeout) - def query(self, collection_name, expr, output_fields=None, partition_names=None, timeout=None, **kwargs): + def query( + self, + collection_name, + expr, + output_fields=None, + partition_names=None, + timeout=None, + **kwargs, + ): """ Query with a set of criteria, and results in a list of records that match the query exactly. @@ -946,9 +1005,19 @@ def query(self, collection_name, expr, output_fields=None, partition_names=None, :raises MilvusException: If the return result from server is not ok """ with self._connection() as handler: - return handler.query(collection_name, expr, output_fields, partition_names, timeout=timeout, **kwargs) + return handler.query( + collection_name, expr, output_fields, partition_names, timeout=timeout, **kwargs + ) - def load_balance(self, collection_name: str, src_node_id, dst_node_ids, sealed_segment_ids, timeout=None, **kwargs): + def load_balance( + self, + collection_name: str, + src_node_id, + dst_node_ids, + sealed_segment_ids, + timeout=None, + **kwargs, + ): """ Do load balancing operation from source query node to destination query node. :param collection_name: The collection to balance. @@ -970,8 +1039,14 @@ def load_balance(self, collection_name: str, src_node_id, dst_node_ids, sealed_s :raises MilvusException: If sealed segments not exist. """ with self._connection() as handler: - return handler.load_balance(collection_name, src_node_id, dst_node_ids, sealed_segment_ids, - timeout=timeout, **kwargs) + return handler.load_balance( + collection_name, + src_node_id, + dst_node_ids, + sealed_segment_ids, + timeout=timeout, + **kwargs, + ) def compact(self, collection_name, timeout=None, **kwargs) -> int: """ @@ -1010,7 +1085,9 @@ def get_compaction_state(self, compaction_id: int, timeout=None, **kwargs) -> Co with self._connection() as handler: return handler.get_compaction_state(compaction_id, timeout=timeout, **kwargs) - def wait_for_compaction_completed(self, compaction_id: int, timeout=None, **kwargs) -> CompactionState: + def wait_for_compaction_completed( + self, compaction_id: int, timeout=None, **kwargs + ) -> CompactionState: with self._connection() as handler: return handler.wait_for_compaction_completed(compaction_id, timeout=timeout, **kwargs) @@ -1033,7 +1110,7 @@ def get_compaction_plans(self, compaction_id: int, timeout=None, **kwargs) -> Co return handler.get_compaction_plans(compaction_id, timeout=timeout, **kwargs) def get_replicas(self, collection_name: str, timeout=None, **kwargs) -> Replica: - """ Get replica infos of a collection + """Get replica infos of a collection :param collection_name: the name of the collection :type collection_name: str @@ -1049,8 +1126,10 @@ def get_replicas(self, collection_name: str, timeout=None, **kwargs) -> Replica: with self._connection() as handler: return handler.get_replicas(collection_name, timeout=timeout, **kwargs) - def do_bulk_insert(self, collection_name: str, partition_name: str, files: list, timeout=None, **kwargs) -> int: - """ do_bulk_insert inserts entities through files, currently supports row-based json file. + def do_bulk_insert( + self, collection_name: str, partition_name: str, files: list, timeout=None, **kwargs + ) -> int: + """do_bulk_insert inserts entities through files, currently supports row-based json file. User need to create the json file with a specified json format which is described in the official user guide. Let's say a collection has two fields: "id" and "vec"(dimension=8), the row-based json format is: {"rows": [ @@ -1086,7 +1165,9 @@ def do_bulk_insert(self, collection_name: str, partition_name: str, files: list, :raises BaseException: If the files input is illegal. """ with self._connection() as handler: - return handler.do_bulk_insert(collection_name, partition_name, files, timeout=timeout, **kwargs) + return handler.do_bulk_insert( + collection_name, partition_name, files, timeout=timeout, **kwargs + ) def get_bulk_insert_state(self, task_id, timeout=None, **kwargs) -> BulkInsertState: """get_bulk_insert_state returns state of a certain task_id @@ -1117,7 +1198,7 @@ def list_bulk_insert_tasks(self, timeout=None, **kwargs) -> list: return handler.list_bulk_insert_tasks(timeout=timeout, **kwargs) def create_user(self, user, password, timeout=None, **kwargs): - """ Create a user using the given user and password. + """Create a user using the given user and password. :param user: the user name. :type user: str :param password: the password. @@ -1146,7 +1227,7 @@ def update_password(self, user, old_password, new_password, timeout=None, **kwar handler.update_password(user, old_password, new_password, timeout=timeout, **kwargs) def delete_user(self, user, timeout=None, **kwargs): - """ Delete user corresponding to the username. + """Delete user corresponding to the username. :param user: the user name. :type user: str :param timeout: The timeout for this method, unit: second @@ -1156,7 +1237,7 @@ def delete_user(self, user, timeout=None, **kwargs): handler.delete_user(user, timeout=timeout, **kwargs) def list_usernames(self, timeout=None, **kwargs): - """ List all usernames. + """List all usernames. :param timeout: The timeout for this method, unit: second :type timeout: int :return list of str: @@ -1166,7 +1247,7 @@ def list_usernames(self, timeout=None, **kwargs): return handler.list_usernames(timeout=timeout, **kwargs) def create_role(self, role_name, timeout=None, **kwargs): - """ Create Role + """Create Role :param role_name: the role name. :type role_name: str """ @@ -1174,7 +1255,7 @@ def create_role(self, role_name, timeout=None, **kwargs): handler.create_role(role_name, timeout=timeout, **kwargs) def drop_role(self, role_name, timeout=None, **kwargs): - """ Drop Role + """Drop Role :param role_name: role name. :type role_name: str """ @@ -1182,7 +1263,7 @@ def drop_role(self, role_name, timeout=None, **kwargs): handler.drop_role(role_name, timeout=timeout, **kwargs) def add_user_to_role(self, username, role_name, timeout=None, **kwargs): - """ Add User To Role + """Add User To Role :param username: user name. :type username: str :param role_name: role name. @@ -1192,7 +1273,7 @@ def add_user_to_role(self, username, role_name, timeout=None, **kwargs): handler.add_user_to_role(username, role_name, timeout=timeout, **kwargs) def remove_user_from_role(self, username, role_name, timeout=None, **kwargs): - """ Remove User From Role + """Remove User From Role :param username: user name. :type username: str :param role_name: role name. @@ -1202,7 +1283,7 @@ def remove_user_from_role(self, username, role_name, timeout=None, **kwargs): handler.remove_user_from_role(username, role_name, timeout=timeout, **kwargs) def select_one_role(self, role_name, include_user_info, timeout=None, **kwargs): - """ Select One Role Info + """Select One Role Info :param role_name: role name. :type role_name: str :param include_user_info: whether to obtain the user information associated with the role @@ -1212,7 +1293,7 @@ def select_one_role(self, role_name, include_user_info, timeout=None, **kwargs): handler.select_one_role(role_name, include_user_info, timeout=timeout, **kwargs) def select_all_role(self, include_user_info, timeout=None, **kwargs): - """ Select All Role Info + """Select All Role Info :param include_user_info: whether to obtain the user information associated with roles :type include_user_info: bool """ @@ -1220,7 +1301,7 @@ def select_all_role(self, include_user_info, timeout=None, **kwargs): handler.select_all_role(include_user_info, timeout=timeout, **kwargs) def select_one_user(self, username, include_role_info, timeout=None, **kwargs): - """ Select One User Info + """Select One User Info :param username: user name. :type username: str :param include_role_info: whether to obtain the role information associated with the user @@ -1230,16 +1311,15 @@ def select_one_user(self, username, include_role_info, timeout=None, **kwargs): handler.select_one_user(username, include_role_info, timeout=timeout, **kwargs) def select_all_user(self, include_role_info, timeout=None, **kwargs): - """ Select All User Info + """Select All User Info :param include_role_info: whether to obtain the role information associated with users :type include_role_info: bool """ with self._connection() as handler: handler.select_all_role(include_role_info, timeout=timeout, **kwargs) - def grant_privilege(self, role_name, object, object_name, privilege, - timeout=None, **kwargs): - """ Grant Privilege + def grant_privilege(self, role_name, object, object_name, privilege, timeout=None, **kwargs): + """Grant Privilege :param role_name: role name. :type role_name: str :param object: object that will be granted the privilege. @@ -1250,12 +1330,12 @@ def grant_privilege(self, role_name, object, object_name, privilege, :type privilege: str """ with self._connection() as handler: - handler.grant_privilege(role_name, object, object_name, privilege, - timeout=timeout, **kwargs) + handler.grant_privilege( + role_name, object, object_name, privilege, timeout=timeout, **kwargs + ) - def revoke_privilege(self, role_name, object, object_name, privilege, - timeout=None, **kwargs): - """ Revoke Privilege + def revoke_privilege(self, role_name, object, object_name, privilege, timeout=None, **kwargs): + """Revoke Privilege :param role_name: role name. :type role_name: str :param object: object that will be granted the privilege. @@ -1266,20 +1346,22 @@ def revoke_privilege(self, role_name, object, object_name, privilege, :type privilege: str """ with self._connection() as handler: - handler.revoke_privilege(role_name, object, object_name, privilege, - timeout=timeout, **kwargs) + handler.revoke_privilege( + role_name, object, object_name, privilege, timeout=timeout, **kwargs + ) def select_grant_for_one_role(self, role_name, timeout=None, **kwargs): - """ Select the grant info about the role + """Select the grant info about the role :param role_name: role name. :type role_name: str """ with self._connection() as handler: handler.select_grant_for_one_role(role_name, timeout=timeout, **kwargs) - def select_grant_for_role_and_object(self, role_name, object, object_name, - timeout=None, **kwargs): - """ Select the grant info about the role and specific object + def select_grant_for_role_and_object( + self, role_name, object, object_name, timeout=None, **kwargs + ): + """Select the grant info about the role and specific object :param role_name: role name. :type role_name: str :param object: object that will be selected the privilege info. @@ -1288,7 +1370,9 @@ def select_grant_for_role_and_object(self, role_name, object, object_name, :type object_name: str """ with self._connection() as handler: - handler.select_grant_for_role_and_object(role_name, object, object_name, timeout=timeout, **kwargs) + handler.select_grant_for_role_and_object( + role_name, object, object_name, timeout=timeout, **kwargs + ) def get_version(self, timeout=None, **kwargs): with self._connection() as handler: @@ -1313,8 +1397,7 @@ def drop_resource_group(self, name, timeout=None, **kwargs): handler.drop_resource_group(name, timeout=timeout, **kwargs) def list_resource_groups(self, timeout=None, **kwargs): - """list all resource group names - """ + """list all resource group names""" with self._connection() as handler: handler.list_resource_groups(timeout=timeout, **kwargs) @@ -1342,7 +1425,9 @@ def transfer_node(self, source, target, num_node, timeout=None, **kwargs): with self._connection() as handler: handler.transfer_node(source, target, num_node, timeout=timeout, **kwargs) - def transfer_replica(self, source, target, collection_name, num_replica, timeout=None, **kwargs): + def transfer_replica( + self, source, target, collection_name, num_replica, timeout=None, **kwargs + ): """transfer num_replica from source resource group to target resource group :param source: source resource group name @@ -1356,4 +1441,5 @@ def transfer_replica(self, source, target, collection_name, num_replica, timeout """ with self._connection() as handler: handler.transfer_replica( - source, target, collection_name, num_replica, timeout=timeout, **kwargs) + source, target, collection_name, num_replica, timeout=timeout, **kwargs + ) diff --git a/pymilvus/client/ts_utils.py b/pymilvus/client/ts_utils.py index 071f08849..46d72e365 100644 --- a/pymilvus/client/ts_utils.py +++ b/pymilvus/client/ts_utils.py @@ -1,31 +1,33 @@ -import threading import datetime +import threading +from typing import Any, Dict, Optional +from pymilvus.grpc_gen import common_pb2 + +from .constants import BOUNDED_TS, EVENTUALLY_TS from .singleton_utils import Singleton -from .utils import hybridts_to_unixtime from .types import get_consistency_level -from .constants import EVENTUALLY_TS, BOUNDED_TS - -from ..grpc_gen import common_pb2 +from .utils import hybridts_to_unixtime ConsistencyLevel = common_pb2.ConsistencyLevel + class GTsDict(metaclass=Singleton): - def __init__(self): + def __init__(self) -> None: # collection id -> last write ts self._last_write_ts_dict = {} self._last_write_ts_dict_lock = threading.Lock() - def __repr__(self): + def __repr__(self) -> str: return self._last_write_ts_dict.__repr__() - def update(self, collection, ts): + def update(self, collection: int, ts: int): # use lru later if necessary with self._last_write_ts_dict_lock: if ts > self._last_write_ts_dict.get(collection, 0): self._last_write_ts_dict[collection] = ts - def get(self, collection): + def get(self, collection: int): return self._last_write_ts_dict.get(collection, 0) @@ -35,31 +37,31 @@ def _get_gts_dict(): # Update the last write ts of collection. -def update_collection_ts(collection, ts): +def update_collection_ts(collection: int, ts: int): _get_gts_dict().update(collection, ts) # Return a callback corresponding to the collection. -def update_ts_on_mutation(collection): - def _update(mutation_result): +def update_ts_on_mutation(collection: int): + def _update(mutation_result: Any): update_collection_ts(collection, mutation_result.timestamp) return _update # Get the last write ts of collection. -def get_collection_ts(collection): +def get_collection_ts(collection: int): return _get_gts_dict().get(collection) # Get the last write timestamp of collection. -def get_collection_timestamp(collection): +def get_collection_timestamp(collection: int): ts = _get_gts_dict().get(collection) return hybridts_to_unixtime(ts) # Get the last write datetime of collection. -def get_collection_datetime(collection, tz=None): +def get_collection_datetime(collection: int, tz: Optional[datetime.timezone] = None): timestamp = get_collection_timestamp(collection) return datetime.datetime.fromtimestamp(timestamp, tz=tz) @@ -72,7 +74,7 @@ def get_bounded_ts(): return BOUNDED_TS -def construct_guarantee_ts(collection_name, kwargs): +def construct_guarantee_ts(collection_name: str, kwargs: Dict): consistency_level = kwargs.get("consistency_level", None) use_default = consistency_level is None if use_default: diff --git a/pymilvus/client/types.py b/pymilvus/client/types.py index ef74a42af..49dd6c2c7 100644 --- a/pymilvus/client/types.py +++ b/pymilvus/client/types.py @@ -1,16 +1,19 @@ import time from enum import IntEnum -from ..grpc_gen import common_pb2 -from ..exceptions import ( +from typing import Any, ClassVar, Dict, List, TypeVar, Union + +from pymilvus.exceptions import ( AutoIDException, ExceptionsMessage, InvalidConsistencyLevel, ) -from ..grpc_gen import milvus_pb2 as milvus_types - +from pymilvus.grpc_gen import common_pb2 +from pymilvus.grpc_gen import milvus_pb2 as milvus_types +Status = TypeVar("Status") ConsistencyLevel = common_pb2.ConsistencyLevel + class Status: """ :attribute code: int (optional) default as ok @@ -46,24 +49,21 @@ class Status: INDEX_NOT_EXIST = 25 EMPTY_COLLECTION = 26 - def __init__(self, code=SUCCESS, message="Success"): + def __init__(self, code: int = SUCCESS, message: str = "Success") -> None: self.code = code self.message = message - def __repr__(self): - attr_list = [f'{key}={value}' for key, value in self.__dict__.items()] + def __repr__(self) -> str: + attr_list = [f"{key}={value}" for key, value in self.__dict__.items()] return f"{self.__class__.__name__}({', '.join(attr_list)})" - def __eq__(self, other): - """ Make Status comparable with self by code """ + def __eq__(self, other: Union[int, Status]): + """Make Status comparable with self by code""" if isinstance(other, int): return self.code == other return isinstance(other, self.__class__) and self.code == other.code - def __ne__(self, other): - return self != other - def OK(self): return self.code == Status.SUCCESS @@ -113,10 +113,10 @@ class IndexType(IntEnum): IVF_FLAT = IVFLAT IVF_SQ8_H = IVF_SQ8H - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__}: {self._name_}>" - def __str__(self): + def __str__(self) -> str: return self._name_ @@ -132,10 +132,10 @@ class MetricType(IntEnum): SUBSTRUCTURE = 6 SUPERSTRUCTURE = 7 - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__}: {self._name_}>" - def __str__(self): + def __str__(self) -> str: return self._name_ @@ -174,19 +174,19 @@ def new(s: int): return State.Completed return State.UndefiedState - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__}: {self._name_}>" - def __str__(self): + def __str__(self) -> str: return self._name_ class LoadState(IntEnum): """ - NotExist: collection or partition isn't existed - NotLoad: collection or partition isn't loaded - Loading: collection or partition is loading - Loaded: collection or partition is loaded + NotExist: collection or partition isn't existed + NotLoad: collection or partition isn't loaded + Loading: collection or partition is loading + Loaded: collection or partition is loaded """ NotExist = 0 @@ -194,27 +194,35 @@ class LoadState(IntEnum): Loading = 2 Loaded = 3 - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__}: {self._name_}>" - def __str__(self): + def __str__(self) -> str: return self._name_ + class CompactionState: """ - in_executing: number of plans in executing - in_timeout: number of plans failed of timeout - completed: number of plans successfully completed + in_executing: number of plans in executing + in_timeout: number of plans failed of timeout + completed: number of plans successfully completed """ - def __init__(self, compaction_id: int, state: State, in_executing: int, in_timeout: int, completed: int): + def __init__( + self, + compaction_id: int, + state: State, + in_executing: int, + in_timeout: int, + completed: int, + ) -> None: self.compaction_id = compaction_id self.state = state self.in_executing = in_executing self.in_timeout = in_timeout self.completed = completed - def __repr__(self): + def __repr__(self) -> str: return f""" CompactionState - compaction id: {self.compaction_id} @@ -226,11 +234,11 @@ def __repr__(self): class Plan: - def __init__(self, sources: list, target: int): + def __init__(self, sources: list, target: int) -> None: self.sources = sources self.target = target - def __repr__(self): + def __repr__(self) -> str: return f""" Plan: - sources: {self.sources} @@ -239,12 +247,12 @@ def __repr__(self): class CompactionPlans: - def __init__(self, compaction_id: int, state: int): + def __init__(self, compaction_id: int, state: int) -> None: self.compaction_id = compaction_id self.state = State.new(state) self.plans = [] - def __repr__(self): + def __repr__(self) -> str: return f""" Compaction Plans: - compaction id: {self.compaction_id} @@ -253,7 +261,7 @@ def __repr__(self): """ -def cmp_consistency_level(l1, l2): +def cmp_consistency_level(l1: Union[str, int], l2: Union[str, int]): if isinstance(l1, str): try: l1 = ConsistencyLevel.Value(l1) @@ -266,18 +274,16 @@ def cmp_consistency_level(l1, l2): except ValueError: return False - if isinstance(l1, int): - if l1 not in ConsistencyLevel.values(): - return False + if isinstance(l1, int) and l1 not in ConsistencyLevel.values(): + return False - if isinstance(l2, int): - if l2 not in ConsistencyLevel.values(): - return False + if isinstance(l2, int) and l2 not in ConsistencyLevel.values(): + return False return l1 == l2 -def get_consistency_level(consistency_level): +def get_consistency_level(consistency_level: Union[str, int]): if isinstance(consistency_level, int): if consistency_level in ConsistencyLevel.values(): return consistency_level @@ -286,18 +292,23 @@ def get_consistency_level(consistency_level): try: return ConsistencyLevel.Value(consistency_level) except ValueError as e: - raise InvalidConsistencyLevel(message=f"invalid consistency level: {consistency_level}") from e + raise InvalidConsistencyLevel( + message=f"invalid consistency level: {consistency_level}" + ) from e raise InvalidConsistencyLevel(message="invalid consistency level") class Shard: - def __init__(self, channel_name: str, shard_nodes: list, shard_leader: int): + def __init__(self, channel_name: str, shard_nodes: list, shard_leader: int) -> None: self._channel_name = channel_name self._shard_nodes = set(shard_nodes) self._shard_leader = shard_leader - def __repr__(self): - return f"""Shard: , , """ + def __repr__(self) -> str: + return ( + f"Shard: , " + f", " + ) @property def channel_name(self) -> str: @@ -313,7 +324,14 @@ def shard_leader(self) -> int: class Group: - def __init__(self, group_id: int, shards: list, group_nodes: list, resource_group: str, num_outbound_node: dict): + def __init__( + self, + group_id: int, + shards: List[str], + group_nodes: List[tuple], + resource_group: str, + num_outbound_node: dict, + ) -> None: self._id = group_id self._shards = shards self._group_nodes = tuple(group_nodes) @@ -321,8 +339,11 @@ def __init__(self, group_id: int, shards: list, group_nodes: list, resource_grou self._num_outbound_node = num_outbound_node def __repr__(self) -> str: - s = f"Group: , , , , " - return s + return ( + f"Group: , , " + f", , " + f"" + ) @property def id(self): @@ -344,16 +365,24 @@ def resource_group(self): def num_outbound_node(self): return self._num_outbound_node + class Replica: """ Replica groups: - Group: , , - , , , ]> + , + , + , + ]> - Group: , , - , , , ]> + , + , + , + ]> """ - def __init__(self, groups: list): + def __init__(self, groups: list) -> None: self._groups = groups def __repr__(self) -> str: @@ -369,6 +398,7 @@ def groups(self): class BulkInsertState: """enum states of bulk insert task""" + ImportPending = 0 ImportFailed = 1 ImportStarted = 2 @@ -386,15 +416,18 @@ class BulkInsertState: """ Bulk insert state example: - - taskID : 44353845454358, - - state : "BulkLoadPersisted", - - row_count : 1000, - - infos : {"files": "rows.json", "collection": "c1", "partition": "", "failed_reason": ""}, - - id_list : [44353845455401, 44353845456401] - - create_ts : 1661398759, + - taskID : 44353845454358, + - state : "BulkLoadPersisted", + - row_count : 1000, + - infos : {"files": "rows.json", + "collection": "c1", + "partition": "", + "failed_reason": ""}, + - id_list : [44353845455401, 44353845456401] + - create_ts : 1661398759, """ - state_2_state = { + state_2_state: ClassVar[Dict] = { common_pb2.ImportPending: ImportPending, common_pb2.ImportFailed: ImportFailed, common_pb2.ImportStarted: ImportStarted, @@ -403,7 +436,7 @@ class BulkInsertState: common_pb2.ImportFailedAndCleaned: ImportFailedAndCleaned, } - state_2_name = { + state_2_name: ClassVar[Dict] = { ImportPending: "Pending", ImportFailed: "Failed", ImportStarted: "Started", @@ -413,7 +446,15 @@ class BulkInsertState: ImportUnknownState: "Unknown", } - def __init__(self, task_id, state, row_count: int, id_ranges: list, infos, create_ts: int): + def __init__( + self, + task_id: int, + state: State, + row_count: int, + id_ranges: list, + infos: Dict, + create_ts: int, + ): self._task_id = task_id self._state = state self._row_count = row_count @@ -431,8 +472,14 @@ def __repr__(self) -> str: - id_ranges : {}, - create_ts : {} >""" - return fmt.format(self._task_id, self.state_name, self.row_count, self.infos, - self.id_ranges, self.create_time_str) + return fmt.format( + self._task_id, + self.state_name, + self.row_count, + self.infos, + self.id_ranges, + self.create_time_str, + ) @property def task_id(self): @@ -534,7 +581,7 @@ def progress(self): class GrantItem: - def __init__(self, entity): + def __init__(self, entity: Any) -> None: self._object = entity.object.name self._object_name = entity.object_name self._db_name = entity.db_name @@ -543,11 +590,12 @@ def __init__(self, entity): self._privilege = entity.grantor.privilege.name def __repr__(self) -> str: - s = f"GrantItem: , , " \ - f", " \ - f", , " \ + return ( + f"GrantItem: , , " + f", " + f", , " f"" - return s + ) @property def object(self): @@ -577,11 +625,13 @@ def privilege(self): class GrantInfo: """ GrantInfo groups: - - GrantItem: , , , , - - GrantItem: , , , , + - GrantItem: , , , + , + - GrantItem: , , , + , """ - def __init__(self, entities): + def __init__(self, entities: List[milvus_types.RoleEntity]) -> None: groups = [] for entity in entities: if isinstance(entity, milvus_types.GrantEntity): @@ -601,7 +651,7 @@ def groups(self): class UserItem: - def __init__(self, username, entities): + def __init__(self, username: str, entities: List[milvus_types.RoleEntity]) -> None: self._username = username roles = [] for entity in entities: @@ -610,8 +660,7 @@ def __init__(self, username, entities): self._roles = tuple(roles) def __repr__(self) -> str: - s = f"UserItem: , " - return s + return f"UserItem: , " @property def username(self): @@ -628,7 +677,7 @@ class UserInfo: - UserItem: , """ - def __init__(self, results): + def __init__(self, results: List[milvus_types.UserResult]): groups = [] for result in results: if isinstance(result, milvus_types.UserResult): @@ -648,7 +697,7 @@ def groups(self): class RoleItem: - def __init__(self, role_name, entities): + def __init__(self, role_name: str, entities: List[milvus_types.UserEntity]): self._role_name = role_name users = [] for entity in entities: @@ -657,8 +706,7 @@ def __init__(self, role_name, entities): self._users = tuple(users) def __repr__(self) -> str: - s = f"RoleItem: , " - return s + return f"RoleItem: , " @property def role_name(self): @@ -675,7 +723,7 @@ class RoleInfo: - UserItem: , """ - def __init__(self, results): + def __init__(self, results: List[milvus_types.RoleResult]) -> None: groups = [] for result in results: if isinstance(result, milvus_types.RoleResult): @@ -695,7 +743,7 @@ def groups(self): class ResourceGroupInfo: - def __init__(self, resource_group): + def __init__(self, resource_group: Any) -> None: self._name = resource_group.name self._capacity = resource_group.capacity self._num_available_node = resource_group.num_available_node @@ -704,41 +752,34 @@ def __init__(self, resource_group): self._num_incoming_node = resource_group.num_incoming_node def __repr__(self) -> str: - s = f"""ResourceGroupInfo: + return f"""ResourceGroupInfo: , , , , , """ - return s - @property def name(self): return self._name - @property def capacity(self): return self._capacity - @property def num_available_node(self): return self._num_available_node - @property def num_loaded_replica(self): return self._num_loaded_replica - @property def num_outgoing_node(self): return self._num_outgoing_node - @property def num_incoming_node(self): return self._num_incoming_node diff --git a/pymilvus/client/utils.py b/pymilvus/client/utils.py index 399072784..aba526665 100644 --- a/pymilvus/client/utils.py +++ b/pymilvus/client/utils.py @@ -1,8 +1,11 @@ import datetime +from datetime import timedelta +from typing import Any, List, Optional, Union + +from pymilvus.exceptions import MilvusException, ParamError -from .types import DataType from .constants import LOGICAL_BITS, LOGICAL_BITS_MASK -from ..exceptions import ParamError, MilvusException +from .types import DataType MILVUS = "milvus" ZILLIZ = "zilliz" @@ -18,7 +21,7 @@ "BIN_FLAT", "BIN_IVF_FLAT", "DISKANN", - "AUTOINDEX" + "AUTOINDEX", ] valid_index_params_keys = [ @@ -28,12 +31,12 @@ "M", "efConstruction", "PQM", - "n_trees" + "n_trees", ] valid_binary_index_types = [ "BIN_FLAT", - "BIN_IVF_FLAT" + "BIN_IVF_FLAT", ] valid_binary_metric_types = [ @@ -41,21 +44,25 @@ "HAMMING", "TANIMOTO", "SUBSTRUCTURE", - "SUPERSTRUCTURE" + "SUPERSTRUCTURE", ] -def hybridts_to_unixtime(ts): +def hybridts_to_unixtime(ts: int): physical = ts >> LOGICAL_BITS return physical / 1000.0 -def mkts_from_hybridts(hybridts, milliseconds=0., delta=None): +def mkts_from_hybridts( + hybridts: int, + milliseconds: Union[int, float] = 0.0, + delta: Optional[timedelta] = None, +) -> int: if not isinstance(milliseconds, (int, float)): raise MilvusException(message="parameter milliseconds should be type of int or float") if isinstance(delta, datetime.timedelta): - milliseconds += (delta.microseconds / 1000.0) + milliseconds += delta.microseconds / 1000.0 elif delta is not None: raise MilvusException(message="parameter delta should be type of datetime.timedelta") @@ -65,11 +72,14 @@ def mkts_from_hybridts(hybridts, milliseconds=0., delta=None): logical = hybridts & LOGICAL_BITS_MASK physical = hybridts >> LOGICAL_BITS - new_ts = int((int((physical + milliseconds)) << LOGICAL_BITS) + logical) - return new_ts + return int((int(physical + milliseconds) << LOGICAL_BITS) + logical) -def mkts_from_unixtime(epoch, milliseconds=0., delta=None): +def mkts_from_unixtime( + epoch: Union[int, float], + milliseconds: Union[int, float] = 0.0, + delta: Optional[timedelta] = None, +) -> int: if not isinstance(epoch, (int, float)): raise MilvusException(message="parameter epoch should be type of int or float") @@ -77,33 +87,37 @@ def mkts_from_unixtime(epoch, milliseconds=0., delta=None): raise MilvusException(message="parameter milliseconds should be type of int or float") if isinstance(delta, datetime.timedelta): - milliseconds += (delta.microseconds / 1000.0) + milliseconds += delta.microseconds / 1000.0 elif delta is not None: raise MilvusException(message="parameter delta should be type of datetime.timedelta") - epoch += (milliseconds / 1000.0) + epoch += milliseconds / 1000.0 int_msecs = int(epoch * 1000 // 1) return int(int_msecs << LOGICAL_BITS) -def mkts_from_datetime(d_time, milliseconds=0., delta=None): +def mkts_from_datetime( + d_time: datetime.datetime, + milliseconds: Union[int, float] = 0.0, + delta: Optional[timedelta] = None, +) -> int: if not isinstance(d_time, datetime.datetime): raise MilvusException(message="parameter d_time should be type of datetime.datetime") return mkts_from_unixtime(d_time.timestamp(), milliseconds=milliseconds, delta=delta) -def check_invalid_binary_vector(entities) -> bool: +def check_invalid_binary_vector(entities: List) -> bool: for entity in entities: - if entity['type'] == DataType.BINARY_VECTOR: - if not isinstance(entity['values'], list) and len(entity['values']) == 0: + if entity["type"] == DataType.BINARY_VECTOR: + if not isinstance(entity["values"], list) and len(entity["values"]) == 0: return False - dim = len(entity['values'][0]) * 8 + dim = len(entity["values"][0]) * 8 if dim == 0: return False - for values in entity['values']: + for values in entity["values"]: if len(values) * 8 != dim: return False if not isinstance(values, bytes): @@ -111,7 +125,7 @@ def check_invalid_binary_vector(entities) -> bool: return True -def len_of(field_data) -> int: +def len_of(field_data: Any) -> int: if field_data.HasField("scalars"): if field_data.scalars.HasField("bool_data"): return len(field_data.scalars.bool_data.data) @@ -144,7 +158,9 @@ def len_of(field_data) -> int: if field_data.vectors.HasField("float_vector"): total_len = len(field_data.vectors.float_vector.data) if total_len % dim != 0: - raise MilvusException(message=f"Invalid vector length: total_len={total_len}, dim={dim}") + raise MilvusException( + message=f"Invalid vector length: total_len={total_len}, dim={dim}" + ) return int(total_len / dim) total_len = len(field_data.vectors.binary_vector) @@ -153,7 +169,7 @@ def len_of(field_data) -> int: raise MilvusException(message="Unknown data type") -def traverse_rows_info(fields_info, entities): +def traverse_rows_info(fields_info: Any, entities: List): location, primary_key_loc, auto_id_loc = {}, None, None for i, field in enumerate(fields_info): @@ -178,29 +194,28 @@ def traverse_rows_info(fields_info, entities): if is_auto_id: if field_name in entity: raise ParamError( - message=f"auto id enabled, {field_name} shouldn't in entities[{j}]") + message=f"auto id enabled, {field_name} shouldn't in entities[{j}]" + ) continue - if is_dynamic: - if field_name in entity: - raise ParamError( - message=f"dynamic field enabled, {field_name} shouldn't in entities[{j}]") + if is_dynamic and field_name in entity: + raise ParamError( + message=f"dynamic field enabled, {field_name} shouldn't in entities[{j}]" + ) value = entity.get(field_name, None) if value is None: - raise ParamError( - message=f"Field {field_name} don't match in entities[{j}]") + raise ParamError(message=f"Field {field_name} don't match in entities[{j}]") if field_type in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]: field_dim = field["params"]["dim"] - if field_type == DataType.FLOAT_VECTOR: - entity_dim = len(value) - else: - entity_dim = len(value) * 8 + entity_dim = len(value) if field_type == DataType.FLOAT_VECTOR else len(value) * 8 if entity_dim != field_dim: - raise ParamError(message=f"Collection field dim is {field_dim}" - f", but entities field dim is {entity_dim}") + raise ParamError( + message=f"Collection field dim is {field_dim}" + f", but entities field dim is {entity_dim}" + ) # though impossible from sdk if primary_key_loc is None: @@ -209,7 +224,7 @@ def traverse_rows_info(fields_info, entities): return location, primary_key_loc, auto_id_loc -def traverse_info(fields_info, entities): +def traverse_info(fields_info: Any, entities: List): location, primary_key_loc, auto_id_loc = {}, None, None for i, field in enumerate(fields_info): if field.get("is_primary", False): @@ -228,41 +243,50 @@ def traverse_info(fields_info, entities): if field_name == entity_name: if field_type != entity_type: - raise ParamError(message=f"Collection field type is {field_type}" - f", but entities field type is {entity_type}") + raise ParamError( + message=f"Collection field type is {field_type}" + f", but entities field type is {entity_type}" + ) entity_dim, field_dim = 0, 0 if entity_type in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]: field_dim = field["params"]["dim"] entity_dim = len(entity["values"][0]) - if entity_type in [DataType.FLOAT_VECTOR, ] and entity_dim != field_dim: - raise ParamError(message=f"Collection field dim is {field_dim}" - f", but entities field dim is {entity_dim}") + if entity_type in [DataType.FLOAT_VECTOR] and entity_dim != field_dim: + raise ParamError( + message=f"Collection field dim is {field_dim}" + f", but entities field dim is {entity_dim}" + ) - if entity_type in [DataType.BINARY_VECTOR, ] and entity_dim * 8 != field_dim: - raise ParamError(message=f"Collection field dim is {field_dim}" - f", but entities field dim is {entity_dim * 8}") + if entity_type in [DataType.BINARY_VECTOR] and entity_dim * 8 != field_dim: + raise ParamError( + message=f"Collection field dim is {field_dim}" + f", but entities field dim is {entity_dim * 8}" + ) location[field["name"]] = i match_flag = True break if not match_flag: - raise ParamError( - message=f"Field {field['name']} don't match in entities") + raise ParamError(message=f"Field {field['name']} don't match in entities") return location, primary_key_loc, auto_id_loc -def get_server_type(host): +def get_server_type(host: str): if host is None or not isinstance(host, str): return MILVUS - splits = host.split('.') + splits = host.split(".") len_of_splits = len(splits) - if len_of_splits >= 2 and \ - (splits[len_of_splits - 2].lower() == "zilliz" or - splits[len_of_splits - 2].lower() == "zillizcloud") and \ - splits[len_of_splits - 1].lower() == "com": + if ( + len_of_splits >= 2 + and ( + splits[len_of_splits - 2].lower() == "zilliz" + or splits[len_of_splits - 2].lower() == "zillizcloud" + ) + and splits[len_of_splits - 1].lower() == "com" + ): return ZILLIZ return MILVUS diff --git a/pymilvus/decorators.py b/pymilvus/decorators.py index 5aa2f181b..f588127e1 100644 --- a/pymilvus/decorators.py +++ b/pymilvus/decorators.py @@ -2,6 +2,7 @@ import functools import logging import time +from typing import Any, Callable, Optional import grpc @@ -12,23 +13,34 @@ WARNING_COLOR = "\033[93m{}\033[0m" -def deprecated(func): +def deprecated(func: Any): @functools.wraps(func) def inner(*args, **kwargs): - dup_msg = "[WARNING] PyMilvus: class Milvus will be deprecated soon, please use Collection/utility instead" - LOGGER.warning(WARNING_COLOR.format(dup_msg)) + LOGGER.warning( + WARNING_COLOR.format( + "[WARNING] PyMilvus: ", + "class Milvus will be deprecated soon, please use Collection/utility instead", + ) + ) return func(*args, **kwargs) + return inner -def retry_on_rpc_failure(retry_times=10, initial_back_off=0.01, max_back_off=60, back_off_multiplier=3): - # the default 7 retry_times will cost about 26s - def wrapper(func): +def retry_on_rpc_failure( + *, + retry_times: int = 10, + initial_back_off: float = 0.01, + max_back_off: float = 60, + back_off_multiplier: int = 3, +): + def wrapper(func: Any): @functools.wraps(func) @error_handler(func_name=func.__name__) @tracing_request() - def handler(self, *args, **kwargs): - # This has to make sure every timeout parameter is passing throught kwargs form as `timeout=10` + def handler(*args, **kwargs): + # This has to make sure every timeout parameter is passing + # throught kwargs form as `timeout=10` _timeout = kwargs.get("timeout", None) _retry_on_rate_limit = kwargs.get("retry_on_rate_limit", True) @@ -37,54 +49,62 @@ def handler(self, *args, **kwargs): back_off = initial_back_off start_time = time.time() - def timeout(start_time) -> bool: - """ If timeout is valid, use timeout as the retry limits, - If timeout is None, use retry_times as the retry limits. + def timeout(start_time: Optional[float] = None) -> bool: + """If timeout is valid, use timeout as the retry limits, + If timeout is None, use retry_times as the retry limits. """ if retry_timeout is not None: return time.time() - start_time >= retry_timeout return counter > retry_times + to_msg = ( + f"Retry timeout: {retry_timeout}s" + if retry_timeout is not None + else f"Retry run out of {retry_times} retry times" + ) + while True: try: - return func(self, *args, **kwargs) + return func(*args, **kwargs) except grpc.RpcError as e: # Reference: https://grpc.github.io/grpc/python/grpc.html#grpc-status-code if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED: raise MilvusException(message=str(e)) from e if timeout(start_time): - timeout_msg = f"Retry timeout: {retry_timeout}s" if retry_timeout is not None \ - else f"Retry run out of {retry_times} retry times" - raise MilvusException(e.code, f"{timeout_msg}, message={e.details()}") from e + raise MilvusException(e.code, f"{to_msg}, message={e.details()}") from e if counter > 3: - retry_msg = f"[{func.__name__}] retry:{counter}, cost: {back_off:.2f}s, reason: <{e.__class__.__name__}: {e.code()}, {e.details()}>" + retry_msg = ( + f"[{func.__name__}] retry:{counter}, cost: {back_off:.2f}s, " + f"reason: <{e.__class__.__name__}: {e.code()}, {e.details()}>" + ) LOGGER.warning(WARNING_COLOR.format(retry_msg)) time.sleep(back_off) back_off = min(back_off * back_off_multiplier, max_back_off) except MilvusException as e: if timeout(start_time): - timeout_msg = f"Retry timeout: {retry_timeout}s" if retry_timeout is not None \ - else f"Retry run out of {retry_times} retry times" - LOGGER.warning(WARNING_COLOR.format(timeout_msg)) - raise MilvusException(e.code, f"{timeout_msg}, message={e.message}") from e + LOGGER.warning(WARNING_COLOR.format(to_msg)) + raise MilvusException( + code=e.code, message=f"{to_msg}, message={e.message}" + ) from e if _retry_on_rate_limit and e.code == common_pb2.RateLimit: time.sleep(back_off) back_off = min(back_off * back_off_multiplier, max_back_off) else: - raise e + raise e from e except Exception as e: - raise e + raise e from e finally: counter += 1 return handler + return wrapper -def error_handler(func_name=""): - def wrapper(func): +def error_handler(func_name: str = ""): + def wrapper(func: Callable): @functools.wraps(func) def handler(*args, **kwargs): inner_name = func_name @@ -97,41 +117,50 @@ def handler(*args, **kwargs): except MilvusException as e: record_dict["RPC error"] = str(datetime.datetime.now()) LOGGER.error(f"RPC error: [{inner_name}], {e}, ") - raise e + raise e from e except grpc.FutureTimeoutError as e: record_dict["gRPC timeout"] = str(datetime.datetime.now()) - LOGGER.error(f"grpc Timeout: [{inner_name}], <{e.__class__.__name__}: {e.code()}, {e.details()}>, ") - raise e + LOGGER.error( + f"grpc Timeout: [{inner_name}], <{e.__class__.__name__}: " + f"{e.code()}, {e.details()}>, " + ) + raise e from e except grpc.RpcError as e: record_dict["gRPC error"] = str(datetime.datetime.now()) - LOGGER.error(f"grpc RpcError: [{inner_name}], <{e.__class__.__name__}: {e.code()}, {e.details()}>, ") - raise e + LOGGER.error( + f"grpc RpcError: [{inner_name}], <{e.__class__.__name__}: " + f"{e.code()}, {e.details()}>, " + ) + raise e from e except Exception as e: record_dict["Exception"] = str(datetime.datetime.now()) LOGGER.error(f"Unexpected error: [{inner_name}], {e}, ") - raise MilvusException(message=f"Unexpected error, message=<{str(e)}>") from e + raise MilvusException(message=f"Unexpected error, message=<{e!s}>") from e + return handler + return wrapper def tracing_request(): - def wrapper(func): + def wrapper(func: Callable): @functools.wraps(func) - def handler(self, *args, **kwargs): + def handler(self: Callable, *args, **kwargs): level = kwargs.get("log_level", None) req_id = kwargs.get("client_request_id", None) if level: self.set_onetime_loglevel(level) if req_id: self.set_onetime_request_id(req_id) - ret = func(self, *args, **kwargs) - return ret + return func(self, *args, **kwargs) + return handler + return wrapper -def ignore_unimplemented(default_return_value): - def wrapper(func): +def ignore_unimplemented(default_return_value: Any): + def wrapper(func: Callable): @functools.wraps(func) def handler(*args, **kwargs): try: @@ -140,24 +169,29 @@ def handler(*args, **kwargs): if e.code() == grpc.StatusCode.UNIMPLEMENTED: LOGGER.warning(f"{func.__name__} unimplemented, ignore it") return default_return_value - raise e + raise e from e except Exception as e: - raise e + raise e from e + return handler + return wrapper -def upgrade_reminder(func): +def upgrade_reminder(func: Callable): @functools.wraps(func) def handler(*args, **kwargs): try: return func(*args, **kwargs) except grpc.RpcError as e: if e.code() == grpc.StatusCode.UNIMPLEMENTED: - msg = "this version of sdk is incompatible with server, please downgrade your sdk or upgrade your " \ - "server " + msg = ( + "this version of sdk is incompatible with server, " + "please downgrade your sdk or upgrade your server" + ) raise MilvusException(message=msg) from e - raise e + raise e from e except Exception as e: - raise e + raise e from e + return handler diff --git a/pymilvus/exceptions.py b/pymilvus/exceptions.py index e1f3ee794..dd83d0666 100644 --- a/pymilvus/exceptions.py +++ b/pymilvus/exceptions.py @@ -19,7 +19,7 @@ class ErrorCode(IntEnum): class MilvusException(Exception): - def __init__(self, code: int = ErrorCode.UNEXPECTED_ERROR, message: str = ""): + def __init__(self, code: int = ErrorCode.UNEXPECTED_ERROR, message: str = "") -> None: super().__init__() self._code = code self._message = message @@ -32,40 +32,40 @@ def code(self): def message(self): return self._message - def __str__(self): + def __str__(self) -> str: return f"<{type(self).__name__}: (code={self.code}, message={self.message})>" class ParamError(MilvusException): - """ Raise when params are incorrect """ + """Raise when params are incorrect""" class ConnectError(MilvusException): - """ Connect server fail """ + """Connect server fail""" class MilvusUnavailableException(MilvusException): - """ Raise when server's Unavaliable""" + """Raise when server's Unavaliable""" class CollectionNotExistException(MilvusException): - """ Raise when collections doesn't exist """ + """Raise when collections doesn't exist""" class DescribeCollectionException(MilvusException): - """ Raise when fail to describe collection """ + """Raise when fail to describe collection""" class PartitionNotExistException(MilvusException): - """ Raise when partition doesn't exist """ + """Raise when partition doesn't exist""" class PartitionAlreadyExistException(MilvusException): - """ Raise when create an exsiting partition """ + """Raise when create an exsiting partition""" class IndexNotExistException(MilvusException): - """ Raise when index doesn't exist """ + """Raise when index doesn't exist""" class AmbiguousIndexName(MilvusException): @@ -73,70 +73,73 @@ class AmbiguousIndexName(MilvusException): class CannotInferSchemaException(MilvusException): - """ Raise when cannot trasfer dataframe to schema """ + """Raise when cannot trasfer dataframe to schema""" class SchemaNotReadyException(MilvusException): - """ Raise when schema is wrong """ + """Raise when schema is wrong""" class DataTypeNotMatchException(MilvusException): - """ Raise when datatype dosen't match """ + """Raise when datatype dosen't match""" class DataTypeNotSupportException(MilvusException): - """ Raise when datatype isn't supported """ + """Raise when datatype isn't supported""" class DataNotMatchException(MilvusException): - """ Raise when insert data isn't match with schema """ + """Raise when insert data isn't match with schema""" class ConnectionNotExistException(MilvusException): - """ Raise when connections doesn't exist """ + """Raise when connections doesn't exist""" class ConnectionConfigException(MilvusException): - """ Raise when configs of connection are invalid """ + """Raise when configs of connection are invalid""" class PrimaryKeyException(MilvusException): - """ Raise when primarykey are invalid """ + """Raise when primarykey are invalid""" class PartitionKeyException(MilvusException): - """ Raise when partitionkey are invalid """ + """Raise when partitionkey are invalid""" class DefaultValueException(MilvusException): - """ Raise when DefaultValue are invalid """ + """Raise when DefaultValue are invalid""" class FieldsTypeException(MilvusException): - """ Raise when fields is invalid """ + """Raise when fields is invalid""" class FieldTypeException(MilvusException): - """ Raise when one field is invalid """ + """Raise when one field is invalid""" class AutoIDException(MilvusException): - """ Raise when autoID is invalid """ + """Raise when autoID is invalid""" class InvalidConsistencyLevel(MilvusException): - """ Raise when consistency level is invalid """ + """Raise when consistency level is invalid""" class UpsertAutoIDTrueException(MilvusException): - """ Raise when upsert autoID is true """ + """Raise when upsert autoID is true""" class ExceptionsMessage: NoHostPort = "connection configuration must contain 'host' and 'port'." HostType = "Type of 'host' must be str." PortType = "Type of 'port' must be str or int." - ConnDiffConf = "Alias of %r already creating connections, but the configure is not the same as passed in." + ConnDiffConf = ( + "Alias of %r already creating connections, " + "but the configure is not the same as passed in." + ) AliasType = "Alias should be string, but %r is given." ConnLackConf = "You need to pass in the configuration of the connection named %r ." ConnectFirst = "should create connect first." @@ -144,13 +147,20 @@ class ExceptionsMessage: NoSchema = "Should be passed into the schema." EmptySchema = "The field of the schema cannot be empty." SchemaType = "Schema type must be schema.CollectionSchema." - SchemaInconsistent = "The collection already exist, but the schema is not the same as the schema passed in." + SchemaInconsistent = ( + "The collection already exist, but the schema is not the same as the schema passed in." + ) AutoIDWithData = "Auto_id is True, primary field should not have data." AutoIDType = "Param auto_id must be bool type." NumPartitionsType = "Param num_partitions must be int type." - AutoIDInconsistent = "The auto_id of the collection is inconsistent with the auto_id of the primary key field." + AutoIDInconsistent = ( + "The auto_id of the collection is inconsistent " + "with the auto_id of the primary key field." + ) AutoIDIllegalRanges = "The auto-generated id ranges should be pairs." - ConsistencyLevelInconsistent = "The parameter consistency_level is inconsistent with that of existed collection." + ConsistencyLevelInconsistent = ( + "The parameter consistency_level is inconsistent with that of existed collection." + ) AutoIDOnlyOnPK = "The auto_id can only be specified on the primary key field" AutoIDFieldType = "The auto_id can only be specified on field with DataType.INT64" DefaultValueTypeNotSupport = "default_value only support scalars except array and json for now." @@ -187,5 +197,7 @@ class ExceptionsMessage: ExprType = "The type of expr must be string ,but %r is given." EnvConfigErr = "Environment variable %s has a wrong format, please check it: %s" AmbiguousIndexName = "There are multiple indexes, please specify the index_name." - InsertUnexpectedField = "Attempt to insert an unexpected field to collection without enabling dynamic field" + InsertUnexpectedField = ( + "Attempt to insert an unexpected field to collection without enabling dynamic field" + ) UpsertAutoIDTrue = "Upsert don't support autoid == true" diff --git a/pymilvus/grpc_gen/__init__.py b/pymilvus/grpc_gen/__init__.py index ca9e40522..1243d0194 100644 --- a/pymilvus/grpc_gen/__init__.py +++ b/pymilvus/grpc_gen/__init__.py @@ -1,9 +1,9 @@ -# -*- coding: utf-8 -*- +from . import schema_pb2, milvus_pb2, common_pb2, feder_pb2, milvus_pb2_grpc __all__ = [ - 'milvus_pb2', - 'common_pb2', - 'feder_pb2', - 'schema_pb2', - 'milvus_pb2_grpc', + "milvus_pb2", + "common_pb2", + "feder_pb2", + "schema_pb2", + "milvus_pb2_grpc", ] diff --git a/pymilvus/milvus_client/milvus_client.py b/pymilvus/milvus_client/milvus_client.py index ed0e504f8..9ea615e4f 100644 --- a/pymilvus/milvus_client/milvus_client.py +++ b/pymilvus/milvus_client/milvus_client.py @@ -1,15 +1,20 @@ """MilvusClient for dealing with simple workflows.""" import logging -from typing import Dict, List, Union +from typing import Dict, List, Optional, Union from uuid import uuid4 -from pymilvus.exceptions import MilvusException, DataTypeNotMatchException, AutoIDException, PrimaryKeyException +from pymilvus.client.constants import DEFAULT_CONSISTENCY_LEVEL +from pymilvus.client.types import ExceptionsMessage +from pymilvus.exceptions import ( + AutoIDException, + DataTypeNotMatchException, + MilvusException, + PrimaryKeyException, +) from pymilvus.orm import utility from pymilvus.orm.collection import CollectionSchema from pymilvus.orm.connections import connections from pymilvus.orm.types import DataType -from pymilvus.client.types import ExceptionsMessage -from pymilvus.client.constants import DEFAULT_CONSISTENCY_LEVEL logger = logging.getLogger() logger.setLevel(logging.DEBUG) @@ -20,13 +25,16 @@ class MilvusClient: # pylint: disable=logging-too-many-args, too-many-instance-attributes, import-outside-toplevel - def __init__(self, - uri: str = "http://localhost:19530", - user: str = "", - password: str = "", - db_name: str = "", - token: str = "", - timeout: float = None, **kwargs): + def __init__( + self, + uri: str = "http://localhost:19530", + user: str = "", + password: str = "", + db_name: str = "", + token: str = "", + timeout: Optional[float] = None, + **kwargs, + ) -> None: """A client for the common Milvus use case. This client attempts to hide away the complexity of using Pymilvus. In a lot ofcases what @@ -42,10 +50,11 @@ def __init__(self, # Optionial TQDM import try: import tqdm + self.tqdm = tqdm.tqdm except ImportError: logger.debug("tqdm not found") - self.tqdm = (lambda x, disable: x) + self.tqdm = lambda x, disable: x self.uri = uri self.timeout = timeout @@ -55,20 +64,21 @@ def __init__(self, self._using = self._create_connection(uri, user, password, db_name, token, **kwargs) self.is_self_hosted = bool( - utility.get_server_type(using=self._using) == "milvus" + utility.get_server_type(using=self._using) == "milvus", ) - def create_collection(self, - collection_name: str, - dimension: int, - primary_field_name: str = "id", # default is "id" - id_type: str = "int", # or "string", - vector_field_name: str = "vector", # default is "vector" - metric_type: str = "IP", - auto_id=False, - timeout: float = None, - **kwargs): - + def create_collection( + self, + collection_name: str, + dimension: int, + primary_field_name: str = "id", # default is "id" + id_type: str = "int", # or "string", + vector_field_name: str = "vector", # default is "vector" + metric_type: str = "IP", + auto_id: bool = False, + timeout: Optional[float] = None, + **kwargs, + ): if "enable_dynamic_field" not in kwargs: kwargs["enable_dynamic_field"] = True @@ -81,9 +91,8 @@ def create_collection(self, else: raise PrimaryKeyException(message=ExceptionsMessage.PrimaryFieldType) - if pk_data_type == DataType.VARCHAR: - if auto_id: - raise AutoIDException(message=ExceptionsMessage.AutoIDFieldType) + if pk_data_type == DataType.VARCHAR and auto_id: + raise AutoIDException(message=ExceptionsMessage.AutoIDFieldType) pk_args = {} if "max_length" in kwargs and pk_data_type == DataType.VARCHAR: @@ -102,7 +111,7 @@ def create_collection(self, logger.debug("Successfully created collection: %s", collection_name) except Exception as ex: logger.error("Failed to create collection: %s", collection_name) - raise ex + raise ex from ex index_params = { "metric_type": metric_type, "params": {}, @@ -110,7 +119,13 @@ def create_collection(self, self._create_index(collection_name, vector_field_name, index_params, timeout=timeout) self._load(collection_name, timeout=timeout) - def _create_index(self, collection_name, vec_field_name, index_params, timeout=None) -> None: + def _create_index( + self, + collection_name: str, + vec_field_name: str, + index_params: Dict, + timeout: Optional[float] = None, + ) -> None: """Create a index on the collection""" conn = self._get_connection() try: @@ -126,18 +141,19 @@ def _create_index(self, collection_name, vec_field_name, index_params, timeout=N ) except Exception as ex: logger.error( - "Failed to create an index on collection: %s", collection_name + "Failed to create an index on collection: %s", + collection_name, ) - raise ex + raise ex from ex def insert( - self, - collection_name: str, - data: Union[Dict, List[Dict]], - batch_size: int = 0, - progress_bar: bool = False, - timeout=None, - **kwargs, + self, + collection_name: str, + data: Union[Dict, List[Dict]], + batch_size: int = 0, + progress_bar: bool = False, + timeout: Optional[float] = None, + **kwargs, ) -> List[Union[str, int]]: """Insert data into the collection. @@ -170,7 +186,8 @@ def insert( if batch_size < 0: logger.error("Invalid batch size provided for insert.") - raise ValueError("Invalid batch size provided for insert.") + msg = "Invalid batch size provided for insert." + raise ValueError(msg) if batch_size == 0: batch_size = len(data) @@ -179,7 +196,7 @@ def insert( pks = [] for i in self.tqdm(range(0, len(data), batch_size), disable=not progress_bar): # Convert dict to list of lists batch for insertion - insert_batch = data[i:i + batch_size] + insert_batch = data[i : i + batch_size] # Insert into the collection. try: res = conn.insert_rows(collection_name, insert_batch, timeout=timeout) @@ -190,20 +207,21 @@ def insert( str(i), str(len(data)), ) - raise ex + raise ex from ex return pks - def search(self, - collection_name: str, - data: Union[List[list], list], - filter: str = "", - limit: int = 10, - output_fields: List[str] = None, - search_params: dict = None, - timeout: float = None, - **kwargs, - ) -> List[dict]: + def search( + self, + collection_name: str, + data: Union[List[list], list], + filter: str = "", + limit: int = 10, + output_fields: Optional[List[str]] = None, + search_params: Optional[dict] = None, + timeout: Optional[float] = None, + **kwargs, + ) -> List[dict]: """Search for a query vector/vectors. In order for the search to process, a collection needs to have been either provided @@ -242,7 +260,7 @@ def search(self, ) except Exception as ex: logger.error("Failed to search collection: %s", collection_name) - raise ex + raise ex from ex ret = [] for hits in res: @@ -254,12 +272,12 @@ def search(self, return ret def query( - self, - collection_name: str, - filter: str, - output_fields: List[str] = None, - timeout: float = None, - **kwargs + self, + collection_name: str, + filter: str, + output_fields: Optional[List[str]] = None, + timeout: Optional[float] = None, + **kwargs, ) -> List[dict]: """Query for entries in the Collection. @@ -282,12 +300,10 @@ def query( conn = self._get_connection() try: - schema_dict = conn.describe_collection( - collection_name, - timeout=timeout, **kwargs) + schema_dict = conn.describe_collection(collection_name, timeout=timeout, **kwargs) except Exception as ex: logger.error("Failed to describe collection: %s", collection_name) - raise ex + raise ex from ex if not output_fields: output_fields = ["*"] @@ -297,23 +313,21 @@ def query( try: res = conn.query( - collection_name, - expr=filter, - output_fields=output_fields, - timeout=timeout, **kwargs) + collection_name, expr=filter, output_fields=output_fields, timeout=timeout, **kwargs + ) except Exception as ex: logger.error("Failed to query collection: %s", collection_name) - raise ex + raise ex from ex return res def get( - self, - collection_name: str, - ids: Union[list, str, int], - output_fields: List[str] = None, - timeout: float = None, - **kwargs, + self, + collection_name: str, + ids: Union[list, str, int], + output_fields: Optional[List[str]] = None, + timeout: Optional[float] = None, + **kwargs, ) -> List[List[float]]: """Grab the inserted vectors using the primary key from the Collection. @@ -339,12 +353,10 @@ def get( conn = self._get_connection() try: - schema_dict = conn.describe_collection( - collection_name, - timeout=timeout, **kwargs) + schema_dict = conn.describe_collection(collection_name, timeout=timeout, **kwargs) except Exception as ex: logger.error("Failed to describe collection: %s", collection_name) - raise ex + raise ex from ex if not output_fields: output_fields = ["*"] @@ -355,22 +367,20 @@ def get( expr = self._pack_pks_expr(schema_dict, ids) try: res = conn.query( - collection_name, - expr=expr, - output_fields=output_fields, - timeout=timeout, **kwargs) + collection_name, expr=expr, output_fields=output_fields, timeout=timeout, **kwargs + ) except Exception as ex: logger.error("Failed to get collection: %s", collection_name) - raise ex + raise ex from ex return res def delete( - self, - collection_name: str, - pks: Union[list, str, int], - timeout: float = None, - **kwargs + self, + collection_name: str, + pks: Union[list, str, int], + timeout: Optional[float] = None, + **kwargs, ) -> List[Union[str, int]]: """Delete entries in the collection by their pk. @@ -392,27 +402,22 @@ def delete( conn = self._get_connection() try: - schema_dict = conn.describe_collection( - collection_name, - timeout=timeout, **kwargs) + schema_dict = conn.describe_collection(collection_name, timeout=timeout, **kwargs) except Exception as ex: logger.error("Failed to describe collection: %s", collection_name) - raise ex + raise ex from ex expr = self._pack_pks_expr(schema_dict, pks) ret_pks = [] try: - res = conn.delete( - collection_name, - expr, - timeout=timeout, **kwargs) + res = conn.delete(collection_name, expr, timeout=timeout, **kwargs) ret_pks.extend(res.primary_keys) except Exception as ex: logger.error("Failed to delete primary keys in collection: %s", collection_name) - raise ex + raise ex from ex return ret_pks - def num_entities(self, collection_name: str, timeout=None) -> int: + def num_entities(self, collection_name: str, timeout: Optional[float] = None) -> int: """return the number of rows in the collection. Returns: @@ -424,13 +429,14 @@ def num_entities(self, collection_name: str, timeout=None) -> int: result["row_count"] = int(result["row_count"]) return result["row_count"] - def flush(self, collection_name, timeout=None, **kwargs): - """ Seal all segments in the collection. Inserts after flushing will be written into + def flush(self, collection_name: str, timeout: Optional[float] = None, **kwargs): + """Seal all segments in the collection. Inserts after flushing will be written into new segments. Only sealed segments can be indexed. Args: timeout (float): an optional duration of time in seconds to allow for the RPCs. - If timeout is not set, the client keeps waiting until the server responds or an error occurs. + If timeout is not set, the client keeps waiting until the server responds + or an error occurs. """ conn = self._get_connection() conn.flush([collection_name], timeout=timeout, **kwargs) @@ -438,11 +444,10 @@ def flush(self, collection_name, timeout=None, **kwargs): def describe_collection(self, collection_name: str, **kwargs): conn = self._get_connection() try: - schema_dict = conn.describe_collection( - collection_name, **kwargs) + schema_dict = conn.describe_collection(collection_name, **kwargs) except Exception as ex: logger.error("Failed to describe collection: %s", collection_name) - raise ex + raise ex from ex return schema_dict def list_collections(self, **kwargs): @@ -451,7 +456,7 @@ def list_collections(self, **kwargs): collection_names = conn.list_collections(**kwargs) except Exception as ex: logger.error("Failed to list collections") - raise ex + raise ex from ex return collection_names def drop_collection(self, collection_name: str): @@ -465,8 +470,15 @@ def create_schema(cls, **kwargs): return CollectionSchema([], **kwargs) @classmethod - def prepare_index_params(cls, field_name, index_type=None, metric_type=None, index_name="", params=None, - **kwargs): + def prepare_index_params( + cls, + field_name: str, + index_type: Optional[str] = None, + metric_type: Optional[str] = None, + index_name: str = "", + params: Optional[Dict] = None, + **kwargs, + ): index_params = {"field_name": field_name} if index_type is not None: index_params["index_type"] = index_type @@ -481,7 +493,14 @@ def prepare_index_params(cls, field_name, index_type=None, metric_type=None, ind return index_params - def create_collection_with_schema(self, collection_name, schema, index_param, timeout=None, **kwargs): + def create_collection_with_schema( + self, + collection_name: str, + schema: CollectionSchema, + index_param: Dict, + timeout: Optional[float] = None, + **kwargs, + ): schema.verify() if kwargs.get("auto_id", True): schema.auto_id = True @@ -503,7 +522,7 @@ def create_collection_with_schema(self, collection_name, schema, index_param, ti logger.debug("Successfully created collection: %s", collection_name) except Exception as ex: logger.error("Failed to create collection: %s", collection_name) - raise ex + raise ex from ex self._create_index(collection_name, vector_field_name, index_param, timeout=timeout) self._load(collection_name, timeout=timeout) @@ -514,19 +533,28 @@ def close(self): def _get_connection(self): return connections._fetch_handler(self._using) - def _create_connection(self, uri, user="", password="", db_name="", token="", **kwargs) -> str: + def _create_connection( + self, + uri: str, + user: str = "", + password: str = "", + db_name: str = "", + token: str = "", + **kwargs, + ) -> str: """Create the connection to the Milvus server.""" # TODO: Implement reuse with new uri style using = uuid4().hex try: connections.connect(using, user, password, db_name, token, uri=uri, **kwargs) - logger.debug("Created new connection using: %s", using) - return using except Exception as ex: logger.error("Failed to create new connection using: %s", using) - raise ex + raise ex from ex + else: + logger.debug("Created new connection using: %s", using) + return using - def _extract_primary_field(self, schema_dict) -> dict: + def _extract_primary_field(self, schema_dict: Dict) -> dict: fields = schema_dict.get("fields", []) if not fields: return {} @@ -537,7 +565,7 @@ def _extract_primary_field(self, schema_dict) -> dict: return {} - def _get_vector_field_name(self, schema_dict): + def _get_vector_field_name(self, schema_dict: Dict): fields = schema_dict.get("fields", []) if not fields: return {} @@ -547,13 +575,13 @@ def _get_vector_field_name(self, schema_dict): return field_dict.get("name", "") return "" - def _pack_pks_expr(self, schema_dict, pks) -> str: + def _pack_pks_expr(self, schema_dict: Dict, pks: List) -> str: primary_field = self._extract_primary_field(schema_dict) pk_field_name = primary_field["name"] - dataType = primary_field["type"] + data_type = primary_field["type"] # Varchar pks need double quotes around the values - if dataType == DataType.VARCHAR: + if data_type == DataType.VARCHAR: ids = ["'" + str(entry) + "'" for entry in pks] expr = f"""{pk_field_name} in [{','.join(ids)}]""" else: @@ -561,7 +589,7 @@ def _pack_pks_expr(self, schema_dict, pks) -> str: expr = f"{pk_field_name} in [{','.join(ids)}]" return expr - def _load(self, collection_name, timeout=None): + def _load(self, collection_name: str, timeout: Optional[float] = None): """Loads the collection.""" conn = self._get_connection() try: @@ -571,4 +599,4 @@ def _load(self, collection_name, timeout=None): "Failed to load collection: %s", collection_name, ) - raise ex + raise ex from ex diff --git a/pymilvus/orm/collection.py b/pymilvus/orm/collection.py index c973bba11..e462aeaf4 100644 --- a/pymilvus/orm/collection.py +++ b/pymilvus/orm/collection.py @@ -12,44 +12,59 @@ import copy import json -from typing import List, Union, Dict -import pandas +from typing import Dict, List, Optional, Union + +import pandas as pd + +from pymilvus.client.constants import DEFAULT_CONSISTENCY_LEVEL +from pymilvus.client.types import ( + CompactionPlans, + CompactionState, + Replica, + cmp_consistency_level, + get_consistency_level, +) +from pymilvus.exceptions import ( + AutoIDException, + DataTypeNotMatchException, + ExceptionsMessage, + IndexNotExistException, + PartitionAlreadyExistException, + PartitionNotExistException, + SchemaNotReadyException, +) +from pymilvus.settings import Config from .connections import connections +from .future import MutationFuture, SearchFuture +from .index import Index +from .iterator import QueryIterator, SearchIterator +from .mutation import MutationResult +from .partition import Partition +from .prepare import Prepare from .schema import ( CollectionSchema, FieldSchema, - construct_fields_from_dataframe, - check_insert_or_upsert_data_schema, - check_insert_or_upsert_is_row_based, + check_insert_schema, + check_is_row_based, check_schema, + check_upset_schema, + construct_fields_from_dataframe, ) -from .prepare import Prepare -from .partition import Partition -from .index import Index from .search import SearchResult -from .mutation import MutationResult from .types import DataType -from ..exceptions import ( - SchemaNotReadyException, - DataTypeNotMatchException, - PartitionAlreadyExistException, - PartitionNotExistException, - IndexNotExistException, - AutoIDException, - ExceptionsMessage, -) -from .future import SearchFuture, MutationFuture from .utility import _get_connection -from ..settings import Config -from ..client.types import CompactionState, CompactionPlans, Replica, get_consistency_level, cmp_consistency_level -from ..client.constants import DEFAULT_CONSISTENCY_LEVEL -from .iterator import QueryIterator, SearchIterator class Collection: - def __init__(self, name: str, schema: CollectionSchema = None, using: str = "default", **kwargs): - """ Constructs a collection by name, schema and other parameters. + def __init__( + self, + name: str, + schema: Optional[CollectionSchema] = None, + using: str = "default", + **kwargs, + ) -> None: + """Constructs a collection by name, schema and other parameters. Args: name (``str``): the name of collection @@ -58,34 +73,35 @@ def __init__(self, name: str, schema: CollectionSchema = None, using: str = "def **kwargs (``dict``): * *num_shards (``int``, optional): how many shards will the insert data be divided. - * *shards_num (``int``, optional, deprecated): how many shards will the insert data be divided. + * *shards_num (``int``, optional, deprecated): + how many shards will the insert data be divided. * *consistency_level* (``int/ str``) Which consistency level to use when searching in the collection. Options of consistency level: Strong, Bounded, Eventually, Session, Customized. - Note: this parameter can be overwritten by the same parameter specified in search. + Note: can be overwritten by the same parameter specified in search. * *properties* (``dict``, optional) Collection properties. * *timeout* (``float``) An optional duration of time in seconds to allow for the RPCs. - If timeout is not set, the client keeps waiting until the server responds or an error occurs. + If timeout is not set, the client keeps waiting until the server + responds or an error occurs. Raises: SchemaNotReadyException: if the schema is wrong. Examples: - >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType - >>> connections.connect() + >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType >>> fields = [ ... FieldSchema("film_id", DataType.INT64, is_primary=True), ... FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=128) ... ] >>> schema = CollectionSchema(fields=fields) - >>> properties = {"collection.ttl.seconds": 1800} - >>> collection = Collection(name="test_collection_init", schema=schema, properties=properties) + >>> prop = {"collection.ttl.seconds": 1800} + >>> collection = Collection(name="test_collection_init", schema=schema, properties=prop) >>> collection.name 'test_collection_init' """ @@ -101,7 +117,9 @@ def __init__(self, name: str, schema: CollectionSchema = None, using: str = "def s_consistency_level = resp.get("consistency_level", DEFAULT_CONSISTENCY_LEVEL) arg_consistency_level = kwargs.get("consistency_level", s_consistency_level) if not cmp_consistency_level(s_consistency_level, arg_consistency_level): - raise SchemaNotReadyException(message=ExceptionsMessage.ConsistencyLevelInconsistent) + raise SchemaNotReadyException( + message=ExceptionsMessage.ConsistencyLevelInconsistent + ) server_schema = CollectionSchema.construct_from_dict(resp) self._consistency_level = s_consistency_level if schema is None: @@ -115,11 +133,15 @@ def __init__(self, name: str, schema: CollectionSchema = None, using: str = "def else: if schema is None: - raise SchemaNotReadyException(message=ExceptionsMessage.CollectionNotExistNoSchema % name) + raise SchemaNotReadyException( + message=ExceptionsMessage.CollectionNotExistNoSchema % name + ) if isinstance(schema, CollectionSchema): schema.verify() check_schema(schema) - consistency_level = get_consistency_level(kwargs.get("consistency_level", DEFAULT_CONSISTENCY_LEVEL)) + consistency_level = get_consistency_level( + kwargs.get("consistency_level", DEFAULT_CONSISTENCY_LEVEL) + ) conn.create_collection(self._name, schema, **kwargs) self._schema = schema @@ -130,11 +152,11 @@ def __init__(self, name: str, schema: CollectionSchema = None, using: str = "def self._schema_dict = self._schema.to_dict() self._schema_dict["consistency_level"] = self._consistency_level - def __repr__(self): + def __repr__(self) -> str: _dict = { - 'name': self.name, - 'description': self.description, - 'schema': self._schema, + "name": self.name, + "description": self.description, + "schema": self._schema, } r = [":\n-------------\n"] s = "<{}>: {}\n" @@ -146,10 +168,8 @@ def _get_connection(self): return connections._fetch_handler(self._using) @classmethod - def construct_from_dataframe(cls, name, dataframe, **kwargs): - if dataframe is None: - raise SchemaNotReadyException(message=ExceptionsMessage.NoneDataFrame) - if not isinstance(dataframe, pandas.DataFrame): + def construct_from_dataframe(cls, name: str, dataframe: pd.DataFrame, **kwargs): + if not isinstance(dataframe, pd.DataFrame): raise SchemaNotReadyException(message=ExceptionsMessage.DataFrameType) primary_field = kwargs.pop("primary_field", None) if primary_field is None: @@ -160,9 +180,8 @@ def construct_from_dataframe(cls, name, dataframe, **kwargs): pk_index = i if pk_index == -1: raise SchemaNotReadyException(message=ExceptionsMessage.PrimaryKeyNotExist) - if "auto_id" in kwargs: - if not isinstance(kwargs.get("auto_id", None), bool): - raise AutoIDException(message=ExceptionsMessage.AutoIDType) + if "auto_id" in kwargs and not isinstance(kwargs.get("auto_id", None), bool): + raise AutoIDException(message=ExceptionsMessage.AutoIDType) auto_id = kwargs.pop("auto_id", False) if auto_id: if dataframe[primary_field].isnull().all(): @@ -179,10 +198,16 @@ def construct_from_dataframe(cls, name, dataframe, **kwargs): else: fields_schema = construct_fields_from_dataframe(dataframe) if auto_id: - fields_schema.insert(pk_index, - FieldSchema(name=primary_field, dtype=DataType.INT64, is_primary=True, - auto_id=True, - **kwargs)) + fields_schema.insert( + pk_index, + FieldSchema( + name=primary_field, + dtype=DataType.INT64, + is_primary=True, + auto_id=True, + **kwargs, + ), + ) for field in fields_schema: if auto_id is False and field.name == primary_field: @@ -199,25 +224,24 @@ def construct_from_dataframe(cls, name, dataframe, **kwargs): @property def schema(self) -> CollectionSchema: - """CollectionSchema: schema of the collection. """ + """CollectionSchema: schema of the collection.""" return self._schema @property def aliases(self, **kwargs) -> list: - """List[str]: all the aliases of the collection. """ + """List[str]: all the aliases of the collection.""" conn = self._get_connection() resp = conn.describe_collection(self._name, **kwargs) - aliases = resp["aliases"] - return aliases + return resp["aliases"] @property def description(self) -> str: - """str: a text description of the collection. """ + """str: a text description of the collection.""" return self._schema.description @property def name(self) -> str: - """str: the name of the collection. """ + """str: the name of the collection.""" return self._name @property @@ -229,7 +253,7 @@ def is_empty(self) -> bool: def num_shards(self, **kwargs) -> int: """int: number of shards used by the collection.""" if self._num_shards is None: - self._num_shards = self.describe().get("num_shards") + self._num_shards = self.describe(timeout=kwargs.get("timeout")).get("num_shards") return self._num_shards @property @@ -237,8 +261,7 @@ def num_entities(self, **kwargs) -> int: """int: The number of entities in the collection, not real time. Examples: - >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType - >>> connections.connect() + >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType >>> schema = CollectionSchema([ ... FieldSchema("film_id", DataType.INT64, is_primary=True), ... FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2) @@ -264,17 +287,17 @@ def primary_field(self) -> FieldSchema: """FieldSchema: the primary field of the collection.""" return self._schema.primary_field - def flush(self, timeout=None, **kwargs): - """ Seal all segments in the collection. Inserts after flushing will be written into + def flush(self, timeout: Optional[float] = None, **kwargs): + """Seal all segments in the collection. Inserts after flushing will be written into new segments. Only sealed segments can be indexed. Args: timeout (float): an optional duration of time in seconds to allow for the RPCs. - If timeout is not set, the client keeps waiting until the server responds or an error occurs. + If timeout is not set, the client keeps waiting until the server + responds or an error occurs. Examples: - >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType - >>> connections.connect() + >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType >>> fields = [ ... FieldSchema("film_id", DataType.INT64, is_primary=True), ... FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=128) @@ -289,16 +312,16 @@ def flush(self, timeout=None, **kwargs): conn = self._get_connection() conn.flush([self.name], timeout=timeout, **kwargs) - def drop(self, timeout=None, **kwargs): - """ Drops the collection. The same as `utility.drop_collection()` + def drop(self, timeout: Optional[float] = None, **kwargs): + """Drops the collection. The same as `utility.drop_collection()` Args: - timeout (float, optional): an optional duration of time in seconds to allow for the RPCs. - If timeout is not set, the client keeps waiting until the server responds or an error occurs. + timeout (float, optional): an optional duration of time in seconds to allow + for the RPCs. If timeout is not set, the client keeps waiting until the + server responds or an error occurs. Examples: - >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType - >>> connections.connect() + >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType >>> schema = CollectionSchema([ ... FieldSchema("film_id", DataType.INT64, is_primary=True), ... FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2) @@ -313,18 +336,18 @@ def drop(self, timeout=None, **kwargs): conn = self._get_connection() conn.drop_collection(self._name, timeout=timeout, **kwargs) - def set_properties(self, properties, timeout=None, **kwargs): - """ Set properties for the collection + def set_properties(self, properties: dict, timeout: Optional[float] = None, **kwargs): + """Set properties for the collection Args: properties (``dict``): collection properties. only support collection TTL with key `collection.ttl.seconds` - timeout (``float``, optional): an optional duration of time in seconds to allow for the RPCs. - If timeout is not set, the client keeps waiting until the server responds or an error occurs. + timeout (float, optional): an optional duration of time in seconds to allow + for the RPCs. If timeout is not set, the client keeps waiting until the + server responds or an error occurs. Examples: - >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType - >>> connections.connect() + >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType >>> fields = [ ... FieldSchema("film_id", DataType.INT64, is_primary=True), ... FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=128) @@ -334,23 +357,35 @@ def set_properties(self, properties, timeout=None, **kwargs): >>> collection.set_properties({"collection.ttl.seconds": 60}) """ conn = self._get_connection() - conn.alter_collection(self.name, properties=properties, timeout=timeout) - - def load(self, partition_names=None, replica_number=1, timeout=None, **kwargs): - """ Load the data into memory. + conn.alter_collection( + self.name, + properties=properties, + timeout=timeout, + **kwargs, + ) + + def load( + self, + partition_names: Optional[list] = None, + replica_number: int = 1, + timeout: Optional[float] = None, + **kwargs, + ): + """Load the data into memory. Args: partition_names (``List[str]``): The specified partitions to load. replica_number (``int``, optional): The replica number to load, defaults to 1. - timeout (``float``, optional): an optional duration of time in seconds to allow for the RPCs. - If timeout is not set, the client keeps waiting until the server responds or an error occurs. + timeout (float, optional): an optional duration of time in seconds to allow + for the RPCs. If timeout is not set, the client keeps waiting until the + server responds or an error occurs. **kwargs (``dict``, optional): * *_async*(``bool``) Indicate if invoke asynchronously. * *_refresh*(``bool``) - Whether to enable refresh mode(renew the segment list of this collection before loading). + Whether to renew the segment list of this collection before loading * *_resource_groups(``List[str]``) Specify resource groups which can be used during loading. @@ -358,7 +393,7 @@ def load(self, partition_names=None, replica_number=1, timeout=None, **kwargs): MilvusException: If anything goes wrong. Examples: - >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType + >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType >>> connections.connect() >>> schema = CollectionSchema([ ... FieldSchema("film_id", DataType.INT64, is_primary=True), @@ -366,48 +401,68 @@ def load(self, partition_names=None, replica_number=1, timeout=None, **kwargs): ... ]) >>> collection = Collection("test_collection_load", schema) >>> collection.insert([[1, 2], [[1.0, 2.0], [3.0, 4.0]]]) - >>> collection.create_index("films", {"index_type": "FLAT", "metric_type": "L2", "params": {}}) + >>> index_param = {"index_type": "FLAT", "metric_type": "L2", "params": {}} + >>> collection.create_index("films", index_param) >>> collection.load() """ conn = self._get_connection() if partition_names is not None: - conn.load_partitions(self._name, partition_names, replica_number=replica_number, timeout=timeout, **kwargs) + conn.load_partitions( + collection_name=self._name, + partition_names=partition_names, + replica_number=replica_number, + timeout=timeout, + **kwargs, + ) else: - conn.load_collection(self._name, replica_number=replica_number, timeout=timeout, **kwargs) + conn.load_collection( + collection_name=self._name, + replica_number=replica_number, + timeout=timeout, + **kwargs, + ) - def release(self, timeout=None, **kwargs): - """ Releases the collection data from memory. + def release(self, timeout: Optional[float] = None, **kwargs): + """Releases the collection data from memory. Args: - timeout (``float``, optional): an optional duration of time in seconds to allow for the RPCs. - If timeout is not set, the client keeps waiting until the server responds or an error occurs. + timeout (float, optional): an optional duration of time in seconds to allow + for the RPCs. If timeout is not set, the client keeps waiting until the + server responds or an error occurs. Examples: - >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType - >>> connections.connect() + >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType >>> schema = CollectionSchema([ ... FieldSchema("film_id", DataType.INT64, is_primary=True), ... FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2) ... ]) >>> collection = Collection("test_collection_release", schema) >>> collection.insert([[1, 2], [[1.0, 2.0], [3.0, 4.0]]]) - >>> collection.create_index("films", {"index_type": "FLAT", "metric_type": "L2", "params": {}}) + >>> index_param = {"index_type": "FLAT", "metric_type": "L2", "params": {}} + >>> collection.create_index("films", index_param) >>> collection.load() >>> collection.release() """ conn = self._get_connection() conn.release_collection(self._name, timeout=timeout, **kwargs) - def insert(self, data: Union[List, pandas.DataFrame, Dict], partition_name: str = None, timeout=None, - **kwargs) -> MutationResult: - """ Insert data into the collection. + def insert( + self, + data: Union[List, pd.DataFrame, Dict], + partition_name: Optional[str] = None, + timeout: Optional[float] = None, + **kwargs, + ) -> MutationResult: + """Insert data into the collection. Args: data (``list/tuple/pandas.DataFrame``): The specified data to insert partition_name (``str``): The partition name which the data will be inserted to, - if partition name is not passed, then the data will be inserted to "_default" partition - timeout (``float``, optional): A duration of time in seconds to allow for the RPC. Defaults to None. - If timeout is set to None, the client keeps waiting until the server responds or an error occurs. + if partition name is not passed, then the data will be inserted + to default partition + timeout (float, optional): an optional duration of time in seconds to allow + for the RPCs. If timeout is not set, the client keeps waiting until the + server responds or an error occurs. Returns: MutationResult: contains 2 properties `insert_count`, and, `primary_keys` `insert_count`: how may entites have been inserted into Milvus, @@ -416,9 +471,8 @@ def insert(self, data: Union[List, pandas.DataFrame, Dict], partition_name: str MilvusException: If anything goes wrong. Examples: - >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType + >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType >>> import random - >>> connections.connect() >>> schema = CollectionSchema([ ... FieldSchema("film_id", DataType.INT64, is_primary=True), ... FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2) @@ -435,39 +489,58 @@ def insert(self, data: Union[List, pandas.DataFrame, Dict], partition_name: str if data is None: return MutationResult(data) - row_based = check_insert_or_upsert_is_row_based(data) + row_based = check_is_row_based(data) conn = self._get_connection() if not row_based: - check_insert_or_upsert_data_schema(self._schema, data) - entities = Prepare.prepare_insert_or_upsert_data(data, self._schema) - res = conn.batch_insert(self._name, entities, partition_name, - timeout=timeout, schema=self._schema_dict, **kwargs) + check_insert_schema(self._schema, data) + entities = Prepare.prepare_insert_data(data, self._schema) + res = conn.batch_insert( + self._name, + entities, + partition_name, + timeout=timeout, + schema=self._schema_dict, + **kwargs, + ) if kwargs.get("_async", False): return MutationFuture(res) else: - res = conn.insert_rows(self._name, data, partition_name, - timeout=timeout, schema=self._schema_dict, **kwargs) + res = conn.insert_rows( + collection_name=self._name, + entities=data, + partition_name=partition_name, + timeout=timeout, + schema=self._schema_dict, + **kwargs, + ) return MutationResult(res) - def delete(self, expr, partition_name=None, timeout=None, **kwargs): - """ Delete entities with an expression condition. + def delete( + self, + expr: str, + partition_name: Optional[str] = None, + timeout: Optional[float] = None, + **kwargs, + ): + """Delete entities with an expression condition. Args: expr (``str``): The specified data to insert. partition_names (``List[str]``): Name of partitions to delete entities. - timeout (``float``, optional): A duration of time in seconds to allow for the RPC. Defaults to None. - If timeout is set to None, the client keeps waiting until the server responds or an error occurs. + timeout (float, optional): an optional duration of time in seconds to allow + for the RPCs. If timeout is not set, the client keeps waiting until the + server responds or an error occurs. Returns: - MutationResult: contains `delete_count` properties represents how many entities might be deleted. + MutationResult: + contains `delete_count` properties represents how many entities might be deleted. Raises: MilvusException: If anything goes wrong. Examples: - >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType + >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType >>> import random - >>> connections.connect() >>> schema = CollectionSchema([ ... FieldSchema("film_id", DataType.INT64, is_primary=True), ... FieldSchema("film_date", DataType.INT64), @@ -492,15 +565,23 @@ def delete(self, expr, partition_name=None, timeout=None, **kwargs): return MutationFuture(res) return MutationResult(res) - def upsert(self, data: Union[List, pandas.DataFrame, Dict], partition_name: str=None, timeout=None, **kwargs) -> MutationResult: - """ Upsert data into the collection. + def upsert( + self, + data: Union[List, pd.DataFrame, Dict], + partition_name: Optional[str] = None, + timeout: Optional[float] = None, + **kwargs, + ) -> MutationResult: + """Upsert data into the collection. Args: data (``list/tuple/pandas.DataFrame``): The specified data to upsert partition_name (``str``): The partition name which the data will be upserted at, - if partition name is not passed, then the data will be upserted in "_default" partition - timeout (``float``, optional): A duration of time in seconds to allow for the RPC. Defaults to None. - If timeout is set to None, the client keeps waiting until the server responds or an error occurs. + if partition name is not passed, then the data will be upserted + in default partition + timeout (float, optional): an optional duration of time in seconds to allow + for the RPCs. If timeout is not set, the client keeps waiting until the + server responds or an error occurs. Returns: MutationResult: contains 2 properties `upsert_count`, and, `primary_keys` `upsert_count`: how may entites have been upserted at Milvus, @@ -509,9 +590,8 @@ def upsert(self, data: Union[List, pandas.DataFrame, Dict], partition_name: str= MilvusException: If anything goes wrong. Examples: - >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType + >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType >>> import random - >>> connections.connect() >>> schema = CollectionSchema([ ... FieldSchema("film_id", DataType.INT64, is_primary=True), ... FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2) @@ -528,48 +608,65 @@ def upsert(self, data: Union[List, pandas.DataFrame, Dict], partition_name: str= if data is None: return MutationResult(data) - row_based = check_insert_or_upsert_is_row_based(data) + row_based = check_is_row_based(data) conn = self._get_connection() if not row_based: - check_insert_or_upsert_data_schema(self._schema, data, False) - entities = Prepare.prepare_insert_or_upsert_data(data, self._schema, False) - - res = conn.upsert(self._name, entities, partition_name, - timeout=timeout, schema=self._schema_dict, **kwargs) + check_upset_schema(self._schema, data) + entities = Prepare.prepare_upsert_data(data, self._schema) + + res = conn.upsert( + self._name, + entities, + partition_name, + timeout=timeout, + schema=self._schema_dict, + **kwargs, + ) if kwargs.get("_async", False): return MutationFuture(res) else: - res = conn.upsert_rows(self._name, data, partition_name, - timeout=timeout, schema=self._schema_dict, **kwargs) + res = conn.upsert_rows( + self._name, + data, + partition_name, + timeout=timeout, + schema=self._schema_dict, + **kwargs, + ) return MutationResult(res) - def search(self, data, anns_field, param, limit, expr=None, partition_names=None, - output_fields=None, timeout=None, round_decimal=-1, **kwargs): - """ Conducts a vector similarity search with an optional boolean expression as filter. + def search( + self, + data: List, + anns_field: str, + param: Dict, + limit: int, + expr: Optional[str] = None, + partition_names: Optional[List[str]] = None, + output_fields: Optional[List[str]] = None, + timeout: Optional[float] = None, + round_decimal: int = -1, + **kwargs, + ): + """Conducts a vector similarity search with an optional boolean expression as filter. Args: data (``List[List[float]]``): The vectors of search data. - the length of data is number of query (nq), and the dim of every vector in data must be equal to - the vector field's of collection. + the length of data is number of query (nq), + and the dim of every vector in data must be equal to the vector field of collection. anns_field (``str``): The name of the vector field used to search of collection. param (``dict[str, Any]``): - The parameters of search. The followings are valid keys of param. - * *nprobe*, *ef*, *search_k*, etc Corresponding search params for a certain index. - * *metric_type* (``str``) similar metricy types, the value must be of type str. - * *offset* (``int``, optional) offset for pagination. - * *limit* (``int``, optional) limit for the search results and pagination. - example for param:: { @@ -586,13 +683,15 @@ def search(self, data, anns_field, param, limit, expr=None, partition_names=None "id_field >= 0", "id_field in [1, 2, 3, 4]" - partition_names (``List[str]``, optional): The names of partitions to search on. Default to None. + partition_names (``List[str]``, optional): The names of partitions to search on. output_fields (``List[str]``, optional): The name of fields to return in the search result. Can only get scalar fields. - round_decimal (``int``, optional): The specified number of decimal places of returned distance. + round_decimal (``int``, optional): + The specified number of decimal places of returned distance. Defaults to -1 means no round to returned distance. - timeout (``float``, optional): A duration of time in seconds to allow for the RPC. Defaults to None. - If timeout is set to None, the client keeps waiting until the server responds or an error occurs. + timeout (``float``, optional): A duration of time in seconds to allow for the RPC. + If timeout is set to None, the client keeps waiting until the server + responds or an error occurs. **kwargs (``dict``): Optional search params * *_async* (``bool``, optional) @@ -608,9 +707,9 @@ def search(self, data, anns_field, param, limit, expr=None, partition_names=None Options of consistency level: Strong, Bounded, Eventually, Session, Customized. - Note: this parameter will overwrite the same parameter specified when user created the collection, - if no consistency level was specified, search will use the consistency level when you create the - collection. + Note: this parameter overwrites the same one specified when creating collection, + if no consistency level was specified, search will use the + consistency level when you create the collection. * *guarantee_timestamp* (``int``, optional) Instructs Milvus to see all operations performed before this timestamp. @@ -642,9 +741,8 @@ def search(self, data, anns_field, param, limit, expr=None, partition_names=None MilvusException: If anything goes wrong Examples: - >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType + >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType >>> import random - >>> connections.connect() >>> schema = CollectionSchema([ ... FieldSchema("film_id", DataType.INT64, is_primary=True), ... FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2) @@ -656,7 +754,8 @@ def search(self, data, anns_field, param, limit, expr=None, partition_names=None ... [[random.random() for _ in range(2)] for _ in range(10)], ... ] >>> collection.insert(data) - >>> collection.create_index("films", {"index_type": "FLAT", "metric_type": "L2", "params": {}}) + >>> index_param = {"index_type": "FLAT", "metric_type": "L2", "params": {}} + >>> collection.create_index("films", index_param) >>> collection.load() >>> # search >>> search_param = { @@ -672,38 +771,79 @@ def search(self, data, anns_field, param, limit, expr=None, partition_names=None >>> assert len(hits) == 2 >>> print(f"- Total hits: {len(hits)}, hits ids: {hits.ids} ") - Total hits: 2, hits ids: [8, 5] - >>> print(f"- Top1 hit id: {hits[0].id}, distance: {hits[0].distance}, score: {hits[0].score} ") - - Top1 hit id: 8, distance: 0.10143111646175385, score: 0.10143111646175385 + >>> print(f"- Top1 hit id: {hits[0].id}, score: {hits[0].score} ") + - Top1 hit id: 8, score: 0.10143111646175385 """ if expr is not None and not isinstance(expr, str): raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(expr)) conn = self._get_connection() - res = conn.search(self._name, data, anns_field, param, limit, expr, - partition_names, output_fields, round_decimal, timeout=timeout, - schema=self._schema_dict, **kwargs) + res = conn.search( + self._name, + data, + anns_field, + param, + limit, + expr, + partition_names, + output_fields, + round_decimal, + timeout=timeout, + schema=self._schema_dict, + **kwargs, + ) if kwargs.get("_async", False): return SearchFuture(res) return SearchResult(res) - def search_iterator(self, data, anns_field, param, limit, expr=None, partition_names=None, - output_fields=None, timeout=None, round_decimal=-1, **kwargs): + def search_iterator( + self, + data: List, + anns_field: str, + param: Dict, + limit: int, + expr: Optional[str] = None, + partition_names: Optional[List[str]] = None, + output_fields: Optional[List[str]] = None, + timeout: Optional[float] = None, + round_decimal: int = -1, + **kwargs, + ): if expr is not None and not isinstance(expr, str): raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(expr)) - conn = self._get_connection() - iterator = SearchIterator(conn, self._name, data, anns_field, param, limit, expr, partition_names, - output_fields, timeout, round_decimal, schema=self._schema_dict, **kwargs) - return iterator - - def query(self, expr, output_fields=None, partition_names=None, timeout=None, **kwargs): - """ Query with expressions + return SearchIterator( + connections=self._get_connection(), + collection_name=self._name, + data=data, + ann_field=anns_field, + param=param, + limit=limit, + expr=expr, + partition_names=partition_names, + output_fields=output_fields, + timeout=timeout, + round_decimal=round_decimal, + schema=self._schema_dict, + **kwargs, + ) + + def query( + self, + expr: str, + output_fields: Optional[List[str]] = None, + partition_names: Optional[List[str]] = None, + timeout: Optional[float] = None, + **kwargs, + ): + """Query with expressions Args: expr (``str``): The query expression. output_fields(``List[str]``): A list of field names to return. Defaults to None. - partition_names: (``List[str]``, optional): A list of partition names to query in. Defaults to None. - timeout (``float``, optional): A duration of time in seconds to allow for the RPC. Defaults to None. - If timeout is set to None, the client keeps waiting until the server responds or an error occurs. + partition_names: (``List[str]``, optional): A list of partition names to query in. + timeout (``float``, optional): A duration of time in seconds to allow for the RPC. + If timeout is set to None, the client keeps waiting until the server + responds or an error occurs. **kwargs (``dict``, optional): * *consistency_level* (``str/int``, optional) @@ -711,9 +851,10 @@ def query(self, expr, output_fields=None, partition_names=None, timeout=None, ** Options of consistency level: Strong, Bounded, Eventually, Session, Customized. - Note: this parameter will overwrite the same parameter specified when user created the collection, - if no consistency level was specified, search will use the consistency level when you create the - collection. + Note: this parameter overwrites the same one specified when creating collection, + if no consistency level was specified, search will use the + consistency level when you create the collection. + * *guarantee_timestamp* (``int``, optional) Instructs Milvus to see all operations performed before this timestamp. @@ -743,9 +884,8 @@ def query(self, expr, output_fields=None, partition_names=None, timeout=None, ** MilvusException: If anything goes wrong Examples: - >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType + >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType >>> import random - >>> connections.connect() >>> schema = CollectionSchema([ ... FieldSchema("film_id", DataType.INT64, is_primary=True), ... FieldSchema("film_date", DataType.INT64), @@ -759,7 +899,8 @@ def query(self, expr, output_fields=None, partition_names=None, timeout=None, ** ... [[random.random() for _ in range(2)] for _ in range(10)], ... ] >>> collection.insert(data) - >>> collection.create_index("films", {"index_type": "FLAT", "metric_type": "L2", "params": {}}) + >>> index_param = {"index_type": "FLAT", "metric_type": "L2", "params": {}} + >>> collection.create_index("films", index_param) >>> collection.load() >>> # query >>> expr = "film_id <= 1" @@ -772,28 +913,47 @@ def query(self, expr, output_fields=None, partition_names=None, timeout=None, ** raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(expr)) conn = self._get_connection() - res = conn.query(self._name, expr, output_fields, partition_names, - timeout=timeout, schema=self._schema_dict, **kwargs) - return res - - def query_iterator(self, expr, output_fields=None, partition_names=None, timeout=None, **kwargs): + return conn.query( + self._name, + expr, + output_fields, + partition_names, + timeout=timeout, + schema=self._schema_dict, + **kwargs, + ) + + def query_iterator( + self, + expr: str, + output_fields: Optional[List[str]] = None, + partition_names: Optional[List[str]] = None, + timeout: Optional[float] = None, + **kwargs, + ): if not isinstance(expr, str): raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(expr)) conn = self._get_connection() - iterator = QueryIterator(conn, self._name, expr, output_fields, partition_names, - timeout=timeout, schema=self._schema_dict, **kwargs) - return iterator + return QueryIterator( + connection=conn, + collection_name=self._name, + expr=expr, + output_fields=output_fields, + partition_names=partition_names, + schema=self._schema_dict, + timeout=timeout, + **kwargs, + ) @property def partitions(self, **kwargs) -> List[Partition]: - """ List[Partition]: List of Partition object. + """List[Partition]: List of Partition object. Raises: MilvusException: If anything goes wrong. Examples: - >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType - >>> connections.connect() + >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType >>> schema = CollectionSchema([ ... FieldSchema("film_id", DataType.INT64, is_primary=True), ... FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2) @@ -809,8 +969,8 @@ def partitions(self, **kwargs) -> List[Partition]: partitions.append(Partition(self, partition, construct_only=True)) return partitions - def partition(self, partition_name, **kwargs) -> Partition: - """ Get the existing partition object according to name. Return None if not existed. + def partition(self, partition_name: str, **kwargs) -> Partition: + """Get the existing partition object according to name. Return None if not existed. Args: partition_name (``str``): The name of the partition to get. @@ -822,8 +982,7 @@ def partition(self, partition_name, **kwargs) -> Partition: MilvusException: If anything goes wrong. Examples: - >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType - >>> connections.connect() + >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType >>> schema = CollectionSchema([ ... FieldSchema("film_id", DataType.INT64, is_primary=True), ... FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2) @@ -832,12 +991,12 @@ def partition(self, partition_name, **kwargs) -> Partition: >>> collection.partition("_default") {"name": "_default", "description": "", "num_entities": 0} """ - if self.has_partition(partition_name, **kwargs) is False: + if not self.has_partition(partition_name, **kwargs): return None return Partition(self, partition_name, construct_only=True, **kwargs) - def create_partition(self, partition_name, description="", **kwargs) -> Partition: - """ Create a new partition corresponding to name if not existed. + def create_partition(self, partition_name: str, description: str = "", **kwargs) -> Partition: + """Create a new partition corresponding to name if not existed. Args: partition_name (``str``): The name of the partition to create. @@ -850,29 +1009,29 @@ def create_partition(self, partition_name, description="", **kwargs) -> Partitio MilvusException: If anything goes wrong. Examples: - >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType - >>> connections.connect() + >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType >>> schema = CollectionSchema([ ... FieldSchema("film_id", DataType.INT64, is_primary=True), ... FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2) ... ]) - >>> collection = Collection("test_collection_create_partition", schema) + >>> collection = Collection("test_create_partition", schema) >>> collection.create_partition("comedy", description="comedy films") - {"name": "comedy", "collection_name": "test_collection_create_partition", "description": ""} + {"name": "comedy", "collection_name": "test_create_partition", "description": ""} >>> collection.partition("comedy") - {"name": "comedy", "collection_name": "test_collection_create_partition", "description": ""} + {"name": "comedy", "collection_name": "test_create_partition", "description": ""} """ if self.has_partition(partition_name, **kwargs) is True: raise PartitionAlreadyExistException(message=ExceptionsMessage.PartitionAlreadyExist) return Partition(self, partition_name, description=description, **kwargs) - def has_partition(self, partition_name, timeout=None, **kwargs) -> bool: - """ Checks if a specified partition exists. + def has_partition(self, partition_name: str, timeout: Optional[float] = None, **kwargs) -> bool: + """Checks if a specified partition exists. Args: partition_name (``str``): The name of the partition to check. - timeout (``float``, optional): An optional duration of time in seconds to allow for the RPC. When timeout - is set to None, client waits until server response or error occur. + timeout (``float``, optional): An optional duration of time in seconds to allow for + the RPC. When timeout is set to None, client waits until server + response or error occur. Returns: bool: True if exists, otherwise false. @@ -881,8 +1040,7 @@ def has_partition(self, partition_name, timeout=None, **kwargs) -> bool: MilvusException: If anything goes wrong. Examples: - >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType - >>> connections.connect() + >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType >>> schema = CollectionSchema([ ... FieldSchema("film_id", DataType.INT64, is_primary=True), ... FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2) @@ -898,21 +1056,21 @@ def has_partition(self, partition_name, timeout=None, **kwargs) -> bool: conn = self._get_connection() return conn.has_partition(self._name, partition_name, timeout=timeout, **kwargs) - def drop_partition(self, partition_name, timeout=None, **kwargs): - """ Drop the partition in this collection. + def drop_partition(self, partition_name: str, timeout: Optional[float] = None, **kwargs): + """Drop the partition in this collection. Args: partition_name (``str``): The name of the partition to drop. - timeout (``float``, optional): An optional duration of time in seconds to allow for the RPC. When timeout - is set to None, client waits until server response or error occur. + timeout (``float``, optional): An optional duration of time in seconds to allow for + the RPC. When timeout is set to None, client waits until server response + or error occur. Raises: PartitionNotExistException: If the partition doesn't exists. MilvusException: If anything goes wrong. Examples: - >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType - >>> connections.connect() + >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType >>> schema = CollectionSchema([ ... FieldSchema("film_id", DataType.INT64, is_primary=True), ... FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2) @@ -936,8 +1094,7 @@ def indexes(self, **kwargs) -> List[Index]: """List[Index]: list of indexes of this collection. Examples: - >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType - >>> connections.connect() + >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType >>> schema = CollectionSchema([ ... FieldSchema("film_id", DataType.INT64, is_primary=True), ... FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2) @@ -955,7 +1112,13 @@ def indexes(self, **kwargs) -> List[Index]: if info_dict.get("params", None): info_dict["params"] = json.loads(info_dict["params"]) - index_info = Index(self, index.field_name, info_dict, index_name=index.index_name, construct_only=True) + index_info = Index( + collection=self, + field_name=index.field_name, + index_params=info_dict, + index_name=index.index_name, + construct_only=True, + ) indexes.append(index_info) return indexes @@ -974,8 +1137,7 @@ def index(self, **kwargs) -> Index: IndexNotExistException: If the index doesn't exists. Examples: - >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType - >>> connections.connect() + >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType >>> schema = CollectionSchema([ ... FieldSchema("film_id", DataType.INT64, is_primary=True), ... FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2) @@ -999,12 +1161,18 @@ def index(self, **kwargs) -> Index: return Index(self, field_name, tmp_index, construct_only=True, index_name=index_name) raise IndexNotExistException(message=ExceptionsMessage.IndexNotExist) - def create_index(self, field_name, index_params={}, timeout=None, **kwargs): + def create_index( + self, + field_name: str, + index_params: Optional[Dict] = None, + timeout: Optional[float] = None, + **kwargs, + ): """Creates index for a specified field, with a index name. Args: field_name (``str``): The name of the field to create index - index_params (``dict``): The parameters to index + index_params (``dict``, optional): The parameters to index * *index_type* (``str``) "index_type" as the key, example values: "FLAT", "IVF_FLAT", etc. @@ -1014,8 +1182,9 @@ def create_index(self, field_name, index_params={}, timeout=None, **kwargs): * *params* (``dict``) "params" as the key, corresponding index params. - timeout (``float``, optional): An optional duration of time in seconds to allow for the RPC. When timeout - is set to None, client waits until server response or error occur. + timeout (``float``, optional): An optional duration of time in seconds to allow + for the RPC. When timeout is set to None, client waits until server + response or error occur. index_name (``str``): The name of index which will be created, must be unique. If no index name is specified, the default index name will be used. @@ -1023,26 +1192,29 @@ def create_index(self, field_name, index_params={}, timeout=None, **kwargs): MilvusException: If anything goes wrong. Examples: - >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType - >>> connections.connect() + >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType >>> schema = CollectionSchema([ ... FieldSchema("film_id", DataType.INT64, is_primary=True), ... FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2) ... ]) >>> collection = Collection("test_collection_create_index", schema) - >>> index_params = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"} + >>> index_params = { + ... "index_type": "IVF_FLAT", + ... "params": {"nlist": 128}, + ... "metric_type": "L2"} >>> collection.create_index("films", index_params, index_name="idx") Status(code=0, message='') """ conn = self._get_connection() return conn.create_index(self._name, field_name, index_params, timeout=timeout, **kwargs) - def has_index(self, timeout=None, **kwargs) -> bool: - """ Check whether a specified index exists. + def has_index(self, timeout: Optional[float] = None, **kwargs) -> bool: + """Check whether a specified index exists. Args: - timeout (``float``, optional): An optional duration of time in seconds to allow for the RPC. When timeout - is set to None, client waits until server response or error occur. + timeout (``float``, optional): An optional duration of time in seconds to allow + for the RPC. When timeout is set to None, client waits until server response + or error occur. **kwargs (``dict``): * *index_name* (``str``) @@ -1052,8 +1224,7 @@ def has_index(self, timeout=None, **kwargs) -> bool: bool: Whether the specified index exists. Examples: - >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType - >>> connections.connect() + >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType >>> schema = CollectionSchema([ ... FieldSchema("film_id", DataType.INT64, is_primary=True), ... FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2) @@ -1071,11 +1242,12 @@ def has_index(self, timeout=None, **kwargs) -> bool: return False return True - def drop_index(self, timeout=None, **kwargs): - """ Drop index and its corresponding index files. + def drop_index(self, timeout: Optional[float] = None, **kwargs): + """Drop index and its corresponding index files. Args: - timeout (``float``, optional): An optional duration of time in seconds to allow for the RPC. When timeout - is set to None, client waits until server response or error occur. + timeout (``float``, optional): An optional duration of time in seconds to allow + for the RPC. When timeout is set to None, client waits until server response + or error occur. **kwargs (``dict``): * *index_name* (``str``) @@ -1085,8 +1257,7 @@ def drop_index(self, timeout=None, **kwargs): MilvusException: If anything goes wrong. Examples: - >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType - >>> connections.connect() + >>> from pymilvus Collection, FieldSchema, CollectionSchema, DataType >>> schema = CollectionSchema([ ... FieldSchema("film_id", DataType.INT64, is_primary=True), ... FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2) @@ -1105,15 +1276,22 @@ def drop_index(self, timeout=None, **kwargs): conn = self._get_connection() tmp_index = conn.describe_index(self._name, index_name, timeout=timeout, **copy_kwargs) if tmp_index is not None: - index = Index(self, tmp_index['field_name'], tmp_index, construct_only=True, index_name=index_name) + index = Index( + collection=self, + field_name=tmp_index["field_name"], + index_params=tmp_index, + construct_only=True, + index_name=index_name, + ) index.drop(timeout=timeout, **kwargs) - def compact(self, timeout=None, **kwargs): - """ Compact merge the small segments in a collection + def compact(self, timeout: Optional[float] = None, **kwargs): + """Compact merge the small segments in a collection Args: - timeout (``float``, optional): An optional duration of time in seconds to allow for the RPC. When timeout - is set to None, client waits until server response or error occur. + timeout (``float``, optional): An optional duration of time in seconds to allow + for the RPC. When timeout is set to None, client waits until server response + or error occur. Raises: MilvusException: If anything goes wrong. @@ -1121,12 +1299,13 @@ def compact(self, timeout=None, **kwargs): conn = self._get_connection() self.compaction_id = conn.compact(self._name, timeout=timeout, **kwargs) - def get_compaction_state(self, timeout=None, **kwargs) -> CompactionState: - """ Get the current compaction state + def get_compaction_state(self, timeout: Optional[float] = None, **kwargs) -> CompactionState: + """Get the current compaction state Args: - timeout (``float``, optional): An optional duration of time in seconds to allow for the RPC. When timeout - is set to None, client waits until server response or error occur. + timeout (``float``, optional): An optional duration of time in seconds to allow + for the RPC. When timeout is set to None, client waits until server response + or error occur. Raises: MilvusException: If anything goes wrong. @@ -1134,12 +1313,17 @@ def get_compaction_state(self, timeout=None, **kwargs) -> CompactionState: conn = self._get_connection() return conn.get_compaction_state(self.compaction_id, timeout=timeout, **kwargs) - def wait_for_compaction_completed(self, timeout=None, **kwargs) -> CompactionState: - """ Block until the current collection's compaction completed + def wait_for_compaction_completed( + self, + timeout: Optional[float] = None, + **kwargs, + ) -> CompactionState: + """Block until the current collection's compaction completed Args: - timeout (``float``, optional): An optional duration of time in seconds to allow for the RPC. When timeout - is set to None, client waits until server response or error occur. + timeout (``float``, optional): An optional duration of time in seconds to allow + for the RPC. When timeout is set to None, client waits until server response + or error occur. Raises: MilvusException: If anything goes wrong. @@ -1147,30 +1331,32 @@ def wait_for_compaction_completed(self, timeout=None, **kwargs) -> CompactionSta conn = self._get_connection() return conn.wait_for_compaction_completed(self.compaction_id, timeout=timeout, **kwargs) - def get_compaction_plans(self, timeout=None, **kwargs) -> CompactionPlans: + def get_compaction_plans(self, timeout: Optional[float] = None, **kwargs) -> CompactionPlans: """Get the current compaction plans Args: - timeout (``float``, optional): An optional duration of time in seconds to allow for the RPC. When timeout - is set to None, client waits until server response or error occur. + timeout (``float``, optional): An optional duration of time in seconds to allow + for the RPC. When timeout is set to None, client waits until server response + or error occur. Returns: CompactionPlans: All the plans' states of this compaction. """ conn = self._get_connection() return conn.get_compaction_plans(self.compaction_id, timeout=timeout, **kwargs) - def get_replicas(self, timeout=None, **kwargs) -> Replica: + def get_replicas(self, timeout: Optional[float] = None, **kwargs) -> Replica: """Get the current loaded replica information Args: - timeout (``float``, optional): An optional duration of time in seconds to allow for the RPC. When timeout - is set to None, client waits until server response or error occur. + timeout (``float``, optional): An optional duration of time in seconds to allow + for the RPC. When timeout is set to None, client waits until server response + or error occur. Returns: Replica: All the replica information. """ conn = self._get_connection() return conn.get_replicas(self.name, timeout=timeout, **kwargs) - def describe(self, timeout=None): + def describe(self, timeout: Optional[float] = None): conn = self._get_connection() return conn.describe_collection(self.name, timeout=timeout) diff --git a/pymilvus/orm/connections.py b/pymilvus/orm/connections.py index fa5728f9f..b06dff343 100644 --- a/pymilvus/orm/connections.py +++ b/pymilvus/orm/connections.py @@ -12,19 +12,23 @@ import copy import threading +from typing import Callable, Tuple, Union from urllib import parse -from typing import Tuple -from ..client.check import is_legal_host, is_legal_port, is_legal_address -from ..client.grpc_handler import GrpcHandler -from ..client.utils import get_server_type, ZILLIZ - -from ..settings import Config -from ..exceptions import ExceptionsMessage, ConnectionConfigException, ConnectionNotExistException +from pymilvus.client.check import is_legal_address, is_legal_host, is_legal_port +from pymilvus.client.grpc_handler import GrpcHandler +from pymilvus.client.utils import ZILLIZ, get_server_type +from pymilvus.exceptions import ( + ConnectionConfigException, + ConnectionNotExistException, + ExceptionsMessage, +) +from pymilvus.settings import Config VIRTUAL_PORT = 443 -def synchronized(func): + +def synchronized(func: Callable): """ Decorator in order to achieve thread-safe singleton class. """ @@ -40,7 +44,7 @@ def lock_func(*args, **kwargs): class SingleInstanceMetaClass(type): instance = None - def __init__(cls, *args, **kwargs): + def __init__(cls, *args, **kwargs) -> None: super().__init__(*args, **kwargs) def __call__(cls, *args, **kwargs): @@ -56,10 +60,10 @@ def __new__(cls, *args, **kwargs): class Connections(metaclass=SingleInstanceMetaClass): - """ Class for managing all connections of milvus. Used as a singleton in this module. """ + """Class for managing all connections of milvus. Used as a singleton in this module.""" - def __init__(self): - """ Constructs a default milvus alias config + def __init__(self) -> None: + """Constructs a default milvus alias config default config will be read from env: MILVUS_URI and MILVUS_CONN_ALIAS with default value: default="localhost:19530" @@ -98,21 +102,25 @@ def __init__(self): self.add_connection(**{Config.MILVUS_CONN_ALIAS: default_conn_config}) - def __verify_host_port(self, host, port): + def __verify_host_port(self, host: str, port: Union[int, str]): if not is_legal_host(host): raise ConnectionConfigException(message=ExceptionsMessage.HostType) if not is_legal_port(port): raise ConnectionConfigException(message=ExceptionsMessage.PortType) if not 0 <= int(port) < 65535: - raise ConnectionConfigException(message=f"port number {port} out of range, valid range [0, 65535)") + msg = f"port number {port} out of range, valid range [0, 65535)" + raise ConnectionConfigException(message=msg) def __parse_address_from_uri(self, uri: str) -> (str, parse.ParseResult): - illegal_uri_msg = "Illegal uri: [{}], expected form 'http[s]://[user:password@]example.com:12345'" + illegal_uri_msg = ( + "Illegal uri: [{}], expected form 'http[s]://[user:password@]example.com:12345'" + ) try: parsed_uri = parse.urlparse(uri) - except (Exception) as e: + except Exception as e: raise ConnectionConfigException( - message=f"{illegal_uri_msg.format(uri)}: <{type(e).__name__}, {e}>") from None + message=f"{illegal_uri_msg.format(uri)}: <{type(e).__name__}, {e}>" + ) from None if len(parsed_uri.netloc) == 0: raise ConnectionConfigException(message=f"{illegal_uri_msg.format(uri)}") from None @@ -129,7 +137,7 @@ def __parse_address_from_uri(self, uri: str) -> (str, parse.ParseResult): return addr, parsed_uri def add_connection(self, **kwargs): - """ Configures a milvus connection. + """Configures a milvus connection. Addresses priority in kwargs: address, uri, host and port @@ -141,7 +149,8 @@ def add_connection(self, **kwargs): Example uri: "http://localhost:19530", "tcp:localhost:19530", "https://ok.s3.south.com:19530". * *host* (``str``) -- Optional. The host of Milvus instance. - Default at "localhost", PyMilvus will fill in the default host if only port is provided. + Default at "localhost", PyMilvus will fill in the default host + if only port is provided. * *port* (``str/int``) -- Optional. The port of Milvus instance. Default at 19530, PyMilvus will fill in the default port if only host is provided. @@ -163,11 +172,11 @@ def add_connection(self, **kwargs): config.get("address", ""), config.get("uri", ""), config.get("host", ""), - config.get("port", "")) + config.get("port", ""), + ) - if alias in self._connected_alias: - if self._alias[alias].get("address") != addr: - raise ConnectionConfigException(message=ExceptionsMessage.ConnDiffConf % alias) + if alias in self._connected_alias and self._alias[alias].get("address") != addr: + raise ConnectionConfigException(message=ExceptionsMessage.ConnDiffConf % alias) alias_config = { "address": addr, @@ -176,12 +185,18 @@ def add_connection(self, **kwargs): self._alias[alias] = alias_config - def __get_full_address(self, address: str = "", uri: str = "", host: str = "", port: str = "") -> ( - str, parse.ParseResult): + def __get_full_address( + self, + address: str = "", + uri: str = "", + host: str = "", + port: str = "", + ) -> (str, parse.ParseResult): if address != "": if not is_legal_address(address): raise ConnectionConfigException( - message=f"Illegal address: {address}, should be in form 'localhost:19530'") + message=f"Illegal address: {address}, should be in form 'localhost:19530'" + ) return address, None if uri != "": @@ -195,7 +210,7 @@ def __get_full_address(self, address: str = "", uri: str = "", host: str = "", p return f"{host}:{port}", None def disconnect(self, alias: str): - """ Disconnects connection from the registry. + """Disconnects connection from the registry. :param alias: The name of milvus connection :type alias: str @@ -207,7 +222,7 @@ def disconnect(self, alias: str): self._connected_alias.pop(alias).close() def remove_connection(self, alias: str): - """ Removes connection from the registry. + """Removes connection from the registry. :param alias: The name of milvus connection :type alias: str @@ -218,8 +233,15 @@ def remove_connection(self, alias: str): self.disconnect(alias) self._alias.pop(alias, None) - # pylint: disable=too-many-statements - def connect(self, alias=Config.MILVUS_CONN_ALIAS, user="", password="", db_name="", token="", **kwargs): + def connect( + self, + alias: str = Config.MILVUS_CONN_ALIAS, + user: str = "", + password: str = "", + db_name: str = "", + token: str = "", + **kwargs, + ) -> None: """ Constructs a milvus connection and register it under given alias. @@ -229,16 +251,13 @@ def connect(self, alias=Config.MILVUS_CONN_ALIAS, user="", password="", db_name= :param kwargs: * *address* (``str``) -- Optional. The actual address of Milvus instance. Example address: "localhost:19530" - * *uri* (``str``) -- Optional. The uri of Milvus instance. Example uri: "http://localhost:19530", "tcp:localhost:19530", "https://ok.s3.south.com:19530". - * *host* (``str``) -- Optional. The host of Milvus instance. - Default at "localhost", PyMilvus will fill in the default host if only port is provided. - + Default at "localhost", PyMilvus will fill in the default host + if only port is provided. * *port* (``str/int``) -- Optional. The port of Milvus instance. Default at 19530, PyMilvus will fill in the default port if only host is provided. - * *secure* (``bool``) -- Optional. Default is false. If set to true, tls will be enabled. * *user* (``str``) -- @@ -249,7 +268,8 @@ def connect(self, alias=Config.MILVUS_CONN_ALIAS, user="", password="", db_name= the user. * *token* (``str``) -- Optional. Serving as the key for identification and authentication purposes. - Whenever a token is furnished, we shall supplement the corresponding header to each RPC call. + Whenever a token is furnished, we shall supplement the corresponding header + to each RPC call. * *db_name* (``str``) -- Optional. default database name of this connection * *client_key_path* (``str``) -- @@ -280,21 +300,17 @@ def connect_milvus(**kwargs): timeout = t if isinstance(t, (int, float)) else Config.MILVUS_CONN_TIMEOUT gh._wait_for_channel_ready(timeout=timeout) - kwargs.pop('password') + kwargs.pop("password") kwargs.pop("token", None) - kwargs.pop('db_name', None) - kwargs.pop('secure', None) + kwargs.pop("db_name", None) + kwargs.pop("secure", None) kwargs.pop("db_name", "") self._connected_alias[alias] = gh self._alias[alias] = copy.deepcopy(kwargs) def with_config(config: Tuple) -> bool: - for c in config: - if c != "": - return True - - return False + return any(c != "" for c in config) if not isinstance(alias, str): raise ConnectionConfigException(message=ExceptionsMessage.AliasType % type(alias)) @@ -304,48 +320,38 @@ def with_config(config: Tuple) -> bool: if uri is not None: server_type = get_server_type(uri) if server_type == ZILLIZ and ":" not in token: - kwargs["uri"] = uri+":"+str(VIRTUAL_PORT) + kwargs["uri"] = uri + ":" + str(VIRTUAL_PORT) config = ( kwargs.pop("address", ""), kwargs.pop("uri", ""), kwargs.pop("host", ""), - kwargs.pop("port", "") + kwargs.pop("port", ""), ) # Make sure passed in None doesnt break - user = user or "" - password = password or "" - token = token or "" - # Make sure passed in are Strings - user = str(user) - password = str(password) - token = str(token) + user, password, token = str(user) or "", str(password) or "", str(token) or "" # 1st Priority: connection from params if with_config(config): in_addr, parsed_uri = self.__get_full_address(*config) kwargs["address"] = in_addr - if self.has_connection(alias): - if self._alias[alias].get("address") != in_addr: - raise ConnectionConfigException(message=ExceptionsMessage.ConnDiffConf % alias) + if self.has_connection(alias) and self._alias[alias].get("address") != in_addr: + raise ConnectionConfigException(message=ExceptionsMessage.ConnDiffConf % alias) # uri might take extra info if parsed_uri is not None: - user = parsed_uri.username if parsed_uri.username is not None else user - password = parsed_uri.password if parsed_uri.password is not None else password + user = parsed_uri.username or user + password = parsed_uri.password or password group = parsed_uri.path.split("/") - db_name = "default" - if len(group) > 1: - db_name = group[1] + db_name = group[1] if len(group) > 1 else "default" # Set secure=True if https scheme if parsed_uri.scheme == "https": kwargs["secure"] = True - connect_milvus(**kwargs, user=user, password=password, token=token, db_name=db_name) return @@ -375,7 +381,7 @@ def with_config(config: Tuple) -> bool: raise ConnectionConfigException(message=ExceptionsMessage.ConnLackConf % alias) def list_connections(self) -> list: - """ List names of all connections. + """List names of all connections. :return list: Names of all connections. @@ -384,7 +390,6 @@ def list_connections(self) -> list: >>> from pymilvus import connections >>> connections.connect("test", host="localhost", port="19530") >>> connections.list_connections() - // TODO [('default', None), ('test', )] """ return [(k, self._connected_alias.get(k, None)) for k in self._alias] @@ -403,7 +408,6 @@ def get_connection_addr(self, alias: str): >>> from pymilvus import connections >>> connections.connect("test", host="localhost", port="19530") >>> connections.list_connections() - [('default', None), ('test', )] >>> connections.get_connection_addr('test') {'host': 'localhost', 'port': '19530'} """ @@ -413,7 +417,7 @@ def get_connection_addr(self, alias: str): return self._alias.get(alias, {}) def has_connection(self, alias: str) -> bool: - """ Check if connection named alias exists. + """Check if connection named alias exists. :param alias: The name of milvus connection :type alias: str @@ -425,7 +429,6 @@ def has_connection(self, alias: str) -> bool: >>> from pymilvus import connections >>> connections.connect("test", host="localhost", port="19530") >>> connections.list_connections() - [('default', None), ('test', )] >>> connections.get_connection_addr('test') {'host': 'localhost', 'port': '19530'} """ @@ -433,8 +436,8 @@ def has_connection(self, alias: str) -> bool: raise ConnectionConfigException(message=ExceptionsMessage.AliasType % type(alias)) return alias in self._connected_alias - def _fetch_handler(self, alias=Config.MILVUS_CONN_ALIAS) -> GrpcHandler: - """ Retrieves a GrpcHandler by alias. """ + def _fetch_handler(self, alias: str = Config.MILVUS_CONN_ALIAS) -> GrpcHandler: + """Retrieves a GrpcHandler by alias.""" if not isinstance(alias, str): raise ConnectionConfigException(message=ExceptionsMessage.AliasType % type(alias)) diff --git a/pymilvus/orm/constants.py b/pymilvus/orm/constants.py index 29935d4d5..56a894e56 100644 --- a/pymilvus/orm/constants.py +++ b/pymilvus/orm/constants.py @@ -1,4 +1,4 @@ - # Copyright (C) 2019-2021 Zilliz. All rights reserved. +# Copyright (C) 2019-2021 Zilliz. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at diff --git a/pymilvus/orm/db.py b/pymilvus/orm/db.py index 866fecc60..8753ab63c 100644 --- a/pymilvus/orm/db.py +++ b/pymilvus/orm/db.py @@ -1,12 +1,14 @@ -from pymilvus import connections +from typing import Optional +from . import connections -def _get_connection(alias): + +def _get_connection(alias: str): return connections._fetch_handler(alias) -def using_database(db_name, using="default"): - """ Using a database as a default database name within this connection +def using_database(db_name: str, using: str = "default"): + """Using a database as a default database name within this connection :param db_name: Database name :type db_name: str @@ -15,8 +17,8 @@ def using_database(db_name, using="default"): _get_connection(using).reset_db_name(db_name) -def create_database(db_name, using="default", timeout=None): - """ Create a database using provided database name +def create_database(db_name: str, using: str = "default", timeout: Optional[float] = None): + """Create a database using provided database name :param db_name: Database name :type db_name: str @@ -25,8 +27,8 @@ def create_database(db_name, using="default", timeout=None): _get_connection(using).create_database(db_name, timeout=timeout) -def drop_database(db_name, using="default", timeout=None): - """ Drop a database using provided database name +def drop_database(db_name: str, using: str = "default", timeout: Optional[float] = None): + """Drop a database using provided database name :param db_name: Database name :type db_name: str @@ -35,8 +37,8 @@ def drop_database(db_name, using="default", timeout=None): _get_connection(using).drop_database(db_name, timeout=timeout) -def list_database(using="default", timeout=None) -> list: - """ List databases +def list_database(using: str = "default", timeout: Optional[float] = None) -> list: + """List databases :return list[str]: List of database names, return when operation is successful diff --git a/pymilvus/orm/future.py b/pymilvus/orm/future.py index 8fef4df6b..09f43b76c 100644 --- a/pymilvus/orm/future.py +++ b/pymilvus/orm/future.py @@ -10,46 +10,42 @@ # or implied. See the License for the specific language governing permissions and limitations under # the License. +from typing import Any -from .search import SearchResult from .mutation import MutationResult +from .search import SearchResult # TODO(dragondriver): how could we inherit the docstring elegantly? class BaseFuture: - def __init__(self, future): + def __init__(self, future: Any) -> None: self._f = future - def result(self, **kwargs): - """ - Return the result from future object. + def result(self) -> Any: + """Return the result from future object. It's a synchronous interface. It will wait executing until server respond or timeout occur(if specified). """ return self.on_response(self._f.result()) - def on_response(self, res): + def on_response(self, res: Any): return res def cancel(self): - """ - Cancel the request. - """ + """Cancel the request.""" return self._f.cancel() def done(self): - """ - Wait for request done. - """ + """Wait for request done.""" return self._f.done() class SearchFuture(BaseFuture): - def on_response(self, res): + def on_response(self, res: Any): return SearchResult(res) class MutationFuture(BaseFuture): - def on_response(self, res): + def on_response(self, res: Any): return MutationResult(res) diff --git a/pymilvus/orm/index.py b/pymilvus/orm/index.py index 2bbd39b62..df151ad67 100644 --- a/pymilvus/orm/index.py +++ b/pymilvus/orm/index.py @@ -11,54 +11,63 @@ # the License. import copy +from typing import Dict, Optional, TypeVar -from ..exceptions import CollectionNotExistException, ExceptionsMessage -from ..settings import Config +from pymilvus.exceptions import CollectionNotExistException, ExceptionsMessage +from pymilvus.settings import Config + +Index = TypeVar("Index") +Collection = TypeVar("Collection") class Index: - def __init__(self, collection, field_name, index_params, **kwargs): - """ - Creates index on a specified field according to the index parameters. - - :param collection: The collection in which the index is created - :type collection: Collection - - :param field_name: The name of the field to create an index for. - :type field_name: str - - :param index_params: Indexing parameters. - :type index_params: dict - - :param kwargs: - * *index_name* (``str``) -- - The name of index which will be created. Then you can use the index name to check the state of index. - If no index name is specified, default index name is used. - - :raises ParamError: If parameters are invalid. - - :example: - >>> from pymilvus import * - >>> from pymilvus.schema import * - >>> from pymilvus.types import DataType - >>> connections.connect() - - >>> field1 = FieldSchema("int64", DataType.INT64, is_primary=True) - >>> field2 = FieldSchema("fvec", DataType.FLOAT_VECTOR, is_primary=False, dim=128) - >>> schema = CollectionSchema(fields=[field1, field2], description="collection description") - >>> collection = Collection(name='test_collection', schema=schema) - >>> # insert some data - >>> index_params = {"index_type": "IVF_FLAT", "metric_type": "L2", "params": {"nlist": 128}} - >>> index = Index(collection, "fvec", index_params) - >>> print(index.params) - {'index_type': 'IVF_FLAT', 'metric_type': 'L2', 'params': {'nlist': 128}} - >>> print(index.collection_name) - test_collection - >>> print(index.field_name) - fvec - >>> index.drop() + def __init__( + self, + collection: Collection, + field_name: str, + index_params: Dict, + **kwargs, + ) -> Index: + """Creates index on a specified field according to the index parameters. + + Args: + collection(Collection): The collection in which the index is created + field_name(str): The name of the field to create an index for. + index_params(dict): Indexing parameters. + kwargs: + * *index_name* (``str``) -- + The name of index which will be created. If no index name is specified, + default index name will be used. + + Raises: + MilvusException: If anything goes wrong. + + Examples: + >>> from pymilvus import * + >>> from pymilvus.schema import * + >>> from pymilvus.types import DataType + >>> connections.connect() + + >>> field1 = FieldSchema("int64", DataType.INT64, is_primary=True) + >>> field2 = FieldSchema("fvec", DataType.FLOAT_VECTOR, is_primary=False, dim=128) + >>> schema = CollectionSchema(fields=[field1, field2]) + >>> collection = Collection(name='test_collection', schema=schema) + >>> # insert some data + >>> index_params = { + ... "index_type": "IVF_FLAT", + ... "metric_type": "L2", + ... "params": {"nlist": 128}} + >>> index = Index(collection, "fvec", index_params) + >>> index.params + {'index_type': 'IVF_FLAT', 'metric_type': 'L2', 'params': {'nlist': 128}} + >>> index.collection_name + test_collection + >>> index.field_name + fvec + >>> index.drop() """ from .collection import Collection + if not isinstance(collection, Collection): raise CollectionNotExistException(message=ExceptionsMessage.CollectionType) self._collection = collection @@ -81,80 +90,57 @@ def __init__(self, collection, field_name, index_params, **kwargs): def _get_connection(self): return self._collection._get_connection() - # read-only @property def params(self) -> dict: - """ - Returns the index parameters. - - :return dict: - The index parameters - """ + """dict: The index parameters""" return copy.deepcopy(self._index_params) - # read-only @property def collection_name(self) -> str: - """ - Returns the corresponding collection name. - - :return str: - The corresponding collection name - """ + """str: The corresponding collection name""" return self._collection.name @property def field_name(self) -> str: - """ - Returns the corresponding field name. - - :return str: - The corresponding field name. - """ + """str: The corresponding field name.""" return self._field_name @property def index_name(self) -> str: - """ - Returns the corresponding index name. - - :return str: - The corresponding index name. - """ + """str: The corresponding index name.""" return self._index_name - def __eq__(self, other) -> bool: - """ - The order of the fields of index must be consistent. - """ + def __eq__(self, other: Index) -> bool: + """The order of the fields of index must be consistent.""" return self.to_dict() == other.to_dict() def to_dict(self): - """ - Put collection name, field name and index params into dict. - """ - _dict = { + """Put collection name, field name and index params into dict.""" + return { "collection": self._collection._name, "field": self._field_name, "index_name": self._index_name, - "index_param": self.params + "index_param": self.params, } - return _dict - - def drop(self, timeout=None, **kwargs): - """ - Drop an index and its corresponding index files. - - :param timeout: An optional duration of time in seconds to allow for the RPC. When timeout - is set to None, client waits until server response or error occur - :type timeout: float - :param kwargs: - * *index_name* (``str``) -- - The name of index. If no index is specified, the default index name is used. + def drop(self, timeout: Optional[float] = None, **kwargs): + """Drop an index and its corresponding index files. + Args: + timeout(float, optional): An optional duration of time in seconds to allow + for the RPC. When timeout is set to None, client waits until server response + or error occur + kwargs: + * *index_name* (``str``) -- + The name of index. If no index is specified, the default index name is used. """ copy_kwargs = copy.deepcopy(kwargs) index_name = copy_kwargs.pop("index_name", Config.IndexName) conn = self._get_connection() - conn.drop_index(self._collection.name, self.field_name, index_name, timeout=timeout, **copy_kwargs) + conn.drop_index( + collection_name=self._collection.name, + field_name=self.field_name, + index_name=index_name, + timeout=timeout, + **copy_kwargs, + ) diff --git a/pymilvus/orm/iterator.py b/pymilvus/orm/iterator.py index 42729c476..69c068d3b 100644 --- a/pymilvus/orm/iterator.py +++ b/pymilvus/orm/iterator.py @@ -1,19 +1,50 @@ -from .constants import CALC_DIST_JACCARD, CALC_DIST_COSINE, OFFSET, LIMIT, METRIC_TYPE,\ - FIELDS, PARAMS, RADIUS, RANGE_FILTER, DEFAULT_MAX_L2_DISTANCE, DEFAULT_MIN_IP_DISTANCE, \ - DEFAULT_MAX_HAMMING_DISTANCE, DEFAULT_MAX_TANIMOTO_DISTANCE, DEFAULT_MAX_JACCARD_DISTANCE, \ - DEFAULT_MIN_COSINE_DISTANCE, MAX_FILTERED_IDS_COUNT_ITERATION, ITERATION_EXTENSION_REDUCE_RATE, \ - CALC_DIST_L2, CALC_DIST_IP, CALC_DIST_HAMMING, CALC_DIST_TANIMOTO - -from .types import DataType -from ..exceptions import ( - MilvusException, +from typing import Any, Dict, List, Optional, TypeVar + +from pymilvus.exceptions import MilvusException + +from .connections import Connections +from .constants import ( + CALC_DIST_COSINE, + CALC_DIST_HAMMING, + CALC_DIST_IP, + CALC_DIST_JACCARD, + CALC_DIST_L2, + CALC_DIST_TANIMOTO, + DEFAULT_MAX_HAMMING_DISTANCE, + DEFAULT_MAX_JACCARD_DISTANCE, + DEFAULT_MAX_L2_DISTANCE, + DEFAULT_MAX_TANIMOTO_DISTANCE, + DEFAULT_MIN_COSINE_DISTANCE, + DEFAULT_MIN_IP_DISTANCE, + FIELDS, + ITERATION_EXTENSION_REDUCE_RATE, + LIMIT, + MAX_FILTERED_IDS_COUNT_ITERATION, + METRIC_TYPE, + OFFSET, + PARAMS, + RADIUS, + RANGE_FILTER, ) +from .schema import CollectionSchema +from .types import DataType +QueryIterator = TypeVar("QueryIterator") +SearchIterator = TypeVar("SearchIterator") -class QueryIterator: - def __init__(self, connection, collection_name, expr, output_fields=None, partition_names=None, schema=None, - timeout=None, **kwargs): +class QueryIterator: + def __init__( + self, + connection: Connections, + collection_name: str, + expr: str, + output_fields: Optional[List[str]] = None, + partition_names: Optional[List[str]] = None, + schema: Optional[CollectionSchema] = None, + timeout: Optional[float] = None, + **kwargs, + ) -> QueryIterator: self._conn = connection self._collection_name = collection_name self._expr = expr @@ -38,76 +69,88 @@ def __seek(self): first_cursor_kwargs[LIMIT] = self._kwargs[OFFSET] first_cursor_kwargs[ITERATION_EXTENSION_REDUCE_RATE] = 0 - res = self._conn.query(self._collection_name, self._expr, self._output_fields, self._partition_names, - timeout=self._timeout, **first_cursor_kwargs) + res = self._conn.query( + collection_name=self._collection_name, + expr=self._expr, + output_field=self._output_fields, + partition_name=self._partition_names, + timeout=self._timeout, + **first_cursor_kwargs, + ) self.__update_cursor(res) self._kwargs[OFFSET] = 0 - def __maybe_cache(self, result): + def __maybe_cache(self, result: List): if len(result) < 2 * self._kwargs[LIMIT]: return start = self._kwargs[LIMIT] cache_result = result[start:] - cache_id = iteratorCache.cache(cache_result, NO_CACHE_ID) + cache_id = iterator_cache.cache(cache_result, NO_CACHE_ID) self._cache_id_in_use = cache_id - def __is_res_sufficient(self, res): + def __is_res_sufficient(self, res: List): return res is not None and len(res) >= self._kwargs[LIMIT] def next(self): - cached_res = iteratorCache.fetch_cache(self._cache_id_in_use) + cached_res = iterator_cache.fetch_cache(self._cache_id_in_use) ret = None if self.__is_res_sufficient(cached_res): - ret = cached_res[0:self._kwargs[LIMIT]] - res_to_cache = cached_res[self._kwargs[LIMIT]:] - iteratorCache.cache(res_to_cache, self._cache_id_in_use) + ret = cached_res[0 : self._kwargs[LIMIT]] + res_to_cache = cached_res[self._kwargs[LIMIT] :] + iterator_cache.cache(res_to_cache, self._cache_id_in_use) else: - iteratorCache.release_cache(self._cache_id_in_use) + iterator_cache.release_cache(self._cache_id_in_use) current_expr = self.__setup_next_expr() - res = self._conn.query(self._collection_name, current_expr, self._output_fields, self._partition_names, - timeout=self._timeout, **self._kwargs) + res = self._conn.query( + collection_name=self._collection_name, + expr=current_expr, + output_fields=self._output_fields, + partition_names=self._partition_names, + timeout=self._timeout, + **self._kwargs, + ) self.__maybe_cache(res) - ret = res[0:min(self._kwargs[LIMIT], len(res))] + ret = res[0 : min(self._kwargs[LIMIT], len(res))] self.__update_cursor(ret) return ret def __setup__pk_prop(self): fields = self._schema[FIELDS] for field in fields: - if field['is_primary']: - if field['type'] == DataType.VARCHAR: + if field["is_primary"]: + if field["type"] == DataType.VARCHAR: self._pk_str = True else: self._pk_str = False - self._pk_field_name = field['name'] + self._pk_field_name = field["name"] break if self._pk_field_name is None or self._pk_field_name == "": raise MilvusException(message="schema must contain pk field, broke") - def __setup_next_expr(self): + def __setup_next_expr(self) -> None: current_expr = self._expr if self._next_id is None: return current_expr filtered_pk_str = "" if self._pk_str: - filtered_pk_str = f"{self._pk_field_name} > \"{self._next_id}\"" + filtered_pk_str = f'{self._pk_field_name} > "{self._next_id}"' else: filtered_pk_str = f"{self._pk_field_name} > {self._next_id}" if current_expr is None or len(current_expr) == 0: return filtered_pk_str return current_expr + " and " + filtered_pk_str - def __update_cursor(self, res): + def __update_cursor(self, res: List) -> None: if len(res) == 0: return self._next_id = res[-1][self._pk_field_name] - def close(self): + def close(self) -> None: # release cache in use - iteratorCache.release_cache(self._cache_id_in_use) + iterator_cache.release_cache(self._cache_id_in_use) -def default_radius(metrics): +def default_radius(metrics: str): if metrics is CALC_DIST_L2: return DEFAULT_MAX_L2_DISTANCE if metrics is CALC_DIST_IP: @@ -124,16 +167,35 @@ def default_radius(metrics): class SearchIterator: - - def __init__(self, connection, collection_name, data, ann_field, param, limit, expr=None, partition_names=None, - output_fields=None, timeout=None, round_decimal=-1, schema=None, **kwargs): + def __init__( + self, + connection: Connections, + collection_name: str, + data: List, + ann_field: str, + param: Dict, + limit: int, + expr: Optional[str] = None, + partition_names: Optional[List[str]] = None, + output_fields: Optional[List[str]] = None, + timeout: Optional[float] = None, + round_decimal: int = -1, + schema: Optional[CollectionSchema] = None, + **kwargs, + ) -> SearchIterator: if len(data) > 1: raise MilvusException(message="Not support multiple vector iterator at present") self._conn = connection - self._iterator_params = {'collection_name': collection_name, "data": data, - "ann_field": ann_field, "limit": limit, - "output_fields": output_fields, "partition_names": partition_names, - "timeout": timeout, "round_decimal": round_decimal} + self._iterator_params = { + "collection_name": collection_name, + "data": data, + "ann_field": ann_field, + "limit": limit, + "output_fields": output_fields, + "partition_names": partition_names, + "timeout": timeout, + "round_decimal": round_decimal, + } self._expr = expr self._param = param self._kwargs = kwargs @@ -149,12 +211,12 @@ def __init__(self, connection, collection_name, data, ann_field, param, limit, e def __setup__pk_prop(self): fields = self._schema[FIELDS] for field in fields: - if field['is_primary']: - if field['type'] == DataType.VARCHAR: + if field["is_primary"]: + if field["type"] == DataType.VARCHAR: self._pk_str = True else: self._pk_str = False - self._pk_field_name = field['name'] + self._pk_field_name = field["name"] break if self._pk_field_name is None or self._pk_field_name == "": raise MilvusException(message="schema must contain pk field, broke") @@ -173,7 +235,7 @@ def __seek(self): if self._kwargs.get(OFFSET, 0) != 0: raise MilvusException(message="Not support offset when searching iteration") - def __update_cursor(self, res): + def __update_cursor(self, res: Any): if len(res[0]) == 0: return last_hit = res[0][-1] @@ -187,39 +249,43 @@ def __update_cursor(self, res): if hit.distance == last_hit.distance: self._filtered_ids.append(hit.id) if len(self._filtered_ids) > MAX_FILTERED_IDS_COUNT_ITERATION: - raise MilvusException(message=f"filtered ids length has accumulated to more than " - f"{str(MAX_FILTERED_IDS_COUNT_ITERATION)}, " - f"there is a danger of overly memory consumption") - + raise MilvusException( + message=f"filtered ids length has accumulated to more than " + f"{MAX_FILTERED_IDS_COUNT_ITERATION!s}, " + f"there is a danger of overly memory consumption" + ) def next(self): next_params = self.__next_params() next_expr = self.__filtered_duplicated_result_expr(self._expr) - res = self._conn.search(self._iterator_params['collection_name'], - self._iterator_params['data'], - self._iterator_params['ann_field'], - next_params, - self._iterator_params['limit'], - next_expr, - self._iterator_params['partition_names'], - self._iterator_params['output_fields'], - self._iterator_params['round_decimal'], - timeout=self._iterator_params['timeout'], - schema=self._schema, **self._kwargs) + res = self._conn.search( + self._iterator_params["collection_name"], + self._iterator_params["data"], + self._iterator_params["ann_field"], + next_params, + self._iterator_params["limit"], + next_expr, + self._iterator_params["partition_names"], + self._iterator_params["output_fields"], + self._iterator_params["round_decimal"], + timeout=self._iterator_params["timeout"], + schema=self._schema, + **self._kwargs, + ) self.__update_cursor(res) return res # at present, the range_filter parameter means 'larger/less and equal', # so there would be vectors with same distances returned multiple times in different pages # we need to refine and remove these results before returning - def __filtered_duplicated_result_expr(self, expr): + def __filtered_duplicated_result_expr(self, expr: str): if len(self._filtered_ids) == 0: return expr filtered_ids_str = "" for filtered_id in self._filtered_ids: if self._pk_str: - filtered_ids_str += f"\"{filtered_id}\"," + filtered_ids_str += f'"{filtered_id}",' else: filtered_ids_str += f"{filtered_id}," filtered_ids_str = filtered_ids_str[0:-1] @@ -242,26 +308,25 @@ def close(self): class IteratorCache: - - def __init__(self): + def __init__(self) -> None: self._cache_id = 0 self._cache_map = {} - def cache(self, result, cache_id): + def cache(self, result: Any, cache_id: int): if cache_id == NO_CACHE_ID: self._cache_id += 1 cache_id = self._cache_id self._cache_map[cache_id] = result return cache_id - def fetch_cache(self, cache_id): + def fetch_cache(self, cache_id: int): return self._cache_map.get(cache_id, None) - def release_cache(self, cache_id): + def release_cache(self, cache_id: int): if self._cache_map.get(cache_id, None) is not None: self._cache_map.pop(cache_id) NO_CACHE_ID = -1 # Singleton Mode in Python -iteratorCache = IteratorCache() +iterator_cache = IteratorCache() diff --git a/pymilvus/orm/mutation.py b/pymilvus/orm/mutation.py index a66e0c905..b11ca9276 100644 --- a/pymilvus/orm/mutation.py +++ b/pymilvus/orm/mutation.py @@ -10,9 +10,11 @@ # or implied. See the License for the specific language governing permissions and limitations under # the License. +from typing import Any + class MutationResult: - def __init__(self, mr): + def __init__(self, mr: Any) -> None: self._mr = mr @property @@ -51,7 +53,7 @@ def succ_index(self): def err_index(self): return self._mr.err_index if self._mr else [] - def __str__(self): + def __str__(self) -> str: """ Return the information of mutation result diff --git a/pymilvus/orm/partition.py b/pymilvus/orm/partition.py index 1f3e226a1..f62e0e0e9 100644 --- a/pymilvus/orm/partition.py +++ b/pymilvus/orm/partition.py @@ -10,31 +10,42 @@ # or implied. See the License for the specific language governing permissions and limitations under # the License. -import json -from typing import Union, List +from typing import Dict, List, Optional, TypeVar, Union -import pandas +import pandas as pd +import ujson -from ..exceptions import ( - PartitionNotExistException, +from pymilvus.client.types import Replica +from pymilvus.exceptions import ( ExceptionsMessage, - MilvusException + MilvusException, + PartitionNotExistException, ) -from .search import SearchResult from .mutation import MutationResult -from ..client.types import Replica +from .search import SearchResult + +CollectionT = TypeVar("Collection") +PartitionT = TypeVar("Partition") class Partition: - def __init__(self, collection, name, description="", **kwargs): + def __init__( + self, + collection: Union[CollectionT, str], + name: str, + description: str = "", + **kwargs, + ) -> PartitionT: from .collection import Collection + if isinstance(collection, Collection): self._collection = collection elif isinstance(collection, str): self._collection = Collection(collection) else: - raise MilvusException(message="Collection must be of type pymilvus.Collection or String") + msg = "Collection must be of type pymilvus.Collection or String" + raise MilvusException(message=msg) self._name = name self._description = description @@ -46,12 +57,14 @@ def __init__(self, collection, name, description="", **kwargs): conn = self._get_connection() conn.create_partition(self._collection.name, self.name, **kwargs) - def __repr__(self): - return json.dumps({ - 'name': self.name, - 'collection_name': self._collection.name, - 'description': self.description, - }) + def __repr__(self) -> str: + return ujson.dumps( + { + "name": self.name, + "collection_name": self._collection.name, + "description": self.description, + } + ) def _get_connection(self): return self._collection._get_connection() @@ -103,8 +116,9 @@ def num_entities(self, **kwargs) -> int: """int: number of entities in the partition Examples: - >>> from pymilvus import connections, Collection, Partition, FieldSchema, CollectionSchema, DataType + >>> from pymilvus import connections >>> connections.connect() + >>> from pymilvus import Collection, Partition, FieldSchema, CollectionSchema, DataType >>> schema = CollectionSchema([ ... FieldSchema("film_id", DataType.INT64, is_primary=True), ... FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2) @@ -120,28 +134,33 @@ def num_entities(self, **kwargs) -> int: 10 """ conn = self._get_connection() - stats = conn.get_partition_stats(collection_name=self._collection.name, partition_name=self.name, **kwargs) + stats = conn.get_partition_stats( + collection_name=self._collection.name, partition_name=self.name, **kwargs + ) result = {stat.key: stat.value for stat in stats} result["row_count"] = int(result["row_count"]) return result["row_count"] - def flush(self, timeout=None, **kwargs): - """ Seal all segment in the collection of this partition. Inserts after flushing will be written into - new segments. Only sealed segments can be indexed. + def flush(self, timeout: Optional[float] = None, **kwargs): + """Seal all segment in the collection of this partition. + Inserts after flushing will be written into new segments. + Only sealed segments can be indexed. Args: - timeout (float): an optional duration of time in seconds to allow for the RPCs. - If timeout is not set, the client keeps waiting until the server responds or an error occurs. + timeout (float, optional): an optional duration of time in seconds to allow + for the RPCs. If timeout is not set, the client keeps waiting until the server + responds or an error occurs. """ conn = self._get_connection() conn.flush([self._collection.name], timeout=timeout, **kwargs) - def drop(self, timeout=None, **kwargs): - """ Drop the partition, the same as Collection.drop_partition + def drop(self, timeout: Optional[float] = None, **kwargs): + """Drop the partition, the same as Collection.drop_partition Args: - timeout (``float``, optional): an optional duration of time in seconds to allow for the RPCs. - If timeout is not set, the client keeps waiting until the server responds or an error occurs. + timeout (``float``, optional): an optional duration of time in seconds to allow + for the RPCs. If timeout is not set, the client keeps waiting until the server + responds or an error occurs. Raises: PartitionNotExistException: If the partitoin doesn't exist @@ -158,20 +177,22 @@ def drop(self, timeout=None, **kwargs): raise PartitionNotExistException(message=ExceptionsMessage.PartitionNotExist) return conn.drop_partition(self._collection.name, self.name, timeout=timeout, **kwargs) - def load(self, replica_number: int=1, timeout=None, **kwargs): - """ Load the partition data into memory. + def load(self, replica_number: int = 1, timeout: Optional[float] = None, **kwargs): + """Load the partition data into memory. Args: replica_number (``int``, optional): The replica number to load, defaults to 1. - timeout (``float``, optional): an optional duration of time in seconds to allow for the RPCs. - If timeout is not set, the client keeps waiting until the server responds or an error occurs. + timeout (``float``, optional): an optional duration of time in seconds to allow + for the RPCs. If timeout is not set, the client keeps waiting until the + server responds or an error occurs. Raises: MilvusException: If anything goes wrong Examples: - >>> from pymilvus import connections, Collection, Partition, FieldSchema, CollectionSchema, DataType + >>> from pymilvus import connections >>> connections.connect() + >>> from pymilvus import Collection, Partition, FieldSchema, CollectionSchema, DataType >>> schema = CollectionSchema([ ... FieldSchema("film_id", DataType.INT64, is_primary=True), ... FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2) @@ -182,22 +203,30 @@ def load(self, replica_number: int=1, timeout=None, **kwargs): """ conn = self._get_connection() if conn.has_partition(self._collection.name, self.name, **kwargs): - return conn.load_partitions(self._collection.name, [self.name], replica_number, timeout=timeout, **kwargs) + return conn.load_partitions( + collection_name=self._collection.name, + partition_names=[self.name], + replica_number=replica_number, + timeout=timeout, + **kwargs, + ) raise PartitionNotExistException(message=ExceptionsMessage.PartitionNotExist) - def release(self, timeout=None, **kwargs): - """ Release the partition data from memory. + def release(self, timeout: Optional[float] = None, **kwargs): + """Release the partition data from memory. Args: - timeout (``float``, optional): an optional duration of time in seconds to allow for the RPCs. - If timeout is not set, the client keeps waiting until the server responds or an error occurs. + timeout (``float``, optional): an optional duration of time in seconds to allow + for the RPCs. If timeout is not set, the client keeps waiting until the + server responds or an error occurs. Raises: MilvusException: If anything goes wrong Examples: - >>> from pymilvus import connections, Collection, Partition, FieldSchema, CollectionSchema, DataType + >>> from pymilvus import connections >>> connections.connect() + >>> from pymilvus import Collection, Partition, FieldSchema, CollectionSchema, DataType >>> schema = CollectionSchema([ ... FieldSchema("film_id", DataType.INT64, is_primary=True), ... FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2) @@ -209,18 +238,29 @@ def release(self, timeout=None, **kwargs): """ conn = self._get_connection() if conn.has_partition(self._collection.name, self._name, **kwargs): - return conn.release_partitions(self._collection.name, [self.name], timeout=timeout, **kwargs) + return conn.release_partitions( + collection_name=self._collection.name, + partition_names=[self.name], + timeout=timeout, + **kwargs, + ) raise PartitionNotExistException(message=ExceptionsMessage.PartitionNotExist) - def insert(self, data: Union[List, pandas.DataFrame], timeout=None, **kwargs) -> MutationResult: - """ Insert data into the partition, the same as Collection.insert(data, [partition]) + def insert( + self, + data: Union[List, pd.DataFrame], + timeout: Optional[float] = None, + **kwargs, + ) -> MutationResult: + """Insert data into the partition, the same as Collection.insert(data, [partition]) Args: data (``list/tuple/pandas.DataFrame``): The specified data to insert partition_name (``str``): The partition name which the data will be inserted to, - if partition name is not passed, then the data will be inserted to "_default" partition - timeout (``float``, optional): A duration of time in seconds to allow for the RPC. Defaults to None. - If timeout is set to None, the client keeps waiting until the server responds or an error occurs. + if partition name is not passed, then the data will be inserted to default partition + timeout (``float``, optional): A duration of time in seconds to allow for the RPC + If timeout is set to None, the client keeps waiting until the server + responds or an error occurs. Returns: MutationResult: contains 2 properties `insert_count`, and, `primary_keys` `insert_count`: how may entites have been inserted into Milvus, @@ -229,8 +269,9 @@ def insert(self, data: Union[List, pandas.DataFrame], timeout=None, **kwargs) -> MilvusException: If anything goes wrong. Examples: - >>> from pymilvus import connections, Collection, Partition, FieldSchema, CollectionSchema, DataType + >>> from pymilvus import connections >>> connections.connect() + >>> from pymilvus import Collection, Partition, FieldSchema, CollectionSchema, DataType >>> schema = CollectionSchema([ ... FieldSchema("film_id", DataType.INT64, is_primary=True), ... FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2) @@ -250,24 +291,25 @@ def insert(self, data: Union[List, pandas.DataFrame], timeout=None, **kwargs) -> raise PartitionNotExistException(message=ExceptionsMessage.PartitionNotExist) return self._collection.insert(data, self.name, timeout=timeout, **kwargs) - def delete(self, expr, timeout=None, **kwargs): - """ Delete entities with an expression condition. + def delete(self, expr: str, timeout: Optional[float] = None, **kwargs): + """Delete entities with an expression condition. Args: expr (``str``): The specified data to insert. partition_names (``List[str]``): Name of partitions to delete entities. - timeout (``float``, optional): A duration of time in seconds to allow for the RPC. Defaults to None. - If timeout is set to None, the client keeps waiting until the server responds or an error occurs. + timeout (``float``, optional): A duration of time in seconds to allow for the RPC. + If timeout is set to None, the client keeps waiting until the server responds + or an error occurs. Returns: - MutationResult: contains `delete_count` properties represents how many entities might be deleted. + MutationResult: contains `delete_count` properties represents + how many entities might be deleted. Raises: MilvusException: If anything goes wrong. Examples: - >>> from pymilvus import connections, Collection, Partition, FieldSchema, CollectionSchema, DataType - >>> connections.connect() + >>> from pymilvus import Collection, Partition, FieldSchema, CollectionSchema, DataType >>> schema = CollectionSchema([ ... FieldSchema("film_id", DataType.INT64, is_primary=True), ... FieldSchema("films", DataType.FLOAT_VECTOR, dim=2) @@ -285,15 +327,21 @@ def delete(self, expr, timeout=None, **kwargs): """ return self._collection.delete(expr, self.name, timeout=timeout, **kwargs) - def upsert(self, data: Union[List, pandas.DataFrame], timeout=None, **kwargs) -> MutationResult: - """ Upsert data into the collection. + def upsert( + self, + data: Union[List, pd.DataFrame], + timeout: Optional[float] = None, + **kwargs, + ) -> MutationResult: + """Upsert data into the collection. Args: data (``list/tuple/pandas.DataFrame``): The specified data to upsert partition_name (``str``): The partition name which the data will be upserted at, - if partition name is not passed, then the data will be upserted in "_default" partition - timeout (``float``, optional): A duration of time in seconds to allow for the RPC. Defaults to None. - If timeout is set to None, the client keeps waiting until the server responds or an error occurs. + if partition name is not passed, then the data will be upserted in default partition + timeout (``float``, optional): A duration of time in seconds to allow for the RPC. + If timeout is set to None, the client keeps waiting until the server responds + or an error occurs. Returns: MutationResult: contains 2 properties `upsert_count`, and, `primary_keys` `upsert_count`: how may entites have been upserted at Milvus, @@ -302,8 +350,7 @@ def upsert(self, data: Union[List, pandas.DataFrame], timeout=None, **kwargs) -> MilvusException: If anything goes wrong. Examples: - >>> from pymilvus import connections, Collection, Partition, FieldSchema, CollectionSchema, DataType - >>> connections.connect() + >>> from pymilvus import Collection, Partition, FieldSchema, CollectionSchema, DataType >>> schema = CollectionSchema([ ... FieldSchema("film_id", DataType.INT64, is_primary=True), ... FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2) @@ -324,14 +371,24 @@ def upsert(self, data: Union[List, pandas.DataFrame], timeout=None, **kwargs) -> return self._collection.upsert(data, self.name, timeout=timeout, **kwargs) - def search(self, data, anns_field, param, limit, - expr=None, output_fields=None, timeout=None, round_decimal=-1, **kwargs) -> SearchResult: - """ Conducts a vector similarity search with an optional boolean expression as filter. + def search( + self, + data: List, + anns_field: str, + param: Dict, + limit: int, + expr: Optional[str] = None, + output_fields: Optional[List[str]] = None, + timeout: Optional[float] = None, + round_decimal: int = -1, + **kwargs, + ) -> SearchResult: + """Conducts a vector similarity search with an optional boolean expression as filter. Args: data (``List[List[float]]``): The vectors of search data. - the length of data is number of query (nq), and the dim of every vector in data must be equal to - the vector field's of collection. + the length of data is number of query (nq), + and the dim of every vector in data must be equal to the vector field of collection. anns_field (``str``): The name of the vector field used to search of collection. param (``dict[str, Any]``): @@ -367,8 +424,8 @@ def search(self, data, anns_field, param, limit, output_fields (``List[str]``, optional): The name of fields to return in the search result. Can only get scalar fields. - round_decimal (``int``, optional): The specified number of decimal places of returned distance - Defaults to -1 means no round to returned distance. + round_decimal (``int``, optional): The specified number of decimal places of + returned distance......... Defaults to -1 means no round to returned distance. **kwargs (``dict``): Optional search params * *_async* (``bool``, optional) @@ -384,9 +441,9 @@ def search(self, data, anns_field, param, limit, Options of consistency level: Strong, Bounded, Eventually, Session, Customized. - Note: this parameter will overwrite the same parameter specified when user created the collection, - if no consistency level was specified, search will use the consistency level when you create the - collection. + Note: this parameter overwrites the same one specified when creating collection, + if no consistency level was specified, search will use the + consistency level when you create the collection. * *guarantee_timestamp* (``int``, optional) Instructs Milvus to see all operations performed before this timestamp. @@ -418,15 +475,16 @@ def search(self, data, anns_field, param, limit, MilvusException: If anything goes wrong Examples: - >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType + >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType >>> import random - >>> connections.connect() >>> schema = CollectionSchema([ ... FieldSchema("film_id", DataType.INT64, is_primary=True), ... FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2) ... ]) >>> collection = Collection("test_collection_search", schema) - >>> collection.create_index("films", {"index_type": "FLAT", "metric_type": "L2", "params": {}}) + >>> collection.create_index( + ... "films", + ... {"index_type": "FLAT", "metric_type": "L2", "params": {}}) >>> partition = Partition(collection, "comedy", "comedy films") >>> # insert >>> data = [ @@ -449,8 +507,8 @@ def search(self, data, anns_field, param, limit, >>> assert len(hits) == 2 >>> print(f"- Total hits: {len(hits)}, hits ids: {hits.ids} ") - Total hits: 2, hits ids: [8, 5] - >>> print(f"- Top1 hit id: {hits[0].id}, distance: {hits[0].distance}, score: {hits[0].score} ") - - Top1 hit id: 8, distance: 0.10143111646175385, score: 0.10143111646175385 + >>> print(f"- Top1 hit id: {hits[0].id}, score: {hits[0].score} ") + - Top1 hit id: 8, score: 0.10143111646175385 """ return self._collection.search( @@ -466,14 +524,21 @@ def search(self, data, anns_field, param, limit, **kwargs, ) - def query(self, expr, output_fields=None, timeout=None, **kwargs): - """ Query with expressions + def query( + self, + expr: str, + output_fields: Optional[List[str]] = None, + timeout: Optional[float] = None, + **kwargs, + ): + """Query with expressions Args: expr (``str``): The query expression. output_fields(``List[str]``): A list of field names to return. Defaults to None. - timeout (``float``, optional): A duration of time in seconds to allow for the RPC. Defaults to None. - If timeout is set to None, the client keeps waiting until the server responds or an error occurs. + timeout (``float``, optional): A duration of time in seconds to allow for the RPC. + If timeout is set to None, the client keeps waiting until the server responds + or an error occurs. **kwargs (``dict``, optional): * *consistency_level* (``str/int``, optional) @@ -481,9 +546,9 @@ def query(self, expr, output_fields=None, timeout=None, **kwargs): Options of consistency level: Strong, Bounded, Eventually, Session, Customized. - Note: this parameter will overwrite the same parameter specified when user created the collection, - if no consistency level was specified, search will use the consistency level when you create the - collection. + Note: this parameter overwrites the same one specified when creating collection, + if no consistency level was specified, search will use the + consistency level when you create the collection. * *guarantee_timestamp* (``int``, optional) Instructs Milvus to see all operations performed before this timestamp. @@ -513,16 +578,17 @@ def query(self, expr, output_fields=None, timeout=None, **kwargs): MilvusException: If anything goes wrong Examples: - >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType + >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType >>> import random - >>> connections.connect() >>> schema = CollectionSchema([ ... FieldSchema("film_id", DataType.INT64, is_primary=True), ... FieldSchema("film_date", DataType.INT64), ... FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2) ... ]) >>> collection = Collection("test_collection_query", schema) - >>> collection.create_index("films", {"index_type": "FLAT", "metric_type": "L2", "params": {}}) + >>> collection.create_index( + ... "films", + ... {"index_type": "FLAT", "metric_type": "L2", "params": {}}) >>> partition = Partition(collection, "comedy", "comedy films") >>> # insert >>> data = [ @@ -540,14 +606,20 @@ def query(self, expr, output_fields=None, timeout=None, **kwargs): - Query results: [{'film_id': 0, 'film_date': 2000}, {'film_id': 1, 'film_date': 2001}] """ - return self._collection.query(expr, output_fields, partition_names=[self.name], timeout=timeout, **kwargs) + return self._collection.query( + expr=expr, + output_fields=output_fields, + partition_names=[self.name], + timeout=timeout, + **kwargs, + ) - def get_replicas(self, timeout=None, **kwargs) -> Replica: + def get_replicas(self, timeout: Optional[float] = None, **kwargs) -> Replica: """Get the current loaded replica information Args: - timeout (``float``, optional): An optional duration of time in seconds to allow for the RPC. When timeout - is set to None, client waits until server response or error occur. + timeout (``float``, optional): An optional duration of time in seconds to allow for + the RPC. When timeout is set to None, client waits until server response or error occur. Returns: Replica: All the replica information. """ diff --git a/pymilvus/orm/prepare.py b/pymilvus/orm/prepare.py index fe7632e40..d2823a10b 100644 --- a/pymilvus/orm/prepare.py +++ b/pymilvus/orm/prepare.py @@ -9,73 +9,80 @@ # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. + import copy +from typing import List, Tuple, Union -import numpy -import pandas +import numpy as np +import pandas as pd -from ..exceptions import ( +from pymilvus.exceptions import ( DataNotMatchException, DataTypeNotSupportException, ExceptionsMessage, UpsertAutoIDTrueException, ) +from .schema import CollectionSchema + class Prepare: @classmethod - def prepare_insert_or_upsert_data(cls, data, schema, is_insert=True): - if not isinstance(data, (list, tuple, pandas.DataFrame)): + def prepare_insert_data( + cls, + data: Union[List, Tuple, pd.DataFrame], + schema: CollectionSchema, + ) -> List: + if not isinstance(data, (list, tuple, pd.DataFrame)): raise DataTypeNotSupportException(message=ExceptionsMessage.DataTypeNotSupport) fields = schema.fields entities = [] # Entities - if isinstance(data, pandas.DataFrame): - if schema.auto_id: - if is_insert is False: - raise UpsertAutoIDTrueException(message=ExceptionsMessage.UpsertAutoIDTrue) - if schema.primary_field.name in data: - if not data[schema.primary_field.name].isnull().all(): - raise DataNotMatchException(message=ExceptionsMessage.AutoIDWithData) - for i, field in enumerate(fields): + if isinstance(data, pd.DataFrame): + if ( + schema.auto_id + and schema.primary_field.name in data + and not data[schema.primary_field.name].isnull().all() + ): + raise DataNotMatchException(message=ExceptionsMessage.AutoIDWithData) + for field in fields: if field.is_primary and field.auto_id: continue values = [] if field.name in list(data.columns): values = list(data[field.name]) - entities.append({"name": field.name, - "type": field.dtype, - "values": values}) - else: - if schema.auto_id: - if is_insert is False: - raise UpsertAutoIDTrueException(message=ExceptionsMessage.UpsertAutoIDTrue) + entities.append({"name": field.name, "type": field.dtype, "values": values}) + return entities - tmp_fields = copy.deepcopy(fields) - for i, field in enumerate(tmp_fields): - # TODO Goose: Checking auto_id and is_primary only, maybe different than - # schema.is_primary, schema.auto_id, need to check why and how schema is built. - if field.is_primary and field.auto_id: - tmp_fields.pop(i) + tmp_fields = copy.deepcopy(fields) + for i, field in enumerate(tmp_fields): + # TODO Goose: Checking auto_id and is_primary only, maybe different than + # schema.is_primary, schema.auto_id, need to check why and how schema is built. + if field.is_primary and field.auto_id: + tmp_fields.pop(i) + + for i, field in enumerate(tmp_fields): + try: + if isinstance(data[i], np.ndarray): + d = data[i].tolist() + else: + d = data[i] if data[i] is not None else [] - for i, field in enumerate(tmp_fields): - try: - d = data[i] - # if pass in None, considering to be passed in order according to the schema - if d is None: - d = [] - if isinstance(data[i], numpy.ndarray): - d = data[i].tolist() - entities.append({ - "name": field.name, - "type": field.dtype, - "values": d}) - # the last missing part of data is also completed in order according to the schema - except IndexError: - entities.append({ - "name": field.name, - "type": field.dtype, - "values": []}) + entities.append({"name": field.name, "type": field.dtype, "values": d}) + # the last missing part of data is also completed in order according to the schema + except IndexError: + entities.append({"name": field.name, "type": field.dtype, "values": []}) return entities + + @classmethod + def prepare_upsert_data( + cls, + data: Union[List, Tuple, pd.DataFrame], + schema: CollectionSchema, + ) -> List: + if schema.auto_id: + raise UpsertAutoIDTrueException(message=ExceptionsMessage.UpsertAutoIDTrue) + + return cls.prepare_insert_data(data, schema) diff --git a/pymilvus/orm/role.py b/pymilvus/orm/role.py index b8c49d36b..eedb733cb 100644 --- a/pymilvus/orm/role.py +++ b/pymilvus/orm/role.py @@ -9,15 +9,19 @@ # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. + from .connections import connections +INCLUDE_USER_INFO, NOT_INCLUDE_USER_INFO = True, False + + class Role: - """ Role, can be granted privileges which are allowed to execute some objects' apis. """ + """Role, can be granted privileges which are allowed to execute some objects' apis.""" - def __init__(self, name: str, using="default", **kwargs): - """ Constructs a role by name - :param name: role name. - :type name: str + def __init__(self, name: str, using: str = "default", **kwargs) -> None: + """Constructs a role by name + :param name: role name. + :type name: str """ self._name = name self._using = using @@ -31,7 +35,7 @@ def name(self): return self._name def create(self): - """ Create a role + """Create a role It will success if the role isn't existed, otherwise fail. :example: @@ -46,7 +50,7 @@ def create(self): return self._get_connection().create_role(self._name) def drop(self): - """ Drop a role + """Drop a role It will success if the role is existed, otherwise fail. :example: @@ -61,7 +65,7 @@ def drop(self): return self._get_connection().drop_role(self._name) def add_user(self, username: str): - """ Add user to role + """Add user to role The user will get permissions that the role are allowed to perform operations. :param username: user name. :type username: str @@ -78,7 +82,7 @@ def add_user(self, username: str): return self._get_connection().add_user_to_role(username, self._name) def remove_user(self, username: str): - """ Remove user from role + """Remove user from role The user will remove permissions that the role are allowed to perform operations. :param username: user name. :type username: str @@ -95,7 +99,7 @@ def remove_user(self, username: str): return self._get_connection().remove_user_from_role(username, self._name) def get_users(self): - """ Get all users who are added to the role. + """Get all users who are added to the role. :return a RoleInfo object which contains a RoleItem group According to the RoleItem, you can get a list of usernames. @@ -110,13 +114,13 @@ def get_users(self): >>> users = role.get_users() >>> print(f"users added to the role: {users}") """ - roles = self._get_connection().select_one_role(self._name, True) + roles = self._get_connection().select_one_role(self._name, INCLUDE_USER_INFO) if len(roles.groups) == 0: return [] return roles.groups[0].users def is_exist(self): - """ Check whether the role is existed. + """Check whether the role is existed. :return a bool value It will be True if the role is existed, otherwise False. @@ -128,11 +132,11 @@ def is_exist(self): >>> is_exist = role.is_exist() >>> print(f"the role: {is_exist}") """ - roles = self._get_connection().select_one_role(self._name, False) + roles = self._get_connection().select_one_role(self._name, NOT_INCLUDE_USER_INFO) return len(roles.groups) != 0 def grant(self, object: str, object_name: str, privilege: str, db_name: str = "default"): - """ Grant a privilege for the role + """Grant a privilege for the role :param object: object type. :type object: str :param object_name: identifies a specific object name. @@ -149,30 +153,31 @@ def grant(self, object: str, object_name: str, privilege: str, db_name: str = "d >>> role = Role(role_name) >>> role.grant("Collection", collection_name, "Insert") """ - return self._get_connection().grant_privilege(self._name, object, object_name, privilege, db_name) + return self._get_connection().grant_privilege( + self._name, object, object_name, privilege, db_name + ) def revoke(self, object: str, object_name: str, privilege: str, db_name: str = "default"): - """ Revoke a privilege for the role - :param object: object type. - :type object: str - :param object_name: identifies a specific object name. - :type object_name: str - :param privilege: privilege name. - :type privilege: str - :param db_name: db name. - :type db_name: str - - :example: + """Revoke a privilege for the role + Args: + object(str): object type. + object_name(str): identifies a specific object name. + privilege(str): privilege name. + db_name(str): db name. + + Examples: >>> from pymilvus import connections >>> from pymilvus.orm.role import Role >>> connections.connect() >>> role = Role(role_name) >>> role.revoke("Collection", collection_name, "Insert") """ - return self._get_connection().revoke_privilege(self._name, object, object_name, privilege, db_name) + return self._get_connection().revoke_privilege( + self._name, object, object_name, privilege, db_name + ) def list_grant(self, object: str, object_name: str, db_name: str = "default"): - """ List a grant info for the role and the specific object + """List a grant info for the role and the specific object :param object: object type. :type object: str :param object_name: identifies a specific object name. @@ -183,7 +188,8 @@ def list_grant(self, object: str, object_name: str, db_name: str = "default"): :rtype GrantInfo GrantInfo groups: - - GrantItem: , , , , + - GrantItem: , , , + , :example: >>> from pymilvus import connections @@ -192,17 +198,20 @@ def list_grant(self, object: str, object_name: str, db_name: str = "default"): >>> role = Role(role_name) >>> role.list_grant("Collection", collection_name) """ - return self._get_connection().select_grant_for_role_and_object(self._name, object, object_name, db_name) + return self._get_connection().select_grant_for_role_and_object( + self._name, object, object_name, db_name + ) def list_grants(self, db_name: str = "default"): - """ List a grant info for the role + """List a grant info for the role :param db_name: db name. :type db_name: str :return a GrantInfo object :rtype GrantInfo GrantInfo groups: - - GrantItem: , , , , + - GrantItem: , , , + , :example: >>> from pymilvus import connections diff --git a/pymilvus/orm/schema.py b/pymilvus/orm/schema.py index 9b57678f2..67c70efed 100644 --- a/pymilvus/orm/schema.py +++ b/pymilvus/orm/schema.py @@ -11,30 +11,37 @@ # the License. import copy -from typing import List, Union, Dict -import pandas +from typing import Any, Dict, List, Optional, Union + +import pandas as pd from pandas.api.types import is_list_like, is_scalar -from ..grpc_gen import schema_pb2 as schema_types -from .constants import COMMON_TYPE_PARAMS -from .types import DataType, map_numpy_dtype_to_datatype, infer_dtype_bydata, infer_dtype_by_scaladata -from ..exceptions import ( - ParamError, +from pymilvus.exceptions import ( + AutoIDException, CannotInferSchemaException, + DataNotMatchException, DataTypeNotSupportException, - PrimaryKeyException, - PartitionKeyException, + ExceptionsMessage, FieldsTypeException, FieldTypeException, - AutoIDException, - ExceptionsMessage, - DataNotMatchException, + ParamError, + PartitionKeyException, + PrimaryKeyException, SchemaNotReadyException, - UpsertAutoIDTrueException + UpsertAutoIDTrueException, ) +from pymilvus.grpc_gen import schema_pb2 as schema_types +from .constants import COMMON_TYPE_PARAMS +from .types import ( + DataType, + infer_dtype_by_scaladata, + infer_dtype_bydata, + map_numpy_dtype_to_datatype, +) -def validate_primary_key(primary_field): + +def validate_primary_key(primary_field: Any): if primary_field is None: raise PrimaryKeyException(message=ExceptionsMessage.PrimaryKeyNotExist) @@ -42,24 +49,24 @@ def validate_primary_key(primary_field): raise PrimaryKeyException(message=ExceptionsMessage.PrimaryKeyType) -def validate_partition_key(partition_key_field_name, partition_key_field, primary_field_name): +def validate_partition_key( + partition_key_field_name: Any, partition_key_field: Any, primary_field_name: Any +): # not allow partition_key field is primary key field if partition_key_field is not None: if partition_key_field.name == primary_field_name: - PartitionKeyException( - message=ExceptionsMessage.PartitionKeyNotPrimary) + PartitionKeyException(message=ExceptionsMessage.PartitionKeyNotPrimary) if partition_key_field.dtype not in [DataType.INT64, DataType.VARCHAR]: - raise PartitionKeyException( - message=ExceptionsMessage.PartitionKeyType) - else: - if partition_key_field_name is not None: - raise PartitionKeyException( - message=ExceptionsMessage.PartitionKeyFieldNotExist % partition_key_field_name) + raise PartitionKeyException(message=ExceptionsMessage.PartitionKeyType) + elif partition_key_field_name is not None: + raise PartitionKeyException( + message=ExceptionsMessage.PartitionKeyFieldNotExist % partition_key_field_name + ) class CollectionSchema: - def __init__(self, fields, description="", **kwargs): + def __init__(self, fields: List, description: str = "", **kwargs): self._kwargs = copy.deepcopy(kwargs) self._fields = [] self._description = description @@ -75,24 +82,20 @@ def __init__(self, fields, description="", **kwargs): if kwargs.get("check_fields", True): self._check_fields() - def _check_kwargs(self): primary_field_name = self._kwargs.get("primary_field", None) partition_key_field_name = self._kwargs.get("partition_key_field", None) if primary_field_name is not None and not isinstance(primary_field_name, str): - raise PrimaryKeyException( - message=ExceptionsMessage.PrimaryFieldType) + raise PrimaryKeyException(message=ExceptionsMessage.PrimaryFieldType) if partition_key_field_name is not None and not isinstance(partition_key_field_name, str): - raise PartitionKeyException( - message=ExceptionsMessage.PartitionKeyFieldType) + raise PartitionKeyException(message=ExceptionsMessage.PartitionKeyFieldType) for field in self._fields: if not isinstance(field, FieldSchema): raise FieldTypeException(message=ExceptionsMessage.FieldType) - if "auto_id" in self._kwargs: - if not isinstance(self._kwargs["auto_id"], bool): - raise AutoIDException(0, ExceptionsMessage.AutoIDType) + if "auto_id" in self._kwargs and not isinstance(self._kwargs["auto_id"], bool): + raise AutoIDException(0, ExceptionsMessage.AutoIDType) def _check_fields(self): primary_field_name = self._kwargs.get("primary_field", None) @@ -105,21 +108,25 @@ def _check_fields(self): if field.is_primary: if primary_field_name is not None and primary_field_name != field.name: - raise PrimaryKeyException( - message=ExceptionsMessage.PrimaryKeyOnlyOne % (primary_field_name, field.name)) + msg = ExceptionsMessage.PrimaryKeyOnlyOne % (primary_field_name, field.name) + raise PrimaryKeyException(message=msg) self._primary_field = field primary_field_name = field.name if field.is_partition_key: if partition_key_field_name is not None and partition_key_field_name != field.name: - raise PartitionKeyException( - message=ExceptionsMessage.PartitionKeyOnlyOne % (partition_key_field_name, field.name)) + msg = ExceptionsMessage.PartitionKeyOnlyOne % ( + partition_key_field_name, + field.name, + ) + raise PartitionKeyException(message=msg) self._partition_key_field = field partition_key_field_name = field.name validate_primary_key(self._primary_field) - validate_partition_key(partition_key_field_name, - self._partition_key_field, self._primary_field.name) + validate_partition_key( + partition_key_field_name, self._partition_key_field, self._primary_field.name + ) auto_id = self._kwargs.get("auto_id", False) if auto_id: @@ -132,22 +139,23 @@ def _check(self): self._check_kwargs() self._check_fields() - def __repr__(self): + def __repr__(self) -> str: return str(self.to_dict()) - def __len__(self): + def __len__(self) -> int: return len(self.fields) - def __eq__(self, other): - """ The order of the fields of schema must be consistent.""" + def __eq__(self, other: Any): + """The order of the fields of schema must be consistent.""" return self.to_dict() == other.to_dict() @classmethod - def construct_from_dict(cls, raw): - fields = [FieldSchema.construct_from_dict( - field_raw) for field_raw in raw['fields']] + def construct_from_dict(cls, raw: Dict): + fields = [FieldSchema.construct_from_dict(field_raw) for field_raw in raw["fields"]] enable_dynamic_field = raw.get("enable_dynamic_field", False) - return CollectionSchema(fields, raw.get('description', ""), enable_dynamic_field=enable_dynamic_field) + return CollectionSchema( + fields, raw.get("description", ""), enable_dynamic_field=enable_dynamic_field + ) @property def primary_field(self): @@ -210,7 +218,7 @@ def auto_id(self): return self.primary_field.auto_id @auto_id.setter - def auto_id(self, value): + def auto_id(self, value: bool): if self.primary_field: self.primary_field.auto_id = bool(value) @@ -219,7 +227,7 @@ def enable_dynamic_field(self): return self._enable_dynamic_field @enable_dynamic_field.setter - def enable_dynamic_field(self, value): + def enable_dynamic_field(self, value: bool): self._enable_dynamic_field = bool(value) def to_dict(self): @@ -236,23 +244,21 @@ def verify(self): # final check, detect obvious problems self._check() - def add_field(self, field_name, datatype, **kwargs): + def add_field(self, field_name: str, datatype: DataType, **kwargs): field = FieldSchema(field_name, datatype, **kwargs) self._fields.append(field) return self class FieldSchema: - def __init__(self, name: str, dtype: DataType, description="", **kwargs): + def __init__(self, name: str, dtype: DataType, description: str = "", **kwargs) -> None: self.name = name try: dtype = DataType(dtype) except ValueError: - raise DataTypeNotSupportException( - message=ExceptionsMessage.FieldDtype) from None + raise DataTypeNotSupportException(message=ExceptionsMessage.FieldDtype) from None if dtype == DataType.UNKNOWN: - raise DataTypeNotSupportException( - message=ExceptionsMessage.FieldDtype) + raise DataTypeNotSupportException(message=ExceptionsMessage.FieldDtype) self._dtype = dtype self._description = description self._type_params = {} @@ -266,33 +272,30 @@ def __init__(self, name: str, dtype: DataType, description="", **kwargs): if not isinstance(self.auto_id, bool): raise AutoIDException(message=ExceptionsMessage.AutoIDType) if not self.is_primary and self.auto_id: - raise PrimaryKeyException( - message=ExceptionsMessage.AutoIDOnlyOnPK) + raise PrimaryKeyException(message=ExceptionsMessage.AutoIDOnlyOnPK) if not isinstance(kwargs.get("is_partition_key", False), bool): - raise PartitionKeyException( - message=ExceptionsMessage.IsPartitionKeyType) + raise PartitionKeyException(message=ExceptionsMessage.IsPartitionKeyType) self.is_partition_key = kwargs.get("is_partition_key", False) self.default_value = kwargs.get("default_value", None) if isinstance(self.default_value, schema_types.ValueField): if self.default_value.WhichOneof("data") is None: self.default_value = None else: - self.default_value = infer_default_value_bydata( - kwargs.get("default_value", None)) + self.default_value = infer_default_value_bydata(kwargs.get("default_value", None)) self._parse_type_params() - def __repr__(self): + def __repr__(self) -> str: return str(self.to_dict()) - def __deepcopy__(self, memodict=None): + def __deepcopy__(self, memodict: Optional[Dict] = None): if memodict is None: memodict = {} return self.construct_from_dict(self.to_dict()) def _parse_type_params(self): # update self._type_params according to self._kwargs - if self._dtype not in (DataType.BINARY_VECTOR, DataType.FLOAT_VECTOR, DataType.VARCHAR,): + if self._dtype not in (DataType.BINARY_VECTOR, DataType.FLOAT_VECTOR, DataType.VARCHAR): return if not self._kwargs: return @@ -305,16 +308,16 @@ def _parse_type_params(self): self._type_params[k] = self._kwargs[k] @classmethod - def construct_from_dict(cls, raw): + def construct_from_dict(cls, raw: Dict): kwargs = {} kwargs.update(raw.get("params", {})) - kwargs['is_primary'] = raw.get("is_primary", False) + kwargs["is_primary"] = raw.get("is_primary", False) if raw.get("auto_id", None) is not None: - kwargs['auto_id'] = raw.get("auto_id", None) - kwargs['is_partition_key'] = raw.get("is_partition_key", False) - kwargs['default_value'] = raw.get("default_value", None) - kwargs['is_dynamic'] = raw.get("is_dynamic", False) - return FieldSchema(raw['name'], raw['type'], raw.get("description", ""), **kwargs) + kwargs["auto_id"] = raw.get("auto_id", None) + kwargs["is_partition_key"] = raw.get("is_partition_key", False) + kwargs["default_value"] = raw.get("default_value", None) + kwargs["is_dynamic"] = raw.get("is_dynamic", False) + return FieldSchema(raw["name"], raw["type"], raw.get("description", ""), **kwargs) def to_dict(self): _dict = { @@ -337,12 +340,12 @@ def to_dict(self): _dict["is_dynamic"] = self.is_dynamic return _dict - def __getattr__(self, item): + def __getattr__(self, item: str): if self._type_params and item in self._type_params: return self._type_params[item] return None - def __eq__(self, other): + def __eq__(self, other: Any): if not isinstance(other, FieldSchema): return False return self.to_dict() == other.to_dict() @@ -387,12 +390,13 @@ def dtype(self) -> DataType: return self._dtype -def check_insert_or_upsert_is_row_based(data: Union[List[List], List[Dict], Dict, pandas.DataFrame]) -> bool: - if not isinstance(data, (pandas.DataFrame, list, dict)): +def check_is_row_based(data: Union[List[List], List[Dict], Dict, pd.DataFrame]) -> bool: + if not isinstance(data, (pd.DataFrame, list, dict)): raise DataTypeNotSupportException( - message="The type of data should be list or pandas.DataFrame or dict") + message="The type of data should be list or pandas.DataFrame or dict" + ) - if isinstance(data, pandas.DataFrame): + if isinstance(data, pd.DataFrame): return False if isinstance(data, dict): @@ -407,44 +411,36 @@ def check_insert_or_upsert_is_row_based(data: Union[List[List], List[Dict], Dict return False -def check_insert_or_upsert_data_schema(schema: CollectionSchema, data: Union[List[List], pandas.DataFrame], - is_insert=True) -> None: - """ check if the insert or upsert data is consist with the collection schema +def check_insert_schema(schema: CollectionSchema, data: Union[List[List], pd.DataFrame]): + if schema is None: + raise SchemaNotReadyException(message="Schema shouldn't be None") + if schema.auto_id and isinstance(data, pd.DataFrame) and schema.primary_field.name in data: + if not data[schema.primary_field.name].isnull().all(): + msg = f"Expect no data for auto_id primary field: {schema.primary_field.name}" + raise DataNotMatchException(message=msg) + data = data.drop(schema.primary_field.name, axis=1) + + infer_fields, tmp_fields, is_data_frame = parse_fields_from_data(schema, data) + check_infer_fields_valid(infer_fields, tmp_fields, is_data_frame) - Args: - schema (CollectionSchema): the schema of the collection - data (List[List], pandas.DataFrame): the data to be inserted or upserted - Raise: - SchemaNotReadyException: if the schema is None - UpsertAutoIDTrueException: if autoid option is true - DataNotMatchException: if the data is in consist with the schema - """ +def check_upset_schema(schema: CollectionSchema, data: Union[List[List], pd.DataFrame]): if schema is None: raise SchemaNotReadyException(message="Schema shouldn't be None") if schema.auto_id: - if is_insert: - if isinstance(data, pandas.DataFrame): - if schema.primary_field.name in data: - if not data[schema.primary_field.name].isnull().all(): - raise DataNotMatchException( - message=f"Please don't provide data for auto_id primary field: {schema.primary_field.name}") - data = data.drop(schema.primary_field.name, axis=1) - else: - raise UpsertAutoIDTrueException( - message=ExceptionsMessage.UpsertAutoIDTrue) + raise UpsertAutoIDTrueException(ExceptionsMessage.UpsertAutoIDTrue) - infer_fields, tmp_fields, is_data_frame = parse_fields_from_data( - schema, data) + infer_fields, tmp_fields, is_data_frame = parse_fields_from_data(schema, data) check_infer_fields_valid(infer_fields, tmp_fields, is_data_frame) -def parse_fields_from_data(schema: CollectionSchema, data: Union[List[List], pandas.DataFrame]): - if not isinstance(data, (pandas.DataFrame, list)): +def parse_fields_from_data(schema: CollectionSchema, data: Union[List[List], pd.DataFrame]): + if not isinstance(data, (pd.DataFrame, list)): raise DataTypeNotSupportException( - message="The type of data should be list or pandas.DataFrame") + message="The type of data should be list or pandas.DataFrame" + ) - if isinstance(data, pandas.DataFrame): + if isinstance(data, pd.DataFrame): return parse_fields_from_dataframe(schema, data) tmp_fields = copy.deepcopy(schema.fields) @@ -461,8 +457,7 @@ def parse_fields_from_data(schema: CollectionSchema, data: Union[List[List], pan infer_fields.append(FieldSchema("", field.dtype)) continue if not is_list_like(d): - raise DataTypeNotSupportException( - message="data should be a list of list") + raise DataTypeNotSupportException(message="data should be a list of list") try: elem = d[0] infer_fields.append(FieldSchema("", infer_dtype_bydata(elem))) @@ -482,15 +477,14 @@ def parse_fields_from_data(schema: CollectionSchema, data: Union[List[List], pan return infer_fields, tmp_fields, False -def parse_fields_from_dataframe(schema: CollectionSchema, df: pandas.DataFrame): - col_names, data_types, column_params_map = prepare_fields_from_dataframe( - df) +def parse_fields_from_dataframe(schema: CollectionSchema, df: pd.DataFrame): + col_names, data_types, column_params_map = prepare_fields_from_dataframe(df) tmp_fields = copy.deepcopy(schema.fields) for i, field in enumerate(schema.fields): if field.is_primary and field.auto_id: tmp_fields.pop(i) infer_fields = [] - for i, field in enumerate(tmp_fields): + for field in tmp_fields: # if no data pass in, considering to be passed in order according to the schema if field.name not in col_names: field_schema = FieldSchema(field.name, field.dtype) @@ -500,7 +494,8 @@ def parse_fields_from_dataframe(schema: CollectionSchema, df: pandas.DataFrame): else: type_params = column_params_map.get(field.name, {}) field_schema = FieldSchema( - field.name, data_types[col_names.index(field.name)], **type_params) + field.name, data_types[col_names.index(field.name)], **type_params + ) infer_fields.append(field_schema) infer_name = [f.name for f in infer_fields] @@ -513,9 +508,8 @@ def parse_fields_from_dataframe(schema: CollectionSchema, df: pandas.DataFrame): return infer_fields, tmp_fields, True -def construct_fields_from_dataframe(df: pandas.DataFrame) -> List[FieldSchema]: - col_names, data_types, column_params_map = prepare_fields_from_dataframe( - df) +def construct_fields_from_dataframe(df: pd.DataFrame) -> List[FieldSchema]: + col_names, data_types, column_params_map = prepare_fields_from_dataframe(df) fields = [] for name, dtype in zip(col_names, data_types): type_params = column_params_map.get(name, {}) @@ -525,7 +519,7 @@ def construct_fields_from_dataframe(df: pandas.DataFrame) -> List[FieldSchema]: return fields -def prepare_fields_from_dataframe(df: pandas.DataFrame): +def prepare_fields_from_dataframe(df: pd.DataFrame): d_types = list(df.dtypes) data_types = list(map(map_numpy_dtype_to_datatype, d_types)) col_names = list(df.columns) @@ -534,45 +528,52 @@ def prepare_fields_from_dataframe(df: pandas.DataFrame): if DataType.UNKNOWN in data_types: if len(df) == 0: - raise CannotInferSchemaException( - message=ExceptionsMessage.DataFrameInvalid) - values = df.head(1).values[0] + raise CannotInferSchemaException(message=ExceptionsMessage.DataFrameInvalid) + values = df.head(1).to_numpy()[0] for i, dtype in enumerate(data_types): if dtype == DataType.UNKNOWN: new_dtype = infer_dtype_bydata(values[i]) if new_dtype in (DataType.BINARY_VECTOR, DataType.FLOAT_VECTOR): vector_type_params = {} if new_dtype == DataType.BINARY_VECTOR: - vector_type_params['dim'] = len(values[i]) * 8 + vector_type_params["dim"] = len(values[i]) * 8 else: - vector_type_params['dim'] = len(values[i]) + vector_type_params["dim"] = len(values[i]) column_params_map[col_names[i]] = vector_type_params data_types[i] = new_dtype if DataType.UNKNOWN in data_types: - raise CannotInferSchemaException( - message=ExceptionsMessage.DataFrameInvalid) + raise CannotInferSchemaException(message=ExceptionsMessage.DataFrameInvalid) return col_names, data_types, column_params_map -def check_infer_fields_valid(infer_fields: List[FieldSchema], tmp_fields: list, is_data_frame: bool): +def check_infer_fields_valid( + infer_fields: List[FieldSchema], + tmp_fields: List, + is_data_frame: bool, +): if len(infer_fields) != len(tmp_fields): i_name = [f.name for f in infer_fields] t_name = [f.name for f in tmp_fields] raise DataNotMatchException( - message=f"The fields don't match with schema fields, expected: {t_name}, got {i_name}") + message=f"The fields don't match with schema fields, expected: {t_name}, got {i_name}" + ) for x, y in zip(infer_fields, tmp_fields): if x.dtype != y.dtype: - raise DataNotMatchException( - message=f"The data type of field {y.name} doesn't match, expected: {y.dtype.name}, got {x.dtype.name}") + msg = ( + f"The data type of field {y.name} doesn't match, " + f"expected: {y.dtype.name}, got {x.dtype.name}" + ) + raise DataNotMatchException(message=msg) if is_data_frame and x.name != y.name: raise DataNotMatchException( - message=f"The name of field don't match, expected: {y.name}, got {x.name}") + message=f"The name of field don't match, expected: {y.name}, got {x.name}" + ) -def check_schema(schema): +def check_schema(schema: CollectionSchema): if schema is None: raise SchemaNotReadyException(message=ExceptionsMessage.NoSchema) if len(schema.fields) < 1: @@ -585,7 +586,7 @@ def check_schema(schema): raise SchemaNotReadyException(message=ExceptionsMessage.NoVector) -def infer_default_value_bydata(data): +def infer_default_value_bydata(data: Any): if data is None: return None default_data = schema_types.ValueField() @@ -594,11 +595,7 @@ def infer_default_value_bydata(data): d_type = infer_dtype_by_scaladata(data) if d_type is DataType.BOOL: default_data.bool_data = data - elif d_type is DataType.INT8: - default_data.int_data = data - elif d_type is DataType.INT16: - default_data.int_data = data - elif d_type is DataType.INT32: + elif d_type in (DataType.INT8, DataType.INT16, DataType.INT32): default_data.int_data = data elif d_type is DataType.INT64: default_data.long_data = data @@ -609,6 +606,5 @@ def infer_default_value_bydata(data): elif d_type is DataType.VARCHAR: default_data.string_data = data else: - raise ParamError( - message=f"Default value unsupported data type: {d_type}") + raise ParamError(message=f"Default value unsupported data type: {d_type}") return default_data diff --git a/pymilvus/orm/search.py b/pymilvus/orm/search.py index fafff122b..5c07ca46a 100644 --- a/pymilvus/orm/search.py +++ b/pymilvus/orm/search.py @@ -11,11 +11,13 @@ # the License. import abc -from ..client.abstract import Entity +from typing import Any, Iterable + +from pymilvus.client.abstract import Entity class _IterableWrapper: - def __init__(self, iterable_obj): + def __init__(self, iterable_obj: Iterable) -> None: self._iterable = iterable_obj def __iter__(self): @@ -24,7 +26,7 @@ def __iter__(self): def __next__(self): return self.on_result(self._iterable.__next__()) - def __getitem__(self, item): + def __getitem__(self, item: str): s = self._iterable.__getitem__(item) if isinstance(item, slice): _start = item.start or 0 @@ -37,31 +39,25 @@ def __getitem__(self, item): return elements return s - def __len__(self): + def __len__(self) -> int: return self._iterable.__len__() @abc.abstractmethod - def on_result(self, res): + def on_result(self, res: Any): raise NotImplementedError # TODO: how to add docstring to method of subclass and don't change the implementation? # for example like below: # class Hits(_IterableWrapper): -# __init__.__doc__ = """doc of __init__""" -# __iter__.__doc__ = """doc of __iter__""" -# __next__.__doc__ = """doc of __next__""" -# __getitem__.__doc__ = """doc of __getitem__""" -# __len__.__doc__ = """doc of __len__""" # # def on_result(self, res): -# return Hit(res) class DocstringMeta(type): - def __new__(cls, name, bases, attrs): + def __new__(cls, name: str, bases: Any, attrs: Any): doc_meta = attrs.pop("docstring", None) - new_cls = super(DocstringMeta, cls).__new__(cls, name, bases, attrs) + new_cls = super().__new__(cls, name, bases, attrs) if doc_meta: for member_name, member in attrs.items(): if member_name in doc_meta: @@ -71,20 +67,12 @@ def __new__(cls, name, bases, attrs): # for example: # class Hits(_IterableWrapper, metaclass=DocstringMeta): -# docstring = { -# "__init__": """doc of __init__""", -# "__iter__": """doc of __iter__""", -# "__next__": """doc of __next__""", -# "__getitem__": """doc of __getitem__""", -# "__len__": """doc of __len__""", -# } # # def on_result(self, res): -# return Hit(res) class Hit: - def __init__(self, hit): + def __init__(self, hit: Any) -> None: """ Construct a Hit object from response. A hit represent a record corresponding to the query. """ @@ -130,7 +118,7 @@ def score(self) -> float: """ return self._hit.score - def __str__(self): + def __str__(self) -> str: """ Return the information of hit record. @@ -146,7 +134,7 @@ def to_dict(self): class Hits: - def __init__(self, hits): + def __init__(self, hits: Any) -> None: """ Construct a Hits object from response. """ @@ -166,7 +154,7 @@ def __next__(self): """ return Hit(self._hits.__next__()) - def __getitem__(self, item): + def __getitem__(self, item: str): """ Return the kth Hit corresponding to the query. @@ -194,10 +182,10 @@ def __len__(self) -> int: """ return self._hits.__len__() - def __str__(self): + def __str__(self) -> str: return str(list(map(str, self.__getitem__(slice(0, 10))))) - def on_result(self, res): + def on_result(self, res: Any): return Hit(res) @property @@ -222,7 +210,7 @@ def distances(self) -> list: class SearchResult: - def __init__(self, query_result=None): + def __init__(self, query_result: Any = None) -> None: """ Construct a search result from response. """ @@ -240,7 +228,7 @@ def __next__(self): """ return self.on_result(self._qs.__next__()) - def __getitem__(self, item): + def __getitem__(self, item: Any): """ Return the Hits corresponding to the nth query. @@ -268,8 +256,8 @@ def __len__(self) -> int: """ return self._qs.__len__() - def __str__(self): + def __str__(self) -> str: return str(list(map(str, self.__getitem__(slice(0, 10))))) - def on_result(self, res): + def on_result(self, res: Any): return Hits(res) diff --git a/pymilvus/orm/types.py b/pymilvus/orm/types.py index 5e61e66de..1bb4eb715 100644 --- a/pymilvus/orm/types.py +++ b/pymilvus/orm/types.py @@ -10,10 +10,18 @@ # or implied. See the License for the specific language governing permissions and limitations under # the License. -from pandas.api.types import infer_dtype, is_list_like, is_scalar, is_float, is_array_like +from typing import Any + import numpy as np +from pandas.api.types import ( + infer_dtype, + is_array_like, + is_float, + is_list_like, + is_scalar, +) -from ..client.types import DataType +from pymilvus.client.types import DataType dtype_str_map = { "string": DataType.VARCHAR, @@ -51,20 +59,20 @@ } -def is_integer_datatype(data_type): +def is_integer_datatype(data_type: DataType): return data_type in (DataType.INT8, DataType.INT16, DataType.INT32, DataType.INT64) -def is_float_datatype(data_type): +def is_float_datatype(data_type: DataType): return data_type in (DataType.FLOAT,) -def is_numeric_datatype(data_type): +def is_numeric_datatype(data_type: DataType): return is_float_datatype(data_type) or is_integer_datatype(data_type) # pylint: disable=too-many-return-statements -def infer_dtype_by_scaladata(data): +def infer_dtype_by_scaladata(data: Any): if isinstance(data, float): return DataType.DOUBLE if isinstance(data, bool): @@ -97,11 +105,10 @@ def infer_dtype_by_scaladata(data): return DataType.UNKNOWN -def infer_dtype_bydata(data): +def infer_dtype_bydata(data: Any): d_type = DataType.UNKNOWN if is_scalar(data): - d_type = infer_dtype_by_scaladata(data) - return d_type + return infer_dtype_by_scaladata(data) if isinstance(data, dict): return DataType.JSON @@ -114,12 +121,7 @@ def infer_dtype_bydata(data): failed = True if not failed: d_type = dtype_str_map.get(type_str, DataType.UNKNOWN) - if is_numeric_datatype(d_type): - d_type = DataType.FLOAT_VECTOR - else: - d_type = DataType.UNKNOWN - - return d_type + return DataType.FLOAT_VECTOR if is_numeric_datatype(d_type) else DataType.UNKNOWN if d_type == DataType.UNKNOWN: try: @@ -139,7 +141,7 @@ def infer_dtype_bydata(data): return d_type -def map_numpy_dtype_to_datatype(d_type): +def map_numpy_dtype_to_datatype(d_type: DataType): d_type_str = str(d_type) return numpy_dtype_str_map.get(d_type_str, DataType.UNKNOWN) diff --git a/pymilvus/orm/utility.py b/pymilvus/orm/utility.py index 03d29745e..91eec09eb 100644 --- a/pymilvus/orm/utility.py +++ b/pymilvus/orm/utility.py @@ -10,18 +10,22 @@ # or implied. See the License for the specific language governing permissions and limitations under # the License. -from .connections import connections +from datetime import datetime, timedelta, timezone +from typing import List, Optional + +from pymilvus.client.types import BulkInsertState +from pymilvus.client.utils import hybridts_to_unixtime as _hybridts_to_unixtime +from pymilvus.client.utils import mkts_from_datetime as _mkts_from_datetime +from pymilvus.client.utils import mkts_from_hybridts as _mkts_from_hybridts +from pymilvus.client.utils import mkts_from_unixtime as _mkts_from_unixtime +from pymilvus.exceptions import MilvusException -from ..client.utils import mkts_from_hybridts as _mkts_from_hybridts -from ..client.utils import mkts_from_unixtime as _mkts_from_unixtime -from ..client.utils import mkts_from_datetime as _mkts_from_datetime -from ..client.utils import hybridts_to_unixtime as _hybridts_to_unixtime -from ..client.types import BulkInsertState +from .connections import connections -def mkts_from_hybridts(hybridts, milliseconds=0., delta=None): - """ - Generate a hybrid timestamp based on an existing hybrid timestamp, timedelta and incremental time internval. +def mkts_from_hybridts(hybridts: int, milliseconds: float = 0.0, delta: Optional[timedelta] = None): + """Generate a hybrid timestamp based on an existing hybrid timestamp, + timedelta and incremental time internval. :param hybridts: The original hybrid timestamp used to generate a new hybrid timestamp. Non-negative interger range from 0 to 18446744073709551615. @@ -38,12 +42,11 @@ def mkts_from_hybridts(hybridts, milliseconds=0., delta=None): Hybrid timetamp is a non-negative interger range from 0 to 18446744073709551615. :example: - >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType, connections, utility - >>> connections.connect(alias="default") + >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType, utility >>> _DIM = 128 >>> field_int64 = FieldSchema("int64", DataType.INT64, description="int64", is_primary=True) - >>> field_vector = FieldSchema("float_vector", DataType.FLOAT_VECTOR, is_primary=False, dim=_DIM) - >>> schema = CollectionSchema(fields=[field_int64, field_vector], description="get collection entities num") + >>> field_vector = FieldSchema("float_vector", DataType.FLOAT_VECTOR, dim=_DIM) + >>> schema = CollectionSchema(fields=[field_int64, field_vector]) >>> collection = Collection(name="test_collection", schema=schema) >>> import pandas as pd >>> int64_series = pd.Series(data=list(range(10, 20)), index=list(range(10)))i @@ -55,7 +58,7 @@ def mkts_from_hybridts(hybridts, milliseconds=0., delta=None): return _mkts_from_hybridts(hybridts, milliseconds=milliseconds, delta=delta) -def mkts_from_unixtime(epoch, milliseconds=0., delta=None): +def mkts_from_unixtime(epoch: float, milliseconds: float = 0.0, delta: Optional[timedelta] = None): """ Generate a hybrid timestamp based on Unix Epoch time, timedelta and incremental time internval. @@ -83,7 +86,9 @@ def mkts_from_unixtime(epoch, milliseconds=0., delta=None): return _mkts_from_unixtime(epoch, milliseconds=milliseconds, delta=delta) -def mkts_from_datetime(d_time, milliseconds=0., delta=None): +def mkts_from_datetime( + d_time: datetime, milliseconds: float = 0.0, delta: Optional[timedelta] = None +): """ Generate a hybrid timestamp based on datetime, timedelta and incremental time internval. @@ -109,7 +114,7 @@ def mkts_from_datetime(d_time, milliseconds=0., delta=None): return _mkts_from_datetime(d_time, milliseconds=milliseconds, delta=delta) -def hybridts_to_datetime(hybridts, tz=None): +def hybridts_to_datetime(hybridts: int, tz: Optional[timezone] = None): """ Convert a hybrid timestamp to the datetime according to timezone. @@ -117,7 +122,7 @@ def hybridts_to_datetime(hybridts, tz=None): Non-negative interger range from 0 to 18446744073709551615. :type hybridts: int :param tz: Timezone defined by a fixed offset from UTC. If argument tz is None or not specified, - the hybridts is converted to the platform’s local date and time. + the hybridts is converted to the platform`s local date and time. :type tz: datetime.timezone :return datetime: @@ -133,13 +138,15 @@ def hybridts_to_datetime(hybridts, tz=None): >>> d = utility.hybridts_to_datetime(ts) """ import datetime + if tz is not None and not isinstance(tz, datetime.timezone): - raise Exception("parameter tz should be type of datetime.timezone") + msg = "parameter tz should be type of datetime.timezone" + raise MilvusException(message=msg) epoch = _hybridts_to_unixtime(hybridts) return datetime.datetime.fromtimestamp(epoch, tz=tz) -def hybridts_to_unixtime(hybridts): +def hybridts_to_unixtime(hybridts: int): """ Convert a hybrid timestamp to UNIX Epoch time ignoring the logic part. @@ -148,7 +155,8 @@ def hybridts_to_unixtime(hybridts): :type hybridts: int :return float: - The Unix Epoch time is the number of seconds that have elapsed since January 1, 1970 (midnight UTC/GMT). + The Unix Epoch time is the number of seconds that have elapsed since + January 1, 1970 (midnight UTC/GMT). :example: >>> import time @@ -161,12 +169,17 @@ def hybridts_to_unixtime(hybridts): return _hybridts_to_unixtime(hybridts) -def _get_connection(alias): +def _get_connection(alias: str): return connections._fetch_handler(alias) -def loading_progress(collection_name, partition_names=None, using="default", timeout=None): - """ Show loading progress of sealed segments in percentage. +def loading_progress( + collection_name: str, + partition_names: Optional[List[str]] = None, + using: str = "default", + timeout: Optional[float] = None, +): + """Show loading progress of sealed segments in percentage. :param collection_name: The name of collection is loading :type collection_name: str @@ -178,10 +191,9 @@ def loading_progress(collection_name, partition_names=None, using="default", tim {'loading_progress': '100%'} :raises PartitionNotExistException: If partition doesn't exist. :example: - >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType, connections, utility + >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType, utility >>> import pandas as pd >>> import random - >>> connections.connect() >>> fields = [ ... FieldSchema("film_id", DataType.INT64, is_primary=True), ... FieldSchema("films", DataType.FLOAT_VECTOR, dim=8), @@ -193,19 +205,28 @@ def loading_progress(collection_name, partition_names=None, using="default", tim ... "films": [[random.random() for _ in range(8)] for _ in range (10)], ... }) >>> collection.insert(data) - >>> collection.create_index("films", {"index_type": "IVF_FLAT", "params": {"nlist": 8}, "metric_type": "L2"}) + >>> collection.create_index( + ... "films", + ... {"index_type": "IVF_FLAT", "params": {"nlist": 8}, "metric_type": "L2"}) >>> collection.load(_async=True) >>> utility.loading_progress("test_loading_progress") {'loading_progress': '100%'} """ - progress = _get_connection(using).get_loading_progress(collection_name, partition_names, timeout=timeout) + progress = _get_connection(using).get_loading_progress( + collection_name, partition_names, timeout=timeout + ) return { "loading_progress": f"{progress:.0f}%", } -def load_state(collection_name, partition_names=None, using="default", timeout=None): - """ Show load state of collection or partitions. +def load_state( + collection_name: str, + partition_names: Optional[float] = None, + using: str = "default", + timeout: Optional[float] = None, +): + """Show load state of collection or partitions. :param collection_name: The name of collection is loading :type collection_name: str @@ -216,11 +237,10 @@ def load_state(collection_name, partition_names=None, using="default", timeout=N The current state of collection or partitions. :example: - >>> from pymilvus import Collection, connections, FieldSchema, CollectionSchema, DataType, utility + >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType, utility >>> from pymilvus.client.types import LoadState >>> import pandas as pd >>> import random - >>> connections.connect() >>> assert utility.load_state("test_load_state") == LoadState.NotExist >>> fields = [ ... FieldSchema("film_id", DataType.INT64, is_primary=True), @@ -234,14 +254,21 @@ def load_state(collection_name, partition_names=None, using="default", timeout=N ... "films": [[random.random() for _ in range(8)] for _ in range (10)], ... }) >>> collection.insert(data) - >>> collection.create_index("films", {"index_type": "IVF_FLAT", "params": {"nlist": 8}, "metric_type": "L2"}) + >>> collection.create_index( + ... "films", + ... {"index_type": "IVF_FLAT", "params": {"nlist": 8}, "metric_type": "L2"}) >>> collection.load(_async=True) >>> assert utility.load_state("test_load_state") == LoadState.Loaded """ return _get_connection(using).get_load_state(collection_name, partition_names, timeout=timeout) -def wait_for_loading_complete(collection_name, partition_names=None, timeout=None, using="default"): +def wait_for_loading_complete( + collection_name: str, + partition_names: Optional[List[str]] = None, + timeout: Optional[float] = None, + using: str = "default", +): """ Block until loading is done or Raise Exception after timeout. @@ -258,12 +285,11 @@ def wait_for_loading_complete(collection_name, partition_names=None, timeout=Non :raises PartitionNotExistException: If partition doesn't exist. :example: - >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType, connections, utility - >>> connections.connect(alias="default") + >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType, utility >>> _DIM = 128 >>> field_int64 = FieldSchema("int64", DataType.INT64, description="int64", is_primary=True) - >>> field_fvec = FieldSchema("float_vector", DataType.FLOAT_VECTOR, is_primary=False, dim=_DIM) - >>> schema = CollectionSchema(fields=[field_int64, field_fvec], description="get collection entities num") + >>> field_fvec = FieldSchema("float_vector", DataType.FLOAT_VECTOR, dim=_DIM) + >>> schema = CollectionSchema(fields=[field_int64, field_fvec]) >>> collection = Collection(name="test_collection", schema=schema) >>> import pandas as pd >>> int64_series = pd.Series(data=list(range(10, 20)), index=list(range(10)))i @@ -275,10 +301,17 @@ def wait_for_loading_complete(collection_name, partition_names=None, timeout=Non """ if not partition_names or len(partition_names) == 0: return _get_connection(using).wait_for_loading_collection(collection_name, timeout=timeout) - return _get_connection(using).wait_for_loading_partitions(collection_name, partition_names, timeout=timeout) + return _get_connection(using).wait_for_loading_partitions( + collection_name, partition_names, timeout=timeout + ) -def index_building_progress(collection_name, index_name="", using="default", timeout=None): +def index_building_progress( + collection_name: str, + index_name: str = "", + using: str = "default", + timeout: Optional[float] = None, +): """ Show # indexed entities vs. # total entities. @@ -298,9 +331,7 @@ def index_building_progress(collection_name, index_name="", using="default", tim :raises CollectionNotExistException: If collection doesn't exist. :raises IndexNotExistException: If index doesn't exist. :example: - >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType, connections, utility - >>> connections.connect() - >>> + >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType, utility >>> fields = [ ... FieldSchema("int64", DataType.INT64, is_primary=True, auto_id=True), ... FieldSchema("float_vector", DataType.FLOAT_VECTOR, dim=128), @@ -317,14 +348,23 @@ def index_building_progress(collection_name, index_name="", using="default", tim ... "index_type": "IVF_FLAT", ... "params": {"nlist": 1024} ... } - >>> index = c.create_index(field_name="float_vector", index_params=index_params, index_name="ivf_flat") + >>> index = c.create_index( + ... field_name="float_vector", + ... index_params=index_params, + ... index_name="ivf_flat") >>> utility.index_building_progress("test_collection", c.name) """ return _get_connection(using).get_index_build_progress( - collection_name=collection_name, index_name=index_name, timeout=timeout) + collection_name=collection_name, index_name=index_name, timeout=timeout + ) -def wait_for_index_building_complete(collection_name, index_name="", timeout=None, using="default"): +def wait_for_index_building_complete( + collection_name: str, + index_name: str = "", + timeout: Optional[float] = None, + using: str = "default", +): """ Block until building is done or Raise Exception after timeout. @@ -341,11 +381,10 @@ def wait_for_index_building_complete(collection_name, index_name="", timeout=Non :raises IndexNotExistException: If index doesn't exist. :example: - >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType, connections, utility - >>> connections.connect(alias="default") + >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType, utility >>> _DIM = 128 >>> field_int64 = FieldSchema("int64", DataType.INT64, description="int64", is_primary=True) - >>> field_vector = FieldSchema("float_vector", DataType.FLOAT_VECTOR, is_primary=False, dim=_DIM) + >>> field_vector = FieldSchema("float_vector", DataType.FLOAT_VECTOR, dim=_DIM) >>> schema = CollectionSchema(fields=[field_int64, field_vector], description="test") >>> collection = Collection(name="test_collection", schema=schema) >>> import random @@ -367,10 +406,12 @@ def wait_for_index_building_complete(collection_name, index_name="", timeout=Non >>> utility.loading_progress("test_collection") """ - return _get_connection(using).wait_for_creating_index(collection_name, index_name, timeout=timeout)[0] + return _get_connection(using).wait_for_creating_index( + collection_name, index_name, timeout=timeout + )[0] -def has_collection(collection_name, using="default", timeout=None): +def has_collection(collection_name: str, using: str = "default", timeout: Optional[float] = None): """ Checks whether a specified collection exists. @@ -381,11 +422,10 @@ def has_collection(collection_name, using="default", timeout=None): Whether the collection exists. :example: - >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType, connections, utility - >>> connections.connect(alias="default") + >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType, utility >>> _DIM = 128 >>> field_int64 = FieldSchema("int64", DataType.INT64, description="int64", is_primary=True) - >>> field_vector = FieldSchema("float_vector", DataType.FLOAT_VECTOR, is_primary=False, dim=_DIM) + >>> field_vector = FieldSchema("float_vector", DataType.FLOAT_VECTOR, dim=_DIM) >>> schema = CollectionSchema(fields=[field_int64, field_vector], description="test") >>> collection = Collection(name="test_collection", schema=schema) >>> utility.has_collection("test_collection") @@ -393,7 +433,12 @@ def has_collection(collection_name, using="default", timeout=None): return _get_connection(using).has_collection(collection_name, timeout=timeout) -def has_partition(collection_name, partition_name, using="default", timeout=None): +def has_partition( + collection_name: str, + partition_name: str, + using: str = "default", + timeout: Optional[float] = None, +) -> bool: """ Checks if a specified partition exists in a collection. @@ -407,11 +452,10 @@ def has_partition(collection_name, partition_name, using="default", timeout=None Whether the partition exist. :example: - >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType, connections, utility - >>> connections.connect(alias="default") + >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType, utility >>> _DIM = 128 >>> field_int64 = FieldSchema("int64", DataType.INT64, description="int64", is_primary=True) - >>> field_vector = FieldSchema("float_vector", DataType.FLOAT_VECTOR, is_primary=False, dim=_DIM) + >>> field_vector = FieldSchema("float_vector", DataType.FLOAT_VECTOR, dim=_DIM) >>> schema = CollectionSchema(fields=[field_int64, field_vector], description="test") >>> collection = Collection(name="test_collection", schema=schema) >>> utility.has_partition("_default") @@ -419,7 +463,7 @@ def has_partition(collection_name, partition_name, using="default", timeout=None return _get_connection(using).has_partition(collection_name, partition_name, timeout=timeout) -def drop_collection(collection_name, timeout=None, using="default"): +def drop_collection(collection_name: str, timeout: Optional[float] = None, using: str = "default"): """ Drop a collection by name @@ -430,8 +474,7 @@ def drop_collection(collection_name, timeout=None, using="default"): :type timeout: float :example: - >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType, connections, utility - >>> connections.connect(alias="default") + >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType, utility >>> schema = CollectionSchema(fields=[ ... FieldSchema("int64", DataType.INT64, description="int64", is_primary=True), ... FieldSchema("float_vector", DataType.FLOAT_VECTOR, is_primary=False, dim=128), @@ -446,7 +489,12 @@ def drop_collection(collection_name, timeout=None, using="default"): return _get_connection(using).drop_collection(collection_name, timeout=timeout) -def rename_collection(old_collection_name, new_collection_name, timeout=None, using="default"): +def rename_collection( + old_collection_name: str, + new_collection_name: str, + timeout: Optional[float] = None, + using: str = "default", +): """ Rename a collection to new collection name @@ -461,8 +509,7 @@ def rename_collection(old_collection_name, new_collection_name, timeout=None, us :type timeout: float :example: - >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType, connections, utility - >>> connections.connect(alias="default") + >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType, utility >>> schema = CollectionSchema(fields=[ ... FieldSchema("int64", DataType.INT64, description="int64", is_primary=True), ... FieldSchema("float_vector", DataType.FLOAT_VECTOR, is_primary=False, dim=128), @@ -474,10 +521,12 @@ def rename_collection(old_collection_name, new_collection_name, timeout=None, us >>> utility.has_collection("new_collection") >>> False """ - return _get_connection(using).rename_collections(old_collection_name, new_collection_name, timeout=timeout) + return _get_connection(using).rename_collections( + old_collection_name, new_collection_name, timeout=timeout + ) -def list_collections(timeout=None, using="default") -> list: +def list_collections(timeout: Optional[float] = None, using: str = "default") -> list: """ Returns a list of all collection names. @@ -489,11 +538,10 @@ def list_collections(timeout=None, using="default") -> list: List of collection names, return when operation is successful :example: - >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType, connections, utility - >>> connections.connect(alias="default") + >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType, utility >>> _DIM = 128 >>> field_int64 = FieldSchema("int64", DataType.INT64, description="int64", is_primary=True) - >>> field_vector = FieldSchema("float_vector", DataType.FLOAT_VECTOR, is_primary=False, dim=_DIM) + >>> field_vector = FieldSchema("float_vector", DataType.FLOAT_VECTOR, dim=_DIM) >>> schema = CollectionSchema(fields=[field_int64, field_vector], description="test") >>> collection = Collection(name="test_collection", schema=schema) >>> utility.list_collections() @@ -501,8 +549,15 @@ def list_collections(timeout=None, using="default") -> list: return _get_connection(using).list_collections(timeout=timeout) -def load_balance(collection_name: str, src_node_id, dst_node_ids=None, sealed_segment_ids=None, timeout=None, using="default"): - """ Do load balancing operation from source query node to destination query node. +def load_balance( + collection_name: str, + src_node_id: int, + dst_node_ids: Optional[List[int]] = None, + sealed_segment_ids: Optional[List[int]] = None, + timeout: Optional[float] = None, + using: str = "default", +): + """Do load balancing operation from source query node to destination query node. :param collection_name: The collection to balance. :type collection_name: str @@ -530,17 +585,22 @@ def load_balance(collection_name: str, src_node_id, dst_node_ids=None, sealed_se >>> src_node_id = 0 >>> dst_node_ids = [1] >>> sealed_segment_ids = [] - >>> res = utility.load_balance("test_collection", src_node_id, dst_node_ids, sealed_segment_ids) + >>> res = utility.load_balance("test", src_node_id, dst_node_ids, sealed_segment_ids) """ if dst_node_ids is None: dst_node_ids = [] if sealed_segment_ids is None: sealed_segment_ids = [] - return _get_connection(using).\ - load_balance(collection_name, src_node_id, dst_node_ids, sealed_segment_ids, timeout=timeout) + return _get_connection(using).load_balance( + collection_name, src_node_id, dst_node_ids, sealed_segment_ids, timeout=timeout + ) -def get_query_segment_info(collection_name, timeout=None, using="default"): +def get_query_segment_info( + collection_name: str, + timeout: Optional[float] = None, + using: str = "default", +): """ Notifies Proxy to return segments information from query nodes. @@ -554,12 +614,11 @@ def get_query_segment_info(collection_name, timeout=None, using="default"): :rtype: QuerySegmentInfo :example: - >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType, connections, utility - >>> connections.connect(alias="default") + >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType, utility >>> _DIM = 128 >>> field_int64 = FieldSchema("int64", DataType.INT64, description="int64", is_primary=True) - >>> field_vector = FieldSchema("float_vector", DataType.FLOAT_VECTOR, is_primary=False, dim=_DIM) - >>> schema = CollectionSchema(fields=[field_int64, field_vector], description="get collection entities num") + >>> field_vector = FieldSchema("float_vector", DataType.FLOAT_VECTOR, dim=_DIM) + >>> schema = CollectionSchema(fields=[field_int64, field_vector]) >>> collection = Collection(name="test_get_segment_info", schema=schema) >>> import pandas as pd >>> int64_series = pd.Series(data=list(range(10, 20)), index=list(range(10)))i @@ -572,8 +631,13 @@ def get_query_segment_info(collection_name, timeout=None, using="default"): return _get_connection(using).get_query_segment_info(collection_name, timeout=timeout) -def create_alias(collection_name: str, alias: str, timeout=None, using="default"): - """ Specify alias for a collection. +def create_alias( + collection_name: str, + alias: str, + timeout: Optional[float] = None, + using: str = "default", +): + """Specify alias for a collection. Alias cannot be duplicated, you can't assign the same alias to different collections. But you can specify multiple aliases for a collection, for example: before create_alias("collection_1", "bob"): @@ -592,8 +656,7 @@ def create_alias(collection_name: str, alias: str, timeout=None, using="default" :raises BaseException: If the alias failed to create. :example: - >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility - >>> connections.connect() + >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType, utility >>> schema = CollectionSchema([ ... FieldSchema("film_id", DataType.INT64, is_primary=True), ... FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2) @@ -605,8 +668,8 @@ def create_alias(collection_name: str, alias: str, timeout=None, using="default" return _get_connection(using).create_alias(collection_name, alias, timeout=timeout) -def drop_alias(alias: str, timeout=None, using="default"): - """ Delete the alias. +def drop_alias(alias: str, timeout: Optional[float] = None, using: str = "default"): + """Delete the alias. No need to provide collection name because an alias can only be assigned to one collection and the server knows which collection it belongs. For example: @@ -626,8 +689,7 @@ def drop_alias(alias: str, timeout=None, using="default"): :raises BaseException: If the alias doesn't exist. :example: - >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility - >>> connections.connect() + >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType, utility >>> schema = CollectionSchema([ ... FieldSchema("film_id", DataType.INT64, is_primary=True), ... FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2) @@ -640,8 +702,13 @@ def drop_alias(alias: str, timeout=None, using="default"): return _get_connection(using).drop_alias(alias, timeout=timeout) -def alter_alias(collection_name: str, alias: str, timeout=None, using="default"): - """ Change the alias of a collection to another collection. +def alter_alias( + collection_name: str, + alias: str, + timeout: Optional[float] = None, + using: str = "default", +): + """Change the alias of a collection to another collection. Raise error if the alias doesn't exist. Alias cannot be duplicated, you can't assign same alias to different collections. This api can change alias owner collection, for example: @@ -666,8 +733,7 @@ def alter_alias(collection_name: str, alias: str, timeout=None, using="default") :raises BaseException: If the alias failed to alter. :example: - >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility - >>> connections.connect() + >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType, utility >>> schema = CollectionSchema([ ... FieldSchema("film_id", DataType.INT64, is_primary=True), ... FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2) @@ -680,15 +746,14 @@ def alter_alias(collection_name: str, alias: str, timeout=None, using="default") return _get_connection(using).alter_alias(collection_name, alias, timeout=timeout) -def list_aliases(collection_name: str, timeout=None, using="default"): - """ Returns alias list of the collection. +def list_aliases(collection_name: str, timeout: Optional[float] = None, using: str = "default"): + """Returns alias list of the collection. :return list of str: The collection aliases, returned when the operation succeeds. :example: - >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility - >>> connections.connect() + >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType, utility >>> fields = [ ... FieldSchema("film_id", DataType.INT64, is_primary=True), ... FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=128) @@ -701,24 +766,37 @@ def list_aliases(collection_name: str, timeout=None, using="default"): """ conn = _get_connection(using) resp = conn.describe_collection(collection_name, timeout=timeout) - aliases = resp["aliases"] - return aliases + return resp["aliases"] + +def do_bulk_insert( + collection_name: str, + files: List[str], + partition_name: Optional[str] = None, + timeout: Optional[float] = None, + using: str = "default", + **kwargs, +) -> int: + """do_bulk_insert inserts entities through files, currently supports row-based json file. + User need to create the json file with a specified json format which is described in + the official user guide. + + Let's say a collection has two fields: "id" and "vec"(dimension=8), + the row-based json format is: -def do_bulk_insert(collection_name: str, files: list, partition_name=None, timeout=None, using="default", **kwargs) -> int: - """ do_bulk_insert inserts entities through files, currently supports row-based json file. - User need to create the json file with a specified json format which is described in the official user guide. - Let's say a collection has two fields: "id" and "vec"(dimension=8), the row-based json format is: {"rows": [ {"id": "0", "vec": [0.190, 0.046, 0.143, 0.972, 0.592, 0.238, 0.266, 0.995]}, {"id": "1", "vec": [0.149, 0.586, 0.012, 0.673, 0.588, 0.917, 0.949, 0.944]}, ...... ] } - The json file must be uploaded to root path of MinIO/S3 storage which is accessed by milvus server. - For example: - the milvus.yml specify the MinIO/S3 storage bucketName as "a-bucket", user can upload his json file - to a-bucket/xxx.json, then call do_bulk_insert(files=["a-bucket/xxx.json"]) + + The json file must be uploaded to root path of MinIO/S3 storage which is + accessed by milvus server. For example: + + the milvus.yml specify the MinIO/S3 storage bucketName as "a-bucket", + user can upload his json file to a-bucket/xxx.json, + then call do_bulk_insert(files=["a-bucket/xxx.json"]) :param collection_name: the name of the collection :type collection_name: str @@ -726,7 +804,8 @@ def do_bulk_insert(collection_name: str, files: list, partition_name=None, timeo :param partition_name: the name of the partition :type partition_name: str - :param files: related path of the file to be imported, for row-based json file, only allow one file each invocation. + :param files: related path of the file to be imported, for row-based json file, + only allow one file each invocation. :type files: list[str] :param timeout: The timeout for this method, unit: second @@ -741,8 +820,7 @@ def do_bulk_insert(collection_name: str, files: list, partition_name=None, timeo :raises BaseException: If the files input is illegal. :example: - >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility - >>> connections.connect() + >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType, utility >>> schema = CollectionSchema([ ... FieldSchema("film_id", DataType.INT64, is_primary=True), ... FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2) @@ -751,10 +829,17 @@ def do_bulk_insert(collection_name: str, files: list, partition_name=None, timeo >>> task_id = utility.do_bulk_insert(collection_name=collection.name, files=['data.json']) >>> print(task_id) """ - return _get_connection(using).do_bulk_insert(collection_name, partition_name, files, timeout=timeout, **kwargs) + return _get_connection(using).do_bulk_insert( + collection_name, partition_name, files, timeout=timeout, **kwargs + ) -def get_bulk_insert_state(task_id, timeout=None, using="default", **kwargs) -> BulkInsertState: +def get_bulk_insert_state( + task_id: int, + timeout: Optional[float] = None, + using: str = "default", + **kwargs, +) -> BulkInsertState: """get_bulk_insert_state returns state of a certain task_id :param task_id: the task id returned by bulk_insert @@ -766,17 +851,26 @@ def get_bulk_insert_state(task_id, timeout=None, using="default", **kwargs) -> B :example: >>> from pymilvus import connections, utility, BulkInsertState >>> connections.connect() - >>> state = utility.get_bulk_insert_state(task_id=id) # the id is returned by do_bulk_insert() - >>> if state.state == BulkInsertState.ImportFailed or state.state == BulkInsertState.ImportFailedAndCleaned: - >>> print("task id:", state.task_id, "failed, reason:", state.failed_reason) + >>> # the id is returned by do_bulk_insert() + >>> state = utility.get_bulk_insert_state(task_id=id) + >>> if state.state == BulkInsertState.ImportFailed or \ + ... state.state == BulkInsertState.ImportFailedAndCleaned: + >>> print("task id:", state.task_id, "failed, reason:", state.failed_reason) """ return _get_connection(using).get_bulk_insert_state(task_id, timeout=timeout, **kwargs) -def list_bulk_insert_tasks(limit=0, collection_name=None, timeout=None, using="default", **kwargs) -> list: +def list_bulk_insert_tasks( + limit: int = 0, + collection_name: Optional[str] = None, + timeout: Optional[float] = None, + using: str = "default", + **kwargs, +) -> list: """list_bulk_insert_tasks lists all bulk load tasks - :param limit: maximum number of tasks returned, list all tasks if the value is 0, else return the latest tasks + :param limit: maximum number of tasks returned, list all tasks if the value is 0, + else return the latest tasks :type limit: int :param collection_name: target collection name, list all tasks if the name is empty @@ -791,10 +885,18 @@ def list_bulk_insert_tasks(limit=0, collection_name=None, timeout=None, using="d >>> tasks = utility.list_bulk_insert_tasks(collection_name=collection_name) >>> print(tasks) """ - return _get_connection(using).list_bulk_insert_tasks(limit, collection_name, timeout=timeout, **kwargs) - - -def reset_password(user: str, old_password: str, new_password: str, using="default", timeout=None): + return _get_connection(using).list_bulk_insert_tasks( + limit, collection_name, timeout=timeout, **kwargs + ) + + +def reset_password( + user: str, + old_password: str, + new_password: str, + using: str = "default", + timeout: Optional[float] = None, +): """ Reset the user & password of the connection. You must provide the original password to check if the operation is valid. @@ -817,8 +919,13 @@ def reset_password(user: str, old_password: str, new_password: str, using="defau return _get_connection(using).reset_password(user, old_password, new_password, timeout=timeout) -def create_user(user: str, password: str, using="default", timeout=None): - """ Create User using the given user and password. +def create_user( + user: str, + password: str, + using: str = "default", + timeout: Optional[float] = None, +): + """Create User using the given user and password. :param user: the user name. :type user: str :param password: the password. @@ -835,7 +942,13 @@ def create_user(user: str, password: str, using="default", timeout=None): return _get_connection(using).create_user(user, password, timeout=timeout) -def update_password(user: str, old_password, new_password: str, using="default", timeout=None): +def update_password( + user: str, + old_password: str, + new_password: str, + using: str = "default", + timeout: Optional[float] = None, +): """ Update user password using the given user and password. You must provide the original password to check if the operation is valid. @@ -860,8 +973,8 @@ def update_password(user: str, old_password, new_password: str, using="default", return _get_connection(using).update_password(user, old_password, new_password, timeout=timeout) -def delete_user(user: str, using="default", timeout=None): - """ Delete User corresponding to the username. +def delete_user(user: str, using: str = "default", timeout: Optional[float] = None): + """Delete User corresponding to the username. :param user: the user name. :type user: str @@ -875,8 +988,8 @@ def delete_user(user: str, using="default", timeout=None): return _get_connection(using).delete_user(user, timeout=timeout) -def list_usernames(using="default", timeout=None): - """ List all usernames. +def list_usernames(using: str = "default", timeout: Optional[float] = None): + """List all usernames. :return list of str: The usernames in Milvus instances. @@ -889,8 +1002,8 @@ def list_usernames(using="default", timeout=None): return _get_connection(using).list_usernames(timeout=timeout) -def list_roles(include_user_info: bool, using="default", timeout=None): - """ List All Role Info +def list_roles(include_user_info: bool, using: str = "default", timeout: Optional[float] = None): + """List All Role Info :param include_user_info: whether to obtain the user information associated with roles :type include_user_info: bool :return RoleInfo @@ -904,8 +1017,13 @@ def list_roles(include_user_info: bool, using="default", timeout=None): return _get_connection(using).select_all_role(include_user_info, timeout=timeout) -def list_user(username: str, include_role_info: bool, using="default", timeout=None): - """ List One User Info +def list_user( + username: str, + include_role_info: bool, + using: str = "default", + timeout: Optional[float] = None, +): + """List One User Info :param username: user name. :type username: str :param include_role_info: whether to obtain the role information associated with the user @@ -921,8 +1039,8 @@ def list_user(username: str, include_role_info: bool, using="default", timeout=N return _get_connection(using).select_one_user(username, include_role_info, timeout=timeout) -def list_users(include_role_info: bool, using="default", timeout=None): - """ List All User Info +def list_users(include_role_info: bool, using: str = "default", timeout: Optional[float] = None): + """List All User Info :param include_role_info: whether to obtain the role information associated with users :type include_role_info: bool :return UserInfo @@ -935,8 +1053,9 @@ def list_users(include_role_info: bool, using="default", timeout=None): """ return _get_connection(using).select_all_user(include_role_info, timeout=timeout) -def get_server_version(using="default", timeout=None) -> str: - """ get the running server's version + +def get_server_version(using: str = "default", timeout: Optional[float] = None) -> str: + """get the running server's version :returns: server's version :rtype: str @@ -949,8 +1068,9 @@ def get_server_version(using="default", timeout=None) -> str: """ return _get_connection(using).get_server_version(timeout=timeout) -def create_resource_group(name, using="default", timeout=None): - """ Create a resource group + +def create_resource_group(name: str, using: str = "default", timeout: Optional[float] = None): + """Create a resource group It will success whether or not the resource group exists. :example: @@ -962,8 +1082,9 @@ def create_resource_group(name, using="default", timeout=None): """ return _get_connection(using).create_resource_group(name, timeout) -def drop_resource_group(name, using="default", timeout=None): - """ Drop a resource group + +def drop_resource_group(name: str, using: str = "default", timeout: Optional[float] = None): + """Drop a resource group It will success if the resource group is existed and empty, otherwise fail. :example: @@ -975,8 +1096,9 @@ def drop_resource_group(name, using="default", timeout=None): """ return _get_connection(using).drop_resource_group(name, timeout) -def describe_resource_group(name, using="default", timeout=None): - """ Drop a resource group + +def describe_resource_group(name: str, using: str = "default", timeout: Optional[float] = None): + """Drop a resource group It will success if the resource group is existed and empty, otherwise fail. :example: @@ -987,7 +1109,8 @@ def describe_resource_group(name, using="default", timeout=None): """ return _get_connection(using).describe_resource_group(name, timeout) -def list_resource_groups(using="default", timeout=None): + +def list_resource_groups(using: str = "default", timeout: Optional[float] = None): """list all resource group names :return: all resource group names @@ -1001,7 +1124,13 @@ def list_resource_groups(using="default", timeout=None): return _get_connection(using).list_resource_groups(timeout) -def transfer_node(source_group, target_group, num_nodes, using="default", timeout=None): +def transfer_node( + source_group: str, + target_group: str, + num_nodes: int, + using: str = "default", + timeout: Optional[float] = None, +): """transfer num_node from source resource group to target resource_group :param source_group: source resource group name @@ -1019,7 +1148,14 @@ def transfer_node(source_group, target_group, num_nodes, using="default", timeou return _get_connection(using).transfer_node(source_group, target_group, num_nodes, timeout) -def transfer_replica(source_group, target_group, collection_name, num_replicas, using="default", timeout=None): +def transfer_replica( + source_group: str, + target_group: str, + collection_name: str, + num_replicas: int, + using: str = "default", + timeout: Optional[float] = None, +): """transfer num_replica from source resource group to target resource group :param source_group: source resource group name @@ -1036,23 +1172,25 @@ def transfer_replica(source_group, target_group, collection_name, num_replicas, >>> connections.connect() >>> rgs = utility.transfer_replica(source, target, collection_name, num_replica) """ - return _get_connection(using).transfer_replica(source_group, target_group, collection_name, num_replicas, timeout) + return _get_connection(using).transfer_replica( + source_group, target_group, collection_name, num_replicas, timeout + ) -def flush_all(using="default", timeout=None, **kwargs): - """ Flush all collections. All insertions, deletions, and upserts before `flush_all` will be synced. +def flush_all(using: str = "default", timeout: Optional[float] = None, **kwargs): + """Flush all collections. All insertions, deletions, and upserts before + `flush_all` will be synced. Args: timeout (float): an optional duration of time in seconds to allow for the RPCs. - If timeout is not set, the client keeps waiting until the server responds or an error occurs. - **kwargs (``dict``, optional): - - * *_async*(``bool``) - Indicate if invoke asynchronously. Default `False`. + If timeout is not set, the client keeps waiting until the server responds or + an error occurs. + **kwargs (``dict``, optional): + * *_async*(``bool``) + Indicate if invoke asynchronously. Default `False`. Examples: - >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility - >>> connections.connect() + >>> from pymilvus import Collection, FieldSchema, CollectionSchema, DataType, utility >>> fields = [ ... FieldSchema("film_id", DataType.INT64, is_primary=True), ... FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=128) @@ -1068,9 +1206,9 @@ def flush_all(using="default", timeout=None, **kwargs): return _get_connection(using).flush_all(timeout=timeout, **kwargs) -def get_server_type(using="default"): - """ Get the server type. Now, it will return "zilliz" if the connection related to an instance on the zilliz cloud, - otherwise "milvus" will be returned. +def get_server_type(using: str = "default"): + """Get the server type. Now, it will return "zilliz" if the connection related to + an instance on the zilliz cloud, otherwise "milvus" will be returned. :param using: Alias to the connection. Default connection is used if this is not specified. :type using: str @@ -1081,9 +1219,15 @@ def get_server_type(using="default"): return _get_connection(using).get_server_type() -def list_indexes(collection_name, using="default", timeout=None, **kwargs): - """ List all indexes of collection. If `field_name` is not specified, return all the indexes of this collection, - otherwise this interface will return all indexes on this field of the collection. +def list_indexes( + collection_name: str, + using: str = "default", + timeout: Optional[float] = None, + **kwargs, +): + """List all indexes of collection. If `field_name` is not specified, + return all the indexes of this collection, otherwise this interface will return + all indexes on this field of the collection. :param collection_name: The name of collection. :type collection_name: str @@ -1096,8 +1240,9 @@ def list_indexes(collection_name, using="default", timeout=None, **kwargs): :type timeout: float/int :param kwargs: - * *field_name* (``str``) - The name of field. If no field name is specified, all indexes of this collection will be returned. + * *field_name* (``str``) + The name of field. If no field name is specified, all indexes + of this collection will be returned. :type kwargs: dict :return: The name list of all indexes. diff --git a/pymilvus/settings.py b/pymilvus/settings.py index d54826731..b1f89d428 100644 --- a/pymilvus/settings.py +++ b/pymilvus/settings.py @@ -1,11 +1,13 @@ +import contextlib import logging.config + import environs + env = environs.Env() -try: +with contextlib.suppress(Exception): env.read_env(".env") -except Exception: - pass + class Config: # legacy env MILVUS_DEFAULT_CONNECTION, not recommended @@ -30,70 +32,69 @@ class Config: WaitTimeDurationWhenLoad = 0.5 # in seconds MaxVarCharLengthKey = "max_length" MaxVarCharLength = 65535 - EncodeProtocol = 'utf-8' + EncodeProtocol = "utf-8" IndexName = "" # logging COLORS = { - 'HEADER': '\033[95m', - 'INFO': '\033[92m', - 'DEBUG': '\033[94m', - 'WARNING': '\033[93m', - 'ERROR': '\033[95m', - 'CRITICAL': '\033[91m', - 'ENDC': '\033[0m', + "HEADER": "\033[95m", + "INFO": "\033[92m", + "DEBUG": "\033[94m", + "WARNING": "\033[93m", + "ERROR": "\033[95m", + "CRITICAL": "\033[91m", + "ENDC": "\033[0m", } class ColorFulFormatColMixin: - def format_col(self, message_str, level_name): + def format_col(self, message_str: str, level_name: str): if level_name in COLORS: - message_str = COLORS.get(level_name) + message_str + COLORS.get('ENDC') + message_str = COLORS.get(level_name) + message_str + COLORS.get("ENDC") return message_str class ColorfulFormatter(logging.Formatter, ColorFulFormatColMixin): - def format(self, record): + def format(self, record: str): message_str = super().format(record) return self.format_col(message_str, level_name=record.levelname) -LOG_LEVEL = 'WARNING' -# LOG_LEVEL = 'DEBUG' +LOG_LEVEL = "WARNING" LOGGING = { - 'version': 1, - 'disable_existing_loggers': False, - 'handlers': { - 'console': { - 'class': 'logging.StreamHandler', - 'level': LOG_LEVEL, + "version": 1, + "disable_existing_loggers": False, + "handlers": { + "console": { + "class": "logging.StreamHandler", + "level": LOG_LEVEL, }, }, - 'loggers': { - 'milvus': { - 'handlers': ['console'], - 'level': LOG_LEVEL, + "loggers": { + "milvus": { + "handlers": ["console"], + "level": LOG_LEVEL, }, }, } -if LOG_LEVEL == 'DEBUG': - LOGGING['formatters'] = { - 'colorful_console': { - 'format': '[%(asctime)s-%(levelname)s-%(name)s]: %(message)s (%(filename)s:%(lineno)s)', - '()': ColorfulFormatter, +if LOG_LEVEL == "DEBUG": + LOGGING["formatters"] = { + "colorful_console": { + "format": "[%(asctime)s-%(levelname)s-%(name)s]: %(message)s (%(filename)s:%(lineno)s)", + "()": ColorfulFormatter, }, } - LOGGING['handlers']['milvus_console'] = { - 'class': 'logging.StreamHandler', - 'formatter': 'colorful_console', + LOGGING["handlers"]["milvus_console"] = { + "class": "logging.StreamHandler", + "formatter": "colorful_console", } - LOGGING['loggers']['milvus'] = { - 'handlers': ['milvus_console'], - 'level': LOG_LEVEL, + LOGGING["loggers"]["milvus"] = { + "handlers": ["milvus_console"], + "level": LOG_LEVEL, } logging.config.dictConfig(LOGGING) diff --git a/pyproject.toml b/pyproject.toml index 54c21304f..c3863b94b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,3 +38,124 @@ dynamic = ["version"] [tool.setuptools_scm] 'local_scheme'= 'no-local-version' 'version_scheme'= 'release-branch-semver' + +[tool.black] +line-length = 100 +target-version = ['py37'] +include = '\.pyi?$' +extend-ignore = ["E203", "E501"] +# 'extend-exclude' excludes files or directories in addition to the defaults +extend-exclude = ''' +# A regex preceded with ^/ will apply only to files and directories +# in the root of the project. +( + ^/foo.py # exclude a file named foo.py in the root of the project + | .*/grpc_gen/ +) +''' + +[tool.ruff] +select = [ + "E", + "F", + "C90", + "I", + "N", + "B", "C", "G", + "A", + "ANN001", + "S", "T", "W", "ARG", "BLE", "COM", "DJ", "EM", "ERA", "EXE", "FBT", "ICN", "INP", "ISC", "NPY", "PD", "PGH", "PIE", "PL", "PT", "PTH", "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", "TID", "TRY", "UP", "YTT" +] +ignore = [ + "N818", + "DTZ", # datatime related + "BLE", # blind-except (BLE001) + "SLF", # SLF001 Private member accessed: `_fetch_handler` [E] + "PD003", + "TRY003", # [ruff] TRY003 Avoid specifying long messages outside the exception class [E] TODO + "PLR2004", # Magic value used in comparison, consider replacing 65535 with a constant variable [E] TODO + "TRY301", #[ruff] TRY301 Abstract `raise` to an inner function [E] + "FBT001", #[ruff] FBT001 Boolean positional arg in function definition [E] TODO + "FBT002", # [ruff] FBT002 Boolean default value in function definition [E] TODO + "PLR0911", # Too many return statements (15 > 6) [E] + "G004", # [ruff] G004 Logging statement uses f-string [E] + "S603", # [ruff] S603 `subprocess` call: check for execution of untrusted input [E] + "N802", #[ruff] N802 Function name `OK` should be lowercase [E] TODO + "PD011", # [ruff] PD011 Use `.to_numpy()` instead of `.values` [E] + "COM812", + "FBT003", # [ruff] FBT003 Boolean positional value in function call [E] TODO + "ARG002", + "E501", # black takes care of it + "ARG005", # [ruff] ARG005 Unused lambda argument: `disable` [E] + "TRY400", +] + +# Allow autofix for all enabled rules (when `--fix`) is provided. +fixable = [ + "A", "B", "C", "D", "E", "F", "G", "I", "N", "Q", "S", "T", "W", + "ANN", "ARG", "BLE", "COM", "DJ", "DTZ", "EM", "ERA", "EXE", "FBT", + "ICN", "INP", "ISC", "NPY", "PD", "PGH", "PIE", "PL", "PT", "PTH", + "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", "TID", "TRY", "UP", + "YTT", +] +unfixable = [] + +show-fixes = true + +# Exclude a variety of commonly ignored directories. +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".git-rewrite", + ".hg", + ".mypy_cache", + ".nox", + ".pants.d", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "venv", + "grpc_gen", + "__pycache__", + "pymilvus/client/stub.py" +] + +# Same as Black. +line-length = 100 + +# Allow unused variables when underscore-prefixed. +dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" + +# Assume Python 3.7 +target-version = "py37" + +[tool.ruff.mccabe] +# Unlike Flake8, default to a complexity level of 10. +max-complexity = 17 + +[tool.ruff.pycodestyle] +max-doc-length = 100 + +[tool.ruff.pylint] +max-args = 20 +max-branches = 15 + +[tool.ruff.flake8-builtins] +builtins-ignorelist = [ + "format", + "next", + "object", # TODO + "id", + "dict", # TODO + "filter", +] diff --git a/requirements.txt b/requirements.txt index 3831cb63c..b2a2b3ca1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -31,5 +31,6 @@ tqdm==4.65.0 pytest>=5.3.4 pytest-cov==2.8.1 pytest-timeout==1.3.4 -pylint==2.13.9 pandas>=1.1.5 +ruff +black diff --git a/test_requirements.txt b/test_requirements.txt index 4e6ebad8d..ea4da2c0d 100644 --- a/test_requirements.txt +++ b/test_requirements.txt @@ -1,6 +1,7 @@ pytest>=5.3.4 pytest-cov==2.8.1 pytest-timeout==1.3.4 -pylint==2.13.9 grpcio-testing sklearn==0.0 +ruff +black diff --git a/tests/test_schema.py b/tests/test_schema.py index fa440a771..15469797d 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -166,4 +166,4 @@ def test_check_insert_data_schema_issue1324(self): ] with pytest.raises(MilvusException): - s.check_insert_or_upsert_data_schema(schema, data) + s.check_insert_schema(schema, data) diff --git a/tests/test_types.py b/tests/test_types.py index 77e45a511..51f389597 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -9,7 +9,8 @@ # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under the License. -from pymilvus import DataType, DEFAULT_RESOURCE_GROUP +from pymilvus import DataType +from pymilvus.client.constants import DEFAULT_RESOURCE_GROUP from pymilvus.exceptions import InvalidConsistencyLevel from pymilvus.client.types import ( get_consistency_level, Shard, Group, Replica, ConsistencyLevel