From 5cc416bdb1dcb61bed1f9725d07f8b1dab091645 Mon Sep 17 00:00:00 2001 From: Karl Seguin Date: Tue, 18 Feb 2025 08:23:36 +0800 Subject: [PATCH 1/3] Upgrade to support latest Zig (0.14.0-dev.3239+d7b93c787) --- .github/workflows/zig-test.yml | 2 +- build.zig | 2 +- build.zig.zon | 2 +- src/std/http/Client.zig | 2 +- src/tls.zig/cipher.zig | 33 +++- src/tls.zig/connection.zig | 308 ++++++++++++++++++++++------- src/tls.zig/handshake_client.zig | 228 +++++++++++++++++++-- src/tls.zig/handshake_common.zig | 8 +- src/tls.zig/handshake_server.zig | 216 ++++++++++++++------ src/tls.zig/record.zig | 82 +++++++- src/tls.zig/{main.zig => root.zig} | 47 +++-- src/tls.zig/rsa/der.zig | 4 +- src/tls.zig/testdata/tls13.zig | 5 + 13 files changed, 747 insertions(+), 192 deletions(-) rename src/tls.zig/{main.zig => root.zig} (53%) diff --git a/.github/workflows/zig-test.yml b/.github/workflows/zig-test.yml index 2bd42bf..ff03a95 100644 --- a/.github/workflows/zig-test.yml +++ b/.github/workflows/zig-test.yml @@ -1,7 +1,7 @@ name: zig-test env: - ZIG_VERSION: 0.13.0 + ZIG_VERSION: 0.14.0-dev.3239+d7b93c787 on: push: diff --git a/build.zig b/build.zig index 743d99a..00532cc 100644 --- a/build.zig +++ b/build.zig @@ -17,7 +17,7 @@ pub fn build(b: *std.Build) void { const tests = b.addTest(.{ .root_source_file = b.path("src/tests.zig"), - .test_runner = b.path("src/test_runner.zig"), + .test_runner = .{ .path = b.path("src/test_runner.zig"), .mode = .simple }, .target = target, .optimize = optimize, }); diff --git a/build.zig.zon b/build.zig.zon index daa586c..a9292c2 100644 --- a/build.zig.zon +++ b/build.zig.zon @@ -1,7 +1,7 @@ .{ .name = "zig-async-io", .version = "0.1.0", - .minimum_zig_version = "0.13.0", + .minimum_zig_version = "0.14.0", .paths = .{ "build.zig", "build.zig.zon", diff --git a/src/std/http/Client.zig b/src/std/http/Client.zig index 2bab4e7..849023c 100644 --- a/src/std/http/Client.zig +++ b/src/std/http/Client.zig @@ -18,7 +18,7 @@ const use_vectors = builtin.zig_backend != .stage2_x86_64; const Client = @This(); const proto = @import("protocol.zig"); -const tls23 = @import("../../tls.zig/main.zig"); +const tls23 = @import("../../tls.zig/root.zig"); const VecPut = @import("../../tls.zig/connection.zig").VecPut; const GenericStack = @import("../../stack.zig").Stack; pub const IO = @import("../../io.zig").IO; diff --git a/src/tls.zig/cipher.zig b/src/tls.zig/cipher.zig index dbf4a07..45eb8bf 100644 --- a/src/tls.zig/cipher.zig +++ b/src/tls.zig/cipher.zig @@ -182,6 +182,12 @@ pub const Cipher = union(CipherSuite) { }; } + pub fn recordLen(c: *Cipher, cleartext_len: usize) usize { + return switch (c.*) { + inline else => |*f| f.recordLen(cleartext_len), + }; + } + pub fn encryptSeq(c: Cipher) u64 { return switch (c) { inline else => |f| f.encrypt_seq, @@ -276,6 +282,10 @@ fn Aead12Type(comptime AeadType: type) type { return buf[0..record_len]; } + pub fn recordLen(_: Self, cleartext_len: usize) usize { + return record.header_len + explicit_iv_len + cleartext_len + auth_tag_len; + } + /// Decrypts payload into cleartext. Returns tls record content type and /// cleartext. /// Accepts tls record header and payload: @@ -362,6 +372,10 @@ fn Aead12ChaChaType(comptime AeadType: type) type { return buf[0..record_len]; } + pub fn recordLen(_: Self, cleartext_len: usize) usize { + return record.header_len + cleartext_len + auth_tag_len; + } + /// Decrypts payload into cleartext. Returns tls record content type and /// cleartext. /// Accepts tls record header and payload: @@ -478,6 +492,11 @@ fn Aead13Type(comptime AeadType: type, comptime Hash: type) type { return buf[0..record_len]; } + pub fn recordLen(_: Self, cleartext_len: usize) usize { + const payload_len = cleartext_len + 1 + auth_tag_len; + return record.header_len + payload_len; + } + /// Decrypts payload into cleartext. Returns tls record content type and /// cleartext. /// Accepts tls record header and payload: @@ -569,8 +588,7 @@ fn CbcType(comptime BlockCipher: type, comptime HashType: type) type { content_type: proto.ContentType, cleartext: []const u8, ) ![]const u8 { - const max_record_len = record.header_len + iv_len + cleartext.len + mac_len + max_padding; - if (buf.len < max_record_len) return error.BufferOverflow; + if (buf.len < self.recordLen(cleartext.len)) return error.BufferOverflow; const cleartext_idx = record.header_len + iv_len; // position of cleartext in buf @memcpy(buf[cleartext_idx..][0..cleartext.len], cleartext); @@ -607,6 +625,12 @@ fn CbcType(comptime BlockCipher: type, comptime HashType: type) type { return buf[0 .. record.header_len + iv_len + ciphertext.len]; } + pub fn recordLen(_: Self, cleartext_len: usize) usize { + const unpadded_len = cleartext_len + mac_len; + const padded_len = paddedLength(unpadded_len); + return record.header_len + iv_len + padded_len; + } + /// Decrypts payload into cleartext. Returns tls record content type and /// cleartext. pub fn decrypt( @@ -678,7 +702,7 @@ fn additionalData(seq: u64, content_type: proto.ContentType, payload_len: usize) // https://ciphersuite.info/page/faq/ // https://github.com/golang/go/blob/73186ba00251b3ed8baaab36e4f5278c7681155b/src/crypto/tls/cipher_suites.go#L226 pub const cipher_suites = struct { - const tls12_secure = if (crypto.core.aes.has_hardware_support) [_]CipherSuite{ + pub const tls12_secure = if (crypto.core.aes.has_hardware_support) [_]CipherSuite{ // recommended .ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, .ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, @@ -698,7 +722,7 @@ pub const cipher_suites = struct { .ECDHE_RSA_WITH_AES_128_GCM_SHA256, .ECDHE_RSA_WITH_AES_256_GCM_SHA384, }; - const tls12_week = [_]CipherSuite{ + pub const tls12_week = [_]CipherSuite{ // week .ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, .ECDHE_ECDSA_WITH_AES_256_CBC_SHA384, @@ -988,6 +1012,7 @@ fn encryptDecrypt(client_cipher: *Cipher, server_cipher: *Cipher) !void { }; }, }; + try testing.expectEqual(client_cipher.recordLen(cleartext.len), encrypted.len); try testing.expectEqual(expected_encrypted_len, encrypted.len); // decrypt const content_type, const decrypted = try server_cipher.decrypt(&buf, Record.init(encrypted)); diff --git a/src/tls.zig/connection.zig b/src/tls.zig/connection.zig index 7a6afcb..55e27fb 100644 --- a/src/tls.zig/connection.zig +++ b/src/tls.zig/connection.zig @@ -1,4 +1,5 @@ const std = @import("std"); +const mem = std.mem; const assert = std.debug.assert; const proto = @import("protocol.zig"); @@ -227,55 +228,6 @@ pub fn Connection(comptime Stream: type) type { return vp.total; } - fn onWriteAll(ctx: *Ctx, res: anyerror!void) anyerror!void { - res catch |err| return ctx.pop(err); - - if (ctx._tls_write_bytes.len - ctx._tls_write_index > 0) { - const rec = ctx.conn().tls_client.prepareRecord(ctx.stream(), ctx) catch |err| return ctx.pop(err); - return ctx.stream().async_writeAll(rec, ctx, onWriteAll) catch |err| return ctx.pop(err); - } - - return ctx.pop({}); - } - - pub fn async_writeAll(c: *Self, stream: anytype, bytes: []const u8, ctx: *Ctx, comptime cbk: Cbk) !void { - assert(bytes.len <= cipher.max_cleartext_len); - - ctx._tls_write_bytes = bytes; - ctx._tls_write_index = 0; - const rec = try c.prepareRecord(stream, ctx); - - try ctx.push(cbk); - return stream.async_writeAll(rec, ctx, onWriteAll); - } - - fn prepareRecord(c: *Self, stream: anytype, ctx: *Ctx) ![]const u8 { - const len = @min(ctx._tls_write_bytes.len - ctx._tls_write_index, cipher.max_cleartext_len); - - // If key update is requested send key update message and update - // my encryption keys. - if (c.cipher.encryptSeq() >= c.max_encrypt_seq or @atomicLoad(bool, &c.key_update_requested, .monotonic)) { - @atomicStore(bool, &c.key_update_requested, false, .monotonic); - - // If the request_update field is set to "update_requested", - // then the receiver MUST send a KeyUpdate of its own with - // request_update set to "update_not_requested" prior to sending - // its next Application Data record. This mechanism allows - // either side to force an update to the entire connection, but - // causes an implementation which receives multiple KeyUpdates - // while it is silent to respond with a single update. - // - // rfc: https://datatracker.ietf.org/doc/html/rfc8446#autoid-57 - const key_update = &record.handshakeHeader(.key_update, 1) ++ [_]u8{0}; - const rec = try c.cipher.encrypt(&ctx._tls_write_buf, .handshake, key_update); - try stream.writeAll(rec); // TODO async - try c.cipher.keyUpdateEncrypt(); - } - - defer ctx._tls_write_index += len; - return c.cipher.encrypt(&ctx._tls_write_buf, .application_data, ctx._tls_write_bytes[ctx._tls_write_index..len]); - } - fn onReadv(ctx: *Ctx, res: anyerror!void) anyerror!void { res catch |err| return ctx.pop(err); @@ -312,23 +264,16 @@ pub fn Connection(comptime Stream: type) type { return c.async_next(stream, ctx, onReadv); } - fn onNext(ctx: *Ctx, res: anyerror!void) anyerror!void { - res catch |err| { - ctx.conn().tls_client.writeAlert(err) catch |e| std.log.err("onNext: write alert: {any}", .{e}); // TODO async - return ctx.pop(err); - }; - - if (ctx._tls_read_content_type != .application_data) { - return ctx.pop(error.TlsUnexpectedMessage); - } + pub fn async_next(c: *Self, stream: anytype, ctx: *Ctx, comptime cbk: Cbk) !void { + try ctx.push(cbk); - return ctx.pop({}); + return c.async_next_decrypt(stream, ctx, onNext); } - pub fn async_next(c: *Self, stream: anytype, ctx: *Ctx, comptime cbk: Cbk) !void { + pub fn async_next_decrypt(c: *Self, stream: anytype, ctx: *Ctx, comptime cbk: Cbk) !void { try ctx.push(cbk); - return c.async_next_decrypt(stream, ctx, onNext); + return c.async_next_record(stream, ctx, onNextDecrypt) catch |err| return ctx.pop(err); } pub fn onNextDecrypt(ctx: *Ctx, res: anyerror!void) anyerror!void { @@ -383,10 +328,39 @@ pub fn Connection(comptime Stream: type) type { return ctx.pop({}); } - pub fn async_next_decrypt(c: *Self, stream: anytype, ctx: *Ctx, comptime cbk: Cbk) !void { - try ctx.push(cbk); + fn onNext(ctx: *Ctx, res: anyerror!void) anyerror!void { + res catch |err| { + ctx.conn().tls_client.writeAlert(err) catch |e| std.log.err("onNext: write alert: {any}", .{e}); // TODO async + return ctx.pop(err); + }; - return c.async_next_record(stream, ctx, onNextDecrypt) catch |err| return ctx.pop(err); + if (ctx._tls_read_content_type != .application_data) { + return ctx.pop(error.TlsUnexpectedMessage); + } + + return ctx.pop({}); + } + + fn onWriteAll(ctx: *Ctx, res: anyerror!void) anyerror!void { + res catch |err| return ctx.pop(err); + + if (ctx._tls_write_bytes.len - ctx._tls_write_index > 0) { + const rec = ctx.conn().tls_client.prepareRecord(ctx.stream(), ctx) catch |err| return ctx.pop(err); + return ctx.stream().async_writeAll(rec, ctx, onWriteAll) catch |err| return ctx.pop(err); + } + + return ctx.pop({}); + } + + pub fn async_writeAll(c: *Self, stream: anytype, bytes: []const u8, ctx: *Ctx, comptime cbk: Cbk) !void { + assert(bytes.len <= cipher.max_cleartext_len); + + ctx._tls_write_bytes = bytes; + ctx._tls_write_index = 0; + const rec = try c.prepareRecord(stream, ctx); + + try ctx.push(cbk); + return stream.async_writeAll(rec, ctx, onWriteAll); } pub fn onNextRecord(ctx: *Ctx, res: anyerror!void) anyerror!void { @@ -422,18 +396,35 @@ pub fn Connection(comptime Stream: type) type { return c.async_reader_next(stream, ctx, onNextRecord); } - pub fn onReaderNext(ctx: *Ctx, res: anyerror!void) anyerror!void { - res catch |err| return ctx.pop(err); + fn prepareRecord(c: *Self, stream: anytype, ctx: *Ctx) ![]const u8 { + const len = @min(ctx._tls_write_bytes.len - ctx._tls_write_index, cipher.max_cleartext_len); - const c = ctx.conn().tls_client; + // If key update is requested send key update message and update + // my encryption keys. + if (c.cipher.encryptSeq() >= c.max_encrypt_seq or @atomicLoad(bool, &c.key_update_requested, .monotonic)) { + @atomicStore(bool, &c.key_update_requested, false, .monotonic); - const n = ctx.len(); - if (n == 0) { - ctx._tls_read_record = null; - return ctx.pop({}); + // If the request_update field is set to "update_requested", + // then the receiver MUST send a KeyUpdate of its own with + // request_update set to "update_not_requested" prior to sending + // its next Application Data record. This mechanism allows + // either side to force an update to the entire connection, but + // causes an implementation which receives multiple KeyUpdates + // while it is silent to respond with a single update. + // + // rfc: https://datatracker.ietf.org/doc/html/rfc8446#autoid-57 + const key_update = &record.handshakeHeader(.key_update, 1) ++ [_]u8{0}; + const rec = try c.cipher.encrypt(&ctx._tls_write_buf, .handshake, key_update); + try stream.writeAll(rec); // TODO async + try c.cipher.keyUpdateEncrypt(); } - c.rec_rdr.end += n; + defer ctx._tls_write_index += len; + return c.cipher.encrypt(&ctx._tls_write_buf, .application_data, ctx._tls_write_bytes[ctx._tls_write_index..len]); + } + + pub fn async_reader_next(c: *Self, _: anytype, ctx: *Ctx, comptime cbk: Cbk) !void { + try ctx.push(cbk); return c.readNext(ctx); } @@ -470,8 +461,18 @@ pub fn Connection(comptime Stream: type) type { .async_read(c.rec_rdr.buffer[c.rec_rdr.end..], ctx, onReaderNext) catch |err| return ctx.pop(err); } - pub fn async_reader_next(c: *Self, _: anytype, ctx: *Ctx, comptime cbk: Cbk) !void { - try ctx.push(cbk); + pub fn onReaderNext(ctx: *Ctx, res: anyerror!void) anyerror!void { + res catch |err| return ctx.pop(err); + + const c = ctx.conn().tls_client; + + const n = ctx.len(); + if (n == 0) { + ctx._tls_read_record = null; + return ctx.pop({}); + } + c.rec_rdr.end += n; + return c.readNext(ctx); } }; @@ -663,3 +664,162 @@ test "client/server connection" { try testing.expectEqualSlices(u8, sent, received); } } + +pub fn Async(comptime Handler: type, comptime HandshakeType: type, comptime Options: type) type { + // ClientType has to have this api: + // + // onHandshake() - notification that tcp handshake is done. + // onRecvCleartext(cleartext) - cleartext data to pass to application. + // sendCiphertext(buf) - ciphertext to pass to server (tcp connection). + // + // Api provided to the client: + // + // onConnect - should be called after tcp client connection is established + // onRecv - data received from the server + // send - data to send to the server + // onSend - tcp is done coping buffer to the kernel + // + // Client should establish tcp connection and call startHandshake. That will + // fire client.sendCiphertext with tls hello. For each raw tcp data client + // will call onRecv. During handshake that data will be consumed here. When + // handshake succeeds we will have cipher, release handshake and call + // client.onHandshake. + // + // After that client should call send with cleartext data, that will be + // encrypted and pass to client.sendCiphertext. Any raw ciphertext received + // on tcp should be pass to onRecv to decrypt and pass to + // client.onRecvCleartext. + // + return struct { + const Self = @This(); + + allocator: mem.Allocator, + handler: *Handler, + handshake: ?*HandshakeType = null, + cipher: ?Cipher = null, + + pub fn init(allocator: mem.Allocator, handler: *Handler, opt: Options) !Self { + const handshake = try allocator.create(HandshakeType); + errdefer allocator.destroy(handshake); + try handshake.init(opt); + return .{ + .allocator = allocator, + .handler = handler, + .handshake = handshake, + }; + } + + pub fn deinit(self: *Self) void { + if (self.handshake) |handshake| + self.allocator.destroy(handshake); + } + + // ----------------- client api + + /// Client has established tcp connection, start tls handshake + pub fn connect(self: *Self) !void { + try self.handshakeSend(); + } + + /// `bytes` are received on plain tcp connection. Use it in handshake or + /// if handshake is done decrypt and send to client. + pub fn recv(self: *Self, bytes: []u8) !usize { + return if (self.handshake) |_| + try self.handshakeRecv(bytes) + else + try self.decrypt(bytes); + } + + /// Client sends data; encrypt it and return to client via sendCiphertext. + pub fn send(self: *Self, bytes: []const u8) !void { + if (self.handshake != null) return error.InvalidState; + const chp = &(self.cipher orelse return error.InvalidState); + + var index: usize = 0; + while (index < bytes.len) { + // Split into max cleartext buffers + const n = @min(bytes.len, cipher.max_cleartext_len); + const buf = bytes[index..][0..n]; + index += n; + + // allocate ciphertext record buffer and encrypt into that buffer + const rec_buf = try self.allocator.alloc(u8, chp.recordLen(buf.len)); + errdefer self.allocator.free(rec_buf); + const rec = try chp.encrypt(rec_buf, .application_data, buf); + assert(rec.len == rec_buf.len); + // send ciphertext record + try self.handler.sendZc(rec); + } + } + + /// Buffer allocated in send is copied to the kernel, safe to free it + /// now. + pub fn onSend(self: *Self, buf: []const u8) void { + if (self.handshake) |_| { + self.checkHandshakeDone(); + } else { + self.allocator.free(buf); + } + } + + // ----------------- client api + + /// NOTE: decrypt reuses provided ciphertext buf for cleartext data + fn decrypt(self: *Self, buf: []u8) !usize { + const chp = &(self.cipher orelse return error.InvalidState); + + var rdr = record.bufferReader(buf); + while (true) { + const content_type, const cleartext = try rdr.nextDecrypt(chp) orelse break; + switch (content_type) { + .application_data => {}, + .handshake => { + // TODO handle key_update and new_session_ticket separately + continue; + }, + .alert => { + if (cleartext.len < 2) return error.TlsUnexpectedMessage; + try proto.Alert.parse(cleartext[0..2].*).toError(); + return error.EndOfFile; // close notify received + }, + else => { + //log.err("unexpected content_type {}", .{content_type}); + return error.TlsUnexpectedMessage; + }, + } + + assert(content_type == .application_data); + self.handler.onRecv(@constCast(cleartext)); + } + return rdr.bytesRead(); + } + + fn handshakeRecv(self: *Self, buf: []u8) !usize { + var handshake = self.handshake orelse unreachable; + const n = handshake.recv(buf) catch |err| switch (err) { + error.EndOfStream => 0, + else => return err, + }; + self.checkHandshakeDone(); + if (n > 0) try self.handshakeSend(); + return n; + } + + fn checkHandshakeDone(self: *Self) void { + var handshake = self.handshake orelse unreachable; + if (!handshake.done()) return; + + self.cipher = handshake.inner.cipher; + self.allocator.destroy(handshake); + self.handshake = null; + + self.handler.onConnect(); + } + + fn handshakeSend(self: *Self) !void { + var handshake = self.handshake orelse return; + if (try handshake.send()) |buf| + try self.handler.sendZc(buf); + } + }; +} diff --git a/src/tls.zig/handshake_client.zig b/src/tls.zig/handshake_client.zig index e7b48cf..74b7b09 100644 --- a/src/tls.zig/handshake_client.zig +++ b/src/tls.zig/handshake_client.zig @@ -2,6 +2,7 @@ const std = @import("std"); const assert = std.debug.assert; const crypto = std.crypto; const mem = std.mem; +const io = std.io; const Certificate = crypto.Certificate; const cipher = @import("cipher.zig"); @@ -23,6 +24,8 @@ const DhKeyPair = common.DhKeyPair; const CertBundle = common.CertBundle; const CertKeyPair = common.CertKeyPair; +const log = std.log.scoped(.tls); + pub const Options = struct { host: []const u8, /// Set of root certificate authorities that clients use when verifying @@ -46,7 +49,7 @@ pub const Options = struct { named_groups: []const proto.NamedGroup = supported_named_groups, /// Client authentication certificates and private key. - auth: ?CertKeyPair = null, + auth: ?*CertKeyPair = null, /// If this structure is provided it will be filled with handshake attributes /// at the end of the handshake process. @@ -102,12 +105,13 @@ pub fn Handshake(comptime Stream: type) type { const HandshakeT = @This(); + // `buf` is used for creating client messages and for decrypting server + // ciphertext messages. pub fn init(buf: []u8, rec_rdr: *RecordReaderT) HandshakeT { return .{ .client_random = undefined, .dh_kp = undefined, .rsa_secret = undefined, - //.now_sec = std.time.timestamp(), .buffer = buf, .rec_rdr = rec_rdr, }; @@ -115,7 +119,7 @@ pub fn Handshake(comptime Stream: type) type { fn initKeys( h: *HandshakeT, - named_groups: []const proto.NamedGroup, + opt: Options, ) !void { const init_keys_buf_len = 32 + 46 + DhKeyPair.seed_len; var buf: [init_keys_buf_len]u8 = undefined; @@ -123,7 +127,13 @@ pub fn Handshake(comptime Stream: type) type { h.client_random = buf[0..32].*; h.rsa_secret = RsaSecret.init(buf[32..][0..46].*); - h.dh_kp = try DhKeyPair.init(buf[32 + 46 ..][0..DhKeyPair.seed_len].*, named_groups); + h.dh_kp = try DhKeyPair.init(buf[32 + 46 ..][0..DhKeyPair.seed_len].*, opt.named_groups); + + h.cert = .{ + .host = opt.host, + .root_ca = opt.root_ca.bundle, + .skip_verify = opt.insecure_skip_verify, + }; } /// Handshake exchanges messages with server to get agreement about @@ -175,12 +185,7 @@ pub fn Handshake(comptime Stream: type) type { /// pub fn handshake(h: *HandshakeT, w: Stream, opt: Options) !Cipher { defer h.updateDiagnostic(opt); - try h.initKeys(opt.named_groups); - h.cert = .{ - .host = opt.host, - .root_ca = opt.root_ca.bundle, - .skip_verify = opt.insecure_skip_verify, - }; + try h.initKeys(opt); try w.writeAll(try h.makeClientHello(opt)); // client flight 1 try h.readServerFlight1(); // server flight 1 @@ -202,6 +207,36 @@ pub fn Handshake(comptime Stream: type) type { return h.cipher; } + fn clientFlight1(h: *HandshakeT, opt: Options) ![]const u8 { + return try h.makeClientHello(opt); + } + + fn serverFlight1(h: *HandshakeT, opt: Options) !void { + try h.readServerFlight1(); + h.transcript.use(h.cipher_suite.hash()); + if (h.tls_version == .tls_1_3) { + try h.generateHandshakeCipher(opt.key_log_callback); + try h.readEncryptedServerFlight1(); + } + } + + fn clientFlight2(h: *HandshakeT, opt: Options) ![]const u8 { + if (h.tls_version == .tls_1_3) { + const app_cipher = try h.generateApplicationCipher(opt.key_log_callback); + const buf = try h.makeClientFlight2Tls13(opt.auth); + h.cipher = app_cipher; + return buf; + } + // tls 1.2 specific handshake part + try h.generateCipher(opt.key_log_callback); + return try h.makeClientFlight2Tls12(opt.auth); + } + + fn serverFlight2(h: *HandshakeT, _: Options) !void { + if (h.tls_version == .tls_1_3) return; + try h.readServerFlight2(); + } + /// Prepare key material and generate cipher for TLS 1.2 fn generateCipher(h: *HandshakeT, key_log_callback: ?key_log.Callback) !void { try h.verifyCertificateSignatureTls12(); @@ -418,7 +453,7 @@ pub fn Handshake(comptime Stream: type) type { 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91, 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C, }; - return std.mem.eql(u8, server_random, &hello_retry_request_magic); + return mem.eql(u8, server_random, &hello_retry_request_magic); } fn parseServerKeyExchange(h: *HandshakeT, d: *record.Decoder) !void { @@ -540,7 +575,7 @@ pub fn Handshake(comptime Stream: type) type { /// finished messages for tls 1.2. /// If client certificate is requested also adds client certificate and /// certificate verify messages. - fn makeClientFlight2Tls12(h: *HandshakeT, auth: ?CertKeyPair) ![]const u8 { + fn makeClientFlight2Tls12(h: *HandshakeT, auth: ?*CertKeyPair) ![]const u8 { var w = record.Writer{ .buf = h.buffer }; var cert_builder: ?CertificateBuilder = null; @@ -594,7 +629,7 @@ pub fn Handshake(comptime Stream: type) type { /// and client certificate verify messages are also created. If the /// server has requested certificate but the client is not configured /// empty certificate message is sent, as is required by rfc. - fn makeClientFlight2Tls13(h: *HandshakeT, auth: ?CertKeyPair) ![]const u8 { + fn makeClientFlight2Tls13(h: *HandshakeT, auth: ?*CertKeyPair) ![]const u8 { var w = record.Writer{ .buf = h.buffer }; // Client change cipher spec message @@ -631,7 +666,7 @@ pub fn Handshake(comptime Stream: type) type { return w.getWritten(); } - fn certificateBuilder(h: *HandshakeT, auth: CertKeyPair) CertificateBuilder { + fn certificateBuilder(h: *HandshakeT, auth: *CertKeyPair) CertificateBuilder { return .{ .bundle = auth.bundle, .key = auth.key, @@ -729,10 +764,10 @@ const data12 = @import("testdata/tls12.zig"); const data13 = @import("testdata/tls13.zig"); const testu = @import("testu.zig"); -fn testReader(data: []const u8) record.Reader(std.io.FixedBufferStream([]const u8)) { - return record.reader(std.io.fixedBufferStream(data)); +fn testReader(data: []const u8) record.Reader(io.FixedBufferStream([]const u8)) { + return record.reader(io.fixedBufferStream(data)); } -const TestHandshake = Handshake(std.io.FixedBufferStream([]const u8)); +const TestHandshake = Handshake(io.FixedBufferStream([]const u8)); test "parse tls 1.2 server hello" { var h = brk: { @@ -953,3 +988,162 @@ test "handshake verify server finished message" { h.transcript.update(&data12.client_finished); try h.readServerFlight2(); } + +pub const Async = struct { + const Self = @This(); + pub const Inner = Handshake([]u8); + + // inner sync handshake + inner: Inner = undefined, + opt: Options = undefined, + buffer: [cipher.max_ciphertext_record_len]u8 = undefined, + state: State = .none, + + const State = enum { + none, + init, + client_flight_1, + server_flight_1, + client_flight_2, + server_flight_2, + + fn next(self: *State) void { + self.* = @enumFromInt(@intFromEnum(self.*) + 1); + } + }; + + pub fn init(self: *Self, opt: Options) !void { + self.* = .{ + .inner = Inner.init(&self.buffer, undefined), + .opt = opt, + }; + try self.inner.initKeys(opt); + self.state = .init; + } + + // Returns null if there is nothing to send at this state + pub fn send(self: *Self) !?[]const u8 { + switch (self.state) { + .init => { + const buf = try self.inner.clientFlight1(self.opt); + self.state.next(); + return buf; + }, + .server_flight_1 => { + const buf = try self.inner.clientFlight2(self.opt); + self.state.next(); + return buf; + }, + else => return null, + } + } + + // Returns number of bytes consumed from buf + pub fn recv(self: *Self, buf: []u8) !usize { + const prev: Transcript = self.inner.transcript; + errdefer self.inner.transcript = prev; + + var rdr = record.bufferReader(buf); + self.inner.rec_rdr = &rdr; + + switch (self.state) { + .client_flight_1 => { + try self.inner.serverFlight1(self.opt); + self.state.next(); + }, + .client_flight_2 => { + try self.inner.serverFlight2(self.opt); + self.state.next(); + }, + else => return error.TlsUnexpectedMessage, + } + + return rdr.bytesRead(); + } + + pub fn done(self: *Self) bool { + const is_done = self.state == .server_flight_2 or + (self.inner.tls_version == .tls_1_3 and self.state == .client_flight_2); + if (is_done) { + self.inner.updateDiagnostic(self.opt); + } + return is_done; + } +}; + +test "async handshake" { + var ah: Async = .{}; + try ah.init(.{ + .host = "example.ulfheim.net", + .insecure_skip_verify = true, + .root_ca = .{}, + .cipher_suites = &[_]CipherSuite{CipherSuite.AES_256_GCM_SHA384}, + .named_groups = &[_]proto.NamedGroup{.x25519}, + }); + var h = &ah.inner; + { // update secrets to well known from example + h.client_random = data13.client_random; + h.dh_kp.x25519_kp = .{ + .public_key = data13.client_public_key, + .secret_key = data13.client_private_key, + }; + } + + const expected_client_flight_1 = testu.hexToBytes( + \\ 16 03 03 00 9c + \\ 01 00 00 98 + \\ 03 03 + \\ 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f + \\ 00 + \\ 00 02 13 02 + \\ 01 00 + \\ 00 6d + \\ 00 2b 00 03 02 03 04 + \\ 00 0d 00 14 00 12 04 03 05 03 08 04 08 05 08 06 08 07 02 01 04 01 05 01 + \\ 00 0a 00 04 00 02 00 1d 00 33 00 26 00 24 00 1d 00 20 + \\ 35 80 72 d6 36 58 80 d1 ae ea 32 9a df 91 21 38 38 51 ed 21 a2 8e 3b 75 e9 65 d0 d2 cd 16 62 54 + \\ 00 00 00 18 00 16 00 00 13 65 78 61 6d 70 6c 65 2e 75 6c 66 68 65 69 6d 2e 6e 65 74 + ); + const client_flight_1 = try ah.send(); + try testing.expectEqualSlices(u8, &expected_client_flight_1, client_flight_1.?); + + { // update transcript to well known from example + h.transcript = .{}; + h.transcript.update(data13.client_hello[record.header_len..]); + } + + // parsing partial server flight message returns error.EndOfStream + for (1..data13.server_flight_1.len - 1) |i| { + const buf = data13.server_flight_1[0..i]; + try testing.expectError(error.EndOfStream, ah.recv(@constCast(buf))); + } + + const n = try ah.recv(@constCast(&(data13.server_flight_1 ++ "dummy footer".*))); + { // inspect + try testing.expectEqual(data13.server_flight_1.len, n); // footer is not touched + try testing.expectEqual(.tls_1_3, h.tls_version); + try testing.expectEqual(.x25519, h.named_group); + try testing.expectEqualSlices(u8, &data13.server_random, &h.server_random); + try testing.expectEqual(.AES_256_GCM_SHA384, h.cipher_suite); + try testing.expectEqualSlices(u8, &data13.server_pub_key, h.server_pub_key); + try testing.expect(!ah.done()); + } + + const client_flight_2 = try ah.send(); + try testing.expectEqualSlices(u8, &data13.client_flight_2, client_flight_2.?); + try testing.expect(ah.done()); + + try testing.expectEqual(0, try ah.recv(@constCast("dummy footer"))); + try testing.expect(ah.done()); +} + +test "sizes" { + try testing.expectEqual(36576, @sizeOf(Async)); + try testing.expectEqual(19792, @sizeOf(Handshake([]u8))); + try testing.expectEqual(14384, @sizeOf(DhKeyPair)); + try testing.expectEqual(128, @sizeOf(Options)); + try testing.expectEqual(2792, @sizeOf(CertKeyPair)); + try testing.expectEqual(1736, @sizeOf(CertificateParser)); + try testing.expectEqual(48, @sizeOf(CertBundle)); + try testing.expectEqual(208, @sizeOf(Cipher)); +} diff --git a/src/tls.zig/handshake_common.zig b/src/tls.zig/handshake_common.zig index 178a3ce..7d0bf37 100644 --- a/src/tls.zig/handshake_common.zig +++ b/src/tls.zig/handshake_common.zig @@ -374,10 +374,10 @@ pub const DhKeyPair = struct { var kp: DhKeyPair = .{}; for (named_groups) |ng| switch (ng) { - .x25519 => kp.x25519_kp = try X25519.KeyPair.create(seed[0..][0..X25519.seed_length].*), - .secp256r1 => kp.secp256r1_kp = try EcdsaP256Sha256.KeyPair.create(seed[32..][0..EcdsaP256Sha256.KeyPair.seed_length].*), - .secp384r1 => kp.secp384r1_kp = try EcdsaP384Sha384.KeyPair.create(seed[32 + 32 ..][0..EcdsaP384Sha384.KeyPair.seed_length].*), - .x25519_kyber768d00 => kp.kyber768_kp = try Kyber768.KeyPair.create(seed[32 + 32 + 48 ..][0..Kyber768.seed_length].*), + .x25519 => kp.x25519_kp = try X25519.KeyPair.generateDeterministic(seed[0..][0..X25519.seed_length].*), + .secp256r1 => kp.secp256r1_kp = try EcdsaP256Sha256.KeyPair.generateDeterministic(seed[32..][0..EcdsaP256Sha256.KeyPair.seed_length].*), + .secp384r1 => kp.secp384r1_kp = try EcdsaP384Sha384.KeyPair.generateDeterministic(seed[32 + 32 ..][0..EcdsaP384Sha384.KeyPair.seed_length].*), + .x25519_kyber768d00 => kp.kyber768_kp = try Kyber768.KeyPair.generateDeterministic(seed[32 + 32 + 48 ..][0..Kyber768.seed_length].*), else => return error.TlsIllegalParameter, }; return kp; diff --git a/src/tls.zig/handshake_server.zig b/src/tls.zig/handshake_server.zig index c26e8c6..e655fe7 100644 --- a/src/tls.zig/handshake_server.zig +++ b/src/tls.zig/handshake_server.zig @@ -21,10 +21,12 @@ const DhKeyPair = common.DhKeyPair; const CertBundle = common.CertBundle; const CertKeyPair = common.CertKeyPair; +const log = std.log.scoped(.tls); + pub const Options = struct { /// Server authentication. If null server will not send Certificate and /// CertificateVerify message. - auth: ?CertKeyPair, + auth: ?*CertKeyPair, /// If not null server will request client certificate. If auth_type is /// .request empty client certificate message will be accepted. @@ -94,11 +96,7 @@ pub fn Handshake(comptime Stream: type) type { } pub fn handshake(h: *HandshakeT, stream: Stream, opt: Options) !Cipher { - crypto.random.bytes(&h.server_random); - if (opt.auth) |a| { - // required signature scheme in client hello - h.signature_scheme = a.key.signature_scheme; - } + h.initKeys(opt); h.readClientHello() catch |err| { try h.writeAlert(stream, null, err); @@ -106,73 +104,92 @@ pub fn Handshake(comptime Stream: type) type { }; h.transcript.use(h.cipher_suite.hash()); - const server_flight = brk: { - var w = record.Writer{ .buf = h.buffer }; + const server_flight = h.serverFlight(opt) catch |err| { + try h.writeAlert(stream, null, err); + return err; + }; + try stream.writeAll(server_flight); - const shared_key = h.sharedKey() catch |err| { - try h.writeAlert(stream, null, err); - return err; - }; - { - const hello = try h.makeServerHello(w.getFree()); - h.transcript.update(hello[record.header_len..]); - w.pos += hello.len; - } - { - const handshake_secret = h.transcript.handshakeSecret(shared_key); - h.cipher = try Cipher.initTls13(h.cipher_suite, handshake_secret, .server); - } - try w.writeRecord(.change_cipher_spec, &[_]u8{1}); - { - const encrypted_extensions = &record.handshakeHeader(.encrypted_extensions, 2) ++ [_]u8{ 0, 0 }; - h.transcript.update(encrypted_extensions); - try h.writeEncrypted(&w, encrypted_extensions); - } - if (opt.client_auth) |_| { - const certificate_request = try makeCertificateRequest(w.getPayload()); - h.transcript.update(certificate_request); - try h.writeEncrypted(&w, certificate_request); - } - if (opt.auth) |a| { - const cm = CertificateBuilder{ - .bundle = a.bundle, - .key = a.key, - .transcript = &h.transcript, - .side = .server, - }; - { - const certificate = try cm.makeCertificate(w.getPayload()); - h.transcript.update(certificate); - try h.writeEncrypted(&w, certificate); - } - { - const certificate_verify = try cm.makeCertificateVerify(w.getPayload()); - h.transcript.update(certificate_verify); - try h.writeEncrypted(&w, certificate_verify); - } - } - { - const finished = try h.makeFinished(w.getPayload()); - h.transcript.update(finished); - try h.writeEncrypted(&w, finished); + h.clientFlight2(opt) catch |err| { + // Alert received from client + if (!mem.startsWith(u8, @errorName(err), "TlsAlert")) { + try h.writeAlert(stream, &h.cipher, err); } - break :brk w.getWritten(); + return err; }; - try stream.writeAll(server_flight); + return h.cipher; + } + + fn initKeys(h: *HandshakeT, opt: Options) void { + crypto.random.bytes(&h.server_random); + if (opt.auth) |a| { + // required signature scheme in client hello + h.signature_scheme = a.key.signature_scheme; + } + } - var app_cipher = brk: { + fn clientFlight1(h: *HandshakeT) !void { + try h.readClientHello(); + h.transcript.use(h.cipher_suite.hash()); + } + + fn clientFlight2(h: *HandshakeT, opt: Options) !void { + const app_cipher = brk: { const application_secret = h.transcript.applicationSecret(); break :brk try Cipher.initTls13(h.cipher_suite, application_secret, .server); }; + defer h.cipher = app_cipher; + try h.readClientFlight2(opt); + } - h.readClientFlight2(opt) catch |err| { - // Alert received from client - if (!mem.startsWith(u8, @errorName(err), "TlsAlert")) { - try h.writeAlert(stream, &app_cipher, err); + fn serverFlight(h: *HandshakeT, opt: Options) ![]const u8 { + var w = record.Writer{ .buf = h.buffer }; + + const shared_key = try h.sharedKey(); + { + const hello = try h.makeServerHello(w.getFree()); + h.transcript.update(hello[record.header_len..]); + w.pos += hello.len; + } + { + const handshake_secret = h.transcript.handshakeSecret(shared_key); + h.cipher = try Cipher.initTls13(h.cipher_suite, handshake_secret, .server); + } + try w.writeRecord(.change_cipher_spec, &[_]u8{1}); + { + const encrypted_extensions = &record.handshakeHeader(.encrypted_extensions, 2) ++ [_]u8{ 0, 0 }; + h.transcript.update(encrypted_extensions); + try h.writeEncrypted(&w, encrypted_extensions); + } + if (opt.client_auth) |_| { + const certificate_request = try makeCertificateRequest(w.getPayload()); + h.transcript.update(certificate_request); + try h.writeEncrypted(&w, certificate_request); + } + if (opt.auth) |a| { + const cm = CertificateBuilder{ + .bundle = a.bundle, + .key = a.key, + .transcript = &h.transcript, + .side = .server, + }; + { + const certificate = try cm.makeCertificate(w.getPayload()); + h.transcript.update(certificate); + try h.writeEncrypted(&w, certificate); } - return err; - }; - return app_cipher; + { + const certificate_verify = try cm.makeCertificateVerify(w.getPayload()); + h.transcript.update(certificate_verify); + try h.writeEncrypted(&w, certificate_verify); + } + } + { + const finished = try h.makeFinished(w.getPayload()); + h.transcript.update(finished); + try h.writeEncrypted(&w, finished); + } + return w.getWritten(); } inline fn sharedKey(h: *HandshakeT) ![]const u8 { @@ -518,3 +535,74 @@ test "make certificate request" { const actual = try TestHandshake.makeCertificateRequest(&buffer); try testing.expectEqualSlices(u8, &expected, actual); } + +pub const Async = struct { + const Self = @This(); + pub const Inner = Handshake([]u8); + + // inner sync handshake + inner: Inner = undefined, + opt: Options = undefined, + buffer: [cipher.max_ciphertext_record_len]u8 = undefined, + state: State = .none, + + const State = enum { + none, + init, + client_flight_1, + server_flight, + client_flight_2, + + fn next(self: *State) void { + self.* = @enumFromInt(@intFromEnum(self.*) + 1); + } + }; + + pub fn init(self: *Self, opt: Options) !void { + self.* = .{ + .inner = Inner.init(&self.buffer, undefined), + .opt = opt, + }; + self.inner.initKeys(opt); + self.state = .init; + } + + // Returns null if there is nothing to send at this state + pub fn send(self: *Self) !?[]const u8 { + switch (self.state) { + .client_flight_1 => { + const buf = try self.inner.serverFlight(self.opt); + self.state.next(); + return buf; + }, + else => return null, + } + } + + // Returns number of bytes consumed from buf + pub fn recv(self: *Self, buf: []u8) !usize { + const prev: Transcript = self.inner.transcript; + errdefer self.inner.transcript = prev; + + var rdr = record.bufferReader(buf); + self.inner.rec_rdr = &rdr; + + switch (self.state) { + .init => { + try self.inner.clientFlight1(); + self.state.next(); + }, + .server_flight => { + try self.inner.clientFlight2(self.opt); + self.state.next(); + }, + else => return error.TlsUnexpectedMessage, + } + + return rdr.bytesRead(); + } + + pub fn done(self: *Self) bool { + return self.state == .client_flight_2; + } +}; diff --git a/src/tls.zig/record.zig b/src/tls.zig/record.zig index 6c4df32..4867697 100644 --- a/src/tls.zig/record.zig +++ b/src/tls.zig/record.zig @@ -1,6 +1,7 @@ const std = @import("std"); const assert = std.debug.assert; const mem = std.mem; +const io = std.io; const proto = @import("protocol.zig"); const cipher = @import("cipher.zig"); @@ -9,15 +10,25 @@ const record = @import("record.zig"); pub const header_len = 5; +pub inline fn int2(int: u16) [2]u8 { + var arr: [2]u8 = undefined; + std.mem.writeInt(u16, &arr, int, .big); + return arr; +} + +pub inline fn int3(int: u24) [3]u8 { + var arr: [3]u8 = undefined; + std.mem.writeInt(u24, &arr, int, .big); + return arr; +} + pub fn header(content_type: proto.ContentType, payload_len: usize) [header_len]u8 { - const int2 = std.crypto.tls.int2; return [1]u8{@intFromEnum(content_type)} ++ int2(@intFromEnum(proto.Version.tls_1_2)) ++ int2(@intCast(payload_len)); } pub fn handshakeHeader(handshake_type: proto.Handshake, payload_len: usize) [4]u8 { - const int3 = std.crypto.tls.int3; return [1]u8{@intFromEnum(handshake_type)} ++ int3(@intCast(payload_len)); } @@ -25,11 +36,20 @@ pub fn reader(inner_reader: anytype) Reader(@TypeOf(inner_reader)) { return .{ .inner_reader = inner_reader }; } +pub fn bufferReader(buf: []u8) Reader([]u8) { + return .{ + .inner_reader = undefined, + .buffer = buf, + .end = buf.len, + }; +} + pub fn Reader(comptime InnerReader: type) type { + const is_slice = isSlice(InnerReader); return struct { - inner_reader: InnerReader, + inner_reader: if (is_slice) void else InnerReader, - buffer: [cipher.max_ciphertext_record_len]u8 = undefined, + buffer: if (is_slice) InnerReader else [cipher.max_ciphertext_record_len]u8 = undefined, start: usize = 0, end: usize = 0, @@ -71,6 +91,8 @@ pub fn Reader(comptime InnerReader: type) type { return Record.init(buffer[0..record_len]); } } + if (is_slice) return null; + { // Move dirty part to the start of the buffer. const n = r.end - r.start; if (n > 0 and r.start > 0) { @@ -110,6 +132,10 @@ pub fn Reader(comptime InnerReader: type) type { pub fn hasMore(r: *ReaderT) bool { return r.end > r.start; } + + pub fn bytesRead(r: *ReaderT) usize { + return r.start; + } }; } @@ -147,7 +173,7 @@ pub const Decoder = struct { pub fn decode(d: *Decoder, comptime T: type) !T { switch (@typeInfo(T)) { - .Int => |info| switch (info.bits) { + .int => |info| switch (info.bits) { 8 => { try skip(d, 1); return d.payload[d.idx - 1]; @@ -167,7 +193,7 @@ pub const Decoder = struct { }, else => @compileError("unsupported int type: " ++ @typeName(T)), }, - .Enum => |info| { + .@"enum" => |info| { const int = try d.decode(info.tag_type); if (info.is_exhaustive) @compileError("exhaustive enum cannot be used"); return @as(T, @enumFromInt(int)); @@ -221,7 +247,7 @@ const testu = @import("testu.zig"); const CipherSuite = @import("cipher.zig").CipherSuite; test Reader { - var fbs = std.io.fixedBufferStream(&data12.server_responses); + var fbs = io.fixedBufferStream(&data12.server_responses); var rdr = reader(fbs.reader()); const expected = [_]struct { @@ -241,10 +267,26 @@ test Reader { try testing.expectEqual(e.payload_len, rec.payload.len); try testing.expectEqual(.tls_1_2, rec.protocol_version); } + + { + var fr = bufferReader(@constCast(&data12.server_responses)); + var n: usize = 0; + for (expected) |e| { + const rec = (try fr.next()).?; + try testing.expectEqual(e.content_type, rec.content_type); + try testing.expectEqual(e.payload_len, rec.payload.len); + try testing.expectEqual(.tls_1_2, rec.protocol_version); + + n += rec.payload.len + record.header_len; + try testing.expectEqual(n, fr.bytesRead()); + } + try testing.expectEqual(data12.server_responses.len, fr.bytesRead()); + try testing.expect(try fr.next() == null); + } } test Decoder { - var fbs = std.io.fixedBufferStream(&data12.server_responses); + var fbs = io.fixedBufferStream(&data12.server_responses); var rdr = reader(fbs.reader()); var d = (try rdr.nextDecoder()); @@ -288,7 +330,7 @@ pub const Writer = struct { pub fn writeInt(self: *Writer, value: anytype) !void { const IntT = @TypeOf(value); - const bytes = @divExact(@typeInfo(IntT).Int.bits, 8); + const bytes = @divExact(@typeInfo(IntT).int.bits, 8); const free = self.buf[self.pos..]; if (free.len < bytes) return error.BufferOverflow; mem.writeInt(IntT, free[0..bytes], value, .big); @@ -403,3 +445,25 @@ test "Writer" { try w.writeInt(@as(u16, 0x1234)); try testing.expectEqualSlices(u8, &[_]u8{ 'a', 'b', 0x03, 0x00, 0x1d, 0x12, 0x34 }, w.getWritten()); } + +test isSlice { + try comptime testing.expect(isSlice([]const u8)); + try comptime testing.expect(isSlice([]u8)); + try comptime testing.expect(!isSlice(io.FixedBufferStream([]u8))); +} + +test "sizes" { + try testing.expectEqual(32, @sizeOf(Reader([]u8))); + try testing.expectEqual(32, @sizeOf(Reader([]const u8))); + try testing.expectEqual(16688, @sizeOf(Reader(io.FixedBufferStream([]u8)))); +} + +fn isSlice(comptime T: type) bool { + return switch (@typeInfo(T)) { + .pointer => |ptr_info| switch (ptr_info.size) { + .slice => true, + else => false, + }, + else => false, + }; +} diff --git a/src/tls.zig/main.zig b/src/tls.zig/root.zig similarity index 53% rename from src/tls.zig/main.zig rename to src/tls.zig/root.zig index b974377..6bbdfbc 100644 --- a/src/tls.zig/main.zig +++ b/src/tls.zig/root.zig @@ -1,26 +1,15 @@ const std = @import("std"); - -pub const CipherSuite = @import("cipher.zig").CipherSuite; -pub const cipher_suites = @import("cipher.zig").cipher_suites; -pub const PrivateKey = @import("PrivateKey.zig"); -pub const Connection = @import("connection.zig").Connection; -pub const ClientOptions = @import("handshake_client.zig").Options; -pub const ServerOptions = @import("handshake_server.zig").Options; -pub const key_log = @import("key_log.zig"); pub const proto = @import("protocol.zig"); -pub const NamedGroup = proto.NamedGroup; -pub const Version = proto.Version; const common = @import("handshake_common.zig"); -pub const CertBundle = common.CertBundle; -pub const CertKeyPair = common.CertKeyPair; pub const record = @import("record.zig"); const connection = @import("connection.zig").connection; const max_ciphertext_record_len = @import("cipher.zig").max_ciphertext_record_len; const HandshakeServer = @import("handshake_server.zig").Handshake; const HandshakeClient = @import("handshake_client.zig").Handshake; +pub const Connection = @import("connection.zig").Connection; -pub fn client(stream: anytype, opt: ClientOptions) !Connection(@TypeOf(stream)) { +pub fn client(stream: anytype, opt: config.Client) !Connection(@TypeOf(stream)) { const Stream = @TypeOf(stream); var conn = connection(stream); var write_buf: [max_ciphertext_record_len]u8 = undefined; @@ -29,7 +18,7 @@ pub fn client(stream: anytype, opt: ClientOptions) !Connection(@TypeOf(stream)) return conn; } -pub fn server(stream: anytype, opt: ServerOptions) !Connection(@TypeOf(stream)) { +pub fn server(stream: anytype, opt: config.Server) !Connection(@TypeOf(stream)) { const Stream = @TypeOf(stream); var conn = connection(stream); var write_buf: [max_ciphertext_record_len]u8 = undefined; @@ -38,6 +27,34 @@ pub fn server(stream: anytype, opt: ServerOptions) !Connection(@TypeOf(stream)) return conn; } +pub const config = struct { + pub const CipherSuite = @import("cipher.zig").CipherSuite; + pub const PrivateKey = @import("PrivateKey.zig"); + pub const NamedGroup = proto.NamedGroup; + pub const Version = proto.Version; + pub const CertBundle = common.CertBundle; + pub const CertKeyPair = common.CertKeyPair; + + pub const cipher_suites = @import("cipher.zig").cipher_suites; + pub const key_log = @import("key_log.zig"); + + pub const Client = @import("handshake_client.zig").Options; + pub const Server = @import("handshake_server.zig").Options; +}; + +pub const asyn = struct { + const Async = @import("connection.zig").Async; + const _hc = @import("handshake_client.zig"); + const _hs = @import("handshake_server.zig"); + + pub fn Client(T: type) type { + return Async(T, _hc.Async, _hc.Options); + } + pub fn Server(T: type) type { + return Async(T, _hs.Async, _hs.Options); + } +}; + test { _ = @import("handshake_common.zig"); _ = @import("handshake_server.zig"); @@ -49,3 +66,5 @@ test { _ = @import("transcript.zig"); _ = @import("PrivateKey.zig"); } + +pub const CertBundle = @compileError("deprecated: use config.CertBundle, see:https://github.com/ianic/tls.zig/commit/c028a2845d546298fdac3a1d3e3849090c8fc1ff"); diff --git a/src/tls.zig/rsa/der.zig b/src/tls.zig/rsa/der.zig index 743a65a..1489bc0 100644 --- a/src/tls.zig/rsa/der.zig +++ b/src/tls.zig/rsa/der.zig @@ -116,10 +116,10 @@ pub const Parser = struct { const bytes = self.view(ele); const info = @typeInfo(T); - if (info != .Int) @compileError(@typeName(T) ++ " is not an int type"); + if (info != .int) @compileError(@typeName(T) ++ " is not an int type"); const Shift = std.math.Log2Int(u8); - var result: std.meta.Int(.unsigned, info.Int.bits) = 0; + var result: std.meta.Int(.unsigned, info.int.bits) = 0; for (bytes, 0..) |b, index| { const shifted = @shlWithOverflow(b, @as(Shift, @intCast(index * 8))); if (shifted[1] == 1) return error.Overflow; diff --git a/src/tls.zig/testdata/tls13.zig b/src/tls.zig/testdata/tls13.zig index f98f9ff..6045fd8 100644 --- a/src/tls.zig/testdata/tls13.zig +++ b/src/tls.zig/testdata/tls13.zig @@ -62,3 +62,8 @@ pub const server_flight = server_certificate_wrapped ++ server_certificate_verify_wrapped ++ server_finished_wrapped; + +pub const server_flight_1 = server_hello ++ server_flight; + +pub const client_flight_2 = + hexToBytes("140303000101") ++ client_finished_wrapped; From 9f2e42b97e9fc63d0575175ce2ac0675c7e08c4e Mon Sep 17 00:00:00 2001 From: Karl Seguin Date: Tue, 4 Mar 2025 15:28:58 +0800 Subject: [PATCH 2/3] Don't clear the context stack in pop Currently, when pop returns, the stack is deinitialized if empty. I assume this is an optimize to free memory ASAP, without having to wait for deinit. But currently, the only way to free a Ctx is from a callback being executed within pop. You have to call `ctx.deinit` within `ctx.pop`, and you possibly have other things to clean up too, like deallocating a heap-allocated Ctx. This makes _any_ operations after @call returns dangerous. Consider this code: https://github.com/lightpanda-io/browser/blob/6ae4ed9fc33eb82d7206d52692d8616737d11bf5/src/xhr/xhr.zig#L617 This calls ctx.deinit() _within_ a call to ctx.pop, it then sets the union referencing ctx to null. After that point, the ctx variable/memory should not be used. --- src/std/http/Client.zig | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/std/http/Client.zig b/src/std/http/Client.zig index 849023c..484babe 100644 --- a/src/std/http/Client.zig +++ b/src/std/http/Client.zig @@ -2469,14 +2469,6 @@ pub const Ctx = struct { if (self.stack) |stack| { const allocator = self.alloc(); const func = stack.pop(allocator, null); - - defer { - if (self.stack != null and self.stack.?.next == null) { - allocator.destroy(self.stack.?); - self.stack = null; - } - } - return @call(.auto, func, .{ self, res }); } unreachable; From cc38625c827575c628b2a6ebddfc3531204b5c0d Mon Sep 17 00:00:00 2001 From: Karl Seguin Date: Thu, 6 Mar 2025 18:56:38 +0800 Subject: [PATCH 3/3] Zig 0.14 compatibility --- .github/workflows/zig-fmt.yml | 2 +- .github/workflows/zig-test.yml | 2 +- build.zig.zon | 3 ++- src/std/http/Client.zig | 24 ++++++++++++------------ src/tls.zig/connection.zig | 24 ++++++++++++------------ 5 files changed, 28 insertions(+), 27 deletions(-) diff --git a/.github/workflows/zig-fmt.yml b/.github/workflows/zig-fmt.yml index 83079f7..067f3ec 100644 --- a/.github/workflows/zig-fmt.yml +++ b/.github/workflows/zig-fmt.yml @@ -1,7 +1,7 @@ name: zig-fmt env: - ZIG_VERSION: 0.13.0 + ZIG_VERSION: 0.14.0 on: pull_request: diff --git a/.github/workflows/zig-test.yml b/.github/workflows/zig-test.yml index ff03a95..827cd1a 100644 --- a/.github/workflows/zig-test.yml +++ b/.github/workflows/zig-test.yml @@ -1,7 +1,7 @@ name: zig-test env: - ZIG_VERSION: 0.14.0-dev.3239+d7b93c787 + ZIG_VERSION: 0.14.0 on: push: diff --git a/build.zig.zon b/build.zig.zon index a9292c2..2809616 100644 --- a/build.zig.zon +++ b/build.zig.zon @@ -1,7 +1,8 @@ .{ - .name = "zig-async-io", + .name = .zig_async_io, .version = "0.1.0", .minimum_zig_version = "0.14.0", + .fingerprint = 0xec4ef418b22755ea, .paths = .{ "build.zig", "build.zig.zon", diff --git a/src/std/http/Client.zig b/src/std/http/Client.zig index 484babe..aed3ae2 100644 --- a/src/std/http/Client.zig +++ b/src/std/http/Client.zig @@ -1122,13 +1122,13 @@ pub const Request = struct { pub const WaitError = RequestError || SendError || TransferReadError || proto.HeadersParser.CheckCompleteHeadError || Response.ParseError || error{ // TODO: file zig fmt issue for this bad indentation - TooManyHttpRedirects, - RedirectRequiresResend, - HttpRedirectLocationMissing, - HttpRedirectLocationInvalid, - CompressionInitializationFailed, - CompressionUnsupported, - }; + TooManyHttpRedirects, + RedirectRequiresResend, + HttpRedirectLocationMissing, + HttpRedirectLocationInvalid, + CompressionInitializationFailed, + CompressionUnsupported, + }; pub fn async_wait(_: *Request, ctx: *Ctx, comptime cbk: Cbk) !void { try ctx.push(cbk); @@ -1997,12 +1997,12 @@ pub fn async_connect( pub const RequestError = ConnectTcpError || ConnectErrorPartial || Request.SendError || std.fmt.ParseIntError || Connection.WriteError || error{ // TODO: file a zig fmt issue for this bad indentation - UnsupportedUriScheme, - UriMissingHost, + UnsupportedUriScheme, + UriMissingHost, - CertificateBundleLoadFailure, - UnsupportedTransferEncoding, -}; + CertificateBundleLoadFailure, + UnsupportedTransferEncoding, + }; pub const RequestOptions = struct { version: http.Version = .@"HTTP/1.1", diff --git a/src/tls.zig/connection.zig b/src/tls.zig/connection.zig index 55e27fb..c10ef5b 100644 --- a/src/tls.zig/connection.zig +++ b/src/tls.zig/connection.zig @@ -135,20 +135,20 @@ pub fn Connection(comptime Stream: type) type { pub const ReadError = Stream.ReadError || proto.Alert.Error || error{ - TlsBadVersion, - TlsUnexpectedMessage, - TlsRecordOverflow, - TlsDecryptError, - TlsDecodeError, - TlsBadRecordMac, - TlsIllegalParameter, - BufferOverflow, - }; + TlsBadVersion, + TlsUnexpectedMessage, + TlsRecordOverflow, + TlsDecryptError, + TlsDecodeError, + TlsBadRecordMac, + TlsIllegalParameter, + BufferOverflow, + }; pub const WriteError = Stream.WriteError || error{ - BufferOverflow, - TlsUnexpectedMessage, - }; + BufferOverflow, + TlsUnexpectedMessage, + }; pub const Reader = std.io.Reader(*Self, ReadError, read); pub const Writer = std.io.Writer(*Self, WriteError, write);