Skip to content

Commit

Permalink
Support custom Ollama Host (#2044)
Browse files Browse the repository at this point in the history
* Add ollama config test

Signed-off-by: Marcel Coetzee <[email protected]>

* Merge host and port in single var

Signed-off-by: Marcel Coetzee <[email protected]>

* Remove redundant mypy ingore

Signed-off-by: Marcel Coetzee <[email protected]>

* [test](lancedb): add embedding model env var in test_model_providers

Signed-off-by: Marcel Coetzee <[email protected]>

* [fix](test): remove redundant LanceDB Ollama test case

Signed-off-by: Marcel Coetzee <[email protected]>

* Format

Signed-off-by: Marcel Coetzee <[email protected]>

* [docs](lancedb): update embedding model provider and add custom endpoint support

Signed-off-by: Marcel Coetzee <[email protected]>

---------

Signed-off-by: Marcel Coetzee <[email protected]>
  • Loading branch information
Pipboyguy authored Nov 25, 2024
1 parent bc25a60 commit f13e3f1
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 3 deletions.
2 changes: 2 additions & 0 deletions dlt/destinations/impl/lancedb/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ class LanceDBClientConfiguration(DestinationClientDwhConfiguration):
"""Embedding provider used for generating embeddings. Default is "cohere". You can find the full list of
providers at https://github.com/lancedb/lancedb/tree/main/python/python/lancedb/embeddings as well as
https://lancedb.github.io/lancedb/embeddings/default_embedding_functions/."""
embedding_model_provider_host: Optional[str] = None
"""Full host URL with protocol and port (e.g. 'http://localhost:11434'). Uses LanceDB's default if not specified, assuming the provider accepts this parameter."""
embedding_model: str = "embed-english-v3.0"
"""The model used by the embedding provider for generating embeddings.
Check with the embedding provider which options are available.
Expand Down
4 changes: 3 additions & 1 deletion dlt/destinations/impl/lancedb/lancedb_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ def __init__(
self.dataset_name = self.config.normalize_dataset_name(self.schema)

embedding_model_provider = self.config.embedding_model_provider
embedding_model_host = self.config.embedding_model_provider_host

# LanceDB doesn't provide a standardized way to set API keys across providers.
# Some use ENV variables and others allow passing api key as an argument.
Expand All @@ -259,12 +260,13 @@ def __init__(
embedding_model_provider,
self.config.credentials.embedding_model_provider_api_key,
)

self.model_func = self.registry.get(embedding_model_provider).create(
name=self.config.embedding_model,
max_retries=self.config.options.max_retries,
api_key=self.config.credentials.api_key,
**({"host": embedding_model_host} if embedding_model_host else {}),
)

self.vector_field_name = self.config.vector_field_name

@property
Expand Down
8 changes: 6 additions & 2 deletions docs/website/docs/dlt-ecosystem/destinations/lancedb.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@ Configure the destination in the dlt secrets file located at `~/.dlt/secrets.tom

```toml
[destination.lancedb]
embedding_model_provider = "cohere"
embedding_model = "embed-english-v3.0"
embedding_model_provider = "ollama"
embedding_model = "mxbai-embed-large"
embedding_model_provider_host = "http://localhost:11434" # Optional: custom endpoint for providers that support it

[destination.lancedb.credentials]
uri = ".lancedb"
api_key = "api_key" # API key to connect to LanceDB Cloud. Leave out if you are using LanceDB OSS.
Expand All @@ -47,6 +49,7 @@ embedding_model_provider_api_key = "embedding_model_provider_api_key" # Not need
- The `embedding_model` specifies the model used by the embedding provider for generating embeddings.
Check with the embedding provider which options are available.
Reference https://lancedb.github.io/lancedb/embeddings/default_embedding_functions/.
- The `embedding_model_provider_host` specifies the full host URL with protocol and port for providers that support custom endpoints (like Ollama). If not specified, the provider's default endpoint will be used.
- The `embedding_model_provider_api_key` is the API key for the embedding model provider used to generate embeddings. If you're using a provider that doesn't need authentication, such as Ollama, you don't need to supply this key.

:::info Available model providers
Expand All @@ -61,6 +64,7 @@ embedding_model_provider_api_key = "embedding_model_provider_api_key" # Not need
- "sentence-transformers"
- "huggingface"
- "colbert"
- "ollama"
:::

### Define your data source
Expand Down
44 changes: 44 additions & 0 deletions tests/load/lancedb/test_model_providers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""
Test intricacies and configuration related to each provider.
"""

import os
from typing import Iterator, Any, Generator

import pytest
from lancedb import DBConnection # type: ignore
from lancedb.embeddings import EmbeddingFunctionRegistry # type: ignore
from lancedb.table import Table # type: ignore

import dlt
from dlt.common.configuration import resolve_configuration
from dlt.common.typing import DictStrStr
from dlt.common.utils import uniq_id
from dlt.destinations.impl.lancedb import lancedb_adapter
from dlt.destinations.impl.lancedb.configuration import LanceDBClientConfiguration
from dlt.destinations.impl.lancedb.lancedb_client import LanceDBClient
from tests.load.utils import drop_active_pipeline_data, sequence_generator
from tests.pipeline.utils import assert_load_info

# Mark all tests as essential, don't remove.
pytestmark = pytest.mark.essential


@pytest.fixture(autouse=True)
def drop_lancedb_data() -> Iterator[Any]:
yield
drop_active_pipeline_data()


def test_lancedb_ollama_endpoint_configuration() -> None:
os.environ["DESTINATION__LANCEDB__EMBEDDING_MODEL_PROVIDER"] = "ollama"
os.environ["DESTINATION__LANCEDB__EMBEDDING_MODEL"] = "nomic-embed-text"
os.environ["DESTINATION__LANCEDB__EMBEDDING_MODEL_PROVIDER_HOST"] = "http://198.163.194.3:24233"

config = resolve_configuration(
LanceDBClientConfiguration()._bind_dataset_name(dataset_name="dataset"),
sections=("destination", "lancedb"),
)
assert config.embedding_model_provider == "ollama"
assert config.embedding_model == "nomic-embed-text"
assert config.embedding_model_provider_host == "http://198.163.194.3:24233"

0 comments on commit f13e3f1

Please sign in to comment.