Skip to content

Commit

Permalink
refactor: Change restorewallet RPC to call wallet.h:RestoreWallet
Browse files Browse the repository at this point in the history
  • Loading branch information
w0xlt committed Dec 8, 2021
1 parent c7b3de7 commit d021bf9
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 33 deletions.
24 changes: 8 additions & 16 deletions src/wallet/rpc/backup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1879,27 +1879,19 @@ RPCHelpMan restorewallet()

auto backup_file = fs::u8path(request.params[1].get_str());

if (!fs::exists(backup_file)) {
throw JSONRPCError(RPC_INVALID_PARAMETER, "Backup file does not exist");
}

std::string wallet_name = request.params[0].get_str();

const fs::path wallet_path = fsbridge::AbsPathJoin(GetWalletDir(), fs::u8path(wallet_name));

if (fs::exists(wallet_path)) {
throw JSONRPCError(RPC_INVALID_PARAMETER, "Wallet name already exists.");
}

if (!TryCreateDirectories(wallet_path)) {
throw JSONRPCError(RPC_WALLET_ERROR, strprintf("Failed to create database path '%s'. Database already exists.", wallet_path.u8string()));
}
std::optional<bool> load_on_start = request.params[2].isNull() ? std::nullopt : std::optional<bool>(request.params[2].get_bool());

auto wallet_file = wallet_path / "wallet.dat";
DatabaseOptions options;
DatabaseStatus status;
options.require_existing = true;
bilingual_str error;
std::vector<bilingual_str> warnings;

fs::copy_file(backup_file, wallet_file, fs::copy_option::fail_if_exists);
const std::shared_ptr<CWallet> wallet = RestoreWallet(context, fs::PathToString(backup_file), wallet_name, load_on_start, options, status, error, warnings);

auto [wallet, warnings] = LoadWalletHelper(context, request.params[2], wallet_name);
HandleWalletError(wallet, status, error);

UniValue obj(UniValue::VOBJ);
obj.pushKV("name", wallet->GetName());
Expand Down
12 changes: 1 addition & 11 deletions src/wallet/rpc/util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,16 +122,8 @@ std::string LabelFromValue(const UniValue& value)
return label;
}

std::tuple<std::shared_ptr<CWallet>, std::vector<bilingual_str>> LoadWalletHelper(WalletContext& context, UniValue load_on_start_param, const std::string wallet_name)
void HandleWalletError(const std::shared_ptr<CWallet> wallet, DatabaseStatus& status, bilingual_str& error)
{
DatabaseOptions options;
DatabaseStatus status;
options.require_existing = true;
bilingual_str error;
std::vector<bilingual_str> warnings;
std::optional<bool> load_on_start = load_on_start_param.isNull() ? std::nullopt : std::optional<bool>(load_on_start_param.get_bool());
std::shared_ptr<CWallet> const wallet = LoadWallet(context, wallet_name, load_on_start, options, status, error, warnings);

if (!wallet) {
// Map bad format to not found, since bad format is returned when the
// wallet directory exists, but doesn't contain a data file.
Expand All @@ -149,6 +141,4 @@ std::tuple<std::shared_ptr<CWallet>, std::vector<bilingual_str>> LoadWalletHelpe
}
throw JSONRPCError(code, error.original);
}

return { wallet, warnings };
}
3 changes: 2 additions & 1 deletion src/wallet/rpc/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

struct bilingual_str;
class CWallet;
enum class DatabaseStatus;
class JSONRPCRequest;
class LegacyScriptPubKeyMan;
class UniValue;
Expand All @@ -37,6 +38,6 @@ bool GetAvoidReuseFlag(const CWallet& wallet, const UniValue& param);
bool ParseIncludeWatchonly(const UniValue& include_watchonly, const CWallet& wallet);
std::string LabelFromValue(const UniValue& value);

std::tuple<std::shared_ptr<CWallet>, std::vector<bilingual_str>> LoadWalletHelper(WalletContext& context, UniValue load_on_start_param, const std::string wallet_name);
void HandleWalletError(const std::shared_ptr<CWallet> wallet, DatabaseStatus& status, bilingual_str& error);

#endif // BITCOIN_WALLET_RPC_UTIL_H
10 changes: 9 additions & 1 deletion src/wallet/rpc/wallet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,15 @@ static RPCHelpMan loadwallet()
WalletContext& context = EnsureWalletContext(request.context);
const std::string name(request.params[0].get_str());

auto [wallet, warnings] = LoadWalletHelper(context, request.params[1], name);
DatabaseOptions options;
DatabaseStatus status;
options.require_existing = true;
bilingual_str error;
std::vector<bilingual_str> warnings;
std::optional<bool> load_on_start = request.params[1].isNull() ? std::nullopt : std::optional<bool>(request.params[1].get_bool());
std::shared_ptr<CWallet> const wallet = LoadWallet(context, name, load_on_start, options, status, error, warnings);

HandleWalletError(wallet, status, error);

UniValue obj(UniValue::VOBJ);
obj.pushKV("name", wallet->GetName());
Expand Down
15 changes: 14 additions & 1 deletion src/wallet/wallet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,12 @@ std::shared_ptr<CWallet> CreateWallet(WalletContext& context, const std::string&

std::shared_ptr<CWallet> RestoreWallet(WalletContext& context, const std::string& backup_file, const std::string& wallet_name, std::optional<bool> load_on_start, const DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error, std::vector<bilingual_str>& warnings)
{
if (!fs::exists(fs::u8path(backup_file))) {
error = Untranslated("Backup file does not exist");
status = DatabaseStatus::FAILED_BAD_PATH;
return nullptr;
}

const fs::path wallet_path = fsbridge::AbsPathJoin(GetWalletDir(), fs::u8path(wallet_name));

if (fs::exists(wallet_path) || !TryCreateDirectories(wallet_path)) {
Expand All @@ -370,7 +376,14 @@ std::shared_ptr<CWallet> RestoreWallet(WalletContext& context, const std::string
auto wallet_file = wallet_path / "wallet.dat";
fs::copy_file(backup_file, wallet_file, fs::copy_option::fail_if_exists);

return LoadWallet(context, wallet_name, load_on_start, options, status, error, warnings);
auto wallet = LoadWallet(context, wallet_name, load_on_start, options, status, error, warnings);

if (!wallet) {
fs::remove(wallet_file);
fs::remove(wallet_path);
}

return wallet;
}

/** @defgroup mapWallet
Expand Down
8 changes: 5 additions & 3 deletions test/functional/wallet_backup.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,15 @@ def restore_nonexistent_wallet(self):
node = self.nodes[3]
nonexistent_wallet_file = os.path.join(self.nodes[0].datadir, 'nonexistent_wallet.bak')
wallet_name = "res0"
assert_raises_rpc_error(-8, "Backup file does not exist", node.restorewallet, wallet_name, nonexistent_wallet_file)
assert_raises_rpc_error(-4, "Backup file does not exist", node.restorewallet, wallet_name, nonexistent_wallet_file)

def restore_wallet_existent_name(self):
node = self.nodes[3]
wallet_file = os.path.join(self.nodes[0].datadir, 'wallet.bak')
backup_file = os.path.join(self.nodes[0].datadir, 'wallet.bak')
wallet_name = "res0"
assert_raises_rpc_error(-8, "Wallet name already exists.", node.restorewallet, wallet_name, wallet_file)
wallet_file = os.path.join(node.datadir, self.chain, 'wallets', wallet_name)
error_message = "Failed to create database path '{}'. Database already exists.".format(wallet_file)
assert_raises_rpc_error(-4, error_message, node.restorewallet, wallet_name, backup_file)

def init_three(self):
self.init_wallet(node=0)
Expand Down

0 comments on commit d021bf9

Please sign in to comment.