feat: Add databricks oauth (#313)
baxen authored Nov 23, 2024
1 parent 7601cdc commit 339f5f4
Showing 11 changed files with 560 additions and 48 deletions.
28 changes: 15 additions & 13 deletions crates/goose-cli/src/profile/provider_helper.rs
@@ -1,5 +1,7 @@
use crate::inputs::inputs::get_env_value_or_input;
use goose::providers::configs::{DatabricksProviderConfig, OllamaProviderConfig, OpenAiProviderConfig, ProviderConfig};
use goose::providers::configs::{
DatabricksAuth, DatabricksProviderConfig, OpenAiProviderConfig, OllamaProviderConfig, ProviderConfig
};
use goose::providers::factory::ProviderType;
use goose::providers::ollama::OLLAMA_HOST;
use strum::IntoEnumIterator;
@@ -35,21 +37,21 @@ pub fn set_provider_config(provider_name: &str, model: String) -> ProviderConfig
temperature: None,
max_tokens: None,
}),
PROVIDER_DATABRICKS => ProviderConfig::Databricks(DatabricksProviderConfig {
host: get_env_value_or_input(
PROVIDER_DATABRICKS => {
let host = get_env_value_or_input(
"DATABRICKS_HOST",
"Please enter your Databricks host:",
false,
),
token: get_env_value_or_input(
"DATABRICKS_TOKEN",
"Please enter your Databricks token:",
true,
),
model,
temperature: None,
max_tokens: None,
}),
);
ProviderConfig::Databricks(DatabricksProviderConfig {
host: host.clone(),
// TODO revisit configuration
auth: DatabricksAuth::oauth(host),
model,
temperature: None,
max_tokens: None,
})
}
PROVIDER_OLLAMA => ProviderConfig::Ollama(OllamaProviderConfig {
host: std::env::var("OLLAMA_HOST")
.unwrap_or_else(|_| String::from(OLLAMA_HOST)),
14 changes: 4 additions & 10 deletions crates/goose-server/src/configuration.rs
@@ -1,7 +1,7 @@
use crate::error::{to_env_var, ConfigError};
use config::{Config, Environment};
use goose::providers::{
configs::{DatabricksProviderConfig, OllamaProviderConfig, OpenAiProviderConfig, ProviderConfig},
configs::{DatabricksAuth, DatabricksProviderConfig, OllamaProviderConfig, OpenAiProviderConfig, ProviderConfig},
factory::ProviderType,
ollama,
};
@@ -41,7 +40,6 @@ pub enum ProviderSettings {
Databricks {
#[serde(default = "default_databricks_host")]
host: String,
token: String,
#[serde(default = "default_model")]
model: String,
#[serde(default)]
@@ -90,13 +89,12 @@ impl ProviderSettings {
}),
ProviderSettings::Databricks {
host,
token,
model,
temperature,
max_tokens,
} => ProviderConfig::Databricks(DatabricksProviderConfig {
host,
token,
host: host.clone(),
auth: DatabricksAuth::oauth(host),
model,
temperature,
max_tokens,
@@ -257,7 +255,6 @@ mod tests {
fn test_databricks_settings() {
clean_env();
env::set_var("GOOSE_PROVIDER__TYPE", "databricks");
env::set_var("GOOSE_PROVIDER__TOKEN", "test-token");
env::set_var("GOOSE_PROVIDER__HOST", "https://custom.databricks.com");
env::set_var("GOOSE_PROVIDER__MODEL", "llama-2-70b");
env::set_var("GOOSE_PROVIDER__TEMPERATURE", "0.7");
@@ -266,14 +263,12 @@
let settings = Settings::new().unwrap();
if let ProviderSettings::Databricks {
host,
token,
model,
temperature,
max_tokens,
} = settings.provider
{
assert_eq!(host, "https://custom.databricks.com");
assert_eq!(token, "test-token");
assert_eq!(model, "llama-2-70b");
assert_eq!(temperature, Some(0.7));
assert_eq!(max_tokens, Some(2000));
@@ -283,7 +278,6 @@

// Clean up
env::remove_var("GOOSE_PROVIDER__TYPE");
env::remove_var("GOOSE_PROVIDER__TOKEN");
env::remove_var("GOOSE_PROVIDER__HOST");
env::remove_var("GOOSE_PROVIDER__MODEL");
env::remove_var("GOOSE_PROVIDER__TEMPERATURE");
@@ -372,4 +366,4 @@ mod tests {
let addr = server_settings.socket_addr();
assert_eq!(addr.to_string(), "127.0.0.1:3000");
}
}
}
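Not part of the diff: a minimal sketch, in the style of the test module above (same imports assumed), of the environment-only configuration that now suffices for Databricks. GOOSE_PROVIDER__TOKEN is no longer read; the ProviderSettings-to-ProviderConfig conversion above falls back to DatabricksAuth::oauth(host).

#[test]
fn databricks_oauth_settings_sketch() {
    // Sketch only: no GOOSE_PROVIDER__TOKEN is set; auth defaults to OAuth.
    env::set_var("GOOSE_PROVIDER__TYPE", "databricks");
    env::set_var("GOOSE_PROVIDER__HOST", "https://custom.databricks.com");
    env::set_var("GOOSE_PROVIDER__MODEL", "llama-2-70b");

    let settings = Settings::new().unwrap();
    if let ProviderSettings::Databricks { host, model, .. } = settings.provider {
        assert_eq!(host, "https://custom.databricks.com");
        assert_eq!(model, "llama-2-70b");
    }
}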
2 changes: 1 addition & 1 deletion crates/goose-server/src/state.rs
@@ -22,7 +22,7 @@ impl Clone for AppState {
ProviderConfig::Databricks(config) => ProviderConfig::Databricks(
goose::providers::configs::DatabricksProviderConfig {
host: config.host.clone(),
token: config.token.clone(),
auth: config.auth.clone(),
model: config.model.clone(),
temperature: config.temperature,
max_tokens: config.max_tokens,
16 changes: 14 additions & 2 deletions crates/goose/Cargo.toml
@@ -16,6 +16,7 @@ reqwest = { version = "0.11", features = ["json"] }
tokio = { version = "1.0", features = ["full"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
serde_urlencoded = "0.7"
uuid = { version = "1.0", features = ["v4"] }
regex = "1.11.1"
async-trait = "0.1"
@@ -25,11 +26,22 @@ strum_macros = "0.26"
tera = "1.20.0"
tokenizers = "0.20.3"
include_dir = "0.7.4"
chrono = "0.4.38"
chrono = { version = "0.4.38", features = ["serde"] }
indoc = "2.0.5"
nanoid = "0.4"
sha2 = "0.10"
base64 = "0.21"
url = "2.5"
axum = "0.7"
tower-http = { version = "0.5", features = ["cors"] }
webbrowser = "0.8"
dotenv = "0.15"

[dev-dependencies]
wiremock = "0.6.0"
mockito = "1.2"
tempfile = "3.8"
dotenv = "0.15"

[[example]]
name = "databricks_oauth"
path = "examples/databricks_oauth.rs"
3 changes: 3 additions & 0 deletions crates/goose/examples/.env.example
@@ -0,0 +1,3 @@
# Databricks OAuth Configuration
DATABRICKS_HOST=https://your-workspace.cloud.databricks.com
DATABRICKS_MODEL=your-model-name
49 changes: 49 additions & 0 deletions crates/goose/examples/databricks_oauth.rs
@@ -0,0 +1,49 @@
use anyhow::Result;
use dotenv::dotenv;
use goose::{
models::message::Message,
providers::{
configs::{DatabricksProviderConfig, ProviderConfig},
factory::get_provider,
},
};

#[tokio::main]
async fn main() -> Result<()> {
// Load environment variables from .env file
dotenv().ok();

// Get required environment variables
let host =
std::env::var("DATABRICKS_HOST").expect("DATABRICKS_HOST environment variable is required");
let model = std::env::var("DATABRICKS_MODEL")
.expect("DATABRICKS_MODEL environment variable is required");

// Create the Databricks provider configuration with OAuth
let config = ProviderConfig::Databricks(DatabricksProviderConfig::with_oauth(host, model));

// Create the provider
let provider = get_provider(config)?;

// Create a simple message
let message = Message::user().with_text("Tell me a short joke about programming.");

// Get a response
let (response, usage) = provider
.complete("You are a helpful assistant.", &[message], &[])
.await?;

// Print the response and usage statistics
println!("\nResponse from AI:");
println!("---------------");
for content in response.content {
dbg!(content);
}
println!("\nToken Usage:");
println!("------------");
println!("Input tokens: {:?}", usage.input_tokens);
println!("Output tokens: {:?}", usage.output_tokens);
println!("Total tokens: {:?}", usage.total_tokens);

Ok(())
}
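Not part of the diff: given the [[example]] entry added to Cargo.toml above, this example can presumably be run from crates/goose with "cargo run --example databricks_oauth" after creating a .env from examples/.env.example (or exporting DATABRICKS_HOST and DATABRICKS_MODEL directly). No DATABRICKS_TOKEN is needed; the OAuth variant of DatabricksAuth supplies the credential at request time, and the new webbrowser and axum dependencies suggest a local-redirect browser login on http://localhost:8020.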
3 changes: 2 additions & 1 deletion crates/goose/src/providers.rs
@@ -2,9 +2,10 @@ pub mod base;
pub mod configs;
pub mod databricks;
pub mod factory;
pub mod oauth;
pub mod ollama;
pub mod openai;
pub mod utils;

#[cfg(test)]
pub mod mock;
pub mod mock;
67 changes: 61 additions & 6 deletions crates/goose/src/providers/configs.rs
@@ -1,27 +1,82 @@
// Unified enum to wrap different provider configurations
use serde::{Deserialize, Serialize};

const DEFAULT_CLIENT_ID: &str = "databricks-cli";
const DEFAULT_REDIRECT_URL: &str = "http://localhost:8020";
const DEFAULT_SCOPES: &[&str] = &["all-apis"];

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ProviderConfig {
OpenAi(OpenAiProviderConfig),
Databricks(DatabricksProviderConfig),
Ollama(OllamaProviderConfig),
}

// Define specific config structs for each provider
pub struct OpenAiProviderConfig {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DatabricksAuth {
Token(String),
OAuth {
host: String,
client_id: String,
redirect_url: String,
scopes: Vec<String>,
},
}

impl DatabricksAuth {
/// Create a new OAuth configuration with default values
pub fn oauth(host: String) -> Self {
Self::OAuth {
host,
client_id: DEFAULT_CLIENT_ID.to_string(),
redirect_url: DEFAULT_REDIRECT_URL.to_string(),
scopes: DEFAULT_SCOPES.iter().map(|s| s.to_string()).collect(),
}
}
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabricksProviderConfig {
pub host: String,
pub api_key: String,
pub model: String,
pub auth: DatabricksAuth,
pub temperature: Option<f32>,
pub max_tokens: Option<i32>,
}

pub struct DatabricksProviderConfig {
impl DatabricksProviderConfig {
/// Create a new configuration with token authentication
pub fn with_token(host: String, model: String, token: String) -> Self {
Self {
host,
model,
auth: DatabricksAuth::Token(token),
temperature: None,
max_tokens: None,
}
}

/// Create a new configuration with OAuth authentication using default settings
pub fn with_oauth(host: String, model: String) -> Self {
Self {
host: host.clone(),
model,
auth: DatabricksAuth::oauth(host),
temperature: None,
max_tokens: None,
}
}
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAiProviderConfig {
pub host: String,
pub token: String,
pub api_key: String,
pub model: String,
pub temperature: Option<f32>,
pub max_tokens: Option<i32>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OllamaProviderConfig {
pub host: String,
pub model: String,
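Not part of the diff: a minimal usage sketch of the two constructors added above; the host, model, and token strings are placeholders.

use goose::providers::configs::{DatabricksAuth, DatabricksProviderConfig};

fn sketch_databricks_configs() {
    // Personal access token auth, matching the previous behavior:
    let token_config = DatabricksProviderConfig::with_token(
        "https://your-workspace.cloud.databricks.com".to_string(),
        "your-model-name".to_string(),
        "your-databricks-token".to_string(),
    );
    assert!(matches!(token_config.auth, DatabricksAuth::Token(_)));

    // OAuth auth with the defaults defined above: client id "databricks-cli",
    // redirect on http://localhost:8020, scope "all-apis".
    let oauth_config = DatabricksProviderConfig::with_oauth(
        "https://your-workspace.cloud.databricks.com".to_string(),
        "your-model-name".to_string(),
    );
    assert!(matches!(oauth_config.auth, DatabricksAuth::OAuth { .. }));
}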
39 changes: 28 additions & 11 deletions crates/goose/src/providers/databricks.rs
@@ -5,7 +5,8 @@ use serde_json::{json, Value};
use std::time::Duration;

use super::base::{Provider, Usage};
use super::configs::DatabricksProviderConfig;
use super::configs::{DatabricksAuth, DatabricksProviderConfig};
use super::oauth;
use super::utils::{
check_openai_context_length_error, messages_to_openai_spec, openai_response_to_message,
tools_to_openai_spec,
@@ -22,16 +23,27 @@ impl DatabricksProvider {
pub fn new(config: DatabricksProviderConfig) -> Result<Self> {
let client = Client::builder()
.timeout(Duration::from_secs(600)) // 10 minutes timeout
.default_headers({
let mut headers = reqwest::header::HeaderMap::new();
headers.insert("Authorization", format!("Bearer {}", config.token).parse()?);
headers
})
.build()?;

Ok(Self { client, config })
}

async fn ensure_auth_header(&self) -> Result<String> {
match &self.config.auth {
DatabricksAuth::Token(token) => Ok(format!("Bearer {}", token)),
DatabricksAuth::OAuth {
host,
client_id,
redirect_url,
scopes,
} => {
let token =
oauth::get_oauth_token_async(host, client_id, redirect_url, scopes).await?;
Ok(format!("Bearer {}", token))
}
}
}

fn get_usage(data: &Value) -> Result<Usage> {
let usage = data
.get("usage")
@@ -66,7 +78,14 @@ impl DatabricksProvider {
self.config.model
);

let response = self.client.post(&url).json(&payload).send().await?;
let auth_header = self.ensure_auth_header().await?;
let response = self
.client
.post(&url)
.header("Authorization", auth_header)
.json(&payload)
.send()
.await?;

match response.status() {
StatusCode::OK => Ok(response.json().await?),
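Design note (not part of the diff): the bearer token is no longer baked into the reqwest client's default headers at construction time. ensure_auth_header resolves the Authorization value per request, so an OAuth token can be obtained lazily on first use instead of requiring a credential when the provider is created.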
@@ -152,13 +171,11 @@
mod tests {
use super::*;
use crate::models::message::MessageContent;
use anyhow::Result;
use serde_json::json;
use wiremock::matchers::{body_json, header, method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};

#[tokio::test]
async fn test_databricks_completion() -> Result<()> {
async fn test_databricks_completion_with_token() -> Result<()> {
// Start a mock server
let mock_server = MockServer::start().await;

@@ -199,8 +216,8 @@ mod tests {
// Create the DatabricksProvider with the mock server's URL as the host
let config = DatabricksProviderConfig {
host: mock_server.uri(),
token: "test_token".to_string(),
model: "my-databricks-model".to_string(),
auth: DatabricksAuth::Token("test_token".to_string()),
temperature: None,
max_tokens: None,
};
(Remaining changed files in this commit are not shown.)
