Merge branch 'dev-v0.3' into remove-deprecated-model-type-param
mattf committed Sep 21, 2024
2 parents 0ae895a + 6325c6a commit 0a45ce5
Showing 31 changed files with 1,193 additions and 743 deletions.
1 change: 0 additions & 1 deletion libs/ai-endpoints/Makefile
@@ -33,7 +33,6 @@ lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test

lint lint_diff lint_package lint_tests:
-	./scripts/check_pydantic.sh .
	./scripts/lint_imports.sh
	poetry run ruff .
	[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
334 changes: 89 additions & 245 deletions libs/ai-endpoints/docs/chat/nvidia_ai_endpoints.ipynb

Large diffs are not rendered by default.

151 changes: 82 additions & 69 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py
@@ -20,41 +20,47 @@
from urllib.parse import urlparse, urlunparse

import requests
-from langchain_core.pydantic_v1 import (
+from pydantic import (
    BaseModel,
+    ConfigDict,
    Field,
    PrivateAttr,
    SecretStr,
-    root_validator,
-    validator,
+    field_validator,
)
from requests.models import Response

from langchain_nvidia_ai_endpoints._statics import MODEL_TABLE, Model, determine_model
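For readers tracking the pydantic migration in this hunk: the commit swaps langchain_core.pydantic_v1 for native pydantic v2 imports. A minimal sketch (illustrative class and field names, not from this repo) of how a v1 validator maps onto a v2 field_validator:

    # Hypothetical example -- not part of the commit; shows the
    # v1 -> v2 validator migration pattern this diff applies.
    from pydantic import BaseModel, field_validator


    class ExampleClient(BaseModel):
        base_url: str

        # pydantic v1 spelled this @validator("base_url")
        @field_validator("base_url")
        @classmethod
        def _check_base_url(cls, v: str) -> str:
            # validators still receive the field value and return the
            # (possibly normalized) value to store
            if not v.startswith(("http://", "https://")):
                raise ValueError(f"expected an http(s) URL, got {v!r}")
            return v


    print(ExampleClient(base_url="https://integrate.api.nvidia.com/v1").base_url)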

logger = logging.getLogger(__name__)

+_API_KEY_VAR = "NVIDIA_API_KEY"
+_BASE_URL_VAR = "NVIDIA_BASE_URL"


class _NVIDIAClient(BaseModel):
"""
Low level client library interface to NIM endpoints.
"""

default_hosted_model_name: str = Field(..., description="Default model name to use")
-    model_name: Optional[str] = Field(..., description="Name of the model to invoke")
+    # "mdl_name" because "model_" is a protected namespace in pydantic
+    mdl_name: Optional[str] = Field(..., description="Name of the model to invoke")
model: Optional[Model] = Field(None, description="The model to invoke")
is_hosted: bool = Field(True)
cls: str = Field(..., description="Class Name")

    # todo: add a validator for requests.Response (last_response attribute) and
    #       remove arbitrary_types_allowed=True
-    class Config:
-        arbitrary_types_allowed = True
+    model_config = ConfigDict(
+        arbitrary_types_allowed=True,
+    )
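A short sketch of the config change above (hypothetical model, not from this repo): pydantic v2 replaces the nested class Config with a model_config = ConfigDict(...) assignment.

    # Hypothetical example of the v2 config style used above.
    import requests
    from pydantic import BaseModel, ConfigDict


    class HoldsASession(BaseModel):
        # pydantic v1 equivalent:
        #     class Config:
        #         arbitrary_types_allowed = True
        model_config = ConfigDict(arbitrary_types_allowed=True)

        session: requests.Session  # non-pydantic type, hence the flag


    HoldsASession(session=requests.Session())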

## Core defaults. These probably should not be changed
-    _api_key_var = "NVIDIA_API_KEY"
    base_url: str = Field(
-        ...,
+        default_factory=lambda: os.getenv(
+            _BASE_URL_VAR, "https://integrate.api.nvidia.com/v1"
+        ),
description="Base URL for standard inference",
)
infer_path: str = Field(
@@ -71,13 +77,28 @@ class Config:
)
get_session_fn: Callable = Field(requests.Session)

-    api_key: Optional[SecretStr] = Field(description="API Key for service of choice")
+    api_key: Optional[SecretStr] = Field(
+        default_factory=lambda: SecretStr(
+            os.getenv(_API_KEY_VAR, "INTERNAL_LCNVAIE_ERROR")
+        )
+        if _API_KEY_VAR in os.environ
+        else None,
+        description="API Key for service of choice",
+    )
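The api_key default above only wraps a SecretStr when the environment variable is actually set; otherwise the field stays None. A standalone sketch of that pattern, with a hypothetical variable name:

    # Illustrative example of an env-guarded default_factory.
    import os
    from typing import Optional

    from pydantic import BaseModel, Field, SecretStr

    _EXAMPLE_KEY_VAR = "EXAMPLE_API_KEY"  # assumed name, for illustration only


    class Example(BaseModel):
        api_key: Optional[SecretStr] = Field(
            # the ternary is inside the lambda: unset -> None, set -> SecretStr
            default_factory=lambda: SecretStr(os.environ[_EXAMPLE_KEY_VAR])
            if _EXAMPLE_KEY_VAR in os.environ
            else None
        )


    os.environ.pop(_EXAMPLE_KEY_VAR, None)
    assert Example().api_key is None
    os.environ[_EXAMPLE_KEY_VAR] = "nvapi-xyz"
    assert Example().api_key is not None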

## Generation arguments
-    timeout: float = Field(60, ge=0, description="Timeout for waiting on response (s)")
-    interval: float = Field(0.02, ge=0, description="Interval for pulling response")
+    timeout: float = Field(
+        60,
+        ge=0,
+        description="The minimum amount of time (in sec) to poll after a 202 response",
+    )
+    interval: float = Field(
+        0.02,
+        ge=0,
+        description="Interval (in sec) between polling attempts after a 202 response",
+    )
last_inputs: Optional[dict] = Field(
description="Last inputs sent over to the server"
default={}, description="Last inputs sent over to the server"
)
last_response: Response = Field(
None, description="Last response sent from the server"
@@ -103,47 +124,25 @@ class Config:
###################################################################################
################### Validation and Initialization #################################

@validator("base_url")
@field_validator("base_url")
def _validate_base_url(cls, v: str) -> str:
        ## Making sure /v1 in added to the url
        if v is not None:
-            result = urlparse(v)
-            expected_format = "Expected format is 'http://host:port'."
-            # Ensure scheme and netloc (domain name) are present
-            if not (result.scheme and result.netloc):
-                raise ValueError(f"Invalid base_url format. {expected_format} Got: {v}")
-        return v
-
-    @root_validator(pre=True)
-    def _preprocess_args(cls, values: Dict[str, Any]) -> Dict[str, Any]:
-        # if api_key is not provided or None,
-        # try to get it from the environment
-        # we can't use Field(default_factory=...)
-        # because construction may happen with api_key=None
-        if values.get("api_key") is None:
-            values["api_key"] = os.getenv(cls._api_key_var)
-
-        ## Making sure /v1 in added to the url, followed by infer_path
-        if "base_url" in values:
-            base_url = values["base_url"].strip("/")
-            parsed = urlparse(base_url)
-            expected_format = "Expected format is: http://host:port"
+            parsed = urlparse(v)

            # Ensure scheme and netloc (domain name) are present
            if not (parsed.scheme and parsed.netloc):
-                raise ValueError(
-                    f"Invalid base_url format. {expected_format} Got: {base_url}"
-                )
+                expected_format = "Expected format is: http://host:port"
+                raise ValueError(f"Invalid base_url format. {expected_format} Got: {v}")

-            if base_url.endswith(
+            if v.strip("/").endswith(
                ("/embeddings", "/completions", "/rankings", "/reranking")
            ):
-                warnings.warn(f"Using {base_url}, ignoring the rest")
+                warnings.warn(f"Using {v}, ignoring the rest")

-            values["base_url"] = base_url = urlunparse(
-                (parsed.scheme, parsed.netloc, "v1", None, None, None)
-            )
-            values["infer_path"] = values["infer_path"].format(base_url=base_url)
+            v = urlunparse((parsed.scheme, parsed.netloc, "v1", None, None, None))

-        return values
+        return v
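The rewritten validator normalizes whatever URL it receives down to scheme + host + /v1. A standalone sketch of that behavior:

    # Standalone sketch of the normalization in _validate_base_url above.
    from urllib.parse import urlparse, urlunparse


    def normalize_base_url(v: str) -> str:
        parsed = urlparse(v)
        if not (parsed.scheme and parsed.netloc):
            raise ValueError(f"Invalid base_url format. Got: {v}")
        # any path ("/v1/chat/completions", "/embeddings", ...) becomes "/v1"
        return urlunparse((parsed.scheme, parsed.netloc, "v1", None, None, None))


    assert normalize_base_url("http://localhost:8000") == "http://localhost:8000/v1"
    assert (
        normalize_base_url("http://localhost:8000/v1/embeddings")
        == "http://localhost:8000/v1"
    )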

# final validation after model is constructed
# todo: when pydantic v2 is available,
@@ -165,10 +164,10 @@ def __init__(self, **kwargs: Any):
)

# set default model for hosted endpoint
-            if not self.model_name:
-                self.model_name = self.default_hosted_model_name
+            if not self.mdl_name:
+                self.mdl_name = self.default_hosted_model_name

-            if model := determine_model(self.model_name):
+            if model := determine_model(self.mdl_name):
if not model.client:
warnings.warn(f"Unable to determine validity of {model.id}")
elif model.client != self.cls:
@@ -186,37 +185,37 @@ def __init__(self, **kwargs: Any):
                candidates = [
                    model
                    for model in self.available_models
-                    if model.id == self.model_name
+                    if model.id == self.mdl_name
                ]
                assert len(candidates) <= 1, (
-                    f"Multiple candidates for {self.model_name} "
+                    f"Multiple candidates for {self.mdl_name} "
                    f"in `available_models`: {candidates}"
                )
                if candidates:
                    model = candidates[0]
                    warnings.warn(
-                        f"Found {self.model_name} in available_models, but type is "
+                        f"Found {self.mdl_name} in available_models, but type is "
                        "unknown and inference may fail."
                    )
                else:
                    raise ValueError(
-                        f"Model {self.model_name} is unknown, check `available_models`"
+                        f"Model {self.mdl_name} is unknown, check `available_models`"
                    )
            self.model = model
-            self.model_name = self.model.id  # name may change because of aliasing
+            self.mdl_name = self.model.id  # name may change because of aliasing
        else:
            # set default model
-            if not self.model_name:
+            if not self.mdl_name:
                valid_models = [
                    model
                    for model in self.available_models
                    if not model.base_model or model.base_model == model.id
                ]
                self.model = next(iter(valid_models), None)
                if self.model:
-                    self.model_name = self.model.id
+                    self.mdl_name = self.model.id
                    warnings.warn(
-                        f"Default model is set as: {self.model_name}. \n"
+                        f"Default model is set as: {self.mdl_name}. \n"
                        "Set model using model parameter. \n"
                        "To get available models use available_models property.",
                        UserWarning,
@@ -233,15 +232,15 @@ def is_lc_serializable(cls) -> bool:

@property
def lc_secrets(self) -> Dict[str, str]:
return {"api_key": self._api_key_var}
return {"api_key": _API_KEY_VAR}

@property
def lc_attributes(self) -> Dict[str, Any]:
attributes: Dict[str, Any] = {}
attributes["base_url"] = self.base_url

-        if self.model_name:
-            attributes["model"] = self.model_name
+        if self.mdl_name:
+            attributes["model"] = self.mdl_name

return attributes

@@ -332,11 +331,15 @@ def _post(
self,
invoke_url: str,
payload: Optional[dict] = {},
+        extra_headers: dict = {},
    ) -> Tuple[Response, requests.Session]:
        """Method for posting to the AI Foundation Model Function API."""
        self.last_inputs = {
            "url": invoke_url,
-            "headers": self.headers_tmpl["call"],
+            "headers": {
+                **self.headers_tmpl["call"],
+                **extra_headers,
+            },
            "json": payload,
        }
session = self.get_session_fn()
@@ -372,9 +375,7 @@ def _wait(self, response: Response, session: requests.Session) -> Response:
start_time = time.time()
# note: the local NIM does not return a 202 status code
# (per RL 22may2024 circa 24.05)
-        while (
-            response.status_code == 202
-        ):  # todo: there are no tests that reach this point
+        while response.status_code == 202:
time.sleep(self.interval)
if (time.time() - start_time) > self.timeout:
raise TimeoutError(
@@ -385,10 +386,12 @@ def _wait(self, response: Response, session: requests.Session) -> Response:
"NVCF-REQID" in response.headers
), "Received 202 response with no request id to follow"
request_id = response.headers.get("NVCF-REQID")
+            # todo: this needs testing, missing auth header update
+            payload = {
+                "url": self.polling_url_tmpl.format(request_id=request_id),
+                "headers": self.headers_tmpl["call"],
+            }
            self.last_response = response = session.get(
-                self.polling_url_tmpl.format(request_id=request_id),
-                headers=self.headers_tmpl["call"],
+                **self.__add_authorization(payload)
)
self._try_raise(response)
return response
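For context, the 202 handling above is a poll-until-done loop: sleep interval seconds between GETs and give up after timeout. A minimal self-contained sketch of that contract (hypothetical polling endpoint; the real code also re-applies the Authorization header via __add_authorization):

    # Minimal sketch of the 202 polling loop in _wait; not the actual client.
    import time

    import requests


    def wait_for_result(
        session: requests.Session,
        polling_url: str,  # hypothetical NVCF polling endpoint
        headers: dict,
        timeout: float = 60.0,
        interval: float = 0.02,
    ) -> requests.Response:
        start = time.time()
        response = session.get(polling_url, headers=headers)
        while response.status_code == 202:  # 202 = accepted, result not ready
            time.sleep(interval)
            if time.time() - start > timeout:
                raise TimeoutError(f"Timeout of {timeout}s exceeded for {polling_url}")
            response = session.get(polling_url, headers=headers)
        response.raise_for_status()
        return response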
@@ -444,9 +447,12 @@ def _try_raise(self, response: Response) -> None:
def get_req(
self,
payload: dict = {},
+        extra_headers: dict = {},
    ) -> Response:
        """Post to the API."""
-        response, session = self._post(self.infer_url, payload)
+        response, session = self._post(
+            self.infer_url, payload, extra_headers=extra_headers
+        )
        return self._wait(response, session)

def postprocess(
@@ -485,7 +491,10 @@ def _aggregate_msgs(self, msg_list: Sequence[dict]) -> Tuple[dict, bool]:
usage_holder = msg.get("usage", {}) ####
if "choices" in msg:
## Tease out ['choices'][0]...['delta'/'message']
msg = msg.get("choices", [{}])[0]
# when streaming w/ usage info, we may get a response
# w/ choices: [] that includes final usage info
choices = msg.get("choices", [{}])
msg = choices[0] if choices else {}
# todo: this meeds to be fixed, the fact we only
# use the first choice breaks the interface
finish_reason_holder = msg.get("finish_reason", None)
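The new choices guard exists because, when streaming with usage reporting enabled, the final chunk can carry usage statistics alongside an empty choices list, and indexing [0] there raises IndexError. An illustrative chunk shape (assumed values, based on the comment above):

    # Assumed shape of a final usage-only streaming chunk; values invented.
    final_chunk = {
        "id": "chat-abc123",
        "choices": [],  # empty on the usage-only chunk
        "usage": {"prompt_tokens": 12, "completion_tokens": 34, "total_tokens": 46},
    }

    # old code: final_chunk.get("choices", [{}])[0] -> IndexError
    choices = final_chunk.get("choices", [{}])
    msg = choices[0] if choices else {}
    assert msg == {}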
@@ -517,18 +526,22 @@ def _aggregate_msgs(self, msg_list: Sequence[dict]) -> Tuple[dict, bool]:
def get_req_stream(
self,
payload: dict,
+        extra_headers: dict = {},
) -> Iterator[Dict]:
self.last_inputs = {
"url": self.infer_url,
"headers": self.headers_tmpl["stream"],
"headers": {
**self.headers_tmpl["stream"],
**extra_headers,
},
"json": payload,
}

response = self.get_session_fn().post(
stream=True, **self.__add_authorization(self.last_inputs)
)
self._try_raise(response)
-        call = self.copy()
+        call: _NVIDIAClient = self.model_copy()

def out_gen() -> Generator[dict, Any, Any]:
## Good for client, since it allows self.last_inputs
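One more pydantic v2 rename visible in this hunk: BaseModel.copy() is deprecated in favor of model_copy(), which shallow-copies by default. A trivial sketch:

    # Illustrative example of the copy() -> model_copy() rename.
    from pydantic import BaseModel


    class M(BaseModel):
        x: int = 1


    m = M(x=2)
    m2 = m.model_copy()  # shallow copy with the same field values
    assert m2.x == 2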
