Skip to content

Commit

Permalink
feat: support write the types of pointer of simple data types (redis#…
Browse files Browse the repository at this point in the history
  • Loading branch information
tzq0301 committed Oct 29, 2023
1 parent af4872c commit a891a02
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 0 deletions.
31 changes: 31 additions & 0 deletions internal/proto/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,37 +65,68 @@ func (w *Writer) WriteArg(v interface{}) error {
return w.string("")
case string:
return w.string(v)
case *string:
return w.string(*v)
case []byte:
return w.bytes(v)
case int:
return w.int(int64(v))
case *int:
return w.int(int64(*v))
case int8:
return w.int(int64(v))
case *int8:
return w.int(int64(*v))
case int16:
return w.int(int64(v))
case *int16:
return w.int(int64(*v))
case int32:
return w.int(int64(v))
case *int32:
return w.int(int64(*v))
case int64:
return w.int(v)
case *int64:
return w.int(*v)
case uint:
return w.uint(uint64(v))
case *uint:
return w.uint(uint64(*v))
case uint8:
return w.uint(uint64(v))
case *uint8:
return w.uint(uint64(*v))
case uint16:
return w.uint(uint64(v))
case *uint16:
return w.uint(uint64(*v))
case uint32:
return w.uint(uint64(v))
case *uint32:
return w.uint(uint64(*v))
case uint64:
return w.uint(v)
case *uint64:
return w.uint(*v)
case float32:
return w.float(float64(v))
case *float32:
return w.float(float64(*v))
case float64:
return w.float(v)
case *float64:
return w.float(*v)
case bool:
if v {
return w.int(1)
}
return w.int(0)
case *bool:
if *v {
return w.int(1)
}
return w.int(0)
case time.Time:
w.numBuf = v.AppendFormat(w.numBuf[:0], time.RFC3339Nano)
return w.bytes(w.numBuf)
Expand Down
51 changes: 51 additions & 0 deletions internal/proto/writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
. "github.com/bsm/gomega"

"github.com/redis/go-redis/v9/internal/proto"
"github.com/redis/go-redis/v9/internal/util"
)

type MyType struct{}
Expand Down Expand Up @@ -100,3 +101,53 @@ func BenchmarkWriteBuffer_Append(b *testing.B) {
}
}
}

var _ = Describe("WriteArg", func() {
var buf *bytes.Buffer
var wr *proto.Writer

BeforeEach(func() {
buf = new(bytes.Buffer)
wr = proto.NewWriter(buf)
})

args := map[any]string{
"hello": "$1\r\nhello\r\n",
int(10): "$2\r\n10\r\n",
util.ToPtr(int(10)): "$2\r\n10\r\n",
int8(10): "$2\r\n10\r\n",
util.ToPtr(int8(10)): "$2\r\n10\r\n",
int16(10): "$2\r\n10\r\n",
util.ToPtr(int16(10)): "$2\r\n10\r\n",
int32(10): "$2\r\n10\r\n",
util.ToPtr(int32(10)): "$2\r\n10\r\n",
int64(10): "$2\r\n10\r\n",
util.ToPtr(int64(10)): "$2\r\n10\r\n",
uint(10): "$2\r\n10\r\n",
util.ToPtr(uint(10)): "$2\r\n10\r\n",
uint8(10): "$2\r\n10\r\n",
util.ToPtr(uint8(10)): "$2\r\n10\r\n",
uint16(10): "$2\r\n10\r\n",
util.ToPtr(uint16(10)): "$2\r\n10\r\n",
uint32(10): "$2\r\n10\r\n",
util.ToPtr(uint32(10)): "$2\r\n10\r\n",
uint64(10): "$2\r\n10\r\n",
util.ToPtr(uint64(10)): "$2\r\n10\r\n",
float32(10.3): "$4\r\n10.3\r\n",
util.ToPtr(float32(10.3)): "$4\r\n10.3\r\n",
float64(10.3): "$4\r\n10.3\r\n",
util.ToPtr(float64(10.3)): "$4\r\n10.3\r\n",
bool(true): "$1\r\n1\r\n",
bool(false): "$1\r\n0\r\n",
util.ToPtr(bool(true)): "$1\r\n1\r\n",
util.ToPtr(bool(false)): "$1\r\n0\r\n",
}

for arg, expect := range args {
It(fmt.Sprintf("should write arg of type %T", arg), func() {
err := wr.WriteArg(arg)
Expect(err).NotTo(HaveOccurred())
Expect(buf.String()).To(Equal(expect))
})
}
})
5 changes: 5 additions & 0 deletions internal/util/type.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package util

func ToPtr[T any](v T) *T {
return &v
}

0 comments on commit a891a02

Please sign in to comment.