Add new attribute 'citations' to ConversationMessage type #193

Open · wants to merge 7 commits into base: main
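This PR threads retriever citations through a new `ConversationMessageMetadata` type, but the type's own definition (in the `types` module) is not visible in this diff view. As a reading aid, here is a minimal sketch inferred from the call sites below (`ConversationMessageMetadata()`, `.citations.extend(...)`, and the `usage=`/`metrics=` keyword arguments); the field names and types are assumptions, not the merged code:

from dataclasses import dataclass, field
from typing import Any, Optional

# Hypothetical reconstruction, inferred from how the PR constructs and
# mutates the type; the real definition is not part of this diff.
@dataclass
class ConversationMessageMetadata:
    citations: list[str] = field(default_factory=list)  # retriever source URIs
    usage: Optional[dict[str, Any]] = None               # token usage, when the provider reports it
    metrics: Optional[dict[str, Any]] = None             # latency metrics (Bedrock converse API)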
5 changes: 4 additions & 1 deletion python/src/multi_agent_orchestrator/agents/agent.py
@@ -23,10 +23,13 @@ class AgentResponse:


class AgentCallbacks:
-    def on_llm_new_token(self, token: str) -> None:
+    def on_llm_new_token(self, message: ConversationMessage) -> None:
        # Default implementation
        pass

+    def on_llm_end(self, message: ConversationMessage) -> None:
+        # Default implementation
+        pass

@dataclass
class AgentOptions:
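The practical effect of the callback change: streaming consumers now receive structured `ConversationMessage` objects instead of raw strings, and `on_llm_end` fires once with the final message. A sketch of a subscriber under the new API (the class name and print-based handling are illustrative, not part of this PR):

from multi_agent_orchestrator.agents import AgentCallbacks
from multi_agent_orchestrator.types import ConversationMessage

class PrintingCallbacks(AgentCallbacks):
    def on_llm_new_token(self, message: ConversationMessage) -> None:
        # Tokens now arrive wrapped in a ConversationMessage whose content
        # is a list of blocks like [{'text': '...'}].
        for block in message.content:
            print(block.get('text', ''), end='', flush=True)

    def on_llm_end(self, message: ConversationMessage) -> None:
        # Called once per streamed response; metadata carries the
        # citations this PR attaches from the retriever.
        if message.metadata and message.metadata.citations:
            print('\nSources:', ', '.join(message.metadata.citations))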
@@ -138,7 +138,12 @@ async def process_request(
decoded_response = chunk['bytes'].decode('utf-8')

                    # Trigger callback for each token (useful for real-time updates)
-                    self.callbacks.on_llm_new_token(decoded_response)
+                    self.callbacks.on_llm_new_token(
+                        ConversationMessage(
+                            role=ParticipantRole.ASSISTANT.value,
+                            content=[{'text': decoded_response}]
+                        )
+                    )
completion += decoded_response

elif 'trace' in event:
9 changes: 7 additions & 2 deletions python/src/multi_agent_orchestrator/agents/anthropic_agent.py
@@ -118,7 +118,7 @@ async def process_request(

if self.retriever:
response = await self.retriever.retrieve_and_combine_results(input_text)
-            context_prompt = f"\nHere is the context to use to answer the user's question:\n{response}"
+            context_prompt = f"\nHere is the context to use to answer the user's question:\n{response['text']}"
system_prompt += context_prompt

input = {
@@ -205,7 +205,12 @@ async def handle_streaming_response(self, input) -> Any:
async with self.client.messages.stream(**input) as stream:
async for event in stream:
if event.type == "text":
-                    self.callbacks.on_llm_new_token(event.text)
+                    self.callbacks.on_llm_new_token(
+                        ConversationMessage(
+                            role=ParticipantRole.ASSISTANT.value,
+                            content=[{'text': event.text}]
+                        )
+                    )
elif event.type == "input_json":
message['input'] = json.loads(event.partial_json)
elif event.type == "content_block_stop":
70 changes: 61 additions & 9 deletions python/src/multi_agent_orchestrator/agents/bedrock_llm_agent.py
@@ -5,14 +5,15 @@
import os
import boto3
from multi_agent_orchestrator.agents import Agent, AgentOptions
-from multi_agent_orchestrator.types import (ConversationMessage,
+from multi_agent_orchestrator.types import (ConversationMessage, ConversationMessageMetadata,
ParticipantRole,
BEDROCK_MODEL_ID_CLAUDE_3_HAIKU,
TemplateVariables,
AgentProviderType)
from multi_agent_orchestrator.utils import conversation_to_dict, Logger, AgentTools
from multi_agent_orchestrator.retrievers import Retriever

+import traceback

@dataclass
class BedrockLLMAgentOptions(AgentOptions):
@@ -114,11 +115,13 @@ async def process_request(
self.update_system_prompt()

system_prompt = self.system_prompt
+        citations = []

if self.retriever:
response = await self.retriever.retrieve_and_combine_results(input_text)
-            context_prompt = "\nHere is the context to use to answer the user's question:\n" + response
+            context_prompt = "\nHere is the context to use to answer the user's question:\n" + response['text']
system_prompt += context_prompt
+            citations = response['sources']

converse_cmd = {
'modelId': self.model_id,
@@ -151,6 +154,17 @@
else:
bedrock_response = await self.handle_single_response(converse_cmd)

+        if citations:
+            if not bedrock_response.metadata:
+                bedrock_response.metadata = ConversationMessageMetadata()
+
+            bedrock_response.metadata.citations.extend(citations)
+
+        if self.streaming:
+            self.callbacks.on_llm_end(bedrock_response)

conversation.append(bedrock_response)

if any('toolUse' in content for content in bedrock_response.content):
@@ -172,18 +186,36 @@ async def process_request(
return final_message

        if self.streaming:
-            return await self.handle_streaming_response(converse_cmd)
+            converse_message = await self.handle_streaming_response(converse_cmd)
+        else:
+            converse_message = await self.handle_single_response(converse_cmd)

-        return await self.handle_single_response(converse_cmd)
+        if citations:
+            if not converse_message.metadata:
+                converse_message.metadata = ConversationMessageMetadata()
+
+            converse_message.metadata.citations.extend(citations)
+
+        if self.streaming:
+            self.callbacks.on_llm_end(converse_message)
+
+        return converse_message

async def handle_single_response(self, converse_input: dict[str, Any]) -> ConversationMessage:
try:
response = self.client.converse(**converse_input)
if 'output' not in response:
raise ValueError("No output received from Bedrock model")

return ConversationMessage(
-                role=response['output']['message']['role'],
-                content=response['output']['message']['content']
+                role=ParticipantRole.ASSISTANT.value,
+                content=response['output']['message']['content'],
+                metadata=ConversationMessageMetadata(
+                    usage=response['usage'],
+                    metrics=response['metrics']
+                )
)
except Exception as error:
Logger.error(f"Error invoking Bedrock model:{str(error)}")
@@ -196,26 +228,37 @@ async def handle_streaming_response(self, converse_input: dict[str, Any]) -> Con
message = {}
content = []
message['content'] = content
+            message['metadata'] = None
text = ''
tool_use = {}

#stream the response into a message.
for chunk in response['stream']:

if 'messageStart' in chunk:
message['role'] = chunk['messageStart']['role']

elif 'contentBlockStart' in chunk:
tool = chunk['contentBlockStart']['start']['toolUse']
tool_use['toolUseId'] = tool['toolUseId']
tool_use['name'] = tool['name']

elif 'contentBlockDelta' in chunk:
delta = chunk['contentBlockDelta']['delta']

if 'toolUse' in delta:
if 'input' not in tool_use:
tool_use['input'] = ''
tool_use['input'] += delta['toolUse']['input']

elif 'text' in delta:
text += delta['text']
-                        self.callbacks.on_llm_new_token(delta['text'])
+                        self.callbacks.on_llm_new_token(
+                            ConversationMessage(
+                                role=ParticipantRole.ASSISTANT.value,
+                                content=[{'text': delta['text']}]
+                            )
+                        )
elif 'contentBlockStop' in chunk:
if 'input' in tool_use:
tool_use['input'] = json.loads(tool_use['input'])
@@ -224,12 +267,21 @@ async def handle_streaming_response(self, converse_input: dict[str, Any]) -> Con
else:
content.append({'text': text})
text = ''

+                elif 'metadata' in chunk:
+                    message['metadata'] = ConversationMessageMetadata(
+                        usage=chunk['metadata']['usage'],
+                        metrics=chunk['metadata']['metrics']
+                    )

            return ConversationMessage(
-                role=ParticipantRole.ASSISTANT.value,
-                content=message['content']
+                **message
)

except Exception as error:
+            traceback.print_exc()
Logger.error(f"Error getting stream from Bedrock model: {str(error)}")
raise error

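Taken together, the BedrockLLMAgent changes let a caller read citations directly off the returned message. A hypothetical end-to-end use, assuming the orchestrator's usual `process_request(input_text, user_id, session_id, chat_history)` signature and invented sample values:

response = await agent.process_request(
    input_text="What is our refund policy?",
    user_id="user-123",
    session_id="session-456",
    chat_history=[],
)

# metadata is only populated when a retriever supplied sources
if response.metadata:
    for uri in response.metadata.citations:
        print("cited:", uri)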
62 changes: 48 additions & 14 deletions python/src/multi_agent_orchestrator/agents/openai_agent.py
@@ -4,6 +4,7 @@
from multi_agent_orchestrator.agents import Agent, AgentOptions
from multi_agent_orchestrator.types import (
ConversationMessage,
+    ConversationMessageMetadata,
ParticipantRole,
OPENAI_MODEL_ID_GPT_O_MINI,
TemplateVariables
@@ -28,15 +29,15 @@ class OpenAIAgentOptions(AgentOptions):
class OpenAIAgent(Agent):
def __init__(self, options: OpenAIAgentOptions):
super().__init__(options)
-        if not options.api_key:
-            raise ValueError("OpenAI API key is required")

        if options.client:
            self.client = options.client
        else:
+            if not options.api_key:
+                raise ValueError("OpenAI API key is required")
+
            self.client = OpenAI(api_key=options.api_key)


self.model = options.model or OPENAI_MODEL_ID_GPT_O_MINI
self.streaming = options.streaming or False
self.retriever: Optional[Retriever] = options.retriever
@@ -83,7 +84,7 @@ def __init__(self, options: OpenAIAgentOptions):
options.custom_system_prompt.get('template'),
options.custom_system_prompt.get('variables')
)



def is_streaming_enabled(self) -> bool:
@@ -102,11 +103,13 @@ async def process_request(
self.update_system_prompt()

system_prompt = self.system_prompt
+            citations = None

if self.retriever:
response = await self.retriever.retrieve_and_combine_results(input_text)
-                context_prompt = "\nHere is the context to use to answer the user's question:\n" + response
+                context_prompt = "\nHere is the context to use to answer the user's question:\n" + response['text']
system_prompt += context_prompt
+                citations = response['sources']


messages = [
@@ -118,7 +121,6 @@
{"role": "user", "content": input_text}
]


request_options = {
"model": self.model,
"messages": messages,
@@ -128,10 +130,24 @@
"stop": self.inference_config.get('stopSequences'),
"stream": self.streaming
}

if self.streaming:
-                return await self.handle_streaming_response(request_options)
+                converse_message = await self.handle_streaming_response(request_options)
            else:
-                return await self.handle_single_response(request_options)
+                converse_message = await self.handle_single_response(request_options)

+            if citations:
+                if not converse_message.metadata:
+                    converse_message.metadata = ConversationMessageMetadata()
+
+                converse_message.metadata.citations.extend(citations)
+
+            if self.streaming:
+                self.callbacks.on_llm_end(converse_message)
+
+            return converse_message

except Exception as error:
Logger.error(f"Error in OpenAI API call: {str(error)}")
@@ -152,7 +168,11 @@ async def handle_single_response(self, request_options: Dict[str, Any]) -> Conve

return ConversationMessage(
role=ParticipantRole.ASSISTANT.value,
-                content=[{"text": assistant_message}]
+                content=[{"text": assistant_message}],
+                metadata=ConversationMessageMetadata(
+                    citations=chat_completion.citations,
+                    usage=chat_completion.usage
+                )
)

except Exception as error:
@@ -163,19 +183,33 @@ async def handle_streaming_response(self, request_options: Dict[str, Any]) -> Co
try:
stream = self.client.chat.completions.create(**request_options)
            accumulated_message = []
+            metadata = {}

            for chunk in stream:
                if chunk.choices[0].delta.content:

+                    metadata = {
+                        'citations': chunk.citations,
+                        'usage': chunk.usage
+                    }

                    chunk_content = chunk.choices[0].delta.content
                    accumulated_message.append(chunk_content)

                    if self.callbacks:
-                        self.callbacks.on_llm_new_token(chunk_content)
+                        self.callbacks.on_llm_new_token(
+                            ConversationMessage(
+                                role=ParticipantRole.ASSISTANT.value,
+                                content=[{'text': chunk_content}],
+                                metadata=ConversationMessageMetadata(**metadata)
+                            )
+                        )
                    #yield chunk_content

            # Store the complete message in the instance for later access if needed
            return ConversationMessage(
-                role=ParticipantRole.ASSISTANT.value,
-                content=[{"text": ''.join(accumulated_message)}]
+                role=ParticipantRole.ASSISTANT.value,
+                content=[{"text": ''.join(accumulated_message)}],
+                metadata=ConversationMessageMetadata(**metadata)
            )

except Exception as error:
@@ -48,8 +48,21 @@ async def retrieve_and_combine_results(self, text, knowledge_base_id=None, retri

@staticmethod
def combine_retrieval_results(retrieval_results):
-        return "\n".join(
+        sources = []
+
+        sources.extend(
+            set(result['metadata']['x-amz-bedrock-kb-source-uri']
+                for result in retrieval_results
+                if result and result.get('metadata') and isinstance(result['metadata'].get('x-amz-bedrock-kb-source-uri'), str))
+        )
+
+        text = "\n".join(
            result['content']['text']
            for result in retrieval_results
            if result and result.get('content') and isinstance(result['content'].get('text'), str)
        )
+
+        return {
+            'text': text,
+            'sources': sources
+        }
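Because `combine_retrieval_results` now returns a dict rather than a bare string, every caller has to unpack it (the agent changes above switch to `response['text']` and `response['sources']`). A small illustration of the new shape; the retriever class name and the sample records are assumptions:

results = [
    {'content': {'text': 'Paris is the capital of France.'},
     'metadata': {'x-amz-bedrock-kb-source-uri': 's3://kb-bucket/docs/france.md'}},
    {'content': {'text': 'France is in Western Europe.'},
     'metadata': {'x-amz-bedrock-kb-source-uri': 's3://kb-bucket/docs/europe.md'}},
]

combined = AmazonKnowledgeBasesRetriever.combine_retrieval_results(results)
# combined['text']    == 'Paris is the capital of France.\nFrance is in Western Europe.'
# combined['sources'] holds the two S3 URIs; their order is not guaranteed
# because duplicates are removed via a set.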
2 changes: 2 additions & 0 deletions python/src/multi_agent_orchestrator/types/__init__.py
@@ -1,6 +1,7 @@
"""Module for importing types."""
from .types import (
ConversationMessage,
+    ConversationMessageMetadata,
ParticipantRole,
TimestampedMessage,
RequestMetadata,
@@ -19,6 +20,7 @@

__all__ = [
'ConversationMessage',
+    'ConversationMessageMetadata',
'ParticipantRole',
'TimestampedMessage',
'RequestMetadata',