diff --git a/.github/workflows/ci-test-go.yml b/.github/workflows/ci-test-go.yml index df544fdc61..54a3a3d152 100644 --- a/.github/workflows/ci-test-go.yml +++ b/.github/workflows/ci-test-go.yml @@ -65,7 +65,7 @@ jobs: uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4 - name: Set up Go - uses: actions/setup-go@0c52d547c9bc32b1aa3301fd7a9cb496313a4491 # v5 + uses: actions/setup-go@41dfa10bad2bb2ae585af6ee5bb4d7d973ad74ed # v5 with: go-version: '${{ inputs.go-version }}' cache-dependency-path: '${{ inputs.project-directory }}/go.sum' @@ -141,7 +141,7 @@ jobs: ./scripts/check_environment.sh - name: Test Summary - uses: test-summary/action@032c8a9cec6aaa3c20228112cae6ca10a3b29336 # v2.3 + uses: test-summary/action@31493c76ec9e7aa675f1585d3ed6f1da69269a86 # v2.4 with: paths: "**/${{ inputs.project-directory }}/TEST-unit*.xml" if: always() diff --git a/.github/workflows/ci-windows.yml b/.github/workflows/ci-windows.yml index fce4331c94..65a8cf573d 100644 --- a/.github/workflows/ci-windows.yml +++ b/.github/workflows/ci-windows.yml @@ -31,7 +31,7 @@ jobs: ref: ${{ github.event.client_payload.pull_request.head.ref }} - name: Set up Go - uses: actions/setup-go@0c52d547c9bc32b1aa3301fd7a9cb496313a4491 # v5 + uses: actions/setup-go@41dfa10bad2bb2ae585af6ee5bb4d7d973ad74ed # v5 with: go-version-file: go.mod id: go diff --git a/.github/workflows/docker-moby-latest.yml b/.github/workflows/docker-moby-latest.yml index bebb968652..957d2eb3ff 100644 --- a/.github/workflows/docker-moby-latest.yml +++ b/.github/workflows/docker-moby-latest.yml @@ -25,7 +25,7 @@ jobs: uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4 - name: Set up Go - uses: actions/setup-go@0c52d547c9bc32b1aa3301fd7a9cb496313a4491 # v5 + uses: actions/setup-go@41dfa10bad2bb2ae585af6ee5bb4d7d973ad74ed # v5 with: go-version-file: 'go.mod' cache-dependency-path: 'go.sum' diff --git a/.github/workflows/scorecards.yml b/.github/workflows/scorecards.yml index 58de84c72e..51b8a535c5 100644 --- a/.github/workflows/scorecards.yml +++ b/.github/workflows/scorecards.yml @@ -25,7 +25,7 @@ jobs: persist-credentials: false - name: "Run analysis" - uses: ossf/scorecard-action@0864cf19026789058feabb7e87baa5f140aac736 # v2.3.1 + uses: ossf/scorecard-action@62b2cac7ed8198b15735ed49ab1e5cf35480ba46 # v2.4.0 with: results_file: results.sarif results_format: sarif diff --git a/container.go b/container.go index 90a7ba91cb..5ee0aac881 100644 --- a/container.go +++ b/container.go @@ -532,7 +532,7 @@ func (c *ContainerRequest) validateMounts() error { if len(hostConfig.Binds) > 0 { for _, bind := range hostConfig.Binds { parts := strings.Split(bind, ":") - if len(parts) != 2 { + if len(parts) != 2 && len(parts) != 3 { return fmt.Errorf("%w: %s", ErrInvalidBindMount, bind) } targetPath := parts[1] diff --git a/container_test.go b/container_test.go index c3581a379e..742a97436e 100644 --- a/container_test.go +++ b/container_test.go @@ -73,11 +73,33 @@ func Test_ContainerValidation(t *testing.T) { }, { Name: "Invalid bind mount", - ExpectedError: "invalid bind mount: /data:/data:/data", + ExpectedError: "invalid bind mount: /data:/data:a:b", ContainerRequest: testcontainers.ContainerRequest{ Image: "redis:latest", HostConfigModifier: func(hc *container.HostConfig) { - hc.Binds = []string{"/data:/data:/data"} + hc.Binds = []string{"/data:/data:a:b"} + }, + }, + }, + { + Name: "bind-options/provided", + ContainerRequest: testcontainers.ContainerRequest{ + Image: "redis:latest", + HostConfigModifier: func(hc *container.HostConfig) { + hc.Binds = []string{ + "/a:/a:nocopy", + "/b:/b:ro", + "/c:/c:rw", + "/d:/d:z", + "/e:/e:Z", + "/f:/f:shared", + "/g:/g:rshared", + "/h:/h:slave", + "/i:/i:rslave", + "/j:/j:private", + "/k:/k:rprivate", + "/l:/l:ro,z,shared", + } }, }, }, diff --git a/docs/features/tls.md b/docs/features/tls.md index fd8b95266d..130f789b5f 100644 --- a/docs/features/tls.md +++ b/docs/features/tls.md @@ -12,6 +12,6 @@ The example will also create a client that will connect to the server using the demonstrating how to use the generated certificate to communicate with a service. -[Create a self-signed certificate](../../modules/cockroachdb/certs.go) inside_block:exampleSelfSignedCert -[Sign a self-signed certificate](../../modules/cockroachdb/certs.go) inside_block:exampleSignSelfSignedCert +[Create a self-signed certificate](../../modules/rabbitmq/examples_test.go) inside_block:exampleSelfSignedCert +[Sign a self-signed certificate](../../modules/rabbitmq/examples_test.go) inside_block:exampleSignSelfSignedCert diff --git a/docs/features/wait/introduction.md b/docs/features/wait/introduction.md index 87adabc3ed..feef9dc939 100644 --- a/docs/features/wait/introduction.md +++ b/docs/features/wait/introduction.md @@ -15,6 +15,7 @@ Below you can find a list of the available wait strategies that you can use: - [Log](./log.md) - [Multi](./multi.md) - [SQL](./sql.md) +- [TLS](./tls.md) ## Startup timeout and Poll interval @@ -25,3 +26,8 @@ If the default 60s timeout is not sufficient, it can be updated with the `WithSt Besides that, it's possible to define a poll interval, which will actually stop 100 milliseconds the test execution. If the default 100 milliseconds poll interval is not sufficient, it can be updated with the `WithPollInterval(pollInterval time.Duration)` function. + +## Modifying request strategies + +It's possible for options to modify `ContainerRequest.WaitingFor` using +[Walk](walk.md). diff --git a/docs/features/wait/tls.md b/docs/features/wait/tls.md new file mode 100644 index 0000000000..a98f78d84c --- /dev/null +++ b/docs/features/wait/tls.md @@ -0,0 +1,31 @@ +# TLS Strategy + +TLS Strategy waits for one or more files to exist in the container and uses them +and other details to construct a `tls.Config` which can be used to create secure +connections. + +It supports: + +- x509 PEM Certificate loaded from a certificate / key file pair. +- Root Certificate Authorities aka RootCAs loaded from PEM encoded files. +- Server name. +- Startup timeout to be used in seconds, default is 60 seconds. +- Poll interval to be used in milliseconds, default is 100 milliseconds. + +## Waiting for certificate pair + +The following snippets show how to configure a request to wait for certificate +pair to exist once started and then read the +[tls.Config](https://pkg.go.dev/crypto/tls#Config), alongside how to copy a test +certificate pair into a container image using a `Dockerfile`. + +It should be noted that copying certificate pairs into an images is only an +example which might be useful for testing with testcontainers-go and should not +be done with production images as that could expose your certificates if your +images become public. + + +[Wait for certificate](../../../wait/tls_test.go) inside_block:waitForTLSCert +[Read TLS Config](../../../wait/tls_test.go) inside_block:waitTLSConfig +[Dockerfile with certificate](../../../wait/testdata/http/Dockerfile) + diff --git a/docs/features/wait/walk.md b/docs/features/wait/walk.md new file mode 100644 index 0000000000..f8db724cc0 --- /dev/null +++ b/docs/features/wait/walk.md @@ -0,0 +1,19 @@ +# Walk + +Walk walks the strategies tree and calls the visit function for each node. + +This allows modules to easily amend default wait strategies, updating or +removing specific strategies based on requirements of functional options. + +For example removing a TLS strategy if a functional option enabled insecure mode +or changing the location of the certificate based on the configured user. + +If visit function returns `wait.VisitStop`, the walk stops. +If visit function returns `wait.VisitRemove`, the current node is removed. + +## Walk removing entries + +The following example shows how to remove a strategy based on its type. + +[Remove FileStrategy entries](../../../wait/walk_test.go) inside_block:walkRemoveFileStrategy + diff --git a/docs/modules/cockroachdb.md b/docs/modules/cockroachdb.md index 6bbdba0792..39956d5417 100644 --- a/docs/modules/cockroachdb.md +++ b/docs/modules/cockroachdb.md @@ -10,7 +10,7 @@ The Testcontainers module for CockroachDB. Please run the following command to add the CockroachDB module to your Go dependencies: -``` +```shell go get github.com/testcontainers/testcontainers-go/modules/cockroachdb ``` @@ -54,9 +54,11 @@ E.g. `Run(context.Background(), "cockroachdb/cockroach:latest-v23.1")`. Set the database that is created & dialled with `cockroachdb.WithDatabase`. -#### Password authentication +#### User and Password + +You can configured the container to create a user with a password by setting `cockroachdb.WithUser` and `cockroachdb.WithPassword`. -Disable insecure mode and connect with password authentication by setting `cockroachdb.WithUser` and `cockroachdb.WithPassword`. +`cockroachdb.WithPassword` is incompatible with `cockroachdb.WithInsecure`. #### Store size @@ -64,13 +66,21 @@ Control the maximum amount of memory used for storage, by default this is 100% b #### TLS authentication -`cockroachdb.WithTLS` lets you provide the CA certificate along with the certicate and key for the node & clients to connect with. -Internally CockroachDB requires a client certificate for the user to connect with. +`cockroachdb.WithInsecure` lets you disable the use of TLS on connections. + +`cockroachdb.WithInsecure` is incompatible with `cockroachdb.WithPassword`. + +#### Initialization Scripts + +`cockroachdb.WithInitScripts` adds the given scripts to those automatically run when the container starts. +These will be ignored if data exists in the `/cockroach/cockroach-data` directory within the container. -A helper `cockroachdb.NewTLSConfig` exists to generate all of this for you. +`cockroachdb.WithNoClusterDefaults` disables the default cluster settings script. -!!!warning - When TLS is enabled there's a very small, unlikely chance that the underlying driver can panic when registering the driver as part of waiting for CockroachDB to be ready to accept connections. If this is repeatedly happening please open an issue. +Without this option Cockroach containers run `data/cluster-defaults.sql` on startup +which configures the settings recommended by Cockroach Labs for +[local testing clusters](https://www.cockroachlabs.com/docs/stable/local-testing) +unless data exists in the `/cockroach/cockroach-data` directory within the container. ### Container Methods @@ -87,3 +97,10 @@ Same as `ConnectionString` but any error to generate the address will raise a pa #### TLSConfig Returns `*tls.Config` setup to allow you to dial your client over TLS, if enabled, else this will error with `cockroachdb.ErrTLSNotEnabled`. + +!!!info + The `TLSConfig()` function is deprecated and will be removed in the next major release of _Testcontainers for Go_. + +#### ConnectionConfig + +Returns `*pgx.ConnConfig` which can be passed to `pgx.ConnectConfig` to open a new connection. diff --git a/mkdocs.yml b/mkdocs.yml index 55083ad2f6..ae90f9a2b2 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -64,6 +64,8 @@ nav: - Log: features/wait/log.md - Multi: features/wait/multi.md - SQL: features/wait/sql.md + - TLS: features/wait/tls.md + - Walk: features/wait/walk.md - Modules: - modules/index.md - modules/artemis.md diff --git a/modules/cockroachdb/certs.go b/modules/cockroachdb/certs.go deleted file mode 100644 index afa12fcd1a..0000000000 --- a/modules/cockroachdb/certs.go +++ /dev/null @@ -1,67 +0,0 @@ -package cockroachdb - -import ( - "crypto/x509" - "errors" - "net" - "time" - - "github.com/mdelapenya/tlscert" -) - -type TLSConfig struct { - CACert *x509.Certificate - NodeCert []byte - NodeKey []byte - ClientCert []byte - ClientKey []byte -} - -// NewTLSConfig creates a new TLSConfig capable of running CockroachDB & connecting over TLS. -func NewTLSConfig() (*TLSConfig, error) { - // exampleSelfSignedCert { - caCert := tlscert.SelfSignedFromRequest(tlscert.Request{ - Name: "ca", - SubjectCommonName: "Cockroach Test CA", - Host: "localhost,127.0.0.1", - IsCA: true, - ValidFor: time.Hour, - }) - if caCert == nil { - return nil, errors.New("failed to generate CA certificate") - } - // } - - // exampleSignSelfSignedCert { - nodeCert := tlscert.SelfSignedFromRequest(tlscert.Request{ - Name: "node", - SubjectCommonName: "node", - Host: "localhost,127.0.0.1", - IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback}, - ValidFor: time.Hour, - Parent: caCert, // using the CA certificate as parent - }) - if nodeCert == nil { - return nil, errors.New("failed to generate node certificate") - } - // } - - clientCert := tlscert.SelfSignedFromRequest(tlscert.Request{ - Name: "client", - SubjectCommonName: defaultUser, - Host: "localhost,127.0.0.1", - ValidFor: time.Hour, - Parent: caCert, // using the CA certificate as parent - }) - if clientCert == nil { - return nil, errors.New("failed to generate client certificate") - } - - return &TLSConfig{ - CACert: caCert.Cert, - NodeCert: nodeCert.Bytes, - NodeKey: nodeCert.KeyBytes, - ClientCert: clientCert.Bytes, - ClientKey: clientCert.KeyBytes, - }, nil -} diff --git a/modules/cockroachdb/cockroachdb.go b/modules/cockroachdb/cockroachdb.go index 092efa4e2a..40da90fcd1 100644 --- a/modules/cockroachdb/cockroachdb.go +++ b/modules/cockroachdb/cockroachdb.go @@ -1,15 +1,14 @@ package cockroachdb import ( + "bytes" "context" "crypto/tls" - "crypto/x509" - "encoding/pem" + _ "embed" "errors" "fmt" "net" "net/url" - "path/filepath" "github.com/docker/go-connections/nat" "github.com/jackc/pgx/v5" @@ -19,11 +18,10 @@ import ( "github.com/testcontainers/testcontainers-go/wait" ) +// ErrTLSNotEnabled is returned when trying to get a TLS config from a container that does not have TLS enabled. var ErrTLSNotEnabled = errors.New("tls not enabled") const ( - certsDir = "/tmp" - defaultSQLPort = "26257/tcp" defaultAdminPort = "8080/tcp" @@ -31,15 +29,63 @@ const ( defaultPassword = "" defaultDatabase = "defaultdb" defaultStoreSize = "100%" + + // initDBPath is the path where the init scripts are placed in the container. + initDBPath = "/docker-entrypoint-initdb.d" + + // cockroachDir is the path where the CockroachDB files are placed in the container. + cockroachDir = "/cockroach" + + // clusterDefaultsContainerFile is the path to the default cluster settings script in the container. + clusterDefaultsContainerFile = initDBPath + "/__cluster_defaults.sql" + + // memStorageFlag is the flag to use in the start command to use an in-memory store. + memStorageFlag = "--store=type=mem,size=" + + // insecureFlag is the flag to use in the start command to disable TLS. + insecureFlag = "--insecure" + + // env vars. + envUser = "COCKROACH_USER" + envPassword = "COCKROACH_PASSWORD" + envDatabase = "COCKROACH_DATABASE" + + // cert files. + certsDir = cockroachDir + "/certs" + fileCACert = certsDir + "/ca.crt" ) +//go:embed data/cluster_defaults.sql +var clusterDefaults []byte + +// defaultsReader is a reader for the default settings scripts +// so that they can be identified and removed from the request. +type defaultsReader struct { + *bytes.Reader +} + +// newDefaultsReader creates a new reader for the default cluster settings script. +func newDefaultsReader(data []byte) *defaultsReader { + return &defaultsReader{Reader: bytes.NewReader(data)} +} + // CockroachDBContainer represents the CockroachDB container type used in the module type CockroachDBContainer struct { testcontainers.Container - opts options + options +} + +// options represents the options for the CockroachDBContainer type. +type options struct { + database string + user string + password string + tlsStrategy *wait.TLSStrategy } -// MustConnectionString panics if the address cannot be determined. +// MustConnectionString returns a connection string to open a new connection to CockroachDB +// as described by [CockroachDBContainer.ConnectionString]. +// It panics if an error occurs. func (c *CockroachDBContainer) MustConnectionString(ctx context.Context) string { addr, err := c.ConnectionString(ctx) if err != nil { @@ -48,35 +94,86 @@ func (c *CockroachDBContainer) MustConnectionString(ctx context.Context) string return addr } -// ConnectionString returns the dial address to open a new connection to CockroachDB. +// ConnectionString returns a connection string to open a new connection to CockroachDB. +// The returned string is suitable for use by [sql.Open] but is not be compatible with +// [pgx.ParseConfig], so if you want to call [pgx.ConnectConfig] use the +// [CockroachDBContainer.ConnectionConfig] method instead. func (c *CockroachDBContainer) ConnectionString(ctx context.Context) (string, error) { + cfg, err := c.ConnectionConfig(ctx) + if err != nil { + return "", fmt.Errorf("connection config: %w", err) + } + + return stdlib.RegisterConnConfig(cfg), nil +} + +// ConnectionConfig returns a [pgx.ConnConfig] for the CockroachDB container. +// This can be passed to [pgx.ConnectConfig] to open a new connection. +func (c *CockroachDBContainer) ConnectionConfig(ctx context.Context) (*pgx.ConnConfig, error) { port, err := c.MappedPort(ctx, defaultSQLPort) if err != nil { - return "", err + return nil, fmt.Errorf("mapped port: %w", err) } host, err := c.Host(ctx) if err != nil { - return "", err + return nil, fmt.Errorf("host: %w", err) } - return connString(c.opts, host, port), nil + return c.connConfig(host, port) } // TLSConfig returns config necessary to connect to CockroachDB over TLS. +// Returns [ErrTLSNotEnabled] if TLS is not enabled. +// +// Deprecated: use [CockroachDBContainer.ConnectionString] or +// [CockroachDBContainer.ConnectionConfig] instead. func (c *CockroachDBContainer) TLSConfig() (*tls.Config, error) { - return connTLS(c.opts) + if cfg := c.tlsStrategy.TLSConfig(); cfg != nil { + return cfg, nil + } + + return nil, ErrTLSNotEnabled } -// Deprecated: use Run instead +// Deprecated: use Run instead. // RunContainer creates an instance of the CockroachDB container type func RunContainer(ctx context.Context, opts ...testcontainers.ContainerCustomizer) (*CockroachDBContainer, error) { return Run(ctx, "cockroachdb/cockroach:latest-v23.1", opts...) } -// Run creates an instance of the CockroachDB container type +// Run start an instance of the CockroachDB container type using the given image and options. +// +// By default, the container will configured with: +// - Cluster: Single node +// - Storage: 100% in-memory +// - User: root +// - Password: "" +// - Database: defaultdb +// - Exposed ports: 26257/tcp (SQL), 8080/tcp (Admin UI) +// - Init Scripts: `data/cluster_defaults.sql` +// +// This supports CockroachDB images v22.2.0 and later, earlier versions will only work with +// customised options, such as disabling TLS and removing the wait for `init_success` using +// a [testcontainers.ContainerCustomizer]. +// +// The init script `data/cluster_defaults.sql` configures the settings recommended +// by Cockroach Labs for [local testing clusters] unless data exists in the +// `/cockroach/cockroach-data` directory within the container. Use [WithNoClusterDefaults] +// to disable this behaviour and provide your own settings using [WithInitScripts]. +// +// For more information see starting a [local cluster in docker]. +// +// [local cluster in docker]: https://www.cockroachlabs.com/docs/stable/start-a-local-cluster-in-docker-linux +// [local testing clusters]: https://www.cockroachlabs.com/docs/stable/local-testing func Run(ctx context.Context, img string, opts ...testcontainers.ContainerCustomizer) (*CockroachDBContainer, error) { - o := defaultOptions() + ctr := &CockroachDBContainer{ + options: options{ + database: defaultDatabase, + user: defaultUser, + password: defaultPassword, + }, + } req := testcontainers.GenericContainerRequest{ ContainerRequest: testcontainers.ContainerRequest{ Image: img, @@ -84,164 +181,80 @@ func Run(ctx context.Context, img string, opts ...testcontainers.ContainerCustom defaultSQLPort, defaultAdminPort, }, - LifecycleHooks: []testcontainers.ContainerLifecycleHooks{ - { - PreStarts: []testcontainers.ContainerHook{ - func(ctx context.Context, container testcontainers.Container) error { - return addTLS(ctx, container, o) - }, - }, - }, + Env: map[string]string{ + "COCKROACH_DATABASE": defaultDatabase, + "COCKROACH_USER": defaultUser, + "COCKROACH_PASSWORD": defaultPassword, + }, + Files: []testcontainers.ContainerFile{{ + Reader: newDefaultsReader(clusterDefaults), + ContainerFilePath: clusterDefaultsContainerFile, + FileMode: 0o644, + }}, + Cmd: []string{ + "start-single-node", + memStorageFlag + defaultStoreSize, }, + WaitingFor: wait.ForAll( + wait.ForFile(cockroachDir+"/init_success"), + wait.ForHTTP("/health").WithPort(defaultAdminPort), + wait.ForTLSCert( + certsDir+"/client."+defaultUser+".crt", + certsDir+"/client."+defaultUser+".key", + ).WithRootCAs(fileCACert).WithServerName("127.0.0.1"), + wait.ForSQL(defaultSQLPort, "pgx/v5", func(host string, port nat.Port) string { + connStr, err := ctr.connString(host, port) + if err != nil { + panic(err) + } + return connStr + }), + ), }, Started: true, } - // apply options for _, opt := range opts { - if apply, ok := opt.(Option); ok { - apply(&o) - } if err := opt.Customize(&req); err != nil { - return nil, err + return nil, fmt.Errorf("customize request: %w", err) } } - // modify request - for _, fn := range []modiferFunc{ - addEnvs, - addCmd, - addWaitingFor, - } { - if err := fn(&req, o); err != nil { - return nil, err - } - } - - container, err := testcontainers.GenericContainer(ctx, req) - var c *CockroachDBContainer - if container != nil { - c = &CockroachDBContainer{Container: container, opts: o} + if err := ctr.configure(&req); err != nil { + return nil, fmt.Errorf("set options: %w", err) } + var err error + ctr.Container, err = testcontainers.GenericContainer(ctx, req) if err != nil { - return c, fmt.Errorf("generic container: %w", err) + return ctr, fmt.Errorf("generic container: %w", err) } - return c, nil + return ctr, nil } -type modiferFunc func(*testcontainers.GenericContainerRequest, options) error - -func addCmd(req *testcontainers.GenericContainerRequest, opts options) error { - req.Cmd = []string{ - "start-single-node", - "--store=type=mem,size=" + opts.StoreSize, - } - - // authN - if opts.TLS != nil { - if opts.User != defaultUser { - return fmt.Errorf("unsupported user %s with TLS, use %s", opts.User, defaultUser) - } - if opts.Password != "" { - return errors.New("cannot use password authentication with TLS") - } - } - - switch { - case opts.TLS != nil: - req.Cmd = append(req.Cmd, "--certs-dir="+certsDir) - case opts.Password != "": - req.Cmd = append(req.Cmd, "--accept-sql-without-tls") - default: - req.Cmd = append(req.Cmd, "--insecure") - } - return nil -} - -func addEnvs(req *testcontainers.GenericContainerRequest, opts options) error { - if req.Env == nil { - req.Env = make(map[string]string) +// connString returns a connection string for the given host, port and options. +func (c *CockroachDBContainer) connString(host string, port nat.Port) (string, error) { + cfg, err := c.connConfig(host, port) + if err != nil { + return "", fmt.Errorf("connection config: %w", err) } - req.Env["COCKROACH_DATABASE"] = opts.Database - req.Env["COCKROACH_USER"] = opts.User - req.Env["COCKROACH_PASSWORD"] = opts.Password - return nil + return stdlib.RegisterConnConfig(cfg), nil } -func addWaitingFor(req *testcontainers.GenericContainerRequest, opts options) error { - var tlsConfig *tls.Config - if opts.TLS != nil { - cfg, err := connTLS(opts) - if err != nil { - return err - } - tlsConfig = cfg - } - - sqlWait := wait.ForSQL(defaultSQLPort, "pgx/v5", func(host string, port nat.Port) string { - connStr := connString(opts, host, port) - if tlsConfig == nil { - return connStr - } - - // register TLS config with pgx driver - connCfg, err := pgx.ParseConfig(connStr) - if err != nil { - panic(err) - } - connCfg.TLSConfig = tlsConfig - - return stdlib.RegisterConnConfig(connCfg) - }) - defaultStrategy := wait.ForAll( - wait.ForHTTP("/health").WithPort(defaultAdminPort), - sqlWait, - ) - - if req.WaitingFor == nil { - req.WaitingFor = defaultStrategy +// connConfig returns a [pgx.ConnConfig] for the given host, port and options. +func (c *CockroachDBContainer) connConfig(host string, port nat.Port) (*pgx.ConnConfig, error) { + var user *url.Userinfo + if c.password != "" { + user = url.UserPassword(c.user, c.password) } else { - req.WaitingFor = wait.ForAll(req.WaitingFor, defaultStrategy) - } - - return nil -} - -func addTLS(ctx context.Context, container testcontainers.Container, opts options) error { - if opts.TLS == nil { - return nil - } - - caBytes := pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE", - Bytes: opts.TLS.CACert.Raw, - }) - files := map[string][]byte{ - "ca.crt": caBytes, - "node.crt": opts.TLS.NodeCert, - "node.key": opts.TLS.NodeKey, - "client.root.crt": opts.TLS.ClientCert, - "client.root.key": opts.TLS.ClientKey, - } - for filename, contents := range files { - if err := container.CopyToContainer(ctx, contents, filepath.Join(certsDir, filename), 0o600); err != nil { - return err - } - } - return nil -} - -func connString(opts options, host string, port nat.Port) string { - user := url.User(opts.User) - if opts.Password != "" { - user = url.UserPassword(opts.User, opts.Password) + user = url.User(c.user) } sslMode := "disable" - if opts.TLS != nil { + tlsConfig := c.tlsStrategy.TLSConfig() + if tlsConfig != nil { sslMode = "verify-full" } params := url.Values{ @@ -252,29 +265,57 @@ func connString(opts options, host string, port nat.Port) string { Scheme: "postgres", User: user, Host: net.JoinHostPort(host, port.Port()), - Path: opts.Database, + Path: c.database, RawQuery: params.Encode(), } - return u.String() + cfg, err := pgx.ParseConfig(u.String()) + if err != nil { + return nil, fmt.Errorf("parse config: %w", err) + } + + cfg.TLSConfig = tlsConfig + + return cfg, nil } -func connTLS(opts options) (*tls.Config, error) { - if opts.TLS == nil { - return nil, ErrTLSNotEnabled +// configure sets the CockroachDBContainer options from the given request and updates the request +// wait strategies to match the options. +func (c *CockroachDBContainer) configure(req *testcontainers.GenericContainerRequest) error { + c.database = req.Env[envDatabase] + c.user = req.Env[envUser] + c.password = req.Env[envPassword] + + var insecure bool + for _, arg := range req.Cmd { + if arg == insecureFlag { + insecure = true + break + } } - keyPair, err := tls.X509KeyPair(opts.TLS.ClientCert, opts.TLS.ClientKey) - if err != nil { - return nil, err - } + // Walk the wait strategies to find the TLS strategy and either remove it or + // update the client certificate files to match the user and configure the + // container to use the TLS strategy. + if err := wait.Walk(&req.WaitingFor, func(strategy wait.Strategy) error { + if cert, ok := strategy.(*wait.TLSStrategy); ok { + if insecure { + // If insecure mode is enabled, the certificate strategy is removed. + return errors.Join(wait.VisitRemove, wait.VisitStop) + } - certPool := x509.NewCertPool() - certPool.AddCert(opts.TLS.CACert) + // Update the client certificate files to match the user which may have changed. + cert.WithCert(certsDir+"/client."+c.user+".crt", certsDir+"/client."+c.user+".key") - return &tls.Config{ - RootCAs: certPool, - Certificates: []tls.Certificate{keyPair}, - ServerName: "localhost", - }, nil + c.tlsStrategy = cert + + // Stop the walk as the certificate strategy has been found. + return wait.VisitStop + } + return nil + }); err != nil { + return fmt.Errorf("walk strategies: %w", err) + } + + return nil } diff --git a/modules/cockroachdb/cockroachdb_test.go b/modules/cockroachdb/cockroachdb_test.go index cc355e9168..e3a7bb1f12 100644 --- a/modules/cockroachdb/cockroachdb_test.go +++ b/modules/cockroachdb/cockroachdb_test.go @@ -2,221 +2,94 @@ package cockroachdb_test import ( "context" - "errors" - "net/url" - "strings" + "database/sql" "testing" - "time" "github.com/jackc/pgx/v5" "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/modules/cockroachdb" - "github.com/testcontainers/testcontainers-go/wait" ) -func TestCockroach_Insecure(t *testing.T) { - suite.Run(t, &AuthNSuite{ - url: "postgres://root@localhost:xxxxx/defaultdb?sslmode=disable", - }) -} +const testImage = "cockroachdb/cockroach:latest-v23.1" -func TestCockroach_NotRoot(t *testing.T) { - suite.Run(t, &AuthNSuite{ - url: "postgres://test@localhost:xxxxx/defaultdb?sslmode=disable", - opts: []testcontainers.ContainerCustomizer{ - cockroachdb.WithUser("test"), - }, - }) +func TestRun(t *testing.T) { + testContainer(t) } -func TestCockroach_Password(t *testing.T) { - suite.Run(t, &AuthNSuite{ - url: "postgres://foo:bar@localhost:xxxxx/defaultdb?sslmode=disable", - opts: []testcontainers.ContainerCustomizer{ - cockroachdb.WithUser("foo"), - cockroachdb.WithPassword("bar"), - }, - }) +func TestRun_WithAllOptions(t *testing.T) { + testContainer(t, + cockroachdb.WithDatabase("testDatabase"), + cockroachdb.WithStoreSize("50%"), + cockroachdb.WithUser("testUser"), + cockroachdb.WithPassword("testPassword"), + cockroachdb.WithNoClusterDefaults(), + cockroachdb.WithInitScripts("testdata/__init.sql"), + // WithInsecure is not present as it is incompatible with WithPassword. + ) } -func TestCockroach_TLS(t *testing.T) { - tlsCfg, err := cockroachdb.NewTLSConfig() - require.NoError(t, err) - - suite.Run(t, &AuthNSuite{ - url: "postgres://root@localhost:xxxxx/defaultdb?sslmode=verify-full", - opts: []testcontainers.ContainerCustomizer{ - cockroachdb.WithTLS(tlsCfg), - }, +func TestRun_WithInsecure(t *testing.T) { + t.Run("valid", func(t *testing.T) { + testContainer(t, cockroachdb.WithInsecure()) }) -} - -type AuthNSuite struct { - suite.Suite - url string - opts []testcontainers.ContainerCustomizer -} - -func (suite *AuthNSuite) TestConnectionString() { - ctx := context.Background() - - ctr, err := cockroachdb.Run(ctx, "cockroachdb/cockroach:latest-v23.1", suite.opts...) - testcontainers.CleanupContainer(suite.T(), ctr) - suite.Require().NoError(err) - - connStr, err := removePort(ctr.MustConnectionString(ctx)) - suite.Require().NoError(err) - - suite.Equal(suite.url, connStr) -} - -func (suite *AuthNSuite) TestPing() { - ctx := context.Background() - - inputs := []struct { - name string - opts []testcontainers.ContainerCustomizer - }{ - { - name: "defaults", - // opts: suite.opts - }, - { - name: "database", - opts: []testcontainers.ContainerCustomizer{ - cockroachdb.WithDatabase("test"), - }, - }, - } - - for _, input := range inputs { - suite.Run(input.name, func() { - opts := suite.opts - opts = append(opts, input.opts...) - - ctr, err := cockroachdb.Run(ctx, "cockroachdb/cockroach:latest-v23.1", opts...) - testcontainers.CleanupContainer(suite.T(), ctr) - suite.Require().NoError(err) - - conn, err := conn(ctx, ctr) - suite.Require().NoError(err) - defer conn.Close(ctx) - - err = conn.Ping(ctx) - suite.Require().NoError(err) - }) - } -} - -func (suite *AuthNSuite) TestQuery() { - ctx := context.Background() - - ctr, err := cockroachdb.Run(ctx, "cockroachdb/cockroach:latest-v23.1", suite.opts...) - testcontainers.CleanupContainer(suite.T(), ctr) - suite.Require().NoError(err) - - conn, err := conn(ctx, ctr) - suite.Require().NoError(err) - defer conn.Close(ctx) - _, err = conn.Exec(ctx, "CREATE TABLE test (id INT PRIMARY KEY)") - suite.Require().NoError(err) - - _, err = conn.Exec(ctx, "INSERT INTO test (id) VALUES (523123)") - suite.Require().NoError(err) - - var id int - err = conn.QueryRow(ctx, "SELECT id FROM test").Scan(&id) - suite.Require().NoError(err) - suite.Equal(523123, id) -} - -// TestWithWaitStrategyAndDeadline covers a previous regression, container creation needs to fail to cover that path. -func (suite *AuthNSuite) TestWithWaitStrategyAndDeadline() { - nodeStartUpCompleted := "node startup completed" - - suite.Run("Expected Failure To Run", func() { - ctx := context.Background() - - // This will never match a log statement - suite.opts = append(suite.opts, testcontainers.WithWaitStrategyAndDeadline(time.Millisecond*250, wait.ForLog("Won't Exist In Logs"))) - ctr, err := cockroachdb.Run(ctx, "cockroachdb/cockroach:latest-v23.1", suite.opts...) - testcontainers.CleanupContainer(suite.T(), ctr) - suite.Require().ErrorIs(err, context.DeadlineExceeded) + t.Run("invalid-password-insecure", func(t *testing.T) { + _, err := cockroachdb.Run(context.Background(), testImage, + cockroachdb.WithPassword("testPassword"), + cockroachdb.WithInsecure(), + ) + require.Error(t, err) }) - suite.Run("Expected Failure To Run But Would Succeed ", func() { - ctx := context.Background() - - // This will timeout as we didn't give enough time for intialization, but would have succeeded otherwise - suite.opts = append(suite.opts, testcontainers.WithWaitStrategyAndDeadline(time.Millisecond*20, wait.ForLog(nodeStartUpCompleted))) - ctr, err := cockroachdb.Run(ctx, "cockroachdb/cockroach:latest-v23.1", suite.opts...) - testcontainers.CleanupContainer(suite.T(), ctr) - suite.Require().ErrorIs(err, context.DeadlineExceeded) + t.Run("invalid-insecure-password", func(t *testing.T) { + _, err := cockroachdb.Run(context.Background(), testImage, + cockroachdb.WithInsecure(), + cockroachdb.WithPassword("testPassword"), + ) + require.Error(t, err) }) +} - suite.Run("Succeeds And Executes Commands", func() { - ctx := context.Background() +// testContainer runs a CockroachDB container and validates its functionality. +func testContainer(t *testing.T, opts ...testcontainers.ContainerCustomizer) { + t.Helper() - // This will succeed - suite.opts = append(suite.opts, testcontainers.WithWaitStrategyAndDeadline(time.Second*60, wait.ForLog(nodeStartUpCompleted))) - ctr, err := cockroachdb.Run(ctx, "cockroachdb/cockroach:latest-v23.1", suite.opts...) - testcontainers.CleanupContainer(suite.T(), ctr) - suite.Require().NoError(err) + ctx := context.Background() + ctr, err := cockroachdb.Run(ctx, testImage, opts...) + testcontainers.CleanupContainer(t, ctr) + require.NoError(t, err) + require.NotNil(t, ctr) - conn, err := conn(ctx, ctr) - suite.Require().NoError(err) - defer conn.Close(ctx) + // Check a raw connection with a ping. + cfg, err := ctr.ConnectionConfig(ctx) + require.NoError(t, err) - _, err = conn.Exec(ctx, "CREATE TABLE test (id INT PRIMARY KEY)") - suite.Require().NoError(err) + conn, err := pgx.ConnectConfig(ctx, cfg) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, conn.Close(ctx)) }) - suite.Run("Succeeds And Executes Commands Waiting on HTTP Endpoint", func() { - ctx := context.Background() + err = conn.Ping(ctx) + require.NoError(t, err) - // This will succeed - suite.opts = append(suite.opts, testcontainers.WithWaitStrategyAndDeadline(time.Second*60, wait.ForHTTP("/health").WithPort("8080/tcp"))) - ctr, err := cockroachdb.Run(ctx, "cockroachdb/cockroach:latest-v23.1", suite.opts...) - testcontainers.CleanupContainer(suite.T(), ctr) - suite.Require().NoError(err) + // Check an SQL connection with a queries. + addr, err := ctr.ConnectionString(ctx) + require.NoError(t, err) - conn, err := conn(ctx, ctr) - suite.Require().NoError(err) - defer conn.Close(ctx) + db, err := sql.Open("pgx/v5", addr) + require.NoError(t, err) - _, err = conn.Exec(ctx, "CREATE TABLE test (id INT PRIMARY KEY)") - suite.Require().NoError(err) - }) -} + _, err = db.ExecContext(ctx, "CREATE TABLE test (id INT PRIMARY KEY)") + require.NoError(t, err) -func conn(ctx context.Context, container *cockroachdb.CockroachDBContainer) (*pgx.Conn, error) { - cfg, err := pgx.ParseConfig(container.MustConnectionString(ctx)) - if err != nil { - return nil, err - } - - tlsCfg, err := container.TLSConfig() - switch { - case err != nil: - if !errors.Is(err, cockroachdb.ErrTLSNotEnabled) { - return nil, err - } - default: - // apply TLS config - cfg.TLSConfig = tlsCfg - } - - return pgx.ConnectConfig(ctx, cfg) -} + _, err = db.ExecContext(ctx, "INSERT INTO test (id) VALUES (523123)") + require.NoError(t, err) -func removePort(s string) (string, error) { - u, err := url.Parse(s) - if err != nil { - return "", err - } - return strings.Replace(s, ":"+u.Port(), ":xxxxx", 1), nil + var id int + err = db.QueryRowContext(ctx, "SELECT id FROM test").Scan(&id) + require.NoError(t, err) + require.Equal(t, 523123, id) } diff --git a/modules/cockroachdb/data/cluster_defaults.sql b/modules/cockroachdb/data/cluster_defaults.sql new file mode 100644 index 0000000000..78502d115e --- /dev/null +++ b/modules/cockroachdb/data/cluster_defaults.sql @@ -0,0 +1,8 @@ +SET CLUSTER SETTING kv.range_merge.queue_interval = '50ms'; +SET CLUSTER SETTING jobs.registry.interval.gc = '30s'; +SET CLUSTER SETTING jobs.registry.interval.cancel = '180s'; +SET CLUSTER SETTING jobs.retention_time = '15s'; +SET CLUSTER SETTING sql.stats.automatic_collection.enabled = false; +SET CLUSTER SETTING kv.range_split.by_load_merge_delay = '5s'; +ALTER RANGE default CONFIGURE ZONE USING "gc.ttlseconds" = 600; +ALTER DATABASE system CONFIGURE ZONE USING "gc.ttlseconds" = 600; diff --git a/modules/cockroachdb/examples_test.go b/modules/cockroachdb/examples_test.go index c06c97596b..a1259c218b 100644 --- a/modules/cockroachdb/examples_test.go +++ b/modules/cockroachdb/examples_test.go @@ -2,9 +2,11 @@ package cockroachdb_test import ( "context" + "database/sql" "fmt" "log" - "net/url" + + "github.com/jackc/pgx/v5" "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/modules/cockroachdb" @@ -33,20 +35,97 @@ func ExampleRun() { } fmt.Println(state.Running) + cfg, err := cockroachdbContainer.ConnectionConfig(ctx) + if err != nil { + log.Printf("failed to get connection string: %s", err) + return + } + + conn, err := pgx.ConnectConfig(ctx, cfg) + if err != nil { + log.Printf("failed to connect: %s", err) + return + } + + defer func() { + if err := conn.Close(ctx); err != nil { + log.Printf("failed to close connection: %s", err) + } + }() + + if err = conn.Ping(ctx); err != nil { + log.Printf("failed to ping: %s", err) + return + } + + // Output: + // true +} + +func ExampleRun_withInitOptions() { + ctx := context.Background() + + cockroachdbContainer, err := cockroachdb.Run(ctx, "cockroachdb/cockroach:latest-v23.1", + cockroachdb.WithNoClusterDefaults(), + cockroachdb.WithInitScripts("testdata/__init.sql"), + ) + defer func() { + if err := testcontainers.TerminateContainer(cockroachdbContainer); err != nil { + log.Printf("failed to terminate container: %s", err) + } + }() + if err != nil { + log.Printf("failed to start container: %s", err) + return + } + + state, err := cockroachdbContainer.State(ctx) + if err != nil { + log.Printf("failed to get container state: %s", err) + return + } + fmt.Println(state.Running) + addr, err := cockroachdbContainer.ConnectionString(ctx) if err != nil { log.Printf("failed to get connection string: %s", err) return } - u, err := url.Parse(addr) + + db, err := sql.Open("pgx/v5", addr) if err != nil { - log.Printf("failed to parse connection string: %s", err) + log.Printf("failed to open connection: %s", err) return } - u.Host = fmt.Sprintf("%s:%s", u.Hostname(), "xxx") - fmt.Println(u.String()) + defer func() { + if err := db.Close(); err != nil { + log.Printf("failed to close connection: %s", err) + } + }() + + var interval string + if err := db.QueryRow("SHOW CLUSTER SETTING kv.range_merge.queue_interval").Scan(&interval); err != nil { + log.Printf("failed to scan row: %s", err) + return + } + fmt.Println(interval) + + if err := db.QueryRow("SHOW CLUSTER SETTING jobs.registry.interval.gc").Scan(&interval); err != nil { + log.Printf("failed to scan row: %s", err) + return + } + fmt.Println(interval) + + var statsCollectionEnabled bool + if err := db.QueryRow("SHOW CLUSTER SETTING sql.stats.automatic_collection.enabled").Scan(&statsCollectionEnabled); err != nil { + log.Printf("failed to scan row: %s", err) + return + } + fmt.Println(statsCollectionEnabled) // Output: // true - // postgres://root@localhost:xxx/defaultdb?sslmode=disable + // 00:00:05 + // 00:00:50 + // true } diff --git a/modules/cockroachdb/go.mod b/modules/cockroachdb/go.mod index fbc0fd6f7a..cf31a35616 100644 --- a/modules/cockroachdb/go.mod +++ b/modules/cockroachdb/go.mod @@ -41,7 +41,6 @@ require ( github.com/kr/text v0.2.0 // indirect github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect github.com/magiconair/properties v1.8.7 // indirect - github.com/mdelapenya/tlscert v0.1.0 github.com/moby/patternmatcher v0.6.0 // indirect github.com/moby/sys/sequential v0.5.0 // indirect github.com/moby/sys/user v0.1.0 // indirect diff --git a/modules/cockroachdb/go.sum b/modules/cockroachdb/go.sum index e8661eb69a..3877e20a9a 100644 --- a/modules/cockroachdb/go.sum +++ b/modules/cockroachdb/go.sum @@ -69,8 +69,6 @@ github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= -github.com/mdelapenya/tlscert v0.1.0 h1:YTpF579PYUX475eOL+6zyEO3ngLTOUWck78NBuJVXaM= -github.com/mdelapenya/tlscert v0.1.0/go.mod h1:wrbyM/DwbFCeCeqdPX/8c6hNOqQgbf0rUDErE1uD+64= github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo= github.com/moby/patternmatcher v0.6.0 h1:GmP9lR19aU5GqSSFko+5pRqHi+Ohk1O69aFiKkVGiPk= diff --git a/modules/cockroachdb/options.go b/modules/cockroachdb/options.go index a2211d77e7..9efac532c6 100644 --- a/modules/cockroachdb/options.go +++ b/modules/cockroachdb/options.go @@ -1,69 +1,119 @@ package cockroachdb -import "github.com/testcontainers/testcontainers-go" - -type options struct { - Database string - User string - Password string - StoreSize string - TLS *TLSConfig -} +import ( + "errors" + "path/filepath" + "strings" + + "github.com/testcontainers/testcontainers-go" +) + +// errInsecureWithPassword is returned when trying to use insecure mode with a password. +var errInsecureWithPassword = errors.New("insecure mode cannot be used with a password") -func defaultOptions() options { - return options{ - User: defaultUser, - Password: defaultPassword, - Database: defaultDatabase, - StoreSize: defaultStoreSize, +// WithDatabase sets the name of the database to create and use. +// This will be converted to lowercase as CockroachDB forces the database to be lowercase. +// The database creation will be skipped if data exists in the `/cockroach/cockroach-data` directory within the container. +func WithDatabase(database string) testcontainers.CustomizeRequestOption { + return func(req *testcontainers.GenericContainerRequest) error { + req.Env[envDatabase] = strings.ToLower(database) + return nil } } -// Compiler check to ensure that Option implements the testcontainers.ContainerCustomizer interface. -var _ testcontainers.ContainerCustomizer = (*Option)(nil) +// WithUser sets the name of the user to create and connect as. +// This will be converted to lowercase as CockroachDB forces the user to be lowercase. +// The user creation will be skipped if data exists in the `/cockroach/cockroach-data` directory within the container. +func WithUser(user string) testcontainers.CustomizeRequestOption { + return func(req *testcontainers.GenericContainerRequest) error { + req.Env[envUser] = strings.ToLower(user) + return nil + } +} -// Option is an option for the CockroachDB container. -type Option func(*options) +// WithPassword sets the password of the user to create and connect as. +// The user creation will be skipped if data exists in the `/cockroach/cockroach-data` directory within the container. +// This will error if insecure mode is enabled. +func WithPassword(password string) testcontainers.CustomizeRequestOption { + return func(req *testcontainers.GenericContainerRequest) error { + for _, arg := range req.Cmd { + if arg == insecureFlag { + return errInsecureWithPassword + } + } -// Customize is a NOOP. It's defined to satisfy the testcontainers.ContainerCustomizer interface. -func (o Option) Customize(*testcontainers.GenericContainerRequest) error { - // NOOP to satisfy interface. - return nil -} + req.Env[envPassword] = password -// WithDatabase sets the name of the database to use. -func WithDatabase(database string) Option { - return func(o *options) { - o.Database = database + return nil } } -// WithUser creates & sets the user to connect as. -func WithUser(user string) Option { - return func(o *options) { - o.User = user +// WithStoreSize sets the amount of available [in-memory storage]. +// +// [in-memory storage]: https://www.cockroachlabs.com/docs/stable/cockroach-start#store +func WithStoreSize(size string) testcontainers.CustomizeRequestOption { + return func(req *testcontainers.GenericContainerRequest) error { + for i, cmd := range req.Cmd { + if strings.HasPrefix(cmd, memStorageFlag) { + req.Cmd[i] = memStorageFlag + size + return nil + } + } + + // Wasn't found, add it. + req.Cmd = append(req.Cmd, memStorageFlag+size) + + return nil } } -// WithPassword sets the password when using password authentication. -func WithPassword(password string) Option { - return func(o *options) { - o.Password = password +// WithNoClusterDefaults disables the default cluster settings script. +// +// Without this option Cockroach containers run `data/cluster-defaults.sql` on startup +// which configures the settings recommended by Cockroach Labs for [local testing clusters] +// unless data exists in the `/cockroach/cockroach-data` directory within the container. +// +// [local testing clusters]: https://www.cockroachlabs.com/docs/stable/local-testing +func WithNoClusterDefaults() testcontainers.CustomizeRequestOption { + return func(req *testcontainers.GenericContainerRequest) error { + for i, file := range req.Files { + if _, ok := file.Reader.(*defaultsReader); ok && file.ContainerFilePath == clusterDefaultsContainerFile { + req.Files = append(req.Files[:i], req.Files[i+1:]...) + return nil + } + } + + return nil } } -// WithStoreSize sets the amount of available in-memory storage. -// See https://www.cockroachlabs.com/docs/stable/cockroach-start#store -func WithStoreSize(size string) Option { - return func(o *options) { - o.StoreSize = size +// WithInitScripts adds the given scripts to those automatically run when the container starts. +// These will be ignored if data exists in the `/cockroach/cockroach-data` directory within the container. +func WithInitScripts(scripts ...string) testcontainers.CustomizeRequestOption { + return func(req *testcontainers.GenericContainerRequest) error { + files := make([]testcontainers.ContainerFile, len(scripts)) + for i, script := range scripts { + files[i] = testcontainers.ContainerFile{ + HostFilePath: script, + ContainerFilePath: initDBPath + "/" + filepath.Base(script), + FileMode: 0o644, + } + } + req.Files = append(req.Files, files...) + + return nil } } -// WithTLS enables TLS on the CockroachDB container. -// Cert and key must be PEM-encoded. -func WithTLS(cfg *TLSConfig) Option { - return func(o *options) { - o.TLS = cfg +// WithInsecure enables insecure mode which disables TLS. +func WithInsecure() testcontainers.CustomizeRequestOption { + return func(req *testcontainers.GenericContainerRequest) error { + if req.Env[envPassword] != "" { + return errInsecureWithPassword + } + + req.Cmd = append(req.Cmd, insecureFlag) + + return nil } } diff --git a/modules/cockroachdb/testdata/__init.sql b/modules/cockroachdb/testdata/__init.sql new file mode 100644 index 0000000000..c2c82dd48a --- /dev/null +++ b/modules/cockroachdb/testdata/__init.sql @@ -0,0 +1 @@ +SET CLUSTER SETTING jobs.registry.interval.gc = '50s'; diff --git a/modules/rabbitmq/examples_test.go b/modules/rabbitmq/examples_test.go index bc6a849456..b9c4e9fdf2 100644 --- a/modules/rabbitmq/examples_test.go +++ b/modules/rabbitmq/examples_test.go @@ -102,6 +102,7 @@ func ExampleRun_withSSL() { defer os.RemoveAll(certDirs) // generates the CA certificate and the certificate + // exampleSelfSignedCert { caCert := tlscert.SelfSignedFromRequest(tlscert.Request{ Name: "ca", Host: "localhost,127.0.0.1", @@ -112,7 +113,9 @@ func ExampleRun_withSSL() { log.Print("failed to generate CA certificate") return } + // } + // exampleSignSelfSignedCert { cert := tlscert.SelfSignedFromRequest(tlscert.Request{ Name: "client", Host: "localhost,127.0.0.1", @@ -124,6 +127,7 @@ func ExampleRun_withSSL() { log.Print("failed to generate certificate") return } + // } sslSettings := rabbitmq.SSLSettings{ CACertFile: caCert.CertPath, diff --git a/wait/file_test.go b/wait/file_test.go index 22133ba349..20bcc13a01 100644 --- a/wait/file_test.go +++ b/wait/file_test.go @@ -20,7 +20,7 @@ import ( const testFilename = "/tmp/file" -var anyContext = mock.AnythingOfType("*context.timerCtx") +var anyContext = mock.MatchedBy(func(_ context.Context) bool { return true }) // newRunningTarget creates a new mockStrategyTarget that is running. func newRunningTarget() *mockStrategyTarget { diff --git a/wait/http_test.go b/wait/http_test.go index 32479bddd4..73e32d44d7 100644 --- a/wait/http_test.go +++ b/wait/http_test.go @@ -5,6 +5,7 @@ import ( "context" "crypto/tls" "crypto/x509" + _ "embed" "fmt" "io" "log" @@ -23,6 +24,9 @@ import ( "github.com/testcontainers/testcontainers-go/wait" ) +//go:embed testdata/root.pem +var caBytes []byte + // https://github.com/testcontainers/testcontainers-go/issues/183 func ExampleHTTPStrategy() { // waitForHTTPWithDefaultPort { @@ -80,7 +84,7 @@ func ExampleHTTPStrategy_WithHeaders() { tlsconfig := &tls.Config{RootCAs: certpool, ServerName: "testcontainer.go.test"} req := testcontainers.ContainerRequest{ FromDockerfile: testcontainers.FromDockerfile{ - Context: "testdata", + Context: "testdata/http", }, ExposedPorts: []string{"6443/tcp"}, WaitingFor: wait.ForHTTP("/headers"). @@ -227,20 +231,13 @@ func ExampleHTTPStrategy_WithBasicAuth() { } func TestHTTPStrategyWaitUntilReady(t *testing.T) { - workdir, err := os.Getwd() - require.NoError(t, err) - - capath := filepath.Join(workdir, "testdata", "root.pem") - cafile, err := os.ReadFile(capath) - require.NoError(t, err) - certpool := x509.NewCertPool() - require.Truef(t, certpool.AppendCertsFromPEM(cafile), "the ca file isn't valid") + require.Truef(t, certpool.AppendCertsFromPEM(caBytes), "the ca file isn't valid") tlsconfig := &tls.Config{RootCAs: certpool, ServerName: "testcontainer.go.test"} dockerReq := testcontainers.ContainerRequest{ FromDockerfile: testcontainers.FromDockerfile{ - Context: filepath.Join(workdir, "testdata"), + Context: "testdata/http", }, ExposedPorts: []string{"6443/tcp"}, WaitingFor: wait.NewHTTPStrategy("/auth-ping").WithTLS(true, tlsconfig). @@ -288,20 +285,13 @@ func TestHTTPStrategyWaitUntilReady(t *testing.T) { } func TestHTTPStrategyWaitUntilReadyWithQueryString(t *testing.T) { - workdir, err := os.Getwd() - require.NoError(t, err) - - capath := filepath.Join(workdir, "testdata", "root.pem") - cafile, err := os.ReadFile(capath) - require.NoError(t, err) - certpool := x509.NewCertPool() - require.Truef(t, certpool.AppendCertsFromPEM(cafile), "the ca file isn't valid") + require.Truef(t, certpool.AppendCertsFromPEM(caBytes), "the ca file isn't valid") tlsconfig := &tls.Config{RootCAs: certpool, ServerName: "testcontainer.go.test"} dockerReq := testcontainers.ContainerRequest{ FromDockerfile: testcontainers.FromDockerfile{ - Context: filepath.Join(workdir, "testdata"), + Context: "testdata/http", }, ExposedPorts: []string{"6443/tcp"}, @@ -348,22 +338,15 @@ func TestHTTPStrategyWaitUntilReadyWithQueryString(t *testing.T) { } func TestHTTPStrategyWaitUntilReadyNoBasicAuth(t *testing.T) { - workdir, err := os.Getwd() - require.NoError(t, err) - - capath := filepath.Join(workdir, "testdata", "root.pem") - cafile, err := os.ReadFile(capath) - require.NoError(t, err) - certpool := x509.NewCertPool() - require.Truef(t, certpool.AppendCertsFromPEM(cafile), "the ca file isn't valid") + require.Truef(t, certpool.AppendCertsFromPEM(caBytes), "the ca file isn't valid") // waitForHTTPStatusCode { tlsconfig := &tls.Config{RootCAs: certpool, ServerName: "testcontainer.go.test"} var i int dockerReq := testcontainers.ContainerRequest{ FromDockerfile: testcontainers.FromDockerfile{ - Context: filepath.Join(workdir, "testdata"), + Context: "testdata/http", }, ExposedPorts: []string{"6443/tcp"}, WaitingFor: wait.NewHTTPStrategy("/ping").WithTLS(true, tlsconfig). diff --git a/wait/testdata/tls.pem b/wait/testdata/cert.crt similarity index 100% rename from wait/testdata/tls.pem rename to wait/testdata/cert.crt diff --git a/wait/testdata/tls-key.pem b/wait/testdata/cert.key similarity index 100% rename from wait/testdata/tls-key.pem rename to wait/testdata/cert.key diff --git a/wait/testdata/Dockerfile b/wait/testdata/http/Dockerfile similarity index 100% rename from wait/testdata/Dockerfile rename to wait/testdata/http/Dockerfile diff --git a/wait/testdata/go.mod b/wait/testdata/http/go.mod similarity index 100% rename from wait/testdata/go.mod rename to wait/testdata/http/go.mod diff --git a/wait/testdata/main.go b/wait/testdata/http/main.go similarity index 100% rename from wait/testdata/main.go rename to wait/testdata/http/main.go diff --git a/wait/testdata/http/tls-key.pem b/wait/testdata/http/tls-key.pem new file mode 100644 index 0000000000..00789d2371 --- /dev/null +++ b/wait/testdata/http/tls-key.pem @@ -0,0 +1,5 @@ +-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIM8HuDwcZyVqBBy2C6db6zNb/dAJ69bq5ejAEz7qGOIQoAoGCCqGSM49 +AwEHoUQDQgAEBL2ioRmfTc70WT0vyx+amSQOGbMeoMRAfF2qaPzpzOqpKTk0aLOG +0735iy9Fz16PX4vqnLMiM/ZupugAhB//yA== +-----END EC PRIVATE KEY----- diff --git a/wait/testdata/http/tls.pem b/wait/testdata/http/tls.pem new file mode 100644 index 0000000000..46348b7900 --- /dev/null +++ b/wait/testdata/http/tls.pem @@ -0,0 +1,12 @@ +-----BEGIN CERTIFICATE----- +MIIBxTCCAWugAwIBAgIUWBLNpiF1o4r+5ZXwawzPOfBM1F8wCgYIKoZIzj0EAwIw +ADAeFw0yMDA4MTkxMzM4MDBaFw0zMDA4MTcxMzM4MDBaMBkxFzAVBgNVBAMTDnRl +c3Rjb250YWluZXJzMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEBL2ioRmfTc70 +WT0vyx+amSQOGbMeoMRAfF2qaPzpzOqpKTk0aLOG0735iy9Fz16PX4vqnLMiM/Zu +pugAhB//yKOBqTCBpjAOBgNVHQ8BAf8EBAMCBaAwEwYDVR0lBAwwCgYIKwYBBQUH +AwEwDAYDVR0TAQH/BAIwADAdBgNVHQ4EFgQUTMdz5PIZ+Gix4jYUzRIHfByrW+Yw +HwYDVR0jBBgwFoAUFdfV6PSYUlHs+lSQNouRwSfR2ZgwMQYDVR0RBCowKIIVdGVz +dGNvbnRhaW5lci5nby50ZXN0gglsb2NhbGhvc3SHBH8AAAEwCgYIKoZIzj0EAwID +SAAwRQIhAJznPNumi2Plf0GsP9DpC+8WukT/jUhnhcDWCfZ6Ini2AiBLhnhFebZX +XWfSsdSNxIo20OWvy6z3wqdybZtRUfdU+g== +-----END CERTIFICATE----- diff --git a/wait/tls.go b/wait/tls.go new file mode 100644 index 0000000000..ab904b271e --- /dev/null +++ b/wait/tls.go @@ -0,0 +1,167 @@ +package wait + +import ( + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "io" + "time" +) + +// Validate we implement interface. +var _ Strategy = (*TLSStrategy)(nil) + +// TLSStrategy is a strategy for handling TLS. +type TLSStrategy struct { + // General Settings. + timeout *time.Duration + pollInterval time.Duration + + // Custom Settings. + certFiles *x509KeyPair + rootFiles []string + + // State. + tlsConfig *tls.Config +} + +// x509KeyPair is a pair of certificate and key files. +type x509KeyPair struct { + certPEMFile string + keyPEMFile string +} + +// ForTLSCert returns a CertStrategy that will add a Certificate to the [tls.Config] +// constructed from PEM formatted certificate key file pair in the container. +func ForTLSCert(certPEMFile, keyPEMFile string) *TLSStrategy { + return &TLSStrategy{ + certFiles: &x509KeyPair{ + certPEMFile: certPEMFile, + keyPEMFile: keyPEMFile, + }, + tlsConfig: &tls.Config{}, + pollInterval: defaultPollInterval(), + } +} + +// ForTLSRootCAs returns a CertStrategy that sets the root CAs for the [tls.Config] +// using the given PEM formatted files from the container. +func ForTLSRootCAs(pemFiles ...string) *TLSStrategy { + return &TLSStrategy{ + rootFiles: pemFiles, + tlsConfig: &tls.Config{}, + pollInterval: defaultPollInterval(), + } +} + +// WithRootCAs sets the root CAs for the [tls.Config] using the given files from +// the container. +func (ws *TLSStrategy) WithRootCAs(files ...string) *TLSStrategy { + ws.rootFiles = files + return ws +} + +// WithCert sets the [tls.Config] Certificates using the given files from the container. +func (ws *TLSStrategy) WithCert(certPEMFile, keyPEMFile string) *TLSStrategy { + ws.certFiles = &x509KeyPair{ + certPEMFile: certPEMFile, + keyPEMFile: keyPEMFile, + } + return ws +} + +// WithServerName sets the server for the [tls.Config]. +func (ws *TLSStrategy) WithServerName(serverName string) *TLSStrategy { + ws.tlsConfig.ServerName = serverName + return ws +} + +// WithStartupTimeout can be used to change the default startup timeout. +func (ws *TLSStrategy) WithStartupTimeout(startupTimeout time.Duration) *TLSStrategy { + ws.timeout = &startupTimeout + return ws +} + +// WithPollInterval can be used to override the default polling interval of 100 milliseconds. +func (ws *TLSStrategy) WithPollInterval(pollInterval time.Duration) *TLSStrategy { + ws.pollInterval = pollInterval + return ws +} + +// TLSConfig returns the TLS config once the strategy is ready. +// If the strategy is nil, it returns nil. +func (ws *TLSStrategy) TLSConfig() *tls.Config { + if ws == nil { + return nil + } + + return ws.tlsConfig +} + +// WaitUntilReady implements the [Strategy] interface. +// It waits for the CA, client cert and key files to be available in the container and +// uses them to setup the TLS config. +func (ws *TLSStrategy) WaitUntilReady(ctx context.Context, target StrategyTarget) error { + size := len(ws.rootFiles) + if ws.certFiles != nil { + size += 2 + } + strategies := make([]Strategy, 0, size) + for _, file := range ws.rootFiles { + strategies = append(strategies, + ForFile(file).WithMatcher(func(r io.Reader) error { + buf, err := io.ReadAll(r) + if err != nil { + return fmt.Errorf("read CA cert file %q: %w", file, err) + } + + if ws.tlsConfig.RootCAs == nil { + ws.tlsConfig.RootCAs = x509.NewCertPool() + } + + if !ws.tlsConfig.RootCAs.AppendCertsFromPEM(buf) { + return fmt.Errorf("invalid CA cert file %q", file) + } + + return nil + }).WithPollInterval(ws.pollInterval), + ) + } + + if ws.certFiles != nil { + var certPEMBlock []byte + strategies = append(strategies, + ForFile(ws.certFiles.certPEMFile).WithMatcher(func(r io.Reader) error { + var err error + if certPEMBlock, err = io.ReadAll(r); err != nil { + return fmt.Errorf("read certificate cert %q: %w", ws.certFiles.certPEMFile, err) + } + + return nil + }).WithPollInterval(ws.pollInterval), + ForFile(ws.certFiles.keyPEMFile).WithMatcher(func(r io.Reader) error { + keyPEMBlock, err := io.ReadAll(r) + if err != nil { + return fmt.Errorf("read certificate key %q: %w", ws.certFiles.keyPEMFile, err) + } + + cert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock) + if err != nil { + return fmt.Errorf("x509 key pair %q %q: %w", ws.certFiles.certPEMFile, ws.certFiles.keyPEMFile, err) + } + + ws.tlsConfig.Certificates = []tls.Certificate{cert} + + return nil + }).WithPollInterval(ws.pollInterval), + ) + } + + strategy := ForAll(strategies...) + if ws.timeout != nil { + strategy.WithStartupTimeout(*ws.timeout) + } + + return strategy.WaitUntilReady(ctx, target) +} diff --git a/wait/tls_test.go b/wait/tls_test.go new file mode 100644 index 0000000000..babc17b3d0 --- /dev/null +++ b/wait/tls_test.go @@ -0,0 +1,150 @@ +package wait_test + +import ( + "bytes" + "context" + "crypto/tls" + "crypto/x509" + _ "embed" + "errors" + "fmt" + "io" + "log" + "testing" + "time" + + "github.com/docker/docker/errdefs" + "github.com/stretchr/testify/require" + + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/wait" +) + +const ( + serverName = "127.0.0.1" + caFilename = "/tmp/ca.pem" + clientCertFilename = "/tmp/cert.crt" + clientKeyFilename = "/tmp/cert.key" +) + +var ( + //go:embed testdata/cert.crt + certBytes []byte + + //go:embed testdata/cert.key + keyBytes []byte +) + +// testForTLSCert creates a new CertStrategy for testing. +func testForTLSCert() *wait.TLSStrategy { + return wait.ForTLSCert(clientCertFilename, clientKeyFilename). + WithRootCAs(caFilename). + WithServerName(serverName). + WithStartupTimeout(time.Millisecond * 50). + WithPollInterval(time.Millisecond) +} + +func TestForCert(t *testing.T) { + errNotFound := errdefs.NotFound(errors.New("file not found")) + ctx := context.Background() + + t.Run("ca-not-found", func(t *testing.T) { + target := newRunningTarget() + target.EXPECT().CopyFileFromContainer(anyContext, caFilename).Return(nil, errNotFound) + err := testForTLSCert().WaitUntilReady(ctx, target) + require.EqualError(t, err, context.DeadlineExceeded.Error()) + }) + + t.Run("cert-not-found", func(t *testing.T) { + target := newRunningTarget() + caFile := io.NopCloser(bytes.NewBuffer(caBytes)) + target.EXPECT().CopyFileFromContainer(anyContext, caFilename).Return(caFile, nil) + target.EXPECT().CopyFileFromContainer(anyContext, clientCertFilename).Return(nil, errNotFound) + err := testForTLSCert().WaitUntilReady(ctx, target) + require.EqualError(t, err, context.DeadlineExceeded.Error()) + }) + + t.Run("key-not-found", func(t *testing.T) { + target := newRunningTarget() + caFile := io.NopCloser(bytes.NewBuffer(caBytes)) + certFile := io.NopCloser(bytes.NewBuffer(certBytes)) + target.EXPECT().CopyFileFromContainer(anyContext, caFilename).Return(caFile, nil) + target.EXPECT().CopyFileFromContainer(anyContext, clientCertFilename).Return(certFile, nil) + target.EXPECT().CopyFileFromContainer(anyContext, clientKeyFilename).Return(nil, errNotFound) + err := testForTLSCert().WaitUntilReady(ctx, target) + require.EqualError(t, err, context.DeadlineExceeded.Error()) + }) + + t.Run("valid", func(t *testing.T) { + target := newRunningTarget() + caFile := io.NopCloser(bytes.NewBuffer(caBytes)) + certFile := io.NopCloser(bytes.NewBuffer(certBytes)) + keyFile := io.NopCloser(bytes.NewBuffer(keyBytes)) + target.EXPECT().CopyFileFromContainer(anyContext, caFilename).Return(caFile, nil) + target.EXPECT().CopyFileFromContainer(anyContext, clientCertFilename).Return(certFile, nil) + target.EXPECT().CopyFileFromContainer(anyContext, clientKeyFilename).Return(keyFile, nil) + + certStrategy := testForTLSCert() + err := certStrategy.WaitUntilReady(ctx, target) + require.NoError(t, err) + + pool := x509.NewCertPool() + require.True(t, pool.AppendCertsFromPEM(caBytes)) + cert, err := tls.X509KeyPair(certBytes, keyBytes) + require.NoError(t, err) + got := certStrategy.TLSConfig() + require.Equal(t, serverName, got.ServerName) + require.Equal(t, []tls.Certificate{cert}, got.Certificates) + require.True(t, pool.Equal(got.RootCAs)) + }) +} + +func ExampleForTLSCert() { + ctx := context.Background() + + // waitForTLSCert { + // The file names passed to ForTLSCert are the paths where the files will + // be copied to in the container as detailed by the Dockerfile. + forCert := wait.ForTLSCert("/app/tls.pem", "/app/tls-key.pem"). + WithServerName("testcontainer.go.test") + req := testcontainers.ContainerRequest{ + FromDockerfile: testcontainers.FromDockerfile{ + Context: "testdata/http", + }, + WaitingFor: forCert, + } + // } + + c, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ + ContainerRequest: req, + Started: true, + }) + defer func() { + if err := testcontainers.TerminateContainer(c); err != nil { + log.Printf("failed to terminate container: %s", err) + } + }() + if err != nil { + log.Printf("failed to start container: %s", err) + return + } + + state, err := c.State(ctx) + if err != nil { + log.Printf("failed to get container state: %s", err) + return + } + + fmt.Println(state.Running) + + // waitTLSConfig { + config := forCert.TLSConfig() + // } + fmt.Println(config.ServerName) + fmt.Println(len(config.Certificates)) + + // Output: + // true + // testcontainer.go.test + // 1 +} diff --git a/wait/walk.go b/wait/walk.go new file mode 100644 index 0000000000..4685e50088 --- /dev/null +++ b/wait/walk.go @@ -0,0 +1,74 @@ +package wait + +import ( + "errors" +) + +var ( + // VisitStop is used as a return value from [VisitFunc] to stop the walk. + // It is not returned as an error by any function. + VisitStop = errors.New("stop the walk") + + // VisitRemove is used as a return value from [VisitFunc] to have the current node removed. + // It is not returned as an error by any function. + VisitRemove = errors.New("remove this strategy") +) + +// VisitFunc is a function that visits a strategy node. +// If it returns [VisitStop], the walk stops. +// If it returns [VisitRemove], the current node is removed. +type VisitFunc func(root Strategy) error + +// Walk walks the strategies tree and calls the visit function for each node. +func Walk(root *Strategy, visit VisitFunc) error { + if root == nil { + return errors.New("root strategy is nil") + } + + if err := walk(root, visit); err != nil { + if errors.Is(err, VisitRemove) || errors.Is(err, VisitStop) { + return nil + } + return err + } + + return nil +} + +// walk walks the strategies tree and calls the visit function for each node. +// It returns an error if the visit function returns an error. +func walk(root *Strategy, visit VisitFunc) error { + if *root == nil { + // No strategy. + return nil + } + + // Allow the visit function to customize the behaviour of the walk before visiting the children. + if err := visit(*root); err != nil { + if errors.Is(err, VisitRemove) { + *root = nil + } + + return err + } + + if s, ok := (*root).(*MultiStrategy); ok { + var i int + for range s.Strategies { + if err := walk(&s.Strategies[i], visit); err != nil { + if errors.Is(err, VisitRemove) { + s.Strategies = append(s.Strategies[:i], s.Strategies[i+1:]...) + if errors.Is(err, VisitStop) { + return VisitStop + } + continue + } + + return err + } + i++ + } + } + + return nil +} diff --git a/wait/walk_test.go b/wait/walk_test.go new file mode 100644 index 0000000000..e8f8df2f2b --- /dev/null +++ b/wait/walk_test.go @@ -0,0 +1,127 @@ +package wait_test + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/wait" +) + +func TestWalk(t *testing.T) { + req := testcontainers.ContainerRequest{ + WaitingFor: wait.ForAll( + wait.ForFile("/tmp/file"), + wait.ForHTTP("/health"), + wait.ForAll( + wait.ForFile("/tmp/other"), + ), + ), + } + + t.Run("walk", func(t *testing.T) { + var count int + err := wait.Walk(&req.WaitingFor, func(_ wait.Strategy) error { + count++ + return nil + }) + require.NoError(t, err) + require.Equal(t, 5, count) + }) + + t.Run("stop", func(t *testing.T) { + var count int + err := wait.Walk(&req.WaitingFor, func(_ wait.Strategy) error { + count++ + return wait.VisitStop + }) + require.NoError(t, err) + require.Equal(t, 1, count) + }) + + t.Run("remove", func(t *testing.T) { + // walkRemoveFileStrategy { + var count, matched int + err := wait.Walk(&req.WaitingFor, func(s wait.Strategy) error { + count++ + if _, ok := s.(*wait.FileStrategy); ok { + matched++ + return wait.VisitRemove + } + + return nil + }) + // } + require.NoError(t, err) + require.Equal(t, 5, count) + require.Equal(t, 2, matched) + + count = 0 + matched = 0 + err = wait.Walk(&req.WaitingFor, func(s wait.Strategy) error { + count++ + if _, ok := s.(*wait.FileStrategy); ok { + matched++ + } + return nil + }) + require.NoError(t, err) + require.Equal(t, 3, count) + require.Zero(t, matched) + }) + + t.Run("remove-stop", func(t *testing.T) { + req := testcontainers.ContainerRequest{ + WaitingFor: wait.ForAll( + wait.ForFile("/tmp/file"), + wait.ForHTTP("/health"), + ), + } + var count int + err := wait.Walk(&req.WaitingFor, func(_ wait.Strategy) error { + count++ + return errors.Join(wait.VisitRemove, wait.VisitStop) + }) + require.NoError(t, err) + require.Equal(t, 1, count) + require.Nil(t, req.WaitingFor) + }) + + t.Run("nil-root", func(t *testing.T) { + err := wait.Walk(nil, func(_ wait.Strategy) error { + return nil + }) + require.EqualError(t, err, "root strategy is nil") + }) + + t.Run("direct-single", func(t *testing.T) { + req := testcontainers.ContainerRequest{ + WaitingFor: wait.ForFile("/tmp/file"), + } + requireVisits(t, req, 1) + }) + + t.Run("for-all-single", func(t *testing.T) { + req := testcontainers.ContainerRequest{ + WaitingFor: wait.ForAll( + wait.ForFile("/tmp/file"), + ), + } + requireVisits(t, req, 2) + }) +} + +// requireVisits validates the number of visits for a given request. +func requireVisits(t *testing.T, req testcontainers.ContainerRequest, expected int) { + t.Helper() + + var count int + err := wait.Walk(&req.WaitingFor, func(_ wait.Strategy) error { + count++ + return nil + }) + require.NoError(t, err) + require.Equal(t, expected, count) +}