diff --git a/components/usage/pkg/apiv1/billing.go b/components/usage/pkg/apiv1/billing.go index 0920dc23802ffb..792fca0e147948 100644 --- a/components/usage/pkg/apiv1/billing.go +++ b/components/usage/pkg/apiv1/billing.go @@ -49,7 +49,7 @@ func (s *BillingService) UpdateInvoices(ctx context.Context, in *v1.UpdateInvoic return nil, status.Errorf(codes.Internal, "Failed to download usage report with ID: %s", in.GetReportId()) } - credits, err := s.creditSummaryForTeams(report) + credits, err := s.creditSummaryForTeams(report, in.GetReportId()) if err != nil { log.Log.WithError(err).Errorf("Failed to compute credit summary.") return nil, status.Errorf(codes.InvalidArgument, "failed to compute credit summary") @@ -100,7 +100,7 @@ func (s *BillingService) GetUpcomingInvoice(ctx context.Context, in *v1.GetUpcom }, nil } -func (s *BillingService) creditSummaryForTeams(sessions db.UsageReport) (map[string]int64, error) { +func (s *BillingService) creditSummaryForTeams(sessions db.UsageReport, reportID string) (map[string]stripe.CreditSummary, error) { creditsPerTeamID := map[string]float64{} for _, session := range sessions { @@ -120,9 +120,12 @@ func (s *BillingService) creditSummaryForTeams(sessions db.UsageReport) (map[str creditsPerTeamID[id] += session.CreditsUsed } - rounded := map[string]int64{} + rounded := map[string]stripe.CreditSummary{} for teamID, credits := range creditsPerTeamID { - rounded[teamID] = int64(math.Ceil(credits)) + rounded[teamID] = stripe.CreditSummary{ + Credits: int64(math.Ceil(credits)), + ReportID: reportID, + } } return rounded, nil diff --git a/components/usage/pkg/apiv1/billing_test.go b/components/usage/pkg/apiv1/billing_test.go index 49e58a126cd6d4..c56328156d846b 100644 --- a/components/usage/pkg/apiv1/billing_test.go +++ b/components/usage/pkg/apiv1/billing_test.go @@ -18,18 +18,19 @@ import ( func TestCreditSummaryForTeams(t *testing.T) { teamID_A, teamID_B := uuid.New().String(), uuid.New().String() teamAttributionID_A, teamAttributionID_B := db.NewTeamAttributionID(teamID_A), db.NewTeamAttributionID(teamID_B) + reportID := "report_id_1" scenarios := []struct { Name string Sessions db.UsageReport BillSessionsAfter time.Time - Expected map[string]int64 + Expected map[string]stripe.CreditSummary }{ { Name: "no instances in report, no summary", BillSessionsAfter: time.Time{}, Sessions: nil, - Expected: map[string]int64{}, + Expected: map[string]stripe.CreditSummary{}, }, { Name: "skips user attributions", @@ -39,7 +40,7 @@ func TestCreditSummaryForTeams(t *testing.T) { AttributionID: db.NewUserAttributionID(uuid.New().String()), }, }, - Expected: map[string]int64{}, + Expected: map[string]stripe.CreditSummary{}, }, { Name: "two workspace instances", @@ -56,9 +57,12 @@ func TestCreditSummaryForTeams(t *testing.T) { CreditsUsed: 10, }, }, - Expected: map[string]int64{ + Expected: map[string]stripe.CreditSummary{ // total of 2 days runtime, at 10 credits per hour, that's 480 credits - teamID_A: 480, + teamID_A: { + Credits: 480, + ReportID: reportID, + }, }, }, { @@ -76,10 +80,16 @@ func TestCreditSummaryForTeams(t *testing.T) { CreditsUsed: (24) * 10, }, }, - Expected: map[string]int64{ + Expected: map[string]stripe.CreditSummary{ // total of 2 days runtime, at 10 credits per hour, that's 480 credits - teamID_A: 120, - teamID_B: 240, + teamID_A: { + Credits: 120, + ReportID: reportID, + }, + teamID_B: { + Credits: 240, + ReportID: reportID, + }, }, }, { @@ -99,8 +109,11 @@ func TestCreditSummaryForTeams(t *testing.T) { StartedAt: time.Now().AddDate(0, 0, -3), }, }, - Expected: map[string]int64{ - teamID_A: 120, + Expected: map[string]stripe.CreditSummary{ + teamID_A: { + Credits: 120, + ReportID: reportID, + }, }, }, } @@ -108,7 +121,7 @@ func TestCreditSummaryForTeams(t *testing.T) { for _, s := range scenarios { t.Run(s.Name, func(t *testing.T) { svc := NewBillingService(&stripe.Client{}, s.BillSessionsAfter, &gorm.DB{}) - actual, err := svc.creditSummaryForTeams(s.Sessions) + actual, err := svc.creditSummaryForTeams(s.Sessions, reportID) require.NoError(t, err) require.Equal(t, s.Expected, actual) }) diff --git a/components/usage/pkg/stripe/stripe.go b/components/usage/pkg/stripe/stripe.go index d8215ea1504d6d..35f421bbd28472 100644 --- a/components/usage/pkg/stripe/stripe.go +++ b/components/usage/pkg/stripe/stripe.go @@ -16,6 +16,10 @@ import ( "github.com/stripe/stripe-go/v72/client" ) +const ( + reportIDMetadataKey = "reportId" +) + type Client struct { sc *client.API } @@ -58,9 +62,14 @@ type Invoice struct { Credits int64 } +type CreditSummary struct { + Credits int64 + ReportID string +} + // UpdateUsage updates teams' Stripe subscriptions with usage data // `usageForTeam` is a map from team name to total workspace seconds used within a billing period. -func (c *Client) UpdateUsage(ctx context.Context, creditsPerTeam map[string]int64) error { +func (c *Client) UpdateUsage(ctx context.Context, creditsPerTeam map[string]CreditSummary) error { teamIds := make([]string, 0, len(creditsPerTeam)) for k := range creditsPerTeam { teamIds = append(teamIds, k) @@ -117,7 +126,7 @@ func (c *Client) findCustomers(ctx context.Context, query string) ([]*stripe.Cus return customers, nil } -func (c *Client) updateUsageForCustomer(ctx context.Context, customer *stripe.Customer, credits int64) (*UsageRecord, error) { +func (c *Client) updateUsageForCustomer(ctx context.Context, customer *stripe.Customer, summary CreditSummary) (*UsageRecord, error) { subscriptions := customer.Subscriptions.Data if len(subscriptions) != 1 { return nil, fmt.Errorf("customer has an unexpected number of subscriptions %v (expected 1, got %d)", subscriptions, len(subscriptions)) @@ -136,15 +145,27 @@ func (c *Client) updateUsageForCustomer(ctx context.Context, customer *stripe.Cu Context: ctx, }, SubscriptionItem: stripe.String(subscriptionItemId), - Quantity: stripe.Int64(credits), + Quantity: stripe.Int64(summary.Credits), }) if err != nil { return nil, fmt.Errorf("failed to register usage for customer %q on subscription item %s", customer.Name, subscriptionItemId) } + invoice, err := c.GetUpcomingInvoice(ctx, customer.ID) + if err != nil { + return nil, fmt.Errorf("failed to find upcoming invoice for customer %s: %w", customer.ID, err) + } + + _, err = c.UpdateInvoiceMetadata(ctx, invoice.ID, map[string]string{ + reportIDMetadataKey: summary.ReportID, + }) + if err != nil { + return nil, fmt.Errorf("failed to udpate invoice %s metadata with report ID: %w", invoice.ID, err) + } + return &UsageRecord{ SubscriptionItemID: subscriptionItemId, - Quantity: credits, + Quantity: summary.Credits, }, nil } @@ -205,6 +226,20 @@ func (c *Client) GetUpcomingInvoice(ctx context.Context, customerID string) (*In }, nil } +func (c *Client) UpdateInvoiceMetadata(ctx context.Context, invoiceID string, metadata map[string]string) (*stripe.Invoice, error) { + invoice, err := c.sc.Invoices.Update(invoiceID, &stripe.InvoiceParams{ + Params: stripe.Params{ + Context: ctx, + Metadata: metadata, + }, + }) + if err != nil { + return nil, fmt.Errorf("failed to update invoice %s metadata: %w", invoiceID, err) + } + + return invoice, nil +} + // queriesForCustomersWithTeamIds constructs Stripe query strings to find the Stripe Customer for each teamId // It returns multiple queries, each being a big disjunction of subclauses so that we can process multiple teamIds in one query. // `clausesPerQuery` is a limit enforced by the Stripe API.