diff --git a/command/agent.go b/command/agent.go index 4d7bc201a60b..111345b4d865 100644 --- a/command/agent.go +++ b/command/agent.go @@ -61,8 +61,9 @@ type AgentCommand struct { startedCh chan (struct{}) // for tests - flagConfigs []string - flagLogLevel string + flagConfigs []string + flagLogLevel string + flagExitAfterAuth bool flagTestVerifyOnly bool flagCombineLogs bool @@ -115,6 +116,15 @@ func (c *AgentCommand) Flags() *FlagSets { "\"trace\", \"debug\", \"info\", \"warn\", and \"err\".", }) + f.BoolVar(&BoolVar{ + Name: "exit-after-auth", + Target: &c.flagExitAfterAuth, + Default: false, + Usage: "If set to true, the agent will exit with code 0 after a single " + + "successful auth, where success means that a token was retrieved and " + + "all sinks successfully wrote it", + }) + // Internal-only flags to follow. // // Why hello there little source code reader! Welcome to the Vault source @@ -223,6 +233,13 @@ func (c *AgentCommand) Run(args []string) int { config.Vault = new(agentConfig.Vault) } + exitAfterAuth := config.ExitAfterAuth + f.Visit(func(fl *flag.Flag) { + if fl.Name == "exit-after-auth" { + exitAfterAuth = c.flagExitAfterAuth + } + }) + c.setStringFlag(f, config.Vault.Address, &StringVar{ Name: flagNameAddress, Target: &c.flagAddress, @@ -524,7 +541,7 @@ func (c *AgentCommand) Run(args []string) int { ss := sink.NewSinkServer(&sink.SinkServerConfig{ Logger: c.logger.Named("sink.server"), Client: client, - ExitAfterAuth: config.ExitAfterAuth, + ExitAfterAuth: exitAfterAuth, }) ssDoneCh = ss.DoneCh @@ -534,7 +551,7 @@ func (c *AgentCommand) Run(args []string) int { LogWriter: c.logWriter, VaultConf: config.Vault, Namespace: namespace, - ExitAfterAuth: config.ExitAfterAuth, + ExitAfterAuth: exitAfterAuth, }) tsDoneCh = ts.DoneCh diff --git a/command/agent_test.go b/command/agent_test.go index d8dbd5cf26e9..8b8a3de27874 100644 --- a/command/agent_test.go +++ b/command/agent_test.go @@ -225,6 +225,16 @@ cache { */ func TestAgent_ExitAfterAuth(t *testing.T) { + t.Run("via_config", func(t *testing.T) { + testAgentExitAfterAuth(t, false) + }) + + t.Run("via_flag", func(t *testing.T) { + testAgentExitAfterAuth(t, true) + }) +} + +func testAgentExitAfterAuth(t *testing.T, viaFlag bool) { logger := logging.NewVaultLogger(hclog.Trace) coreConfig := &vault.CoreConfig{ Logger: logger, @@ -313,8 +323,13 @@ func TestAgent_ExitAfterAuth(t *testing.T) { logger.Trace("wrote test jwt", "path", in) } + exitAfterAuthTemplText := "exit_after_auth = true" + if viaFlag { + exitAfterAuthTemplText = "" + } + config := ` -exit_after_auth = true +%s auto_auth { method { @@ -340,23 +355,37 @@ auto_auth { } ` - config = fmt.Sprintf(config, in, sink1, sink2) + config = fmt.Sprintf(config, exitAfterAuthTemplText, in, sink1, sink2) if err := ioutil.WriteFile(conf, []byte(config), 0600); err != nil { t.Fatal(err) } else { logger.Trace("wrote test config", "path", conf) } - // If this hangs forever until the test times out, exit-after-auth isn't - // working - ui, cmd := testAgentCommand(t, logger) - cmd.client = client + doneCh := make(chan struct{}) + go func() { + ui, cmd := testAgentCommand(t, logger) + cmd.client = client - code := cmd.Run([]string{"-config", conf}) - if code != 0 { - t.Errorf("expected %d to be %d", code, 0) - t.Logf("output from agent:\n%s", ui.OutputWriter.String()) - t.Logf("error from agent:\n%s", ui.ErrorWriter.String()) + args := []string{"-config", conf} + if viaFlag { + args = append(args, "-exit-after-auth") + } + + code := cmd.Run(args) + if code != 0 { + t.Errorf("expected %d to be %d", code, 0) + t.Logf("output from agent:\n%s", ui.OutputWriter.String()) + t.Logf("error from agent:\n%s", ui.ErrorWriter.String()) + } + close(doneCh) + }() + + select { + case <-doneCh: + break + case <-time.After(1 * time.Minute): + t.Fatal("timeout reached while waiting for agent to exit") } sink1Bytes, err := ioutil.ReadFile(sink1)