Skip to content

Commit

Permalink
feat: add oidc middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
francesconi committed Nov 15, 2024
1 parent 6d3f689 commit f2d3424
Show file tree
Hide file tree
Showing 4 changed files with 221 additions and 2 deletions.
11 changes: 10 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,19 @@ module github.com/HGV/x

go 1.22

require github.com/stretchr/testify v1.9.0
require (
github.com/coreos/go-oidc/v3 v3.11.0
github.com/stretchr/testify v1.9.0
)

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-jose/go-jose/v4 v4.0.2 // indirect
github.com/kr/pretty v0.3.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.13.1 // indirect
golang.org/x/crypto v0.27.0 // indirect
golang.org/x/oauth2 v0.21.0 // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
27 changes: 26 additions & 1 deletion go.sum
Original file line number Diff line number Diff line change
@@ -1,10 +1,35 @@
github.com/coreos/go-oidc/v3 v3.11.0 h1:Ia3MxdwpSw702YW0xgfmP1GVCMA9aEFWu12XUZ3/OtI=
github.com/coreos/go-oidc/v3 v3.11.0/go.mod h1:gE3LgjOgFoHi9a4ce4/tJczr0Ai2/BoDhf0r5lltWI0=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-jose/go-jose/v4 v4.0.2 h1:R3l3kkBds16bO7ZFAEEcofK0MkrAJt3jlJznWZG0nvk=
github.com/go-jose/go-jose/v4 v4.0.2/go.mod h1:WVf9LFMHh/QVrmqrOfqun0C45tMe3RoiKJMPvgWwLfY=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A=
golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70=
golang.org/x/oauth2 v0.21.0 h1:tsimM75w1tF/uws5rbeHzIWxEqElMehnc+iW793zsZs=
golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
104 changes: 104 additions & 0 deletions middlewarex/oidc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package middlewarex

import (
"context"
"errors"
"net/http"
"strings"

"github.com/coreos/go-oidc/v3/oidc"
)

type (
OIDCMiddlewareOption func(*oidcMiddleware)

oidcMiddleware struct {
authFailedHandler func(error) http.HandlerFunc
config *oidc.Config
verifier *oidc.IDTokenVerifier
}
oidcContextKey int
)

const (
idTokenContextKey oidcContextKey = iota
)

func OIDC(ctx context.Context, issuer string, opts ...OIDCMiddlewareOption) func(next http.Handler) http.Handler {
provider, err := oidc.NewProvider(ctx, issuer)
if err != nil {
panic(err)
}

mw := oidcMiddleware{
authFailedHandler: oidcAuthFailed,
}

for _, opt := range opts {
opt(&mw)
}

if mw.config == nil {
mw.verifier = provider.Verifier(&oidc.Config{})
} else {
mw.verifier = provider.Verifier(mw.config)
}

return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authHeader := r.Header.Get("Authorization")
bearerToken, ok := validateAuthHeader(authHeader, "Bearer ")
if !ok {
mw.authFailedHandler(errors.New("bearer token is missing or invalid")).ServeHTTP(w, r)
return
}

idToken, err := mw.verifier.Verify(r.Context(), bearerToken)
if err != nil {
mw.authFailedHandler(err).ServeHTTP(w, r)
return
}

ctx := contextWithIDToken(r.Context(), idToken)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}

func contextWithIDToken(ctx context.Context, idToken *oidc.IDToken) context.Context {
return context.WithValue(ctx, idTokenContextKey, idToken)
}

func IDTokenFromContext(ctx context.Context) *oidc.IDToken {
if idToken, ok := ctx.Value(idTokenContextKey).(*oidc.IDToken); ok {
return idToken
}
return nil
}

func validateAuthHeader(s, scheme string) (string, bool) {
if len(s) >= len(scheme) && strings.EqualFold(s[0:len(scheme)], scheme) {
return s[len(scheme):], true
}
return s, false
}

func oidcAuthFailed(err error) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
}
}

func WithAuthFailedHandler(h func(error) http.HandlerFunc) OIDCMiddlewareOption {
return func(opt *oidcMiddleware) {
if h != nil {
opt.authFailedHandler = h
}
}
}

func WithOIDCConfig(c oidc.Config) OIDCMiddlewareOption {
return func(opt *oidcMiddleware) {
opt.config = &c
}
}
81 changes: 81 additions & 0 deletions middlewarex/oidc_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package middlewarex

import (
"context"
"net/http"
"net/http/httptest"
"testing"

"github.com/coreos/go-oidc/v3/oidc"
"github.com/stretchr/testify/assert"
)

func TestValidateAuthHeader(t *testing.T) {
tests := []struct {
authHeader string
scheme string
expectedToken string
expectedOk bool
}{
{
authHeader: "", scheme: "bearer",
expectedToken: "", expectedOk: false,
},
{
authHeader: "bearer token", scheme: "bearer ",
expectedToken: "token", expectedOk: true,
},
{
authHeader: "BEARER token", scheme: "bearer ",
expectedToken: "token", expectedOk: true,
},
}

for _, tt := range tests {
t.Run("", func(t *testing.T) {
token, ok := validateAuthHeader(tt.authHeader, tt.scheme)
assert.Equal(t, tt.expectedOk, ok)
assert.Equal(t, tt.expectedToken, token)
})
}
}

func TestHandler(t *testing.T) {
issuer := "https://api.accounts.hgv.it"
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.NotNil(t, IDTokenFromContext(r.Context()))
w.WriteHeader(http.StatusTeapot)
})

t.Run("unauthorized", func(t *testing.T) {
h := OIDC(context.Background(), issuer)(next)
r := httptest.NewRequest(http.MethodGet, "/", nil)
w := httptest.NewRecorder()
h.ServeHTTP(w, r)
assert.Equal(t, http.StatusUnauthorized, w.Code)
})

t.Run("overwrite default error handler", func(t *testing.T) {
h := OIDC(context.Background(), issuer, WithAuthFailedHandler(func(err error) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusForbidden)
}
}))(next)
r := httptest.NewRequest(http.MethodGet, "/", nil)
w := httptest.NewRecorder()
h.ServeHTTP(w, r)
assert.Equal(t, http.StatusForbidden, w.Code)
})

t.Run("valid expired token", func(t *testing.T) {
h := OIDC(context.Background(), issuer, WithOIDCConfig(oidc.Config{
SkipClientIDCheck: true,
SkipExpiryCheck: true,
}))(next)
r := httptest.NewRequest(http.MethodGet, "/", nil)
r.Header.Add("Authorization", "Bearer eyJhbGciOiJSUzI1NiIsImtpZCI6IjJlNTc0NjE3LTJlYzYtNGNhNy1hYTE2LThiYTYyMWRlMGI3YSIsInR5cCI6IkpXVCJ9.eyJhdWQiOltdLCJjbGllbnRfaWQiOiJjMDI2NTZiZC00NzZkLTQ1MGYtOWMwZC0zN2ZiMDhiYTI3MjEiLCJleHAiOjE3MzE2Njk3NDMsImV4dCI6e30sImlhdCI6MTczMTY2NjE0MywiaXNzIjoiaHR0cHM6Ly9hcGkuYWNjb3VudHMuaGd2Lml0IiwianRpIjoiZjk3YWE1ODAtZjZmNC00ZGQ3LTlkMDgtMjM1YTM5ZGU4ZWZlIiwibmJmIjoxNzMxNjY2MTQzLCJzY3AiOltdLCJzdWIiOiJjMDI2NTZiZC00NzZkLTQ1MGYtOWMwZC0zN2ZiMDhiYTI3MjEifQ.IeIc2EWCYjH8EaYClYpaTpYz-DDRbpu4vRuzirmBXZy28r7OazSrJdRSEa2a_G9Yq0UzmJXeBtPAouvsQdwmHX1PdBFzwwqLPT4kXcxMmlX6RvnTy-95wVfXnJJP-cGU5U4sMKKFGnsecAQotesEsYk19Dxylr5RMA-DsgwwpN8GQuf4KdLJk4IDJx8Z-FlfAG4XMODGM2S3sqGCwc6b5nQUXa_cUTIMqJCyUdb3Kd3OcQHKEK0o0esG1CBgqj3RrRE98BejeEjR5LOYiQpY1aAklmxa_3UOtEi9Bej1PRyybRxV7QbNE8_K0WVdj3CCedbtpK7DB0mNGCtas2bjiFxsr9MBHUtDcU3taXEoEkSqye7vIbLgd66SFm5gq78-PeJEvbwYqpt4LB7b7F-ZpyhCU-3T3SNkMPHY-q7hIBPauRbJbtWdK3w_xjjjCJdgjspk-CEyOUfhogjKmavxcuuXOGBphOeJ7WCRMTlmv9ira0DZqwBCQTGitkGGT98l4guaIYoB27Zsl-wdgxK2F0AwjvHFTYNUsG3Nf9NJ4ULjPMusBBA9hHBoO1UrlNWgXEpJWvr5YV_vt0Omlqvv-ci7M3Rx1-MjRyBYTQRxVRLhtDtGK4TbW4jCEIE38_k5IDqH6WxaUsgxTxFu8rx5xWhpRlKuIQRrDyWA1ylMo_U")
w := httptest.NewRecorder()
h.ServeHTTP(w, r)
assert.Equal(t, http.StatusTeapot, w.Code)
})
}

0 comments on commit f2d3424

Please sign in to comment.