Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(tests): use testcontainers-go's Ollama module for running tests with Ollama #1079

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
Draft
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ require (
github.com/testcontainers/testcontainers-go/modules/milvus v0.31.0
github.com/testcontainers/testcontainers-go/modules/mongodb v0.31.0
github.com/testcontainers/testcontainers-go/modules/mysql v0.31.0
github.com/testcontainers/testcontainers-go/modules/ollama v0.31.0
github.com/testcontainers/testcontainers-go/modules/opensearch v0.31.0
github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0
github.com/testcontainers/testcontainers-go/modules/qdrant v0.31.0
Expand Down Expand Up @@ -158,7 +159,6 @@ require (
gitlab.com/golang-commonmark/linkify v0.0.0-20191026162114-a0c2df6c8f82 // indirect
gitlab.com/golang-commonmark/mdurl v0.0.0-20191124015652-932350d1cb84 // indirect
gitlab.com/golang-commonmark/puny v0.0.0-20191124015043-9f83538fa04f // indirect
go.mongodb.org/mongo-driver/v2 v2.0.0-beta1 // indirect
go.opencensus.io v0.24.0 // indirect
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.51.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
Expand Down Expand Up @@ -220,6 +220,7 @@ require (
github.com/weaviate/weaviate-go-client/v4 v4.13.1
gitlab.com/golang-commonmark/markdown v0.0.0-20211110145824-bf3e522c626a
go.mongodb.org/mongo-driver v1.14.0
go.mongodb.org/mongo-driver/v2 v2.0.0-beta1
go.starlark.net v0.0.0-20230302034142-4b1e35fe2254
golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1
golang.org/x/tools v0.14.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,8 @@ github.com/testcontainers/testcontainers-go/modules/mongodb v0.31.0 h1:0ZAEX50NN
github.com/testcontainers/testcontainers-go/modules/mongodb v0.31.0/go.mod h1:n5KbYAdzD8xJrNVGdPvSacJtwZ4D0Q/byTMI5vR/dk8=
github.com/testcontainers/testcontainers-go/modules/mysql v0.31.0 h1:790+S8ewZYCbG+o8IiFlZ8ZZ33XbNO6zV9qhU6xhlRk=
github.com/testcontainers/testcontainers-go/modules/mysql v0.31.0/go.mod h1:REFmO+lSG9S6uSBEwIMZCxeI36uhScjTwChYADeO3JA=
github.com/testcontainers/testcontainers-go/modules/ollama v0.31.0 h1:oI6DUYdoS/HtoSWxRAYFmLLQi5LbfF9AYVWOF4fHhOQ=
github.com/testcontainers/testcontainers-go/modules/ollama v0.31.0/go.mod h1:4nkgxv1tsEr624OpwgnRshIDCbbtqOvfou9VkztZuLo=
github.com/testcontainers/testcontainers-go/modules/opensearch v0.31.0 h1:sgo2PJb8oCK7ogJjRxAkidXmt+gPzwtyhZpaxSI5wDo=
github.com/testcontainers/testcontainers-go/modules/opensearch v0.31.0/go.mod h1:l4Z7QqGpdk4wTTQk8J8CZ75pfqAz1dizm+LECOLuNVw=
github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0 h1:isAwFS3KNKRbJMbWv+wolWqOFUECmjYZ+sIRZCIBc/E=
Expand Down
82 changes: 68 additions & 14 deletions llms/ollama/ollama_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,61 @@ import (
"strings"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/testcontainers/testcontainers-go"
tcollama "github.com/testcontainers/testcontainers-go/modules/ollama"
"github.com/tmc/langchaingo/llms"
)

const (
ollamaVersion string = "0.3.13"
llamaModel string = "llama3.2"
llamaModelTag string = "1b" // the 1b model is the smallest model, that fits in CPUs instead of GPUs.
llamaModelName string = llamaModel + ":" + llamaModelTag

// ollamaImage is the Docker image to use for the test container.
// See https://hub.docker.com/r/mdelapenya/llama3.2/tags
ollamaImage string = "mdelapenya/" + llamaModel + ":" + ollamaVersion + "-" + llamaModelTag
)

func runOllama(t *testing.T) (string, error) {
t.Helper()

ctx := context.Background()

// the Ollama container is reused across tests, that's why it defines a fixed container name and reuses it.
ollamaContainer, err := tcollama.RunContainer(
ctx,
testcontainers.WithImage(ollamaImage),
testcontainers.CustomizeRequest(testcontainers.GenericContainerRequest{
ContainerRequest: testcontainers.ContainerRequest{
Name: "ollama-model",
},
Reuse: true,
},
))
if err != nil {
return "", err
}

url, err := ollamaContainer.ConnectionString(ctx)
if err != nil {
return "", err
}
return url, nil
}

func newTestClient(t *testing.T, opts ...Option) *LLM {
t.Helper()
var ollamaModel string
if ollamaModel = os.Getenv("OLLAMA_TEST_MODEL"); ollamaModel == "" {
t.Skip("OLLAMA_TEST_MODEL not set")
return nil
address, err := runOllama(t)
if err != nil {
t.Skip("OLLAMA_TEST_MODEL not set")
return nil
}
ollamaModel = llamaModelName
opts = append(opts, WithServerURL(address))
}

opts = append([]Option{WithModel(ollamaModel)}, opts...)
Expand All @@ -44,17 +88,23 @@ func TestGenerateContent(t *testing.T) {
rsp, err := llm.GenerateContent(context.Background(), content)
require.NoError(t, err)

assert.NotEmpty(t, rsp.Choices)
require.NotEmpty(t, rsp.Choices)
c1 := rsp.Choices[0]
assert.Regexp(t, "feet", strings.ToLower(c1.Content))
require.Regexp(t, "feet", strings.ToLower(c1.Content))
}

func TestWithFormat(t *testing.T) {
t.Parallel()
llm := newTestClient(t, WithFormat("json"))

parts := []llms.ContentPart{
llms.TextContent{Text: "How many feet are in a nautical mile?"},
llms.TextContent{Text: `Please respond with a JSON object in this format:
{
"feet": <number>,
"explanation": "<string explaining the conversion>"
}

How many feet are in a nautical mile?`},
}
content := []llms.MessageContent{
{
Expand All @@ -66,14 +116,18 @@ func TestWithFormat(t *testing.T) {
rsp, err := llm.GenerateContent(context.Background(), content)
require.NoError(t, err)

assert.NotEmpty(t, rsp.Choices)
require.NotEmpty(t, rsp.Choices)
c1 := rsp.Choices[0]
assert.Regexp(t, "feet", strings.ToLower(c1.Content))
require.Regexp(t, "feet", strings.ToLower(c1.Content))

// check whether we got *any* kind of JSON object.
var result map[string]any
err = json.Unmarshal([]byte(c1.Content), &result)
require.NoError(t, err)

// Verify the response contains the expected fields
require.Contains(t, result, "feet", "Response should contain 'feet' field")
require.Contains(t, result, "explanation", "Response should contain 'explanation' field")
}

func TestWithStreaming(t *testing.T) {
Expand All @@ -98,10 +152,10 @@ func TestWithStreaming(t *testing.T) {
}))
require.NoError(t, err)

assert.NotEmpty(t, rsp.Choices)
require.NotEmpty(t, rsp.Choices)
c1 := rsp.Choices[0]
assert.Regexp(t, "feet", strings.ToLower(c1.Content))
assert.Regexp(t, "feet", strings.ToLower(sb.String()))
require.Regexp(t, "feet", strings.ToLower(c1.Content))
require.Regexp(t, "feet", strings.ToLower(sb.String()))
}

func TestWithKeepAlive(t *testing.T) {
Expand All @@ -121,11 +175,11 @@ func TestWithKeepAlive(t *testing.T) {
resp, err := llm.GenerateContent(context.Background(), content)
require.NoError(t, err)

assert.NotEmpty(t, resp.Choices)
require.NotEmpty(t, resp.Choices)
c1 := resp.Choices[0]
assert.Regexp(t, "feet", strings.ToLower(c1.Content))
require.Regexp(t, "feet", strings.ToLower(c1.Content))

vector, err := llm.CreateEmbedding(context.Background(), []string{"test embedding with keep_alive"})
require.NoError(t, err)
assert.NotEmpty(t, vector)
require.NotEmpty(t, vector)
}
49 changes: 47 additions & 2 deletions vectorstores/milvus/milvus_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,52 @@ import (
"github.com/stretchr/testify/require"
"github.com/testcontainers/testcontainers-go"
tcmilvus "github.com/testcontainers/testcontainers-go/modules/milvus"
tcollama "github.com/testcontainers/testcontainers-go/modules/ollama"
"github.com/tmc/langchaingo/embeddings"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/llms/ollama"
"github.com/tmc/langchaingo/schema"
"github.com/tmc/langchaingo/vectorstores"
)

const (
ollamaVersion string = "0.3.13"
llamaModel string = "llama3.2"
llamaModelTag string = "1b" // the 1b model is the smallest model, that fits in CPUs instead of GPUs.
llamaModelName string = llamaModel + ":" + llamaModelTag

// ollamaImage is the Docker image to use for the test container.
// See https://hub.docker.com/r/mdelapenya/llama3.2/tags
ollamaImage string = "mdelapenya/" + llamaModel + ":" + ollamaVersion + "-" + llamaModelTag
)

func runOllama(t *testing.T) (string, error) {
t.Helper()

ctx := context.Background()

// the Ollama container is reused across tests, that's why it defines a fixed container name and reuses it.
ollamaContainer, err := tcollama.RunContainer(
ctx,
testcontainers.WithImage(ollamaImage),
testcontainers.CustomizeRequest(testcontainers.GenericContainerRequest{
ContainerRequest: testcontainers.ContainerRequest{
Name: "ollama-model-milvus",
},
Reuse: true,
},
))
if err != nil {
return "", err
}

url, err := ollamaContainer.ConnectionString(ctx)
if err != nil {
return "", err
}
return url, nil
}

func getEmbedding(model string, connectionStr ...string) (llms.Model, *embeddings.EmbedderImpl) {
opts := []ollama.Option{ollama.WithModel(model)}
if len(connectionStr) > 0 {
Expand All @@ -40,9 +79,15 @@ func getNewStore(t *testing.T, opts ...Option) (Store, error) {
t.Helper()
ollamaURL := os.Getenv("OLLAMA_HOST")
if ollamaURL == "" {
t.Skip("OLLAMA_HOST not set")
address, err := runOllama(t)
if err != nil {
t.Skip("OLLAMA_HOST not set")
return Store{}, err
}

ollamaURL = address
}
_, e := getEmbedding("gemma:2b")
_, e := getEmbedding(llamaModelName, ollamaURL)

url := os.Getenv("MILVUS_URL")
if url == "" {
Expand Down
Loading
Loading