Skip to content

Commit

Permalink
Implement custom secp256k1 error callbacks
Browse files Browse the repository at this point in the history
These callbacks are only triggered either by arguments do not match explicit rquirements of the sepc256k1 library, or by hardware failures, memory corruption
or bug in secp256k1, and not by misuse of the library.
In theory we do not need to implement them, except to find bugs in our own code, but the default callbacks print a message to stderr and call abort() which
is not nice especially on mobile apps.

=> Here we introduce 2 specific exceptions, Secp256k1ErrorCallbackException and Secp256k1IllegalCallbackException, which are thrown when the error callback or illegal callback are called.
  • Loading branch information
sstone committed Dec 11, 2023
1 parent 929e2cd commit 8e17e70
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 12 deletions.
115 changes: 106 additions & 9 deletions jni/c/src/fr_acinq_secp256k1_Secp256k1CFunctions.c
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,18 @@ void JNI_ThrowByName(JNIEnv *penv, const char *name, const char *msg)
(*penv)->DeleteLocalRef(penv, cls);
}
}
/**
* secp256k1 uses callbacks for errors that are either hw pbs or bugs in the calling library, for example
* passing parameters with values that are explicitly defined as illegal in the API, and should never be called for normal operations
* But if they are, default behaviour is to print an error to stderr and abort which is not what we want especially in mobile apps
* => we set up string pointers in every method, and custom callback that will set them to the message passed in by sec256k1's callbacks, which
* we turn into specific Sec256k1 exceptions
*/
#define SETUP_ERROR_CALLBACKS \
char *error_callback_message = NULL; \
char *illegal_callback_message = NULL; \
secp256k1_context_set_error_callback(ctx, my_error_callback_fn, &error_callback_message); \
secp256k1_context_set_illegal_callback(ctx, my_illegal_callback_fn, &illegal_callback_message);

#define CHECKRESULT(errorcheck, message) \
{ \
Expand All @@ -33,15 +45,61 @@ void JNI_ThrowByName(JNIEnv *penv, const char *name, const char *msg)
} \
}

#define CHECKRESULT1(errorcheck, message, dosomething) \
{ \
if (errorcheck) \
{ \
dosomething; \
JNI_ThrowByName(penv, "fr/acinq/secp256k1/Secp256k1Exception", message); \
return 0; \
} \
#define CHECKRESULT(errorcheck, message) \
{ \
if (error_callback_message) \
{ \
JNI_ThrowByName(penv, "fr/acinq/secp256k1/Secp256k1ErrorCallbackException", error_callback_message); \
return 0; \
} \
if (illegal_callback_message) \
{ \
JNI_ThrowByName(penv, "fr/acinq/secp256k1/Secp256k1IllegalCallbackException", illegal_callback_message); \
return 0; \
} \
if (errorcheck) \
{ \
JNI_ThrowByName(penv, "fr/acinq/secp256k1/Secp256k1Exception", message); \
return 0; \
} \
}

#define CHECKRESULT1(errorcheck, message, dosomething) \
{ \
if (error_callback_message) \
{ \
dosomething; \
JNI_ThrowByName(penv, "fr/acinq/secp256k1/Secp256k1ErrorCallbackException", error_callback_message); \
return 0; \
} \
if (illegal_callback_message) \
{ \
dosomething; \
JNI_ThrowByName(penv, "fr/acinq/secp256k1/Secp256k1IllegalCallbackException", illegal_callback_message); \
return 0; \
} \
if (errorcheck) \
{ \
JNI_ThrowByName(penv, "fr/acinq/secp256k1/Secp256k1Exception", message); \
return 0; \
} \
}

void my_illegal_callback_fn(const char *str, void *data)
{
if (data != NULL)
{
*(char **)data = str;
}
}

void my_error_callback_fn(const char *str, void *data)
{
if (data != NULL)
{
*(char **)data = str;
}
}

/*
* Class: fr_acinq_bitcoin_Secp256k1Bindings
Expand Down Expand Up @@ -84,6 +142,8 @@ JNIEXPORT jint JNICALL Java_fr_acinq_secp256k1_Secp256k1CFunctions_secp256k1_1ec
if ((*penv)->GetArrayLength(penv, jseckey) != 32)
return 0;

SETUP_ERROR_CALLBACKS

seckey = (*penv)->GetByteArrayElements(penv, jseckey, 0);
result = secp256k1_ec_seckey_verify(ctx, (unsigned char *)seckey);
(*penv)->ReleaseByteArrayElements(penv, jseckey, seckey, 0);
Expand All @@ -108,6 +168,8 @@ JNIEXPORT jbyteArray JNICALL Java_fr_acinq_secp256k1_Secp256k1CFunctions_secp256
if (jpubkey == NULL)
return 0;

SETUP_ERROR_CALLBACKS

size = (*penv)->GetArrayLength(penv, jpubkey);
CHECKRESULT((size != 33) && (size != 65), "invalid public key size");

Expand Down Expand Up @@ -144,6 +206,8 @@ JNIEXPORT jbyteArray JNICALL Java_fr_acinq_secp256k1_Secp256k1CFunctions_secp256
if (jctx == 0)
return NULL;

SETUP_ERROR_CALLBACKS

CHECKRESULT((*penv)->GetArrayLength(penv, jseckey) != 32, "secret key must be 32 bytes");
seckey = (*penv)->GetByteArrayElements(penv, jseckey, 0);
result = secp256k1_ec_pubkey_create(ctx, &pub, (unsigned char *)seckey);
Expand Down Expand Up @@ -178,6 +242,8 @@ JNIEXPORT jbyteArray JNICALL Java_fr_acinq_secp256k1_Secp256k1CFunctions_secp256
if (jseckey == NULL)
return NULL;

SETUP_ERROR_CALLBACKS

CHECKRESULT((*penv)->GetArrayLength(penv, jseckey) != 32, "secret key must be 32 bytes");
CHECKRESULT((*penv)->GetArrayLength(penv, jmsg) != 32, "message key must be 32 bytes");
seckey = (*penv)->GetByteArrayElements(penv, jseckey, 0);
Expand Down Expand Up @@ -228,6 +294,8 @@ JNIEXPORT jint JNICALL Java_fr_acinq_secp256k1_Secp256k1CFunctions_secp256k1_1ec
if (jpubkey == NULL)
return 0;

SETUP_ERROR_CALLBACKS

sigSize = (*penv)->GetArrayLength(penv, jsig);
int sigFormat = GetSignatureFormat(sigSize);
CHECKRESULT(sigFormat == SIG_FORMAT_UNKNOWN, "invalid signature size");
Expand Down Expand Up @@ -285,6 +353,8 @@ JNIEXPORT jint JNICALL Java_fr_acinq_secp256k1_Secp256k1CFunctions_secp256k1_1ec
if (jsigout == NULL)
return 0;

SETUP_ERROR_CALLBACKS

size = (*penv)->GetArrayLength(penv, jsigin);
sigFormat = GetSignatureFormat(size);
CHECKRESULT(sigFormat == SIG_FORMAT_UNKNOWN, "invalid signature size");
Expand Down Expand Up @@ -328,6 +398,9 @@ JNIEXPORT jbyteArray JNICALL Java_fr_acinq_secp256k1_Secp256k1CFunctions_secp256
return 0;
if (jseckey == NULL)
return 0;

SETUP_ERROR_CALLBACKS

CHECKRESULT((*penv)->GetArrayLength(penv, jseckey) != 32, "secret key must be 32 bytes");
seckey = (*penv)->GetByteArrayElements(penv, jseckey, 0);
result = secp256k1_ec_seckey_negate(ctx, (unsigned char *)seckey);
Expand All @@ -354,6 +427,8 @@ JNIEXPORT jbyteArray JNICALL Java_fr_acinq_secp256k1_Secp256k1CFunctions_secp256
if (jpubkey == NULL)
return 0;

SETUP_ERROR_CALLBACKS

size = (*penv)->GetArrayLength(penv, jpubkey);
CHECKRESULT((size != 33) && (size != 65), "invalid public key size");
pub = (*penv)->GetByteArrayElements(penv, jpubkey, 0);
Expand Down Expand Up @@ -391,6 +466,8 @@ JNIEXPORT jbyteArray JNICALL Java_fr_acinq_secp256k1_Secp256k1CFunctions_secp256
if (jtweak == NULL)
return NULL;

SETUP_ERROR_CALLBACKS

CHECKRESULT((*penv)->GetArrayLength(penv, jseckey) != 32, "secret key must be 32 bytes");
CHECKRESULT((*penv)->GetArrayLength(penv, jtweak) != 32, "tweak must be 32 bytes");
seckey = (*penv)->GetByteArrayElements(penv, jseckey, 0);
Expand Down Expand Up @@ -422,6 +499,8 @@ JNIEXPORT jbyteArray JNICALL Java_fr_acinq_secp256k1_Secp256k1CFunctions_secp256
if (jtweak == NULL)
return NULL;

SETUP_ERROR_CALLBACKS

size = (*penv)->GetArrayLength(penv, jpubkey);
CHECKRESULT((size != 33) && (size != 65), "invalid public key size");
CHECKRESULT((*penv)->GetArrayLength(penv, jtweak) != 32, "tweak must be 32 bytes");
Expand Down Expand Up @@ -463,6 +542,8 @@ JNIEXPORT jbyteArray JNICALL Java_fr_acinq_secp256k1_Secp256k1CFunctions_secp256
if (jtweak == NULL)
return NULL;

SETUP_ERROR_CALLBACKS

CHECKRESULT((*penv)->GetArrayLength(penv, jseckey) != 32, "secret key must be 32 bytes");
CHECKRESULT((*penv)->GetArrayLength(penv, jtweak) != 32, "tweak must be 32 bytes");
seckey = (*penv)->GetByteArrayElements(penv, jseckey, 0);
Expand Down Expand Up @@ -494,6 +575,8 @@ JNIEXPORT jbyteArray JNICALL Java_fr_acinq_secp256k1_Secp256k1CFunctions_secp256
if (jtweak == NULL)
return NULL;

SETUP_ERROR_CALLBACKS

size = (*penv)->GetArrayLength(penv, jpubkey);
CHECKRESULT((size != 33) && (size != 65), "invalid public key size");
CHECKRESULT((*penv)->GetArrayLength(penv, jtweak) != 32, "tweak must be 32 bytes");
Expand Down Expand Up @@ -548,6 +631,8 @@ JNIEXPORT jbyteArray JNICALL Java_fr_acinq_secp256k1_Secp256k1CFunctions_secp256
if (jpubkeys == NULL)
return NULL;

SETUP_ERROR_CALLBACKS

count = (*penv)->GetArrayLength(penv, jpubkeys);
pubkeys = calloc(count, sizeof(secp256k1_pubkey *));

Expand Down Expand Up @@ -596,6 +681,8 @@ JNIEXPORT jbyteArray JNICALL Java_fr_acinq_secp256k1_Secp256k1CFunctions_secp256
if (jpubkey == NULL)
return NULL;

SETUP_ERROR_CALLBACKS

CHECKRESULT((*penv)->GetArrayLength(penv, jseckey) != 32, "invalid private key size");

size = (*penv)->GetArrayLength(penv, jpubkey);
Expand Down Expand Up @@ -637,7 +724,10 @@ JNIEXPORT jbyteArray JNICALL Java_fr_acinq_secp256k1_Secp256k1CFunctions_secp256
return NULL;
if (jmsg == NULL)
return NULL;
CHECKRESULT(recid < 0 || recid > 3, "recid must be 0, 1, 2 or 3")

SETUP_ERROR_CALLBACKS

// CHECKRESULT(recid < 0 || recid > 3, "recid must be 0, 1, 2 or 3")
sigSize = (*penv)->GetArrayLength(penv, jsig);
int sigFormat = GetSignatureFormat(sigSize);
CHECKRESULT(sigFormat == SIG_FORMAT_UNKNOWN, "invalid signature size");
Expand Down Expand Up @@ -693,6 +783,9 @@ JNIEXPORT jbyteArray JNICALL Java_fr_acinq_secp256k1_Secp256k1CFunctions_secp256
return 0;
if (jsig == NULL)
return 0;

SETUP_ERROR_CALLBACKS

CHECKRESULT((*penv)->GetArrayLength(penv, jsig) != 64, "invalid signature size");

size = (*penv)->GetArrayLength(penv, jsig);
Expand Down Expand Up @@ -732,6 +825,8 @@ JNIEXPORT jbyteArray JNICALL Java_fr_acinq_secp256k1_Secp256k1CFunctions_secp256
if (jseckey == NULL)
return NULL;

SETUP_ERROR_CALLBACKS

CHECKRESULT((*penv)->GetArrayLength(penv, jseckey) != 32, "secret key must be 32 bytes");
CHECKRESULT((*penv)->GetArrayLength(penv, jmsg) != 32, "message must be 32 bytes");
if (jauxrand32 != 0)
Expand Down Expand Up @@ -785,6 +880,8 @@ JNIEXPORT jint JNICALL Java_fr_acinq_secp256k1_Secp256k1CFunctions_secp256k1_1sc
if (jpubkey == NULL)
return 0;

SETUP_ERROR_CALLBACKS

CHECKRESULT((*penv)->GetArrayLength(penv, jsig) != 64, "signature must be 64 bytes");
CHECKRESULT((*penv)->GetArrayLength(penv, jpubkey) != 32, "public key must be 32 bytes");
CHECKRESULT((*penv)->GetArrayLength(penv, jmsg) != 32, "message must be 32 bytes");
Expand Down
12 changes: 11 additions & 1 deletion src/commonMain/kotlin/fr/acinq/secp256k1/Secp256k1.kt
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,17 @@ public interface Secp256k1 {

internal expect fun getSecpk256k1(): Secp256k1

public class Secp256k1Exception : RuntimeException {
public open class Secp256k1Exception : RuntimeException {
public constructor() : super()
public constructor(message: String?) : super(message)
}

public class Secp256k1ErrorCallbackException : Secp256k1Exception {
public constructor() : super()
public constructor(message: String?) : super(message)
}

public class Secp256k1IllegalCallbackException : Secp256k1Exception {
public constructor() : super()
public constructor(message: String?) : super(message)
}
53 changes: 51 additions & 2 deletions src/nativeMain/kotlin/fr/acinq/secp256k1/Secp256k1Native.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,64 @@ import kotlinx.cinterop.*
import platform.posix.size_tVar
import secp256k1.*

private typealias MyHandler = (String) -> Unit

private object CallbackHandler {
var illegalCallBackMessage: String? = null
val illegalHandler: MyHandler = { x: String -> illegalCallBackMessage = x }
val illegalCallbackRef = StableRef.create(illegalHandler)
var errorCallBackMessage: String? = null
val errorHandler: MyHandler = { x: String -> errorCallBackMessage = x }
val errorCallbackRef = StableRef.create(errorHandler)

fun checkForErrors() {
if (errorCallBackMessage != null) {
val message = errorCallBackMessage
errorCallBackMessage = null
throw Secp256k1ErrorCallbackException(message)
}
if (illegalCallBackMessage != null) {
val message = illegalCallBackMessage
illegalCallBackMessage = null
throw Secp256k1IllegalCallbackException(message)
}
}
}

@OptIn(ExperimentalUnsignedTypes::class)
public object Secp256k1Native : Secp256k1 {

private val ctx: CPointer<secp256k1_context> by lazy {
secp256k1_context_create((SECP256K1_FLAGS_TYPE_CONTEXT or SECP256K1_FLAGS_BIT_CONTEXT_SIGN or SECP256K1_FLAGS_BIT_CONTEXT_VERIFY).toUInt())

val ctx = secp256k1_context_create((SECP256K1_FLAGS_TYPE_CONTEXT or SECP256K1_FLAGS_BIT_CONTEXT_SIGN or SECP256K1_FLAGS_BIT_CONTEXT_VERIFY).toUInt())
?: error("Could not create secp256k1 context")

secp256k1_context_set_error_callback(
ctx, staticCFunction { buffer: CPointer<ByteVar>?, data: COpaquePointer? ->
if (data != null) {
val callback = data.asStableRef<MyHandler>().get()
callback(buffer?.toKString() ?: "error callback triggered")
}
},
CallbackHandler.errorCallbackRef.asCPointer()
)
secp256k1_context_set_illegal_callback(
ctx, staticCFunction { buffer: CPointer<ByteVar>?, data: COpaquePointer? ->
if (data != null) {
val callback = data.asStableRef<MyHandler>().get()
callback(buffer?.toKString() ?: "illegal callback triggered")
}
},
CallbackHandler.illegalCallbackRef.asCPointer()
)

ctx
}

private fun Int.requireSuccess(message: String): Int = if (this != 1) throw Secp256k1Exception(message) else this
private fun Int.requireSuccess(message: String): Int {
CallbackHandler.checkForErrors()
return if (this != 1) throw Secp256k1Exception(message) else this
}

private fun MemScope.allocSignature(input: ByteArray): secp256k1_ecdsa_signature {
val sig = alloc<secp256k1_ecdsa_signature>()
Expand Down

0 comments on commit 8e17e70

Please sign in to comment.