From 48009415e080ac27db7f70174e6d0f6e6b281de4 Mon Sep 17 00:00:00 2001 From: BlackyDrum Date: Tue, 18 Jun 2024 19:27:00 +0200 Subject: [PATCH] Add GetOrCreateCollection method --- README.md | 16 +++++ include/ChromaDB/Client/Client.h | 13 ++++ src/ChromaDB/Client/Client.cpp | 15 ++++- tests/test_client.cpp | 109 +++++++++++++++++++++++++++++++ 4 files changed, 152 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 4784766..494ecb8 100644 --- a/README.md +++ b/README.md @@ -296,6 +296,22 @@ int main() - **name**: The name of the collection to retrieve. - **embeddingFunction**: (Optional) A shared pointer to an embedding function for the collection. +### Get or Create a Collection +To retrieve an existing collection or create a new one in ChromaDB, use the `GetOrCreateCollection` method. This method allows you to specify the name of the collection, optional metadata, and an optional embedding function. + +```cpp +#include "ChromaDB/ChromaDB.h" + +int main() +{ + chromadb::Collection collection = client.GetOrCreateCollection("test_collection"); +} +``` +**Parameters** +- **name**: The name of the collection to retrieve or create. +- **metadata**: (Optional) A map of metadata key-value pairs to associate with the collection if it is created. +- **embeddingFunction**: (Optional) A shared pointer to an embedding function for the collection. + ### Get all Collections To retrieve all existing collections in ChromaDB, use the `GetCollections` method. This method allows you to specify an optional embedding function that applies to all collections. 
diff --git a/include/ChromaDB/Client/Client.h b/include/ChromaDB/Client/Client.h index 8a6247b..6f86b5c 100644 --- a/include/ChromaDB/Client/Client.h +++ b/include/ChromaDB/Client/Client.h @@ -103,6 +103,19 @@ namespace chromadb { * @throw ChromaException if something goes wrong */ Collection GetCollection(const std::string& name, std::shared_ptr embeddingFunction = nullptr); + + /* + * @brief Get the collection with the given name, or create it if it cannot be retrieved + * + * @param name The name of the collection + * @param metadata The metadata of the collection (optional); only forwarded to CreateCollection, i.e. used when the collection is created + * @param embeddingFunction The embedding function of the collection (optional) + * + * @return Collection The retrieved or newly created collection + * + * @throw ChromaException if retrieval fails and the subsequent creation fails as well + */ + Collection GetOrCreateCollection(const std::string& name, const std::unordered_map& metadata = {}, std::shared_ptr embeddingFunction = nullptr); /* * @brief Get all collections diff --git a/src/ChromaDB/Client/Client.cpp b/src/ChromaDB/Client/Client.cpp index af26a19..e2562eb 100644 --- a/src/ChromaDB/Client/Client.cpp +++ b/src/ChromaDB/Client/Client.cpp @@ -131,6 +131,18 @@ namespace chromadb { } } + Collection Client::GetOrCreateCollection(const std::string& name, const std::unordered_map& metadata, std::shared_ptr embeddingFunction) + { + try + { + return this->GetCollection(name, embeddingFunction); + } + catch (const ChromaException&) + { /* Retrieval failed: fall back to creating the collection; errors from CreateCollection propagate to the caller */ + return this->CreateCollection(name, metadata, embeddingFunction); + } + } + std::vector Client::GetCollections(std::shared_ptr embeddingFunction) { try @@ -523,7 +535,8 @@ namespace chromadb { return { validatedIds, finalEmbeddings, metadata, documents }; } - void Client::handleChromaApiException(const ChromaException& e) { + void Client::handleChromaApiException(const ChromaException& e) + { const auto* connectException = dynamic_cast(&e); if (connectException) throw ChromaConnectionException(connectException->what()); diff --git a/tests/test_client.cpp b/tests/test_client.cpp index 520bfcc..d97236a 100644 --- 
a/tests/test_client.cpp +++ b/tests/test_client.cpp @@ -231,6 +231,115 @@ TEST_F(ClientTest, GetCollectionThrowsExceptionIfCollectionDoesNotExist) EXPECT_THROW(client->GetCollection("test_collection2"), ChromaValueException); } +TEST_F(ClientTest, CanGetOrCreateCollectionWithoutMetadata) +{ /* Calls GetOrCreateCollection with only a name, then re-reads the same name via GetCollection; both are expected to report empty metadata, a non-empty id and no embedding function. */ + Collection collection = client->GetOrCreateCollection("test_collection"); + EXPECT_EQ(collection.GetName(), "test_collection"); + EXPECT_EQ(collection.GetMetadata().size(), 0); + EXPECT_EQ(collection.GetId().empty(), false); + EXPECT_EQ(collection.GetEmbeddingFunction(), nullptr); + + Collection collection2 = client->GetCollection("test_collection"); + EXPECT_EQ(collection2.GetName(), "test_collection"); + EXPECT_EQ(collection2.GetMetadata().size(), 0); + EXPECT_EQ(collection2.GetId().empty(), false); + EXPECT_EQ(collection2.GetEmbeddingFunction(), nullptr); +} + +TEST_F(ClientTest, CanGetOrCreateCollectionWithMetadata) +{ /* Passes two metadata entries; they are expected on the returned collection and on a later GetCollection of the same name. */ + std::unordered_map metadata = { {"key1", "value1"}, {"key2", "value2"} }; + Collection collection = client->GetOrCreateCollection("test_collection", metadata); + EXPECT_EQ(collection.GetName(), "test_collection"); + EXPECT_EQ(collection.GetMetadata().size(), 2); + EXPECT_EQ(collection.GetMetadata().at("key1"), "value1"); + EXPECT_EQ(collection.GetMetadata().at("key2"), "value2"); + EXPECT_EQ(collection.GetId().empty(), false); + EXPECT_EQ(collection.GetEmbeddingFunction(), nullptr); + + Collection collection2 = client->GetCollection("test_collection"); + EXPECT_EQ(collection2.GetName(), "test_collection"); + EXPECT_EQ(collection2.GetMetadata().size(), 2); + EXPECT_EQ(collection2.GetMetadata().at("key1"), "value1"); + EXPECT_EQ(collection2.GetMetadata().at("key2"), "value2"); + EXPECT_EQ(collection2.GetId().empty(), false); + EXPECT_EQ(collection2.GetEmbeddingFunction(), nullptr); +} + +TEST_F(ClientTest, CanGetOrCreateCollectionWithEmbeddingFunction) +{ /* The supplied embedding function is expected on the returned collection; a GetCollection call made without one is expected to yield nullptr until SetEmbeddingFunction is called. */ + std::shared_ptr embeddingFunction = std::make_shared("jina-api-key"); + Collection collection = 
client->GetOrCreateCollection("test_collection", {}, embeddingFunction); + EXPECT_EQ(collection.GetName(), "test_collection"); + EXPECT_EQ(collection.GetMetadata().size(), 0); + EXPECT_EQ(collection.GetId().empty(), false); + EXPECT_EQ(collection.GetEmbeddingFunction(), embeddingFunction); + + Collection collection2 = client->GetCollection("test_collection"); + EXPECT_EQ(collection2.GetName(), "test_collection"); + EXPECT_EQ(collection2.GetMetadata().size(), 0); + EXPECT_EQ(collection2.GetId().empty(), false); + EXPECT_EQ(collection2.GetEmbeddingFunction(), nullptr); + collection2.SetEmbeddingFunction(embeddingFunction); + EXPECT_EQ(collection2.GetEmbeddingFunction(), embeddingFunction); +} + +TEST_F(ClientTest, CanGetOrCreateCollectionWithMetadataAndEmbeddingFunction) +{ /* Supplies both metadata and an embedding function, and passes the same embedding function to the follow-up GetCollection. */ + std::unordered_map metadata = { {"key1", "value1"}, {"key2", "value2"} }; + std::shared_ptr embeddingFunction = std::make_shared("jina-api-key"); + Collection collection = client->GetOrCreateCollection("test_collection", metadata, embeddingFunction); + EXPECT_EQ(collection.GetName(), "test_collection"); + EXPECT_EQ(collection.GetMetadata().size(), 2); + EXPECT_EQ(collection.GetMetadata().at("key1"), "value1"); + EXPECT_EQ(collection.GetMetadata().at("key2"), "value2"); + EXPECT_EQ(collection.GetId().empty(), false); + EXPECT_EQ(collection.GetEmbeddingFunction(), embeddingFunction); + + Collection collection2 = client->GetCollection("test_collection", embeddingFunction); + EXPECT_EQ(collection2.GetName(), "test_collection"); + EXPECT_EQ(collection2.GetMetadata().size(), 2); + EXPECT_EQ(collection2.GetMetadata().at("key1"), "value1"); + EXPECT_EQ(collection2.GetMetadata().at("key2"), "value2"); + EXPECT_EQ(collection2.GetId().empty(), false); + EXPECT_EQ(collection2.GetEmbeddingFunction(), embeddingFunction); +} + +TEST_F(ClientTest, GetOrCreateCollectionDoesNotThrowExceptionIfCollectionAlreadyExists) +{ /* The collection is created first, so GetOrCreateCollection is called on a name that already exists. */ + Collection collection = client->CreateCollection("test_collection"); + + 
EXPECT_NO_THROW(client->GetOrCreateCollection("test_collection")); +} + +TEST_F(ClientTest, GetOrCreateCollectionThrowsExceptionIfInvalidNameProvided) +{ /* NOTE(review): "te" is presumably rejected by name validation defined elsewhere (too short?) - confirm. */ + EXPECT_THROW(client->GetOrCreateCollection("te"), ChromaValueException); +} + +TEST_F(ClientTest, CanGetOrCreateCollectionWithMetadataAndEmbeddingFunctionIfCollectionAlreadyExists) +{ /* Pre-creates the collection with metadata but without an embedding function, then expects GetOrCreateCollection to return it with the supplied embedding function attached. */ + std::unordered_map metadata = { {"key1", "value1"}, {"key2", "value2"} }; + std::shared_ptr embeddingFunction = std::make_shared("jina-api-key"); + client->CreateCollection("test_collection", metadata); + + Collection collection = client->GetOrCreateCollection("test_collection", metadata, embeddingFunction); + EXPECT_EQ(collection.GetName(), "test_collection"); + EXPECT_EQ(collection.GetMetadata().size(), 2); + EXPECT_EQ(collection.GetMetadata().at("key1"), "value1"); + EXPECT_EQ(collection.GetMetadata().at("key2"), "value2"); + EXPECT_EQ(collection.GetId().empty(), false); + EXPECT_EQ(collection.GetEmbeddingFunction(), embeddingFunction); + + Collection collection2 = client->GetCollection("test_collection", embeddingFunction); + EXPECT_EQ(collection2.GetName(), "test_collection"); + EXPECT_EQ(collection2.GetMetadata().size(), 2); + EXPECT_EQ(collection2.GetMetadata().at("key1"), "value1"); + EXPECT_EQ(collection2.GetMetadata().at("key2"), "value2"); + EXPECT_EQ(collection2.GetId().empty(), false); + EXPECT_EQ(collection2.GetEmbeddingFunction(), embeddingFunction); +} + TEST_F(ClientTest, CanGetCollections) { std::vector collections = client->GetCollections();