Skip to content

Commit

Permalink
Merge pull request #340 from chaen/git_sh
Browse files Browse the repository at this point in the history
use sh.git instead of git module
  • Loading branch information
chaen authored Dec 5, 2024
2 parents 52460ef + 5ef5b17 commit 2d1bb99
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 16 deletions.
1 change: 1 addition & 0 deletions diracx-core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ dependencies = [
"pydantic >=2.10",
"pydantic-settings",
"pyyaml",
"sh",
]
dynamic = ["version"]

Expand Down
71 changes: 55 additions & 16 deletions diracx-core/src/diracx/core/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

__all__ = ("Config", "ConfigSource", "LocalGitConfigSource", "RemoteGitConfigSource")

import asyncio
import logging
import os
from abc import ABCMeta, abstractmethod
Expand All @@ -15,7 +16,7 @@
from tempfile import TemporaryDirectory
from typing import Annotated

import git
import sh
import yaml
from cachetools import Cache, LRUCache, TTLCache, cachedmethod
from pydantic import AnyUrl, BeforeValidator, TypeAdapter, UrlConstraints
Expand All @@ -34,6 +35,14 @@
logger = logging.getLogger(__name__)


def is_running_in_async_context():
    """Report whether the caller is executing inside a running asyncio event loop.

    ``asyncio.get_running_loop`` raises ``RuntimeError`` when no loop is
    running in the current thread; that signal is mapped onto a boolean.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread.
        return False
    else:
        return True


def _apply_default_scheme(value: str) -> str:
"""Applies the default git+file:// scheme if not present."""
if isinstance(value, str) and "://" not in value:
Expand Down Expand Up @@ -117,10 +126,9 @@ class BaseGitConfigSource(ConfigSource):
The caching is based on 2 caches:
* TTL to find the latest commit hashes
* LRU to keep in memory the last few versions.
"""

repo: git.Repo
repo_location: Path

# Needed because of the ConfigSource.__init_subclass__
scheme = "basegit"
Expand All @@ -134,22 +142,41 @@ def __init__(self, *, backend_url: ConfigSourceUrl) -> None:
# Resolve DEFAULT_GIT_BRANCH to a (revision, commit time in UTC) pair.
# cachedmethod stores the result in self._latest_revision_cache.
# NOTE(review): this span is a rendered diff with no +/- markers, so statements
# from both sides of the change are interleaved. The lines touching
# self.repo / git.exc look like the removed GitPython side and the sh.git lines
# like the added side — confirm against the commit before treating either as
# the live implementation.
@cachedmethod(lambda self: self._latest_revision_cache)
def latest_revision(self) -> tuple[str, datetime]:
try:
# GitPython variant: resolve the branch name through self.repo.
rev = self.repo.rev_parse(DEFAULT_GIT_BRANCH)
except git.exc.ODBError as e: # type: ignore
# sh variant: run `git rev-parse <branch>` inside repo_location; the stripped
# output is what gets returned as the revision.
# _async is the result of is_running_in_async_context(), i.e. True only when
# an asyncio event loop is already running.
rev = sh.git(
"rev-parse",
DEFAULT_GIT_BRANCH,
_cwd=self.repo_location,
_tty_out=False,
_async=is_running_in_async_context(),
).strip()
# `git show -s --format=%ct <rev>`: the stripped output is parsed below as an
# integer Unix timestamp.
commit_info = sh.git.show(
"-s",
"--format=%ct",
rev,
_cwd=self.repo_location,
_tty_out=False,
_async=is_running_in_async_context(),
).strip()
# Build a timezone-aware datetime in UTC from that timestamp.
modified = datetime.fromtimestamp(int(commit_info), tz=timezone.utc)
except sh.ErrorReturnCode as e:
# A failing git command is re-raised as BadConfigurationVersion, chained to
# the original error.
raise BadConfigurationVersion(f"Error parsing latest revision: {e}") from e
# GitPython variant: commit time converted to UTC, then log and return hexsha.
modified = rev.committed_datetime.astimezone(timezone.utc)
logger.debug(
"Latest revision for %s is %s with mtime %s", self, rev.hexsha, modified
)
return rev.hexsha, modified
# sh variant: rev is already the string produced by rev-parse.
logger.debug("Latest revision for %s is %s with mtime %s", self, rev, modified)
return rev, modified

@cachedmethod(lambda self: self._read_raw_cache)
def read_raw(self, hexsha: str, modified: datetime) -> Config:
""":param: hexsha commit hash"""
logger.debug("Reading %s for %s with mtime %s", self, hexsha, modified)
rev = self.repo.rev_parse(hexsha)
blob = rev.tree / DEFAULT_CONFIG_FILE
raw_obj = yaml.safe_load(blob.data_stream.read().decode())
try:
blob = sh.git.show(
f"{hexsha}:{DEFAULT_CONFIG_FILE}",
_cwd=self.repo_location,
_tty_out=False,
_async=False,
)
raw_obj = yaml.safe_load(blob)
except sh.ErrorReturnCode as e:
raise BadConfigurationVersion(f"Error reading configuration: {e}") from e

config_class: Config = select_from_extension(group="diracx", name="config")[
0
Expand Down Expand Up @@ -177,7 +204,19 @@ def __init__(self, *, backend_url: ConfigSourceUrl) -> None:
raise ValueError("Empty path for LocalGitConfigSource")

self.repo_location = Path(backend_url.path)
self.repo = git.Repo(self.repo_location)
# Check if it's a valid git repository
try:
sh.git(
"rev-parse",
"--git-dir",
_cwd=self.repo_location,
_tty_out=False,
_async=False,
)
except sh.ErrorReturnCode as e:
raise ValueError(
f"{self.repo_location} is not a valid git repository"
) from e

def __hash__(self):
    """Hash this config source by the location of its git repository."""
    location = self.repo_location
    return hash(location)
Expand All @@ -197,7 +236,7 @@ def __init__(self, *, backend_url: ConfigSourceUrl) -> None:
self.remote_url = str(backend_url).replace("git+", "")
self._temp_dir = TemporaryDirectory()
self.repo_location = Path(self._temp_dir.name)
self.repo = git.Repo.clone_from(self.remote_url, self.repo_location)
sh.git.clone(self.remote_url, self.repo_location, _async=False)
self._pull_cache: Cache = TTLCache(
MAX_PULL_CACHED_VERSIONS, DEFAULT_PULL_CACHE_TTL
)
Expand All @@ -212,7 +251,7 @@ def __hash__(self):
# cachedmethod stores the (None) result in self._pull_cache, which the
# constructor creates as a TTLCache; presumably this limits how often the
# pull is actually executed — confirm against cachetools semantics.
@cachedmethod(lambda self: self._pull_cache)
def _pull(self):
"""Git pull from remote repo."""
# NOTE(review): rendered diff — the next two statements look like the removed
# (GitPython) and added (sh) forms of the same pull; confirm which one is live.
self.repo.remotes.origin.pull()
# Runs `git pull` with repo_location as the working directory, with _async=False.
sh.git.pull(_cwd=self.repo_location, _async=False)

def latest_revision(self) -> tuple[str, datetime]:
self._pull()
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,10 @@ ignore_missing_imports = true
module = 'authlib.*'
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = 'sh.*'
ignore_missing_imports = true

[tool.pytest.ini_options]
testpaths = [
"diracx-api/tests",
Expand Down

0 comments on commit 2d1bb99

Please sign in to comment.