From 4f970ad974572618dad00d9ff8f4f90acbc2f316 Mon Sep 17 00:00:00 2001 From: M Hightower <27247790+mhightower83@users.noreply.github.com> Date: Mon, 7 Oct 2019 10:45:55 -0700 Subject: [PATCH] Corrected missed edit for _recv() in ASYNC_TCP_SSL_ENABLED in AsyncServer::_poll(). And other missed edit for errorTracker around ASYNC_TCP_SSL_ENABLED. This should resolve @kasedy comment https://github.com/me-no-dev/ESPAsyncTCP/pull/115#issuecomment-538816623 and @mcspr. Tested ASYNC_TCP_SSL_ENABLED using marvinroger/async-mqtt-client/ .. examples/FullyFeaturedSSL. Ran test against test.mosquitto.org's server. Thanks to @mcspr for suggesting. Updated tcp_ssl_read() to check for fd_data being freed by callback functions. I observed this with asyncmqttclient example. When finger print did not match during fd_data->on_handshake callback, the mqtt library did a close(true) which rippled down to an tcp_ssl_free(). Improvements in debug printing to handle debug print from tcp.axtls.c. --- src/DebugPrintMacros.h | 50 ++++++++++++++++++++++++++++-------------- src/ESPAsyncTCP.cpp | 41 ++++++++++++++++++++++++++-------- src/async_config.h | 6 ++++- src/tcp_axtls.c | 45 ++++++++++++++++++++++++++++--------- 4 files changed, 105 insertions(+), 37 deletions(-) diff --git a/src/DebugPrintMacros.h b/src/DebugPrintMacros.h index 29accaf..30ed706 100644 --- a/src/DebugPrintMacros.h +++ b/src/DebugPrintMacros.h @@ -27,18 +27,34 @@ inline struct _DEBUG_TIME_STAMP debugTimeStamp(void) { } #endif +#if defined(DEBUG_ESP_PORT) && !defined(DEBUG_ESP_PORT_PRINTF) + +#ifdef __cplusplus +#define DEBUG_ESP_PORT_PRINTF(format, ...) DEBUG_ESP_PORT.printf((format), ##__VA_ARGS__) +#define DEBUG_ESP_PORT_PRINTF_F(format, ...) DEBUG_ESP_PORT.printf_P(PSTR(format), ##__VA_ARGS__) +#define DEBUG_ESP_PORT_FLUSH DEBUG_ESP_PORT.flush +#else +// Handle debug printing from .c without CPP Stream, Print, ... classes +// Cannot handle flash strings in this setting +#define DEBUG_ESP_PORT_PRINTF ets_uart_printf +#define DEBUG_ESP_PORT_PRINTF_F ets_uart_printf +#define DEBUG_ESP_PORT_FLUSH (void)0 +#endif + +#endif + #if defined(DEBUG_ESP_PORT) && !defined(DEBUG_GENERIC) #define DEBUG_GENERIC( module, format, ... ) \ do { \ struct _DEBUG_TIME_STAMP st = debugTimeStamp(); \ - DEBUG_ESP_PORT.printf( DEBUG_TIME_STAMP_FMT module " " format, st.whole, st.dec, ##__VA_ARGS__ ); \ + DEBUG_ESP_PORT_PRINTF( (DEBUG_TIME_STAMP_FMT module " " format), st.whole, st.dec, ##__VA_ARGS__ ); \ } while(false) #endif -#if defined(DEBUG_ESP_PORT) && !defined(DEBUG_GENERIC_P) - #define DEBUG_GENERIC_P( module, format, ... ) \ +#if defined(DEBUG_ESP_PORT) && !defined(DEBUG_GENERIC_F) + #define DEBUG_GENERIC_F( module, format, ... ) \ do { \ struct _DEBUG_TIME_STAMP st = debugTimeStamp(); \ - DEBUG_ESP_PORT.printf_P(PSTR( DEBUG_TIME_STAMP_FMT module " " format ), st.whole, st.dec, ##__VA_ARGS__ ); \ + DEBUG_ESP_PORT_PRINTF_F( (DEBUG_TIME_STAMP_FMT module " " format), st.whole, st.dec, ##__VA_ARGS__ ); \ } while(false) #endif @@ -47,16 +63,16 @@ inline struct _DEBUG_TIME_STAMP debugTimeStamp(void) { do { \ if ( !(a) ) { \ DEBUG_GENERIC( module, "%s:%s:%u: ASSERT("#a") failed!\n", __FILE__, __func__, __LINE__); \ - DEBUG_ESP_PORT.flush(); \ + DEBUG_ESP_PORT_FLUSH(); \ } \ } while(false) #endif -#if defined(DEBUG_GENERIC_P) && !defined(ASSERT_GENERIC_P) -#define ASSERT_GENERIC_P( a, module ) \ +#if defined(DEBUG_GENERIC_F) && !defined(ASSERT_GENERIC_F) +#define ASSERT_GENERIC_F( a, module ) \ do { \ if ( !(a) ) { \ - DEBUG_GENERIC_P( module, "%s:%s:%u: ASSERT("#a") failed!\n", __FILE__, __func__, __LINE__); \ - DEBUG_ESP_PORT.flush(); \ + DEBUG_GENERIC_F( module, "%s:%s:%u: ASSERT("#a") failed!\n", __FILE__, __func__, __LINE__); \ + DEBUG_ESP_PORT_FLUSH(); \ } \ } while(false) #endif @@ -65,32 +81,32 @@ inline struct _DEBUG_TIME_STAMP debugTimeStamp(void) { #define DEBUG_GENERIC(...) do { (void)0;} while(false) #endif -#ifndef DEBUG_GENERIC_P -#define DEBUG_GENERIC_P(...) do { (void)0;} while(false) +#ifndef DEBUG_GENERIC_F +#define DEBUG_GENERIC_F(...) do { (void)0;} while(false) #endif #ifndef ASSERT_GENERIC #define ASSERT_GENERIC(...) do { (void)0;} while(false) #endif -#ifndef ASSERT_GENERIC_P -#define ASSERT_GENERIC_P(...) do { (void)0;} while(false) +#ifndef ASSERT_GENERIC_F +#define ASSERT_GENERIC_F(...) do { (void)0;} while(false) #endif #ifndef DEBUG_ESP_PRINTF -#define DEBUG_ESP_PRINTF( format, ...) DEBUG_GENERIC_P("[%s]", format, &_FILENAME_[1], ##__VA_ARGS__) +#define DEBUG_ESP_PRINTF( format, ...) DEBUG_GENERIC_F("[%s]", format, &_FILENAME_[1], ##__VA_ARGS__) #endif #if defined(DEBUG_ESP_ASYNC_TCP) && !defined(ASYNC_TCP_DEBUG) -#define ASYNC_TCP_DEBUG( format, ...) DEBUG_GENERIC_P("[ASYNC_TCP]", format, ##__VA_ARGS__) +#define ASYNC_TCP_DEBUG( format, ...) DEBUG_GENERIC_F("[ASYNC_TCP]", format, ##__VA_ARGS__) #endif #ifndef ASYNC_TCP_ASSERT -#define ASYNC_TCP_ASSERT( a ) ASSERT_GENERIC_P( (a), "[ASYNC_TCP]") +#define ASYNC_TCP_ASSERT( a ) ASSERT_GENERIC_F( (a), "[ASYNC_TCP]") #endif #if defined(DEBUG_ESP_TCP_SSL) && !defined(TCP_SSL_DEBUG) -#define TCP_SSL_DEBUG( format, ...) DEBUG_GENERIC_P("[TCP_SSL]", format, ##__VA_ARGS__) +#define TCP_SSL_DEBUG( format, ...) DEBUG_GENERIC_F("[TCP_SSL]", format, ##__VA_ARGS__) #endif #endif //_DEBUG_PRINT_MACROS_H diff --git a/src/ESPAsyncTCP.cpp b/src/ESPAsyncTCP.cpp index 7a9fdc7..0f4dad0 100644 --- a/src/ESPAsyncTCP.cpp +++ b/src/ESPAsyncTCP.cpp @@ -70,6 +70,7 @@ yield(), etc. */ + #include "Arduino.h" #include "ESPAsyncTCP.h" @@ -355,10 +356,12 @@ void AsyncClient::abort(){ void AsyncClient::close(bool now){ if(_pcb) tcp_recved(_pcb, _rx_ack_len); - if(now) + if(now) { + ASYNC_TCP_DEBUG("close[%u]: AsyncClient 0x%" PRIXPTR "\n", getConnectionId(), uintptr_t(this)); _close(); - else + } else { _close_pcb = true; + } } void AsyncClient::stop() { @@ -503,6 +506,7 @@ void AsyncClient::_close(){ err_t err = tcp_close(_pcb); if(ERR_OK == err) { setCloseError(err); + ASYNC_TCP_DEBUG("_close[%u]: AsyncClient 0x%" PRIXPTR "\n", getConnectionId(), uintptr_t(this)); } else { ASYNC_TCP_DEBUG("_close[%u]: abort() called for AsyncClient 0x%" PRIXPTR "\n", getConnectionId(), uintptr_t(this)); abort(); @@ -664,6 +668,7 @@ void AsyncClient::_poll(std::shared_ptr& errorTracker, tcp_pcb* // Close requested if(_close_pcb){ + ASYNC_TCP_DEBUG("_poll[%u]: Process _close_pcb.\n", errorTracker->getConnectionId() ); _close_pcb = false; _close(); return; @@ -679,12 +684,14 @@ void AsyncClient::_poll(std::shared_ptr& errorTracker, tcp_pcb* } // RX Timeout if(_rx_since_timeout && (now - _rx_last_packet) >= (_rx_since_timeout * 1000)){ + ASYNC_TCP_DEBUG("_poll[%u]: RX Timeout.\n", errorTracker->getConnectionId() ); _close(); return; } #if ASYNC_TCP_SSL_ENABLED // SSL Handshake Timeout if(_pcb_secure && !_handshake_done && (now - _rx_last_packet) >= 2000){ + ASYNC_TCP_DEBUG("_poll[%u]: SSL Handshake Timeout.\n", errorTracker->getConnectionId() ); _close(); return; } @@ -762,12 +769,15 @@ err_t AsyncClient::_s_connected(void* arg, void* tpcb, err_t err){ #if ASYNC_TCP_SSL_ENABLED void AsyncClient::_s_data(void *arg, struct tcp_pcb *tcp, uint8_t * data, size_t len){ + (void)tcp; AsyncClient *c = reinterpret_cast(arg); if(c->_recv_cb) c->_recv_cb(c->_recv_cb_arg, c, data, len); } void AsyncClient::_s_handshake(void *arg, struct tcp_pcb *tcp, SSL *ssl){ + (void)tcp; + (void)ssl; AsyncClient *c = reinterpret_cast(arg); c->_handshake_done = true; if(c->_connect_cb) @@ -775,6 +785,12 @@ void AsyncClient::_s_handshake(void *arg, struct tcp_pcb *tcp, SSL *ssl){ } void AsyncClient::_s_ssl_error(void *arg, struct tcp_pcb *tcp, int8_t err){ + (void)tcp; +#ifdef DEBUG_ESP_ASYNC_TCP + AsyncClient *c = reinterpret_cast(arg); + auto errorTracker = c->getACErrorTracker(); + ASYNC_TCP_DEBUG("_ssl_error[%u] err = %d\n", errorTracker->getConnectionId(), err); +#endif reinterpret_cast(arg)->_ssl_error(err); } #endif @@ -1230,7 +1246,7 @@ err_t AsyncServer::_accept(tcp_pcb* pcb, err_t err){ } return ERR_OK; } - ASYNC_TCP_DEBUG("### put to wait: %d\n", _clients_waiting); + //1 ASYNC_TCP_DEBUG("### put to wait: %d\n", _clients_waiting); new_item->pcb = pcb; new_item->pb = NULL; new_item->next = NULL; @@ -1252,6 +1268,7 @@ err_t AsyncServer::_accept(tcp_pcb* pcb, err_t err){ if(c){ ASYNC_TCP_DEBUG("_accept[%u]: SSL connected\n", c->getConnectionId()); c->onConnect([this](void * arg, AsyncClient *c){ + (void)arg; _connect_cb(_connect_cb_arg, c); }, this); } else { @@ -1303,6 +1320,7 @@ err_t AsyncServer::_s_accept(void *arg, tcp_pcb* pcb, err_t err){ #if ASYNC_TCP_SSL_ENABLED err_t AsyncServer::_poll(tcp_pcb* pcb){ + err_t err = ERR_OK; if(!tcp_ssl_has_client() && _pending){ struct pending_pcb * p = _pending; if(p->pcb == pcb){ @@ -1314,29 +1332,34 @@ err_t AsyncServer::_poll(tcp_pcb* pcb){ p->next = b->next; p = b; } - ASYNC_TCP_DEBUG("### remove from wait: %d\n", _clients_waiting); + //1 ASYNC_TCP_DEBUG("### remove from wait: %d\n", _clients_waiting); AsyncClient *c = new (std::nothrow) AsyncClient(pcb, _ssl_ctx); if(c){ c->onConnect([this](void * arg, AsyncClient *c){ + (void)arg; _connect_cb(_connect_cb_arg, c); }, this); - if(p->pb) - c->_recv(pcb, p->pb, 0); + if(p->pb) { + auto errorTracker = c->getACErrorTracker(); + c->_recv(errorTracker, pcb, p->pb, 0); + err = errorTracker->getCallbackCloseError(); + } } // Should there be error handling for when "new AsynClient" fails?? free(p); } - return ERR_OK; + return err; } err_t AsyncServer::_recv(struct tcp_pcb *pcb, struct pbuf *pb, err_t err){ + (void)err; if(!_pending) return ERR_OK; struct pending_pcb * p; if(!pb){ - ASYNC_TCP_DEBUG("### close from wait: %d\n", _clients_waiting); + //1 ASYNC_TCP_DEBUG("### close from wait: %d\n", _clients_waiting); p = _pending; if(p->pcb == pcb){ _pending = _pending->next; @@ -1357,7 +1380,7 @@ err_t AsyncServer::_recv(struct tcp_pcb *pcb, struct pbuf *pb, err_t err){ return ERR_ABRT; } } else { - ASYNC_TCP_DEBUG("### wait _recv: %u %d\n", pb->tot_len, _clients_waiting); + //1 ASYNC_TCP_DEBUG("### wait _recv: %u %d\n", pb->tot_len, _clients_waiting); p = _pending; while(p && p->pcb != pcb) p = p->next; diff --git a/src/async_config.h b/src/async_config.h index ca6912f..0ce336a 100644 --- a/src/async_config.h +++ b/src/async_config.h @@ -20,9 +20,11 @@ // Starting with Arduino Core 2.4.0 and up the define of DEBUG_ESP_PORT // can be handled through the Arduino IDE Board options instead of here. // #define DEBUG_ESP_PORT Serial - // #define DEBUG_ESP_ASYNC_TCP 1 // #define DEBUG_ESP_TCP_SSL 1 + +#ifndef DEBUG_SKIP__DEBUG_PRINT_MACROS + #include #ifndef ASYNC_TCP_ASSERT @@ -35,4 +37,6 @@ #define TCP_SSL_DEBUG(...) do { (void)0;} while(false) #endif +#endif + #endif /* LIBRARIES_ESPASYNCTCP_SRC_ASYNC_CONFIG_H_ */ diff --git a/src/tcp_axtls.c b/src/tcp_axtls.c index cdbdf41..f026b5f 100644 --- a/src/tcp_axtls.c +++ b/src/tcp_axtls.c @@ -22,7 +22,13 @@ * Compatibility for AxTLS with LWIP raw tcp mode (http://lwip.wikia.com/wiki/Raw/TCP) * Original Code and Inspiration: Slavey Karadzhov */ + +// To handle all the definitions needed for debug printing, we need to delay +// macro definitions till later. +#define DEBUG_SKIP__DEBUG_PRINT_MACROS 1 #include +#undef DEBUG_SKIP__DEBUG_PRINT_MACROS + #if ASYNC_TCP_SSL_ENABLED #include "lwip/opt.h" @@ -34,6 +40,13 @@ #include #include +// ets_uart_printf is defined in esp8266_undocumented.h, in newer Arduino ESP8266 Core. +extern int ets_uart_printf(const char *format, ...) __attribute__ ((format (printf, 1, 2))); +#include +#ifndef TCP_SSL_DEBUG +#define TCP_SSL_DEBUG(...) do { (void)0;} while(false) +#endif + uint8_t * default_private_key = NULL; uint16_t default_private_key_len = 0; @@ -377,7 +390,8 @@ int tcp_ssl_read(struct tcp_pcb *tcp, struct pbuf *p) { do { read_bytes = ssl_read(fd_data->ssl, &read_buf); - //TCP_SSL_DEBUG("tcp_ssl_ssl_read: %d\n", read_bytes); + TCP_SSL_DEBUG("tcp_ssl_ssl_read: %d\n", read_bytes); + if(read_bytes < SSL_OK) { if(read_bytes != SSL_CLOSE_NOTIFY) { TCP_SSL_DEBUG("tcp_ssl_read: read error: %d\n", read_bytes); @@ -387,20 +401,31 @@ int tcp_ssl_read(struct tcp_pcb *tcp, struct pbuf *p) { } else if(read_bytes > 0){ if(fd_data->on_data){ fd_data->on_data(fd_data->arg, tcp, read_buf, read_bytes); + // fd_data may have been freed in callback + fd_data = tcp_ssl_get(tcp); + if(NULL == fd_data) + return SSL_CLOSE_NOTIFY; } total_bytes+= read_bytes; } else { if(fd_data->handshake != SSL_OK) { - fd_data->handshake = ssl_handshake_status(fd_data->ssl); - if(fd_data->handshake == SSL_OK){ - //TCP_SSL_DEBUG("tcp_ssl_read: handshake OK\n"); + // fd_data may be freed in callbacks. + int handshake = fd_data->handshake = ssl_handshake_status(fd_data->ssl); + if(handshake == SSL_OK){ + TCP_SSL_DEBUG("tcp_ssl_read: handshake OK\n"); if(fd_data->on_handshake) fd_data->on_handshake(fd_data->arg, fd_data->tcp, fd_data->ssl); - } else if(fd_data->handshake != SSL_NOT_OK){ - TCP_SSL_DEBUG("tcp_ssl_read: handshake error: %d\n", fd_data->handshake); + fd_data = tcp_ssl_get(tcp); + if(NULL == fd_data) + return SSL_CLOSE_NOTIFY; + } else if(handshake != SSL_NOT_OK){ + TCP_SSL_DEBUG("tcp_ssl_read: handshake error: %d\n", handshake); if(fd_data->on_error) - fd_data->on_error(fd_data->arg, fd_data->tcp, fd_data->handshake); - return fd_data->handshake; + fd_data->on_error(fd_data->arg, fd_data->tcp, handshake); + return handshake; + // With current code APP gets called twice at onError handler. + // Once here and again after return when handshake != SSL_CLOSE_NOTIFY. + // As always APP must never free resources at onError only at onDisconnect. } } } @@ -525,13 +550,13 @@ int ax_port_write(int fd, uint8_t *data, uint16_t len) { TCP_SSL_DEBUG("ax_port_write: No memory %d (%d)\n", tcp_len, len); return err; } - TCP_SSL_DEBUG("ax_port_write: tcp_write error: %d\n", err); + TCP_SSL_DEBUG("ax_port_write: tcp_write error: %ld\n", err); return err; } else if (err == ERR_OK) { //TCP_SSL_DEBUG("ax_port_write: tcp_output: %d / %d\n", tcp_len, len); err = tcp_output(fd_data->tcp); if(err != ERR_OK) { - TCP_SSL_DEBUG("ax_port_write: tcp_output err: %d\n", err); + TCP_SSL_DEBUG("ax_port_write: tcp_output err: %ld\n", err); return err; } }