Skip to content

Commit

Permalink
修改单测以适配新的 V2 Dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
Dobiichi-Origami committed Nov 4, 2024
1 parent 34dfb20 commit aeab72a
Show file tree
Hide file tree
Showing 5 changed files with 385 additions and 92 deletions.
20 changes: 14 additions & 6 deletions python/qianfan/dataset/data_source/baidu_qianfan.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import json
import os
import re
import uuid
import zipfile
from typing import Any, Dict, Optional, Tuple
Expand Down Expand Up @@ -91,7 +92,13 @@ def _get_transmission_bos_info(
storage_region = sup_storage_region
elif self.storage_type == V2Consts.StorageType.Bos:
assert self.storage_region
assert self.storage_path
storage_region = self.storage_region
match_result = re.search(r"^bos://(.*?)/(.*)/$", self.storage_path)
if match_result is None:
raise ValueError("no bos bucket and path found")
groups = match_result.groups()
storage_id, storage_path = groups[0], groups[1]
elif self.storage_type == V2Consts.StorageType.SysStorage:
err_msg = "don't support upload dataset to dataset which use platform bos"
log_error(err_msg)
Expand Down Expand Up @@ -176,11 +183,12 @@ def save(
V2Consts.DatasetFormat.PromptImageResponse,
]

ak, sk = self._get_console_ak_and_sk()

# 获取存储信息和鉴权信息
storage_id, storage_path, storage_region = self._get_transmission_bos_info(
sup_storage_id, sup_storage_path, sup_storage_region
)
ak, sk = self._get_console_ak_and_sk()

# 构造本地和远端的路径
if not should_save_as_zip_file:
Expand Down Expand Up @@ -471,7 +479,6 @@ def _create_bare_dataset(
name=name,
version=qianfan_resp["versionNumber"],
storage_type=storage_type,
storage_path=qianfan_resp["storagePath"],
info=(
{**qianfan_resp, **addition_info} if addition_info else {**qianfan_resp}
),
Expand Down Expand Up @@ -701,11 +708,12 @@ def get_existed_dataset(

def create_new_version(self) -> "QianfanDataSource":
qianfan_resp = Data.V2.create_dataset_version(self.group_id)
result = qianfan_resp["result"]
dataset = QianfanDataSource(
id=qianfan_resp["versionId"],
group_id=qianfan_resp["datasetId"],
name=qianfan_resp["datasetName"],
version=qianfan_resp["versionNumber"],
id=result["versionId"],
group_id=result["datasetId"],
name=result["datasetName"],
version=result["versionNumber"],
data_format_type=self.data_format_type,
storage_type=self.storage_type,
storage_path=self.storage_path,
Expand Down
76 changes: 28 additions & 48 deletions python/qianfan/tests/dataset/data_source_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,7 @@
QianfanLocalCacheDir,
)
from qianfan.dataset.data_source import FileDataSource, FormatType, QianfanDataSource
from qianfan.resources.console.consts import (
DataProjectType,
DataSetType,
DataStorageType,
DataTemplateType,
)
from qianfan.resources.console.consts import V2 as V2Consts


def _clean_func():
Expand Down Expand Up @@ -159,56 +154,43 @@ def test_save_as_folder():
def test_create_bare_qianfan_data_source():
datasource_1 = QianfanDataSource.create_bare_dataset(
"name",
DataTemplateType.NonSortedConversation,
DataStorageType.PublicBos,
V2Consts.DatasetFormat.PromptResponse,
V2Consts.StorageType.SysStorage,
)

assert datasource_1.template_type == DataTemplateType.NonSortedConversation
assert datasource_1.project_type == DataProjectType.Conversation
assert datasource_1.set_type == DataSetType.TextOnly
assert datasource_1.data_format_type == V2Consts.DatasetFormat.PromptResponse

datasource_2 = QianfanDataSource.create_bare_dataset(
"name",
DataTemplateType.Text2Image,
DataStorageType.PrivateBos,
storage_id="a",
storage_path="b",
V2Consts.DatasetFormat.PromptImage,
V2Consts.StorageType.Bos,
storage_path="bos://a/b/",
)

assert datasource_2.template_type == DataTemplateType.Text2Image
assert datasource_2.project_type == DataProjectType.Text2Image
assert datasource_2.set_type == DataSetType.MultiModel
assert (
datasource_2.storage_path
== "/easydata/_system_/dataset/ds-z07hkq2kyvsmrmdw/texts"
)
assert datasource_2.storage_id == "a"
assert datasource_2.data_format_type == V2Consts.DatasetFormat.PromptImage
assert datasource_2.storage_path == "bos://a/b/"
assert datasource_2.storage_region == "bj"
assert datasource_2.format_type() == FormatType.Text2Image


def test_create_qianfan_data_source_from_existed():
source = QianfanDataSource.get_existed_dataset("12", False)
assert source.id == "12"
assert source.storage_region == "bj"
source = QianfanDataSource.create_bare_dataset(
"empty", V2Consts.DatasetFormat.PromptResponse, V2Consts.StorageType.SysStorage
)
new_source = QianfanDataSource.get_existed_dataset(source.id, False)
assert new_source.id == source.id


def create_an_empty_qianfan_datasource() -> QianfanDataSource:
return QianfanDataSource(
id=1,
group_id=2,
id="1",
group_id="2",
name="test",
set_type=DataSetType.TextOnly,
project_type=DataProjectType.Conversation,
template_type=DataTemplateType.NonSortedConversation,
data_format_type=V2Consts.DatasetFormat.PromptResponse,
version=1,
storage_type=DataStorageType.PrivateBos,
storage_id="123",
storage_path="456",
storage_name="storage_name",
storage_raw_path="/s/",
storage_type=V2Consts.StorageType.Bos,
storage_path="bos://123/456/",
storage_region="bj",
data_format_type=FormatType.Jsonl,
)


Expand All @@ -221,11 +203,13 @@ def test_qianfan_data_source_save(mocker: MockerFixture, *args, **kwargs):
empty_table = Dataset.create_from_pyobj({QianfanDatasetPackColumnName: ["1"]})
ds = create_an_empty_qianfan_datasource()

ds.storage_type = DataStorageType.PublicBos
ds.storage_type = V2Consts.StorageType.SysStorage
with pytest.raises(NotImplementedError):
ds.save(empty_table)

ds = create_an_empty_qianfan_datasource()
ds = QianfanDataSource.create_bare_dataset(
"test", V2Consts.DatasetFormat.PromptResponse
)
config = get_config()

config.ACCESS_KEY = ""
Expand All @@ -240,12 +224,12 @@ def test_qianfan_data_source_save(mocker: MockerFixture, *args, **kwargs):
)

ds.ak = "1"

with pytest.raises(ValueError):
ds.save(empty_table)

ds.sk = "2"
assert ds.save(empty_table)
with pytest.raises(NotImplementedError):
ds.save(empty_table)

config.ACCESS_KEY = "1"
config.SECRET_KEY = "2"
Expand All @@ -256,16 +240,12 @@ def test_qianfan_data_source_save(mocker: MockerFixture, *args, **kwargs):
sup_storage_path="/sdasd/",
sup_storage_region="bj",
)
assert ds.save(
empty_table,
sup_storage_id="1",
sup_storage_path="/sdasd/",
sup_storage_region="bj",
)


def test_qianfan_data_source_load():
ds = create_an_empty_qianfan_datasource()
ds = QianfanDataSource.create_bare_dataset(
"empty", V2Consts.DatasetFormat.PromptResponse, V2Consts.StorageType.SysStorage
)
content = Dataset(inner_table=ds.fetch()).list()
assert content[0][0]["response"] == [["no response"]]

Expand Down
20 changes: 7 additions & 13 deletions python/qianfan/tests/dataset/dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
QianfanSortedConversation,
)
from qianfan.dataset.table import Table
from qianfan.resources.console.consts import DataTemplateType
from qianfan.resources.console.consts import V2 as V2Consts
from qianfan.utils.pydantic import BaseModel


Expand Down Expand Up @@ -128,7 +128,7 @@ def test_dataset_create():

def test_dataset_online_process():
qianfan_data_source = QianfanDataSource.create_bare_dataset(
"test", DataTemplateType.GenericText
"test", V2Consts.DatasetFormat.Text
)
dataset = Dataset.load(source=qianfan_data_source)
assert dataset.online_data_process(
Expand Down Expand Up @@ -171,19 +171,13 @@ def test_branch_save(*args, **kwargs):
ds.unpack()
ds.save(fake_data_source)

from qianfan.tests.dataset.data_source_test import (
create_an_empty_qianfan_datasource,
fake_qianfan_data_source = QianfanDataSource.create_bare_dataset(
"test",
V2Consts.DatasetFormat.PromptResponse,
V2Consts.StorageType.Bos,
"bos://are/you/ok/",
)

fake_qianfan_data_source = create_an_empty_qianfan_datasource()
ds = Dataset.create_from_pyobj([{"prompt": "nihao", "response": [["hello"]]}])

ds.save(fake_qianfan_data_source)
ds.save(FakeDataSource(origin_data="", format=FormatType.Json))

fake_qianfan_data_source = create_an_empty_qianfan_datasource()
fake_qianfan_data_source.data_format_type = FormatType.Text
fake_qianfan_data_source.template_type = DataTemplateType.GenericText
fake_qianfan_data_source.project_type = DataTemplateType.GenericText
ds = Dataset.create_from_pyobj({QianfanDatasetPackColumnName: ["wenben"]})
ds.save(fake_qianfan_data_source)
29 changes: 17 additions & 12 deletions python/qianfan/tests/trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def dispatch(self, event: Event) -> None:

def test_load_data_action():
qianfan_data_source = QianfanDataSource.create_bare_dataset(
"test", console_consts.DataTemplateType.NonSortedConversation
"test",
console_consts.V2.DatasetFormat.PromptResponse,
)
ds = Dataset.load(source=qianfan_data_source, organize_data_as_group=True)

Expand All @@ -68,7 +69,7 @@ def test_load_data_action():

res = LoadDataSetAction(
preset,
dataset_format_type=console_consts.DataTemplateType.NonSortedConversation,
dataset_format_type=console_consts.V2.DatasetFormat.PromptResponse,
).exec()
assert isinstance(res, dict)
assert "datasets" in res
Expand Down Expand Up @@ -134,7 +135,8 @@ def test_trainer_sft_run():
peft_type=PeftType.ALL,
)
qianfan_data_source = QianfanDataSource.create_bare_dataset(
"test", console_consts.DataTemplateType.NonSortedConversation
"test",
console_consts.V2.DatasetFormat.PromptResponse,
)
ds = Dataset.load(source=qianfan_data_source, organize_data_as_group=True)

Expand Down Expand Up @@ -183,7 +185,7 @@ def test_trainer_sft_with_deploy():
)
deploy_config = DeployConfig(replicas=1, pool_type=1, service_type=ServiceType.Chat)
qianfan_data_source = QianfanDataSource.create_bare_dataset(
"test", console_consts.DataTemplateType.NonSortedConversation
"test", console_consts.V2.DatasetFormat.PromptResponse
)
ds = Dataset.load(source=qianfan_data_source, organize_data_as_group=True)

Expand Down Expand Up @@ -231,7 +233,7 @@ def test_service_exec():

def test_trainer_resume():
qianfan_data_source = QianfanDataSource.create_bare_dataset(
name="test", template_type=console_consts.DataTemplateType.NonSortedConversation
"test", console_consts.V2.DatasetFormat.PromptResponse
)
ds = Dataset.load(source=qianfan_data_source, organize_data_as_group=True)

Expand Down Expand Up @@ -272,7 +274,7 @@ def test_batch_run_on_qianfan():
# 测试_parse_from_input方法
def test__parse_from_input():
qianfan_data_source = QianfanDataSource.create_bare_dataset(
"eval", console_consts.DataTemplateType.NonSortedConversation
"eval", console_consts.V2.DatasetFormat.PromptResponse
)
test_dataset = Dataset.load(source=qianfan_data_source, organize_data_as_group=True)
test_evaluators = [QianfanRuleEvaluator(using_accuracy=True)] # 创建一些评估器
Expand Down Expand Up @@ -300,7 +302,7 @@ def test__parse_from_input():
# 测试eval action exec方法
def test_eval_action_exec():
qianfan_data_source = QianfanDataSource.create_bare_dataset(
"eval", console_consts.DataTemplateType.NonSortedConversation
"eval", console_consts.V2.DatasetFormat.PromptResponse
)
test_dataset = Dataset.load(source=qianfan_data_source, organize_data_as_group=True)
test_evaluators = [QianfanRuleEvaluator(using_similarity=True)] # 创建一些评估器
Expand All @@ -317,7 +319,7 @@ def test_eval_action_exec():
# 测试eval action resume方法
def test_eval_action_resume():
qianfan_data_source = QianfanDataSource.create_bare_dataset(
"eval", console_consts.DataTemplateType.NonSortedConversation
"eval", console_consts.V2.DatasetFormat.PromptResponse
)
test_dataset = Dataset.load(source=qianfan_data_source, organize_data_as_group=True)
test_evaluators = [QianfanRuleEvaluator(using_similarity=True)] # 创建一些评估器
Expand All @@ -338,11 +340,12 @@ def test_trainer_sft_with_eval():
peft_type=PeftType.ALL,
)
qianfan_data_source = QianfanDataSource.create_bare_dataset(
"train", console_consts.DataTemplateType.NonSortedConversation
"train", console_consts.V2.DatasetFormat.PromptResponse
)
ds = Dataset.load(source=qianfan_data_source, organize_data_as_group=True)
qianfan_eval_data_source = QianfanDataSource.create_bare_dataset(
"eval", console_consts.DataTemplateType.NonSortedConversation
"eval",
console_consts.V2.DatasetFormat.PromptResponse,
)
eval_ds = Dataset.load(source=qianfan_eval_data_source, organize_data_as_group=True)
eh = MyEventHandler()
Expand Down Expand Up @@ -487,7 +490,8 @@ def test_failed_sft_run():
peft_type=PeftType.ALL,
)
qianfan_data_source = QianfanDataSource.create_bare_dataset(
"test", console_consts.DataTemplateType.NonSortedConversation
"test",
console_consts.V2.DatasetFormat.PromptResponse,
)
ds = Dataset.load(source=qianfan_data_source, organize_data_as_group=True)

Expand Down Expand Up @@ -527,7 +531,8 @@ def test_persist():
),
)
qianfan_data_source = QianfanDataSource.create_bare_dataset(
"test", console_consts.DataTemplateType.NonSortedConversation
"test",
console_consts.V2.DatasetFormat.PromptResponse,
)
ds = Dataset.load(source=qianfan_data_source, organize_data_as_group=True)

Expand Down
Loading

0 comments on commit aeab72a

Please sign in to comment.