Skip to content

Commit

Permalink
handle close during handshake
Browse files Browse the repository at this point in the history
  • Loading branch information
naoki9911 committed Dec 31, 2023
1 parent 6b20657 commit 0f1c747
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 34 deletions.
35 changes: 20 additions & 15 deletions src/client.zig
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ pub fn TLSClientImpl(comptime ReaderType: type, comptime WriterType: type, compt
// payload protection
cipher_suites: ArrayList(msg.CipherSuite),
ks: key.KeyScheduler,
hs_protector: RecordPayloadProtector,
ap_protector: RecordPayloadProtector,
hs_protector: ?RecordPayloadProtector,
ap_protector: ?RecordPayloadProtector,

// certificate
allow_self_signed: bool = false,
Expand Down Expand Up @@ -183,8 +183,8 @@ pub fn TLSClientImpl(comptime ReaderType: type, comptime WriterType: type, compt
.key_shares = ArrayList(NamedGroup).init(allocator),
.cipher_suites = ArrayList(msg.CipherSuite).init(allocator),
.ks = undefined,
.hs_protector = undefined,
.ap_protector = undefined,
.hs_protector = null,
.ap_protector = null,
.rootCA = crypto.root.RootCA.init(allocator),
.signature_schems = ArrayList(signature_scheme.SignatureScheme).init(allocator),
.cert_pubkeys = ArrayList(crypto.key.PublicKey).init(allocator),
Expand Down Expand Up @@ -474,7 +474,7 @@ pub fn TLSClientImpl(comptime ReaderType: type, comptime WriterType: type, compt
const recv_record = try TLSCipherText.decode(self.reader, t, self.allocator);
defer recv_record.deinit();

var plain_record = try self.hs_protector.decrypt(recv_record, self.allocator);
var plain_record = try self.hs_protector.?.decrypt(recv_record, self.allocator);
defer plain_record.deinit();

if (plain_record.content_type == .alert) {
Expand All @@ -499,7 +499,7 @@ pub fn TLSClientImpl(comptime ReaderType: type, comptime WriterType: type, compt
}

self.write_engine = .{
.protector = &self.ap_protector,
.protector = &self.ap_protector.?,
.ks = &self.ks,
.write_buffer = &self.write_buffer,
.allocator = self.allocator,
Expand Down Expand Up @@ -537,9 +537,14 @@ pub fn TLSClientImpl(comptime ReaderType: type, comptime WriterType: type, compt
}
}

// TODO: Send close notify as plain text
if (self.ap_protector == null) {
return;
}

// close connection
const close_notify = Content{ .alert = Alert{ .level = .warning, .description = .close_notify } };
_ = try self.ap_protector.encryptFromMessageAndWrite(
_ = try self.ap_protector.?.encryptFromMessageAndWrite(
close_notify,
self.allocator,
self.write_buffer.writer(),
Expand All @@ -559,7 +564,7 @@ pub fn TLSClientImpl(comptime ReaderType: type, comptime WriterType: type, compt
const recv_record = try TLSCipherText.decode(self.reader, t, self.allocator);
defer recv_record.deinit();

const plain_record = try self.ap_protector.decrypt(recv_record, self.allocator);
const plain_record = try self.ap_protector.?.decrypt(recv_record, self.allocator);
defer plain_record.deinit();

if (plain_record.content_type != .alert) {
Expand Down Expand Up @@ -756,7 +761,7 @@ pub fn TLSClientImpl(comptime ReaderType: type, comptime WriterType: type, compt
if (self.state == .SEND_FINISHED) {
// generate keys
try self.ks.generateApplicationSecrets(self.msgs_stream.getWritten());
self.ap_protector = RecordPayloadProtector.init(self.hs_protector.aead, self.ks.secret.c_ap_keys, self.ks.secret.s_ap_keys);
self.ap_protector = RecordPayloadProtector.init(self.hs_protector.?.aead, self.ks.secret.c_ap_keys, self.ks.secret.s_ap_keys);

if (self.early_data_ok) {
const eoed = Content{ .handshake = Handshake{ .end_of_early_data = [0]u8{} } };
Expand All @@ -773,7 +778,7 @@ pub fn TLSClientImpl(comptime ReaderType: type, comptime WriterType: type, compt
defer hs_c_finished.deinit();
_ = try hs_c_finished.encode(self.msgs_stream.writer());

_ = try self.hs_protector.encryptFromMessageAndWrite(hs_c_finished, self.allocator, writer);
_ = try self.hs_protector.?.encryptFromMessageAndWrite(hs_c_finished, self.allocator, writer);

self.state = .CONNECTED;
}
Expand Down Expand Up @@ -959,8 +964,8 @@ pub fn TLSClientImpl(comptime ReaderType: type, comptime WriterType: type, compt
pub fn handleKeyUpdate(self: *Self, ku: KeyUpdate) !void {
// update decoding key(server key)
try self.ks.updateServerSecrets();
self.ap_protector.dec_keys = self.ks.secret.s_ap_keys;
self.ap_protector.dec_cnt = 0;
self.ap_protector.?.dec_keys = self.ks.secret.s_ap_keys;
self.ap_protector.?.dec_cnt = 0;

switch (ku.request_update) {
.update_not_requested => {
Expand All @@ -971,13 +976,13 @@ pub fn TLSClientImpl(comptime ReaderType: type, comptime WriterType: type, compt
const update = Content{ .handshake = .{ .key_update = .{ .request_update = .update_not_requested } } };
defer update.deinit();

_ = try self.ap_protector.encryptFromMessageAndWrite(update, self.allocator, self.write_buffer.writer());
_ = try self.ap_protector.?.encryptFromMessageAndWrite(update, self.allocator, self.write_buffer.writer());
try self.write_buffer.flush();

// update encoding key(clieny key)
try self.ks.updateClientSecrets();
self.ap_protector.enc_keys = self.ks.secret.c_ap_keys;
self.ap_protector.enc_cnt = 0;
self.ap_protector.?.enc_keys = self.ks.secret.c_ap_keys;
self.ap_protector.?.enc_cnt = 0;
},
}
}
Expand Down
5 changes: 3 additions & 2 deletions src/common.zig
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ pub fn ReadEngine(comptime Entity: type, comptime et: EntityType) type {
return try msg_stream.getPos();
}

const updated = try checkAndUpdateKey(&self.entity.ap_protector, &self.entity.ks, &self.entity.write_buffer, self.entity.allocator, et);
const updated = try checkAndUpdateKey(&self.entity.ap_protector.?, &self.entity.ks, &self.entity.write_buffer, self.entity.allocator, et);

if (updated) {
log.debug("KeyUpdate updated_request has been sent", .{});
}
Expand All @@ -155,7 +156,7 @@ pub fn ReadEngine(comptime Entity: type, comptime et: EntityType) type {
const recv_record = try TLSCipherText.decode(self.entity.reader, t, self.entity.allocator);
defer recv_record.deinit();

const plain_record = try self.entity.ap_protector.decrypt(recv_record, self.entity.allocator);
const plain_record = try self.entity.ap_protector.?.decrypt(recv_record, self.entity.allocator);
defer plain_record.deinit();

if (plain_record.content_type != .application_data) {
Expand Down
1 change: 1 addition & 0 deletions src/protector.zig
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ pub const RecordPayloadProtector = struct {

i = 0;
while (i < @sizeOf(u64)) : (i += 1) {
std.debug.print("nonce={} i={}", .{ self.aead.nonce_length, i });
nonce.slice()[nonce.len - i - 1] = @as(u8, @intCast((count >> (@as(u6, @intCast(i * 8)))) & 0xFF));
}

Expand Down
39 changes: 22 additions & 17 deletions src/server.zig
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,8 @@ pub fn TLSStreamImpl(comptime ReaderType: type, comptime WriterType: type, compt
// payload protection
cipher_suite: msg.CipherSuite,
ks: key.KeyScheduler,
hs_protector: RecordPayloadProtector,
ap_protector: RecordPayloadProtector,
hs_protector: ?RecordPayloadProtector,
ap_protector: ?RecordPayloadProtector,

// misc
allocator: std.mem.Allocator,
Expand Down Expand Up @@ -304,8 +304,8 @@ pub fn TLSStreamImpl(comptime ReaderType: type, comptime WriterType: type, compt
.secp256r1_key = try P256.KeyPair.fromSecretKey(try P256.SecretKey.fromBytes(secp256r1_priv_key)),
.cipher_suite = undefined,
.ks = undefined,
.hs_protector = undefined,
.ap_protector = undefined,
.hs_protector = null,
.ap_protector = null,
.allocator = allocator,
};

Expand Down Expand Up @@ -550,7 +550,7 @@ pub fn TLSStreamImpl(comptime ReaderType: type, comptime WriterType: type, compt
defer c_record.deinit();

// client may send early_data. Currently, server does not accept early_data.
var p_record = self.hs_protector.decrypt(c_record, self.allocator) catch |err| {
var p_record = self.hs_protector.?.decrypt(c_record, self.allocator) catch |err| {
switch (err) {
error.AuthenticationFailed => continue,
else => return err,
Expand Down Expand Up @@ -589,7 +589,7 @@ pub fn TLSStreamImpl(comptime ReaderType: type, comptime WriterType: type, compt
}

self.write_engine = .{
.protector = &self.ap_protector,
.protector = &self.ap_protector.?,
.ks = &self.ks,
.write_buffer = &self.write_buffer,
.allocator = self.allocator,
Expand Down Expand Up @@ -620,8 +620,13 @@ pub fn TLSStreamImpl(comptime ReaderType: type, comptime WriterType: type, compt
}
}

// TODO: Send close notify as plain text
if (self.ap_protector == null) {
return;
}

const close_notify = Content{ .alert = Alert{ .level = .warning, .description = .close_notify } };
_ = self.ap_protector.encryptFromMessageAndWrite(
_ = self.ap_protector.?.encryptFromMessageAndWrite(
close_notify,
self.allocator,
self.write_buffer.writer(),
Expand Down Expand Up @@ -882,7 +887,7 @@ pub fn TLSStreamImpl(comptime ReaderType: type, comptime WriterType: type, compt
defer cont_ee.deinit();
_ = try cont_ee.encode(self.msgs_stream.writer());

_ = try self.hs_protector.encryptFromMessageAndWrite(cont_ee, self.allocator, self.write_buffer.writer());
_ = try self.hs_protector.?.encryptFromMessageAndWrite(cont_ee, self.allocator, self.write_buffer.writer());

log.debug("EncryptedExtensions has been written to send buffer", .{});
}
Expand All @@ -903,7 +908,7 @@ pub fn TLSStreamImpl(comptime ReaderType: type, comptime WriterType: type, compt
}
var cont_c = Content{ .handshake = .{ .certificate = c } };
_ = try cont_c.encode(self.msgs_stream.writer());
_ = try self.hs_protector.encryptFromMessageAndWrite(cont_c, self.allocator, self.write_buffer.writer());
_ = try self.hs_protector.?.encryptFromMessageAndWrite(cont_c, self.allocator, self.write_buffer.writer());

log.debug("Certificate has been written to send buffer", .{});
}
Expand Down Expand Up @@ -957,7 +962,7 @@ pub fn TLSStreamImpl(comptime ReaderType: type, comptime WriterType: type, compt
const cont_cv = Content{ .handshake = .{ .certificate_verify = cv } };
defer cont_cv.deinit();
_ = try cont_cv.encode(self.msgs_stream.writer());
_ = try self.hs_protector.encryptFromMessageAndWrite(cont_cv, self.allocator, self.write_buffer.writer());
_ = try self.hs_protector.?.encryptFromMessageAndWrite(cont_cv, self.allocator, self.write_buffer.writer());

log.debug("CertificateVerify has been written to send buffer", .{});
}
Expand All @@ -968,7 +973,7 @@ pub fn TLSStreamImpl(comptime ReaderType: type, comptime WriterType: type, compt
defer cont_fin.deinit();
_ = try cont_fin.encode(self.msgs_stream.writer());

_ = try self.hs_protector.encryptFromMessageAndWrite(cont_fin, self.allocator, self.write_buffer.writer());
_ = try self.hs_protector.?.encryptFromMessageAndWrite(cont_fin, self.allocator, self.write_buffer.writer());

log.debug("Finished has been written to send buffer", .{});
}
Expand Down Expand Up @@ -1009,15 +1014,15 @@ pub fn TLSStreamImpl(comptime ReaderType: type, comptime WriterType: type, compt
const c = Content{ .handshake = Handshake{ .new_session_ticket = nst } };
defer c.deinit();

_ = try self.ap_protector.encryptFromMessageAndWrite(c, self.allocator, self.write_buffer.writer());
_ = try self.ap_protector.?.encryptFromMessageAndWrite(c, self.allocator, self.write_buffer.writer());
log.debug("NewSessionTicket has been written to send buffer", .{});
}

pub fn handleKeyUpdate(self: *Self, ku: KeyUpdate) !void {
// update decoding key(client key)
try self.ks.updateClientSecrets();
self.ap_protector.dec_keys = self.ks.secret.c_ap_keys;
self.ap_protector.dec_cnt = 0;
self.ap_protector.?.dec_keys = self.ks.secret.c_ap_keys;
self.ap_protector.?.dec_cnt = 0;

switch (ku.request_update) {
.update_not_requested => {
Expand All @@ -1028,13 +1033,13 @@ pub fn TLSStreamImpl(comptime ReaderType: type, comptime WriterType: type, compt
const update = Content{ .handshake = .{ .key_update = .{ .request_update = .update_not_requested } } };
defer update.deinit();

_ = try self.ap_protector.encryptFromMessageAndWrite(update, self.allocator, self.write_buffer.writer());
_ = try self.ap_protector.?.encryptFromMessageAndWrite(update, self.allocator, self.write_buffer.writer());
try self.write_buffer.flush();

// update encoding key(server key)
try self.ks.updateServerSecrets();
self.ap_protector.enc_keys = self.ks.secret.s_ap_keys;
self.ap_protector.enc_cnt = 0;
self.ap_protector.?.enc_keys = self.ks.secret.s_ap_keys;
self.ap_protector.?.enc_cnt = 0;
},
}
}
Expand Down

0 comments on commit 0f1c747

Please sign in to comment.