diff --git a/components/usage/pkg/apiv1/billing.go b/components/usage/pkg/apiv1/billing.go index b5a7e85dc8f3d6..af0be259c34a42 100644 --- a/components/usage/pkg/apiv1/billing.go +++ b/components/usage/pkg/apiv1/billing.go @@ -37,6 +37,47 @@ type BillingService struct { v1.UnimplementedBillingServiceServer } +func (s *BillingService) GetStripeCustomer(ctx context.Context, req *v1.GetStripeCustomerRequest) (*v1.GetStripeCustomerResponse, error) { + switch identifier := req.GetIdentifier().(type) { + case *v1.GetStripeCustomerRequest_AttributionId: + attributionID, err := db.ParseAttributionID(identifier.AttributionId) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "Invalid attribution ID %s", attributionID) + } + + customer, err := s.stripeClient.GetCustomerByAttributionID(ctx, string(attributionID)) + if err != nil { + return nil, err + } + + return &v1.GetStripeCustomerResponse{ + AttributionId: string(attributionID), + Customer: &v1.StripeCustomer{ + Id: customer.ID, + }, + }, nil + case *v1.GetStripeCustomerRequest_StripeCustomerId: + customer, err := s.stripeClient.GetCustomer(ctx, identifier.StripeCustomerId) + if err != nil { + return nil, err + } + + attributionID, err := stripe.GetAttributionID(ctx, customer) + if err != nil { + return nil, status.Errorf(codes.Internal, "Failed to parse attribution ID from Stripe customer %s", customer.ID) + } + + return &v1.GetStripeCustomerResponse{ + AttributionId: string(attributionID), + Customer: &v1.StripeCustomer{ + Id: customer.ID, + }, + }, nil + default: + return nil, status.Errorf(codes.InvalidArgument, "Unknown identifier") + } +} + func (s *BillingService) ReconcileInvoices(ctx context.Context, in *v1.ReconcileInvoicesRequest) (*v1.ReconcileInvoicesResponse, error) { balances, err := db.ListBalance(ctx, s.conn) if err != nil { diff --git a/components/usage/pkg/stripe/stripe.go b/components/usage/pkg/stripe/stripe.go index 2743024276c2bf..bdb9e36cd8e60c 100644 --- a/components/usage/pkg/stripe/stripe.go +++ b/components/usage/pkg/stripe/stripe.go @@ -187,36 +187,40 @@ func (c *Client) updateUsageForCustomer(ctx context.Context, customer *stripe.Cu return nil } -func (c *Client) GetCustomerByTeamID(ctx context.Context, teamID string) (*stripe.Customer, error) { - customers, err := c.findCustomers(ctx, fmt.Sprintf("metadata['teamId']:'%s'", teamID)) +func (c *Client) GetCustomerByAttributionID(ctx context.Context, attributionID string) (*stripe.Customer, error) { + customers, err := c.findCustomers(ctx, fmt.Sprintf("metadata['attributionId']:'%s'", attributionID)) if err != nil { - return nil, fmt.Errorf("failed to find customers: %w", err) + return nil, status.Errorf(codes.Internal, "failed to find customers: %v", err) } if len(customers) == 0 { - return nil, fmt.Errorf("no team customer found for id: %s", teamID) + return nil, status.Errorf(codes.NotFound, "no team customer found for attribution_id: %s", attributionID) } if len(customers) > 1 { - return nil, fmt.Errorf("found multiple team customers for id: %s", teamID) + return nil, status.Errorf(codes.FailedPrecondition, "found multiple customers for attributiuon_id: %s", attributionID) } return customers[0], nil } -func (c *Client) GetCustomerByUserID(ctx context.Context, userID string) (*stripe.Customer, error) { - customers, err := c.findCustomers(ctx, fmt.Sprintf("metadata['userId']:'%s'", userID)) +func (c *Client) GetCustomer(ctx context.Context, customerID string) (*stripe.Customer, error) { + customer, err := c.sc.Customers.Get(customerID, &stripe.CustomerParams{ + Params: stripe.Params{ + Context: ctx, + }, + }) if err != nil { - return nil, fmt.Errorf("failed to find customers: %w", err) - } + if stripeErr, ok := err.(*stripe.Error); ok { + switch stripeErr.Code { + case stripe.ErrorCodeMissing: + return nil, status.Errorf(codes.NotFound, "customer %s does not exist in stripe", customerID) + } + } - if len(customers) == 0 { - return nil, fmt.Errorf("no user customer found for id: %s", userID) - } - if len(customers) > 1 { - return nil, fmt.Errorf("found multiple user customers for id: %s", userID) + return nil, fmt.Errorf("failed to get customer by customer ID %s", customerID) } - return customers[0], nil + return customer, nil } func (c *Client) GetInvoiceWithCustomer(ctx context.Context, invoiceID string) (*stripe.Invoice, error) {