diff --git a/jupyter-lib/shared-compiler/src/main/kotlin/org/jetbrains/kotlinx/jupyter/protocol/HMAC.kt b/jupyter-lib/shared-compiler/src/main/kotlin/org/jetbrains/kotlinx/jupyter/protocol/HMAC.kt index 7b42f868..85e14a26 100644 --- a/jupyter-lib/shared-compiler/src/main/kotlin/org/jetbrains/kotlinx/jupyter/protocol/HMAC.kt +++ b/jupyter-lib/shared-compiler/src/main/kotlin/org/jetbrains/kotlinx/jupyter/protocol/HMAC.kt @@ -3,19 +3,17 @@ package org.jetbrains.kotlinx.jupyter.protocol import javax.crypto.Mac import javax.crypto.spec.SecretKeySpec -class HMAC(algorithm: String, key: String?) { - private val mac = if (key?.isNotBlank() == true) Mac.getInstance(algorithm) else null - - init { - mac?.init(SecretKeySpec(key!!.toByteArray(), algorithm)) - } +class HMAC(algorithm: String, key: String) { + private val mac: Mac = + Mac.getInstance(algorithm).apply { + init(SecretKeySpec(key.toByteArray(), algorithm)) + } @Synchronized - operator fun invoke(data: Iterable): String? = - mac?.let { mac -> - data.forEach { mac.update(it) } - mac.doFinal().toHexString() - } + operator fun invoke(data: Iterable): String { + data.forEach { mac.update(it) } + return mac.doFinal().toHexString() + } - operator fun invoke(vararg data: ByteArray): String? = invoke(data.asIterable()) + operator fun invoke(vararg data: ByteArray): String = invoke(data.asIterable()) } diff --git a/jupyter-lib/shared-compiler/src/main/kotlin/org/jetbrains/kotlinx/jupyter/protocol/SocketWrapper.kt b/jupyter-lib/shared-compiler/src/main/kotlin/org/jetbrains/kotlinx/jupyter/protocol/SocketWrapper.kt index b581892d..654f0601 100644 --- a/jupyter-lib/shared-compiler/src/main/kotlin/org/jetbrains/kotlinx/jupyter/protocol/SocketWrapper.kt +++ b/jupyter-lib/shared-compiler/src/main/kotlin/org/jetbrains/kotlinx/jupyter/protocol/SocketWrapper.kt @@ -89,7 +89,7 @@ class SocketWrapper( prop -> prop.get(msg)?.let { MessageFormat.encodeToString(it) }?.toByteArray() ?: emptyJsonObjectStringBytes } - sendMore(hmac(signableMsg) ?: "") + sendMore(hmac(signableMsg)) for (i in 0 until (signableMsg.size - 1)) { sendMore(signableMsg[i]) } @@ -121,7 +121,7 @@ class SocketWrapper( val content = recv() val calculatedSig = hmac(header, parentHeader, metadata, content) - if (calculatedSig != null && sig != calculatedSig) { + if (sig != calculatedSig) { throw SignatureException("Invalid signature: expected $calculatedSig, received $sig - $ids") } diff --git a/src/test/kotlin/org/jetbrains/kotlinx/jupyter/test/protocol/KernelServerTestsBase.kt b/src/test/kotlin/org/jetbrains/kotlinx/jupyter/test/protocol/KernelServerTestsBase.kt index e774173e..bc88e8c1 100644 --- a/src/test/kotlin/org/jetbrains/kotlinx/jupyter/test/protocol/KernelServerTestsBase.kt +++ b/src/test/kotlin/org/jetbrains/kotlinx/jupyter/test/protocol/KernelServerTestsBase.kt @@ -41,7 +41,7 @@ abstract class KernelServerTestsBase(protected val runServerInSeparateProcess: B protected val kernelConfig = createKotlinKernelConfig( ports = createRandomKernelPorts(), - signatureKey = "", + signatureKey = "abc", scriptClasspath = classpath, homeDir = File(""), )