diff --git a/service/go.mod b/service/go.mod index 295a21df8..51d945bca 100644 --- a/service/go.mod +++ b/service/go.mod @@ -28,7 +28,6 @@ require ( github.com/spf13/viper v1.18.2 github.com/stretchr/testify v1.9.0 github.com/testcontainers/testcontainers-go v0.32.0 - github.com/valyala/fasthttp v1.52.0 golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 google.golang.org/grpc v1.66.0 google.golang.org/protobuf v1.34.2 @@ -39,6 +38,7 @@ require ( github.com/Microsoft/hcsshim v0.12.0 // indirect github.com/OneOfOne/xxhash v1.2.8 // indirect github.com/agnivade/levenshtein v1.1.1 // indirect + github.com/andybalholm/brotli v1.1.0 // indirect github.com/antlr4-go/antlr/v4 v4.13.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/casbin/govaluate v1.1.0 // indirect diff --git a/service/go.sum b/service/go.sum index 9fd67e525..98234e600 100644 --- a/service/go.sum +++ b/service/go.sum @@ -367,10 +367,6 @@ github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+F github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY= github.com/tursodatabase/libsql-client-go v0.0.0-20231216154754-8383a53d618f h1:teZ0Pj1Wp3Wk0JObKBiKZqgxhYwLeJhVAyj6DRgmQtY= github.com/tursodatabase/libsql-client-go v0.0.0-20231216154754-8383a53d618f/go.mod h1:UMde0InJz9I0Le/1YIR4xsB0E2vb01MrDY6k/eNdfkg= -github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= -github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -github.com/valyala/fasthttp v1.52.0 h1:wqBQpxH71XW0e2g+Og4dzQM8pk34aFYlA1Ga8db7gU0= -github.com/valyala/fasthttp v1.52.0/go.mod h1:hf5C4QnVMkNXMspnsUlfM3WitlgYflyhHYoKol/szxQ= github.com/vertica/vertica-sql-go v1.3.3 h1:fL+FKEAEy5ONmsvya2WH5T8bhkvY27y/Ik3ReR2T+Qw= github.com/vertica/vertica-sql-go v1.3.3/go.mod h1:jnn2GFuv+O2Jcjktb7zyc4Utlbu9YVqpHH/lx63+1M4= github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= diff --git a/service/internal/server/memhttp/listener.go b/service/internal/server/memhttp/listener.go new file mode 100644 index 000000000..1ec303ddb --- /dev/null +++ b/service/internal/server/memhttp/listener.go @@ -0,0 +1,59 @@ +package memhttp + +import ( + "context" + "errors" + "net" + "sync" +) + +type memoryListener struct { + conns chan net.Conn + once sync.Once + closed chan struct{} +} + +// Accept implements net.Listener. +func (l *memoryListener) Accept() (net.Conn, error) { + select { + case conn := <-l.conns: + return conn, nil + case <-l.closed: + return nil, errors.New("listener closed") + } +} + +// Close implements net.Listener. +func (l *memoryListener) Close() error { + l.once.Do(func() { + close(l.closed) + }) + return nil +} + +// Addr implements net.Listener. +func (l *memoryListener) Addr() net.Addr { + return &memoryAddr{} +} + +// DialContext is the type expected by http.Transport.DialContext. +func (l *memoryListener) DialContext(ctx context.Context, _, _ string) (net.Conn, error) { + server, client := net.Pipe() + select { + case <-l.closed: + return nil, errors.New("listener closed") + case l.conns <- server: + return client, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +type memoryAddr struct{} + +// Network implements net.Addr. +func (*memoryAddr) Network() string { return "memory" } + +// String implements io.Stringer, returning a value that matches the +// certificates used by net/http/httptest. +func (*memoryAddr) String() string { return "opentdf.io" } diff --git a/service/internal/server/memhttp/memhttp.go b/service/internal/server/memhttp/memhttp.go new file mode 100644 index 000000000..5f053e27c --- /dev/null +++ b/service/internal/server/memhttp/memhttp.go @@ -0,0 +1,138 @@ +// Package memhttp provides an in-memory HTTP server and client. For +// testing-specific adapters, see the memhttptest subpackage. +package memhttp + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "net" + "net/http" + "time" + + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" +) + +// Server is a net/http server that uses in-memory pipes instead of TCP. By +// default, it has TLS enabled and supports HTTP/2. It otherwise uses the same +// configuration as the zero value of [http.Server]. +type Server struct { + server *http.Server + Listener *memoryListener + url string + serveErr chan error + cleanupContext func() (context.Context, context.CancelFunc) +} + +// New constructs and starts a Server. +func New(handler http.Handler, opts ...Option) *Server { + var cfg config + WithCleanupTimeout(5 * time.Second).apply(&cfg) //nolint:mnd // Specific to cleanup timeout. + for _, opt := range opts { + opt.apply(&cfg) + } + mlis := &memoryListener{ + conns: make(chan net.Conn), + closed: make(chan struct{}), + } + + var lis net.Listener = mlis + + http2Server := &http2.Server{} + + handler = h2c.NewHandler(handler, http2Server) + + server := &http.Server{ + Handler: handler, + ReadHeaderTimeout: 5 * time.Second, //nolint:mnd // Specific to read header timeout. + } + + serveErr := make(chan error, 1) + go func() { + serveErr <- server.Serve(lis) + }() + + return &Server{ + server: server, + Listener: mlis, + url: fmt.Sprintf("http://%s", mlis.Addr().String()), + serveErr: serveErr, + cleanupContext: cfg.CleanupContext, + } +} + +// Transport returns an [http2.Transport] configured to use in-memory pipes +// rather than TCP, disable automatic compression, trust the server's TLS +// certificate (if any), and use HTTP/2 (if the server supports it). +// +// Callers may reconfigure the returned Transport without affecting other +// transports or clients. +func (s *Server) Transport() *http2.Transport { + transport := &http2.Transport{ + DialTLSContext: func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) { + return s.Listener.DialContext(ctx, network, addr) + }, + AllowHTTP: true, + } + + return transport +} + +// Client returns an [http.Client] configured to use in-memory pipes rather +// than TCP, disable automatic compression, trust the server's TLS certificate +// (if any), and use HTTP/2 (if the server supports it). +// +// Callers may reconfigure the returned client without affecting other clients. +func (s *Server) Client() *http.Client { + return &http.Client{Transport: s.Transport()} +} + +// URL returns the server's URL. +func (s *Server) URL() string { + return s.url +} + +// Close immediately shuts down the server. To shut down the server without +// interrupting in-flight requests, use Shutdown. +func (s *Server) Close() error { + if err := s.server.Close(); err != nil { + return err + } + return s.listenErr() +} + +// Shutdown gracefully shuts down the server, without interrupting any active +// connections. See [http.Server.Shutdown] for details. +func (s *Server) Shutdown(ctx context.Context) error { + if err := s.server.Shutdown(ctx); err != nil { + return err + } + return s.listenErr() +} + +// Cleanup calls Shutdown with a five second timeout. To customize the timeout, +// use WithCleanupTimeout. +// +// Cleanup is primarily intended for use in tests. If you find yourself using +// it, you may want to use the memhttptest package instead. +func (s *Server) Cleanup() error { + ctx, cancel := s.cleanupContext() + defer cancel() + return s.Shutdown(ctx) +} + +// RegisterOnShutdown registers a function to call on Shutdown. It's often used +// to cleanly shut down connections that have been hijacked. See +// [http.Server.RegisterOnShutdown] for details. +func (s *Server) RegisterOnShutdown(f func()) { + s.server.RegisterOnShutdown(f) +} + +func (s *Server) listenErr() error { + if err := <-s.serveErr; err != nil && !errors.Is(err, http.ErrServerClosed) { + return err + } + return nil +} diff --git a/service/internal/server/memhttp/option.go b/service/internal/server/memhttp/option.go new file mode 100644 index 000000000..6b9666efa --- /dev/null +++ b/service/internal/server/memhttp/option.go @@ -0,0 +1,47 @@ +package memhttp + +import ( + "context" + "log" + "time" +) + +type config struct { + CleanupContext func() (context.Context, context.CancelFunc) + ErrorLog *log.Logger +} + +// An Option configures a Server. +type Option interface { + apply(*config) +} + +type optionFunc func(*config) + +func (f optionFunc) apply(cfg *config) { f(cfg) } + +// WithOptions composes multiple Options into one. +func WithOptions(opts ...Option) Option { + return optionFunc(func(cfg *config) { + for _, opt := range opts { + opt.apply(cfg) + } + }) +} + +// WithCleanupTimeout customizes the default five-second timeout for the +// server's Cleanup method. It's most useful with the memhttptest subpackage. +func WithCleanupTimeout(d time.Duration) Option { + return optionFunc(func(cfg *config) { + cfg.CleanupContext = func() (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), d) + } + }) +} + +// WithErrorLog sets [http.Server.ErrorLog]. +func WithErrorLog(l *log.Logger) Option { + return optionFunc(func(cfg *config) { + cfg.ErrorLog = l + }) +} diff --git a/service/internal/server/server.go b/service/internal/server/server.go index d66461b6a..f923730e7 100644 --- a/service/internal/server/server.go +++ b/service/internal/server/server.go @@ -21,9 +21,9 @@ import ( sdkAudit "github.com/opentdf/platform/sdk/audit" "github.com/opentdf/platform/service/internal/auth" "github.com/opentdf/platform/service/internal/security" + "github.com/opentdf/platform/service/internal/server/memhttp" "github.com/opentdf/platform/service/logger" "github.com/opentdf/platform/service/logger/audit" - "github.com/valyala/fasthttp/fasthttputil" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" "google.golang.org/grpc" @@ -122,7 +122,7 @@ https://github.com/heroku/x/blob/master/grpc/grpcserver/inprocess.go https://github.com/valyala/fasthttp/blob/master/fasthttputil/inmemory_listener.go */ type inProcessServer struct { - ln *fasthttputil.InmemoryListener + ln *memhttp.Server srv *grpc.Server maxCallRecvMsgSize int @@ -157,9 +157,12 @@ func NewOpenTDFServer(config Config, logger *logger.Logger) (*OpenTDFServer, err if err != nil { return nil, fmt.Errorf("failed to create grpc server: %w", err) } + + grpcInProcessServer := newGrpcInProcessServer() + grpcIPCServer := &inProcessServer{ - ln: fasthttputil.NewInmemoryListener(), - srv: newGrpcInProcessServer(), + ln: memhttp.New(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { grpcInProcessServer.ServeHTTP(w, r) })), + srv: grpcInProcessServer, maxCallRecvMsgSize: config.GRPC.MaxCallRecvMsgSizeBytes, maxCallSendMsgSize: config.GRPC.MaxCallSendMsgSizeBytes, } @@ -407,8 +410,8 @@ func (s inProcessServer) Conn() *grpc.ClientConn { grpc.MaxCallRecvMsgSize(s.maxCallRecvMsgSize), grpc.MaxCallSendMsgSize(s.maxCallSendMsgSize), ), - grpc.WithContextDialer(func(_ context.Context, _ string) (net.Conn, error) { - conn, err := s.ln.Dial() + grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) { + conn, err := s.ln.Listener.DialContext(ctx, "inprocess", "") if err != nil { return nil, fmt.Errorf("failed to dial in process grpc server: %w", err) } @@ -424,7 +427,7 @@ func (s inProcessServer) Conn() *grpc.ClientConn { func (s OpenTDFServer) startInProcessGrpcServer() { s.logger.Info("starting in process grpc server") - if err := s.GRPCInProcess.srv.Serve(s.GRPCInProcess.ln); err != nil { + if err := s.GRPCInProcess.srv.Serve(s.GRPCInProcess.ln.Listener); err != nil { s.logger.Error("failed to serve in process grpc", slog.String("error", err.Error())) panic(err) }