Skip to content

Commit

Permalink
[stripe] Set reportId on invoices after updating credits
Browse files Browse the repository at this point in the history
  • Loading branch information
easyCZ authored and roboquat committed Aug 30, 2022
1 parent 9df045e commit 33c613c
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 19 deletions.
11 changes: 7 additions & 4 deletions components/usage/pkg/apiv1/billing.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func (s *BillingService) UpdateInvoices(ctx context.Context, in *v1.UpdateInvoic
return nil, status.Errorf(codes.Internal, "Failed to download usage report with ID: %s", in.GetReportId())
}

credits, err := s.creditSummaryForTeams(report)
credits, err := s.creditSummaryForTeams(report, in.GetReportId())
if err != nil {
log.Log.WithError(err).Errorf("Failed to compute credit summary.")
return nil, status.Errorf(codes.InvalidArgument, "failed to compute credit summary")
Expand Down Expand Up @@ -100,7 +100,7 @@ func (s *BillingService) GetUpcomingInvoice(ctx context.Context, in *v1.GetUpcom
}, nil
}

func (s *BillingService) creditSummaryForTeams(sessions db.UsageReport) (map[string]int64, error) {
func (s *BillingService) creditSummaryForTeams(sessions db.UsageReport, reportID string) (map[string]stripe.CreditSummary, error) {
creditsPerTeamID := map[string]float64{}

for _, session := range sessions {
Expand All @@ -120,9 +120,12 @@ func (s *BillingService) creditSummaryForTeams(sessions db.UsageReport) (map[str
creditsPerTeamID[id] += session.CreditsUsed
}

rounded := map[string]int64{}
rounded := map[string]stripe.CreditSummary{}
for teamID, credits := range creditsPerTeamID {
rounded[teamID] = int64(math.Ceil(credits))
rounded[teamID] = stripe.CreditSummary{
Credits: int64(math.Ceil(credits)),
ReportID: reportID,
}
}

return rounded, nil
Expand Down
35 changes: 24 additions & 11 deletions components/usage/pkg/apiv1/billing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,19 @@ import (
func TestCreditSummaryForTeams(t *testing.T) {
teamID_A, teamID_B := uuid.New().String(), uuid.New().String()
teamAttributionID_A, teamAttributionID_B := db.NewTeamAttributionID(teamID_A), db.NewTeamAttributionID(teamID_B)
reportID := "report_id_1"

scenarios := []struct {
Name string
Sessions db.UsageReport
BillSessionsAfter time.Time
Expected map[string]int64
Expected map[string]stripe.CreditSummary
}{
{
Name: "no instances in report, no summary",
BillSessionsAfter: time.Time{},
Sessions: nil,
Expected: map[string]int64{},
Expected: map[string]stripe.CreditSummary{},
},
{
Name: "skips user attributions",
Expand All @@ -39,7 +40,7 @@ func TestCreditSummaryForTeams(t *testing.T) {
AttributionID: db.NewUserAttributionID(uuid.New().String()),
},
},
Expected: map[string]int64{},
Expected: map[string]stripe.CreditSummary{},
},
{
Name: "two workspace instances",
Expand All @@ -56,9 +57,12 @@ func TestCreditSummaryForTeams(t *testing.T) {
CreditsUsed: 10,
},
},
Expected: map[string]int64{
Expected: map[string]stripe.CreditSummary{
// total of 2 days runtime, at 10 credits per hour, that's 480 credits
teamID_A: 480,
teamID_A: {
Credits: 480,
ReportID: reportID,
},
},
},
{
Expand All @@ -76,10 +80,16 @@ func TestCreditSummaryForTeams(t *testing.T) {
CreditsUsed: (24) * 10,
},
},
Expected: map[string]int64{
Expected: map[string]stripe.CreditSummary{
// total of 2 days runtime, at 10 credits per hour, that's 480 credits
teamID_A: 120,
teamID_B: 240,
teamID_A: {
Credits: 120,
ReportID: reportID,
},
teamID_B: {
Credits: 240,
ReportID: reportID,
},
},
},
{
Expand All @@ -99,16 +109,19 @@ func TestCreditSummaryForTeams(t *testing.T) {
StartedAt: time.Now().AddDate(0, 0, -3),
},
},
Expected: map[string]int64{
teamID_A: 120,
Expected: map[string]stripe.CreditSummary{
teamID_A: {
Credits: 120,
ReportID: reportID,
},
},
},
}

for _, s := range scenarios {
t.Run(s.Name, func(t *testing.T) {
svc := NewBillingService(&stripe.Client{}, s.BillSessionsAfter, &gorm.DB{})
actual, err := svc.creditSummaryForTeams(s.Sessions)
actual, err := svc.creditSummaryForTeams(s.Sessions, reportID)
require.NoError(t, err)
require.Equal(t, s.Expected, actual)
})
Expand Down
43 changes: 39 additions & 4 deletions components/usage/pkg/stripe/stripe.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ import (
"github.com/stripe/stripe-go/v72/client"
)

const (
reportIDMetadataKey = "reportId"
)

type Client struct {
sc *client.API
}
Expand Down Expand Up @@ -58,9 +62,14 @@ type Invoice struct {
Credits int64
}

type CreditSummary struct {
Credits int64
ReportID string
}

// UpdateUsage updates teams' Stripe subscriptions with usage data
// `usageForTeam` is a map from team name to total workspace seconds used within a billing period.
func (c *Client) UpdateUsage(ctx context.Context, creditsPerTeam map[string]int64) error {
func (c *Client) UpdateUsage(ctx context.Context, creditsPerTeam map[string]CreditSummary) error {
teamIds := make([]string, 0, len(creditsPerTeam))
for k := range creditsPerTeam {
teamIds = append(teamIds, k)
Expand Down Expand Up @@ -117,7 +126,7 @@ func (c *Client) findCustomers(ctx context.Context, query string) ([]*stripe.Cus
return customers, nil
}

func (c *Client) updateUsageForCustomer(ctx context.Context, customer *stripe.Customer, credits int64) (*UsageRecord, error) {
func (c *Client) updateUsageForCustomer(ctx context.Context, customer *stripe.Customer, summary CreditSummary) (*UsageRecord, error) {
subscriptions := customer.Subscriptions.Data
if len(subscriptions) != 1 {
return nil, fmt.Errorf("customer has an unexpected number of subscriptions %v (expected 1, got %d)", subscriptions, len(subscriptions))
Expand All @@ -136,15 +145,27 @@ func (c *Client) updateUsageForCustomer(ctx context.Context, customer *stripe.Cu
Context: ctx,
},
SubscriptionItem: stripe.String(subscriptionItemId),
Quantity: stripe.Int64(credits),
Quantity: stripe.Int64(summary.Credits),
})
if err != nil {
return nil, fmt.Errorf("failed to register usage for customer %q on subscription item %s", customer.Name, subscriptionItemId)
}

invoice, err := c.GetUpcomingInvoice(ctx, customer.ID)
if err != nil {
return nil, fmt.Errorf("failed to find upcoming invoice for customer %s: %w", customer.ID, err)
}

_, err = c.UpdateInvoiceMetadata(ctx, invoice.ID, map[string]string{
reportIDMetadataKey: summary.ReportID,
})
if err != nil {
return nil, fmt.Errorf("failed to udpate invoice %s metadata with report ID: %w", invoice.ID, err)
}

return &UsageRecord{
SubscriptionItemID: subscriptionItemId,
Quantity: credits,
Quantity: summary.Credits,
}, nil
}

Expand Down Expand Up @@ -205,6 +226,20 @@ func (c *Client) GetUpcomingInvoice(ctx context.Context, customerID string) (*In
}, nil
}

func (c *Client) UpdateInvoiceMetadata(ctx context.Context, invoiceID string, metadata map[string]string) (*stripe.Invoice, error) {
invoice, err := c.sc.Invoices.Update(invoiceID, &stripe.InvoiceParams{
Params: stripe.Params{
Context: ctx,
Metadata: metadata,
},
})
if err != nil {
return nil, fmt.Errorf("failed to update invoice %s metadata: %w", invoiceID, err)
}

return invoice, nil
}

// queriesForCustomersWithTeamIds constructs Stripe query strings to find the Stripe Customer for each teamId
// It returns multiple queries, each being a big disjunction of subclauses so that we can process multiple teamIds in one query.
// `clausesPerQuery` is a limit enforced by the Stripe API.
Expand Down

0 comments on commit 33c613c

Please sign in to comment.