Skip to content

Commit

Permalink
feat: connectrpc realip interceptor (#1728)
Browse files Browse the repository at this point in the history
Depends On: #1715 

This is needed because we can't use the
[realip](https://github.com/grpc-ecosystem/go-grpc-middleware/tree/main/interceptors/realip)
interceptor with connectrpc.

---------

Co-authored-by: Jake Van Vorhis <83739412+jakedoublev@users.noreply.github.com>
  • Loading branch information
strantalis and jakedoublev authored Nov 7, 2024
1 parent 3cdd1b2 commit 292fca0
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 0 deletions.
62 changes: 62 additions & 0 deletions service/internal/server/realip/realip.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package realip

import (
"context"
"net"
"net/http"
"net/netip"
"strings"

"connectrpc.com/connect"
)

const (
XRealIP = "X-Real-IP"
XForwardedFor = "X-Forwarded-For"
TrueClientIP = "True-Client-Ip"
)

type clientIP struct{}

func ConnectRealIPUnaryInterceptor() connect.UnaryInterceptorFunc {
interceptor := func(next connect.UnaryFunc) connect.UnaryFunc {
return connect.UnaryFunc(func(
ctx context.Context,
req connect.AnyRequest,
) (connect.AnyResponse, error) {
ip := getIP(ctx, req.Peer(), req.Header())

ctx = context.WithValue(ctx, clientIP{}, ip)

return next(ctx, req)
})
}
return connect.UnaryInterceptorFunc(interceptor)
}

func getIP(_ context.Context, peer connect.Peer, headers http.Header) net.IP {
for _, header := range []string{XRealIP, XForwardedFor, TrueClientIP} {
if ip := headers.Get(header); ip != "" {
ips := strings.Split(ip, ",")
if ips[0] == "" || net.ParseIP(ips[0]) == nil {
continue
}
return net.ParseIP(ips[0])
}
}

ip, err := netip.ParseAddrPort(peer.Addr)
if err != nil {
return net.IP{}
}

return net.IP(ip.Addr().AsSlice())
}

func FromContext(ctx context.Context) net.IP {
ip, ok := ctx.Value(clientIP{}).(net.IP)
if !ok {
return net.IP{}
}
return ip
}
57 changes: 57 additions & 0 deletions service/internal/server/realip/realip_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package realip

import (
"context"
"net/http"
"testing"

"connectrpc.com/connect"
"github.com/stretchr/testify/suite"
)

type RealIPTestSuite struct {
suite.Suite
}

func TestRealIPSuite(t *testing.T) {
suite.Run(t, new(RealIPTestSuite))
}

func (s *RealIPTestSuite) Test_getIP_from_x_real_ip_header() {
ip := "1.1.1.1"
peer := connect.Peer{}

headers := http.Header{}
headers.Add(XRealIP, ip)
foundIP := getIP(context.Background(), peer, headers)
s.Equal(ip, foundIP.String())
}

func (s *RealIPTestSuite) Test_getIP_from_x_forwarded_for_header() {
ip := "1.1.1.1"
peer := connect.Peer{}

headers := http.Header{}
headers.Add(XForwardedFor, ip)
foundIP := getIP(context.Background(), peer, headers)
s.Equal(ip, foundIP.String())
}

func (s *RealIPTestSuite) Test_getIP_from_true_client_ip_header() {
ip := "1.1.1.1"
peer := connect.Peer{}

headers := http.Header{}
headers.Add(TrueClientIP, ip)
foundIP := getIP(context.Background(), peer, headers)
s.Equal(ip, foundIP.String())
}

func (s *RealIPTestSuite) Test_getIP_from_peer() {
ip := "1.1.1.1"
peer := connect.Peer{Addr: ip + ":1234"}

headers := http.Header{}
foundIP := getIP(context.Background(), peer, headers)
s.Equal(ip, foundIP.String())
}

0 comments on commit 292fca0

Please sign in to comment.