diff --git a/reflect.go b/reflect.go index 08d40b4..1614079 100644 --- a/reflect.go +++ b/reflect.go @@ -54,8 +54,9 @@ func Reflect(rcvr any, opts ...options.ServiceDefinitionOption) ServiceDefinitio if !method.IsExported() { continue } - // Method needs three ins: receiver, Context, I - if mtype.NumIn() != 3 { + // Method needs 2-3 ins: receiver, Context, optionally I + numIn := mtype.NumIn() + if numIn < 2 || numIn > 3 { continue } @@ -87,42 +88,61 @@ func Reflect(rcvr any, opts ...options.ServiceDefinitionOption) ServiceDefinitio continue } - // Method needs two outs: O, and error - if mtype.NumOut() != 2 { + // Method needs 0-2 outs: (), (O), (error), (O, error) are all valid + var output reflect.Type + var hasError bool + switch mtype.NumOut() { + case 0: + // () + output = nil + hasError = false + case 1: + if returnType := mtype.Out(0); returnType == typeOfError { + // (error) + output = nil + hasError = true + } else { + output = returnType + hasError = false + } + case 2: + if returnType := mtype.Out(1); returnType != typeOfError { + continue + } + output = mtype.Out(0) + hasError = true + default: continue } - // The second return type of the method must be error. - if returnType := mtype.Out(1); returnType != typeOfError { - continue + var input reflect.Type + if numIn > 2 { + input = mtype.In(2) } - input := mtype.In(2) - output := mtype.Out(0) - switch def := definition.(type) { case *service: - def.Handler(mname, &serviceReflectHandler{ - reflectHandler{ - fn: method.Func, - receiver: val, - input: input, - output: output, - options: options.HandlerOptions{}, - handlerType: nil, - }, - }) + def.Handler(mname, &reflectHandler{ + fn: method.Func, + receiver: val, + input: input, + output: output, + hasError: hasError, + options: options.HandlerOptions{}, + handlerType: nil, + }, + ) case *object: - def.Handler(mname, &objectReflectHandler{ - reflectHandler{ - fn: method.Func, - receiver: val, - input: input, - output: input, - options: options.HandlerOptions{}, - handlerType: &handlerType, - }, - }) + def.Handler(mname, &reflectHandler{ + fn: method.Func, + receiver: val, + input: input, + output: input, + hasError: hasError, + options: options.HandlerOptions{}, + handlerType: &handlerType, + }, + ) } } @@ -138,6 +158,7 @@ type reflectHandler struct { receiver reflect.Value input reflect.Type output reflect.Type + hasError bool options options.HandlerOptions handlerType *internal.ServiceHandlerType } @@ -158,30 +179,44 @@ func (h *reflectHandler) HandlerType() *internal.ServiceHandlerType { return h.handlerType } -type objectReflectHandler struct { - reflectHandler -} - -var _ state.Handler = (*objectReflectHandler)(nil) +func (h *reflectHandler) Call(ctx *state.Context, bytes []byte) ([]byte, error) { + var args []reflect.Value + if h.input != nil { + input := reflect.New(h.input) -func (h *objectReflectHandler) Call(ctx *state.Context, bytes []byte) ([]byte, error) { - input := reflect.New(h.input) + if err := encoding.Unmarshal(h.options.Codec, bytes, input.Interface()); err != nil { + return nil, TerminalError(fmt.Errorf("request could not be decoded into handler input type: %w", err), http.StatusBadRequest) + } - if err := encoding.Unmarshal(h.options.Codec, bytes, input.Interface()); err != nil { - return nil, TerminalError(fmt.Errorf("request could not be decoded into handler input type: %w", err), http.StatusBadRequest) + args = []reflect.Value{h.receiver, reflect.ValueOf(ctxWrapper{ctx}), input.Elem()} + } else { + args = []reflect.Value{h.receiver, reflect.ValueOf(ctxWrapper{ctx})} } - // we are sure about the fn signature so it's safe to do this - output := h.fn.Call([]reflect.Value{ - h.receiver, - reflect.ValueOf(ctxWrapper{ctx}), - input.Elem(), - }) - - outI := output[0].Interface() - errI := output[1].Interface() - if errI != nil { - return nil, errI.(error) + output := h.fn.Call(args) + var outI any + + switch [2]bool{h.output != nil, h.hasError} { + case [2]bool{false, false}: + // () + return nil, nil + case [2]bool{false, true}: + // (error) + errI := output[0].Interface() + if errI != nil { + return nil, errI.(error) + } + return nil, nil + case [2]bool{true, false}: + // (O) + outI = output[0].Interface() + case [2]bool{true, true}: + // (O, error) + errI := output[1].Interface() + if errI != nil { + return nil, errI.(error) + } + outI = output[0].Interface() } bytes, err := encoding.Marshal(h.options.Codec, outI) @@ -193,37 +228,4 @@ func (h *objectReflectHandler) Call(ctx *state.Context, bytes []byte) ([]byte, e return bytes, nil } -type serviceReflectHandler struct { - reflectHandler -} - -var _ state.Handler = (*serviceReflectHandler)(nil) - -func (h *serviceReflectHandler) Call(ctx *state.Context, bytes []byte) ([]byte, error) { - input := reflect.New(h.input) - - if err := encoding.Unmarshal(h.options.Codec, bytes, input.Interface()); err != nil { - return nil, TerminalError(fmt.Errorf("request could not be decoded into handler input type: %w", err), http.StatusBadRequest) - } - - // we are sure about the fn signature so it's safe to do this - output := h.fn.Call([]reflect.Value{ - h.receiver, - reflect.ValueOf(ctxWrapper{ctx}), - input.Elem(), - }) - - outI := output[0].Interface() - errI := output[1].Interface() - if errI != nil { - return nil, errI.(error) - } - - bytes, err := encoding.Marshal(h.options.Codec, outI) - if err != nil { - // we don't use a terminal error here as this is hot-fixable by changing the return type - return nil, fmt.Errorf("failed to serialize output: %w", err) - } - - return bytes, nil -} +var _ state.Handler = (*reflectHandler)(nil) diff --git a/reflect_test.go b/reflect_test.go index 3121480..2a9e577 100644 --- a/reflect_test.go +++ b/reflect_test.go @@ -25,8 +25,15 @@ var shared = internal.ServiceHandlerType_SHARED var tests []reflectTestParams = []reflectTestParams{ {rcvr: validObject{}, serviceName: "validObject", expectedMethods: expectedMethods{ - "Greet": &exclusive, - "GreetShared": &shared, + "Greet": &exclusive, + "GreetShared": &shared, + "NoInput": &exclusive, + "NoError": &exclusive, + "NoOutput": &exclusive, + "NoOutputNoError": &exclusive, + "NoInputNoError": &exclusive, + "NoInputNoOutput": &exclusive, + "NoInputNoOutputNoError": &exclusive, }}, {rcvr: validService{}, serviceName: "validService", expectedMethods: expectedMethods{ "Greet": nil, @@ -53,17 +60,11 @@ func TestReflect(t *testing.T) { } }() def := restate.Reflect(test.rcvr, test.opts...) - foundMethods := make([]string, 0, len(def.Handlers())) - for k := range def.Handlers() { - foundMethods = append(foundMethods, k) - } - for k, expectedTyp := range test.expectedMethods { - handler, ok := def.Handlers()[k] - if !ok { - t.Fatalf("missing handler %s", k) - } - require.Equal(t, expectedTyp, handler.HandlerType(), "mismatched handler type") + foundMethods := make(map[string]*internal.ServiceHandlerType, len(def.Handlers())) + for k, foundHandler := range def.Handlers() { + foundMethods[k] = foundHandler.HandlerType() } + require.Equal(t, test.expectedMethods, foundMethods) require.Equal(t, test.serviceName, def.Name()) }) } @@ -79,22 +80,40 @@ func (validObject) GreetShared(ctx restate.ObjectSharedContext, _ string) (strin return "", nil } -func (validObject) SkipInvalidArgCount(ctx restate.ObjectContext) (string, error) { +func (validObject) NoInput(ctx restate.ObjectContext) (string, error) { return "", nil } -func (validObject) SkipInvalidCtx(ctx context.Context, _ string) (string, error) { - return "", nil +func (validObject) NoError(ctx restate.ObjectContext, _ string) string { + return "" } -func (validObject) SkipInvalidError(ctx restate.ObjectContext, _ string) (string, string) { - return "", "" +func (validObject) NoOutput(ctx restate.ObjectContext, _ string) error { + return nil } -func (validObject) SkipInvalidRetCount(ctx restate.ObjectContext, _ string) string { +func (validObject) NoOutputNoError(ctx restate.ObjectContext, _ string) { +} + +func (validObject) NoInputNoError(ctx restate.ObjectContext) string { return "" } +func (validObject) NoInputNoOutput(ctx restate.ObjectContext) error { + return nil +} + +func (validObject) NoInputNoOutputNoError(ctx restate.ObjectContext) { +} + +func (validObject) SkipInvalidCtx(ctx context.Context, _ string) (string, error) { + return "", nil +} + +func (validObject) SkipInvalidError(ctx restate.ObjectContext, _ string) (error, string) { + return nil, "" +} + func (validObject) skipUnexported(_ string) (string, error) { return "", nil }