diff --git a/example-client.c b/example-client.c index 3a86e2f..a70f1aa 100644 --- a/example-client.c +++ b/example-client.c @@ -1,7 +1,7 @@ /* - * openssl: gcc example-client.c openssl.c -lssl -lcrypto - * wolfssl: gcc example-client.c openssl.c -lwolfssl -DHAVE_WOLFSSL - * mbedtls: gcc example-client.c mbedtls.c -lmbedtls -lmbedcrypto -lmbedx509 + * openssl: gcc example-client.c openssl.c -lssl -lcrypto -o client + * wolfssl: gcc example-client.c openssl.c -lwolfssl -DHAVE_WOLFSSL -o client + * mbedtls: gcc example-client.c mbedtls.c -lmbedtls -lmbedcrypto -lmbedx509 -o client */ #include @@ -12,16 +12,12 @@ #include #include #include +#include #include #include #include -#include "ssl.h" - -static void on_verify_error(int error, const char *str, void *arg) -{ - fprintf(stderr, "WARNING: SSL certificate error(%d): %s\n", error, str); -} +#include "example.h" static void chat(void *ssl, int sock) { @@ -47,32 +43,29 @@ static void chat(void *ssl, int sock) if (FD_ISSET(STDIN_FILENO, &rfds)) { int n = read(STDIN_FILENO, buf, sizeof(buf)); - do { - ret = ssl_write(ssl, buf, n); - if (ret < 0) { - fprintf(stderr, "ssl_write: %s\n", ssl_last_error_string(err_buf, sizeof(err_buf))); - return; - } - } while (ret == 0); + ret = ssl_write_nonblock(ssl, sock, buf, n); + if (ret < 0) + return; printf("Send: %.*s\n", ret, buf); } else if (FD_ISSET(sock, &rfds)) { - ret = ssl_read(ssl, buf, sizeof(buf)); + bool closed; + ret = ssl_read_nonblock(ssl, sock, buf, sizeof(buf), &closed); if (ret < 0) { - fprintf(stderr, "ssl_read: %s\n", ssl_last_error_string(err_buf, sizeof(err_buf))); ssl_session_free(ssl); close(sock); return; } - if (ret == 0) { + if (closed) { fprintf(stderr, "Connection closed by peer\n"); ssl_session_free(ssl); close(sock); return; } - printf("Recv: %.*s\n", ret, buf); + if (ret > 0) + printf("Recv: %.*s\n", ret, buf); } } } @@ -100,14 +93,19 @@ static void *connect_ssl(int sock, const char *host) ssl_set_server_name(ssl, host); - do { + while (true) { ret = ssl_connect(ssl, on_verify_error, NULL); + if (ret == SSL_OK) + break; if (ret == SSL_ERROR) { fprintf(stderr, "ssl_connect: %s\n", ssl_last_error_string(err_buf, sizeof(err_buf))); return NULL; } - } while (ret == SSL_PENDING); + + if (ssl_select(sock, ret)) + return NULL; + } printf("SSL negotiation OK\n"); @@ -140,6 +138,11 @@ static bool wait_connect(int sock) return false; } + if (err) { + fprintf(stderr, "connect: %s\n", strerror(err)); + return false; + } + return true; } diff --git a/example-server.c b/example-server.c index bb3ef84..c2d3ce4 100644 --- a/example-server.c +++ b/example-server.c @@ -1,7 +1,7 @@ /* - * openssl: gcc example-server.c openssl.c -lssl -lcrypto - * wolfssl: gcc example-server.c openssl.c -lwolfssl -DHAVE_WOLFSSL - * mbedtls: gcc example-server.c mbedtls.c -lmbedtls -lmbedcrypto -lmbedx509 + * openssl: gcc example-server.c openssl.c -lssl -lcrypto -o server + * wolfssl: gcc example-server.c openssl.c -lwolfssl -DHAVE_WOLFSSL -o server + * mbedtls: gcc example-server.c mbedtls.c -lmbedtls -lmbedcrypto -lmbedx509 -o server */ #define _GNU_SOURCE @@ -17,15 +17,10 @@ #include #include -#include "ssl.h" +#include "example.h" static struct ssl_context *ctx; -static void on_verify_error(int error, const char *str, void *arg) -{ - fprintf(stderr, "WARNING: SSL certificate error(%d): %s\n", error, str); -} - static void chat(void *ssl, int sock) { char err_buf[128]; @@ -50,32 +45,29 @@ static void chat(void *ssl, int sock) if (FD_ISSET(STDIN_FILENO, &rfds)) { int n = read(STDIN_FILENO, buf, sizeof(buf)); - do { - ret = ssl_write(ssl, buf, n); - if (ret < 0) { - fprintf(stderr, "ssl_write: %s\n", ssl_last_error_string(err_buf, sizeof(err_buf))); - return; - } - } while (ret == 0); + ret = ssl_write_nonblock(ssl, sock, buf, n); + if (ret < 0) + return; printf("Send: %.*s\n", ret, buf); } else if (FD_ISSET(sock, &rfds)) { - ret = ssl_read(ssl, buf, sizeof(buf)); + bool closed; + ret = ssl_read_nonblock(ssl, sock, buf, sizeof(buf), &closed); if (ret < 0) { - fprintf(stderr, "ssl_read error: %s\n", ssl_last_error_string(err_buf, sizeof(err_buf))); ssl_session_free(ssl); close(sock); return; } - if (ret == 0) { + if (closed) { fprintf(stderr, "Connection closed by peer\n"); ssl_session_free(ssl); close(sock); return; } - printf("Recv: %.*s\n", ret, buf); + if (ret > 0) + printf("Recv: %.*s\n", ret, buf); } } } @@ -94,14 +86,19 @@ static void *ssl_negotiation(int sock) printf("Wait SSL negotiation...\n"); - do { + while (true) { ret = ssl_accept(ssl, on_verify_error, NULL); + if (ret == SSL_OK) + break; if (ret == SSL_ERROR) { fprintf(stderr, "ssl_connect: %s\n", ssl_last_error_string(err_buf, sizeof(err_buf))); return NULL; } - } while (ret == SSL_PENDING); + + if (ssl_select(sock, ret)) + return NULL; + } printf("SSL negotiation OK\n"); @@ -137,8 +134,15 @@ int main(int argc, char **argv) ctx = ssl_context_new(true); - ssl_load_crt_file(ctx, "example.crt"); - ssl_load_key_file(ctx, "example.key"); + if (ssl_load_crt_file(ctx, "example.crt")) { + fprintf(stderr, "ssl_load_crt_file fail\n"); + return -1; + } + + if (ssl_load_key_file(ctx, "example.key")) { + fprintf(stderr, "ssl_load_key_file fail\n"); + return -1; + } printf("Wait connect...\n"); diff --git a/example.crt b/example.crt new file mode 100644 index 0000000..7b68b55 --- /dev/null +++ b/example.crt @@ -0,0 +1,21 @@ +-----BEGIN CERTIFICATE----- +MIIDbzCCAlegAwIBAgIUdHJODTH4Ym0AAUyM54BqJBQZNIAwDQYJKoZIhvcNAQEL +BQAwRzELMAkGA1UEBhMCQ1oxEjAQBgNVBAoMCUFjbWUgSW5jLjENMAsGA1UECwwE +QUNNRTEVMBMGA1UEAwwMQUNNRS1ERVYtMTIzMB4XDTIyMDQyOTE1NDUzN1oXDTIy +MDUyOTE1NDUzN1owRzELMAkGA1UEBhMCQ1oxEjAQBgNVBAoMCUFjbWUgSW5jLjEN +MAsGA1UECwwEQUNNRTEVMBMGA1UEAwwMQUNNRS1ERVYtMTIzMIIBIjANBgkqhkiG +9w0BAQEFAAOCAQ8AMIIBCgKCAQEA6udad6dBQwX1GW2z+7yPX7Aq8Rkoi7BFIdmQ +px3w/5t35FQtE+O/+4GsIPXIXEq0Q0vezZRCkRe8Jdona5blqZv8s2tP2rPoI3Sk +YE/0kAvAMA18Xh4VyGw5xlFIrOa9T1yAhAD3xGD33kS/mhJEDnLn9GCQ7nEwoxoJ +XEZf8/7B0+GxLfoR2VpO+gZNWtsCHYidIY6QU3YvgcvGb2RAKULCQAoY86CKleBd +WO3nVKNOCv1mb1WrqKiUC6kmFACeau6gLywyJXGej8M3+YuozxDirwAHiuFKZItX +CGoaMwZrGBouSCq4T4UJiIjZPaSCRY0B0bB1WDuXSK+JqHFDJQIDAQABo1MwUTAd +BgNVHQ4EFgQU/KbZRAdXeskytl6tj0Gw0kjXVZcwHwYDVR0jBBgwFoAU/KbZRAdX +eskytl6tj0Gw0kjXVZcwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOC +AQEAFknK2HoE86Yxf29mux2MltkBvLsUIgliCVanPUiJOFOjoLPzLnpmyJcTnpdG +W+d9GqLzxfbmtVPI0YzB8uYV2r/Xh3LN5H+UB9v/5IjQcFftTIj3oECGEOQOEu75 +M1o4d/moazbZCauZi63N3GlgdQmIsuAoQy6wy3j8AkxQGt2ptASffx0rDXjS7Z+L +XM3RN0/d/h3iAwboeT5XhUVkEa2zSOw0NmXTCKpNJvPr0ulLLCpSOErD6Jq+sMw5 +En+VZhh4QiK7eqqb/tuQclkcGxcSAIRpVg98+PLnwi0VYR2YPnwxV1SAqMDWKHYH +U3X0GiiOs125w7x4zFV9d+73/Q== +-----END CERTIFICATE----- diff --git a/example.h b/example.h new file mode 100644 index 0000000..fc198d5 --- /dev/null +++ b/example.h @@ -0,0 +1,86 @@ +#ifndef __EXAMPLE_H +#define __EXAMPLE_H + +#include +#include +#include + +#include "ssl.h" + +static int ssl_select(int sock, int ret) +{ + fd_set fds = {}; + + FD_SET(sock, &fds); + + if (ret == SSL_WANT_READ) + ret = select(sock + 1, &fds, NULL, NULL, NULL); + else if (ret == SSL_WANT_WRITE) + ret = select(sock + 1, NULL, &fds, NULL, NULL); + + if (ret < 0) { + perror("select"); + return -1; + } + + return 0; +} + +static void on_verify_error(int error, const char *str, void *arg) +{ + fprintf(stderr, "WARNING: SSL certificate error(%d): %s\n", error, str); +} + +static int ssl_write_nonblock(void *ssl, int sock, void *data, int len) +{ + char err_buf[128]; + fd_set fds; + int ret; + + while (true) { + ret = ssl_write(ssl, data, len); + if (ret == SSL_ERROR) { + fprintf(stderr, "ssl_write: %s\n", ssl_last_error_string(err_buf, sizeof(err_buf))); + return -1; + } + + if (ret > 0) + return ret; + + if (ssl_select(sock, ret)) + return -1; + } +} + +static int ssl_read_nonblock(void *ssl, int sock, void *data, int len, bool *closed) +{ + char err_buf[128]; + fd_set fds; + int ret; + + *closed = false; + + while (true) { + ret = ssl_read(ssl, data, len); + if (ret == SSL_ERROR) { + fprintf(stderr, "ssl_read: %s\n", ssl_last_error_string(err_buf, sizeof(err_buf))); + return -1; + } + + if (ret > 0) + return ret; + + if (ret == 0) { + *closed = true; + return 0; + } + + if (ret == SSL_WANT_READ) + return 0; + + if (ssl_select(sock, ret)) + return -1; + } +} + +#endif diff --git a/example.key b/example.key new file mode 100644 index 0000000..4bd00ba --- /dev/null +++ b/example.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDq51p3p0FDBfUZ +bbP7vI9fsCrxGSiLsEUh2ZCnHfD/m3fkVC0T47/7gawg9chcSrRDS97NlEKRF7wl +2idrluWpm/yza0/as+gjdKRgT/SQC8AwDXxeHhXIbDnGUUis5r1PXICEAPfEYPfe +RL+aEkQOcuf0YJDucTCjGglcRl/z/sHT4bEt+hHZWk76Bk1a2wIdiJ0hjpBTdi+B +y8ZvZEApQsJAChjzoIqV4F1Y7edUo04K/WZvVauoqJQLqSYUAJ5q7qAvLDIlcZ6P +wzf5i6jPEOKvAAeK4Upki1cIahozBmsYGi5IKrhPhQmIiNk9pIJFjQHRsHVYO5dI +r4mocUMlAgMBAAECggEAAb61N9VTZp2OjAv6EyEmntXZWSV63SAdQK/z40yVeTic +7o0c2/HMVSB0owwLBdB86qZkzHnQ+BtJMljJWS3A8qkYz/ZjR03LKCmaJ6hVueo0 +bnGd3mRyfKTSgGF4h5Gb5LvcV96v+H1QlLibBG3P+UbWPUT9s/US9kCKbZfiPOxA +Zv/eEodSjzPkbjkBp54qfH04bndJPPcapSyCbIUwUpfjjz5vgubq4+wtNfH73+Pj +Iaj1N7erOTJ27Dip19rwK+vfIxUMoFiX+IoZy6P6j4ZrmII2d+kkMI0P/4r8SutB +HP/zyXXsVEILhHy+8Mj193Wb1MUtnobj3kXfSbiXowKBgQD717Eqc+z73SotiTvv +09uod5fDEzOsF/i7Isb3uDpEKU0q0yC8A1+KGp6ocFvng0jf8FvM43vRKA3jp1oW +hHPpJBzZOcSFKv6qhT96BnINvPt9G2U7IOcfWF9AQsAcAdeHqRzh1+XdPOOzcjHr +0Q5UQ1c1F2uW14YGAhHU8AWM1wKBgQDuyBORH4Hdh+ncJMBDbro78aRLioQQLwpS +1W2poKspmaiP5/ZgAPFx+Slylm+QOqXgtsLK4jy+cKRxGtH2fEvAex5YSsHWIxYm +XWKzTkcpOQGdLbDlahhDktTym8ezv0//9WWHSqXlweQmN9rz+ySjwLPZrK4CU+tv +L3CgVq0UYwKBgDOWxpsMrkIV1xsG3rlNK9UB0pvKZi5dpr0m7Z03JvBpiX45S55Y +Do0q0M9uXNU7BoWhJhz9iJKa4uV8la7BKUFb/XDeLYyd9xcVPqCPi3OW/+lr4DvR +jKbWIoT2Z4YVNoJ6uQjmghbk7zwGK4XECGxocwfUKVz3/2Nhrydwl6J3AoGAPN4f +rsS7U/9La+SqZgYZzyH/4YnDtGRpW0gwlibwusACqfxVX4+d/JGpMR6L/dYVZrzv +1svo9Bq+sF5H228/2CcKSzNzSeTTxp/TgyWXGjj/4lM9Xp225bLOObHgLD++Yt7p +LJ2owHK2d8+RLtR8OInszrYn/UvrHgKX0SeHI0UCgYEA7YM/WpeB24Wlt2viujEi +EcBx5BLu+uoere7Q9veCXXvbUYm46nkAU0tf2BolGtmDTMdPPf8sVLfwS4xUQiXq +nbe5SHBUHzB/js8FQtnud4SkDB1I/93XoxHsX20quwISQEi146ezTxBdTzC3wDvM +OkkCScqCD+xXUwi1dsbXr3g= +-----END PRIVATE KEY----- diff --git a/mbedtls.c b/mbedtls.c index 4718b29..74238f2 100644 --- a/mbedtls.c +++ b/mbedtls.c @@ -379,16 +379,13 @@ void ssl_set_server_name(void *ssl, const char *name) mbedtls_ssl_set_hostname(ssl, name); } -static bool ssl_do_wait(int ret) -{ - switch(ret) { - case MBEDTLS_ERR_SSL_WANT_READ: - case MBEDTLS_ERR_SSL_WANT_WRITE: - return true; - default: - return false; - } -} +#define ssl_need_retry(ret) \ + do { \ + if (ret == MBEDTLS_ERR_SSL_WANT_READ) \ + return SSL_WANT_READ; \ + else if (ret == MBEDTLS_ERR_SSL_WANT_WRITE) \ + return SSL_WANT_WRITE; \ + } while (0) static void ssl_verify_cert(void *ssl, void (*on_verify_error)(int error, const char *str, void *arg), void *arg) { @@ -424,8 +421,7 @@ static int ssl_handshake(void *ssl, bool server, return SSL_OK; } - if (ssl_do_wait(r)) - return SSL_PENDING; + ssl_need_retry(r); ssl_err_code = r; @@ -453,9 +449,7 @@ int ssl_write(void *ssl, const void *buf, int len) ret = mbedtls_ssl_write(ssl, (const unsigned char *)buf + done, len - done); if (ret < 0) { - if (ssl_do_wait(ret)) - return done; - + ssl_need_retry(ret); ssl_err_code = ret; return -1; } @@ -473,8 +467,7 @@ int ssl_read(void *ssl, void *buf, int len) ssl_err_code = 0; if (ret < 0) { - if (ssl_do_wait(ret)) - return SSL_PENDING; + ssl_need_retry(ret); if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) return 0; diff --git a/openssl.c b/openssl.c index 2f87e7e..8b99c9e 100644 --- a/openssl.c +++ b/openssl.c @@ -371,6 +371,14 @@ static bool handle_wolfssl_asn_error(void *ssl, int r, } #endif +#define ssl_need_retry(ret) \ + do { \ + if (ret == SSL_ERROR_WANT_READ) \ + return SSL_WANT_READ; \ + else if (ret == SSL_ERROR_WANT_WRITE) \ + return SSL_WANT_WRITE; \ + } while (0) + static int ssl_handshake(void *ssl, bool server, void (*on_verify_error)(int error, const char *str, void *arg), void *arg) { @@ -391,8 +399,8 @@ static int ssl_handshake(void *ssl, bool server, } r = SSL_get_error(ssl, r); - if (r == SSL_ERROR_WANT_READ || r == SSL_ERROR_WANT_WRITE) - return SSL_PENDING; + + ssl_need_retry(r); #ifdef WOLFSSL_SSL_H if (handle_wolfssl_asn_error(ssl, r, on_verify_error, arg)) @@ -426,9 +434,7 @@ int ssl_write(void *ssl, const void *buf, int len) if (ret < 0) { ret = SSL_get_error(ssl, ret); - if (ret == SSL_ERROR_WANT_WRITE || ret == SSL_ERROR_WANT_READ) - return SSL_PENDING; - + ssl_need_retry(ret); ssl_err_code = ret; return SSL_ERROR; } @@ -447,9 +453,7 @@ int ssl_read(void *ssl, void *buf, int len) ret = SSL_read(ssl, buf, len); if (ret < 0) { ret = SSL_get_error(ssl, ret); - if (ret == SSL_ERROR_WANT_WRITE || ret == SSL_ERROR_WANT_READ) - return SSL_PENDING; - + ssl_need_retry(ret); ssl_err_code = ret; return SSL_ERROR; } diff --git a/ssl.h b/ssl.h index de4862f..7be74cd 100644 --- a/ssl.h +++ b/ssl.h @@ -30,7 +30,8 @@ enum { SSL_OK = 0, SSL_ERROR = -1, - SSL_PENDING = -2 + SSL_WANT_READ = -2, + SSL_WANT_WRITE = -3 }; struct ssl_context;