Skip to content

Commit

Permalink
psbt: simplify test, add more coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
jgriffiths committed Jul 19, 2022
1 parent c4d0a62 commit 26aef05
Showing 1 changed file with 79 additions and 123 deletions.
202 changes: 79 additions & 123 deletions src/swig_python/contrib/psbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,37 +11,46 @@
SAMPLE_V2 = 'cHNidP8BAgR7AAAAAQQBAQEFAQEB+wQCAAAAAAEOIAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gAQ8EAQAAAAABAwiH1hIAAAAAAAEEIQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA='
SAMPLE_PSET = 'cHNldP8BAgQCAAAAAQQBAQEFAQEBBgEDAfsEAgAAAAABDiCd/GYowmxYmf4b09wzhmW/1V162hD2Iglz3y04bewSdgEPBAEAAAABEAT///8AAAEDCPA9zR0AAAAAAQQWABR7OgC/3BTSd5XCt0kB0J2m7xM1eQf8BHBzZXQCIHd3d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3AA=='

def wally_fn(name):
return globals().get(name)

def accessors(typ, field):
names = [f'psbt_set_{typ}_{field}',
f'psbt_get_{typ}_{field}', f'psbt_get_{typ}_{field}_len',
f'psbt_has_{typ}_{field}', f'psbt_clear_{typ}_{field}']
return [wally_fn(n) for n in names]


class PSBTTests(unittest.TestCase):

def _throws(self, fn, psbt, *args):
self.assertRaises(ValueError, lambda: fn(psbt, *args))
def _throws(self, func, psbt, *args):
self.assertRaises(ValueError, lambda: func(psbt, *args))

def _try_invalid(self, fn, psbt, *args):
self._throws(fn, None, 0, *args) # Null PSBT
self._throws(fn, psbt, 1, *args) # Invalid index
def _try_invalid(self, func, psbt, *args):
self._throws(func, None, 0, *args) # Null PSBT
self._throws(func, psbt, 1, *args) # Invalid index

def _round_trip(self, psbt):
psbt_bytes = psbt_to_bytes(psbt, 0)
deserialized = psbt_from_bytes(psbt_bytes)
new_bytes = psbt_to_bytes(deserialized, 0)
self.assertEqual(psbt_bytes, new_bytes)

def _try_set(self, fn, psbt, valid_value, null_value=None, mandatory=False, allow_null=True, roundtrip=True):
def _try_set(self, func, psbt, valid_value, null_value=None, mandatory=False, allow_null=True, roundtrip=True):
if roundtrip:
self._round_trip(psbt)
fn(psbt, 0, valid_value) # Set
func(psbt, 0, valid_value) # Set
if roundtrip:
self._round_trip(psbt)
if allow_null:
fn(psbt, 0, null_value) # Un-set
func(psbt, 0, null_value) # Un-set
if mandatory:
fn(psbt, 0, valid_value) # Set
func(psbt, 0, valid_value) # Set
elif roundtrip:
self._round_trip(psbt)
else:
self._throws(fn, psbt, 0, null_value)
self._try_invalid(fn, psbt, valid_value)
self._throws(func, psbt, 0, null_value)
self._try_invalid(func, psbt, valid_value)

def _try_get_set_i(self, setfn, clearfn, getfn, psbt, valid_value, invalid_value=None, mandatory=False):
self._try_invalid(setfn, psbt, valid_value)
Expand Down Expand Up @@ -152,11 +161,11 @@ def test_psbt(self):
self.assertEqual(hex_from_bytes(psbt_bytes),
hex_from_bytes(psbt_to_bytes(psbt_tmp, 0)))

for fn, ret in [(psbt_get_version, 0 if p == psbt else 2),
(psbt_get_num_inputs, 1),
(psbt_get_num_outputs, 1)]:
self.assertEqual(fn(p), ret)
self._throws(fn, None) # Null PSBT
for func, ret in [(psbt_get_version, 0 if p == psbt else 2),
(psbt_get_num_inputs, 1),
(psbt_get_num_outputs, 1)]:
self.assertEqual(func(p), ret)
self._throws(func, None) # Null PSBT

sample = SAMPLE if p == psbt else SAMPLE_V2

Expand Down Expand Up @@ -230,9 +239,9 @@ def test_psbt(self):
if is_elements_build():
dummy_nonce = bytearray(b'\x00' * WALLY_TX_ASSET_CT_NONCE_LEN)
dummy_bf = bytearray(b'\x00' * BLINDING_FACTOR_LEN)
dummy_asset_commitment = bytearray(b'\x0a' * ASSET_COMMITMENT_LEN)
dummy_value_commitment = bytearray(b'\x08' * WALLY_TX_ASSET_CT_VALUE_UNBLIND_LEN)
dummy_nonce_commitment = bytearray(b'\x02' * ASSET_COMMITMENT_LEN)
dummy_blind_asset = bytearray(b'\x0a' * ASSET_COMMITMENT_LEN)
dummy_blind_value = bytearray(b'\x08' * WALLY_TX_ASSET_CT_VALUE_UNBLIND_LEN)
dummy_nonce = bytearray(b'\x02' * ASSET_COMMITMENT_LEN)
dummy_asset = bytearray(b'\x00' * ASSET_TAG_LEN)
dummy_nonce = bytearray(b'\x77' * ASSET_TAG_LEN)

Expand Down Expand Up @@ -307,15 +316,9 @@ def test_psbt(self):
self._try_invalid(psbt_get_input_utxo, p)
self._try_set(psbt_set_input_witness_utxo, p, dummy_txout)
self._try_invalid(psbt_get_input_witness_utxo, p)
self._try_get_set_b(psbt_set_input_redeem_script,
psbt_get_input_redeem_script,
psbt_get_input_redeem_script_len, p, dummy_bytes)
self._try_get_set_b(psbt_set_input_witness_script,
psbt_get_input_witness_script,
psbt_get_input_witness_script_len, p, dummy_bytes)
self._try_get_set_b(psbt_set_input_final_scriptsig,
psbt_get_input_final_scriptsig,
psbt_get_input_final_scriptsig_len, p, dummy_bytes)
for field in ['redeem_script', 'witness_script', 'final_scriptsig']:
setfn, getfn, lenfn, hasfn, clearfn = accessors('input', field)
self._try_get_set_b(setfn, getfn, lenfn, p, dummy_bytes)
self._try_set(psbt_set_input_final_witness, p, dummy_witness)
self._try_invalid(psbt_get_input_final_witness, p)
self._try_get_set_m(psbt_set_input_keypaths,
Expand Down Expand Up @@ -429,71 +432,36 @@ def test_psbt(self):
psbt_set_input_inflation_keys(pset2, 0, 0)

cases = [
# PSET: blinded issuance amount (issuance amount commitment)
(psbt_set_input_issuance_amount_commitment,
psbt_get_input_issuance_amount_commitment,
psbt_clear_input_issuance_amount_commitment, dummy_value_commitment, dummy_asset_commitment),
# PSET: blinded issuance amount rangeproof
(psbt_set_input_issuance_amount_rangeproof, psbt_get_input_issuance_amount_rangeproof,
psbt_clear_input_issuance_amount_rangeproof, dummy_bytes, None),
# PSET: issuance blinding nonce
(psbt_set_input_issuance_blinding_nonce,
psbt_get_input_issuance_blinding_nonce,
psbt_clear_input_issuance_blinding_nonce, dummy_nonce, dummy_nonce_commitment),
# PSET: issuance blinding entropy
(psbt_set_input_issuance_asset_entropy,
psbt_get_input_issuance_asset_entropy,
psbt_clear_input_issuance_asset_entropy, dummy_nonce, dummy_asset_commitment),
# PSET: blinded issuance amount value rangeproof
# (Confusing: this proves the blinded issuance amount matches
# the unblinded amount, for constructors/blinders use)
(psbt_set_input_issuance_amount_blinding_rangeproof,
psbt_get_input_issuance_amount_blinding_rangeproof,
psbt_clear_input_issuance_amount_blinding_rangeproof, dummy_bytes, None),
# PSET: peg-in claim script
(psbt_set_input_pegin_claim_script, psbt_get_input_pegin_claim_script,
psbt_clear_input_pegin_claim_script, dummy_bytes, None),
# PSET: peg-in genesis blockhash
(psbt_set_input_pegin_genesis_blockhash, psbt_get_input_pegin_genesis_blockhash,
psbt_clear_input_pegin_genesis_blockhash, dummy_txid, dummy_asset_commitment),
# PSET: peg-in txout proof
(psbt_set_input_pegin_txout_proof, psbt_get_input_pegin_txout_proof,
psbt_clear_input_pegin_txout_proof, dummy_bytes, None),
# PSET: blinded number of inflation keys (issuance keys commitment)
(psbt_set_input_inflation_keys_commitment, psbt_get_input_inflation_keys_commitment,
psbt_clear_input_inflation_keys_commitment, dummy_value_commitment, dummy_asset_commitment),
# PSET: blinded inflation keys rangeproof
(psbt_set_input_inflation_keys_rangeproof, psbt_get_input_inflation_keys_rangeproof,
psbt_clear_input_inflation_keys_rangeproof, dummy_bytes, None),
# PSET: blidned inflation keys value rangeproof
# (Confusing: this proves the number of blinded reissuance tokens
# matches the unblinded number, for constructors/blinders use)
(psbt_set_input_inflation_keys_blinding_rangeproof,
psbt_get_input_inflation_keys_blinding_rangeproof,
psbt_clear_input_inflation_keys_blinding_rangeproof, dummy_bytes, None),
# PSET: utxo rangeproof
(psbt_set_input_utxo_rangeproof, psbt_get_input_utxo_rangeproof,
psbt_clear_input_utxo_rangeproof, dummy_bytes, None),
('issuance_amount_commitment', dummy_blind_value, dummy_blind_asset),
('issuance_amount_rangeproof', dummy_bytes, None),
('issuance_blinding_nonce', dummy_nonce, dummy_nonce),
('issuance_asset_entropy', dummy_nonce, dummy_blind_asset),
('issuance_amount_blinding_rangeproof', dummy_bytes, None),
('pegin_claim_script', dummy_bytes, None),
('pegin_genesis_blockhash', dummy_txid, dummy_blind_asset),
('pegin_txout_proof', dummy_bytes, None),
('inflation_keys_commitment', dummy_blind_value, dummy_blind_asset),
('inflation_keys_rangeproof', dummy_bytes, None),
('inflation_keys_blinding_rangeproof', dummy_bytes, None),
('utxo_rangeproof', dummy_bytes, None),
]
for setfn, getfn, clearfn, valid_value, invalid_value in cases:
for field, valid_value, invalid_value in cases:
setfn, getfn, lenfn, hasfn, clearfn = accessors('input', field)

self._throws(setfn, psbt, 0, valid_value) # Non v2 PSBT
if invalid_value:
self._throws(setfn, psbt, 0, invalid_value) # Invalid value
self._throws(getfn, psbt, 0) # Non v2 PSBT
self._throws(getfn, psbt, 0) # Non v2 PSBT
self._throws(clearfn, psbt, 0) # Non v2 PSBT
self._try_get_set_b(setfn, getfn, clearfn, pset2, valid_value)
for func in getfn, lenfn, clearfn:
self._throws(func, psbt, 0) # Non v2 PSBT
self._try_get_set_b(setfn, getfn, lenfn, pset2, valid_value)

#
# Outputs
#
for p in [psbt, psbt2]:
self._try_get_set_b(psbt_set_output_redeem_script,
psbt_get_output_redeem_script,
psbt_get_output_redeem_script_len, p, dummy_bytes)
self._try_get_set_b(psbt_set_output_witness_script,
psbt_get_output_witness_script,
psbt_get_output_witness_script_len, p, dummy_bytes)
for field in ['redeem_script', 'witness_script']:
setfn, getfn, lenfn, hasfn, clearfn = accessors('output', field)
self._try_get_set_b(setfn, getfn, lenfn, p, dummy_bytes)
self._try_get_set_m(psbt_set_output_keypaths,
psbt_get_output_keypaths_size,
psbt_get_output_keypath_len,
Expand Down Expand Up @@ -542,57 +510,45 @@ def test_psbt(self):
self._try_get_set_i(setfn, None, getfn, pset2, 1234)

cases = [
# PSET: blinded issuance amount (issuance amount commitment)
(psbt_set_output_value_commitment, psbt_get_output_value_commitment,
psbt_clear_output_value_commitment, dummy_value_commitment, dummy_asset_commitment),
(psbt_set_output_asset, psbt_get_output_asset,
psbt_clear_output_asset, dummy_asset, dummy_asset_commitment),
(psbt_set_output_asset_commitment, psbt_get_output_asset_commitment,
psbt_clear_output_asset_commitment, dummy_asset_commitment, dummy_value_commitment),
(psbt_set_output_value_rangeproof, psbt_get_output_value_rangeproof,
psbt_clear_output_value_rangeproof, dummy_bytes, None),
(psbt_set_output_asset_surjectionproof,
psbt_get_output_asset_surjectionproof,
psbt_clear_output_asset_surjectionproof, dummy_bytes, None),
(psbt_set_output_blinding_public_key, psbt_get_output_blinding_public_key,
psbt_clear_output_blinding_public_key, dummy_pubkey, dummy_sig),
(psbt_set_output_ecdh_public_key, psbt_get_output_ecdh_public_key,
psbt_clear_output_ecdh_public_key, dummy_pubkey, dummy_sig),
(psbt_set_output_value_blinding_rangeproof,
psbt_get_output_value_blinding_rangeproof,
psbt_clear_output_value_blinding_rangeproof, dummy_bytes, None),
(psbt_set_output_asset_blinding_surjectionproof,
psbt_get_output_asset_blinding_surjectionproof,
psbt_clear_output_asset_blinding_surjectionproof, dummy_bytes, None),
('value_commitment', dummy_blind_value, dummy_blind_asset),
('asset', dummy_asset, dummy_blind_asset),
('asset_commitment', dummy_blind_asset, dummy_blind_value),
('value_rangeproof', dummy_bytes, None),
('asset_surjectionproof', dummy_bytes, None),
('blinding_public_key', dummy_pubkey, dummy_sig),
('ecdh_public_key', dummy_pubkey, dummy_sig),
('value_blinding_rangeproof', dummy_bytes, None),
('asset_blinding_surjectionproof', dummy_bytes, None),
]
for setfn, getfn, clearfn, valid_value, invalid_value in cases:
for field, valid_value, invalid_value in cases:
setfn, getfn, lenfn, hasfn, clearfn = accessors('output', field)

self._throws(setfn, psbt, 0, valid_value) # Non v2 PSBT
if invalid_value:
self._throws(setfn, psbt, 0, invalid_value) # Invalid value
self._throws(getfn, psbt, 0) # Non v2 PSBT
self._throws(getfn, psbt, 0) # Non v2 PSBT
self._throws(clearfn, psbt, 0) # Non v2 PSBT
is_commitment_fn = setfn in [psbt_set_output_value_commitment,
psbt_set_output_asset_commitment,
psbt_set_output_value_blinding_rangeproof,
psbt_set_output_asset_blinding_surjectionproof]
is_mandatory_fn = setfn in [psbt_set_output_asset]
self._try_get_set_b(setfn, getfn, clearfn, pset2, valid_value,
for func in getfn, lenfn, clearfn:
self._throws(func, psbt, 0) # Non v2 PSBT
is_commitment_fn = field in ['value_commitment',
'asset_commitment',
'value_blinding_rangeproof',
'asset_blinding_surjectionproof']
is_mandatory_fn = field in ['asset']
self._try_get_set_b(setfn, getfn, lenfn, pset2, valid_value,
mandatory=is_mandatory_fn, roundtrip=not is_commitment_fn)
if is_commitment_fn:
clearfn(pset2, 0)
else:
self._round_trip(pset2)

# Blinding status
fn = psbt_get_output_blinding_status
self._throws(fn, psbt, 0, 0) # Non v2 PSBT
self._throws(fn, psbt2, 3, 0) # Bad output index
self._throws(fn, psbt2, 0, 1) # Unknown flag
self.assertEqual(fn(psbt2, 0, 0), WALLY_PSET_BLINDED_NONE)
self.assertEqual(fn(pset2, 0, 0), WALLY_PSET_BLINDED_PARTIAL)
func = psbt_get_output_blinding_status
self._throws(func, psbt, 0, 0) # Non v2 PSBT
self._throws(func, psbt2, 3, 0) # Bad output index
self._throws(func, psbt2, 0, 1) # Unknown flag
self.assertEqual(func(psbt2, 0, 0), WALLY_PSET_BLINDED_NONE)
self.assertEqual(func(pset2, 0, 0), WALLY_PSET_BLINDED_PARTIAL)
psbt_clear_output_blinding_public_key(pset2, 0)
self.assertEqual(fn(pset2, 0, 0), WALLY_PSET_BLINDED_NONE)
self.assertEqual(func(pset2, 0, 0), WALLY_PSET_BLINDED_NONE)

if __name__ == '__main__':
unittest.main()

0 comments on commit 26aef05

Please sign in to comment.