diff --git a/environment_dev.yml b/environment_dev.yml index dcb2c0b7fd..3b73689e5c 100644 --- a/environment_dev.yml +++ b/environment_dev.yml @@ -43,9 +43,9 @@ dependencies: - pgmpy - plotly>=4.1.0 - snorkel>=0.9.7 - - spacy==3.1.0 - - https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.1.0/en_core_web_sm-3.1.0.tar.gz + - spacy==3.5.0 + - https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.5.0/en_core_web_sm-3.5.0.tar.gz - transformers[torch]~=4.18.0 - - loguru + - rich==13.0.1 # install Argilla in editable mode - -e .[server,listeners] diff --git a/pyproject.toml b/pyproject.toml index 42133cb2c7..bf7123e24f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,11 +38,13 @@ dependencies = [ "wrapt >= 1.13,< 1.15", # weaksupervision "numpy < 1.24.0", + # for progressbars "tqdm >= 4.27.0", # monitor background consumers "backoff", - "monotonic" - + "monotonic", + # for logging, tracebacks, printing, progressbars + "rich <= 13.0.1" ] dynamic = ["version"] diff --git a/src/argilla/__init__.py b/src/argilla/__init__.py index 1dbd654c9f..81ed00085a 100644 --- a/src/argilla/__init__.py +++ b/src/argilla/__init__.py @@ -26,6 +26,14 @@ from . import _version from .utils import LazyargillaModule as _LazyargillaModule +try: + from rich.traceback import install as _install_rich + + # Rely on `rich` for tracebacks + _install_rich() +except ModuleNotFoundError: + pass + __version__ = _version.version if _TYPE_CHECKING: diff --git a/src/argilla/client/client.py b/src/argilla/client/client.py index ef36f35800..70f9938509 100644 --- a/src/argilla/client/client.py +++ b/src/argilla/client/client.py @@ -20,7 +20,8 @@ from asyncio import Future from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union -from tqdm.auto import tqdm +from rich import print as rprint +from rich.progress import Progress from argilla._constants import ( _OLD_WORKSPACE_HEADER_NAME, @@ -363,25 +364,26 @@ async def log_async( raise InputValueError(f"Unknown record type {record_type}. Available values are" f" {Record.__args__}") processed, failed = 0, 0 - progress_bar = tqdm(total=len(records), disable=not verbose) - for i in range(0, len(records), chunk_size): - chunk = records[i : i + chunk_size] - - response = await async_bulk( - client=self._client, - name=name, - json_body=bulk_class( - tags=tags, - metadata=metadata, - records=[creation_class.from_client(r) for r in chunk], - ), - ) + with Progress() as progress_bar: + task = progress_bar.add_task("Logging...", total=len(records), visible=verbose) + + for i in range(0, len(records), chunk_size): + chunk = records[i : i + chunk_size] + + response = await async_bulk( + client=self._client, + name=name, + json_body=bulk_class( + tags=tags, + metadata=metadata, + records=[creation_class.from_client(r) for r in chunk], + ), + ) - processed += response.parsed.processed - failed += response.parsed.failed + processed += response.parsed.processed + failed += response.parsed.failed - progress_bar.update(len(chunk)) - progress_bar.close() + progress_bar.update(task, advance=len(chunk)) # TODO: improve logging policy in library if verbose: @@ -389,7 +391,7 @@ async def log_async( workspace = self.get_workspace() if not workspace: # Just for backward comp. with datasets with no workspaces workspace = "-" - print(f"{processed} records logged to" f" {self._client.base_url}/datasets/{workspace}/{name}") + rprint(f"{processed} records logged to {self._client.base_url}/datasets/{workspace}/{name}") # Creating a composite BulkResponse with the total processed and failed return BulkResponse(dataset=name, processed=processed, failed=failed) diff --git a/src/argilla/logging.py b/src/argilla/logging.py index 94b1cc249e..51e3c56b49 100644 --- a/src/argilla/logging.py +++ b/src/argilla/logging.py @@ -18,13 +18,13 @@ """ import logging -from logging import Logger +from logging import Logger, StreamHandler from typing import Type try: - from loguru import logger + from rich.logging import RichHandler as ArgillaHandler except ModuleNotFoundError: - logger = None + ArgillaHandler = StreamHandler def full_qualified_class_name(_class: Type) -> str: @@ -60,64 +60,10 @@ def logger(self) -> logging.Logger: return self.__logger__ -class LoguruLoggerHandler(logging.Handler): - """This logging handler enables an easy way to use loguru fo all built-in logger traces""" - - __LOGLEVEL_MAPPING__ = { - 50: "CRITICAL", - 40: "ERROR", - 30: "WARNING", - 20: "INFO", - 10: "DEBUG", - 0: "NOTSET", - } - - @property - def is_available(self) -> bool: - """Return True if handler can tackle log records. False otherwise""" - return logger is not None - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - if not self.is_available: - self.emit = lambda record: None - - def emit(self, record: logging.LogRecord): - try: - level = logger.level(record.levelname).name - except AttributeError: - level = self.__LOGLEVEL_MAPPING__[record.levelno] - - frame, depth = logging.currentframe(), 2 - while frame.f_code.co_filename == logging.__file__: - frame = frame.f_back - depth += 1 - - log = logger.bind(request_id="argilla") - log.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage()) - - def configure_logging(): """Normalizes logging configuration for argilla and its dependencies""" - intercept_handler = LoguruLoggerHandler() - if not intercept_handler.is_available: - return - - logging.basicConfig(handlers=[intercept_handler], level=logging.WARNING) - for name in logging.root.manager.loggerDict: - logger_ = logging.getLogger(name) - logger_.handlers = [] - - for name in [ - "uvicorn", - "uvicorn.lifespan", - "uvicorn.error", - "uvicorn.access", - "fastapi", - "argilla", - "argilla.server", - ]: - logger_ = logging.getLogger(name) - logger_.propagate = False - logger_.handlers = [intercept_handler] + handler = ArgillaHandler() + + # See the note here: https://docs.python.org/3/library/logging.html#logging.Logger.propagate + # We only attach our handler to the root logger and let propagation take care of the rest + logging.basicConfig(handlers=[handler], level=logging.WARNING) diff --git a/tests/conftest.py b/tests/conftest.py index 33738a6db6..f8495bff2b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,15 +15,10 @@ import httpx import pytest from _pytest.logging import LogCaptureFixture -from argilla.client.sdk.users import api as users_api -from argilla.server.commons import telemetry - -try: - from loguru import logger -except ModuleNotFoundError: - logger = None from argilla import app from argilla.client.api import active_api +from argilla.client.sdk.users import api as users_api +from argilla.server.commons import telemetry from starlette.testclient import TestClient from .helpers import SecuredClient @@ -68,13 +63,3 @@ def whoami_mocked(client): monkeypatch.setattr(rb_api._client, "__httpx__", client_) yield client_ - - -@pytest.fixture -def caplog(caplog: LogCaptureFixture): - if not logger: - yield caplog - else: - handler_id = logger.add(caplog.handler, format="{message}") - yield caplog - logger.remove(handler_id) diff --git a/tests/labeling/text_classification/test_label_errors.py b/tests/labeling/text_classification/test_label_errors.py index b07780ec71..0e46394bf0 100644 --- a/tests/labeling/text_classification/test_label_errors.py +++ b/tests/labeling/text_classification/test_label_errors.py @@ -17,6 +17,7 @@ import argilla as ar import cleanlab import pytest +from _pytest.logging import LogCaptureFixture from argilla.labeling.text_classification import find_label_errors from argilla.labeling.text_classification.label_errors import ( MissingPredictionError, @@ -70,7 +71,7 @@ def test_no_records(): find_label_errors(records) -def test_multi_label_warning(caplog): +def test_multi_label_warning(caplog: LogCaptureFixture): record = ar.TextClassificationRecord( text="test", prediction=[("mock", 0.0), ("mock2", 0.0)], diff --git a/tests/test_init.py b/tests/test_init.py index 9a3a37b103..753e230216 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -16,7 +16,7 @@ import logging import sys -from argilla.logging import LoguruLoggerHandler +from argilla.logging import ArgillaHandler from argilla.utils import LazyargillaModule @@ -25,4 +25,7 @@ def test_lazy_module(): def test_configure_logging_call(): - assert isinstance(logging.getLogger("argilla").handlers[0], LoguruLoggerHandler) + # Ensure that the root logger uses the ArgillaHandler (RichHandler if rich is installed), + # whereas the other loggers do not have handlers + assert isinstance(logging.getLogger().handlers[0], ArgillaHandler) + assert len(logging.getLogger("argilla").handlers) == 0 diff --git a/tests/test_logging.py b/tests/test_logging.py index 054deb7740..5f7f9f5be4 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from argilla.logging import LoggingMixin, LoguruLoggerHandler +from argilla.logging import ArgillaHandler, LoggingMixin class LoggingForTest(LoggingMixin): @@ -50,8 +50,8 @@ def test_logging_mixin_without_breaking_constructors(): def test_logging_handler(mocker): - mocker.patch.object(LoguruLoggerHandler, "emit", autospec=True) - handler = LoguruLoggerHandler() + mocker.patch.object(ArgillaHandler, "emit", autospec=True) + handler = ArgillaHandler() logger = logging.getLogger(__name__) logger.handlers = [handler]