Skip to content

Commit

Permalink
Merge pull request #61 from langchain-ai/mattf/payload-refector-0
Browse files Browse the repository at this point in the history
refactor payload preprocessing
  • Loading branch information
mattf authored Jul 1, 2024
2 parents ec3c54c + 4a00aa0 commit 0338598
Showing 1 changed file with 31 additions and 53 deletions.
84 changes: 31 additions & 53 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import urllib.parse
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
Expand Down Expand Up @@ -217,10 +216,6 @@ def _generate(
generation = ChatGeneration(message=message)
return ChatResult(generations=[generation], llm_output=responses)

def _get_filled_chunk(self, **kwargs: Any) -> ChatGenerationChunk:
"""Fill the generation chunk."""
return ChatGenerationChunk(message=ChatMessageChunk(**kwargs))

def _stream(
self,
messages: List[BaseMessage],
Expand All @@ -232,7 +227,9 @@ def _stream(
inputs = self._custom_preprocess(messages)
for response in self._get_stream(inputs=inputs, stop=stop, **kwargs):
self._set_callback_out(response, run_manager)
chunk = self._get_filled_chunk(**self._custom_postprocess(response))
chunk = ChatGenerationChunk(
message=ChatMessageChunk(**self._custom_postprocess(response))
)
if run_manager:
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
yield chunk
Expand All @@ -251,7 +248,19 @@ def _set_callback_out(
def _custom_preprocess( # todo: remove
self, msg_list: Sequence[BaseMessage]
) -> List[Dict[str, str]]:
return [self._preprocess_msg(m) for m in msg_list]
def _preprocess_msg(msg: BaseMessage) -> Dict[str, str]:
if isinstance(msg, BaseMessage):
role_convert = {"ai": "assistant", "human": "user"}
if isinstance(msg, ChatMessage):
role = msg.role
else:
role = msg.type
role = role_convert.get(role, role)
content = self._process_content(msg.content)
return {"role": role, "content": content}
raise ValueError(f"Invalid message: {repr(msg)} of type {type(msg)}")

return [_preprocess_msg(m) for m in msg_list]

def _process_content(self, content: Union[str, List[Union[dict, str]]]) -> str:
if isinstance(content, str):
Expand Down Expand Up @@ -284,18 +293,6 @@ def _process_content(self, content: Union[str, List[Union[dict, str]]]) -> str:
raise ValueError(f"Unrecognized message part format: {part}")
return "".join(string_array)

def _preprocess_msg(self, msg: BaseMessage) -> Dict[str, str]: # todo: remove
if isinstance(msg, BaseMessage):
role_convert = {"ai": "assistant", "human": "user"}
if isinstance(msg, ChatMessage):
role = msg.role
else:
role = msg.type
role = role_convert.get(role, role)
content = self._process_content(msg.content)
return {"role": role, "content": content}
raise ValueError(f"Invalid message: {repr(msg)} of type {type(msg)}")

def _custom_postprocess(self, msg: dict) -> dict: # todo: remove
kw_left = msg.copy()
out_dict = {
Expand Down Expand Up @@ -336,16 +333,6 @@ def _get_stream( # todo: remove
payload = self._get_payload(inputs=inputs, stream=True, **kwargs)
return self._client.client.get_req_stream(payload=payload)

def _get_astream( # todo: remove
self,
inputs: Sequence[Dict],
**kwargs: Any,
) -> AsyncIterator:
"""Call to client astream methods with call scope"""
kwargs["stop"] = kwargs.get("stop") or self.stop
payload = self._get_payload(inputs=inputs, stream=True, **kwargs)
return self._client.client.get_req_astream(payload=payload)

def _get_payload(
self, inputs: Sequence[Dict], **kwargs: Any
) -> dict: # todo: remove
Expand All @@ -358,32 +345,23 @@ def _get_payload(
"seed": self.seed,
"stop": self.stop,
}
# if model_name := self._get_binding_model():
# attr_kwargs["model"] = model_name
attr_kwargs = {k: v for k, v in attr_kwargs.items() if v is not None}
new_kwargs = {**attr_kwargs, **kwargs}
return self._prep_payload(inputs=inputs, **new_kwargs)

def _prep_payload(
self, inputs: Sequence[Dict], **kwargs: Any
) -> dict: # todo: remove
"""Prepares a message or list of messages for the payload"""
messages = [self._prep_msg(m) for m in inputs]
if kwargs.get("stop") is None:
kwargs.pop("stop")
return {"messages": messages, **kwargs}

def _prep_msg(self, msg: Union[str, dict, BaseMessage]) -> dict: # todo: remove
"""Helper Method: Ensures a message is a dictionary with a role and content."""
if isinstance(msg, str):
# (WFH) this shouldn't ever be reached but leaving this here bcs
# it's a Chesterton's fence I'm unwilling to touch
return dict(role="user", content=msg)
if isinstance(msg, dict):
if msg.get("content", None) is None:
raise ValueError(f"Message {msg} has no content")
return msg
raise ValueError(f"Unknown message received: {msg} of type {type(msg)}")
messages: List[Dict[str, Any]] = []
for msg in inputs:
if isinstance(msg, str):
# (WFH) this shouldn't ever be reached but leaving this here bcs
# it's a Chesterton's fence I'm unwilling to touch
messages.append(dict(role="user", content=msg))
elif isinstance(msg, dict):
if msg.get("content", None) is None:
raise ValueError(f"Message {msg} has no content")
messages.append(msg)
else:
raise ValueError(f"Unknown message received: {msg} of type {type(msg)}")
if new_kwargs.get("stop") is None:
new_kwargs.pop("stop")
return {"messages": messages, **new_kwargs}

def bind_tools(
self,
Expand Down

0 comments on commit 0338598

Please sign in to comment.