Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor TssMessagePuller & GeneratingKeyViewModel #1360

Merged
merged 2 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import com.vultisig.wallet.data.api.SessionApi
import com.vultisig.wallet.data.api.models.FeatureFlagJson
import com.vultisig.wallet.data.mediator.MediatorService
import com.vultisig.wallet.data.models.TssAction
import com.vultisig.wallet.data.models.TssKeyType
import com.vultisig.wallet.data.models.Vault
import com.vultisig.wallet.data.repositories.LastOpenedVaultRepository
import com.vultisig.wallet.data.repositories.VaultDataStoreRepository
Expand All @@ -30,7 +29,6 @@ import dagger.hilt.android.qualifiers.ApplicationContext
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import timber.log.Timber
import tss.ServiceImpl
Expand All @@ -44,7 +42,10 @@ internal sealed interface KeygenState {
data object ReshareECDSA : KeygenState
data object ReshareEdDSA : KeygenState
data object Success : KeygenState
data class Error(val errorMessage: UiText, val isThresholdError: Boolean) : KeygenState
data class Error(
val title: UiText?,
val message: UiText,
) : KeygenState
}

internal class GeneratingKeyViewModel(
Expand All @@ -58,6 +59,7 @@ internal class GeneratingKeyViewModel(
private val oldResharePrefix: String,
private val password: String? = null,
private val hint: String? = null,

@ApplicationContext private val context: Context,
private val navigator: Navigator<Destination>,
private val saveVault: SaveVaultUseCase,
Expand All @@ -70,198 +72,144 @@ internal class GeneratingKeyViewModel(
private val vaultPasswordRepository: VaultPasswordRepository,
private val vaultMetadataRepo: VaultMetadataRepo,
) : ViewModel() {
private var tssInstance: ServiceImpl? = null
private var tssMessenger: TssMessenger? = null

val state = MutableStateFlow<KeygenState>(KeygenState.CreatingInstance)

private val localStateAccessor: tss.LocalStateAccessor = LocalStateAccessor(vault)
val currentState: MutableStateFlow<KeygenState> = MutableStateFlow(KeygenState.CreatingInstance)
private var _messagePuller: TssMessagePuller? = null
private var featureFlag: FeatureFlagJson? = null
init {
viewModelScope.launch {
collectCurrentState()
}
}

private suspend fun collectCurrentState() {
currentState.collect { state ->
when (state) {
is KeygenState.Error -> {
stopService()
}

KeygenState.Success -> {
saveVault()
}

else -> Unit
}
}
}

suspend fun generateKey() {
currentState.value = KeygenState.CreatingInstance
state.value = KeygenState.CreatingInstance

try {
withContext(Dispatchers.IO) {
featureFlag = featureFlagApi.getFeatureFlag()
createInstance()
}
featureFlag = featureFlagApi.getFeatureFlag()

this.tssInstance?.let {
keygenWithRetry(it, 1)
}
this.vault.signers = keygenCommittee
currentState.value = KeygenState.Success
this._messagePuller?.stop()
val tss = createTss()

keygenWithRetry(tss, 1)

vault.signers = keygenCommittee
state.value = KeygenState.Success

saveVault()
} catch (e: Exception) {
Timber.tag("GeneratingKeyViewModel").d("generateKey error: %s", e.stackTraceToString())
val errorMessage = UiText.DynamicString(e.message ?: "Unknown error")
val isThresholdError = checkIsThresholdError(e)
currentState.value = KeygenState.Error(
if (isThresholdError)
UiText.StringResource(R.string.threshold_error) else errorMessage,
isThresholdError
)
}
Timber.d("generateKey error: %s", e.stackTraceToString())

state.value = resolveKeygenErrorFromException(e)

stopService()
}
}

private suspend fun keygenWithRetry(service: ServiceImpl, attempt: Int = 1) {
try {
_messagePuller = TssMessagePuller(
withContext(Dispatchers.IO) {
val messagePuller = TssMessagePuller(
service,
this.encryptionKeyHex,
encryptionKeyHex,
serverAddress,
vault.localPartyID,
sessionId,
sessionApi,
encryption,
featureFlag?.isEncryptGcmEnabled == true
)
_messagePuller?.pullMessages(null)
when (this.action) {
TssAction.KEYGEN -> {
// generate ECDSA
currentState.value = KeygenState.KeygenECDSA
val keygenRequest = tss.KeygenRequest()
keygenRequest.localPartyID = vault.localPartyID
keygenRequest.allParties = keygenCommittee.joinToString(",")
keygenRequest.chainCodeHex = vault.hexChainCode
val ecdsaResp = tssKeygen(service, keygenRequest, TssKeyType.ECDSA)
vault.pubKeyECDSA = ecdsaResp.pubKey
delay(1.seconds) // backoff for 1 second
currentState.value = KeygenState.KeygenEdDSA
val eddsaResp = tssKeygen(service, keygenRequest, TssKeyType.EDDSA)
vault.pubKeyEDDSA = eddsaResp.pubKey
}

TssAction.ReShare -> {
currentState.value = KeygenState.ReshareECDSA
val reshareRequest = tss.ReshareRequest()
reshareRequest.localPartyID = vault.localPartyID
reshareRequest.pubKey = vault.pubKeyECDSA
reshareRequest.oldParties = oldCommittee.joinToString(",")
reshareRequest.newParties = keygenCommittee.joinToString(",")
reshareRequest.resharePrefix =
vault.resharePrefix.ifEmpty { oldResharePrefix }
reshareRequest.chainCodeHex = vault.hexChainCode
val ecdsaResp = tssReshare(service, reshareRequest, TssKeyType.ECDSA)
currentState.value = KeygenState.ReshareEdDSA
delay(1.seconds) // backoff for 1 second
reshareRequest.pubKey = vault.pubKeyEDDSA
reshareRequest.newResharePrefix = ecdsaResp.resharePrefix
val eddsaResp = tssReshare(service, reshareRequest, TssKeyType.EDDSA)
vault.pubKeyEDDSA = eddsaResp.pubKey
vault.pubKeyECDSA = ecdsaResp.pubKey
vault.resharePrefix = ecdsaResp.resharePrefix
try {
messagePuller.pullMessages(null)

when (action) {
TssAction.KEYGEN -> {
// generate ECDSA
state.value = KeygenState.KeygenECDSA
val keygenRequest = tss.KeygenRequest()
keygenRequest.localPartyID = vault.localPartyID
keygenRequest.allParties = keygenCommittee.joinToString(",")
keygenRequest.chainCodeHex = vault.hexChainCode
val ecdsaResp = service.keygenECDSA(keygenRequest)
vault.pubKeyECDSA = ecdsaResp.pubKey
delay(1.seconds) // backoff for 1 second
state.value = KeygenState.KeygenEdDSA
val eddsaResp = service.keygenEdDSA(keygenRequest)
vault.pubKeyEDDSA = eddsaResp.pubKey
}

TssAction.ReShare -> {
state.value = KeygenState.ReshareECDSA
val reshareRequest = tss.ReshareRequest()
reshareRequest.localPartyID = vault.localPartyID
reshareRequest.pubKey = vault.pubKeyECDSA
reshareRequest.oldParties = oldCommittee.joinToString(",")
reshareRequest.newParties = keygenCommittee.joinToString(",")
reshareRequest.resharePrefix =
vault.resharePrefix.ifEmpty { oldResharePrefix }
reshareRequest.chainCodeHex = vault.hexChainCode
val ecdsaResp = service.reshareECDSA(reshareRequest)
state.value = KeygenState.ReshareEdDSA
delay(1.seconds) // backoff for 1 second
reshareRequest.pubKey = vault.pubKeyEDDSA
reshareRequest.newResharePrefix = ecdsaResp.resharePrefix
val eddsaResp = service.resharingEdDSA(reshareRequest)
vault.pubKeyEDDSA = eddsaResp.pubKey
vault.pubKeyECDSA = ecdsaResp.pubKey
vault.resharePrefix = ecdsaResp.resharePrefix
}
}
}
// here is the keygen process is done
withContext(Dispatchers.IO) {
sessionApi.markLocalPartyComplete(serverAddress, sessionId, listOf(vault.localPartyID))

// here is the keygen process is done
sessionApi.markLocalPartyComplete(
serverAddress,
sessionId,
listOf(vault.localPartyID)
)
Timber.d("Local party ${vault.localPartyID} marked as complete")

var counter = 0
var isSuccess = false
while (counter < 60){
val serverCompletedParties = sessionApi.getCompletedParties(serverAddress, sessionId)
while (counter < 60) {
val serverCompletedParties =
sessionApi.getCompletedParties(serverAddress, sessionId)
if (serverCompletedParties.containsAll(keygenCommittee)) {
isSuccess = true
break // this means all parties have completed the key generation process
}
delay(1000)
delay(1.seconds)
counter++
}
if (isSuccess.not()) {

if (!isSuccess) {
throw Exception("Timeout waiting for all parties to complete the key generation process")
}

Timber.d("All parties have completed the key generation process")

}
} catch (e: Exception) {
this._messagePuller?.stop()
Timber.tag("GeneratingKeyViewModel")
.e("attempt $attempt keygenWithRetry: ${e.stackTraceToString()}")
if (attempt < 3) {
keygenWithRetry(service, attempt + 1)
} else {
throw e
messagePuller.stop()
} catch (e: Exception) {
messagePuller.stop()

Timber.e(e, "attempt $attempt keygenWithRetry failed")

if (attempt < MAX_KEYGEN_ATTEMPTS) {
keygenWithRetry(service, attempt + 1)
} else {
throw e
}
}
}
}

private fun createInstance() {
this.tssMessenger = TssMessenger(
private suspend fun createTss(): ServiceImpl = withContext(Dispatchers.IO) {
val messenger = TssMessenger(
serverAddress,
sessionId,
encryptionKeyHex,
sessionApi = sessionApi,
coroutineScope = viewModelScope,
encryption = encryption,
isEncryptionGCM = this.featureFlag?.isEncryptGcmEnabled == true,
isEncryptionGCM = featureFlag?.isEncryptGcmEnabled == true,
)
this.tssMessenger?.let { messenger ->
// this will take a while
this.tssInstance = Tss.newService(messenger, this.localStateAccessor, true)
}

}

private suspend fun tssKeygen(
service: ServiceImpl,
keygenRequest: tss.KeygenRequest,
tssKeyType: TssKeyType,
): tss.KeygenResponse {
return withContext(Dispatchers.IO) {
when (tssKeyType) {
TssKeyType.ECDSA -> {
return@withContext service.keygenECDSA(keygenRequest)
}

TssKeyType.EDDSA -> {
return@withContext service.keygenEdDSA(keygenRequest)
}
}
}
}

private suspend fun tssReshare(
service: ServiceImpl,
reshareRequest: tss.ReshareRequest,
tssKeyType: TssKeyType,
): tss.ReshareResponse {
return withContext(Dispatchers.IO) {
when (tssKeyType) {
TssKeyType.ECDSA -> {
return@withContext service.reshareECDSA(reshareRequest)
}

TssKeyType.EDDSA -> {
return@withContext service.resharingEdDSA(reshareRequest)
}
}
}
// this will take a while
return@withContext Tss.newService(messenger, localStateAccessor, true)
}

private suspend fun saveVault() {
Expand Down Expand Up @@ -312,9 +260,28 @@ internal class GeneratingKeyViewModel(

}

private fun checkIsThresholdError(errorMessage: Exception) =
errorMessage.message?.let { message ->
private fun resolveKeygenErrorFromException(e: Exception): KeygenState.Error {
val isThresholdError = checkIsThresholdError(e)

return KeygenState.Error(
title = when {
isThresholdError -> null
isReshareMode -> UiText.StringResource(R.string.generating_key_screen_reshare_failed)
else -> UiText.StringResource(R.string.generating_key_screen_keygen_failed)
},
message = if (isThresholdError) {
UiText.StringResource(R.string.threshold_error)
} else {
UiText.DynamicString(e.message ?: "Unknown error")
}
)
}

private fun checkIsThresholdError(exception: Exception) =
exception.message?.let { message ->
message.contains("threshold") ||
message.contains("failed to update from bytes to new local party")
} ?: false
}
}

private const val MAX_KEYGEN_ATTEMPTS = 3
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
package com.vultisig.wallet.ui.models.keygen

import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.collectAsState
import androidx.compose.ui.res.stringResource
import androidx.hilt.navigation.compose.hiltViewModel
import androidx.navigation.NavHostController
import com.vultisig.wallet.R
import com.vultisig.wallet.ui.components.KeepScreenOn
import com.vultisig.wallet.ui.screens.keygen.GeneratingKey
import com.vultisig.wallet.ui.screens.keygen.KeygenPeerDiscovery
Expand Down Expand Up @@ -34,9 +31,5 @@ fun KeygenFlowView(
KeygenFlowState.ERROR -> {
KeyGenErrorScreen(navController)
}

KeygenFlowState.SUCCESS -> {
Text(text = stringResource(R.string.keygen_flow_views_success))
}
}
}
Loading
Loading