-
Notifications
You must be signed in to change notification settings - Fork 2
/
callback.py
30 lines (23 loc) · 1.04 KB
/
callback.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
"""Callback handlers used for the chat API"""
from typing import Any, Dict, List
from langchain.callbacks.base import AsyncCallbackHandler
from schemas import ChatResponse
class StreamingLLMCallbackHandler(AsyncCallbackHandler):
    """Callback handler for streaming LLM responses.

    Every new token produced by the LLM is wrapped in a ``ChatResponse``
    with ``sender="bot"`` and ``type="stream"`` and sent over the stored
    websocket as JSON.
    """

    def __init__(self, websocket: Any) -> None:
        """Remember the connection that tokens are pushed to.

        Args:
            websocket: Object exposing an awaitable ``send_json`` method
                (presumably a FastAPI/Starlette ``WebSocket`` -- confirm
                against the caller).
        """
        # Cooperative init: do not silently skip any setup the LangChain
        # base class (or a mixin in its MRO) performs in __init__.
        super().__init__()
        self.websocket = websocket

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Forward a single newly generated token to the client.

        Args:
            token: Token text emitted by the LLM.
            **kwargs: Extra callback data supplied by LangChain; unused here.
        """
        resp = ChatResponse(sender="bot", message=token, type="stream")
        await self.websocket.send_json(resp.dict())
class QuestionGenCallbackHandler(AsyncCallbackHandler):
    """Callback handler for question generation.

    When an LLM run starts, sends a fixed ``info`` ``ChatResponse``
    (``"Synthesizing question..."``) over the stored websocket as JSON.
    """

    def __init__(self, websocket: Any) -> None:
        """Remember the connection that status messages are pushed to.

        Args:
            websocket: Object exposing an awaitable ``send_json`` method
                (presumably a FastAPI/Starlette ``WebSocket`` -- confirm
                against the caller).
        """
        # Cooperative init: do not silently skip any setup the LangChain
        # base class (or a mixin in its MRO) performs in __init__.
        super().__init__()
        self.websocket = websocket

    async def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts running.

        Sends the constant ``"Synthesizing question..."`` info message to the
        client. ``serialized``, ``prompts`` and ``kwargs`` are accepted to
        satisfy the callback signature but are not used.
        """
        resp = ChatResponse(
            sender="bot", message="Synthesizing question...", type="info"
        )
        await self.websocket.send_json(resp.dict())