Skip to content

std.heap.PageAllocator updates to fix race condition and utilize NtAllocateVirtualMemory / NtFreeVirtualMemory instead of VirtualAlloc / VirtualFree #23097

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 65 additions & 62 deletions lib/std/heap/PageAllocator.zig
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,14 @@ const maxInt = std.math.maxInt;
const assert = std.debug.assert;
const native_os = builtin.os.tag;
const windows = std.os.windows;
const ntdll = windows.ntdll;
const posix = std.posix;
const page_size_min = std.heap.page_size_min;

const SUCCESS = @import("../os/windows/ntstatus.zig").NTSTATUS.SUCCESS;
const MEM_RESERVE_PLACEHOLDER = windows.MEM_RESERVE_PLACEHOLDER;
const MEM_PRESERVE_PLACEHOLDER = windows.MEM_PRESERVE_PLACEHOLDER;

pub const vtable: Allocator.VTable = .{
.alloc = alloc,
.resize = resize,
Expand All @@ -22,51 +27,62 @@ pub fn map(n: usize, alignment: mem.Alignment) ?[*]u8 {
const alignment_bytes = alignment.toByteUnits();

if (native_os == .windows) {
// According to official documentation, VirtualAlloc aligns to page
// boundary, however, empirically it reserves pages on a 64K boundary.
// Since it is very likely the requested alignment will be honored,
// this logic first tries a call with exactly the size requested,
// before falling back to the loop below.
// https://devblogs.microsoft.com/oldnewthing/?p=42223
const addr = windows.VirtualAlloc(
null,
// VirtualAlloc will round the length to a multiple of page size.
// "If the lpAddress parameter is NULL, this value is rounded up to
// the next page boundary".
n,
windows.MEM_COMMIT | windows.MEM_RESERVE,
windows.PAGE_READWRITE,
) catch return null;

if (mem.isAligned(@intFromPtr(addr), alignment_bytes))
return @ptrCast(addr);

// Fallback: reserve a range of memory large enough to find a
// sufficiently aligned address, then free the entire range and
// immediately allocate the desired subset. Another thread may have won
// the race to map the target range, in which case a retry is needed.
windows.VirtualFree(addr, 0, windows.MEM_RELEASE);
var base_addr: ?*anyopaque = null;
var size: windows.SIZE_T = n;

var status = ntdll.NtAllocateVirtualMemory(windows.GetCurrentProcess(), @ptrCast(&base_addr), 0, &size, windows.MEM_COMMIT | windows.MEM_RESERVE, windows.PAGE_READWRITE);

if (status == SUCCESS and mem.isAligned(@intFromPtr(base_addr), alignment_bytes)) {
return @ptrCast(base_addr);
}

if (status == SUCCESS) {
var region_size: windows.SIZE_T = 0;
_ = ntdll.NtFreeVirtualMemory(windows.GetCurrentProcess(), @ptrCast(&base_addr), &region_size, windows.MEM_RELEASE);
}

const overalloc_len = n + alignment_bytes - page_size;
const aligned_len = mem.alignForward(usize, n, page_size);

while (true) {
const reserved_addr = windows.VirtualAlloc(
null,
overalloc_len,
windows.MEM_RESERVE,
windows.PAGE_NOACCESS,
) catch return null;
const aligned_addr = mem.alignForward(usize, @intFromPtr(reserved_addr), alignment_bytes);
windows.VirtualFree(reserved_addr, 0, windows.MEM_RELEASE);
const ptr = windows.VirtualAlloc(
@ptrFromInt(aligned_addr),
aligned_len,
windows.MEM_COMMIT | windows.MEM_RESERVE,
windows.PAGE_READWRITE,
) catch continue;
return @ptrCast(ptr);
base_addr = null;
size = overalloc_len;

status = ntdll.NtAllocateVirtualMemory(windows.GetCurrentProcess(), @ptrCast(&base_addr), 0, &size, windows.MEM_RESERVE | MEM_RESERVE_PLACEHOLDER, windows.PAGE_NOACCESS);

if (status != SUCCESS) return null;

const placeholder_addr = @intFromPtr(base_addr);
const aligned_addr = mem.alignForward(usize, placeholder_addr, alignment_bytes);
const prefix_size = aligned_addr - placeholder_addr;

if (prefix_size > 0) {
var prefix_base = base_addr;
var prefix_size_param: windows.SIZE_T = prefix_size;
_ = ntdll.NtFreeVirtualMemory(windows.GetCurrentProcess(), @ptrCast(&prefix_base), &prefix_size_param, windows.MEM_RELEASE | MEM_PRESERVE_PLACEHOLDER);
}

const suffix_start = aligned_addr + aligned_len;
const suffix_size = (placeholder_addr + overalloc_len) - suffix_start;
if (suffix_size > 0) {
var suffix_base = @as(?*anyopaque, @ptrFromInt(suffix_start));
var suffix_size_param: windows.SIZE_T = suffix_size;
_ = ntdll.NtFreeVirtualMemory(windows.GetCurrentProcess(), @ptrCast(&suffix_base), &suffix_size_param, windows.MEM_RELEASE | MEM_PRESERVE_PLACEHOLDER);
}

base_addr = @ptrFromInt(aligned_addr);
size = aligned_len;

status = ntdll.NtAllocateVirtualMemory(windows.GetCurrentProcess(), @ptrCast(&base_addr), 0, &size, windows.MEM_COMMIT | MEM_PRESERVE_PLACEHOLDER, windows.PAGE_READWRITE);

if (status == SUCCESS) {
return @ptrCast(base_addr);
}

base_addr = @as(?*anyopaque, @ptrFromInt(aligned_addr));
size = aligned_len;
_ = ntdll.NtFreeVirtualMemory(windows.GetCurrentProcess(), @ptrCast(&base_addr), &size, windows.MEM_RELEASE);

return null;
}

const aligned_len = mem.alignForward(usize, n, page_size);
Expand Down Expand Up @@ -104,26 +120,14 @@ fn alloc(context: *anyopaque, n: usize, alignment: mem.Alignment, ra: usize) ?[*
return map(n, alignment);
}

fn resize(
context: *anyopaque,
memory: []u8,
alignment: mem.Alignment,
new_len: usize,
return_address: usize,
) bool {
fn resize(context: *anyopaque, memory: []u8, alignment: mem.Alignment, new_len: usize, return_address: usize) bool {
_ = context;
_ = alignment;
_ = return_address;
return realloc(memory, new_len, false) != null;
}

fn remap(
context: *anyopaque,
memory: []u8,
alignment: mem.Alignment,
new_len: usize,
return_address: usize,
) ?[*]u8 {
fn remap(context: *anyopaque, memory: []u8, alignment: mem.Alignment, new_len: usize, return_address: usize) ?[*]u8 {
_ = context;
_ = alignment;
_ = return_address;
Expand All @@ -139,7 +143,9 @@ fn free(context: *anyopaque, memory: []u8, alignment: mem.Alignment, return_addr

pub fn unmap(memory: []align(page_size_min) u8) void {
if (native_os == .windows) {
windows.VirtualFree(memory.ptr, 0, windows.MEM_RELEASE);
var base_addr: ?*anyopaque = memory.ptr;
var region_size: windows.SIZE_T = 0;
_ = ntdll.NtFreeVirtualMemory(windows.GetCurrentProcess(), @ptrCast(&base_addr), &region_size, windows.MEM_RELEASE);
} else {
const page_aligned_len = mem.alignForward(usize, memory.len, std.heap.pageSize());
posix.munmap(memory.ptr[0..page_aligned_len]);
Expand All @@ -157,13 +163,10 @@ pub fn realloc(uncasted_memory: []u8, new_len: usize, may_move: bool) ?[*]u8 {
const old_addr_end = base_addr + memory.len;
const new_addr_end = mem.alignForward(usize, base_addr + new_len, page_size);
if (old_addr_end > new_addr_end) {
// For shrinking that is not releasing, we will only decommit
// the pages not needed anymore.
windows.VirtualFree(
@ptrFromInt(new_addr_end),
old_addr_end - new_addr_end,
windows.MEM_DECOMMIT,
);
var decommit_addr: ?*anyopaque = @ptrFromInt(new_addr_end);
var decommit_size: windows.SIZE_T = old_addr_end - new_addr_end;

_ = ntdll.NtAllocateVirtualMemory(windows.GetCurrentProcess(), @ptrCast(&decommit_addr), 0, &decommit_size, windows.MEM_RESET, windows.PAGE_NOACCESS);
}
return memory.ptr;
}
Expand Down
34 changes: 34 additions & 0 deletions lib/std/os/windows.zig
Original file line number Diff line number Diff line change
Expand Up @@ -1758,6 +1758,38 @@ pub fn TerminateProcess(hProcess: HANDLE, uExitCode: UINT) TerminateProcessError
}
}

pub const NtAllocateVirtualMemoryError = error{
AccessDenied,
InvalidParameter,
NoMemory,
Unexpected,
};

pub fn NtAllocateVirtualMemory(hProcess: HANDLE, addr: ?*PVOID, zero_bits: ULONG_PTR, size: ?*SIZE_T, alloc_type: ULONG, protect: ULONG) NtAllocateVirtualMemoryError!void {
return switch (ntdll.NtAllocateVirtualMemory(hProcess, addr, zero_bits, size, alloc_type, protect)) {
.SUCCESS => return,
.ACCESS_DENIED => NtAllocateVirtualMemoryError.AccessDenied,
.INVALID_PARAMETER => NtAllocateVirtualMemoryError.InvalidParameter,
.NO_MEMORY => NtAllocateVirtualMemoryError.NoMemory,
else => |st| unexpectedStatus(st),
};
}

pub const NtFreeVirtualMemoryError = error{
AccessDenied,
InvalidParameter,
Unexpected,
};

pub fn NtFreeVirtualMemory(hProcess: HANDLE, addr: ?*PVOID, size: *SIZE_T, free_type: ULONG) NtFreeVirtualMemoryError!void {
return switch (ntdll.NtFreeVirtualMemory(hProcess, addr, size, free_type)) {
.SUCCESS => return,
.ACCESS_DENIED => NtFreeVirtualMemoryError.AccessDenied,
.INVALID_PARAMETER => NtFreeVirtualMemoryError.InvalidParameter,
else => NtFreeVirtualMemoryError.Unexpected,
};
}

pub const VirtualAllocError = error{Unexpected};

pub fn VirtualAlloc(addr: ?LPVOID, size: usize, alloc_type: DWORD, flProtect: DWORD) VirtualAllocError!LPVOID {
Expand Down Expand Up @@ -3539,6 +3571,8 @@ pub const MEM_LARGE_PAGES = 0x20000000;
pub const MEM_PHYSICAL = 0x400000;
pub const MEM_TOP_DOWN = 0x100000;
pub const MEM_WRITE_WATCH = 0x200000;
pub const MEM_RESERVE_PLACEHOLDER = 0x00040000;
pub const MEM_PRESERVE_PLACEHOLDER = 0x00000400;

// Protect values
pub const PAGE_EXECUTE = 0x10;
Expand Down
17 changes: 17 additions & 0 deletions lib/std/os/windows/ntdll.zig
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ const BOOL = windows.BOOL;
const DWORD = windows.DWORD;
const DWORD64 = windows.DWORD64;
const ULONG = windows.ULONG;
const ULONG_PTR = windows.ULONG_PTR;
const NTSTATUS = windows.NTSTATUS;
const WORD = windows.WORD;
const HANDLE = windows.HANDLE;
Expand Down Expand Up @@ -358,3 +359,19 @@ pub extern "ntdll" fn NtCreateNamedPipeFile(
OutboundQuota: ULONG,
DefaultTimeout: *LARGE_INTEGER,
) callconv(.winapi) NTSTATUS;

pub extern "ntdll" fn NtAllocateVirtualMemory(
ProcessHandle: HANDLE,
BaseAddress: ?*PVOID,
ZeroBits: ULONG_PTR,
RegionSize: ?*SIZE_T,
AllocationType: ULONG,
PageProtection: ULONG,
) callconv(.winapi) NTSTATUS;

pub extern "ntdll" fn NtFreeVirtualMemory(
ProcessHandle: HANDLE,
BaseAddress: ?*PVOID,
RegionSize: *SIZE_T,
FreeType: ULONG,
) callconv(.winapi) NTSTATUS;