From 8a6ad071277feca0c94368e1ac09a46faea65401 Mon Sep 17 00:00:00 2001 From: Hammad Bashir Date: Mon, 25 Sep 2023 09:25:39 -0700 Subject: [PATCH] [CHORE] Add support for pydantic v2 (#1174) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Description of changes Closes #893 *Summarize the changes made by this PR.* - Improvements & Bug fixes - Adds support for pydantic v2.0 by changing how Collection model inits - this simple change fixes pydantic v2 - Fixes the cross version tests to handle pydantic specifically - Conditionally imports pydantic-settings based on what is available. In v2 BaseSettings was moved to a new package. - New functionality - N/A ## Test plan Existing tests were run with the following configs 1. Fastapi < 0.100, Pydantic >= 2.0 - Unsupported as the fastapi dependencies will not allow it. They likely should, as pydantic.v1 imports would support this, but this is a downstream issue. 2. Fastapi >= 0.100, Pydantic >= 2.0, Supported via normal imports ✅ (Tested with fastapi==0.103.1, pydantic==2.3.0) 3. Fastapi < 0.100 Pydantic < 2.0, Supported via normal imports ✅ (Tested with fastapi==0.95.2, pydantic==1.9.2) 4. Fastapi >= 0.100, Pydantic < 2.0, Supported via normal imports ✅ (Tested with latest fastapi, pydantic==1.9.2) - [x] Tests pass locally with `pytest` for python, `yarn test` for js ## Documentation Changes None required. --- chromadb/api/models/Collection.py | 3 ++- chromadb/auth/providers.py | 1 - chromadb/config.py | 13 ++++++++++++- .../test/property/test_cross_version_persist.py | 12 +++++++++--- pyproject.toml | 4 ++-- requirements.txt | 6 +++--- 6 files changed, 28 insertions(+), 11 deletions(-) diff --git a/chromadb/api/models/Collection.py b/chromadb/api/models/Collection.py index 6b4f7f18bd9..c11a04b1fa4 100644 --- a/chromadb/api/models/Collection.py +++ b/chromadb/api/models/Collection.py @@ -1,5 +1,6 @@ from typing import TYPE_CHECKING, Optional, Tuple, cast, List from pydantic import BaseModel, PrivateAttr + from uuid import UUID import chromadb.utils.embedding_functions as ef @@ -50,9 +51,9 @@ def __init__( embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), metadata: Optional[CollectionMetadata] = None, ): + super().__init__(name=name, metadata=metadata, id=id) self._client = client self._embedding_function = embedding_function - super().__init__(name=name, metadata=metadata, id=id) def __repr__(self) -> str: return f"Collection(name={self.name})" diff --git a/chromadb/auth/providers.py b/chromadb/auth/providers.py index a3bb23616e2..eceee3bc2ab 100644 --- a/chromadb/auth/providers.py +++ b/chromadb/auth/providers.py @@ -5,7 +5,6 @@ import requests from overrides import override from pydantic import SecretStr - from chromadb.auth import ( ServerAuthCredentialsProvider, AbstractCredentials, diff --git a/chromadb/config.py b/chromadb/config.py index 6167193acd2..0e9c87e5572 100644 --- a/chromadb/config.py +++ b/chromadb/config.py @@ -9,9 +9,20 @@ from overrides import EnforceOverrides from overrides import override -from pydantic import BaseSettings, validator from typing_extensions import Literal + +in_pydantic_v2 = False +try: + from pydantic import BaseSettings +except ImportError: + in_pydantic_v2 = True + from pydantic.v1 import BaseSettings + from pydantic.v1 import validator + +if not in_pydantic_v2: + from pydantic import validator # type: ignore # noqa + # The thin client will have a flag to control which implementations to use is_thin_client = False try: diff --git a/chromadb/test/property/test_cross_version_persist.py b/chromadb/test/property/test_cross_version_persist.py index d1785b84140..529fe02dda7 100644 --- a/chromadb/test/property/test_cross_version_persist.py +++ b/chromadb/test/property/test_cross_version_persist.py @@ -24,6 +24,9 @@ MINIMUM_VERSION = "0.4.1" version_re = re.compile(r"^[0-9]+\.[0-9]+\.[0-9]+$") +# Some modules do not work across versions, since we upgrade our support for them, and should be explicitly reimported in the subprocess +VERSIONED_MODULES = ["pydantic"] + def versions() -> List[str]: """Returns the pinned minimum version and the latest version of chromadb.""" @@ -49,7 +52,7 @@ def _patch_boolean_metadata( # boolean value metadata to int collection_metadata = collection.metadata if collection_metadata is not None: - _bool_to_int(collection_metadata) + _bool_to_int(collection_metadata) # type: ignore if embeddings["metadatas"] is not None: if isinstance(embeddings["metadatas"], list): @@ -162,7 +165,10 @@ def switch_to_version(version: str) -> ModuleType: old_modules = { n: m for n, m in sys.modules.items() - if n == module_name or (n.startswith(module_name + ".")) + if n == module_name + or (n.startswith(module_name + ".")) + or n in VERSIONED_MODULES + or (any(n.startswith(m + ".") for m in VERSIONED_MODULES)) } for n in old_modules: del sys.modules[n] @@ -197,7 +203,7 @@ def persist_generated_data_with_old_version( api.reset() coll = api.create_collection( name=collection_strategy.name, - metadata=collection_strategy.metadata, + metadata=collection_strategy.metadata, # type: ignore # In order to test old versions, we can't rely on the not_implemented function embedding_function=not_implemented_ef(), ) diff --git a/pyproject.toml b/pyproject.toml index 8fc60673607..7db0fe821ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,9 +16,9 @@ classifiers = [ ] dependencies = [ 'requests >= 2.28', - 'pydantic>=1.9,<2.0', + 'pydantic >= 1.9', 'chroma-hnswlib==0.7.3', - 'fastapi>=0.95.2, <0.100.0', + 'fastapi >= 0.95.2', 'uvicorn[standard] >= 0.18.3', 'numpy == 1.21.6; python_version < "3.8"', 'numpy >= 1.22.5; python_version >= "3.8"', diff --git a/requirements.txt b/requirements.txt index 9a9fdcc295c..80f4d9be904 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ bcrypt==4.0.1 chroma-hnswlib==0.7.3 -fastapi>=0.95.2, <0.100.0 +fastapi>=0.95.2 graphlib_backport==1.0.3; python_version < '3.9' importlib-resources numpy==1.21.6; python_version < '3.8' @@ -9,11 +9,11 @@ onnxruntime==1.14.1 overrides==7.3.1 posthog==2.4.0 pulsar-client==3.1.0 -pydantic>=1.9,<2.0 +pydantic>=1.9 pypika==0.48.9 requests==2.28.1 tokenizers==0.13.2 tqdm==4.65.0 typer>=0.9.0 -typing_extensions==4.5.0 +typing_extensions>=4.5.0 uvicorn[standard]==0.18.3