diff --git a/server/server.go b/server/server.go index 3662582d..b3e05af8 100644 --- a/server/server.go +++ b/server/server.go @@ -100,6 +100,9 @@ type Server struct { // Calls are inserted into this queue, to be handled // by a goroutine running handleCalls() callQueue *mpsc.Queue[*Call] + + // Handler for custom behavior of unknown methods + HandleUnknownMethod func(m capnp.Method) *Method } // New returns a client hook that makes calls to a set of methods. @@ -126,6 +129,9 @@ func New(methods []Method, brand any, shutdown Shutdowner) *Server { // Send starts a method call. func (srv *Server) Send(ctx context.Context, s capnp.Send) (*capnp.Answer, capnp.ReleaseFunc) { mm := srv.methods.find(s.Method) + if mm == nil && srv.HandleUnknownMethod != nil { + mm = srv.HandleUnknownMethod(s.Method) + } if mm == nil { return capnp.ErrorAnswer(s.Method, capnp.Unimplemented("unimplemented")), func() {} } @@ -150,6 +156,9 @@ func (srv *Server) Send(ctx context.Context, s capnp.Send) (*capnp.Answer, capnp // Recv starts a method call. func (srv *Server) Recv(ctx context.Context, r capnp.Recv) capnp.PipelineCaller { mm := srv.methods.find(r.Method) + if mm == nil && srv.HandleUnknownMethod != nil { + mm = srv.HandleUnknownMethod(r.Method) + } if mm == nil { r.Reject(capnp.Unimplemented("unimplemented")) return nil diff --git a/server/server_test.go b/server/server_test.go index b9578624..d3698aaf 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -5,8 +5,11 @@ import ( "errors" "strings" "sync" + "sync/atomic" "testing" + "github.com/stretchr/testify/require" + "capnproto.org/go/capnp/v3" air "capnproto.org/go/capnp/v3/internal/aircraftlib" "capnproto.org/go/capnp/v3/server" @@ -91,6 +94,47 @@ func TestServerCall(t *testing.T) { } } }) + t.Run("Unimplemented hook", func(t *testing.T) { + t.Parallel() + var echoText = "are you there?" + var proxyReceived atomic.Value + + // start a proxy server with hook + srv := server.New(nil, nil, nil) + srv.HandleUnknownMethod = func(method capnp.Method) *server.Method { + sm := server.Method{ + Method: method, + Impl: nil, + } + sm.Impl = func(ctx context.Context, call *server.Call) error { + echoArgs := air.Echo_echo_Params(call.Args()) + inText, err := echoArgs.In() + require.NoError(t, err) + proxyReceived.Store(inText) + // pretend we received an answer + echo := air.Echo_echo{Call: call} + resp, _ := echo.AllocResults() + err = resp.SetOut(inText) + return err + } + return &sm + } + blankBoot := capnp.NewClient(srv) + echoClient := air.Echo(blankBoot) + defer echoClient.Release() + + ans, finish := echoClient.Echo(context.Background(), func(p air.Echo_echo_Params) error { + err := p.SetIn(echoText) + return err + }) + defer finish() + resp, err := ans.Struct() + answerOut, _ := resp.Out() + rxValue := proxyReceived.Load() + require.Equal(t, echoText, rxValue) + assert.Equal(t, echoText, answerOut) + assert.NoError(t, err, "echo.Echo() error != ; want success") + }) } type callSeq uint32