From 9b9fa568fad861c376181092fb07b4cb7456efee Mon Sep 17 00:00:00 2001 From: Christian Banse Date: Wed, 16 Feb 2022 20:27:52 +0100 Subject: [PATCH] Basic implementation of `client_credentials` grant (#2) --- .codecov.yml | 5 ++++ .gitignore | 1 + go.mod | 5 +++- go.sum | 2 ++ server.go | 75 ++++++++++++++++++++++++++++++++++++++++++++++++-- server_test.go | 18 +++++++++--- 6 files changed, 99 insertions(+), 7 deletions(-) create mode 100644 .codecov.yml diff --git a/.codecov.yml b/.codecov.yml new file mode 100644 index 0000000..3a3bdee --- /dev/null +++ b/.codecov.yml @@ -0,0 +1,5 @@ +coverage: + status: + project: + default: + threshold: 0.5% diff --git a/.gitignore b/.gitignore index ec83cb6..ee9e54c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ server .DS_Store +coverage.cov \ No newline at end of file diff --git a/go.mod b/go.mod index 3182824..7ad733d 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,7 @@ module github.com/oxisto/oauth2go go 1.16 -require golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8 +require ( + github.com/golang-jwt/jwt/v4 v4.3.0 + golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8 +) diff --git a/go.sum b/go.sum index 18ee05a..26c46a0 100644 --- a/go.sum +++ b/go.sum @@ -47,6 +47,8 @@ github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7 github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= +github.com/golang-jwt/jwt/v4 v4.3.0 h1:kHL1vqdqWNfATmA0FNMdmZNMyZI1U6O31X4rlIPoBog= +github.com/golang-jwt/jwt/v4 v4.3.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= diff --git a/server.go b/server.go index b538149..bae81b7 100644 --- a/server.go +++ b/server.go @@ -6,9 +6,13 @@ import ( "crypto/rand" "encoding/base64" "encoding/json" + "errors" "fmt" "net/http" + "strings" + "time" + "github.com/golang-jwt/jwt/v4" "golang.org/x/oauth2" ) @@ -114,10 +118,37 @@ func (srv *server) handleToken(w http.ResponseWriter, r *http.Request) { } } +// doClientCredentialsFlow implements the Client Credentials Grant +// flow (see https://datatracker.ietf.org/doc/html/rfc6749#section-4.4). func (srv *server) doClientCredentialsFlow(w http.ResponseWriter, r *http.Request) { - var token oauth2.Token + var ( + err error + token oauth2.Token + client *client + expiry time.Time + ) + + // Retrieve the client + if client, err = srv.retrieveClient(r); err != nil { + w.Header().Set("WWW-Authenticate", "Basic") + writeError(w, 401, "invalid_client") + return + } + + expiry = time.Now().Add(time.Hour * 24) + + // Create a new JWT + t := jwt.NewWithClaims(jwt.SigningMethodES256, jwt.RegisteredClaims{ + Subject: client.clientID, + ExpiresAt: jwt.NewNumericDate(expiry), + }) + t.Header["kid"] = 1 - token.AccessToken = "some access token" + token.TokenType = "Bearer" + token.Expiry = expiry + if token.AccessToken, err = t.SignedString(srv.signingKey); err != nil { + writeError(w, 500, "error while creating JWT") + } writeJSON(w, &token) } @@ -144,6 +175,46 @@ func (srv *server) handleJWKS(w http.ResponseWriter, r *http.Request) { writeJSON(w, &keySet) } +func (srv *server) retrieveClient(r *http.Request) (*client, error) { + var ( + idx int + b []byte + authorization string + basic string + clientID string + clientSecret string + ) + + authorization = r.Header.Get("authorization") + idx = strings.Index(authorization, "Basic ") + if idx == -1 { + return nil, errors.New("invalid authentication scheme") + } + + b, err := base64.StdEncoding.DecodeString(authorization[idx+6:]) + if err != nil { + return nil, fmt.Errorf("could not decode basic authentication: %w", err) + } + + basic = string(b) + idx = strings.Index(basic, ":") + if idx == -1 { + return nil, errors.New("misformed basic authentication") + } + + clientID = basic[0:idx] + clientSecret = basic[idx+1:] + + // Look for a matching client + for _, c := range srv.clients { + if c.clientID == clientID && c.clientSecret == clientSecret { + return c, nil + } + } + + return nil, errors.New("no matching client") +} + func writeError(w http.ResponseWriter, statusCode int, error string) { w.Header().Set("Content-Type", "application/json") diff --git a/server_test.go b/server_test.go index d609f74..f2c2d34 100644 --- a/server_test.go +++ b/server_test.go @@ -9,6 +9,7 @@ import ( "net/http" "testing" + "github.com/golang-jwt/jwt/v4" "golang.org/x/oauth2/clientcredentials" ) @@ -44,10 +45,10 @@ func Test_server_handleToken(t *testing.T) { } func TestIntegration(t *testing.T) { - srv := NewServer(":0", WithClient("admin", "admin"), WithClient("client", "secret")) + srv := NewServer(":0", WithUser("admin", "admin"), WithClient("client", "secret")) ln, err := net.Listen("tcp", srv.Addr) if err != nil { - t.Errorf("[TestIntegration] Error while listening key: %v", err) + t.Errorf("Error while listening key: %v", err) } go srv.Serve(ln) @@ -55,14 +56,23 @@ func TestIntegration(t *testing.T) { config := clientcredentials.Config{ ClientID: "client", - ClientSecret: "client", + ClientSecret: "secret", TokenURL: fmt.Sprintf("http://localhost:%d/token", ln.Addr().(*net.TCPAddr).Port), } token, err := config.Token(context.Background()) if err != nil { - t.Errorf("[TestIntegration] Error while retrieving a token: %v", err) + t.Errorf("Error while retrieving a token: %v", err) } log.Printf("Token: %s", token.AccessToken) + + jwtoken, err := jwt.ParseWithClaims(token.AccessToken, &jwt.RegisteredClaims{}, func(t *jwt.Token) (interface{}, error) { + return &srv.signingKey.PublicKey, nil + }) + if err != nil { + t.Errorf("Error while retrieving a token: %v", err) + } + + log.Printf("JWT: %+v", jwtoken) }