Skip to content

Commit 729a051

Browse files
mizuochikandrewrk
authored andcommitted
std.http: Fix segfault while redirecting
Make to avoid releasing request's connection twice. Change the `Request.connection` field optional. This field is null while the connection is released. Fixes #15965
1 parent e23d48e commit 729a051

File tree

2 files changed

+66
-25
lines changed

2 files changed

+66
-25
lines changed

lib/std/http/Client.zig

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -451,7 +451,8 @@ pub const Response = struct {
451451
pub const Request = struct {
452452
uri: Uri,
453453
client: *Client,
454-
connection: *ConnectionPool.Node,
454+
/// is null when this connection is released
455+
connection: ?*ConnectionPool.Node,
455456

456457
method: http.Method,
457458
version: http.Version = .@"HTTP/1.1",
@@ -481,13 +482,14 @@ pub const Request = struct {
481482
req.response.parser.header_bytes.deinit(req.client.allocator);
482483
}
483484

484-
if (!req.response.parser.done) {
485-
// If the response wasn't fully read, then we need to close the connection.
486-
req.connection.data.closing = true;
485+
if (req.connection) |connection| {
486+
if (!req.response.parser.done) {
487+
// If the response wasn't fully read, then we need to close the connection.
488+
connection.data.closing = true;
489+
}
490+
req.client.connection_pool.release(req.client, connection);
487491
}
488492

489-
req.client.connection_pool.release(req.client, req.connection);
490-
491493
req.arena.deinit();
492494
req.* = undefined;
493495
}
@@ -504,7 +506,8 @@ pub const Request = struct {
504506
.zstd => |*zstd| zstd.deinit(),
505507
}
506508

507-
req.client.connection_pool.release(req.client, req.connection);
509+
req.client.connection_pool.release(req.client, req.connection.?);
510+
req.connection = null;
508511

509512
const protocol = protocol_map.get(uri.scheme) orelse return error.UnsupportedUrlScheme;
510513

@@ -534,7 +537,7 @@ pub const Request = struct {
534537

535538
/// Send the request to the server.
536539
pub fn start(req: *Request) StartError!void {
537-
var buffered = std.io.bufferedWriter(req.connection.data.writer());
540+
var buffered = std.io.bufferedWriter(req.connection.?.data.writer());
538541
const w = buffered.writer();
539542

540543
try w.writeAll(@tagName(req.method));
@@ -544,7 +547,7 @@ pub const Request = struct {
544547
try w.writeAll(req.uri.host.?);
545548
try w.writeByte(':');
546549
try w.print("{}", .{req.uri.port.?});
547-
} else if (req.connection.data.proxied) {
550+
} else if (req.connection.?.data.proxied) {
548551
// proxied connections require the full uri
549552
try w.print("{+/}", .{req.uri});
550553
} else {
@@ -625,7 +628,7 @@ pub const Request = struct {
625628

626629
var index: usize = 0;
627630
while (index == 0) {
628-
const amt = try req.response.parser.read(&req.connection.data, buf[index..], req.response.skip);
631+
const amt = try req.response.parser.read(&req.connection.?.data, buf[index..], req.response.skip);
629632
if (amt == 0 and req.response.parser.done) break;
630633
index += amt;
631634
}
@@ -643,23 +646,23 @@ pub const Request = struct {
643646
pub fn wait(req: *Request) WaitError!void {
644647
while (true) { // handle redirects
645648
while (true) { // read headers
646-
try req.connection.data.fill();
649+
try req.connection.?.data.fill();
647650

648-
const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.peek());
649-
req.connection.data.drop(@intCast(u16, nchecked));
651+
const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.?.data.peek());
652+
req.connection.?.data.drop(@intCast(u16, nchecked));
650653

651654
if (req.response.parser.state.isContent()) break;
652655
}
653656

654657
try req.response.parse(req.response.parser.header_bytes.items, false);
655658

656659
if (req.response.status == .switching_protocols) {
657-
req.connection.data.closing = false;
660+
req.connection.?.data.closing = false;
658661
req.response.parser.done = true;
659662
}
660663

661664
if (req.method == .CONNECT and req.response.status == .ok) {
662-
req.connection.data.closing = false;
665+
req.connection.?.data.closing = false;
663666
req.response.parser.done = true;
664667
}
665668

@@ -670,9 +673,9 @@ pub const Request = struct {
670673
const res_connection = req.response.headers.getFirstValue("connection");
671674
const res_keepalive = res_connection != null and !std.ascii.eqlIgnoreCase("close", res_connection.?);
672675
if (res_keepalive and (req_keepalive or req_connection == null)) {
673-
req.connection.data.closing = false;
676+
req.connection.?.data.closing = false;
674677
} else {
675-
req.connection.data.closing = true;
678+
req.connection.?.data.closing = true;
676679
}
677680

678681
if (req.response.transfer_encoding) |te| {
@@ -762,10 +765,10 @@ pub const Request = struct {
762765
const has_trail = !req.response.parser.state.isContent();
763766

764767
while (!req.response.parser.state.isContent()) { // read trailing headers
765-
try req.connection.data.fill();
768+
try req.connection.?.data.fill();
766769

767-
const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.peek());
768-
req.connection.data.drop(@intCast(u16, nchecked));
770+
const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.?.data.peek());
771+
req.connection.?.data.drop(@intCast(u16, nchecked));
769772
}
770773

771774
if (has_trail) {
@@ -803,16 +806,16 @@ pub const Request = struct {
803806
pub fn write(req: *Request, bytes: []const u8) WriteError!usize {
804807
switch (req.transfer_encoding) {
805808
.chunked => {
806-
try req.connection.data.writer().print("{x}\r\n", .{bytes.len});
807-
try req.connection.data.writeAll(bytes);
808-
try req.connection.data.writeAll("\r\n");
809+
try req.connection.?.data.writer().print("{x}\r\n", .{bytes.len});
810+
try req.connection.?.data.writeAll(bytes);
811+
try req.connection.?.data.writeAll("\r\n");
809812

810813
return bytes.len;
811814
},
812815
.content_length => |*len| {
813816
if (len.* < bytes.len) return error.MessageTooLong;
814817

815-
const amt = try req.connection.data.write(bytes);
818+
const amt = try req.connection.?.data.write(bytes);
816819
len.* -= amt;
817820
return amt;
818821
},
@@ -832,7 +835,7 @@ pub const Request = struct {
832835
/// Finish the body of a request. This notifies the server that you have no more data to send.
833836
pub fn finish(req: *Request) FinishError!void {
834837
switch (req.transfer_encoding) {
835-
.chunked => try req.connection.data.writeAll("0\r\n\r\n"),
838+
.chunked => try req.connection.?.data.writeAll("0\r\n\r\n"),
836839
.content_length => |len| if (len != 0) return error.MessageNotCompleted,
837840
.none => {},
838841
}

test/standalone/http.zig

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,15 @@ fn handleRequest(res: *Server.Response) !void {
129129
try res.writeAll("Hello, ");
130130
try res.writeAll("Redirected!\n");
131131
try res.finish();
132+
} else if (mem.eql(u8, res.request.target, "/redirect/invalid")) {
133+
const invalid_port = try getUnusedTcpPort();
134+
const location = try std.fmt.allocPrint(salloc, "http://127.0.0.1:{d}", .{invalid_port});
135+
defer salloc.free(location);
136+
137+
res.status = .found;
138+
try res.headers.append("location", location);
139+
try res.do();
140+
try res.finish();
132141
} else {
133142
res.status = .not_found;
134143
try res.do();
@@ -180,6 +189,14 @@ fn killServer(addr: std.net.Address) void {
180189
conn.close();
181190
}
182191

192+
fn getUnusedTcpPort() !u16 {
193+
const addr = try std.net.Address.parseIp("127.0.0.1", 0);
194+
var s = std.net.StreamServer.init(.{});
195+
defer s.deinit();
196+
try s.listen(addr);
197+
return s.listen_address.in.getPort();
198+
}
199+
183200
pub fn main() !void {
184201
const log = std.log.scoped(.client);
185202

@@ -533,6 +550,27 @@ pub fn main() !void {
533550
// connection has been kept alive
534551
try testing.expect(client.connection_pool.free_len == 1);
535552

553+
{ // check client without segfault by connection error after redirection
554+
var h = http.Headers{ .allocator = calloc };
555+
defer h.deinit();
556+
557+
const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/redirect/invalid", .{port});
558+
defer calloc.free(location);
559+
const uri = try std.Uri.parse(location);
560+
561+
log.info("{s}", .{location});
562+
var req = try client.request(.GET, uri, h, .{});
563+
defer req.deinit();
564+
565+
try req.start();
566+
const result = req.wait();
567+
568+
try testing.expectError(error.ConnectionRefused, result); // expects not segfault but the regular error
569+
}
570+
571+
// connection has been kept alive
572+
try testing.expect(client.connection_pool.free_len == 1);
573+
536574
client.deinit();
537575

538576
killServer(server.socket.listen_address);

0 commit comments

Comments
 (0)