From f2d3424935f7ffadd7fda510833033692217cb59 Mon Sep 17 00:00:00 2001 From: Daniel Francesconi Date: Fri, 15 Nov 2024 14:13:04 +0100 Subject: [PATCH 1/2] feat: add oidc middleware --- go.mod | 11 ++++- go.sum | 27 +++++++++- middlewarex/oidc.go | 104 +++++++++++++++++++++++++++++++++++++++ middlewarex/oidc_test.go | 81 ++++++++++++++++++++++++++++++ 4 files changed, 221 insertions(+), 2 deletions(-) create mode 100644 middlewarex/oidc.go create mode 100644 middlewarex/oidc_test.go diff --git a/go.mod b/go.mod index 73ce4e3..d3f132c 100644 --- a/go.mod +++ b/go.mod @@ -2,10 +2,19 @@ module github.com/HGV/x go 1.22 -require github.com/stretchr/testify v1.9.0 +require ( + github.com/coreos/go-oidc/v3 v3.11.0 + github.com/stretchr/testify v1.9.0 +) require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/go-jose/go-jose/v4 v4.0.2 // indirect + github.com/kr/pretty v0.3.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/rogpeppe/go-internal v1.13.1 // indirect + golang.org/x/crypto v0.27.0 // indirect + golang.org/x/oauth2 v0.21.0 // indirect + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 60ce688..ff3bf54 100644 --- a/go.sum +++ b/go.sum @@ -1,10 +1,35 @@ +github.com/coreos/go-oidc/v3 v3.11.0 h1:Ia3MxdwpSw702YW0xgfmP1GVCMA9aEFWu12XUZ3/OtI= +github.com/coreos/go-oidc/v3 v3.11.0/go.mod h1:gE3LgjOgFoHi9a4ce4/tJczr0Ai2/BoDhf0r5lltWI0= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-jose/go-jose/v4 v4.0.2 h1:R3l3kkBds16bO7ZFAEEcofK0MkrAJt3jlJznWZG0nvk= +github.com/go-jose/go-jose/v4 v4.0.2/go.mod h1:WVf9LFMHh/QVrmqrOfqun0C45tMe3RoiKJMPvgWwLfY= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= +golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= +golang.org/x/oauth2 v0.21.0 h1:tsimM75w1tF/uws5rbeHzIWxEqElMehnc+iW793zsZs= +golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/middlewarex/oidc.go b/middlewarex/oidc.go new file mode 100644 index 0000000..80547e7 --- /dev/null +++ b/middlewarex/oidc.go @@ -0,0 +1,104 @@ +package middlewarex + +import ( + "context" + "errors" + "net/http" + "strings" + + "github.com/coreos/go-oidc/v3/oidc" +) + +type ( + OIDCMiddlewareOption func(*oidcMiddleware) + + oidcMiddleware struct { + authFailedHandler func(error) http.HandlerFunc + config *oidc.Config + verifier *oidc.IDTokenVerifier + } + oidcContextKey int +) + +const ( + idTokenContextKey oidcContextKey = iota +) + +func OIDC(ctx context.Context, issuer string, opts ...OIDCMiddlewareOption) func(next http.Handler) http.Handler { + provider, err := oidc.NewProvider(ctx, issuer) + if err != nil { + panic(err) + } + + mw := oidcMiddleware{ + authFailedHandler: oidcAuthFailed, + } + + for _, opt := range opts { + opt(&mw) + } + + if mw.config == nil { + mw.verifier = provider.Verifier(&oidc.Config{}) + } else { + mw.verifier = provider.Verifier(mw.config) + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + authHeader := r.Header.Get("Authorization") + bearerToken, ok := validateAuthHeader(authHeader, "Bearer ") + if !ok { + mw.authFailedHandler(errors.New("bearer token is missing or invalid")).ServeHTTP(w, r) + return + } + + idToken, err := mw.verifier.Verify(r.Context(), bearerToken) + if err != nil { + mw.authFailedHandler(err).ServeHTTP(w, r) + return + } + + ctx := contextWithIDToken(r.Context(), idToken) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + +func contextWithIDToken(ctx context.Context, idToken *oidc.IDToken) context.Context { + return context.WithValue(ctx, idTokenContextKey, idToken) +} + +func IDTokenFromContext(ctx context.Context) *oidc.IDToken { + if idToken, ok := ctx.Value(idTokenContextKey).(*oidc.IDToken); ok { + return idToken + } + return nil +} + +func validateAuthHeader(s, scheme string) (string, bool) { + if len(s) >= len(scheme) && strings.EqualFold(s[0:len(scheme)], scheme) { + return s[len(scheme):], true + } + return s, false +} + +func oidcAuthFailed(err error) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + } +} + +func WithAuthFailedHandler(h func(error) http.HandlerFunc) OIDCMiddlewareOption { + return func(opt *oidcMiddleware) { + if h != nil { + opt.authFailedHandler = h + } + } +} + +func WithOIDCConfig(c oidc.Config) OIDCMiddlewareOption { + return func(opt *oidcMiddleware) { + opt.config = &c + } +} diff --git a/middlewarex/oidc_test.go b/middlewarex/oidc_test.go new file mode 100644 index 0000000..03b7e10 --- /dev/null +++ b/middlewarex/oidc_test.go @@ -0,0 +1,81 @@ +package middlewarex + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/coreos/go-oidc/v3/oidc" + "github.com/stretchr/testify/assert" +) + +func TestValidateAuthHeader(t *testing.T) { + tests := []struct { + authHeader string + scheme string + expectedToken string + expectedOk bool + }{ + { + authHeader: "", scheme: "bearer", + expectedToken: "", expectedOk: false, + }, + { + authHeader: "bearer token", scheme: "bearer ", + expectedToken: "token", expectedOk: true, + }, + { + authHeader: "BEARER token", scheme: "bearer ", + expectedToken: "token", expectedOk: true, + }, + } + + for _, tt := range tests { + t.Run("", func(t *testing.T) { + token, ok := validateAuthHeader(tt.authHeader, tt.scheme) + assert.Equal(t, tt.expectedOk, ok) + assert.Equal(t, tt.expectedToken, token) + }) + } +} + +func TestHandler(t *testing.T) { + issuer := "https://api.accounts.hgv.it" + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.NotNil(t, IDTokenFromContext(r.Context())) + w.WriteHeader(http.StatusTeapot) + }) + + t.Run("unauthorized", func(t *testing.T) { + h := OIDC(context.Background(), issuer)(next) + r := httptest.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + h.ServeHTTP(w, r) + assert.Equal(t, http.StatusUnauthorized, w.Code) + }) + + t.Run("overwrite default error handler", func(t *testing.T) { + h := OIDC(context.Background(), issuer, WithAuthFailedHandler(func(err error) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + } + }))(next) + r := httptest.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + h.ServeHTTP(w, r) + assert.Equal(t, http.StatusForbidden, w.Code) + }) + + t.Run("valid expired token", func(t *testing.T) { + h := OIDC(context.Background(), issuer, WithOIDCConfig(oidc.Config{ + SkipClientIDCheck: true, + SkipExpiryCheck: true, + }))(next) + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.Header.Add("Authorization", "Bearer eyJhbGciOiJSUzI1NiIsImtpZCI6IjJlNTc0NjE3LTJlYzYtNGNhNy1hYTE2LThiYTYyMWRlMGI3YSIsInR5cCI6IkpXVCJ9.eyJhdWQiOltdLCJjbGllbnRfaWQiOiJjMDI2NTZiZC00NzZkLTQ1MGYtOWMwZC0zN2ZiMDhiYTI3MjEiLCJleHAiOjE3MzE2Njk3NDMsImV4dCI6e30sImlhdCI6MTczMTY2NjE0MywiaXNzIjoiaHR0cHM6Ly9hcGkuYWNjb3VudHMuaGd2Lml0IiwianRpIjoiZjk3YWE1ODAtZjZmNC00ZGQ3LTlkMDgtMjM1YTM5ZGU4ZWZlIiwibmJmIjoxNzMxNjY2MTQzLCJzY3AiOltdLCJzdWIiOiJjMDI2NTZiZC00NzZkLTQ1MGYtOWMwZC0zN2ZiMDhiYTI3MjEifQ.IeIc2EWCYjH8EaYClYpaTpYz-DDRbpu4vRuzirmBXZy28r7OazSrJdRSEa2a_G9Yq0UzmJXeBtPAouvsQdwmHX1PdBFzwwqLPT4kXcxMmlX6RvnTy-95wVfXnJJP-cGU5U4sMKKFGnsecAQotesEsYk19Dxylr5RMA-DsgwwpN8GQuf4KdLJk4IDJx8Z-FlfAG4XMODGM2S3sqGCwc6b5nQUXa_cUTIMqJCyUdb3Kd3OcQHKEK0o0esG1CBgqj3RrRE98BejeEjR5LOYiQpY1aAklmxa_3UOtEi9Bej1PRyybRxV7QbNE8_K0WVdj3CCedbtpK7DB0mNGCtas2bjiFxsr9MBHUtDcU3taXEoEkSqye7vIbLgd66SFm5gq78-PeJEvbwYqpt4LB7b7F-ZpyhCU-3T3SNkMPHY-q7hIBPauRbJbtWdK3w_xjjjCJdgjspk-CEyOUfhogjKmavxcuuXOGBphOeJ7WCRMTlmv9ira0DZqwBCQTGitkGGT98l4guaIYoB27Zsl-wdgxK2F0AwjvHFTYNUsG3Nf9NJ4ULjPMusBBA9hHBoO1UrlNWgXEpJWvr5YV_vt0Omlqvv-ci7M3Rx1-MjRyBYTQRxVRLhtDtGK4TbW4jCEIE38_k5IDqH6WxaUsgxTxFu8rx5xWhpRlKuIQRrDyWA1ylMo_U") + w := httptest.NewRecorder() + h.ServeHTTP(w, r) + assert.Equal(t, http.StatusTeapot, w.Code) + }) +} From 692f5f91d8adbb7bfb92e7ac6905162ecb3947f7 Mon Sep 17 00:00:00 2001 From: Daniel Francesconi Date: Wed, 20 Nov 2024 11:31:13 +0100 Subject: [PATCH 2/2] refactor: change return type of helper function --- middlewarex/oidc.go | 12 ++++-------- middlewarex/oidc_test.go | 14 ++++++++------ 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/middlewarex/oidc.go b/middlewarex/oidc.go index 80547e7..15a6bde 100644 --- a/middlewarex/oidc.go +++ b/middlewarex/oidc.go @@ -20,9 +20,7 @@ type ( oidcContextKey int ) -const ( - idTokenContextKey oidcContextKey = iota -) +const idTokenContextKey oidcContextKey = iota func OIDC(ctx context.Context, issuer string, opts ...OIDCMiddlewareOption) func(next http.Handler) http.Handler { provider, err := oidc.NewProvider(ctx, issuer) @@ -69,11 +67,9 @@ func contextWithIDToken(ctx context.Context, idToken *oidc.IDToken) context.Cont return context.WithValue(ctx, idTokenContextKey, idToken) } -func IDTokenFromContext(ctx context.Context) *oidc.IDToken { - if idToken, ok := ctx.Value(idTokenContextKey).(*oidc.IDToken); ok { - return idToken - } - return nil +func IDTokenFromContext(ctx context.Context) (idToken *oidc.IDToken, ok bool) { + idToken, ok = ctx.Value(idTokenContextKey).(*oidc.IDToken) + return } func validateAuthHeader(s, scheme string) (string, bool) { diff --git a/middlewarex/oidc_test.go b/middlewarex/oidc_test.go index 03b7e10..2e07844 100644 --- a/middlewarex/oidc_test.go +++ b/middlewarex/oidc_test.go @@ -15,26 +15,26 @@ func TestValidateAuthHeader(t *testing.T) { authHeader string scheme string expectedToken string - expectedOk bool + expectedOK bool }{ { authHeader: "", scheme: "bearer", - expectedToken: "", expectedOk: false, + expectedToken: "", expectedOK: false, }, { authHeader: "bearer token", scheme: "bearer ", - expectedToken: "token", expectedOk: true, + expectedToken: "token", expectedOK: true, }, { authHeader: "BEARER token", scheme: "bearer ", - expectedToken: "token", expectedOk: true, + expectedToken: "token", expectedOK: true, }, } for _, tt := range tests { t.Run("", func(t *testing.T) { token, ok := validateAuthHeader(tt.authHeader, tt.scheme) - assert.Equal(t, tt.expectedOk, ok) + assert.Equal(t, tt.expectedOK, ok) assert.Equal(t, tt.expectedToken, token) }) } @@ -43,7 +43,9 @@ func TestValidateAuthHeader(t *testing.T) { func TestHandler(t *testing.T) { issuer := "https://api.accounts.hgv.it" next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.NotNil(t, IDTokenFromContext(r.Context())) + idToken, ok := IDTokenFromContext(r.Context()) + assert.True(t, ok) + assert.NotNil(t, idToken) w.WriteHeader(http.StatusTeapot) })