Skip to content

Commit

Permalink
Optimize implementations (#72)
Browse files Browse the repository at this point in the history
  • Loading branch information
05nelsonm authored Dec 21, 2024
1 parent dd9078b commit ebd0eb1
Show file tree
Hide file tree
Showing 13 changed files with 288 additions and 218 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import kotlinx.benchmark.*
import org.kotlincrypto.core.digest.Digest
import org.kotlincrypto.hash.sha2.SHA256
import org.kotlincrypto.hash.sha2.SHA512
import org.kotlincrypto.hash.sha2.SHA512_224

@State(Scope.Benchmark)
@BenchmarkMode(Mode.AverageTime)
Expand All @@ -37,3 +38,12 @@ open class SHA256Benchmark: DigestBenchmarkBase() {
open class SHA512Benchmark: DigestBenchmarkBase() {
override val d: Digest = SHA512()
}

@State(Scope.Benchmark)
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(BenchmarkTimeUnit.NANOSECONDS)
@Warmup(iterations = ITERATIONS, time = TIME_WARMUP)
@Measurement(iterations = ITERATIONS, time = TIME_MEASURE)
open class SHA512_224Benchmark: DigestBenchmarkBase() {
override val d: Digest = SHA512_224()
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public class MD5: Digest {
var c = state[2]
var d = state[3]

for (i in 0 until blockSize()) {
for (i in 0..<blockSize()) {
when {
i < 16 -> {
var j = (i * 4) + offset
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,15 @@ public class SHA1: Digest {
val x = x

var j = offset
for (i in 0 until 16) {
for (i in 0..<16) {
x[i] =
((input[j++].toInt() and 0xff) shl 24) or
((input[j++].toInt() and 0xff) shl 16) or
((input[j++].toInt() and 0xff) shl 8) or
((input[j++].toInt() and 0xff) )
}

for (i in 16 until 80) {
for (i in 16..<80) {
x[i] = (x[i - 3] xor x[i - 8] xor x[i - 14] xor x[i - 16]) rotateLeft 1
}

Expand All @@ -62,7 +62,7 @@ public class SHA1: Digest {
var d = state[3]
var e = state[4]

for (i in 0 until 80) {
for (i in 0..<80) {
val a2 = when {
i < 20 -> {
val f = d xor (b and (c xor d))
Expand Down
13 changes: 13 additions & 0 deletions library/sha2/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,18 @@ kmpConfiguration {
}
}
}

kotlin {
with(sourceSets) {
val nonJsSources = listOf("jvm", "native", "wasmJs", "wasmWasi").mapNotNull {
findByName(it + "Main")
}
if (nonJsSources.isEmpty()) return@kotlin
val nonJsMain = maybeCreate("nonJsMain").apply {
dependsOn(getByName("commonMain"))
}
nonJsSources.forEach { it.dependsOn(nonJsMain) }
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,15 @@ public sealed class Bit32Digest: Digest {
val x = x

var j = offset
for (i in 0 until 16) {
for (i in 0..<16) {
x[i] =
((input[j++].toInt() and 0xff) shl 24) or
((input[j++].toInt() and 0xff) shl 16) or
((input[j++].toInt() and 0xff) shl 8) or
((input[j++].toInt() and 0xff) )
}

for (i in 16 until 64) {
for (i in 16..<64) {
val x15 = x[i - 15]
val s0 =
((x15 ushr 7) or (x15 shl 25)) xor
Expand All @@ -119,7 +119,7 @@ public sealed class Bit32Digest: Digest {
var g = state[6]
var h = state[7]

for (i in 0 until 64) {
for (i in 0..<64) {
val s0 =
((a ushr 2) or (a shl 30)) xor
((a ushr 13) or (a shl 19)) xor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.kotlincrypto.hash.sha2
import org.kotlincrypto.core.InternalKotlinCryptoApi
import org.kotlincrypto.core.digest.Digest
import org.kotlincrypto.core.digest.internal.DigestState
import org.kotlincrypto.hash.sha2.internal.rotateRight
import kotlin.jvm.JvmField

/**
Expand Down Expand Up @@ -94,7 +95,7 @@ public sealed class Bit64Digest: Digest {
val x = x

var j = offset
for (i in 0 until 16) {
for (i in 0..<16) {
x[i] =
((input[j++].toLong() and 0xff) shl 56) or
((input[j++].toLong() and 0xff) shl 48) or
Expand All @@ -106,11 +107,11 @@ public sealed class Bit64Digest: Digest {
((input[j++].toLong() and 0xff) )
}

for (i in 16 until 80) {
for (i in 16..<80) {
val x15 = x[i - 15]
val s0 = (x15 rotateRight 1) xor (x15 rotateRight 8) xor (x15 ushr 7)
val s0 = (x15.rotateRight(1)) xor (x15.rotateRight(8)) xor (x15 ushr 7)
val x2 = x[i - 2]
val s1 = (x2 rotateRight 19) xor (x2 rotateRight 61) xor (x2 ushr 6)
val s1 = (x2.rotateRight(19)) xor (x2.rotateRight(61)) xor (x2 ushr 6)
val x16 = x[i - 16]
val x7 = x[i - 7]
x[i] = x16 + s0 + x7 + s1
Expand All @@ -127,9 +128,9 @@ public sealed class Bit64Digest: Digest {
var g = state[6]
var h = state[7]

for (i in 0 until 80) {
val s0 = (a rotateRight 28) xor (a rotateRight 34) xor (a rotateRight 39)
val s1 = (e rotateRight 14) xor (e rotateRight 18) xor (e rotateRight 41)
for (i in 0..<80) {
val s0 = (a.rotateRight(28)) xor (a.rotateRight(34)) xor (a.rotateRight(39))
val s1 = (e.rotateRight(14)) xor (e.rotateRight(18)) xor (e.rotateRight(41))

val ch = (e and f) xor (e.inv() and g)
val maj = (a and b) xor (a and c) xor (b and c)
Expand Down Expand Up @@ -215,9 +216,6 @@ public sealed class Bit64Digest: Digest {
state[7] = h7
}

@Suppress("NOTHING_TO_INLINE", "KotlinRedundantDiagnosticSuppress")
private inline infix fun Long.rotateRight(n: Int): Long = (this ushr n) or (this shl (64 - n))

private companion object {
private val K = longArrayOf(
4794697086780616226L, 8158064640168781261L, -5349999486874862801L, -1606136188198331460L,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,78 +111,79 @@ public class SHA512t: Bit64Digest {
return ByteArray(0)
}

val out = ByteArray(digestLength())
var i = 0
val len = digestLength()
val out = ByteArray(len)

fun Long.putOut() {
if (i == out.size) return
var i = 0
fun Long.putOut(): Unit? {
out[i++] = toByte()
return if (i == len) null else Unit
}

(a shr 56).putOut()
(a shr 48).putOut()
(a shr 40).putOut()
(a shr 32).putOut()
(a shr 24).putOut()
(a shr 16).putOut()
(a shr 8).putOut()
(a ).putOut()
(b shr 56).putOut()
(b shr 48).putOut()
(b shr 40).putOut()
(b shr 32).putOut()
(b shr 24).putOut()
(b shr 16).putOut()
(b shr 8).putOut()
(b ).putOut()
(c shr 56).putOut()
(c shr 48).putOut()
(c shr 40).putOut()
(c shr 32).putOut()
(c shr 24).putOut()
(c shr 16).putOut()
(c shr 8).putOut()
(c ).putOut()
(d shr 56).putOut()
(d shr 48).putOut()
(d shr 40).putOut()
(d shr 32).putOut()
(d shr 24).putOut()
(d shr 16).putOut()
(d shr 8).putOut()
(d ).putOut()
(e shr 56).putOut()
(e shr 48).putOut()
(e shr 40).putOut()
(e shr 32).putOut()
(e shr 24).putOut()
(e shr 16).putOut()
(e shr 8).putOut()
(e ).putOut()
(f shr 56).putOut()
(f shr 48).putOut()
(f shr 40).putOut()
(f shr 32).putOut()
(f shr 24).putOut()
(f shr 16).putOut()
(f shr 8).putOut()
(f ).putOut()
(g shr 56).putOut()
(g shr 48).putOut()
(g shr 40).putOut()
(g shr 32).putOut()
(g shr 24).putOut()
(g shr 16).putOut()
(g shr 8).putOut()
(g ).putOut()
(h shr 56).putOut()
(h shr 48).putOut()
(h shr 40).putOut()
(h shr 32).putOut()
(h shr 24).putOut()
(h shr 16).putOut()
(h shr 8).putOut()
(h ).putOut()
(a shr 56).putOut() ?: return out
(a shr 48).putOut() ?: return out
(a shr 40).putOut() ?: return out
(a shr 32).putOut() ?: return out
(a shr 24).putOut() ?: return out
(a shr 16).putOut() ?: return out
(a shr 8).putOut() ?: return out
(a ).putOut() ?: return out
(b shr 56).putOut() ?: return out
(b shr 48).putOut() ?: return out
(b shr 40).putOut() ?: return out
(b shr 32).putOut() ?: return out
(b shr 24).putOut() ?: return out
(b shr 16).putOut() ?: return out
(b shr 8).putOut() ?: return out
(b ).putOut() ?: return out
(c shr 56).putOut() ?: return out
(c shr 48).putOut() ?: return out
(c shr 40).putOut() ?: return out
(c shr 32).putOut() ?: return out
(c shr 24).putOut() ?: return out
(c shr 16).putOut() ?: return out
(c shr 8).putOut() ?: return out
(c ).putOut() ?: return out
(d shr 56).putOut() ?: return out
(d shr 48).putOut() ?: return out
(d shr 40).putOut() ?: return out
(d shr 32).putOut() ?: return out
(d shr 24).putOut() ?: return out
(d shr 16).putOut() ?: return out
(d shr 8).putOut() ?: return out
(d ).putOut() ?: return out
(e shr 56).putOut() ?: return out
(e shr 48).putOut() ?: return out
(e shr 40).putOut() ?: return out
(e shr 32).putOut() ?: return out
(e shr 24).putOut() ?: return out
(e shr 16).putOut() ?: return out
(e shr 8).putOut() ?: return out
(e ).putOut() ?: return out
(f shr 56).putOut() ?: return out
(f shr 48).putOut() ?: return out
(f shr 40).putOut() ?: return out
(f shr 32).putOut() ?: return out
(f shr 24).putOut() ?: return out
(f shr 16).putOut() ?: return out
(f shr 8).putOut() ?: return out
(f ).putOut() ?: return out
(g shr 56).putOut() ?: return out
(g shr 48).putOut() ?: return out
(g shr 40).putOut() ?: return out
(g shr 32).putOut() ?: return out
(g shr 24).putOut() ?: return out
(g shr 16).putOut() ?: return out
(g shr 8).putOut() ?: return out
(g ).putOut() ?: return out
(h shr 56).putOut() ?: return out
(h shr 48).putOut() ?: return out
(h shr 40).putOut() ?: return out
(h shr 32).putOut() ?: return out
(h shr 24).putOut() ?: return out
(h shr 16).putOut() ?: return out
(h shr 8).putOut() ?: return out
(h ).putOut() ?: return out

return out
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,9 @@
* See the License for the specific language governing permissions and
* limitations under the License.
**/
package org.kotlincrypto.hash.benchmarks
@file:Suppress("KotlinRedundantDiagnosticSuppress")

import kotlinx.benchmark.*
import org.kotlincrypto.core.digest.Digest
import org.kotlincrypto.hash.md.MD5
package org.kotlincrypto.hash.sha2.internal

@State(Scope.Benchmark)
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(BenchmarkTimeUnit.NANOSECONDS)
@Warmup(iterations = ITERATIONS, time = TIME_WARMUP)
@Measurement(iterations = ITERATIONS, time = TIME_MEASURE)
open class M55Benchmark: DigestBenchmarkBase() {
override val d: Digest = MD5()
}
@Suppress("NOTHING_TO_INLINE")
internal expect inline fun Long.rotateRight(n: Int): Long
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,11 @@
* See the License for the specific language governing permissions and
* limitations under the License.
**/
package org.kotlincrypto.hash.benchmarks
@file:Suppress("KotlinRedundantDiagnosticSuppress")

import kotlinx.benchmark.*
import org.kotlincrypto.core.digest.Digest
import org.kotlincrypto.hash.sha1.SHA1
package org.kotlincrypto.hash.sha2.internal

@State(Scope.Benchmark)
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(BenchmarkTimeUnit.NANOSECONDS)
@Warmup(iterations = ITERATIONS, time = TIME_WARMUP)
@Measurement(iterations = ITERATIONS, time = TIME_MEASURE)
open class SHA1Benchmark: DigestBenchmarkBase() {
override val d: Digest = SHA1()
}
import kotlin.rotateRight as kRotateRight

@Suppress("NOTHING_TO_INLINE")
internal actual inline fun Long.rotateRight(n: Int): Long = kRotateRight(n)
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
* Copyright (c) 2024 Matthew Nelson
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
**/
@file:Suppress("KotlinRedundantDiagnosticSuppress")

package org.kotlincrypto.hash.sha2.internal

@Suppress("NOTHING_TO_INLINE")
internal actual inline fun Long.rotateRight(n: Int): Long = (this ushr n) or (this shl (64 - n))
Loading

0 comments on commit ebd0eb1

Please sign in to comment.