Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(client): make job info --web works #3052

Merged
merged 1 commit into from
Dec 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 2 additions & 15 deletions client/starwhale/web/data_store.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
import glob
import typing as t
import os.path

from fastapi import APIRouter
from pydantic import Field, BaseModel

from starwhale.utils.config import SWCliConfigMixed
from starwhale.web.response import success, SuccessResp
from starwhale.base.uri.project import Project
from starwhale.api._impl.data_store import SwType, _get_type, TableDesc, LocalDataStore
Expand Down Expand Up @@ -38,18 +35,8 @@ class Config:

@router.post("/listTables")
def list_tables(request: ListTablesRequest) -> SuccessResp:
# TODO: use datastore builtin function
root = str(SWCliConfigMixed().datastore_dir)
path = os.path.join(root, request.prefix)
files = glob.glob(f"{path}**", recursive=True)
tables = []
for f in files:
if not os.path.isfile(f):
continue
p, file = os.path.split(f)
p = p[len(root) :].lstrip("/")
table_name = file.split(".sw-datastore.zip")[0]
tables.append(f"{p}/{table_name}")
ds = LocalDataStore.get_instance()
tables = ds.list_tables([request.prefix])
return success({"tables": tables})


Expand Down
2 changes: 1 addition & 1 deletion client/starwhale/web/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ async def proxy(request: Request) -> Response:
)
resp = await client.send(req, stream=True)
return StreamingResponse(
resp.aiter_bytes(),
resp.aiter_raw(), # use raw to support gzipped response
status_code=resp.status_code,
headers=resp.headers,
background=BackgroundTask(resp.aclose),
Expand Down
28 changes: 23 additions & 5 deletions client/tests/web/test_server.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
import io
import os
import sys
import gzip
from typing import AsyncIterable
from asyncio import Future
from pathlib import Path
from unittest.mock import patch, MagicMock, PropertyMock

import httpx
from httpx._content import AsyncIteratorByteStream
from fastapi.testclient import TestClient

from starwhale import Link
from starwhale.utils.fs import ensure_file
from starwhale.web.server import Server
from starwhale.base.uri.instance import Instance
from starwhale.api._impl.data_store import TableWriter, LocalDataStore


def test_static_faked_response():
Expand Down Expand Up @@ -61,12 +65,14 @@ def test_datastore_list_tables(root: MagicMock, tmpdir: Path):
assert resp.status_code == 200
assert resp.json()["data"]["tables"] == []

ensure_file(tmpdir / "a" / "foo.sw-datastore.zip", b"", parents=True)
ensure_file(tmpdir / "b" / "c" / "foo.sw-datastore.zip", b"", parents=True)
ds = LocalDataStore.get_instance()
tw = TableWriter("foo", data_store=ds)
tw.insert({"id": 1})
tw.close()

resp = client.post("/api/v1/datastore/listTables", content='{"prefix": ""}')
assert resp.status_code == 200
assert set(resp.json()["data"]["tables"]) == {"a/foo", "b/c/foo"}
assert set(resp.json()["data"]["tables"]) == {"foo"}


@patch("starwhale.api._impl.data_store.LocalDataStore.scan_tables")
Expand Down Expand Up @@ -275,7 +281,19 @@ def test_proxy(m_send: MagicMock, m_ac: MagicMock, m_sw_config: MagicMock):
},
}

resp = httpx.Response(200, json={"data": "ok"})
compressed = io.BytesIO()
with gzip.GzipFile(fileobj=compressed, mode="wb") as f:
f.write(b'{ "data": "ok" }')

async def async_generator(data: bytes) -> AsyncIterable[bytes]:
for byte in data:
yield byte.to_bytes(1, "little")

resp = httpx.Response(
200,
stream=AsyncIteratorByteStream(async_generator(compressed.getvalue())),
headers={"Content-Type": "application/json", "Content-Encoding": "gzip"},
)

if sys.version_info >= (3, 8):
m_send.return_value = resp
Expand Down
Loading