-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add meta and version to the Prompt class (#12)
* feat: add meta and version to the Prompt class * move away from positional args
- Loading branch information
Showing
3 changed files
with
82 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]> | ||
# | ||
# SPDX-License-Identifier: MIT | ||
from typing import Optional | ||
from typing import Any | ||
|
||
from .cache import DefaultCache, RenderCache | ||
from .config import config | ||
|
@@ -12,7 +12,13 @@ | |
|
||
class BasePrompt: | ||
def __init__( | ||
self, text: str, canary_word: Optional[str] = None, render_cache: Optional[RenderCache] = None | ||
self, | ||
text: str, | ||
*, | ||
version: str | None = None, | ||
metadata: dict[str, Any] | None = None, | ||
canary_word: str | None = None, | ||
render_cache: RenderCache | None = None, | ||
) -> None: | ||
""" | ||
Prompt constructor. | ||
|
@@ -24,30 +30,34 @@ def __init__( | |
render_cache: The caching backend to store rendered prompts. If `None`, the default in-memory backend will | ||
be used. | ||
""" | ||
self._render_cache = render_cache or DefaultCache() | ||
self._metadata = metadata or {} | ||
self._raw: str = text | ||
self._render_cache = render_cache or DefaultCache() | ||
self._template = env.from_string(text) | ||
self.defaults = {"canary_word": canary_word or generate_canary_word()} | ||
self._version = version | ||
|
||
def _cache_get(self, data: dict) -> Optional[str]: | ||
return self._render_cache.get(data) | ||
|
||
def _cache_set(self, data: dict, text: str) -> None: | ||
self._render_cache.set(data, text) | ||
self.defaults = {"canary_word": canary_word or generate_canary_word()} | ||
|
||
def _get_context(self, data: Optional[dict]) -> dict: | ||
def _get_context(self, data: dict | None) -> dict: | ||
if data is None: | ||
return self.defaults | ||
return data | self.defaults | ||
|
||
@property | ||
def metadata(self) -> dict[str, Any]: | ||
return self._metadata | ||
|
||
@property | ||
def raw(self) -> str: | ||
"""Returns the raw text of the prompt.""" | ||
return self._raw | ||
|
||
@property | ||
def version(self) -> str | None: | ||
return self._version | ||
|
||
def canary_leaked(self, text: str) -> bool: | ||
""" | ||
Returns whether the canary word is present in `text`, signalling the prompt might have leaked. | ||
""" | ||
"""Returns whether the canary word is present in `text`, signalling the prompt might have leaked.""" | ||
return self.defaults["canary_word"] in text | ||
|
||
|
||
|
@@ -68,20 +78,20 @@ class Prompt(BasePrompt): | |
``` | ||
""" | ||
|
||
def text(self, data: Optional[dict] = None) -> str: | ||
def text(self, data: dict[str, Any] | None = None) -> str: | ||
""" | ||
Render the prompt using variables present in `data` | ||
Parameters: | ||
data: A dictionary containing the context variables. | ||
""" | ||
data = self._get_context(data) | ||
cached = self._cache_get(data) | ||
cached = self._render_cache.get(data) | ||
if cached: | ||
return cached | ||
|
||
rendered: str = self._template.render(data) | ||
self._cache_set(data, rendered) | ||
self._render_cache.set(data, rendered) | ||
return rendered | ||
|
||
|
||
|
@@ -146,19 +156,25 @@ async def main(): | |
``` | ||
""" | ||
|
||
def __init__(self, text: str) -> None: | ||
super().__init__(text) | ||
def __init__(self, *args: Any, **kwargs: Any) -> None: | ||
super().__init__(*args, **kwargs) | ||
|
||
if not config.ASYNC_ENABLED: | ||
msg = "Async is not enabled. Please set the environment variable 'BANKS_ASYNC_ENABLED=on' and try again." | ||
raise AsyncError(msg) | ||
|
||
async def text(self, data: Optional[dict] = None) -> str: | ||
async def text(self, data: dict[str, Any] | None = None) -> str: | ||
""" | ||
Render the prompt using variables present in `data` | ||
Parameters: | ||
data: A dictionary containing the context variables. | ||
""" | ||
data = self._get_context(data) | ||
cached = self._cache_get(data) | ||
cached = self._render_cache.get(data) | ||
if cached: | ||
return cached | ||
|
||
rendered: str = await self._template.render_async(data) | ||
self._cache_set(data, rendered) | ||
self._render_cache.set(data, rendered) | ||
return rendered |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters