Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Basic taproot BIP32 derivation & keyspend support #33

Merged
152 changes: 126 additions & 26 deletions src/embit/psbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .script import Script, Witness
from . import script
from .base import EmbitBase, EmbitError

from binascii import b2a_base64, a2b_base64, hexlify, unhexlify
from io import BytesIO

Expand Down Expand Up @@ -133,6 +134,11 @@ def __init__(self, unknown: dict = {}, vin=None, compress=CompressMode.KEEP_ALL)
self.redeem_script = None
self.witness_script = None
self.bip32_derivations = OrderedDict()

# tuples of ([leaf_hashes], DerivationPath)
self.taproot_bip32_derivations = OrderedDict()
self.taproot_internal_key = None

self.final_scriptsig = None
self.final_scriptwitness = None
self.parse_unknowns()
Expand All @@ -152,6 +158,8 @@ def clear_metadata(self, compress=CompressMode.CLEAR_ALL):
if self.witness_utxo is not None:
self.non_witness_utxo = None
self.bip32_derivations = OrderedDict()
self.taproot_bip32_derivations = OrderedDict()
self.taproot_internal_key = None

def update(self, other):
self.txid = other.txid or self.txid
Expand All @@ -166,6 +174,8 @@ def update(self, other):
self.redeem_script = other.redeem_script or self.redeem_script
self.witness_script = other.witness_script or self.witness_script
self.bip32_derivations.update(other.bip32_derivations)
self.taproot_bip32_derivations.update(other.taproot_bip32_derivations)
self.taproot_internal_key = other.taproot_internal_key
self.final_scriptsig = other.final_scriptsig or self.final_scriptsig
self.final_scriptwitness = other.final_scriptwitness or self.final_scriptwitness

Expand Down Expand Up @@ -229,7 +239,9 @@ def read_value(self, stream, k):
tx = self.TX_CLS.read_from(stream)
self.non_witness_utxo = tx
return

v = read_string(stream)

# witness utxo
if k[0] == 0x01:
if len(k) != 1:
Expand Down Expand Up @@ -274,13 +286,15 @@ def read_value(self, stream, k):
self.witness_script = Script(v)
else:
raise PSBTError("Duplicated witness script")
# bip32 derivation

# PSBT_IN_BIP32_DERIVATION
elif k[0] == 0x06:
pub = ec.PublicKey.parse(k[1:])
if pub in self.bip32_derivations:
raise PSBTError("Duplicated derivation path")
else:
self.bip32_derivations[pub] = DerivationPath.parse(v)

# final scriptsig
elif k[0] == 0x07:
# we don't need this key for signing
Expand All @@ -303,12 +317,30 @@ def read_value(self, stream, k):
self.final_scriptwitness = Witness.parse(v)
else:
raise PSBTError("Duplicated final scriptwitness")

elif k == b"\x0e":
self.txid = bytes(reversed(v))
elif k == b"\x0f":
self.vout = int.from_bytes(v, 'little')
elif k == b"\x10":
self.sequence = int.from_bytes(v, 'little')

# PSBT_IN_TAP_BIP32_DERIVATION
elif k[0] == 0x16:
pub = ec.PublicKey.from_xonly(k[1:])
if pub not in self.taproot_bip32_derivations:
# Field begins with the number of leaf hashes; for now only support the
# internal key where there are no leaf hashes.
# TODO: Support keysigns from leaves within the taptree.
if v[0] > 0:
return
leaf_hashes = [] # TODO: Actually parse leaf hashes, if present
self.taproot_bip32_derivations[pub] = (leaf_hashes, DerivationPath.parse(v[1:]))

# PSBT_IN_TAP_INTERNAL_KEY
elif k[0] == 0x17:
self.taproot_internal_key = ec.PublicKey.from_xonly(v)

else:
if k in self.unknown:
raise PSBTError("Duplicated key")
Expand Down Expand Up @@ -343,6 +375,7 @@ def write_to(self, stream, skip_separator=False, version=None, **kwargs) -> int:
if self.final_scriptwitness is not None:
r += stream.write(b"\x01\x08")
r += ser_string(stream, self.final_scriptwitness.serialize())

if version == 2:
if self.txid is not None:
r += ser_string(stream, b"\x0e")
Expand All @@ -353,6 +386,18 @@ def write_to(self, stream, skip_separator=False, version=None, **kwargs) -> int:
if self.sequence is not None:
r += ser_string(stream, b"\x10")
r += ser_string(stream, self.sequence.to_bytes(4, 'little'))

# PSBT_IN_TAP_BIP32_DERIVATION
for pub in self.taproot_bip32_derivations:
r += ser_string(stream, b"\x16" + pub.xonly())
leaf_hashes, derivation = self.taproot_bip32_derivations[pub]
r += ser_string(stream, len(leaf_hashes).to_bytes(1, 'little') + derivation.serialize())

# PSBT_IN_TAP_INTERNAL_KEY
if self.taproot_internal_key is not None:
r += ser_string(stream, b"\x17")
r += ser_string(stream, self.taproot_internal_key.xonly())

# unknown
for key in self.unknown:
r += ser_string(stream, key)
Expand All @@ -375,6 +420,8 @@ def __init__(self, unknown: dict = {}, vout=None, compress=CompressMode.KEEP_ALL
self.redeem_script = None
self.witness_script = None
self.bip32_derivations = OrderedDict()
self.taproot_bip32_derivations = OrderedDict()
self.taproot_internal_key = None
self.parse_unknowns()

def clear_metadata(self, compress=CompressMode.CLEAR_ALL):
Expand All @@ -385,6 +432,8 @@ def clear_metadata(self, compress=CompressMode.CLEAR_ALL):
self.redeem_script = None
self.witness_script = None
self.bip32_derivations = OrderedDict()
self.taproot_bip32_derivations = OrderedDict()
self.taproot_internal_key = None

def update(self, other):
self.value = other.value if other.value is not None else self.value
Expand All @@ -393,6 +442,8 @@ def update(self, other):
self.redeem_script = other.redeem_script or self.redeem_script
self.witness_script = other.witness_script or self.witness_script
self.bip32_derivations.update(other.bip32_derivations)
self.taproot_bip32_derivations.update(other.bip32_derivations)
self.taproot_internal_key = other.taproot_internal_key

@property
def vout(self):
Expand All @@ -402,7 +453,9 @@ def read_value(self, stream, k):
# separator
if len(k) == 0:
return

v = read_string(stream)

# redeem script
if k[0] == 0x00:
if len(k) != 1:
Expand All @@ -426,15 +479,34 @@ def read_value(self, stream, k):
raise PSBTError("Duplicated derivation path")
else:
self.bip32_derivations[pub] = DerivationPath.parse(v)

elif k == b"\x03":
self.value = int.from_bytes(v, 'little')
elif k == b"\x04":
self.script_pubkey = Script(v)

# PSBT_OUT_TAP_INTERNAL_KEY
elif k[0] == 0x05:
self.taproot_internal_key = ec.PublicKey.from_xonly(v)

# PSBT_OUT_TAP_BIP32_DERIVATION
elif k[0] == 0x07:
pub = ec.PublicKey.from_xonly(k[1:])
if pub not in self.taproot_bip32_derivations:
# Field begins with the number of leaf hashes; for now only support the
# internal key where there are no leaf hashes.
# TODO: Support keysigns from leaves within the taptree.
if v[0] > 0:
return
leaf_hashes = [] # TODO: Actually parse leaf hashes, if present
self.taproot_bip32_derivations[pub] = (leaf_hashes, DerivationPath.parse(v[1:]))

else:
if k in self.unknown:
raise PSBTError("Duplicated key")
self.unknown[k] = v


def write_to(self, stream, skip_separator=False, version=None, **kwargs) -> int:
r = 0
if self.redeem_script is not None:
Expand All @@ -446,6 +518,7 @@ def write_to(self, stream, skip_separator=False, version=None, **kwargs) -> int:
for pub in self.bip32_derivations:
r += ser_string(stream, b"\x02" + pub.serialize())
r += ser_string(stream, self.bip32_derivations[pub].serialize())

if version == 2:
if self.value is not None:
r += ser_string(stream, b"\x03")
Expand All @@ -454,6 +527,17 @@ def write_to(self, stream, skip_separator=False, version=None, **kwargs) -> int:
r += ser_string(stream, b"\x04")
r += self.script_pubkey.write_to(stream)

# PSBT_OUT_TAP_INTERNAL_KEY
if self.taproot_internal_key is not None:
r += ser_string(stream, b"\x05")
r += ser_string(stream, self.taproot_internal_key.xonly())

# PSBT_OUT_TAP_BIP32_DERIVATION
for pub in self.taproot_bip32_derivations:
r += ser_string(stream, b"\x07" + pub.xonly())
leaf_hashes, derivation = self.taproot_bip32_derivations[pub]
r += ser_string(stream, len(leaf_hashes).to_bytes(1, 'little') + derivation.serialize())

# unknown
for key in self.unknown:
r += ser_string(stream, key)
Expand Down Expand Up @@ -753,32 +837,48 @@ def sign_with(self, root, sighash=SIGHASH.DEFAULT) -> int:
counter += 1
# if we use HDKey
else:
# TODO: add taproot derivation paths and scripts
bip32_derivations = []
for pub in inp.taproot_bip32_derivations:
leaf_hashes, derivation = inp.taproot_bip32_derivations[pub]
if derivation.fingerprint == fingerprint:
bip32_derivations.append((pub, derivation))

# "Legacy" support for workaround when BIP-371 Taproot psbt fields aren't available.
# TODO: Remove this (and refactor above) when workaround has been phased out.
for pub in inp.bip32_derivations:
# check if it is root key
if inp.bip32_derivations[pub].fingerprint == fingerprint:
der = inp.bip32_derivations[pub].derivation
if hasattr(root, "origin"):
# for descriptor key remove origin part
if root.origin:
if root.origin.derivation != der[:len(root.origin.derivation)]:
continue
der = der[len(root.origin.derivation):]
hdkey = root.key.derive(der)
else:
hdkey = root.derive(der)
mypub = hdkey.key.get_public_key()
if mypub != pub:
raise PSBTError("Derivation path doesn't look right")
pk = hdkey.taproot_tweak(b"")
if pk.xonly() in sc.data:
sig = pk.schnorr_sign(h)
# sig plus sighash flag
wit = sig.serialize()
if inp_sighash != SIGHASH.DEFAULT:
wit += bytes([inp_sighash])
inp.final_scriptwitness = Witness([wit])
counter += 1
derivation = inp.bip32_derivations[pub]
if derivation.fingerprint == fingerprint:
bip32_derivations.append((pub, derivation))

for pub, derivation in bip32_derivations:
der = derivation.derivation
if hasattr(root, "origin"):
# for descriptor key remove origin part
if root.origin:
if root.origin.derivation != der[:len(root.origin.derivation)]:
continue
der = der[len(root.origin.derivation):]
hdkey = root.key.derive(der)
else:
hdkey = root.derive(der)

# Taproot BIP32 derivations use X-only pubkeys
xonly_pub = hdkey.key.xonly()
mypub = ec.PublicKey.from_xonly(xonly_pub)

if mypub != pub:
raise PSBTError("Derivation path doesn't look right")

# TODO: Support signing for keys within leaves
pk = hdkey.taproot_tweak(b"")
if pk.xonly() in sc.data:
sig = pk.schnorr_sign(h)
# sig plus sighash flag
wit = sig.serialize()
if inp_sighash != SIGHASH.DEFAULT:
wit += bytes([inp_sighash])
inp.final_scriptwitness = Witness([wit])
counter += 1
continue

# if we have individual private key
Expand Down
45 changes: 29 additions & 16 deletions src/embit/psbtview.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,23 +598,36 @@ def sign_input(self, i, root, sig_stream, sighash=SIGHASH.DEFAULT, extra_scope_d
counter = 1
# if we use HDKey
else:
# TODO: add taproot derivation paths and scripts
bip32_derivations = []
for pub in inp.taproot_bip32_derivations:
num_leaf_hashes, leaf_hashes, derivation = inp.taproot_bip32_derivations[pub]
if derivation.fingerprint == fingerprint:
bip32_derivations.append((pub, derivation))

# "Legacy" support for workaround when BIP-371 Taproot psbt fields aren't available
for pub in inp.bip32_derivations:
# check if it is root key
if inp.bip32_derivations[pub].fingerprint == fingerprint:
hdkey = root.derive(inp.bip32_derivations[pub].derivation)
mypub = hdkey.key.get_public_key()
if mypub != pub:
raise PSBTError("Derivation path doesn't look right")
pk = hdkey.taproot_tweak(b"")
if pk.xonly() in sc.data:
sig = pk.schnorr_sign(h)
# sig plus sighash flag
wit = sig.serialize()
if inp_sighash != SIGHASH.DEFAULT:
wit += bytes([inp_sighash])
inp.final_scriptwitness = Witness([wit])
counter = 1
derivation = inp.bip32_derivations[pub]
if derivation.fingerprint == fingerprint:
bip32_derivations.append((pub, derivation))

for pub, derivation in bip32_derivations:
hdkey = root.derive(derivation.derivation)

# Taproot BIP32 derivations use X-only pubkeys
xonly_pub = hdkey.key.xonly()
mypub = ec.PublicKey.from_xonly(xonly_pub)

if mypub != pub:
raise PSBTError("Derivation path doesn't look right")
pk = hdkey.taproot_tweak(b"")
if pk.xonly() in sc.data:
sig = pk.schnorr_sign(h)
# sig plus sighash flag
wit = sig.serialize()
if inp_sighash != SIGHASH.DEFAULT:
wit += bytes([inp_sighash])
inp.final_scriptwitness = Witness([wit])
counter = 1
if counter:
ser_string(sig_stream, b"\x08")
ser_string(sig_stream, inp.final_scriptwitness.serialize())
Expand Down
4 changes: 2 additions & 2 deletions tests/tests/test_psbtview.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
"cHNidP8BAgQCAAAAAQMEaAAAAAEEAQIBBQECAfsEAgAAAAABAIkCAAAAAZtxoIrIK2VWr0F9+yJWW+u4lrJ7Jt8a7vcqyTehjnHhAQAAAAD+////AkBLTAAAAAAAIgAgEmbKGKSfdnRULLw/X2F3vsCX8kCU1Q5MHsyqnemip2IaWG0pAQAAACJRIO/5p6DvLGW4bS1rGXEN90/F4WYvXsAZ6lcDOu7r3vMsAAAAAAEBK0BLTAAAAAAAIgAgEmbKGKSfdnRULLw/X2F3vsCX8kCU1Q5MHsyqnemip2IBBUdRIQI0d9dRmoMchemObZucs3XuW+t9B5czDD26ejPpXp6deyEDaJVy4osP7anWmBwCN9nl3tv+wW3nFD9oCgLtXRWfdYdSriIGAjR311GagxyF6Y5tm5yzde5b630HlzMMPbp6M+lenp17HCYUvcQwAACAAQAAgAAAAIACAACAAAAAAAEAAAAiBgNolXLiiw/tqdaYHAI32eXe2/7BbecUP2gKAu1dFZ91hxxzxdoKMAAAgAEAAIAAAACAAgAAgAAAAAABAAAAAQ4grCdq+90cjg1s4G1WH2DAMxoSKEQQMdkZC+0i4NDCxf8BDwQAAAAAARAE/f///wABAIkCAAAAAZa1AcGuoTE/hvyx2z6J8nCQXx/w95NYLDJz/86ofJ8nAQAAAAD+////Ak1Chu0AAAAAIlEg2+nu/5cRFoR+h6D1Hr+2KDi4AXhSqzDAOcP+U73sm3yAlpgAAAAAACIAIFFGwzWLmusy+RnSQwoKbc3sAGDkWif+iApu0bBCGNTuAAAAAAEBK4CWmAAAAAAAIgAgUUbDNYua6zL5GdJDCgptzewAYORaJ/6ICm7RsEIY1O4BBUdRIQMLkO0uhrrX8qT+l2m7QX17qcqhEkgH2/s2Lfvutl5+ASED+7FQTEpqQ7FapyTEsb8sa40VIED+lpRG1BU0dIRy8glSriIGAwuQ7S6GutfypP6XabtBfXupyqESSAfb+zYt++62Xn4BHHPF2gowAACAAQAAgAAAAIACAACAAAAAAAAAAAAiBgP7sVBMSmpDsVqnJMSxvyxrjRUgQP6WlEbUFTR0hHLyCRwmFL3EMAAAgAEAAIAAAACAAgAAgAAAAAAAAAAAAQ4gYfVl48Yz2drXieVIad5Y7Wi0tSRJtFHrQlEnngrZ3QkBDwQBAAAAARAE/f///wABAUdRIQOgfTvgutY8gDXSHJe0EIkNPToZ0uQDr7P8/GgmqiY8diED41V+tIWahSqtEAq6IPVKeX7RITmu5WdtwNFRHa8RujRSriICA6B9O+C61jyANdIcl7QQiQ09OhnS5AOvs/z8aCaqJjx2HHPF2gowAACAAQAAgAAAAIACAACAAQAAAAAAAAAiAgPjVX60hZqFKq0QCrog9Up5ftEhOa7lZ23A0VEdrxG6NBwmFL3EMAAAgAEAAIAAAACAAgAAgAEAAAAAAAAAAQMIrMUtAAAAAAABBCIAIN+3uFj/eYJpVvPHmKjOj9adl8+SNgs7Tk9+G2HXP2gUAAEDCAAbtwAAAAAAAQQWABTQxKPvCemXtumeOX5Rj+PkGhGMoQA=",
]

class PSBTTest(TestCase):
class PSBTViewTest(TestCase):
def test_scopes(self):
"""Tests that PSBT and PSBTView result in the same scopes and other constants"""
for compress in [CompressMode.KEEP_ALL, CompressMode.CLEAR_ALL, CompressMode.PARTIAL]:
Expand Down Expand Up @@ -65,7 +65,7 @@ def test_sign(self):
"""Test if we can sign psbtview and get the same as from signing psbt"""
for compress in [CompressMode.KEEP_ALL, CompressMode.CLEAR_ALL, CompressMode.PARTIAL]:
for b64 in PSBTS:
psbt = PSBT.from_string(b64)
psbt = PSBT.from_string(b64, compress=compress)
stream = BytesIO(a2b_base64(b64))
psbtv = PSBTView.view(stream, compress=compress)

Expand Down
Loading