From 961ade4589df179cc1ba3e2691eed7cb21fad52a Mon Sep 17 00:00:00 2001
From: Matthew Farrellee
Date: Wed, 12 Jun 2024 12:53:32 -0400
Subject: [PATCH 1/4] refactor _preprocess_msg into private func

---
 .../chat_models.py | 26 +++++++++----------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py
index 51a9be59..8f0fb2a4 100644
--- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py
+++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py
@@ -251,7 +251,19 @@ def _set_callback_out(
     def _custom_preprocess(  # todo: remove
         self, msg_list: Sequence[BaseMessage]
     ) -> List[Dict[str, str]]:
-        return [self._preprocess_msg(m) for m in msg_list]
+        def _preprocess_msg(msg: BaseMessage) -> Dict[str, str]:
+            if isinstance(msg, BaseMessage):
+                role_convert = {"ai": "assistant", "human": "user"}
+                if isinstance(msg, ChatMessage):
+                    role = msg.role
+                else:
+                    role = msg.type
+                role = role_convert.get(role, role)
+                content = self._process_content(msg.content)
+                return {"role": role, "content": content}
+            raise ValueError(f"Invalid message: {repr(msg)} of type {type(msg)}")
+
+        return [_preprocess_msg(m) for m in msg_list]
 
     def _process_content(self, content: Union[str, List[Union[dict, str]]]) -> str:
         if isinstance(content, str):
@@ -284,18 +296,6 @@ def _process_content(self, content: Union[str, List[Union[dict, str]]]) -> str:
             raise ValueError(f"Unrecognized message part format: {part}")
         return "".join(string_array)
 
-    def _preprocess_msg(self, msg: BaseMessage) -> Dict[str, str]:  # todo: remove
-        if isinstance(msg, BaseMessage):
-            role_convert = {"ai": "assistant", "human": "user"}
-            if isinstance(msg, ChatMessage):
-                role = msg.role
-            else:
-                role = msg.type
-            role = role_convert.get(role, role)
-            content = self._process_content(msg.content)
-            return {"role": role, "content": content}
-        raise ValueError(f"Invalid message: {repr(msg)} of type {type(msg)}")
-
     def _custom_postprocess(self, msg: dict) -> dict:  # todo: remove
         kw_left = msg.copy()
         out_dict = {

From 17dc40f408e01dd246a2e2d62a60fe3fd978076e Mon Sep 17 00:00:00 2001
From: Matthew Farrellee
Date: Wed, 12 Jun 2024 13:01:50 -0400
Subject: [PATCH 2/4] refactor _get_filled_chunk inline

---
 .../langchain_nvidia_ai_endpoints/chat_models.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py
index 8f0fb2a4..4aeb007f 100644
--- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py
+++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py
@@ -217,10 +217,6 @@ def _generate(
         generation = ChatGeneration(message=message)
         return ChatResult(generations=[generation], llm_output=responses)
 
-    def _get_filled_chunk(self, **kwargs: Any) -> ChatGenerationChunk:
-        """Fill the generation chunk."""
-        return ChatGenerationChunk(message=ChatMessageChunk(**kwargs))
-
     def _stream(
         self,
         messages: List[BaseMessage],
@@ -232,7 +228,9 @@ def _stream(
         inputs = self._custom_preprocess(messages)
         for response in self._get_stream(inputs=inputs, stop=stop, **kwargs):
             self._set_callback_out(response, run_manager)
-            chunk = self._get_filled_chunk(**self._custom_postprocess(response))
+            chunk = ChatGenerationChunk(
+                message=ChatMessageChunk(**self._custom_postprocess(response))
+            )
             if run_manager:
                 run_manager.on_llm_new_token(chunk.text, chunk=chunk)
             yield chunk
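Note for reviewers: the behavior the first two patches must preserve is the role
mapping from LangChain message types to API roles. A minimal standalone sketch of
that mapping, written against the public langchain_core message classes; the
expected values below are my reading of the diff, not library guarantees:

    from langchain_core.messages import AIMessage, ChatMessage, HumanMessage

    # Same table as the inlined _preprocess_msg helper.
    role_convert = {"ai": "assistant", "human": "user"}

    def to_role(msg) -> str:
        # ChatMessage carries an explicit .role; other message types use .type.
        role = msg.role if isinstance(msg, ChatMessage) else msg.type
        return role_convert.get(role, role)

    assert to_role(HumanMessage(content="hi")) == "user"
    assert to_role(AIMessage(content="hello")) == "assistant"
    assert to_role(ChatMessage(role="context", content="docs")) == "context"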
From 7f51f0bfac17941583454a99c45eedf60aa39c3e Mon Sep 17 00:00:00 2001
From: Matthew Farrellee
Date: Mon, 17 Jun 2024 10:55:32 -0400
Subject: [PATCH 3/4] refactor inline _prep_payload

---
 .../chat_models.py | 39 +++++++------------
 1 file changed, 15 insertions(+), 24 deletions(-)

diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py
index 4aeb007f..04d6260b 100644
--- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py
+++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py
@@ -356,32 +356,23 @@ def _get_payload(
             "seed": self.seed,
             "stop": self.stop,
         }
-        # if model_name := self._get_binding_model():
-        #     attr_kwargs["model"] = model_name
         attr_kwargs = {k: v for k, v in attr_kwargs.items() if v is not None}
         new_kwargs = {**attr_kwargs, **kwargs}
-        return self._prep_payload(inputs=inputs, **new_kwargs)
-
-    def _prep_payload(
-        self, inputs: Sequence[Dict], **kwargs: Any
-    ) -> dict:  # todo: remove
-        """Prepares a message or list of messages for the payload"""
-        messages = [self._prep_msg(m) for m in inputs]
-        if kwargs.get("stop") is None:
-            kwargs.pop("stop")
-        return {"messages": messages, **kwargs}
-
-    def _prep_msg(self, msg: Union[str, dict, BaseMessage]) -> dict:  # todo: remove
-        """Helper Method: Ensures a message is a dictionary with a role and content."""
-        if isinstance(msg, str):
-            # (WFH) this shouldn't ever be reached but leaving this here bcs
-            # it's a Chesterton's fence I'm unwilling to touch
-            return dict(role="user", content=msg)
-        if isinstance(msg, dict):
-            if msg.get("content", None) is None:
-                raise ValueError(f"Message {msg} has no content")
-            return msg
-        raise ValueError(f"Unknown message received: {msg} of type {type(msg)}")
+        messages: List[Dict[str, Any]] = []
+        for msg in inputs:
+            if isinstance(msg, str):
+                # (WFH) this shouldn't ever be reached but leaving this here bcs
+                # it's a Chesterton's fence I'm unwilling to touch
+                messages.append(dict(role="user", content=msg))
+            elif isinstance(msg, dict):
+                if msg.get("content", None) is None:
+                    raise ValueError(f"Message {msg} has no content")
+                messages.append(msg)
+            else:
+                raise ValueError(f"Unknown message received: {msg} of type {type(msg)}")
+        if new_kwargs.get("stop") is None:
+            new_kwargs.pop("stop")
+        return {"messages": messages, **new_kwargs}
 
     def bind_tools(
         self,
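Note for reviewers: the inlining in patch 3 is easiest to check against the payload
it produces. A rough standalone approximation of the message-prep loop and the
resulting request shape; build_payload is a hypothetical stand-in for illustration,
not the class method itself:

    from typing import Any, Dict, List, Sequence, Union

    def build_payload(inputs: Sequence[Union[str, dict]], **kwargs: Any) -> dict:
        # Mirrors the loop now inlined into _get_payload.
        messages: List[Dict[str, Any]] = []
        for msg in inputs:
            if isinstance(msg, str):
                messages.append(dict(role="user", content=msg))
            elif isinstance(msg, dict):
                if msg.get("content") is None:
                    raise ValueError(f"Message {msg} has no content")
                messages.append(msg)
            else:
                raise ValueError(f"Unknown message received: {msg} of type {type(msg)}")
        if kwargs.get("stop") is None:
            # defaulted pop here; in the method the caller always supplies a stop key
            kwargs.pop("stop", None)
        return {"messages": messages, **kwargs}

    payload = build_payload([{"role": "user", "content": "hello"}], stream=True)
    # {'messages': [{'role': 'user', 'content': 'hello'}], 'stream': True}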
From 4a00aa0fb6bde3500af6b2f5eb3cbe5c56c4319f Mon Sep 17 00:00:00 2001
From: Matthew Farrellee
Date: Mon, 17 Jun 2024 11:27:20 -0400
Subject: [PATCH 4/4] remove unused _get_astream

---
 .../langchain_nvidia_ai_endpoints/chat_models.py | 11 -----------
 1 file changed, 11 deletions(-)

diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py
index 04d6260b..e102aad3 100644
--- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py
+++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py
@@ -10,7 +10,6 @@
 import urllib.parse
 from typing import (
     Any,
-    AsyncIterator,
     Callable,
     Dict,
     Iterator,
@@ -334,16 +333,6 @@ def _get_stream(  # todo: remove
         payload = self._get_payload(inputs=inputs, stream=True, **kwargs)
         return self._client.client.get_req_stream(payload=payload)
 
-    def _get_astream(  # todo: remove
-        self,
-        inputs: Sequence[Dict],
-        **kwargs: Any,
-    ) -> AsyncIterator:
-        """Call to client astream methods with call scope"""
-        kwargs["stop"] = kwargs.get("stop") or self.stop
-        payload = self._get_payload(inputs=inputs, stream=True, **kwargs)
-        return self._client.client.get_req_astream(payload=payload)
-
     def _get_payload(
         self, inputs: Sequence[Dict], **kwargs: Any
     ) -> dict:  # todo: remove