diff --git a/dvc/repo/experiments/push.py b/dvc/repo/experiments/push.py index 9cc56784e3..aeeb56b817 100644 --- a/dvc/repo/experiments/push.py +++ b/dvc/repo/experiments/push.py @@ -1,12 +1,12 @@ import logging -from typing import Iterable, List, Mapping, Optional, Set, Union +from typing import TYPE_CHECKING, Any, Iterable, List, Mapping, Optional, Set, Union from funcy import compact, group_by from scmrepo.git.backend.base import SyncStatus from dvc.repo import locked from dvc.repo.scm_context import scm_context -from dvc.scm import TqdmGit, iter_revs +from dvc.scm import Git, TqdmGit, iter_revs from dvc.ui import ui from dvc.utils import env2bool @@ -14,10 +14,14 @@ from .refs import ExpRefInfo from .utils import exp_commits, exp_refs, exp_refs_by_baseline, resolve_name +if TYPE_CHECKING: + from dvc.repo import Repo + logger = logging.getLogger(__name__) -def notify_refs_to_studio(scm, config, git_remote: str, **refs: List[str]): +def notify_refs_to_studio(repo: "Repo", git_remote: str, **refs: List[str]) -> None: + config = repo.config["feature"] refs = compact(refs) if not refs or env2bool("DVC_TEST"): return @@ -40,7 +44,7 @@ def notify_refs_to_studio(scm, config, git_remote: str, **refs: List[str]): from dvc.utils import studio - _, repo_url = get_remote_repo(scm.dulwich.repo, git_remote) + _, repo_url = get_remote_repo(repo.scm.dulwich.repo, git_remote) studio_url = config.get("studio_url") studio.notify_refs( repo_url, @@ -53,17 +57,18 @@ def notify_refs_to_studio(scm, config, git_remote: str, **refs: List[str]): @locked @scm_context def push( # noqa: C901 - repo, + repo: "Repo", git_remote: str, exp_names: Union[Iterable[str], str], - all_commits=False, + all_commits: bool = False, rev: Optional[str] = None, - num=1, + num: int = 1, force: bool = False, push_cache: bool = False, - **kwargs, + **kwargs: Any, ) -> Iterable[str]: exp_ref_set: Set["ExpRefInfo"] = set() + assert isinstance(repo.scm, Git) if all_commits: exp_ref_set.update(exp_refs(repo.scm)) @@ -106,14 +111,12 @@ def push( # noqa: C901 refs = push_result[SyncStatus.SUCCESS] pushed_refs = [str(r) for r in refs] - notify_refs_to_studio( - repo.scm, repo.config["feature"], git_remote, pushed=pushed_refs - ) + notify_refs_to_studio(repo, git_remote, pushed=pushed_refs) return [ref.name for ref in refs] def _push( - repo, + repo: "Repo", git_remote: str, refs: Iterable["ExpRefInfo"], force: bool, @@ -145,14 +148,15 @@ def group_result(refspec): def _push_cache( - repo, + repo: "Repo", refs: Union[ExpRefInfo, Iterable["ExpRefInfo"]], - dvc_remote=None, - jobs=None, - run_cache=False, + dvc_remote: Optional[str] = None, + jobs: Optional[int] = None, + run_cache: bool = False, ): if isinstance(refs, ExpRefInfo): refs = [refs] + assert isinstance(repo.scm, Git) revs = list(exp_commits(repo.scm, refs)) logger.debug("dvc push experiment '%s'", refs) repo.push(jobs=jobs, remote=dvc_remote, run_cache=run_cache, revs=revs) diff --git a/dvc/repo/experiments/remove.py b/dvc/repo/experiments/remove.py index 74cb64ae06..ec5b8f6261 100644 --- a/dvc/repo/experiments/remove.py +++ b/dvc/repo/experiments/remove.py @@ -79,18 +79,18 @@ def remove( # noqa: C901, PLR0912 from .push import notify_refs_to_studio removed_refs = [str(r) for r in exp_ref_list] - notify_refs_to_studio( - repo.scm, repo.config["feature"], git_remote, removed=removed_refs - ) + notify_refs_to_studio(repo, git_remote, removed=removed_refs) return removed def _resolve_exp_by_baseline( - repo, + repo: "Repo", rev: str, num: int, git_remote: Optional[str] = None, ) -> Dict[str, "ExpRefInfo"]: + assert isinstance(repo.scm, Git) + commit_ref_dict: Dict[str, "ExpRefInfo"] = {} rev_dict = iter_revs(repo.scm, [rev], num) rev_set = set(rev_dict.keys())