Skip to content

Commit

Permalink
Merge pull request #15 from langchain-ai/mattf/add-api-catalog-embedd…
Browse files Browse the repository at this point in the history
…ings

add api catalog embeddings
  • Loading branch information
mattf authored Apr 2, 2024
2 parents e77f0e7 + 4bea6ec commit 3df8119
Show file tree
Hide file tree
Showing 7 changed files with 140 additions and 84 deletions.
59 changes: 30 additions & 29 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/_statics.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,46 +6,47 @@
class Model(BaseModel):
id: str
model_type: Optional[str] = None
api_type: Optional[str] = None
model_name: Optional[str] = None
client: Optional[str] = None
path: str


MODEL_SPECS = {
"playground_smaug_72b": {"model_type": "chat"},
"playground_kosmos_2": {"model_type": "image_in"},
"playground_llama2_70b": {"model_type": "chat"},
"playground_nvolveqa_40k": {"model_type": "embedding"},
"playground_nemotron_qa_8b": {"model_type": "qa"},
"playground_gemma_7b": {"model_type": "chat"},
"playground_mistral_7b": {"model_type": "chat"},
"playground_mamba_chat": {"model_type": "chat"},
"playground_phi2": {"model_type": "chat"},
"playground_sdxl": {"model_type": "image_out"},
"playground_nv_llama2_rlhf_70b": {"model_type": "chat"},
"playground_neva_22b": {"model_type": "image_in"},
"playground_yi_34b": {"model_type": "chat"},
"playground_nemotron_steerlm_8b": {"model_type": "chat"},
"playground_cuopt": {"model_type": "cuopt"},
"playground_llama_guard": {"model_type": "classifier"},
"playground_starcoder2_15b": {"model_type": "completion"},
"playground_deplot": {"model_type": "image_in"},
"playground_llama2_code_70b": {"model_type": "chat"},
"playground_gemma_2b": {"model_type": "chat"},
"playground_seamless": {"model_type": "translation"},
"playground_mixtral_8x7b": {"model_type": "chat"},
"playground_fuyu_8b": {"model_type": "image_in"},
"playground_llama2_code_34b": {"model_type": "chat"},
"playground_llama2_code_13b": {"model_type": "chat"},
"playground_steerlm_llama_70b": {"model_type": "chat"},
"playground_clip": {"model_type": "similarity"},
"playground_llama2_13b": {"model_type": "chat"},
"playground_smaug_72b": {"model_type": "chat", "api_type": "aifm"},
"playground_kosmos_2": {"model_type": "image_in", "api_type": "aifm"},
"playground_llama2_70b": {"model_type": "chat", "api_type": "aifm"},
"playground_nvolveqa_40k": {"model_type": "embedding", "api_type": "aifm"},
"playground_nemotron_qa_8b": {"model_type": "qa", "api_type": "aifm"},
"playground_gemma_7b": {"model_type": "chat", "api_type": "aifm"},
"playground_mistral_7b": {"model_type": "chat", "api_type": "aifm"},
"playground_mamba_chat": {"model_type": "chat", "api_type": "aifm"},
"playground_phi2": {"model_type": "chat", "api_type": "aifm"},
"playground_sdxl": {"model_type": "image_out", "api_type": "aifm"},
"playground_nv_llama2_rlhf_70b": {"model_type": "chat", "api_type": "aifm"},
"playground_neva_22b": {"model_type": "image_in", "api_type": "aifm"},
"playground_yi_34b": {"model_type": "chat", "api_type": "aifm"},
"playground_nemotron_steerlm_8b": {"model_type": "chat", "api_type": "aifm"},
"playground_cuopt": {"model_type": "cuopt", "api_type": "aifm"},
"playground_llama_guard": {"model_type": "classifier", "api_type": "aifm"},
"playground_starcoder2_15b": {"model_type": "completion", "api_type": "aifm"},
"playground_deplot": {"model_type": "image_in", "api_type": "aifm"},
"playground_llama2_code_70b": {"model_type": "chat", "api_type": "aifm"},
"playground_gemma_2b": {"model_type": "chat", "api_type": "aifm"},
"playground_seamless": {"model_type": "translation", "api_type": "aifm"},
"playground_mixtral_8x7b": {"model_type": "chat", "api_type": "aifm"},
"playground_fuyu_8b": {"model_type": "image_in", "api_type": "aifm"},
"playground_llama2_code_34b": {"model_type": "chat", "api_type": "aifm"},
"playground_llama2_code_13b": {"model_type": "chat", "api_type": "aifm"},
"playground_steerlm_llama_70b": {"model_type": "chat", "api_type": "aifm"},
"playground_clip": {"model_type": "similarity", "api_type": "aifm"},
"playground_llama2_13b": {"model_type": "chat", "api_type": "aifm"},
}

MODEL_SPECS.update(
{
"ai-codellama-70b": {"model_type": "chat", "model_name": "meta/codellama-70b"},
# 'ai-embedding-2b': {'model_type': 'embedding'},
"ai-embed-qa-4": {"model_type": "embedding", "model_name": "NV-Embed-QA"},
"ai-fuyu-8b": {"model_type": "image_in"},
"ai-gemma-7b": {"model_type": "chat", "model_name": "google/gemma-7b"},
"ai-google-deplot": {"model_type": "image_in"},
Expand Down
40 changes: 35 additions & 5 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Embeddings Components Derived from NVEModel/Embeddings"""

from typing import List, Literal, Optional

from langchain_core.embeddings import Embeddings
Expand All @@ -8,6 +9,8 @@
from langchain_nvidia_ai_endpoints._common import _NVIDIAClient
from langchain_nvidia_ai_endpoints.callbacks import usage_callback_var

from ._statics import MODEL_SPECS


class NVIDIAEmbeddings(_NVIDIAClient, Embeddings):
"""NVIDIA's AI Foundation Retriever Question-Answering Asymmetric Model."""
Expand All @@ -25,13 +28,40 @@ def _embed(
self, texts: List[str], model_type: Literal["passage", "query"]
) -> List[List[float]]:
"""Embed a single text entry to either passage or query type"""
response = self.client.get_req(
model_name=self.model,
payload={
# AI Foundation Model API -
# input: str | list[str] -- <= 2048 characters, <= 50 inputs
# model: "query" | "passage" -- type of input text to be embedded
# encoding_format: "float" | "base64"
# API Catalog API -
# input: str | list[str] -- char limit depends on model
# model: str -- model name, e.g. NV-Embed-QA
# encoding_format: "float" | "base64"
# input_type: "query" | "passage"
# user: str -- ignored
# truncate: "NONE" | "START" | "END" -- default "NONE", error raised if
# an input is too long
# todo: remove the playground aliases
model_name = self.model
if model_name not in MODEL_SPECS:
if f"playground_{model_name}" in MODEL_SPECS:
model_name = f"playground_{model_name}"
if MODEL_SPECS.get(model_name, {}).get("api_type", None) == "aifm":
payload = {
"input": texts,
"model": self.get_binding_model() or model_type,
"model": model_type,
"encoding_format": "float",
},
}
else: # default to the API Catalog API
payload = {
"input": texts,
"model": self.get_binding_model() or self.model,
"encoding_format": "float",
"input_type": model_type,
}

response = self.client.get_req(
model_name=self.model,
payload=payload,
endpoint="infer",
)
response.raise_for_status()
Expand Down
66 changes: 16 additions & 50 deletions libs/ai-endpoints/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions libs/ai-endpoints/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ codespell = "^2.2.0"
optional = true

[tool.poetry.group.test_integration.dependencies]
requests-mock = "^1.11.0"

[tool.poetry.group.lint]
optional = true
Expand Down
7 changes: 7 additions & 0 deletions libs/ai-endpoints/tests/integration_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ def pytest_addoption(parser: pytest.Parser) -> None:
action="store",
help="Run tests for a specific chat model",
)
parser.addoption(
"--embedding-model-id",
action="store",
help="Run tests for a specific embedding model",
)
parser.addoption(
"--all-models",
action="store_true",
Expand Down Expand Up @@ -60,6 +65,8 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:

if "embedding_model" in metafunc.fixturenames:
models = ["nvolveqa_40k"]
if metafunc.config.getoption("embedding_model_id"):
models = [metafunc.config.getoption("embedding_model_id")]
if metafunc.config.getoption("all_models"):
models = [
model.id
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Test ChatNVIDIA chat model."""

import warnings

import pytest
Expand Down
50 changes: 50 additions & 0 deletions libs/ai-endpoints/tests/integration_tests/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
Note: These tests are designed to validate the functionality of NVIDIAEmbeddings.
"""

import pytest
import requests_mock

from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings


Expand Down Expand Up @@ -46,3 +50,49 @@ async def test_nvai_play_embedding_async_documents(embedding_model: str) -> None
output = await embedding.aembed_documents(documents)
assert len(output) == 3
assert all(len(doc) == 1024 for doc in output)


def test_embed_available_models() -> None:
embedding = NVIDIAEmbeddings()
models = embedding.available_models
assert len(models) >= 2 # nvolveqa_40k and ai-embed-qa-4
assert "nvolveqa_40k" in [model.id for model in models]
assert "ai-embed-qa-4" in [model.id for model in models]


def test_embed_available_models_cached() -> None:
"""Test NVIDIA embeddings for available models."""
pytest.skip("There's a bug that needs to be fixed")
with requests_mock.Mocker(real_http=True) as mock:
embedding = NVIDIAEmbeddings()
assert not mock.called
embedding.available_models
assert mock.called
embedding.available_models
embedding.available_models
assert mock.call_count == 1


def test_embed_long_query_text(embedding_model: str) -> None:
embedding = NVIDIAEmbeddings(model=embedding_model)
text = "nvidia " * 2048
with pytest.raises(Exception):
embedding.embed_query(text)


def test_embed_many_texts(embedding_model: str) -> None:
embedding = NVIDIAEmbeddings(model=embedding_model)
texts = ["nvidia " * 32] * 1000
output = embedding.embed_documents(texts)
assert len(output) == 1000
assert all(len(embedding) == 1024 for embedding in output)


def test_embed_mixed_long_texts(embedding_model: str) -> None:
if embedding_model == "nvolveqa_40k":
pytest.skip("AI Foundation Model trucates by default")
embedding = NVIDIAEmbeddings(model=embedding_model)
texts = ["nvidia " * 32] * 50
texts[42] = "nvidia " * 2048
with pytest.raises(Exception):
embedding.embed_documents(texts)

0 comments on commit 3df8119

Please sign in to comment.