Skip to content

Commit

Permalink
feat(xmr): add support for HF15, BP+
Browse files Browse the repository at this point in the history
  • Loading branch information
ph4r05 authored and matejcik committed May 11, 2022
1 parent 74ff406 commit 13f60c6
Show file tree
Hide file tree
Showing 15 changed files with 218 additions and 120 deletions.
19 changes: 2 additions & 17 deletions common/protob/messages-monero.proto
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ message MoneroTransactionSetInputRequest {
/**
* Response: Response to setting UTXO for signature. Contains sealed values needed for further protocol steps.
* @next MoneroTransactionSetInputAck
* @next MoneroTransactionInputsPermutationRequest
* @next MoneroTransactionInputViniRequest
*/
message MoneroTransactionSetInputAck {
optional bytes vini = 1; // xmrtypes.TxinToKey
Expand All @@ -183,21 +183,6 @@ message MoneroTransactionSetInputAck {
optional bytes spend_key = 6;
}

/**
* Request: Sub request of MoneroTransactionSign. Permutation on key images.
* @next MoneroTransactionInputsPermutationAck
*/
message MoneroTransactionInputsPermutationRequest {
repeated uint32 perm = 1;
}

/**
* Response: Response to setting permutation on key images
* @next MoneroTransactionInputViniRequest
*/
message MoneroTransactionInputsPermutationAck {
}

/**
* Request: Sub request of MoneroTransactionSign. Sends one UTXO to device together with sealed values.
* @next MoneroTransactionInputViniAck
Expand Down Expand Up @@ -327,7 +312,7 @@ message MoneroTransactionFinalAck {
optional bytes salt = 2;
optional bytes rand_mult = 3;
optional bytes tx_enc_keys = 4;
optional bytes opening_key = 5; // enc master key to decrypt MLSAGs after protocol finishes correctly
optional bytes opening_key = 5; // enc master key to decrypt CLSAGs after protocol finishes correctly
}

/**
Expand Down
2 changes: 0 additions & 2 deletions common/protob/messages.proto
Original file line number Diff line number Diff line change
Expand Up @@ -297,8 +297,6 @@ enum MessageType {
MessageType_MoneroTransactionInitAck = 502 [(wire_out) = true];
MessageType_MoneroTransactionSetInputRequest = 503 [(wire_out) = true];
MessageType_MoneroTransactionSetInputAck = 504 [(wire_out) = true];
MessageType_MoneroTransactionInputsPermutationRequest = 505 [(wire_out) = true];
MessageType_MoneroTransactionInputsPermutationAck = 506 [(wire_out) = true];
MessageType_MoneroTransactionInputViniRequest = 507 [(wire_out) = true];
MessageType_MoneroTransactionInputViniAck = 508 [(wire_out) = true];
MessageType_MoneroTransactionAllInputsSetRequest = 509 [(wire_out) = true];
Expand Down
2 changes: 2 additions & 0 deletions core/src/apps/monero/signing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ class RctType:
"""
There are several types of monero Ring Confidential Transactions
like RCTTypeFull and RCTTypeSimple but currently we use only CLSAG
and RCTTypeBulletproofPlus
"""

CLSAG = 5
RCTTypeBulletproofPlus = 6
1 change: 1 addition & 0 deletions core/src/apps/monero/signing/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def __init__(self, ctx: Context) -> None:
self.rsig_grouping: list[int] | None = []
# is range proof computing offloaded or not
self.rsig_offload: bool | None = False
self.rsig_is_bp_plus: bool | None = False

# sum of all inputs' pseudo out masks
self.sumpouts_alphas: Scalar = crypto.Scalar(0)
Expand Down
6 changes: 5 additions & 1 deletion core/src/apps/monero/signing/step_01_init_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ async def init_transaction(
state.fee = tsx_data.fee
state.account_idx = tsx_data.account
state.last_step = state.STEP_INIT
state.tx_type = signing.RctType.CLSAG
if tsx_data.hard_fork:
state.hard_fork = tsx_data.hard_fork
if state.hard_fork < 13:
Expand Down Expand Up @@ -208,6 +207,11 @@ def _check_rsig_data(state: State, rsig_data: MoneroTransactionRsigData) -> None
elif rsig_data.rsig_type not in (1, 2, 3):
raise ValueError("Unknown rsig type")

state.tx_type = signing.RctType.CLSAG
if rsig_data.bp_version == 4:
state.rsig_is_bp_plus = True
state.tx_type = signing.RctType.RCTTypeBulletproofPlus

if state.output_count > 2:
state.rsig_offload = True

Expand Down
85 changes: 67 additions & 18 deletions core/src/apps/monero/signing/step_06_set_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@

if TYPE_CHECKING:
from apps.monero.xmr.serialize_messages.tx_ecdh import EcdhTuple
from apps.monero.xmr.serialize_messages.tx_rsig_bulletproof import Bulletproof
from apps.monero.xmr.serialize_messages.tx_rsig_bulletproof import (
Bulletproof,
BulletproofPlus,
)
from trezor.messages import (
MoneroTransactionDestinationEntry,
MoneroTransactionSetOutputAck,
Expand Down Expand Up @@ -303,7 +306,7 @@ def _rsig_bp(state: State) -> bytes:
from apps.monero.xmr import range_signatures

rsig = range_signatures.prove_range_bp_batch(
state.output_amounts, state.output_masks
state.output_amounts, state.output_masks, state.rsig_is_bp_plus
)
state.mem_trace("post-bp" if __debug__ else None, collect=True)

Expand All @@ -314,7 +317,7 @@ def _rsig_bp(state: State) -> bytes:
state.full_message_hasher.rsig_val(rsig, raw=False)
state.mem_trace("post-bp-hash" if __debug__ else None, collect=True)

rsig = _dump_rsig_bp(rsig)
rsig = _dump_rsig_bp_plus(rsig) if state.rsig_is_bp_plus else _dump_rsig_bp(rsig)
state.mem_trace(
f"post-bp-ser, size: {len(rsig)}" if __debug__ else None, collect=True
)
Expand All @@ -327,9 +330,15 @@ def _rsig_bp(state: State) -> bytes:

def _rsig_process_bp(state: State, rsig_data: MoneroTransactionRsigData):
from apps.monero.xmr import range_signatures
from apps.monero.xmr.serialize_messages.tx_rsig_bulletproof import Bulletproof
from apps.monero.xmr.serialize_messages.tx_rsig_bulletproof import (
Bulletproof,
BulletproofPlus,
)

bp_obj = serialize.parse_msg(rsig_data.rsig, Bulletproof)
if state.rsig_is_bp_plus:
bp_obj = serialize.parse_msg(rsig_data.rsig, BulletproofPlus)
else:
bp_obj = serialize.parse_msg(rsig_data.rsig, Bulletproof)
rsig_data.rsig = None

# BP is hashed with raw=False as hash does not contain L, R
Expand Down Expand Up @@ -366,8 +375,45 @@ def _dump_rsig_bp(rsig: Bulletproof) -> bytes:
utils.memcpy(buff, 32 * 4, rsig.taux, 0, 32)
utils.memcpy(buff, 32 * 5, rsig.mu, 0, 32)

buff[32 * 6] = len(rsig.L)
offset = 32 * 6 + 1
offset = _dump_rsig_lr(buff, 32 * 6, rsig)

utils.memcpy(buff, offset, rsig.a, 0, 32)
offset += 32
utils.memcpy(buff, offset, rsig.b, 0, 32)
offset += 32
utils.memcpy(buff, offset, rsig.t, 0, 32)
return buff


def _dump_rsig_bp_plus(rsig: BulletproofPlus) -> bytes:
if len(rsig.L) > 127:
raise ValueError("Too large")

# Manual serialization as the generic purpose serialize.dump_msg_gc
# is more memory intensive which is not desired in the range proof section.

# BP: "V", "A", "A1", "B", "r1", "s1", "d1", "V", "L", "R"
# Commitment vector V is not serialized
# Vector size under 127 thus varint occupies 1 B
buff_size = 32 * (6 + 2 * (len(rsig.L))) + 2
buff = bytearray(buff_size)

utils.memcpy(buff, 0, rsig.A, 0, 32)
utils.memcpy(buff, 32, rsig.A1, 0, 32)
utils.memcpy(buff, 32 * 2, rsig.B, 0, 32)
utils.memcpy(buff, 32 * 3, rsig.r1, 0, 32)
utils.memcpy(buff, 32 * 4, rsig.s1, 0, 32)
utils.memcpy(buff, 32 * 5, rsig.d1, 0, 32)

_dump_rsig_lr(buff, 32 * 6, rsig)
return buff


def _dump_rsig_lr(
buff: bytearray, offset: int, rsig: Bulletproof | BulletproofPlus
) -> int:
buff[offset] = len(rsig.L)
offset += 1

for x in rsig.L:
utils.memcpy(buff, offset, x, 0, 32)
Expand All @@ -379,18 +425,12 @@ def _dump_rsig_bp(rsig: Bulletproof) -> bytes:
for x in rsig.R:
utils.memcpy(buff, offset, x, 0, 32)
offset += 32

utils.memcpy(buff, offset, rsig.a, 0, 32)
offset += 32
utils.memcpy(buff, offset, rsig.b, 0, 32)
offset += 32
utils.memcpy(buff, offset, rsig.t, 0, 32)
return buff
return offset


def _return_rsig_data(
rsig: bytes | None = None, mask: bytes | None = None
) -> MoneroTransactionRsigData:
) -> MoneroTransactionRsigData | None:
if rsig is None and mask is None:
return None

Expand Down Expand Up @@ -420,9 +460,18 @@ def _get_ecdh_info_and_out_pk(
so the recipient is able to reconstruct the commitment.
"""
out_pk_dest = crypto_helpers.encodepoint(tx_out_key)
out_pk_commitment = crypto_helpers.encodepoint(
crypto.gen_commitment_into(None, mask, amount)
)
if state.rsig_is_bp_plus:
# HF15+ stores commitment multiplied by 8**-1
inv8 = crypto.decodeint_into_noreduce(None, crypto_helpers.INV_EIGHT)
mask8 = crypto.sc_mul_into(None, mask, inv8)
amnt8 = crypto.Scalar(amount)
amnt8 = crypto.sc_mul_into(amnt8, amnt8, inv8)
out_pk_commitment = crypto.add_keys2_into(None, mask8, amnt8, crypto.xmr_H())
del (inv8, mask8, amnt8)
else:
out_pk_commitment = crypto.gen_commitment_into(None, mask, amount)

out_pk_commitment = crypto_helpers.encodepoint(out_pk_commitment)
crypto.sc_add_into(state.sumout, state.sumout, mask)
ecdh_info = _ecdh_encode(amount, crypto_helpers.encodeint(amount_key))

Expand Down
108 changes: 95 additions & 13 deletions core/src/apps/monero/xmr/bulletproof.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def _invert_batch(x):

for i in range(len(x) - 1, -1, -1):
_sc_mul(tmp, acc, x[i])
_sc_mul(x[i], acc, scratch[i])
x[i] = _sc_mul(x[i], acc, scratch[i])
memcpy(acc, 0, tmp, 0, 32)
return x

Expand Down Expand Up @@ -1161,6 +1161,55 @@ def __getitem__(self, item):
return self.cur_sc if self.raw else self.cur


class KeyChallengeCacheVct(KeyVBase):
"""
Challenge cache vector for BP+ verification
More on this in the verification code, near "challenge_cache" definition
"""

__slots__ = (
"nbits",
"ch_",
"chi",
"precomp",
"precomp_depth",
"cur",
)

def __init__(
self, nbits: int, ch_: KeyVBase, chi: KeyVBase, precomputed: KeyVBase | None
):
super().__init__(1 << nbits)
self.nbits = nbits
self.ch_ = ch_
self.chi = chi
self.precomp = precomputed
self.precomp_depth = 0
self.cur = _ensure_dst_key()
if not precomputed:
return

while (1 << self.precomp_depth) < len(precomputed):
self.precomp_depth += 1

def __getitem__(self, item):
i = self.idxize(item)
bits_done = 0

if self.precomp:
_copy_key(self.cur, self.precomp[i >> (self.nbits - self.precomp_depth)])
bits_done += self.precomp_depth
else:
_copy_key(self.cur, _ONE)

for j in range(self.nbits - 1, bits_done - 1, -1):
if i & (1 << (self.nbits - 1 - j)) > 0:
_sc_mul(self.cur, self.cur, self.ch_[j])
else:
_sc_mul(self.cur, self.cur, self.chi[j])
return self.cur


class KeyR0(KeyVBase):
"""
Vector r0. Allows only sequential access (no jumping). Resets on [0,1] access.
Expand Down Expand Up @@ -2736,6 +2785,9 @@ def vec_sc_fnc(i, d):

return BulletproofPlus(V=V, A=A, A1=A1, B=B, r1=r1, s1=s1, d1=d1, L=L, R=R)

def verify(self, proof: BulletproofPlus) -> bool:
return self.verify_batch([proof])

def verify_batch(self, proofs: list[BulletproofPlus]):
"""
BP+ batch verification
Expand Down Expand Up @@ -2763,7 +2815,8 @@ def verify_batch(self, proofs: list[BulletproofPlus]):
max_logm = 0

proof_data = []
to_invert = []
to_invert_offset = 0
to_invert = _ensure_dst_keyvect(None, 11 * len(proofs))
for proof in proofs:
max_length = max(max_length, len(proof.L))
nV += len(proof.V)
Expand Down Expand Up @@ -2806,11 +2859,17 @@ def verify_batch(self, proofs: list[BulletproofPlus]):
# batch scalar inversions
pd.inv_offset = inv_offset
for j in range(rounds): # max rounds is 10 = lg(16*64) = lg(1024)
to_invert.append(bytearray(pd.challenges[j]))
to_invert.append(bytearray(pd.y))
to_invert.read(to_invert_offset, pd.challenges[j])
to_invert_offset += 1

to_invert.read(to_invert_offset, pd.y)
to_invert_offset += 1
inv_offset += rounds + 1
self.gc(2)

to_invert.resize(inv_offset)
self.gc(2)

utils.ensure(max_length < 32, "At least one proof is too large")
maxMN = 1 << max_length
tmp2 = _ensure_dst_key()
Expand Down Expand Up @@ -2937,32 +2996,55 @@ def verify_batch(self, proofs: list[BulletproofPlus]):
yinv = inverses[pd.inv_offset + rounds]
self.gc(6)

# Compute challenge products
challenges_cache = _ensure_dst_keyvect(
None, 1 << rounds
) # [_ZERO] * (1 << rounds)
# Description of challenges_cache:
# Let define ch_[i] = pd.challenges[i] and
# chi[i] = pd.challenges[i]^{-1}
# Also define b_j[i] = i-th bit of integer j, 0 is MSB
# encoded in {rounds} bits
#
# challenges_cache[i] contains multiplication ch_ or chi depending on the b_i
# i.e., its binary representation. chi is for 0, ch_ for 1 in the b_i repr.
#
# challenges_cache[i] = \\mult_{j \in [0, rounds)} (b_i[j] * ch_[j]) +
# (1-b_i[j]) * chi[j]
# Originally, it is constructed iteratively, starting with 1 bit, 2 bits.
# We cannot afford having it all precomputed, thus we precompute it up to
# a threshold challenges_cache_depth_lim bits, the rest is evaluated on the fly
challenges_cache_depth_lim = const(8)
challenges_cache_depth = min(rounds, challenges_cache_depth_lim)
challenges_cache = _ensure_dst_keyvect(None, 1 << challenges_cache_depth)

challenges_cache[0] = inverses[pd.inv_offset]
challenges_cache[1] = pd.challenges[0]
for j in range(1, rounds):

for j in range(1, challenges_cache_depth):
slots = 1 << (j + 1)
for s in range(slots - 1, -1, -2):
challenges_cache.read(
s,
_sc_mul(
challenges_cache[s],
_tmp_bf_0,
challenges_cache[s // 2],
pd.challenges[j],
pd.challenges[j], # even s
),
)
challenges_cache.read(
s - 1,
_sc_mul(
challenges_cache[s - 1],
_tmp_bf_0,
challenges_cache[s // 2],
inverses[pd.inv_offset + j],
inverses[pd.inv_offset + j], # odd s
),
)

if rounds > challenges_cache_depth:
challenges_cache = KeyChallengeCacheVct(
rounds,
pd.challenges,
inverses.slice_view(pd.inv_offset, pd.inv_offset + rounds + 1),
challenges_cache,
)

# Gi and Hi
self.gc(7)
e_r1_w_y = _ensure_dst_key()
Expand Down
Loading

0 comments on commit 13f60c6

Please sign in to comment.