Skip to content

Commit

Permalink
Merge pull request #10 from BlackyDrum/test-embedding-functions
Browse files Browse the repository at this point in the history
Add tests for embedding function
  • Loading branch information
BlackyDrum authored Nov 5, 2024
2 parents d8fd47c + 477106c commit 8949e1c
Show file tree
Hide file tree
Showing 2 changed files with 223 additions and 1 deletion.
2 changes: 2 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ on:
jobs:
build:
runs-on: ubuntu-latest
env:
JINA_API_KEY: ${{ secrets.JINA_API_KEY }}

services:
chroma-no-auth:
Expand Down
222 changes: 221 additions & 1 deletion tests/test_client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,23 @@

#include "ChromaDB/ChromaDB.h"

#include <cstdlib>

#pragma warning(disable: 4996)

using namespace chromadb;

class ClientTest : public ::testing::Test {
protected:
Client* client;

std::string jinaApiKey;
protected:
void SetUp() override
{
client = new Client("http", "localhost", "8080", "test_database", "test_tenant", "");

jinaApiKey = GetEnvVar("JINA_API_KEY");
}

void TearDown() override
Expand All @@ -22,6 +30,12 @@ class ClientTest : public ::testing::Test {

delete client;
}
private:
std::string GetEnvVar(const std::string& name)
{
const char* value = std::getenv(name.c_str());
return value == nullptr ? "" : std::string(value);
}
};

TEST_F(ClientTest, ConstructorInitializesCorrectly)
Expand Down Expand Up @@ -1508,4 +1522,210 @@ TEST_F(ClientTest, ThrowsIfEmptyIdForAdd)
std::vector<std::unordered_map<std::string, std::string>> metadatas = { { {"key1", "value1"} }, { {"key1", "value2"} }, { {"key1", "value3"} } };

EXPECT_THROW(client->AddEmbeddings(collection, ids, embeddings, metadatas, documents), ChromaInvalidArgumentException);
}
}

TEST_F(ClientTest, CanAddEmbeddingsUsingEmbeddingFunctionWithManualGeneration)
{
std::shared_ptr<EmbeddingFunction> embeddingFunction = std::make_shared<JinaEmbeddingFunction>(jinaApiKey, "jina-embeddings-v3");

chromadb::Collection collection = client->CreateCollection("test_collection", {}, embeddingFunction);

std::vector<std::string> ids = { "ID1", "ID2", "ID3" };
std::vector<std::string> documents = { "document1", "document2", "document3" };
auto embeddings = embeddingFunction->Generate(documents);

EXPECT_EQ(embeddings.size(), 3);
EXPECT_EQ(embeddings[0].size(), 1024);
EXPECT_EQ(embeddings[1].size(), 1024);
EXPECT_EQ(embeddings[2].size(), 1024);

client->AddEmbeddings(collection, ids, embeddings, {}, documents);

auto queryResponse = client->GetEmbeddings(collection, ids, { "embeddings", "documents" }, {}, {});
EXPECT_EQ(queryResponse.size(), 3);

EXPECT_EQ(queryResponse[0].id, "ID1");
EXPECT_EQ(queryResponse[0].embeddings->size(), 1024);
EXPECT_EQ(*queryResponse[0].document, "document1");

EXPECT_EQ(queryResponse[1].id, "ID2");
EXPECT_EQ(queryResponse[1].embeddings->size(), 1024);
EXPECT_EQ(*queryResponse[1].document, "document2");

EXPECT_EQ(queryResponse[2].id, "ID3");
EXPECT_EQ(queryResponse[2].embeddings->size(), 1024);
EXPECT_EQ(*queryResponse[2].document, "document3");
}

TEST_F(ClientTest, CanAddEmbeddingsUsingEmbeddingFunctionWithAutoGeneration)
{
std::shared_ptr<EmbeddingFunction> embeddingFunction = std::make_shared<JinaEmbeddingFunction>(jinaApiKey, "jina-embeddings-v3");

chromadb::Collection collection = client->CreateCollection("test_collection", {}, embeddingFunction);

std::vector<std::string> ids = { "ID1", "ID2", "ID3" };
std::vector<std::string> documents = { "document1", "document2", "document3" };

client->AddEmbeddings(collection, ids, {}, {}, documents);

auto queryResponse = client->GetEmbeddings(collection, ids, { "embeddings", "documents" }, {}, {});
EXPECT_EQ(queryResponse.size(), 3);

EXPECT_EQ(queryResponse[0].id, "ID1");
EXPECT_EQ(queryResponse[0].embeddings->size(), 1024);
EXPECT_EQ(*queryResponse[0].document, "document1");

EXPECT_EQ(queryResponse[1].id, "ID2");
EXPECT_EQ(queryResponse[1].embeddings->size(), 1024);
EXPECT_EQ(*queryResponse[1].document, "document2");

EXPECT_EQ(queryResponse[2].id, "ID3");
EXPECT_EQ(queryResponse[2].embeddings->size(), 1024);
EXPECT_EQ(*queryResponse[2].document, "document3");
}

TEST_F(ClientTest, CanAddEmbeddingsUsingEmbeddingFunctionWithManuallySettingEmbeddingFunction)
{
std::shared_ptr<EmbeddingFunction> embeddingFunction = std::make_shared<JinaEmbeddingFunction>(jinaApiKey, "jina-embeddings-v3");

chromadb::Collection collection = client->CreateCollection("test_collection", {});
collection.SetEmbeddingFunction(embeddingFunction);

std::vector<std::string> ids = { "ID1", "ID2", "ID3" };
std::vector<std::string> documents = { "document1", "document2", "document3" };

client->AddEmbeddings(collection, ids, {}, {}, documents);

auto queryResponse = client->GetEmbeddings(collection, ids, { "embeddings", "documents" }, {}, {});
EXPECT_EQ(queryResponse.size(), 3);

EXPECT_EQ(queryResponse[0].id, "ID1");
EXPECT_EQ(queryResponse[0].embeddings->size(), 1024);
EXPECT_EQ(*queryResponse[0].document, "document1");

EXPECT_EQ(queryResponse[1].id, "ID2");
EXPECT_EQ(queryResponse[1].embeddings->size(), 1024);
EXPECT_EQ(*queryResponse[1].document, "document2");

EXPECT_EQ(queryResponse[2].id, "ID3");
EXPECT_EQ(queryResponse[2].embeddings->size(), 1024);
EXPECT_EQ(*queryResponse[2].document, "document3");
}

TEST_F(ClientTest, CanAddEmbeddingsUsingEmbeddingFunctionWithAutoGenerationAndMetadata)
{
std::shared_ptr<EmbeddingFunction> embeddingFunction = std::make_shared<JinaEmbeddingFunction>(jinaApiKey, "jina-embeddings-v3");

chromadb::Collection collection = client->CreateCollection("test_collection", {}, embeddingFunction);

std::vector<std::string> ids = { "ID1", "ID2", "ID3" };
std::vector<std::string> documents = { "document1", "document2", "document3" };
std::vector<std::unordered_map<std::string, std::string>> metadatas = { { {"key1", "value1"} }, { {"key1", "value2"} }, { {"key1", "value3"} } };

client->AddEmbeddings(collection, ids, {}, metadatas, documents);

auto queryResponse = client->GetEmbeddings(collection, ids, { "embeddings", "documents", "metadatas" }, {}, {});
EXPECT_EQ(queryResponse.size(), 3);

EXPECT_EQ(queryResponse[0].id, "ID1");
EXPECT_EQ(queryResponse[0].embeddings->size(), 1024);
EXPECT_EQ(*queryResponse[0].document, "document1");
EXPECT_EQ(queryResponse[0].metadata->size(), 1);
EXPECT_EQ(queryResponse[0].metadata->at("key1"), "value1");

EXPECT_EQ(queryResponse[1].id, "ID2");
EXPECT_EQ(queryResponse[1].embeddings->size(), 1024);
EXPECT_EQ(*queryResponse[1].document, "document2");
EXPECT_EQ(queryResponse[1].metadata->size(), 1);
EXPECT_EQ(queryResponse[1].metadata->at("key1"), "value2");

EXPECT_EQ(queryResponse[2].id, "ID3");
EXPECT_EQ(queryResponse[2].embeddings->size(), 1024);
EXPECT_EQ(*queryResponse[2].document, "document3");
EXPECT_EQ(queryResponse[2].metadata->size(), 1);
EXPECT_EQ(queryResponse[2].metadata->at("key1"), "value3");
}

TEST_F(ClientTest, CanAddEmbeddingsUsingEmbeddingFunctionWithManualGenerationAndMetadata)
{
std::shared_ptr<EmbeddingFunction> embeddingFunction = std::make_shared<JinaEmbeddingFunction>(jinaApiKey, "jina-embeddings-v3");

chromadb::Collection collection = client->CreateCollection("test_collection", {}, embeddingFunction);

std::vector<std::string> ids = { "ID1", "ID2", "ID3" };
std::vector<std::string> documents = { "document1", "document2", "document3" };
auto embeddings = embeddingFunction->Generate(documents);
std::vector<std::unordered_map<std::string, std::string>> metadatas = { { {"key1", "value1"} }, { {"key1", "value2"} }, { {"key1", "value3"} } };

client->AddEmbeddings(collection, ids, embeddings, metadatas, documents);

auto queryResponse = client->GetEmbeddings(collection, ids, { "embeddings", "documents", "metadatas" }, {}, {});
EXPECT_EQ(queryResponse.size(), 3);

EXPECT_EQ(queryResponse[0].id, "ID1");
EXPECT_EQ(queryResponse[0].embeddings->size(), 1024);
EXPECT_EQ(*queryResponse[0].document, "document1");
EXPECT_EQ(queryResponse[0].metadata->size(), 1);
EXPECT_EQ(queryResponse[0].metadata->at("key1"), "value1");

EXPECT_EQ(queryResponse[1].id, "ID2");
EXPECT_EQ(queryResponse[1].embeddings->size(), 1024);
EXPECT_EQ(*queryResponse[1].document, "document2");
EXPECT_EQ(queryResponse[1].metadata->size(), 1);
EXPECT_EQ(queryResponse[1].metadata->at("key1"), "value2");

EXPECT_EQ(queryResponse[2].id, "ID3");
EXPECT_EQ(queryResponse[2].embeddings->size(), 1024);
EXPECT_EQ(*queryResponse[2].document, "document3");
EXPECT_EQ(queryResponse[2].metadata->size(), 1);
EXPECT_EQ(queryResponse[2].metadata->at("key1"), "value3");
}

TEST_F(ClientTest, ThrowsIfWrongApiKeyForEmbeddingFunction)
{
std::shared_ptr<EmbeddingFunction> embeddingFunction = std::make_shared<JinaEmbeddingFunction>("wrong-api-key", "jina-embeddings-v3");

chromadb::Collection collection = client->CreateCollection("test_collection", {}, embeddingFunction);

std::vector<std::string> ids = { "ID1", "ID2", "ID3" };
std::vector<std::string> documents = { "document1", "document2", "document3" };

EXPECT_THROW(client->AddEmbeddings(collection, ids, {}, {}, documents), EmbeddingProviderRequestException);
}

TEST_F(ClientTest, ThrowsIfWrongEmbeddingModelNameForEmbeddingFunction)
{
std::shared_ptr<EmbeddingFunction> embeddingFunction = std::make_shared<JinaEmbeddingFunction>(jinaApiKey, "wrong-embedding-model");

chromadb::Collection collection = client->CreateCollection("test_collection", {}, embeddingFunction);

std::vector<std::string> ids = { "ID1", "ID2", "ID3" };
std::vector<std::string> documents = { "document1", "document2", "document3" };

EXPECT_THROW(client->AddEmbeddings(collection, ids, {}, {}, documents), EmbeddingProviderRequestException);
}

TEST_F(ClientTest, ThrowsIfWrongHostForEmbeddingFunction)
{
std::shared_ptr<EmbeddingFunction> embeddingFunction = std::make_shared<JinaEmbeddingFunction>(jinaApiKey, "jina-embeddings-v3", "wrong-host");

chromadb::Collection collection = client->CreateCollection("test_collection", {}, embeddingFunction);

std::vector<std::string> ids = { "ID1", "ID2", "ID3" };
std::vector<std::string> documents = { "document1", "document2", "document3" };

EXPECT_THROW(client->AddEmbeddings(collection, ids, {}, {}, documents), EmbeddingProviderConnectionException);
}

TEST_F(ClientTest, CanGetRequestMetadataFromEmbeddingFunction)
{
std::shared_ptr<EmbeddingFunction> embeddingFunction = std::make_shared<JinaEmbeddingFunction>(jinaApiKey, "jina-embeddings-v3");
auto embeddings = embeddingFunction->Generate({ "document1", "document2", "document3" });

auto requestMetadata = embeddingFunction->GetRequestMetadata();

EXPECT_EQ(requestMetadata["model"], "jina-embeddings-v3");
EXPECT_EQ(requestMetadata["object"], "list");
EXPECT_GT(requestMetadata["usage"]["prompt_tokens"], 0);
EXPECT_GT(requestMetadata["usage"]["total_tokens"], 0);
}

0 comments on commit 8949e1c

Please sign in to comment.