Implement Pydantic models from neon-data-models
#9
Changes from all commits: 11b788f, f7e5f1c, 6b2dbfa, 5e5cccd, c6358dd, 55e5f3e, 01fdf2e, 45437b7, 68d06a6, 8ad38c1, 4166d09, 33c6ce7, 16fa24a
@@ -24,21 +24,28 @@
 # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

-from typing import List
+from typing import List, Optional

 from chatbot_core.v2 import ChatBot
+from neon_data_models.models.api.mq import (LLMProposeRequest,
+    LLMDiscussRequest, LLMVoteRequest, LLMProposeResponse, LLMDiscussResponse,
+    LLMVoteResponse)
 from neon_mq_connector.utils.client_utils import send_mq_request
 from neon_utils.logger import LOG
+from neon_data_models.models.api.llm import LLMPersona

 from neon_llm_core.utils.config import LLMMQConfig
+from neon_llm_core.utils.constants import DEFAULT_RESPONSE, DEFAULT_VOTE


 class LLMBot(ChatBot):

     def __init__(self, *args, **kwargs):
         ChatBot.__init__(self, *args, **kwargs)
         self.bot_type = "submind"
-        self.base_llm = kwargs.get("llm_name")  # chatgpt, fastchat, etc.
-        self.persona = kwargs.get("persona")
+        self.base_llm = kwargs["llm_name"]  # chatgpt, fastchat, etc.
+        self.persona = kwargs["persona"]
+        self.persona = LLMPersona(**self.persona) if \
+            isinstance(self.persona, dict) else self.persona
         self.mq_queue_config = self.get_llm_mq_config(self.base_llm)
         LOG.info(f'Initialised config for llm={self.base_llm}|'
                  f'persona={self._bot_id}')
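The `__init__` change above normalizes `persona` so the rest of the bot can rely on a validated `LLMPersona` instance rather than a raw dict. A minimal sketch of that pattern with a stand-in model (the `Persona` class and its fields are illustrative assumptions, not the actual `LLMPersona` schema):

```python
from typing import Union

from pydantic import BaseModel


class Persona(BaseModel):
    # Stand-in for LLMPersona; these field names are assumptions for illustration
    name: str
    description: str = ""


def normalize_persona(persona: Union[dict, Persona]) -> Persona:
    """Accept either a plain dict (e.g. from configuration) or a model instance."""
    # Validation happens at construction time, so a bad config fails fast here
    return Persona(**persona) if isinstance(persona, dict) else persona


print(normalize_persona({"name": "tutor", "description": "a helpful teacher"}))
print(normalize_persona(Persona(name="tutor")))
```

Doing the conversion once in the constructor keeps validation errors close to configuration loading instead of surfacing later inside the MQ request helpers.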
@@ -57,13 +64,12 @@ def ask_chatbot(self, user: str, shout: str, timestamp: str,
         :param timestamp: formatted timestamp of shout
         :param context: message context
         """
-        prompt_id = context.get('prompt_id')
+        prompt_id = context.get('prompt_id') if context else None
         if prompt_id:
             self.prompt_id_to_shout[prompt_id] = shout
         LOG.debug(f"Getting response to {shout}")
-        response = self._get_llm_api_response(
-            shout=shout).get("response", "I have nothing to say here...")
-        return response
+        response = self._get_llm_api_response(shout=shout)
+        return response.response if response else DEFAULT_RESPONSE

     def ask_discusser(self, options: dict, context: dict = None) -> str:
         """
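Because `_get_llm_api_response` now returns `Optional[LLMProposeResponse]` instead of a plain dict, `ask_chatbot` reads a typed attribute and falls back to a constant when the request fails. A generic sketch of that calling pattern (the stand-in model and the constant's value are assumptions for illustration):

```python
from typing import Optional

from pydantic import BaseModel

DEFAULT_RESPONSE = "I have nothing to say here..."  # assumed to match the old inline default


class ProposeResponse(BaseModel):
    # Stand-in for LLMProposeResponse; only the attribute used by ask_chatbot is modelled
    response: str


def ask(api_result: Optional[ProposeResponse]) -> str:
    # Attribute access replaces dict.get(); None signals the MQ request failed
    return api_result.response if api_result else DEFAULT_RESPONSE


print(ask(ProposeResponse(response="Hello!")))  # -> "Hello!"
print(ask(None))                                # -> fallback constant
```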
@@ -73,84 +79,115 @@ def ask_discusser(self, options: dict, context: dict = None) -> str:
         :param context: message context
         """
         options = {k: v for k, v in options.items() if k != self.service_name}
-        prompt_sentence = self.prompt_id_to_shout.get(context['prompt_id'], '')
+        prompt_id = context.get('prompt_id') if context else None
+        prompt_sentence = None
+        if prompt_id:
+            prompt_sentence = self.prompt_id_to_shout.get(prompt_id)
         LOG.info(f'prompt_sentence={prompt_sentence}, options={options}')
         opinion = self._get_llm_api_opinion(prompt=prompt_sentence,
-                                            options=options).get('opinion', '')
-        return opinion
+                                            options=options)
+        return opinion.opinion if opinion else DEFAULT_RESPONSE

     def ask_appraiser(self, options: dict, context: dict = None) -> str:
         """
         Selects one of the responses to a prompt and casts a vote in the conversation.
         :param options: proposed responses (botname: response)
         :param context: message context
         """
+        # Determine the relevant prompt
+        prompt_id = context.get('prompt_id') if context else None
+        prompt_sentence = None
+        if prompt_id:
+            prompt_sentence = self.prompt_id_to_shout.get(prompt_id)
+
+        # Remove self answer from available options
+        options = {k: v for k, v in options.items()
+                   if k != self.service_name}
+
         if options:
-            options = {k: v for k, v in options.items()
-                       if k != self.service_name}
             bots = list(options)
             bot_responses = list(options.values())
             LOG.info(f'bots={bots}, answers={bot_responses}')
-            prompt = self.prompt_id_to_shout.pop(context['prompt_id'], '')
-            answer_data = self._get_llm_api_choice(prompt=prompt,
+            answer_data = self._get_llm_api_choice(prompt=prompt_sentence,
                                                    responses=bot_responses)
             LOG.info(f'Received answer_data={answer_data}')
-            sorted_answer_indexes = answer_data.get('sorted_answer_indexes')
-            if sorted_answer_indexes:
-                return bots[sorted_answer_indexes[0]]
-        return "abstain"
+            if answer_data and answer_data.sorted_answer_indexes:
+                return bots[answer_data.sorted_answer_indexes[0]]
+        return DEFAULT_VOTE

-    def _get_llm_api_response(self, shout: str) -> dict:
+    def _get_llm_api_response(self, shout: str) -> Optional[LLMProposeResponse]:
         """
         Requests LLM API for response on provided shout
-        :param shout: provided should string
-        :returns response string from LLM API
+        :param shout: Input prompt to respond to
+        :returns response from LLM API
         """
         queue = self.mq_queue_config.ask_response_queue
         LOG.info(f"Sending to {self.mq_queue_config.vhost}/{queue}")
         try:
-            return send_mq_request(vhost=self.mq_queue_config.vhost,
-                                   request_data={"query": shout,
-                                                 "history": [],
-                                                 "persona": self.persona},
-                                   target_queue=queue,
-                                   response_queue=f"{queue}.response")
+            request_data = LLMProposeRequest(model=self.base_llm,
+                                             persona=self.persona,
+                                             query=shout,
+                                             history=[],
+                                             message_id="")
+            resp_data = send_mq_request(vhost=self.mq_queue_config.vhost,
+                                        request_data=request_data.model_dump(),
+                                        target_queue=queue,
+                                        response_queue=f"{queue}.response")
+            return LLMProposeResponse(**resp_data)
Review comment: I would omit kwargs-like argument passing here, to make it more flexible to any potential updates to the callback body of

Reply: This should handle any mis-match of keys as part of the Pydantic validation. My intent with the Pydantic models here is that anything moving across MQ (and HTTP and Messagebus) is defined in neon-data-models.
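A self-contained sketch of the validation behaviour discussed above, using stand-in request/response models rather than the actual neon-data-models schemas (default Pydantic v2 settings assumed): unknown keys in a payload are ignored, while a missing required key raises `ValidationError` at the MQ boundary instead of propagating a malformed dict.

```python
from pydantic import BaseModel, ValidationError


class ProposeRequest(BaseModel):
    # Stand-in for LLMProposeRequest; the real field set lives in neon-data-models
    model: str
    query: str
    history: list = []
    message_id: str = ""


class ProposeResponse(BaseModel):
    # Stand-in for LLMProposeResponse
    response: str
    message_id: str = ""


request = ProposeRequest(model="chatgpt", query="Hello?")
payload = request.model_dump()  # plain dict suitable for MQ transport

# Well-formed reply: extra keys are ignored under the default model config
reply = ProposeResponse(**{"response": "Hi!", "routing_key": "ignored"})
print(reply.response)

# Malformed reply: a missing required key fails loudly at the boundary
try:
    ProposeResponse(**{"routing_key": "oops"})
except ValidationError as e:
    print(f"Rejected malformed payload: {e.error_count()} error(s)")
```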
         except Exception as e:
             LOG.exception(f"Failed to get response on "
-                          f"{self.mq_queue_config.vhost}/"
-                          f"{self.mq_queue_config.ask_response_queue}: "
-                          f"{e}")
-            return dict()
+                          f"{self.mq_queue_config.vhost}/{queue}: {e}")

-    def _get_llm_api_opinion(self, prompt: str, options: dict) -> dict:
+    def _get_llm_api_opinion(self, prompt: str,
+                             options: dict) -> Optional[LLMDiscussResponse]:
         """
-        Requests LLM API for opinion on provided submind responses
+        Requests LLM API for discussion of provided submind responses
         :param prompt: incoming prompt text
         :param options: proposed responses (botname: response)
         :returns response data from LLM API
         """
         queue = self.mq_queue_config.ask_discusser_queue
-        return send_mq_request(vhost=self.mq_queue_config.vhost,
-                               request_data={"query": prompt,
-                                             "options": options,
-                                             "persona": self.persona},
-                               target_queue=queue,
-                               response_queue=f"{queue}.response")
-
-    def _get_llm_api_choice(self, prompt: str, responses: List[str]) -> dict:
+        try:
+            request_data = LLMDiscussRequest(model=self.base_llm,
+                                             persona=self.persona,
+                                             query=prompt,
+                                             options=options,
+                                             history=[],
+                                             message_id="")
+            resp_data = send_mq_request(vhost=self.mq_queue_config.vhost,
+                                        request_data=request_data.model_dump(),
+                                        target_queue=queue,
+                                        response_queue=f"{queue}.response")
+            return LLMDiscussResponse(**resp_data)
+        except Exception as e:
+            LOG.exception(f"Failed to get response on "
+                          f"{self.mq_queue_config.vhost}/{queue}: {e}")
+
+    def _get_llm_api_choice(self, prompt: str,
+                            responses: List[str]) -> Optional[LLMVoteResponse]:
         """
         Requests LLM API for choice among provided message list
         :param prompt: incoming prompt text
         :param responses: list of answers to select from
         :returns response data from LLM API
         """
         queue = self.mq_queue_config.ask_appraiser_queue
-        return send_mq_request(vhost=self.mq_queue_config.vhost,
-                               request_data={"query": prompt,
-                                             "responses": responses,
-                                             "persona": self.persona},
-                               target_queue=queue,
-                               response_queue=f"{queue}.response")
+        try:
+            request_data = LLMVoteRequest(model=self.base_llm,
+                                          persona=self.persona,
+                                          query=prompt,
+                                          responses=responses,
+                                          history=[],
+                                          message_id="")
+            resp_data = send_mq_request(vhost=self.mq_queue_config.vhost,
+                                        request_data=request_data.model_dump(),
+                                        target_queue=queue,
+                                        response_queue=f"{queue}.response")
+            return LLMVoteResponse(**resp_data)
+        except Exception as e:
+            LOG.exception(f"Failed to get response on "
+                          f"{self.mq_queue_config.vhost}/{queue}: {e}")

     @staticmethod
     def get_llm_mq_config(llm_name: str) -> LLMMQConfig:
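For completeness, a stand-alone sketch of the vote-selection step that `ask_appraiser` performs on the returned data, with a stand-in response model and an assumed value for `DEFAULT_VOTE`:

```python
from typing import List, Optional

from pydantic import BaseModel

DEFAULT_VOTE = "abstain"  # assumed to match the previous hard-coded fallback


class VoteResponse(BaseModel):
    # Stand-in for LLMVoteResponse; only the attribute used by ask_appraiser is modelled
    sorted_answer_indexes: List[int]


def pick_winner(bots: List[str], answer_data: Optional[VoteResponse]) -> str:
    # First index is the appraiser's top-ranked response; otherwise abstain
    if answer_data and answer_data.sorted_answer_indexes:
        return bots[answer_data.sorted_answer_indexes[0]]
    return DEFAULT_VOTE


print(pick_winner(["alpha", "beta"], VoteResponse(sorted_answer_indexes=[1, 0])))  # -> "beta"
print(pick_winner(["alpha", "beta"], None))                                        # -> "abstain"
```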
Review comment: Do we not intend to support Python 3.8 in our system anymore?

Reply: I've been dropping it from tests as it recently reached EoL.