Skip to content

Commit

Permalink
msgpack: fix number types conversions after decoding
Browse files Browse the repository at this point in the history
  • Loading branch information
oleg-jukovec committed Jun 2, 2022
1 parent e2c68d7 commit 36b43bb
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 11 deletions.
30 changes: 29 additions & 1 deletion queue/queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,12 +283,40 @@ func (q *queue) produce(cmd string, params ...interface{}) (string, error) {
return qd.task.status, nil
}

func convertUint64(v interface{}) (result uint64, err error) {
switch v := v.(type) {
case uint:
result = uint64(v)
case uint8:
result = uint64(v)
case uint16:
result = uint64(v)
case uint32:
result = uint64(v)
case uint64:
result = uint64(v)
case int:
result = uint64(v)
case int8:
result = uint64(v)
case int16:
result = uint64(v)
case int32:
result = uint64(v)
case int64:
result = uint64(v)
default:
err = fmt.Errorf("Non-number value %T", v)
}
return
}

// Reverse the effect of a bury request on one or more tasks.
func (q *queue) Kick(count uint64) (uint64, error) {
resp, err := q.conn.Call(q.cmds.kick, []interface{}{count})
var id uint64
if err == nil {
id = resp.Data[0].([]interface{})[0].(uint64)
id, err = convertUint64(resp.Data[0].([]interface{})[0])
}
return id, err
}
Expand Down
58 changes: 48 additions & 10 deletions tarantool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,34 @@ func (m *Member) decodeMsgpackImpl(d *Decoder) error {
return nil
}

func convertUint64(v interface{}) (result uint64, err error) {
switch v := v.(type) {
case uint:
result = uint64(v)
case uint8:
result = uint64(v)
case uint16:
result = uint64(v)
case uint32:
result = uint64(v)
case uint64:
result = uint64(v)
case int:
result = uint64(v)
case int8:
result = uint64(v)
case int16:
result = uint64(v)
case int32:
result = uint64(v)
case int64:
result = uint64(v)
default:
err = fmt.Errorf("Non-number value %T", v)
}
return
}

var server = "127.0.0.1:3013"
var spaceNo = uint32(517)
var spaceName = "test"
Expand Down Expand Up @@ -524,7 +552,7 @@ func TestClient(t *testing.T) {
if len(tpl) != 3 {
t.Errorf("Unexpected body of Insert (tuple len)")
}
if id, ok := tpl[0].(uint64); !ok || id != 1 {
if id, err := convertUint64(tpl[0]); err != nil || id != 1 {
t.Errorf("Unexpected body of Insert (0)")
}
if h, ok := tpl[1].(string); !ok || h != "hello" {
Expand Down Expand Up @@ -557,7 +585,7 @@ func TestClient(t *testing.T) {
if len(tpl) != 3 {
t.Errorf("Unexpected body of Delete (tuple len)")
}
if id, ok := tpl[0].(uint64); !ok || id != 1 {
if id, err := convertUint64(tpl[0]); err != nil || id != 1 {
t.Errorf("Unexpected body of Delete (0)")
}
if h, ok := tpl[1].(string); !ok || h != "hello" {
Expand Down Expand Up @@ -599,7 +627,7 @@ func TestClient(t *testing.T) {
if len(tpl) != 3 {
t.Errorf("Unexpected body of Replace (tuple len)")
}
if id, ok := tpl[0].(uint64); !ok || id != 2 {
if id, err := convertUint64(tpl[0]); err != nil || id != 2 {
t.Errorf("Unexpected body of Replace (0)")
}
if h, ok := tpl[1].(string); !ok || h != "hi" {
Expand All @@ -624,7 +652,7 @@ func TestClient(t *testing.T) {
if len(tpl) != 2 {
t.Errorf("Unexpected body of Update (tuple len)")
}
if id, ok := tpl[0].(uint64); !ok || id != 2 {
if id, err := convertUint64(tpl[0]); err != nil || id != 2 {
t.Errorf("Unexpected body of Update (0)")
}
if h, ok := tpl[1].(string); !ok || h != "bye" {
Expand Down Expand Up @@ -673,7 +701,7 @@ func TestClient(t *testing.T) {
if tpl, ok := resp.Data[0].([]interface{}); !ok {
t.Errorf("Unexpected body of Select")
} else {
if id, ok := tpl[0].(uint64); !ok || id != 10 {
if id, err := convertUint64(tpl[0]); err != nil || id != 10 {
t.Errorf("Unexpected body of Select (0)")
}
if h, ok := tpl[1].(string); !ok || h != "val 10" {
Expand Down Expand Up @@ -768,15 +796,15 @@ func TestClient(t *testing.T) {
if err != nil {
t.Errorf("Failed to use Call")
}
if resp.Data[0].([]interface{})[0].(uint64) != 2 {
if val, err := convertUint64(resp.Data[0].([]interface{})[0]); err != nil || val != 2 {
t.Errorf("result is not {{1}} : %v", resp.Data)
}

resp, err = conn.Call17("simple_incr", []interface{}{1})
if err != nil {
t.Errorf("Failed to use Call17")
}
if resp.Data[0].(uint64) != 2 {
if val, err := convertUint64(resp.Data[0]); err != nil || val != 2 {
t.Errorf("result is not {{1}} : %v", resp.Data)
}

Expand All @@ -791,8 +819,7 @@ func TestClient(t *testing.T) {
if len(resp.Data) < 1 {
t.Errorf("Response.Data is empty after Eval")
}
val := resp.Data[0].(uint64)
if val != 11 {
if val, err := convertUint64(resp.Data[0]); err != nil || val != 11 {
t.Errorf("5 + 6 == 11, but got %v", val)
}
}
Expand Down Expand Up @@ -984,7 +1011,18 @@ func TestSQL(t *testing.T) {
assert.NoError(t, err, "Failed to Execute, Query number: %d", i)
assert.NotNil(t, resp, "Response is nil after Execute\nQuery number: %d", i)
for j := range resp.Data {
assert.Equal(t, resp.Data[j], test.Resp.Data[j], "Response data is wrong")
fixed := make([]interface{}, 0)

assert.EqualValues(t, reflect.TypeOf(resp.Data[j]).Kind(), reflect.Slice)
s := reflect.ValueOf(resp.Data[j])
for k := 0; k < s.Len(); k++ {
val := s.Index(k)
if num, err := convertUint64(val.Interface()); err == nil {
val = reflect.ValueOf(num)
}
fixed = append(fixed, val.Interface())
}
assert.Equal(t, fixed, test.Resp.Data[j], "Response data is wrong")
}
assert.Equal(t, resp.SQLInfo.AffectedCount, test.Resp.SQLInfo.AffectedCount, "Affected count is wrong")

Expand Down

0 comments on commit 36b43bb

Please sign in to comment.