diff --git a/diracx-core/pyproject.toml b/diracx-core/pyproject.toml index c9cb8f9e..5b058978 100644 --- a/diracx-core/pyproject.toml +++ b/diracx-core/pyproject.toml @@ -22,6 +22,7 @@ dependencies = [ "pydantic >=2.10", "pydantic-settings", "pyyaml", + "sh", ] dynamic = ["version"] diff --git a/diracx-core/src/diracx/core/config/__init__.py b/diracx-core/src/diracx/core/config/__init__.py index 60a435ee..4431d2f4 100644 --- a/diracx-core/src/diracx/core/config/__init__.py +++ b/diracx-core/src/diracx/core/config/__init__.py @@ -7,6 +7,7 @@ __all__ = ("Config", "ConfigSource", "LocalGitConfigSource", "RemoteGitConfigSource") +import asyncio import logging import os from abc import ABCMeta, abstractmethod @@ -15,7 +16,7 @@ from tempfile import TemporaryDirectory from typing import Annotated -import git +import sh import yaml from cachetools import Cache, LRUCache, TTLCache, cachedmethod from pydantic import AnyUrl, BeforeValidator, TypeAdapter, UrlConstraints @@ -34,6 +35,14 @@ logger = logging.getLogger(__name__) +def is_running_in_async_context(): + try: + asyncio.get_running_loop() + return True + except RuntimeError: + return False + + def _apply_default_scheme(value: str) -> str: """Applies the default git+file:// scheme if not present.""" if isinstance(value, str) and "://" not in value: @@ -117,10 +126,9 @@ class BaseGitConfigSource(ConfigSource): The caching is based on 2 caches: * TTL to find the latest commit hashes * LRU to keep in memory the last few versions. - """ - repo: git.Repo + repo_location: Path # Needed because of the ConfigSource.__init_subclass__ scheme = "basegit" @@ -134,22 +142,41 @@ def __init__(self, *, backend_url: ConfigSourceUrl) -> None: @cachedmethod(lambda self: self._latest_revision_cache) def latest_revision(self) -> tuple[str, datetime]: try: - rev = self.repo.rev_parse(DEFAULT_GIT_BRANCH) - except git.exc.ODBError as e: # type: ignore + rev = sh.git( + "rev-parse", + DEFAULT_GIT_BRANCH, + _cwd=self.repo_location, + _tty_out=False, + _async=is_running_in_async_context(), + ).strip() + commit_info = sh.git.show( + "-s", + "--format=%ct", + rev, + _cwd=self.repo_location, + _tty_out=False, + _async=is_running_in_async_context(), + ).strip() + modified = datetime.fromtimestamp(int(commit_info), tz=timezone.utc) + except sh.ErrorReturnCode as e: raise BadConfigurationVersion(f"Error parsing latest revision: {e}") from e - modified = rev.committed_datetime.astimezone(timezone.utc) - logger.debug( - "Latest revision for %s is %s with mtime %s", self, rev.hexsha, modified - ) - return rev.hexsha, modified + logger.debug("Latest revision for %s is %s with mtime %s", self, rev, modified) + return rev, modified @cachedmethod(lambda self: self._read_raw_cache) def read_raw(self, hexsha: str, modified: datetime) -> Config: """:param: hexsha commit hash""" logger.debug("Reading %s for %s with mtime %s", self, hexsha, modified) - rev = self.repo.rev_parse(hexsha) - blob = rev.tree / DEFAULT_CONFIG_FILE - raw_obj = yaml.safe_load(blob.data_stream.read().decode()) + try: + blob = sh.git.show( + f"{hexsha}:{DEFAULT_CONFIG_FILE}", + _cwd=self.repo_location, + _tty_out=False, + _async=False, + ) + raw_obj = yaml.safe_load(blob) + except sh.ErrorReturnCode as e: + raise BadConfigurationVersion(f"Error reading configuration: {e}") from e config_class: Config = select_from_extension(group="diracx", name="config")[ 0 @@ -177,7 +204,19 @@ def __init__(self, *, backend_url: ConfigSourceUrl) -> None: raise ValueError("Empty path for LocalGitConfigSource") self.repo_location = Path(backend_url.path) - self.repo = git.Repo(self.repo_location) + # Check if it's a valid git repository + try: + sh.git( + "rev-parse", + "--git-dir", + _cwd=self.repo_location, + _tty_out=False, + _async=False, + ) + except sh.ErrorReturnCode as e: + raise ValueError( + f"{self.repo_location} is not a valid git repository" + ) from e def __hash__(self): return hash(self.repo_location) @@ -197,7 +236,7 @@ def __init__(self, *, backend_url: ConfigSourceUrl) -> None: self.remote_url = str(backend_url).replace("git+", "") self._temp_dir = TemporaryDirectory() self.repo_location = Path(self._temp_dir.name) - self.repo = git.Repo.clone_from(self.remote_url, self.repo_location) + sh.git.clone(self.remote_url, self.repo_location, _async=False) self._pull_cache: Cache = TTLCache( MAX_PULL_CACHED_VERSIONS, DEFAULT_PULL_CACHE_TTL ) @@ -212,7 +251,7 @@ def __hash__(self): @cachedmethod(lambda self: self._pull_cache) def _pull(self): """Git pull from remote repo.""" - self.repo.remotes.origin.pull() + sh.git.pull(_cwd=self.repo_location, _async=False) def latest_revision(self) -> tuple[str, datetime]: self._pull() diff --git a/pyproject.toml b/pyproject.toml index 4109f0cf..06b78e3b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -120,6 +120,10 @@ ignore_missing_imports = true module = 'authlib.*' ignore_missing_imports = true +[[tool.mypy.overrides]] +module = 'sh.*' +ignore_missing_imports = true + [tool.pytest.ini_options] testpaths = [ "diracx-api/tests",