refactor payload preprocessing #61

Merged
merged 4 commits on Jul 1, 2024
84 changes: 31 additions & 53 deletions in libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py
@@ -10,7 +10,6 @@
 import urllib.parse
 from typing import (
     Any,
-    AsyncIterator,
     Callable,
     Dict,
     Iterator,
@@ -217,10 +216,6 @@ def _generate(
         generation = ChatGeneration(message=message)
         return ChatResult(generations=[generation], llm_output=responses)

-    def _get_filled_chunk(self, **kwargs: Any) -> ChatGenerationChunk:
-        """Fill the generation chunk."""
-        return ChatGenerationChunk(message=ChatMessageChunk(**kwargs))
-
     def _stream(
         self,
         messages: List[BaseMessage],
@@ -232,7 +227,9 @@ def _stream(
         inputs = self._custom_preprocess(messages)
         for response in self._get_stream(inputs=inputs, stop=stop, **kwargs):
             self._set_callback_out(response, run_manager)
-            chunk = self._get_filled_chunk(**self._custom_postprocess(response))
+            chunk = ChatGenerationChunk(
+                message=ChatMessageChunk(**self._custom_postprocess(response))
+            )
             if run_manager:
                 run_manager.on_llm_new_token(chunk.text, chunk=chunk)
             yield chunk
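
For reference, the inlined construction maps the postprocessed response dict straight onto langchain-core chunk types. A minimal standalone sketch, assuming `_custom_postprocess` returns a plain dict with `role` and `content` keys (an assumption about this class's internals, not something the diff guarantees):

# Sketch only: the sample dict below is an assumed _custom_postprocess output.
from langchain_core.messages import ChatMessageChunk
from langchain_core.outputs import ChatGenerationChunk

postprocessed = {"role": "assistant", "content": "Hello"}  # assumed shape
chunk = ChatGenerationChunk(message=ChatMessageChunk(**postprocessed))
print(chunk.text)  # Hello
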
@@ -251,7 +248,19 @@ def _set_callback_out(
     def _custom_preprocess(  # todo: remove
         self, msg_list: Sequence[BaseMessage]
     ) -> List[Dict[str, str]]:
-        return [self._preprocess_msg(m) for m in msg_list]
+        def _preprocess_msg(msg: BaseMessage) -> Dict[str, str]:
+            if isinstance(msg, BaseMessage):
+                role_convert = {"ai": "assistant", "human": "user"}
+                if isinstance(msg, ChatMessage):
+                    role = msg.role
+                else:
+                    role = msg.type
+                role = role_convert.get(role, role)
+                content = self._process_content(msg.content)
+                return {"role": role, "content": content}
+            raise ValueError(f"Invalid message: {repr(msg)} of type {type(msg)}")
+
+        return [_preprocess_msg(m) for m in msg_list]

     def _process_content(self, content: Union[str, List[Union[dict, str]]]) -> str:
         if isinstance(content, str):
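
To see what the nested `_preprocess_msg` produces, a small usage sketch; the message classes are real langchain-core types, and the role mapping mirrors the `role_convert` table in the diff:

# Sketch of the role conversion performed by _preprocess_msg above.
from langchain_core.messages import AIMessage, ChatMessage, HumanMessage

role_convert = {"ai": "assistant", "human": "user"}
msgs = [HumanMessage(content="hi"), AIMessage(content="hello"),
        ChatMessage(role="context", content="doc text")]
for msg in msgs:
    role = msg.role if isinstance(msg, ChatMessage) else msg.type
    print({"role": role_convert.get(role, role), "content": msg.content})
# {'role': 'user', 'content': 'hi'}
# {'role': 'assistant', 'content': 'hello'}
# {'role': 'context', 'content': 'doc text'}
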
@@ -284,18 +293,6 @@ def _process_content(self, content: Union[str, List[Union[dict, str]]]) -> str:
                 raise ValueError(f"Unrecognized message part format: {part}")
         return "".join(string_array)

-    def _preprocess_msg(self, msg: BaseMessage) -> Dict[str, str]:  # todo: remove
-        if isinstance(msg, BaseMessage):
-            role_convert = {"ai": "assistant", "human": "user"}
-            if isinstance(msg, ChatMessage):
-                role = msg.role
-            else:
-                role = msg.type
-            role = role_convert.get(role, role)
-            content = self._process_content(msg.content)
-            return {"role": role, "content": content}
-        raise ValueError(f"Invalid message: {repr(msg)} of type {type(msg)}")
-
     def _custom_postprocess(self, msg: dict) -> dict:  # todo: remove
         kw_left = msg.copy()
         out_dict = {
@@ -336,16 +333,6 @@ def _get_stream(  # todo: remove
         payload = self._get_payload(inputs=inputs, stream=True, **kwargs)
         return self._client.client.get_req_stream(payload=payload)

-    def _get_astream(  # todo: remove
-        self,
-        inputs: Sequence[Dict],
-        **kwargs: Any,
-    ) -> AsyncIterator:
-        """Call to client astream methods with call scope"""
-        kwargs["stop"] = kwargs.get("stop") or self.stop
-        payload = self._get_payload(inputs=inputs, stream=True, **kwargs)
-        return self._client.client.get_req_astream(payload=payload)
-
     def _get_payload(
         self, inputs: Sequence[Dict], **kwargs: Any
     ) -> dict:  # todo: remove
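
Dropping `_get_astream` leans on langchain-core's default async path, which, to the best of my reading, runs the synchronous `_stream` in an executor when a model defines no `_astream` of its own. A hedged caller-side sketch, assuming a configured NVIDIA API key and default model selection:

# Sketch only: astream is expected to keep working via the BaseChatModel
# fallback from _astream to the sync _stream; endpoint/model are assumptions.
import asyncio
from langchain_nvidia_ai_endpoints import ChatNVIDIA

async def main() -> None:
    llm = ChatNVIDIA()  # assumes NVIDIA_API_KEY is set in the environment
    async for chunk in llm.astream("Write a haiku about GPUs"):
        print(chunk.content, end="")

asyncio.run(main())
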
@@ -358,32 +345,23 @@ def _get_payload(
             "seed": self.seed,
             "stop": self.stop,
         }
-        # if model_name := self._get_binding_model():
-        #     attr_kwargs["model"] = model_name
         attr_kwargs = {k: v for k, v in attr_kwargs.items() if v is not None}
         new_kwargs = {**attr_kwargs, **kwargs}
-        return self._prep_payload(inputs=inputs, **new_kwargs)
-
-    def _prep_payload(
-        self, inputs: Sequence[Dict], **kwargs: Any
-    ) -> dict:  # todo: remove
-        """Prepares a message or list of messages for the payload"""
-        messages = [self._prep_msg(m) for m in inputs]
-        if kwargs.get("stop") is None:
-            kwargs.pop("stop")
-        return {"messages": messages, **kwargs}
-
-    def _prep_msg(self, msg: Union[str, dict, BaseMessage]) -> dict:  # todo: remove
-        """Helper Method: Ensures a message is a dictionary with a role and content."""
-        if isinstance(msg, str):
-            # (WFH) this shouldn't ever be reached but leaving this here bcs
-            # it's a Chesterton's fence I'm unwilling to touch
-            return dict(role="user", content=msg)
-        if isinstance(msg, dict):
-            if msg.get("content", None) is None:
-                raise ValueError(f"Message {msg} has no content")
-            return msg
-        raise ValueError(f"Unknown message received: {msg} of type {type(msg)}")
+        messages: List[Dict[str, Any]] = []
+        for msg in inputs:
+            if isinstance(msg, str):
+                # (WFH) this shouldn't ever be reached but leaving this here bcs
+                # it's a Chesterton's fence I'm unwilling to touch
+                messages.append(dict(role="user", content=msg))
+            elif isinstance(msg, dict):
+                if msg.get("content", None) is None:
+                    raise ValueError(f"Message {msg} has no content")
+                messages.append(msg)
+            else:
+                raise ValueError(f"Unknown message received: {msg} of type {type(msg)}")
+        if new_kwargs.get("stop") is None:
+            new_kwargs.pop("stop")
+        return {"messages": messages, **new_kwargs}

     def bind_tools(
         self,
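
To make the inlined payload construction concrete, a standalone sketch of the same loop with illustrative inputs; `build_payload` is a hypothetical stand-in for the private `_get_payload`, and the kwargs are assumed values:

# Hypothetical helper mirroring the inlined loop in _get_payload above.
from typing import Any, Dict, List, Sequence, Union

def build_payload(inputs: Sequence[Union[str, dict]], **kwargs: Any) -> dict:
    messages: List[Dict[str, Any]] = []
    for msg in inputs:
        if isinstance(msg, str):
            messages.append(dict(role="user", content=msg))
        elif isinstance(msg, dict):
            if msg.get("content", None) is None:
                raise ValueError(f"Message {msg} has no content")
            messages.append(msg)
        else:
            raise ValueError(f"Unknown message received: {msg} of type {type(msg)}")
    if kwargs.get("stop") is None:
        kwargs.pop("stop", None)  # tolerate a missing key, unlike the strict pop above
    return {"messages": messages, **kwargs}

print(build_payload([{"role": "user", "content": "hi"}], temperature=0.2, stop=None))
# {'messages': [{'role': 'user', 'content': 'hi'}], 'temperature': 0.2}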