diff --git a/cmd/client/grpc_client.go b/cmd/client/grpc_client.go index b7d2f3b18..9c8dcb379 100644 --- a/cmd/client/grpc_client.go +++ b/cmd/client/grpc_client.go @@ -151,5 +151,5 @@ func RegisterRemoteURLFlags(flags *pflag.FlagSet) { flags.String(FlagAuthority, "", "Set the authority header for the remote gRPC server.") flags.Bool(FlagInsecureNoTransportSecurity, false, "Disables transport security. Do not use this in production.") flags.Bool(FlagInsecureSkipHostVerification, false, "Disables hostname verification. Do not use this in production.") - flags.Bool(FlagBlock, false, "Block until all migrations have been applied") + flags.Bool(FlagBlock, false, "Block until the connection is up.") } diff --git a/cmd/migrate/migrate_test.go b/cmd/migrate/migrate_test.go index c5ba46a9c..189ead920 100644 --- a/cmd/migrate/migrate_test.go +++ b/cmd/migrate/migrate_test.go @@ -8,6 +8,7 @@ import ( "context" "strings" "testing" + "time" "github.com/ory/x/dbal" @@ -98,6 +99,24 @@ func TestMigrate(t *testing.T) { cmd := newCmd(ctx, "-c", cf) + t.Run("case=shows status", func(t *testing.T) { + stdOut := cmd.ExecNoErr(t, "status") + assert.Contains(t, stdOut, "Pending") + assert.NotContains(t, stdOut, "Applied") + }) + + t.Run("case=status blocks until all are applied", func(t *testing.T) { + cmd := *cmd + ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) + defer cancel() + cmd.Ctx = ctx + + stdOut, stdErr, err := cmd.Exec(nil, "status", "--block") + require.ErrorIs(t, err, cmdx.ErrNoPrintButFail) + assert.Contains(t, stdOut, "Waiting for migrations to finish...") + assert.Contains(t, stdErr, "Context was canceled, exiting...", stdOut) + }) + t.Run("case=aborts on no", func(t *testing.T) { stdOut, stdErr, err := cmd.Exec(bytes.NewBufferString("n\n"), "up") require.NoError(t, err, "%s %s", stdOut, stdErr) diff --git a/cmd/migrate/status.go b/cmd/migrate/status.go index 1b7b3aeaa..cbeb8b339 100644 --- a/cmd/migrate/status.go +++ b/cmd/migrate/status.go @@ -12,12 +12,12 @@ import ( "github.com/ory/x/cmdx" "github.com/spf13/cobra" - "github.com/ory/keto/cmd/client" "github.com/ory/keto/internal/driver" "github.com/ory/keto/ketoctx" ) func newStatusCmd(opts []ketoctx.Option) *cobra.Command { + block := false cmd := &cobra.Command{ Use: "status", Short: "Get the current migration status", @@ -26,11 +26,6 @@ func newStatusCmd(opts []ketoctx.Option) *cobra.Command { RunE: func(cmd *cobra.Command, _ []string) error { ctx := cmd.Context() - block, err := cmd.Flags().GetBool(client.FlagBlock) - if err != nil { - return err - } - reg, err := driver.NewDefaultRegistry(ctx, cmd.Flags(), true, opts) if err != nil { return err @@ -54,7 +49,12 @@ func newStatusCmd(opts []ketoctx.Option) *cobra.Command { _, _ = fmt.Fprintf(cmd.OutOrStdout(), " - %s\n", m.Name) } } - time.Sleep(time.Second) + select { + case <-ctx.Done(): + _, _ = fmt.Fprintln(cmd.ErrOrStderr(), "Context was canceled, exiting...") + return cmdx.FailSilently(cmd) + case <-time.After(time.Second): + } s, err = mb.Status(ctx) if err != nil { _, _ = fmt.Fprintf(cmd.OutOrStdout(), "Could not get migration status: %+v\n", err) @@ -68,6 +68,7 @@ func newStatusCmd(opts []ketoctx.Option) *cobra.Command { } cmdx.RegisterFormatFlags(cmd.Flags()) + cmd.Flags().BoolVar(&block, "block", false, "Block until all migrations have been applied") return cmd }