Skip to content

Commit

Permalink
WebSockets client: handle TCP fragmentation and simplify code
Browse files Browse the repository at this point in the history
  • Loading branch information
argosphil committed Feb 24, 2024
1 parent 49b2e2d commit 840a2b3
Showing 1 changed file with 58 additions and 108 deletions.
166 changes: 58 additions & 108 deletions src/http/websocket_http_client.zig
Original file line number Diff line number Diff line change
Expand Up @@ -781,8 +781,8 @@ fn parseWebSocketHeader(
// | Payload Data continued ... |
// +---------------------------------------------------------------+
const header = WebsocketHeader.fromSlice(bytes);
const payload = @as(usize, header.len);
payload_length.* = payload;
const payload_len = @as(usize, header.len);
payload_length.* = payload_len;
receiving_type.* = header.opcode;
is_fragmented.* = switch (header.opcode) {
.Continue => true,
Expand All @@ -799,11 +799,11 @@ fn parseWebSocketHeader(
}

return switch (header.opcode) {
.Text, .Continue, .Binary => if (payload <= 125)
.Text, .Continue, .Binary => if (payload_len <= 125)
return .need_body
else if (payload == 126)
else if (payload_len == 126)
return .extended_payload_length_16
else if (payload == 127)
else if (payload_len == 127)
return .extended_payload_length_64
else
return .fail,
Expand Down Expand Up @@ -922,11 +922,8 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
// we need to start with final so we validate the first frame
receiving_is_final: bool = true,

recursion: i32 = 0,

ping_frame_bytes: [128 + 6]u8 = [_]u8{0} ** (128 + 6),
ping_len: u8 = 0,
ping_received: bool = false,
close_received: bool = false,

receive_frame: usize = 0,
Expand All @@ -939,9 +936,9 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
globalThis: *JSC.JSGlobalObject,
poll_ref: Async.KeepAlive = Async.KeepAlive.init(),

initial_fragment: []u8 = &.{},
buffered_data: []u8 = &.{},
required_data_len: usize = 0,

initial_data_handler: ?*InitialDataHandler = null,
event_loop: *JSC.EventLoop = undefined,

pub const name = if (ssl) "WebSocketClientTLS" else "WebSocketClient";
Expand Down Expand Up @@ -984,9 +981,12 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {

pub fn clearData(this: *WebSocket) void {
this.poll_ref.unref(this.globalThis.bunVM());
if (this.buffered_data.len > 0) {
bun.default_allocator.free(this.buffered_data);
this.buffered_data = this.buffered_data[0..0];
}
this.clearReceiveBuffers(true);
this.clearSendBuffers(true);
this.ping_received = false;
this.ping_len = 0;
this.receive_pending_chunk_len = 0;
}
Expand Down Expand Up @@ -1154,64 +1154,47 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
}

pub fn handleData(this: *WebSocket, socket: Socket, data_: []const u8) void {
if (data_.len > 1 and this.recursion == 0) {
for (0..data_.len) |i| {
handleData(this, socket, data_[i .. i + 1]);
}
return;
}
// after receiving close we should ignore the data
if (this.close_received) return;

if (this.initial_fragment.len > 0) {
const fragment = this.initial_fragment;
this.initial_fragment = fragment[0..0];
const new_data = bun.default_allocator.alloc(u8, fragment.len + data_.len) catch {
this.sendClose();
var data = data_;

// This code deliberately uses recursion and dynamic memory allocation in the uncommon case where TCP
// fragmentation splits a WebSocket frame header across reads. The common case (no fragmentation, or
// fragmentation within the payload) falls right through and doesn't bloat the struct further.
while (this.buffered_data.len > 0) {
const data_len = @min(data.len, this.required_data_len - this.buffered_data.len);
const fragment = this.buffered_data;
this.buffered_data = fragment[0..0];
// allocated separately so we can use @memcpy, which requires non-overlapping memory regions.
const new_data = bun.default_allocator.alloc(u8, fragment.len + data_len) catch {
this.fail(ErrorCode.failed_to_allocate_memory);
return;
};
@memcpy(new_data[0..fragment.len], fragment);
bun.default_allocator.free(fragment);
@memcpy(new_data[fragment.len..], data_);
this.recursion += 1;
@memcpy(new_data[fragment.len..][0..data_len], data[0..data_len]);
// the recursive call may change this.buffered_data, this.close_received, or this.required_data_len
handleData(this, socket, new_data);
this.recursion -= 1;
bun.default_allocator.free(new_data);
return;
}
// Due to scheduling, it is possible for the websocket onData
// handler to run with additional data before the microtask queue is
// drained.
if (this.initial_data_handler) |initial_handler| {
// This calls `handleData`
// We deliberately do not set this.initial_data_handler to null here; that's done in handleWithoutDeinit.
// We do not free the memory here since the lifetime is managed by the microtask queue (it should free when called from there)
initial_handler.handleWithoutDeinit();

// handleWithoutDeinit is supposed to clear the handler from WebSocket*
// to prevent an infinite loop
std.debug.assert(this.initial_data_handler == null);

// If we disconnected for any reason in the re-entrant case, we should just ignore the data
if (this.outgoing_websocket == null or this.tcp.isShutdown() or this.tcp.isClosed())
data = data[data_len..];
if (data.len == 0 or this.close_received) {
return;
}
}

var terminated = false;
var data = data_;
defer {
if (!terminated and data.len > 0) {
this.initial_fragment = bun.default_allocator.alloc(u8, data.len) catch blk: {
this.sendClose();
data = data[0..0];
this.buffered_data = bun.default_allocator.dupe(u8, data) catch blk: {
// not much we can do in this case.
this.fail(ErrorCode.failed_to_allocate_memory);
break :blk &.{};
};
@memcpy(this.initial_fragment, data);
}
}

var receive_state = this.receive_state;
var is_fragmented = false;
var receiving_type = this.receiving_type;
var receive_body_remain = this.receive_body_remain;
var is_final = this.receiving_is_final;
Expand All @@ -1227,13 +1210,11 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
}
}

// In the WebSocket specification, control frames may not be fragmented.
// However, the frame parser should handle fragmented control frames nonetheless.
// Whether the frame parser is given fragmented bytes to parse depends on
// how the client buffers and coalesces received bytes.

while (data.len > 0) {
log("onData ({s} {} {})", .{ @tagName(receive_state), data[0], data.len });
var progress = true;
while (data.len > 0 or progress) {
log("onData ({s})", .{@tagName(receive_state)});
const old_receive_state = receive_state;
defer progress = receive_state != old_receive_state;

switch (receive_state) {
// 0 1 2 3
Expand All @@ -1256,12 +1237,14 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
// +---------------------------------------------------------------+
.need_header => {
if (data.len < 2) {
this.required_data_len = 2;
break;
}

receive_body_remain = 0;
var need_compression = false;
is_final = false;
var is_fragmented = false;

receive_state = parseWebSocketHeader(
data[0..2].*,
Expand Down Expand Up @@ -1313,27 +1296,6 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
terminated = true;
break;
}

// Handle when the payload length is 0, but it is a message
//
// This should become
//
// - ArrayBuffer(0)
// - ""
// - Buffer(0) (etc)
//
if (receive_body_remain == 0 and receive_state == .need_body and is_final) {
_ = this.consume(
"",
receive_body_remain,
last_receive_data_type,
is_final,
);

// Return to the header state to read the next frame
receive_state = .need_header;
is_fragmented = false;
}
},
.need_mask => {
this.terminate(.unexpected_mask_from_server);
Expand All @@ -1349,6 +1311,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {

// we need to wait for more data
if (data.len < byte_size) {
this.required_data_len = byte_size;
break;
}

Expand All @@ -1371,21 +1334,17 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
}
},
.ping => {
if (!this.ping_received) {
if (receive_body_remain > 125) {
this.terminate(ErrorCode.invalid_control_frame);
terminated = true;
break;
}
this.ping_len = @truncate(receive_body_remain);
this.ping_received = true;
if (receive_body_remain > 125) {
this.terminate(ErrorCode.invalid_control_frame);
terminated = true;
break;
}
const ping_len = this.ping_len;

if (data.len < receive_body_remain) {
this.required_data_len = receive_body_remain;
break;
}
const ping_data = this.ping_frame_bytes[6..][0..ping_len];
this.ping_len = @truncate(receive_body_remain);
const ping_data = this.ping_frame_bytes[6..][0..this.ping_len];
@memcpy(ping_data, data[0..ping_data.len]);
data = data[ping_data.len..];

Expand All @@ -1394,13 +1353,13 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
receive_state = .need_header;
receive_body_remain = 0;
receiving_type = last_receive_data_type;
this.ping_received = false;

// we need to send all pongs to pass autobahn tests
_ = this.sendPong(socket);
},
.pong => {
if (data.len < receive_body_remain) {
this.required_data_len = receive_body_remain;
break;
}
const pong_len = receive_body_remain;
Expand All @@ -1411,19 +1370,15 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
receive_state = .need_header;
receive_body_remain = 0;
receiving_type = last_receive_data_type;

if (data.len == 0) break;
},
.need_body => {
const to_consume = @min(receive_body_remain, data.len);

const consumed = this.consume(data[0..to_consume], receive_body_remain, last_receive_data_type, is_final);

receive_body_remain -= consumed;
data = data[to_consume..];
if (receive_body_remain == 0) {
receive_state = .need_header;
is_fragmented = false;
}
},

Expand All @@ -1434,19 +1389,18 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
break;
}
if (data.len < receive_body_remain) {
this.required_data_len = receive_body_remain;
break;
}

this.close_received = true;
terminated = true;

// invalid close frame with 1 byte
if (receive_body_remain == 1) {
// invalid close frame with 1 byte
this.terminate(ErrorCode.invalid_control_frame);
terminated = true;
break;
}
// 2 byte close code and optional reason
if (data.len >= 2) {
} else if (receive_body_remain >= 2) {
// 2 byte close code and optional reason
var code = std.mem.readInt(u16, data[0..2], .big);
log("Received close with code {d}", .{code});
if (code == 1001) {
Expand All @@ -1461,12 +1415,10 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
@memcpy(close_reason_buf[0..reason_len], data[2..receive_body_remain]);
this.sendCloseWithBody(socket, code, &close_reason_buf, reason_len);
data = data[receive_body_remain..];
terminated = true;
break;
} else {
// empty close
this.sendClose();
}

this.sendClose();
terminated = true;
break;
},
.fail => {
Expand Down Expand Up @@ -1768,24 +1720,21 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
const InitialDataHandler = struct {
adopted: ?*WebSocket,
ws: *CppWebSocket,
slice: []u8,

pub const Handle = JSC.AnyTask.New(@This(), handle);

pub fn handleWithoutDeinit(this: *@This()) void {
var this_socket = this.adopted orelse return;
this.adopted = null;
this_socket.initial_data_handler = null;
var ws = this.ws;
defer ws.unref();

if (this_socket.outgoing_websocket != null)
this_socket.handleData(this_socket.tcp, this.slice);
this_socket.handleData(this_socket.tcp, "");
}

pub fn handle(this: *@This()) void {
defer {
bun.default_allocator.free(this.slice);
bun.default_allocator.destroy(this);
}
this.handleWithoutDeinit();
Expand Down Expand Up @@ -1828,9 +1777,10 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
const buffered_slice: []u8 = buffered_data[0..buffered_data_len];
if (buffered_slice.len > 0) {
const initial_data = bun.default_allocator.create(InitialDataHandler) catch unreachable;
ws.buffered_data = buffered_slice;
ws.required_data_len = buffered_slice.len;
initial_data.* = .{
.adopted = ws,
.slice = buffered_slice,
.ws = outgoing,
};

Expand Down

0 comments on commit 840a2b3

Please sign in to comment.