Skip to content

Commit

Permalink
Allow more method signatures in .Reflect()
Browse files Browse the repository at this point in the history
  • Loading branch information
jackkleeman committed Aug 21, 2024
1 parent f505c4c commit a8ab706
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 105 deletions.
184 changes: 97 additions & 87 deletions reflect.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,17 @@ var (
// Reflect converts a struct with methods into a service definition where each correctly-typed
// and exported method of the struct will become a handler in the definition. The service name
// defaults to the name of the struct, but this can be overidden by providing a `ServiceName() string` method.
// The handler name is the name of the method. Handler methods should be of the type `ServiceHandlerFn[I,O]`,
// `ObjectHandlerFn[I, O]` or `ObjectSharedHandlerFn[I, O]`. This function will panic if a mixture of
// object and service method signatures or opts are provided.
// The handler name is the name of the method. Handler methods should be one of the following signatures:
// - (ctx, I) (O, error)
// - (ctx, I) (O)
// - (ctx, I) (error)
// - (ctx, I)
// - (ctx)
// - (ctx) (error)
// - (ctx) (O)
// - (ctx) (O, error)
// Where ctx is [ObjectContext], [ObjectSharedContext] or [Context]. Other signatures are ignored.
// This function will panic if a mixture of object and service method signatures or opts are provided.
//
// Input types will be deserialised with the provided codec (defaults to JSON) except when they are restate.Void,
// in which case no input bytes or content type may be sent.
Expand All @@ -54,8 +62,9 @@ func Reflect(rcvr any, opts ...options.ServiceDefinitionOption) ServiceDefinitio
if !method.IsExported() {
continue
}
// Method needs three ins: receiver, Context, I
if mtype.NumIn() != 3 {
// Method needs 2-3 ins: receiver, Context, optionally I
numIn := mtype.NumIn()
if numIn < 2 || numIn > 3 {
continue
}

Expand Down Expand Up @@ -87,42 +96,61 @@ func Reflect(rcvr any, opts ...options.ServiceDefinitionOption) ServiceDefinitio
continue
}

// Method needs two outs: O, and error
if mtype.NumOut() != 2 {
// Method needs 0-2 outs: (), (O), (error), (O, error) are all valid
var output reflect.Type
var hasError bool
switch mtype.NumOut() {
case 0:
// ()
output = nil
hasError = false
case 1:
if returnType := mtype.Out(0); returnType == typeOfError {
// (error)
output = nil
hasError = true
} else {
output = returnType
hasError = false
}
case 2:
if returnType := mtype.Out(1); returnType != typeOfError {
continue
}
output = mtype.Out(0)
hasError = true
default:
continue
}

// The second return type of the method must be error.
if returnType := mtype.Out(1); returnType != typeOfError {
continue
var input reflect.Type
if numIn > 2 {
input = mtype.In(2)
}

input := mtype.In(2)
output := mtype.Out(0)

switch def := definition.(type) {
case *service:
def.Handler(mname, &serviceReflectHandler{
reflectHandler{
fn: method.Func,
receiver: val,
input: input,
output: output,
options: options.HandlerOptions{},
handlerType: nil,
},
})
def.Handler(mname, &reflectHandler{
fn: method.Func,
receiver: val,
input: input,
output: output,
hasError: hasError,
options: options.HandlerOptions{},
handlerType: nil,
},
)
case *object:
def.Handler(mname, &objectReflectHandler{
reflectHandler{
fn: method.Func,
receiver: val,
input: input,
output: input,
options: options.HandlerOptions{},
handlerType: &handlerType,
},
})
def.Handler(mname, &reflectHandler{
fn: method.Func,
receiver: val,
input: input,
output: input,
hasError: hasError,
options: options.HandlerOptions{},
handlerType: &handlerType,
},
)
}
}

Expand All @@ -138,6 +166,7 @@ type reflectHandler struct {
receiver reflect.Value
input reflect.Type
output reflect.Type
hasError bool
options options.HandlerOptions
handlerType *internal.ServiceHandlerType
}
Expand All @@ -158,30 +187,44 @@ func (h *reflectHandler) HandlerType() *internal.ServiceHandlerType {
return h.handlerType
}

type objectReflectHandler struct {
reflectHandler
}

var _ state.Handler = (*objectReflectHandler)(nil)
func (h *reflectHandler) Call(ctx *state.Context, bytes []byte) ([]byte, error) {
var args []reflect.Value
if h.input != nil {
input := reflect.New(h.input)

func (h *objectReflectHandler) Call(ctx *state.Context, bytes []byte) ([]byte, error) {
input := reflect.New(h.input)
if err := encoding.Unmarshal(h.options.Codec, bytes, input.Interface()); err != nil {
return nil, TerminalError(fmt.Errorf("request could not be decoded into handler input type: %w", err), http.StatusBadRequest)
}

if err := encoding.Unmarshal(h.options.Codec, bytes, input.Interface()); err != nil {
return nil, TerminalError(fmt.Errorf("request could not be decoded into handler input type: %w", err), http.StatusBadRequest)
args = []reflect.Value{h.receiver, reflect.ValueOf(ctxWrapper{ctx}), input.Elem()}
} else {
args = []reflect.Value{h.receiver, reflect.ValueOf(ctxWrapper{ctx})}
}

// we are sure about the fn signature so it's safe to do this
output := h.fn.Call([]reflect.Value{
h.receiver,
reflect.ValueOf(ctxWrapper{ctx}),
input.Elem(),
})

outI := output[0].Interface()
errI := output[1].Interface()
if errI != nil {
return nil, errI.(error)
output := h.fn.Call(args)
var outI any

switch [2]bool{h.output != nil, h.hasError} {
case [2]bool{false, false}:
// ()
return nil, nil
case [2]bool{false, true}:
// (error)
errI := output[0].Interface()
if errI != nil {
return nil, errI.(error)
}
return nil, nil
case [2]bool{true, false}:
// (O)
outI = output[0].Interface()
case [2]bool{true, true}:
// (O, error)
errI := output[1].Interface()
if errI != nil {
return nil, errI.(error)
}
outI = output[0].Interface()
}

bytes, err := encoding.Marshal(h.options.Codec, outI)
Expand All @@ -193,37 +236,4 @@ func (h *objectReflectHandler) Call(ctx *state.Context, bytes []byte) ([]byte, e
return bytes, nil
}

type serviceReflectHandler struct {
reflectHandler
}

var _ state.Handler = (*serviceReflectHandler)(nil)

func (h *serviceReflectHandler) Call(ctx *state.Context, bytes []byte) ([]byte, error) {
input := reflect.New(h.input)

if err := encoding.Unmarshal(h.options.Codec, bytes, input.Interface()); err != nil {
return nil, TerminalError(fmt.Errorf("request could not be decoded into handler input type: %w", err), http.StatusBadRequest)
}

// we are sure about the fn signature so it's safe to do this
output := h.fn.Call([]reflect.Value{
h.receiver,
reflect.ValueOf(ctxWrapper{ctx}),
input.Elem(),
})

outI := output[0].Interface()
errI := output[1].Interface()
if errI != nil {
return nil, errI.(error)
}

bytes, err := encoding.Marshal(h.options.Codec, outI)
if err != nil {
// we don't use a terminal error here as this is hot-fixable by changing the return type
return nil, fmt.Errorf("failed to serialize output: %w", err)
}

return bytes, nil
}
var _ state.Handler = (*reflectHandler)(nil)
55 changes: 37 additions & 18 deletions reflect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,15 @@ var shared = internal.ServiceHandlerType_SHARED

var tests []reflectTestParams = []reflectTestParams{
{rcvr: validObject{}, serviceName: "validObject", expectedMethods: expectedMethods{
"Greet": &exclusive,
"GreetShared": &shared,
"Greet": &exclusive,
"GreetShared": &shared,
"NoInput": &exclusive,
"NoError": &exclusive,
"NoOutput": &exclusive,
"NoOutputNoError": &exclusive,
"NoInputNoError": &exclusive,
"NoInputNoOutput": &exclusive,
"NoInputNoOutputNoError": &exclusive,
}},
{rcvr: validService{}, serviceName: "validService", expectedMethods: expectedMethods{
"Greet": nil,
Expand All @@ -53,17 +60,11 @@ func TestReflect(t *testing.T) {
}
}()
def := restate.Reflect(test.rcvr, test.opts...)
foundMethods := make([]string, 0, len(def.Handlers()))
for k := range def.Handlers() {
foundMethods = append(foundMethods, k)
}
for k, expectedTyp := range test.expectedMethods {
handler, ok := def.Handlers()[k]
if !ok {
t.Fatalf("missing handler %s", k)
}
require.Equal(t, expectedTyp, handler.HandlerType(), "mismatched handler type")
foundMethods := make(map[string]*internal.ServiceHandlerType, len(def.Handlers()))
for k, foundHandler := range def.Handlers() {
foundMethods[k] = foundHandler.HandlerType()
}
require.Equal(t, test.expectedMethods, foundMethods)
require.Equal(t, test.serviceName, def.Name())
})
}
Expand All @@ -79,22 +80,40 @@ func (validObject) GreetShared(ctx restate.ObjectSharedContext, _ string) (strin
return "", nil
}

func (validObject) SkipInvalidArgCount(ctx restate.ObjectContext) (string, error) {
func (validObject) NoInput(ctx restate.ObjectContext) (string, error) {
return "", nil
}

func (validObject) SkipInvalidCtx(ctx context.Context, _ string) (string, error) {
return "", nil
func (validObject) NoError(ctx restate.ObjectContext, _ string) string {
return ""
}

func (validObject) SkipInvalidError(ctx restate.ObjectContext, _ string) (string, string) {
return "", ""
func (validObject) NoOutput(ctx restate.ObjectContext, _ string) error {
return nil
}

func (validObject) SkipInvalidRetCount(ctx restate.ObjectContext, _ string) string {
func (validObject) NoOutputNoError(ctx restate.ObjectContext, _ string) {
}

func (validObject) NoInputNoError(ctx restate.ObjectContext) string {
return ""
}

func (validObject) NoInputNoOutput(ctx restate.ObjectContext) error {
return nil
}

func (validObject) NoInputNoOutputNoError(ctx restate.ObjectContext) {
}

func (validObject) SkipInvalidCtx(ctx context.Context, _ string) (string, error) {
return "", nil
}

func (validObject) SkipInvalidError(ctx restate.ObjectContext, _ string) (error, string) {
return nil, ""
}

func (validObject) skipUnexported(_ string) (string, error) {
return "", nil
}
Expand Down

0 comments on commit a8ab706

Please sign in to comment.