Skip to content

Commit

Permalink
Fix client Websockets (broken) (#2749)
Browse files Browse the repository at this point in the history
This PR addresses several issues discovered in the client websocket stack. A simple local websocket echo server has been added for testing, which can be run using `make wsserver`.

**Memory leaks**

Running `HttpServer_Websockets` sample with valgrind shows a memory leak (details below, fig 1).
I can see from the logic of `WebsocketConnection::send` that there are multiple reasons the call could fail, but the `source` stream is only destroyed in one of them.

**Failed connection**

Testing with the local server failed with `websockets.exceptions.InvalidHeaderValue: invalid Sec-WebSocket-Key header`.
The key was 17 bytes instead of 16.

**utf-8 decoding errors**

Turns out message was getting corrupted because mask value passed to XorStream is on stack, which then gets overwritten before message has been sent out. Fixed by taking a copy of the value.

**CLOSE message not being sent**

Tested with `Websocket_Client` sample (running local echo server) to confirm correct behaviour, noticed a `Streams without known size are not supported` message when closing the connection. This blocked sending 'CLOSE' notifications which have no payload.

**Issues during CLOSE**

The TCP socket was being closed too soon, causing additional errors. Increasing timeout to 2 seconds fixes this. Also included a status code in the CLOSE message indicating normal closure; this is optional, but seems like a good thing to do.

RFC 6455 states: *The application MUST NOT send any more data frames after sending a Close frame. If an endpoint receives a Close frame and did not previously send a Close frame, the endpoint MUST send a Close frame in response.*

Therefore, the `close()` logic has been changed so that a CLOSE message is *always* sent, either in response to a previous incoming request (in which case the received status is echoed back) or as notification that we're closing (status 1000 - normal closure). Checked server operation with `HttpServer_Websockets` sample 


**Simplify packet generation**

It's not necessary to pre-calculate the packet length as it's never more than 14 bytes in length.

**WebsocketConnection not getting destroyed**

HttpConnection objects are not 'auto-released' so leaks memory every time a WebsocketConnection is destroyed (512+ bytes). Simplest fix for this is to add a `setAutoSelfDestruct` method to `TcpConnection` so this can be changed.

**Messages not being received**

Connection is activated OK, but `HttpClientConnection` then calls `init` in its `onConnected` handler which resets the new `receive` delegate. This causes incoming websocket frames to be passed to the http parser, instead of the WS parser, hence the `HTTP parser error: HPE_INVALID_CONSTANT` message.


====

Fig 1: Initial memory leak reported by valgrind

```
==1291918== 
==1291918== HEAP SUMMARY:
==1291918==     in use at exit: 4,133 bytes in 16 blocks
==1291918==   total heap usage: 573 allocs, 557 frees, 71,139 bytes allocated
==1291918== 
==1291918== 64 bytes in 2 blocks are definitely lost in loss record 10 of 13
==1291918==    at 0x4041D7D: operator new(unsigned int) (vg_replace_malloc.c:476)
==1291918==    by 0x8075BC5: WebsocketConnection::send(char const*, unsigned int, ws_frame_type_t) (WebsocketConnection.cpp:180)
==1291918==    by 0x804EB65: send (WebsocketConnection.h:107)
==1291918==    by 0x804EB65: sendString (WebsocketConnection.h:145)
==1291918==    by 0x804EB65: wsCommandReceived(WebsocketConnection&, String const&) (application.cpp:88)
==1291918==    by 0x807575D: operator() (std_function.h:591)
==1291918==    by 0x807575D: WebsocketConnection::staticOnDataPayload(void*, char const*, unsigned int) (WebsocketConnection.cpp:128)
==1291918==    by 0x8081E5C: ws_parser_execute (ws_parser.c:263)
==1291918==    by 0x80756C6: WebsocketConnection::processFrame(TcpClient&, char*, int) (WebsocketConnection.cpp:103)
==1291918==    by 0x8079305: operator() (std_function.h:591)
==1291918==    by 0x8079305: TcpClient::onReceive(pbuf*) (TcpClient.cpp:150)
==1291918==    by 0x8078A8E: TcpConnection::internalOnReceive(pbuf*, signed char) (TcpConnection.cpp:484)
==1291918==    by 0x8058560: tcp_input (in /stripe/sandboxes/sming-dev/samples/HttpServer_WebSockets/out/Host/debug/firmware/app)
==1291918==    by 0x80627F3: ip4_input (in /stripe/sandboxes/sming-dev/samples/HttpServer_WebSockets/out/Host/debug/firmware/app)
==1291918==    by 0x8063C89: ethernet_input (in /stripe/sandboxes/sming-dev/samples/HttpServer_WebSockets/out/Host/debug/firmware/app)
==1291918==    by 0x8064245: tapif_select (in /stripe/sandboxes/sming-dev/samples/HttpServer_WebSockets/out/Host/debug/firmware/app)
==1291918== 
==1291918== LEAK SUMMARY:
==1291918==    definitely lost: 64 bytes in 2 blocks
==1291918==    indirectly lost: 0 bytes in 0 blocks
==1291918==      possibly lost: 0 bytes in 0 blocks
==1291918==    still reachable: 4,069 bytes in 14 blocks
==1291918==         suppressed: 0 bytes in 0 blocks
==1291918== Reachable blocks (those to which a pointer was found) are not shown.
==1291918== To see them, rerun with: --leak-check=full --show-leak-kinds=all
==1291918== 
==1291918== For lists of detected and suppressed errors, rerun with: -s
==1291918== ERROR SUMMARY: 1 errors from 1 contexts (suppressed: 0 from 0)
```
  • Loading branch information
mikee47 authored Apr 3, 2024
1 parent 7f3eba5 commit c487f91
Show file tree
Hide file tree
Showing 12 changed files with 253 additions and 163 deletions.
11 changes: 11 additions & 0 deletions Sming/Components/Network/component.mk
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,14 @@ COMPONENT_INCDIRS += \
endif

endif

##@Testing

# Websocket Server
CACHE_VARS += WSSERVER_PORT
WSSERVER_PORT ?= 9999
.PHONY: wsserver
wsserver: ##Launch a simple python Websocket echo server for testing client applications
$(info Starting Websocket server for TESTING)
$(Q) $(PYTHON) $(CMP_Network_PATH)/tools/wsserver.py $(WSSERVER_PORT)

Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,11 @@
****/

#include "WebsocketConnection.h"
#include <BitManipulations.h>
#include <Crypto/Sha1.h>
#include <Data/WebHelpers/base64.h>
#include <Data/Stream/MemoryDataStream.h>
#include <Data/Stream/XorOutputStream.h>
#include <Data/Stream/SharedMemoryStream.h>
#include <memory>

DEFINE_FSTR(WSSTR_CONNECTION, "connection")
DEFINE_FSTR(WSSTR_UPGRADE, "upgrade")
Expand All @@ -32,12 +30,14 @@ WebsocketList WebsocketConnection::websocketList;
/** @brief ws_parser function table
* @note stored in flash memory; as it is word-aligned it can be accessed directly
*/
const ws_parser_callbacks_t WebsocketConnection::parserSettings PROGMEM = {.on_data_begin = staticOnDataBegin,
.on_data_payload = staticOnDataPayload,
.on_data_end = staticOnDataEnd,
.on_control_begin = staticOnControlBegin,
.on_control_payload = staticOnControlPayload,
.on_control_end = staticOnControlEnd};
const ws_parser_callbacks_t WebsocketConnection::parserSettings PROGMEM{
.on_data_begin = staticOnDataBegin,
.on_data_payload = staticOnDataPayload,
.on_data_end = staticOnDataEnd,
.on_control_begin = staticOnControlBegin,
.on_control_payload = staticOnControlPayload,
.on_control_end = staticOnControlEnd,
};

/** @brief Boilerplate code for ws_parser callbacks
* @note Obtain connection object and check it
Expand All @@ -55,6 +55,14 @@ WebsocketConnection::WebsocketConnection(HttpConnection* connection, bool isClie
ws_parser_init(&parser);
}

void WebsocketConnection::setConnection(HttpConnection* connection, bool isClientConnection)
{
assert(this->connection == nullptr);
this->connection = connection;
this->isClientConnection = isClientConnection;
this->state = connection ? eWSCS_Ready : eWSCS_Closed;
}

bool WebsocketConnection::bind(HttpRequest& request, HttpResponse& response)
{
String version = request.headers[HTTP_HEADER_SEC_WEBSOCKET_VERSION];
Expand Down Expand Up @@ -124,10 +132,21 @@ int WebsocketConnection::staticOnDataPayload(void* userData, const char* at, siz
{
GET_CONNECTION();

if(connection->frameType == WS_FRAME_TEXT && connection->wsMessage) {
connection->wsMessage(*connection, String(at, length));
} else if(connection->frameType == WS_FRAME_BINARY && connection->wsBinary) {
connection->wsBinary(*connection, reinterpret_cast<uint8_t*>(const_cast<char*>(at)), length);
switch(connection->frameType) {
case WS_FRAME_TEXT:
if(connection->wsMessage) {
connection->wsMessage(*connection, String(at, length));
}
break;
case WS_FRAME_BINARY:
if(connection->wsBinary) {
connection->wsBinary(*connection, reinterpret_cast<uint8_t*>(const_cast<char*>(at)), length);
}
break;
case WS_FRAME_CLOSE:
case WS_FRAME_PING:
case WS_FRAME_PONG:
break;
}

return WS_OK;
Expand All @@ -142,11 +161,7 @@ int WebsocketConnection::staticOnControlBegin(void* userData, ws_frame_type_t ty
{
GET_CONNECTION();

connection->controlFrame = WsFrameInfo(type, nullptr, 0);

if(type == WS_FRAME_CLOSE) {
connection->close();
}
connection->controlFrame = WsFrameInfo{type};

return WS_OK;
}
Expand All @@ -165,13 +180,26 @@ int WebsocketConnection::staticOnControlEnd(void* userData)
{
GET_CONNECTION();

if(connection->controlFrame.type == WS_FRAME_PING) {
switch(connection->controlFrame.type) {
case WS_FRAME_PING:
connection->send(connection->controlFrame.payload, connection->controlFrame.payloadLength, WS_FRAME_PONG);
}
break;
case WS_FRAME_PONG:
if(connection->wsPong) {
connection->wsPong(*connection);
}
break;

case WS_FRAME_CLOSE:
debug_hex(DBG, "WS: CLOSE", connection->controlFrame.payload, connection->controlFrame.payloadLength);
connection->close();
break;

if(connection->controlFrame.type == WS_FRAME_PONG && connection->wsPong) {
connection->wsPong(*connection);
case WS_FRAME_TEXT:
case WS_FRAME_BINARY:
break;
}

return WS_OK;
}

Expand All @@ -186,6 +214,7 @@ bool WebsocketConnection::send(const char* message, size_t length, ws_frame_type
size_t written = stream->write(message, length);
if(written != length) {
debug_e("Unable to store data in memory buffer");
delete stream;
return false;
}

Expand All @@ -194,96 +223,78 @@ bool WebsocketConnection::send(const char* message, size_t length, ws_frame_type

bool WebsocketConnection::send(IDataSourceStream* source, ws_frame_type_t type, bool useMask, bool isFin)
{
// Ensure source gets destroyed if we return prematurely
std::unique_ptr<IDataSourceStream> sourceRef(source);

if(source == nullptr) {
debug_w("WS: No source");
return false;
}

if(connection == nullptr) {
debug_w("WS: No connection");
return false;
}

if(!activated) {
debug_e("WS Connection is not activated yet!");
debug_e("WS: Not activated");
return false;
}

int available = source->available();
if(available < 1) {
debug_e("Streams without known size are not supported");
if(available < 0) {
debug_e("WS: Unknown stream size");
return false;
}

debug_d("Sending: %d bytes, Type: %d\n", available, type);

size_t packetLength = 2;
uint16_t lengthValue = available;
debug_d("WS: Sending %d bytes, type %d", available, type);

// calculate message length ....
if(available <= 125) {
lengthValue = available;
} else if(available < 65536) {
lengthValue = 126;
packetLength += 2;
} else {
lengthValue = 127;
packetLength += 8;
}

if(useMask) {
packetLength += 4; // we use mask with size 4 bytes
}

uint8_t packet[packetLength];
memset(packet, 0, packetLength);

int i = 0;
// byte 0
// Construct packet
uint8_t packet[16]{};
unsigned len = 0;
if(isFin) {
packet[i] |= bit(7); // set Fin
packet[len] |= _BV(7); // set Fin
}
packet[i++] |= (uint8_t)type; // set opcode
// byte 1
packet[len++] |= type; // set opcode
if(useMask) {
packet[i] |= bit(7); // set mask
packet[len] |= _BV(7); // set mask
}

// length
if(lengthValue < 126) {
packet[i++] |= lengthValue;
} else if(lengthValue == 126) {
packet[i++] |= 126;
packet[i++] = (available >> 8) & 0xFF;
packet[i++] = available & 0xFF;
} else if(lengthValue == 127) {
packet[i++] |= 127;
packet[i++] = 0;
packet[i++] = 0;
packet[i++] = 0;
packet[i++] = 0;
packet[i++] = (available >> 24) & 0xFF;
packet[i++] = (available >> 16) & 0xFF;
packet[i++] = (available >> 8) & 0xFF;
packet[i++] = (available)&0xFF;
if(available <= 125) {
packet[len++] |= available;
} else if(available <= 0xffff) {
packet[len++] |= 126;
packet[len++] = available >> 8;
packet[len++] = available;
} else {
packet[len++] |= 127;
len += 4; // All 0
packet[len++] = available >> 24;
packet[len++] = available >> 16;
packet[len++] = available >> 8;
packet[len++] = available;
}

if(useMask) {
uint8_t maskKey[4] = {0x00, 0x00, 0x00, 0x00};
for(uint8_t x = 0; x < sizeof(maskKey); x++) {
maskKey[x] = (char)os_random();
packet[i++] = maskKey[x];
}
uint8_t maskKey[4];
os_get_random(maskKey, sizeof(maskKey));
memcpy(&packet[len], maskKey, sizeof(maskKey));
len += sizeof(maskKey);

auto xorStream = new XorOutputStream(source, maskKey, sizeof(maskKey));
source = xorStream;
if(xorStream == nullptr) {
return false;
}
sourceRef.release();
sourceRef.reset(xorStream);
}

// send the header
if(!connection->send(reinterpret_cast<const char*>(packet), packetLength)) {
delete source;
if(!connection->send(reinterpret_cast<const char*>(packet), len)) {
return false;
}

return connection->send(source);
// Pass stream to connection
return connection->send(sourceRef.release());
}

void WebsocketConnection::broadcast(const char* message, size_t length, ws_frame_type_t type)
Expand All @@ -300,23 +311,29 @@ void WebsocketConnection::broadcast(const char* message, size_t length, ws_frame

void WebsocketConnection::close()
{
debug_d("Terminating Websocket connection.");
if(connection == nullptr) {
return;
}

debug_d("WS: Terminating connection %p, state %u", connection, state);
websocketList.removeElement(this);
if(state != eWSCS_Closed) {
state = eWSCS_Closed;
if(isClientConnection) {
send(nullptr, 0, WS_FRAME_CLOSE);
if(controlFrame.type == WS_FRAME_CLOSE) {
send(controlFrame.payload, controlFrame.payloadLength, WS_FRAME_CLOSE);
} else {
uint16_t status = htons(1000);
send(reinterpret_cast<char*>(&status), sizeof(status), WS_FRAME_CLOSE);
}
activated = false;
if(wsDisconnect) {
wsDisconnect(*this);
}
}

if(connection) {
connection->setTimeOut(1);
connection = nullptr;
}
connection->setTimeOut(2);
connection->setAutoSelfDestruct(true);
connection = nullptr;
}

void WebsocketConnection::reset()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,6 @@ struct WsFrameInfo {
ws_frame_type_t type = WS_FRAME_TEXT;
char* payload = nullptr;
size_t payloadLength = 0;

WsFrameInfo() = default;

WsFrameInfo(ws_frame_type_t type, char* payload, size_t payloadLength)
: type(type), payload(payload), payloadLength(payloadLength)
{
}
};

class WebsocketConnection
Expand All @@ -73,7 +66,7 @@ class WebsocketConnection
* @param connection the transport connection
* @param isClientConnection true when the passed connection is an http client connection
*/
WebsocketConnection(HttpConnection* connection, bool isClientConnection = true);
WebsocketConnection(HttpConnection* connection = nullptr, bool isClientConnection = true);

virtual ~WebsocketConnection()
{
Expand Down Expand Up @@ -109,14 +102,14 @@ class WebsocketConnection

/**
* @brief Sends websocket message from a stream
* @param stream
* @param source The stream to send - we get ownership of the stream
* @param type
* @param useMask MUST be true for client connections
* @param isFin true if this is the final frame
*
* @retval bool true on success
*/
bool send(IDataSourceStream* stream, ws_frame_type_t type = WS_FRAME_TEXT, bool useMask = false, bool isFin = true);
bool send(IDataSourceStream* source, ws_frame_type_t type = WS_FRAME_TEXT, bool useMask = false, bool isFin = true);

/**
* @brief Broadcasts a message to all active websocket connections
Expand Down Expand Up @@ -271,11 +264,7 @@ class WebsocketConnection
* @param connection the transport connection
* @param isClientConnection true when the passed connection is an http client connection
*/
void setConnection(HttpConnection* connection, bool isClientConnection = true)
{
this->connection = connection;
this->isClientConnection = isClientConnection;
}
void setConnection(HttpConnection* connection, bool isClientConnection = true);

/** @brief Gets the state of the websocket connection
* @retval WsConnectionState
Expand Down Expand Up @@ -311,7 +300,7 @@ class WebsocketConnection

void* userData = nullptr;

WsConnectionState state = eWSCS_Ready;
WsConnectionState state;

private:
ws_frame_type_t frameType = WS_FRAME_TEXT;
Expand All @@ -322,9 +311,8 @@ class WebsocketConnection

static WebsocketList websocketList;

bool isClientConnection = true;

HttpConnection* connection = nullptr;
bool isClientConnection;
bool activated = false;
};

Expand Down
Loading

0 comments on commit c487f91

Please sign in to comment.