From 62073571dfcbbf9dfa041426814a49a9f26d57e6 Mon Sep 17 00:00:00 2001 From: "zhouyiheng.go" Date: Sun, 2 Apr 2023 23:54:36 +0800 Subject: [PATCH] feat: custom json and base64 encoders for Token and Parser Co-Authored-By: Christian Banse --- encoder.go | 17 ++++++++++++++ parser.go | 29 +++++++++++++++++++---- parser_option.go | 14 +++++++++++ parser_test.go | 60 +++++++++++++++++++++++++++++++++++++++++++++++- token.go | 42 ++++++++++++++++++++++++--------- token_option.go | 12 ++++++++++ token_test.go | 23 +++++++++++++++++++ 7 files changed, 180 insertions(+), 17 deletions(-) create mode 100644 encoder.go diff --git a/encoder.go b/encoder.go new file mode 100644 index 00000000..1b48152d --- /dev/null +++ b/encoder.go @@ -0,0 +1,17 @@ +package jwt + +// Base64Encoder is an interface that allows to implement custom Base64 encoding +// algorithms. +type Base64EncodeFunc func(src []byte) string + +// Base64Decoder is an interface that allows to implement custom Base64 decoding +// algorithms. +type Base64DecodeFunc func(s string) ([]byte, error) + +// JSONEncoder is an interface that allows to implement custom JSON encoding +// algorithms. +type JSONMarshalFunc func(v any) ([]byte, error) + +// JSONUnmarshal is an interface that allows to implement custom JSON unmarshal +// algorithms. +type JSONUnmarshalFunc func(data []byte, v any) error diff --git a/parser.go b/parser.go index 1ed2e4e4..68aa2212 100644 --- a/parser.go +++ b/parser.go @@ -12,7 +12,7 @@ type Parser struct { // If populated, only these methods will be considered valid. validMethods []string - // Use JSON Number format in JSON decoder. + // Use JSON Number format in JSON decoder. This field is disabled when using a custom json encoder. useJSONNumber bool // Skip claims validation during token parsing. @@ -20,9 +20,14 @@ type Parser struct { validator *validator + // This field is disabled when using a custom base64 encoder. decodeStrict bool + // This field is disabled when using a custom base64 encoder. decodePaddingAllowed bool + + unmarshalFunc JSONUnmarshalFunc + base64DecodeFunc Base64DecodeFunc } // NewParser creates a new Parser with the specified options @@ -148,7 +153,17 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke if headerBytes, err = p.DecodeSegment(parts[0]); err != nil { return token, parts, newError("could not base64 decode header", ErrTokenMalformed, err) } - if err = json.Unmarshal(headerBytes, &token.Header); err != nil { + + // Choose our JSON decoder. If no custom function is supplied, we use the standard library. + var unmarshal JSONUnmarshalFunc + if p.unmarshalFunc != nil { + unmarshal = p.unmarshalFunc + } else { + unmarshal = json.Unmarshal + } + + err = unmarshal(headerBytes, &token.Header) + if err != nil { return token, parts, newError("could not JSON decode header", ErrTokenMalformed, err) } @@ -162,13 +177,13 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke // If `useJSONNumber` is enabled then we must use *json.Decoder to decode // the claims. However, this comes with a performance penalty so only use - // it if we must and, otherwise, simple use json.Unmarshal. + // it if we must and, otherwise, simple use our decode function. if !p.useJSONNumber { // JSON Unmarshal. Special case for map type to avoid weird pointer behavior. if c, ok := token.Claims.(MapClaims); ok { - err = json.Unmarshal(claimBytes, &c) + err = unmarshal(claimBytes, &c) } else { - err = json.Unmarshal(claimBytes, &claims) + err = unmarshal(claimBytes, &claims) } } else { dec := json.NewDecoder(bytes.NewBuffer(claimBytes)) @@ -200,6 +215,10 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke // take into account whether the [Parser] is configured with additional options, // such as [WithStrictDecoding] or [WithPaddingAllowed]. func (p *Parser) DecodeSegment(seg string) ([]byte, error) { + if p.base64DecodeFunc != nil { + return p.base64DecodeFunc(seg) + } + encoding := base64.RawURLEncoding if p.decodePaddingAllowed { diff --git a/parser_option.go b/parser_option.go index 1b5af970..662a03ff 100644 --- a/parser_option.go +++ b/parser_option.go @@ -118,3 +118,17 @@ func WithStrictDecoding() ParserOption { p.decodeStrict = true } } + +// WithJSONUnmarshal supports a custom [JSONUnmarshal] to use in parsing the JWT. +func WithJSONUnmarshal(f JSONUnmarshalFunc) ParserOption { + return func(p *Parser) { + p.unmarshalFunc = f + } +} + +// WithBase64Decoder supports a custom [Base64Decoder] to use in parsing the JWT. +func WithBase64Decoder(f Base64DecodeFunc) ParserOption { + return func(p *Parser) { + p.base64DecodeFunc = f + } +} diff --git a/parser_test.go b/parser_test.go index 1825dfc3..05c58d31 100644 --- a/parser_test.go +++ b/parser_test.go @@ -3,6 +3,7 @@ package jwt_test import ( "crypto" "crypto/rsa" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -72,6 +73,7 @@ var jwtTestData = []struct { err []error parser *jwt.Parser signingMethod jwt.SigningMethod // The method to sign the JWT token for test purpose + options []jwt.ParserOption }{ { "invalid JWT", @@ -82,6 +84,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenMalformed}, nil, jwt.SigningMethodRS256, + nil, }, { "invalid JSON claim", @@ -92,6 +95,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenMalformed}, nil, jwt.SigningMethodRS256, + nil, }, { "bearer in JWT", @@ -102,6 +106,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenMalformed}, nil, jwt.SigningMethodRS256, + nil, }, { "basic", @@ -112,6 +117,7 @@ var jwtTestData = []struct { nil, nil, jwt.SigningMethodRS256, + nil, }, { "multiple keys, last matches", @@ -122,6 +128,7 @@ var jwtTestData = []struct { nil, nil, jwt.SigningMethodRS256, + nil, }, { "multiple keys not []interface{} type, all match", @@ -132,6 +139,7 @@ var jwtTestData = []struct { nil, nil, jwt.SigningMethodRS256, + nil, }, { "multiple keys, first matches", @@ -142,6 +150,7 @@ var jwtTestData = []struct { nil, nil, jwt.SigningMethodRS256, + nil, }, { "public keys slice, not allowed", @@ -152,6 +161,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenSignatureInvalid}, nil, jwt.SigningMethodRS256, + nil, }, { "basic expired", @@ -162,6 +172,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenExpired}, nil, jwt.SigningMethodRS256, + nil, }, { "basic nbf", @@ -172,6 +183,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenNotValidYet}, nil, jwt.SigningMethodRS256, + nil, }, { "expired and nbf", @@ -182,6 +194,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenNotValidYet, jwt.ErrTokenExpired}, nil, jwt.SigningMethodRS256, + nil, }, { "basic invalid", @@ -192,6 +205,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenSignatureInvalid, rsa.ErrVerification}, nil, jwt.SigningMethodRS256, + nil, }, { "basic nokeyfunc", @@ -202,6 +216,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenUnverifiable}, nil, jwt.SigningMethodRS256, + nil, }, { "basic nokey", @@ -212,6 +227,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenSignatureInvalid}, nil, jwt.SigningMethodRS256, + nil, }, { "multiple nokey", @@ -222,6 +238,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenSignatureInvalid}, nil, jwt.SigningMethodRS256, + nil, }, { "empty verification key set", @@ -232,6 +249,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenUnverifiable}, nil, jwt.SigningMethodRS256, + nil, }, { "zero length key list", @@ -242,6 +260,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenSignatureInvalid}, nil, jwt.SigningMethodRS256, + nil, }, { "basic errorkey", @@ -252,6 +271,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenUnverifiable, errKeyFuncError}, nil, jwt.SigningMethodRS256, + nil, }, { "invalid signing method", @@ -262,6 +282,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenSignatureInvalid}, jwt.NewParser(jwt.WithValidMethods([]string{"HS256"})), jwt.SigningMethodRS256, + nil, }, { "valid RSA signing method", @@ -272,6 +293,7 @@ var jwtTestData = []struct { nil, jwt.NewParser(jwt.WithValidMethods([]string{"RS256", "HS256"})), jwt.SigningMethodRS256, + nil, }, { "ECDSA signing method not accepted", @@ -282,6 +304,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenSignatureInvalid}, jwt.NewParser(jwt.WithValidMethods([]string{"RS256", "HS256"})), jwt.SigningMethodES256, + nil, }, { "valid ECDSA signing method", @@ -292,6 +315,7 @@ var jwtTestData = []struct { nil, jwt.NewParser(jwt.WithValidMethods([]string{"HS256", "ES256"})), jwt.SigningMethodES256, + nil, }, { "JSON Number", @@ -302,6 +326,7 @@ var jwtTestData = []struct { nil, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, + nil, }, { "JSON Number - basic expired", @@ -312,6 +337,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenExpired}, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, + nil, }, { "JSON Number - basic nbf", @@ -322,6 +348,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenNotValidYet}, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, + nil, }, { "JSON Number - expired and nbf", @@ -332,6 +359,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenNotValidYet, jwt.ErrTokenExpired}, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, + nil, }, { "SkipClaimsValidation during token parsing", @@ -342,6 +370,7 @@ var jwtTestData = []struct { nil, jwt.NewParser(jwt.WithJSONNumber(), jwt.WithoutClaimsValidation()), jwt.SigningMethodRS256, + nil, }, { "RFC7519 Claims", @@ -354,6 +383,7 @@ var jwtTestData = []struct { nil, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, + nil, }, { "RFC7519 Claims - single aud", @@ -366,6 +396,7 @@ var jwtTestData = []struct { nil, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, + nil, }, { "RFC7519 Claims - multiple aud", @@ -378,6 +409,7 @@ var jwtTestData = []struct { nil, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, + nil, }, { "RFC7519 Claims - single aud with wrong type", @@ -390,6 +422,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenMalformed}, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, + nil, }, { "RFC7519 Claims - multiple aud with wrong types", @@ -402,6 +435,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenMalformed}, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, + nil, }, { "RFC7519 Claims - nbf with 60s skew", @@ -412,6 +446,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenNotValidYet}, jwt.NewParser(jwt.WithLeeway(time.Minute)), jwt.SigningMethodRS256, + nil, }, { "RFC7519 Claims - nbf with 120s skew", @@ -422,6 +457,29 @@ var jwtTestData = []struct { nil, jwt.NewParser(jwt.WithLeeway(2 * time.Minute)), jwt.SigningMethodRS256, + nil, + }, + { + "custom json encoder", + "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg", + defaultKeyFunc, + jwt.MapClaims{"foo": "bar"}, + true, + nil, + nil, + jwt.SigningMethodRS256, + []jwt.ParserOption{jwt.WithJSONUnmarshal(json.Unmarshal)}, + }, + { + "custom base64 encoder", + "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg", + defaultKeyFunc, + jwt.MapClaims{"foo": "bar"}, + true, + nil, + nil, + jwt.SigningMethodRS256, + []jwt.ParserOption{jwt.WithBase64Decoder(base64.RawURLEncoding.DecodeString)}, }, } @@ -453,7 +511,7 @@ func TestParser_Parse(t *testing.T) { var err error var parser = data.parser if parser == nil { - parser = jwt.NewParser() + parser = jwt.NewParser(data.options...) } // Figure out correct claims type switch data.claims.(type) { diff --git a/token.go b/token.go index c8c9d4db..c4511440 100644 --- a/token.go +++ b/token.go @@ -28,12 +28,14 @@ type VerificationKeySet struct { // Token represents a JWT Token. Different fields will be used depending on // whether you're creating or parsing/verifying a token. type Token struct { - Raw string // Raw contains the raw token. Populated when you [Parse] a token - Method SigningMethod // Method is the signing method used or to be used - Header map[string]interface{} // Header is the first segment of the token in decoded form - Claims Claims // Claims is the second segment of the token in decoded form - Signature []byte // Signature is the third segment of the token in decoded form. Populated when you Parse a token - Valid bool // Valid specifies if the token is valid. Populated when you Parse/Verify a token + Raw string // Raw contains the raw token. Populated when you [Parse] a token + Method SigningMethod // Method is the signing method used or to be used + Header map[string]interface{} // Header is the first segment of the token in decoded form + Claims Claims // Claims is the second segment of the token in decoded form + Signature []byte // Signature is the third segment of the token in decoded form. Populated when you Parse a token + Valid bool // Valid specifies if the token is valid. Populated when you Parse/Verify a token + jsonEncoder JSONMarshalFunc // jsonEncoder is the custom json encoder/decoder + base64Encoder Base64EncodeFunc // base64Encoder is the custom base64 encoder/decoder } // New creates a new [Token] with the specified signing method and an empty map @@ -45,7 +47,7 @@ func New(method SigningMethod, opts ...TokenOption) *Token { // NewWithClaims creates a new [Token] with the specified signing method and // claims. Additional options can be specified, but are currently unused. func NewWithClaims(method SigningMethod, claims Claims, opts ...TokenOption) *Token { - return &Token{ + t := &Token{ Header: map[string]interface{}{ "typ": "JWT", "alg": method.Alg(), @@ -53,6 +55,10 @@ func NewWithClaims(method SigningMethod, claims Claims, opts ...TokenOption) *To Claims: claims, Method: method, } + for _, opt := range opts { + opt(t) + } + return t } // SignedString creates and returns a complete, signed JWT. The token is signed @@ -78,12 +84,19 @@ func (t *Token) SignedString(key interface{}) (string, error) { // of the whole deal. Unless you need this for something special, just go // straight for the SignedString. func (t *Token) SigningString() (string, error) { - h, err := json.Marshal(t.Header) + var marshal JSONMarshalFunc + if t.jsonEncoder != nil { + marshal = t.jsonEncoder + } else { + marshal = json.Marshal + } + + h, err := marshal(t.Header) if err != nil { return "", err } - c, err := json.Marshal(t.Claims) + c, err := marshal(t.Claims) if err != nil { return "", err } @@ -95,6 +108,13 @@ func (t *Token) SigningString() (string, error) { // stripped. In the future, this function might take into account a // [TokenOption]. Therefore, this function exists as a method of [Token], rather // than a global function. -func (*Token) EncodeSegment(seg []byte) string { - return base64.RawURLEncoding.EncodeToString(seg) +func (t *Token) EncodeSegment(seg []byte) string { + var enc Base64EncodeFunc + if t.base64Encoder != nil { + enc = t.base64Encoder + } else { + enc = base64.RawURLEncoding.EncodeToString + } + + return enc(seg) } diff --git a/token_option.go b/token_option.go index b4ae3bad..3e4c20b6 100644 --- a/token_option.go +++ b/token_option.go @@ -3,3 +3,15 @@ package jwt // TokenOption is a reserved type, which provides some forward compatibility, // if we ever want to introduce token creation-related options. type TokenOption func(*Token) + +func WithJSONEncoder(f JSONMarshalFunc) TokenOption { + return func(token *Token) { + token.jsonEncoder = f + } +} + +func WithBase64Encoder(f Base64EncodeFunc) TokenOption { + return func(token *Token) { + token.base64Encoder = f + } +} diff --git a/token_test.go b/token_test.go index f18329e0..7c76fade 100644 --- a/token_test.go +++ b/token_test.go @@ -1,6 +1,8 @@ package jwt_test import ( + "encoding/base64" + "encoding/json" "testing" "github.com/golang-jwt/jwt/v5" @@ -14,6 +16,7 @@ func TestToken_SigningString(t1 *testing.T) { Claims jwt.Claims Signature []byte Valid bool + Options []jwt.TokenOption } tests := []struct { name string @@ -21,6 +24,22 @@ func TestToken_SigningString(t1 *testing.T) { want string wantErr bool }{ + { + name: "", + fields: fields{ + Raw: "", + Method: jwt.SigningMethodHS256, + Header: map[string]interface{}{ + "typ": "JWT", + "alg": jwt.SigningMethodHS256.Alg(), + }, + Claims: jwt.RegisteredClaims{}, + Valid: false, + Options: nil, + }, + want: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30", + wantErr: false, + }, { name: "", fields: fields{ @@ -32,6 +51,10 @@ func TestToken_SigningString(t1 *testing.T) { }, Claims: jwt.RegisteredClaims{}, Valid: false, + Options: []jwt.TokenOption{ + jwt.WithJSONEncoder(json.Marshal), + jwt.WithBase64Encoder(base64.StdEncoding.EncodeToString), + }, }, want: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30", wantErr: false,