diff --git a/modules/kafka/kafka.go b/modules/kafka/kafka.go index b2b0e831b5..c0c02890d4 100644 --- a/modules/kafka/kafka.go +++ b/modules/kafka/kafka.go @@ -71,31 +71,16 @@ func Run(ctx context.Context, img string, opts ...testcontainers.ContainerCustom LifecycleHooks: []testcontainers.ContainerLifecycleHooks{ { PostStarts: []testcontainers.ContainerHook{ - // 1. copy the starter script into the container + // Use a single hook to copy the starter script and wait for + // the Kafka server to be ready. This prevents the wait running + // if the starter script fails to copy. func(ctx context.Context, c testcontainers.Container) error { - host, err := c.Host(ctx) - if err != nil { - return err + // 1. copy the starter script into the container + if err := copyStarterScript(ctx, c); err != nil { + return fmt.Errorf("copy starter script: %w", err) } - inspect, err := c.Inspect(ctx) - if err != nil { - return err - } - - hostname := inspect.Config.Hostname - - port, err := c.MappedPort(ctx, publicPort) - if err != nil { - return err - } - - scriptContent := fmt.Sprintf(starterScriptContent, host, port.Int(), hostname) - - return c.CopyToContainer(ctx, []byte(scriptContent), starterScript, 0o755) - }, - // 2. wait for the Kafka server to be ready - func(ctx context.Context, c testcontainers.Container) error { + // 2. wait for the Kafka server to be ready return wait.ForLog(".*Transitioning from RECOVERY to RUNNING.*").AsRegexp().WaitUntilReady(ctx, c) }, }, @@ -131,6 +116,40 @@ func Run(ctx context.Context, img string, opts ...testcontainers.ContainerCustom return &KafkaContainer{Container: container, ClusterID: clusterID}, nil } +// copyStarterScript copies the starter script into the container. +func copyStarterScript(ctx context.Context, c testcontainers.Container) error { + if err := wait.ForListeningPort(publicPort). + SkipInternalCheck(). + WaitUntilReady(ctx, c); err != nil { + return fmt.Errorf("wait for exposed port: %w", err) + } + + host, err := c.Host(ctx) + if err != nil { + return fmt.Errorf("host: %w", err) + } + + inspect, err := c.Inspect(ctx) + if err != nil { + return fmt.Errorf("inspect: %w", err) + } + + hostname := inspect.Config.Hostname + + port, err := c.MappedPort(ctx, publicPort) + if err != nil { + return fmt.Errorf("mapped port: %w", err) + } + + scriptContent := fmt.Sprintf(starterScriptContent, host, port.Int(), hostname) + + if err := c.CopyToContainer(ctx, []byte(scriptContent), starterScript, 0o755); err != nil { + return fmt.Errorf("copy to container: %w", err) + } + + return nil +} + func WithClusterID(clusterID string) testcontainers.CustomizeRequestOption { return func(req *testcontainers.GenericContainerRequest) error { req.Env["CLUSTER_ID"] = clusterID