diff --git a/communication/src/dsakeygen.cpp b/communication/src/dsakeygen.cpp index ff2a5111bc..07dfe96265 100644 --- a/communication/src/dsakeygen.cpp +++ b/communication/src/dsakeygen.cpp @@ -32,59 +32,97 @@ using namespace std; class RSADERCommon { - - uint8_t* buffer; - int length; +private: + uint8_t* buffer_; + int length_; + int written_; protected: - void write_mpi(const mpi* data) { - int size = mpi_size(data); - if (length>=size) { - mpi_write_binary(data, buffer, size); - length -= size; - buffer += size; + int written() const { + return written_; + } + + void write_length(int len) { + uint8_t len_field = 0x00; + + if (len < 128) { + // Length can be encoded as a single byte + len_field = (uint8_t)len; + write(&len_field, sizeof(len_field)); + } else { + // Length has to be encoded as: + // (0x80 | number of length bytes) (length byte 0) ... (length byte N) + len_field = 0x80; + mpi intlen; + mpi_init(&intlen); + mpi_lset(&intlen, len); + int bytelen = mpi_size(&intlen); + len_field |= (uint8_t)bytelen; + + write(&len_field, sizeof(len_field)); + write_mpi(&intlen); + + mpi_free(&intlen); } } - void write_integer_1024(const mpi* data) { - uint8_t header[] = { 0x02, 0x81, 0x81, 0 }; // padded with an extra 0 byte to ensure number is positive - write(header, sizeof(header)); - write_mpi(data); + void write_mpi(const mpi* data, int fixedLength = -1) { + int len = fixedLength == -1 ? mpi_size(data) : fixedLength; + if (length_ >= len) { + mpi_write_binary(data, buffer_, len); + write(nullptr, len); + } } - void write_integer_512(const mpi* data) { - uint8_t header[] = { 0x02, 0x41, 0 }; // padded with an extra 0 byte to ensure number is positive - write(header, sizeof(header)); - write_mpi(data); + void write_integer_immediate(int value, int fixedLength = -1) { + mpi integer; + mpi_init(&integer); + mpi_lset(&integer, value); + write_integer(&integer, fixedLength); + mpi_free(&integer); } - void write_public_exponent() { - uint8_t data[] = { 2, 3, 1, 0, 1 }; - write(data, 5); + void write_integer(const mpi* data, int fixedLength = -1) { + int len = fixedLength == -1 ? mpi_size(data) : fixedLength; + uint8_t tmp = 0x02; // INTEGER + + // Write type + write(&tmp, sizeof(tmp)); + + write_length(len); + + write_mpi(data, fixedLength); } void write(const uint8_t* data, size_t length) { - while (length && this->length) { - *this->buffer++ = *data++; - length--; - this->length--; + while (length && length_ > 0) { + if (buffer_) { + if (data) + *buffer_++ = *data++; + else + buffer_++; + } + length--; + length_--; + written_++; } } void set_buffer(void* buffer, size_t size) { - this->buffer = (uint8_t*)buffer; - this->length = size; + buffer_ = (uint8_t*)buffer; + length_ = size; + written_ = 0; } }; class RSAPrivateKeyWriter : RSADERCommon { - void write_sequence_header() { - // sequence tag and version - uint8_t header[] = { 0x30, 0x82, 0x2, 0x5F, 2, 1, 0 }; - write(header, sizeof(header)); + void write_sequence_header(int len) { + uint8_t tag = 0x30; // SEQUENCE + write(&tag, sizeof(tag)); + write_length(len); } public: @@ -100,16 +138,26 @@ class RSAPrivateKeyWriter : RSADERCommon { * @param QP */ void write_private_key(uint8_t* buf, size_t length, rsa_context& ctx) { + // Dummy run to calculate sequence length + set_buffer(nullptr, length); + write_private_key_parts(ctx); + int seq_len = written(); + set_buffer(buf, length); - write_sequence_header(); - write_integer_1024(&ctx.N); - write_public_exponent(); // 5 - write_integer_1024(&ctx.D); // 132 * 2 - write_integer_512(&ctx.P); // 67 * 5 - write_integer_512(&ctx.Q); - write_integer_512(&ctx.DP); - write_integer_512(&ctx.DQ); - write_integer_512(&ctx.QP); // total is 604 bytes, 0x25C + write_sequence_header(seq_len); + write_private_key_parts(ctx); + } + + void write_private_key_parts(rsa_context& ctx) { + write_integer_immediate(0, 1); + write_integer(&ctx.N, 129); + write_integer(&ctx.E, 3); + write_integer(&ctx.D, 129); + write_integer(&ctx.P, 65); + write_integer(&ctx.Q, 65); + write_integer(&ctx.DP, 65); + write_integer(&ctx.DQ, 65); + write_integer(&ctx.QP, 65); } }; diff --git a/user/tests/app/rsakeygen/rsakeygen.cpp b/user/tests/app/rsakeygen/rsakeygen.cpp new file mode 100644 index 0000000000..46f31764ec --- /dev/null +++ b/user/tests/app/rsakeygen/rsakeygen.cpp @@ -0,0 +1,47 @@ +#include "application.h" +#include "ota_flash_hal.h" +#include "dsakeygen.h" +#include "dct.h" + +static int s_loops = 1000; + +static uint8_t privkey[EXTERNAL_FLASH_CORE_PRIVATE_KEY_LENGTH]; +static uint8_t pubkey[DCT_DEVICE_PUBLIC_KEY_SIZE]; + +static const char c_hexmap[] = { + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', + 'A', 'B', 'C', 'D', 'E', 'F' +}; + +void printHex(const uint8_t* buf, size_t size) { + for (int i = 0; i < size; i++) { + Serial.write(c_hexmap[(buf[i] & 0xF0) >> 4]); + Serial.write(c_hexmap[buf[i] & 0x0F]); + } +} + +int key_gen_random(void* p) { + return (int)HAL_RNG_GetRandomNumber(); +} + +/* executes once at startup */ +void setup() { + Serial.begin(57600); + + int error = 1; + while (s_loops > 0) { + memset(pubkey, 0, sizeof(pubkey)); + memset(privkey, 0, sizeof(privkey)); + error = gen_rsa_key(privkey, EXTERNAL_FLASH_CORE_PRIVATE_KEY_LENGTH, key_gen_random, NULL); + if (!error) { + Serial.print("keys:"); + extract_public_rsa_key(pubkey, privkey); + printHex(privkey, sizeof(privkey)); + Serial.print(":"); + printHex(pubkey, sizeof(pubkey)); + Serial.print("\n"); + } + s_loops--; + } + Serial.print("done"); +} diff --git a/user/tests/app/rsakeygen/verify_keys.py b/user/tests/app/rsakeygen/verify_keys.py new file mode 100755 index 0000000000..c097c0422a --- /dev/null +++ b/user/tests/app/rsakeygen/verify_keys.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python +import sys +import os +import subprocess + +def usage(): + print '%s [input_file or tty]' % (sys.argv[0]) + +if len(sys.argv) != 2: + usage() + sys.exit(1) + +f = open(sys.argv[1], 'r') +i = -1 + +def hex2bin(hex): + return hex.decode('hex') + +def run(cmd, stdin): + p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE) + out, err = p.communicate(stdin) + return (out.strip(), err.strip()) + +for l in f: + l = l.strip() + if l.startswith('done'): + print 'Done' + sys.exit(0) + + if l.startswith('keys:'): + l = l[5:] + else: + continue + priv, pub = l.split(":") + i += 1 + privb = hex2bin(priv) + pubb = hex2bin(pub) + + privcheck = run(['openssl', 'rsa', '-inform', 'DER', '-noout', '-check'], privb) + privmod = run(['openssl', 'rsa', '-inform', 'DER', '-noout', '-modulus'], privb) + pubmod = run(['openssl', 'rsa', '-pubin', '-inform', 'DER', '-noout', '-modulus'], pubb) + + if not privcheck[1] and (privmod[0] == pubmod[0]) and not privmod[1] and not pubmod[1]: + print '%d:OK' % (i, ) + else: + print '%d:FAIL' % (i, ) + sys.stderr.write(privcheck[0]) + sys.stderr.write(privcheck[1]) + sys.stderr.write(privmod[0]) + sys.stderr.write(privmod[1]) + sys.stderr.write(pubmod[0]) + sys.stderr.write(pubmod[1]) + +print 'Done' +sys.exit(0)