diff --git a/modules/cockroachdb/certs.go b/modules/cockroachdb/certs.go index afa12fcd1a..61280b4db9 100644 --- a/modules/cockroachdb/certs.go +++ b/modules/cockroachdb/certs.go @@ -1,8 +1,10 @@ package cockroachdb import ( + "crypto/tls" "crypto/x509" "errors" + "fmt" "net" "time" @@ -65,3 +67,20 @@ func NewTLSConfig() (*TLSConfig, error) { ClientKey: clientCert.KeyBytes, }, nil } + +// tlsConfig returns a [tls.Config] for options. +func (c *TLSConfig) tlsConfig() (*tls.Config, error) { + keyPair, err := tls.X509KeyPair(c.ClientCert, c.ClientKey) + if err != nil { + return nil, fmt.Errorf("x509 key pair: %w", err) + } + + certPool := x509.NewCertPool() + certPool.AddCert(c.CACert) + + return &tls.Config{ + RootCAs: certPool, + Certificates: []tls.Certificate{keyPair}, + ServerName: "localhost", + }, nil +} diff --git a/modules/cockroachdb/cockroachdb.go b/modules/cockroachdb/cockroachdb.go index 092efa4e2a..98c5ebf149 100644 --- a/modules/cockroachdb/cockroachdb.go +++ b/modules/cockroachdb/cockroachdb.go @@ -3,22 +3,20 @@ package cockroachdb import ( "context" "crypto/tls" - "crypto/x509" + "database/sql" "encoding/pem" "errors" "fmt" - "net" - "net/url" "path/filepath" "github.com/docker/go-connections/nat" "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/stdlib" "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/wait" ) +// ErrTLSNotEnabled is returned when trying to get a TLS config from a container that does not have TLS enabled. var ErrTLSNotEnabled = errors.New("tls not enabled") const ( @@ -39,7 +37,9 @@ type CockroachDBContainer struct { opts options } -// MustConnectionString panics if the address cannot be determined. +// MustConnectionString returns a connection string to open a new connection to CockroachDB +// as described by [CockroachDBContainer.ConnectionString]. +// It panics if an error occurs. func (c *CockroachDBContainer) MustConnectionString(ctx context.Context) string { addr, err := c.ConnectionString(ctx) if err != nil { @@ -48,27 +48,33 @@ func (c *CockroachDBContainer) MustConnectionString(ctx context.Context) string return addr } -// ConnectionString returns the dial address to open a new connection to CockroachDB. +// ConnectionString returns a connection string to open a new connection to CockroachDB. +// The returned string is suitable for use by [sql.Open] but is not be compatible with +// [pgx.ParseConfig], so if you want to call [pgx.ConnectConfig] use the +// [CockroachDBContainer.ConnectionConfig] method instead. func (c *CockroachDBContainer) ConnectionString(ctx context.Context) (string, error) { - port, err := c.MappedPort(ctx, defaultSQLPort) - if err != nil { - return "", err - } - - host, err := c.Host(ctx) - if err != nil { - return "", err - } + return c.opts.containerConnString(ctx, c.Container) +} - return connString(c.opts, host, port), nil +// ConnectionConfig returns a [pgx.ConnConfig] for the CockroachDB container. +// This can be passed to [pgx.ConnectConfig] to open a new connection. +func (c *CockroachDBContainer) ConnectionConfig(ctx context.Context) (*pgx.ConnConfig, error) { + return c.opts.containerConnConfig(ctx, c.Container) } // TLSConfig returns config necessary to connect to CockroachDB over TLS. +// +// Deprecated: use [CockroachDBContainer.ConnectionConfig] or +// [CockroachDBContainer.ConnectionConfig] instead. func (c *CockroachDBContainer) TLSConfig() (*tls.Config, error) { - return connTLS(c.opts) + if c.opts.TLS == nil { + return nil, ErrTLSNotEnabled + } + + return c.opts.TLS.tlsConfig() } -// Deprecated: use Run instead +// Deprecated: use Run instead. // RunContainer creates an instance of the CockroachDB container type func RunContainer(ctx context.Context, opts ...testcontainers.ContainerCustomizer) (*CockroachDBContainer, error) { return Run(ctx, "cockroachdb/cockroach:latest-v23.1", opts...) @@ -91,6 +97,11 @@ func Run(ctx context.Context, img string, opts ...testcontainers.ContainerCustom return addTLS(ctx, container, o) }, }, + PostReadies: []testcontainers.ContainerHook{ + func(ctx context.Context, container testcontainers.Container) error { + return runStatements(ctx, container, o) + }, + }, }, }, }, @@ -172,29 +183,12 @@ func addEnvs(req *testcontainers.GenericContainerRequest, opts options) error { } func addWaitingFor(req *testcontainers.GenericContainerRequest, opts options) error { - var tlsConfig *tls.Config - if opts.TLS != nil { - cfg, err := connTLS(opts) - if err != nil { - return err - } - tlsConfig = cfg - } - sqlWait := wait.ForSQL(defaultSQLPort, "pgx/v5", func(host string, port nat.Port) string { - connStr := connString(opts, host, port) - if tlsConfig == nil { - return connStr - } - - // register TLS config with pgx driver - connCfg, err := pgx.ParseConfig(connStr) + connStr, err := opts.connString(host, port) if err != nil { panic(err) } - connCfg.TLSConfig = tlsConfig - - return stdlib.RegisterConnConfig(connCfg) + return connStr }) defaultStrategy := wait.ForAll( wait.ForHTTP("/health").WithPort(defaultAdminPort), @@ -234,47 +228,33 @@ func addTLS(ctx context.Context, container testcontainers.Container, opts option return nil } -func connString(opts options, host string, port nat.Port) string { - user := url.User(opts.User) - if opts.Password != "" { - user = url.UserPassword(opts.User, opts.Password) - } - - sslMode := "disable" - if opts.TLS != nil { - sslMode = "verify-full" - } - params := url.Values{ - "sslmode": []string{sslMode}, - } - - u := url.URL{ - Scheme: "postgres", - User: user, - Host: net.JoinHostPort(host, port.Port()), - Path: opts.Database, - RawQuery: params.Encode(), +// runStatements runs the configured statements against the CockroachDB container. +func runStatements(ctx context.Context, container testcontainers.Container, opts options) (err error) { + if len(opts.Statements) == 0 { + return nil } - return u.String() -} - -func connTLS(opts options) (*tls.Config, error) { - if opts.TLS == nil { - return nil, ErrTLSNotEnabled + connStr, err := opts.containerConnString(ctx, container) + if err != nil { + return fmt.Errorf("connection string: %w", err) } - keyPair, err := tls.X509KeyPair(opts.TLS.ClientCert, opts.TLS.ClientKey) + db, err := sql.Open("pgx/v5", connStr) if err != nil { - return nil, err + return fmt.Errorf("sql.Open: %w", err) } + defer func() { + cerr := db.Close() + if err == nil { + err = cerr + } + }() - certPool := x509.NewCertPool() - certPool.AddCert(opts.TLS.CACert) + for _, stmt := range opts.Statements { + if _, err = db.Exec(stmt); err != nil { + return fmt.Errorf("db.Exec: %w", err) + } + } - return &tls.Config{ - RootCAs: certPool, - Certificates: []tls.Certificate{keyPair}, - ServerName: "localhost", - }, nil + return nil } diff --git a/modules/cockroachdb/cockroachdb_test.go b/modules/cockroachdb/cockroachdb_test.go index cc355e9168..d3d05d1d86 100644 --- a/modules/cockroachdb/cockroachdb_test.go +++ b/modules/cockroachdb/cockroachdb_test.go @@ -2,9 +2,6 @@ package cockroachdb_test import ( "context" - "errors" - "net/url" - "strings" "testing" "time" @@ -18,26 +15,28 @@ import ( ) func TestCockroach_Insecure(t *testing.T) { - suite.Run(t, &AuthNSuite{ - url: "postgres://root@localhost:xxxxx/defaultdb?sslmode=disable", - }) + suite.Run(t, &AuthNSuite{}) } func TestCockroach_NotRoot(t *testing.T) { suite.Run(t, &AuthNSuite{ - url: "postgres://test@localhost:xxxxx/defaultdb?sslmode=disable", opts: []testcontainers.ContainerCustomizer{ cockroachdb.WithUser("test"), + // Do not run the default statements as the user used on this test is + // lacking the needed MODIFYCLUSTERSETTING privilege to run them. + cockroachdb.WithStatements(), }, }) } func TestCockroach_Password(t *testing.T) { suite.Run(t, &AuthNSuite{ - url: "postgres://foo:bar@localhost:xxxxx/defaultdb?sslmode=disable", opts: []testcontainers.ContainerCustomizer{ cockroachdb.WithUser("foo"), cockroachdb.WithPassword("bar"), + // Do not run the default statements as the user used on this test is + // lacking the needed MODIFYCLUSTERSETTING privilege to run them. + cockroachdb.WithStatements(), }, }) } @@ -47,16 +46,26 @@ func TestCockroach_TLS(t *testing.T) { require.NoError(t, err) suite.Run(t, &AuthNSuite{ - url: "postgres://root@localhost:xxxxx/defaultdb?sslmode=verify-full", opts: []testcontainers.ContainerCustomizer{ cockroachdb.WithTLS(tlsCfg), }, }) } +func TestTLS(t *testing.T) { + tlsCfg, err := cockroachdb.NewTLSConfig() + require.NoError(t, err) + + ctx := context.Background() + + ctr, err := cockroachdb.Run(ctx, "cockroachdb/cockroach:latest-v23.1", cockroachdb.WithTLS(tlsCfg)) + testcontainers.CleanupContainer(t, ctr) + require.NoError(t, err) + require.NotNil(t, ctr) +} + type AuthNSuite struct { suite.Suite - url string opts []testcontainers.ContainerCustomizer } @@ -66,11 +75,6 @@ func (suite *AuthNSuite) TestConnectionString() { ctr, err := cockroachdb.Run(ctx, "cockroachdb/cockroach:latest-v23.1", suite.opts...) testcontainers.CleanupContainer(suite.T(), ctr) suite.Require().NoError(err) - - connStr, err := removePort(ctr.MustConnectionString(ctx)) - suite.Require().NoError(err) - - suite.Equal(suite.url, connStr) } func (suite *AuthNSuite) TestPing() { @@ -194,29 +198,10 @@ func (suite *AuthNSuite) TestWithWaitStrategyAndDeadline() { } func conn(ctx context.Context, container *cockroachdb.CockroachDBContainer) (*pgx.Conn, error) { - cfg, err := pgx.ParseConfig(container.MustConnectionString(ctx)) + cfg, err := container.ConnectionConfig(ctx) if err != nil { return nil, err } - tlsCfg, err := container.TLSConfig() - switch { - case err != nil: - if !errors.Is(err, cockroachdb.ErrTLSNotEnabled) { - return nil, err - } - default: - // apply TLS config - cfg.TLSConfig = tlsCfg - } - return pgx.ConnectConfig(ctx, cfg) } - -func removePort(s string) (string, error) { - u, err := url.Parse(s) - if err != nil { - return "", err - } - return strings.Replace(s, ":"+u.Port(), ":xxxxx", 1), nil -} diff --git a/modules/cockroachdb/examples_test.go b/modules/cockroachdb/examples_test.go index c06c97596b..4cc14e8b7b 100644 --- a/modules/cockroachdb/examples_test.go +++ b/modules/cockroachdb/examples_test.go @@ -2,9 +2,11 @@ package cockroachdb_test import ( "context" + "database/sql" "fmt" "log" - "net/url" + + "github.com/jackc/pgx/v5" "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/modules/cockroachdb" @@ -33,20 +35,87 @@ func ExampleRun() { } fmt.Println(state.Running) + cfg, err := cockroachdbContainer.ConnectionConfig(ctx) + if err != nil { + log.Printf("failed to get connection string: %s", err) + return + } + + conn, err := pgx.ConnectConfig(ctx, cfg) + if err != nil { + log.Printf("failed to connect: %s", err) + return + } + + defer func() { + if err := conn.Close(ctx); err != nil { + log.Printf("failed to close connection: %s", err) + } + }() + + if err = conn.Ping(ctx); err != nil { + log.Printf("failed to ping: %s", err) + return + } + + // Output: + // true +} + +func ExampleRun_withCustomStatements() { + ctx := context.Background() + + cockroachdbContainer, err := cockroachdb.Run(ctx, "cockroachdb/cockroach:latest-v23.1", cockroachdb.WithStatements(cockroachdb.DefaultStatements...)) + defer func() { + if err := testcontainers.TerminateContainer(cockroachdbContainer); err != nil { + log.Printf("failed to terminate container: %s", err) + } + }() + if err != nil { + log.Printf("failed to start container: %s", err) + return + } + + state, err := cockroachdbContainer.State(ctx) + if err != nil { + log.Printf("failed to get container state: %s", err) + return + } + fmt.Println(state.Running) + addr, err := cockroachdbContainer.ConnectionString(ctx) if err != nil { log.Printf("failed to get connection string: %s", err) return } - u, err := url.Parse(addr) + + db, err := sql.Open("pgx/v5", addr) if err != nil { - log.Printf("failed to parse connection string: %s", err) + log.Printf("failed to open connection: %s", err) + return + } + defer func() { + if err := db.Close(); err != nil { + log.Printf("failed to close connection: %s", err) + } + }() + + var queueInterval string + if err := db.QueryRow("SHOW CLUSTER SETTING kv.range_merge.queue_interval").Scan(&queueInterval); err != nil { + log.Printf("failed to scan row: %s", err) + return + } + fmt.Println(queueInterval) + + var statsCollectionEnabled bool + if err := db.QueryRow("SHOW CLUSTER SETTING sql.stats.automatic_collection.enabled").Scan(&statsCollectionEnabled); err != nil { + log.Printf("failed to scan row: %s", err) return } - u.Host = fmt.Sprintf("%s:%s", u.Hostname(), "xxx") - fmt.Println(u.String()) + fmt.Println(statsCollectionEnabled) // Output: // true - // postgres://root@localhost:xxx/defaultdb?sslmode=disable + // 00:00:00.05 + // false } diff --git a/modules/cockroachdb/options.go b/modules/cockroachdb/options.go index a2211d77e7..81a1843f71 100644 --- a/modules/cockroachdb/options.go +++ b/modules/cockroachdb/options.go @@ -1,21 +1,109 @@ package cockroachdb -import "github.com/testcontainers/testcontainers-go" +import ( + "context" + "fmt" + "net" + "net/url" + + "github.com/docker/go-connections/nat" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/stdlib" + + "github.com/testcontainers/testcontainers-go" +) type options struct { - Database string - User string - Password string - StoreSize string - TLS *TLSConfig + Database string + User string + Password string + StoreSize string + TLS *TLSConfig + Statements []string +} + +// containerConnConfig returns the [pgx.ConnConfig] for the given container and options. +func (opts options) containerConnConfig(ctx context.Context, container testcontainers.Container) (*pgx.ConnConfig, error) { + port, err := container.MappedPort(ctx, defaultSQLPort) + if err != nil { + return nil, fmt.Errorf("mapped port: %w", err) + } + + host, err := container.Host(ctx) + if err != nil { + return nil, fmt.Errorf("host: %w", err) + } + + return opts.connConfig(host, port) +} + +// containerConnString returns the connection string for the given container and options. +func (opts options) containerConnString(ctx context.Context, container testcontainers.Container) (string, error) { + cfg, err := opts.containerConnConfig(ctx, container) + if err != nil { + return "", fmt.Errorf("container connection config: %w", err) + } + + return stdlib.RegisterConnConfig(cfg), nil +} + +// connString returns a connection string for the given host, port and options. +func (opts options) connString(host string, port nat.Port) (string, error) { + cfg, err := opts.connConfig(host, port) + if err != nil { + return "", fmt.Errorf("connection config: %w", err) + } + + return stdlib.RegisterConnConfig(cfg), nil +} + +// connConfig returns a [pgx.ConnConfig] for the given host, port and options. +func (opts options) connConfig(host string, port nat.Port) (*pgx.ConnConfig, error) { + user := url.User(opts.User) + if opts.Password != "" { + user = url.UserPassword(opts.User, opts.Password) + } + + sslMode := "disable" + if opts.TLS != nil { + sslMode = "require" // We can't use "verify-full" as it might be a self signed cert. + } + params := url.Values{ + "sslmode": []string{sslMode}, + } + + u := url.URL{ + Scheme: "postgres", + User: user, + Host: net.JoinHostPort(host, port.Port()), + Path: opts.Database, + RawQuery: params.Encode(), + } + + cfg, err := pgx.ParseConfig(u.String()) + if err != nil { + return nil, fmt.Errorf("parse config: %w", err) + } + + if opts.TLS != nil { + tlsCfg, err := opts.TLS.tlsConfig() + if err != nil { + return nil, fmt.Errorf("tls config: %w", err) + } + + cfg.TLSConfig = tlsCfg + } + + return cfg, nil } func defaultOptions() options { return options{ - User: defaultUser, - Password: defaultPassword, - Database: defaultDatabase, - StoreSize: defaultStoreSize, + User: defaultUser, + Password: defaultPassword, + Database: defaultDatabase, + StoreSize: defaultStoreSize, + Statements: DefaultStatements, } } @@ -67,3 +155,26 @@ func WithTLS(cfg *TLSConfig) Option { o.TLS = cfg } } + +// DefaultStatements are the settings recommended by Cockroach Labs for testing clusters. +// Note that to use these defaults the user needs to have MODIFYCLUSTERSETTING privilege. +// See https://www.cockroachlabs.com/docs/stable/local-testing for more information. +var DefaultStatements = []string{ + "SET CLUSTER SETTING kv.range_merge.queue_interval = '50ms'", + "SET CLUSTER SETTING jobs.registry.interval.gc = '30s'", + "SET CLUSTER SETTING jobs.registry.interval.cancel = '180s'", + "SET CLUSTER SETTING jobs.retention_time = '15s'", + "SET CLUSTER SETTING sql.stats.automatic_collection.enabled = false", + "SET CLUSTER SETTING kv.range_split.by_load_merge_delay = '5s'", + `ALTER RANGE default CONFIGURE ZONE USING "gc.ttlseconds" = 600`, + `ALTER DATABASE system CONFIGURE ZONE USING "gc.ttlseconds" = 600`, +} + +// WithStatements sets the statements to run on the CockroachDB cluster once the container is ready. +// By default, the container will run the statements in [DefaultStatements] as recommended by +// Cockroach Labs however that is not always possible due to the user not having the required privileges. +func WithStatements(statements ...string) Option { + return func(o *options) { + o.Statements = statements + } +}