diff --git a/encoding.go b/encoding.go index 831761f3..6e4159fa 100644 --- a/encoding.go +++ b/encoding.go @@ -91,16 +91,35 @@ func MakeTypedEncoder(f interface{}) func(*Request) func(io.Writer) Encoder { panic("MakeTypedEncoder must receive a function matching func(*Request, io.Writer, ...)") } - valType := t.In(2) - valTypePtr := reflect.PtrTo(valType) + var ( + valType, valTypeAlt reflect.Type + ) + + valType = t.In(2) + valTypeIsPtr := valType.Kind() == reflect.Ptr + if valTypeIsPtr { + valTypeAlt = valType.Elem() + } else { + valTypeAlt = reflect.PtrTo(valType) + } return MakeEncoder(func(req *Request, w io.Writer, i interface{}) error { iType := reflect.TypeOf(i) iValue := reflect.ValueOf(i) switch iType { case valType: - case valTypePtr: - iValue = iValue.Elem() + case valTypeAlt: + if valTypeIsPtr { + if iValue.CanAddr() { + iValue = iValue.Addr() + } else { + oldValue := iValue + iValue = reflect.New(iType) + iValue.Elem().Set(oldValue) + } + } else { + iValue = iValue.Elem() + } default: return fmt.Errorf("unexpected type %T, expected %v", i, valType) } diff --git a/encoding_test.go b/encoding_test.go index 67a2aaf4..4566e6a5 100644 --- a/encoding_test.go +++ b/encoding_test.go @@ -61,6 +61,31 @@ func TestMakeTypedEncoderByValue(t *testing.T) { } } +func TestMakeTypedEncoderByPointer(t *testing.T) { + expErr := fmt.Errorf("command fooTestObj failed") + f := MakeTypedEncoder(func(req *Request, w io.Writer, v *fooTestObj) error { + if v.Good { + return nil + } + return expErr + }) + + req := &Request{} + + encoderFunc := f(req) + + buf := new(bytes.Buffer) + encoder := encoderFunc(buf) + + if err := encoder.Encode(fooTestObj{true}); err != nil { + t.Fatal(err) + } + + if err := encoder.Encode(fooTestObj{false}); err != expErr { + t.Fatal("expected: ", expErr) + } +} + func TestMakeTypedEncoderArrays(t *testing.T) { f := MakeTypedEncoder(func(req *Request, w io.Writer, v []fooTestObj) error { if len(v) != 2 {