Skip to content

Commit 87d09ed

Browse files
authored
Merge pull request #17352 from kcbanner/extern_union_comptime_memory
sema: Support reinterpreting extern/packed unions at comptime via field access
2 parents c933a7c + 1b8a50e commit 87d09ed

File tree

7 files changed

+376
-109
lines changed

7 files changed

+376
-109
lines changed

src/Module.zig

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6607,6 +6607,7 @@ pub fn unionFieldNormalAlignment(mod: *Module, u: InternPool.UnionType, field_in
66076607
return field_ty.abiAlignment(mod);
66086608
}
66096609

6610+
/// Returns the index of the active field, given the current tag value
66106611
pub fn unionTagFieldIndex(mod: *Module, u: InternPool.UnionType, enum_tag: Value) ?u32 {
66116612
const ip = &mod.intern_pool;
66126613
if (enum_tag.toIntern() == .none) return null;

src/Sema.zig

Lines changed: 110 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -27258,7 +27258,7 @@ fn unionFieldVal(
2725827258
return sema.failWithOwnedErrorMsg(block, msg);
2725927259
}
2726027260
},
27261-
.Packed, .Extern => {
27261+
.Packed, .Extern => |layout| {
2726227262
if (tag_matches) {
2726327263
return Air.internedToRef(un.val);
2726427264
} else {
@@ -27267,7 +27267,7 @@ fn unionFieldVal(
2726727267
else
2726827268
union_ty.unionFieldType(un.tag.toValue(), mod).?;
2726927269

27270-
if (try sema.bitCastVal(block, src, un.val.toValue(), old_ty, field_ty, 0)) |new_val| {
27270+
if (try sema.bitCastUnionFieldVal(block, src, un.val.toValue(), old_ty, field_ty, layout)) |new_val| {
2727127271
return Air.internedToRef(new_val.toIntern());
2727227272
}
2727327273
}
@@ -29788,13 +29788,19 @@ fn storePtrVal(
2978829788
error.IllDefinedMemoryLayout => unreachable, // Sema was supposed to emit a compile error already
2978929789
error.Unimplemented => return sema.fail(block, src, "TODO: implement writeToMemory for type '{}'", .{mut_kit.ty.fmt(mod)}),
2979029790
};
29791-
operand_val.writeToMemory(operand_ty, mod, buffer[reinterpret.byte_offset..]) catch |err| switch (err) {
29792-
error.OutOfMemory => return error.OutOfMemory,
29793-
error.ReinterpretDeclRef => unreachable,
29794-
error.IllDefinedMemoryLayout => unreachable, // Sema was supposed to emit a compile error already
29795-
error.Unimplemented => return sema.fail(block, src, "TODO: implement writeToMemory for type '{}'", .{operand_ty.fmt(mod)}),
29796-
};
29797-
29791+
if (reinterpret.write_packed) {
29792+
operand_val.writeToPackedMemory(operand_ty, mod, buffer[reinterpret.byte_offset..], 0) catch |err| switch (err) {
29793+
error.OutOfMemory => return error.OutOfMemory,
29794+
error.ReinterpretDeclRef => unreachable,
29795+
};
29796+
} else {
29797+
operand_val.writeToMemory(operand_ty, mod, buffer[reinterpret.byte_offset..]) catch |err| switch (err) {
29798+
error.OutOfMemory => return error.OutOfMemory,
29799+
error.ReinterpretDeclRef => unreachable,
29800+
error.IllDefinedMemoryLayout => unreachable, // Sema was supposed to emit a compile error already
29801+
error.Unimplemented => return sema.fail(block, src, "TODO: implement writeToMemory for type '{}'", .{operand_ty.fmt(mod)}),
29802+
};
29803+
}
2979829804
const val = Value.readFromMemory(mut_kit.ty, mod, buffer, sema.arena) catch |err| switch (err) {
2979929805
error.OutOfMemory => return error.OutOfMemory,
2980029806
error.IllDefinedMemoryLayout => unreachable,
@@ -29826,6 +29832,8 @@ const ComptimePtrMutationKit = struct {
2982629832
reinterpret: struct {
2982729833
val_ptr: *Value,
2982829834
byte_offset: usize,
29835+
/// If set, write the operand to packed memory
29836+
write_packed: bool = false,
2982929837
},
2983029838
/// If the root decl could not be used as parent, this means `ty` is the type that
2983129839
/// caused that by not having a well-defined layout.
@@ -30189,21 +30197,43 @@ fn beginComptimePtrMutation(
3018930197
);
3019030198
},
3019130199
.@"union" => {
30192-
// We need to set the active field of the union.
30193-
const union_tag_ty = base_child_ty.unionTagTypeHypothetical(mod);
30194-
3019530200
const payload = &val_ptr.castTag(.@"union").?.data;
30196-
payload.tag = try mod.enumValueFieldIndex(union_tag_ty, field_index);
30201+
const layout = base_child_ty.containerLayout(mod);
3019730202

30198-
return beginComptimePtrMutationInner(
30199-
sema,
30200-
block,
30201-
src,
30202-
parent.ty.structFieldType(field_index, mod),
30203-
&payload.val,
30204-
ptr_elem_ty,
30205-
parent.mut_decl,
30206-
);
30203+
const tag_type = base_child_ty.unionTagTypeHypothetical(mod);
30204+
const hypothetical_tag = try mod.enumValueFieldIndex(tag_type, field_index);
30205+
if (layout == .Auto or (payload.tag != null and hypothetical_tag.eql(payload.tag.?, tag_type, mod))) {
30206+
// We need to set the active field of the union.
30207+
payload.tag = hypothetical_tag;
30208+
30209+
const field_ty = parent.ty.structFieldType(field_index, mod);
30210+
return beginComptimePtrMutationInner(
30211+
sema,
30212+
block,
30213+
src,
30214+
field_ty,
30215+
&payload.val,
30216+
ptr_elem_ty,
30217+
parent.mut_decl,
30218+
);
30219+
} else {
30220+
// Writing to a different field (a different or unknown tag is active) requires reinterpreting
30221+
// memory of the entire union, which requires knowing its abiSize.
30222+
try sema.resolveTypeLayout(parent.ty);
30223+
30224+
// This union value no longer has a well-defined tag type.
30225+
// The reinterpretation will read it back out as .none.
30226+
payload.val = try payload.val.unintern(sema.arena, mod);
30227+
return ComptimePtrMutationKit{
30228+
.mut_decl = parent.mut_decl,
30229+
.pointee = .{ .reinterpret = .{
30230+
.val_ptr = val_ptr,
30231+
.byte_offset = 0,
30232+
.write_packed = layout == .Packed,
30233+
} },
30234+
.ty = parent.ty,
30235+
};
30236+
}
3020730237
},
3020830238
.slice => switch (field_index) {
3020930239
Value.slice_ptr_index => return beginComptimePtrMutationInner(
@@ -30704,6 +30734,7 @@ fn bitCastVal(
3070430734
// For types with well-defined memory layouts, we serialize them a byte buffer,
3070530735
// then deserialize to the new type.
3070630736
const abi_size = try sema.usizeCast(block, src, old_ty.abiSize(mod));
30737+
3070730738
const buffer = try sema.gpa.alloc(u8, abi_size);
3070830739
defer sema.gpa.free(buffer);
3070930740
val.writeToMemory(old_ty, mod, buffer) catch |err| switch (err) {
@@ -30720,6 +30751,63 @@ fn bitCastVal(
3072030751
};
3072130752
}
3072230753

30754+
fn bitCastUnionFieldVal(
30755+
sema: *Sema,
30756+
block: *Block,
30757+
src: LazySrcLoc,
30758+
val: Value,
30759+
old_ty: Type,
30760+
field_ty: Type,
30761+
layout: std.builtin.Type.ContainerLayout,
30762+
) !?Value {
30763+
const mod = sema.mod;
30764+
if (old_ty.eql(field_ty, mod)) return val;
30765+
30766+
const old_size = try sema.usizeCast(block, src, old_ty.abiSize(mod));
30767+
const field_size = try sema.usizeCast(block, src, field_ty.abiSize(mod));
30768+
const endian = mod.getTarget().cpu.arch.endian();
30769+
30770+
const buffer = try sema.gpa.alloc(u8, @max(old_size, field_size));
30771+
defer sema.gpa.free(buffer);
30772+
30773+
// Reading a larger value means we need to reinterpret from undefined bytes.
30774+
const offset = switch (layout) {
30775+
.Extern => offset: {
30776+
if (field_size > old_size) @memset(buffer[old_size..], 0xaa);
30777+
val.writeToMemory(old_ty, mod, buffer) catch |err| switch (err) {
30778+
error.OutOfMemory => return error.OutOfMemory,
30779+
error.ReinterpretDeclRef => return null,
30780+
error.IllDefinedMemoryLayout => unreachable, // Sema was supposed to emit a compile error already
30781+
error.Unimplemented => return sema.fail(block, src, "TODO: implement writeToMemory for type '{}'", .{old_ty.fmt(mod)}),
30782+
};
30783+
break :offset 0;
30784+
},
30785+
.Packed => offset: {
30786+
if (field_size > old_size) {
30787+
const min_size = @max(old_size, 1);
30788+
switch (endian) {
30789+
.Little => @memset(buffer[min_size - 1 ..], 0xaa),
30790+
.Big => @memset(buffer[0 .. buffer.len - min_size + 1], 0xaa),
30791+
}
30792+
}
30793+
30794+
val.writeToPackedMemory(old_ty, mod, buffer, 0) catch |err| switch (err) {
30795+
error.OutOfMemory => return error.OutOfMemory,
30796+
error.ReinterpretDeclRef => return null,
30797+
};
30798+
30799+
break :offset if (endian == .Big) buffer.len - field_size else 0;
30800+
},
30801+
.Auto => unreachable,
30802+
};
30803+
30804+
return Value.readFromMemory(field_ty, mod, buffer[offset..], sema.arena) catch |err| switch (err) {
30805+
error.OutOfMemory => return error.OutOfMemory,
30806+
error.IllDefinedMemoryLayout => unreachable,
30807+
error.Unimplemented => return sema.fail(block, src, "TODO: implement readFromMemory for type '{}'", .{field_ty.fmt(mod)}),
30808+
};
30809+
}
30810+
3072330811
fn coerceArrayPtrToSlice(
3072430812
sema: *Sema,
3072530813
block: *Block,

src/TypedValue.zig

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -84,22 +84,27 @@ pub fn print(
8484
if (level == 0) {
8585
return writer.writeAll(".{ ... }");
8686
}
87-
const union_val = val.castTag(.@"union").?.data;
87+
const payload = val.castTag(.@"union").?.data;
8888
try writer.writeAll(".{ ");
8989

90-
if (union_val.tag.toIntern() != .none) {
90+
if (payload.tag) |tag| {
9191
try print(.{
9292
.ty = ip.indexToKey(ty.toIntern()).union_type.enum_tag_ty.toType(),
93-
.val = union_val.tag,
93+
.val = tag,
9494
}, writer, level - 1, mod);
9595
try writer.writeAll(" = ");
96-
const field_ty = ty.unionFieldType(union_val.tag, mod).?;
96+
const field_ty = ty.unionFieldType(tag, mod).?;
9797
try print(.{
9898
.ty = field_ty,
99-
.val = union_val.val,
99+
.val = payload.val,
100100
}, writer, level - 1, mod);
101101
} else {
102-
return writer.writeAll("(unknown tag)");
102+
try writer.writeAll("(unknown tag) = ");
103+
const backing_ty = try ty.unionBackingType(mod);
104+
try print(.{
105+
.ty = backing_ty,
106+
.val = payload.val,
107+
}, writer, level - 1, mod);
103108
}
104109

105110
return writer.writeAll(" }");
@@ -421,7 +426,12 @@ pub fn print(
421426
.val = un.val.toValue(),
422427
}, writer, level - 1, mod);
423428
} else {
424-
try writer.writeAll("(unknown tag)");
429+
try writer.writeAll("(unknown tag) = ");
430+
const backing_ty = try ty.unionBackingType(mod);
431+
try print(.{
432+
.ty = backing_ty,
433+
.val = un.val.toValue(),
434+
}, writer, level - 1, mod);
425435
}
426436
} else try writer.writeAll("...");
427437
return writer.writeAll(" }");

src/type.zig

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1954,6 +1954,16 @@ pub const Type = struct {
19541954
return true;
19551955
}
19561956

1957+
/// Returns the type used for backing storage of this union during comptime operations.
1958+
/// Asserts the type is either an extern or packed union.
1959+
pub fn unionBackingType(ty: Type, mod: *Module) !Type {
1960+
return switch (ty.containerLayout(mod)) {
1961+
.Extern => try mod.arrayType(.{ .len = ty.abiSize(mod), .child = .u8_type }),
1962+
.Packed => try mod.intType(.unsigned, @intCast(ty.bitSize(mod))),
1963+
.Auto => unreachable,
1964+
};
1965+
}
1966+
19571967
pub fn unionGetLayout(ty: Type, mod: *Module) Module.UnionLayout {
19581968
const ip = &mod.intern_pool;
19591969
const union_type = ip.indexToKey(ty.toIntern()).union_type;

src/value.zig

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -327,11 +327,19 @@ pub const Value = struct {
327327
},
328328
.@"union" => {
329329
const pl = val.castTag(.@"union").?.data;
330-
return mod.intern(.{ .un = .{
331-
.ty = ty.toIntern(),
332-
.tag = try pl.tag.intern(ty.unionTagTypeHypothetical(mod), mod),
333-
.val = try pl.val.intern(ty.unionFieldType(pl.tag, mod).?, mod),
334-
} });
330+
if (pl.tag) |pl_tag| {
331+
return mod.intern(.{ .un = .{
332+
.ty = ty.toIntern(),
333+
.tag = try pl_tag.intern(ty.unionTagTypeHypothetical(mod), mod),
334+
.val = try pl.val.intern(ty.unionFieldType(pl_tag, mod).?, mod),
335+
} });
336+
} else {
337+
return mod.intern(.{ .un = .{
338+
.ty = ty.toIntern(),
339+
.tag = .none,
340+
.val = try pl.val.intern(try ty.unionBackingType(mod), mod),
341+
} });
342+
}
335343
},
336344
}
337345
}
@@ -399,10 +407,7 @@ pub const Value = struct {
399407

400408
.un => |un| Tag.@"union".create(arena, .{
401409
// toValue asserts that the value cannot be .none which is valid on unions.
402-
.tag = .{
403-
.ip_index = un.tag,
404-
.legacy = undefined,
405-
},
410+
.tag = if (un.tag == .none) null else un.tag.toValue(),
406411
.val = un.val.toValue(),
407412
}),
408413

@@ -709,21 +714,22 @@ pub const Value = struct {
709714
.Union => switch (ty.containerLayout(mod)) {
710715
.Auto => return error.IllDefinedMemoryLayout, // Sema is supposed to have emitted a compile error already
711716
.Extern => {
712-
const union_obj = mod.typeToUnion(ty).?;
713717
if (val.unionTag(mod)) |union_tag| {
718+
const union_obj = mod.typeToUnion(ty).?;
714719
const field_index = mod.unionTagFieldIndex(union_obj, union_tag).?;
715720
const field_type = union_obj.field_types.get(&mod.intern_pool)[field_index].toType();
716721
const field_val = try val.fieldValue(mod, field_index);
717722
const byte_count = @as(usize, @intCast(field_type.abiSize(mod)));
718723
return writeToMemory(field_val, field_type, mod, buffer[0..byte_count]);
719724
} else {
720-
const union_size = ty.abiSize(mod);
721-
const array_type = try mod.arrayType(.{ .len = union_size, .child = .u8_type });
722-
return writeToMemory(val.unionValue(mod), array_type, mod, buffer[0..@as(usize, @intCast(union_size))]);
725+
const backing_ty = try ty.unionBackingType(mod);
726+
const byte_count: usize = @intCast(backing_ty.abiSize(mod));
727+
return writeToMemory(val.unionValue(mod), backing_ty, mod, buffer[0..byte_count]);
723728
}
724729
},
725730
.Packed => {
726-
const byte_count = (@as(usize, @intCast(ty.bitSize(mod))) + 7) / 8;
731+
const backing_ty = try ty.unionBackingType(mod);
732+
const byte_count: usize = @intCast(backing_ty.abiSize(mod));
727733
return writeToPackedMemory(val, ty, mod, buffer[0..byte_count], 0);
728734
},
729735
},
@@ -842,9 +848,8 @@ pub const Value = struct {
842848
const field_val = try val.fieldValue(mod, field_index);
843849
return field_val.writeToPackedMemory(field_type, mod, buffer, bit_offset);
844850
} else {
845-
const union_bits: u16 = @intCast(ty.bitSize(mod));
846-
const int_ty = try mod.intType(.unsigned, union_bits);
847-
return val.unionValue(mod).writeToPackedMemory(int_ty, mod, buffer, bit_offset);
851+
const backing_ty = try ty.unionBackingType(mod);
852+
return val.unionValue(mod).writeToPackedMemory(backing_ty, mod, buffer, bit_offset);
848853
}
849854
},
850855
}
@@ -1146,10 +1151,8 @@ pub const Value = struct {
11461151
.Union => switch (ty.containerLayout(mod)) {
11471152
.Auto, .Extern => unreachable, // Handled by non-packed readFromMemory
11481153
.Packed => {
1149-
const union_bits: u16 = @intCast(ty.bitSize(mod));
1150-
assert(union_bits != 0);
1151-
const int_ty = try mod.intType(.unsigned, union_bits);
1152-
const val = (try readFromPackedMemory(int_ty, mod, buffer, bit_offset, arena)).toIntern();
1154+
const backing_ty = try ty.unionBackingType(mod);
1155+
const val = (try readFromPackedMemory(backing_ty, mod, buffer, bit_offset, arena)).toIntern();
11531156
return (try mod.intern(.{ .un = .{
11541157
.ty = ty.toIntern(),
11551158
.tag = .none,
@@ -4017,7 +4020,7 @@ pub const Value = struct {
40174020
data: Data,
40184021

40194022
pub const Data = struct {
4020-
tag: Value,
4023+
tag: ?Value,
40214024
val: Value,
40224025
};
40234026
};

0 commit comments

Comments
 (0)