Skip to content

Commit

Permalink
Simplify taproot_bip32_derivations
Browse files Browse the repository at this point in the history
  • Loading branch information
kdmukai committed Aug 29, 2022
1 parent cf91d2f commit 6346888
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 26 deletions.
40 changes: 17 additions & 23 deletions src/embit/psbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def __init__(self, unknown: dict = {}, vin=None, compress=CompressMode.KEEP_ALL)
self.witness_script = None
self.bip32_derivations = OrderedDict()

# tuples of (num_leaf_hashes, [leaf_hashes], DerivationPath)
# tuples of ([leaf_hashes], DerivationPath)
self.taproot_bip32_derivations = OrderedDict()
self.taproot_internal_key = None

Expand Down Expand Up @@ -328,17 +328,14 @@ def read_value(self, stream, k):
# PSBT_IN_TAP_BIP32_DERIVATION
elif k[0] == 0x16:
pub = ec.PublicKey.from_xonly(k[1:])
if pub in self.taproot_bip32_derivations:
raise PSBTError("Duplicated derivation path")
else:
if pub not in self.taproot_bip32_derivations:
# Field begins with the number of leaf hashes; for now only support the
# internal key where there are no leaf hashes.
# TODO: Support keysigns from leaves within the taptree.
if v[0] != 0:
raise PSBTError("Signing for public keys in leaves not yet implemented")
num_leaf_hashes = 0
leaf_hashes = None
self.taproot_bip32_derivations[pub] = (num_leaf_hashes, leaf_hashes, DerivationPath.parse(v[1:]))
if v[0] > 0:
return
leaf_hashes = [] # TODO: Actually parse leaf hashes, if present
self.taproot_bip32_derivations[pub] = (leaf_hashes, DerivationPath.parse(v[1:]))

# PSBT_IN_TAP_INTERNAL_KEY
elif k[0] == 0x17:
Expand Down Expand Up @@ -393,8 +390,8 @@ def write_to(self, stream, skip_separator=False, version=None, **kwargs) -> int:
# PSBT_IN_TAP_BIP32_DERIVATION
for pub in self.taproot_bip32_derivations:
r += ser_string(stream, b"\x16" + pub.xonly())
num_leaf_hashes, leaf_hashes, derivation = self.taproot_bip32_derivations[pub]
r += ser_string(stream, num_leaf_hashes.to_bytes(1, 'little') + derivation.serialize())
leaf_hashes, derivation = self.taproot_bip32_derivations[pub]
r += ser_string(stream, len(leaf_hashes).to_bytes(1, 'little') + derivation.serialize())

# PSBT_IN_TAP_INTERNAL_KEY
if self.taproot_internal_key is not None:
Expand Down Expand Up @@ -495,17 +492,14 @@ def read_value(self, stream, k):
# PSBT_OUT_TAP_BIP32_DERIVATION
elif k[0] == 0x07:
pub = ec.PublicKey.from_xonly(k[1:])
if pub in self.taproot_bip32_derivations:
raise PSBTError("Duplicated derivation path")
else:
if pub not in self.taproot_bip32_derivations:
# Field begins with the number of leaf hashes; for now only support the
# internal key where there are no leaf hashes.
# TODO: Support outputs to leaves within a taptree.
if v[0] != 0:
raise PSBTError("Outputs for public keys in leaves not yet implemented")
num_leaf_hashes = v[0]
leaf_hashes = None
self.taproot_bip32_derivations[pub] = (num_leaf_hashes, leaf_hashes, DerivationPath.parse(v[1:]))
# TODO: Support keysigns from leaves within the taptree.
if v[0] > 0:
return
leaf_hashes = [] # TODO: Actually parse leaf hashes, if present
self.taproot_bip32_derivations[pub] = (leaf_hashes, DerivationPath.parse(v[1:]))

else:
if k in self.unknown:
Expand Down Expand Up @@ -541,8 +535,8 @@ def write_to(self, stream, skip_separator=False, version=None, **kwargs) -> int:
# PSBT_OUT_TAP_BIP32_DERIVATION
for pub in self.taproot_bip32_derivations:
r += ser_string(stream, b"\x07" + pub.xonly())
num_leaf_hashes, leaf_hashes, derivation = self.taproot_bip32_derivations[pub]
r += ser_string(stream, num_leaf_hashes.to_bytes(1, 'little') + derivation.serialize())
leaf_hashes, derivation = self.taproot_bip32_derivations[pub]
r += ser_string(stream, len(leaf_hashes).to_bytes(1, 'little') + derivation.serialize())

# unknown
for key in self.unknown:
Expand Down Expand Up @@ -845,7 +839,7 @@ def sign_with(self, root, sighash=SIGHASH.DEFAULT) -> int:
else:
bip32_derivations = []
for pub in inp.taproot_bip32_derivations:
num_leaf_hashes, leaf_hashes, derivation = inp.taproot_bip32_derivations[pub]
leaf_hashes, derivation = inp.taproot_bip32_derivations[pub]
if derivation.fingerprint == fingerprint:
bip32_derivations.append((pub, derivation))

Expand Down
1 change: 0 additions & 1 deletion tests/tests/test_psbtview.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ def test_sign(self):
"""Test if we can sign psbtview and get the same as from signing psbt"""
for compress in [True, False]:
for b64 in PSBTS:
print(b64)
psbt = PSBT.from_string(b64, compress=compress)
stream = BytesIO(a2b_base64(b64))
psbtv = PSBTView.view(stream, compress=compress)
Expand Down
4 changes: 2 additions & 2 deletions tests/tests/test_taproot.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,10 @@ def test_taproot_internal_keyspend(self):
inp = psbt_act.inputs[0]
self.assertTrue(inp.is_taproot)

# Should have extracted: X-only pubkey, (num_leaf_hashes, leaf_hashes, DerivationPath)
# Should have extracted: X-only pubkey, ([leaf_hashes], DerivationPath)
# from `PSBT_IN_TAP_BIP32_DERIVATION`
self.assertTrue(len(inp.taproot_bip32_derivations) > 0)
for pub in inp.taproot_bip32_derivations:
num_leaf_hashes, leaf_hashes, der = inp.taproot_bip32_derivations[pub]
leaf_hashes, der = inp.taproot_bip32_derivations[pub]
self.assertTrue(der.fingerprint is not None)
self.assertTrue(der.derivation is not None)

0 comments on commit 6346888

Please sign in to comment.