diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index dca5a772f45d81..838dea19819194 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -197,10 +197,10 @@ async fn create_billing_subscription( .await? .ok_or_else(|| anyhow!("user not found"))?; - let Some((stripe_client, stripe_price_id)) = app + let Some((stripe_client, stripe_access_price_id)) = app .stripe_client .clone() - .zip(app.config.stripe_llm_usage_price_id.clone()) + .zip(app.config.stripe_llm_access_price_id.clone()) else { log::error!("failed to retrieve Stripe client or price ID"); Err(Error::http( @@ -232,8 +232,8 @@ async fn create_billing_subscription( params.customer = Some(customer_id); params.client_reference_id = Some(user.github_login.as_str()); params.line_items = Some(vec![CreateCheckoutSessionLineItems { - price: Some(stripe_price_id.to_string()), - quantity: Some(0), + price: Some(stripe_access_price_id.to_string()), + quantity: Some(1), ..Default::default() }]); let success_url = format!("{}/account", app.config.zed_dot_dev_url()); @@ -787,22 +787,33 @@ async fn update_stripe_subscription( monthly_spending.saturating_sub(FREE_TIER_MONTHLY_SPENDING_LIMIT); let new_quantity = (monthly_spending_over_free_tier.0 as f32 / 100.).ceil(); - Subscription::update( - stripe_client, - &subscription_id, - stripe::UpdateSubscription { - items: Some(vec![stripe::UpdateSubscriptionItems { - // TODO: Do we need to send up the `id` if a subscription item - // with this price already exists, or will Stripe take care of - // it? - id: None, - price: Some(stripe_llm_usage_price_id.to_string()), - quantity: Some(new_quantity as u64), - ..Default::default() - }]), + let current_subscription = Subscription::retrieve(stripe_client, &subscription_id, &[]).await?; + + let mut update_params = stripe::UpdateSubscription { + proration_behavior: Some( + stripe::generated::billing::subscription::SubscriptionProrationBehavior::None, + ), + ..Default::default() + }; + + if let Some(existing_item) = current_subscription.items.data.iter().find(|item| { + item.price.as_ref().map_or(false, |price| { + price.id == stripe_llm_usage_price_id.as_ref() + }) + }) { + update_params.items = Some(vec![stripe::UpdateSubscriptionItems { + id: Some(existing_item.id.to_string()), + quantity: Some(new_quantity as u64), ..Default::default() - }, - ) - .await?; + }]); + } else { + update_params.items = Some(vec![stripe::UpdateSubscriptionItems { + price: Some(stripe_llm_usage_price_id.to_string()), + quantity: Some(new_quantity as u64), + ..Default::default() + }]); + } + + Subscription::update(stripe_client, &subscription_id, update_params).await?; Ok(()) } diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index a6141abb888730..3896926f4372fe 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -176,6 +176,7 @@ pub struct Config { pub slack_panics_webhook: Option, pub auto_join_channel_id: Option, pub stripe_api_key: Option, + pub stripe_llm_access_price_id: Option>, pub stripe_llm_usage_price_id: Option>, pub supermaven_admin_api_key: Option>, pub user_backfiller_github_access_token: Option>, @@ -237,6 +238,7 @@ impl Config { migrations_path: None, seed_path: None, stripe_api_key: None, + stripe_llm_access_price_id: None, stripe_llm_usage_price_id: None, supermaven_admin_api_key: None, user_backfiller_github_access_token: None, diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index bbbd4e562cb901..bd227f17c70404 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -155,7 +155,8 @@ async fn main() -> Result<()> { .await .trace_err(); - if let Some(llm_db) = llm_db { + if let Some(mut llm_db) = llm_db { + llm_db.initialize().await?; sync_llm_usage_with_stripe_periodically(state.clone(), llm_db); } diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index 55bc279c8eaf6e..683a53a2f56618 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -677,6 +677,7 @@ impl TestServer { migrations_path: None, seed_path: None, stripe_api_key: None, + stripe_llm_access_price_id: None, stripe_llm_usage_price_id: None, supermaven_admin_api_key: None, user_backfiller_github_access_token: None,