Skip to content

Commit

Permalink
Merge pull request #1 from BlackyDrum/embedding-functions
Browse files Browse the repository at this point in the history
Embedding functions
  • Loading branch information
BlackyDrum authored Jun 16, 2024
2 parents 8ef5668 + 6fb9671 commit be3a6d4
Show file tree
Hide file tree
Showing 15 changed files with 317 additions and 29 deletions.
3 changes: 3 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ set(SOURCES
src/ChromaDB/Embeddings/EmbeddingFunction.cpp
src/ChromaDB/Embeddings/JinaEmbeddingFunction.cpp
src/ChromaDB/Embeddings/OpenAIEmbeddingFunction.cpp
src/ChromaDB/Embeddings/CohereEmbeddingFunction.cpp
src/ChromaDB/Embeddings/VoyageAIEmbeddingFunction.cpp
src/ChromaDB/Embeddings/TogetherAIEmbeddingFunction.cpp
src/ChromaDB/Exceptions/ChromaException.cpp
)

Expand Down
104 changes: 86 additions & 18 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ To retrieve an existing collection in ChromaDB, use the `GetCollection` method.

int main()
{
Collection collection = client.GetCollection("test_collection");
chromadb::Collection collection = client.GetCollection("test_collection");
std::cout << "Collection name: " << collection.GetName() << std::endl;
std::cout << "Collection id: " << collection.GetId() << std::endl;

Expand All @@ -306,9 +306,9 @@ To retrieve all existing collections in ChromaDB, use the `GetCollections` metho

int main()
{
std::vector<Collection> collections = client.GetCollections();
std::vector<chromadb::Collection> collections = client.GetCollections();

for (Collection& collection : collections)
for (chromadb::Collection& collection : collections)
std::cout << "Collection name: " << collection.GetName() << std::endl;
}
```
Expand Down Expand Up @@ -338,7 +338,7 @@ int main()
{
std::string newName = "test_collection_updated";
std::unordered_map<std::string, std::string> newMetadata = { {"key3", "value3"}, {"key4", "value4"} };
Collection updatedCollection = client.UpdateCollection("test_collection", newName, newMetadata);
chromadb::Collection updatedCollection = client.UpdateCollection("test_collection", newName, newMetadata);

std::cout << updatedCollection.GetName() << std::endl; // "test_collection_updated"
}
Expand Down Expand Up @@ -411,7 +411,7 @@ int main()
{
std::shared_ptr<chromadb::OpenAIEmbeddingFunction> embeddingFunction = std::make_shared<chromadb::OpenAIEmbeddingFunction>("openai-api-key");

Collection collection = client.GetCollection("test_collection", embeddingFunction);
chromadb::Collection collection = client.GetCollection("test_collection", embeddingFunction);

std::vector<std::string> ids = { "ID1", "ID2", "ID3" };
std::vector<std::string> documents = { "document1", "document2", "document3" };
Expand All @@ -420,25 +420,93 @@ int main()
}
```

We currently supports `JinaEmbeddingFunction` and `OpenAIEmbeddingFunction` for this purpose.
We currently supports `JinaEmbeddingFunction`, `OpenAIEmbeddingFunction`, `CohereEmbeddingFunction`, `VoyageAIEmbeddingFunction` and `TogetherAIEmbeddingFunction` for this purpose.

**JinaEmbeddingFunction**
```cpp
#include "ChromaDB/ChromaDB.h"

int main()
{
std::shared_ptr<chromadb::OpenAIEmbeddingFunction> openaiEmbeddingFunction = std::make_shared<chromadb::OpenAIEmbeddingFunction>("openai-api-key");
std::shared_ptr<chromadb::JinaEmbeddingFunction> jinaEmbeddingFunction = std::make_shared<chromadb::JinaEmbeddingFunction>("jina-api-key");
}
```

**Parameters**
- **apiKey**: The API key to access the API.
- **model**: (Optional) The model to use for generating embeddings. Defaults to "text-embedding-3-small" or "jina-embeddings-v2-base-en".
- **baseUrl**: (Optional) The base URL of the API server. Defaults to "api.openai.com" or "api.jina.ai".
- **path**: (Optional) The path of the endpoint for generating embeddings. Defaults to "/v1/embeddings".
- **model**: (Optional) The model to use for generating embeddings. Defaults to `jina-embeddings-v2-base-en`.
- **baseUrl**: (Optional) The base URL of the API server. Defaults to `api.jina.ai`.
- **path**: (Optional) The path of the endpoint for generating embeddings. Defaults to `/v1/embeddings`.

> Note: You can get started immediately by obtaining a free Jina API Key [here](https://jina.ai/embeddings/#apiform)
> Note: You can get started immediately by obtaining a free Jina API Key with 1M Tokens [here](https://jina.ai/embeddings/#apiform)
**OpenAIEmbeddingFunction**
```cpp
#include "ChromaDB/ChromaDB.h"

int main()
{
std::shared_ptr<chromadb::OpenAIEmbeddingFunction> openAIEmbeddingFunction = std::make_shared<chromadb::OpenAIEmbeddingFunction>("openai-api-key");
}
```

**Parameters**
- **apiKey**: The API key to access the API.
- **model**: (Optional) The model to use for generating embeddings. Defaults to `text-embedding-3-small`.
- **dimensions**: (Optional) The number of dimensions of the embeddings. Defaults to `1536`.
- **baseUrl**: (Optional) The base URL of the API server. Defaults to `api.openai.com`.
- **path**: (Optional) The path of the endpoint for generating embeddings. Defaults to `/v1/embeddings`.

**CohereEmbeddingFunction**
```cpp
#include "ChromaDB/ChromaDB.h"

int main()
{
std::shared_ptr<chromadb::CohereEmbeddingFunction> cohereEmbeddingFunction = std::make_shared<chromadb::CohereEmbeddingFunction>("cohere-api-key");
}
```

**Parameters**
- **apiKey**: The API key to access the API.
- **model**: (Optional) The model to use for generating embeddings. Defaults to `embed-english-v3.0`.
- **inputType**: (Optional) The input type passed to the model. Defaults to `classification`.
- **baseUrl**: (Optional) The base URL of the API server. Defaults to `api.cohere.com`.
- **path**: (Optional) The path of the endpoint for generating embeddings. Defaults to `/v1/embed`.


**VoyageAIEmbeddingFunction**
```cpp
#include "ChromaDB/ChromaDB.h"

int main()
{
std::shared_ptr<chromadb::VoyageAIEmbeddingFunction> voyageAIEmbeddingFunction = std::make_shared<chromadb::VoyageAIEmbeddingFunction>("voyageai-api-key");
}
```

**Parameters**
- **apiKey**: The API key to access the API.
- **model**: (Optional) The model to use for generating embeddings. Defaults to `voyage-2`.
- **inputType**: (Optional) The input type passed to the model. Defaults to `document`.
- **baseUrl**: (Optional) The base URL of the API server. Defaults to `api.voyageai.com`.
- **path**: (Optional) The path of the endpoint for generating embeddings. Defaults to `/v1/embeddings`.

**TogetherAIEmbeddingFunction**
```cpp
#include "ChromaDB/ChromaDB.h"

int main()
{
std::shared_ptr<chromadb::TogetherAIEmbeddingFunction> togetherAIEmbeddingFunction = std::make_shared<chromadb::TogetherAIEmbeddingFunction>("togetherai-api-key");
}
```

**Parameters**
- **apiKey**: The API key to access the API.
- **model**: (Optional) The model to use for generating embeddings. Defaults to `togethercomputer/m2-bert-80M-8k-retrieval`.
- **baseUrl**: (Optional) The base URL of the API server. Defaults to `api.together.xyz`.
- **path**: (Optional) The path of the endpoint for generating embeddings. Defaults to `/v1/embeddings`.

### Get Embeddings from a Collection
To retrieve embeddings from an existing collection in ChromaDB, use the `GetEmbeddings` method. This method allows you to specify the collection, optional IDs of the embeddings, and optional filters and fields to include in the result.
Expand All @@ -448,7 +516,7 @@ To retrieve embeddings from an existing collection in ChromaDB, use the `GetEmbe

int main()
{
Collection collection = client.GetCollection("test_collection");
chromadb::Collection collection = client.GetCollection("test_collection");

std::vector<std::string> ids = { "ID1", "ID2", "ID3" };
std::vector<std::vector<double>> embeddings = { { 1.0, 2.0, 3.0 }, { 4.0, 5.0, 6.0 }, { 7.0, 8.0, 9.0 } };
Expand Down Expand Up @@ -543,7 +611,7 @@ The `where_document` filter works similarly to the `where` filter but for filter

int main()
{
Collection collection = client.CreateCollection("test_collection");
chromadb::Collection collection = client.CreateCollection("test_collection");

std::vector<std::string> ids = { "ID1", "ID2", "ID3", "ID4" };
std::vector<std::vector<double>> embeddings = { { 1.0, 2.0, 3.0 }, { 4.0, 5.0, 6.0 }, { 7.0, 8.0, 9.0 }, { 10.0, 11.0, 12.0 } };
Expand Down Expand Up @@ -575,7 +643,7 @@ To retrieve the count of embeddings from an existing collection in ChromaDB, use

int main()
{
Collection collection = client.CreateCollection("test_collection");
chromadb::Collection collection = client.CreateCollection("test_collection");

std::cout << client.GetEmbeddingCount(collection) << std::endl;
}
Expand All @@ -592,7 +660,7 @@ To update embeddings in an existing collection in ChromaDB, use the `UpdateEmbed

int main()
{
Collection collection = client.GetCollection("test_collection");
chromadb::Collection collection = client.GetCollection("test_collection");

std::vector<std::string> ids = { "ID1", "ID2" };
std::vector<std::string> new_documents = { "NewDocument1", "NewDocument2" };
Expand All @@ -617,7 +685,7 @@ To delete embeddings from an existing collection in ChromaDB, use the `DeleteEmb

int main()
{
Collection collection = client.GetCollection("test_collection");
chromadb::Collection collection = client.GetCollection("test_collection");

client.DeleteEmbeddings(collection, { "ID1", "ID3" });
}
Expand All @@ -639,9 +707,9 @@ To query an existing collection in ChromaDB, use the `Query` method. This method

int main()
{
std::shared_ptr<JinaEmbeddingFunction> embeddingFunction = std::make_shared<JinaEmbeddingFunction>("jina-api-key");
std::shared_ptr<chromadb::JinaEmbeddingFunction> embeddingFunction = std::make_shared<chromadb::JinaEmbeddingFunction>("jina-api-key");

Collection collection = client.GetCollection("test_collection", embeddingFunction); // or collection.SetEmbeddingFunction(embeddingFunction);
chromadb::Collection collection = client.GetCollection("test_collection", embeddingFunction); // or collection.SetEmbeddingFunction(embeddingFunction);

auto queryResponse = client.Query(collection, { "This is a query document" }, {}, 3, { "metadatas", "documents", "embeddings", "distances" });

Expand Down
7 changes: 6 additions & 1 deletion include/ChromaDB/ChromaDB.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@
#include "ChromaDB/Exceptions/ChromaTypeException.h"
#include "ChromaDB/Exceptions/ChromaUniqueConstraintException.h"
#include "ChromaDB/Exceptions/ChromaValueException.h"
#include "ChromaDB/Exceptions/External/EmbeddingProviderConnectionException.h"
#include "ChromaDB/Exceptions/External/EmbeddingProviderRequestException.h"

#include "ChromaDB/Embeddings/JinaEmbeddingFunction.h"
#include "ChromaDB/Embeddings/OpenAIEmbeddingFunction.h"
#include "ChromaDB/Embeddings/OpenAIEmbeddingFunction.h"
#include "ChromaDB/Embeddings/CohereEmbeddingFunction.h"
#include "ChromaDB/Embeddings/VoyageAIEmbeddingFunction.h"
#include "ChromaDB/Embeddings/TogetherAIEmbeddingFunction.h"
36 changes: 36 additions & 0 deletions include/ChromaDB/Embeddings/CohereEmbeddingFunction.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#pragma once

#include "ChromaDB/Embeddings/EmbeddingFunction.h"

namespace chromadb {

class CohereEmbeddingFunction : public EmbeddingFunction
{
public:
/*
* @brief Construct a new CohereEmbeddingFunction object
*
* @param apiKey The API key
* @param model The model to use (optional)
* @param inputType The input type (optional)
* @param baseUrl The base URL of the server (optional)
* @param path The path of the endpoint (optional)
*/
CohereEmbeddingFunction(const std::string& apiKey, const std::string& model = "embed-english-v3.0", const std::string& inputType = "classification", const std::string & baseUrl = "api.cohere.com", const std::string & path = "/v1/embed");

/*
* @brief Generate the embeddings of the documents
*
* @param documents The documents to generate the embeddings
*
* @return std::vector<std::vector<double>> The embeddings of the documents
*
* @throw ChromaException if something goes wrong
*/
std::vector<std::vector<double>> Generate(const std::vector<std::string>& documents);

private:
std::string m_InputType;
};

} // namespace chromadb
3 changes: 2 additions & 1 deletion include/ChromaDB/Embeddings/EmbeddingFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

#define CPPHTTPLIB_OPENSSL_SUPPORT

#include "ChromaDB/Exceptions/ChromaException.h"
#include "ChromaDB/Exceptions/External/EmbeddingProviderConnectionException.h"
#include "ChromaDB/Exceptions/External/EmbeddingProviderRequestException.h"

#include "Json/json.h"
#include "Http/httplib.h"
Expand Down
6 changes: 5 additions & 1 deletion include/ChromaDB/Embeddings/OpenAIEmbeddingFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@ namespace chromadb {
*
* @param apiKey The API key
* @param model The model to use (optional)
* @param dimensions The number of dimensions of the embeddings (optional)
* @param baseUrl The base URL of the server (optional)
* @param path The path of the endpoint (optional)
*/
OpenAIEmbeddingFunction(const std::string& apiKey, const std::string& model = "text-embedding-3-small", const std::string& baseUrl = "api.openai.com", const std::string& path = "/v1/embeddings");
OpenAIEmbeddingFunction(const std::string& apiKey, const std::string& model = "text-embedding-3-small", size_t dimensions = 1536, const std::string& baseUrl = "api.openai.com", const std::string& path = "/v1/embeddings");

/*
* @brief Generate the embeddings of the documents
Expand All @@ -27,6 +28,9 @@ namespace chromadb {
* @throw ChromaException if something goes wrong
*/
std::vector<std::vector<double>> Generate(const std::vector<std::string>& documents);

private:
size_t m_Dimensions;
};

} // namespace chromadb
32 changes: 32 additions & 0 deletions include/ChromaDB/Embeddings/TogetherAIEmbeddingFunction.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#pragma once

#include "ChromaDB/Embeddings/EmbeddingFunction.h"

namespace chromadb {

class TogetherAIEmbeddingFunction : public EmbeddingFunction
{
public:
/*
* @brief Construct a new TogetherAIEmbeddingFunction object
*
* @param apiKey The API key
* @param model The model to use (optional)
* @param baseUrl The base URL of the server (optional)
* @param path The path of the endpoint (optional)
*/
TogetherAIEmbeddingFunction(const std::string& apiKey, const std::string& model = "togethercomputer/m2-bert-80M-8k-retrieval", const std::string& baseUrl = "api.together.xyz", const std::string& path = "/v1/embeddings");

/*
* @brief Generate the embeddings of the documents
*
* @param documents The documents to generate the embeddings
*
* @return std::vector<std::vector<double>> The embeddings of the documents
*
* @throw ChromaException if something goes wrong
*/
std::vector<std::vector<double>> Generate(const std::vector<std::string>& documents);
};

} // namespace chromadb
36 changes: 36 additions & 0 deletions include/ChromaDB/Embeddings/VoyageAIEmbeddingFunction.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#pragma once

#include "ChromaDB/Embeddings/EmbeddingFunction.h"

namespace chromadb {

class VoyageAIEmbeddingFunction : public EmbeddingFunction
{
public:
/*
* @brief Construct a new VoyageAIEmbeddingFunction object
*
* @param apiKey The API key
* @param model The model to use (optional)
* @param inputType The input type (optional)
* @param baseUrl The base URL of the server (optional)
* @param path The path of the endpoint (optional)
*/
VoyageAIEmbeddingFunction(const std::string& apiKey, const std::string& model = "voyage-2", const std::string& inputType = "document", const std::string& baseUrl = "api.voyageai.com", const std::string& path = "/v1/embeddings");

/*
* @brief Generate the embeddings of the documents
*
* @param documents The documents to generate the embeddings
*
* @return std::vector<std::vector<double>> The embeddings of the documents
*
* @throw ChromaException if something goes wrong
*/
std::vector<std::vector<double>> Generate(const std::vector<std::string>& documents);

private:
std::string m_InputType;
};

} // namespace chromadb
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#pragma once

#include "ChromaDB/Exceptions/ChromaException.h"

namespace chromadb {

class EmbeddingProviderConnectionException : public ChromaException
{
using ChromaException::ChromaException;
};

} // namespace chromadb
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#pragma once

#include "ChromaDB/Exceptions/ChromaException.h"

namespace chromadb {

class EmbeddingProviderRequestException : public ChromaException
{
using ChromaException::ChromaException;
};

} // namespace chromadb
Loading

0 comments on commit be3a6d4

Please sign in to comment.