diff --git a/lnprototest/stash/stash.py b/lnprototest/stash/stash.py index 54703bb..801158a 100644 --- a/lnprototest/stash/stash.py +++ b/lnprototest/stash/stash.py @@ -1,10 +1,12 @@ -from lnprototest import Runner, Event, Side, SpecFileError, Funding -from typing import Callable, Optional, Any -from pyln.proto.message import Message import functools import time import coincurve +from typing import Callable, Optional, Any +from pyln.proto.message import Message + +from lnprototest import Runner, Event, Side, SpecFileError, Funding + def commitsig_to_send() -> Callable[[Runner, Event, str], str]: """Get the appropriate signature for the local side to send to the remote""" @@ -275,12 +277,15 @@ def _funding_close_tx(runner: Runner, event: Event, field: str) -> str: return _funding_close_tx -def stash_field_from_event(stash_key: str) -> Callable[[Runner, Event, str], str]: +def stash_field_from_event( + stash_key: str, field_name: Optional[str] = None, dummy_val: Optional[Any] = None +) -> Callable[[Runner, Event, str], str]: """Generic stash function to get the information back from a previous event""" def _stash_field_from_event(runner: Runner, event: Event, field: str) -> str: if runner._is_dummy(): - return "0" + return dummy_val + field = field if field_name is None else field_name return runner.get_stash(event, stash_key).fields[field] return _stash_field_from_event diff --git a/lnprototest/utils/ln_spec_utils.py b/lnprototest/utils/ln_spec_utils.py index ad6f875..967485a 100644 --- a/lnprototest/utils/ln_spec_utils.py +++ b/lnprototest/utils/ln_spec_utils.py @@ -8,7 +8,7 @@ author: Vincenzo PAlazzo https://github.com/vincenzopalazzo """ -from typing import List +from typing import List, Optional class LightningUtils: @@ -49,11 +49,12 @@ def connect_to_node_helper( runner: "Runner", tx_spendable: str, conn_privkey: str = "02", - global_features="", - features: str = "", + global_features: Optional[str] = None, + features: Optional[str] = None, ) -> List["Event"]: """Helper function to make a connection with the node""" from lnprototest.utils.bitcoin_utils import tx_spendable + from lnprototest.stash import stash_field_from_event from lnprototest import ( Connect, Block, @@ -65,7 +66,15 @@ def connect_to_node_helper( Block(blockheight=102, txs=[tx_spendable]), Connect(connprivkey=conn_privkey), ExpectMsg("init"), - Msg("init", globalfeatures=global_features, features=features), + Msg( + "init", + globalfeatures=stash_field_from_event("init", dummy_val="") + if global_features is None + else global_features, + features=stash_field_from_event("init", dummy_val="") + if features is None + else features, + ), ] diff --git a/tests/test_bolt2-01-open_channel.py b/tests/test_bolt2-01-open_channel.py index 89fae88..bac95ed 100644 --- a/tests/test_bolt2-01-open_channel.py +++ b/tests/test_bolt2-01-open_channel.py @@ -112,7 +112,7 @@ def test_open_channel_from_accepter_side(runner: Runner) -> None: delayed_payment_basepoint=remote_delayed_payment_basepoint(), htlc_basepoint=remote_htlc_basepoint(), first_per_commitment_point=remote_per_commitment_point(0), - minimum_depth=stash_field_from_event("accept_channel"), + minimum_depth=stash_field_from_event("accept_channel", dummy_val=3), channel_reserve_satoshis=9998, ), # Ignore unknown odd messages