Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(datasets): BI-5860 check title while finding existing data sources #653

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions lib/dl_api_lib/dl_api_lib/dataset/base_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,15 +263,14 @@ def get_backend_type(self) -> SourceBackendType:
return self._capabilities.get_backend_type(role=self.resolve_role())

def _reload_sources(self) -> None:
self._has_sources = bool(self._ds.get_single_data_source_id())
# resolve database characteristics
# never go to database from here --> only_cache=True
if self._has_sources:
if (source_id := self._ds.get_single_data_source_id()) is not None:
self._has_sources = True
role = self.resolve_role()
try:
backend_type = self._capabilities.get_backend_type(role=role)
dialect_name = resolve_dialect_name(backend_type=backend_type)
source_id = self._ds.get_single_data_source_id()
dsrc = self._get_data_source_strict(source_id=source_id, role=role)
db_info = dsrc.get_cached_db_info()
db_version = db_info.version
Expand All @@ -281,6 +280,8 @@ def _reload_sources(self) -> None:
)
except ReferencedUSEntryNotFound:
self.dialect = D.DUMMY
else:
self._has_sources = False

if self._query_spec is not None:
self.load_exbuilders()
Expand Down
2 changes: 1 addition & 1 deletion lib/dl_api_lib/dl_api_lib/dataset/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1126,7 +1126,7 @@ def validate_source_already_exists(self, source_id: str, source_data: dict) -> b
connection_id=source_data.get("connection_id"),
created_from=source_data["source_type"],
parameters=source_data.get("parameters"),
# Currently ignoring: `title`
title=source_data.get("title"),
)
if existing_id:
# such source already exists, don't add its copy to the dataset
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def test_get_param_hash(
dataset = saved_dataset
service_registry = conn_default_service_registry
source_id = dataset.get_single_data_source_id()
assert source_id is not None
dsrc_coll = dataset_wrapper.get_data_source_coll_strict(source_id=source_id)
hash_from_dataset = dsrc_coll.get_param_hash()

Expand Down
3 changes: 2 additions & 1 deletion lib/dl_core/dl_core/components/editor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from collections import defaultdict
from copy import deepcopy
import logging
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -156,7 +157,7 @@ def add_data_source(
raw_schema = None

connection_ref = connection_ref_from_id(connection_id=connection_id)
parameters = parameters or {}
parameters = deepcopy(parameters or {})
parameters["connection_ref"] = connection_ref
parameters["raw_schema"] = raw_schema
parameters["index_info_set"] = index_info_set
Expand Down
3 changes: 0 additions & 3 deletions lib/dl_core/dl_core/us_connection_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,9 +303,6 @@ def update_data_source(
) -> None:
raise NotImplementedError(self._dsrc_error)

def get_single_data_source_id(self) -> str:
raise NotImplementedError(self._dsrc_error)

def get_data_source_template_templates(self, localizer: Localizer) -> list[DataSourceTemplate]:
"""
For user-input parametrized sources such as subselects.
Expand Down
8 changes: 5 additions & 3 deletions lib/dl_core/dl_core/us_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,15 @@ def rls(self) -> RLS:
def error_registry(self) -> ComponentErrorRegistry:
return self.data.component_errors

def get_single_data_source_id(self, ignore_source_ids: Optional[Collection[str]] = None) -> str: # type: ignore # TODO: fix
def get_single_data_source_id(self, ignore_source_ids: Optional[Collection[str]] = None) -> Optional[str]:
    """Return the id of the first user-managed source collection, or None.

    Collections whose id is listed in ``ignore_source_ids`` and collections
    not managed by the user are skipped.
    """
    # FIXME: remove in the future
    skipped = frozenset(ignore_source_ids or ())
    return next(
        (
            coll_spec.id
            for coll_spec in self.data.source_collections or ()
            if coll_spec.id not in skipped
            and (coll_spec.managed_by or ManagedBy.user) == ManagedBy.user
        ),
        None,
    )

def get_own_materialized_tables(self, source_id: Optional[str] = None) -> Generator[str, None, None]:
for dsrc_coll_spec in self.data.source_collections or ():
Expand All @@ -111,7 +112,7 @@ def get_own_materialized_tables(self, source_id: Optional[str] = None) -> Genera
if dsrc_spec.table_name is not None:
yield dsrc_spec.table_name

def find_data_source_configuration( # type: ignore # TODO: fix
def find_data_source_configuration(
self,
connection_id: Optional[str],
created_from: Optional[DataSourceType] = None,
Expand All @@ -132,7 +133,7 @@ def find_data_source_configuration( # type: ignore # TODO: fix

def spec_matches_parameters(existing_spec: DataSourceSpec) -> bool:
# FIXME: Refactor
for key, value in parameters.items(): # type: ignore # TODO: fix
for key, value in parameters.items():
if getattr(existing_spec, key, None) != value:
return False
return True
Expand All @@ -150,6 +151,7 @@ def spec_matches_parameters(existing_spec: DataSourceSpec) -> bool:
):
# reference matches params
return dsrc_coll_spec.id
return None

def create_result_schema_field(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ class TestComponentErrors(DefaultCoreTestClass):
def test_component_errors(self, sync_us_manager, saved_dataset):
dataset = saved_dataset
source_id = dataset.get_single_data_source_id()
assert source_id is not None
erreg = dataset.error_registry
erreg.add_error(id=source_id, type=ComponentType.data_source, message="This is an error", code=["ERR", "1"])
assert len(erreg.get_pack(id=source_id).errors) == 1
Expand Down
124 changes: 124 additions & 0 deletions lib/dl_core/dl_core_tests/db/test_data_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
from copy import deepcopy

from dl_constants.enums import (
DataSourceType,
RawSQLLevel,
)
from dl_core_testing.dataset import add_dataset_source
from dl_core_tests.db.base import DefaultCoreTestClass

from dl_connector_clickhouse.core.clickhouse.constants import (
SOURCE_TYPE_CH_SUBSELECT,
SOURCE_TYPE_CH_TABLE,
)


FAKE_CREATED_FROM = DataSourceType.declare("FAKE_SOURCE")


class TestDataSource(DefaultCoreTestClass):
    # Subselect sources require raw SQL to be enabled on the connection.
    raw_sql_level = RawSQLLevel.subselect

    def test_find_data_source_configuration(
        self,
        saved_connection,
        saved_dataset,
        sync_us_manager,
        editable_dataset_wrapper,
        dsrc_params,
    ):
        """Lookup must match only on exact connection/created_from/parameters/title combinations."""
        conn_id = saved_connection.uuid

        def find(**kwargs):
            # Shorthand: every lookup in this test targets the same connection.
            return saved_dataset.find_data_source_configuration(connection_id=conn_id, **kwargs)

        # table source
        table_params = deepcopy(dsrc_params)
        assert find(created_from=SOURCE_TYPE_CH_TABLE, parameters=table_params) is not None

        # wrong created_from
        assert find(created_from=FAKE_CREATED_FROM, parameters=table_params) is None

        # wrong table name
        assert (
            find(
                created_from=SOURCE_TYPE_CH_TABLE,
                parameters=table_params | dict(table_name="fake_table_name"),
            )
            is None
        )

        # create a subsql source with a title
        subsql_params = dict(subsql="SELECT 1 AS A")
        title = "My SQL"
        add_dataset_source(
            sync_usm=sync_us_manager,
            connection=saved_connection,
            dataset=saved_dataset,
            editable_dataset_wrapper=editable_dataset_wrapper,
            created_from=SOURCE_TYPE_CH_SUBSELECT,
            dsrc_params=subsql_params,
            title=title,
        )
        sync_us_manager.save(saved_dataset)
        source_id = find(created_from=SOURCE_TYPE_CH_SUBSELECT, parameters=subsql_params, title=title)
        assert source_id is not None

        # find without a title — the titled source should still match
        assert find(created_from=SOURCE_TYPE_CH_SUBSELECT, parameters=subsql_params) == source_id

        # wrong created_from (table type combined with the subselect's title)
        assert find(created_from=SOURCE_TYPE_CH_TABLE, parameters=table_params, title=title) is None

        # wrong query
        assert (
            find(
                created_from=SOURCE_TYPE_CH_SUBSELECT,
                parameters=dict(subsql="SELECT 2 AS A"),
                title=title,
            )
            is None
        )

        # wrong title
        assert (
            find(created_from=SOURCE_TYPE_CH_SUBSELECT, parameters=subsql_params, title="Not my SQL")
            is None
        )
58 changes: 38 additions & 20 deletions lib/dl_core_testing/dl_core_testing/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,38 @@ def make_ds_key(*args: str) -> str:
return "/".join(["tests", *args])


def add_dataset_source(
    sync_usm: SyncUSManager,
    connection: ConnectionBase,
    dataset: Dataset,
    editable_dataset_wrapper: EditableDatasetTestWrapper,
    created_from: DataSourceType,
    dsrc_params: dict,
    title: Optional[str] = None,
) -> None:
    """Attach a new origin data source to ``dataset`` and create its main avatar.

    The source's raw schema is resolved synchronously via a connection executor.
    The dataset is NOT saved here — the caller is responsible for persisting it.
    """

    def make_sync_conn_executor() -> SyncConnExecutorBase:
        # Build a sync executor bound to the given connection on demand.
        registry = sync_usm.get_services_registry()
        return registry.get_conn_executor_factory().get_sync_conn_executor(conn=connection)

    new_source_id = str(uuid.uuid4())
    editable_dataset_wrapper.add_data_source(
        source_id=new_source_id,
        role=DataSourceRole.origin,
        connection_id=connection.uuid,
        created_from=created_from,
        parameters=dsrc_params,
        title=title,
    )
    sync_usm.load_dependencies(dataset)

    # Fill in the raw schema now that the source exists in the dataset.
    data_source = editable_dataset_wrapper.get_data_source_strict(
        source_id=new_source_id,
        role=DataSourceRole.origin,
    )
    schema_info = data_source.get_schema_info(conn_executor_factory=make_sync_conn_executor)
    editable_dataset_wrapper.update_data_source(
        source_id=new_source_id,
        role=DataSourceRole.origin,
        raw_schema=schema_info.schema,
    )

    editable_dataset_wrapper.add_avatar(
        avatar_id=str(uuid.uuid4()),
        source_id=new_source_id,
        title="Main Avatar",
    )


def make_dataset(
sync_usm: SyncUSManager,
connection: Optional[ConnectionBase] = None,
Expand All @@ -45,7 +77,6 @@ def make_dataset(
dsrc_params: Optional[dict] = None,
created_via: Optional[DataSourceCreatedVia] = None,
) -> Dataset:
service_registry = sync_usm.get_services_registry()
ds_info = dict(ds_info or {})
db = db_table.db if db_table else None
table_name = table_name or (db_table.name if db_table else None)
Expand Down Expand Up @@ -76,28 +107,15 @@ def make_dataset(
**(dsrc_params or {}),
}
dsrc_params = {key: value for key, value in dsrc_params.items() if value is not None}
source_id = str(uuid.uuid4())

def conn_executor_factory() -> SyncConnExecutorBase:
return service_registry.get_conn_executor_factory().get_sync_conn_executor(conn=connection)

assert created_from is not None
ds_wrapper.add_data_source(
source_id=source_id,
role=DataSourceRole.origin,
connection_id=connection.uuid,
add_dataset_source(
sync_usm=sync_usm,
connection=connection,
dataset=dataset,
editable_dataset_wrapper=ds_wrapper,
created_from=created_from,
parameters=dsrc_params,
dsrc_params=dsrc_params,
)
sync_usm.load_dependencies(dataset)
dsrc = ds_wrapper.get_data_source_strict(source_id=source_id, role=DataSourceRole.origin)
ds_wrapper.update_data_source(
source_id=source_id,
role=DataSourceRole.origin,
raw_schema=dsrc.get_schema_info(conn_executor_factory=conn_executor_factory).schema,
)

ds_wrapper.add_avatar(avatar_id=str(uuid.uuid4()), source_id=source_id, title="Main Avatar")

return dataset

Expand Down
4 changes: 3 additions & 1 deletion lib/dl_core_testing/dl_core_testing/dataset_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,9 @@ def get_new_raw_schema(
)

def quote(self, value: str, role: DataSourceRole) -> str:
    """Quote ``value`` as an identifier using the dialect of the dataset's single data source."""
    single_source_id = self._dataset.get_single_data_source_id()
    assert single_source_id is not None  # test datasets are expected to have exactly one source
    data_source = self.get_sql_data_source_strict(source_id=single_source_id, role=role)
    return data_source.get_dialect().identifier_preparer.quote(value)

Expand Down
1 change: 1 addition & 0 deletions lib/dl_core_testing/dl_core_testing/testcases/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def test_get_param_hash(
dataset = saved_dataset
service_registry = conn_default_service_registry
source_id = dataset.get_single_data_source_id()
assert source_id is not None
dsrc_coll = dataset_wrapper.get_data_source_coll_strict(source_id=source_id)
hash_from_dataset = dsrc_coll.get_param_hash()

Expand Down
Loading