Skip to content

Commit

Permalink
Corrected missed edit for _recv() in ASYNC_TCP_SSL_ENABLED in AsyncSe…
Browse files Browse the repository at this point in the history
…rver::_poll().

And other missed edit for errorTracker around ASYNC_TCP_SSL_ENABLED.
This should resolve @kasedy comment me-no-dev#115 (comment)
and @mcspr.

Tested ASYNC_TCP_SSL_ENABLED using marvinroger/async-mqtt-client/
.. examples/FullyFeaturedSSL. Ran test against test.mosquitto.org's
server. Thanks to @mcspr for suggesting.

Updated tcp_ssl_read() to check for fd_data being freed by callback
functions. I observed this with asyncmqttclient example. When finger
print did not match during fd_data->on_handshake callback, the mqtt
library did a close(true) which rippled down to an tcp_ssl_free().

Improvements in debug printing to handle debug print from tcp.axtls.c.
  • Loading branch information
mhightower83 authored and kleini committed Dec 7, 2019
1 parent 1547686 commit 4f970ad
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 37 deletions.
50 changes: 33 additions & 17 deletions src/DebugPrintMacros.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,34 @@ inline struct _DEBUG_TIME_STAMP debugTimeStamp(void) {
}
#endif

#if defined(DEBUG_ESP_PORT) && !defined(DEBUG_ESP_PORT_PRINTF)

#ifdef __cplusplus
#define DEBUG_ESP_PORT_PRINTF(format, ...) DEBUG_ESP_PORT.printf((format), ##__VA_ARGS__)
#define DEBUG_ESP_PORT_PRINTF_F(format, ...) DEBUG_ESP_PORT.printf_P(PSTR(format), ##__VA_ARGS__)
#define DEBUG_ESP_PORT_FLUSH DEBUG_ESP_PORT.flush
#else
// Handle debug printing from .c without CPP Stream, Print, ... classes
// Cannot handle flash strings in this setting
#define DEBUG_ESP_PORT_PRINTF ets_uart_printf
#define DEBUG_ESP_PORT_PRINTF_F ets_uart_printf
#define DEBUG_ESP_PORT_FLUSH (void)0
#endif

#endif

#if defined(DEBUG_ESP_PORT) && !defined(DEBUG_GENERIC)
#define DEBUG_GENERIC( module, format, ... ) \
do { \
struct _DEBUG_TIME_STAMP st = debugTimeStamp(); \
DEBUG_ESP_PORT.printf( DEBUG_TIME_STAMP_FMT module " " format, st.whole, st.dec, ##__VA_ARGS__ ); \
DEBUG_ESP_PORT_PRINTF( (DEBUG_TIME_STAMP_FMT module " " format), st.whole, st.dec, ##__VA_ARGS__ ); \
} while(false)
#endif
#if defined(DEBUG_ESP_PORT) && !defined(DEBUG_GENERIC_P)
#define DEBUG_GENERIC_P( module, format, ... ) \
#if defined(DEBUG_ESP_PORT) && !defined(DEBUG_GENERIC_F)
#define DEBUG_GENERIC_F( module, format, ... ) \
do { \
struct _DEBUG_TIME_STAMP st = debugTimeStamp(); \
DEBUG_ESP_PORT.printf_P(PSTR( DEBUG_TIME_STAMP_FMT module " " format ), st.whole, st.dec, ##__VA_ARGS__ ); \
DEBUG_ESP_PORT_PRINTF_F( (DEBUG_TIME_STAMP_FMT module " " format), st.whole, st.dec, ##__VA_ARGS__ ); \
} while(false)
#endif

Expand All @@ -47,16 +63,16 @@ inline struct _DEBUG_TIME_STAMP debugTimeStamp(void) {
do { \
if ( !(a) ) { \
DEBUG_GENERIC( module, "%s:%s:%u: ASSERT("#a") failed!\n", __FILE__, __func__, __LINE__); \
DEBUG_ESP_PORT.flush(); \
DEBUG_ESP_PORT_FLUSH(); \
} \
} while(false)
#endif
#if defined(DEBUG_GENERIC_P) && !defined(ASSERT_GENERIC_P)
#define ASSERT_GENERIC_P( a, module ) \
#if defined(DEBUG_GENERIC_F) && !defined(ASSERT_GENERIC_F)
#define ASSERT_GENERIC_F( a, module ) \
do { \
if ( !(a) ) { \
DEBUG_GENERIC_P( module, "%s:%s:%u: ASSERT("#a") failed!\n", __FILE__, __func__, __LINE__); \
DEBUG_ESP_PORT.flush(); \
DEBUG_GENERIC_F( module, "%s:%s:%u: ASSERT("#a") failed!\n", __FILE__, __func__, __LINE__); \
DEBUG_ESP_PORT_FLUSH(); \
} \
} while(false)
#endif
Expand All @@ -65,32 +81,32 @@ inline struct _DEBUG_TIME_STAMP debugTimeStamp(void) {
#define DEBUG_GENERIC(...) do { (void)0;} while(false)
#endif

#ifndef DEBUG_GENERIC_P
#define DEBUG_GENERIC_P(...) do { (void)0;} while(false)
#ifndef DEBUG_GENERIC_F
#define DEBUG_GENERIC_F(...) do { (void)0;} while(false)
#endif

#ifndef ASSERT_GENERIC
#define ASSERT_GENERIC(...) do { (void)0;} while(false)
#endif

#ifndef ASSERT_GENERIC_P
#define ASSERT_GENERIC_P(...) do { (void)0;} while(false)
#ifndef ASSERT_GENERIC_F
#define ASSERT_GENERIC_F(...) do { (void)0;} while(false)
#endif

#ifndef DEBUG_ESP_PRINTF
#define DEBUG_ESP_PRINTF( format, ...) DEBUG_GENERIC_P("[%s]", format, &_FILENAME_[1], ##__VA_ARGS__)
#define DEBUG_ESP_PRINTF( format, ...) DEBUG_GENERIC_F("[%s]", format, &_FILENAME_[1], ##__VA_ARGS__)
#endif

#if defined(DEBUG_ESP_ASYNC_TCP) && !defined(ASYNC_TCP_DEBUG)
#define ASYNC_TCP_DEBUG( format, ...) DEBUG_GENERIC_P("[ASYNC_TCP]", format, ##__VA_ARGS__)
#define ASYNC_TCP_DEBUG( format, ...) DEBUG_GENERIC_F("[ASYNC_TCP]", format, ##__VA_ARGS__)
#endif

#ifndef ASYNC_TCP_ASSERT
#define ASYNC_TCP_ASSERT( a ) ASSERT_GENERIC_P( (a), "[ASYNC_TCP]")
#define ASYNC_TCP_ASSERT( a ) ASSERT_GENERIC_F( (a), "[ASYNC_TCP]")
#endif

#if defined(DEBUG_ESP_TCP_SSL) && !defined(TCP_SSL_DEBUG)
#define TCP_SSL_DEBUG( format, ...) DEBUG_GENERIC_P("[TCP_SSL]", format, ##__VA_ARGS__)
#define TCP_SSL_DEBUG( format, ...) DEBUG_GENERIC_F("[TCP_SSL]", format, ##__VA_ARGS__)
#endif

#endif //_DEBUG_PRINT_MACROS_H
41 changes: 32 additions & 9 deletions src/ESPAsyncTCP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ yield(), etc.
*/

#include "Arduino.h"

#include "ESPAsyncTCP.h"
Expand Down Expand Up @@ -355,10 +356,12 @@ void AsyncClient::abort(){
void AsyncClient::close(bool now){
if(_pcb)
tcp_recved(_pcb, _rx_ack_len);
if(now)
if(now) {
ASYNC_TCP_DEBUG("close[%u]: AsyncClient 0x%" PRIXPTR "\n", getConnectionId(), uintptr_t(this));
_close();
else
} else {
_close_pcb = true;
}
}

void AsyncClient::stop() {
Expand Down Expand Up @@ -503,6 +506,7 @@ void AsyncClient::_close(){
err_t err = tcp_close(_pcb);
if(ERR_OK == err) {
setCloseError(err);
ASYNC_TCP_DEBUG("_close[%u]: AsyncClient 0x%" PRIXPTR "\n", getConnectionId(), uintptr_t(this));
} else {
ASYNC_TCP_DEBUG("_close[%u]: abort() called for AsyncClient 0x%" PRIXPTR "\n", getConnectionId(), uintptr_t(this));
abort();
Expand Down Expand Up @@ -664,6 +668,7 @@ void AsyncClient::_poll(std::shared_ptr<ACErrorTracker>& errorTracker, tcp_pcb*

// Close requested
if(_close_pcb){
ASYNC_TCP_DEBUG("_poll[%u]: Process _close_pcb.\n", errorTracker->getConnectionId() );
_close_pcb = false;
_close();
return;
Expand All @@ -679,12 +684,14 @@ void AsyncClient::_poll(std::shared_ptr<ACErrorTracker>& errorTracker, tcp_pcb*
}
// RX Timeout
if(_rx_since_timeout && (now - _rx_last_packet) >= (_rx_since_timeout * 1000)){
ASYNC_TCP_DEBUG("_poll[%u]: RX Timeout.\n", errorTracker->getConnectionId() );
_close();
return;
}
#if ASYNC_TCP_SSL_ENABLED
// SSL Handshake Timeout
if(_pcb_secure && !_handshake_done && (now - _rx_last_packet) >= 2000){
ASYNC_TCP_DEBUG("_poll[%u]: SSL Handshake Timeout.\n", errorTracker->getConnectionId() );
_close();
return;
}
Expand Down Expand Up @@ -762,19 +769,28 @@ err_t AsyncClient::_s_connected(void* arg, void* tpcb, err_t err){

#if ASYNC_TCP_SSL_ENABLED
void AsyncClient::_s_data(void *arg, struct tcp_pcb *tcp, uint8_t * data, size_t len){
(void)tcp;
AsyncClient *c = reinterpret_cast<AsyncClient*>(arg);
if(c->_recv_cb)
c->_recv_cb(c->_recv_cb_arg, c, data, len);
}

void AsyncClient::_s_handshake(void *arg, struct tcp_pcb *tcp, SSL *ssl){
(void)tcp;
(void)ssl;
AsyncClient *c = reinterpret_cast<AsyncClient*>(arg);
c->_handshake_done = true;
if(c->_connect_cb)
c->_connect_cb(c->_connect_cb_arg, c);
}

void AsyncClient::_s_ssl_error(void *arg, struct tcp_pcb *tcp, int8_t err){
(void)tcp;
#ifdef DEBUG_ESP_ASYNC_TCP
AsyncClient *c = reinterpret_cast<AsyncClient*>(arg);
auto errorTracker = c->getACErrorTracker();
ASYNC_TCP_DEBUG("_ssl_error[%u] err = %d\n", errorTracker->getConnectionId(), err);
#endif
reinterpret_cast<AsyncClient*>(arg)->_ssl_error(err);
}
#endif
Expand Down Expand Up @@ -1230,7 +1246,7 @@ err_t AsyncServer::_accept(tcp_pcb* pcb, err_t err){
}
return ERR_OK;
}
ASYNC_TCP_DEBUG("### put to wait: %d\n", _clients_waiting);
//1 ASYNC_TCP_DEBUG("### put to wait: %d\n", _clients_waiting);
new_item->pcb = pcb;
new_item->pb = NULL;
new_item->next = NULL;
Expand All @@ -1252,6 +1268,7 @@ err_t AsyncServer::_accept(tcp_pcb* pcb, err_t err){
if(c){
ASYNC_TCP_DEBUG("_accept[%u]: SSL connected\n", c->getConnectionId());
c->onConnect([this](void * arg, AsyncClient *c){
(void)arg;
_connect_cb(_connect_cb_arg, c);
}, this);
} else {
Expand Down Expand Up @@ -1303,6 +1320,7 @@ err_t AsyncServer::_s_accept(void *arg, tcp_pcb* pcb, err_t err){

#if ASYNC_TCP_SSL_ENABLED
err_t AsyncServer::_poll(tcp_pcb* pcb){
err_t err = ERR_OK;
if(!tcp_ssl_has_client() && _pending){
struct pending_pcb * p = _pending;
if(p->pcb == pcb){
Expand All @@ -1314,29 +1332,34 @@ err_t AsyncServer::_poll(tcp_pcb* pcb){
p->next = b->next;
p = b;
}
ASYNC_TCP_DEBUG("### remove from wait: %d\n", _clients_waiting);
//1 ASYNC_TCP_DEBUG("### remove from wait: %d\n", _clients_waiting);
AsyncClient *c = new (std::nothrow) AsyncClient(pcb, _ssl_ctx);
if(c){
c->onConnect([this](void * arg, AsyncClient *c){
(void)arg;
_connect_cb(_connect_cb_arg, c);
}, this);
if(p->pb)
c->_recv(pcb, p->pb, 0);
if(p->pb) {
auto errorTracker = c->getACErrorTracker();
c->_recv(errorTracker, pcb, p->pb, 0);
err = errorTracker->getCallbackCloseError();
}
}
// Should there be error handling for when "new AsynClient" fails??
free(p);
}
return ERR_OK;
return err;
}

err_t AsyncServer::_recv(struct tcp_pcb *pcb, struct pbuf *pb, err_t err){
(void)err;
if(!_pending)
return ERR_OK;

struct pending_pcb * p;

if(!pb){
ASYNC_TCP_DEBUG("### close from wait: %d\n", _clients_waiting);
//1 ASYNC_TCP_DEBUG("### close from wait: %d\n", _clients_waiting);
p = _pending;
if(p->pcb == pcb){
_pending = _pending->next;
Expand All @@ -1357,7 +1380,7 @@ err_t AsyncServer::_recv(struct tcp_pcb *pcb, struct pbuf *pb, err_t err){
return ERR_ABRT;
}
} else {
ASYNC_TCP_DEBUG("### wait _recv: %u %d\n", pb->tot_len, _clients_waiting);
//1 ASYNC_TCP_DEBUG("### wait _recv: %u %d\n", pb->tot_len, _clients_waiting);
p = _pending;
while(p && p->pcb != pcb)
p = p->next;
Expand Down
6 changes: 5 additions & 1 deletion src/async_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@
// Starting with Arduino Core 2.4.0 and up the define of DEBUG_ESP_PORT
// can be handled through the Arduino IDE Board options instead of here.
// #define DEBUG_ESP_PORT Serial

// #define DEBUG_ESP_ASYNC_TCP 1
// #define DEBUG_ESP_TCP_SSL 1

#ifndef DEBUG_SKIP__DEBUG_PRINT_MACROS

#include <DebugPrintMacros.h>

#ifndef ASYNC_TCP_ASSERT
Expand All @@ -35,4 +37,6 @@
#define TCP_SSL_DEBUG(...) do { (void)0;} while(false)
#endif

#endif

#endif /* LIBRARIES_ESPASYNCTCP_SRC_ASYNC_CONFIG_H_ */
45 changes: 35 additions & 10 deletions src/tcp_axtls.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,13 @@
* Compatibility for AxTLS with LWIP raw tcp mode (http://lwip.wikia.com/wiki/Raw/TCP)
* Original Code and Inspiration: Slavey Karadzhov
*/

// To handle all the definitions needed for debug printing, we need to delay
// macro definitions till later.
#define DEBUG_SKIP__DEBUG_PRINT_MACROS 1
#include <async_config.h>
#undef DEBUG_SKIP__DEBUG_PRINT_MACROS

#if ASYNC_TCP_SSL_ENABLED

#include "lwip/opt.h"
Expand All @@ -34,6 +40,13 @@
#include <stdbool.h>
#include <tcp_axtls.h>

// ets_uart_printf is defined in esp8266_undocumented.h, in newer Arduino ESP8266 Core.
extern int ets_uart_printf(const char *format, ...) __attribute__ ((format (printf, 1, 2)));
#include <DebugPrintMacros.h>
#ifndef TCP_SSL_DEBUG
#define TCP_SSL_DEBUG(...) do { (void)0;} while(false)
#endif

uint8_t * default_private_key = NULL;
uint16_t default_private_key_len = 0;

Expand Down Expand Up @@ -377,7 +390,8 @@ int tcp_ssl_read(struct tcp_pcb *tcp, struct pbuf *p) {

do {
read_bytes = ssl_read(fd_data->ssl, &read_buf);
//TCP_SSL_DEBUG("tcp_ssl_ssl_read: %d\n", read_bytes);
TCP_SSL_DEBUG("tcp_ssl_ssl_read: %d\n", read_bytes);

if(read_bytes < SSL_OK) {
if(read_bytes != SSL_CLOSE_NOTIFY) {
TCP_SSL_DEBUG("tcp_ssl_read: read error: %d\n", read_bytes);
Expand All @@ -387,20 +401,31 @@ int tcp_ssl_read(struct tcp_pcb *tcp, struct pbuf *p) {
} else if(read_bytes > 0){
if(fd_data->on_data){
fd_data->on_data(fd_data->arg, tcp, read_buf, read_bytes);
// fd_data may have been freed in callback
fd_data = tcp_ssl_get(tcp);
if(NULL == fd_data)
return SSL_CLOSE_NOTIFY;
}
total_bytes+= read_bytes;
} else {
if(fd_data->handshake != SSL_OK) {
fd_data->handshake = ssl_handshake_status(fd_data->ssl);
if(fd_data->handshake == SSL_OK){
//TCP_SSL_DEBUG("tcp_ssl_read: handshake OK\n");
// fd_data may be freed in callbacks.
int handshake = fd_data->handshake = ssl_handshake_status(fd_data->ssl);
if(handshake == SSL_OK){
TCP_SSL_DEBUG("tcp_ssl_read: handshake OK\n");
if(fd_data->on_handshake)
fd_data->on_handshake(fd_data->arg, fd_data->tcp, fd_data->ssl);
} else if(fd_data->handshake != SSL_NOT_OK){
TCP_SSL_DEBUG("tcp_ssl_read: handshake error: %d\n", fd_data->handshake);
fd_data = tcp_ssl_get(tcp);
if(NULL == fd_data)
return SSL_CLOSE_NOTIFY;
} else if(handshake != SSL_NOT_OK){
TCP_SSL_DEBUG("tcp_ssl_read: handshake error: %d\n", handshake);
if(fd_data->on_error)
fd_data->on_error(fd_data->arg, fd_data->tcp, fd_data->handshake);
return fd_data->handshake;
fd_data->on_error(fd_data->arg, fd_data->tcp, handshake);
return handshake;
// With current code APP gets called twice at onError handler.
// Once here and again after return when handshake != SSL_CLOSE_NOTIFY.
// As always APP must never free resources at onError only at onDisconnect.
}
}
}
Expand Down Expand Up @@ -525,13 +550,13 @@ int ax_port_write(int fd, uint8_t *data, uint16_t len) {
TCP_SSL_DEBUG("ax_port_write: No memory %d (%d)\n", tcp_len, len);
return err;
}
TCP_SSL_DEBUG("ax_port_write: tcp_write error: %d\n", err);
TCP_SSL_DEBUG("ax_port_write: tcp_write error: %ld\n", err);
return err;
} else if (err == ERR_OK) {
//TCP_SSL_DEBUG("ax_port_write: tcp_output: %d / %d\n", tcp_len, len);
err = tcp_output(fd_data->tcp);
if(err != ERR_OK) {
TCP_SSL_DEBUG("ax_port_write: tcp_output err: %d\n", err);
TCP_SSL_DEBUG("ax_port_write: tcp_output err: %ld\n", err);
return err;
}
}
Expand Down

0 comments on commit 4f970ad

Please sign in to comment.