From d4d87b4c208724c7a6979f12bfde812acd6665f1 Mon Sep 17 00:00:00 2001 From: Philip Russell Date: Sat, 24 Feb 2024 15:11:55 +0000 Subject: [PATCH] WebSockets client: handle TCP fragmentation and simplify code Fixes #9052. --- src/http/websocket_http_client.zig | 166 ++++++++++------------------- 1 file changed, 58 insertions(+), 108 deletions(-) diff --git a/src/http/websocket_http_client.zig b/src/http/websocket_http_client.zig index 73f52f38e52986..27bc60e8181294 100644 --- a/src/http/websocket_http_client.zig +++ b/src/http/websocket_http_client.zig @@ -781,8 +781,8 @@ fn parseWebSocketHeader( // | Payload Data continued ... | // +---------------------------------------------------------------+ const header = WebsocketHeader.fromSlice(bytes); - const payload = @as(usize, header.len); - payload_length.* = payload; + const payload_len = @as(usize, header.len); + payload_length.* = payload_len; receiving_type.* = header.opcode; is_fragmented.* = switch (header.opcode) { .Continue => true, @@ -799,11 +799,11 @@ fn parseWebSocketHeader( } return switch (header.opcode) { - .Text, .Continue, .Binary => if (payload <= 125) + .Text, .Continue, .Binary => if (payload_len <= 125) return .need_body - else if (payload == 126) + else if (payload_len == 126) return .extended_payload_length_16 - else if (payload == 127) + else if (payload_len == 127) return .extended_payload_length_64 else return .fail, @@ -922,11 +922,8 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { // we need to start with final so we validate the first frame receiving_is_final: bool = true, - recursion: i32 = 0, - ping_frame_bytes: [128 + 6]u8 = [_]u8{0} ** (128 + 6), ping_len: u8 = 0, - ping_received: bool = false, close_received: bool = false, receive_frame: usize = 0, @@ -939,9 +936,9 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { globalThis: *JSC.JSGlobalObject, poll_ref: Async.KeepAlive = Async.KeepAlive.init(), - initial_fragment: []u8 = &.{}, + buffered_data: []u8 = &.{}, + required_data_len: usize = 0, - initial_data_handler: ?*InitialDataHandler = null, event_loop: *JSC.EventLoop = undefined, pub const name = if (ssl) "WebSocketClientTLS" else "WebSocketClient"; @@ -984,9 +981,12 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { pub fn clearData(this: *WebSocket) void { this.poll_ref.unref(this.globalThis.bunVM()); + if (this.buffered_data.len > 0) { + bun.default_allocator.free(this.buffered_data); + this.buffered_data = this.buffered_data[0..0]; + } this.clearReceiveBuffers(true); this.clearSendBuffers(true); - this.ping_received = false; this.ping_len = 0; this.receive_pending_chunk_len = 0; } @@ -1154,64 +1154,47 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { } pub fn handleData(this: *WebSocket, socket: Socket, data_: []const u8) void { - if (data_.len > 1 and this.recursion == 0) { - for (0..data_.len) |i| { - handleData(this, socket, data_[i .. i + 1]); - } - return; - } // after receiving close we should ignore the data if (this.close_received) return; - if (this.initial_fragment.len > 0) { - const fragment = this.initial_fragment; - this.initial_fragment = fragment[0..0]; - const new_data = bun.default_allocator.alloc(u8, fragment.len + data_.len) catch { - this.sendClose(); + var data = data_; + + // This code deliberately uses recursion and dynamic memory allocation in the uncommon case of TCP + // fragmentation of WebSockets frame headers. The common case (no fragmentation or fragmentation in + // the payload) falls right through and doesn't bloat the struct further. + while (this.buffered_data.len > 0) { + const data_len = @min(data.len, this.required_data_len - this.buffered_data.len); + const fragment = this.buffered_data; + this.buffered_data = fragment[0..0]; + // allocated separately so we can use @memcpy, which requires non-overlapping memory regions. + const new_data = bun.default_allocator.alloc(u8, fragment.len + data_len) catch { + this.fail(ErrorCode.failed_to_allocate_memory); return; }; @memcpy(new_data[0..fragment.len], fragment); bun.default_allocator.free(fragment); - @memcpy(new_data[fragment.len..], data_); - this.recursion += 1; + @memcpy(new_data[fragment.len..][0..data_len], data[0..data_len]); + // the recursive call may change this.buffered_data, this.close_received, or this.required_data_len handleData(this, socket, new_data); - this.recursion -= 1; bun.default_allocator.free(new_data); - return; - } - // Due to scheduling, it is possible for the websocket onData - // handler to run with additional data before the microtask queue is - // drained. - if (this.initial_data_handler) |initial_handler| { - // This calls `handleData` - // We deliberately do not set this.initial_data_handler to null here, that's done in handleWithoutDeinit. - // We do not free the memory here since the lifetime is managed by the microtask queue (it should free when called from there) - initial_handler.handleWithoutDeinit(); - - // handleWithoutDeinit is supposed to clear the handler from WebSocket* - // to prevent an infinite loop - std.debug.assert(this.initial_data_handler == null); - - // If we disconnected for any reason in the re-entrant case, we should just ignore the data - if (this.outgoing_websocket == null or this.tcp.isShutdown() or this.tcp.isClosed()) + data = data[data_len..]; + if (data.len == 0 or this.close_received) { return; + } } var terminated = false; - var data = data_; defer { if (!terminated and data.len > 0) { - this.initial_fragment = bun.default_allocator.alloc(u8, data.len) catch blk: { - this.sendClose(); - data = data[0..0]; + this.buffered_data = bun.default_allocator.dupe(u8, data) catch blk: { + // not much we can do in this case. + this.fail(ErrorCode.failed_to_allocate_memory); break :blk &.{}; }; - @memcpy(this.initial_fragment, data); } } var receive_state = this.receive_state; - var is_fragmented = false; var receiving_type = this.receiving_type; var receive_body_remain = this.receive_body_remain; var is_final = this.receiving_is_final; @@ -1227,13 +1210,11 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { } } - // In the WebSocket specification, control frames may not be fragmented. - // However, the frame parser should handle fragmented control frames nonetheless. - // Whether or not the frame parser is given a set of fragmented bytes to parse is subject - // to the strategy in which the client buffers and coalesces received bytes. - - while (data.len > 0) { - log("onData ({s} {} {})", .{ @tagName(receive_state), data[0], data.len }); + var progress = true; + while (data.len > 0 or progress) { + log("onData ({s})", .{@tagName(receive_state)}); + const old_receive_state = receive_state; + defer progress = receive_state != old_receive_state; switch (receive_state) { // 0 1 2 3 @@ -1256,12 +1237,14 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { // +---------------------------------------------------------------+ .need_header => { if (data.len < 2) { + this.required_data_len = 2; break; } receive_body_remain = 0; var need_compression = false; is_final = false; + var is_fragmented = false; receive_state = parseWebSocketHeader( data[0..2].*, @@ -1313,27 +1296,6 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { terminated = true; break; } - - // Handle when the payload length is 0, but it is a message - // - // This should become - // - // - ArrayBuffer(0) - // - "" - // - Buffer(0) (etc) - // - if (receive_body_remain == 0 and receive_state == .need_body and is_final) { - _ = this.consume( - "", - receive_body_remain, - last_receive_data_type, - is_final, - ); - - // Return to the header state to read the next frame - receive_state = .need_header; - is_fragmented = false; - } }, .need_mask => { this.terminate(.unexpected_mask_from_server); @@ -1349,6 +1311,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { // we need to wait for more data if (data.len < byte_size) { + this.required_data_len = byte_size; break; } @@ -1371,21 +1334,17 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { } }, .ping => { - if (!this.ping_received) { - if (receive_body_remain > 125) { - this.terminate(ErrorCode.invalid_control_frame); - terminated = true; - break; - } - this.ping_len = @truncate(receive_body_remain); - this.ping_received = true; + if (receive_body_remain > 125) { + this.terminate(ErrorCode.invalid_control_frame); + terminated = true; + break; } - const ping_len = this.ping_len; - if (data.len < receive_body_remain) { + this.required_data_len = receive_body_remain; break; } - const ping_data = this.ping_frame_bytes[6..][0..ping_len]; + this.ping_len = @truncate(receive_body_remain); + const ping_data = this.ping_frame_bytes[6..][0..this.ping_len]; @memcpy(ping_data, data[0..ping_data.len]); data = data[ping_data.len..]; @@ -1394,13 +1353,13 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { receive_state = .need_header; receive_body_remain = 0; receiving_type = last_receive_data_type; - this.ping_received = false; // we need to send all pongs to pass autobahn tests _ = this.sendPong(socket); }, .pong => { if (data.len < receive_body_remain) { + this.required_data_len = receive_body_remain; break; } const pong_len = receive_body_remain; @@ -1411,19 +1370,15 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { receive_state = .need_header; receive_body_remain = 0; receiving_type = last_receive_data_type; - - if (data.len == 0) break; }, .need_body => { const to_consume = @min(receive_body_remain, data.len); - const consumed = this.consume(data[0..to_consume], receive_body_remain, last_receive_data_type, is_final); receive_body_remain -= consumed; data = data[to_consume..]; if (receive_body_remain == 0) { receive_state = .need_header; - is_fragmented = false; } }, @@ -1434,19 +1389,18 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { break; } if (data.len < receive_body_remain) { + this.required_data_len = receive_body_remain; break; } this.close_received = true; + terminated = true; - // invalid close frame with 1 byte if (receive_body_remain == 1) { + // invalid close frame with 1 byte this.terminate(ErrorCode.invalid_control_frame); - terminated = true; - break; - } - // 2 byte close code and optional reason - if (data.len >= 2) { + } else if (receive_body_remain >= 2) { + // 2 byte close code and optional reason var code = std.mem.readInt(u16, data[0..2], .big); log("Received close with code {d}", .{code}); if (code == 1001) { @@ -1461,12 +1415,10 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { @memcpy(close_reason_buf[0..reason_len], data[2..receive_body_remain]); this.sendCloseWithBody(socket, code, &close_reason_buf, reason_len); data = data[receive_body_remain..]; - terminated = true; - break; + } else { + // empty close + this.sendClose(); } - - this.sendClose(); - terminated = true; break; }, .fail => { @@ -1768,24 +1720,21 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { const InitialDataHandler = struct { adopted: ?*WebSocket, ws: *CppWebSocket, - slice: []u8, pub const Handle = JSC.AnyTask.New(@This(), handle); pub fn handleWithoutDeinit(this: *@This()) void { var this_socket = this.adopted orelse return; this.adopted = null; - this_socket.initial_data_handler = null; var ws = this.ws; defer ws.unref(); if (this_socket.outgoing_websocket != null) - this_socket.handleData(this_socket.tcp, this.slice); + this_socket.handleData(this_socket.tcp, ""); } pub fn handle(this: *@This()) void { defer { - bun.default_allocator.free(this.slice); bun.default_allocator.destroy(this); } this.handleWithoutDeinit(); @@ -1828,9 +1777,10 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { const buffered_slice: []u8 = buffered_data[0..buffered_data_len]; if (buffered_slice.len > 0) { const initial_data = bun.default_allocator.create(InitialDataHandler) catch unreachable; + ws.buffered_data = buffered_slice; + ws.required_data_len = buffered_slice.len; initial_data.* = .{ .adopted = ws, - .slice = buffered_slice, .ws = outgoing, };