Skip to content

Commit

Permalink
feat: add google provider (#489)
Browse files Browse the repository at this point in the history
  • Loading branch information
lifeizhou-ap authored Dec 18, 2024
1 parent 1327fc4 commit 7154da7
Show file tree
Hide file tree
Showing 12 changed files with 789 additions and 9 deletions.
3 changes: 3 additions & 0 deletions crates/goose-cli/src/commands/configure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ pub async fn handle_configure(
("databricks", "Databricks", "Models on AI Gateway"),
("ollama", "Ollama", "Local open source models"),
("anthropic", "Anthropic", "Claude models"),
("google", "Google Gemini", "Gemini models"),
])
.interact()?
.to_string()
Expand Down Expand Up @@ -157,6 +158,7 @@ pub fn get_recommended_model(provider_name: &str) -> &str {
"databricks" => "claude-3-5-sonnet-2",
"ollama" => OLLAMA_MODEL,
"anthropic" => "claude-3-5-sonnet-2",
"google" => "gemini-1.5-flash",
_ => panic!("Invalid provider name"),
}
}
Expand All @@ -167,6 +169,7 @@ pub fn get_required_keys(provider_name: &str) -> Vec<&'static str> {
"databricks" => vec!["DATABRICKS_HOST"],
"ollama" => vec!["OLLAMA_HOST"],
"anthropic" => vec!["ANTHROPIC_API_KEY"], // Removed ANTHROPIC_HOST since we use a fixed endpoint
"google" => vec!["GOOGLE_API_KEY"],
_ => panic!("Invalid provider name"),
}
}
14 changes: 12 additions & 2 deletions crates/goose-cli/src/profile.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use anyhow::Result;
use goose::key_manager::{get_keyring_secret, KeyRetrievalStrategy};
use goose::providers::configs::{
AnthropicProviderConfig, DatabricksAuth, DatabricksProviderConfig, ModelConfig,
OllamaProviderConfig, OpenAiProviderConfig, ProviderConfig,
AnthropicProviderConfig, DatabricksAuth, DatabricksProviderConfig, GoogleProviderConfig,
ModelConfig, OllamaProviderConfig, OpenAiProviderConfig, ProviderConfig,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
Expand Down Expand Up @@ -125,6 +125,16 @@ pub fn get_provider_config(provider_name: &str, profile: Profile) -> ProviderCon
model: model_config,
})
}
"google" => {
let api_key = get_keyring_secret("GOOGLE_API_KEY", KeyRetrievalStrategy::Both)
.expect("GOOGLE_API_KEY not available in env or the keychain\nSet an env var or rerun `goose configure`");

ProviderConfig::Google(GoogleProviderConfig {
host: "https://generativelanguage.googleapis.com".to_string(), // Default Anthropic API endpoint
api_key,
model: model_config,
})
}
_ => panic!("Invalid provider name"),
}
}
Expand Down
36 changes: 35 additions & 1 deletion crates/goose-server/src/configuration.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use crate::error::{to_env_var, ConfigError};
use config::{Config, Environment};
use goose::providers::configs::GoogleProviderConfig;
use goose::providers::{
configs::{
DatabricksAuth, DatabricksProviderConfig, ModelConfig, OllamaProviderConfig,
OpenAiProviderConfig, ProviderConfig,
},
factory::ProviderType,
ollama,
google, ollama,
utils::ImageFormat,
};
use serde::Deserialize;
Expand Down Expand Up @@ -76,6 +77,17 @@ pub enum ProviderSettings {
#[serde(default)]
estimate_factor: Option<f32>,
},
Google {
#[serde(default = "default_google_host")]
host: String,
api_key: String,
#[serde(default = "default_google_model")]
model: String,
#[serde(default)]
temperature: Option<f32>,
#[serde(default)]
max_tokens: Option<i32>,
},
}

impl ProviderSettings {
Expand All @@ -86,6 +98,7 @@ impl ProviderSettings {
ProviderSettings::OpenAi { .. } => ProviderType::OpenAi,
ProviderSettings::Databricks { .. } => ProviderType::Databricks,
ProviderSettings::Ollama { .. } => ProviderType::Ollama,
ProviderSettings::Google { .. } => ProviderType::Google,
}
}

Expand Down Expand Up @@ -142,6 +155,19 @@ impl ProviderSettings {
.with_context_limit(context_limit)
.with_estimate_factor(estimate_factor),
}),
ProviderSettings::Google {
host,
api_key,
model,
temperature,
max_tokens,
} => ProviderConfig::Google(GoogleProviderConfig {
host,
api_key,
model: ModelConfig::new(model)
.with_temperature(temperature)
.with_max_tokens(max_tokens),
}),
}
}
}
Expand Down Expand Up @@ -233,6 +259,14 @@ fn default_ollama_model() -> String {
ollama::OLLAMA_MODEL.to_string()
}

fn default_google_host() -> String {
google::GOOGLE_API_HOST.to_string()
}

fn default_google_model() -> String {
google::GOOGLE_DEFAULT_MODEL.to_string()
}

fn default_image_format() -> ImageFormat {
ImageFormat::Anthropic
}
Expand Down
7 changes: 7 additions & 0 deletions crates/goose-server/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ impl Clone for AppState {
model: config.model.clone(),
})
}
ProviderConfig::Google(config) => {
ProviderConfig::Google(goose::providers::configs::GoogleProviderConfig {
host: config.host.clone(),
api_key: config.api_key.clone(),
model: config.model.clone(),
})
}
},
agent: self.agent.clone(),
secret_key: self.secret_key.clone(),
Expand Down
1 change: 1 addition & 0 deletions crates/goose/src/providers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@ pub mod ollama;
pub mod openai;
pub mod utils;

pub mod google;
#[cfg(test)]
pub mod mock;
14 changes: 14 additions & 0 deletions crates/goose/src/providers/configs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub enum ProviderConfig {
Databricks(DatabricksProviderConfig),
Ollama(OllamaProviderConfig),
Anthropic(AnthropicProviderConfig),
Google(GoogleProviderConfig),
}

/// Configuration for model-specific settings and limits
Expand Down Expand Up @@ -208,6 +209,19 @@ impl ProviderModelConfig for OpenAiProviderConfig {
}
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GoogleProviderConfig {
pub host: String,
pub api_key: String,
pub model: ModelConfig,
}

impl ProviderModelConfig for GoogleProviderConfig {
fn model_config(&self) -> &ModelConfig {
&self.model
}
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OllamaProviderConfig {
pub host: String,
Expand Down
5 changes: 4 additions & 1 deletion crates/goose/src/providers/factory.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use super::{
anthropic::AnthropicProvider, base::Provider, configs::ProviderConfig,
databricks::DatabricksProvider, ollama::OllamaProvider, openai::OpenAiProvider,
databricks::DatabricksProvider, google::GoogleProvider, ollama::OllamaProvider,
openai::OpenAiProvider,
};
use anyhow::Result;
use strum_macros::EnumIter;
Expand All @@ -11,6 +12,7 @@ pub enum ProviderType {
Databricks,
Ollama,
Anthropic,
Google,
}

pub fn get_provider(config: ProviderConfig) -> Result<Box<dyn Provider + Send + Sync>> {
Expand All @@ -23,5 +25,6 @@ pub fn get_provider(config: ProviderConfig) -> Result<Box<dyn Provider + Send +
ProviderConfig::Anthropic(anthropic_config) => {
Ok(Box::new(AnthropicProvider::new(anthropic_config)?))
}
ProviderConfig::Google(google_config) => Ok(Box::new(GoogleProvider::new(google_config)?)),
}
}
Loading

0 comments on commit 7154da7

Please sign in to comment.