Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable non-blocking socket mode in coreMQTT transport (CA-329) #214

Merged
merged 2 commits into from
Jul 18, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
236 changes: 170 additions & 66 deletions libraries/coreMQTT/port/network_transport/network_transport.c
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
#include "esp_err.h"
#include "freertos/FreeRTOS.h"
#include "freertos/projdefs.h"
#include "freertos/semphr.h"
#include <string.h>
#include "esp_log.h"
#include "esp_tls.h"
#include "sys/socket.h"
#include "network_transport.h"
#include "sdkconfig.h"

#define TAG "network_transport"

TlsTransportStatus_t xTlsConnect( NetworkContext_t* pxNetworkContext )
{
TlsTransportStatus_t xRet = TLS_TRANSPORT_SUCCESS;
TlsTransportStatus_t xResult = TLS_TRANSPORT_CONNECT_FAILURE;

esp_tls_cfg_t xEspTlsConfig = {
.cacert_buf = (const unsigned char*) ( pxNetworkContext->pcServerRootCA ),
Expand All @@ -21,99 +26,198 @@ TlsTransportStatus_t xTlsConnect( NetworkContext_t* pxNetworkContext )
.ds_data = pxNetworkContext->ds_data,
.clientkey_buf = ( const unsigned char* )( pxNetworkContext->pcClientKey ),
.clientkey_bytes = pxNetworkContext->pcClientKeySize,
.timeout_ms = 1000,
.timeout_ms = 2000,
.non_block = false,
};

esp_tls_t* pxTls = esp_tls_init();

xSemaphoreTake(pxNetworkContext->xTlsContextSemaphore, portMAX_DELAY);
pxNetworkContext->pxTls = pxTls;

if (esp_tls_conn_new_sync( pxNetworkContext->pcHostname,
strlen( pxNetworkContext->pcHostname ),
pxNetworkContext->xPort,
&xEspTlsConfig, pxTls) <= 0)
if( xSemaphoreTake(pxNetworkContext->xTlsContextSemaphore, portMAX_DELAY) == pdTRUE )
{
if (pxNetworkContext->pxTls)
int lConnectResult = -1;
esp_tls_t * pxTls = esp_tls_init();

if( pxTls != NULL )
{
esp_tls_conn_destroy(pxNetworkContext->pxTls);
pxNetworkContext->pxTls = NULL;
pxNetworkContext->pxTls = pxTls;

lConnectResult = esp_tls_conn_new_sync( pxNetworkContext->pcHostname,
strlen( pxNetworkContext->pcHostname ),
pxNetworkContext->xPort,
&xEspTlsConfig, pxTls );

if( lConnectResult == 1 )
{

int lSockFd = -1;
if( esp_tls_get_conn_sockfd(pxNetworkContext->pxTls, &lSockFd) == ESP_OK)
{
int flags = fcntl(lSockFd, F_GETFL);

if( fcntl(lSockFd, F_SETFL, flags | O_NONBLOCK ) != -1)
{
xResult = TLS_TRANSPORT_SUCCESS;
}
}
}

if( xResult != TLS_TRANSPORT_SUCCESS )
{
esp_tls_conn_destroy( pxNetworkContext->pxTls );
pxNetworkContext->pxTls = NULL;
}
}
xRet = TLS_TRANSPORT_CONNECT_FAILURE;
( void ) xSemaphoreGive( pxNetworkContext->xTlsContextSemaphore );
}

xSemaphoreGive(pxNetworkContext->xTlsContextSemaphore);

return xRet;
return xResult;
}

TlsTransportStatus_t xTlsDisconnect( NetworkContext_t* pxNetworkContext )
{
BaseType_t xRet = TLS_TRANSPORT_SUCCESS;
BaseType_t xResult;

xSemaphoreTake(pxNetworkContext->xTlsContextSemaphore, portMAX_DELAY);
if (pxNetworkContext->pxTls != NULL &&
esp_tls_conn_destroy(pxNetworkContext->pxTls) < 0)
if( xSemaphoreTake(pxNetworkContext->xTlsContextSemaphore, portMAX_DELAY ) == pdTRUE )
{
xRet = TLS_TRANSPORT_DISCONNECT_FAILURE;
if( pxNetworkContext->pxTls == NULL )
{
xResult = TLS_TRANSPORT_SUCCESS;
}
else if(esp_tls_conn_destroy(pxNetworkContext->pxTls ) == 0)
{
xResult = TLS_TRANSPORT_SUCCESS;
}
else
{
xResult = TLS_TRANSPORT_DISCONNECT_FAILURE;
}

( void ) xSemaphoreGive( pxNetworkContext->xTlsContextSemaphore );
}
else
{
xResult = TLS_TRANSPORT_DISCONNECT_FAILURE;
}
pxNetworkContext->pxTls = NULL;
xSemaphoreGive(pxNetworkContext->xTlsContextSemaphore);

return xRet;
return xResult;
}

int32_t espTlsTransportSend(NetworkContext_t* pxNetworkContext,
const void* pvData, size_t uxDataLen)
int32_t espTlsTransportSend( NetworkContext_t* pxNetworkContext,
const void* pvData, size_t uxDataLen)
{
if (pvData == NULL || uxDataLen == 0)
int32_t lBytesSent = -1;

if( ( pvData != NULL ) &&
( uxDataLen > 0 ) &&
( pxNetworkContext != NULL ) &&
( pxNetworkContext->pxTls != NULL ) )
{
return -1;
}
TimeOut_t xTimeout;
TickType_t xTicksToWait = pdMS_TO_TICKS(10);

int32_t lBytesSent = 0;
vTaskSetTimeOutState( &xTimeout );

if(pxNetworkContext != NULL && pxNetworkContext->pxTls != NULL)
{
xSemaphoreTake(pxNetworkContext->xTlsContextSemaphore, portMAX_DELAY);
lBytesSent = esp_tls_conn_write(pxNetworkContext->pxTls, pvData, uxDataLen);
xSemaphoreGive(pxNetworkContext->xTlsContextSemaphore);
}
else
{
lBytesSent = -1;
if( xSemaphoreTake( pxNetworkContext->xTlsContextSemaphore, xTicksToWait ) == pdTRUE )
{
int lSockFd = -1;
esp_err_t xError = esp_tls_get_conn_sockfd( pxNetworkContext->pxTls, &lSockFd );
if( xError == ESP_OK)
{
unsigned char * pucData = ( unsigned char * ) pvData;
struct timeval timeout = { .tv_usec = 10000, .tv_sec = 0 };
lBytesSent = 0;
do
{
fd_set write_fds;
fd_set error_fds;
int lSelectResult = -1;

FD_ZERO( &write_fds );
FD_SET( lSockFd, &write_fds );

FD_ZERO( &error_fds );
FD_SET( lSockFd, &error_fds );

lSelectResult = select( lSockFd + 1, NULL, &write_fds, &error_fds, &timeout );

if( lSelectResult < 0 )
{
lBytesSent = -1;
ESP_LOGE(TAG, "Error during call to select.");
break;
}
else if( ( lSelectResult > 0 ) && ( FD_ISSET( lSockFd, &write_fds ) != 0 ) )
{
ssize_t lResult = esp_tls_conn_write( pxNetworkContext->pxTls,
&(pucData[lBytesSent]),
uxDataLen - lBytesSent );

if( lResult > 0 )
{
lBytesSent += ( int32_t ) lResult;
}
else if( ( lResult != MBEDTLS_ERR_SSL_WANT_WRITE ) &&
( lResult != MBEDTLS_ERR_SSL_WANT_READ ) )
{
lBytesSent = lResult;
}
else
{
/* Empty when lResult == 0 */
}

if( ( lBytesSent < 0 ) ||
( lBytesSent == uxDataLen ) )
{
break;
}
}

else
{
/* Empty when lSelectResult == 0 */
}
}
while( xTaskCheckForTimeOut( &xTimeout, &xTicksToWait ) == pdFALSE );
}
xSemaphoreGive(pxNetworkContext->xTlsContextSemaphore);
}
}

return lBytesSent;
}

int32_t espTlsTransportRecv(NetworkContext_t* pxNetworkContext,
void* pvData, size_t uxDataLen)
int32_t espTlsTransportRecv( NetworkContext_t* pxNetworkContext,
void* pvData, size_t uxDataLen)
{
if (pvData == NULL || uxDataLen == 0)
{
return -1;
}
int32_t lBytesRead = 0;
if(pxNetworkContext != NULL && pxNetworkContext->pxTls != NULL)
{
xSemaphoreTake(pxNetworkContext->xTlsContextSemaphore, portMAX_DELAY);
lBytesRead = esp_tls_conn_read(pxNetworkContext->pxTls, pvData, uxDataLen);
xSemaphoreGive(pxNetworkContext->xTlsContextSemaphore);
}
else
int32_t lBytesRead = -1;

if( ( pvData != NULL ) &&
( uxDataLen > 0 ) &&
( pxNetworkContext != NULL ) &&
( pxNetworkContext->pxTls != NULL ) )
{
return -1; /* pxNetworkContext or pxTls uninitialised */
}
if (lBytesRead == ESP_TLS_ERR_SSL_WANT_WRITE || lBytesRead == ESP_TLS_ERR_SSL_WANT_READ) {
return 0;
}
if (lBytesRead < 0) {
return lBytesRead;
}
if (lBytesRead == 0) {
/* Connection closed */
return -1;
if( xSemaphoreTake( pxNetworkContext->xTlsContextSemaphore, portMAX_DELAY ) == pdTRUE )
{
ssize_t lResult = esp_tls_conn_read( pxNetworkContext->pxTls,
pvData,
( size_t ) uxDataLen );

if( lResult > 0 )
{
lBytesRead = ( int32_t ) lResult;
}
else if( ( lResult != MBEDTLS_ERR_SSL_WANT_WRITE ) &&
( lResult != MBEDTLS_ERR_SSL_WANT_READ ) )
{
lBytesRead = ( int32_t ) lResult;
}
else
{
lBytesRead = 0;
}

( void ) xSemaphoreGive( pxNetworkContext->xTlsContextSemaphore);
}
}

return lBytesRead;
}