Skip to content

Commit

Permalink
EXPERIMENTAL: reworked context injection such it is handled immediate…
Browse files Browse the repository at this point in the history
…ly in 'structured_view.ask()' and than stored in 'ExposedFunction' instances
  • Loading branch information
ds-jakub-cierocki committed Jul 15, 2024
1 parent a154577 commit 623effd
Show file tree
Hide file tree
Showing 14 changed files with 97 additions and 68 deletions.
4 changes: 2 additions & 2 deletions src/dbally/collection/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from dbally.audit.events import RequestEnd, RequestStart
from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError
from dbally.collection.results import ExecutionResult
from dbally.context.context import CustomContext
from dbally.context.context import BaseCallerContext
from dbally.llms.base import LLM
from dbally.llms.clients.base import LLMOptions
from dbally.nl_responder.nl_responder import NLResponder
Expand Down Expand Up @@ -157,7 +157,7 @@ async def ask(
dry_run: bool = False,
return_natural_response: bool = False,
llm_options: Optional[LLMOptions] = None,
contexts: Optional[Iterable[CustomContext]] = None,
contexts: Optional[Iterable[BaseCallerContext]] = None,
) -> ExecutionResult:
"""
Ask question in a text form and retrieve the answer based on the available views.
Expand Down
22 changes: 10 additions & 12 deletions src/dbally/context/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@
from abc import ABC
from typing import ClassVar, Iterable

from typing_extensions import Self, TypeAlias
from typing_extensions import Self

from dbally.context.exceptions import ContextNotAvailableError

CustomContext: TypeAlias = "BaseCallerContext"
from dbally.context.exceptions import BaseContextError


class BaseCallerContext(ABC):
Expand All @@ -23,7 +21,7 @@ class BaseCallerContext(ABC):
alias: ClassVar[str] = "AskerContext"

@classmethod
def select_context(cls, contexts: Iterable[CustomContext]) -> Self:
def select_context(cls, contexts: Iterable["BaseCallerContext"]) -> Self:
"""
Typically called from a subclass of BaseCallerContext, selects a member object from `contexts` being
an instance of the same class. Effectively provides a type dispatch mechanism, substituting the context
Expand All @@ -36,17 +34,17 @@ class by its right instance.
An instance of the same BaseCallerContext subclass this method is caller from.
Raises:
ContextNotAvailableError: If the sequence of context objects passed as argument is empty.
BaseContextError: If no element in `contexts` matches `cls` class.
"""

if not contexts:
raise ContextNotAvailableError(
"The LLM detected that the context is required to execute the query"
"and the filter signature allows contextualization while the context was not provided."
)
try:
selected_context = next(filter(lambda obj: isinstance(obj, cls), contexts))
except StopIteration as e:
# this custom exception provides more clear message what have just gone wrong
raise BaseContextError() from e

# TODO confirm whether it is possible to design a correct type hints here and skipping `type: ignore`
return next(filter(lambda obj: isinstance(obj, cls), contexts)) # type: ignore
return selected_context # type: ignore

@classmethod
def is_context_call(cls, node: ast.expr) -> bool:
Expand Down
24 changes: 21 additions & 3 deletions src/dbally/context/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,23 @@
class ContextNotAvailableError(Exception):
class BaseContextError(Exception):
"""
An exception inheriting from BaseContextException pointining that no sufficient context information
was provided by the user while calling view.ask().
A base error for context handling logic.
"""


class SuitableContextNotProvidedError(BaseContextError):
"""
Raised when method argument type hint points that a contextualization is available
but not suitable context was provided.
"""

def __init__(self, filter_fun_signature: str, context_class_name: str) -> None:
# this syntax 'or BaseCallerContext' is just to prevent type checkers
# from raising a warning, as filter_.context_class can be None. It's essenially a fallback that should never
# be reached, unless somebody will use this Exception against its purpose.
# TODO consider raising a warning/error when this happens.

message = (
f"No context of class {context_class_name} was provided"
f"while the filter {filter_fun_signature} requires it."
)
super().__init__(message)
13 changes: 3 additions & 10 deletions src/dbally/iql/_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from dbally.audit.event_tracker import EventTracker
from dbally.context._utils import _does_arg_allow_context
from dbally.context.context import BaseCallerContext, CustomContext
from dbally.context.context import BaseCallerContext
from dbally.iql import syntax
from dbally.iql._exceptions import (
IQLArgumentParsingError,
Expand All @@ -23,21 +23,17 @@ class IQLProcessor:
Attributes:
source: Raw LLM response containing IQL filter calls.
allowed_functions: A mapping (typically a dict) of all filters implemented for a certain View.
contexts: A sequence (typically a list) of context objects, each being an instance of
a subclass of BaseCallerContext. May contain contexts irrelevant for the currently processed query.
allowed_functions: A mapping (typically a dict) of all filters implemented for a certain View.=
"""

source: str
allowed_functions: Mapping[str, "ExposedFunction"]
contexts: Iterable[CustomContext]
_event_tracker: EventTracker

def __init__(
self,
source: str,
allowed_functions: Iterable[ExposedFunction],
contexts: Optional[Iterable[CustomContext]] = None,
event_tracker: Optional[EventTracker] = None,
) -> None:
"""
Expand All @@ -46,14 +42,11 @@ def __init__(
Args:
source: Raw LLM response containing IQL filter calls.
allowed_functions: An interable (typically a list) of all filters implemented for a certain View.
contexts: An iterable (typically a list) of context objects, each being an instance of
a subclass of BaseCallerContext.
even_tracker: An EvenTracker instance.
"""

self.source = source
self.allowed_functions = {func.name: func for func in allowed_functions}
self.contexts = contexts or []
self._event_tracker = event_tracker or EventTracker()

async def process(self) -> syntax.Node:
Expand Down Expand Up @@ -148,7 +141,7 @@ def _parse_arg(
if not _does_arg_allow_context(arg_spec):
raise IQLContextNotAllowedError(arg, self.source, arg_name=arg_spec.name)

return parent_func_def.context_class.select_context(self.contexts)
return parent_func_def.context

if not isinstance(arg, ast.Constant):
raise IQLArgumentParsingError(arg, self.source)
Expand Down
15 changes: 4 additions & 11 deletions src/dbally/iql/_query.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from typing import TYPE_CHECKING, Iterable, List, Optional
from typing import TYPE_CHECKING, List, Optional

from typing_extensions import Self

from dbally.context.context import CustomContext

from ..audit.event_tracker import EventTracker
from . import syntax
from ._processor import IQLProcessor
Expand All @@ -28,11 +26,7 @@ def __str__(self) -> str:

@classmethod
async def parse(
cls,
source: str,
allowed_functions: List["ExposedFunction"],
event_tracker: Optional[EventTracker] = None,
contexts: Optional[Iterable[CustomContext]] = None,
cls, source: str, allowed_functions: List["ExposedFunction"], event_tracker: Optional[EventTracker] = None
) -> Self:
"""
Parse IQL string to IQLQuery object.
Expand All @@ -41,11 +35,10 @@ async def parse(
source: IQL string that needs to be parsed
allowed_functions: list of IQL functions that are allowed for this query
event_tracker: EventTracker object to track events
contexts: An iterable (typically a list) of context objects, each being
an instance of a subclass of BaseCallerContext.
Returns:
IQLQuery object
"""

root = await IQLProcessor(source, allowed_functions, contexts, event_tracker).process()
root = await IQLProcessor(source, allowed_functions, event_tracker).process()
return cls(root=root, source=source)
2 changes: 1 addition & 1 deletion src/dbally/iql/_type_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def validate_arg_type(required_type: Union[Type, _GenericAlias], value: Any) ->
actual_type = type_ext.get_origin(required_type) if isinstance(required_type, _GenericAlias) else required_type
# typing.Union is an instance of _GenericAlias
if actual_type is None:
# workaround to prevent type warning in line `if isisntanc(value, actual_type):`, TODO check whether necessary
# workaround to prevent type warning in line `if isisntance(value, actual_type):`, TODO check whether necessary
actual_type = required_type.__origin__

if actual_type is Union:
Expand Down
10 changes: 2 additions & 8 deletions src/dbally/iql_generator/iql_generator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Iterable, List, Optional
from typing import List, Optional

from dbally.audit.event_tracker import EventTracker
from dbally.context.context import CustomContext
from dbally.iql import IQLError, IQLQuery
from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, IQLGenerationPromptFormat
from dbally.llms.base import LLM
Expand Down Expand Up @@ -43,7 +42,6 @@ async def generate_iql(
examples: Optional[List[FewShotExample]] = None,
llm_options: Optional[LLMOptions] = None,
n_retries: int = 3,
contexts: Optional[Iterable[CustomContext]] = None,
) -> IQLQuery:
"""
Generates IQL in text form using LLM.
Expand All @@ -55,8 +53,6 @@ async def generate_iql(
examples: List of examples to be injected into the conversation.
llm_options: Options to use for the LLM client.
n_retries: Number of retries to regenerate IQL in case of errors.
contexts: An iterable (typically a list) of context objects, each being
an instance of a subclass of BaseCallerContext.
Returns:
Generated IQL query.
Expand All @@ -78,9 +74,7 @@ async def generate_iql(
# TODO: Move response parsing to llm generate_text method
iql = formatted_prompt.response_parser(response)
# TODO: Move IQL query parsing to prompt response parser
return await IQLQuery.parse(
source=iql, allowed_functions=filters, event_tracker=event_tracker, contexts=contexts
)
return await IQLQuery.parse(source=iql, allowed_functions=filters, event_tracker=event_tracker)
except IQLError as exc:
# TODO handle the possibility of variable `response` being not initialized
# while runnning the following line
Expand Down
8 changes: 4 additions & 4 deletions src/dbally/views/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from dbally.audit.event_tracker import EventTracker
from dbally.collection.results import ViewExecutionResult
from dbally.context.context import CustomContext
from dbally.context.context import BaseCallerContext
from dbally.llms.base import LLM
from dbally.llms.clients.base import LLMOptions
from dbally.prompt.elements import FewShotExample
Expand All @@ -29,7 +29,7 @@ async def ask(
n_retries: int = 3,
dry_run: bool = False,
llm_options: Optional[LLMOptions] = None,
contexts: Optional[Iterable[CustomContext]] = None,
contexts: Optional[Iterable[BaseCallerContext]] = None,
) -> ViewExecutionResult:
"""
Executes the query and returns the result.
Expand Down Expand Up @@ -59,9 +59,9 @@ def list_similarity_indexes(self) -> Dict[AbstractSimilarityIndex, List[IndexLoc

def list_few_shots(self) -> List[FewShotExample]:
"""
List all examples to be injected into few-shot prompt.
Lists all examples to be injected into few-shot prompt.
Returns:
List of few-shot examples
List of few-shot examples.
"""
return []
23 changes: 22 additions & 1 deletion src/dbally/views/exposed_functions.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from dataclasses import dataclass
from inspect import isclass
from typing import _GenericAlias # type: ignore
from typing import Generator, Optional, Sequence, Type, Union
from typing import Generator, Iterable, Optional, Sequence, Type, Union

import typing_extensions as type_ext

from dbally.context.context import BaseCallerContext
from dbally.context.exceptions import BaseContextError, SuitableContextNotProvidedError
from dbally.similarity import AbstractSimilarityIndex


Expand Down Expand Up @@ -127,6 +128,7 @@ class ExposedFunction:
description: str
parameters: Sequence[MethodParamWithTyping]
context_class: Optional[Type[BaseCallerContext]] = None
context: Optional[BaseCallerContext] = None

def __str__(self) -> str:
base_str = f"{self.name}({', '.join(str(param) for param in self.parameters)})"
Expand All @@ -135,3 +137,22 @@ def __str__(self) -> str:
return f"{base_str} - {self.description}"

return base_str

def inject_context(self, contexts: Iterable[BaseCallerContext]) -> None:
"""
Inserts reference to the member of `contexts` of the proper class in self.context.
Args:
contexts: An iterable of user-provided context objects.
Raises:
SuitableContextNotProvidedError: Ff no element in `contexts` matches `self.context_class`.
"""

if self.context_class is None:
return

try:
self.context = self.context_class.select_context(contexts)
except BaseContextError as e:
raise SuitableContextNotProvidedError(str(self), self.context_class.__name__) from e
4 changes: 2 additions & 2 deletions src/dbally/views/freeform/text2sql/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from dbally.audit.event_tracker import EventTracker
from dbally.collection.results import ViewExecutionResult
from dbally.context.context import CustomContext
from dbally.context.context import BaseCallerContext
from dbally.llms.base import LLM
from dbally.llms.clients.base import LLMOptions
from dbally.prompt.template import PromptTemplate
Expand Down Expand Up @@ -104,7 +104,7 @@ async def ask(
n_retries: int = 3,
dry_run: bool = False,
llm_options: Optional[LLMOptions] = None,
contexts: Optional[Iterable[CustomContext]] = None,
contexts: Optional[Iterable[BaseCallerContext]] = None,
) -> ViewExecutionResult:
"""
Executes the query and returns the result. It generates the SQL query from the natural language query and
Expand Down
23 changes: 20 additions & 3 deletions src/dbally/views/structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from dbally.audit.event_tracker import EventTracker
from dbally.collection.results import ViewExecutionResult
from dbally.context.context import CustomContext
from dbally.context.context import BaseCallerContext
from dbally.iql import IQLQuery
from dbally.iql_generator.iql_generator import IQLGenerator
from dbally.llms.base import LLM
Expand Down Expand Up @@ -33,6 +33,22 @@ def get_iql_generator(self, llm: LLM) -> IQLGenerator:
"""
return IQLGenerator(llm=llm)

@classmethod
def contextualize_filters(
cls, filters: Iterable[ExposedFunction], contexts: Optional[Iterable[BaseCallerContext]]
) -> None:
"""
Updates a list of filters packed as ExposedFunction's by ingesting the matching context objects.
Args:
filters: An iterable of filters.
contexts: An iterable of context objects.
"""

contexts = contexts or []
for filter_ in filters:
filter_.inject_context(contexts)

async def ask(
self,
query: str,
Expand All @@ -41,7 +57,7 @@ async def ask(
n_retries: int = 3,
dry_run: bool = False,
llm_options: Optional[LLMOptions] = None,
contexts: Optional[Iterable[CustomContext]] = None,
contexts: Optional[Iterable[BaseCallerContext]] = None,
) -> ViewExecutionResult:
"""
Executes the query and returns the result. It generates the IQL query from the natural language query\
Expand All @@ -65,14 +81,15 @@ async def ask(
filters = self.list_filters()
examples = self.list_few_shots()

self.contextualize_filters(filters, contexts)

iql = await iql_generator.generate_iql(
question=query,
filters=filters,
examples=examples,
event_tracker=event_tracker,
llm_options=llm_options,
n_retries=n_retries,
contexts=contexts,
)

await self.apply_filters(iql)
Expand Down
Loading

0 comments on commit 623effd

Please sign in to comment.