Skip to content

Commit c1c9bc0

Browse files
committed
Sema: do not assume switch item indices align with union field indices
Resolves: #17754
1 parent 5257643 commit c1c9bc0

File tree

2 files changed

+47
-16
lines changed

2 files changed

+47
-16
lines changed

src/Sema.zig

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10789,23 +10789,24 @@ const SwitchProngAnalysis = struct {
1078910789
const first_field_index: u32 = mod.unionTagFieldIndex(union_obj, first_item_val).?;
1079010790
const first_field_ty = union_obj.field_types.get(ip)[first_field_index].toType();
1079110791

10792-
const field_tys = try sema.arena.alloc(Type, case_vals.len);
10793-
for (case_vals, field_tys) |item, *field_ty| {
10792+
const field_indices = try sema.arena.alloc(u32, case_vals.len);
10793+
for (case_vals, field_indices) |item, *field_idx| {
1079410794
const item_val = sema.resolveConstDefinedValue(block, .unneeded, item, undefined) catch unreachable;
10795-
const field_idx = mod.unionTagFieldIndex(union_obj, item_val).?;
10796-
field_ty.* = union_obj.field_types.get(ip)[field_idx].toType();
10795+
field_idx.* = mod.unionTagFieldIndex(union_obj, item_val).?;
1079710796
}
1079810797

1079910798
// Fast path: if all the operands are the same type already, we don't need to hit
1080010799
// PTR! This will also allow us to emit simpler code.
10801-
const same_types = for (field_tys[1..]) |field_ty| {
10802-
if (!field_ty.eql(field_tys[0], sema.mod)) break false;
10800+
const same_types = for (field_indices[1..]) |field_idx| {
10801+
const field_ty = union_obj.field_types.get(ip)[field_idx].toType();
10802+
if (!field_ty.eql(first_field_ty, sema.mod)) break false;
1080310803
} else true;
1080410804

10805-
const capture_ty = if (same_types) field_tys[0] else capture_ty: {
10805+
const capture_ty = if (same_types) first_field_ty else capture_ty: {
1080610806
// We need values to run PTR on, so make a bunch of undef constants.
1080710807
const dummy_captures = try sema.arena.alloc(Air.Inst.Ref, case_vals.len);
10808-
for (dummy_captures, field_tys) |*dummy, field_ty| {
10808+
for (dummy_captures, field_indices) |*dummy, field_idx| {
10809+
const field_ty = union_obj.field_types.get(ip)[field_idx].toType();
1080910810
dummy.* = try mod.undefRef(field_ty);
1081010811
}
1081110812

@@ -10852,7 +10853,8 @@ const SwitchProngAnalysis = struct {
1085210853
// By-ref captures of hetereogeneous types are only allowed if each field
1085310854
// pointer type is in-memory coercible to the capture pointer type.
1085410855
if (!same_types) {
10855-
for (field_tys, 0..) |field_ty, i| {
10856+
for (field_indices, 0..) |field_idx, i| {
10857+
const field_ty = union_obj.field_types.get(ip)[field_idx].toType();
1085610858
const field_ptr_ty = try sema.ptrType(.{
1085710859
.child = field_ty.toIntern(),
1085810860
.flags = .{
@@ -10915,7 +10917,8 @@ const SwitchProngAnalysis = struct {
1091510917
// We may have to emit a switch block which coerces the operand to the capture type.
1091610918
// If we can, try to avoid that using in-memory coercions.
1091710919
const first_non_imc = in_mem: {
10918-
for (field_tys, 0..) |field_ty, i| {
10920+
for (field_indices, 0..) |field_idx, i| {
10921+
const field_ty = union_obj.field_types.get(ip)[field_idx].toType();
1091910922
if (.ok != try sema.coerceInMemoryAllowed(block, capture_ty, field_ty, false, sema.mod.getTarget(), .unneeded, .unneeded)) {
1092010923
break :in_mem i;
1092110924
}
@@ -10933,11 +10936,12 @@ const SwitchProngAnalysis = struct {
1093310936
// be several, and we can squash all of these cases into the same switch prong using
1093410937
// a simple bitcast. We'll make this the 'else' prong.
1093510938

10936-
var in_mem_coercible = try std.DynamicBitSet.initFull(sema.arena, field_tys.len);
10939+
var in_mem_coercible = try std.DynamicBitSet.initFull(sema.arena, field_indices.len);
1093710940
in_mem_coercible.unset(first_non_imc);
1093810941
{
1093910942
const next = first_non_imc + 1;
10940-
for (field_tys[next..], next..) |field_ty, i| {
10943+
for (field_indices[next..], next..) |field_idx, i| {
10944+
const field_ty = union_obj.field_types.get(ip)[field_idx].toType();
1094110945
if (.ok != try sema.coerceInMemoryAllowed(block, capture_ty, field_ty, false, sema.mod.getTarget(), .unneeded, .unneeded)) {
1094210946
in_mem_coercible.unset(i);
1094310947
}
@@ -10954,7 +10958,7 @@ const SwitchProngAnalysis = struct {
1095410958
},
1095510959
});
1095610960

10957-
const prong_count = field_tys.len - in_mem_coercible.count();
10961+
const prong_count = field_indices.len - in_mem_coercible.count();
1095810962

1095910963
const estimated_extra = prong_count * 6; // 2 for Case, 1 item, probably 3 insts
1096010964
var cases_extra = try std.ArrayList(u32).initCapacity(sema.gpa, estimated_extra);
@@ -10967,7 +10971,9 @@ const SwitchProngAnalysis = struct {
1096710971
var coerce_block = block.makeSubBlock();
1096810972
defer coerce_block.instructions.deinit(sema.gpa);
1096910973

10970-
const uncoerced = try coerce_block.addStructFieldVal(spa.operand, @intCast(idx), field_tys[idx]);
10974+
const field_idx = field_indices[idx];
10975+
const field_ty = union_obj.field_types.get(ip)[field_idx].toType();
10976+
const uncoerced = try coerce_block.addStructFieldVal(spa.operand, field_idx, field_ty);
1097110977
const coerced = sema.coerce(&coerce_block, capture_ty, uncoerced, .unneeded) catch |err| switch (err) {
1097210978
error.NeededSourceLocation => {
1097310979
const multi_idx = raw_capture_src.multi_capture;
@@ -10993,8 +10999,10 @@ const SwitchProngAnalysis = struct {
1099310999
var coerce_block = block.makeSubBlock();
1099411000
defer coerce_block.instructions.deinit(sema.gpa);
1099511001

10996-
const first_imc = in_mem_coercible.findFirstSet().?;
10997-
const uncoerced = try coerce_block.addStructFieldVal(spa.operand, @intCast(first_imc), field_tys[first_imc]);
11002+
const first_imc_item_idx = in_mem_coercible.findFirstSet().?;
11003+
const first_imc_field_idx = field_indices[first_imc_item_idx];
11004+
const first_imc_field_ty = union_obj.field_types.get(ip)[first_imc_field_idx].toType();
11005+
const uncoerced = try coerce_block.addStructFieldVal(spa.operand, first_imc_field_idx, first_imc_field_ty);
1099811006
const coerced = try coerce_block.addBitCast(capture_ty, uncoerced);
1099911007
_ = try coerce_block.addBr(capture_block_inst, coerced);
1100011008

test/behavior/switch.zig

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -800,3 +800,26 @@ test "nested break ignores switch conditions and breaks instead" {
800800
// Originally reported at https://github.com/ziglang/zig/issues/10196
801801
try expect(0x01 == try S.register_to_address("a0"));
802802
}
803+
804+
test "peer type resolution on switch captures ignores unused payload bits" {
805+
if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
806+
807+
const Foo = union(enum) {
808+
a: u32,
809+
b: u64,
810+
};
811+
812+
var val: Foo = undefined;
813+
@memset(std.mem.asBytes(&val), 0xFF);
814+
815+
// This is runtime-known so the following store isn't comptime-known.
816+
var rt: u32 = 123;
817+
val = .{ .a = rt }; // will not necessarily zero remaning payload memory
818+
819+
// Fields intentionally backwards here
820+
const x = switch (val) {
821+
.b, .a => |x| x,
822+
};
823+
824+
try expect(x == 123);
825+
}

0 commit comments

Comments
 (0)