From 1b5a45508b9c39f98cc39d8c108f9cab2dec5519 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Saugat=20Pachhai=20=28=E0=A4=B8=E0=A5=8C=E0=A4=97=E0=A4=BE?= =?UTF-8?q?=E0=A4=A4=29?= Date: Wed, 15 Mar 2023 12:35:02 +0545 Subject: [PATCH] dvcfs: support caching remote streams; use that in plots --- dvc/fs/dvc.py | 7 +++-- dvc/repo/plots/__init__.py | 47 ++++++++++++++++++--------------- dvc/utils/serialize/__init__.py | 7 +++++ 3 files changed, 38 insertions(+), 23 deletions(-) diff --git a/dvc/fs/dvc.py b/dvc/fs/dvc.py index 7a78ec87868..423cea3db37 100644 --- a/dvc/fs/dvc.py +++ b/dvc/fs/dvc.py @@ -251,12 +251,15 @@ def _open( try: return self.repo.fs.open(fs_path, mode=mode) except FileNotFoundError: - _, dvc_fs, subkey = self._get_subrepo_info(key) + repo, dvc_fs, subkey = self._get_subrepo_info(key) if not dvc_fs: raise dvc_path = _get_dvc_path(dvc_fs, subkey) - return dvc_fs.open(dvc_path, mode=mode) + cache_odb = None + if kwargs.get("cache_remote_stream", False): + cache_odb = repo.cache.local + return dvc_fs.open(dvc_path, mode=mode, cache_odb=cache_odb) def isdvc(self, path, **kwargs) -> bool: """Is this entry dvc-tracked?""" diff --git a/dvc/repo/plots/__init__.py b/dvc/repo/plots/__init__.py index cf36d36c3cd..9f08c7fb8fb 100644 --- a/dvc/repo/plots/__init__.py +++ b/dvc/repo/plots/__init__.py @@ -20,13 +20,14 @@ import dpath import dpath.options -from funcy import distinct, first, project +from funcy import distinct, first, project, reraise from dvc.exceptions import DvcException from dvc.utils import error_handler, errored_revisions, onerror_collect from dvc.utils.objects import cached_property -from dvc.utils.serialize import LOADERS +from dvc.utils.serialize import PARSERS, EncodingError from dvc.utils.threadpool import ThreadPoolExecutor +from dvc_render.image import ImageRenderer if TYPE_CHECKING: from dvc.fs import FileSystem @@ -39,6 +40,9 @@ logger = logging.getLogger(__name__) +SUPPORTED_IMAGE_EXTENSIONS = ImageRenderer.EXTENSIONS + + class PlotMetricTypeError(DvcException): def __init__(self, file): super().__init__( @@ -196,7 +200,7 @@ def show( onerror=onerror, props=props, ): - _resolve_data_sources(data) + _resolve_data_sources(data, cache_remote_stream=True) result.update(data) errored = errored_revisions(result) @@ -265,7 +269,7 @@ def _is_plot(out: "Output") -> bool: return bool(out.plot) -def _resolve_data_sources(plots_data: Dict): +def _resolve_data_sources(plots_data: Dict, cache_remote_stream: bool = False): values = list(plots_data.values()) to_resolve = [] while values: @@ -278,7 +282,7 @@ def _resolve_data_sources(plots_data: Dict): def resolve(value): data_source = value.pop("data_source") assert callable(data_source) - value.update(data_source()) + value.update(data_source(cache_remote_stream=cache_remote_stream)) executor = ThreadPoolExecutor( max_workers=4 * cpu_count(), @@ -524,20 +528,25 @@ def unpack_if_dir(fs, path, props: Dict[str, str], onerror: Optional[Callable] = @error_handler -def parse(fs, path, props=None, **kwargs): +def parse(fs, path, props=None, **fs_kwargs): props = props or {} _, extension = os.path.splitext(path) - if extension in (".tsv", ".csv"): - header = props.get("header", True) - if extension == ".csv": - return _load_sv(path=path, fs=fs, delimiter=",", header=header) - return _load_sv(path=path, fs=fs, delimiter="\t", header=header) - if extension in LOADERS or extension in (".yml", ".yaml"): - return LOADERS[extension](path=path, fs=fs) - if extension in (".jpeg", ".jpg", ".gif", ".png", ".svg"): - with fs.open(path, "rb") as fd: + if extension in SUPPORTED_IMAGE_EXTENSIONS: + with fs.open(path, mode="rb", **fs_kwargs) as fd: return fd.read() - raise PlotMetricTypeError(path) + + if extension not in PARSERS.keys() | {".yml", ".yaml", ".csv", ".tsv"}: + raise PlotMetricTypeError(path) + + with reraise(UnicodeDecodeError, EncodingError(path, "utf8")): + with fs.open(path, mode="r", encoding="utf8", **fs_kwargs) as fd: + contents = fd.read() + + if extension in (".csv", ".tsv"): + header = props.get("header", True) + delim = "\t" if extension == ".tsv" else "," + return _load_sv(contents, delimiter=delim, header=header) + return PARSERS[extension](contents, path) def _plot_props(out: "Output") -> Dict: @@ -553,10 +562,7 @@ def _plot_props(out: "Output") -> Dict: return project(out.plot, PLOT_PROPS) -def _load_sv(path, fs, delimiter=",", header=True): - with fs.open(path, "r") as fd: - content = fd.read() - +def _load_sv(content, delimiter=",", header=True): if header: reader = csv.DictReader(io.StringIO(content), delimiter=delimiter) else: @@ -566,5 +572,4 @@ def _load_sv(path, fs, delimiter=",", header=True): delimiter=delimiter, fieldnames=[str(i) for i in range(len(first_row))], ) - return list(reader) diff --git a/dvc/utils/serialize/__init__.py b/dvc/utils/serialize/__init__.py index bd8bfa75a64..c5f7f9759b6 100644 --- a/dvc/utils/serialize/__init__.py +++ b/dvc/utils/serialize/__init__.py @@ -12,6 +12,13 @@ ) LOADERS.update({".toml": load_toml, ".json": load_json, ".py": load_py}) # noqa: F405 +PARSERS: DefaultDict[str, ParserFn] = defaultdict( # noqa: F405 + lambda: parse_yaml # noqa: F405 +) +PARSERS.update( + {".toml": parse_toml, ".json": parse_json, ".py": parse_py} # noqa: F405 +) + def load_path(fs_path, fs): suffix = fs.path.suffix(fs_path).lower()