Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support custom json and base64 encoders for Token and Parser #301

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions encoder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package jwt

import "io"

// Base64Encoding represents an object that can encode and decode base64. A
// common example is [encoding/base64.Encoding].
type Base64Encoding interface {
EncodeToString(src []byte) string
DecodeString(s string) ([]byte, error)
}

type StrictFunc[T Base64Encoding] func() T
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if this type really needs to be exported at all


type Stricter[T Base64Encoding] interface {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This at least needs a comment if we want to keep It exported, it could probably also be private I guess

Strict() T
}

func DoStrict[S Base64Encoding, T Stricter[S]](x T) Base64Encoding {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Double-check if this is actually still needed, which I think it is not?

return x.Strict()
}

// JSONMarshalFunc is a function type that allows to implement custom JSON
// encoding algorithms.
type JSONMarshalFunc func(v any) ([]byte, error)

// JSONUnmarshalFunc is a function type that allows to implement custom JSON
// unmarshal algorithms.
type JSONUnmarshalFunc func(data []byte, v any) error

type JSONDecoder interface {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs a comment because its a public interface

UseNumber()
Decode(v any) error
}

type JSONNewDecoderFunc[T JSONDecoder] func(r io.Reader) T
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs a comment

1 change: 1 addition & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ var (
ErrTokenInvalidId = errors.New("token has invalid id")
ErrTokenInvalidClaims = errors.New("token has invalid claims")
ErrInvalidType = errors.New("invalid type for claim")
ErrUnsupported = errors.New("operation is unsupported")
)

// joinedError is an error type that works similar to what [errors.Join]
Expand Down
46 changes: 46 additions & 0 deletions example_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package jwt_test

import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"log"
Expand All @@ -9,6 +11,21 @@ import (
"github.com/golang-jwt/jwt/v5"
)

// Example creating a token by passing jwt.WithJSONEncoder or jwt.WithBase64Encoder to
// options to specify the custom encoders when sign the token to string.
// You can try other encoders when you get tired of the standard library.
func ExampleNew_customEncoder() {
mySigningKey := []byte("AllYourBase")

customJSONEncoderFunc := json.Marshal
customBase64Encoder := base64.RawURLEncoding
token := jwt.New(jwt.SigningMethodHS256, jwt.WithJSONEncoder(customJSONEncoderFunc), jwt.WithBase64Encoder(customBase64Encoder))

ss, err := token.SignedString(mySigningKey)
fmt.Println(ss, err)
// Output: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.E9f4bo8SFbMyEfLEOEXEO2RGcO9cQhznYfSKqTjWwrM <nil>
}

// Example (atypical) using the RegisteredClaims type by itself to parse a token.
// The RegisteredClaims type is designed to be embedded into your custom types
// to provide standard validation features. You can use it alone, but there's
Expand Down Expand Up @@ -161,6 +178,35 @@ func ExampleParseWithClaims_customValidation() {
// Output: bar test
}

// Example parsing a string to a token with using a custom decoders.
// It's convenient to use the jwt.WithJSONDecoder or jwt.WithBase64Decoder options when create a parser
// to parse string to token by using your favorite JSON or Base64 decoders.
func ExampleParseWithClaims_customDecoder() {
tokenString := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJ0ZXN0IiwiYXVkIjoic2luZ2xlIn0.QAWg1vGvnqRuCFTMcPkjZljXHh8U3L_qUjszOtQbeaA"

customJSONUnmarshalFunc := json.Unmarshal
customNewJSONDecoderFunc := json.NewDecoder

customBase64RawUrlEncoder := base64.RawURLEncoding
customBase64UrlEncoder := base64.URLEncoding

jwtParser := jwt.NewParser(jwt.WithJSONDecoder(customJSONUnmarshalFunc, customNewJSONDecoderFunc), jwt.WithBase64Decoder(customBase64RawUrlEncoder, customBase64UrlEncoder))

token, err := jwtParser.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
return []byte("AllYourBase"), nil
})
if err != nil {
log.Fatal(err)
}
if !token.Valid {
log.Fatal("invalid")
} else {
fmt.Println("valid")
}

// Output: valid
}

// An example of parsing the error types using errors.Is.
func ExampleParse_errorChecking() {
// Token from another example. This token is expired
Expand Down
96 changes: 71 additions & 25 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,26 @@ type Parser struct {
// If populated, only these methods will be considered valid.
validMethods []string

// Use JSON Number format in JSON decoder.
useJSONNumber bool

// Skip claims validation during token parsing.
skipClaimsValidation bool

validator *Validator

decodeStrict bool
decoding
}

type decoding struct {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

its not private, but still we probably could use some comments here, especially for this outer struct

jsonUnmarshal JSONUnmarshalFunc
jsonNewDecoder JSONNewDecoderFunc[JSONDecoder]

rawUrlBase64Encoding Base64Encoding
urlBase64Encoding Base64Encoding
strict StrictFunc[Base64Encoding]

// Use JSON Number format in JSON decoder.
useJSONNumber bool

decodeStrict bool
decodePaddingAllowed bool
}

Expand Down Expand Up @@ -148,7 +158,18 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke
if headerBytes, err = p.DecodeSegment(parts[0]); err != nil {
return token, parts, newError("could not base64 decode header", ErrTokenMalformed, err)
}
if err = json.Unmarshal(headerBytes, &token.Header); err != nil {

// Choose our JSON decoder. If no custom function is supplied, we use the standard library.
var unmarshal JSONUnmarshalFunc
if p.jsonUnmarshal != nil {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW, why are we determining whether a custom JSON marshaller is present in the Parse() method instead of during the Parser creation? The same applies to other places. Parser.jsonUnmarshal should be initialized during Parser construction, as we're not planning to replace JSON marshallers or base64 encoders on the fly.

unmarshal = p.jsonUnmarshal
} else {
unmarshal = json.Unmarshal
}

// JSON Unmarshal the header
err = unmarshal(headerBytes, &token.Header)
if err != nil {
return token, parts, newError("could not JSON decode header", ErrTokenMalformed, err)
}

Expand All @@ -160,25 +181,31 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke
return token, parts, newError("could not base64 decode claim", ErrTokenMalformed, err)
}

// If `useJSONNumber` is enabled then we must use *json.Decoder to decode
// the claims. However, this comes with a performance penalty so only use
// it if we must and, otherwise, simple use json.Unmarshal.
if !p.useJSONNumber {
// JSON Unmarshal. Special case for map type to avoid weird pointer behavior.
if c, ok := token.Claims.(MapClaims); ok {
err = json.Unmarshal(claimBytes, &c)
} else {
err = json.Unmarshal(claimBytes, &claims)
// If `useJSONNumber` is enabled, then we must use a dedicated JSONDecoder
// to decode the claims. However, this comes with a performance penalty so
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

small nit: we only know definitly that the performance penalty occurs when using json.NewDecoder. Third-party libraries might not have that penalty, so we might at least state this like ... a performance penalty (when using the standard library decoder).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any evidence?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. #303

// only use it if we must and, otherwise, simple use our existing unmarshal
// function.
if p.useJSONNumber {
unmarshal = func(data []byte, v any) error {
buffer := bytes.NewBuffer(claimBytes)

var decoder JSONDecoder
if p.jsonNewDecoder != nil {
decoder = p.jsonNewDecoder(buffer)
} else {
decoder = json.NewDecoder(buffer)
}
decoder.UseNumber()
return decoder.Decode(v)
}
}

// JSON Unmarshal the claims. Special case for map type to avoid weird
// pointer behavior.
if c, ok := token.Claims.(MapClaims); ok {
err = unmarshal(claimBytes, &c)
} else {
dec := json.NewDecoder(bytes.NewBuffer(claimBytes))
dec.UseNumber()
// JSON Decode. Special case for map type to avoid weird pointer behavior.
if c, ok := token.Claims.(MapClaims); ok {
err = dec.Decode(&c)
} else {
err = dec.Decode(&claims)
}
err = unmarshal(claimBytes, &claims)
}
if err != nil {
return token, parts, newError("could not JSON decode claim", ErrTokenMalformed, err)
Expand All @@ -200,18 +227,37 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke
// take into account whether the [Parser] is configured with additional options,
// such as [WithStrictDecoding] or [WithPaddingAllowed].
func (p *Parser) DecodeSegment(seg string) ([]byte, error) {
encoding := base64.RawURLEncoding
var encoding Base64Encoding
if p.rawUrlBase64Encoding != nil {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same comment about stage where custom base64 encoder function set. We can set p.rawUrlBase64Encoding to base64.RawURLEncoding or custom function on Parser creation

Copy link
Collaborator

@oxisto oxisto Oct 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added this as a safe guard if people ever came across the crazy idea to directly create a Parser struct, which is unfortunately possible because we did not hide it behind an interface. But yes, if we assume that in the case you are doing that you REALLY know what you are doing we can do it in the init

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that sounds reasonable.

encoding = p.rawUrlBase64Encoding
} else {
encoding = base64.RawURLEncoding
}

if p.decodePaddingAllowed {
if l := len(seg) % 4; l > 0 {
seg += strings.Repeat("=", 4-l)
}
encoding = base64.URLEncoding

if p.urlBase64Encoding != nil {
encoding = p.urlBase64Encoding
} else {
encoding = base64.URLEncoding
}
}

if p.decodeStrict {
encoding = encoding.Strict()
if p.strict != nil {
encoding = p.strict()
} else {
stricter, ok := encoding.(Stricter[*base64.Encoding])
if !ok {
return nil, newError("WithStrictDecoding() was enabled but supplied base64 encoder does not support strict mode", ErrUnsupported)
}
encoding = stricter.Strict()
}
}

return encoding.DecodeString(seg)
}

Expand Down
77 changes: 76 additions & 1 deletion parser_option.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package jwt

import "time"
import (
"io"
"time"
)

// ParserOption is used to implement functional-style options that modify the
// behavior of the parser. To add new options, just create a function (ideally
Expand Down Expand Up @@ -121,8 +124,80 @@ func WithPaddingAllowed() ParserOption {
// WithStrictDecoding will switch the codec used for decoding JWTs into strict
// mode. In this mode, the decoder requires that trailing padding bits are zero,
// as described in RFC 4648 section 3.5.
//
// Note: This is only supported when using [encoding/base64.Encoding], but not
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this note is not true anymore, since we are checking whether the library supports a Strict() function. We need to check this.

// by any other decoder specified with [WithBase64Decoder].
func WithStrictDecoding() ParserOption {
return func(p *Parser) {
p.decodeStrict = true
}
}

// WithJSONDecoder supports a custom JSON decoder to use in parsing the JWT.
// There are two functions that can be supplied:
// - jsonUnmarshal is a [JSONUnmarshalFunc] that is used for the
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo: double "the"

// un-marshalling the header and claims when no other options are specified
// - jsonNewDecoder is a [JSONNewDecoderFunc] that is used to create an object
// satisfying the [JSONDecoder] interface.
//
// The latter is used when the [WithJSONNumber] option is used.
//
// If any of the supplied functions is set to nil, the defaults from the Go
// standard library, [encoding/json.Unmarshal] and [encoding/json.NewDecoder]
// are used.
//
// Example using the https://github.com/bytedance/sonic library.
//
// import (
// "github.com/bytedance/sonic"
// )
//
// var parser = jwt.NewParser(jwt.WithJSONDecoder(sonic.Unmarshal, sonic.ConfigDefault.NewDecoder))
func WithJSONDecoder[T JSONDecoder](jsonUnmarshal JSONUnmarshalFunc, jsonNewDecoder JSONNewDecoderFunc[T]) ParserOption {
return func(p *Parser) {
p.jsonUnmarshal = jsonUnmarshal
// This seems to be necessary, since we don't want to store the specific
// JSONDecoder type in our parser, but need it in the function
// interface.
p.jsonNewDecoder = func(r io.Reader) JSONDecoder {
return jsonNewDecoder(r)
}
}
}

// WithBase64Decoder supports a custom Base64 when decoding a base64 encoded
Copy link
Collaborator

@oxisto oxisto Jul 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

grammar: "a custom Base64 [something missing here]"

// token. Two encoding can be specified:
// - rawURL needs to contain a [Base64Encoding] that is based on base64url
// without padding. This is used for parsing tokens with the default
// options.
// - url needs to contain a [Base64Encoding] based on base64url with padding.
// The sole use of this to decode tokens when [WithPaddingAllowed] is
// enabled.
//
// If any of the supplied encodings are set to nil, the defaults from the Go
// standard library, [encoding/base64.RawURLEncoding] and
// [encoding/base64.URLEncoding] are used.
//
// Example using the https://github.com/segmentio/asm library.
//
// import (
// asmbase64 "github.com/segmentio/asm/base64"
// )
//
// var parser = jwt.NewParser(jwt.WithBase64Decoder(asmbase64.RawURLEncoding, asmbase64.URLEncoding))
func WithBase64Decoder[T Base64Encoding](rawURL Base64Encoding, url T) ParserOption {
return func(p *Parser) {
p.rawUrlBase64Encoding = rawURL
p.urlBase64Encoding = url

// Check, whether the library supports the Strict() function
stricter, ok := rawURL.(Stricter[T])
if ok {
// We need to get rid of the type parameter T, so we need to wrap it
// here
p.strict = func() Base64Encoding {
return stricter.Strict()
}
}
}
}
34 changes: 34 additions & 0 deletions parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package jwt_test
import (
"crypto"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -423,6 +424,39 @@ var jwtTestData = []struct {
jwt.NewParser(jwt.WithLeeway(2 * time.Minute)),
jwt.SigningMethodRS256,
},
{
"custom json encoder",
"eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg",
defaultKeyFunc,
jwt.MapClaims{"foo": "bar"},
true,
nil,
jwt.NewParser(jwt.WithJSONDecoder(json.Unmarshal, json.NewDecoder)),
jwt.SigningMethodRS256,
},
{
"custom json encoder - use numbers",
"eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg",
defaultKeyFunc,
jwt.MapClaims{"foo": "bar"},
true,
nil,
jwt.NewParser(
jwt.WithJSONDecoder(json.Unmarshal, json.NewDecoder),
jwt.WithJSONNumber(),
),
jwt.SigningMethodRS256,
},
{
"custom base64 encoder",
"eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg",
defaultKeyFunc,
jwt.MapClaims{"foo": "bar"},
true,
nil,
jwt.NewParser(jwt.WithBase64Decoder(base64.RawURLEncoding, base64.URLEncoding)),
jwt.SigningMethodRS256,
},
{
"rejects if exp is required but missing",
"", // autogen
Expand Down
Loading