diff --git a/lib/dl_api_lib/dl_api_lib/query/registry.py b/lib/dl_api_lib/dl_api_lib/query/registry.py index 04fe91430..10b97dc13 100644 --- a/lib/dl_api_lib/dl_api_lib/query/registry.py +++ b/lib/dl_api_lib/dl_api_lib/query/registry.py @@ -9,7 +9,6 @@ QueryProcessingMode, SourceBackendType, ) -from dl_core.fields import ResultSchema from dl_formula.core.dialect import ( DialectCombo, DialectName, @@ -64,12 +63,11 @@ def _get_default_mqm_factory_cls() -> Type[MultiQueryMutatorFactoryBase]: return DefaultMultiQueryMutatorFactory -def get_multi_query_mutator_factory( +def get_multi_query_mutator_factory_class( query_proc_mode: QueryProcessingMode, backend_type: SourceBackendType, dialect: DialectCombo, - result_schema: ResultSchema, -) -> MultiQueryMutatorFactoryBase: +) -> Type[MultiQueryMutatorFactoryBase]: prioritized_keys = ( # First try with exact dialect and mode (exact match) MQMFactoryKey(query_proc_mode=query_proc_mode, backend_type=backend_type, dialect=dialect), @@ -96,7 +94,7 @@ def get_multi_query_mutator_factory( factory_cls = _get_default_mqm_factory_cls() assert factory_cls is not None - return factory_cls(result_schema=result_schema) + return factory_cls def register_multi_query_mutator_factory_cls( diff --git a/lib/dl_api_lib/dl_api_lib/service_registry/multi_query_mutator_factory.py b/lib/dl_api_lib/dl_api_lib/service_registry/multi_query_mutator_factory.py index 63c692dc4..d23368589 100644 --- a/lib/dl_api_lib/dl_api_lib/service_registry/multi_query_mutator_factory.py +++ b/lib/dl_api_lib/dl_api_lib/service_registry/multi_query_mutator_factory.py @@ -1,10 +1,12 @@ -import abc import logging -from typing import Sequence +from typing import ( + Sequence, + Type, +) import attr -from dl_api_lib.query.registry import get_multi_query_mutator_factory +from dl_api_lib.query.registry import get_multi_query_mutator_factory_class from dl_constants.enums import ( QueryProcessingMode, SourceBackendType, @@ -18,47 +20,41 @@ LOGGER = logging.getLogger(__name__) -class SRMultiQueryMutatorFactory(abc.ABC): - @abc.abstractmethod - def get_mqm_factory( - self, - backend_type: SourceBackendType, - dialect: DialectCombo, - dataset: Dataset, - ) -> MultiQueryMutatorFactoryBase: - raise NotImplementedError - - def get_multi_query_mutators( - self, - backend_type: SourceBackendType, - dialect: DialectCombo, - dataset: Dataset, - ) -> Sequence[MultiQueryMutatorBase]: - mqm_factory = self.get_mqm_factory(backend_type=backend_type, dialect=dialect, dataset=dataset) - return mqm_factory.get_mutators() - - @attr.s -class DefaultSRMultiQueryMutatorFactory(SRMultiQueryMutatorFactory): +class SRMultiQueryMutatorFactory: _query_proc_mode: QueryProcessingMode = attr.ib(kw_only=True) + _mqm_factory_cls_cache: dict[tuple[SourceBackendType, DialectCombo], Type[MultiQueryMutatorFactoryBase]] = attr.ib( + init=False, factory=dict + ) - def get_mqm_factory( + def get_mqm_factory_cls( self, backend_type: SourceBackendType, dialect: DialectCombo, - dataset: Dataset, - ) -> MultiQueryMutatorFactoryBase: + ) -> Type[MultiQueryMutatorFactoryBase]: # Try to get for the specified query mode - factory = get_multi_query_mutator_factory( + factory_cls = get_multi_query_mutator_factory_class( query_proc_mode=self._query_proc_mode, backend_type=backend_type, dialect=dialect, - result_schema=dataset.result_schema, ) LOGGER.info( f"Resolved MQM factory for backend_type {backend_type.name} " f"and dialect {dialect.common_name_and_version} " f"in {self._query_proc_mode.name} mode " - f"to {type(factory).__name__}" + f"to {factory_cls.__name__}" ) - return factory + return factory_cls + + def get_multi_query_mutators( + self, + backend_type: SourceBackendType, + dialect: DialectCombo, + dataset: Dataset, + ) -> Sequence[MultiQueryMutatorBase]: + cache_key = (backend_type, dialect) + if (mqm_factory_cls := self._mqm_factory_cls_cache.get(cache_key)) is None: + mqm_factory_cls = self.get_mqm_factory_cls(backend_type=backend_type, dialect=dialect) + self._mqm_factory_cls_cache[cache_key] = mqm_factory_cls + mqm_factory = mqm_factory_cls(result_schema=dataset.result_schema) + return mqm_factory.get_mutators() diff --git a/lib/dl_api_lib/dl_api_lib/service_registry/service_registry.py b/lib/dl_api_lib/dl_api_lib/service_registry/service_registry.py index 34d069558..847ff250a 100644 --- a/lib/dl_api_lib/dl_api_lib/service_registry/service_registry.py +++ b/lib/dl_api_lib/dl_api_lib/service_registry/service_registry.py @@ -11,10 +11,7 @@ from dl_api_lib.connector_availability.base import ConnectorAvailabilityConfig from dl_api_lib.service_registry.field_id_generator_factory import FieldIdGeneratorFactory from dl_api_lib.service_registry.formula_parser_factory import FormulaParserFactory -from dl_api_lib.service_registry.multi_query_mutator_factory import ( - DefaultSRMultiQueryMutatorFactory, - SRMultiQueryMutatorFactory, -) +from dl_api_lib.service_registry.multi_query_mutator_factory import SRMultiQueryMutatorFactory from dl_api_lib.service_registry.supported_functions_manager import SupportedFunctionsManager from dl_api_lib.service_registry.typed_query_processor_factory import ( DefaultQueryProcessorFactory, @@ -96,6 +93,14 @@ class DefaultApiServiceRegistry(DefaultServicesRegistry, ApiServiceRegistry): # _pivot_transformer_factory: Optional[PivotTransformerFactory] = attr.ib(kw_only=True, default=None) _typed_query_processor_factory: TypedQueryProcessorFactory = attr.ib(kw_only=True) + _multi_query_mutator_factory_factory: SRMultiQueryMutatorFactory = attr.ib( + init=False, + default=attr.Factory( + lambda self: SRMultiQueryMutatorFactory(query_proc_mode=self._query_proc_mode), + takes_self=True, + ), + ) + @_formula_parser_factory.default # noqa def _default_formula_parser_factory(self) -> FormulaParserFactory: return FormulaParserFactory(default_formula_parser_type=self._default_formula_parser_type) @@ -140,7 +145,7 @@ def get_connector_availability(self) -> ConnectorAvailabilityConfig: return self._connector_availability def get_multi_query_mutator_factory_factory(self) -> SRMultiQueryMutatorFactory: - return DefaultSRMultiQueryMutatorFactory(query_proc_mode=self._query_proc_mode) + return self._multi_query_mutator_factory_factory def get_pivot_transformer_factory(self) -> PivotTransformerFactory: assert self._pivot_transformer_factory is not None diff --git a/lib/dl_api_lib/dl_api_lib_tests/db/control_api/test_avatars.py b/lib/dl_api_lib/dl_api_lib_tests/db/control_api/test_avatars.py new file mode 100644 index 000000000..2317fcd47 --- /dev/null +++ b/lib/dl_api_lib/dl_api_lib_tests/db/control_api/test_avatars.py @@ -0,0 +1,27 @@ +from dl_api_lib_tests.db.base import DefaultApiTestBase + + +class TestAvatars(DefaultApiTestBase): + def test_delete_avatar_with_formula_fields(self, control_api, saved_dataset): + ds = saved_dataset + field_count = len(saved_dataset.result_schema) + fields = {f"{field.title}_copy": ds.field(formula=f"[{field.title}]") for field in saved_dataset.result_schema} + for title, field in fields.items(): + ds.result_schema[title] = field + ds = control_api.apply_updates(dataset=ds).dataset + ds = control_api.save_dataset(dataset=ds).dataset + assert len(ds.result_schema) == 2 * field_count + + resp = control_api.apply_updates( + ds, + updates=[ + ds.source_avatars["avatar_1"].delete(), + ], + fail_ok=True, + ) + assert resp.status_code == 400, resp.json # unknown fields in formulas + + # check that only copies are left + ds = resp.dataset + assert len(ds.result_schema) == field_count + assert all(field.title.endswith("_copy") for field in ds.result_schema) diff --git a/lib/dl_core/dl_core/fields.py b/lib/dl_core/dl_core/fields.py index 654479a53..064510c5f 100644 --- a/lib/dl_core/dl_core/fields.py +++ b/lib/dl_core/dl_core/fields.py @@ -351,18 +351,14 @@ def aggregation_locked(self) -> bool: @staticmethod def rename_in_formula(formula: str, key_map: Dict[str, str]) -> str: - found_keys = FIELD_RE.findall(formula) - - for key in found_keys: - try: - value = key_map[key] - except KeyError: + def replace(match: re.Match) -> str: + key = match.group(1) + if (value := key_map.get(key)) is None: LOGGER.warning("Unknown field: %s", key) - continue - - formula = formula.replace("[{}]".format(key), "[{}]".format(value)) + return key + return value - return formula + return FIELD_RE.sub(replace, formula) def depends_on(self, field: BIField) -> bool: return self.calc_spec.depends_on(field) diff --git a/lib/dl_query_processing/dl_query_processing/compilation/formula_compiler.py b/lib/dl_query_processing/dl_query_processing/compilation/formula_compiler.py index 9df28ad4b..230fa91b8 100644 --- a/lib/dl_query_processing/dl_query_processing/compilation/formula_compiler.py +++ b/lib/dl_query_processing/dl_query_processing/compilation/formula_compiler.py @@ -524,12 +524,6 @@ def update_environments( self._group_by_ids = set(group_by_ids) self._order_by_specs = list(order_by_specs) - def get_expression_hash(self, expr: str) -> str: - expr = expr.strip() - # TODO: replace with formula_obj.extract - expr = BIField.rename_in_formula(formula=expr, key_map=self._fields.titles_to_guids) - return expr - def _try_parse_formula( self, field: BIField, collect_errors: bool = False ) -> Tuple[Optional[formula_nodes.Formula], List[FormulaErrorCtx]]: