Skip to content

Commit

Permalink
refactor(datastore): use flat structure to avoid too many sub directories (#2809)
Browse files Browse the repository at this point in the history
  • Loading branch information
jialeicui authored Oct 8, 2023
1 parent 824ae66 commit 731773c
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 91 deletions.
109 changes: 65 additions & 44 deletions client/starwhale/api/_impl/data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import contextlib
from abc import ABCMeta, abstractmethod
from http import HTTPStatus
from uuid import uuid4
from typing import (
Any,
cast,
Expand Down Expand Up @@ -896,19 +897,6 @@ def __eq__(self, other: Any) -> bool:
)


def _get_table_path(root_path: str | Path, table_name: str) -> Path:
    """Resolve the on-disk file path backing a datastore table.

    Globs under ``root_path`` for files whose name starts with
    ``<table_name><datastore_table_file_ext>``.

    Returns:
        The single matching file when exactly one exists, otherwise the
        expected (possibly not-yet-existing) path so callers can create it.

    Raises:
        RuntimeError: when several files match the prefix — the table's
            backing file would be ambiguous.
    """
    expect_prefix = Path(root_path) / (table_name.strip("/") + datastore_table_file_ext)
    paths = list(expect_prefix.parent.glob(f"{expect_prefix.name}*"))
    if len(paths) > 1:
        # more than one candidate: we can not tell which file backs the table
        raise RuntimeError(
            f"ambiguous table {table_name}, multiple files match: {paths}"
        )
    if len(paths) == 1:
        return paths[0]
    return expect_prefix


def _merge_scan(
iters: List[Iterator[Dict[str, Any]]], keep_none: bool
) -> Iterator[dict]:
Expand Down Expand Up @@ -1983,6 +1971,17 @@ def to_dict(self) -> Dict[str, Any]:
return ret


class LocalTableDesc(SwBaseModel):
    """Manifest entry describing where one local table lives on disk."""

    # logical table name used by callers of the datastore API
    name: str
    # directory holding the table's data, relative to the datastore root
    dir: str
    created_at: int  # timestamp in milliseconds


class LocalDataStoreManifest(SwBaseModel):
    """Top-level manifest listing every table known to a local datastore."""

    # manifest schema version
    version: str = "0.1"
    # one descriptor per table; the flat name->dir mapping replaces the
    # previous convention of deriving paths from the table name
    tables: List[LocalTableDesc]


class LocalDataStore:
_instance = None
_lock = threading.Lock()
Expand Down Expand Up @@ -2012,24 +2011,12 @@ def list_tables(
self,
prefixes: List[str],
) -> List[str]:
table_names = set()

for prefix in prefixes:
prefix_path = Path(self.root_path) / prefix.strip("/")
for fpath in prefix_path.rglob(f"*{datastore_table_file_ext}*"):
if not fpath.is_file():
continue

table_name = str(fpath.relative_to(self.root_path)).split(
datastore_table_file_ext
)[0]
table_names.add(table_name)

for table in self.tables:
if table.startswith(prefix):
table_names.add(table)

return list(table_names)
manifest = self._load_manifest()
return [
table.name
for table in manifest.tables
if any(table.name.startswith(prefix) for prefix in prefixes)
]

def update_table(
self,
Expand Down Expand Up @@ -2074,25 +2061,59 @@ def update_table(
# revision will never be None or empty (len(records) > 0), makes mypy happy
return revision or ""

def _load_manifest(self) -> LocalDataStoreManifest:
    """Read the datastore manifest from disk, or an empty one if absent."""
    path = Path(self.root_path) / datastore_manifest_file_name
    if not path.exists():
        # no manifest yet: a fresh datastore with no registered tables
        return LocalDataStoreManifest(tables=[])
    return LocalDataStoreManifest.parse_raw(path.read_text())

def _dump_manifest(self, manifest: LocalDataStoreManifest) -> None:
    """Persist the manifest atomically under a cross-process file lock.

    Writes to a temporary file first, then moves it over the manifest so
    concurrent readers never observe a partially written file.
    """
    root = Path(self.root_path)
    manifest_file = root / datastore_manifest_file_name
    with filelock.FileLock(str(root / ".lock")):
        # Create the temp file in the destination directory: a same-filesystem
        # move is an atomic rename, while the default system temp dir may live
        # on another filesystem where shutil.move degrades to copy+delete.
        with tempfile.NamedTemporaryFile(mode="w", dir=root, delete=False) as tmp:
            tmp.write(manifest.json(indent=2))
            tmp.flush()
        # move after the handle is closed; moving an open file fails on Windows
        shutil.move(tmp.name, manifest_file)

def _get_table(
self, table_name: str, key_column: ColumnSchema | None, create: bool = True
) -> LocalTable | None:
with self.lock:
table = self.tables.get(table_name, None)
if table is None:
try:
table = LocalTable(
table_name=table_name,
root_path=Path(self.root_path) / table_name,
key_column=key_column and key_column.name or None,
create_if_missing=create,
)
self.tables[table_name] = table
except TableEmptyException:
if create:
raise
if table is not None:
return table
# open or create
manifest = self._load_manifest()
table_root: Path | None = None
for t in manifest.tables:
if t.name == table_name:
table_root = Path(self.root_path) / t.dir
break
if table_root is None:
if not create:
return None
return table
uuid = uuid4().hex
table_root = Path(self.root_path) / uuid[:2] / uuid[2:-1]

table_root.mkdir(parents=True, exist_ok=True)
# no try, let it raise
table = LocalTable(
table_name=table_name,
root_path=table_root,
key_column=key_column and key_column.name or None,
create_if_missing=create,
)
manifest.tables.append(
LocalTableDesc(
name=table_name,
dir=str(table_root.relative_to(self.root_path)),
created_at=int(time.time() * 1000),
)
)
self._dump_manifest(manifest)
self.tables[table_name] = table
return table

def scan_tables(
self,
Expand Down
14 changes: 5 additions & 9 deletions client/tests/core/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
import json
import typing as t
from pathlib import Path
from unittest import TestCase
from unittest.mock import patch, MagicMock

import yaml
from click.testing import CliRunner
from requests_mock import Mocker
from pyfakefs.fake_filesystem_unittest import TestCase

from tests import ROOT_DIR
from tests import ROOT_DIR, BaseTestCase
from starwhale.utils import config as sw_config
from starwhale.utils import load_yaml
from starwhale.consts import (
Expand Down Expand Up @@ -45,11 +45,7 @@
_dataset_yaml = open(f"{_dataset_data_dir}/dataset.yaml").read()


class StandaloneDatasetTestCase(TestCase):
def setUp(self) -> None:
self.setUpPyfakefs()
sw_config._config = {}

class StandaloneDatasetTestCase(BaseTestCase):
@patch("starwhale.base.uri.resource.Resource._refine_local_rc_info")
@patch("starwhale.api._impl.dataset.model.Dataset.commit")
@patch("starwhale.api._impl.dataset.model.Dataset.__setitem__")
Expand Down Expand Up @@ -382,7 +378,7 @@ def __iter__(self) -> t.Generator:

sw = SWCliConfigMixed()

workdir = "/home/starwhale/myproject"
workdir = self.local_storage
name = "mnist"

ensure_dir(os.path.join(workdir, "data"))
Expand Down Expand Up @@ -419,7 +415,7 @@ def __iter__(self) -> t.Generator:
assert isinstance(_info, LocalDatasetInfo)
assert _info.version == build_version
assert _info.name == name
assert _info.path == str(snapshot_workdir.resolve())
assert str(Path(_info.path).resolve()) == str(snapshot_workdir.resolve())

tags = sd.tag.list()
assert set(tags) == {"t0", "t1", "latest", "v0"}
Expand Down
36 changes: 7 additions & 29 deletions client/tests/sdk/test_data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

from tests import BaseTestCase
from starwhale.consts import HTTPMethod
from starwhale.utils.fs import ensure_dir, ensure_file
from starwhale.api._impl import data_store
from starwhale.api._impl.data_store import (
INT32,
Expand All @@ -32,20 +31,6 @@


class TestBasicFunctions(BaseTestCase):
def test_get_table_path(self) -> None:
    # _get_table_path joins the root path with the table name and appends
    # the ".sw-datastore" extension; slash-separated table names map to
    # nested directories under the root.
    self.assertEqual(
        Path("a") / "b.sw-datastore",
        data_store._get_table_path("a", "b"),
    )
    self.assertEqual(
        Path("a") / "b" / "c.sw-datastore",
        data_store._get_table_path("a", "b/c"),
    )
    self.assertEqual(
        Path("a") / "b" / "c" / "d.sw-datastore",
        data_store._get_table_path("a", "b/c/d"),
    )

def test_merge_scan(self) -> None:
self.assertEqual([], list(data_store._merge_scan([], False)), "no iter")
self.assertEqual(
Expand Down Expand Up @@ -1160,25 +1145,19 @@ def ds_update(index: int) -> bool:
def test_list_tables(self) -> None:
ds = data_store.LocalDataStore(self.datastore_root)

prefix = "project/self/eval/test-0/"
prefix = "project/self/eval/test-0"

tables = ds.list_tables([prefix])
assert tables == []

root = Path(self.datastore_root) / prefix.strip("/")
schema = data_store.TableSchema(
"k", [data_store.ColumnSchema("k", data_store.INT64)]
)
ds.update_table(f"{prefix}/labels", schema, [{"k": 0}])
ds.update_table(f"{prefix}/results", schema, [{"k": 0}])

ensure_file(root / "labels.sw-datastore.zip", "abc", parents=True)
ensure_file(root / "results.sw-datastore.zip", "abc", parents=True)
for i in range(0, 3):
ensure_file(root / "roc" / f"{i}.sw-datastore.zip", "abc", parents=True)

ensure_dir(root / "mock-dir.sw-datastore.zip")
ensure_file(root / "dummy.file", "abc", parents=True)

m_table_name = f"{prefix}memory-test-table"
ds.tables[m_table_name] = data_store.LocalTable(
m_table_name, self.datastore_root, "k"
)
ds.update_table(f"{prefix}/roc/{i}", schema, [{"k": 0}])

tables = ds.list_tables([prefix])
assert set(tables) == {
Expand All @@ -1189,7 +1168,6 @@ def test_list_tables(self) -> None:
"roc/0",
"roc/1",
"roc/2",
"memory-test-table",
}
}

Expand Down
18 changes: 9 additions & 9 deletions client/tests/sdk/test_track.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from starwhale.utils.process import check_call
from starwhale.base.data_type import Link, Audio, Image
from starwhale.base.uri.project import Project
from starwhale.api._impl.data_store import TableDesc, TableWriter
from starwhale.api._impl.data_store import TableDesc, TableWriter, LocalDataStore
from starwhale.api._impl.track.base import (
_TrackType,
TrackRecord,
Expand Down Expand Up @@ -456,9 +456,9 @@ def test_handle_metrics(self) -> None:
assert isinstance(h._table_writers["metrics/user"], TableWriter)

h.flush()
datastore_file_path = workdir / "metrics" / "user" / "manifest.json"
assert datastore_file_path.exists()
assert datastore_file_path.is_file()

ds = LocalDataStore(str(workdir))
assert ds.list_tables(["m"]) == ["metrics/user"]

records = list(h._data_store.scan_tables([TableDesc("metrics/user")]))
assert len(records) == 2
Expand Down Expand Up @@ -512,9 +512,8 @@ def test_handle_artifacts(self) -> None:

h.flush()

datastore_file_path = workdir / "artifacts" / "user" / "manifest.json"
assert datastore_file_path.exists()
assert datastore_file_path.is_file()
ds = LocalDataStore(str(workdir))
assert ds.list_tables(["a"]) == ["artifacts/user"]

files_dir = workdir / "artifacts" / "_files"
assert files_dir.exists()
Expand Down Expand Up @@ -566,8 +565,9 @@ def test_run(self) -> None:
assert "metrics/_system" in h._table_writers
assert "artifacts/user" in h._table_writers

assert (workdir / "metrics" / "user" / "manifest.json").exists()
assert (workdir / "metrics" / "_system" / "manifest.json").exists()
ds = LocalDataStore(str(workdir))
assert ds.list_tables(["m"]) == ["metrics/user", "metrics/_system"]

assert (workdir / "artifacts" / "_files").exists()
assert len(list((workdir / "artifacts" / "_files").iterdir())) != 0
assert (workdir / "params" / "user.json").exists()
Expand Down

0 comments on commit 731773c

Please sign in to comment.