From 7598f51171d8bf7c90cc2151b538ed792d663fbb Mon Sep 17 00:00:00 2001 From: yvebe <14962060+yvebe@users.noreply.github.com> Date: Thu, 9 Jan 2025 14:52:11 +0400 Subject: [PATCH] Add dkls/schnorr keysign --- .../vultisig/wallet/data/keygen/DklsHelper.kt | 28 ++ .../wallet/data/keygen/DklsKeysign.kt | 375 ++++++++++++++++++ .../wallet/data/keygen/SchnorrKeysign.kt | 375 ++++++++++++++++++ .../ui/models/keysign/JoinKeysignViewModel.kt | 3 +- .../ui/models/keysign/KeysignFlowViewModel.kt | 1 + .../ui/models/keysign/KeysignViewModel.kt | 74 +++- commondata | 2 +- .../vultisig/wallet/data/api/SessionApi.kt | 8 +- 8 files changed, 862 insertions(+), 4 deletions(-) create mode 100644 app/src/main/java/com/vultisig/wallet/data/keygen/DklsKeysign.kt create mode 100644 app/src/main/java/com/vultisig/wallet/data/keygen/SchnorrKeysign.kt diff --git a/app/src/main/java/com/vultisig/wallet/data/keygen/DklsHelper.kt b/app/src/main/java/com/vultisig/wallet/data/keygen/DklsHelper.kt index ae2723ac9..ccaf5846f 100644 --- a/app/src/main/java/com/vultisig/wallet/data/keygen/DklsHelper.kt +++ b/app/src/main/java/com/vultisig/wallet/data/keygen/DklsHelper.kt @@ -22,4 +22,32 @@ object DklsHelper { } return byteArray.toByteArray() } + + // Function to encode an integer as ASN.1 DER + fun encodeASN1Integer(value: ByteArray): ByteArray { + val encoded = mutableListOf() + encoded.add(0x02) // ASN.1 INTEGER tag + if (value.first() >= 0x80.toByte()) { + encoded.add((value.size + 1).toByte()) + encoded.add(0x00) + } else { + encoded.add(value.size.toByte()) + } + encoded.addAll(value.toList()) + return encoded.toByteArray() + } + + // Function to create a DER-encoded ECDSA signature + fun createDERSignature(r: ByteArray, s: ByteArray): ByteArray { + val encodedR = encodeASN1Integer(r) + val encodedS = encodeASN1Integer(s) + + val derSignature = mutableListOf() + derSignature.add(0x30) // ASN.1 SEQUENCE tag + derSignature.add((encodedR.size + encodedS.size).toByte()) + derSignature.addAll(encodedR.toList()) + derSignature.addAll(encodedS.toList()) + + return derSignature.toByteArray() + } } \ No newline at end of file diff --git a/app/src/main/java/com/vultisig/wallet/data/keygen/DklsKeysign.kt b/app/src/main/java/com/vultisig/wallet/data/keygen/DklsKeysign.kt new file mode 100644 index 000000000..e668625b6 --- /dev/null +++ b/app/src/main/java/com/vultisig/wallet/data/keygen/DklsKeysign.kt @@ -0,0 +1,375 @@ +@file:OptIn(ExperimentalEncodingApi::class, ExperimentalStdlibApi::class) + +package com.vultisig.wallet.data.keygen + +import com.silencelaboratories.godkls.BufferUtilJNI +import com.silencelaboratories.godkls.Handle +import com.silencelaboratories.godkls.go_slice +import com.silencelaboratories.godkls.godkls.dkls_decode_message +import com.silencelaboratories.godkls.godkls.dkls_keyshare_from_bytes +import com.silencelaboratories.godkls.godkls.dkls_keyshare_key_id +import com.silencelaboratories.godkls.godkls.dkls_sign_session_finish +import com.silencelaboratories.godkls.godkls.dkls_sign_session_from_setup +import com.silencelaboratories.godkls.godkls.dkls_sign_session_input_message +import com.silencelaboratories.godkls.godkls.dkls_sign_session_message_receiver +import com.silencelaboratories.godkls.godkls.dkls_sign_session_output_message +import com.silencelaboratories.godkls.godkls.dkls_sign_setupmsg_new +import com.silencelaboratories.godkls.godkls.tss_buffer_free +import com.silencelaboratories.godkls.lib_error +import com.silencelaboratories.godkls.lib_error.LIB_OK +import com.silencelaboratories.godkls.tss_buffer +import com.vultisig.wallet.data.api.KeysignVerify +import com.vultisig.wallet.data.api.SessionApi +import com.vultisig.wallet.data.common.md5 +import com.vultisig.wallet.data.mediator.Message +import com.vultisig.wallet.data.models.Vault +import com.vultisig.wallet.data.tss.TssMessenger +import com.vultisig.wallet.data.usecases.Encryption +import com.vultisig.wallet.data.utils.Numeric +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.delay +import kotlinx.coroutines.launch +import timber.log.Timber +import tss.KeysignResponse +import java.util.concurrent.locks.ReentrantLock +import kotlin.concurrent.withLock +import kotlin.io.encoding.Base64 +import kotlin.io.encoding.ExperimentalEncodingApi + +class DKLSKeysign( + val keysignCommittee: List, + val mediatorURL: String, + val sessionID: String, + val messageToSign: List, + val vault: Vault, + val encryptionKeyHex: String, + val chainPath: String, + val isInitiateDevice: Boolean, + + private val sessionApi: SessionApi, + private val encryption: Encryption, +) { + val localPartyID: String = vault.localPartyID + val publicKeyECDSA: String = vault.pubKeyECDSA + var messenger: TssMessenger = + TssMessenger( + serverAddress = mediatorURL, + sessionID = sessionID, + encryptionHex = encryptionKeyHex, + sessionApi = sessionApi, + coroutineScope = CoroutineScope(Dispatchers.IO), + encryption = encryption, + isEncryptionGCM = true + ) + var keysignDoneIndicator = false + val keySignLock = ReentrantLock() + val cache = mutableMapOf() + val signatures = mutableMapOf() + var keyshare: ByteArray = byteArrayOf() + + fun isKeysignDone(): Boolean = keySignLock.withLock { keysignDoneIndicator } + + fun setKeysignDone(status: Boolean) = keySignLock.withLock { keysignDoneIndicator = status } + + fun getKeyshareString(): String? { + for (ks in vault.keyshares) { + if (ks.pubKey == publicKeyECDSA) { + return ks.keyShare + } + } + return null + } + + @Throws(Exception::class) + fun getKeyshareBytes(): ByteArray { + val localKeyshare = getKeyshareString() ?: error("fail to get local keyshare") + val keyshareData = Base64.decode(localKeyshare) + return keyshareData + } + + @Throws(Exception::class) + fun getDKLSKeyshareID(): ByteArray { + val buf = tss_buffer() + try { + val keyShareBytes = getKeyshareBytes() + val keyshareSlice = keyShareBytes.toGoSlice() + val h = Handle() + val result = dkls_keyshare_from_bytes(keyshareSlice, h) + if (result != LIB_OK) { + error("fail to create keyshare handle from bytes, $result") + } + val keyIDResult = dkls_keyshare_key_id(h, buf) + if (keyIDResult != LIB_OK) { + error("fail to get key id from keyshare: $keyIDResult") + } + return BufferUtilJNI.get_bytes_from_tss_buffer(buf) + } finally { + tss_buffer_free(buf) + } + } + + @Throws(Exception::class) + fun getDKLSKeysignSetupMessage(message: String): ByteArray { + val buf = tss_buffer() + try { + val keyIdArr = getDKLSKeyshareID() + val keyIdSlice = keyIdArr.toGoSlice() + val byteArray = DklsHelper.arrayToBytes(keysignCommittee) + val ids = byteArray.toGoSlice() + val chainPathArr = chainPath.replace("'", "").toByteArray(Charsets.UTF_8) + val chainPathSlice = chainPathArr.toGoSlice() + val decodedMsgData = message.hexToByteArray() + val msgArr = decodedMsgData + val msgSlice = msgArr.toGoSlice() + val err = dkls_sign_setupmsg_new(keyIdSlice, chainPathSlice, msgSlice, ids, buf) + if (err != LIB_OK) { + error("fail to setup keysign message, dkls error: $err") + } + return BufferUtilJNI.get_bytes_from_tss_buffer(buf) + } finally { + tss_buffer_free(buf) + } + } + + @Throws(Exception::class) + fun DKLSDecodeMessage(setupMsg: ByteArray): String { + val buf = tss_buffer() + try { + val setupMsgSlice = setupMsg.toGoSlice() + val result = dkls_decode_message(setupMsgSlice, buf) + if (result != LIB_OK) { + error("fail to extract message from setup message: $result") + } + return BufferUtilJNI.get_bytes_from_tss_buffer(buf).toHexString() + } finally { + tss_buffer_free(buf) + } + } + + fun getOutboundMessageReceiver(handle: Handle, message: go_slice, idx: Long): ByteArray { + val bufReceiver = tss_buffer() + try { + val receiverResult = dkls_sign_session_message_receiver(handle, message, idx, bufReceiver) + if (receiverResult != LIB_OK) { + println("fail to get receiver message, error: $receiverResult") + return byteArrayOf() + } + return BufferUtilJNI.get_bytes_from_tss_buffer(bufReceiver) + } finally { + tss_buffer_free(bufReceiver) + } + } + + fun getDKLSOutboundMessage(handle: Handle): Pair { + val buf = tss_buffer() + try { + val result = dkls_sign_session_output_message(handle, buf) + if (result != LIB_OK) { + println("fail to get outbound message: $result") + return Pair(result, byteArrayOf()) + } + return Pair(result, BufferUtilJNI.get_bytes_from_tss_buffer(buf)) + } finally { + tss_buffer_free(buf) + } + } + + suspend fun processDKLSOutboundMessage(handle: Handle) { + while (true) { + val (result, outboundMessage) = getDKLSOutboundMessage(handle) + if (result != LIB_OK) { + println("fail to get outbound message, $result") + } + if (outboundMessage.isEmpty()) { + if (isKeysignDone()) { + println("DKLS ECDSA keysign finished") + return + } + delay(100) + continue + } + val message = outboundMessage.toGoSlice() + val encodedOutboundMessage = Base64.encode(outboundMessage) + for (i in keysignCommittee.indices) { + val receiverArray = getOutboundMessageReceiver(handle, message, i.toLong()) + if (receiverArray.isEmpty()) { + break + } + val receiverString = String(receiverArray, Charsets.UTF_8) + println("sending message from $localPartyID to: $receiverString, content length: ${encodedOutboundMessage.length}") + messenger?.send(localPartyID, receiverString, encodedOutboundMessage) + } + } + } + + suspend fun pullInboundMessages(handle: Handle, messageID: String): Boolean { + Timber.d("start pulling inbound messages") + + val start = System.nanoTime() + while (true) { + try { + val msgs = sessionApi + .getTssMessages(mediatorURL, sessionID, localPartyID, messageID) + + if (msgs.isNotEmpty()) { + if (processInboundMessage(handle, msgs, messageID)) { + return true + } + } else { + delay(100) + } + } catch (e: Exception) { + Timber.e("Failed to get messages", e) + } + + val elapsedTime = (System.nanoTime() - start) / 1_000_000_000.0 + if (elapsedTime > 60) { + error("timeout: failed to create vault within 60 seconds") + } + } + + return false + } + + suspend fun processInboundMessage(handle: Handle, msgs: List, messageID: String): Boolean { + val sortedMsgs = msgs.sortedBy { it.sequenceNo } + for (msg in sortedMsgs) { + val key = "$sessionID-$localPartyID-$messageID-${msg.hash}" + if (cache[key] != null) { + println("message with key: $key has been applied before") + continue + } + println("Got message from: ${msg.from}, to: ${msg.to}, key: $key") + val decryptedBody = encryption.decrypt( + Base64.Default.decode(msg.body), + Numeric.hexStringToByteArray(encryptionKeyHex) + ) ?: error("fail to decrypt message body") + val decodedMsg = Base64.decode(decryptedBody) + val decryptedBodySlice = decodedMsg.toGoSlice() + val isFinished = intArrayOf(0) + val result = dkls_sign_session_input_message(handle, decryptedBodySlice, isFinished) + if (result != LIB_OK) { + error("fail to apply message to dkls, $result") + } + cache[key] = Any() + deleteMessageFromServer(msg.hash, messageID) + if (isFinished[0] != 0) { + return true + } + } + return false + } + + private suspend fun deleteMessageFromServer(hash: String, messageID: String) { + sessionApi.deleteTssMessage(mediatorURL, sessionID, localPartyID, hash, messageID) + } + + suspend fun DKLSKeysignOneMessageWithRetry(attempt: Int, messageToSign: String) { + setKeysignDone(false) + val msgHash = messageToSign.md5() + val localMessenger = TssMessenger(mediatorURL, sessionID, encryptionKeyHex, + sessionApi, CoroutineScope(Dispatchers.IO), encryption, true) + localMessenger.setMessageID(msgHash) + messenger = localMessenger + try { + val keysignSetupMsg: ByteArray + + if (isInitiateDevice) { + keysignSetupMsg = getDKLSKeysignSetupMessage(messageToSign) + + sessionApi.uploadSetupMessage( + serverUrl = mediatorURL, + sessionId = sessionID, + message = Base64.encode( + encryption.encrypt( + Base64.encodeToByteArray(keysignSetupMsg), + Numeric.hexStringToByteArray(encryptionKeyHex) + ) + ) + ) + } else { + keysignSetupMsg = sessionApi.getSetupMessage(mediatorURL, sessionID) + .let { + encryption.decrypt( + Base64.Default.decode(it), + Numeric.hexStringToByteArray(encryptionKeyHex) + )!! + }.let { + Base64.decode(it) + } + } + + val signingMsg = DKLSDecodeMessage(keysignSetupMsg) + if (signingMsg != messageToSign) { + error("message doesn't match ($messageToSign) vs ($signingMsg)") + } + val finalSetupMsgArr = keysignSetupMsg + val decodedSetupMsg = finalSetupMsgArr.toGoSlice() + val handler = Handle() + val localPartyIDArr = localPartyID.toByteArray() + val localPartySlice = localPartyIDArr.toGoSlice() + val keyShareBytes = getKeyshareBytes() + val keyshareSlice = keyShareBytes.toGoSlice() + val keyshareHandle = Handle() + val result = dkls_keyshare_from_bytes(keyshareSlice, keyshareHandle) + if (result != LIB_OK) { + error("fail to create keyshare handle from bytes, $result") + } + val sessionResult = dkls_sign_session_from_setup(decodedSetupMsg, localPartySlice, keyshareHandle, handler) + if (sessionResult != LIB_OK) { + error("fail to create sign session from setup message, error: $sessionResult") + } + CoroutineScope(Dispatchers.IO).launch { + processDKLSOutboundMessage(handler) + } + val isFinished = pullInboundMessages(handler, msgHash) + if (isFinished) { + setKeysignDone(true) + val sig = dklsSignSessionFinish(handler) + val resp = KeysignResponse() + resp.msg = messageToSign + val r = sig.copyOfRange(0, 32) + val s = sig.copyOfRange(32, 64) + resp.r = r.toHexString() + resp.s = s.toHexString() + resp.recoveryID = String.format("%02x", sig[64]) + resp.derSignature = DklsHelper.createDERSignature(r, s).toHexString() + val keySignVerify = KeysignVerify(mediatorURL, sessionID, sessionApi) + keySignVerify.markLocalPartyKeysignComplete(msgHash, resp) + signatures[messageToSign] = resp + } + } catch (e: Exception) { + println("Failed to sign message ($messageToSign), error: ${e.localizedMessage}") + if (attempt < 3) { + DKLSKeysignOneMessageWithRetry(attempt + 1, messageToSign) + } + } + } + + @Throws(Exception::class) + fun dklsSignSessionFinish(handle: Handle): ByteArray { + val buf = tss_buffer() + try { + val result = dkls_sign_session_finish(handle, buf) + if (result != LIB_OK) { + error("fail to get keysign signature $result") + } + return BufferUtilJNI.get_bytes_from_tss_buffer(buf) + } finally { + tss_buffer_free(buf) + } + } + + suspend fun DKLSKeysignWithRetry(attempt: Int) { + for (msg in messageToSign) { + DKLSKeysignOneMessageWithRetry(0, msg) + } + } + + private fun ByteArray.toGoSlice(): go_slice { + val slice = go_slice() + BufferUtilJNI.set_bytes_on_go_slice(slice, this) + return slice + } +} \ No newline at end of file diff --git a/app/src/main/java/com/vultisig/wallet/data/keygen/SchnorrKeysign.kt b/app/src/main/java/com/vultisig/wallet/data/keygen/SchnorrKeysign.kt new file mode 100644 index 000000000..ddaac3e0d --- /dev/null +++ b/app/src/main/java/com/vultisig/wallet/data/keygen/SchnorrKeysign.kt @@ -0,0 +1,375 @@ +@file:OptIn(ExperimentalEncodingApi::class, ExperimentalStdlibApi::class) + +package com.vultisig.wallet.data.keygen + +import com.silencelaboratories.goschnorr.BufferUtilJNI +import com.silencelaboratories.goschnorr.Handle +import com.silencelaboratories.goschnorr.go_slice +import com.silencelaboratories.goschnorr.goschnorr.schnorr_decode_message +import com.silencelaboratories.goschnorr.goschnorr.schnorr_keyshare_from_bytes +import com.silencelaboratories.goschnorr.goschnorr.schnorr_keyshare_key_id +import com.silencelaboratories.goschnorr.goschnorr.schnorr_sign_session_finish +import com.silencelaboratories.goschnorr.goschnorr.schnorr_sign_session_from_setup +import com.silencelaboratories.goschnorr.goschnorr.schnorr_sign_session_input_message +import com.silencelaboratories.goschnorr.goschnorr.schnorr_sign_session_message_receiver +import com.silencelaboratories.goschnorr.goschnorr.schnorr_sign_session_output_message +import com.silencelaboratories.goschnorr.goschnorr.schnorr_sign_setupmsg_new +import com.silencelaboratories.goschnorr.goschnorr.tss_buffer_free +import com.silencelaboratories.goschnorr.lib_error +import com.silencelaboratories.goschnorr.lib_error.LIB_OK +import com.silencelaboratories.goschnorr.tss_buffer +import com.vultisig.wallet.data.api.KeysignVerify +import com.vultisig.wallet.data.api.SessionApi +import com.vultisig.wallet.data.common.md5 +import com.vultisig.wallet.data.mediator.Message +import com.vultisig.wallet.data.models.Vault +import com.vultisig.wallet.data.tss.TssMessenger +import com.vultisig.wallet.data.usecases.Encryption +import com.vultisig.wallet.data.utils.Numeric +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.delay +import kotlinx.coroutines.launch +import timber.log.Timber +import tss.KeysignResponse +import java.util.concurrent.locks.ReentrantLock +import kotlin.concurrent.withLock +import kotlin.io.encoding.Base64 +import kotlin.io.encoding.ExperimentalEncodingApi + +class SchnorrKeysign( + val keysignCommittee: List, + val mediatorURL: String, + val sessionID: String, + val messageToSign: List, + val vault: Vault, + val encryptionKeyHex: String, + val isInitiateDevice: Boolean, + + + private val sessionApi: SessionApi, + private val encryption: Encryption, +) { + val localPartyID: String = vault.localPartyID + val publicKeyEdDSA: String = vault.pubKeyEDDSA + var messenger: TssMessenger? = null + var keysignDoneIndicator = false + val keySignLock = ReentrantLock() + val cache = mutableMapOf() + val signatures = mutableMapOf() + var keyshare: ByteArray = byteArrayOf() + + fun isKeysignDone(): Boolean = keySignLock.withLock { keysignDoneIndicator } + + fun setKeysignDone(status: Boolean) = keySignLock.withLock { keysignDoneIndicator = status } + + fun getKeyshareString(): String? { + for (ks in vault.keyshares) { + if (ks.pubKey == publicKeyEdDSA) { + return ks.keyShare + } + } + return null + } + + @Throws(Exception::class) + fun getKeyshareBytes(): ByteArray { + val localKeyshare = + getKeyshareString() ?: throw RuntimeException("fail to get local keyshare") + return Base64.decode(localKeyshare) + } + + @Throws(Exception::class) + fun getKeyshareID(): ByteArray { + val buf = tss_buffer() + try { + val keyShareBytes = getKeyshareBytes() + val keyshareSlice = keyShareBytes.toGoSlice() + val h = Handle() + val result = schnorr_keyshare_from_bytes(keyshareSlice, h) + if (result != LIB_OK) { + throw RuntimeException("fail to create keyshare handle from bytes, $result") + } + val keyIDResult = schnorr_keyshare_key_id(h, buf) + if (keyIDResult != LIB_OK) { + throw RuntimeException("fail to get key id from keyshare: $keyIDResult") + } + return BufferUtilJNI.get_bytes_from_tss_buffer(buf) + } finally { + tss_buffer_free(buf) + } + } + + @Throws(Exception::class) + fun getKeysignSetupMessage(message: String): ByteArray { + val buf = tss_buffer() + try { + val keyIdArr = getKeyshareID() + val keyIdSlice = keyIdArr.toGoSlice() + val byteArray = DklsHelper.arrayToBytes(keysignCommittee) + val ids = byteArray.toGoSlice() + val decodedMsgData = message.hexToByteArray() + val msgSlice = decodedMsgData.toGoSlice() + val err = schnorr_sign_setupmsg_new(keyIdSlice, null, msgSlice, ids, buf) + if (err != LIB_OK) { + throw RuntimeException("fail to setup keysign message, error: $err") + } + return BufferUtilJNI.get_bytes_from_tss_buffer(buf) + } finally { + tss_buffer_free(buf) + } + } + + @Throws(Exception::class) + fun DKLSDecodeMessage(setupMsg: ByteArray): String { + val buf = tss_buffer() + try { + val setupMsgSlice = setupMsg.toGoSlice() + val result = schnorr_decode_message(setupMsgSlice, buf) + if (result != LIB_OK) { + throw RuntimeException("fail to extract message from setup message: $result") + } + return BufferUtilJNI.get_bytes_from_tss_buffer(buf).toHexString() + } finally { + tss_buffer_free(buf) + } + } + + fun getOutboundMessageReceiver(handle: Handle, message: go_slice, idx: Long): ByteArray { + val bufReceiver = tss_buffer() + try { + val receiverResult = + schnorr_sign_session_message_receiver(handle, message, idx, bufReceiver) + if (receiverResult != LIB_OK) { + println("fail to get receiver message, error: $receiverResult") + return byteArrayOf() + } + return BufferUtilJNI.get_bytes_from_tss_buffer(bufReceiver) + } finally { + tss_buffer_free(bufReceiver) + } + } + + fun getSchnorrOutboundMessage(handle: Handle): Pair { + val buf = tss_buffer() + try { + val result = schnorr_sign_session_output_message(handle, buf) + if (result != LIB_OK) { + println("fail to get outbound message: $result") + return Pair(result, byteArrayOf()) + } + return Pair(result, BufferUtilJNI.get_bytes_from_tss_buffer(buf)) + } finally { + tss_buffer_free(buf) + } + } + + suspend fun processSchnorrOutboundMessage(handle: Handle) { + while (true) { + val (result, outboundMessage) = getSchnorrOutboundMessage(handle) + if (result != LIB_OK) { + println("fail to get outbound message") + } + if (outboundMessage.isEmpty()) { + if (isKeysignDone()) { + println("EdDSA keysign finished") + return + } + delay(100) + continue + } + val message = outboundMessage.toGoSlice() + val encodedOutboundMessage = Base64.encode(outboundMessage) + for (i in keysignCommittee.indices) { + val receiverArray = getOutboundMessageReceiver(handle, message, i.toLong()) + if (receiverArray.isEmpty()) { + break + } + val receiverString = String(receiverArray, Charsets.UTF_8) + println("sending message from $localPartyID to: $receiverString, content length: ${encodedOutboundMessage.length}") + messenger?.send(localPartyID, receiverString, encodedOutboundMessage) + } + } + } + + suspend fun pullInboundMessages(handle: Handle, messageID: String): Boolean { + Timber.d("start pulling inbound messages") + + val start = System.nanoTime() + while (true) { + try { + val msgs = sessionApi + .getTssMessages(mediatorURL, sessionID, localPartyID, messageID) + + if (msgs.isNotEmpty()) { + if (processInboundMessage(handle, msgs, messageID)) { + return true + } + } else { + delay(100) + } + } catch (e: Exception) { + Timber.e("Failed to get messages", e) + } + + val elapsedTime = (System.nanoTime() - start) / 1_000_000_000.0 + if (elapsedTime > 60) { + error("timeout: failed to create vault within 60 seconds") + } + } + + return false + } + + suspend fun processInboundMessage( + handle: Handle, + msgs: List, + messageID: String + ): Boolean { + val sortedMsgs = msgs.sortedBy { it.sequenceNo } + for (msg in sortedMsgs) { + val key = "$sessionID-$localPartyID-$messageID-${msg.hash}" + if (cache.containsKey(key)) { + println("message with key: $key has been applied before") + continue + } + println("Got message from: ${msg.from}, to: ${msg.to}, key: $key") + val decryptedBody = encryption.decrypt( + Base64.Default.decode(msg.body), + Numeric.hexStringToByteArray(encryptionKeyHex) + ) ?: error("fail to decrypt message body") + val decodedMsg = Base64.decode(decryptedBody) + val decryptedBodySlice = decodedMsg.toGoSlice() + val isFinished = intArrayOf(0) + val result = schnorr_sign_session_input_message(handle, decryptedBodySlice, isFinished) + if (result != LIB_OK) { + throw RuntimeException("fail to apply message to dkls, $result") + } + cache[key] = Any() + deleteMessageFromServer(msg.hash, messageID) + if (isFinished[0] != 0) { + return true + } + } + return false + } + + private suspend fun deleteMessageFromServer(hash: String, messageID: String) { + sessionApi.deleteTssMessage(mediatorURL, sessionID, localPartyID, hash, messageID) + } + + suspend fun keysignOneMessageWithRetry(attempt: Int, messageToSign: String) { + setKeysignDone(false) + val msgHash = messageToSign.md5() + val localMessenger = TssMessenger( + mediatorURL, sessionID, encryptionKeyHex, sessionApi, + CoroutineScope(Dispatchers.IO), encryption, true + ) + localMessenger.setMessageID(msgHash) + messenger = localMessenger + try { + val keysignSetupMsg: ByteArray + + if (isInitiateDevice) { + keysignSetupMsg = getKeysignSetupMessage(messageToSign) + + sessionApi.uploadSetupMessage( + serverUrl = mediatorURL, + sessionId = sessionID, + message = Base64.encode( + encryption.encrypt( + Base64.encodeToByteArray(keysignSetupMsg), + Numeric.hexStringToByteArray(encryptionKeyHex) + ) + ) + ) + } else { + keysignSetupMsg = sessionApi.getSetupMessage(mediatorURL, sessionID) + .let { + encryption.decrypt( + Base64.Default.decode(it), + Numeric.hexStringToByteArray(encryptionKeyHex) + )!! + }.let { + Base64.decode(it) + } + } + + val signingMsg = DKLSDecodeMessage(keysignSetupMsg) + if (signingMsg != messageToSign) { + throw RuntimeException("message doesn't match ($messageToSign) vs ($signingMsg)") + } + val finalSetupMsgArr = keysignSetupMsg + val decodedSetupMsg = finalSetupMsgArr.toGoSlice() + val handler = Handle() + val localPartyIDArr = localPartyID.toByteArray() + val localPartySlice = localPartyIDArr.toGoSlice() + val keyShareBytes = getKeyshareBytes() + val keyshareSlice = keyShareBytes.toGoSlice() + val keyshareHandle = Handle() + val result = schnorr_keyshare_from_bytes(keyshareSlice, keyshareHandle) + if (result != LIB_OK) { + throw RuntimeException("fail to create keyshare handle from bytes, $result") + } + val sessionResult = schnorr_sign_session_from_setup( + decodedSetupMsg, + localPartySlice, + keyshareHandle, + handler + ) + if (sessionResult != LIB_OK) { + throw RuntimeException("fail to create sign session from setup message, error: $sessionResult") + } + val h = handler + val task = CoroutineScope(Dispatchers.IO).launch { + processSchnorrOutboundMessage(h) + } + val isFinished = pullInboundMessages(h, msgHash) + if (isFinished) { + setKeysignDone(true) + val sig = signSessionFinish(h) + val resp = KeysignResponse() + resp.msg = messageToSign + val r = sig.copyOfRange(0, 32).reversedArray() + val s = sig.copyOfRange(32, 64).reversedArray() + resp.r = r.toHexString() + resp.s = s.toHexString() + resp.derSignature = DklsHelper.createDERSignature(r, s).toHexString() + val keySignVerify = KeysignVerify(mediatorURL, sessionID, sessionApi) + keySignVerify.markLocalPartyKeysignComplete(msgHash, resp) + signatures[messageToSign] = resp + } + } catch (e: Exception) { + println("Failed to sign message ($messageToSign), error: ${e.localizedMessage}") + if (attempt < 3) { + keysignOneMessageWithRetry(attempt + 1, messageToSign) + } + } + } + + @Throws(Exception::class) + fun signSessionFinish(handle: Handle): ByteArray { + val buf = tss_buffer() + try { + val result = schnorr_sign_session_finish(handle, buf) + if (result != LIB_OK) { + throw RuntimeException("fail to get keysign signature $result") + } + return BufferUtilJNI.get_bytes_from_tss_buffer(buf) + } finally { + tss_buffer_free(buf) + } + } + + suspend fun keysignWithRetry(attempt: Int) { + for (msg in messageToSign) { + keysignOneMessageWithRetry(0, msg) + } + } + + private fun ByteArray.toGoSlice(): go_slice { + val slice = go_slice() + BufferUtilJNI.set_bytes_on_go_slice(slice, this) + return slice + } +} \ No newline at end of file diff --git a/app/src/main/java/com/vultisig/wallet/ui/models/keysign/JoinKeysignViewModel.kt b/app/src/main/java/com/vultisig/wallet/ui/models/keysign/JoinKeysignViewModel.kt index deab49bb4..9a64da7b2 100644 --- a/app/src/main/java/com/vultisig/wallet/ui/models/keysign/JoinKeysignViewModel.kt +++ b/app/src/main/java/com/vultisig/wallet/ui/models/keysign/JoinKeysignViewModel.kt @@ -37,8 +37,8 @@ import com.vultisig.wallet.data.models.proto.v1.KeysignPayloadProto import com.vultisig.wallet.data.models.settings.AppCurrency import com.vultisig.wallet.data.repositories.AppCurrencyRepository import com.vultisig.wallet.data.repositories.ChainAccountAddressRepository -import com.vultisig.wallet.data.repositories.FourByteRepository import com.vultisig.wallet.data.repositories.ExplorerLinkRepository +import com.vultisig.wallet.data.repositories.FourByteRepository import com.vultisig.wallet.data.repositories.GasFeeRepository import com.vultisig.wallet.data.repositories.SwapQuoteRepository import com.vultisig.wallet.data.repositories.TokenRepository @@ -215,6 +215,7 @@ internal class JoinKeysignViewModel @Inject constructor( featureFlagApi = featureFlagApi, pullTssMessages = pullTssMessages, customMessagePayload = customMessagePayload, + isInitiatingDevice = false, ) val verifyUiModel = diff --git a/app/src/main/java/com/vultisig/wallet/ui/models/keysign/KeysignFlowViewModel.kt b/app/src/main/java/com/vultisig/wallet/ui/models/keysign/KeysignFlowViewModel.kt index a95624640..518d8e717 100644 --- a/app/src/main/java/com/vultisig/wallet/ui/models/keysign/KeysignFlowViewModel.kt +++ b/app/src/main/java/com/vultisig/wallet/ui/models/keysign/KeysignFlowViewModel.kt @@ -188,6 +188,7 @@ internal class KeysignFlowViewModel @Inject constructor( featureFlagApi = featureFlagApi, transactionTypeUiModel = transactionTypeUiModel, pullTssMessages = pullTssMessages, + isInitiatingDevice = true, ) init { diff --git a/app/src/main/java/com/vultisig/wallet/ui/models/keysign/KeysignViewModel.kt b/app/src/main/java/com/vultisig/wallet/ui/models/keysign/KeysignViewModel.kt index d0a26b330..eb291b325 100644 --- a/app/src/main/java/com/vultisig/wallet/ui/models/keysign/KeysignViewModel.kt +++ b/app/src/main/java/com/vultisig/wallet/ui/models/keysign/KeysignViewModel.kt @@ -12,6 +12,9 @@ import com.vultisig.wallet.data.chains.helpers.SigningHelper import com.vultisig.wallet.data.chains.helpers.THORChainSwaps import com.vultisig.wallet.data.common.md5 import com.vultisig.wallet.data.common.toHexBytes +import com.vultisig.wallet.data.keygen.DKLSKeysign +import com.vultisig.wallet.data.keygen.SchnorrKeysign +import com.vultisig.wallet.data.models.SigningLibType import com.vultisig.wallet.data.models.TssKeyType import com.vultisig.wallet.data.models.Vault import com.vultisig.wallet.data.models.payload.BlockChainSpecific @@ -83,6 +86,7 @@ internal class KeysignViewModel( private val featureFlagApi: FeatureFlagApi, val transactionTypeUiModel: TransactionTypeUiModel?, private val pullTssMessages: PullTssMessagesUseCase, + private val isInitiatingDevice: Boolean, ) : ViewModel() { val currentState: MutableStateFlow = MutableStateFlow(KeysignState.CreatingInstance) @@ -104,9 +108,77 @@ internal class KeysignViewModel( fun startKeysign() { viewModelScope.launch { withContext(Dispatchers.IO) { - signAndBroadcast() + when (vault.libType) { + SigningLibType.GG20 -> + signAndBroadcast() + + SigningLibType.DKLS -> + startKeysignDkls() + } + } + } + } + + private suspend fun startKeysignDkls() { + val keysignPayload = keysignPayload ?: error("Keysign payload is null") + + when (keyType){ + TssKeyType.ECDSA -> { + currentState.value = KeysignState.KeysignECDSA + + val dkls = DKLSKeysign( + vault = vault, + keysignCommittee = keysignCommittee, + mediatorURL = serverUrl, + sessionID = sessionId, + encryptionKeyHex = encryptionKeyHex, + messageToSign = messagesToSign, + chainPath = keysignPayload.coin.coinType.derivationPath(), + isInitiateDevice = isInitiatingDevice, + sessionApi = sessionApi, + encryption = encryption, + ) + + dkls.DKLSKeysignWithRetry(0) + + this.signatures += dkls.signatures + + if (signatures.isEmpty()) { + error("Failed to sign transaction, signatures empty") + } + } + TssKeyType.EDDSA -> { + currentState.value = KeysignState.KeysignEdDSA + + val schnorr = SchnorrKeysign( + vault = vault, + keysignCommittee = keysignCommittee, + mediatorURL = serverUrl, + sessionID = sessionId, + encryptionKeyHex = encryptionKeyHex, + messageToSign = messagesToSign, + isInitiateDevice = isInitiatingDevice, + sessionApi = sessionApi, + encryption = encryption, + ) + + schnorr.keysignWithRetry(0) + + this.signatures += schnorr.signatures + + if (signatures.isEmpty()) { + error("Failed to sign transaction, signatures empty") + } } } + + Timber.d("All messages signed, broadcasting transaction") + + broadcastTransaction() + checkThorChainTxResult() + + currentState.value = KeysignState.KeysignFinished + isNavigateToHome = true } @Suppress("ReplaceNotNullAssertionWithElvisReturn") diff --git a/commondata b/commondata index 3a35c1eac..1d07a8dac 160000 --- a/commondata +++ b/commondata @@ -1 +1 @@ -Subproject commit 3a35c1eac46ae59f71c382530b72736b00fe0cd3 +Subproject commit 1d07a8dacecef87bd157a5f1f5b6512a77377b89 diff --git a/data/src/main/kotlin/com/vultisig/wallet/data/api/SessionApi.kt b/data/src/main/kotlin/com/vultisig/wallet/data/api/SessionApi.kt index b1d74a840..0095c09b3 100644 --- a/data/src/main/kotlin/com/vultisig/wallet/data/api/SessionApi.kt +++ b/data/src/main/kotlin/com/vultisig/wallet/data/api/SessionApi.kt @@ -26,6 +26,7 @@ interface SessionApi { serverUrl: String, sessionId: String, localPartyId: String, + messageId: String? = null, ): List suspend fun deleteTssMessage( serverUrl: String, @@ -114,8 +115,13 @@ internal class SessionApiImpl @Inject constructor( serverUrl: String, sessionId: String, localPartyId: String, + messageId: String?, ): List { - return httpClient.get("$serverUrl/message/$sessionId/$localPartyId") + return httpClient.get("$serverUrl/message/$sessionId/$localPartyId") { + messageId?.let { + header(MESSAGE_ID_HEADER_TITLE, it) + } + } .throwIfUnsuccessful() .body>() }