Skip to content

Commit

Permalink
perf(datasets): DLHELP-12032 optimize avatar deletion action (#779)
Browse files Browse the repository at this point in the history
  • Loading branch information
MCPN authored Jan 15, 2025
1 parent 230db12 commit 62dd0db
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 57 deletions.
8 changes: 3 additions & 5 deletions lib/dl_api_lib/dl_api_lib/query/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
QueryProcessingMode,
SourceBackendType,
)
from dl_core.fields import ResultSchema
from dl_formula.core.dialect import (
DialectCombo,
DialectName,
Expand Down Expand Up @@ -64,12 +63,11 @@ def _get_default_mqm_factory_cls() -> Type[MultiQueryMutatorFactoryBase]:
return DefaultMultiQueryMutatorFactory


def get_multi_query_mutator_factory(
def get_multi_query_mutator_factory_class(
query_proc_mode: QueryProcessingMode,
backend_type: SourceBackendType,
dialect: DialectCombo,
result_schema: ResultSchema,
) -> MultiQueryMutatorFactoryBase:
) -> Type[MultiQueryMutatorFactoryBase]:
prioritized_keys = (
# First try with exact dialect and mode (exact match)
MQMFactoryKey(query_proc_mode=query_proc_mode, backend_type=backend_type, dialect=dialect),
Expand All @@ -96,7 +94,7 @@ def get_multi_query_mutator_factory(
factory_cls = _get_default_mqm_factory_cls()

assert factory_cls is not None
return factory_cls(result_schema=result_schema)
return factory_cls


def register_multi_query_mutator_factory_cls(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import abc
import logging
from typing import Sequence
from typing import (
Sequence,
Type,
)

import attr

from dl_api_lib.query.registry import get_multi_query_mutator_factory
from dl_api_lib.query.registry import get_multi_query_mutator_factory_class
from dl_constants.enums import (
QueryProcessingMode,
SourceBackendType,
Expand All @@ -18,47 +20,41 @@
LOGGER = logging.getLogger(__name__)


@attr.s
class SRMultiQueryMutatorFactory:
    """
    Provider of multi-query mutators for a fixed query processing mode.

    The ``MultiQueryMutatorFactoryBase`` subclass for a
    ``(backend_type, dialect)`` pair is resolved through
    ``get_multi_query_mutator_factory_class`` and remembered on this instance,
    so repeated requests for the same pair skip the registry lookup.
    A factory object is still created on every ``get_multi_query_mutators``
    call, because it is constructed from the result schema of the dataset
    passed to that call.
    """

    _query_proc_mode: QueryProcessingMode = attr.ib(kw_only=True)
    # Resolved factory classes keyed by (backend_type, dialect);
    # the processing mode is not part of the key because it is fixed per instance.
    _mqm_factory_cls_cache: dict[tuple[SourceBackendType, DialectCombo], Type[MultiQueryMutatorFactoryBase]] = attr.ib(
        init=False, factory=dict
    )

    def get_mqm_factory_cls(
        self,
        backend_type: SourceBackendType,
        dialect: DialectCombo,
    ) -> Type[MultiQueryMutatorFactoryBase]:
        """
        Resolve (without caching) the factory class for the given backend type and dialect.

        The resolution result is logged at INFO level.
        """
        # Try to get for the specified query mode
        factory_cls = get_multi_query_mutator_factory_class(
            query_proc_mode=self._query_proc_mode,
            backend_type=backend_type,
            dialect=dialect,
        )
        LOGGER.info(
            f"Resolved MQM factory for backend_type {backend_type.name} "
            f"and dialect {dialect.common_name_and_version} "
            f"in {self._query_proc_mode.name} mode "
            f"to {factory_cls.__name__}"
        )
        return factory_cls

    def get_multi_query_mutators(
        self,
        backend_type: SourceBackendType,
        dialect: DialectCombo,
        dataset: Dataset,
    ) -> Sequence[MultiQueryMutatorBase]:
        """
        Return the mutators produced by the factory that matches ``backend_type`` and ``dialect``.

        The factory class comes from the per-instance cache when available;
        the factory itself is instantiated with ``dataset.result_schema``.
        """
        cache_key = (backend_type, dialect)
        if (mqm_factory_cls := self._mqm_factory_cls_cache.get(cache_key)) is None:
            mqm_factory_cls = self.get_mqm_factory_cls(backend_type=backend_type, dialect=dialect)
            self._mqm_factory_cls_cache[cache_key] = mqm_factory_cls
        mqm_factory = mqm_factory_cls(result_schema=dataset.result_schema)
        return mqm_factory.get_mutators()
15 changes: 10 additions & 5 deletions lib/dl_api_lib/dl_api_lib/service_registry/service_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,7 @@
from dl_api_lib.connector_availability.base import ConnectorAvailabilityConfig
from dl_api_lib.service_registry.field_id_generator_factory import FieldIdGeneratorFactory
from dl_api_lib.service_registry.formula_parser_factory import FormulaParserFactory
from dl_api_lib.service_registry.multi_query_mutator_factory import (
DefaultSRMultiQueryMutatorFactory,
SRMultiQueryMutatorFactory,
)
from dl_api_lib.service_registry.multi_query_mutator_factory import SRMultiQueryMutatorFactory
from dl_api_lib.service_registry.supported_functions_manager import SupportedFunctionsManager
from dl_api_lib.service_registry.typed_query_processor_factory import (
DefaultQueryProcessorFactory,
Expand Down Expand Up @@ -96,6 +93,14 @@ class DefaultApiServiceRegistry(DefaultServicesRegistry, ApiServiceRegistry): #
_pivot_transformer_factory: Optional[PivotTransformerFactory] = attr.ib(kw_only=True, default=None)
_typed_query_processor_factory: TypedQueryProcessorFactory = attr.ib(kw_only=True)

_multi_query_mutator_factory_factory: SRMultiQueryMutatorFactory = attr.ib(
init=False,
default=attr.Factory(
lambda self: SRMultiQueryMutatorFactory(query_proc_mode=self._query_proc_mode),
takes_self=True,
),
)

@_formula_parser_factory.default  # noqa
def _default_formula_parser_factory(self) -> FormulaParserFactory:
    """Build a ``FormulaParserFactory`` configured with this registry's ``_default_formula_parser_type``."""
    parser_type = self._default_formula_parser_type
    return FormulaParserFactory(default_formula_parser_type=parser_type)
Expand Down Expand Up @@ -140,7 +145,7 @@ def get_connector_availability(self) -> ConnectorAvailabilityConfig:
return self._connector_availability

def get_multi_query_mutator_factory_factory(self) -> SRMultiQueryMutatorFactory:
    """
    Return the ``SRMultiQueryMutatorFactory`` stored on this registry.

    The same object (``_multi_query_mutator_factory_factory``) is handed out
    on every call instead of building a new one, so state kept on that
    object is shared between callers of this registry.
    """
    return self._multi_query_mutator_factory_factory

def get_pivot_transformer_factory(self) -> PivotTransformerFactory:
assert self._pivot_transformer_factory is not None
Expand Down
27 changes: 27 additions & 0 deletions lib/dl_api_lib/dl_api_lib_tests/db/control_api/test_avatars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from dl_api_lib_tests.db.base import DefaultApiTestBase


class TestAvatars(DefaultApiTestBase):
    """Control API tests for source avatar updates."""

    def test_delete_avatar_with_formula_fields(self, control_api, saved_dataset):
        """
        Delete an avatar from a dataset in which every field has a formula copy.

        Each original field gets a ``<title>_copy`` field whose formula refers
        to the original by title. Deleting ``avatar_1`` is then expected to
        answer with 400 and to leave only the copies in the result schema.
        """
        dataset = saved_dataset
        initial_count = len(saved_dataset.result_schema)

        # Build every copy first, so the schema is not modified while it is being iterated.
        copies = {}
        for source_field in saved_dataset.result_schema:
            copy_title = f"{source_field.title}_copy"
            copies[copy_title] = dataset.field(formula=f"[{source_field.title}]")
        for copy_title, copy_field in copies.items():
            dataset.result_schema[copy_title] = copy_field

        dataset = control_api.apply_updates(dataset=dataset).dataset
        dataset = control_api.save_dataset(dataset=dataset).dataset
        assert len(dataset.result_schema) == 2 * initial_count

        avatar_removal = dataset.source_avatars["avatar_1"].delete()
        resp = control_api.apply_updates(
            dataset,
            updates=[avatar_removal],
            fail_ok=True,
        )
        assert resp.status_code == 400, resp.json  # unknown fields in formulas

        # check that only copies are left
        dataset = resp.dataset
        assert len(dataset.result_schema) == initial_count
        assert all(field.title.endswith("_copy") for field in dataset.result_schema)
16 changes: 6 additions & 10 deletions lib/dl_core/dl_core/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,18 +351,14 @@ def aggregation_locked(self) -> bool:

@staticmethod
def rename_in_formula(formula: str, key_map: Dict[str, str]) -> str:
    """
    Rename field references in ``formula`` according to ``key_map``.

    All matches of ``FIELD_RE`` are handled in a single pass; the first
    capture group of a match is taken as the field key. A key found in
    ``key_map`` is replaced by its mapped value, an unknown key is logged
    with a warning and its reference is left exactly as written.

    :param formula: formula source text
    :param key_map: mapping from the current key to its replacement
    :return: ``formula`` with every known key replaced
    """

    def replace(match: re.Match) -> str:
        key = match.group(1)
        if (value := key_map.get(key)) is None:
            LOGGER.warning("Unknown field: %s", key)
            return match.group(0)
        # ``sub`` substitutes the whole match, so swap only the captured key
        # and keep whatever else the pattern matched around it.
        # NOTE(review): presumably that is the enclosing square brackets - confirm against FIELD_RE.
        whole = match.group(0)
        offset = match.start()
        key_start, key_end = match.span(1)
        return whole[: key_start - offset] + value + whole[key_end - offset :]

    return FIELD_RE.sub(replace, formula)

def depends_on(self, field: BIField) -> bool:
return self.calc_spec.depends_on(field)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -524,12 +524,6 @@ def update_environments(
self._group_by_ids = set(group_by_ids)
self._order_by_specs = list(order_by_specs)

def get_expression_hash(self, expr: str) -> str:
    """
    Return ``expr`` with surrounding whitespace stripped and passed through
    ``BIField.rename_in_formula`` using the ``titles_to_guids`` mapping of ``self._fields``.
    """
    # TODO: replace with formula_obj.extract
    return BIField.rename_in_formula(
        formula=expr.strip(),
        key_map=self._fields.titles_to_guids,
    )

def _try_parse_formula(
self, field: BIField, collect_errors: bool = False
) -> Tuple[Optional[formula_nodes.Formula], List[FormulaErrorCtx]]:
Expand Down

0 comments on commit 62dd0db

Please sign in to comment.