From f3a2b9651e036d70949624d2ae33bbd8e3e76d97 Mon Sep 17 00:00:00 2001 From: Christian Banse Date: Tue, 14 Mar 2023 14:55:36 +0100 Subject: [PATCH] Adding CORS header everywhere (#58) Previously, it was only added to some calls, making some things not work. --- jwks.go | 2 +- metadata.go | 2 +- server.go | 24 ++++++++++++++---------- server_test.go | 6 ++++-- 4 files changed, 20 insertions(+), 14 deletions(-) diff --git a/jwks.go b/jwks.go index 490e205..eb27fb6 100644 --- a/jwks.go +++ b/jwks.go @@ -48,5 +48,5 @@ func (srv *AuthorizationServer) handleJWKS(w http.ResponseWriter, r *http.Reques }) } - writeJSON(w, keySet) + srv.writeJSON(w, keySet) } diff --git a/metadata.go b/metadata.go index 1bad0cd..6c46105 100644 --- a/metadata.go +++ b/metadata.go @@ -37,5 +37,5 @@ func (srv *AuthorizationServer) handleMetadata(w http.ResponseWriter, r *http.Re return } - writeJSON(w, srv.metadata) + srv.writeJSON(w, srv.metadata) } diff --git a/server.go b/server.go index 7c2e377..0f0ca02 100644 --- a/server.go +++ b/server.go @@ -199,7 +199,7 @@ func (srv *AuthorizationServer) doClientCredentialsFlow(w http.ResponseWriter, r return } - writeToken(w, token) + srv.writeToken(w, token) } // doAuthorizationCodeFlow implements the Authorization Code Grant @@ -213,10 +213,6 @@ func (srv *AuthorizationServer) doAuthorizationCodeFlow(w http.ResponseWriter, r client *Client ) - if srv.allowedOrigin != "" { - w.Header().Add("Access-Control-Allow-Origin", srv.allowedOrigin) - } - // Retrieve the client client, err = srv.retrieveClient(r, true) if err != nil { @@ -245,7 +241,7 @@ func (srv *AuthorizationServer) doAuthorizationCodeFlow(w http.ResponseWriter, r return } - writeToken(w, token) + srv.writeToken(w, token) } // doRefreshTokenFlow implements refreshing an access token. @@ -304,7 +300,7 @@ issue: return } - writeToken(w, token) + srv.writeToken(w, token) } // GetClient returns the client for the given ID or ErrClientNotFound. @@ -444,6 +440,12 @@ func (srv *AuthorizationServer) GenerateToken(clientID string, signingKeyID int, return } +func (srv *AuthorizationServer) cors(w http.ResponseWriter) { + if srv.allowedOrigin != "" { + w.Header().Add("Access-Control-Allow-Origin", srv.allowedOrigin) + } +} + func Error(w http.ResponseWriter, error string, statusCode int) { w.Header().Set("Content-Type", "application/json") @@ -463,7 +465,7 @@ func RedirectError(w http.ResponseWriter, http.Redirect(w, r, fmt.Sprintf("%s?%s", redirectURI, params.Encode()), http.StatusFound) } -func writeToken(w http.ResponseWriter, token *oauth2.Token) { +func (srv *AuthorizationServer) writeToken(w http.ResponseWriter, token *oauth2.Token) { // We need to transform this into our own struct, otherwise // the expiry will be translated into a string representation, // while it should be represented as seconds. @@ -479,12 +481,14 @@ func writeToken(w http.ResponseWriter, token *oauth2.Token) { Expiry: int(time.Until(token.Expiry).Seconds()), } - writeJSON(w, s) + srv.writeJSON(w, s) } -func writeJSON(w http.ResponseWriter, value interface{}) { +func (srv *AuthorizationServer) writeJSON(w http.ResponseWriter, value interface{}) { w.Header().Set("Content-Type", "application/json") + srv.cors(w) + if err := json.NewEncoder(w).Encode(value); err != nil { Error(w, "could not encode JSON", http.StatusInternalServerError) return diff --git a/server_test.go b/server_test.go index a78709c..3ab6206 100644 --- a/server_test.go +++ b/server_test.go @@ -215,7 +215,8 @@ func TestAuthorizationServer_retrieveClient(t *testing.T) { } } -func Test_writeJSON(t *testing.T) { +func TestAuthorizationServer_writeJSON(t *testing.T) { + type fields struct{} type args struct { w http.ResponseWriter value interface{} @@ -239,7 +240,8 @@ func Test_writeJSON(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - writeJSON(tt.args.w, tt.args.value) + srv := &AuthorizationServer{} + srv.writeJSON(tt.args.w, tt.args.value) var rr *httptest.ResponseRecorder switch v := tt.args.w.(type) {