Skip to content

Commit

Permalink
context propagation: pkg/database/machines (#3248)
Browse files Browse the repository at this point in the history
  • Loading branch information
mmetc authored Sep 20, 2024
1 parent e2196bd commit fee3deb
Show file tree
Hide file tree
Showing 11 changed files with 109 additions and 94 deletions.
56 changes: 30 additions & 26 deletions cmd/crowdsec-cli/climachine/machines.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package climachine

import (
"context"
"encoding/csv"
"encoding/json"
"errors"
Expand Down Expand Up @@ -210,11 +211,11 @@ func (cli *cliMachines) listCSV(out io.Writer, machines ent.Machines) error {
return nil
}

func (cli *cliMachines) List(out io.Writer, db *database.Client) error {
func (cli *cliMachines) List(ctx context.Context, out io.Writer, db *database.Client) error {
// XXX: must use the provided db object, the one in the struct might be nil
// (calling List directly skips the PersistentPreRunE)

machines, err := db.ListMachines()
machines, err := db.ListMachines(ctx)
if err != nil {
return fmt.Errorf("unable to list machines: %w", err)
}
Expand Down Expand Up @@ -251,8 +252,8 @@ func (cli *cliMachines) newListCmd() *cobra.Command {
Example: `cscli machines list`,
Args: cobra.NoArgs,
DisableAutoGenTag: true,
RunE: func(_ *cobra.Command, _ []string) error {
return cli.List(color.Output, cli.db)
RunE: func(cmd *cobra.Command, _ []string) error {
return cli.List(cmd.Context(), color.Output, cli.db)
},
}

Expand All @@ -278,8 +279,8 @@ func (cli *cliMachines) newAddCmd() *cobra.Command {
cscli machines add MyTestMachine --auto
cscli machines add MyTestMachine --password MyPassword
cscli machines add -f- --auto > /tmp/mycreds.yaml`,
RunE: func(_ *cobra.Command, args []string) error {
return cli.add(args, string(password), dumpFile, apiURL, interactive, autoAdd, force)
RunE: func(cmd *cobra.Command, args []string) error {
return cli.add(cmd.Context(), args, string(password), dumpFile, apiURL, interactive, autoAdd, force)
},
}

Expand All @@ -294,7 +295,7 @@ cscli machines add -f- --auto > /tmp/mycreds.yaml`,
return cmd
}

func (cli *cliMachines) add(args []string, machinePassword string, dumpFile string, apiURL string, interactive bool, autoAdd bool, force bool) error {
func (cli *cliMachines) add(ctx context.Context, args []string, machinePassword string, dumpFile string, apiURL string, interactive bool, autoAdd bool, force bool) error {
var (
err error
machineID string
Expand Down Expand Up @@ -353,7 +354,7 @@ func (cli *cliMachines) add(args []string, machinePassword string, dumpFile stri

password := strfmt.Password(machinePassword)

_, err = cli.db.CreateMachine(&machineID, &password, "", true, force, types.PasswordAuthType)
_, err = cli.db.CreateMachine(ctx, &machineID, &password, "", true, force, types.PasswordAuthType)
if err != nil {
return fmt.Errorf("unable to create machine: %w", err)
}
Expand Down Expand Up @@ -399,6 +400,7 @@ func (cli *cliMachines) validMachineID(cmd *cobra.Command, args []string, toComp
var err error

cfg := cli.cfg()
ctx := cmd.Context()

// need to load config and db because PersistentPreRunE is not called for completions

Expand All @@ -407,13 +409,13 @@ func (cli *cliMachines) validMachineID(cmd *cobra.Command, args []string, toComp
return nil, cobra.ShellCompDirectiveNoFileComp
}

cli.db, err = require.DBClient(cmd.Context(), cfg.DbConfig)
cli.db, err = require.DBClient(ctx, cfg.DbConfig)
if err != nil {
cobra.CompError("unable to list machines " + err.Error())
return nil, cobra.ShellCompDirectiveNoFileComp
}

machines, err := cli.db.ListMachines()
machines, err := cli.db.ListMachines(ctx)
if err != nil {
cobra.CompError("unable to list machines " + err.Error())
return nil, cobra.ShellCompDirectiveNoFileComp
Expand All @@ -430,9 +432,9 @@ func (cli *cliMachines) validMachineID(cmd *cobra.Command, args []string, toComp
return ret, cobra.ShellCompDirectiveNoFileComp
}

func (cli *cliMachines) delete(machines []string, ignoreMissing bool) error {
func (cli *cliMachines) delete(ctx context.Context, machines []string, ignoreMissing bool) error {
for _, machineID := range machines {
if err := cli.db.DeleteWatcher(machineID); err != nil {
if err := cli.db.DeleteWatcher(ctx, machineID); err != nil {
var notFoundErr *database.MachineNotFoundError
if ignoreMissing && errors.As(err, &notFoundErr) {
return nil
Expand Down Expand Up @@ -460,8 +462,8 @@ func (cli *cliMachines) newDeleteCmd() *cobra.Command {
Aliases: []string{"remove"},
DisableAutoGenTag: true,
ValidArgsFunction: cli.validMachineID,
RunE: func(_ *cobra.Command, args []string) error {
return cli.delete(args, ignoreMissing)
RunE: func(cmd *cobra.Command, args []string) error {
return cli.delete(cmd.Context(), args, ignoreMissing)
},
}

Expand All @@ -471,7 +473,7 @@ func (cli *cliMachines) newDeleteCmd() *cobra.Command {
return cmd
}

func (cli *cliMachines) prune(duration time.Duration, notValidOnly bool, force bool) error {
func (cli *cliMachines) prune(ctx context.Context, duration time.Duration, notValidOnly bool, force bool) error {
if duration < 2*time.Minute && !notValidOnly {
if yes, err := ask.YesNo(
"The duration you provided is less than 2 minutes. "+
Expand All @@ -484,12 +486,12 @@ func (cli *cliMachines) prune(duration time.Duration, notValidOnly bool, force b
}

machines := []*ent.Machine{}
if pending, err := cli.db.QueryPendingMachine(); err == nil {
if pending, err := cli.db.QueryPendingMachine(ctx); err == nil {
machines = append(machines, pending...)
}

if !notValidOnly {
if pending, err := cli.db.QueryMachinesInactiveSince(time.Now().UTC().Add(-duration)); err == nil {
if pending, err := cli.db.QueryMachinesInactiveSince(ctx, time.Now().UTC().Add(-duration)); err == nil {
machines = append(machines, pending...)
}
}
Expand All @@ -512,7 +514,7 @@ func (cli *cliMachines) prune(duration time.Duration, notValidOnly bool, force b
}
}

deleted, err := cli.db.BulkDeleteWatchers(machines)
deleted, err := cli.db.BulkDeleteWatchers(ctx, machines)
if err != nil {
return fmt.Errorf("unable to prune machines: %w", err)
}
Expand Down Expand Up @@ -540,8 +542,8 @@ cscli machines prune --duration 1h
cscli machines prune --not-validated-only --force`,
Args: cobra.NoArgs,
DisableAutoGenTag: true,
RunE: func(_ *cobra.Command, _ []string) error {
return cli.prune(duration, notValidOnly, force)
RunE: func(cmd *cobra.Command, _ []string) error {
return cli.prune(cmd.Context(), duration, notValidOnly, force)
},
}

Expand All @@ -553,8 +555,8 @@ cscli machines prune --not-validated-only --force`,
return cmd
}

func (cli *cliMachines) validate(machineID string) error {
if err := cli.db.ValidateMachine(machineID); err != nil {
func (cli *cliMachines) validate(ctx context.Context, machineID string) error {
if err := cli.db.ValidateMachine(ctx, machineID); err != nil {
return fmt.Errorf("unable to validate machine '%s': %w", machineID, err)
}

Expand All @@ -571,8 +573,8 @@ func (cli *cliMachines) newValidateCmd() *cobra.Command {
Example: `cscli machines validate "machine_name"`,
Args: cobra.ExactArgs(1),
DisableAutoGenTag: true,
RunE: func(_ *cobra.Command, args []string) error {
return cli.validate(args[0])
RunE: func(cmd *cobra.Command, args []string) error {
return cli.validate(cmd.Context(), args[0])
},
}

Expand Down Expand Up @@ -690,9 +692,11 @@ func (cli *cliMachines) newInspectCmd() *cobra.Command {
Args: cobra.ExactArgs(1),
DisableAutoGenTag: true,
ValidArgsFunction: cli.validMachineID,
RunE: func(_ *cobra.Command, args []string) error {
RunE: func(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()
machineID := args[0]
machine, err := cli.db.QueryMachineByID(machineID)

machine, err := cli.db.QueryMachineByID(ctx, machineID)
if err != nil {
return fmt.Errorf("unable to read machine data '%s': %w", machineID, err)
}
Expand Down
6 changes: 3 additions & 3 deletions cmd/crowdsec-cli/clisupport/support.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ func (cli *cliSupport) dumpBouncers(ctx context.Context, zw *zip.Writer, db *dat
return nil
}

func (cli *cliSupport) dumpAgents(zw *zip.Writer, db *database.Client) error {
func (cli *cliSupport) dumpAgents(ctx context.Context, zw *zip.Writer, db *database.Client) error {
log.Info("Collecting agents")

if db == nil {
Expand All @@ -220,7 +220,7 @@ func (cli *cliSupport) dumpAgents(zw *zip.Writer, db *database.Client) error {
out := new(bytes.Buffer)
cm := climachine.New(cli.cfg)

if err := cm.List(out, db); err != nil {
if err := cm.List(ctx, out, db); err != nil {
return err
}

Expand Down Expand Up @@ -529,7 +529,7 @@ func (cli *cliSupport) dump(ctx context.Context, outFile string) error {
log.Warnf("could not collect bouncers information: %s", err)
}

if err = cli.dumpAgents(zipWriter, db); err != nil {
if err = cli.dumpAgents(ctx, zipWriter, db); err != nil {
log.Warnf("could not collect agents information: %s", err)
}

Expand Down
4 changes: 3 additions & 1 deletion pkg/apiserver/apic.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ func randomDuration(d time.Duration, delta time.Duration) time.Duration {
func (a *apic) FetchScenariosListFromDB() ([]string, error) {
scenarios := make([]string, 0)

machines, err := a.dbClient.ListMachines()
ctx := context.TODO()

machines, err := a.dbClient.ListMachines(ctx)
if err != nil {
return nil, fmt.Errorf("while listing machines: %w", err)
}
Expand Down
10 changes: 5 additions & 5 deletions pkg/apiserver/apic_metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func (a *apic) GetUsageMetrics(ctx context.Context) (*models.AllMetrics, []int,
allMetrics := &models.AllMetrics{}
metricsIds := make([]int, 0)

lps, err := a.dbClient.ListMachines()
lps, err := a.dbClient.ListMachines(ctx)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -186,7 +186,7 @@ func (a *apic) MarkUsageMetricsAsSent(ctx context.Context, ids []int) error {
}

func (a *apic) GetMetrics(ctx context.Context) (*models.Metrics, error) {
machines, err := a.dbClient.ListMachines()
machines, err := a.dbClient.ListMachines(ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -230,8 +230,8 @@ func (a *apic) GetMetrics(ctx context.Context) (*models.Metrics, error) {
}, nil
}

func (a *apic) fetchMachineIDs() ([]string, error) {
machines, err := a.dbClient.ListMachines()
func (a *apic) fetchMachineIDs(ctx context.Context) ([]string, error) {
machines, err := a.dbClient.ListMachines(ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -277,7 +277,7 @@ func (a *apic) SendMetrics(stop chan (bool)) {
machineIDs := []string{}

reloadMachineIDs := func() {
ids, err := a.fetchMachineIDs()
ids, err := a.fetchMachineIDs(ctx)
if err != nil {
log.Debugf("unable to get machines (%s), will retry", err)

Expand Down
10 changes: 5 additions & 5 deletions pkg/apiserver/apiserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,12 @@ func NewAPITestForwardedFor(t *testing.T) (*gin.Engine, csconfig.Config) {
}

func ValidateMachine(t *testing.T, machineID string, config *csconfig.DatabaseCfg) {
ctx := context.Background()
ctx := context.TODO()

dbClient, err := database.NewClient(ctx, config)
require.NoError(t, err)

err = dbClient.ValidateMachine(machineID)
err = dbClient.ValidateMachine(ctx, machineID)
require.NoError(t, err)
}

Expand All @@ -197,7 +197,7 @@ func GetMachineIP(t *testing.T, machineID string, config *csconfig.DatabaseCfg)
dbClient, err := database.NewClient(ctx, config)
require.NoError(t, err)

machines, err := dbClient.ListMachines()
machines, err := dbClient.ListMachines(ctx)
require.NoError(t, err)

for _, machine := range machines {
Expand Down Expand Up @@ -332,7 +332,7 @@ func TestUnknownPath(t *testing.T) {
req.Header.Set("User-Agent", UserAgent)
router.ServeHTTP(w, req)

assert.Equal(t, 404, w.Code)
assert.Equal(t, http.StatusNotFound, w.Code)
}

/*
Expand Down Expand Up @@ -390,7 +390,7 @@ func TestLoggingDebugToFileConfig(t *testing.T) {
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/test42", nil)
req.Header.Set("User-Agent", UserAgent)
api.router.ServeHTTP(w, req)
assert.Equal(t, 404, w.Code)
assert.Equal(t, http.StatusNotFound, w.Code)
// wait for the request to happen
time.Sleep(500 * time.Millisecond)

Expand Down
4 changes: 3 additions & 1 deletion pkg/apiserver/controllers/v1/heartbeat.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ import (
func (c *Controller) HeartBeat(gctx *gin.Context) {
machineID, _ := getMachineIDFromContext(gctx)

if err := c.DBClient.UpdateMachineLastHeartBeat(machineID); err != nil {
ctx := gctx.Request.Context()

if err := c.DBClient.UpdateMachineLastHeartBeat(ctx, machineID); err != nil {
c.HandleDBErrors(gctx, err)
return
}
Expand Down
4 changes: 3 additions & 1 deletion pkg/apiserver/controllers/v1/machines.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ func (c *Controller) shouldAutoRegister(token string, gctx *gin.Context) (bool,
}

func (c *Controller) CreateMachine(gctx *gin.Context) {
ctx := gctx.Request.Context()

var input models.WatcherRegistrationRequest

if err := gctx.ShouldBindJSON(&input); err != nil {
Expand All @@ -66,7 +68,7 @@ func (c *Controller) CreateMachine(gctx *gin.Context) {
return
}

if _, err := c.DBClient.CreateMachine(input.MachineID, input.Password, gctx.ClientIP(), autoRegister, false, types.PasswordAuthType); err != nil {
if _, err := c.DBClient.CreateMachine(ctx, input.MachineID, input.Password, gctx.ClientIP(), autoRegister, false, types.PasswordAuthType); err != nil {
c.HandleDBErrors(gctx, err)
return
}
Expand Down
Loading

0 comments on commit fee3deb

Please sign in to comment.