Skip to content

Commit

Permalink
Merge branch 'main' into dev-v0.3
Browse files Browse the repository at this point in the history
  • Loading branch information
mattf committed Sep 19, 2024
2 parents 1eeb30f + ac7408e commit 6325c6a
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 11 deletions.
23 changes: 23 additions & 0 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
)
from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import BaseChatModel, LanguageModelInput
from langchain_core.language_models.chat_models import LangSmithParams
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
Expand Down Expand Up @@ -342,6 +343,28 @@ def _llm_type(self) -> str:
"""Return type of NVIDIA AI Foundation Model Interface."""
return "chat-nvidia-ai-playground"

def _get_ls_params(
    self,
    stop: Optional[List[str]] = None,
    **kwargs: Any,
) -> LangSmithParams:
    """Return the standard LangSmith parameters used for tracing this model.

    Args:
        stop: Optional stop sequences forwarded to the invocation params.
        **kwargs: Additional invocation keyword arguments.

    Returns:
        A ``LangSmithParams`` TypedDict describing provider, model name,
        model type, and sampling settings for the trace.
    """
    invocation_params = self._get_invocation_params(stop=stop, **kwargs)
    # ls_model_name is typed as a plain str, but self.model is Optional[str];
    # substitute a sentinel when the model is unset.
    model_name = self.model or "UNKNOWN"
    # NOTE: ls_top_p / ls_seed are intentionally omitted — they are not
    # members of the LangSmithParams TypedDict (mypy typeddict-item error).
    return LangSmithParams(
        ls_provider="NVIDIA",
        ls_model_name=model_name,
        ls_model_type="chat",
        ls_temperature=invocation_params.get("temperature", self.temperature),
        ls_max_tokens=invocation_params.get("max_tokens", self.max_tokens),
        ls_stop=invocation_params.get("stop", self.stop),
    )

def _generate(
self,
messages: List[BaseMessage],
Expand Down
44 changes: 33 additions & 11 deletions libs/ai-endpoints/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions libs/ai-endpoints/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1"
langchain-core = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/core" }
requests-mock = "^1.11.0"
langchain-standard-tests = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/standard-tests" }
faker = "^24.4.0"

[tool.poetry.group.codespell]
Expand Down
23 changes: 23 additions & 0 deletions libs/ai-endpoints/tests/integration_tests/test_standard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""Standard LangChain interface tests"""

from typing import Type

import pytest
from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.integration_tests import ChatModelIntegrationTests

from langchain_nvidia_ai_endpoints import ChatNVIDIA


class TestNVIDIAStandard(ChatModelIntegrationTests):
    """Run the shared LangChain chat-model integration suite against ChatNVIDIA."""

    @property
    def chat_model_params(self) -> dict:
        """Constructor kwargs used to instantiate the model under test."""
        return dict(model="meta/llama-3.1-8b-instruct")

    @property
    def chat_model_class(self) -> Type[BaseChatModel]:
        """The chat-model implementation exercised by the standard tests."""
        return ChatNVIDIA

    @pytest.mark.xfail(reason="anthropic-style list content not supported")
    def test_tool_message_histories_list_content(self, model: BaseChatModel) -> None:
        """Delegate to the base suite; expected to fail for this provider."""
        return super().test_tool_message_histories_list_content(model)
18 changes: 18 additions & 0 deletions libs/ai-endpoints/tests/unit_tests/test_standard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""Standard LangChain interface tests"""

from typing import Type

from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.unit_tests import ChatModelUnitTests

from langchain_nvidia_ai_endpoints import ChatNVIDIA


class TestNVIDIAStandard(ChatModelUnitTests):
    """Run the shared LangChain chat-model unit-test suite against ChatNVIDIA."""

    @property
    def chat_model_params(self) -> dict:
        """Constructor kwargs used to instantiate the model under test."""
        return dict(model="meta/llama-3.1-8b-instruct")

    @property
    def chat_model_class(self) -> Type[BaseChatModel]:
        """The chat-model implementation exercised by the standard tests."""
        return ChatNVIDIA

0 comments on commit 6325c6a

Please sign in to comment.