Skip to content

Commit

Permalink
Add GetOrCreateCollection method
Browse files Browse the repository at this point in the history
  • Loading branch information
BlackyDrum committed Jun 19, 2024
1 parent 91e90b6 commit 4800941
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 1 deletion.
16 changes: 16 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,22 @@ int main()
- **name**: The name of the collection to retrieve.
- **embeddingFunction**: (Optional) A shared pointer to an embedding function for the collection.

### Get or Create a Collection
To retrieve an existing collection or create a new one in ChromaDB, use the `GetOrCreateCollection` method. This method allows you to specify the name of the collection, optional metadata, and an optional embedding function.

```cpp
#include "ChromaDB/ChromaDB.h"

int main()
{
chromadb::Collection collection = client.GetOrCreateCollection("test_collection");
}
```
**Parameters**
- **name**: The name of the collection to retrieve or, if it does not exist yet, to create.
- **metadata**: (Optional) A map of metadata key-value pairs to associate with the collection when it is created.
- **embeddingFunction**: (Optional) A shared pointer to an embedding function for the collection.

### Get all Collections
To retrieve all existing collections in ChromaDB, use the `GetCollections` method. This method allows you to specify an optional embedding function that applies to all collections.

Expand Down
13 changes: 13 additions & 0 deletions include/ChromaDB/Client/Client.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,19 @@ namespace chromadb {
* @throw ChromaException if something goes wrong
*/
Collection GetCollection(const std::string& name, std::shared_ptr<EmbeddingFunction> embeddingFunction = nullptr);

/*
* @brief Get or create a collection
*
* Tries to get the collection with the given name first and only creates
* it when that lookup fails with a ChromaException.
*
* @param name The name of the collection
* @param metadata The metadata of the collection (optional); only passed on
*        when the collection is created, not when an existing one is returned
* @param embeddingFunction The embedding function of the collection (optional);
*        passed to both the lookup and the creation
*
* @return Collection The existing or newly created collection
*
* @throw ChromaException if the collection can neither be retrieved nor created
*/
Collection GetOrCreateCollection(const std::string& name, const std::unordered_map<std::string, std::string>& metadata = {}, std::shared_ptr<EmbeddingFunction> embeddingFunction = nullptr);

/*
* @brief Get all collections
Expand Down
15 changes: 14 additions & 1 deletion src/ChromaDB/Client/Client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,18 @@ namespace chromadb {
}
}

// Returns the collection called `name`, creating it when the lookup fails.
//
// name:              name of the collection to look up or create.
// metadata:          forwarded to CreateCollection only; it is not used when
//                    the lookup succeeds.
// embeddingFunction: forwarded to both GetCollection and CreateCollection.
//
// Throws whatever CreateCollection throws when the fallback creation fails.
Collection Client::GetOrCreateCollection(const std::string& name, const std::unordered_map<std::string, std::string>& metadata, std::shared_ptr<EmbeddingFunction> embeddingFunction)
{
    try
    {
        return this->GetCollection(name, embeddingFunction);
    }
    // Catch by const reference: catching the base class by value would copy
    // and slice the derived exception object.
    catch (const ChromaException&)
    {
        // Any ChromaException from the lookup leads to a creation attempt;
        // errors from the creation itself are not caught and reach the caller.
        return this->CreateCollection(name, metadata, embeddingFunction);
    }
}

std::vector<Collection> Client::GetCollections(std::shared_ptr<EmbeddingFunction> embeddingFunction)
{
try
Expand Down Expand Up @@ -523,7 +535,8 @@ namespace chromadb {
return { validatedIds, finalEmbeddings, metadata, documents };
}

void Client::handleChromaApiException(const ChromaException& e) {
void Client::handleChromaApiException(const ChromaException& e)
{
const auto* connectException = dynamic_cast<const ChromaConnectionException*>(&e);
if (connectException)
throw ChromaConnectionException(connectException->what());
Expand Down
109 changes: 109 additions & 0 deletions tests/test_client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,115 @@ TEST_F(ClientTest, GetCollectionThrowsExceptionIfCollectionDoesNotExist)
EXPECT_THROW(client->GetCollection("test_collection2"), ChromaValueException);
}

// GetOrCreateCollection with only a name: checks the returned collection and
// then the result of a separate GetCollection call for the same name.
TEST_F(ClientTest, CanGetOrCreateCollectionWithoutMetadata)
{
    Collection created = client->GetOrCreateCollection("test_collection");

    EXPECT_EQ(created.GetName(), "test_collection");
    EXPECT_EQ(created.GetMetadata().size(), 0u);
    EXPECT_FALSE(created.GetId().empty());
    EXPECT_EQ(created.GetEmbeddingFunction(), nullptr);

    // Look the collection up again by name.
    Collection fetched = client->GetCollection("test_collection");

    EXPECT_EQ(fetched.GetName(), "test_collection");
    EXPECT_EQ(fetched.GetMetadata().size(), 0u);
    EXPECT_FALSE(fetched.GetId().empty());
    EXPECT_EQ(fetched.GetEmbeddingFunction(), nullptr);
}

// GetOrCreateCollection with a name and two metadata entries: both entries
// must be present on the returned collection and on a later GetCollection.
TEST_F(ClientTest, CanGetOrCreateCollectionWithMetadata)
{
    std::unordered_map<std::string, std::string> expectedMetadata = { {"key1", "value1"}, {"key2", "value2"} };

    Collection created = client->GetOrCreateCollection("test_collection", expectedMetadata);

    EXPECT_EQ(created.GetName(), "test_collection");
    EXPECT_EQ(created.GetMetadata().size(), 2u);
    EXPECT_EQ(created.GetMetadata().at("key1"), "value1");
    EXPECT_EQ(created.GetMetadata().at("key2"), "value2");
    EXPECT_FALSE(created.GetId().empty());
    EXPECT_EQ(created.GetEmbeddingFunction(), nullptr);

    // Look the collection up again by name.
    Collection fetched = client->GetCollection("test_collection");

    EXPECT_EQ(fetched.GetName(), "test_collection");
    EXPECT_EQ(fetched.GetMetadata().size(), 2u);
    EXPECT_EQ(fetched.GetMetadata().at("key1"), "value1");
    EXPECT_EQ(fetched.GetMetadata().at("key2"), "value2");
    EXPECT_FALSE(fetched.GetId().empty());
    EXPECT_EQ(fetched.GetEmbeddingFunction(), nullptr);
}

// GetOrCreateCollection with empty metadata and an embedding function: the
// returned collection carries the function, while a GetCollection call made
// without one reports nullptr until SetEmbeddingFunction is called.
TEST_F(ClientTest, CanGetOrCreateCollectionWithEmbeddingFunction)
{
    std::shared_ptr<EmbeddingFunction> jinaFunction = std::make_shared<JinaEmbeddingFunction>("jina-api-key");

    Collection created = client->GetOrCreateCollection("test_collection", {}, jinaFunction);

    EXPECT_EQ(created.GetName(), "test_collection");
    EXPECT_EQ(created.GetMetadata().size(), 0u);
    EXPECT_FALSE(created.GetId().empty());
    EXPECT_EQ(created.GetEmbeddingFunction(), jinaFunction);

    // Fetch without passing an embedding function.
    Collection fetched = client->GetCollection("test_collection");

    EXPECT_EQ(fetched.GetName(), "test_collection");
    EXPECT_EQ(fetched.GetMetadata().size(), 0u);
    EXPECT_FALSE(fetched.GetId().empty());
    EXPECT_EQ(fetched.GetEmbeddingFunction(), nullptr);

    // Attaching the function afterwards must be reflected by the getter.
    fetched.SetEmbeddingFunction(jinaFunction);
    EXPECT_EQ(fetched.GetEmbeddingFunction(), jinaFunction);
}

// GetOrCreateCollection with both metadata and an embedding function: checks
// the returned collection and a GetCollection call made with the same function.
TEST_F(ClientTest, CanGetOrCreateCollectionWithMetadataAndEmbeddingFunction)
{
    std::unordered_map<std::string, std::string> expectedMetadata = { {"key1", "value1"}, {"key2", "value2"} };
    std::shared_ptr<EmbeddingFunction> jinaFunction = std::make_shared<JinaEmbeddingFunction>("jina-api-key");

    Collection created = client->GetOrCreateCollection("test_collection", expectedMetadata, jinaFunction);

    EXPECT_EQ(created.GetName(), "test_collection");
    EXPECT_EQ(created.GetMetadata().size(), 2u);
    EXPECT_EQ(created.GetMetadata().at("key1"), "value1");
    EXPECT_EQ(created.GetMetadata().at("key2"), "value2");
    EXPECT_FALSE(created.GetId().empty());
    EXPECT_EQ(created.GetEmbeddingFunction(), jinaFunction);

    // Fetch again, this time passing the embedding function to GetCollection.
    Collection fetched = client->GetCollection("test_collection", jinaFunction);

    EXPECT_EQ(fetched.GetName(), "test_collection");
    EXPECT_EQ(fetched.GetMetadata().size(), 2u);
    EXPECT_EQ(fetched.GetMetadata().at("key1"), "value1");
    EXPECT_EQ(fetched.GetMetadata().at("key2"), "value2");
    EXPECT_FALSE(fetched.GetId().empty());
    EXPECT_EQ(fetched.GetEmbeddingFunction(), jinaFunction);
}

// Calling GetOrCreateCollection for a name that was just created through
// CreateCollection must not throw.
TEST_F(ClientTest, GetOrCreateCollectionDoesNotThrowExceptionIfCollectionAlreadyExists)
{
    // Create the collection up front so the name is already taken.
    Collection existing = client->CreateCollection("test_collection");

    EXPECT_NO_THROW(client->GetOrCreateCollection("test_collection"));
}

// GetOrCreateCollection must surface a ChromaValueException for an invalid name.
TEST_F(ClientTest, GetOrCreateCollectionThrowsExceptionIfInvalidNameProvided)
{
    // NOTE(review): presumably rejected for being too short — confirm against
    // the server's collection naming rules.
    const std::string invalidName = "te";

    EXPECT_THROW(client->GetOrCreateCollection(invalidName), ChromaValueException);
}

// The collection is created first via CreateCollection (with metadata, without
// an embedding function); GetOrCreateCollection is then called with the same
// metadata plus an embedding function, and the result is checked.
TEST_F(ClientTest, CanGetOrCreateCollectionWithMetadataAndEmbeddingFunctionIfCollectionAlreadyExists)
{
    std::unordered_map<std::string, std::string> expectedMetadata = { {"key1", "value1"}, {"key2", "value2"} };
    std::shared_ptr<EmbeddingFunction> jinaFunction = std::make_shared<JinaEmbeddingFunction>("jina-api-key");

    // Pre-create the collection so GetOrCreateCollection hits an existing name.
    client->CreateCollection("test_collection", expectedMetadata);

    Collection obtained = client->GetOrCreateCollection("test_collection", expectedMetadata, jinaFunction);

    EXPECT_EQ(obtained.GetName(), "test_collection");
    EXPECT_EQ(obtained.GetMetadata().size(), 2u);
    EXPECT_EQ(obtained.GetMetadata().at("key1"), "value1");
    EXPECT_EQ(obtained.GetMetadata().at("key2"), "value2");
    EXPECT_FALSE(obtained.GetId().empty());
    EXPECT_EQ(obtained.GetEmbeddingFunction(), jinaFunction);

    // Fetch again, passing the embedding function to GetCollection.
    Collection fetched = client->GetCollection("test_collection", jinaFunction);

    EXPECT_EQ(fetched.GetName(), "test_collection");
    EXPECT_EQ(fetched.GetMetadata().size(), 2u);
    EXPECT_EQ(fetched.GetMetadata().at("key1"), "value1");
    EXPECT_EQ(fetched.GetMetadata().at("key2"), "value2");
    EXPECT_FALSE(fetched.GetId().empty());
    EXPECT_EQ(fetched.GetEmbeddingFunction(), jinaFunction);
}

TEST_F(ClientTest, CanGetCollections)
{
std::vector<Collection> collections = client->GetCollections();
Expand Down

0 comments on commit 4800941

Please sign in to comment.