Skip to content

Commit

Permalink
make websocket.Client more composable
Browse files Browse the repository at this point in the history
  • Loading branch information
karlseguin committed Aug 8, 2024
1 parent fc97183 commit 31f5ca0
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 93 deletions.
159 changes: 87 additions & 72 deletions src/client/client.zig
Original file line number Diff line number Diff line change
Expand Up @@ -138,79 +138,90 @@ pub const Client = struct {

pub fn readLoop(self: *Client, handler: anytype) !void {
const H = @TypeOf(handler);
const Handler = switch (@typeInfo(H)) {
.Struct => H,
.Pointer => |ptr| ptr.child,
else => @compileError("readLoop handler must be a struct, got: " ++ @tagName(@typeInfo(H))),
};

var reader = &self._reader;
const stream = &self.stream;

defer if (comptime std.meta.hasFn(H, "close")) {
handler.close();
};

// block until we have data
try self.readTimeout(0);

while (true) {
// for a single fill, we might have multiple messages to process
while (true) {
const has_more, const message = reader.read() catch |err| {
self.closeWithCode(1002);
return err;
} orelse break; // orelse, we don't have enough data, so break out of the inner loop, and go get more data from the socket in the outer loop

const message_type = message.type;
defer reader.done(message_type);

switch (message_type) {
.text, .binary => {
switch (comptime @typeInfo(@TypeOf(H.handleMessage)).Fn.params.len) {
2 => try handler.handleMessage(message.data),
3 => try handler.handleMessage(message.data, if (message_type == .text) .text else .binary),
else => @compileError(@typeName(H) ++ ".handleMessage must accept 2 or 3 parameters"),
}
},
.ping => if (comptime std.meta.hasFn(H, "handlePing")) {
try handler.handlePing(message.data);
const message = (try self.read()) orelse unreachable;
const message_type = message.type;
defer reader.done(message_type);

switch (message_type) {
.text, .binary => {
switch (comptime @typeInfo(@TypeOf(Handler.handleMessage)).Fn.params.len) {
2 => try handler.handleMessage(message.data),
3 => try handler.handleMessage(message.data, if (message_type == .text) .text else .binary),
else => @compileError(@typeName(Handler) ++ ".handleMessage must accept 2 or 3 parameters"),
}
},
.ping => if (comptime std.meta.hasFn(Handler, "handlePing")) {
try handler.handlePing(message.data);
} else {
// @constCast is safe because we know message.data points to
// reader.buffer.buf, which we own and which can be mutated
try self.writeFrame(.pong, @constCast(message.data));
},
.close => {
if (comptime std.meta.hasFn(Handler, "handleClose")) {
try handler.handleClose(message.data);
} else {
// @constCast is safe because we know message.data points to
// reader.buffer.buf, which we own and which can be mutated
try self.writeFrame(.pong, @constCast(message.data));
},
.close => {
if (comptime std.meta.hasFn(H, "handleClose")) {
try handler.handleClose(message.data);
} else {
self.close();
}
return;
},
.pong => if (comptime std.meta.hasFn(H, "handlePong")) {
try handler.handlePong();
},
}

if (has_more == false) {
// break out of the inner loop, and back into the outer loop
// to fill the buffer with more data from the socket
break;
}
}

// Wondering why we don't call fill at the start of the outer loop?
// When we read the handshake response, we might already have a message
// (or part of a message) to process. So we always want to run reader.next()
// first, to process any initial message we might have.
// If we call fill first, we might block forever, despite there being
// an initial message waiting.
reader.fill(stream) catch |err| switch (err) {
error.Closed, error.ConnectionResetByPeer, error.BrokenPipe, error.NotOpenForReading => {
_ = @cmpxchgStrong(bool, &self._closed, false, true, .monotonic, .monotonic);
self.close();
}
return;
},
else => {
self.closeWithCode(1002);
return err;
.pong => if (comptime std.meta.hasFn(Handler, "handlePong")) {
try handler.handlePong();
},
}
}
}

pub fn read(self: *Client) !?proto.Message {
var reader = &self._reader;
const stream = &self.stream;

while (true) {
// try to read a message from our buffer first, before trying to
// get more data from the socket.
const has_more, const message = reader.read() catch |err| {
self.closeWithCode(1002);
return err;
} orelse {
reader.fill(stream) catch |err| switch (err) {
error.WouldBlock => return null,
error.Closed, error.ConnectionResetByPeer, error.BrokenPipe, error.NotOpenForReading => {
_ = @cmpxchgStrong(bool, &self._closed, false, true, .monotonic, .monotonic);
return error.Closed;
},
else => {
self.closeWithCode(1002);
return err;
},
};
continue;
};

_ = has_more;
return message;
}
}

pub fn done(self:* Client, message: proto.Message) void {
self.reader.done(message.type);
}

pub fn readLoopInNewThread(self: *Client, h: anytype) !std.Thread {
return std.Thread.spawn(.{}, readLoopOwnedThread, .{ self, h });
}
Expand All @@ -219,6 +230,14 @@ pub const Client = struct {
self.readLoop(h) catch {};
}

pub fn writeTimeout(self: *const Client, ms: u32) !void {
return self.stream.writeTimeout(ms);
}

pub fn readTimeout(self: *const Client, ms: u32) !void {
return self.stream.readTimeout(ms);
}

pub fn write(self: *Client, data: []u8) !void {
return self.writeFrame(.text, data);
}
Expand Down Expand Up @@ -311,31 +330,27 @@ pub const Stream = struct {

const zero_timeout = std.mem.toBytes(posix.timeval{ .sec = 0, .usec = 0 });
pub fn writeTimeout(self: *const Stream, ms: u32) !void {
if (ms == 0) {
return self.setsockopt(posix.SO.SNDTIMEO, &zero_timeout);
}
return self.setTimeout(posix.SO.SNDTIMEO, ms);
}

const timeout = std.mem.toBytes(posix.timeval{
.sec = @intCast(@divTrunc(ms, 1000)),
.usec = @intCast(@mod(ms, 1000) * 1000),
});
return self.setsockopt(posix.SO.SNDTIMEO, &timeout);
pub fn readTimeout(self: *const Stream, ms: u32) !void {
return self.setTimeout(posix.SO.RCVTIMEO, ms);
}

pub fn receiveTimeout(self: *const Stream, ms: u32) !void {
fn setTimeout(self: *const Stream, opt_name: u32, ms: u32) !void {
if (ms == 0) {
return self.setsockopt(posix.SO.RCVTIMEO, &zero_timeout);
return self.setsockopt(opt_name, &zero_timeout);
}

const timeout = std.mem.toBytes(posix.timeval{
.sec = @intCast(@divTrunc(ms, 1000)),
.usec = @intCast(@mod(ms, 1000) * 1000),
});
return self.setsockopt(posix.SO.RCVTIMEO, &timeout);
return self.setsockopt(opt_name, &timeout);
}

pub fn setsockopt(self: *const Stream, optname: u32, value: []const u8) !void {
return posix.setsockopt(self.stream.handle, posix.SOL.SOCKET, optname, value);
pub fn setsockopt(self: *const Stream, opt_name: u32, value: []const u8) !void {
return posix.setsockopt(self.stream.handle, posix.SOL.SOCKET, opt_name, value);
}
};

Expand Down Expand Up @@ -402,7 +417,7 @@ fn readHandshakeReply(buf: []u8, key: []const u8, opts: *const Client.HandshakeO

const timeout_ms = opts.timeout_ms;
const deadline = std.time.milliTimestamp() + timeout_ms;
try stream.receiveTimeout(timeout_ms);
try stream.readTimeout(timeout_ms);

var pos: usize = 0;
var line_start: usize = 0;
Expand All @@ -422,7 +437,7 @@ fn readHandshakeReply(buf: []u8, key: []const u8, opts: *const Client.HandshakeO
}
const over_read = pos - (line_start + 2);
std.mem.copyForwards(u8, buf[0..over_read], buf[line_start + 2 .. pos]);
try stream.receiveTimeout(0);
try stream.readTimeout(0);
return over_read;
}

Expand Down
55 changes: 34 additions & 21 deletions support/autobahn/client/main.zig
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
const std = @import("std");
const websocket = @import("websocket");

const Allocator = std.mem.Allocator;

pub fn main() !void {
var gpa = std.heap.GeneralPurposeAllocator(.{}){};
defer _ = gpa.detectLeaks();
Expand Down Expand Up @@ -66,25 +68,9 @@ pub fn main() !void {
// if (!std.mem.eql(u8, case, "1.1.1")) continue;
std.debug.print("running case: {s}\n", .{case});

const path = try std.fmt.allocPrint(allocator, "/runCase?casetuple={s}&agent=websocket.zig", .{case});
defer allocator.free(path);

var client = try websocket.Client.init(allocator, .{
.port = 9001,
.host = "localhost",
.buffer_provider = &buffer_provider,
});
defer client.deinit();
const handler = Handler{.client = &client};

client.handshake(path, .{
.timeout_ms = 500,
.headers = "host: localhost:9001\r\n",
}) catch continue;

// optional, if you want to set a SO_SNDTIMEO on the socket
try client.stream.writeTimeout(5000);
client.readLoop(handler) catch |err| switch (err) {
var handler = try Handler.init(allocator, &buffer_provider, case);
defer handler.deinit();
handler.readLoop() catch |err| switch (err) {
// Each of these error cases reqiuer that we close the connection, as per
// the spec. You probably just want to re-connect. But, for the autobahn tests
// we don't want to shutdown, since autobahn is testing these invalid cases.
Expand Down Expand Up @@ -116,9 +102,36 @@ fn updateReport(allocator: std.mem.Allocator) !void {
}

const Handler = struct {
client: *websocket.Client,
client: websocket.Client,

fn init(allocator: Allocator, buffer_provider: *websocket.buffer.Provider, case: []const u8) !Handler {
const path = try std.fmt.allocPrint(allocator, "/runCase?casetuple={s}&agent=websocket.zig", .{case});
defer allocator.free(path);

var client = try websocket.Client.init(allocator, .{
.port = 9001,
.host = "localhost",
.buffer_provider = buffer_provider,
});
errdefer client.deinit();
try client.handshake(path, .{
.timeout_ms = 500,
.headers = "host: localhost:9001\r\n",
});
return .{
.client = client,
};
}

fn deinit(self: *Handler) void {
self.client.deinit();
}

fn readLoop(self: *Handler) !void {
return self.client.readLoop(self);
}

pub fn handleMessage(self: Handler, data: []const u8, tpe: websocket.Message.TextType) !void {
pub fn handleMessage(self: *Handler, data: []const u8, tpe: websocket.Message.TextType) !void {
switch (tpe) {
.binary => try self.client.writeBin(@constCast(data)),
.text => {
Expand Down

0 comments on commit 31f5ca0

Please sign in to comment.