Skip to content

Commit

Permalink
mixedbread support
Browse files Browse the repository at this point in the history
  • Loading branch information
asg017 committed Jun 4, 2024
1 parent b6d9a64 commit 977d374
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 11 deletions.
25 changes: 14 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,20 @@ A SQLite extension for generating text embedding from remote sources (llamafile,

Work in progress!

| Client | API Reference |
| --------- | -------------------------------------------------------------------------- |
| Ollama | https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings |
| Nomic | https://docs.nomic.ai/reference/endpoints/nomic-embed-text |
| Cohere | https://docs.cohere.com/reference/embed |
| OpenAI | https://platform.openai.com/docs/guides/embeddings |
| llamafile | https://github.com/Mozilla-Ocho/llamafile |
| Jina | https://api.jina.ai/redoc#tag/embeddings |
| Client | API Reference |
| ---------- | -------------------------------------------------------------------------- |
| OpenAI | https://platform.openai.com/docs/guides/embeddings |
| Nomic | https://docs.nomic.ai/reference/endpoints/nomic-embed-text |
| Cohere | https://docs.cohere.com/reference/embed |
| Jina | https://api.jina.ai/redoc#tag/embeddings |
| MixedBread | https://www.mixedbread.ai/api-reference#quick-start-guide |
| llamafile | https://github.com/Mozilla-Ocho/llamafile |
| Ollama | https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings |

TODO

- Google AI API https://ai.google.dev/api/rest/v1beta/models/embedText
- text-embeddings-inference https://github.com/huggingface/text-embeddings-inference
- mixedbread https://www.mixedbread.ai/api-reference#whats-inside-
- [ ] Support Google AI API https://ai.google.dev/api/rest/v1beta/models/embedText
- [ ] Support text-embeddings-inference https://github.com/huggingface/text-embeddings-inference
- [ ] image embeddings support
- [ ] batch support
- [ ] extra params (X-Client-Name headers, truncation_strategy, input_type, etc.)
82 changes: 82 additions & 0 deletions src/clients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,87 @@ impl JinaClient {
})
}
}
#[derive(Clone)]
pub struct MixedbreadClient {
url: String,
model: String,
key: String,
}
const DEFAULT_MIXEDBREAD_URL: &str = "https://api.mixedbread.ai/v1/embeddings/";
const DEFAULT_MIXEDBREAD_API_KEY_ENV: &str = "MIXEDBREAD_API_KEY";

impl MixedbreadClient {
pub fn new<S: Into<String>>(
model: S,
url: Option<String>,
key: Option<String>,
) -> Result<Self> {
Ok(Self {
model: model.into(),
url: url.unwrap_or(DEFAULT_MIXEDBREAD_URL.to_owned()),
key: match key {
Some(key) => key,
None => try_env_var(DEFAULT_MIXEDBREAD_API_KEY_ENV)?,
},
})
}

pub fn infer_single(&self, input: &str) -> Result<Vec<f32>> {
let mut body = serde_json::Map::new();
body.insert("input".to_owned(), vec![input.to_owned()].into());
body.insert("model".to_owned(), self.model.to_owned().into());

let data: serde_json::Value = ureq::post(&self.url)
.set("Content-Type", "application/json")
.set("Accept", "application/json")
.set("Authorization", format!("Bearer {}", self.key).as_str())
.send_bytes(
serde_json::to_vec(&body)
.map_err(|error| {
Error::new_message(format!("Error serializing body to JSON: {error}"))
})?
.as_ref(),
)
.map_err(|error| Error::new_message(format!("Error sending HTTP request: {error}")))?
.into_json()
.map_err(|error| {
Error::new_message(format!("Error parsing HTTP response as JSON: {error}"))
})?;
JinaClient::parse_single_response(data)
}
pub fn parse_single_response(value: serde_json::Value) -> Result<Vec<f32>> {
value
.get("data")
.ok_or_else(|| Error::new_message("expected 'data' key in response body"))
.and_then(|v| {
v.get(0)
.ok_or_else(|| Error::new_message("expected 'data.0' path in response body"))
})
.and_then(|v| {
v.get("embedding").ok_or_else(|| {
Error::new_message("expected 'data.0.embedding' path in response body")
})
})
.and_then(|v| {
v.as_array().ok_or_else(|| {
Error::new_message("expected 'data.0.embedding' path to be an array")
})
})
.and_then(|arr| {
arr.iter()
.map(|v| {
v.as_f64()
.ok_or_else(|| {
Error::new_message(
"expected 'data.0.embedding' array to contain floats",
)
})
.map(|f| f as f32)
})
.collect()
})
}
}

#[derive(Clone)]
pub struct OllamaClient {
Expand Down Expand Up @@ -431,4 +512,5 @@ pub enum Client {
Ollama(OllamaClient),
Llamafile(LlamafileClient),
Jina(JinaClient),
Mixedbread(MixedbreadClient),
}
4 changes: 4 additions & 0 deletions src/clients_vtab.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use sqlite_loadable::{
};
use std::{cell::RefCell, collections::HashMap, marker::PhantomData, mem, os::raw::c_int, rc::Rc};

use crate::clients::MixedbreadClient;
use crate::{
clients::{
Client, CohereClient, JinaClient, LlamafileClient, NomicClient, OllamaClient, OpenAiClient,
Expand Down Expand Up @@ -90,6 +91,9 @@ impl<'vtab, 'a> VTabWriteable<'vtab> for ClientsTable {
let client = match api::value_type(&values[1]) {
ValueType::Text => match api::value_text(&values[1])? {
"openai" => Client::OpenAI(OpenAiClient::new(name, None, None)?),
"mixedbread" => {
Client::Mixedbread(MixedbreadClient::new(name, None, None)?)
}
"jina" => Client::Jina(JinaClient::new(name, None, None)?),
"nomic" => Client::Nomic(NomicClient::new(name, None, None)?),
"cohere" => Client::Cohere(CohereClient::new(name, None, None)?),
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ pub fn rembed(
let embedding = match client {
Client::OpenAI(client) => client.infer_single(input)?,
Client::Jina(client) => client.infer_single(input)?,
Client::Mixedbread(client) => client.infer_single(input)?,
Client::Ollama(client) => client.infer_single(input)?,
Client::Llamafile(client) => client.infer_single(input)?,
Client::Nomic(client) => {
Expand Down
3 changes: 3 additions & 0 deletions test.sql
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
INSERT INTO temp.rembed_clients(name, options) VALUES
('text-embedding-3-small','openai'),
('jina-embeddings-v2-base-en','jina'),
('mixedbread-ai/mxbai-embed-large-v1','mixedbread'),
('nomic-embed-text-v1.5', 'nomic'),
('embed-english-v3.0', 'cohere'),
('snowflake-arctic-embed:s', 'ollama'),
Expand All @@ -21,6 +22,8 @@ INSERT INTO temp.rembed_clients(name, options) VALUES
)
);

select length(rembed('mixedbread-ai/mxbai-embed-large-v1', 'obama the person'));
.exit
select length(rembed('jina-embeddings-v2-base-en', 'obama the person'));

.exit
Expand Down

0 comments on commit 977d374

Please sign in to comment.