Skip to content

Commit

Permalink
benchmarks: optimize sm3
Browse files Browse the repository at this point in the history
This commit optimizes the RV32 SM3 implementation to yield a speedup of
about 2.6x the original implementation.

The RV64 version is faster now too, but the toolchain seems to be
broken as grev and rol don't seem to compile. So it's still a lot slower.
  • Loading branch information
HCPauKaifler committed Feb 12, 2021
1 parent 53cbf4f commit 74c2612
Show file tree
Hide file tree
Showing 4 changed files with 243 additions and 143 deletions.
123 changes: 78 additions & 45 deletions benchmarks/sm3/reference/sm3.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#define SM3_BLOCK_SIZE (16 * sizeof(uint32_t))

// Reverses the byte order of `V`
#define REVERSE_BITS_32(V) \
#define REVERSE_BYTES_32(V) \
(((V & 0x000000FF) << 24) | (((V)&0x0000FF00) << 8) | \
(((V)&0x00FF0000) >> 8) | (((V)&0xFF000000) >> 24))

Expand All @@ -19,54 +19,83 @@
#define SM3_P0(X) ((X) ^ SM3_ROTATE_32((X), 9) ^ SM3_ROTATE_32((X), 17))
#define SM3_P1(X) ((X) ^ SM3_ROTATE_32((X), 15) ^ SM3_ROTATE_32((X), 23))

// Expands the state `s` to `w`
static void sm3_expand(uint32_t w[68], uint32_t s[24]) {
for (int i = 0; i < 16; ++i) {
w[i] = REVERSE_BITS_32(s[i + 8]);
// Expands state values and returns the result
#define SM3_EXPAND_STEP(W0, W3, W7, W10, W13) \
(SM3_P1((W0) ^ (W7) ^ SM3_ROTATE_32((W13), 15)) ^ SM3_ROTATE_32((W3), 7) ^ \
(W10))

// Performs a compression step with permutation constant T, iteration I
// and expanded words W1 and W2
#define SM3_COMPRESS_STEP(I, W1, W2) \
{ \
uint32_t t = (I) < 16 ? 0x79CC4519 : 0x7A879D8A; \
uint32_t rot = SM3_ROTATE_32(x[0], 12); \
uint32_t ss1 = SM3_ROTATE_32(rot + x[4] + SM3_ROTATE_32(t, (I)), 7); \
\
uint32_t tt1, tt2; \
/* optimized out by the compiler */ \
if ((I) < 16) { \
tt1 = (x[0] ^ x[1] ^ x[2]) + x[3] + (ss1 ^ rot) + ((W1) ^ (W2)); \
tt2 = (x[4] ^ x[5] ^ x[6]) + x[7] + ss1 + (W1); \
} else { \
tt1 = ((x[0] & x[1]) | (x[0] & x[2]) | (x[1] & x[2])) + x[3] + \
(ss1 ^ rot) + ((W1) ^ (W2)); \
tt2 = ((x[4] & x[5]) | (~x[4] & x[6])) + x[7] + ss1 + (W1); \
} \
\
x[3] = x[2]; \
x[2] = SM3_ROTATE_32(x[1], 9); \
x[1] = x[0]; \
x[0] = tt1; \
x[7] = x[6]; \
x[6] = SM3_ROTATE_32(x[5], 19); \
x[5] = x[4]; \
x[4] = SM3_P0(tt2); \
}

for (int i = 16; i < 68; ++i) {
w[i] = SM3_P1(w[i - 16] ^ w[i - 9] ^ SM3_ROTATE_32(w[i - 3], 15)) ^
SM3_ROTATE_32(w[i - 13], 7) ^ w[i - 6];
}
}

// Compresses `s` in place
static void sm3_compress(uint32_t s[24]) {
uint32_t w[68];
sm3_expand(w, s);

// The IV and iteration state
uint32_t x[8];
memcpy(x, s, 8 * sizeof(uint32_t));

// The state update transformation below uses and modifies `x`
// depending on the expansion `w` and the current iteration `i`
for (int i = 0; i < 64; ++i) {
// The round constant `t` provides additional randomness
uint32_t t = (i < 16) ? 0x79CC4519 : 0x7A879D8A;
uint32_t rot = SM3_ROTATE_32(x[0], 12);
uint32_t ss1 = SM3_ROTATE_32(rot + x[4] + SM3_ROTATE_32(t, i % 32), 7);
uint32_t ss2 = ss1 ^ rot;
uint32_t w_i = w[i] ^ w[i + 4];

uint32_t tt1, tt2;
if (i < 16) {
tt1 = (x[0] ^ x[1] ^ x[2]) + x[3] + ss2 + w_i;
tt2 = (x[4] ^ x[5] ^ x[6]) + x[7] + ss1 + w[i];
} else {
tt1 = ((x[0] & x[1]) | (x[0] & x[2]) | (x[1] & x[2])) + x[3] + ss2 + w_i;
tt2 = ((x[4] & x[5]) | (~x[4] & x[6])) + x[7] + ss1 + w[i];
for (int i = 0; i < 8; ++i) {
x[i] = s[i];
}

// `w` contains 16 of the expanded words.
uint32_t w[16];
for (int i = 0; i < 16; ++i) {
w[i] = REVERSE_BYTES_32(s[i + 8]);
}

// Compress first 12 words.
for (int i = 0; i < 12; ++i) {
SM3_COMPRESS_STEP(i, w[i], w[i + 4]);
}
// Compress and expand the remaining 4 words.
for (int i = 0; i < 4; ++i) {
w[i] =
SM3_EXPAND_STEP(w[i], w[3 + i], w[7 + i], w[10 + i], w[(13 + i) % 16]);
SM3_COMPRESS_STEP(i + 12, w[i + 12], w[i]);
}

// Rounds 16 to 64
for (int j = 16; j < 64; j += 16) {
// Expand and then compress the first 12 words as the remaining 4 need to be
// handled differently in this implementation.
for (int i = 0; i < 12; ++i) {
w[4 + i] = SM3_EXPAND_STEP(w[4 + i], w[(7 + i) % 16], w[(11 + i) % 16],
w[(14 + i) % 16], w[(1 + i) % 16]);
}
for (int i = 0; i < 12; ++i) {
SM3_COMPRESS_STEP(i + j, w[i], w[i + 4]);
}

x[3] = x[2];
x[2] = SM3_ROTATE_32(x[1], 9);
x[1] = x[0];
x[0] = tt1;
x[7] = x[6];
x[6] = SM3_ROTATE_32(x[5], 19);
x[5] = x[4];
x[4] = SM3_P0(tt2);
// Now expand and compress the remaining 4 words.
for (int i = 0; i < 4; ++i) {
w[i] = SM3_EXPAND_STEP(w[i], w[3 + i], w[7 + i], w[10 + i],
w[(13 + i) % 16]);
SM3_COMPRESS_STEP(i + j + 12, w[i + 12], w[i]);
}
}

// Xor `s` with `x`
Expand All @@ -87,14 +116,18 @@ void sm3_hash(uint8_t hash[32], const uint8_t *message, size_t len) {

// Hash complete blocks first
while (remaining >= SM3_BLOCK_SIZE) {
memcpy(&s[8], m, SM3_BLOCK_SIZE);
for (int i = 0; i < SM3_BLOCK_SIZE; ++i) {
b[i] = m[i];
}
sm3_compress(s);
remaining -= SM3_BLOCK_SIZE;
m += SM3_BLOCK_SIZE;
}

// Hash the last block with padding
memcpy(b, m, remaining);
for (int i = 0; i < remaining; ++i) {
b[i] = m[i];
}
// Append bit 1 after the message
b[remaining] = 0b10000000;
++remaining;
Expand All @@ -107,8 +140,8 @@ void sm3_hash(uint8_t hash[32], const uint8_t *message, size_t len) {
memset(&b[remaining], 0x00, SM3_BLOCK_SIZE - 8 - remaining);
// Append the length of the message in bits
uint64_t bitlen = 8 * (uint64_t)len;
s[22] = REVERSE_BITS_32((uint32_t)(bitlen >> 32));
s[23] = REVERSE_BITS_32((uint32_t)bitlen);
s[22] = REVERSE_BYTES_32((uint32_t)(bitlen >> 32));
s[23] = REVERSE_BYTES_32((uint32_t)bitlen);
sm3_compress(s);

// stores `s` in `hash` in big-endian
Expand Down
130 changes: 81 additions & 49 deletions benchmarks/sm3/zscrypto_rv32/sm3.c
Original file line number Diff line number Diff line change
Expand Up @@ -2,72 +2,100 @@
#include <stdio.h>
#include <string.h>

#include "riscvcrypto/sm3/api_sm3.h"
#include "riscvcrypto/share/riscv-crypto-intrinsics.h"
#include "riscvcrypto/sm3/api_sm3.h"
#include "rvintrin.h"

// The block size in bytes
#define SM3_BLOCK_SIZE (16 * sizeof(uint32_t))

// Reverses the byte order of `V`
#define REVERSE_BITS_32(V) \
(((V & 0x000000FF) << 24) | (((V)&0x0000FF00) << 8) | \
(((V)&0x00FF0000) >> 8) | (((V)&0xFF000000) >> 24))
#define REVERSE_BYTES_32(V) (_rv32_grev((V), 0x18))

// Rotates `V` by `N` bits to the left
#define SM3_ROTATE_32(V, N) (((V) << (N)) | ((V) >> (32 - (N))))
#define SM3_ROTATE_32(V, N) (_rv32_rol((V), (N)))

// The two permutation functions
#define SM3_P0(X) _sm3p0((X))
#define SM3_P1(X) _sm3p1((X))

// Expands the state `s` to `w`
static void sm3_expand(uint32_t w[68], uint32_t s[24]) {
for (int i = 0; i < 16; ++i) {
w[i] = REVERSE_BITS_32(s[i + 8]);
}

for (int i = 16; i < 68; ++i) {
w[i] = SM3_P1(w[i - 16] ^ w[i - 9] ^ SM3_ROTATE_32(w[i - 3], 15)) ^
SM3_ROTATE_32(w[i - 13], 7) ^ w[i - 6];
// Expands state values and returns the result
#define SM3_EXPAND_STEP(W0, W3, W7, W10, W13) \
(SM3_P1((W0) ^ (W7) ^ SM3_ROTATE_32((W13), 15)) ^ SM3_ROTATE_32((W3), 7) ^ \
(W10))

// Performs a compression step with permutation constant T, iteration I
// and expanded words W1 and W2
#define SM3_COMPRESS_STEP(I, W1, W2) \
{ \
uint32_t t = (I) < 16 ? 0x79CC4519 : 0x7A879D8A; \
uint32_t rot = SM3_ROTATE_32(x[0], 12); \
uint32_t ss1 = SM3_ROTATE_32(rot + x[4] + SM3_ROTATE_32(t, (I)), 7); \
\
uint32_t tt1, tt2; \
/* optimized out by the compiler */ \
if ((I) < 16) { \
tt1 = (x[0] ^ x[1] ^ x[2]) + x[3] + (ss1 ^ rot) + ((W1) ^ (W2)); \
tt2 = (x[4] ^ x[5] ^ x[6]) + x[7] + ss1 + (W1); \
} else { \
tt1 = ((x[0] & x[1]) | (x[0] & x[2]) | (x[1] & x[2])) + x[3] + \
(ss1 ^ rot) + ((W1) ^ (W2)); \
tt2 = ((x[4] & x[5]) | (~x[4] & x[6])) + x[7] + ss1 + (W1); \
} \
\
x[3] = x[2]; \
x[2] = SM3_ROTATE_32(x[1], 9); \
x[1] = x[0]; \
x[0] = tt1; \
x[7] = x[6]; \
x[6] = SM3_ROTATE_32(x[5], 19); \
x[5] = x[4]; \
x[4] = SM3_P0(tt2); \
}
}

// Compresses `s` in place
static void sm3_compress(uint32_t s[24]) {
uint32_t w[68];
sm3_expand(w, s);

// The IV and iteration state
uint32_t x[8];
memcpy(x, s, 8 * sizeof(uint32_t));

// The state update transformation below uses and modifies `x`
// depending on the expansion `w` and the current iteration `i`
for (int i = 0; i < 64; ++i) {
// The round constant `t` provides additional randomness
uint32_t t = (i < 16) ? 0x79CC4519 : 0x7A879D8A;
uint32_t rot = SM3_ROTATE_32(x[0], 12);
uint32_t ss1 = SM3_ROTATE_32(rot + x[4] + SM3_ROTATE_32(t, i % 32), 7);
uint32_t ss2 = ss1 ^ rot;
uint32_t w_i = w[i] ^ w[i + 4];

uint32_t tt1, tt2;
if (i < 16) {
tt1 = (x[0] ^ x[1] ^ x[2]) + x[3] + ss2 + w_i;
tt2 = (x[4] ^ x[5] ^ x[6]) + x[7] + ss1 + w[i];
} else {
tt1 = ((x[0] & x[1]) | (x[0] & x[2]) | (x[1] & x[2])) + x[3] + ss2 + w_i;
tt2 = ((x[4] & x[5]) | (~x[4] & x[6])) + x[7] + ss1 + w[i];
for (int i = 0; i < 8; ++i) {
x[i] = s[i];
}

// `w` contains 16 of the expanded words.
uint32_t w[16];
for (int i = 0; i < 16; ++i) {
w[i] = REVERSE_BYTES_32(s[i + 8]);
}

// Compress first 12 words.
for (int i = 0; i < 12; ++i) {
SM3_COMPRESS_STEP(i, w[i], w[i + 4]);
}
// Compress and expand the remaining 4 words.
for (int i = 0; i < 4; ++i) {
w[i] =
SM3_EXPAND_STEP(w[i], w[3 + i], w[7 + i], w[10 + i], w[(13 + i) % 16]);
SM3_COMPRESS_STEP(i + 12, w[i + 12], w[i]);
}

// Rounds 16 to 64
for (int j = 16; j < 64; j += 16) {
// Expand and then compress the first 12 words as the remaining 4 need to be
// handled differently in this implementation.
for (int i = 0; i < 12; ++i) {
w[4 + i] = SM3_EXPAND_STEP(w[4 + i], w[(7 + i) % 16], w[(11 + i) % 16],
w[(14 + i) % 16], w[(1 + i) % 16]);
}
for (int i = 0; i < 12; ++i) {
SM3_COMPRESS_STEP(i + j, w[i], w[i + 4]);
}

x[3] = x[2];
x[2] = SM3_ROTATE_32(x[1], 9);
x[1] = x[0];
x[0] = tt1;
x[7] = x[6];
x[6] = SM3_ROTATE_32(x[5], 19);
x[5] = x[4];
x[4] = SM3_P0(tt2);
// Now expand and compress the remaining 4 words.
for (int i = 0; i < 4; ++i) {
w[i] = SM3_EXPAND_STEP(w[i], w[3 + i], w[7 + i], w[10 + i],
w[(13 + i) % 16]);
SM3_COMPRESS_STEP(i + j + 12, w[i + 12], w[i]);
}
}

// Xor `s` with `x`
Expand All @@ -88,14 +116,18 @@ void sm3_hash(uint8_t hash[32], const uint8_t *message, size_t len) {

// Hash complete blocks first
while (remaining >= SM3_BLOCK_SIZE) {
memcpy(&s[8], m, SM3_BLOCK_SIZE);
for (int i = 0; i < SM3_BLOCK_SIZE; ++i) {
b[i] = m[i];
}
sm3_compress(s);
remaining -= SM3_BLOCK_SIZE;
m += SM3_BLOCK_SIZE;
}

// Hash the last block with padding
memcpy(b, m, remaining);
for (int i = 0; i < remaining; ++i) {
b[i] = m[i];
}
// Append bit 1 after the message
b[remaining] = 0b10000000;
++remaining;
Expand All @@ -108,8 +140,8 @@ void sm3_hash(uint8_t hash[32], const uint8_t *message, size_t len) {
memset(&b[remaining], 0x00, SM3_BLOCK_SIZE - 8 - remaining);
// Append the length of the message in bits
uint64_t bitlen = 8 * (uint64_t)len;
s[22] = REVERSE_BITS_32((uint32_t)(bitlen >> 32));
s[23] = REVERSE_BITS_32((uint32_t)bitlen);
s[22] = REVERSE_BYTES_32((uint32_t)(bitlen >> 32));
s[23] = REVERSE_BYTES_32((uint32_t)bitlen);
sm3_compress(s);

// stores `s` in `hash` in big-endian
Expand Down
Loading

0 comments on commit 74c2612

Please sign in to comment.