Skip to content

Commit

Permalink
dvcfs: support caching remote streams; use that in plots
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry committed Mar 15, 2023
1 parent 54a6709 commit 1b5a455
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 23 deletions.
7 changes: 5 additions & 2 deletions dvc/fs/dvc.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,12 +251,15 @@ def _open(
try:
return self.repo.fs.open(fs_path, mode=mode)
except FileNotFoundError:
_, dvc_fs, subkey = self._get_subrepo_info(key)
repo, dvc_fs, subkey = self._get_subrepo_info(key)
if not dvc_fs:
raise

dvc_path = _get_dvc_path(dvc_fs, subkey)
return dvc_fs.open(dvc_path, mode=mode)
cache_odb = None
if kwargs.get("cache_remote_stream", False):
cache_odb = repo.cache.local
return dvc_fs.open(dvc_path, mode=mode, cache_odb=cache_odb)

def isdvc(self, path, **kwargs) -> bool:
"""Is this entry dvc-tracked?"""
Expand Down
47 changes: 26 additions & 21 deletions dvc/repo/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@

import dpath
import dpath.options
from funcy import distinct, first, project
from funcy import distinct, first, project, reraise

from dvc.exceptions import DvcException
from dvc.utils import error_handler, errored_revisions, onerror_collect
from dvc.utils.objects import cached_property
from dvc.utils.serialize import LOADERS
from dvc.utils.serialize import PARSERS, EncodingError
from dvc.utils.threadpool import ThreadPoolExecutor
from dvc_render.image import ImageRenderer

if TYPE_CHECKING:
from dvc.fs import FileSystem
Expand All @@ -39,6 +40,9 @@
logger = logging.getLogger(__name__)


# Image extensions that dvc_render's ImageRenderer can display; plot files
# with one of these extensions are read as raw bytes rather than parsed.
SUPPORTED_IMAGE_EXTENSIONS = ImageRenderer.EXTENSIONS


class PlotMetricTypeError(DvcException):
def __init__(self, file):
super().__init__(
Expand Down Expand Up @@ -196,7 +200,7 @@ def show(
onerror=onerror,
props=props,
):
_resolve_data_sources(data)
_resolve_data_sources(data, cache_remote_stream=True)
result.update(data)

errored = errored_revisions(result)
Expand Down Expand Up @@ -265,7 +269,7 @@ def _is_plot(out: "Output") -> bool:
return bool(out.plot)


def _resolve_data_sources(plots_data: Dict):
def _resolve_data_sources(plots_data: Dict, cache_remote_stream: bool = False):
values = list(plots_data.values())
to_resolve = []
while values:
Expand All @@ -278,7 +282,7 @@ def _resolve_data_sources(plots_data: Dict):
def resolve(value):
data_source = value.pop("data_source")
assert callable(data_source)
value.update(data_source())
value.update(data_source(cache_remote_stream=cache_remote_stream))

executor = ThreadPoolExecutor(
max_workers=4 * cpu_count(),
Expand Down Expand Up @@ -524,20 +528,25 @@ def unpack_if_dir(fs, path, props: Dict[str, str], onerror: Optional[Callable] =


@error_handler
def parse(fs, path, props=None, **fs_kwargs):
    """Read and parse a single plot data file.

    Args:
        fs: filesystem object used to open ``path``.
        path: location of the plot file on ``fs``.
        props: optional plot properties; only ``header`` is consulted here
            (for csv/tsv parsing).
        **fs_kwargs: extra keyword arguments forwarded to ``fs.open`` —
            e.g. ``cache_remote_stream`` to cache remote reads locally.

    Returns:
        Raw ``bytes`` for supported image extensions, a list of row dicts
        for ``.csv``/``.tsv``, or the structure produced by the matching
        ``PARSERS`` entry otherwise.

    Raises:
        PlotMetricTypeError: if the extension is not a supported plot type.
        EncodingError: if a text plot file is not valid utf-8.
    """
    props = props or {}
    _, extension = os.path.splitext(path)
    # Images are returned as-is (binary); no text decoding applies.
    if extension in SUPPORTED_IMAGE_EXTENSIONS:
        with fs.open(path, mode="rb", **fs_kwargs) as fd:
            return fd.read()

    if extension not in PARSERS.keys() | {".yml", ".yaml", ".csv", ".tsv"}:
        raise PlotMetricTypeError(path)

    # Surface decode failures as a dvc EncodingError instead of a raw
    # UnicodeDecodeError.
    with reraise(UnicodeDecodeError, EncodingError(path, "utf8")):
        with fs.open(path, mode="r", encoding="utf8", **fs_kwargs) as fd:
            contents = fd.read()

    if extension in (".csv", ".tsv"):
        header = props.get("header", True)
        delim = "\t" if extension == ".tsv" else ","
        return _load_sv(contents, delimiter=delim, header=header)
    # PARSERS is a defaultdict, so unknown-but-allowed extensions
    # (.yml/.yaml) fall back to its default parser.
    return PARSERS[extension](contents, path)


def _plot_props(out: "Output") -> Dict:
Expand All @@ -553,10 +562,7 @@ def _plot_props(out: "Output") -> Dict:
return project(out.plot, PLOT_PROPS)


def _load_sv(path, fs, delimiter=",", header=True):
with fs.open(path, "r") as fd:
content = fd.read()

def _load_sv(content, delimiter=",", header=True):
if header:
reader = csv.DictReader(io.StringIO(content), delimiter=delimiter)
else:
Expand All @@ -566,5 +572,4 @@ def _load_sv(path, fs, delimiter=",", header=True):
delimiter=delimiter,
fieldnames=[str(i) for i in range(len(first_row))],
)

return list(reader)
7 changes: 7 additions & 0 deletions dvc/utils/serialize/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@
)
LOADERS.update({".toml": load_toml, ".json": load_json, ".py": load_py}) # noqa: F405

# Content-based parsers: each takes already-read text (plus the source path)
# rather than reading from the filesystem like LOADERS above.  Extensions
# without an explicit entry fall back to YAML parsing via the defaultdict.
PARSERS: DefaultDict[str, ParserFn] = defaultdict(  # noqa: F405
    lambda: parse_yaml  # noqa: F405
)
PARSERS.update(
    {".toml": parse_toml, ".json": parse_json, ".py": parse_py}  # noqa: F405
)


def load_path(fs_path, fs):
suffix = fs.path.suffix(fs_path).lower()
Expand Down

0 comments on commit 1b5a455

Please sign in to comment.