Implement Pydantic models from neon-data-models (#9)
* WIP implementation of `neon_data_models` for MQ message validation

* Add note to deprecate refactored import
Outline unit tests with automation
Add test extra dependencies

* Add chatbot shutdown to unit tests to allow tests to finish

* Outline `test_llm` TestCase
Add tests of MQ request/response handling in `chatbot` module
Update `chatbot` module to use Pydantic models in place of `dict` objects for MQ message validation

* Add test coverage for `ask_chatbot`, `ask_discusser`, and `ask_appraiser`
Refactor `chatbot` methods to safely handle missing context

* Refactor order-dependent test case

* Define methodology for RMQ unit tests with basic initialization test case

* Troubleshoot permissions error in GHA tests

* Allow test fixture to select the RMQ port to troubleshoot conflicts in GHA runs

* Troubleshoot permissions handling in GHA

* Add RMQ unit test coverage
Include `routing_key` in LLM responses to associate inputs/responses

* Refactor static strings into `constants` module
Update tests to reference constant strings for more specific testing

* Update `neon-data-models` dependency spec
NeonDaniel authored Jan 6, 2025
1 parent dc55b1c commit fd6ef5d
Showing 11 changed files with 665 additions and 97 deletions.
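
The central change, per the commit message above, is swapping untyped `dict` MQ payloads for validated Pydantic models from `neon-data-models`. The following is a minimal sketch of that pattern only, using illustrative field names rather than the real `LLMProposeRequest`/`LLMProposeResponse` schemas:

```python
# Illustrative sketch only; the actual schemas live in
# neon_data_models.models.api.mq and may differ in fields and defaults.
from typing import List, Optional

from pydantic import BaseModel, ValidationError


class ExampleProposeRequest(BaseModel):
    message_id: str
    query: str
    history: List[List[str]] = []
    persona: Optional[dict] = None


incoming = {"message_id": "abc123", "query": "hello", "history": []}
try:
    request = ExampleProposeRequest(**incoming)  # raises if keys/types are wrong
    payload = request.model_dump()               # plain dict for send_mq_request
except ValidationError as err:
    print(f"Rejected malformed MQ message: {err}")
```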
25 changes: 23 additions & 2 deletions .github/workflows/unit_tests.yml
@@ -6,5 +6,26 @@ on:
jobs:
py_build_tests:
uses: neongeckocom/.github/.github/workflows/python_build_tests.yml@master
with:
python_version: "3.8"
unit_tests:
strategy:
matrix:
python-version: [ 3.9, "3.10", "3.11", "3.12" ]
timeout-minutes: 15
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install apt dependencies
run: |
sudo apt update
sudo apt install -y rabbitmq-server
- name: Install package
run: |
python -m pip install --upgrade pip
pip install .[chatbots,test]
- name: Run Tests
run: |
pytest tests
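
The new `unit_tests` job installs a system `rabbitmq-server` on each runner. One commit bullet also mentions letting the test fixture choose the RMQ port to avoid conflicts on shared GHA runners; below is a minimal sketch of one way a fixture could grab a free port (illustrative only, not the project's actual test code):

```python
# Illustrative sketch: ask the OS for an unused TCP port so a throwaway
# RabbitMQ (or mock broker) instance does not collide with other jobs.
import socket


def get_free_port() -> int:
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("127.0.0.1", 0))  # port 0 lets the OS pick a free port
        return sock.getsockname()[1]


print(get_free_port())
```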
131 changes: 84 additions & 47 deletions neon_llm_core/chatbot.py
@@ -24,21 +24,28 @@
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from typing import List
from typing import List, Optional
from chatbot_core.v2 import ChatBot
from neon_data_models.models.api.mq import (LLMProposeRequest,
LLMDiscussRequest, LLMVoteRequest, LLMProposeResponse, LLMDiscussResponse,
LLMVoteResponse)
from neon_mq_connector.utils.client_utils import send_mq_request
from neon_utils.logger import LOG
from neon_data_models.models.api.llm import LLMPersona

from neon_llm_core.utils.config import LLMMQConfig
from neon_llm_core.utils.constants import DEFAULT_RESPONSE, DEFAULT_VOTE


class LLMBot(ChatBot):

def __init__(self, *args, **kwargs):
ChatBot.__init__(self, *args, **kwargs)
self.bot_type = "submind"
self.base_llm = kwargs.get("llm_name") # chatgpt, fastchat, etc.
self.persona = kwargs.get("persona")
self.base_llm = kwargs["llm_name"] # chatgpt, fastchat, etc.
self.persona = kwargs["persona"]
self.persona = LLMPersona(**self.persona) if \
isinstance(self.persona, dict) else self.persona
self.mq_queue_config = self.get_llm_mq_config(self.base_llm)
LOG.info(f'Initialised config for llm={self.base_llm}|'
f'persona={self._bot_id}')
@@ -57,13 +64,12 @@ def ask_chatbot(self, user: str, shout: str, timestamp: str,
:param timestamp: formatted timestamp of shout
:param context: message context
"""
prompt_id = context.get('prompt_id')
prompt_id = context.get('prompt_id') if context else None
if prompt_id:
self.prompt_id_to_shout[prompt_id] = shout
LOG.debug(f"Getting response to {shout}")
response = self._get_llm_api_response(
shout=shout).get("response", "I have nothing to say here...")
return response
response = self._get_llm_api_response(shout=shout)
return response.response if response else DEFAULT_RESPONSE

def ask_discusser(self, options: dict, context: dict = None) -> str:
"""
@@ -73,84 +79,115 @@ def ask_discusser(self, options: dict, context: dict = None) -> str:
:param context: message context
"""
options = {k: v for k, v in options.items() if k != self.service_name}
prompt_sentence = self.prompt_id_to_shout.get(context['prompt_id'], '')
prompt_id = context.get('prompt_id') if context else None
prompt_sentence = None
if prompt_id:
prompt_sentence = self.prompt_id_to_shout.get(prompt_id)
LOG.info(f'prompt_sentence={prompt_sentence}, options={options}')
opinion = self._get_llm_api_opinion(prompt=prompt_sentence,
options=options).get('opinion', '')
return opinion
options=options)
return opinion.opinion if opinion else DEFAULT_RESPONSE

def ask_appraiser(self, options: dict, context: dict = None) -> str:
"""
Selects one of the responses to a prompt and casts a vote in the conversation.
:param options: proposed responses (botname: response)
:param context: message context
"""
# Determine the relevant prompt
prompt_id = context.get('prompt_id') if context else None
prompt_sentence = None
if prompt_id:
prompt_sentence = self.prompt_id_to_shout.get(prompt_id)

# Remove self answer from available options
options = {k: v for k, v in options.items()
if k != self.service_name}

if options:
options = {k: v for k, v in options.items()
if k != self.service_name}
bots = list(options)
bot_responses = list(options.values())
LOG.info(f'bots={bots}, answers={bot_responses}')
prompt = self.prompt_id_to_shout.pop(context['prompt_id'], '')
answer_data = self._get_llm_api_choice(prompt=prompt,
answer_data = self._get_llm_api_choice(prompt=prompt_sentence,
responses=bot_responses)
LOG.info(f'Received answer_data={answer_data}')
sorted_answer_indexes = answer_data.get('sorted_answer_indexes')
if sorted_answer_indexes:
return bots[sorted_answer_indexes[0]]
return "abstain"
if answer_data and answer_data.sorted_answer_indexes:
return bots[answer_data.sorted_answer_indexes[0]]
return DEFAULT_VOTE

def _get_llm_api_response(self, shout: str) -> dict:
def _get_llm_api_response(self, shout: str) -> Optional[LLMProposeResponse]:
"""
Requests LLM API for response on provided shout
:param shout: provided should string
:returns response string from LLM API
:param shout: Input prompt to respond to
:returns response from LLM API
"""
queue = self.mq_queue_config.ask_response_queue
LOG.info(f"Sending to {self.mq_queue_config.vhost}/{queue}")
try:
return send_mq_request(vhost=self.mq_queue_config.vhost,
request_data={"query": shout,
"history": [],
"persona": self.persona},
target_queue=queue,
response_queue=f"{queue}.response")
request_data = LLMProposeRequest(model=self.base_llm,
persona=self.persona,
query=shout,
history=[],
message_id="")
resp_data = send_mq_request(vhost=self.mq_queue_config.vhost,
request_data=request_data.model_dump(),
target_queue=queue,
response_queue=f"{queue}.response")
return LLMProposeResponse(**resp_data)
except Exception as e:
LOG.exception(f"Failed to get response on "
f"{self.mq_queue_config.vhost}/"
f"{self.mq_queue_config.ask_response_queue}: "
f"{e}")
return dict()
f"{self.mq_queue_config.vhost}/{queue}: {e}")

def _get_llm_api_opinion(self, prompt: str, options: dict) -> dict:
def _get_llm_api_opinion(self, prompt: str,
options: dict) -> Optional[LLMDiscussResponse]:
"""
Requests LLM API for opinion on provided submind responses
Requests LLM API for discussion of provided submind responses
:param prompt: incoming prompt text
:param options: proposed responses (botname: response)
:returns response data from LLM API
"""
queue = self.mq_queue_config.ask_discusser_queue
return send_mq_request(vhost=self.mq_queue_config.vhost,
request_data={"query": prompt,
"options": options,
"persona": self.persona},
target_queue=queue,
response_queue=f"{queue}.response")

def _get_llm_api_choice(self, prompt: str, responses: List[str]) -> dict:
try:
request_data = LLMDiscussRequest(model=self.base_llm,
persona=self.persona,
query=prompt,
options=options,
history=[],
message_id="")
resp_data = send_mq_request(vhost=self.mq_queue_config.vhost,
request_data=request_data.model_dump(),
target_queue=queue,
response_queue=f"{queue}.response")
return LLMDiscussResponse(**resp_data)
except Exception as e:
LOG.exception(f"Failed to get response on "
f"{self.mq_queue_config.vhost}/{queue}: {e}")

def _get_llm_api_choice(self, prompt: str,
responses: List[str]) -> Optional[LLMVoteResponse]:
"""
Requests LLM API for choice among provided message list
:param prompt: incoming prompt text
:param responses: list of answers to select from
:returns response data from LLM API
"""
queue = self.mq_queue_config.ask_appraiser_queue
return send_mq_request(vhost=self.mq_queue_config.vhost,
request_data={"query": prompt,
"responses": responses,
"persona": self.persona},
target_queue=queue,
response_queue=f"{queue}.response")

try:
request_data = LLMVoteRequest(model=self.base_llm,
persona=self.persona,
query=prompt,
responses=responses,
history=[],
message_id="")
resp_data = send_mq_request(vhost=self.mq_queue_config.vhost,
request_data=request_data.model_dump(),
target_queue=queue,
response_queue=f"{queue}.response")
return LLMVoteResponse(**resp_data)
except Exception as e:
LOG.exception(f"Failed to get response on "
f"{self.mq_queue_config.vhost}/{queue}: {e}")

@staticmethod
def get_llm_mq_config(llm_name: str) -> LLMMQConfig:
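
Taken together, the `chatbot.py` changes above replace raw-dict handling with typed request/response models and make each `ask_*` method fall back to a default string when the MQ call fails. Below is a minimal, self-contained sketch of that Optional-return-plus-fallback pattern, using stand-in names rather than the real `neon-data-models` classes:

```python
# Illustrative stand-ins for LLMProposeResponse / DEFAULT_RESPONSE; the real
# classes come from neon_data_models and neon_llm_core.utils.constants.
from typing import Optional

from pydantic import BaseModel

DEFAULT_RESPONSE = "I have nothing to say here..."


class ProposeResponse(BaseModel):
    message_id: str
    response: str


def parse_response(raw: Optional[dict]) -> Optional[ProposeResponse]:
    """Parse a raw MQ reply, returning None instead of raising on bad data."""
    try:
        return ProposeResponse(**(raw or {}))
    except Exception:
        return None


def ask_chatbot(raw_reply: Optional[dict]) -> str:
    response = parse_response(raw_reply)
    # Fall back to a canned reply if the API call failed or was malformed
    return response.response if response else DEFAULT_RESPONSE


assert ask_chatbot({"message_id": "1", "response": "hi"}) == "hi"
assert ask_chatbot(None) == DEFAULT_RESPONSE
```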
62 changes: 33 additions & 29 deletions neon_llm_core/rmq.py
@@ -26,11 +26,15 @@

from abc import abstractmethod, ABC
from threading import Thread
from typing import Optional

from neon_mq_connector.connector import MQConnector
from neon_mq_connector.utils.rabbit_utils import create_mq_callback
from neon_utils.logger import LOG

from neon_data_models.models.api.mq import (LLMProposeResponse,
LLMDiscussResponse, LLMVoteResponse)

from neon_llm_core.utils.config import load_config
from neon_llm_core.llm import NeonLLM
from neon_llm_core.utils.constants import LLM_VHOST
@@ -44,10 +48,10 @@ class NeonLLMMQConnector(MQConnector, ABC):

async_consumers_enabled = True

def __init__(self):
def __init__(self, config: Optional[dict] = None):
self.service_name = f'neon_llm_{self.name}'

self.ovos_config = load_config()
self.ovos_config = config or load_config()
mq_config = self.ovos_config.get("MQ", dict())
super().__init__(config=mq_config, service_name=self.service_name)
self.vhost = LLM_VHOST
@@ -106,15 +110,17 @@ def model(self) -> NeonLLM:
pass

@create_mq_callback()
def handle_request(self, body: dict):
def handle_request(self, body: dict) -> Thread:
"""
Handles ask requests from MQ to LLM
Handles ask requests (response to prompt) from MQ to LLM
:param body: request body (dict)
"""
# Handle this asynchronously so multiple subminds can be handled
# concurrently
Thread(target=self._handle_request_async, args=(body,),
daemon=True).start()
t = Thread(target=self._handle_request_async, args=(body,),
daemon=True)
t.start()
return t

def _handle_request_async(self, request: dict):
message_id = request["message_id"]
@@ -131,20 +137,20 @@ def _handle_request_async(self, request: dict):
LOG.error(f'ValueError={err}')
response = ('Sorry, but I cannot respond to your message at the '
'moment, please try again later')
api_response = {
"message_id": message_id,
"response": response
}
api_response = LLMProposeResponse(message_id=message_id,
response=response,
routing_key=routing_key)
LOG.info(f"Sending response: {response}")
self.send_message(request_data=api_response,
self.send_message(request_data=api_response.model_dump(),
queue=routing_key)
LOG.info(f"Handled ask request for message_id={message_id}")

# TODO: Refactor score and opinion to work async like request
@create_mq_callback()
def handle_score_request(self, body: dict):
"""
Handles score requests from MQ to LLM
:param body: request body (dict)
Handles score requests (vote) from MQ to LLM
:param body: request body (dict)
"""
message_id = body["message_id"]
routing_key = body["routing_key"]
@@ -154,27 +160,27 @@ def handle_score_request(self, body: dict):
persona = body.get("persona", {})

if not responses:
sorted_answer_indexes = []
sorted_answer_idx = []
else:
try:
sorted_answer_indexes = self.model.get_sorted_answer_indexes(
sorted_answer_idx = self.model.get_sorted_answer_indexes(
question=query, answers=responses, persona=persona)
except ValueError as err:
LOG.error(f'ValueError={err}')
sorted_answer_indexes = []
api_response = {
"message_id": message_id,
"sorted_answer_indexes": sorted_answer_indexes
}
self.send_message(request_data=api_response,
sorted_answer_idx = []

api_response = LLMVoteResponse(message_id=message_id,
routing_key=routing_key,
sorted_answer_indexes=sorted_answer_idx)
self.send_message(request_data=api_response.model_dump(),
queue=routing_key)
LOG.info(f"Handled score request for message_id={message_id}")

@create_mq_callback()
def handle_opinion_request(self, body: dict):
"""
Handles opinion requests from MQ to LLM
:param body: request body (dict)
Handles opinion requests (discuss) from MQ to LLM
:param body: request body (dict)
"""
message_id = body["message_id"]
routing_key = body["routing_key"]
@@ -200,12 +206,10 @@ def handle_opinion_request(self, body: dict):
opinion = ("Sorry, but I experienced an issue trying to form "
"an opinion on this topic")

api_response = {
"message_id": message_id,
"opinion": opinion
}

self.send_message(request_data=api_response,
api_response = LLMDiscussResponse(message_id=message_id,
routing_key=routing_key,
opinion=opinion)
self.send_message(request_data=api_response.model_dump(),
queue=routing_key)
LOG.info(f"Handled ask request for message_id={message_id}")

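
Beyond the Pydantic response models, `handle_request` now returns the worker `Thread` it spawns, which makes the asynchronous handling much easier to exercise from unit tests. A small sketch of why that matters (illustrative, not the connector's actual code):

```python
# Illustrative sketch: returning the Thread lets a test join() it instead of
# sleeping and hoping the async handler has finished.
from threading import Thread
import time


def _handle_request_async(body: dict):
    time.sleep(0.1)  # stand-in for the real LLM call and MQ response
    print(f"handled {body['message_id']}")


def handle_request(body: dict) -> Thread:
    t = Thread(target=_handle_request_async, args=(body,), daemon=True)
    t.start()
    return t


worker = handle_request({"message_id": "abc"})
worker.join(timeout=5)   # deterministic wait in a test
assert not worker.is_alive()
```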
2 changes: 2 additions & 0 deletions neon_llm_core/utils/constants.py
@@ -25,3 +25,5 @@
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

LLM_VHOST = '/llm'
DEFAULT_RESPONSE = "I have nothing to say here..."
DEFAULT_VOTE = "abstain"
19 changes: 2 additions & 17 deletions neon_llm_core/utils/personas/models.py
@@ -23,21 +23,6 @@
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from typing import Optional

from pydantic import BaseModel, computed_field


class PersonaModel(BaseModel):
name: str
description: str
enabled: bool = True
user_id: Optional[str] = None

@computed_field
@property
def id(self) -> str:
persona_id = self.name
if self.user_id:
persona_id += f"_{self.user_id}"
return persona_id
from neon_data_models.models.api.llm import LLMPersona as PersonaModel
# TODO: Mark for deprecation
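
The module is now just a backwards-compatible alias for `LLMPersona`. If the TODO above is acted on, one conventional way to flag the old import path (a hypothetical follow-up, not part of this commit) would be:

```python
# Hypothetical follow-up sketch: warn callers who still import PersonaModel
# from neon_llm_core.utils.personas.models.
import warnings

from neon_data_models.models.api.llm import LLMPersona as PersonaModel

warnings.warn(
    "PersonaModel is deprecated; import LLMPersona from "
    "neon_data_models.models.api.llm instead",
    DeprecationWarning,
    stacklevel=2,
)
```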