diff --git a/lnd.go b/lnd.go index cccf30e61c..075b6fb94d 100644 --- a/lnd.go +++ b/lnd.go @@ -395,6 +395,7 @@ func Main(cfg *Config, lisCfg ListenerCfg, shutdownChan <-chan struct{}) error { params, err := waitForWalletPassword( cfg, cfg.RESTListeners, serverOpts, restDialOpts, restProxyDest, tlsCfg, walletUnlockerListeners, + shutdownChan, ) if err != nil { err := fmt.Errorf("unable to set up wallet password "+ @@ -966,7 +967,8 @@ type WalletUnlockParams struct { func waitForWalletPassword(cfg *Config, restEndpoints []net.Addr, serverOpts []grpc.ServerOption, restDialOpts []grpc.DialOption, restProxyDest string, tlsConf *tls.Config, - getListeners rpcListeners) (*WalletUnlockParams, error) { + getListeners rpcListeners, + shutdownChan <-chan struct{}) (*WalletUnlockParams, error) { // Start a gRPC server listening for HTTP/2 connections, solely used // for getting the encryption password from the client. @@ -996,7 +998,7 @@ func waitForWalletPassword(cfg *Config, restEndpoints []net.Addr, } pwService := walletunlocker.New( chainConfig.ChainDir, activeNetParams.Params, !cfg.SyncFreelist, - macaroonFiles, + macaroonFiles, shutdownChan, ) lnrpc.RegisterWalletUnlockerServer(grpcServer, pwService) @@ -1113,6 +1115,10 @@ func waitForWalletPassword(cfg *Config, restEndpoints []net.Addr, return nil, err } + // Now that the wallet has been initialized, we'll close the + // done channel so the call can unblock. + close(initMsg.Done) + return &WalletUnlockParams{ Password: password, Birthday: birthday, @@ -1124,6 +1130,13 @@ func waitForWalletPassword(cfg *Config, restEndpoints []net.Addr, // The wallet has already been created in the past, and is simply being // unlocked. So we'll just return these passphrases. case unlockMsg := <-pwService.UnlockMsgs: + + // Now that we have the parameters, we'll close the done + // channel to allow other operations for the unlocker service. + // + // TODO(roasbeef): push down further? + close(unlockMsg.Done) + return &WalletUnlockParams{ Password: unlockMsg.Passphrase, RecoveryWindow: unlockMsg.RecoveryWindow, diff --git a/walletunlocker/service.go b/walletunlocker/service.go index 90e84bae82..117c8c7a9d 100644 --- a/walletunlocker/service.go +++ b/walletunlocker/service.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "os" + "sync" "time" "github.com/btcsuite/btcd/chaincfg" @@ -54,6 +55,10 @@ type WalletInitMsg struct { // ChanBackups a set of static channel backups that should be received // after the wallet has been initialized. ChanBackups ChannelsToRecover + + // Done is a channel that should be closed once the wallet has been + // fully initialized. + Done chan struct{} } // WalletUnlockMsg is a message sent by the UnlockerService when a user wishes @@ -81,6 +86,10 @@ type WalletUnlockMsg struct { // ChanBackups a set of static channel backups that should be received // after the wallet has been unlocked. ChanBackups ChannelsToRecover + + // Done is a channel that should be closed once the wallet has been + // fully unlocked. + Done chan struct{} } // UnlockerService implements the WalletUnlocker service used to provide lnd @@ -100,11 +109,18 @@ type UnlockerService struct { noFreelistSync bool netParams *chaincfg.Params macaroonFiles []string + + quitChan <-chan struct{} + + // This mutex is only used to guard concurrent access to the external + // RPC calls. This ensures that we don't allow multiple callers to + // init/unlock the wallet. + sync.Mutex } // New creates and returns a new UnlockerService. func New(chainDir string, params *chaincfg.Params, noFreelistSync bool, - macaroonFiles []string) *UnlockerService { + macaroonFiles []string, quitChan <-chan struct{}) *UnlockerService { return &UnlockerService{ InitMsgs: make(chan *WalletInitMsg, 1), @@ -112,6 +128,7 @@ func New(chainDir string, params *chaincfg.Params, noFreelistSync bool, chainDir: chainDir, netParams: params, macaroonFiles: macaroonFiles, + quitChan: quitChan, } } @@ -242,6 +259,9 @@ func extractChanBackups(chanBackups *lnrpc.ChanBackupSnapshot) *ChannelsToRecove func (u *UnlockerService) InitWallet(ctx context.Context, in *lnrpc.InitWalletRequest) (*lnrpc.InitWalletResponse, error) { + u.Lock() + defer u.Unlock() + // Make sure the password meets our constraints. password := in.WalletPassword if err := ValidatePassword(password); err != nil { @@ -293,6 +313,7 @@ func (u *UnlockerService) InitWallet(ctx context.Context, Passphrase: password, WalletSeed: cipherSeed, RecoveryWindow: uint32(recoveryWindow), + Done: make(chan struct{}), } // Before we return the unlock payload, we'll check if we can extract @@ -302,7 +323,19 @@ func (u *UnlockerService) InitWallet(ctx context.Context, initMsg.ChanBackups = *chansToRestore } - u.InitMsgs <- initMsg + select { + case u.InitMsgs <- initMsg: + case <-u.quitChan: + return nil, fmt.Errorf("server shutting down") + } + + // As we want to avoid a possible deadlock scenario, we'll wait the + // daemon to respond that the wallet has been initialized. + select { + case <-initMsg.Done: + case <-u.quitChan: + return nil, fmt.Errorf("server shutting down") + } return &lnrpc.InitWalletResponse{}, nil } @@ -313,6 +346,9 @@ func (u *UnlockerService) InitWallet(ctx context.Context, func (u *UnlockerService) UnlockWallet(ctx context.Context, in *lnrpc.UnlockWalletRequest) (*lnrpc.UnlockWalletResponse, error) { + u.Lock() + defer u.Unlock() + password := in.WalletPassword recoveryWindow := uint32(in.RecoveryWindow) @@ -346,6 +382,7 @@ func (u *UnlockerService) UnlockWallet(ctx context.Context, Passphrase: password, RecoveryWindow: recoveryWindow, Wallet: unlockedWallet, + Done: make(chan struct{}), } // Before we return the unlock payload, we'll check if we can extract @@ -358,7 +395,19 @@ func (u *UnlockerService) UnlockWallet(ctx context.Context, // At this point we was able to open the existing wallet with the // provided password. We send the password over the UnlockMsgs // channel, such that it can be used by lnd to open the wallet. - u.UnlockMsgs <- walletUnlockMsg + select { + case u.UnlockMsgs <- walletUnlockMsg: + case <-u.quitChan: + return nil, fmt.Errorf("server shutting down") + } + + // As we want to avoid a possible deadlock scenario, we'll wait the + // daemon to respond that the wallet has been unlocked. + select { + case <-walletUnlockMsg.Done: + case <-u.quitChan: + return nil, fmt.Errorf("server shutting down") + } return &lnrpc.UnlockWalletResponse{}, nil } @@ -369,6 +418,9 @@ func (u *UnlockerService) UnlockWallet(ctx context.Context, func (u *UnlockerService) ChangePassword(ctx context.Context, in *lnrpc.ChangePasswordRequest) (*lnrpc.ChangePasswordResponse, error) { + u.Lock() + defer u.Unlock() + netDir := btcwallet.NetworkDir(u.chainDir, u.netParams) loader := wallet.NewLoader(u.netParams, netDir, u.noFreelistSync, 0) @@ -431,7 +483,23 @@ func (u *UnlockerService) ChangePassword(ctx context.Context, // Finally, send the new password across the UnlockPasswords channel to // automatically unlock the wallet. - u.UnlockMsgs <- &WalletUnlockMsg{Passphrase: in.NewPassword} + unlockMsg := &WalletUnlockMsg{ + Passphrase: in.NewPassword, + Done: make(chan struct{}), + } + select { + case u.UnlockMsgs <- unlockMsg: + case <-u.quitChan: + return nil, fmt.Errorf("server shutting down") + } + + // As we want to avoid a possible deadlock scenario, we'll wait the + // daemon to respond that the wallet has been unlocked. + select { + case <-unlockMsg.Done: + case <-u.quitChan: + return nil, fmt.Errorf("server shutting down") + } return &lnrpc.ChangePasswordResponse{}, nil } diff --git a/walletunlocker/service_test.go b/walletunlocker/service_test.go index fc331f42c5..822d05ade3 100644 --- a/walletunlocker/service_test.go +++ b/walletunlocker/service_test.go @@ -77,7 +77,9 @@ func TestGenSeed(t *testing.T) { } defer os.RemoveAll(testDir) - service := walletunlocker.New(testDir, testNetParams, true, nil) + service := walletunlocker.New( + testDir, testNetParams, true, nil, make(chan struct{}), + ) // Now that the service has been created, we'll ask it to generate a // new seed for us given a test passphrase. @@ -118,7 +120,9 @@ func TestGenSeedGenerateEntropy(t *testing.T) { defer func() { os.RemoveAll(testDir) }() - service := walletunlocker.New(testDir, testNetParams, true, nil) + service := walletunlocker.New( + testDir, testNetParams, true, nil, make(chan struct{}), + ) // Now that the service has been created, we'll ask it to generate a // new seed for us given a test passphrase. Note that we don't actually @@ -158,7 +162,9 @@ func TestGenSeedInvalidEntropy(t *testing.T) { defer func() { os.RemoveAll(testDir) }() - service := walletunlocker.New(testDir, testNetParams, true, nil) + service := walletunlocker.New( + testDir, testNetParams, true, nil, make(chan struct{}), + ) // Now that the service has been created, we'll ask it to generate a // new seed for us given a test passphrase. However, we'll be using an @@ -196,7 +202,9 @@ func TestInitWallet(t *testing.T) { }() // Create new UnlockerService. - service := walletunlocker.New(testDir, testNetParams, true, nil) + service := walletunlocker.New( + testDir, testNetParams, true, nil, make(chan struct{}), + ) // Once we have the unlocker service created, we'll now instantiate a // new cipher seed instance. @@ -226,10 +234,18 @@ func TestInitWallet(t *testing.T) { AezeedPassphrase: pass, RecoveryWindow: int32(testRecoveryWindow), } - _, err = service.InitWallet(ctx, req) - if err != nil { - t.Fatalf("InitWallet call failed: %v", err) - } + + // As the InitWallet call will block until the operation has been + // completed, we'll execute it in a goroutine so we can check our + // assertions below. + go func() { + // TODO(rosabeef): other option is a goroutine in the calls + // themselves + _, err = service.InitWallet(ctx, req) + if err != nil { + t.Fatalf("InitWallet call failed: %v", err) + } + }() // The same user passphrase, and also the plaintext cipher seed // should be sent over and match exactly. @@ -260,6 +276,10 @@ func TestInitWallet(t *testing.T) { msg.RecoveryWindow) } + // We'll now close the done channel to unlock the service to be + // able to accept another request. + close(msg.Done) + case <-time.After(3 * time.Second): t.Fatalf("password not received") } @@ -297,7 +317,9 @@ func TestCreateWalletInvalidEntropy(t *testing.T) { }() // Create new UnlockerService. - service := walletunlocker.New(testDir, testNetParams, true, nil) + service := walletunlocker.New( + testDir, testNetParams, true, nil, make(chan struct{}), + ) // We'll attempt to init the wallet with an invalid cipher seed and // passphrase. @@ -330,7 +352,9 @@ func TestUnlockWallet(t *testing.T) { }() // Create new UnlockerService. - service := walletunlocker.New(testDir, testNetParams, true, nil) + service := walletunlocker.New( + testDir, testNetParams, true, nil, make(chan struct{}), + ) ctx := context.Background() req := &lnrpc.UnlockWalletRequest{ @@ -356,11 +380,15 @@ func TestUnlockWallet(t *testing.T) { t.Fatalf("expected call to UnlockWallet to fail") } - // With the correct password, we should be able to unlock the wallet. - _, err = service.UnlockWallet(ctx, req) - if err != nil { - t.Fatalf("unable to unlock wallet: %v", err) - } + // We'll unlock the wallet in a new goroutine as UnlockWallet now + // blocks until the wallet has been fully unlocked. + go func() { + // With the correct password, we should be able to unlock the wallet. + _, err = service.UnlockWallet(ctx, req) + if err != nil { + t.Fatalf("unable to unlock wallet: %v", err) + } + }() // Password and recovery window should be sent over the channel. select { @@ -374,6 +402,11 @@ func TestUnlockWallet(t *testing.T) { "got %d", testRecoveryWindow, unlockMsg.RecoveryWindow) } + + // We'll now close the done channel to unlock the service to be + // able to accept another request. + close(unlockMsg.Done) + case <-time.After(3 * time.Second): t.Fatalf("password not received") } @@ -404,7 +437,9 @@ func TestChangeWalletPassword(t *testing.T) { } // Create a new UnlockerService with our temp files. - service := walletunlocker.New(testDir, testNetParams, true, tempFiles) + service := walletunlocker.New( + testDir, testNetParams, true, tempFiles, make(chan struct{}), + ) ctx := context.Background() newPassword := []byte("hunter2???") @@ -450,20 +485,17 @@ func TestChangeWalletPassword(t *testing.T) { t.Fatal("expected call to ChangePassword to fail") } - // When providing the correct wallet's current password and a new - // password that meets the length requirement, the password change - // should succeed. - _, err = service.ChangePassword(ctx, req) - if err != nil { - t.Fatalf("unable to change wallet's password: %v", err) - } - - // The files should no longer exist. - for _, tempFile := range tempFiles { - if _, err := os.Open(tempFile); err == nil { - t.Fatal("file exists but it shouldn't") + // We change the password in a new goroutine, as we expect this case to + // succeed. + go func() { + // When providing the correct wallet's current password and a + // new password that meets the length requirement, the password + // change should succeed. + _, err = service.ChangePassword(ctx, req) + if err != nil { + t.Fatalf("unable to change wallet's password: %v", err) } - } + }() // The new password should be sent over the channel. select { @@ -472,7 +504,20 @@ func TestChangeWalletPassword(t *testing.T) { t.Fatalf("expected to receive password %x, got %x", testPassword, unlockMsg.Passphrase) } + + // We'll now close the done channel to unlock the service to be + // able to accept another request. + close(unlockMsg.Done) + case <-time.After(3 * time.Second): t.Fatalf("password not received") } + + // The files should no longer exist. + for _, tempFile := range tempFiles { + if _, err := os.Open(tempFile); err == nil { + t.Fatal("file exists but it shouldn't") + } + } + }