From d005b200a0664a2791ea76d888c88dded9a688f1 Mon Sep 17 00:00:00 2001 From: carla Date: Tue, 5 Jan 2021 17:00:29 +0200 Subject: [PATCH] client: add option to block until lnd is unlocked This commit adds an optional wait for lnd to unlock to LndServices. The wallet unlocker is not used, because polling the unlock endpoint exacerbates a known deadlock: https://github.com/lightningnetwork/lnd/issues/3631. Instead, a call to the main grpc server is used. The wallet is considered locked when we receive a grpc unimplemented error, because the main server does not become active until the wallet is unlocked. Once the wallet is unlocked, there is a race condition where a query to the main server can return an unavailable code while the server is busy registering. This error is the same as when lnd is just not online at all, so we allow it to occur once (assuming our backoff period will be sufficient) to account for this race while still failing if lnd is consistently offline. --- go.mod | 1 + lnd_services.go | 139 +++++++++++++++++++++++++++++++++++++------ lnd_services_test.go | 113 +++++++++++++++++++++++++++++++++++ 3 files changed, 234 insertions(+), 19 deletions(-) diff --git a/go.mod b/go.mod index 16e308b3..b0814f96 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/btcsuite/btcutil v1.0.2 github.com/btcsuite/btcwallet/wtxmgr v1.2.0 github.com/lightningnetwork/lnd v0.11.0-beta + github.com/stretchr/testify v1.5.1 google.golang.org/grpc v1.24.0 gopkg.in/macaroon.v2 v2.1.0 ) diff --git a/lnd_services.go b/lnd_services.go index 44ecc31c..30a7c58a 100644 --- a/lnd_services.go +++ b/lnd_services.go @@ -25,6 +25,10 @@ var ( // call to find out if lnd is fully synced to its chain backend. chainSyncPollInterval = 5 * time.Second + // defaultUnlockedInterval is the default amount of time we wait between + // checks that the wallet is unlocked. + defaultUnlockedInterval = 2 * time.Second + // minimalCompatibleVersion is the minimum version and build tags // required in lnd to get all functionality implemented in lndclient. // Users can provide their own, specific version if needed. If only a @@ -97,13 +101,23 @@ type LndServicesConfig struct { // block download is still in progress. BlockUntilChainSynced bool + // BlockUntilUnlocked denotes that the NewLndServices function should + // block until lnd is unlocked. + BlockUntilUnlocked bool + + // UnlockInterval sets the interval at which we will query lnd to + // determine whether lnd is unlocked when BlockUntilUnlocked is true. + // This value is optional, and will be replaced with a default if it is + // zero. + UnlockInterval time.Duration + // CallerCtx is an optional context that can be passed if the caller // would like to be able to cancel the long waits involved in starting // up the client, such as waiting for chain sync to complete when - // BlockUntilChainSynced is set to true. If a context is passed in and - // its Done() channel sends a message, the wait for chain sync is - // aborted. This allows a client to still be shut down properly if lnd - // takes a long time to sync. + // BlockUntilChainSynced is set to true, or waiting for lnd to be + // unlocked when BlockUntilUnlocked is set to true. If a context is + // passed in and its Done() channel sends a message, these waits will + // be aborted. This allows a client to still be shut down properly. CallerCtx context.Context } @@ -226,8 +240,19 @@ func NewLndServices(cfg *LndServicesConfig) (*GrpcLndServices, error) { } } + // Get lnd's info, blocking until lnd is unlocked if required. + client := newLightningClient(conn, chainParams, readonlyMac) + info, err := getLndInfo( + cfg.CallerCtx, client, cfg.BlockUntilUnlocked, + cfg.UnlockInterval, + ) + if err != nil { + cleanupConn() + return nil, err + } + nodeAlias, nodeKey, version, err := checkLndCompatibility( - conn, chainParams, readonlyMac, cfg.Network, cfg.CheckVersion, + conn, readonlyMac, info, cfg.Network, cfg.CheckVersion, ) if err != nil { cleanupConn() @@ -374,11 +399,94 @@ func (s *GrpcLndServices) waitForChainSync(ctx context.Context) error { return <-update } +// getLndInfo queries lnd for information about the node it is connected to. +// If the waitForUnlocked boolean is set, it will examine any errors returned +// and make back off if the failure is due to lnd currently being locked. +// Otherwise, it will fail fast on any errors returned. +func getLndInfo(ctx context.Context, client LightningClient, + waitForUnlocked bool, waitInterval time.Duration) (*Info, error) { + + // allowUnavailable is a bool that tracks whether we allow lnd to + // respond with an unavailable error. This bool is used to track a race + // condition after unlock where lnd briefly responds with unavailable. + // This should not be confused with the case where lnd is just down, + // so we only allow unavailable errors once directly after an unlock. + var allowUnavailable bool + + if waitInterval == 0 { + waitInterval = defaultUnlockedInterval + } + + if ctx == nil { + ctx = context.Background() + } + + if waitForUnlocked { + log.Info("Waiting for lnd to unlock") + } + + for { + info, err := client.GetInfo(ctx) + if err == nil { + return info, nil + } + + // If we do not want to wait for lnd to be unlocked, we just + // fail immediately on any error. + if !waitForUnlocked { + return nil, err + } + + // If we do not get a rpc error code, something else is wrong + // with the call, so we fail. + rpcErrorCode, ok := status.FromError(err) + if !ok { + return nil, err + } + + // If the main rpc server is unimplemented, we need to fall + // through to our back off because lnd is still locked. + switch rpcErrorCode.Code() { + case codes.Unimplemented: + allowUnavailable = true + + // If the server is unavailable, we check whether this error + // directly follows an error caused by the wallet being locked. + // If this is the case, we allow a single unavailable error to + // account for a race where the main rpc server throws this + // error after unlock. Otherwise, it is likely that lnd is down, + // so we fail. + case codes.Unavailable: + if !allowUnavailable { + return nil, err + } + + log.Info("Lnd unavailable, allowing single backoff") + + // Do not allow another unavailable error after this. + allowUnavailable = false + + default: + return nil, err + } + + // At this point, we know lnd is locked, so we wait for our + // interval, exiting if context is cancelled. + select { + case <-ctx.Done(): + return nil, ctx.Err() + + case <-time.After(waitInterval): + + } + } +} + // checkLndCompatibility makes sure the connected lnd instance is running on the // correct network, has the version RPC implemented, is the correct minimal // version and supports all required build tags/subservers. -func checkLndCompatibility(conn *grpc.ClientConn, chainParams *chaincfg.Params, - readonlyMac serializedMacaroon, network Network, +func checkLndCompatibility(conn *grpc.ClientConn, + readonlyMac serializedMacaroon, info *Info, network Network, minVersion *verrpc.Version) (string, [33]byte, *verrpc.Version, error) { // onErr is a closure that simplifies returning multiple values in the @@ -396,24 +504,17 @@ func checkLndCompatibility(conn *grpc.ClientConn, chainParams *chaincfg.Params, return "", [33]byte{}, nil, newErr } - // We use our own clients with a readonly macaroon here, because we know - // that's all we need for the checks. - lightningClient := newLightningClient(conn, chainParams, readonlyMac) - versionerClient := newVersionerClient(conn, readonlyMac) - - // With our readonly macaroon obtained, we'll ensure that the network - // for lnd matches our expected network. - info, err := lightningClient.GetInfo(context.Background()) - if err != nil { - err := fmt.Errorf("unable to get info for lnd node: %v", err) - return onErr(err) - } + // Ensure that the network for lnd matches our expected network. if string(network) != info.Network { err := fmt.Errorf("network mismatch with connected lnd node, "+ "wanted '%s', got '%s'", network, info.Network) return onErr(err) } + // We use our own clients with a readonly macaroon here, because we know + // that's all we need for the checks. + versionerClient := newVersionerClient(conn, readonlyMac) + // Now let's also check the version of the connected lnd node. version, err := checkVersionCompatibility(versionerClient, minVersion) if err != nil { diff --git a/lnd_services_test.go b/lnd_services_test.go index 30bc8be2..f898a3f2 100644 --- a/lnd_services_test.go +++ b/lnd_services_test.go @@ -2,9 +2,11 @@ package lndclient import ( "context" + "errors" "testing" "github.com/lightningnetwork/lnd/lnrpc/verrpc" + "github.com/stretchr/testify/require" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -156,3 +158,114 @@ func TestLndVersionCheckComparison(t *testing.T) { }) } } + +// lockLNDMock is a mock lightning client which mocks calls to getinfo to +// determine the unlocked state of lnd. +type lockLNDMock struct { + LightningClient + + callCount int + errors []error +} + +// GetInfo mocks a call to getinfo, using our call count to get the error for +// this call as the index in our pre-set error slice. +func (l *lockLNDMock) GetInfo(ctx context.Context) (*Info, error) { + // Our actual call would use ctx, so add a panic to reflect that. + if ctx == nil { + panic("nil context for getinfo") + } + + err := l.errors[l.callCount] + + l.callCount++ + + return nil, err +} + +func newLockLndMock(errors []error) *lockLNDMock { + return &lockLNDMock{ + errors: errors, + } +} + +// TestGetLndInfo tests our logic for querying lnd for information in the case +// where we wait for the wallet to unlock, and when we fail fast. +func TestGetLndInfo(t *testing.T) { + // Override our default so that we don't have long waits in tests. + defaultUnlockedInterval = 1 + + var ( + ctx = context.Background() + nonNilErr = errors.New("failed") + unlockErr = status.Error(codes.Unimplemented, "unimpl") + unavailErr = status.Error(codes.Unavailable, "unavail") + ) + + tests := []struct { + name string + context context.Context + waitUnlocked bool + errors []error + expected error + }{ + { + name: "no error", + context: ctx, + errors: []error{nil}, + expected: nil, + }, + { + name: "nil context", + errors: []error{nil}, + expected: nil, + }, + { + name: "do not wait for unlock", + errors: []error{unlockErr}, + expected: unlockErr, + }, + { + name: "wait for unlock", + waitUnlocked: true, + errors: []error{unlockErr, nil}, + expected: nil, + }, + { + name: "wait for unlock with race", + waitUnlocked: true, + errors: []error{unlockErr, unavailErr, nil}, + expected: nil, + }, { + name: "lnd not available without unlock", + waitUnlocked: true, + errors: []error{unavailErr}, + expected: unavailErr, + }, + { + name: "lnd not available long term", + waitUnlocked: true, + errors: []error{unavailErr, unavailErr}, + expected: unavailErr, + }, + { + name: "other error", + waitUnlocked: true, + errors: []error{nonNilErr}, + expected: nonNilErr, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + mock := newLockLndMock(test.errors) + + _, err := getLndInfo( + test.context, mock, test.waitUnlocked, 0, + ) + require.Equal(t, test.expected, err) + }) + } +}