refactor(datastore): use flat structure to avoid too many sub directories
jialeicui committed Sep 27, 2023
1 parent 3977573 commit 5867faa
Showing 3 changed files with 82 additions and 83 deletions.
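
In brief: before this change, a table name such as project/self/eval/test-0/labels was mapped directly onto a nested path under the datastore root, so deeply nested names meant deeply nested directory trees. Now each table gets a short uuid-derived directory, and a manifest records the name-to-directory mapping. A minimal sketch of the directory derivation, mirroring the uuid[:2] / uuid[2:-1] split used in _get_table below (the root path is a made-up example):

from pathlib import Path
from uuid import uuid4

root = Path("/tmp/sw-datastore")  # hypothetical datastore root

# before: one sub directory per path segment of the table name
old_table_dir = root / "project/self/eval/test-0/labels"

# after: a flat, two-level directory derived from a random uuid;
# the name-to-directory mapping lives in a manifest instead
uuid = uuid4().hex
new_table_dir = root / uuid[:2] / uuid[2:-1]

print(old_table_dir)  # .../project/self/eval/test-0/labels
print(new_table_dir)  # e.g. .../5f/<rest of the uuid>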
110 changes: 66 additions & 44 deletions client/starwhale/api/_impl/data_store.py
@@ -21,13 +21,15 @@
import contextlib
from abc import ABCMeta, abstractmethod
from http import HTTPStatus
+from uuid import uuid4
from typing import Any, cast, Dict, List, Type, Tuple, Union, Iterator, Optional
from pathlib import Path
from collections import UserDict, OrderedDict

import dill
import numpy as np
import pyarrow as pa # type: ignore
import filelock
import requests
import tenacity
import jsonlines
@@ -883,19 +885,6 @@ def __eq__(self, other: Any) -> bool:
)


-def _get_table_path(root_path: str | Path, table_name: str) -> Path:
-    """
-    get table path from table name, return the matched file path if there is only one file match the table name
-    """
-    expect_prefix = Path(root_path) / (table_name.strip("/") + datastore_table_file_ext)
-    paths = list(expect_prefix.parent.glob(f"{expect_prefix.name}*"))
-    if len(paths) > 1:
-        raise RuntimeError(f"can not find table {table_name}, get files {paths}")
-    if len(paths) == 1:
-        return paths[0]
-    return expect_prefix
-
-
def _merge_scan(
iters: List[Iterator[Dict[str, Any]]], keep_none: bool
) -> Iterator[dict]:
@@ -1851,6 +1840,17 @@ def to_dict(self) -> Dict[str, Any]:
return ret


+class LocalTableDesc(SwBaseModel):
+    name: str
+    dir: str
+    created_at: int  # timestamp in milliseconds
+
+
+class LocalDataStoreManifest(SwBaseModel):
+    version: str = "0.1"
+    tables: List[LocalTableDesc]
+
+
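The two models above are the entire manifest format. Assuming SwBaseModel behaves like a pydantic v1 BaseModel (the parse_raw() and json(indent=2) calls in _load_manifest/_dump_manifest below suggest it), a written manifest would look roughly like this sketch; the stand-in classes, table name, directory, and timestamp are illustrative:

# stand-in models, assuming SwBaseModel is pydantic v1 compatible
from typing import List
from pydantic import BaseModel


class LocalTableDesc(BaseModel):
    name: str
    dir: str
    created_at: int  # timestamp in milliseconds


class LocalDataStoreManifest(BaseModel):
    version: str = "0.1"
    tables: List[LocalTableDesc]


manifest = LocalDataStoreManifest(
    tables=[
        LocalTableDesc(
            name="project/self/eval/test-0/labels",
            dir="5f/3c9a7d2ebc",  # uuid-derived directory, illustrative
            created_at=1695800000000,
        )
    ]
)
print(manifest.json(indent=2))
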
class LocalDataStore:
_instance = None
_lock = threading.Lock()
@@ -1880,24 +1880,12 @@ def list_tables(
self,
prefixes: List[str],
) -> List[str]:
-        table_names = set()
-
-        for prefix in prefixes:
-            prefix_path = Path(self.root_path) / prefix.strip("/")
-            for fpath in prefix_path.rglob(f"*{datastore_table_file_ext}*"):
-                if not fpath.is_file():
-                    continue
-
-                table_name = str(fpath.relative_to(self.root_path)).split(
-                    datastore_table_file_ext
-                )[0]
-                table_names.add(table_name)
-
-        for table in self.tables:
-            if table.startswith(prefix):
-                table_names.add(table)
-
-        return list(table_names)
+        manifest = self._load_manifest()
+        return [
+            table.name
+            for table in manifest.tables
+            if any(table.name.startswith(prefix) for prefix in prefixes)
+        ]
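
With the manifest as the single source of truth, list_tables becomes a plain prefix filter over the recorded names instead of a filesystem walk. A small usage sketch, with an illustrative root path and table name:

from starwhale.api._impl import data_store

ds = data_store.LocalDataStore("/tmp/sw-datastore")  # hypothetical root
schema = data_store.TableSchema(
    "k", [data_store.ColumnSchema("k", data_store.INT64)]
)
ds.update_table("project/self/eval/test-0/labels", schema, [{"k": 0}])

assert ds.list_tables(["project/self/eval/test-0"]) == [
    "project/self/eval/test-0/labels"
]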

def update_table(
self,
@@ -1942,25 +1930,59 @@ def update_table
# revision will never be None or empty (len(records) > 0), makes mypy happy
return revision or ""

+    def _load_manifest(self) -> LocalDataStoreManifest:
+        manifest_file = Path(self.root_path) / datastore_manifest_file_name
+        if manifest_file.exists():
+            with manifest_file.open() as f:
+                return LocalDataStoreManifest.parse_raw(f.read())
+        return LocalDataStoreManifest(tables=[])
+
+    def _dump_manifest(self, manifest: LocalDataStoreManifest) -> None:
+        manifest_file = Path(self.root_path) / datastore_manifest_file_name
+        with filelock.FileLock(str(Path(self.root_path) / ".lock")):
+            with tempfile.NamedTemporaryFile(mode="w", delete=False) as tmp:
+                tmp.write(manifest.json(indent=2))
+                tmp.flush()
+                shutil.move(tmp.name, manifest_file)

    def _get_table(
        self, table_name: str, key_column: ColumnSchema | None, create: bool = True
    ) -> LocalTable | None:
        with self.lock:
            table = self.tables.get(table_name, None)
-            if table is None:
-                try:
-                    table = LocalTable(
-                        table_name=table_name,
-                        root_path=Path(self.root_path) / table_name,
-                        key_column=key_column and key_column.name or None,
-                        create_if_missing=create,
-                    )
-                    self.tables[table_name] = table
-                except TableEmptyException:
-                    if create:
-                        raise
+            if table is not None:
+                return table
+            # open or create
+            manifest = self._load_manifest()
+            table_root: Path | None = None
+            for t in manifest.tables:
+                if t.name == table_name:
+                    table_root = Path(self.root_path) / t.dir
+                    break
+            if table_root is None:
+                if not create:
+                    return None
-            return table
+                uuid = uuid4().hex
+                table_root = Path(self.root_path) / uuid[:2] / uuid[2:-1]
+
+            table_root.mkdir(parents=True, exist_ok=True)
+            # no try, let it raise
+            table = LocalTable(
+                table_name=table_name,
+                root_path=table_root,
+                key_column=key_column and key_column.name or None,
+                create_if_missing=create,
+            )
+            manifest.tables.append(
+                LocalTableDesc(
+                    name=table_name,
+                    dir=str(table_root.relative_to(self.root_path)),
+                    created_at=int(time.time() * 1000),
+                )
+            )
+            self._dump_manifest(manifest)
+            self.tables[table_name] = table
+            return table
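
Net effect: _get_table resolves a name through the manifest, or mints a new uuid-based directory and records it, so the depth of a table name no longer dictates directory depth; the manifest itself is rewritten via a temp file and a filelock, so a crash or a concurrent writer cannot leave it half-written. A sketch of the end-to-end behavior under an assumed root path and table name:

from pathlib import Path

from starwhale.api._impl import data_store

root = "/tmp/sw-datastore"  # hypothetical root directory
ds = data_store.LocalDataStore(root)
schema = data_store.TableSchema(
    "k", [data_store.ColumnSchema("k", data_store.INT64)]
)

# a deeply nested table name...
ds.update_table("a/b/c/d/e/metrics", schema, [{"k": 0}])

# ...no longer turns into an a/b/c/d/e/ directory chain on disk; only the
# datastore manifest (and its lock file) plus one flat <xx>/<uuid-rest>/
# directory per table appear under the root
print(sorted(str(p.relative_to(root)) for p in Path(root).rglob("*")))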

def scan_tables(
self,
37 changes: 7 additions & 30 deletions client/tests/sdk/test_data_store.py
@@ -3,7 +3,6 @@
import time
import unittest
import concurrent.futures
-from pathlib import Path
from unittest.mock import Mock, patch, MagicMock

import numpy as np
@@ -13,7 +12,6 @@

from tests import BaseTestCase
from starwhale.consts import HTTPMethod
-from starwhale.utils.fs import ensure_dir, ensure_file
from starwhale.api._impl import data_store
from starwhale.api._impl.data_store import (
INT32,
@@ -31,20 +29,6 @@


class TestBasicFunctions(BaseTestCase):
-    def test_get_table_path(self) -> None:
-        self.assertEqual(
-            Path("a") / "b.sw-datastore",
-            data_store._get_table_path("a", "b"),
-        )
-        self.assertEqual(
-            Path("a") / "b" / "c.sw-datastore",
-            data_store._get_table_path("a", "b/c"),
-        )
-        self.assertEqual(
-            Path("a") / "b" / "c" / "d.sw-datastore",
-            data_store._get_table_path("a", "b/c/d"),
-        )
-
def test_merge_scan(self) -> None:
self.assertEqual([], list(data_store._merge_scan([], False)), "no iter")
self.assertEqual(
@@ -1159,25 +1143,19 @@ def ds_update(index: int) -> bool:
def test_list_tables(self) -> None:
ds = data_store.LocalDataStore(self.datastore_root)

-        prefix = "project/self/eval/test-0/"
+        prefix = "project/self/eval/test-0"

        tables = ds.list_tables([prefix])
        assert tables == []

-        root = Path(self.datastore_root) / prefix.strip("/")
+        schema = data_store.TableSchema(
+            "k", [data_store.ColumnSchema("k", data_store.INT64)]
+        )
+        ds.update_table(f"{prefix}/labels", schema, [{"k": 0}])
+        ds.update_table(f"{prefix}/results", schema, [{"k": 0}])

-        ensure_file(root / "labels.sw-datastore.zip", "abc", parents=True)
-        ensure_file(root / "results.sw-datastore.zip", "abc", parents=True)
        for i in range(0, 3):
-            ensure_file(root / "roc" / f"{i}.sw-datastore.zip", "abc", parents=True)
-
-        ensure_dir(root / "mock-dir.sw-datastore.zip")
-        ensure_file(root / "dummy.file", "abc", parents=True)
-
-        m_table_name = f"{prefix}memory-test-table"
-        ds.tables[m_table_name] = data_store.LocalTable(
-            m_table_name, self.datastore_root, "k"
-        )
+            ds.update_table(f"{prefix}/roc/{i}", schema, [{"k": 0}])

tables = ds.list_tables([prefix])
assert set(tables) == {
@@ -1188,7 +1166,6 @@ def test_list_tables(self) -> None:
"roc/0",
"roc/1",
"roc/2",
"memory-test-table",
}
}

18 changes: 9 additions & 9 deletions client/tests/sdk/test_track.py
@@ -19,7 +19,7 @@
from starwhale.utils.process import check_call
from starwhale.base.data_type import Link, Audio, Image
from starwhale.base.uri.project import Project
-from starwhale.api._impl.data_store import TableDesc, TableWriter
+from starwhale.api._impl.data_store import TableDesc, TableWriter, LocalDataStore
from starwhale.api._impl.track.base import (
_TrackType,
TrackRecord,
@@ -451,9 +451,9 @@ def test_handle_metrics(self) -> None:
assert isinstance(h._table_writers["metrics/user"], TableWriter)

h.flush()
-        datastore_file_path = workdir / "metrics" / "user" / "manifest.json"
-        assert datastore_file_path.exists()
-        assert datastore_file_path.is_file()
+
+        ds = LocalDataStore(str(workdir))
+        assert ds.list_tables(["m"]) == ["metrics/user"]

records = list(h._data_store.scan_tables([TableDesc("metrics/user")]))
assert len(records) == 2
@@ -507,9 +507,8 @@ def test_handle_artifacts(self) -> None:

h.flush()

-        datastore_file_path = workdir / "artifacts" / "user" / "manifest.json"
-        assert datastore_file_path.exists()
-        assert datastore_file_path.is_file()
+        ds = LocalDataStore(str(workdir))
+        assert ds.list_tables(["a"]) == ["artifacts/user"]

files_dir = workdir / "artifacts" / "_files"
assert files_dir.exists()
@@ -561,8 +560,9 @@ def test_run(self) -> None:
assert "metrics/_system" in h._table_writers
assert "artifacts/user" in h._table_writers

-        assert (workdir / "metrics" / "user" / "manifest.json").exists()
-        assert (workdir / "metrics" / "_system" / "manifest.json").exists()
+        ds = LocalDataStore(str(workdir))
+        assert ds.list_tables(["m"]) == ["metrics/user", "metrics/_system"]

assert (workdir / "artifacts" / "_files").exists()
assert len(list((workdir / "artifacts" / "_files").iterdir())) != 0
assert (workdir / "params" / "user.json").exists()
