Skip to content

Commit

Permalink
feat: add meta and version to the Prompt class (#12)
Browse files Browse the repository at this point in the history
* feat: add meta and version to the Prompt class

* move away from positional args
  • Loading branch information
masci authored Sep 28, 2024
1 parent 6210a8d commit b9d42f4
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 23 deletions.
58 changes: 37 additions & 21 deletions src/banks/prompt.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]>
#
# SPDX-License-Identifier: MIT
from typing import Optional
from typing import Any

from .cache import DefaultCache, RenderCache
from .config import config
Expand All @@ -12,7 +12,13 @@

class BasePrompt:
def __init__(
    self,
    text: str,
    *,
    version: str | None = None,
    metadata: dict[str, Any] | None = None,
    canary_word: str | None = None,
    render_cache: RenderCache | None = None,
) -> None:
    """
    Prompt constructor.

    Parameters:
        text: The template text of the prompt, compiled with the module's Jinja environment.
        version: An optional version string attached to the prompt.
        metadata: Optional, arbitrary key/value pairs attached to the prompt.
        canary_word: The canary word stored in the render defaults. If `None`, a random one is generated.
        render_cache: The caching backend to store rendered prompts. If `None`, the default in-memory backend will
            be used.
    """
    self._metadata = metadata or {}
    self._raw: str = text
    self._render_cache = render_cache or DefaultCache()
    self._template = env.from_string(text)
    self._version = version

    # `defaults` is merged into every render context (see `_get_context`).
    self.defaults = {"canary_word": canary_word or generate_canary_word()}

def _get_context(self, data: Optional[dict]) -> dict:
def _get_context(self, data: dict | None) -> dict:
if data is None:
return self.defaults
return data | self.defaults

@property
def metadata(self) -> dict[str, Any]:
    """Key/value pairs attached to this prompt at construction time."""
    return self._metadata

@property
def raw(self) -> str:
    """The raw, unrendered template text of the prompt."""
    return self._raw

@property
def version(self) -> str | None:
return self._version

def canary_leaked(self, text: str) -> bool:
    """Returns whether the canary word is present in `text`, signalling the prompt might have leaked."""
    return self.defaults["canary_word"] in text


Expand All @@ -68,20 +78,20 @@ class Prompt(BasePrompt):
```
"""

def text(self, data: Optional[dict] = None) -> str:
def text(self, data: dict[str, Any] | None = None) -> str:
"""
Render the prompt using variables present in `data`
Parameters:
data: A dictionary containing the context variables.
"""
data = self._get_context(data)
cached = self._cache_get(data)
cached = self._render_cache.get(data)
if cached:
return cached

rendered: str = self._template.render(data)
self._cache_set(data, rendered)
self._render_cache.set(data, rendered)
return rendered


Expand Down Expand Up @@ -146,19 +156,25 @@ async def main():
```
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
    """
    Build the prompt, forwarding all arguments to `BasePrompt`.

    Raises:
        AsyncError: If async template support is not enabled via configuration.
    """
    super().__init__(*args, **kwargs)

    if not config.ASYNC_ENABLED:
        msg = "Async is not enabled. Please set the environment variable 'BANKS_ASYNC_ENABLED=on' and try again."
        raise AsyncError(msg)

async def text(self, data: Optional[dict] = None) -> str:
async def text(self, data: dict[str, Any] | None = None) -> str:
"""
Render the prompt using variables present in `data`
Parameters:
data: A dictionary containing the context variables.
"""
data = self._get_context(data)
cached = self._cache_get(data)
cached = self._render_cache.get(data)
if cached:
return cached

rendered: str = await self._template.render_async(data)
self._cache_set(data, rendered)
self._render_cache.set(data, rendered)
return rendered
1 change: 0 additions & 1 deletion src/banks/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ def set(
*,
name: str,
prompt: Prompt,
meta: dict | None = None,
version: str | None = None,
overwrite: bool = False,
): ...
46 changes: 45 additions & 1 deletion tests/test_prompt.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from unittest import mock

import pytest
import regex as re
from jinja2 import Environment

from banks import Prompt
from banks import AsyncPrompt, Prompt
from banks.cache import DefaultCache
from banks.errors import AsyncError


def test_canary_word_generation():
Expand All @@ -25,3 +28,44 @@ def test_prompt_cache():
p = Prompt("This is my prompt", render_cache=mock_cache)
p.text()
mock_cache.set.assert_called_once()


def test_ctor():
    """All keyword arguments given to the constructor are exposed back by the prompt."""
    prompt = Prompt(
        text="This is raw text",
        version="1.0",
        metadata={"LLM": "GPT-3.5"},
        canary_word="FOO",
        render_cache=DefaultCache(),
    )
    assert prompt.raw == "This is raw text"
    assert prompt.version == "1.0"
    assert prompt.metadata == {"LLM": "GPT-3.5"}
    assert prompt.canary_leaked("The message is FOO")
    # Render twice: the second call exercises the render-cache path.
    assert prompt.text() == "This is raw text"
    assert prompt.text() == "This is raw text"


def test_ctor_async_disabled():
    """Constructing an AsyncPrompt while async support is off must raise AsyncError."""
    with mock.patch("banks.prompt.config", ASYNC_ENABLED=False), pytest.raises(AsyncError):
        AsyncPrompt(text="This is raw text")


@pytest.mark.asyncio
async def test_ctor_async():
    """AsyncPrompt exposes ctor arguments and renders via the async template path."""
    with mock.patch("banks.prompt.config", ASYNC_ENABLED=True):
        subject = AsyncPrompt(
            text="This is raw text",
            version="1.0",
            metadata={"LLM": "GPT-3.5"},
            canary_word="FOO",
            render_cache=DefaultCache(),
        )
        # Swap in an async-enabled Jinja environment so render_async works.
        subject._template = Environment(autoescape=True, enable_async=True).from_string(subject.raw)
        assert subject.raw == "This is raw text"
        assert subject.version == "1.0"
        assert subject.metadata == {"LLM": "GPT-3.5"}
        assert subject.canary_leaked("The message is FOO")
        # Render twice: the second call exercises the render-cache path.
        assert await subject.text() == "This is raw text"
        assert await subject.text() == "This is raw text"

0 comments on commit b9d42f4

Please sign in to comment.