Merge pull request #74 from langchain-ai/mattf/add-chat-tool-calling
add chat tool calling
mattf authored Jul 23, 2024
2 parents bb67c4f + a26088e commit e207e02
Showing 11 changed files with 1,600 additions and 89 deletions.
69 changes: 69 additions & 0 deletions libs/ai-endpoints/docs/chat/nvidia_ai_endpoints.ipynb
@@ -548,6 +548,75 @@
"source": [
"conversation.invoke(\"Tell me about yourself.\")[\"response\"]"
]
},
{
"cell_type": "markdown",
"id": "f3cbbba0",
"metadata": {},
"source": [
"## Tool calling\n",
"\n",
"Starting in v0.2, `ChatNVIDIA` supports [bind_tools](https://api.python.langchain.com/en/latest/language_models/langchain_core.language_models.chat_models.BaseChatModel.html#langchain_core.language_models.chat_models.BaseChatModel.bind_tools).\n",
"\n",
"`ChatNVIDIA` provides integration with the variety of models on [build.nvidia.com](https://build.nvidia.com) as well as local NIMs. Not all these models are trained for tool calling. Be sure to select a model that does have tool calling for your experimention and applications."
]
},
{
"cell_type": "markdown",
"id": "6f7b535e",
"metadata": {},
"source": [
"You can get a list of models that are known to support tool calling with,"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e36c8911",
"metadata": {},
"outputs": [],
"source": [
"tool_models = [model for model in ChatNVIDIA.get_available_models() if model.supports_tools]\n",
"tool_models"
]
},
{
"cell_type": "markdown",
"id": "b01d75a7",
"metadata": {},
"source": [
"With a tool capable model,"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bd54f174",
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.pydantic_v1 import Field\n",
"from langchain_core.tools import tool\n",
"\n",
"@tool\n",
"def get_current_weather(\n",
" location: str = Field(..., description=\"The location to get the weather for.\")\n",
"):\n",
" \"\"\"Get the current weather for a location.\"\"\"\n",
" ...\n",
"\n",
"llm = ChatNVIDIA(model=tool_models[0].id).bind_tools(tools=[get_current_weather])\n",
"response = llm.invoke(\"What is the weather in Boston?\")\n",
"response.tool_calls"
]
},
{
"cell_type": "markdown",
"id": "e08df68c",
"metadata": {},
"source": [
"See [How to use chat models to call tools](https://python.langchain.com/v0.2/docs/how_to/tool_calling/) for additional examples."
]
}
],
"metadata": {
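The notebook cells above stop at inspecting `response.tool_calls`. For context, a minimal sketch of the full round trip — executing the requested tool and feeding the result back to the model — might look like the following; the model id, the conversation wiring, and the canned weather result are illustrative assumptions, not part of this diff:

# Sketch of a complete tool-calling round trip with ChatNVIDIA.
# The model id and the stubbed weather result are assumptions for
# illustration; any model from the tool_models list would do.
from langchain_core.messages import ToolMessage
from langchain_core.pydantic_v1 import Field
from langchain_core.tools import tool
from langchain_nvidia_ai_endpoints import ChatNVIDIA

@tool
def get_current_weather(
    location: str = Field(..., description="The location to get the weather for.")
) -> str:
    """Get the current weather for a location."""
    return f"It is sunny in {location}."  # stand-in for a real lookup

llm = ChatNVIDIA(model="meta/llama-3.1-70b-instruct").bind_tools([get_current_weather])
messages = [("human", "What is the weather in Boston?")]
response = llm.invoke(messages)

# Run each tool the model asked for and return the results to it.
messages.append(response)
for call in response.tool_calls:
    result = get_current_weather.invoke(call["args"])
    messages.append(ToolMessage(content=result, tool_call_id=call["id"]))
print(llm.invoke(messages).content)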
8 changes: 7 additions & 1 deletion libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py
@@ -368,13 +368,17 @@ def _aggregate_msgs(self, msg_list: Sequence[dict]) -> Tuple[dict, bool]:
content_buffer: Dict[str, Any] = dict()
content_holder: Dict[Any, Any] = dict()
usage_holder: Dict[Any, Any] = dict() ####
finish_reason_holder: Optional[str] = None
is_stopped = False
for msg in msg_list:
usage_holder = msg.get("usage", {}) ####
if "choices" in msg:
## Tease out ['choices'][0]...['delta'/'message']
msg = msg.get("choices", [{}])[0]
is_stopped = msg.get("finish_reason", "") == "stop"
# todo: this needs to be fixed; only using the
# first choice breaks the interface
finish_reason_holder = msg.get("finish_reason", None)
is_stopped = finish_reason_holder == "stop"
msg = msg.get("delta", msg.get("message", msg.get("text", "")))
if not isinstance(msg, dict):
msg = {"content": msg}
@@ -392,6 +396,8 @@ def _aggregate_msgs(self, msg_list: Sequence[dict]) -> Tuple[dict, bool]:
content_holder = {**content_holder, **content_buffer}
if usage_holder:
content_holder.update(token_usage=usage_holder) ####
if finish_reason_holder:
content_holder.update(finish_reason=finish_reason_holder)
return content_holder, is_stopped

####################################################################################
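The `_common.py` change above captures `finish_reason` while streamed chunks are folded together, since only the final choice in a stream carries it. A standalone sketch with hypothetical chunk data (not taken from this PR) illustrates the aggregation the holder performs:

# Hypothetical streamed chunks, shaped like the 'choices' entries that
# _aggregate_msgs iterates over; only the last chunk has a finish_reason.
chunks = [
    {"choices": [{"delta": {"content": "It is "}, "finish_reason": None}]},
    {"choices": [{"delta": {"content": "sunny."}, "finish_reason": None}]},
    {"choices": [{"delta": {"content": ""}, "finish_reason": "stop"}]},
]

content, finish_reason = "", None
for chunk in chunks:
    choice = chunk["choices"][0]  # first choice only, as the todo above notes
    content += choice["delta"].get("content") or ""
    if choice.get("finish_reason") is not None:
        finish_reason = choice["finish_reason"]

print(content)        # It is sunny.
print(finish_reason)  # stop -- now surfaced as content_holder["finish_reason"]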
21 changes: 21 additions & 0 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/_statics.py
@@ -1,3 +1,4 @@
import os
import warnings
from typing import Literal, Optional

@@ -13,6 +14,7 @@ class Model(BaseModel):
client: client name, e.g. ChatNVIDIA, NVIDIAEmbeddings, NVIDIARerank
endpoint: custom endpoint for the model
aliases: list of aliases for the model
supports_tools: whether the model supports tool calling
All aliases are deprecated and will trigger a warning when used.
"""
@@ -25,6 +27,7 @@ class Model(BaseModel):
client: Optional[Literal["ChatNVIDIA", "NVIDIAEmbeddings", "NVIDIARerank"]] = None
endpoint: Optional[str] = None
aliases: Optional[list] = None
supports_tools: Optional[bool] = False
base_model: Optional[str] = None

def __hash__(self) -> int:
@@ -286,16 +289,19 @@ def validate_client(cls, client: str, values: dict) -> str:
id="meta/llama-3.1-8b-instruct",
model_type="chat",
client="ChatNVIDIA",
supports_tools=True,
),
"meta/llama-3.1-70b-instruct": Model(
id="meta/llama-3.1-70b-instruct",
model_type="chat",
client="ChatNVIDIA",
supports_tools=True,
),
"meta/llama-3.1-405b-instruct": Model(
id="meta/llama-3.1-405b-instruct",
model_type="chat",
client="ChatNVIDIA",
supports_tools=True,
),
}

@@ -438,6 +444,18 @@ def validate_client(cls, client: str, values: dict) -> str:
# ),
# }


OPENAI_MODEL_TABLE = {
"gpt-3.5-turbo": Model(
id="gpt-3.5-turbo",
model_type="chat",
client="ChatNVIDIA",
endpoint="https://api.openai.com/v1/chat/completions",
supports_tools=True,
),
}


MODEL_TABLE = {
**CHAT_MODEL_TABLE,
**QA_MODEL_TABLE,
@@ -446,6 +464,9 @@ def validate_client(cls, client: str, values: dict) -> str:
**RANKING_MODEL_TABLE,
}

if "_INCLUDE_OPENAI" in os.environ:
MODEL_TABLE.update(OPENAI_MODEL_TABLE)


def register_model(model: Model) -> None:
"""
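Because `supports_tools` defaults to `False`, a model outside the built-in tables — a locally hosted NIM, say — needs the flag set when it is registered in order to show up in the notebook's `tool_models` filter. A sketch, assuming the `register_model` helper shown above is importable from the private `_statics` module, with placeholder id and endpoint values:

# Sketch: registering a hypothetical tool-capable local NIM so that it
# passes the supports_tools filter. The id and endpoint are placeholders.
from langchain_nvidia_ai_endpoints import ChatNVIDIA
from langchain_nvidia_ai_endpoints._statics import Model, register_model

register_model(
    Model(
        id="my-org/my-tool-model",
        model_type="chat",
        client="ChatNVIDIA",
        endpoint="http://localhost:8000/v1/chat/completions",
        supports_tools=True,
    )
)

tool_models = [m for m in ChatNVIDIA.get_available_models() if m.supports_tools]

Note also that `MODEL_TABLE` is assembled at import time, so the `_INCLUDE_OPENAI` gate added above only takes effect if the variable is set in the environment before `langchain_nvidia_ai_endpoints` is first imported.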
