diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index 8ff62599..a0fccde9 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -10,7 +10,7 @@ from dbally.audit.events import RequestEnd, RequestStart from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError from dbally.collection.results import ExecutionResult -from dbally.context.context import CustomContext +from dbally.context.context import BaseCallerContext from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions from dbally.nl_responder.nl_responder import NLResponder @@ -157,7 +157,7 @@ async def ask( dry_run: bool = False, return_natural_response: bool = False, llm_options: Optional[LLMOptions] = None, - contexts: Optional[Iterable[CustomContext]] = None, + contexts: Optional[Iterable[BaseCallerContext]] = None, ) -> ExecutionResult: """ Ask question in a text form and retrieve the answer based on the available views. diff --git a/src/dbally/context/context.py b/src/dbally/context/context.py index b01cad2e..1f5a32d6 100644 --- a/src/dbally/context/context.py +++ b/src/dbally/context/context.py @@ -2,11 +2,9 @@ from abc import ABC from typing import ClassVar, Iterable -from typing_extensions import Self, TypeAlias +from typing_extensions import Self -from dbally.context.exceptions import ContextNotAvailableError - -CustomContext: TypeAlias = "BaseCallerContext" +from dbally.context.exceptions import BaseContextError class BaseCallerContext(ABC): @@ -23,7 +21,7 @@ class BaseCallerContext(ABC): alias: ClassVar[str] = "AskerContext" @classmethod - def select_context(cls, contexts: Iterable[CustomContext]) -> Self: + def select_context(cls, contexts: Iterable["BaseCallerContext"]) -> Self: """ Typically called from a subclass of BaseCallerContext, selects a member object from `contexts` being an instance of the same class. Effectively provides a type dispatch mechanism, substituting the context @@ -36,17 +34,17 @@ class by its right instance. An instance of the same BaseCallerContext subclass this method is caller from. Raises: - ContextNotAvailableError: If the sequence of context objects passed as argument is empty. + BaseContextError: If no element in `contexts` matches `cls` class. """ - if not contexts: - raise ContextNotAvailableError( - "The LLM detected that the context is required to execute the query" - "and the filter signature allows contextualization while the context was not provided." - ) + try: + selected_context = next(filter(lambda obj: isinstance(obj, cls), contexts)) + except StopIteration as e: + # this custom exception provides more clear message what have just gone wrong + raise BaseContextError() from e # TODO confirm whether it is possible to design a correct type hints here and skipping `type: ignore` - return next(filter(lambda obj: isinstance(obj, cls), contexts)) # type: ignore + return selected_context # type: ignore @classmethod def is_context_call(cls, node: ast.expr) -> bool: diff --git a/src/dbally/context/exceptions.py b/src/dbally/context/exceptions.py index 15c1d303..c538ee18 100644 --- a/src/dbally/context/exceptions.py +++ b/src/dbally/context/exceptions.py @@ -1,5 +1,23 @@ -class ContextNotAvailableError(Exception): +class BaseContextError(Exception): """ - An exception inheriting from BaseContextException pointining that no sufficient context information - was provided by the user while calling view.ask(). + A base error for context handling logic. """ + + +class SuitableContextNotProvidedError(BaseContextError): + """ + Raised when method argument type hint points that a contextualization is available + but not suitable context was provided. + """ + + def __init__(self, filter_fun_signature: str, context_class_name: str) -> None: + # this syntax 'or BaseCallerContext' is just to prevent type checkers + # from raising a warning, as filter_.context_class can be None. It's essenially a fallback that should never + # be reached, unless somebody will use this Exception against its purpose. + # TODO consider raising a warning/error when this happens. + + message = ( + f"No context of class {context_class_name} was provided" + f"while the filter {filter_fun_signature} requires it." + ) + super().__init__(message) diff --git a/src/dbally/iql/_processor.py b/src/dbally/iql/_processor.py index 5e18a480..37c08ee0 100644 --- a/src/dbally/iql/_processor.py +++ b/src/dbally/iql/_processor.py @@ -3,7 +3,7 @@ from dbally.audit.event_tracker import EventTracker from dbally.context._utils import _does_arg_allow_context -from dbally.context.context import BaseCallerContext, CustomContext +from dbally.context.context import BaseCallerContext from dbally.iql import syntax from dbally.iql._exceptions import ( IQLArgumentParsingError, @@ -23,21 +23,17 @@ class IQLProcessor: Attributes: source: Raw LLM response containing IQL filter calls. - allowed_functions: A mapping (typically a dict) of all filters implemented for a certain View. - contexts: A sequence (typically a list) of context objects, each being an instance of - a subclass of BaseCallerContext. May contain contexts irrelevant for the currently processed query. + allowed_functions: A mapping (typically a dict) of all filters implemented for a certain View.= """ source: str allowed_functions: Mapping[str, "ExposedFunction"] - contexts: Iterable[CustomContext] _event_tracker: EventTracker def __init__( self, source: str, allowed_functions: Iterable[ExposedFunction], - contexts: Optional[Iterable[CustomContext]] = None, event_tracker: Optional[EventTracker] = None, ) -> None: """ @@ -46,14 +42,11 @@ def __init__( Args: source: Raw LLM response containing IQL filter calls. allowed_functions: An interable (typically a list) of all filters implemented for a certain View. - contexts: An iterable (typically a list) of context objects, each being an instance of - a subclass of BaseCallerContext. even_tracker: An EvenTracker instance. """ self.source = source self.allowed_functions = {func.name: func for func in allowed_functions} - self.contexts = contexts or [] self._event_tracker = event_tracker or EventTracker() async def process(self) -> syntax.Node: @@ -148,7 +141,7 @@ def _parse_arg( if not _does_arg_allow_context(arg_spec): raise IQLContextNotAllowedError(arg, self.source, arg_name=arg_spec.name) - return parent_func_def.context_class.select_context(self.contexts) + return parent_func_def.context if not isinstance(arg, ast.Constant): raise IQLArgumentParsingError(arg, self.source) diff --git a/src/dbally/iql/_query.py b/src/dbally/iql/_query.py index cc090ad6..a9080a49 100644 --- a/src/dbally/iql/_query.py +++ b/src/dbally/iql/_query.py @@ -1,9 +1,7 @@ -from typing import TYPE_CHECKING, Iterable, List, Optional +from typing import TYPE_CHECKING, List, Optional from typing_extensions import Self -from dbally.context.context import CustomContext - from ..audit.event_tracker import EventTracker from . import syntax from ._processor import IQLProcessor @@ -28,11 +26,7 @@ def __str__(self) -> str: @classmethod async def parse( - cls, - source: str, - allowed_functions: List["ExposedFunction"], - event_tracker: Optional[EventTracker] = None, - contexts: Optional[Iterable[CustomContext]] = None, + cls, source: str, allowed_functions: List["ExposedFunction"], event_tracker: Optional[EventTracker] = None ) -> Self: """ Parse IQL string to IQLQuery object. @@ -41,11 +35,10 @@ async def parse( source: IQL string that needs to be parsed allowed_functions: list of IQL functions that are allowed for this query event_tracker: EventTracker object to track events - contexts: An iterable (typically a list) of context objects, each being - an instance of a subclass of BaseCallerContext. + Returns: IQLQuery object """ - root = await IQLProcessor(source, allowed_functions, contexts, event_tracker).process() + root = await IQLProcessor(source, allowed_functions, event_tracker).process() return cls(root=root, source=source) diff --git a/src/dbally/iql/_type_validators.py b/src/dbally/iql/_type_validators.py index 7b993ef5..b06f8305 100644 --- a/src/dbally/iql/_type_validators.py +++ b/src/dbally/iql/_type_validators.py @@ -70,7 +70,7 @@ def validate_arg_type(required_type: Union[Type, _GenericAlias], value: Any) -> actual_type = type_ext.get_origin(required_type) if isinstance(required_type, _GenericAlias) else required_type # typing.Union is an instance of _GenericAlias if actual_type is None: - # workaround to prevent type warning in line `if isisntanc(value, actual_type):`, TODO check whether necessary + # workaround to prevent type warning in line `if isisntance(value, actual_type):`, TODO check whether necessary actual_type = required_type.__origin__ if actual_type is Union: diff --git a/src/dbally/iql_generator/iql_generator.py b/src/dbally/iql_generator/iql_generator.py index 8018f6e1..c6aeec31 100644 --- a/src/dbally/iql_generator/iql_generator.py +++ b/src/dbally/iql_generator/iql_generator.py @@ -1,7 +1,6 @@ -from typing import Iterable, List, Optional +from typing import List, Optional from dbally.audit.event_tracker import EventTracker -from dbally.context.context import CustomContext from dbally.iql import IQLError, IQLQuery from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, IQLGenerationPromptFormat from dbally.llms.base import LLM @@ -43,7 +42,6 @@ async def generate_iql( examples: Optional[List[FewShotExample]] = None, llm_options: Optional[LLMOptions] = None, n_retries: int = 3, - contexts: Optional[Iterable[CustomContext]] = None, ) -> IQLQuery: """ Generates IQL in text form using LLM. @@ -55,8 +53,6 @@ async def generate_iql( examples: List of examples to be injected into the conversation. llm_options: Options to use for the LLM client. n_retries: Number of retries to regenerate IQL in case of errors. - contexts: An iterable (typically a list) of context objects, each being - an instance of a subclass of BaseCallerContext. Returns: Generated IQL query. @@ -78,9 +74,7 @@ async def generate_iql( # TODO: Move response parsing to llm generate_text method iql = formatted_prompt.response_parser(response) # TODO: Move IQL query parsing to prompt response parser - return await IQLQuery.parse( - source=iql, allowed_functions=filters, event_tracker=event_tracker, contexts=contexts - ) + return await IQLQuery.parse(source=iql, allowed_functions=filters, event_tracker=event_tracker) except IQLError as exc: # TODO handle the possibility of variable `response` being not initialized # while runnning the following line diff --git a/src/dbally/views/base.py b/src/dbally/views/base.py index e2292b56..a83b961d 100644 --- a/src/dbally/views/base.py +++ b/src/dbally/views/base.py @@ -5,7 +5,7 @@ from dbally.audit.event_tracker import EventTracker from dbally.collection.results import ViewExecutionResult -from dbally.context.context import CustomContext +from dbally.context.context import BaseCallerContext from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions from dbally.prompt.elements import FewShotExample @@ -29,7 +29,7 @@ async def ask( n_retries: int = 3, dry_run: bool = False, llm_options: Optional[LLMOptions] = None, - contexts: Optional[Iterable[CustomContext]] = None, + contexts: Optional[Iterable[BaseCallerContext]] = None, ) -> ViewExecutionResult: """ Executes the query and returns the result. @@ -59,9 +59,9 @@ def list_similarity_indexes(self) -> Dict[AbstractSimilarityIndex, List[IndexLoc def list_few_shots(self) -> List[FewShotExample]: """ - List all examples to be injected into few-shot prompt. + Lists all examples to be injected into few-shot prompt. Returns: - List of few-shot examples + List of few-shot examples. """ return [] diff --git a/src/dbally/views/exposed_functions.py b/src/dbally/views/exposed_functions.py index 9ee307ec..07b88005 100644 --- a/src/dbally/views/exposed_functions.py +++ b/src/dbally/views/exposed_functions.py @@ -1,11 +1,12 @@ from dataclasses import dataclass from inspect import isclass from typing import _GenericAlias # type: ignore -from typing import Generator, Optional, Sequence, Type, Union +from typing import Generator, Iterable, Optional, Sequence, Type, Union import typing_extensions as type_ext from dbally.context.context import BaseCallerContext +from dbally.context.exceptions import BaseContextError, SuitableContextNotProvidedError from dbally.similarity import AbstractSimilarityIndex @@ -127,6 +128,7 @@ class ExposedFunction: description: str parameters: Sequence[MethodParamWithTyping] context_class: Optional[Type[BaseCallerContext]] = None + context: Optional[BaseCallerContext] = None def __str__(self) -> str: base_str = f"{self.name}({', '.join(str(param) for param in self.parameters)})" @@ -135,3 +137,22 @@ def __str__(self) -> str: return f"{base_str} - {self.description}" return base_str + + def inject_context(self, contexts: Iterable[BaseCallerContext]) -> None: + """ + Inserts reference to the member of `contexts` of the proper class in self.context. + + Args: + contexts: An iterable of user-provided context objects. + + Raises: + SuitableContextNotProvidedError: Ff no element in `contexts` matches `self.context_class`. + """ + + if self.context_class is None: + return + + try: + self.context = self.context_class.select_context(contexts) + except BaseContextError as e: + raise SuitableContextNotProvidedError(str(self), self.context_class.__name__) from e diff --git a/src/dbally/views/freeform/text2sql/view.py b/src/dbally/views/freeform/text2sql/view.py index 27596d7e..31f4c041 100644 --- a/src/dbally/views/freeform/text2sql/view.py +++ b/src/dbally/views/freeform/text2sql/view.py @@ -8,7 +8,7 @@ from dbally.audit.event_tracker import EventTracker from dbally.collection.results import ViewExecutionResult -from dbally.context.context import CustomContext +from dbally.context.context import BaseCallerContext from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions from dbally.prompt.template import PromptTemplate @@ -104,7 +104,7 @@ async def ask( n_retries: int = 3, dry_run: bool = False, llm_options: Optional[LLMOptions] = None, - contexts: Optional[Iterable[CustomContext]] = None, + contexts: Optional[Iterable[BaseCallerContext]] = None, ) -> ViewExecutionResult: """ Executes the query and returns the result. It generates the SQL query from the natural language query and diff --git a/src/dbally/views/structured.py b/src/dbally/views/structured.py index cfe2d6ba..0b5e1f27 100644 --- a/src/dbally/views/structured.py +++ b/src/dbally/views/structured.py @@ -4,7 +4,7 @@ from dbally.audit.event_tracker import EventTracker from dbally.collection.results import ViewExecutionResult -from dbally.context.context import CustomContext +from dbally.context.context import BaseCallerContext from dbally.iql import IQLQuery from dbally.iql_generator.iql_generator import IQLGenerator from dbally.llms.base import LLM @@ -33,6 +33,22 @@ def get_iql_generator(self, llm: LLM) -> IQLGenerator: """ return IQLGenerator(llm=llm) + @classmethod + def contextualize_filters( + cls, filters: Iterable[ExposedFunction], contexts: Optional[Iterable[BaseCallerContext]] + ) -> None: + """ + Updates a list of filters packed as ExposedFunction's by ingesting the matching context objects. + + Args: + filters: An iterable of filters. + contexts: An iterable of context objects. + """ + + contexts = contexts or [] + for filter_ in filters: + filter_.inject_context(contexts) + async def ask( self, query: str, @@ -41,7 +57,7 @@ async def ask( n_retries: int = 3, dry_run: bool = False, llm_options: Optional[LLMOptions] = None, - contexts: Optional[Iterable[CustomContext]] = None, + contexts: Optional[Iterable[BaseCallerContext]] = None, ) -> ViewExecutionResult: """ Executes the query and returns the result. It generates the IQL query from the natural language query\ @@ -65,6 +81,8 @@ async def ask( filters = self.list_filters() examples = self.list_few_shots() + self.contextualize_filters(filters, contexts) + iql = await iql_generator.generate_iql( question=query, filters=filters, @@ -72,7 +90,6 @@ async def ask( event_tracker=event_tracker, llm_options=llm_options, n_retries=n_retries, - contexts=contexts, ) await self.apply_filters(iql) diff --git a/tests/unit/iql/test_iql_parser.py b/tests/unit/iql/test_iql_parser.py index d28604e5..691ab39d 100644 --- a/tests/unit/iql/test_iql_parser.py +++ b/tests/unit/iql/test_iql_parser.py @@ -16,14 +16,8 @@ class TestCustomContext(BaseCallerContext): city: str -@dataclass -class AnotherTestCustomContext(BaseCallerContext): - some_field: str - - async def test_iql_parser(): custom_context = TestCustomContext(city="cracow") - custom_context2 = AnotherTestCustomContext(some_field="aaa") parsed = await IQLQuery.parse( "not (filter_by_name(['John', 'Anne']) and filter_by_city(AskerContext()) and filter_by_company('deepsense.ai'))", @@ -36,12 +30,12 @@ async def test_iql_parser(): description="", parameters=[MethodParamWithTyping(name="city", type=Union[str, TestCustomContext])], context_class=TestCustomContext, + context=custom_context, ), ExposedFunction( name="filter_by_company", description="", parameters=[MethodParamWithTyping(name="company", type=str)] ), ], - contexts=[custom_context, custom_context2], ) not_op = parsed.root diff --git a/tests/unit/test_iql_generator.py b/tests/unit/test_iql_generator.py index 7fb0a379..08189943 100644 --- a/tests/unit/test_iql_generator.py +++ b/tests/unit/test_iql_generator.py @@ -77,7 +77,7 @@ async def test_iql_generation(iql_generator: IQLGenerator, event_tracker: EventT options=None, ) mock_parse.assert_called_once_with( - source="filter_by_id(1)", allowed_functions=filters, event_tracker=event_tracker, contexts=None + source="filter_by_id(1)", allowed_functions=filters, event_tracker=event_tracker ) diff --git a/tests/unit/views/test_sqlalchemy_base.py b/tests/unit/views/test_sqlalchemy_base.py index 20f11e72..cd115410 100644 --- a/tests/unit/views/test_sqlalchemy_base.py +++ b/tests/unit/views/test_sqlalchemy_base.py @@ -59,10 +59,11 @@ async def test_filter_sql_generation() -> None: mock_connection = sqlalchemy.create_mock_engine("postgresql://", executor=None) mock_view = MockSqlAlchemyView(mock_connection.engine) + filters = mock_view.list_filters() + mock_view.contextualize_filters(filters, [SomeTestContext(age=69)]) + query = await IQLQuery.parse( - 'method_foo(1) and method_bar("London", 2020) and method_baz(AskerContext())', - allowed_functions=mock_view.list_filters(), - contexts=[SomeTestContext(age=69)], + 'method_foo(1) and method_bar("London", 2020) and method_baz(AskerContext())', allowed_functions=filters ) await mock_view.apply_filters(query) sql = normalize_whitespace(mock_view.execute(dry_run=True).metadata["sql"])