From 90d315ce0ce5b738514a3a953213f84e0d22a0a7 Mon Sep 17 00:00:00 2001 From: "zhouyiheng.go" Date: Sun, 2 Apr 2023 23:54:36 +0800 Subject: [PATCH 1/8] feat: custom json and base64 encoders for Token and Parser Co-Authored-By: Christian Banse --- encoder.go | 17 ++++++++++++++ parser.go | 29 +++++++++++++++++++---- parser_option.go | 14 +++++++++++ parser_test.go | 60 +++++++++++++++++++++++++++++++++++++++++++++++- token.go | 42 ++++++++++++++++++++++++--------- token_option.go | 12 ++++++++++ token_test.go | 23 +++++++++++++++++++ 7 files changed, 180 insertions(+), 17 deletions(-) create mode 100644 encoder.go diff --git a/encoder.go b/encoder.go new file mode 100644 index 00000000..1b48152d --- /dev/null +++ b/encoder.go @@ -0,0 +1,17 @@ +package jwt + +// Base64Encoder is an interface that allows to implement custom Base64 encoding +// algorithms. +type Base64EncodeFunc func(src []byte) string + +// Base64Decoder is an interface that allows to implement custom Base64 decoding +// algorithms. +type Base64DecodeFunc func(s string) ([]byte, error) + +// JSONEncoder is an interface that allows to implement custom JSON encoding +// algorithms. +type JSONMarshalFunc func(v any) ([]byte, error) + +// JSONUnmarshal is an interface that allows to implement custom JSON unmarshal +// algorithms. +type JSONUnmarshalFunc func(data []byte, v any) error diff --git a/parser.go b/parser.go index ecf99af7..07e30ab9 100644 --- a/parser.go +++ b/parser.go @@ -12,7 +12,7 @@ type Parser struct { // If populated, only these methods will be considered valid. validMethods []string - // Use JSON Number format in JSON decoder. + // Use JSON Number format in JSON decoder. This field is disabled when using a custom json encoder. useJSONNumber bool // Skip claims validation during token parsing. @@ -20,9 +20,14 @@ type Parser struct { validator *Validator + // This field is disabled when using a custom base64 encoder. decodeStrict bool + // This field is disabled when using a custom base64 encoder. decodePaddingAllowed bool + + unmarshalFunc JSONUnmarshalFunc + base64DecodeFunc Base64DecodeFunc } // NewParser creates a new Parser with the specified options @@ -148,7 +153,17 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke if headerBytes, err = p.DecodeSegment(parts[0]); err != nil { return token, parts, newError("could not base64 decode header", ErrTokenMalformed, err) } - if err = json.Unmarshal(headerBytes, &token.Header); err != nil { + + // Choose our JSON decoder. If no custom function is supplied, we use the standard library. + var unmarshal JSONUnmarshalFunc + if p.unmarshalFunc != nil { + unmarshal = p.unmarshalFunc + } else { + unmarshal = json.Unmarshal + } + + err = unmarshal(headerBytes, &token.Header) + if err != nil { return token, parts, newError("could not JSON decode header", ErrTokenMalformed, err) } @@ -162,13 +177,13 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke // If `useJSONNumber` is enabled then we must use *json.Decoder to decode // the claims. However, this comes with a performance penalty so only use - // it if we must and, otherwise, simple use json.Unmarshal. + // it if we must and, otherwise, simple use our decode function. if !p.useJSONNumber { // JSON Unmarshal. Special case for map type to avoid weird pointer behavior. if c, ok := token.Claims.(MapClaims); ok { - err = json.Unmarshal(claimBytes, &c) + err = unmarshal(claimBytes, &c) } else { - err = json.Unmarshal(claimBytes, &claims) + err = unmarshal(claimBytes, &claims) } } else { dec := json.NewDecoder(bytes.NewBuffer(claimBytes)) @@ -200,6 +215,10 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke // take into account whether the [Parser] is configured with additional options, // such as [WithStrictDecoding] or [WithPaddingAllowed]. func (p *Parser) DecodeSegment(seg string) ([]byte, error) { + if p.base64DecodeFunc != nil { + return p.base64DecodeFunc(seg) + } + encoding := base64.RawURLEncoding if p.decodePaddingAllowed { diff --git a/parser_option.go b/parser_option.go index 88a780fb..735774eb 100644 --- a/parser_option.go +++ b/parser_option.go @@ -126,3 +126,17 @@ func WithStrictDecoding() ParserOption { p.decodeStrict = true } } + +// WithJSONUnmarshal supports a custom [JSONUnmarshal] to use in parsing the JWT. +func WithJSONUnmarshal(f JSONUnmarshalFunc) ParserOption { + return func(p *Parser) { + p.unmarshalFunc = f + } +} + +// WithBase64Decoder supports a custom [Base64Decoder] to use in parsing the JWT. +func WithBase64Decoder(f Base64DecodeFunc) ParserOption { + return func(p *Parser) { + p.base64DecodeFunc = f + } +} diff --git a/parser_test.go b/parser_test.go index c0f81711..d2b7dd39 100644 --- a/parser_test.go +++ b/parser_test.go @@ -3,6 +3,7 @@ package jwt_test import ( "crypto" "crypto/rsa" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -72,6 +73,7 @@ var jwtTestData = []struct { err []error parser *jwt.Parser signingMethod jwt.SigningMethod // The method to sign the JWT token for test purpose + options []jwt.ParserOption }{ { "invalid JWT", @@ -82,6 +84,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenMalformed}, nil, jwt.SigningMethodRS256, + nil, }, { "invalid JSON claim", @@ -92,6 +95,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenMalformed}, nil, jwt.SigningMethodRS256, + nil, }, { "bearer in JWT", @@ -102,6 +106,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenMalformed}, nil, jwt.SigningMethodRS256, + nil, }, { "basic", @@ -112,6 +117,7 @@ var jwtTestData = []struct { nil, nil, jwt.SigningMethodRS256, + nil, }, { "multiple keys, last matches", @@ -122,6 +128,7 @@ var jwtTestData = []struct { nil, nil, jwt.SigningMethodRS256, + nil, }, { "multiple keys not []interface{} type, all match", @@ -132,6 +139,7 @@ var jwtTestData = []struct { nil, nil, jwt.SigningMethodRS256, + nil, }, { "multiple keys, first matches", @@ -142,6 +150,7 @@ var jwtTestData = []struct { nil, nil, jwt.SigningMethodRS256, + nil, }, { "public keys slice, not allowed", @@ -152,6 +161,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenSignatureInvalid}, nil, jwt.SigningMethodRS256, + nil, }, { "basic expired", @@ -162,6 +172,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenExpired}, nil, jwt.SigningMethodRS256, + nil, }, { "basic nbf", @@ -172,6 +183,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenNotValidYet}, nil, jwt.SigningMethodRS256, + nil, }, { "expired and nbf", @@ -182,6 +194,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenNotValidYet, jwt.ErrTokenExpired}, nil, jwt.SigningMethodRS256, + nil, }, { "basic invalid", @@ -192,6 +205,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenSignatureInvalid, rsa.ErrVerification}, nil, jwt.SigningMethodRS256, + nil, }, { "basic nokeyfunc", @@ -202,6 +216,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenUnverifiable}, nil, jwt.SigningMethodRS256, + nil, }, { "basic nokey", @@ -212,6 +227,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenSignatureInvalid}, nil, jwt.SigningMethodRS256, + nil, }, { "multiple nokey", @@ -222,6 +238,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenSignatureInvalid}, nil, jwt.SigningMethodRS256, + nil, }, { "empty verification key set", @@ -232,6 +249,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenUnverifiable}, nil, jwt.SigningMethodRS256, + nil, }, { "zero length key list", @@ -242,6 +260,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenSignatureInvalid}, nil, jwt.SigningMethodRS256, + nil, }, { "basic errorkey", @@ -252,6 +271,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenUnverifiable, errKeyFuncError}, nil, jwt.SigningMethodRS256, + nil, }, { "invalid signing method", @@ -262,6 +282,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenSignatureInvalid}, jwt.NewParser(jwt.WithValidMethods([]string{"HS256"})), jwt.SigningMethodRS256, + nil, }, { "valid RSA signing method", @@ -272,6 +293,7 @@ var jwtTestData = []struct { nil, jwt.NewParser(jwt.WithValidMethods([]string{"RS256", "HS256"})), jwt.SigningMethodRS256, + nil, }, { "ECDSA signing method not accepted", @@ -282,6 +304,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenSignatureInvalid}, jwt.NewParser(jwt.WithValidMethods([]string{"RS256", "HS256"})), jwt.SigningMethodES256, + nil, }, { "valid ECDSA signing method", @@ -292,6 +315,7 @@ var jwtTestData = []struct { nil, jwt.NewParser(jwt.WithValidMethods([]string{"HS256", "ES256"})), jwt.SigningMethodES256, + nil, }, { "JSON Number", @@ -302,6 +326,7 @@ var jwtTestData = []struct { nil, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, + nil, }, { "JSON Number - basic expired", @@ -312,6 +337,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenExpired}, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, + nil, }, { "JSON Number - basic nbf", @@ -322,6 +348,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenNotValidYet}, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, + nil, }, { "JSON Number - expired and nbf", @@ -332,6 +359,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenNotValidYet, jwt.ErrTokenExpired}, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, + nil, }, { "SkipClaimsValidation during token parsing", @@ -342,6 +370,7 @@ var jwtTestData = []struct { nil, jwt.NewParser(jwt.WithJSONNumber(), jwt.WithoutClaimsValidation()), jwt.SigningMethodRS256, + nil, }, { "RFC7519 Claims", @@ -354,6 +383,7 @@ var jwtTestData = []struct { nil, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, + nil, }, { "RFC7519 Claims - single aud", @@ -366,6 +396,7 @@ var jwtTestData = []struct { nil, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, + nil, }, { "RFC7519 Claims - multiple aud", @@ -378,6 +409,7 @@ var jwtTestData = []struct { nil, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, + nil, }, { "RFC7519 Claims - single aud with wrong type", @@ -390,6 +422,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenMalformed}, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, + nil, }, { "RFC7519 Claims - multiple aud with wrong types", @@ -402,6 +435,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenMalformed}, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, + nil, }, { "RFC7519 Claims - nbf with 60s skew", @@ -412,6 +446,7 @@ var jwtTestData = []struct { []error{jwt.ErrTokenNotValidYet}, jwt.NewParser(jwt.WithLeeway(time.Minute)), jwt.SigningMethodRS256, + nil, }, { "RFC7519 Claims - nbf with 120s skew", @@ -422,6 +457,29 @@ var jwtTestData = []struct { nil, jwt.NewParser(jwt.WithLeeway(2 * time.Minute)), jwt.SigningMethodRS256, + nil, + }, + { + "custom json encoder", + "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg", + defaultKeyFunc, + jwt.MapClaims{"foo": "bar"}, + true, + nil, + nil, + jwt.SigningMethodRS256, + []jwt.ParserOption{jwt.WithJSONUnmarshal(json.Unmarshal)}, + }, + { + "custom base64 encoder", + "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg", + defaultKeyFunc, + jwt.MapClaims{"foo": "bar"}, + true, + nil, + nil, + jwt.SigningMethodRS256, + []jwt.ParserOption{jwt.WithBase64Decoder(base64.RawURLEncoding.DecodeString)}, }, { "rejects if exp is required but missing", @@ -463,7 +521,7 @@ func TestParser_Parse(t *testing.T) { var err error var parser = data.parser if parser == nil { - parser = jwt.NewParser() + parser = jwt.NewParser(data.options...) } // Figure out correct claims type switch data.claims.(type) { diff --git a/token.go b/token.go index 352873a2..618a1c8e 100644 --- a/token.go +++ b/token.go @@ -28,12 +28,14 @@ type VerificationKeySet struct { // Token represents a JWT Token. Different fields will be used depending on // whether you're creating or parsing/verifying a token. type Token struct { - Raw string // Raw contains the raw token. Populated when you [Parse] a token - Method SigningMethod // Method is the signing method used or to be used - Header map[string]interface{} // Header is the first segment of the token in decoded form - Claims Claims // Claims is the second segment of the token in decoded form - Signature []byte // Signature is the third segment of the token in decoded form. Populated when you Parse a token - Valid bool // Valid specifies if the token is valid. Populated when you Parse/Verify a token + Raw string // Raw contains the raw token. Populated when you [Parse] a token + Method SigningMethod // Method is the signing method used or to be used + Header map[string]interface{} // Header is the first segment of the token in decoded form + Claims Claims // Claims is the second segment of the token in decoded form + Signature []byte // Signature is the third segment of the token in decoded form. Populated when you Parse a token + Valid bool // Valid specifies if the token is valid. Populated when you Parse/Verify a token + jsonEncoder JSONMarshalFunc // jsonEncoder is the custom json encoder/decoder + base64Encoder Base64EncodeFunc // base64Encoder is the custom base64 encoder/decoder } // New creates a new [Token] with the specified signing method and an empty map @@ -45,7 +47,7 @@ func New(method SigningMethod, opts ...TokenOption) *Token { // NewWithClaims creates a new [Token] with the specified signing method and // claims. Additional options can be specified, but are currently unused. func NewWithClaims(method SigningMethod, claims Claims, opts ...TokenOption) *Token { - return &Token{ + t := &Token{ Header: map[string]interface{}{ "typ": "JWT", "alg": method.Alg(), @@ -53,6 +55,10 @@ func NewWithClaims(method SigningMethod, claims Claims, opts ...TokenOption) *To Claims: claims, Method: method, } + for _, opt := range opts { + opt(t) + } + return t } // SignedString creates and returns a complete, signed JWT. The token is signed @@ -78,12 +84,19 @@ func (t *Token) SignedString(key interface{}) (string, error) { // of the whole deal. Unless you need this for something special, just go // straight for the SignedString. func (t *Token) SigningString() (string, error) { - h, err := json.Marshal(t.Header) + var marshal JSONMarshalFunc + if t.jsonEncoder != nil { + marshal = t.jsonEncoder + } else { + marshal = json.Marshal + } + + h, err := marshal(t.Header) if err != nil { return "", err } - c, err := json.Marshal(t.Claims) + c, err := marshal(t.Claims) if err != nil { return "", err } @@ -95,6 +108,13 @@ func (t *Token) SigningString() (string, error) { // stripped. In the future, this function might take into account a // [TokenOption]. Therefore, this function exists as a method of [Token], rather // than a global function. -func (*Token) EncodeSegment(seg []byte) string { - return base64.RawURLEncoding.EncodeToString(seg) +func (t *Token) EncodeSegment(seg []byte) string { + var enc Base64EncodeFunc + if t.base64Encoder != nil { + enc = t.base64Encoder + } else { + enc = base64.RawURLEncoding.EncodeToString + } + + return enc(seg) } diff --git a/token_option.go b/token_option.go index b4ae3bad..3e4c20b6 100644 --- a/token_option.go +++ b/token_option.go @@ -3,3 +3,15 @@ package jwt // TokenOption is a reserved type, which provides some forward compatibility, // if we ever want to introduce token creation-related options. type TokenOption func(*Token) + +func WithJSONEncoder(f JSONMarshalFunc) TokenOption { + return func(token *Token) { + token.jsonEncoder = f + } +} + +func WithBase64Encoder(f Base64EncodeFunc) TokenOption { + return func(token *Token) { + token.base64Encoder = f + } +} diff --git a/token_test.go b/token_test.go index f18329e0..7c76fade 100644 --- a/token_test.go +++ b/token_test.go @@ -1,6 +1,8 @@ package jwt_test import ( + "encoding/base64" + "encoding/json" "testing" "github.com/golang-jwt/jwt/v5" @@ -14,6 +16,7 @@ func TestToken_SigningString(t1 *testing.T) { Claims jwt.Claims Signature []byte Valid bool + Options []jwt.TokenOption } tests := []struct { name string @@ -21,6 +24,22 @@ func TestToken_SigningString(t1 *testing.T) { want string wantErr bool }{ + { + name: "", + fields: fields{ + Raw: "", + Method: jwt.SigningMethodHS256, + Header: map[string]interface{}{ + "typ": "JWT", + "alg": jwt.SigningMethodHS256.Alg(), + }, + Claims: jwt.RegisteredClaims{}, + Valid: false, + Options: nil, + }, + want: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30", + wantErr: false, + }, { name: "", fields: fields{ @@ -32,6 +51,10 @@ func TestToken_SigningString(t1 *testing.T) { }, Claims: jwt.RegisteredClaims{}, Valid: false, + Options: []jwt.TokenOption{ + jwt.WithJSONEncoder(json.Marshal), + jwt.WithBase64Encoder(base64.StdEncoding.EncodeToString), + }, }, want: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30", wantErr: false, From 8089d9eb7806b2dbdb888958326f68fad5648032 Mon Sep 17 00:00:00 2001 From: Christian Banse Date: Wed, 13 Sep 2023 16:49:08 +0200 Subject: [PATCH 2/8] Slightly better way to handle useNumber --- parser.go | 31 +++++++++++++++---------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/parser.go b/parser.go index 07e30ab9..a0a9274e 100644 --- a/parser.go +++ b/parser.go @@ -162,6 +162,7 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke unmarshal = json.Unmarshal } + // JSON Unmarshal the header err = unmarshal(headerBytes, &token.Header) if err != nil { return token, parts, newError("could not JSON decode header", ErrTokenMalformed, err) @@ -175,25 +176,23 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke return token, parts, newError("could not base64 decode claim", ErrTokenMalformed, err) } - // If `useJSONNumber` is enabled then we must use *json.Decoder to decode + // If `useJSONNumber` is enabled, then we must use *json.Decoder to decode // the claims. However, this comes with a performance penalty so only use - // it if we must and, otherwise, simple use our decode function. - if !p.useJSONNumber { - // JSON Unmarshal. Special case for map type to avoid weird pointer behavior. - if c, ok := token.Claims.(MapClaims); ok { - err = unmarshal(claimBytes, &c) - } else { - err = unmarshal(claimBytes, &claims) + // it if we must and, otherwise, simple use our existing unmarshal function. + if p.useJSONNumber { + unmarshal = func(data []byte, v any) error { + decoder := json.NewDecoder(bytes.NewBuffer(claimBytes)) + decoder.UseNumber() + return decoder.Decode(v) } + } + + // JSON Unmarshal the claims. Special case for map type to avoid weird + // pointer behavior. + if c, ok := token.Claims.(MapClaims); ok { + err = unmarshal(claimBytes, &c) } else { - dec := json.NewDecoder(bytes.NewBuffer(claimBytes)) - dec.UseNumber() - // JSON Decode. Special case for map type to avoid weird pointer behavior. - if c, ok := token.Claims.(MapClaims); ok { - err = dec.Decode(&c) - } else { - err = dec.Decode(&claims) - } + err = unmarshal(claimBytes, &claims) } if err != nil { return token, parts, newError("could not JSON decode claim", ErrTokenMalformed, err) From 7a5e5d6efe0720be8ae10ceba8040fac23b4906a Mon Sep 17 00:00:00 2001 From: Christian Banse Date: Wed, 13 Sep 2023 16:53:08 +0200 Subject: [PATCH 3/8] Extract encoders/decoders to extra struct --- parser.go | 18 +++++++++++------- parser_option.go | 4 ++-- token.go | 29 +++++++++++++++++------------ token_option.go | 4 ++-- 4 files changed, 32 insertions(+), 23 deletions(-) diff --git a/parser.go b/parser.go index a0a9274e..e79b414b 100644 --- a/parser.go +++ b/parser.go @@ -20,14 +20,18 @@ type Parser struct { validator *Validator + decoders +} + +type decoders struct { + jsonUnmarshal JSONUnmarshalFunc + base64Decode Base64DecodeFunc + // This field is disabled when using a custom base64 encoder. decodeStrict bool // This field is disabled when using a custom base64 encoder. decodePaddingAllowed bool - - unmarshalFunc JSONUnmarshalFunc - base64DecodeFunc Base64DecodeFunc } // NewParser creates a new Parser with the specified options @@ -156,8 +160,8 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke // Choose our JSON decoder. If no custom function is supplied, we use the standard library. var unmarshal JSONUnmarshalFunc - if p.unmarshalFunc != nil { - unmarshal = p.unmarshalFunc + if p.jsonUnmarshal != nil { + unmarshal = p.jsonUnmarshal } else { unmarshal = json.Unmarshal } @@ -214,8 +218,8 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke // take into account whether the [Parser] is configured with additional options, // such as [WithStrictDecoding] or [WithPaddingAllowed]. func (p *Parser) DecodeSegment(seg string) ([]byte, error) { - if p.base64DecodeFunc != nil { - return p.base64DecodeFunc(seg) + if p.base64Decode != nil { + return p.base64Decode(seg) } encoding := base64.RawURLEncoding diff --git a/parser_option.go b/parser_option.go index 735774eb..b724731d 100644 --- a/parser_option.go +++ b/parser_option.go @@ -130,13 +130,13 @@ func WithStrictDecoding() ParserOption { // WithJSONUnmarshal supports a custom [JSONUnmarshal] to use in parsing the JWT. func WithJSONUnmarshal(f JSONUnmarshalFunc) ParserOption { return func(p *Parser) { - p.unmarshalFunc = f + p.jsonUnmarshal = f } } // WithBase64Decoder supports a custom [Base64Decoder] to use in parsing the JWT. func WithBase64Decoder(f Base64DecodeFunc) ParserOption { return func(p *Parser) { - p.base64DecodeFunc = f + p.base64Decode = f } } diff --git a/token.go b/token.go index 618a1c8e..585770f3 100644 --- a/token.go +++ b/token.go @@ -28,14 +28,19 @@ type VerificationKeySet struct { // Token represents a JWT Token. Different fields will be used depending on // whether you're creating or parsing/verifying a token. type Token struct { - Raw string // Raw contains the raw token. Populated when you [Parse] a token - Method SigningMethod // Method is the signing method used or to be used - Header map[string]interface{} // Header is the first segment of the token in decoded form - Claims Claims // Claims is the second segment of the token in decoded form - Signature []byte // Signature is the third segment of the token in decoded form. Populated when you Parse a token - Valid bool // Valid specifies if the token is valid. Populated when you Parse/Verify a token - jsonEncoder JSONMarshalFunc // jsonEncoder is the custom json encoder/decoder - base64Encoder Base64EncodeFunc // base64Encoder is the custom base64 encoder/decoder + Raw string // Raw contains the raw token. Populated when you [Parse] a token + Method SigningMethod // Method is the signing method used or to be used + Header map[string]interface{} // Header is the first segment of the token in decoded form + Claims Claims // Claims is the second segment of the token in decoded form + Signature []byte // Signature is the third segment of the token in decoded form. Populated when you Parse a token + Valid bool // Valid specifies if the token is valid. Populated when you Parse/Verify a token + + encoders +} + +type encoders struct { + jsonMarshal JSONMarshalFunc // jsonEncoder is the custom json encoder/decoder + base64Encode Base64EncodeFunc // base64Encoder is the custom base64 encoder/decoder } // New creates a new [Token] with the specified signing method and an empty map @@ -85,8 +90,8 @@ func (t *Token) SignedString(key interface{}) (string, error) { // straight for the SignedString. func (t *Token) SigningString() (string, error) { var marshal JSONMarshalFunc - if t.jsonEncoder != nil { - marshal = t.jsonEncoder + if t.jsonMarshal != nil { + marshal = t.jsonMarshal } else { marshal = json.Marshal } @@ -110,8 +115,8 @@ func (t *Token) SigningString() (string, error) { // than a global function. func (t *Token) EncodeSegment(seg []byte) string { var enc Base64EncodeFunc - if t.base64Encoder != nil { - enc = t.base64Encoder + if t.base64Encode != nil { + enc = t.base64Encode } else { enc = base64.RawURLEncoding.EncodeToString } diff --git a/token_option.go b/token_option.go index 3e4c20b6..3a9ca8d9 100644 --- a/token_option.go +++ b/token_option.go @@ -6,12 +6,12 @@ type TokenOption func(*Token) func WithJSONEncoder(f JSONMarshalFunc) TokenOption { return func(token *Token) { - token.jsonEncoder = f + token.jsonMarshal = f } } func WithBase64Encoder(f Base64EncodeFunc) TokenOption { return func(token *Token) { - token.base64Encoder = f + token.base64Encode = f } } From f0fa303116a0300136394404dcf107213a577587 Mon Sep 17 00:00:00 2001 From: Christian Banse Date: Wed, 13 Sep 2023 17:24:50 +0200 Subject: [PATCH 4/8] Also supports UseNumber --- encoder.go | 9 ++++++++ parser.go | 23 ++++++++++++++------- parser_option.go | 14 ++++++++++--- parser_test.go | 54 ++++++++++++++---------------------------------- 4 files changed, 51 insertions(+), 49 deletions(-) diff --git a/encoder.go b/encoder.go index 1b48152d..375cb692 100644 --- a/encoder.go +++ b/encoder.go @@ -1,5 +1,7 @@ package jwt +import "io" + // Base64Encoder is an interface that allows to implement custom Base64 encoding // algorithms. type Base64EncodeFunc func(src []byte) string @@ -15,3 +17,10 @@ type JSONMarshalFunc func(v any) ([]byte, error) // JSONUnmarshal is an interface that allows to implement custom JSON unmarshal // algorithms. type JSONUnmarshalFunc func(data []byte, v any) error + +type JSONDecoder interface { + UseNumber() + Decode(v any) error +} + +type JSONNewDecoderFunc[T JSONDecoder] func(r io.Reader) T diff --git a/parser.go b/parser.go index e79b414b..57fbe706 100644 --- a/parser.go +++ b/parser.go @@ -12,7 +12,7 @@ type Parser struct { // If populated, only these methods will be considered valid. validMethods []string - // Use JSON Number format in JSON decoder. This field is disabled when using a custom json encoder. + // Use JSON Number format in JSON decoder. useJSONNumber bool // Skip claims validation during token parsing. @@ -24,8 +24,9 @@ type Parser struct { } type decoders struct { - jsonUnmarshal JSONUnmarshalFunc - base64Decode Base64DecodeFunc + jsonUnmarshal JSONUnmarshalFunc + jsonNewDecoder JSONNewDecoderFunc[JSONDecoder] + base64Decode Base64DecodeFunc // This field is disabled when using a custom base64 encoder. decodeStrict bool @@ -180,12 +181,20 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke return token, parts, newError("could not base64 decode claim", ErrTokenMalformed, err) } - // If `useJSONNumber` is enabled, then we must use *json.Decoder to decode - // the claims. However, this comes with a performance penalty so only use - // it if we must and, otherwise, simple use our existing unmarshal function. + // If `useJSONNumber` is enabled, then we must use a dedicated JSONDecoder + // to decode the claims. However, this comes with a performance penalty so + // only use it if we must and, otherwise, simple use our existing unmarshal + // function. if p.useJSONNumber { unmarshal = func(data []byte, v any) error { - decoder := json.NewDecoder(bytes.NewBuffer(claimBytes)) + buffer := bytes.NewBuffer(claimBytes) + + var decoder JSONDecoder + if p.jsonNewDecoder != nil { + decoder = p.jsonNewDecoder(buffer) + } else { + decoder = json.NewDecoder(buffer) + } decoder.UseNumber() return decoder.Decode(v) } diff --git a/parser_option.go b/parser_option.go index b724731d..9e30b439 100644 --- a/parser_option.go +++ b/parser_option.go @@ -1,6 +1,9 @@ package jwt -import "time" +import ( + "io" + "time" +) // ParserOption is used to implement functional-style options that modify the // behavior of the parser. To add new options, just create a function (ideally @@ -127,10 +130,15 @@ func WithStrictDecoding() ParserOption { } } -// WithJSONUnmarshal supports a custom [JSONUnmarshal] to use in parsing the JWT. -func WithJSONUnmarshal(f JSONUnmarshalFunc) ParserOption { +// WithJSONDecoder supports a custom [JSONUnmarshal] to use in parsing the JWT. +func WithJSONDecoder[T JSONDecoder](f JSONUnmarshalFunc, f2 JSONNewDecoderFunc[T]) ParserOption { return func(p *Parser) { p.jsonUnmarshal = f + // This seems to be necessary, since we don't want to store the specific + // JSONDecoder type in our parser, but need it in the function interface. + p.jsonNewDecoder = func(r io.Reader) JSONDecoder { + return f2(r) + } } } diff --git a/parser_test.go b/parser_test.go index d2b7dd39..0e7b32fe 100644 --- a/parser_test.go +++ b/parser_test.go @@ -73,7 +73,6 @@ var jwtTestData = []struct { err []error parser *jwt.Parser signingMethod jwt.SigningMethod // The method to sign the JWT token for test purpose - options []jwt.ParserOption }{ { "invalid JWT", @@ -84,7 +83,6 @@ var jwtTestData = []struct { []error{jwt.ErrTokenMalformed}, nil, jwt.SigningMethodRS256, - nil, }, { "invalid JSON claim", @@ -95,7 +93,6 @@ var jwtTestData = []struct { []error{jwt.ErrTokenMalformed}, nil, jwt.SigningMethodRS256, - nil, }, { "bearer in JWT", @@ -106,7 +103,6 @@ var jwtTestData = []struct { []error{jwt.ErrTokenMalformed}, nil, jwt.SigningMethodRS256, - nil, }, { "basic", @@ -117,7 +113,6 @@ var jwtTestData = []struct { nil, nil, jwt.SigningMethodRS256, - nil, }, { "multiple keys, last matches", @@ -128,7 +123,6 @@ var jwtTestData = []struct { nil, nil, jwt.SigningMethodRS256, - nil, }, { "multiple keys not []interface{} type, all match", @@ -139,7 +133,6 @@ var jwtTestData = []struct { nil, nil, jwt.SigningMethodRS256, - nil, }, { "multiple keys, first matches", @@ -150,7 +143,6 @@ var jwtTestData = []struct { nil, nil, jwt.SigningMethodRS256, - nil, }, { "public keys slice, not allowed", @@ -161,7 +153,6 @@ var jwtTestData = []struct { []error{jwt.ErrTokenSignatureInvalid}, nil, jwt.SigningMethodRS256, - nil, }, { "basic expired", @@ -172,7 +163,6 @@ var jwtTestData = []struct { []error{jwt.ErrTokenExpired}, nil, jwt.SigningMethodRS256, - nil, }, { "basic nbf", @@ -183,7 +173,6 @@ var jwtTestData = []struct { []error{jwt.ErrTokenNotValidYet}, nil, jwt.SigningMethodRS256, - nil, }, { "expired and nbf", @@ -194,7 +183,6 @@ var jwtTestData = []struct { []error{jwt.ErrTokenNotValidYet, jwt.ErrTokenExpired}, nil, jwt.SigningMethodRS256, - nil, }, { "basic invalid", @@ -205,7 +193,6 @@ var jwtTestData = []struct { []error{jwt.ErrTokenSignatureInvalid, rsa.ErrVerification}, nil, jwt.SigningMethodRS256, - nil, }, { "basic nokeyfunc", @@ -216,7 +203,6 @@ var jwtTestData = []struct { []error{jwt.ErrTokenUnverifiable}, nil, jwt.SigningMethodRS256, - nil, }, { "basic nokey", @@ -227,7 +213,6 @@ var jwtTestData = []struct { []error{jwt.ErrTokenSignatureInvalid}, nil, jwt.SigningMethodRS256, - nil, }, { "multiple nokey", @@ -238,7 +223,6 @@ var jwtTestData = []struct { []error{jwt.ErrTokenSignatureInvalid}, nil, jwt.SigningMethodRS256, - nil, }, { "empty verification key set", @@ -249,7 +233,6 @@ var jwtTestData = []struct { []error{jwt.ErrTokenUnverifiable}, nil, jwt.SigningMethodRS256, - nil, }, { "zero length key list", @@ -260,7 +243,6 @@ var jwtTestData = []struct { []error{jwt.ErrTokenSignatureInvalid}, nil, jwt.SigningMethodRS256, - nil, }, { "basic errorkey", @@ -271,7 +253,6 @@ var jwtTestData = []struct { []error{jwt.ErrTokenUnverifiable, errKeyFuncError}, nil, jwt.SigningMethodRS256, - nil, }, { "invalid signing method", @@ -282,7 +263,6 @@ var jwtTestData = []struct { []error{jwt.ErrTokenSignatureInvalid}, jwt.NewParser(jwt.WithValidMethods([]string{"HS256"})), jwt.SigningMethodRS256, - nil, }, { "valid RSA signing method", @@ -293,7 +273,6 @@ var jwtTestData = []struct { nil, jwt.NewParser(jwt.WithValidMethods([]string{"RS256", "HS256"})), jwt.SigningMethodRS256, - nil, }, { "ECDSA signing method not accepted", @@ -304,7 +283,6 @@ var jwtTestData = []struct { []error{jwt.ErrTokenSignatureInvalid}, jwt.NewParser(jwt.WithValidMethods([]string{"RS256", "HS256"})), jwt.SigningMethodES256, - nil, }, { "valid ECDSA signing method", @@ -315,7 +293,6 @@ var jwtTestData = []struct { nil, jwt.NewParser(jwt.WithValidMethods([]string{"HS256", "ES256"})), jwt.SigningMethodES256, - nil, }, { "JSON Number", @@ -326,7 +303,6 @@ var jwtTestData = []struct { nil, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, - nil, }, { "JSON Number - basic expired", @@ -337,7 +313,6 @@ var jwtTestData = []struct { []error{jwt.ErrTokenExpired}, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, - nil, }, { "JSON Number - basic nbf", @@ -348,7 +323,6 @@ var jwtTestData = []struct { []error{jwt.ErrTokenNotValidYet}, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, - nil, }, { "JSON Number - expired and nbf", @@ -359,7 +333,6 @@ var jwtTestData = []struct { []error{jwt.ErrTokenNotValidYet, jwt.ErrTokenExpired}, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, - nil, }, { "SkipClaimsValidation during token parsing", @@ -370,7 +343,6 @@ var jwtTestData = []struct { nil, jwt.NewParser(jwt.WithJSONNumber(), jwt.WithoutClaimsValidation()), jwt.SigningMethodRS256, - nil, }, { "RFC7519 Claims", @@ -383,7 +355,6 @@ var jwtTestData = []struct { nil, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, - nil, }, { "RFC7519 Claims - single aud", @@ -396,7 +367,6 @@ var jwtTestData = []struct { nil, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, - nil, }, { "RFC7519 Claims - multiple aud", @@ -409,7 +379,6 @@ var jwtTestData = []struct { nil, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, - nil, }, { "RFC7519 Claims - single aud with wrong type", @@ -422,7 +391,6 @@ var jwtTestData = []struct { []error{jwt.ErrTokenMalformed}, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, - nil, }, { "RFC7519 Claims - multiple aud with wrong types", @@ -435,7 +403,6 @@ var jwtTestData = []struct { []error{jwt.ErrTokenMalformed}, jwt.NewParser(jwt.WithJSONNumber()), jwt.SigningMethodRS256, - nil, }, { "RFC7519 Claims - nbf with 60s skew", @@ -446,7 +413,6 @@ var jwtTestData = []struct { []error{jwt.ErrTokenNotValidYet}, jwt.NewParser(jwt.WithLeeway(time.Minute)), jwt.SigningMethodRS256, - nil, }, { "RFC7519 Claims - nbf with 120s skew", @@ -457,7 +423,6 @@ var jwtTestData = []struct { nil, jwt.NewParser(jwt.WithLeeway(2 * time.Minute)), jwt.SigningMethodRS256, - nil, }, { "custom json encoder", @@ -466,9 +431,21 @@ var jwtTestData = []struct { jwt.MapClaims{"foo": "bar"}, true, nil, + jwt.NewParser(jwt.WithJSONDecoder(json.Unmarshal, json.NewDecoder)), + jwt.SigningMethodRS256, + }, + { + "custom json encoder - use numbers", + "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg", + defaultKeyFunc, + jwt.MapClaims{"foo": "bar"}, + true, nil, + jwt.NewParser( + jwt.WithJSONDecoder(json.Unmarshal, json.NewDecoder), + jwt.WithJSONNumber(), + ), jwt.SigningMethodRS256, - []jwt.ParserOption{jwt.WithJSONUnmarshal(json.Unmarshal)}, }, { "custom base64 encoder", @@ -477,9 +454,8 @@ var jwtTestData = []struct { jwt.MapClaims{"foo": "bar"}, true, nil, - nil, + jwt.NewParser(jwt.WithBase64Decoder(base64.RawURLEncoding.DecodeString)), jwt.SigningMethodRS256, - []jwt.ParserOption{jwt.WithBase64Decoder(base64.RawURLEncoding.DecodeString)}, }, { "rejects if exp is required but missing", @@ -521,7 +497,7 @@ func TestParser_Parse(t *testing.T) { var err error var parser = data.parser if parser == nil { - parser = jwt.NewParser(data.options...) + parser = jwt.NewParser() } // Figure out correct claims type switch data.claims.(type) { From 7684d3e29a9a4805a5b765fcdf63317a8d1df2e1 Mon Sep 17 00:00:00 2001 From: Christian Banse Date: Thu, 14 Sep 2023 21:25:59 +0200 Subject: [PATCH 5/8] Support almost all base64 options --- encoder.go | 21 +++++++++++---------- errors.go | 1 + parser.go | 32 ++++++++++++++++++++++---------- parser_option.go | 7 ++++--- parser_test.go | 2 +- token.go | 14 +++++++------- token_option.go | 4 ++-- token_test.go | 2 +- 8 files changed, 49 insertions(+), 34 deletions(-) diff --git a/encoder.go b/encoder.go index 375cb692..5f21e808 100644 --- a/encoder.go +++ b/encoder.go @@ -2,20 +2,21 @@ package jwt import "io" -// Base64Encoder is an interface that allows to implement custom Base64 encoding -// algorithms. -type Base64EncodeFunc func(src []byte) string +type Base64Encoding interface { + EncodeToString(src []byte) string + DecodeString(s string) ([]byte, error) +} -// Base64Decoder is an interface that allows to implement custom Base64 decoding -// algorithms. -type Base64DecodeFunc func(s string) ([]byte, error) +type Stricter[T Base64Encoding] interface { + Strict() T +} -// JSONEncoder is an interface that allows to implement custom JSON encoding -// algorithms. +// JSONMarshalFunc is an function type that allows to implement custom JSON +// encoding algorithms. type JSONMarshalFunc func(v any) ([]byte, error) -// JSONUnmarshal is an interface that allows to implement custom JSON unmarshal -// algorithms. +// JSONUnmarshalFunc is an function type that allows to implement custom JSON +// unmarshal algorithms. type JSONUnmarshalFunc func(data []byte, v any) error type JSONDecoder interface { diff --git a/errors.go b/errors.go index 23bb616d..a8fe9be9 100644 --- a/errors.go +++ b/errors.go @@ -22,6 +22,7 @@ var ( ErrTokenInvalidId = errors.New("token has invalid id") ErrTokenInvalidClaims = errors.New("token has invalid claims") ErrInvalidType = errors.New("invalid type for claim") + ErrUnsupported = errors.New("operation is unsupported") ) // joinedError is an error type that works similar to what [errors.Join] diff --git a/parser.go b/parser.go index 57fbe706..65570b47 100644 --- a/parser.go +++ b/parser.go @@ -26,12 +26,11 @@ type Parser struct { type decoders struct { jsonUnmarshal JSONUnmarshalFunc jsonNewDecoder JSONNewDecoderFunc[JSONDecoder] - base64Decode Base64DecodeFunc - // This field is disabled when using a custom base64 encoder. - decodeStrict bool + rawUrlBase64Encoding Base64Encoding + urlBase64Encoding Base64Encoding - // This field is disabled when using a custom base64 encoder. + decodeStrict bool decodePaddingAllowed bool } @@ -227,22 +226,35 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke // take into account whether the [Parser] is configured with additional options, // such as [WithStrictDecoding] or [WithPaddingAllowed]. func (p *Parser) DecodeSegment(seg string) ([]byte, error) { - if p.base64Decode != nil { - return p.base64Decode(seg) + var encoding Base64Encoding + if p.rawUrlBase64Encoding != nil { + encoding = p.rawUrlBase64Encoding + } else { + encoding = base64.RawURLEncoding } - encoding := base64.RawURLEncoding - if p.decodePaddingAllowed { if l := len(seg) % 4; l > 0 { seg += strings.Repeat("=", 4-l) } - encoding = base64.URLEncoding + + if p.urlBase64Encoding != nil { + encoding = p.urlBase64Encoding + } else { + encoding = base64.URLEncoding + } } if p.decodeStrict { - encoding = encoding.Strict() + // For now we can only support the standard library here because of the + // current state of the type parameter system + stricter, ok := encoding.(Stricter[*base64.Encoding]) + if !ok { + return nil, newError("strict mode is only supported in encoding/base64", ErrUnsupported) + } + encoding = stricter.Strict() } + return encoding.DecodeString(seg) } diff --git a/parser_option.go b/parser_option.go index 9e30b439..8d701a0d 100644 --- a/parser_option.go +++ b/parser_option.go @@ -142,9 +142,10 @@ func WithJSONDecoder[T JSONDecoder](f JSONUnmarshalFunc, f2 JSONNewDecoderFunc[T } } -// WithBase64Decoder supports a custom [Base64Decoder] to use in parsing the JWT. -func WithBase64Decoder(f Base64DecodeFunc) ParserOption { +// WithBase64Decoder supports a custom [Base64Encoding] to use in parsing the JWT. +func WithBase64Decoder(rawURL Base64Encoding, url Base64Encoding) ParserOption { return func(p *Parser) { - p.base64Decode = f + p.rawUrlBase64Encoding = rawURL + p.urlBase64Encoding = url } } diff --git a/parser_test.go b/parser_test.go index 0e7b32fe..3319c979 100644 --- a/parser_test.go +++ b/parser_test.go @@ -454,7 +454,7 @@ var jwtTestData = []struct { jwt.MapClaims{"foo": "bar"}, true, nil, - jwt.NewParser(jwt.WithBase64Decoder(base64.RawURLEncoding.DecodeString)), + jwt.NewParser(jwt.WithBase64Decoder(base64.RawURLEncoding, base64.URLEncoding)), jwt.SigningMethodRS256, }, { diff --git a/token.go b/token.go index 585770f3..93b87a36 100644 --- a/token.go +++ b/token.go @@ -39,8 +39,8 @@ type Token struct { } type encoders struct { - jsonMarshal JSONMarshalFunc // jsonEncoder is the custom json encoder/decoder - base64Encode Base64EncodeFunc // base64Encoder is the custom base64 encoder/decoder + jsonMarshal JSONMarshalFunc // jsonEncoder is the custom json encoder/decoder + base64Encoding Base64Encoding // base64Encoder is the custom base64 encoding } // New creates a new [Token] with the specified signing method and an empty map @@ -114,12 +114,12 @@ func (t *Token) SigningString() (string, error) { // [TokenOption]. Therefore, this function exists as a method of [Token], rather // than a global function. func (t *Token) EncodeSegment(seg []byte) string { - var enc Base64EncodeFunc - if t.base64Encode != nil { - enc = t.base64Encode + var enc Base64Encoding + if t.base64Encoding != nil { + enc = t.base64Encoding } else { - enc = base64.RawURLEncoding.EncodeToString + enc = base64.RawURLEncoding } - return enc(seg) + return enc.EncodeToString(seg) } diff --git a/token_option.go b/token_option.go index 3a9ca8d9..0fab6a37 100644 --- a/token_option.go +++ b/token_option.go @@ -10,8 +10,8 @@ func WithJSONEncoder(f JSONMarshalFunc) TokenOption { } } -func WithBase64Encoder(f Base64EncodeFunc) TokenOption { +func WithBase64Encoder(enc Base64Encoding) TokenOption { return func(token *Token) { - token.base64Encode = f + token.base64Encoding = enc } } diff --git a/token_test.go b/token_test.go index 7c76fade..d572339e 100644 --- a/token_test.go +++ b/token_test.go @@ -53,7 +53,7 @@ func TestToken_SigningString(t1 *testing.T) { Valid: false, Options: []jwt.TokenOption{ jwt.WithJSONEncoder(json.Marshal), - jwt.WithBase64Encoder(base64.StdEncoding.EncodeToString), + jwt.WithBase64Encoder(base64.StdEncoding), }, }, want: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30", From 3ae2a4a3c8b8e42a0d9c896dd60f05c1c4e6437c Mon Sep 17 00:00:00 2001 From: Christian Banse Date: Fri, 15 Sep 2023 22:33:21 +0200 Subject: [PATCH 6/8] Added more Godoc --- encoder.go | 2 ++ parser_option.go | 54 ++++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 50 insertions(+), 6 deletions(-) diff --git a/encoder.go b/encoder.go index 5f21e808..1a989068 100644 --- a/encoder.go +++ b/encoder.go @@ -2,6 +2,8 @@ package jwt import "io" +// Base64Encoding represents an object that can encode and decode base64. A +// common example is [encoding/base64.Encoding]. type Base64Encoding interface { EncodeToString(src []byte) string DecodeString(s string) ([]byte, error) diff --git a/parser_option.go b/parser_option.go index 8d701a0d..9513387b 100644 --- a/parser_option.go +++ b/parser_option.go @@ -124,25 +124,67 @@ func WithPaddingAllowed() ParserOption { // WithStrictDecoding will switch the codec used for decoding JWTs into strict // mode. In this mode, the decoder requires that trailing padding bits are zero, // as described in RFC 4648 section 3.5. +// +// Note: This is only supported when using [encoding/base64.Encoding], but not +// by any other decoder specified with [WithBase64Decoder]. func WithStrictDecoding() ParserOption { return func(p *Parser) { p.decodeStrict = true } } -// WithJSONDecoder supports a custom [JSONUnmarshal] to use in parsing the JWT. -func WithJSONDecoder[T JSONDecoder](f JSONUnmarshalFunc, f2 JSONNewDecoderFunc[T]) ParserOption { +// WithJSONDecoder supports a custom JSON decoder to use in parsing the JWT. +// There are two functions that can be supplied: +// - jsonUnmarshal is a [JSONUnmarshalFunc] that is used for the +// un-marshalling the header and claims when no other options are specified +// - jsonNewDecoder is a [JSONNewDecoderFunc] that is used to create an object +// satisfying the [JSONDecoder] interface. +// +// The latter is used when the [WithJSONNumber] option is used. +// +// If any of the supplied functions is set to nil, the defaults from the Go +// standard library, [encoding/json.Unmarshal] and [encoding/json.NewDecoder] +// are used. +// +// Example using the https://github.com/bytedance/sonic library. +// +// import ( +// "github.com/bytedance/sonic" +// ) +// +// var parser = NewParser(WithJSONDecoder(sonic.Unmarshal, sonic.ConfigDefault.NewDecoder)) +func WithJSONDecoder[T JSONDecoder](jsonUnmarshal JSONUnmarshalFunc, jsonNewDecoder JSONNewDecoderFunc[T]) ParserOption { return func(p *Parser) { - p.jsonUnmarshal = f + p.jsonUnmarshal = jsonUnmarshal // This seems to be necessary, since we don't want to store the specific - // JSONDecoder type in our parser, but need it in the function interface. + // JSONDecoder type in our parser, but need it in the function + // interface. p.jsonNewDecoder = func(r io.Reader) JSONDecoder { - return f2(r) + return jsonNewDecoder(r) } } } -// WithBase64Decoder supports a custom [Base64Encoding] to use in parsing the JWT. +// WithBase64Decoder supports a custom Base64 when decoding a base64 encoded +// token. Two encoding can be specified: +// - rawURL needs to contain a [Base64Encoding] that is based on base64url +// without padding. This is used for parsing tokens with the default +// options. +// - url needs to contain a [Base64Encoding] based on base64url with padding. +// The sole use of this to decode tokens when [WithPaddingAllowed] is +// enabled. +// +// If any of the supplied encodings are set to nil, the defaults from the Go +// standard library, [encoding/base64.RawURLEncoding] and +// [encoding/base64.URLEncoding] are used. +// +// Example using the https://github.com/segmentio/asm library. +// +// import ( +// asmbase64 "github.com/segmentio/asm/base64" +// ) +// +// var parser = NewParser(WithBase64Decoder(asmbase64.RawURLEncoding, asmbase64.URLEncoding)) func WithBase64Decoder(rawURL Base64Encoding, url Base64Encoding) ParserOption { return func(p *Parser) { p.rawUrlBase64Encoding = rawURL From f64f4609f38566ac018a564205c25729e0d07b16 Mon Sep 17 00:00:00 2001 From: Christian Banse Date: Sat, 16 Sep 2023 13:31:27 +0200 Subject: [PATCH 7/8] supporting strict() for all libraries --- encoder.go | 6 ++++++ parser.go | 25 ++++++++++++++----------- parser_option.go | 16 +++++++++++++--- 3 files changed, 33 insertions(+), 14 deletions(-) diff --git a/encoder.go b/encoder.go index 1a989068..8b13411d 100644 --- a/encoder.go +++ b/encoder.go @@ -9,10 +9,16 @@ type Base64Encoding interface { DecodeString(s string) ([]byte, error) } +type StrictFunc[T Base64Encoding] func() T + type Stricter[T Base64Encoding] interface { Strict() T } +func DoStrict[S Base64Encoding, T Stricter[S]](x T) Base64Encoding { + return x.Strict() +} + // JSONMarshalFunc is an function type that allows to implement custom JSON // encoding algorithms. type JSONMarshalFunc func(v any) ([]byte, error) diff --git a/parser.go b/parser.go index 65570b47..5b774e1f 100644 --- a/parser.go +++ b/parser.go @@ -12,23 +12,24 @@ type Parser struct { // If populated, only these methods will be considered valid. validMethods []string - // Use JSON Number format in JSON decoder. - useJSONNumber bool - // Skip claims validation during token parsing. skipClaimsValidation bool validator *Validator - decoders + decoding } -type decoders struct { +type decoding struct { jsonUnmarshal JSONUnmarshalFunc jsonNewDecoder JSONNewDecoderFunc[JSONDecoder] rawUrlBase64Encoding Base64Encoding urlBase64Encoding Base64Encoding + strict StrictFunc[Base64Encoding] + + // Use JSON Number format in JSON decoder. + useJSONNumber bool decodeStrict bool decodePaddingAllowed bool @@ -246,13 +247,15 @@ func (p *Parser) DecodeSegment(seg string) ([]byte, error) { } if p.decodeStrict { - // For now we can only support the standard library here because of the - // current state of the type parameter system - stricter, ok := encoding.(Stricter[*base64.Encoding]) - if !ok { - return nil, newError("strict mode is only supported in encoding/base64", ErrUnsupported) + if p.strict != nil { + encoding = p.strict() + } else { + stricter, ok := encoding.(Stricter[*base64.Encoding]) + if !ok { + return nil, newError("WithStrictDecoding() was enabled but supplied base64 encoder does not support strict mode", ErrUnsupported) + } + encoding = stricter.Strict() } - encoding = stricter.Strict() } return encoding.DecodeString(seg) diff --git a/parser_option.go b/parser_option.go index 9513387b..5cf5cd99 100644 --- a/parser_option.go +++ b/parser_option.go @@ -152,7 +152,7 @@ func WithStrictDecoding() ParserOption { // "github.com/bytedance/sonic" // ) // -// var parser = NewParser(WithJSONDecoder(sonic.Unmarshal, sonic.ConfigDefault.NewDecoder)) +// var parser = jwt.NewParser(jwt.WithJSONDecoder(sonic.Unmarshal, sonic.ConfigDefault.NewDecoder)) func WithJSONDecoder[T JSONDecoder](jsonUnmarshal JSONUnmarshalFunc, jsonNewDecoder JSONNewDecoderFunc[T]) ParserOption { return func(p *Parser) { p.jsonUnmarshal = jsonUnmarshal @@ -184,10 +184,20 @@ func WithJSONDecoder[T JSONDecoder](jsonUnmarshal JSONUnmarshalFunc, jsonNewDeco // asmbase64 "github.com/segmentio/asm/base64" // ) // -// var parser = NewParser(WithBase64Decoder(asmbase64.RawURLEncoding, asmbase64.URLEncoding)) -func WithBase64Decoder(rawURL Base64Encoding, url Base64Encoding) ParserOption { +// var parser = jwt.NewParser(jwt.WithBase64Decoder(asmbase64.RawURLEncoding, asmbase64.URLEncoding)) +func WithBase64Decoder[T Base64Encoding](rawURL Base64Encoding, url T) ParserOption { return func(p *Parser) { p.rawUrlBase64Encoding = rawURL p.urlBase64Encoding = url + + // Check, whether the library supports the Strict() function + stricter, ok := rawURL.(Stricter[T]) + if ok { + // We need to get rid of the type parameter T, so we need to wrap it + // here + p.strict = func() Base64Encoding { + return stricter.Strict() + } + } } } From 5e2ab08478c2ff56a96e8f95f9a03bec0a38c454 Mon Sep 17 00:00:00 2001 From: "zhouyiheng.go" Date: Sat, 6 Jul 2024 13:33:20 +0800 Subject: [PATCH 8/8] test: supply example tests for custom encoder and decoder --- encoder.go | 4 ++-- example_test.go | 46 ++++++++++++++++++++++++++++++++++++++++++++++ token_test.go | 2 +- 3 files changed, 49 insertions(+), 3 deletions(-) diff --git a/encoder.go b/encoder.go index 8b13411d..6f49ec28 100644 --- a/encoder.go +++ b/encoder.go @@ -19,11 +19,11 @@ func DoStrict[S Base64Encoding, T Stricter[S]](x T) Base64Encoding { return x.Strict() } -// JSONMarshalFunc is an function type that allows to implement custom JSON +// JSONMarshalFunc is a function type that allows to implement custom JSON // encoding algorithms. type JSONMarshalFunc func(v any) ([]byte, error) -// JSONUnmarshalFunc is an function type that allows to implement custom JSON +// JSONUnmarshalFunc is a function type that allows to implement custom JSON // unmarshal algorithms. type JSONUnmarshalFunc func(data []byte, v any) error diff --git a/example_test.go b/example_test.go index 651841de..abd9e38d 100644 --- a/example_test.go +++ b/example_test.go @@ -1,6 +1,8 @@ package jwt_test import ( + "encoding/base64" + "encoding/json" "errors" "fmt" "log" @@ -9,6 +11,21 @@ import ( "github.com/golang-jwt/jwt/v5" ) +// Example creating a token by passing jwt.WithJSONEncoder or jwt.WithBase64Encoder to +// options to specify the custom encoders when sign the token to string. +// You can try other encoders when you get tired of the standard library. +func ExampleNew_customEncoder() { + mySigningKey := []byte("AllYourBase") + + customJSONEncoderFunc := json.Marshal + customBase64Encoder := base64.RawURLEncoding + token := jwt.New(jwt.SigningMethodHS256, jwt.WithJSONEncoder(customJSONEncoderFunc), jwt.WithBase64Encoder(customBase64Encoder)) + + ss, err := token.SignedString(mySigningKey) + fmt.Println(ss, err) + // Output: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.E9f4bo8SFbMyEfLEOEXEO2RGcO9cQhznYfSKqTjWwrM +} + // Example (atypical) using the RegisteredClaims type by itself to parse a token. // The RegisteredClaims type is designed to be embedded into your custom types // to provide standard validation features. You can use it alone, but there's @@ -161,6 +178,35 @@ func ExampleParseWithClaims_customValidation() { // Output: bar test } +// Example parsing a string to a token with using a custom decoders. +// It's convenient to use the jwt.WithJSONDecoder or jwt.WithBase64Decoder options when create a parser +// to parse string to token by using your favorite JSON or Base64 decoders. +func ExampleParseWithClaims_customDecoder() { + tokenString := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJ0ZXN0IiwiYXVkIjoic2luZ2xlIn0.QAWg1vGvnqRuCFTMcPkjZljXHh8U3L_qUjszOtQbeaA" + + customJSONUnmarshalFunc := json.Unmarshal + customNewJSONDecoderFunc := json.NewDecoder + + customBase64RawUrlEncoder := base64.RawURLEncoding + customBase64UrlEncoder := base64.URLEncoding + + jwtParser := jwt.NewParser(jwt.WithJSONDecoder(customJSONUnmarshalFunc, customNewJSONDecoderFunc), jwt.WithBase64Decoder(customBase64RawUrlEncoder, customBase64UrlEncoder)) + + token, err := jwtParser.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + return []byte("AllYourBase"), nil + }) + if err != nil { + log.Fatal(err) + } + if !token.Valid { + log.Fatal("invalid") + } else { + fmt.Println("valid") + } + + // Output: valid +} + // An example of parsing the error types using errors.Is. func ExampleParse_errorChecking() { // Token from another example. This token is expired diff --git a/token_test.go b/token_test.go index d572339e..ff9eefe5 100644 --- a/token_test.go +++ b/token_test.go @@ -41,7 +41,7 @@ func TestToken_SigningString(t1 *testing.T) { wantErr: false, }, { - name: "", + name: "encode with custom json and base64 encoder", fields: fields{ Raw: "", Method: jwt.SigningMethodHS256,