Skip to content

Commit

Permalink
Merge pull request #102 from langchain-ai/mattf/remove-deprecated-model-type-param
Browse files Browse the repository at this point in the history

remove deprecated model_type param (embedding)
  • Loading branch information
mattf authored Sep 21, 2024
2 parents 6325c6a + 0a45ce5 commit c73b9b6
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 36 deletions.
22 changes: 2 additions & 20 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
ConfigDict,
Field,
PrivateAttr,
field_validator,
)

from langchain_nvidia_ai_endpoints._common import _NVIDIAClient
Expand Down Expand Up @@ -50,9 +49,6 @@ class NVIDIAEmbeddings(BaseModel, Embeddings):
),
)
max_batch_size: int = Field(default=_DEFAULT_BATCH_SIZE)
model_type: Optional[Literal["passage", "query"]] = Field(
None, description="(DEPRECATED) The type of text to be embedded."
)

def __init__(self, **kwargs: Any):
"""
Expand Down Expand Up @@ -111,18 +107,6 @@ def __init__(self, **kwargs: Any):
)
self.truncate = "END"

@field_validator("model_type")
def _validate_model_type(
cls, v: Optional[Literal["passage", "query"]]
) -> Optional[Literal["passage", "query"]]:
if v:
warnings.warn(
"Warning: `model_type` is deprecated and will be removed "
"in a future release. Please use `embed_query` or "
"`embed_documents` appropriately."
)
return v

@property
def available_models(self) -> List[Model]:
"""
Expand Down Expand Up @@ -175,7 +159,7 @@ def _embed(

def embed_query(self, text: str) -> List[float]:
"""Input pathway for query embeddings."""
return self._embed([text], model_type=self.model_type or "query")[0]
return self._embed([text], model_type="query")[0]

def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Input pathway for document embeddings."""
Expand All @@ -187,9 +171,7 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
all_embeddings = []
for i in range(0, len(texts), self.max_batch_size):
batch = texts[i : i + self.max_batch_size]
all_embeddings.extend(
self._embed(batch, model_type=self.model_type or "passage")
)
all_embeddings.extend(self._embed(batch, model_type="passage"))
return all_embeddings

def _invoke_callback_vars(self, response: dict) -> None:
Expand Down
2 changes: 0 additions & 2 deletions libs/ai-endpoints/tests/integration_tests/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,5 @@ def test_embed_nvolveqa_40k_compat(nvolveqa_40k: str, mode: dict) -> None:
assert len(output) > 3


# todo: test model_type ("passage" and embed_query,
# "query" and embed_documents; compare results)
# todo: test max_length > max length accepted by the model
# todo: test max_batch_size > max batch size accepted by the model
15 changes: 1 addition & 14 deletions libs/ai-endpoints/tests/unit_tests/test_embeddings.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Generator, Literal
from typing import Any, Generator

import pytest
from requests_mock import Mocker
Expand Down Expand Up @@ -86,17 +86,4 @@ def test_embed_query_truncate_invalid(truncate: Any) -> None:
NVIDIAEmbeddings(truncate=truncate)


@pytest.mark.parametrize("model_type", ["query", "passage"])
def test_embed_model_type_deprecated(model_type: Literal["query", "passage"]) -> None:
with pytest.warns(UserWarning) as record:
NVIDIAEmbeddings(api_key="BOGUS", model_type=model_type)
assert len(record) == 1
assert "`model_type` is deprecated" in str(record[0].message)
x = NVIDIAEmbeddings(api_key="BOGUS")
with pytest.warns(UserWarning) as record:
x.model_type = model_type
assert len(record) == 1
assert "`model_type` is deprecated" in str(record[0].message)


# todo: test max_batch_size (-50, 0, 1, 50)

0 comments on commit c73b9b6

Please sign in to comment.