diff --git a/pyproject.toml b/pyproject.toml index f8685ff..e8f3d15 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -148,7 +148,7 @@ ban-relative-imports = "parents" [tool.ruff.lint.per-file-ignores] # Tests can use magic values, assertions, and relative imports -"tests/**/*" = ["PLR2004", "S101", "TID252"] +"tests/**/*" = ["PLR2004", "S101", "TID252", "E501"] [tool.coverage.run] diff --git a/src/banks/extensions/chat.py b/src/banks/extensions/chat.py index db6cb15..ea2df08 100644 --- a/src/banks/extensions/chat.py +++ b/src/banks/extensions/chat.py @@ -1,51 +1,15 @@ # SPDX-FileCopyrightText: 2023-present Massimiliano Pippi # # SPDX-License-Identifier: MIT -from html.parser import HTMLParser +import re from jinja2 import TemplateSyntaxError, nodes from jinja2.ext import Extension -from banks.types import ChatMessage, ChatMessageContent, ContentBlock, ContentBlockType +from banks.types import ChatMessage, ContentBlock, ContentBlockType SUPPORTED_TYPES = ("system", "user") - - -class _ContentBlockParser(HTMLParser): - """A parser used to extract text surrounded by `<content_block>` and `</content_block>` tags.""" - - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self._parse_block_content = False - self._content_blocks: list[ContentBlock] = [] - - @property - def content(self) -> ChatMessageContent: - """Returns ChatMessageContent data that can be directly assigned to ChatMessage.content. - - If only one block is present, this block is of type text and has no cache control set, we just - return it as plain text for simplicity. 
- """ - if len(self._content_blocks) == 1: - block = self._content_blocks[0] - if block.type == "text" and block.cache_control is None: - return block.text or "" - - return self._content_blocks - - def handle_starttag(self, tag, attrs): # noqa - if tag == "content_block": - self._parse_block_content = True - - def handle_endtag(self, tag): - if tag == "content_block": - self._parse_block_content = False - - def handle_data(self, data): - if self._parse_block_content: - self._content_blocks.append(ContentBlock.model_validate_json(data)) - else: - self._content_blocks.append(ContentBlock(type=ContentBlockType.text, text=data)) +CONTENT_BLOCK_REGEX = re.compile(r"<content_block>((?s:.)*)<\/content_block>") class ChatExtension(Extension): @@ -105,7 +69,19 @@ def _store_chat_messages(self, role, caller): """ Helper callback. """ - parser = _ContentBlockParser() - parser.feed(caller()) - cm = ChatMessage(role=role, content=parser.content) + content_blocks: list[ContentBlock] = [] + result = CONTENT_BLOCK_REGEX.match(caller()) + if result is not None: + for g in result.groups(): + content_blocks.append(ContentBlock.model_validate_json(g)) + else: + content_blocks.append(ContentBlock(type=ContentBlockType.text, text=caller())) + + content = content_blocks + if len(content_blocks) == 1: + block = content_blocks[0] + if block.type == "text" and block.cache_control is None: + content = block.text or "" + + cm = ChatMessage(role=role, content=content) return cm.model_dump_json(exclude_none=True) + "\n" diff --git a/tests/templates/chat.jinja b/tests/templates/chat.jinja index 95d7980..ababafd 100644 --- a/tests/templates/chat.jinja +++ b/tests/templates/chat.jinja @@ -3,7 +3,7 @@ You are a helpful assistant. {% endchat %} {% chat role="user" %} -Hello, how are you? +{{ "Hello, how are you?" | cache_control("ephemeral") }} {% endchat %} {% chat role="system" %} @@ -14,4 +14,4 @@ I'm doing well, thank you! How can I assist you today? Can you explain quantum computing? 
{% endchat %} -Some random text. \ No newline at end of file +Some random text. diff --git a/tests/test_chat.py b/tests/test_chat.py index 38a90c0..2941d75 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -2,8 +2,6 @@ from jinja2 import TemplateSyntaxError from banks import Prompt -from banks.extensions.chat import _ContentBlockParser -from banks.types import CacheControl, ContentBlock, ContentBlockType def test_wrong_tag(): @@ -19,43 +17,3 @@ def test_wrong_tag_params(): def test_wrong_role_type(): with pytest.raises(TemplateSyntaxError): Prompt('{% chat role="does not exist" %}{% endchat %}') - - -def test_content_block_parser_init(): - p = _ContentBlockParser() - assert p._parse_block_content is False - assert p._content_blocks == [] - - -def test_content_block_parser_single_with_cache_control(): - p = _ContentBlockParser() - p.feed( - '<content_block>{"type":"text","cache_control":{"type":"ephemeral"},"text":"foo","source":null}</content_block>' - ) - assert p.content == [ - ContentBlock(type=ContentBlockType.text, cache_control=CacheControl(type="ephemeral"), text="foo", source=None) - ] - - -def test_content_block_parser_single_no_cache_control(): - p = _ContentBlockParser() - p.feed('<content_block>{"type":"text","cache_control":null,"text":"foo","source":null}</content_block>') - assert p.content == "foo" - - -def test_content_block_parser_multiple(): - p = _ContentBlockParser() - p.feed( - '<content_block>{"type":"text","cache_control":null,"text":"foo","source":null}</content_block>' - '<content_block>{"type":"text","cache_control":null,"text":"bar","source":null}</content_block>' - ) - assert p.content == [ - ContentBlock(type=ContentBlockType.text, cache_control=None, text="foo", source=None), - ContentBlock(type=ContentBlockType.text, cache_control=None, text="bar", source=None), - ] - - -def test_content_block_parser_other_tags(): - p = _ContentBlockParser() - p.feed("FOO") - assert p.content == "FOO" diff --git a/tests/test_prompt.py b/tests/test_prompt.py index ace025d..9d9b554 100644 --- a/tests/test_prompt.py +++ b/tests/test_prompt.py @@ -8,6 +8,7 @@ from 
banks import AsyncPrompt, ChatMessage, Prompt from banks.cache import DefaultCache from banks.errors import AsyncError +from banks.types import CacheControl, ContentBlock, ContentBlockType def test_canary_word_generation(): @@ -90,7 +91,7 @@ def test_chat_messages(): == """ {"role":"system","content":"You are a helpful assistant.\\n"} -{"role":"user","content":"Hello, how are you?\\n"} +{"role":"user","content":[{"type":"text","cache_control":{"type":"ephemeral"},"text":"Hello, how are you?"}]} {"role":"system","content":"I'm doing well, thank you! How can I assist you today?\\n"} @@ -102,7 +103,19 @@ def test_chat_messages(): assert p.chat_messages() == [ ChatMessage(role="system", content="You are a helpful assistant.\n"), - ChatMessage(role="user", content="Hello, how are you?\n"), + ChatMessage( + role="user", + content=[ + ContentBlock( + type=ContentBlockType.text, + cache_control=CacheControl(type="ephemeral"), + text="Hello, how are you?", + image_url=None, + ) + ], + tool_call_id=None, + name=None, + ), ChatMessage(role="system", content="I'm doing well, thank you! How can I assist you today?\n"), ChatMessage(role="user", content="Can you explain quantum computing?\n"), ]