Skip to content

Commit

Permalink
feat!(cockroachdb): simplify connection handling
Browse files Browse the repository at this point in the history
Simplify the connection handling in cockroachdb so that ConnectionString
can be used without the user doing extra work to handle TLS if enabled.

Deprecate TLSConfig which is no longer needed separately.

BREAKING_CHANGE: This now returns a registered connection string so is
no longer compatible with pgx.ParseConfig, use ConnectionConfig for this
use case instead.
  • Loading branch information
stevenh committed Nov 1, 2024
1 parent ff63c4c commit f704010
Show file tree
Hide file tree
Showing 5 changed files with 163 additions and 137 deletions.
18 changes: 18 additions & 0 deletions modules/cockroachdb/certs.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cockroachdb

import (
"crypto/tls"
"crypto/x509"
"fmt"
"net"
Expand Down Expand Up @@ -65,3 +66,20 @@ func NewTLSConfig() (*TLSConfig, error) {
ClientKey: clientCert.KeyBytes,
}, nil
}

// tlsConfig returns a [tls.Config] for options.
func (c *TLSConfig) tlsConfig() (*tls.Config, error) {
keyPair, err := tls.X509KeyPair(c.ClientCert, c.ClientKey)
if err != nil {
return nil, fmt.Errorf("x509 key pair: %w", err)
}

certPool := x509.NewCertPool()
certPool.AddCert(c.CACert)

return &tls.Config{
RootCAs: certPool,
Certificates: []tls.Certificate{keyPair},
ServerName: "localhost",
}, nil
}
118 changes: 28 additions & 90 deletions modules/cockroachdb/cockroachdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,19 @@ package cockroachdb
import (
"context"
"crypto/tls"
"crypto/x509"
"database/sql"
"encoding/pem"
"fmt"
"net"
"net/url"
"path/filepath"

"github.com/docker/go-connections/nat"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/stdlib"

"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/wait"
)

// ErrTLSNotEnabled is returned when trying to get a TLS config from a container that does not have TLS enabled.
var ErrTLSNotEnabled = fmt.Errorf("tls not enabled")

const (
Expand All @@ -39,7 +36,9 @@ type CockroachDBContainer struct {
opts options
}

// MustConnectionString panics if the address cannot be determined.
// MustConnectionString returns a connection string to open a new connection to CockroachDB
// as described by [CockroachDBContainer.ConnectionString].
// It panics if an error occurs.
func (c *CockroachDBContainer) MustConnectionString(ctx context.Context) string {
addr, err := c.ConnectionString(ctx)
if err != nil {
Expand All @@ -48,27 +47,33 @@ func (c *CockroachDBContainer) MustConnectionString(ctx context.Context) string
return addr
}

// ConnectionString returns the dial address to open a new connection to CockroachDB.
// ConnectionString returns a connection string to open a new connection to CockroachDB.
// The returned string is suitable for use by [sql.Open] but is not be compatible with
// [pgx.ParseConfig], so if you want to call [pgx.ConnectConfig] use the
// [CockroachDBContainer.ConnectionConfig] method instead.
func (c *CockroachDBContainer) ConnectionString(ctx context.Context) (string, error) {
port, err := c.MappedPort(ctx, defaultSQLPort)
if err != nil {
return "", err
}

host, err := c.Host(ctx)
if err != nil {
return "", err
}
return c.opts.containerConnString(ctx, c.Container)
}

return connString(c.opts, host, port), nil
// ConnectionConfig returns a [pgx.ConnConfig] for the CockroachDB container.
// This can be passed to [pgx.ConnectConfig] to open a new connection.
func (c *CockroachDBContainer) ConnectionConfig(ctx context.Context) (*pgx.ConnConfig, error) {
return c.opts.containerConnConfig(ctx, c.Container)
}

// TLSConfig returns config necessary to connect to CockroachDB over TLS.
//
// Deprecated: use [CockroachDBContainer.ConnectionConfig] or
// [CockroachDBContainer.ConnectionConfig] instead.
func (c *CockroachDBContainer) TLSConfig() (*tls.Config, error) {
return connTLS(c.opts)
if c.opts.TLS == nil {
return nil, ErrTLSNotEnabled
}

return c.opts.TLS.tlsConfig()
}

// Deprecated: use Run instead
// Deprecated: use Run instead.
// RunContainer creates an instance of the CockroachDB container type
func RunContainer(ctx context.Context, opts ...testcontainers.ContainerCustomizer) (*CockroachDBContainer, error) {
return Run(ctx, "cockroachdb/cockroach:latest-v23.1", opts...)
Expand Down Expand Up @@ -177,29 +182,12 @@ func addEnvs(req *testcontainers.GenericContainerRequest, opts options) error {
}

func addWaitingFor(req *testcontainers.GenericContainerRequest, opts options) error {
var tlsConfig *tls.Config
if opts.TLS != nil {
cfg, err := connTLS(opts)
if err != nil {
return err
}
tlsConfig = cfg
}

sqlWait := wait.ForSQL(defaultSQLPort, "pgx/v5", func(host string, port nat.Port) string {
connStr := connString(opts, host, port)
if tlsConfig == nil {
return connStr
}

// register TLS config with pgx driver
connCfg, err := pgx.ParseConfig(connStr)
connStr, err := opts.connString(host, port)
if err != nil {
panic(err)
}
connCfg.TLSConfig = tlsConfig

return stdlib.RegisterConnConfig(connCfg)
return connStr
})
defaultStrategy := wait.ForAll(
wait.ForHTTP("/health").WithPort(defaultAdminPort),
Expand Down Expand Up @@ -245,17 +233,12 @@ func runStatements(ctx context.Context, container testcontainers.Container, opts
return nil
}

port, err := container.MappedPort(ctx, defaultSQLPort)
if err != nil {
return fmt.Errorf("mapped port: %w", err)
}

host, err := container.Host(ctx)
connStr, err := opts.containerConnString(ctx, container)
if err != nil {
return fmt.Errorf("host: %w", err)
return fmt.Errorf("connection string: %w", err)
}

db, err := sql.Open("pgx/v5", connString(opts, host, port))
db, err := sql.Open("pgx/v5", connStr)
if err != nil {
return fmt.Errorf("sql.Open: %w", err)
}
Expand All @@ -274,48 +257,3 @@ func runStatements(ctx context.Context, container testcontainers.Container, opts

return nil
}

func connString(opts options, host string, port nat.Port) string {
user := url.User(opts.User)
if opts.Password != "" {
user = url.UserPassword(opts.User, opts.Password)
}

sslMode := "disable"
if opts.TLS != nil {
sslMode = "verify-full"
}
params := url.Values{
"sslmode": []string{sslMode},
}

u := url.URL{
Scheme: "postgres",
User: user,
Host: net.JoinHostPort(host, port.Port()),
Path: opts.Database,
RawQuery: params.Encode(),
}

return u.String()
}

func connTLS(opts options) (*tls.Config, error) {
if opts.TLS == nil {
return nil, ErrTLSNotEnabled
}

keyPair, err := tls.X509KeyPair(opts.TLS.ClientCert, opts.TLS.ClientKey)
if err != nil {
return nil, err
}

certPool := x509.NewCertPool()
certPool.AddCert(opts.TLS.CACert)

return &tls.Config{
RootCAs: certPool,
Certificates: []tls.Certificate{keyPair},
ServerName: "localhost",
}, nil
}
52 changes: 14 additions & 38 deletions modules/cockroachdb/cockroachdb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@ package cockroachdb_test

import (
"context"
"errors"
"net/url"
"strings"
"testing"
"time"

Expand All @@ -18,14 +15,11 @@ import (
)

func TestCockroach_Insecure(t *testing.T) {
suite.Run(t, &AuthNSuite{
url: "postgres://root@localhost:xxxxx/defaultdb?sslmode=disable",
})
suite.Run(t, &AuthNSuite{})
}

func TestCockroach_NotRoot(t *testing.T) {
suite.Run(t, &AuthNSuite{
url: "postgres://test@localhost:xxxxx/defaultdb?sslmode=disable",
opts: []testcontainers.ContainerCustomizer{
cockroachdb.WithUser("test"),
// Do not run the default statements as the user used on this test is
Expand All @@ -37,7 +31,6 @@ func TestCockroach_NotRoot(t *testing.T) {

func TestCockroach_Password(t *testing.T) {
suite.Run(t, &AuthNSuite{
url: "postgres://foo:bar@localhost:xxxxx/defaultdb?sslmode=disable",
opts: []testcontainers.ContainerCustomizer{
cockroachdb.WithUser("foo"),
cockroachdb.WithPassword("bar"),
Expand All @@ -53,19 +46,26 @@ func TestCockroach_TLS(t *testing.T) {
require.NoError(t, err)

suite.Run(t, &AuthNSuite{
url: "postgres://root@localhost:xxxxx/defaultdb?sslmode=verify-full",
opts: []testcontainers.ContainerCustomizer{
cockroachdb.WithTLS(tlsCfg),
// Do not run the default statements as the user used on this test is
// lacking the needed MODIFYCLUSTERSETTING privilege to run them.
cockroachdb.WithStatements(),
},
})
}

func TestTLS(t *testing.T) {
tlsCfg, err := cockroachdb.NewTLSConfig()
require.NoError(t, err)

ctx := context.Background()

ctr, err := cockroachdb.Run(ctx, "cockroachdb/cockroach:latest-v23.1", cockroachdb.WithTLS(tlsCfg))
testcontainers.CleanupContainer(t, ctr)
require.NoError(t, err)
require.NotNil(t, ctr)
}

type AuthNSuite struct {
suite.Suite
url string
opts []testcontainers.ContainerCustomizer
}

Expand All @@ -75,11 +75,6 @@ func (suite *AuthNSuite) TestConnectionString() {
ctr, err := cockroachdb.Run(ctx, "cockroachdb/cockroach:latest-v23.1", suite.opts...)
testcontainers.CleanupContainer(suite.T(), ctr)
suite.Require().NoError(err)

connStr, err := removePort(ctr.MustConnectionString(ctx))
suite.Require().NoError(err)

suite.Equal(suite.url, connStr)
}

func (suite *AuthNSuite) TestPing() {
Expand Down Expand Up @@ -203,29 +198,10 @@ func (suite *AuthNSuite) TestWithWaitStrategyAndDeadline() {
}

func conn(ctx context.Context, container *cockroachdb.CockroachDBContainer) (*pgx.Conn, error) {
cfg, err := pgx.ParseConfig(container.MustConnectionString(ctx))
cfg, err := container.ConnectionConfig(ctx)
if err != nil {
return nil, err
}

tlsCfg, err := container.TLSConfig()
switch {
case err != nil:
if !errors.Is(err, cockroachdb.ErrTLSNotEnabled) {
return nil, err
}
default:
// apply TLS config
cfg.TLSConfig = tlsCfg
}

return pgx.ConnectConfig(ctx, cfg)
}

func removePort(s string) (string, error) {
u, err := url.Parse(s)
if err != nil {
return "", err
}
return strings.Replace(s, ":"+u.Port(), ":xxxxx", 1), nil
}
25 changes: 17 additions & 8 deletions modules/cockroachdb/examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ import (
"database/sql"
"fmt"
"log"
"net/url"

"github.com/jackc/pgx/v5"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/modules/cockroachdb"
)
Expand Down Expand Up @@ -34,25 +34,34 @@ func ExampleRun() {
}
fmt.Println(state.Running)

addr, err := cockroachdbContainer.ConnectionString(ctx)
cfg, err := cockroachdbContainer.ConnectionConfig(ctx)
if err != nil {
log.Printf("failed to get connection string: %s", err)
return
}
u, err := url.Parse(addr)

conn, err := pgx.ConnectConfig(ctx, cfg)
if err != nil {
log.Printf("failed to parse connection string: %s", err)
log.Printf("failed to connect: %s", err)
return
}

defer func() {
if err := conn.Close(ctx); err != nil {
log.Printf("failed to close connection: %s", err)
}
}()

if err = conn.Ping(ctx); err != nil {
log.Printf("failed to ping: %s", err)
return
}
u.Host = fmt.Sprintf("%s:%s", u.Hostname(), "xxx")
fmt.Println(u.String())

// Output:
// true
// postgres://root@localhost:xxx/defaultdb?sslmode=disable
}

func ExampleRun_withRecommendedSettings() {
func ExampleRun_withCustomStatements() {
ctx := context.Background()

cockroachdbContainer, err := cockroachdb.Run(ctx, "cockroachdb/cockroach:latest-v23.1", cockroachdb.WithStatements(cockroachdb.DefaultStatements...))
Expand Down
Loading

0 comments on commit f704010

Please sign in to comment.