From 5ef1daf090f60f8255833091068644f2b28c46e6 Mon Sep 17 00:00:00 2001 From: "min.tian" Date: Wed, 24 Jul 2024 16:03:51 +0800 Subject: [PATCH] add new db_config for better labeling: version, note Signed-off-by: min.tian --- vectordb_bench/backend/clients/api.py | 21 +++++++- .../frontend/components/check_results/data.py | 19 ++++--- .../components/run_test/dbConfigSetting.py | 52 +++++++++++++------ .../frontend/components/run_test/initStyle.py | 4 +- vectordb_bench/models.py | 12 +++-- 5 files changed, 81 insertions(+), 27 deletions(-) diff --git a/vectordb_bench/backend/clients/api.py b/vectordb_bench/backend/clients/api.py index 0024bf600..d9ec5d83b 100644 --- a/vectordb_bench/backend/clients/api.py +++ b/vectordb_bench/backend/clients/api.py @@ -38,6 +38,22 @@ class DBConfig(ABC, BaseModel): """ db_label: str = "" + version: str = "" + note: str = "" + + @staticmethod + def common_short_configs() -> list[str]: + """ + short input, such as `db_label`, `version` + """ + return ["version", "db_label"] + + @staticmethod + def common_long_configs() -> list[str]: + """ + long input, such as `note` + """ + return ["note"] @abstractmethod def to_dict(self) -> dict: @@ -45,7 +61,10 @@ def to_dict(self) -> dict: @validator("*") def not_empty_field(cls, v, field): - if field.name == "db_label": + if ( + field.name in cls.common_short_configs() + or field.name in cls.common_long_configs() + ): return v if not v and isinstance(v, (str, SecretStr)): raise ValueError("Empty string!") diff --git a/vectordb_bench/frontend/components/check_results/data.py b/vectordb_bench/frontend/components/check_results/data.py index 1e6bba00e..b3cac21e1 100644 --- a/vectordb_bench/frontend/components/check_results/data.py +++ b/vectordb_bench/frontend/components/check_results/data.py @@ -24,7 +24,10 @@ def getFilterTasks( task for task in tasks if task.task_config.db_name in dbNames - and task.task_config.case_config.case_id.case_cls(task.task_config.case_config.custom_case).name in caseNames + and task.task_config.case_config.case_id.case_cls( + task.task_config.case_config.custom_case + ).name + in caseNames ] return filterTasks @@ -35,17 +38,20 @@ def mergeTasks(tasks: list[CaseResult]): db_name = task.task_config.db_name db = task.task_config.db.value db_label = task.task_config.db_config.db_label or "" - case = task.task_config.case_config.case_id.case_cls(task.task_config.case_config.custom_case) + version = task.task_config.db_config.version or "" + case = task.task_config.case_config.case_id.case_cls( + task.task_config.case_config.custom_case + ) dbCaseMetricsMap[db_name][case.name] = { "db": db, "db_label": db_label, + "version": version, "metrics": mergeMetrics( dbCaseMetricsMap[db_name][case.name].get("metrics", {}), asdict(task.metrics), ), "label": getBetterLabel( - dbCaseMetricsMap[db_name][case.name].get( - "label", ResultLabel.FAILED), + dbCaseMetricsMap[db_name][case.name].get("label", ResultLabel.FAILED), task.label, ), } @@ -57,6 +63,7 @@ def mergeTasks(tasks: list[CaseResult]): metrics = metricInfo["metrics"] db = metricInfo["db"] db_label = metricInfo["db_label"] + version = metricInfo["version"] label = metricInfo["label"] if label == ResultLabel.NORMAL: mergedTasks.append( @@ -64,6 +71,7 @@ def mergeTasks(tasks: list[CaseResult]): "db_name": db_name, "db": db, "db_label": db_label, + "version": version, "case_name": case_name, "metricsSet": set(metrics.keys()), **metrics, @@ -79,8 +87,7 @@ def mergeMetrics(metrics_1: dict, metrics_2: dict) -> dict: metrics = {**metrics_1} for key, value in metrics_2.items(): metrics[key] = ( - getBetterMetric( - key, value, metrics[key]) if key in metrics else value + getBetterMetric(key, value, metrics[key]) if key in metrics else value ) return metrics diff --git a/vectordb_bench/frontend/components/run_test/dbConfigSetting.py b/vectordb_bench/frontend/components/run_test/dbConfigSetting.py index 8f4f35c93..257608413 100644 --- a/vectordb_bench/frontend/components/run_test/dbConfigSetting.py +++ b/vectordb_bench/frontend/components/run_test/dbConfigSetting.py @@ -1,9 +1,10 @@ from pydantic import ValidationError -from vectordb_bench.frontend.config.styles import * +from vectordb_bench.backend.clients import DB +from vectordb_bench.frontend.config.styles import DB_CONFIG_SETTING_COLUMNS from vectordb_bench.frontend.utils import inputIsPassword -def dbConfigSettings(st, activedDbList): +def dbConfigSettings(st, activedDbList: list[DB]): expander = st.expander("Configurations for the selected databases", True) dbConfigs = {} @@ -27,7 +28,7 @@ def dbConfigSettings(st, activedDbList): return dbConfigs, isAllValid -def dbConfigSettingItem(st, activeDb): +def dbConfigSettingItem(st, activeDb: DB): st.markdown( f"
{activeDb.value}
", unsafe_allow_html=True, @@ -36,20 +37,41 @@ def dbConfigSettingItem(st, activeDb): dbConfigClass = activeDb.config_cls properties = dbConfigClass.schema().get("properties") - propertiesItems = list(properties.items()) - moveDBLabelToLast(propertiesItems) dbConfig = {} - for j, property in enumerate(propertiesItems): - column = columns[j % DB_CONFIG_SETTING_COLUMNS] - key, value = property + idx = 0 + + # db config (unique) + for key, property in properties.items(): + if ( + key not in dbConfigClass.common_short_configs() + and key not in dbConfigClass.common_long_configs() + ): + column = columns[idx % DB_CONFIG_SETTING_COLUMNS] + idx += 1 + dbConfig[key] = column.text_input( + key, + key="%s-%s" % (activeDb.name, key), + value=property.get("default", ""), + type="password" if inputIsPassword(key) else "default", + ) + # db config (common short labels) + for key in dbConfigClass.common_short_configs(): + column = columns[idx % DB_CONFIG_SETTING_COLUMNS] + idx += 1 dbConfig[key] = column.text_input( key, - key="%s-%s" % (activeDb, key), - value=value.get("default", ""), - type="password" if inputIsPassword(key) else "default", + key="%s-%s" % (activeDb.name, key), + value="", + type="default", + placeholder="optional, for labeling results", ) - return dbConfig - -def moveDBLabelToLast(propertiesItems): - propertiesItems.sort(key=lambda x: 1 if x[0] == "db_label" else 0) + # db config (common long text_input) + for key in dbConfigClass.common_long_configs(): + dbConfig[key] = st.text_area( + key, + key="%s-%s" % (activeDb.name, key), + value="", + placeholder="optional", + ) + return dbConfig diff --git a/vectordb_bench/frontend/components/run_test/initStyle.py b/vectordb_bench/frontend/components/run_test/initStyle.py index 59dd438e1..1e6af57ad 100644 --- a/vectordb_bench/frontend/components/run_test/initStyle.py +++ b/vectordb_bench/frontend/components/run_test/initStyle.py @@ -9,6 +9,8 @@ def initStyle(st): div[data-testid='stHorizontalBlock'] {gap: 8px;} /* check box */ .stCheckbox p { color: #000; font-size: 18px; font-weight: 600; } + /* db selector - db_name should not wrap */ + div[data-testid="stVerticalBlockBorderWrapper"] div[data-testid="stCheckbox"] div[data-testid="stWidgetLabel"] p { white-space: nowrap; } """, unsafe_allow_html=True, - ) \ No newline at end of file + ) diff --git a/vectordb_bench/models.py b/vectordb_bench/models.py index 56034796e..41be95b7c 100644 --- a/vectordb_bench/models.py +++ b/vectordb_bench/models.py @@ -2,7 +2,7 @@ import pathlib from datetime import date from enum import Enum, StrEnum, auto -from typing import List, Self, Sequence, Set +from typing import List, Self import ujson @@ -10,7 +10,6 @@ DB, DBConfig, DBCaseConfig, - IndexType, ) from .backend.cases import CaseType from .base import BaseModel @@ -128,9 +127,14 @@ class TaskConfig(BaseModel): @property def db_name(self): - db = self.db.value + db_name = f"{self.db.value}" db_label = self.db_config.db_label - return f"{db}-{db_label}" if db_label else db + if db_label: + db_name += f"-{db_label}" + version = self.db_config.version + if version: + db_name += f"-{version}" + return db_name class ResultLabel(Enum):