From 9c960d00686a5f537ff5597fcd7dd29254c8c0a4 Mon Sep 17 00:00:00 2001 From: daavoo Date: Wed, 10 May 2023 14:50:53 +0200 Subject: [PATCH] repro: Move pull logic to inside Stage. --- dvc/repo/reproduce.py | 4 ---- dvc/stage/__init__.py | 16 +++++++++++----- tests/func/test_run_cache.py | 2 +- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/dvc/repo/reproduce.py b/dvc/repo/reproduce.py index c4129d49bb..fdf4999b05 100644 --- a/dvc/repo/reproduce.py +++ b/dvc/repo/reproduce.py @@ -194,10 +194,6 @@ def _reproduce_stages( # noqa: C901 ) try: - if kwargs.get("pull") and stage.changed(): - logger.debug("Pulling %s", stage.addressing) - stage.repo.pull(stage.addressing, allow_missing=True) - ret = _reproduce_stage(stage, **kwargs) if len(ret) == 0: diff --git a/dvc/stage/__init__.py b/dvc/stage/__init__.py index d982859494..53e5059032 100644 --- a/dvc/stage/__init__.py +++ b/dvc/stage/__init__.py @@ -561,7 +561,7 @@ def commit(self, allow_missing=False, filter_info=None, **kwargs) -> None: raise CacheLinkError(link_failures) @rwlocked(read=["deps", "outs"]) - def run( + def run( # noqa: C901 self, dry=False, no_commit=False, @@ -573,15 +573,21 @@ def run( if (self.cmd or self.is_import) and not self.frozen and not dry: self.remove_outs(ignore_remove=False, force=False) - if (not self.frozen and self.is_import) or self.is_partial_import: + if ( + self.is_import and (not self.frozen or kwargs.get("pull")) + ) or self.is_partial_import: self._sync_import(dry, force, kwargs.get("jobs", None), no_download) elif not self.frozen and self.cmd: self._run_stage(dry, force, **kwargs) - else: + elif kwargs.get("pull"): + logger.info("Pulling data for %s", self) + for objs in self.get_used_objs().values(): + self.repo.cloud.pull(objs) + self.checkout() + elif not dry: args = ("outputs", "frozen ") if self.frozen else ("data sources", "") logger.info("Verifying %s in %s%s", *args, self) - if not dry: - self._check_missing_outputs() + self._check_missing_outputs() if not dry: if kwargs.get("checkpoint_func", None) or no_download: diff --git a/tests/func/test_run_cache.py b/tests/func/test_run_cache.py index dc907c2f4c..ea12eb147c 100644 --- a/tests/func/test_run_cache.py +++ b/tests/func/test_run_cache.py @@ -182,7 +182,7 @@ def test_restore_pull(tmp_dir, dvc, run_copy, mocker, local_remote): mock_restore.assert_called_once_with(stage, pull=True, dry=False) mock_run.assert_not_called() - assert mock_checkout.call_count == 3 + assert mock_checkout.call_count == 2 assert (tmp_dir / "bar").exists() assert not (tmp_dir / "foo").unlink() assert (tmp_dir / LOCK_FILE).exists()