From c58631bf667c03061ec5a10745fd9048655f9e7a Mon Sep 17 00:00:00 2001 From: Steven Hartland Date: Fri, 1 Nov 2024 14:37:34 +0000 Subject: [PATCH] feat!(cockroachdb): simplify connection handling Simplify the connection handling in cockroachdb so that ConnectionString can be used without the user doing extra work to handle TLS if enabled. Deprecate TLSConfig which is no longer needed separately. BREAKING_CHANGE: This now returns a registered connection string so is no longer compatible with pgx.ParseConfig, use ConnectionConfig for this use case instead. --- modules/cockroachdb/certs.go | 19 ++++ modules/cockroachdb/cockroachdb.go | 118 ++++++------------------ modules/cockroachdb/cockroachdb_test.go | 52 +++-------- modules/cockroachdb/examples_test.go | 26 ++++-- modules/cockroachdb/options.go | 92 +++++++++++++++++- 5 files changed, 168 insertions(+), 139 deletions(-) diff --git a/modules/cockroachdb/certs.go b/modules/cockroachdb/certs.go index afa12fcd1a..61280b4db9 100644 --- a/modules/cockroachdb/certs.go +++ b/modules/cockroachdb/certs.go @@ -1,8 +1,10 @@ package cockroachdb import ( + "crypto/tls" "crypto/x509" "errors" + "fmt" "net" "time" @@ -65,3 +67,20 @@ func NewTLSConfig() (*TLSConfig, error) { ClientKey: clientCert.KeyBytes, }, nil } + +// tlsConfig returns a [tls.Config] for options. +func (c *TLSConfig) tlsConfig() (*tls.Config, error) { + keyPair, err := tls.X509KeyPair(c.ClientCert, c.ClientKey) + if err != nil { + return nil, fmt.Errorf("x509 key pair: %w", err) + } + + certPool := x509.NewCertPool() + certPool.AddCert(c.CACert) + + return &tls.Config{ + RootCAs: certPool, + Certificates: []tls.Certificate{keyPair}, + ServerName: "localhost", + }, nil +} diff --git a/modules/cockroachdb/cockroachdb.go b/modules/cockroachdb/cockroachdb.go index 884d8f076f..98c5ebf149 100644 --- a/modules/cockroachdb/cockroachdb.go +++ b/modules/cockroachdb/cockroachdb.go @@ -3,23 +3,20 @@ package cockroachdb import ( "context" "crypto/tls" - "crypto/x509" "database/sql" "encoding/pem" "errors" "fmt" - "net" - "net/url" "path/filepath" "github.com/docker/go-connections/nat" "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/stdlib" "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/wait" ) +// ErrTLSNotEnabled is returned when trying to get a TLS config from a container that does not have TLS enabled. var ErrTLSNotEnabled = errors.New("tls not enabled") const ( @@ -40,7 +37,9 @@ type CockroachDBContainer struct { opts options } -// MustConnectionString panics if the address cannot be determined. +// MustConnectionString returns a connection string to open a new connection to CockroachDB +// as described by [CockroachDBContainer.ConnectionString]. +// It panics if an error occurs. func (c *CockroachDBContainer) MustConnectionString(ctx context.Context) string { addr, err := c.ConnectionString(ctx) if err != nil { @@ -49,27 +48,33 @@ func (c *CockroachDBContainer) MustConnectionString(ctx context.Context) string return addr } -// ConnectionString returns the dial address to open a new connection to CockroachDB. +// ConnectionString returns a connection string to open a new connection to CockroachDB. +// The returned string is suitable for use by [sql.Open] but is not be compatible with +// [pgx.ParseConfig], so if you want to call [pgx.ConnectConfig] use the +// [CockroachDBContainer.ConnectionConfig] method instead. func (c *CockroachDBContainer) ConnectionString(ctx context.Context) (string, error) { - port, err := c.MappedPort(ctx, defaultSQLPort) - if err != nil { - return "", err - } - - host, err := c.Host(ctx) - if err != nil { - return "", err - } + return c.opts.containerConnString(ctx, c.Container) +} - return connString(c.opts, host, port), nil +// ConnectionConfig returns a [pgx.ConnConfig] for the CockroachDB container. +// This can be passed to [pgx.ConnectConfig] to open a new connection. +func (c *CockroachDBContainer) ConnectionConfig(ctx context.Context) (*pgx.ConnConfig, error) { + return c.opts.containerConnConfig(ctx, c.Container) } // TLSConfig returns config necessary to connect to CockroachDB over TLS. +// +// Deprecated: use [CockroachDBContainer.ConnectionConfig] or +// [CockroachDBContainer.ConnectionConfig] instead. func (c *CockroachDBContainer) TLSConfig() (*tls.Config, error) { - return connTLS(c.opts) + if c.opts.TLS == nil { + return nil, ErrTLSNotEnabled + } + + return c.opts.TLS.tlsConfig() } -// Deprecated: use Run instead +// Deprecated: use Run instead. // RunContainer creates an instance of the CockroachDB container type func RunContainer(ctx context.Context, opts ...testcontainers.ContainerCustomizer) (*CockroachDBContainer, error) { return Run(ctx, "cockroachdb/cockroach:latest-v23.1", opts...) @@ -178,29 +183,12 @@ func addEnvs(req *testcontainers.GenericContainerRequest, opts options) error { } func addWaitingFor(req *testcontainers.GenericContainerRequest, opts options) error { - var tlsConfig *tls.Config - if opts.TLS != nil { - cfg, err := connTLS(opts) - if err != nil { - return err - } - tlsConfig = cfg - } - sqlWait := wait.ForSQL(defaultSQLPort, "pgx/v5", func(host string, port nat.Port) string { - connStr := connString(opts, host, port) - if tlsConfig == nil { - return connStr - } - - // register TLS config with pgx driver - connCfg, err := pgx.ParseConfig(connStr) + connStr, err := opts.connString(host, port) if err != nil { panic(err) } - connCfg.TLSConfig = tlsConfig - - return stdlib.RegisterConnConfig(connCfg) + return connStr }) defaultStrategy := wait.ForAll( wait.ForHTTP("/health").WithPort(defaultAdminPort), @@ -246,17 +234,12 @@ func runStatements(ctx context.Context, container testcontainers.Container, opts return nil } - port, err := container.MappedPort(ctx, defaultSQLPort) - if err != nil { - return fmt.Errorf("mapped port: %w", err) - } - - host, err := container.Host(ctx) + connStr, err := opts.containerConnString(ctx, container) if err != nil { - return fmt.Errorf("host: %w", err) + return fmt.Errorf("connection string: %w", err) } - db, err := sql.Open("pgx/v5", connString(opts, host, port)) + db, err := sql.Open("pgx/v5", connStr) if err != nil { return fmt.Errorf("sql.Open: %w", err) } @@ -275,48 +258,3 @@ func runStatements(ctx context.Context, container testcontainers.Container, opts return nil } - -func connString(opts options, host string, port nat.Port) string { - user := url.User(opts.User) - if opts.Password != "" { - user = url.UserPassword(opts.User, opts.Password) - } - - sslMode := "disable" - if opts.TLS != nil { - sslMode = "verify-full" - } - params := url.Values{ - "sslmode": []string{sslMode}, - } - - u := url.URL{ - Scheme: "postgres", - User: user, - Host: net.JoinHostPort(host, port.Port()), - Path: opts.Database, - RawQuery: params.Encode(), - } - - return u.String() -} - -func connTLS(opts options) (*tls.Config, error) { - if opts.TLS == nil { - return nil, ErrTLSNotEnabled - } - - keyPair, err := tls.X509KeyPair(opts.TLS.ClientCert, opts.TLS.ClientKey) - if err != nil { - return nil, err - } - - certPool := x509.NewCertPool() - certPool.AddCert(opts.TLS.CACert) - - return &tls.Config{ - RootCAs: certPool, - Certificates: []tls.Certificate{keyPair}, - ServerName: "localhost", - }, nil -} diff --git a/modules/cockroachdb/cockroachdb_test.go b/modules/cockroachdb/cockroachdb_test.go index 45df7909bb..d3d05d1d86 100644 --- a/modules/cockroachdb/cockroachdb_test.go +++ b/modules/cockroachdb/cockroachdb_test.go @@ -2,9 +2,6 @@ package cockroachdb_test import ( "context" - "errors" - "net/url" - "strings" "testing" "time" @@ -18,14 +15,11 @@ import ( ) func TestCockroach_Insecure(t *testing.T) { - suite.Run(t, &AuthNSuite{ - url: "postgres://root@localhost:xxxxx/defaultdb?sslmode=disable", - }) + suite.Run(t, &AuthNSuite{}) } func TestCockroach_NotRoot(t *testing.T) { suite.Run(t, &AuthNSuite{ - url: "postgres://test@localhost:xxxxx/defaultdb?sslmode=disable", opts: []testcontainers.ContainerCustomizer{ cockroachdb.WithUser("test"), // Do not run the default statements as the user used on this test is @@ -37,7 +31,6 @@ func TestCockroach_NotRoot(t *testing.T) { func TestCockroach_Password(t *testing.T) { suite.Run(t, &AuthNSuite{ - url: "postgres://foo:bar@localhost:xxxxx/defaultdb?sslmode=disable", opts: []testcontainers.ContainerCustomizer{ cockroachdb.WithUser("foo"), cockroachdb.WithPassword("bar"), @@ -53,19 +46,26 @@ func TestCockroach_TLS(t *testing.T) { require.NoError(t, err) suite.Run(t, &AuthNSuite{ - url: "postgres://root@localhost:xxxxx/defaultdb?sslmode=verify-full", opts: []testcontainers.ContainerCustomizer{ cockroachdb.WithTLS(tlsCfg), - // Do not run the default statements as the user used on this test is - // lacking the needed MODIFYCLUSTERSETTING privilege to run them. - cockroachdb.WithStatements(), }, }) } +func TestTLS(t *testing.T) { + tlsCfg, err := cockroachdb.NewTLSConfig() + require.NoError(t, err) + + ctx := context.Background() + + ctr, err := cockroachdb.Run(ctx, "cockroachdb/cockroach:latest-v23.1", cockroachdb.WithTLS(tlsCfg)) + testcontainers.CleanupContainer(t, ctr) + require.NoError(t, err) + require.NotNil(t, ctr) +} + type AuthNSuite struct { suite.Suite - url string opts []testcontainers.ContainerCustomizer } @@ -75,11 +75,6 @@ func (suite *AuthNSuite) TestConnectionString() { ctr, err := cockroachdb.Run(ctx, "cockroachdb/cockroach:latest-v23.1", suite.opts...) testcontainers.CleanupContainer(suite.T(), ctr) suite.Require().NoError(err) - - connStr, err := removePort(ctr.MustConnectionString(ctx)) - suite.Require().NoError(err) - - suite.Equal(suite.url, connStr) } func (suite *AuthNSuite) TestPing() { @@ -203,29 +198,10 @@ func (suite *AuthNSuite) TestWithWaitStrategyAndDeadline() { } func conn(ctx context.Context, container *cockroachdb.CockroachDBContainer) (*pgx.Conn, error) { - cfg, err := pgx.ParseConfig(container.MustConnectionString(ctx)) + cfg, err := container.ConnectionConfig(ctx) if err != nil { return nil, err } - tlsCfg, err := container.TLSConfig() - switch { - case err != nil: - if !errors.Is(err, cockroachdb.ErrTLSNotEnabled) { - return nil, err - } - default: - // apply TLS config - cfg.TLSConfig = tlsCfg - } - return pgx.ConnectConfig(ctx, cfg) } - -func removePort(s string) (string, error) { - u, err := url.Parse(s) - if err != nil { - return "", err - } - return strings.Replace(s, ":"+u.Port(), ":xxxxx", 1), nil -} diff --git a/modules/cockroachdb/examples_test.go b/modules/cockroachdb/examples_test.go index 9a8fb12881..4cc14e8b7b 100644 --- a/modules/cockroachdb/examples_test.go +++ b/modules/cockroachdb/examples_test.go @@ -5,7 +5,8 @@ import ( "database/sql" "fmt" "log" - "net/url" + + "github.com/jackc/pgx/v5" "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/modules/cockroachdb" @@ -34,25 +35,34 @@ func ExampleRun() { } fmt.Println(state.Running) - addr, err := cockroachdbContainer.ConnectionString(ctx) + cfg, err := cockroachdbContainer.ConnectionConfig(ctx) if err != nil { log.Printf("failed to get connection string: %s", err) return } - u, err := url.Parse(addr) + + conn, err := pgx.ConnectConfig(ctx, cfg) if err != nil { - log.Printf("failed to parse connection string: %s", err) + log.Printf("failed to connect: %s", err) + return + } + + defer func() { + if err := conn.Close(ctx); err != nil { + log.Printf("failed to close connection: %s", err) + } + }() + + if err = conn.Ping(ctx); err != nil { + log.Printf("failed to ping: %s", err) return } - u.Host = fmt.Sprintf("%s:%s", u.Hostname(), "xxx") - fmt.Println(u.String()) // Output: // true - // postgres://root@localhost:xxx/defaultdb?sslmode=disable } -func ExampleRun_withRecommendedSettings() { +func ExampleRun_withCustomStatements() { ctx := context.Background() cockroachdbContainer, err := cockroachdb.Run(ctx, "cockroachdb/cockroach:latest-v23.1", cockroachdb.WithStatements(cockroachdb.DefaultStatements...)) diff --git a/modules/cockroachdb/options.go b/modules/cockroachdb/options.go index eba101834e..81a1843f71 100644 --- a/modules/cockroachdb/options.go +++ b/modules/cockroachdb/options.go @@ -1,6 +1,17 @@ package cockroachdb -import "github.com/testcontainers/testcontainers-go" +import ( + "context" + "fmt" + "net" + "net/url" + + "github.com/docker/go-connections/nat" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/stdlib" + + "github.com/testcontainers/testcontainers-go" +) type options struct { Database string @@ -11,6 +22,81 @@ type options struct { Statements []string } +// containerConnConfig returns the [pgx.ConnConfig] for the given container and options. +func (opts options) containerConnConfig(ctx context.Context, container testcontainers.Container) (*pgx.ConnConfig, error) { + port, err := container.MappedPort(ctx, defaultSQLPort) + if err != nil { + return nil, fmt.Errorf("mapped port: %w", err) + } + + host, err := container.Host(ctx) + if err != nil { + return nil, fmt.Errorf("host: %w", err) + } + + return opts.connConfig(host, port) +} + +// containerConnString returns the connection string for the given container and options. +func (opts options) containerConnString(ctx context.Context, container testcontainers.Container) (string, error) { + cfg, err := opts.containerConnConfig(ctx, container) + if err != nil { + return "", fmt.Errorf("container connection config: %w", err) + } + + return stdlib.RegisterConnConfig(cfg), nil +} + +// connString returns a connection string for the given host, port and options. +func (opts options) connString(host string, port nat.Port) (string, error) { + cfg, err := opts.connConfig(host, port) + if err != nil { + return "", fmt.Errorf("connection config: %w", err) + } + + return stdlib.RegisterConnConfig(cfg), nil +} + +// connConfig returns a [pgx.ConnConfig] for the given host, port and options. +func (opts options) connConfig(host string, port nat.Port) (*pgx.ConnConfig, error) { + user := url.User(opts.User) + if opts.Password != "" { + user = url.UserPassword(opts.User, opts.Password) + } + + sslMode := "disable" + if opts.TLS != nil { + sslMode = "require" // We can't use "verify-full" as it might be a self signed cert. + } + params := url.Values{ + "sslmode": []string{sslMode}, + } + + u := url.URL{ + Scheme: "postgres", + User: user, + Host: net.JoinHostPort(host, port.Port()), + Path: opts.Database, + RawQuery: params.Encode(), + } + + cfg, err := pgx.ParseConfig(u.String()) + if err != nil { + return nil, fmt.Errorf("parse config: %w", err) + } + + if opts.TLS != nil { + tlsCfg, err := opts.TLS.tlsConfig() + if err != nil { + return nil, fmt.Errorf("tls config: %w", err) + } + + cfg.TLSConfig = tlsCfg + } + + return cfg, nil +} + func defaultOptions() options { return options{ User: defaultUser, @@ -85,8 +171,8 @@ var DefaultStatements = []string{ } // WithStatements sets the statements to run on the CockroachDB cluster once the container is ready. -// This, in combination with DefaultStatements, can be used to configure the cluster with the settings -// recommended by Cockroach Labs. +// By default, the container will run the statements in [DefaultStatements] as recommended by +// Cockroach Labs however that is not always possible due to the user not having the required privileges. func WithStatements(statements ...string) Option { return func(o *options) { o.Statements = statements