Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add google provider #489

Merged
merged 19 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions crates/goose-cli/src/commands/configure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ pub async fn handle_configure(
("databricks", "Databricks", "Models on AI Gateway"),
("ollama", "Ollama", "Local open source models"),
("anthropic", "Anthropic", "Claude models"),
("google", "Google Gemini", "Gemini models"),
])
.interact()?
.to_string()
Expand Down Expand Up @@ -157,6 +158,7 @@ pub fn get_recommended_model(provider_name: &str) -> &str {
"databricks" => "claude-3-5-sonnet-2",
"ollama" => OLLAMA_MODEL,
"anthropic" => "claude-3-5-sonnet-2",
"google" => "gemini-1.5-flash",
_ => panic!("Invalid provider name"),
}
}
Expand All @@ -167,6 +169,7 @@ pub fn get_required_keys(provider_name: &str) -> Vec<&'static str> {
"databricks" => vec!["DATABRICKS_HOST"],
"ollama" => vec!["OLLAMA_HOST"],
"anthropic" => vec!["ANTHROPIC_API_KEY"], // Removed ANTHROPIC_HOST since we use a fixed endpoint
"google" => vec!["GOOGLE_API_KEY"],
_ => panic!("Invalid provider name"),
}
}
14 changes: 12 additions & 2 deletions crates/goose-cli/src/profile.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use anyhow::Result;
use goose::key_manager::{get_keyring_secret, KeyRetrievalStrategy};
use goose::providers::configs::{
AnthropicProviderConfig, DatabricksAuth, DatabricksProviderConfig, ModelConfig,
OllamaProviderConfig, OpenAiProviderConfig, ProviderConfig,
AnthropicProviderConfig, DatabricksAuth, DatabricksProviderConfig, GoogleProviderConfig,
ModelConfig, OllamaProviderConfig, OpenAiProviderConfig, ProviderConfig,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
Expand Down Expand Up @@ -125,6 +125,16 @@ pub fn get_provider_config(provider_name: &str, profile: Profile) -> ProviderCon
model: model_config,
})
}
"google" => {
let api_key = get_keyring_secret("GOOGLE_API_KEY", KeyRetrievalStrategy::Both)
.expect("GOOGLE_API_KEY not available in env or the keychain\nSet an env var or rerun `goose configure`");

ProviderConfig::Google(GoogleProviderConfig {
host: "https://generativelanguage.googleapis.com".to_string(), // Default Anthropic API endpoint
api_key,
model: model_config,
})
}
_ => panic!("Invalid provider name"),
}
}
Expand Down
36 changes: 35 additions & 1 deletion crates/goose-server/src/configuration.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use crate::error::{to_env_var, ConfigError};
use config::{Config, Environment};
use goose::providers::configs::GoogleProviderConfig;
use goose::providers::{
configs::{
DatabricksAuth, DatabricksProviderConfig, ModelConfig, OllamaProviderConfig,
OpenAiProviderConfig, ProviderConfig,
},
factory::ProviderType,
ollama,
google, ollama,
utils::ImageFormat,
};
use serde::Deserialize;
Expand Down Expand Up @@ -76,6 +77,17 @@ pub enum ProviderSettings {
#[serde(default)]
estimate_factor: Option<f32>,
},
Google {
#[serde(default = "default_google_host")]
host: String,
api_key: String,
#[serde(default = "default_google_model")]
model: String,
#[serde(default)]
temperature: Option<f32>,
#[serde(default)]
max_tokens: Option<i32>,
},
}

impl ProviderSettings {
Expand All @@ -86,6 +98,7 @@ impl ProviderSettings {
ProviderSettings::OpenAi { .. } => ProviderType::OpenAi,
ProviderSettings::Databricks { .. } => ProviderType::Databricks,
ProviderSettings::Ollama { .. } => ProviderType::Ollama,
ProviderSettings::Google { .. } => ProviderType::Google,
}
}

Expand Down Expand Up @@ -142,6 +155,19 @@ impl ProviderSettings {
.with_context_limit(context_limit)
.with_estimate_factor(estimate_factor),
}),
ProviderSettings::Google {
host,
api_key,
model,
temperature,
max_tokens,
} => ProviderConfig::Google(GoogleProviderConfig {
host,
api_key,
model: ModelConfig::new(model)
.with_temperature(temperature)
.with_max_tokens(max_tokens),
}),
}
}
}
Expand Down Expand Up @@ -233,6 +259,14 @@ fn default_ollama_model() -> String {
ollama::OLLAMA_MODEL.to_string()
}

fn default_google_host() -> String {
google::GOOGLE_API_HOST.to_string()
}

fn default_google_model() -> String {
google::GOOGLE_DEFAULT_MODEL.to_string()
}

fn default_image_format() -> ImageFormat {
ImageFormat::Anthropic
}
Expand Down
7 changes: 7 additions & 0 deletions crates/goose-server/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ impl Clone for AppState {
model: config.model.clone(),
})
}
ProviderConfig::Google(config) => {
ProviderConfig::Google(goose::providers::configs::GoogleProviderConfig {
host: config.host.clone(),
api_key: config.api_key.clone(),
model: config.model.clone(),
})
}
},
agent: self.agent.clone(),
secret_key: self.secret_key.clone(),
Expand Down
1 change: 1 addition & 0 deletions crates/goose/src/providers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@ pub mod ollama;
pub mod openai;
pub mod utils;

pub mod google;
#[cfg(test)]
pub mod mock;
14 changes: 14 additions & 0 deletions crates/goose/src/providers/configs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub enum ProviderConfig {
Databricks(DatabricksProviderConfig),
Ollama(OllamaProviderConfig),
Anthropic(AnthropicProviderConfig),
Google(GoogleProviderConfig),
}

/// Configuration for model-specific settings and limits
Expand Down Expand Up @@ -208,6 +209,19 @@ impl ProviderModelConfig for OpenAiProviderConfig {
}
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GoogleProviderConfig {
pub host: String,
pub api_key: String,
pub model: ModelConfig,
}

impl ProviderModelConfig for GoogleProviderConfig {
fn model_config(&self) -> &ModelConfig {
&self.model
}
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OllamaProviderConfig {
pub host: String,
Expand Down
3 changes: 3 additions & 0 deletions crates/goose/src/providers/factory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use super::{
anthropic::AnthropicProvider, base::Provider, configs::ProviderConfig,
databricks::DatabricksProvider, ollama::OllamaProvider, openai::OpenAiProvider,
};
use crate::providers::google::GoogleProvider;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move this up into the super for consistency?

use anyhow::Result;
use strum_macros::EnumIter;

Expand All @@ -11,6 +12,7 @@ pub enum ProviderType {
Databricks,
Ollama,
Anthropic,
Google,
}

pub fn get_provider(config: ProviderConfig) -> Result<Box<dyn Provider + Send + Sync>> {
Expand All @@ -23,5 +25,6 @@ pub fn get_provider(config: ProviderConfig) -> Result<Box<dyn Provider + Send +
ProviderConfig::Anthropic(anthropic_config) => {
Ok(Box::new(AnthropicProvider::new(anthropic_config)?))
}
ProviderConfig::Google(google_config) => Ok(Box::new(GoogleProvider::new(google_config)?)),
}
}
Loading
Loading