Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(datastore): use flat structure to avoid too many sub directories #2809

Merged
merged 3 commits into from
Oct 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 65 additions & 44 deletions client/starwhale/api/_impl/data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import contextlib
from abc import ABCMeta, abstractmethod
from http import HTTPStatus
from uuid import uuid4
from typing import (
Any,
cast,
Expand Down Expand Up @@ -896,19 +897,6 @@ def __eq__(self, other: Any) -> bool:
)


def _get_table_path(root_path: str | Path, table_name: str) -> Path:
    """Resolve the on-disk file path for a datastore table.

    Globs under ``root_path`` for files whose name starts with
    ``<table_name><datastore_table_file_ext>`` (variants may append a further
    suffix, e.g. a compression extension).

    Arguments:
        root_path: datastore root directory.
        table_name: table name; may contain "/" sub paths, surrounding "/"
            are stripped.

    Returns:
        The single matching file path, or the expected (possibly not yet
        existing) default path when nothing matches.

    Raises:
        RuntimeError: when more than one file matches, the table location is
            ambiguous.
    """
    expect_prefix = Path(root_path) / (table_name.strip("/") + datastore_table_file_ext)
    paths = list(expect_prefix.parent.glob(f"{expect_prefix.name}*"))
    if len(paths) > 1:
        # Multiple candidates: we cannot tell which file holds the table.
        # (The previous message claimed the table could not be found, which
        # described the opposite situation.)
        raise RuntimeError(
            f"ambiguous table {table_name}, multiple files match: {paths}"
        )
    if len(paths) == 1:
        return paths[0]
    return expect_prefix


def _merge_scan(
iters: List[Iterator[Dict[str, Any]]], keep_none: bool
) -> Iterator[dict]:
Expand Down Expand Up @@ -1983,6 +1971,17 @@ def to_dict(self) -> Dict[str, Any]:
return ret


class LocalTableDesc(SwBaseModel):
    # Manifest entry describing where one table lives inside the datastore root.
    name: str  # full table name as used by callers (may contain "/" sub paths)
    dir: str  # table directory, relative to the datastore root; note: shadows the `dir` builtin
    created_at: int  # timestamp in milliseconds


class LocalDataStoreManifest(SwBaseModel):
    # On-disk manifest for the local datastore: maps table names to the flat
    # directories that hold their data, avoiding deep per-table sub directories.
    version: str = "0.1"  # manifest schema version
    tables: List[LocalTableDesc]  # every table known to this datastore root


class LocalDataStore:
_instance = None
_lock = threading.Lock()
Expand Down Expand Up @@ -2012,24 +2011,12 @@ def list_tables(
self,
prefixes: List[str],
) -> List[str]:
table_names = set()

for prefix in prefixes:
prefix_path = Path(self.root_path) / prefix.strip("/")
for fpath in prefix_path.rglob(f"*{datastore_table_file_ext}*"):
if not fpath.is_file():
continue

table_name = str(fpath.relative_to(self.root_path)).split(
datastore_table_file_ext
)[0]
table_names.add(table_name)

for table in self.tables:
if table.startswith(prefix):
table_names.add(table)

return list(table_names)
manifest = self._load_manifest()
tianweidut marked this conversation as resolved.
Show resolved Hide resolved
return [
table.name
for table in manifest.tables
if any(table.name.startswith(prefix) for prefix in prefixes)
]

def update_table(
self,
Expand Down Expand Up @@ -2074,25 +2061,59 @@ def update_table(
# revision will never be None or empty (len(records) > 0), makes mypy happy
return revision or ""

def _load_manifest(self) -> LocalDataStoreManifest:
    """Read the datastore manifest from the root directory.

    Returns:
        The parsed manifest, or an empty manifest when the file does not
        exist yet (fresh datastore).
    """
    path = Path(self.root_path) / datastore_manifest_file_name
    if not path.exists():
        return LocalDataStoreManifest(tables=[])
    with path.open() as reader:
        raw = reader.read()
    return LocalDataStoreManifest.parse_raw(raw)

def _dump_manifest(self, manifest: LocalDataStoreManifest) -> None:
    """Atomically persist the datastore manifest.

    The manifest is written to a temporary file first and then moved over
    the real manifest file, so concurrent readers never observe a partially
    written manifest. A file lock serializes writers across processes.

    Arguments:
        manifest: the manifest to serialize as indented JSON.
    """
    root = Path(self.root_path)
    manifest_file = root / datastore_manifest_file_name
    with filelock.FileLock(str(root / ".lock")):
        # Create the temp file in the destination directory: the move then
        # degenerates to an atomic same-filesystem rename. The default system
        # temp dir may be on another filesystem, where shutil.move falls back
        # to a non-atomic copy.
        with tempfile.NamedTemporaryFile(
            mode="w", dir=str(root), delete=False
        ) as tmp:
            tmp.write(manifest.json(indent=2))
            tmp.flush()
        # Move only after the handle is closed (required on Windows).
        shutil.move(tmp.name, manifest_file)

def _get_table(
self, table_name: str, key_column: ColumnSchema | None, create: bool = True
) -> LocalTable | None:
with self.lock:
table = self.tables.get(table_name, None)
if table is None:
try:
table = LocalTable(
table_name=table_name,
root_path=Path(self.root_path) / table_name,
key_column=key_column and key_column.name or None,
create_if_missing=create,
)
self.tables[table_name] = table
except TableEmptyException:
if create:
raise
if table is not None:
return table
# open or create
manifest = self._load_manifest()
table_root: Path | None = None
for t in manifest.tables:
if t.name == table_name:
table_root = Path(self.root_path) / t.dir
break
if table_root is None:
if not create:
return None
return table
uuid = uuid4().hex
table_root = Path(self.root_path) / uuid[:2] / uuid[2:-1]

table_root.mkdir(parents=True, exist_ok=True)
# no try, let it raise
table = LocalTable(
table_name=table_name,
root_path=table_root,
key_column=key_column and key_column.name or None,
create_if_missing=create,
)
manifest.tables.append(
LocalTableDesc(
name=table_name,
dir=str(table_root.relative_to(self.root_path)),
created_at=int(time.time() * 1000),
)
)
self._dump_manifest(manifest)
self.tables[table_name] = table
return table

def scan_tables(
self,
Expand Down
14 changes: 5 additions & 9 deletions client/tests/core/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
import json
import typing as t
from pathlib import Path
from unittest import TestCase
from unittest.mock import patch, MagicMock

import yaml
from click.testing import CliRunner
from requests_mock import Mocker
from pyfakefs.fake_filesystem_unittest import TestCase

from tests import ROOT_DIR
from tests import ROOT_DIR, BaseTestCase
from starwhale.utils import config as sw_config
from starwhale.utils import load_yaml
from starwhale.consts import (
Expand Down Expand Up @@ -45,11 +45,7 @@
_dataset_yaml = open(f"{_dataset_data_dir}/dataset.yaml").read()


class StandaloneDatasetTestCase(TestCase):
def setUp(self) -> None:
self.setUpPyfakefs()
sw_config._config = {}

class StandaloneDatasetTestCase(BaseTestCase):
@patch("starwhale.base.uri.resource.Resource._refine_local_rc_info")
@patch("starwhale.api._impl.dataset.model.Dataset.commit")
@patch("starwhale.api._impl.dataset.model.Dataset.__setitem__")
Expand Down Expand Up @@ -382,7 +378,7 @@ def __iter__(self) -> t.Generator:

sw = SWCliConfigMixed()

workdir = "/home/starwhale/myproject"
workdir = self.local_storage
name = "mnist"

ensure_dir(os.path.join(workdir, "data"))
Expand Down Expand Up @@ -419,7 +415,7 @@ def __iter__(self) -> t.Generator:
assert isinstance(_info, LocalDatasetInfo)
assert _info.version == build_version
assert _info.name == name
assert _info.path == str(snapshot_workdir.resolve())
assert str(Path(_info.path).resolve()) == str(snapshot_workdir.resolve())

tags = sd.tag.list()
assert set(tags) == {"t0", "t1", "latest", "v0"}
Expand Down
36 changes: 7 additions & 29 deletions client/tests/sdk/test_data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

from tests import BaseTestCase
from starwhale.consts import HTTPMethod
from starwhale.utils.fs import ensure_dir, ensure_file
from starwhale.api._impl import data_store
from starwhale.api._impl.data_store import (
INT32,
Expand All @@ -32,20 +31,6 @@


class TestBasicFunctions(BaseTestCase):
def test_get_table_path(self) -> None:
self.assertEqual(
Path("a") / "b.sw-datastore",
data_store._get_table_path("a", "b"),
)
self.assertEqual(
Path("a") / "b" / "c.sw-datastore",
data_store._get_table_path("a", "b/c"),
)
self.assertEqual(
Path("a") / "b" / "c" / "d.sw-datastore",
data_store._get_table_path("a", "b/c/d"),
)

def test_merge_scan(self) -> None:
self.assertEqual([], list(data_store._merge_scan([], False)), "no iter")
self.assertEqual(
Expand Down Expand Up @@ -1160,25 +1145,19 @@ def ds_update(index: int) -> bool:
def test_list_tables(self) -> None:
ds = data_store.LocalDataStore(self.datastore_root)

prefix = "project/self/eval/test-0/"
prefix = "project/self/eval/test-0"

tables = ds.list_tables([prefix])
assert tables == []

root = Path(self.datastore_root) / prefix.strip("/")
schema = data_store.TableSchema(
"k", [data_store.ColumnSchema("k", data_store.INT64)]
)
ds.update_table(f"{prefix}/labels", schema, [{"k": 0}])
ds.update_table(f"{prefix}/results", schema, [{"k": 0}])

ensure_file(root / "labels.sw-datastore.zip", "abc", parents=True)
ensure_file(root / "results.sw-datastore.zip", "abc", parents=True)
for i in range(0, 3):
ensure_file(root / "roc" / f"{i}.sw-datastore.zip", "abc", parents=True)

ensure_dir(root / "mock-dir.sw-datastore.zip")
ensure_file(root / "dummy.file", "abc", parents=True)

m_table_name = f"{prefix}memory-test-table"
ds.tables[m_table_name] = data_store.LocalTable(
m_table_name, self.datastore_root, "k"
)
ds.update_table(f"{prefix}/roc/{i}", schema, [{"k": 0}])

tables = ds.list_tables([prefix])
assert set(tables) == {
Expand All @@ -1189,7 +1168,6 @@ def test_list_tables(self) -> None:
"roc/0",
"roc/1",
"roc/2",
"memory-test-table",
}
}

Expand Down
18 changes: 9 additions & 9 deletions client/tests/sdk/test_track.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from starwhale.utils.process import check_call
from starwhale.base.data_type import Link, Audio, Image
from starwhale.base.uri.project import Project
from starwhale.api._impl.data_store import TableDesc, TableWriter
from starwhale.api._impl.data_store import TableDesc, TableWriter, LocalDataStore
from starwhale.api._impl.track.base import (
_TrackType,
TrackRecord,
Expand Down Expand Up @@ -456,9 +456,9 @@ def test_handle_metrics(self) -> None:
assert isinstance(h._table_writers["metrics/user"], TableWriter)

h.flush()
datastore_file_path = workdir / "metrics" / "user" / "manifest.json"
assert datastore_file_path.exists()
assert datastore_file_path.is_file()

ds = LocalDataStore(str(workdir))
assert ds.list_tables(["m"]) == ["metrics/user"]

records = list(h._data_store.scan_tables([TableDesc("metrics/user")]))
assert len(records) == 2
Expand Down Expand Up @@ -512,9 +512,8 @@ def test_handle_artifacts(self) -> None:

h.flush()

datastore_file_path = workdir / "artifacts" / "user" / "manifest.json"
assert datastore_file_path.exists()
assert datastore_file_path.is_file()
ds = LocalDataStore(str(workdir))
assert ds.list_tables(["a"]) == ["artifacts/user"]

files_dir = workdir / "artifacts" / "_files"
assert files_dir.exists()
Expand Down Expand Up @@ -566,8 +565,9 @@ def test_run(self) -> None:
assert "metrics/_system" in h._table_writers
assert "artifacts/user" in h._table_writers

assert (workdir / "metrics" / "user" / "manifest.json").exists()
assert (workdir / "metrics" / "_system" / "manifest.json").exists()
ds = LocalDataStore(str(workdir))
assert ds.list_tables(["m"]) == ["metrics/user", "metrics/_system"]

assert (workdir / "artifacts" / "_files").exists()
assert len(list((workdir / "artifacts" / "_files").iterdir())) != 0
assert (workdir / "params" / "user.json").exists()
Expand Down
Loading