Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add decoding option to allow invalid UTF-8 #342

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,26 @@ func (ec ExtraDecErrorCond) valid() bool {
return ec < maxExtraDecError
}

// UTF8Mode specifies whether the decoder should decode
// CBOR text strings that contain invalid UTF-8.
type UTF8Mode int

const (
	// UTF8RejectInvalid rejects CBOR text strings containing
	// invalid UTF-8. It is the zero value of UTF8Mode.
	UTF8RejectInvalid UTF8Mode = iota

	// UTF8DecodeInvalid allows decoding CBOR text strings containing
	// invalid UTF-8.
	UTF8DecodeInvalid

	// maxUTF8Mode is one past the last defined mode and is used
	// only for range checking.
	maxUTF8Mode
)

// valid reports whether um is one of the defined UTF8Mode values.
//
// UTF8Mode is a signed integer, so the lower bound must be checked
// explicitly: a negative value is not a defined mode and must not be
// accepted just because it compares less than maxUTF8Mode.
func (um UTF8Mode) valid() bool {
	return um >= 0 && um < maxUTF8Mode
}

// DecOptions specifies decoding options.
type DecOptions struct {
// DupMapKey specifies whether to enforce duplicate map key.
Expand Down Expand Up @@ -295,6 +315,10 @@ type DecOptions struct {
// when unmarshalling CBOR into an empty interface value.
// By default, unmarshal uses map[interface{}]interface{}.
DefaultMapType reflect.Type

// UTF8 specifies if decoder should decode CBOR Text containing invalid UTF-8.
// By default, unmarshal rejects CBOR text containing invalid UTF-8.
UTF8 UTF8Mode
}

// DecMode returns DecMode with immutable options and no tags (safe for concurrency).
Expand Down Expand Up @@ -397,6 +421,9 @@ func (opts DecOptions) decMode() (*decMode, error) {
if opts.DefaultMapType != nil && opts.DefaultMapType.Kind() != reflect.Map {
return nil, fmt.Errorf("cbor: invalid DefaultMapType %s", opts.DefaultMapType)
}
if !opts.UTF8.valid() {
return nil, errors.New("cbor: invalid UTF8 " + strconv.Itoa(int(opts.UTF8)))
}
dm := decMode{
dupMapKey: opts.DupMapKey,
timeTag: opts.TimeTag,
Expand All @@ -408,6 +435,7 @@ func (opts DecOptions) decMode() (*decMode, error) {
intDec: opts.IntDec,
extraReturnErrors: opts.ExtraReturnErrors,
defaultMapType: opts.DefaultMapType,
utf8: opts.UTF8,
}
return &dm, nil
}
Expand Down Expand Up @@ -440,6 +468,7 @@ type decMode struct {
intDec IntDecMode
extraReturnErrors ExtraDecErrorCond
defaultMapType reflect.Type
utf8 UTF8Mode
}

var defaultDecMode, _ = DecOptions{}.decMode()
Expand All @@ -456,6 +485,7 @@ func (dm *decMode) DecOptions() DecOptions {
TagsMd: dm.tagsMd,
IntDec: dm.intDec,
ExtraReturnErrors: dm.extraReturnErrors,
UTF8: dm.utf8,
}
}

Expand Down Expand Up @@ -1064,7 +1094,7 @@ func (d *decoder) parseTextString() ([]byte, error) {
if ai != 31 {
b := d.data[d.off : d.off+int(val)]
d.off += int(val)
if !utf8.Valid(b) {
if d.dm.utf8 == UTF8RejectInvalid && !utf8.Valid(b) {
return nil, &SemanticError{"cbor: invalid UTF-8 string"}
}
return b, nil
Expand All @@ -1075,7 +1105,7 @@ func (d *decoder) parseTextString() ([]byte, error) {
_, _, val = d.getHead()
x := d.data[d.off : d.off+int(val)]
d.off += int(val)
if !utf8.Valid(x) {
if d.dm.utf8 == UTF8RejectInvalid && !utf8.Valid(x) {
for !d.foundBreak() {
d.skip() // Skip remaining chunk on error
}
Expand Down
183 changes: 166 additions & 17 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1383,33 +1383,157 @@ func TestInvalidCBORUnmarshal(t *testing.T) {
}
}

func TestInvalidUTF8TextString(t *testing.T) {
invalidUTF8TextStringTests := []struct {
// TestValidUTF8String checks that CBOR text strings holding valid UTF-8
// decode to the expected Go string under both UTF8RejectInvalid and
// UTF8DecodeInvalid, for definite-length and indefinite-length encodings,
// and for both interface{} and string destinations.
func TestValidUTF8String(t *testing.T) {
	// Setup failures are fatal: without a decoding mode none of the
	// cases below can produce a meaningful result.
	dmRejectInvalidUTF8, err := DecOptions{UTF8: UTF8RejectInvalid}.DecMode()
	if err != nil {
		t.Fatalf("DecMode() returned an error %+v", err)
	}
	dmDecodeInvalidUTF8, err := DecOptions{UTF8: UTF8DecodeInvalid}.DecMode()
	if err != nil {
		t.Fatalf("DecMode() returned an error %+v", err)
	}

	testCases := []struct {
		name     string
		cborData []byte
		dm       DecMode
		wantObj  interface{}
	}{
		{
			name:     "with UTF8RejectInvalid",
			cborData: hexDecode("6973747265616d696e67"),
			dm:       dmRejectInvalidUTF8,
			wantObj:  "streaming",
		},
		{
			name:     "with UTF8DecodeInvalid",
			cborData: hexDecode("6973747265616d696e67"),
			dm:       dmDecodeInvalidUTF8,
			wantObj:  "streaming",
		},
		{
			name:     "indef length with UTF8RejectInvalid",
			cborData: hexDecode("7f657374726561646d696e67ff"),
			dm:       dmRejectInvalidUTF8,
			wantObj:  "streaming",
		},
		{
			name:     "indef length with UTF8DecodeInvalid",
			cborData: hexDecode("7f657374726561646d696e67ff"),
			dm:       dmDecodeInvalidUTF8,
			wantObj:  "streaming",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// Decode to empty interface. The error is scoped to this check
			// so subtests do not write to the enclosing function's err.
			var i interface{}
			if err := tc.dm.Unmarshal(tc.cborData, &i); err != nil {
				t.Errorf("Unmarshal(0x%x) returned error %q", tc.cborData, err)
			}
			if !reflect.DeepEqual(i, tc.wantObj) {
				t.Errorf("Unmarshal(0x%x) returned %v (%T), want %v (%T)", tc.cborData, i, i, tc.wantObj, tc.wantObj)
			}

			// Decode to string
			var v string
			if err := tc.dm.Unmarshal(tc.cborData, &v); err != nil {
				t.Errorf("Unmarshal(0x%x) returned error %q", tc.cborData, err)
			}
			if !reflect.DeepEqual(v, tc.wantObj) {
				t.Errorf("Unmarshal(0x%x) returned %v (%T), want %v (%T)", tc.cborData, v, v, tc.wantObj, tc.wantObj)
			}
		})
	}
}

func TestInvalidUTF8String(t *testing.T) {
dmRejectInvalidUTF8, err := DecOptions{UTF8: UTF8RejectInvalid}.DecMode()
if err != nil {
t.Errorf("DecMode() returned an error %+v", err)
}
dmDecodeInvalidUTF8, err := DecOptions{UTF8: UTF8DecodeInvalid}.DecMode()
if err != nil {
t.Errorf("DecMode() returned an error %+v", err)
}

testCases := []struct {
name string
cborData []byte
dm DecMode
wantObj interface{}
wantErrorMsg string
}{
{"definite length text string", hexDecode("61fe"), invalidUTF8ErrorMsg},
{"indefinite length text string", hexDecode("7f62e6b061b4ff"), invalidUTF8ErrorMsg},
{
name: "with UTF8RejectInvalid",
cborData: hexDecode("61fe"),
dm: dmRejectInvalidUTF8,
wantErrorMsg: invalidUTF8ErrorMsg,
},
{
name: "with UTF8DecodeInvalid",
cborData: hexDecode("61fe"),
dm: dmDecodeInvalidUTF8,
wantObj: string([]byte{0xfe}),
},
{
name: "indef length with UTF8RejectInvalid",
cborData: hexDecode("7f62e6b061b4ff"),
dm: dmRejectInvalidUTF8,
wantErrorMsg: invalidUTF8ErrorMsg,
},
{
name: "indef length with UTF8DecodeInvalid",
cborData: hexDecode("7f62e6b061b4ff"),
dm: dmDecodeInvalidUTF8,
wantObj: string([]byte{0xe6, 0xb0, 0xb4}),
},
}
for _, tc := range invalidUTF8TextStringTests {

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var i interface{}
if err := Unmarshal(tc.cborData, &i); err == nil {
t.Errorf("Unmarshal(0x%x) didn't return an error", tc.cborData)
} else if err.Error() != tc.wantErrorMsg {
t.Errorf("Unmarshal(0x%x) error %q, want %q", tc.cborData, err.Error(), tc.wantErrorMsg)
// Decode to empty interface
var v interface{}
err = tc.dm.Unmarshal(tc.cborData, &v)
if tc.wantErrorMsg != "" {
if err == nil {
t.Errorf("Unmarshal(0x%x) didn't return error", tc.cborData)
} else if !strings.Contains(err.Error(), tc.wantErrorMsg) {
t.Errorf("Unmarshal(0x%x) error %q, want %q", tc.cborData, err.Error(), tc.wantErrorMsg)
}
} else {
if err != nil {
t.Errorf("Unmarshal(0x%x) returned error %q", tc.cborData, err)
}
if !reflect.DeepEqual(v, tc.wantObj) {
t.Errorf("Unmarshal(0x%x) returned %v (%T), want %v (%T)", tc.cborData, v, v, tc.wantObj, tc.wantObj)
}
}

// Decode to string
var s string
if err := Unmarshal(tc.cborData, &s); err == nil {
t.Errorf("Unmarshal(0x%x) didn't return an error", tc.cborData)
} else if err.Error() != tc.wantErrorMsg {
t.Errorf("Unmarshal(0x%x) error %q, want %q", tc.cborData, err.Error(), tc.wantErrorMsg)
err = tc.dm.Unmarshal(tc.cborData, &s)
if tc.wantErrorMsg != "" {
if err == nil {
t.Errorf("Unmarshal(0x%x) didn't return error", tc.cborData)
} else if !strings.Contains(err.Error(), tc.wantErrorMsg) {
t.Errorf("Unmarshal(0x%x) error %q, want %q", tc.cborData, err.Error(), tc.wantErrorMsg)
}
} else {
if err != nil {
t.Errorf("Unmarshal(0x%x) returned error %q", tc.cborData, err)
}
if !reflect.DeepEqual(s, tc.wantObj) {
t.Errorf("Unmarshal(0x%x) returned %v (%T), want %v (%T)", tc.cborData, s, s, tc.wantObj, tc.wantObj)
}
}
})
}

// Test decoding of mixed invalid text string and valid text string
// with UTF8RejectInvalid option (default)
cborData := hexDecode("7f62e6b061b4ff7f657374726561646d696e67ff")
dec := NewDecoder(bytes.NewReader(cborData))
var s string
Expand All @@ -1423,6 +1547,20 @@ func TestInvalidUTF8TextString(t *testing.T) {
} else if s != "streaming" {
t.Errorf("Decode() returned %q, want %q", s, "streaming")
}

// Test decoding of mixed invalid text string and valid text string
// with UTF8DecodeInvalid option
dec = dmDecodeInvalidUTF8.NewDecoder(bytes.NewReader(cborData))
if err := dec.Decode(&s); err != nil {
t.Errorf("Decode() returned error %q", err)
} else if s != string([]byte{0xe6, 0xb0, 0xb4}) {
t.Errorf("Decode() returned %q, want %q", s, string([]byte{0xe6, 0xb0, 0xb4}))
}
if err := dec.Decode(&s); err != nil {
t.Errorf("Decode() returned error %v", err)
} else if s != "streaming" {
t.Errorf("Decode() returned %q, want %q", s, "streaming")
}
}

func TestUnmarshalStruct(t *testing.T) {
Expand Down Expand Up @@ -3063,15 +3201,16 @@ func TestUnmarshalToNotNilInterface(t *testing.T) {

func TestDecOptions(t *testing.T) {
opts1 := DecOptions{
TimeTag: DecTagRequired,
DupMapKey: DupMapKeyEnforcedAPF,
IndefLength: IndefLengthForbidden,
TimeTag: DecTagRequired,
MaxNestedLevels: 100,
MaxMapPairs: 101,
MaxArrayElements: 102,
MaxMapPairs: 101,
IndefLength: IndefLengthForbidden,
TagsMd: TagsForbidden,
IntDec: IntDecConvertSigned,
ExtraReturnErrors: ExtraDecErrorUnknownField,
UTF8: UTF8DecodeInvalid,
}
dm, err := opts1.DecMode()
if err != nil {
Expand Down Expand Up @@ -4566,6 +4705,16 @@ func TestExtraErrorCondUnknowField(t *testing.T) {
}
}

// TestInvalidUTF8Mode checks that building a decoding mode from a UTF8
// option value of 2 fails with the expected error message.
func TestInvalidUTF8Mode(t *testing.T) {
	const wantErrorMsg = "cbor: invalid UTF8 2"

	_, gotErr := DecOptions{UTF8: 2}.DecMode()
	switch {
	case gotErr == nil:
		t.Errorf("DecMode() didn't return an error")
	case gotErr.Error() != wantErrorMsg:
		t.Errorf("DecMode() returned error %q, want %q", gotErr.Error(), wantErrorMsg)
	}
}

func TestStreamExtraErrorCondUnknowField(t *testing.T) {
type s struct {
A string
Expand Down