From f738d4b0a08fa75e91a00d1636c7bc09d71fe1b0 Mon Sep 17 00:00:00 2001
From: David Ankin
Date: Thu, 27 Jun 2024 06:34:43 -0400
Subject: [PATCH] fix: improve ollama docs, s/ollama_dir/ollama_home/g

---
 .../ollama/testcontainers/ollama/__init__.py | 50 +++++++++++++++++--
 modules/ollama/tests/test_ollama.py          |  4 +-
 2 files changed, 47 insertions(+), 7 deletions(-)

diff --git a/modules/ollama/testcontainers/ollama/__init__.py b/modules/ollama/testcontainers/ollama/__init__.py
index 286aabd5..ea089f14 100644
--- a/modules/ollama/testcontainers/ollama/__init__.py
+++ b/modules/ollama/testcontainers/ollama/__init__.py
@@ -34,7 +34,13 @@ class OllamaContainer(DockerContainer):
     """
     Ollama Container
 
-    Example:
+    :param: image - the ollama image to use (default: :code:`ollama/ollama:0.1.44`)
+    :param: ollama_home - the directory to mount for model data (default: None)
+
+    you may pass :code:`pathlib.Path.home() / ".ollama"` to re-use models
+    that have already been pulled with ollama running on this host outside the container.
+
+    Examples:
 
         .. doctest::
 
@@ -42,6 +48,40 @@ class OllamaContainer(DockerContainer):
             >>> with OllamaContainer() as ollama:
             ...     ollama.list_models()
             []
+
+        .. code-block:: python
+
+            >>> from json import loads
+            >>> from pathlib import Path
+            >>> from requests import post
+            >>> from testcontainers.ollama import OllamaContainer
+            >>> def split_by_line(generator):
+            ...     data = b''
+            ...     for each_item in generator:
+            ...         for line in each_item.splitlines(True):
+            ...             data += line
+            ...             if data.endswith((b'\\r\\r', b'\\n\\n', b'\\r\\n\\r\\n', b'\\n')):
+            ...                 yield from data.splitlines()
+            ...                 data = b''
+            ...     if data:
+            ...         yield from data.splitlines()
+
+            >>> with OllamaContainer(ollama_home=Path.home() / ".ollama") as ollama:
+            ...     if "llama3:latest" not in [e["name"] for e in ollama.list_models()]:
+            ...         print("did not find 'llama3:latest', pulling")
+            ...         ollama.pull_model("llama3:latest")
+            ...     endpoint = ollama.get_endpoint()
+            ...     for chunk in split_by_line(
+            ...         post(url=f"{endpoint}/api/chat", stream=True, json={
+            ...             "model": "llama3:latest",
+            ...             "messages": [{
+            ...                 "role": "user",
+            ...                 "content": "what color is the sky? MAX ONE WORD"
+            ...             }]
+            ...         })
+            ...     ):
+            ...         print(loads(chunk)["message"]["content"], end="")
+            Blue.
     """
 
     OLLAMA_PORT = 11434
@@ -49,12 +89,12 @@ class OllamaContainer(DockerContainer):
     def __init__(
         self,
         image: str = "ollama/ollama:0.1.44",
-        ollama_dir: Optional[Union[str, PathLike]] = None,
+        ollama_home: Optional[Union[str, PathLike]] = None,
         **kwargs,
         #
     ):
         super().__init__(image=image, **kwargs)
-        self.ollama_dir = ollama_dir
+        self.ollama_home = ollama_home
         self.with_exposed_ports(OllamaContainer.OLLAMA_PORT)
         self._check_and_add_gpu_capabilities()
 
@@ -67,8 +107,8 @@ def start(self) -> "OllamaContainer":
         """
         Start the Ollama server
         """
-        if self.ollama_dir:
-            self.with_volume_mapping(self.ollama_dir, "/root/.ollama", "rw")
+        if self.ollama_home:
+            self.with_volume_mapping(self.ollama_home, "/root/.ollama", "rw")
         super().start()
         wait_for_logs(self, "Listening on ", timeout=30)
 
diff --git a/modules/ollama/tests/test_ollama.py b/modules/ollama/tests/test_ollama.py
index 80b22a46..980dac00 100644
--- a/modules/ollama/tests/test_ollama.py
+++ b/modules/ollama/tests/test_ollama.py
@@ -49,12 +49,12 @@ def test_download_model_and_commit_to_image():
 
 
 def test_models_saved_in_folder(tmp_path: Path):
-    with OllamaContainer("ollama/ollama:0.1.26", ollama_dir=tmp_path) as ollama:
+    with OllamaContainer("ollama/ollama:0.1.26", ollama_home=tmp_path) as ollama:
         assert len(ollama.list_models()) == 0
         ollama.pull_model("all-minilm")
         assert len(ollama.list_models()) == 1
         assert "all-minilm" in ollama.list_models()[0].get("name")
-    with OllamaContainer("ollama/ollama:0.1.26", ollama_dir=tmp_path) as ollama:
+    with OllamaContainer("ollama/ollama:0.1.26", ollama_home=tmp_path) as ollama:
         assert len(ollama.list_models()) == 1
         assert "all-minilm" in ollama.list_models()[0].get("name")