diff --git a/src/client/client.zig b/src/client/client.zig index fb9dbf6..0897be7 100644 --- a/src/client/client.zig +++ b/src/client/client.zig @@ -138,79 +138,90 @@ pub const Client = struct { pub fn readLoop(self: *Client, handler: anytype) !void { const H = @TypeOf(handler); + const Handler = switch (@typeInfo(H)) { + .Struct => H, + .Pointer => |ptr| ptr.child, + else => @compileError("readLoop handler must be a struct, got: " ++ @tagName(@typeInfo(H))), + }; var reader = &self._reader; - const stream = &self.stream; defer if (comptime std.meta.hasFn(H, "close")) { handler.close(); }; + // block until we have data + try self.readTimeout(0); + while (true) { - // for a single fill, we might have multiple messages to process - while (true) { - const has_more, const message = reader.read() catch |err| { - self.closeWithCode(1002); - return err; - } orelse break; // orelse, we don't have enough data, so break out of the inner loop, and go get more data from the socket in the outer loop - - const message_type = message.type; - defer reader.done(message_type); - - switch (message_type) { - .text, .binary => { - switch (comptime @typeInfo(@TypeOf(H.handleMessage)).Fn.params.len) { - 2 => try handler.handleMessage(message.data), - 3 => try handler.handleMessage(message.data, if (message_type == .text) .text else .binary), - else => @compileError(@typeName(H) ++ ".handleMessage must accept 2 or 3 parameters"), - } - }, - .ping => if (comptime std.meta.hasFn(H, "handlePing")) { - try handler.handlePing(message.data); + const message = (try self.read()) orelse unreachable; + const message_type = message.type; + defer reader.done(message_type); + + switch (message_type) { + .text, .binary => { + switch (comptime @typeInfo(@TypeOf(Handler.handleMessage)).Fn.params.len) { + 2 => try handler.handleMessage(message.data), + 3 => try handler.handleMessage(message.data, if (message_type == .text) .text else .binary), + else => @compileError(@typeName(Handler) ++ ".handleMessage must accept 2 or 3 parameters"), + } + }, + .ping => if (comptime std.meta.hasFn(Handler, "handlePing")) { + try handler.handlePing(message.data); + } else { + // @constCast is safe because we know message.data points to + // reader.buffer.buf, which we own and which can be mutated + try self.writeFrame(.pong, @constCast(message.data)); + }, + .close => { + if (comptime std.meta.hasFn(Handler, "handleClose")) { + try handler.handleClose(message.data); } else { - // @constCast is safe because we know message.data points to - // reader.buffer.buf, which we own and which can be mutated - try self.writeFrame(.pong, @constCast(message.data)); - }, - .close => { - if (comptime std.meta.hasFn(H, "handleClose")) { - try handler.handleClose(message.data); - } else { - self.close(); - } - return; - }, - .pong => if (comptime std.meta.hasFn(H, "handlePong")) { - try handler.handlePong(); - }, - } - - if (has_more == false) { - // break out of the inner loop, and back into the outer loop - // to fill the buffer with more data from the socket - break; - } - } - - // Wondering why we don't call fill at the start of the outer loop? - // When we read the handshake response, we might already have a message - // (or part of a message) to process. So we always want to run reader.next() - // first, to process any initial message we might have. - // If we call fill first, we might block forever, despite there being - // an initial message waiting. - reader.fill(stream) catch |err| switch (err) { - error.Closed, error.ConnectionResetByPeer, error.BrokenPipe, error.NotOpenForReading => { - _ = @cmpxchgStrong(bool, &self._closed, false, true, .monotonic, .monotonic); + self.close(); + } return; }, - else => { - self.closeWithCode(1002); - return err; + .pong => if (comptime std.meta.hasFn(Handler, "handlePong")) { + try handler.handlePong(); }, + } + } + } + + pub fn read(self: *Client) !?proto.Message { + var reader = &self._reader; + const stream = &self.stream; + + while (true) { + // try to read a message from our buffer first, before trying to + // get more data from the socket. + const has_more, const message = reader.read() catch |err| { + self.closeWithCode(1002); + return err; + } orelse { + reader.fill(stream) catch |err| switch (err) { + error.WouldBlock => return null, + error.Closed, error.ConnectionResetByPeer, error.BrokenPipe, error.NotOpenForReading => { + _ = @cmpxchgStrong(bool, &self._closed, false, true, .monotonic, .monotonic); + return error.Closed; + }, + else => { + self.closeWithCode(1002); + return err; + }, + }; + continue; }; + + _ = has_more; + return message; } } + pub fn done(self:* Client, message: proto.Message) void { + self.reader.done(message.type); + } + pub fn readLoopInNewThread(self: *Client, h: anytype) !std.Thread { return std.Thread.spawn(.{}, readLoopOwnedThread, .{ self, h }); } @@ -219,6 +230,14 @@ pub const Client = struct { self.readLoop(h) catch {}; } + pub fn writeTimeout(self: *const Client, ms: u32) !void { + return self.stream.writeTimeout(ms); + } + + pub fn readTimeout(self: *const Client, ms: u32) !void { + return self.stream.readTimeout(ms); + } + pub fn write(self: *Client, data: []u8) !void { return self.writeFrame(.text, data); } @@ -311,31 +330,27 @@ pub const Stream = struct { const zero_timeout = std.mem.toBytes(posix.timeval{ .sec = 0, .usec = 0 }); pub fn writeTimeout(self: *const Stream, ms: u32) !void { - if (ms == 0) { - return self.setsockopt(posix.SO.SNDTIMEO, &zero_timeout); - } + return self.setTimeout(posix.SO.SNDTIMEO, ms); + } - const timeout = std.mem.toBytes(posix.timeval{ - .sec = @intCast(@divTrunc(ms, 1000)), - .usec = @intCast(@mod(ms, 1000) * 1000), - }); - return self.setsockopt(posix.SO.SNDTIMEO, &timeout); + pub fn readTimeout(self: *const Stream, ms: u32) !void { + return self.setTimeout(posix.SO.RCVTIMEO, ms); } - pub fn receiveTimeout(self: *const Stream, ms: u32) !void { + fn setTimeout(self: *const Stream, opt_name: u32, ms: u32) !void { if (ms == 0) { - return self.setsockopt(posix.SO.RCVTIMEO, &zero_timeout); + return self.setsockopt(opt_name, &zero_timeout); } const timeout = std.mem.toBytes(posix.timeval{ .sec = @intCast(@divTrunc(ms, 1000)), .usec = @intCast(@mod(ms, 1000) * 1000), }); - return self.setsockopt(posix.SO.RCVTIMEO, &timeout); + return self.setsockopt(opt_name, &timeout); } - pub fn setsockopt(self: *const Stream, optname: u32, value: []const u8) !void { - return posix.setsockopt(self.stream.handle, posix.SOL.SOCKET, optname, value); + pub fn setsockopt(self: *const Stream, opt_name: u32, value: []const u8) !void { + return posix.setsockopt(self.stream.handle, posix.SOL.SOCKET, opt_name, value); } }; @@ -402,7 +417,7 @@ fn readHandshakeReply(buf: []u8, key: []const u8, opts: *const Client.HandshakeO const timeout_ms = opts.timeout_ms; const deadline = std.time.milliTimestamp() + timeout_ms; - try stream.receiveTimeout(timeout_ms); + try stream.readTimeout(timeout_ms); var pos: usize = 0; var line_start: usize = 0; @@ -422,7 +437,7 @@ fn readHandshakeReply(buf: []u8, key: []const u8, opts: *const Client.HandshakeO } const over_read = pos - (line_start + 2); std.mem.copyForwards(u8, buf[0..over_read], buf[line_start + 2 .. pos]); - try stream.receiveTimeout(0); + try stream.readTimeout(0); return over_read; } diff --git a/support/autobahn/client/main.zig b/support/autobahn/client/main.zig index 95d385a..68d8829 100644 --- a/support/autobahn/client/main.zig +++ b/support/autobahn/client/main.zig @@ -1,6 +1,8 @@ const std = @import("std"); const websocket = @import("websocket"); +const Allocator = std.mem.Allocator; + pub fn main() !void { var gpa = std.heap.GeneralPurposeAllocator(.{}){}; defer _ = gpa.detectLeaks(); @@ -66,25 +68,9 @@ pub fn main() !void { // if (!std.mem.eql(u8, case, "1.1.1")) continue; std.debug.print("running case: {s}\n", .{case}); - const path = try std.fmt.allocPrint(allocator, "/runCase?casetuple={s}&agent=websocket.zig", .{case}); - defer allocator.free(path); - - var client = try websocket.Client.init(allocator, .{ - .port = 9001, - .host = "localhost", - .buffer_provider = &buffer_provider, - }); - defer client.deinit(); - const handler = Handler{.client = &client}; - - client.handshake(path, .{ - .timeout_ms = 500, - .headers = "host: localhost:9001\r\n", - }) catch continue; - - // optional, if you want to set a SO_SNDTIMEO on the socket - try client.stream.writeTimeout(5000); - client.readLoop(handler) catch |err| switch (err) { + var handler = try Handler.init(allocator, &buffer_provider, case); + defer handler.deinit(); + handler.readLoop() catch |err| switch (err) { // Each of these error cases reqiuer that we close the connection, as per // the spec. You probably just want to re-connect. But, for the autobahn tests // we don't want to shutdown, since autobahn is testing these invalid cases. @@ -116,9 +102,36 @@ fn updateReport(allocator: std.mem.Allocator) !void { } const Handler = struct { - client: *websocket.Client, + client: websocket.Client, + + fn init(allocator: Allocator, buffer_provider: *websocket.buffer.Provider, case: []const u8) !Handler { + const path = try std.fmt.allocPrint(allocator, "/runCase?casetuple={s}&agent=websocket.zig", .{case}); + defer allocator.free(path); + + var client = try websocket.Client.init(allocator, .{ + .port = 9001, + .host = "localhost", + .buffer_provider = buffer_provider, + }); + errdefer client.deinit(); + try client.handshake(path, .{ + .timeout_ms = 500, + .headers = "host: localhost:9001\r\n", + }); + return .{ + .client = client, + }; + } + + fn deinit(self: *Handler) void { + self.client.deinit(); + } + + fn readLoop(self: *Handler) !void { + return self.client.readLoop(self); + } - pub fn handleMessage(self: Handler, data: []const u8, tpe: websocket.Message.TextType) !void { + pub fn handleMessage(self: *Handler, data: []const u8, tpe: websocket.Message.TextType) !void { switch (tpe) { .binary => try self.client.writeBin(@constCast(data)), .text => {