Skip to content

Commit

Permalink
other code
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin committed Nov 14, 2024
1 parent a5ab95c commit c7e6659
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 69 deletions.
14 changes: 13 additions & 1 deletion docs/examples/chat-app.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@ Demonstrates:

* reusing chat history
* serializing messages
* streaming responses

This demonstrates storing chat history between requests and using it to give the model context for new responses.

Most of the complex logic here is in `chat_app.html` which includes the page layout and JavaScript to handle the chat.
Most of the complex logic here is split between `chat_app.py`, which streams the response to the browser,
and `chat_app.ts`, which renders messages in the browser.

## Running the Example

Expand All @@ -27,10 +29,20 @@ TODO screenshot.

## Example Code

Python code that runs the chat app:

```py title="chat_app.py"
#! pydantic_ai_examples/chat_app.py
```

Simple HTML page to render the app:

```html title="chat_app.html"
#! pydantic_ai_examples/chat_app.html
```

TypeScript to handle rendering the messages. To keep this simple (and at the risk of offending frontend developers), the TypeScript code is passed to the browser as plain text and transpiled in the browser.

```ts title="chat_app.ts"
#! pydantic_ai_examples/chat_app.ts
```
5 changes: 5 additions & 0 deletions pydantic_ai/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from dataclasses import dataclass
from datetime import datetime
from typing import Generic, TypeVar, cast

import logfire_api
Expand Down Expand Up @@ -273,6 +274,10 @@ def cost(self) -> Cost:
"""
return self.cost_so_far + self._stream_response.cost()

def timestamp(self) -> datetime:
    """Get the timestamp of the response.

    Delegates to the wrapped stream-response object — presumably the time
    the model response was created; confirm against the stream response
    implementation.
    """
    return self._stream_response.timestamp()

async def validate_structured_result(
self, message: messages.ModelStructuredResponse, *, allow_partial: bool = False
) -> ResultData:
Expand Down
77 changes: 15 additions & 62 deletions pydantic_ai_examples/chat_app.html
Original file line number Diff line number Diff line change
Expand Up @@ -58,71 +58,24 @@ <h1>Chat App</h1>
</main>
</body>
</html>
<script src="https://cdnjs.cloudflare.com/ajax/libs/typescript/5.6.3/typescript.min.js" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
<script type="module">
import { marked } from 'https://cdn.jsdelivr.net/npm/marked/lib/marked.esm.js';

function addMessages(lines) {
const messages = lines.filter(line => line.length > 1).map((line) => JSON.parse(line))
const parent = document.getElementById('conversation');
for (const message of messages) {
let msgDiv = document.createElement('div');
msgDiv.classList.add('border-top', 'pt-2', message.role);
msgDiv.innerHTML = marked.parse(message.content);
parent.appendChild(msgDiv);
}
// To let me write TypeScript without adding the burden of npm, we do a dirty, non-production-ready hack
// and transpile the TypeScript code in the browser.
// This is (arguably) a neat demo trick, but not suitable for production!
// Fetch the raw TypeScript source from the server, transpile it with the
// browser-loaded TypeScript compiler, and inject the result into the page
// as a module script so it executes immediately.
async function loadTs() {
  const res = await fetch('/chat_app.ts');
  const tsSource = await res.text();
  const jsSource = window.ts.transpile(tsSource, { target: "es2015" });
  const scriptEl = document.createElement('script');
  scriptEl.type = 'module';
  scriptEl.text = jsSource;
  document.body.appendChild(scriptEl);
}

function onError(error) {
console.error(error);
loadTs().catch((e) => {
console.error(e);
document.getElementById('error').classList.remove('d-none');
document.getElementById('spinner').classList.remove('active');
}

// Consume a fetch() response whose body is newline-delimited JSON messages,
// rendering each complete line as it arrives and re-enabling the prompt
// input once the stream ends. Throws on a non-OK HTTP status.
async function fetchResponse(response) {
  // Buffer holding any partial (not yet newline-terminated) trailing line.
  let text = '';
  if (response.ok) {
    const reader = response.body.getReader();
    while (true) {
      const {done, value} = await reader.read();
      if (done) {
        break;
      }
      // NOTE(review): a fresh TextDecoder per chunk assumes no multi-byte
      // character spans a chunk boundary — verify, or decode with
      // {stream: true} on a single decoder instance.
      text += new TextDecoder().decode(value);
      const lines = text.split('\n');
      if (lines.length > 1) {
        // Render every complete line; keep the trailing fragment buffered.
        addMessages(lines.slice(0, -1));
        text = lines[lines.length - 1];
      }
    }
    // Stream finished: flush whatever remains in the buffer.
    addMessages(text.split('\n'));
    // Re-enable the prompt input so the user can send the next message.
    let input = document.getElementById('prompt-input')
    input.disabled = false;
    input.focus();
  } else {
    const text = await response.text();
    console.error(`Unexpected response: ${response.status}`, {response, text});
    throw new Error(`Unexpected response: ${response.status}`);
  }
}

// Handle chat-form submission: show the activity spinner, clear and disable
// the prompt input, POST the form data to /chat/, stream the reply into the
// page via fetchResponse, then hide the spinner again.
async function onSubmit(event) {
  event.preventDefault();
  const formData = new FormData(event.target);

  const spinnerEl = document.getElementById('spinner');
  spinnerEl.classList.add('active');

  const promptInput = document.getElementById('prompt-input');
  promptInput.value = '';
  promptInput.disabled = true;

  const resp = await fetch('/chat/', {method: 'POST', body: formData});
  await fetchResponse(resp);
  spinnerEl.classList.remove('active');
}

// Run onSubmit when the form is submitted (send button or Enter key);
// any failure is routed to onError.
document.querySelector('form').addEventListener('submit', (e) => onSubmit(e).catch(onError));

// On initial page load, fetch the stored conversation history and render it.
fetch('/chat/').then(fetchResponse).catch(onError);
});
</script>
27 changes: 21 additions & 6 deletions pydantic_ai_examples/chat_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@
from pydantic import Field, TypeAdapter

from pydantic_ai import Agent
from pydantic_ai.messages import Message, MessagesTypeAdapter, UserPrompt
from pydantic_ai.messages import (
Message,
MessagesTypeAdapter,
ModelTextResponse,
UserPrompt,
)

# 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured
logfire.configure(send_to_logfire='if-token-present')
Expand All @@ -32,6 +37,12 @@ async def index() -> HTMLResponse:
return HTMLResponse((THIS_DIR / 'chat_app.html').read_bytes())


@app.get('/chat_app.ts')
async def main_ts() -> Response:
    """Serve the raw TypeScript source as plain text; it is transpiled in the browser."""
    ts_source = (THIS_DIR / 'chat_app.ts').read_bytes()
    return Response(ts_source, media_type='text/plain')


@app.get('/chat/')
async def get_chat() -> Response:
msgs = database.get_messages()
Expand All @@ -49,12 +60,16 @@ async def stream_messages():
yield MessageTypeAdapter.dump_json(UserPrompt(content=prompt)) + b'\n'
# get the chat history so far to pass as context to the agent
messages = list(database.get_messages())
response = await agent.run(prompt, message_history=messages)
# run the agent with the user prompt and the chat history
async with agent.run_stream(prompt, message_history=messages) as result:
async for text in result.stream(debounce_by=0.01):
# text here is a `str` and the frontend wants
# JSON encoded ModelTextResponse, so we create one
m = ModelTextResponse(content=text, timestamp=result.timestamp())
yield MessageTypeAdapter.dump_json(m) + b'\n'

# add new messages (e.g. the user prompt and the agent response in this case) to the database
database.add_messages(response.new_messages_json())
# stream the last message which will be the agent response, we can't just yield `new_messages_json()`
# since we already stream the user prompt
yield MessageTypeAdapter.dump_json(response.all_messages()[-1]) + b'\n'
database.add_messages(result.new_messages_json())

return StreamingResponse(stream_messages(), media_type='text/plain')

Expand Down
3 changes: 3 additions & 0 deletions tests/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
)
from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel
from pydantic_ai.models.test import TestModel
from pydantic_ai.result import Cost
from tests.conftest import IsNow

pytestmark = pytest.mark.anyio
Expand Down Expand Up @@ -51,6 +52,8 @@ async def ret_a(x: str) -> str:
response = await result.get_data()
assert response == snapshot('{"ret_a":"a-apple"}')
assert result.is_complete
assert result.cost() == snapshot(Cost())
assert result.timestamp() == IsNow(tz=timezone.utc)
assert result.all_messages() == snapshot(
[
UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
Expand Down

0 comments on commit c7e6659

Please sign in to comment.