Skip to content

Commit

Permalink
api: traversable files API
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry committed Jan 4, 2022
1 parent 0a4958b commit 9546b36
Show file tree
Hide file tree
Showing 3 changed files with 338 additions and 35 deletions.
66 changes: 31 additions & 35 deletions dvc/api.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,22 @@
import os
from contextlib import _GeneratorContextManager as GCM
from contextlib import contextmanager
from typing import ContextManager, Iterator

from funcy import reraise

from dvc.exceptions import OutputNotFoundError, PathMissingError
from dvc.exceptions import NoOutputInExternalRepoError, OutputNotFoundError
from dvc.repo import Repo
from dvc.repo_path import RepoPath


def files(path=os.curdir, repo=None, rev=None) -> ContextManager[RepoPath]:
@contextmanager
def inner() -> Iterator["RepoPath"]:
with Repo.open(
repo, rev=rev, subrepos=True, uninitialized=True
) as root_repo:
yield RepoPath(path, fs=root_repo.repo_fs)

return inner()


def get_url(path, repo=None, rev=None, remote=None):
Expand All @@ -18,17 +30,11 @@ def get_url(path, repo=None, rev=None, remote=None):
NOTE: This function does not check for the actual existence of the file or
directory in the remote storage.
"""
with Repo.open(repo, rev=rev, subrepos=True, uninitialized=True) as _repo:
fs_path = _repo.fs.path.join(_repo.root_dir, path)
with reraise(FileNotFoundError, PathMissingError(path, repo)):
metadata = _repo.repo_fs.metadata(fs_path)

if not metadata.is_dvc:
raise OutputNotFoundError(path, repo)

cloud = metadata.repo.cloud
md5 = metadata.repo.dvcfs.info(fs_path)["md5"]
return cloud.get_url_for(remote, checksum=md5)
try:
with files(path, repo=repo, rev=rev) as path_obj:
return path_obj.url(remote=remote)
except NoOutputInExternalRepoError as exc:
raise OutputNotFoundError(exc.path, repo=repo)


def open( # noqa, pylint: disable=redefined-builtin
Expand All @@ -46,15 +52,15 @@ def open( # noqa, pylint: disable=redefined-builtin
) as fd:
# ... Handle file object fd
"""
args = (path,)
kwargs = {
"repo": repo,
"remote": remote,
"rev": rev,
"mode": mode,
"encoding": encoding,
}
return _OpenContextManager(_open, args, kwargs)

def _open():
with files(path, repo=repo, rev=rev) as path_obj:
with path_obj.open( # pylint: disable=not-context-manager
remote=remote, mode=mode, encoding=encoding
) as fd:
yield fd

return _OpenContextManager(_open, (), {})


class _OpenContextManager(GCM):
Expand All @@ -70,24 +76,14 @@ def __getattr__(self, name):
)


def _open(path, repo=None, rev=None, remote=None, mode="r", encoding=None):
with Repo.open(repo, rev=rev, subrepos=True, uninitialized=True) as _repo:
with _repo.open_by_relpath(
path, remote=remote, mode=mode, encoding=encoding
) as fd:
yield fd


def read(path, repo=None, rev=None, remote=None, mode="r", encoding=None):
"""
Returns the contents of a tracked file (by DVC or Git). For Git repos, HEAD
is used unless a rev argument is supplied. The default remote is tried
unless a remote argument is supplied.
"""
with open(
path, repo=repo, rev=rev, remote=remote, mode=mode, encoding=encoding
) as fd:
return fd.read()
with files(path, repo=repo, rev=rev) as path_obj:
return path_obj.read(remote=remote, mode=mode, encoding=encoding)


def make_checkpoint():
Expand Down
1 change: 1 addition & 0 deletions dvc/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ class NoOutputInExternalRepoError(DvcException):
def __init__(self, path, external_repo_path, external_repo_url):
from dvc.utils import relpath

self.path = path
super().__init__(
"Output '{}' not found in target repository '{}'".format(
relpath(path, external_repo_path), external_repo_url
Expand Down
Loading

0 comments on commit 9546b36

Please sign in to comment.