Skip to content

Commit

Permalink
glightning: add shutdown subscription
Browse files Browse the repository at this point in the history
Adds a new shutdown subscription to the plugin system,
allowing plugins to be notified when lightningd is shutting down.
This enables plugins to perform any necessary cleanup
or finalization tasks before the daemon exits.
  • Loading branch information
YusukeShimizu committed Nov 18, 2024
1 parent 08a556b commit 0552053
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 25 deletions.
45 changes: 37 additions & 8 deletions glightning/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ const (
_Forward Subscription = "forward_event"
_SendPaySuccess Subscription = "sendpay_success"
_SendPayFailure Subscription = "sendpay_failure"
_Shutdown Subscription = "shutdown"
_PeerConnected Hook = "peer_connected"
_DbWrite Hook = "db_write"
_InvoicePayment Hook = "invoice_payment"
Expand All @@ -38,7 +39,7 @@ const (
var lightningMethodRegistry map[string]*jrpc2.Method

// The custommsg plugin hook is the receiving counterpart to the dev-sendcustommsg RPC method
//and allows plugins to handle messages that are not handled internally.
// and allows plugins to handle messages that are not handled internally.
type CustomMsgReceivedEvent struct {
PeerId string `json:"peer_id"`
Payload string `json:"payload"`
Expand Down Expand Up @@ -81,8 +82,9 @@ func (pc *CustomMsgReceivedEvent) Fail() *CustomMsgReceivedResponse {
}

// This hook is called whenever a peer has connected and successfully completed
// the cryptographic handshake. The parameters have the following structure if
// there is a channel with the peer:
//
// the cryptographic handshake. The parameters have the following structure if
// there is a channel with the peer:
type PeerConnectedEvent struct {
Peer PeerEvent `json:"peer"`
hook func(*PeerConnectedEvent) (*PeerConnectedResponse, error)
Expand Down Expand Up @@ -450,10 +452,11 @@ func (rc *RpcCommandEvent) ReturnError(errMsg string, errCode int) (*RpcCommandR
// its result determines how `lightningd` should treat that HTLC.
//
// Warning: `lightningd` will replay the HTLCs for which it doesn't have a final
// verdict during startup. This means that, if the plugin response wasn't
// processed before the HTLC was forwarded, failed, or resolved, then the plugin
// may see the same HTLC again during startup. It is therefore paramount that the
// plugin is idempotent if it talks to an external system.
//
// verdict during startup. This means that, if the plugin response wasn't
// processed before the HTLC was forwarded, failed, or resolved, then the plugin
// may see the same HTLC again during startup. It is therefore paramount that the
// plugin is idempotent if it talks to an external system.
type HtlcAcceptedEvent struct {
Onion Onion `json:"onion"`
Htlc HtlcOffer `json:"htlc"`
Expand Down Expand Up @@ -763,6 +766,25 @@ func (e *WarnEvent) Call() (jrpc2.Result, error) {
return nil, nil
}

type ShutdownEvent struct {
cb func()
}

func (e *ShutdownEvent) Name() string {
return string(_Shutdown)
}

func (e *ShutdownEvent) New() interface{} {
return &ShutdownEvent{
cb: e.cb,
}
}

func (e *ShutdownEvent) Call() (jrpc2.Result, error) {
e.cb()
return nil, nil
}

type OptionType string

const _String OptionType = "string"
Expand Down Expand Up @@ -1182,7 +1204,8 @@ func (p *Plugin) Log(message string, level LogLevel) {
}

// Map for registering hooks. Not the *most* elegant but
// it'll do for now.
//
// it'll do for now.
type Hooks struct {
PeerConnected func(*PeerConnectedEvent) (*PeerConnectedResponse, error)
DbWrite func(*DbWriteEvent) (*DbWriteResponse, error)
Expand Down Expand Up @@ -1559,6 +1582,12 @@ func (p *Plugin) SubscribeForwardings(cb func(c *Forwarding)) {
})
}

func (p *Plugin) SubscribeShutdown(cb func()) {
p.subscribe(&ShutdownEvent{
cb: cb,
})
}

func (p *Plugin) subscribe(subscription jrpc2.ServerMethod) {
p.server.Register(subscription)
p.subscriptions = append(p.subscriptions, subscription.Name())
Expand Down
26 changes: 26 additions & 0 deletions glightning/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,32 @@ func TestSubscription_Disconnected(t *testing.T) {
runTest(t, plugin, msg+"\n\n", "")
}

func TestSubscription_Shutdown(t *testing.T) {
var wg sync.WaitGroup
defer await(t, &wg)
shutdownCalled := make(chan struct{})

wg.Add(1)
initFn := getInitFunc(t, func(t *testing.T, options map[string]glightning.Option, config *glightning.Config) {
t.Error("Should not have called init when calling get manifest")
})
plugin := glightning.NewPlugin(initFn)
plugin.SubscribeShutdown(func() {
defer wg.Done()
shutdownCalled <- struct{}{}
})

msg := `{"jsonrpc":"2.0","method":"shutdown"}`

runTest(t, plugin, msg+"\n\n", "")

select {
case <-shutdownCalled:
case <-time.After(1 * time.Second):
t.Fatal("SubscribeShutdown was not called")
}
}

func await(t *testing.T, wg *sync.WaitGroup) {
awaitWithTimeout(t, wg, 1*time.Second)
}
Expand Down
60 changes: 43 additions & 17 deletions jrpc2/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,16 @@ type Server struct {
registry sync.Map // map[string]ServerMethod
outQueue chan interface{}
shutdown bool
// shutdownChan is closed when the Shutdown is called.
// This is used to signal the listener to stop listening.
shutdownChan chan interface{}
}

func NewServer() *Server {
server := &Server{}
server.outQueue = make(chan interface{})
server.shutdown = false
server.shutdownChan = make(chan interface{})
return server
}

Expand Down Expand Up @@ -71,6 +75,7 @@ func (s *Server) StartUp(in, out *os.File) error {

func (s *Server) Shutdown() {
s.shutdown = true
close(s.shutdownChan)
close(s.outQueue)
}

Expand Down Expand Up @@ -105,25 +110,46 @@ func (s *Server) listen(in io.Reader) error {
buf := make([]byte, 1024)
scanner.Buffer(buf, MaxIntakeBuffer)
scanner.Split(scanDoubleNewline)
for scanner.Scan() && !s.shutdown {
msg := scanner.Bytes()
if debugIO(true) {
log.Println(string(msg))

msgChan := make(chan []byte)
errChan := make(chan error)

// listen reads messages from the provided io.Reader and processes them.
// It uses a bufio.Scanner to read messages separated by double newline
// characters. Messages are processed concurrently using goroutines.
// The function listens for shutdown signals via the shutdownChan channel.
go func() {
for scanner.Scan() {
msg := scanner.Bytes()
msg_buf := make([]byte, len(msg))
copy(msg_buf, msg)
msgChan <- msg_buf
}
if err := scanner.Err(); err != nil {
errChan <- err
}
close(msgChan)
close(errChan)
}()

for {
select {
case msg, ok := <-msgChan:
if !ok {
return nil
}
if debugIO(true) {
log.Println(string(msg))
}
go processMsg(s, msg)
case err := <-errChan:
if err != nil {
return err
}
case <-s.shutdownChan:
return nil
}
// pass down a copy so things stay sane
msg_buf := make([]byte, len(msg))
copy(msg_buf, msg)
// todo: send this over a channel
// for processing, so the number
// of things we process at once
// is more easy to control
go processMsg(s, msg_buf)
}
if err := scanner.Err(); err != nil {
log.Fatal(err)
return err
}
return nil
}

func (s *Server) setupWriteQueue(outWriter io.Writer) {
Expand Down
52 changes: 52 additions & 0 deletions jrpc2/server_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package jrpc2

import (
"os"
"testing"
"time"
)

func TestServer_StartUp(t *testing.T) {
t.Parallel()
in, err := os.CreateTemp("", "input")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
defer os.Remove(in.Name())
defer in.Close()

out, err := os.CreateTemp("", "output")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
defer os.Remove(out.Name())
defer out.Close()

server := NewServer()
err = server.StartUp(in, out)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
}

func TestServer_Shutdown(t *testing.T) {
t.Parallel()
server := NewServer()
resultChan := make(chan error)

go func() {
resultChan <- server.StartUp(os.Stdin, os.Stdout)
}()

time.Sleep(5 * time.Second) // Give some time for the server to start

server.Shutdown()
select {
case err := <-resultChan:
if err != nil {
t.Fatalf("Server startup failed: %v", err)
}
case <-time.After(5 * time.Second):
t.Fatalf("Server startup timed out")
}
}

0 comments on commit 0552053

Please sign in to comment.