diff --git a/chromadb/api/client.py b/chromadb/api/client.py index 0a97918d6b3..b0e7d690192 100644 --- a/chromadb/api/client.py +++ b/chromadb/api/client.py @@ -1,7 +1,4 @@ -import json -import logging from typing import ClassVar, Dict, Optional, Sequence -from urllib import request from uuid import UUID import uuid @@ -31,9 +28,7 @@ from chromadb.telemetry.product.events import ClientStartEvent from chromadb.types import Database, Tenant, Where, WhereDocument import chromadb.utils.embedding_functions as ef -from chromadb.utils.client_utils import compare_versions - -logger = logging.getLogger(__name__) +from chromadb.utils.client_utils import _upgrade_check class SharedSystemClient: @@ -134,7 +129,6 @@ class Client(SharedSystemClient, ClientAPI): _server: ServerAPI # An internal admin client for verifying that databases and tenants exist _admin_client: AdminAPI - _upgrade_check_url: str = "https://pypi.org/pypi/chromadb/json" # region Initialization def __init__( @@ -144,7 +138,7 @@ def __init__( settings: Settings = Settings(), ) -> None: super().__init__(settings=settings) - self._upgrade_check() + _upgrade_check() self.tenant = tenant self.database = database # Create an admin client for verifying that databases and tenants exist @@ -172,25 +166,6 @@ def from_system( # endregion - def _upgrade_check(self) -> None: - """Check pypi index for new version if possible.""" - try: - data = json.load( - request.urlopen(request.Request(self._upgrade_check_url), timeout=5) - ) - from chromadb import __version__ as local_chroma_version - - latest_version = data["info"]["version"] - if compare_versions(latest_version, local_chroma_version) > 0: - logger.info( - f"\033[38;5;069m[notice]\033[0m A new release of chromadb is available: " - f"\033[38;5;196m{local_chroma_version}!\033[0m -> " - f"\033[38;5;082m{latest_version}\033[0m\n" - "\033[38;5;069m[notice]\033[0m To upgrade, run `pip install --upgrade chromadb`." - ) - except Exception: - pass - # region BaseAPI Methods # Note - we could do this in less verbose ways, but they break type checking @override diff --git a/chromadb/app.py b/chromadb/app.py index 420bc2fce42..35d3c3b6d9e 100644 --- a/chromadb/app.py +++ b/chromadb/app.py @@ -1,7 +1,8 @@ -import chromadb import chromadb.config from chromadb.server.fastapi import FastAPI +from chromadb.utils.client_utils import _upgrade_check +_upgrade_check() settings = chromadb.config.Settings() server = FastAPI(settings) app = server.app() diff --git a/chromadb/cli/cli.py b/chromadb/cli/cli.py index bdfedd99da1..bf452aaf327 100644 --- a/chromadb/cli/cli.py +++ b/chromadb/cli/cli.py @@ -5,6 +5,8 @@ import os import webbrowser +from chromadb.utils.client_utils import _upgrade_check + app = typer.Typer() _logo = """ @@ -50,6 +52,11 @@ def run( "\033[1mGetting started guide\033[0m: https://docs.trychroma.com/getting-started\n\n" ) + upgrade_message = _upgrade_check() + if upgrade_message: + for m in upgrade_message: + typer.echo(m) + # set ENV variable for PERSIST_DIRECTORY to path os.environ["IS_PERSISTENT"] = "True" os.environ["PERSIST_DIRECTORY"] = path @@ -69,7 +76,6 @@ def run( "log_config": f"{chromadb_path}/log_config.yml", "timeout_keep_alive": 30, } - if test: return diff --git a/chromadb/test/client/test_client_upgrade.py b/chromadb/test/client/test_client_upgrade.py index 808bbf1698f..172e2e6b662 100644 --- a/chromadb/test/client/test_client_upgrade.py +++ b/chromadb/test/client/test_client_upgrade.py @@ -9,7 +9,7 @@ def test_new_release_available(caplog: pytest.LogCaptureFixture) -> None: with patch( - "chromadb.api.client.Client._upgrade_check_url", + "chromadb.utils.client_utils._upgrade_check_url", new="http://localhost:8008/pypi/chromadb/json", ): with HTTPServer(port=8008) as httpserver: @@ -39,7 +39,7 @@ def test_on_latest_release(caplog: pytest.LogCaptureFixture) -> None: def test_local_version_newer_than_latest(caplog: pytest.LogCaptureFixture) -> None: with patch( - "chromadb.api.client.Client._upgrade_check_url", + "chromadb.utils.client_utils._upgrade_check_url", new="http://localhost:8008/pypi/chromadb/json", ): with HTTPServer(port=8008) as httpserver: @@ -56,7 +56,7 @@ def test_local_version_newer_than_latest(caplog: pytest.LogCaptureFixture) -> No def test_pypi_unavailable(caplog: pytest.LogCaptureFixture) -> None: with patch( - "chromadb.api.client.Client._upgrade_check_url", + "chromadb.utils.client_utils._upgrade_check_url", new="http://localhost:8008/pypi/chromadb/json", ): with HTTPServer(port=8009) as httpserver: diff --git a/chromadb/test/test_cli.py b/chromadb/test/test_cli.py index 231877341f5..7d878584bf9 100644 --- a/chromadb/test/test_cli.py +++ b/chromadb/test/test_cli.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + from typer.testing import CliRunner from chromadb.cli.cli import app @@ -19,3 +21,24 @@ def test_app() -> None: ) assert "chroma_test_data" in result.stdout assert "8001" in result.stdout + + +def test_app_version_upgrade() -> None: + with patch( + "chromadb.__version__", + new="0.0.1", + ): + result = runner.invoke( + app, + [ + "run", + "--path", + "chroma_test_data", + "--port", + "8001", + "--test", + ], + ) + assert "A new release of chromadb is available" in result.stdout + assert "chroma_test_data" in result.stdout + assert "8001" in result.stdout diff --git a/chromadb/utils/client_utils.py b/chromadb/utils/client_utils.py index bfbb1d93613..1eba2041571 100644 --- a/chromadb/utils/client_utils.py +++ b/chromadb/utils/client_utils.py @@ -1,3 +1,11 @@ +import json +import logging +from typing import List +from urllib import request + +logger = logging.getLogger(__name__) + + def compare_versions(version1: str, version2: str) -> int: """Compares two versions of the format X.Y.Z and returns 1 if version1 is greater than version2, -1 if version1 is less than version2, and 0 if version1 is equal to version2. @@ -17,3 +25,38 @@ def compare_versions(version1: str, version2: str) -> int: return -1 return 0 + + +_upgrade_check_url: str = "https://pypi.org/pypi/chromadb/json" +_check_performed: bool = False + + +def _upgrade_check() -> List[str]: + """Check pypi index for new version if possible.""" + global _check_performed + upgrade_messages: List[str] = [] + # this is to prevent cli from double printing + if _check_performed: + return upgrade_messages + try: + data = json.load( + request.urlopen(request.Request(_upgrade_check_url), timeout=5) + ) + from chromadb import __version__ as local_chroma_version + + latest_version = data["info"]["version"] + if compare_versions(latest_version, local_chroma_version) > 0: + upgrade_messages.append( + f"\033[38;5;069m[notice]\033[0m A new release of chromadb is available: " + f"\033[38;5;196m{local_chroma_version}\033[0m -> " + f"\033[38;5;082m{latest_version}\033[0m" + ) + upgrade_messages.append( + "\033[38;5;069m[notice]\033[0m To upgrade, run `pip install --upgrade chromadb`." + ) + except Exception: + pass + _check_performed = True + for m in upgrade_messages: + logger.info(m) + return upgrade_messages