diff --git a/client/batch.go b/client/batch.go index cfce47626482..ece9b91d8a00 100644 --- a/client/batch.go +++ b/client/batch.go @@ -19,6 +19,7 @@ package client import ( "fmt" + "reflect" "github.com/cockroachdb/cockroach/proto" gogoproto "github.com/gogo/protobuf/proto" @@ -228,12 +229,12 @@ func (b *Batch) Put(key, value interface{}) { b.initResult(0, 1, err) return } - v, err := marshalValue(value) + v, err := marshalValue(reflect.ValueOf(value)) if err != nil { b.initResult(0, 1, err) return } - b.calls = append(b.calls, proto.PutCall(proto.Key(k), proto.Value{Bytes: v})) + b.calls = append(b.calls, proto.PutCall(proto.Key(k), v)) b.initResult(1, 1, nil) } @@ -252,17 +253,17 @@ func (b *Batch) CPut(key, value, expValue interface{}) { b.initResult(0, 1, err) return } - v, err := marshalValue(value) + v, err := marshalValue(reflect.ValueOf(value)) if err != nil { b.initResult(0, 1, err) return } - ev, err := marshalValue(expValue) + ev, err := marshalValue(reflect.ValueOf(expValue)) if err != nil { b.initResult(0, 1, err) return } - b.calls = append(b.calls, proto.ConditionalPutCall(proto.Key(k), v, ev)) + b.calls = append(b.calls, proto.ConditionalPutCall(proto.Key(k), v.Bytes, ev.Bytes)) b.initResult(1, 1, nil) } diff --git a/client/db.go b/client/db.go index e66128bbf78e..2ba9637289da 100644 --- a/client/db.go +++ b/client/db.go @@ -19,7 +19,6 @@ package client import ( "bytes" - "encoding" "fmt" "math/rand" "net/url" @@ -474,44 +473,6 @@ func (db *DB) send(calls ...proto.Call) (err error) { return } -func marshalKey(k interface{}) ([]byte, error) { - // Note that the ordering here is important. In particular, proto.Key is also - // a fmt.Stringer. - switch t := k.(type) { - case string: - return []byte(t), nil - case proto.Key: - return []byte(t), nil - case []byte: - return t, nil - case encoding.BinaryMarshaler: - return t.MarshalBinary() - case fmt.Stringer: - return []byte(t.String()), nil - } - return nil, fmt.Errorf("unable to marshal key: %T", k) -} - -func marshalValue(v interface{}) ([]byte, error) { - switch t := v.(type) { - case nil: - return nil, nil - case string: - return []byte(t), nil - case proto.Key: - return []byte(t), nil - case []byte: - return t, nil - case gogoproto.Message: - return gogoproto.Marshal(t) - case encoding.BinaryMarshaler: - return t.MarshalBinary() - case fmt.Stringer: - return []byte(t.String()), nil - } - return nil, fmt.Errorf("unable to marshal value: %T", v) -} - // Runner only exports the Run method on a batch of operations. type Runner interface { Run(b *Batch) error diff --git a/client/table.go b/client/table.go index 6d62fbc6b5c2..e8d129cac694 100644 --- a/client/table.go +++ b/client/table.go @@ -19,9 +19,7 @@ package client import ( "bytes" - "encoding" "fmt" - "math" "reflect" "strings" @@ -29,7 +27,6 @@ import ( "github.com/cockroachdb/cockroach/proto" roachencoding "github.com/cockroachdb/cockroach/util/encoding" "github.com/cockroachdb/cockroach/util/log" - gogoproto "github.com/gogo/protobuf/proto" ) // TODO(pmattis): @@ -553,7 +550,7 @@ func (b *Batch) GetStruct(obj interface{}, columns ...string) { c := proto.GetCall(proto.Key(key)) c.Post = func() error { reply := c.Reply.(*proto.GetResponse) - return unmarshalTableValue(reply.Value, v.FieldByIndex(col.field.Index)) + return unmarshalValue(reply.Value, v.FieldByIndex(col.field.Index)) } calls = append(calls, c) } @@ -600,7 +597,7 @@ func (b *Batch) PutStruct(obj interface{}, columns ...string) { log.Infof("Put %q -> %v", key, value.Interface()) } - v, err := marshalTableValue(value) + v, err := marshalValue(value) if err != nil { b.initResult(0, 0, err) return @@ -649,7 +646,7 @@ func (b *Batch) IncStruct(obj interface{}, value int64, column string) { // integer value directly instead of encoding it into a []byte. pv := &proto.Value{} pv.SetInteger(reply.NewValue) - return unmarshalTableValue(pv, v.FieldByIndex(col.field.Index)) + return unmarshalValue(pv, v.FieldByIndex(col.field.Index)) } b.calls = append(b.calls, c) @@ -772,7 +769,7 @@ func (b *Batch) ScanStruct(dest, start, end interface{}, maxRows int64, columns if !ok { return fmt.Errorf("%s: unable to find column %d", m.name, colID) } - if err := unmarshalTableValue(&row.Value, result.FieldByIndex(col.field.Index)); err != nil { + if err := unmarshalValue(&row.Value, result.FieldByIndex(col.field.Index)); err != nil { return err } } @@ -834,136 +831,3 @@ func (b *Batch) DelStruct(obj interface{}, columns ...string) { b.calls = append(b.calls, calls...) b.initResult(len(calls), len(calls), nil) } - -// marshalTableValue returns a proto.Value initialized from the source -// reflect.Value, returning an error if the types are not compatible. -func marshalTableValue(v reflect.Value) (proto.Value, error) { - var r proto.Value - switch t := v.Interface().(type) { - case nil: - return r, nil - - case string: - r.Bytes = []byte(t) - return r, nil - - case []byte: - r.Bytes = t - return r, nil - - case gogoproto.Message: - var err error - r.Bytes, err = gogoproto.Marshal(t) - return r, err - - case encoding.BinaryMarshaler: - var err error - r.Bytes, err = t.MarshalBinary() - return r, err - } - - switch v.Kind() { - case reflect.Bool: - i := int64(0) - if v.Bool() { - i = 1 - } - r.SetInteger(i) - return r, nil - - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - r.SetInteger(v.Int()) - return r, nil - - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - r.SetInteger(int64(v.Uint())) - return r, nil - - case reflect.Float32, reflect.Float64: - r.SetInteger(int64(math.Float64bits(v.Float()))) - return r, nil - - case reflect.String: - r.Bytes = []byte(v.String()) - return r, nil - } - - return r, fmt.Errorf("unable to marshal value: %s", v) -} - -// unmarshalTableValue sets the destination reflect.Value contents from the -// source proto.Value, returning an error if the types are not compatible. -func unmarshalTableValue(src *proto.Value, dest reflect.Value) error { - if src == nil { - dest.Set(reflect.Zero(dest.Type())) - return nil - } - - switch d := dest.Addr().Interface().(type) { - case *string: - if src.Bytes != nil { - *d = string(src.Bytes) - } else { - *d = "" - } - return nil - - case *[]byte: - if src.Bytes != nil { - *d = src.Bytes - } else { - *d = nil - } - return nil - - case *gogoproto.Message: - panic("TODO(pmattis): unimplemented") - - case *encoding.BinaryMarshaler: - panic("TODO(pmattis): unimplemented") - } - - switch dest.Kind() { - case reflect.Bool: - i, err := src.GetInteger() - if err != nil { - return err - } - dest.SetBool(i != 0) - return nil - - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - i, err := src.GetInteger() - if err != nil { - return err - } - dest.SetInt(i) - return nil - - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - i, err := src.GetInteger() - if err != nil { - return err - } - dest.SetUint(uint64(i)) - return nil - - case reflect.Float32, reflect.Float64: - i, err := src.GetInteger() - if err != nil { - return err - } - dest.SetFloat(math.Float64frombits(uint64(i))) - return nil - - case reflect.String: - if src == nil || src.Bytes == nil { - dest.SetString("") - return nil - } - dest.SetString(string(src.Bytes)) - return nil - } - - return fmt.Errorf("unable to unmarshal value: %s", dest.Type()) -} diff --git a/client/util.go b/client/util.go new file mode 100644 index 000000000000..604ad602f888 --- /dev/null +++ b/client/util.go @@ -0,0 +1,189 @@ +// Copyright 2015 The Cockroach Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. See the License for the specific language governing +// permissions and limitations under the License. See the AUTHORS file +// for names of contributors. +// +// Author: Peter Mattis (peter@cockroachlabs.com) + +package client + +import ( + "encoding" + "fmt" + "math" + "reflect" + + "github.com/cockroachdb/cockroach/proto" + gogoproto "github.com/gogo/protobuf/proto" +) + +// TODO(pmattis): The methods in this file needs tests. + +func marshalKey(k interface{}) ([]byte, error) { + // Note that the ordering here is important. In particular, proto.Key is also + // a fmt.Stringer. + switch t := k.(type) { + case string: + return []byte(t), nil + case proto.Key: + return []byte(t), nil + case []byte: + return t, nil + case encoding.BinaryMarshaler: + return t.MarshalBinary() + case fmt.Stringer: + return []byte(t.String()), nil + } + return nil, fmt.Errorf("unable to marshal key: %T", k) +} + +// marshalValue returns a proto.Value initialized from the source +// reflect.Value, returning an error if the types are not compatible. +func marshalValue(v reflect.Value) (proto.Value, error) { + var r proto.Value + if !v.IsValid() { + return r, nil + } + + switch t := v.Interface().(type) { + case nil: + return r, nil + + case string: + r.Bytes = []byte(t) + return r, nil + + case []byte: + r.Bytes = t + return r, nil + + case proto.Key: + r.Bytes = []byte(t) + return r, nil + + case gogoproto.Message: + var err error + r.Bytes, err = gogoproto.Marshal(t) + return r, err + + case encoding.BinaryMarshaler: + var err error + r.Bytes, err = t.MarshalBinary() + return r, err + } + + switch v.Kind() { + case reflect.Bool: + i := int64(0) + if v.Bool() { + i = 1 + } + r.SetInteger(i) + return r, nil + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + r.SetInteger(v.Int()) + return r, nil + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + r.SetInteger(int64(v.Uint())) + return r, nil + + case reflect.Float32, reflect.Float64: + r.SetInteger(int64(math.Float64bits(v.Float()))) + return r, nil + + case reflect.String: + r.Bytes = []byte(v.String()) + return r, nil + } + + return r, fmt.Errorf("unable to marshal value: %s", v) +} + +// unmarshalValue sets the destination reflect.Value contents from the source +// proto.Value, returning an error if the types are not compatible. +func unmarshalValue(src *proto.Value, dest reflect.Value) error { + if src == nil { + dest.Set(reflect.Zero(dest.Type())) + return nil + } + + switch d := dest.Addr().Interface().(type) { + case *string: + if src.Bytes != nil { + *d = string(src.Bytes) + } else { + *d = "" + } + return nil + + case *[]byte: + if src.Bytes != nil { + *d = src.Bytes + } else { + *d = nil + } + return nil + + case *gogoproto.Message: + panic("TODO(pmattis): unimplemented") + + case *encoding.BinaryUnmarshaler: + panic("TODO(pmattis): unimplemented") + } + + switch dest.Kind() { + case reflect.Bool: + i, err := src.GetInteger() + if err != nil { + return err + } + dest.SetBool(i != 0) + return nil + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + i, err := src.GetInteger() + if err != nil { + return err + } + dest.SetInt(i) + return nil + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + i, err := src.GetInteger() + if err != nil { + return err + } + dest.SetUint(uint64(i)) + return nil + + case reflect.Float32, reflect.Float64: + i, err := src.GetInteger() + if err != nil { + return err + } + dest.SetFloat(math.Float64frombits(uint64(i))) + return nil + + case reflect.String: + if src == nil || src.Bytes == nil { + dest.SetString("") + return nil + } + dest.SetString(string(src.Bytes)) + return nil + } + + return fmt.Errorf("unable to unmarshal value: %s", dest.Type()) +}