diff --git a/glightning/plugin.go b/glightning/plugin.go
index 47c9dd9..da111b9 100644
--- a/glightning/plugin.go
+++ b/glightning/plugin.go
@@ -26,6 +26,7 @@ const (
 	_Forward        Subscription = "forward_event"
 	_SendPaySuccess Subscription = "sendpay_success"
 	_SendPayFailure Subscription = "sendpay_failure"
+	_Shutdown       Subscription = "shutdown"
 	_PeerConnected  Hook         = "peer_connected"
 	_DbWrite        Hook         = "db_write"
 	_InvoicePayment Hook         = "invoice_payment"
@@ -38,7 +39,7 @@ const (
 var lightningMethodRegistry map[string]*jrpc2.Method
 
 // The custommsg plugin hook is the receiving counterpart to the dev-sendcustommsg RPC method
-//and allows plugins to handle messages that are not handled internally.
+// and allows plugins to handle messages that are not handled internally.
 type CustomMsgReceivedEvent struct {
 	PeerId  string `json:"peer_id"`
 	Payload string `json:"payload"`
@@ -763,6 +764,28 @@ func (e *WarnEvent) Call() (jrpc2.Result, error) {
 	return nil, nil
 }
 
+// ShutdownEvent is dispatched when lightningd sends the "shutdown"
+// notification, just before the plugin is expected to exit. It invokes
+// the callback registered with SubscribeShutdown.
+type ShutdownEvent struct {
+	cb func()
+}
+
+func (e *ShutdownEvent) Name() string {
+	return string(_Shutdown)
+}
+
+func (e *ShutdownEvent) New() interface{} {
+	return &ShutdownEvent{
+		cb: e.cb,
+	}
+}
+
+func (e *ShutdownEvent) Call() (jrpc2.Result, error) {
+	e.cb()
+	return nil, nil
+}
+
 type OptionType string
 
 const _String OptionType = "string"
@@ -1559,6 +1579,14 @@ func (p *Plugin) SubscribeForwardings(cb func(c *Forwarding)) {
 	})
 }
 
+// SubscribeShutdown registers cb to be called when lightningd sends the
+// "shutdown" notification, i.e. just before lightningd stops the plugin.
+func (p *Plugin) SubscribeShutdown(cb func()) {
+	p.subscribe(&ShutdownEvent{
+		cb: cb,
+	})
+}
+
 func (p *Plugin) subscribe(subscription jrpc2.ServerMethod) {
 	p.server.Register(subscription)
 	p.subscriptions = append(p.subscriptions, subscription.Name())
diff --git a/glightning/plugin_test.go b/glightning/plugin_test.go
index d9d7b88..5e62ad5 100644
--- a/glightning/plugin_test.go
+++ b/glightning/plugin_test.go
@@ -756,6 +756,32 @@ func TestSubscription_Disconnected(t *testing.T) {
 	runTest(t, plugin, msg+"\n\n", "")
 }
 
+func TestSubscription_Shutdown(t *testing.T) {
+	var wg sync.WaitGroup
+	defer await(t, &wg)
+	shutdownCalled := make(chan struct{})
+
+	wg.Add(1)
+	initFn := getInitFunc(t, func(t *testing.T, options map[string]glightning.Option, config *glightning.Config) {
+		t.Error("Should not have called init when calling get manifest")
+	})
+	plugin := glightning.NewPlugin(initFn)
+	plugin.SubscribeShutdown(func() {
+		defer wg.Done()
+		shutdownCalled <- struct{}{}
+	})
+
+	msg := `{"jsonrpc":"2.0","method":"shutdown"}`
+
+	runTest(t, plugin, msg+"\n\n", "")
+
+	select {
+	case <-shutdownCalled:
+	case <-time.After(1 * time.Second):
+		t.Fatal("SubscribeShutdown was not called")
+	}
+}
+
 func await(t *testing.T, wg *sync.WaitGroup) {
 	awaitWithTimeout(t, wg, 1*time.Second)
 }
diff --git a/jrpc2/server.go b/jrpc2/server.go
index ce411fb..429d2e9 100644
--- a/jrpc2/server.go
+++ b/jrpc2/server.go
@@ -31,12 +31,16 @@ type Server struct {
 	registry sync.Map // map[string]ServerMethod
 	outQueue chan interface{}
 	shutdown bool
+	// shutdownChan is closed exactly once, by Shutdown, to signal the
+	// listen loop and its reader goroutine to stop.
+	shutdownChan chan struct{}
 }
 
 func NewServer() *Server {
 	server := &Server{}
 	server.outQueue = make(chan interface{})
 	server.shutdown = false
+	server.shutdownChan = make(chan struct{})
 	return server
 }
 
@@ -71,6 +75,7 @@ func (s *Server) StartUp(in, out *os.File) error {
 
 func (s *Server) Shutdown() {
 	s.shutdown = true
+	close(s.shutdownChan)
 	close(s.outQueue)
 }
 
@@ -105,25 +110,57 @@ func (s *Server) listen(in io.Reader) error {
 	buf := make([]byte, 1024)
 	scanner.Buffer(buf, MaxIntakeBuffer)
 	scanner.Split(scanDoubleNewline)
-	for scanner.Scan() && !s.shutdown {
-		msg := scanner.Bytes()
-		if debugIO(true) {
-			log.Println(string(msg))
-		}
-		// pass down a copy so things stay sane
-		msg_buf := make([]byte, len(msg))
-		copy(msg_buf, msg)
-		// todo: send this over a channel
-		// for processing, so the number
-		// of things we process at once
-		// is more easy to control
-		go processMsg(s, msg_buf)
-	}
-	if err := scanner.Err(); err != nil {
-		log.Fatal(err)
-		return err
-	}
-	return nil
+
+	msgChan := make(chan []byte)
+	// errChan is buffered so the reader goroutine can report a scanner
+	// error and exit even if the receiver has already returned.
+	errChan := make(chan error, 1)
+
+	// Read and copy messages on a separate goroutine so the loop below
+	// can select on both incoming messages and shutdown. Sends select
+	// on shutdownChan too, so this goroutine cannot leak blocked on a
+	// send after the receiver has gone away. (A goroutine blocked in
+	// scanner.Scan itself can only exit when the input closes.)
+	go func() {
+		defer close(msgChan)
+		for scanner.Scan() {
+			msg := scanner.Bytes()
+			// pass down a copy so things stay sane
+			msgBuf := make([]byte, len(msg))
+			copy(msgBuf, msg)
+			select {
+			case msgChan <- msgBuf:
+			case <-s.shutdownChan:
+				return
+			}
+		}
+		if err := scanner.Err(); err != nil {
+			errChan <- err
+		}
+	}()
+
+	for {
+		select {
+		case msg, ok := <-msgChan:
+			if !ok {
+				// Input exhausted. Any scanner error was buffered
+				// before msgChan was closed, so this read is
+				// race-free.
+				select {
+				case err := <-errChan:
+					return err
+				default:
+					return nil
+				}
+			}
+			if debugIO(true) {
+				log.Println(string(msg))
+			}
+			go processMsg(s, msg)
+		case err := <-errChan:
+			return err
+		case <-s.shutdownChan:
+			return nil
+		}
+	}
 }
 
 func (s *Server) setupWriteQueue(outWriter io.Writer) {
diff --git a/jrpc2/server_test.go b/jrpc2/server_test.go
new file mode 100644
index 0000000..1176af8
--- /dev/null
+++ b/jrpc2/server_test.go
@@ -0,0 +1,54 @@
+package jrpc2
+
+import (
+	"os"
+	"testing"
+	"time"
+)
+
+func TestServer_StartUp(t *testing.T) {
+	t.Parallel()
+	in, err := os.CreateTemp("", "input")
+	if err != nil {
+		t.Fatalf("Failed to create temp file: %v", err)
+	}
+	defer os.Remove(in.Name())
+	defer in.Close()
+
+	out, err := os.CreateTemp("", "output")
+	if err != nil {
+		t.Fatalf("Failed to create temp file: %v", err)
+	}
+	defer os.Remove(out.Name())
+	defer out.Close()
+
+	server := NewServer()
+	err = server.StartUp(in, out)
+	if err != nil {
+		t.Fatalf("Expected no error, got %v", err)
+	}
+}
+
+func TestServer_Shutdown(t *testing.T) {
+	t.Parallel()
+	server := NewServer()
+	resultChan := make(chan error)
+
+	go func() {
+		resultChan <- server.StartUp(os.Stdin, os.Stdout)
+	}()
+
+	// Give the server a moment to enter its listen loop before
+	// shutting it down.
+	time.Sleep(100 * time.Millisecond)
+
+	server.Shutdown()
+	select {
+	case err := <-resultChan:
+		if err != nil {
+			t.Fatalf("StartUp returned an error after Shutdown: %v", err)
+		}
+	case <-time.After(5 * time.Second):
+		t.Fatal("server did not shut down in time")
+	}
+}