Skip to content

Commit

Permalink
Merge pull request #18 from langchain-ai/erick/ai-endpoints-patch-standard-tests
Browse files Browse the repository at this point in the history

ai-endpoints[patch]: standard tests
  • Loading branch information
mattf authored Sep 19, 2024
2 parents 2f9b86c + f611783 commit fa4018b
Show file tree
Hide file tree
Showing 5 changed files with 453 additions and 242 deletions.
23 changes: 23 additions & 0 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
)
from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import BaseChatModel, LanguageModelInput
from langchain_core.language_models.chat_models import LangSmithParams
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
Expand Down Expand Up @@ -282,6 +283,28 @@ def _llm_type(self) -> str:
"""Return type of NVIDIA AI Foundation Model Interface."""
return "chat-nvidia-ai-playground"

def _get_ls_params(
self,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> LangSmithParams:
"""Get standard LangSmith parameters for tracing."""
params = self._get_invocation_params(stop=stop, **kwargs)
return LangSmithParams(
ls_provider="NVIDIA",
# error: Incompatible types (expression has type "Optional[str]",
# TypedDict item "ls_model_name" has type "str") [typeddict-item]
ls_model_name=self.model or "UNKNOWN",
ls_model_type="chat",
ls_temperature=params.get("temperature", self.temperature),
ls_max_tokens=params.get("max_tokens", self.max_tokens),
# mypy error: Extra keys ("ls_top_p", "ls_seed")
# for TypedDict "LangSmithParams" [typeddict-item]
# ls_top_p=params.get("top_p", self.top_p),
# ls_seed=params.get("seed", self.seed),
ls_stop=params.get("stop", self.stop),
)

def _generate(
self,
messages: List[BaseMessage],
Expand Down
Loading

0 comments on commit fa4018b

Please sign in to comment.