Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove pyln-client json monkey patch for _msat fields to Millisatoshis #6865

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 13 additions & 73 deletions contrib/pyln-client/pyln/client/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,15 @@
from typing import Optional, Union


def _patched_default(self, obj):
return getattr(obj.__class__, "to_json", _patched_default.default)(obj)

def to_json_default(self, obj):
"""
Try to use .to_json() if available, otherwise use the normal JSON default method.
"""
return getattr(obj.__class__, "to_json", old_json_default)(obj)

def monkey_patch_json(patch=True):
is_patched = JSONEncoder.default == _patched_default

if patch and not is_patched:
_patched_default.default = JSONEncoder.default # Save unmodified
JSONEncoder.default = _patched_default # Replace it.
elif not patch and is_patched:
JSONEncoder.default = _patched_default.default
old_json_default = JSONEncoder.default
JSONEncoder.default = to_json_default


class RpcError(ValueError):
Expand All @@ -41,8 +38,7 @@ class Millisatoshi:
"""
A subtype to represent thousandths of a satoshi.

Many JSON API fields are expressed in millisatoshis: these automatically
get turned into Millisatoshi types. Converts to and from int.
If you put this in an object, converting to JSON automatically makes it an "...msat" string, so you can safely hand it even to our APIs which treat raw numbers as satoshis. Converts to and from int.
"""
def __init__(self, v: Union[int, str, Decimal]):
"""
Expand Down Expand Up @@ -286,10 +282,8 @@ def __del__(self) -> None:


class UnixDomainSocketRpc(object):
def __init__(self, socket_path, executor=None, logger=logging, encoder_cls=json.JSONEncoder, decoder=json.JSONDecoder(), caller_name=None):
def __init__(self, socket_path, executor=None, logger=logging, caller_name=None):
self.socket_path = socket_path
self.encoder_cls = encoder_cls
self.decoder = decoder
self.executor = executor
self.logger = logger
self._notify = None
Expand All @@ -303,7 +297,7 @@ def __init__(self, socket_path, executor=None, logger=logging, encoder_cls=json.
self.next_id = 1

def _writeobj(self, sock, obj):
s = json.dumps(obj, ensure_ascii=False, cls=self.encoder_cls)
s = json.dumps(obj, ensure_ascii=False)
sock.sendall(bytearray(s, 'UTF-8'))

def _readobj(self, sock, buff=b''):
Expand All @@ -318,7 +312,7 @@ def _readobj(self, sock, buff=b''):
return {'error': 'Connection to RPC server lost.'}, buff
else:
buff = parts[1]
obj, _ = self.decoder.raw_decode(parts[0].decode("UTF-8"))
obj, _ = json.JSONDecoder().raw_decode(parts[0].decode("UTF-8"))
return obj, buff

def __getattr__(self, name):
Expand Down Expand Up @@ -480,67 +474,13 @@ class LightningRpc(UnixDomainSocketRpc):
between calls, but it does not (yet) support concurrent calls.
"""

class LightningJSONEncoder(json.JSONEncoder):
def default(self, o):
try:
return o.to_json()
except NameError:
pass
return json.JSONEncoder.default(self, o)

class LightningJSONDecoder(json.JSONDecoder):
def __init__(self, *, object_hook=None, parse_float=None,
parse_int=None, parse_constant=None,
strict=True, object_pairs_hook=None,
patch_json=True):
self.object_hook_next = object_hook
super().__init__(object_hook=self.millisatoshi_hook, parse_float=parse_float, parse_int=parse_int, parse_constant=parse_constant, strict=strict, object_pairs_hook=object_pairs_hook)

@staticmethod
def replace_amounts(obj):
"""
Recursively replace _msat fields with appropriate values with Millisatoshi.
"""
if isinstance(obj, dict):
for k, v in obj.items():
# Objects ending in msat are not treated specially!
if k.endswith('msat') and not isinstance(v, dict):
if isinstance(v, list):
obj[k] = [Millisatoshi(e) for e in v]
# FIXME: Deprecated "listconfigs" gives two 'null' fields:
# "lease-fee-base-msat": null,
# "channel-fee-max-base-msat": null,
# FIXME: Removed for v23.08, delete this code in 24.08?
elif v is None:
obj[k] = None
else:
obj[k] = Millisatoshi(v)
else:
obj[k] = LightningRpc.LightningJSONDecoder.replace_amounts(v)
elif isinstance(obj, list):
obj = [LightningRpc.LightningJSONDecoder.replace_amounts(e) for e in obj]

return obj

def millisatoshi_hook(self, obj):
obj = LightningRpc.LightningJSONDecoder.replace_amounts(obj)
if self.object_hook_next:
obj = self.object_hook_next(obj)
return obj

def __init__(self, socket_path, executor=None, logger=logging,
patch_json=True):
def __init__(self, socket_path, executor=None, logger=logging):
super().__init__(
socket_path,
executor,
logger,
self.LightningJSONEncoder,
self.LightningJSONDecoder()
logger
)

if patch_json:
monkey_patch_json(patch=True)

def addgossip(self, message):
"""
Inject this (hex-encoded) gossip message.
Expand Down
1 change: 0 additions & 1 deletion contrib/pyln-client/pyln/client/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,7 +682,6 @@ def _write_locked(self, obj: JSONType) -> None:
# then utf8 ourselves.
s = bytes(json.dumps(
obj,
cls=LightningRpc.LightningJSONEncoder,
ensure_ascii=False
) + "\n\n", encoding='utf-8')
with self.write_lock:
Expand Down
4 changes: 2 additions & 2 deletions contrib/pyln-testing/pyln/testing/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,8 +348,8 @@ def is_msat_request(checker, instance):
return False

def is_msat_response(checker, instance):
"""An integer, but we convert to Millisatoshi in JSON parsing"""
return type(instance) is Millisatoshi
"""A positive integer"""
return type(instance) is int and instance >= 0

def is_txid(checker, instance):
"""Bitcoin transaction ID"""
Expand Down
3 changes: 1 addition & 2 deletions contrib/pyln-testing/pyln/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,12 +684,11 @@ class PrettyPrintingLightningRpc(LightningRpc):
Also validates (optional) schemas for us.
"""
def __init__(self, socket_path, executor=None, logger=logging,
patch_json=True, jsonschemas={}):
jsonschemas={}):
super().__init__(
socket_path,
executor,
logger,
patch_json,
)
self.jsonschemas = jsonschemas
self.check_request_schemas = True
Expand Down
4 changes: 2 additions & 2 deletions tests/test_closing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3716,7 +3716,7 @@ def test_closing_anchorspend_htlc_tx_rbf(node_factory, bitcoind):
assert 'anchors_zero_fee_htlc_tx/even' in only_one(l1.rpc.listpeerchannels()['channels'])['channel_type']['names']

# We reduce l1's UTXOs so it's forced to use more than one UTXO to push.
fundsats = int(only_one(l1.rpc.listfunds()['outputs'])['amount_msat'].to_satoshi())
fundsats = int(Millisatoshi(only_one(l1.rpc.listfunds()['outputs'])['amount_msat']).to_satoshi())
psbt = l1.rpc.fundpsbt("all", "1000perkw", 1000)['psbt']
# Pay 5k sats in fees, send most to l2
psbt = l1.rpc.addpsbtoutput(fundsats - 20000 - 5000, psbt, destination=l2.rpc.newaddr()['bech32'])['psbt']
Expand Down Expand Up @@ -3909,7 +3909,7 @@ def test_peer_anchor_push(node_factory, bitcoind, executor, chainparams):
wait_for_announce=True)

# We splinter l2's funds so it's forced to use more than one UTXO to push.
fundsats = int(only_one(l2.rpc.listfunds()['outputs'])['amount_msat'].to_satoshi())
fundsats = int(Millisatoshi(only_one(l2.rpc.listfunds()['outputs'])['amount_msat']).to_satoshi())
OUTPUT_SAT = 10000
NUM_OUTPUTS = 10
psbt = l2.rpc.fundpsbt("all", "1000perkw", 1000)['psbt']
Expand Down
2 changes: 1 addition & 1 deletion tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,7 +803,7 @@ def test_multiplexed_rpc(node_factory):
# (delaying completion should mean we don't see the other commands intermingled).
for i in commands:
obj, buff = l1.rpc._readobj(sock, buff)
assert obj['id'] == l1.rpc.decoder.decode(i.decode("UTF-8"))['id']
assert obj['id'] == json.loads(i.decode("UTF-8"))['id']
sock.close()


Expand Down
16 changes: 7 additions & 9 deletions tests/test_pay.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,11 @@ def test_pay_amounts(node_factory):

invoice = only_one(l2.rpc.listinvoices('test_pay_amounts')['invoices'])

assert isinstance(invoice['amount_msat'], Millisatoshi)
assert invoice['amount_msat'] == Millisatoshi(123000)

l1.dev_pay(inv, dev_use_shadow=False)

invoice = only_one(l2.rpc.listinvoices('test_pay_amounts')['invoices'])
assert isinstance(invoice['amount_received_msat'], Millisatoshi)
assert invoice['amount_received_msat'] >= Millisatoshi(123000)


Expand Down Expand Up @@ -2844,9 +2842,9 @@ def test_shadow_routing(node_factory):
inv = l3.rpc.invoice(amount, "{}".format(i), "test")["bolt11"]
total_amount += l1.rpc.pay(inv)["amount_sent_msat"]

assert total_amount.millisatoshis > n_payments * amount
assert total_amount > n_payments * amount
# Test that the added amount isn't absurd
assert total_amount.millisatoshis < (n_payments * amount) * (1 + 0.01)
assert total_amount < int((n_payments * amount) * (1 + 0.01))

# FIXME: Test cltv delta too ?

Expand Down Expand Up @@ -3732,11 +3730,11 @@ def test_pay_peer(node_factory, bitcoind):
def spendable(n1, n2):
chan = n1.rpc.listpeerchannels(n2.info['id'])['channels'][0]
avail = chan['spendable_msat']
return avail
return Millisatoshi(avail)

amt = Millisatoshi(10**8)
# How many payments do we expect to go through directly?
direct = spendable(l1, l2).millisatoshis // amt.millisatoshis
direct = spendable(l1, l2) // amt

# Remember the l1 -> l3 capacity, it should not change until we run out of
# direct capacity.
Expand Down Expand Up @@ -3797,8 +3795,8 @@ def test_mpp_adaptive(node_factory, bitcoind):
# Make sure neither channel can fit the payment by itself.
c12 = l1.rpc.listpeerchannels(l2.info['id'])['channels'][0]
c34 = l3.rpc.listpeerchannels(l4.info['id'])['channels'][0]
assert(c12['spendable_msat'].millisatoshis < amt)
assert(c34['spendable_msat'].millisatoshis < amt)
assert(c12['spendable_msat'] < amt)
assert(c34['spendable_msat'] < amt)

# Make sure all HTLCs entirely resolved before we mine more blocks!
def all_htlcs(n):
Expand Down Expand Up @@ -3872,7 +3870,7 @@ def test_pay_fail_unconfirmed_channel(node_factory, bitcoind):
l2.rpc.pay(invl1)

# Wait for us to recognize that the channel is available
wait_for(lambda: l1.rpc.listpeerchannels()['channels'][0]['spendable_msat'].millisatoshis > amount_sat * 1000)
wait_for(lambda: l1.rpc.listpeerchannels()['channels'][0]['spendable_msat'] > amount_sat * 1000)

# Now l1 can pay to l2.
l1.rpc.pay(invl2)
Expand Down
8 changes: 3 additions & 5 deletions tests/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,15 +171,13 @@ def test_millisatoshi_passthrough(node_factory):
plugin_path = os.path.join(os.getcwd(), 'tests/plugins/millisatoshis.py')
n = node_factory.get_node(options={'plugin': plugin_path, 'log-level': 'io'})

# By keyword
# By keyword (plugin literally returns Millisatoshi, which becomes a string)
ret = n.rpc.call('echo', {'msat': Millisatoshi(17), 'not_an_msat': '22msat'})['echo_msat']
assert type(ret) == Millisatoshi
assert ret == Millisatoshi(17)
assert Millisatoshi(ret) == Millisatoshi(17)

# By position
ret = n.rpc.call('echo', [Millisatoshi(18), '22msat'])['echo_msat']
assert type(ret) == Millisatoshi
assert ret == Millisatoshi(18)
assert Millisatoshi(ret) == Millisatoshi(18)


def test_rpc_passthrough(node_factory):
Expand Down
4 changes: 0 additions & 4 deletions tests/test_renepay.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,13 +205,11 @@ def test_amounts(node_factory):

invoice = only_one(l2.rpc.listinvoices('test_pay_amounts')['invoices'])

assert isinstance(invoice['amount_msat'], Millisatoshi)
assert invoice['amount_msat'] == Millisatoshi(123456)

l1.rpc.call('renepay', {'invstring': inv, 'dev_use_shadow': False})

invoice = only_one(l2.rpc.listinvoices('test_pay_amounts')['invoices'])
assert isinstance(invoice['amount_received_msat'], Millisatoshi)
assert invoice['amount_received_msat'] >= Millisatoshi(123456)


Expand Down Expand Up @@ -268,7 +266,6 @@ def test_limits(node_factory):
l1.rpc.call(
'renepay', {'invstring': inv2['bolt11']})
invoice = only_one(l6.rpc.listinvoices('inv2')['invoices'])
assert isinstance(invoice['amount_received_msat'], Millisatoshi)
assert invoice['amount_received_msat'] >= Millisatoshi('800000sat')


Expand Down Expand Up @@ -358,5 +355,4 @@ def test_hardmpp(node_factory):
json.loads("".join([l for l in lines if not l.startswith('#')]))
l1.wait_for_htlcs()
invoice = only_one(l6.rpc.listinvoices('inv2')['invoices'])
assert isinstance(invoice['amount_received_msat'], Millisatoshi)
assert invoice['amount_received_msat'] >= Millisatoshi('1800000sat')
Loading