Skip to content

Commit

Permalink
Add option to prefer text unmarshaler when data type is String
Browse files Browse the repository at this point in the history
This library currently always prefers to use encoding.BinaryUnmashaler
when it's implemented by the target type.

This might lead to problems when the type also has a text
representation implemented as encoding.TextUnmarshaler.

Consider netip.Addr from stdlib as example, which implements both.
This library won't be able decode MessagePack containing a string
"192.0.2.1" into *netip.Addr because it will attempt to use
encoding.BinaryUnmashaler which doesn't expect text representation.

Fortunately, MessagePack has distinct string and binary types, so we can
check the source data type before choosing the interface to use.

This commit changes the behaviour of decoder as follows. When

1) target Go data type implements both BinaryUnmashaler and
   TextUnmarshaler
2) source MessagePack data type is a string

TextUnmarshaler will be preferred over BinaryUnmashaler.

This feature is gated behind a Decoder option, because it is potentially
backward-incompatible change.

See vmihailenco#370
  • Loading branch information
WGH- committed Feb 27, 2024
1 parent 19c91df commit d8fe095
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 4 deletions.
15 changes: 15 additions & 0 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ const (
disallowUnknownFieldsFlag
usePreallocateValues
disableAllocLimitFlag
preferTextUnmarshalerForString
)

type bufReader interface {
Expand Down Expand Up @@ -184,6 +185,20 @@ func (d *Decoder) DisableAllocLimit(on bool) {
}
}

// PreferTextUnmarshalerForString makes the decoder prefer [encoding.TextUnmarshaler]
// over [encoding.BinaryUnmarshaler] when both are implemented, and source
// MessagePack data is a String (as opposed to Binary).
//
// If this option is not enabled, [encoding.BinaryUnmarshaler] will be preferred
// instead, regardless of MessagePack data type.
func (d *Decoder) PreferTextUnmarshalerForString(on bool) {
if on {
d.flags |= preferTextUnmarshalerForString
} else {
d.flags &= ^preferTextUnmarshalerForString
}
}

// Buffered returns a reader of the data remaining in the Decoder's buffer.
// The reader is valid until the next call to Decode.
func (d *Decoder) Buffered() io.Reader {
Expand Down
39 changes: 35 additions & 4 deletions decode_value.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"errors"
"fmt"
"reflect"

"github.com/vmihailenco/msgpack/v5/msgpcode"
)

var (
Expand Down Expand Up @@ -70,10 +72,16 @@ func _getDecoder(typ reflect.Type) decoderFunc {
if typ.Implements(unmarshalerType) {
return nilAwareDecoder(typ, unmarshalValue)
}
if typ.Implements(binaryUnmarshalerType) {

implementsBinaryUnmarshaler := typ.Implements(binaryUnmarshalerType)
implementsTextUnmarshaler := typ.Implements(textUnmarshalerType)
if implementsBinaryUnmarshaler && implementsTextUnmarshaler {
return nilAwareDecoder(typ, unmarshalBinaryOrTextValue)
}
if implementsBinaryUnmarshaler {
return nilAwareDecoder(typ, unmarshalBinaryValue)
}
if typ.Implements(textUnmarshalerType) {
if implementsTextUnmarshaler {
return nilAwareDecoder(typ, unmarshalTextValue)
}

Expand All @@ -86,10 +94,15 @@ func _getDecoder(typ reflect.Type) decoderFunc {
if ptr.Implements(unmarshalerType) {
return addrDecoder(nilAwareDecoder(typ, unmarshalValue))
}
if ptr.Implements(binaryUnmarshalerType) {
implementsBinaryUnmarshaler := ptr.Implements(binaryUnmarshalerType)
implementsTextUnmarshaler := ptr.Implements(textUnmarshalerType)
if implementsBinaryUnmarshaler && implementsTextUnmarshaler {
return addrDecoder(nilAwareDecoder(typ, unmarshalBinaryOrTextValue))
}
if implementsBinaryUnmarshaler {
return addrDecoder(nilAwareDecoder(typ, unmarshalBinaryValue))
}
if ptr.Implements(textUnmarshalerType) {
if implementsTextUnmarshaler {
return addrDecoder(nilAwareDecoder(typ, unmarshalTextValue))
}
}
Expand Down Expand Up @@ -249,3 +262,21 @@ func unmarshalTextValue(d *Decoder, v reflect.Value) error {
unmarshaler := v.Interface().(encoding.TextUnmarshaler)
return unmarshaler.UnmarshalText(data)
}

func unmarshalBinaryOrTextValue(d *Decoder, v reflect.Value) error {
useText := false
if d.flags&preferTextUnmarshalerForString != 0 {
code, err := d.PeekCode()
if err != nil {
return err
}
if msgpcode.IsString(code) {
useText = true
}
}
if useText {
return unmarshalTextValue(d, v)
} else {
return unmarshalBinaryValue(d, v)
}
}
67 changes: 67 additions & 0 deletions types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package msgpack_test

import (
"bytes"
"encoding"
"encoding/binary"
"encoding/hex"
"fmt"
"math"
Expand Down Expand Up @@ -427,6 +429,8 @@ type typeTest struct {
wantnil bool
wantzero bool
wanted interface{}

preferTextUnmarshalerForString bool
}

func (t typeTest) String() string {
Expand All @@ -442,6 +446,36 @@ func (t *typeTest) requireErr(err error, s string) {
}
}

type binaryTextType uint32

// UnmarshalText implements encoding.TextUnmarshaler
func (v *binaryTextType) UnmarshalText(text []byte) error {
var b [4]byte
n, err := hex.Decode(b[:], text)
if err != nil {
return err
}
if n != 4 {
return fmt.Errorf("invalid length %d", n)
}
*v = binaryTextType(binary.BigEndian.Uint32(b[:]))
return nil
}

// UnmarshalBinary implements encoding.BinaryUnmarshaler
func (v *binaryTextType) UnmarshalBinary(data []byte) error {
if n := len(data); n != 4 {
return fmt.Errorf("invalid length %d", n)
}
*v = binaryTextType(binary.BigEndian.Uint32(data))
return nil
}

var (
_ encoding.TextUnmarshaler = new(binaryTextType)
_ encoding.BinaryUnmarshaler = new(binaryTextType)
)

var (
intSlice = make([]int, 0, 3)
repoURL, _ = url.Parse("https://github.com/vmihailenco/msgpack")
Expand Down Expand Up @@ -622,6 +656,36 @@ var (
},

{in: big.NewInt(123), out: new(big.Int)},

{
in: "deadbeef",
out: new(binaryTextType),
wanted: binaryTextType(0xdeadbeef),
decErr: "invalid length 8",

preferTextUnmarshalerForString: false,
},
{
in: "deadbeef",
out: new(binaryTextType),
wanted: binaryTextType(0xdeadbeef),

preferTextUnmarshalerForString: true,
},
{
in: []byte{0xde, 0xad, 0xbe, 0xef},
out: new(binaryTextType),
wanted: binaryTextType(0xdeadbeef),

preferTextUnmarshalerForString: false,
},
{
in: []byte{0xde, 0xad, 0xbe, 0xef},
out: new(binaryTextType),
wanted: binaryTextType(0xdeadbeef),

preferTextUnmarshalerForString: true,
},
}
)

Expand Down Expand Up @@ -655,6 +719,9 @@ func TestTypes(t *testing.T) {
}

dec := msgpack.NewDecoder(&buf)
if test.preferTextUnmarshalerForString {
dec.PreferTextUnmarshalerForString(true)
}
err = dec.Decode(test.out)
if test.decErr != "" {
test.requireErr(err, test.decErr)
Expand Down

0 comments on commit d8fe095

Please sign in to comment.