Skip to content

Commit b3d463c

Browse files
authored
Merge pull request #12337 from Vexu/stage2-safety
Stage2: implement remaining runtime safety checks
2 parents 3e2defd + 75275a1 commit b3d463c

33 files changed

+378
-102
lines changed

src/Air.zig

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -660,6 +660,10 @@ pub const Inst = struct {
660660
/// Uses the `pl_op` field with payload `AtomicRmw`. Operand is `ptr`.
661661
atomic_rmw,
662662

663+
/// Returns true if enum tag value has a name.
664+
/// Uses the `un_op` field.
665+
is_named_enum_value,
666+
663667
/// Given an enum tag value, returns the tag name. The enum type may be non-exhaustive.
664668
/// Result type is always `[:0]const u8`.
665669
/// Uses the `un_op` field.
@@ -1057,6 +1061,7 @@ pub fn typeOfIndex(air: Air, inst: Air.Inst.Index) Type {
10571061
.is_non_err,
10581062
.is_err_ptr,
10591063
.is_non_err_ptr,
1064+
.is_named_enum_value,
10601065
=> return Type.bool,
10611066

10621067
.const_ty => return Type.type,

src/Liveness.zig

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -291,6 +291,7 @@ pub fn categorizeOperand(
291291
.is_non_err_ptr,
292292
.ptrtoint,
293293
.bool_to_int,
294+
.is_named_enum_value,
294295
.tag_name,
295296
.error_name,
296297
.sqrt,
@@ -858,6 +859,7 @@ fn analyzeInst(
858859
.bool_to_int,
859860
.ret,
860861
.ret_load,
862+
.is_named_enum_value,
861863
.tag_name,
862864
.error_name,
863865
.sqrt,

src/Sema.zig

Lines changed: 159 additions & 53 deletions
Original file line number | Diff line number | Diff line change
@@ -1578,8 +1578,7 @@ pub fn setupErrorReturnTrace(sema: *Sema, block: *Block, last_arg_index: usize)
15781578

15791579
// st.index = 0;
15801580
const index_field_ptr = try sema.fieldPtr(&err_trace_block, src, st_ptr, "index", src, true);
1581-
const zero = try sema.addConstant(Type.usize, Value.zero);
1582-
try sema.storePtr2(&err_trace_block, src, index_field_ptr, src, zero, src, .store);
1581+
try sema.storePtr2(&err_trace_block, src, index_field_ptr, src, .zero_usize, src, .store);
15831582

15841583
// @errorReturnTrace() = &st;
15851584
_ = try err_trace_block.addUnOp(.set_err_return_trace, st_ptr);
@@ -6949,8 +6948,12 @@ fn zirIntToEnum(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!A
69496948
}
69506949

69516950
try sema.requireRuntimeBlock(block, src, operand_src);
6952-
// TODO insert safety check to make sure the value matches an enum value
6953-
return block.addTyOp(.intcast, dest_ty, operand);
6951+
const result = try block.addTyOp(.intcast, dest_ty, operand);
6952+
if (block.wantSafety() and !dest_ty.isNonexhaustiveEnum() and sema.mod.comp.bin_file.options.use_llvm) {
6953+
const ok = try block.addUnOp(.is_named_enum_value, result);
6954+
try sema.addSafetyCheck(block, ok, .invalid_enum_value);
6955+
}
6956+
return result;
69546957
}
69556958

69566959
/// Pointer in, pointer out.
@@ -9707,7 +9710,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
97079710
}
97089711

97099712
var final_else_body: []const Air.Inst.Index = &.{};
9710-
if (special.body.len != 0 or !is_first) {
9713+
if (special.body.len != 0 or !is_first or case_block.wantSafety()) {
97119714
var wip_captures = try WipCaptureScope.init(gpa, sema.perm_arena, child_block.wip_capture_scope);
97129715
defer wip_captures.deinit();
97139716

@@ -9730,9 +9733,11 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
97309733
} else {
97319734
// We still need a terminator in this block, but we have proven
97329735
// that it is unreachable.
9733-
// TODO this should be a special safety panic other than unreachable, something
9734-
// like "panic: switch operand had corrupt value not allowed by the type"
9735-
try case_block.addUnreachable(src, true);
9736+
if (case_block.wantSafety()) {
9737+
_ = try sema.safetyPanic(&case_block, src, .corrupt_switch);
9738+
} else {
9739+
_ = try case_block.addNoOp(.unreach);
9740+
}
97369741
}
97379742

97389743
try wip_captures.finalize();
@@ -10241,34 +10246,57 @@ fn zirShl(
1024110246
} else rhs;
1024210247

1024310248
try sema.requireRuntimeBlock(block, src, runtime_src);
10244-
if (block.wantSafety() and air_tag == .shl_exact) {
10245-
const op_ov_tuple_ty = try sema.overflowArithmeticTupleType(lhs_ty);
10246-
const op_ov = try block.addInst(.{
10247-
.tag = .shl_with_overflow,
10248-
.data = .{ .ty_pl = .{
10249-
.ty = try sema.addType(op_ov_tuple_ty),
10250-
.payload = try sema.addExtra(Air.Bin{
10251-
.lhs = lhs,
10252-
.rhs = rhs,
10253-
}),
10254-
} },
10255-
});
10256-
const ov_bit = try sema.tupleFieldValByIndex(block, src, op_ov, 1, op_ov_tuple_ty);
10257-
const any_ov_bit = if (lhs_ty.zigTypeTag() == .Vector)
10258-
try block.addInst(.{
10259-
.tag = if (block.float_mode == .Optimized) .reduce_optimized else .reduce,
10260-
.data = .{ .reduce = .{
10261-
.operand = ov_bit,
10262-
.operation = .Or,
10249+
if (block.wantSafety()) {
10250+
const bit_count = scalar_ty.intInfo(target).bits;
10251+
if (!std.math.isPowerOfTwo(bit_count)) {
10252+
const bit_count_val = try Value.Tag.int_u64.create(sema.arena, bit_count);
10253+
10254+
const ok = if (rhs_ty.zigTypeTag() == .Vector) ok: {
10255+
const bit_count_inst = try sema.addConstant(rhs_ty, try Value.Tag.repeated.create(sema.arena, bit_count_val));
10256+
const lt = try block.addCmpVector(rhs, bit_count_inst, .lt, try sema.addType(rhs_ty));
10257+
break :ok try block.addInst(.{
10258+
.tag = .reduce,
10259+
.data = .{ .reduce = .{
10260+
.operand = lt,
10261+
.operation = .And,
10262+
} },
10263+
});
10264+
} else ok: {
10265+
const bit_count_inst = try sema.addConstant(rhs_ty, bit_count_val);
10266+
break :ok try block.addBinOp(.cmp_lt, rhs, bit_count_inst);
10267+
};
10268+
try sema.addSafetyCheck(block, ok, .shift_rhs_too_big);
10269+
}
10270+
10271+
if (air_tag == .shl_exact) {
10272+
const op_ov_tuple_ty = try sema.overflowArithmeticTupleType(lhs_ty);
10273+
const op_ov = try block.addInst(.{
10274+
.tag = .shl_with_overflow,
10275+
.data = .{ .ty_pl = .{
10276+
.ty = try sema.addType(op_ov_tuple_ty),
10277+
.payload = try sema.addExtra(Air.Bin{
10278+
.lhs = lhs,
10279+
.rhs = rhs,
10280+
}),
1026310281
} },
10264-
})
10265-
else
10266-
ov_bit;
10267-
const zero_ov = try sema.addConstant(Type.@"u1", Value.zero);
10268-
const no_ov = try block.addBinOp(.cmp_eq, any_ov_bit, zero_ov);
10282+
});
10283+
const ov_bit = try sema.tupleFieldValByIndex(block, src, op_ov, 1, op_ov_tuple_ty);
10284+
const any_ov_bit = if (lhs_ty.zigTypeTag() == .Vector)
10285+
try block.addInst(.{
10286+
.tag = if (block.float_mode == .Optimized) .reduce_optimized else .reduce,
10287+
.data = .{ .reduce = .{
10288+
.operand = ov_bit,
10289+
.operation = .Or,
10290+
} },
10291+
})
10292+
else
10293+
ov_bit;
10294+
const zero_ov = try sema.addConstant(Type.@"u1", Value.zero);
10295+
const no_ov = try block.addBinOp(.cmp_eq, any_ov_bit, zero_ov);
1026910296

10270-
try sema.addSafetyCheck(block, no_ov, .shl_overflow);
10271-
return sema.tupleFieldValByIndex(block, src, op_ov, 0, op_ov_tuple_ty);
10297+
try sema.addSafetyCheck(block, no_ov, .shl_overflow);
10298+
return sema.tupleFieldValByIndex(block, src, op_ov, 0, op_ov_tuple_ty);
10299+
}
1027210300
}
1027310301
return block.addBinOp(air_tag, lhs, new_rhs);
1027410302
}
@@ -10347,20 +10375,43 @@ fn zirShr(
1034710375

1034810376
try sema.requireRuntimeBlock(block, src, runtime_src);
1034910377
const result = try block.addBinOp(air_tag, lhs, rhs);
10350-
if (block.wantSafety() and air_tag == .shr_exact) {
10351-
const back = try block.addBinOp(.shl, result, rhs);
10352-
10353-
const ok = if (rhs_ty.zigTypeTag() == .Vector) ok: {
10354-
const eql = try block.addCmpVector(lhs, back, .eq, try sema.addType(rhs_ty));
10355-
break :ok try block.addInst(.{
10356-
.tag = if (block.float_mode == .Optimized) .reduce_optimized else .reduce,
10357-
.data = .{ .reduce = .{
10358-
.operand = eql,
10359-
.operation = .And,
10360-
} },
10361-
});
10362-
} else try block.addBinOp(.cmp_eq, lhs, back);
10363-
try sema.addSafetyCheck(block, ok, .shr_overflow);
10378+
if (block.wantSafety()) {
10379+
const bit_count = scalar_ty.intInfo(target).bits;
10380+
if (!std.math.isPowerOfTwo(bit_count)) {
10381+
const bit_count_val = try Value.Tag.int_u64.create(sema.arena, bit_count);
10382+
10383+
const ok = if (rhs_ty.zigTypeTag() == .Vector) ok: {
10384+
const bit_count_inst = try sema.addConstant(rhs_ty, try Value.Tag.repeated.create(sema.arena, bit_count_val));
10385+
const lt = try block.addCmpVector(rhs, bit_count_inst, .lt, try sema.addType(rhs_ty));
10386+
break :ok try block.addInst(.{
10387+
.tag = .reduce,
10388+
.data = .{ .reduce = .{
10389+
.operand = lt,
10390+
.operation = .And,
10391+
} },
10392+
});
10393+
} else ok: {
10394+
const bit_count_inst = try sema.addConstant(rhs_ty, bit_count_val);
10395+
break :ok try block.addBinOp(.cmp_lt, rhs, bit_count_inst);
10396+
};
10397+
try sema.addSafetyCheck(block, ok, .shift_rhs_too_big);
10398+
}
10399+
10400+
if (air_tag == .shr_exact) {
10401+
const back = try block.addBinOp(.shl, result, rhs);
10402+
10403+
const ok = if (rhs_ty.zigTypeTag() == .Vector) ok: {
10404+
const eql = try block.addCmpVector(lhs, back, .eq, try sema.addType(rhs_ty));
10405+
break :ok try block.addInst(.{
10406+
.tag = if (block.float_mode == .Optimized) .reduce_optimized else .reduce,
10407+
.data = .{ .reduce = .{
10408+
.operand = eql,
10409+
.operation = .And,
10410+
} },
10411+
});
10412+
} else try block.addBinOp(.cmp_eq, lhs, back);
10413+
try sema.addSafetyCheck(block, ok, .shr_overflow);
10414+
}
1036410415
}
1036510416
return result;
1036610417
}
@@ -15961,6 +16012,11 @@ fn zirTagName(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air
1596116012
const field_name = enum_ty.enumFieldName(field_index);
1596216013
return sema.addStrLit(block, field_name);
1596316014
}
16015+
try sema.requireRuntimeBlock(block, src, operand_src);
16016+
if (block.wantSafety() and sema.mod.comp.bin_file.options.use_llvm) {
16017+
const ok = try block.addUnOp(.is_named_enum_value, casted_operand);
16018+
try sema.addSafetyCheck(block, ok, .invalid_enum_value);
16019+
}
1596416020
// In case the value is runtime-known, we have an AIR instruction for this instead
1596516021
// of trying to lower it in Sema because an optimization pass may result in the operand
1596616022
// being comptime-known, which would let us elide the `tag_name` AIR instruction.
@@ -16942,7 +16998,7 @@ fn zirIntToPtr(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Ai
1694216998
}
1694316999

1694417000
try sema.requireRuntimeBlock(block, src, operand_src);
16945-
if (block.wantSafety()) {
17001+
if (block.wantSafety() and try sema.typeHasRuntimeBits(block, sema.src, type_res.elemType2())) {
1694617002
if (!type_res.isAllowzeroPtr()) {
1694717003
const is_non_zero = try block.addBinOp(.cmp_neq, operand_coerced, .zero_usize);
1694817004
try sema.addSafetyCheck(block, is_non_zero, .cast_to_null);
@@ -17234,7 +17290,9 @@ fn zirAlignCast(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!A
1723417290
}
1723517291

1723617292
try sema.requireRuntimeBlock(block, inst_data.src(), ptr_src);
17237-
if (block.wantSafety() and dest_align > 1) {
17293+
if (block.wantSafety() and dest_align > 1 and
17294+
try sema.typeHasRuntimeBits(block, sema.src, dest_ty.elemType2()))
17295+
{
1723817296
const val_payload = try sema.arena.create(Value.Payload.U64);
1723917297
val_payload.* = .{
1724017298
.base = .{ .tag = .int_u64 },
@@ -17253,7 +17311,7 @@ fn zirAlignCast(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!A
1725317311
const is_aligned = try block.addBinOp(.cmp_eq, remainder, .zero_usize);
1725417312
const ok = if (ptr_ty.isSlice()) ok: {
1725517313
const len = try sema.analyzeSliceLen(block, ptr_src, ptr);
17256-
const len_zero = try block.addBinOp(.cmp_eq, len, try sema.addConstant(Type.usize, Value.zero));
17314+
const len_zero = try block.addBinOp(.cmp_eq, len, .zero_usize);
1725717315
break :ok try block.addBinOp(.bit_or, len_zero, is_aligned);
1725817316
} else is_aligned;
1725917317
try sema.addSafetyCheck(block, ok, .incorrect_alignment);
@@ -20114,6 +20172,9 @@ pub const PanicId = enum {
2011420172
/// TODO make this call `std.builtin.panicInactiveUnionField`.
2011520173
inactive_union_field,
2011620174
integer_part_out_of_bounds,
20175+
corrupt_switch,
20176+
shift_rhs_too_big,
20177+
invalid_enum_value,
2011720178
};
2011820179

2011920180
fn addSafetyCheck(
@@ -20408,6 +20469,9 @@ fn safetyPanic(
2040820469
.exact_division_remainder => "exact division produced remainder",
2040920470
.inactive_union_field => "access of inactive union field",
2041020471
.integer_part_out_of_bounds => "integer part of floating point value out of bounds",
20472+
.corrupt_switch => "switch on corrupt value",
20473+
.shift_rhs_too_big => "shift amount is greater than the type size",
20474+
.invalid_enum_value => "invalid enum value",
2041120475
};
2041220476

2041320477
const msg_inst = msg_inst: {
@@ -22096,7 +22160,6 @@ fn coerceExtra(
2209622160
.ok => {},
2209722161
else => break :src_c_ptr,
2209822162
}
22099-
// TODO add safety check for null pointer
2210022163
return sema.coerceCompatiblePtrs(block, dest_ty, inst, inst_src);
2210122164
}
2210222165

@@ -24569,6 +24632,24 @@ fn coerceCompatiblePtrs(
2456924632
return sema.addConstant(dest_ty, val);
2457024633
}
2457124634
try sema.requireRuntimeBlock(block, inst_src, null);
24635+
const inst_ty = sema.typeOf(inst);
24636+
const inst_allows_zero = (inst_ty.zigTypeTag() == .Pointer and inst_ty.ptrAllowsZero()) or true;
24637+
if (block.wantSafety() and inst_allows_zero and !dest_ty.ptrAllowsZero() and
24638+
try sema.typeHasRuntimeBits(block, sema.src, dest_ty.elemType2()))
24639+
{
24640+
const actual_ptr = if (inst_ty.isSlice())
24641+
try sema.analyzeSlicePtr(block, inst_src, inst, inst_ty)
24642+
else
24643+
inst;
24644+
const ptr_int = try block.addUnOp(.ptrtoint, actual_ptr);
24645+
const is_non_zero = try block.addBinOp(.cmp_neq, ptr_int, .zero_usize);
24646+
const ok = if (inst_ty.isSlice()) ok: {
24647+
const len = try sema.analyzeSliceLen(block, inst_src, inst);
24648+
const len_zero = try block.addBinOp(.cmp_eq, len, .zero_usize);
24649+
break :ok try block.addBinOp(.bit_or, len_zero, is_non_zero);
24650+
} else is_non_zero;
24651+
try sema.addSafetyCheck(block, ok, .cast_to_null);
24652+
}
2457224653
return sema.bitCast(block, dest_ty, inst, inst_src);
2457324654
}
2457424655

@@ -25708,6 +25789,27 @@ fn analyzeSlice(
2570825789
const new_ptr_val = opt_new_ptr_val orelse {
2570925790
const result = try block.addBitCast(return_ty, new_ptr);
2571025791
if (block.wantSafety()) {
25792+
// requirement: slicing C ptr is non-null
25793+
if (ptr_ptr_child_ty.isCPtr()) {
25794+
const is_non_null = try sema.analyzeIsNull(block, ptr_src, ptr, true);
25795+
try sema.addSafetyCheck(block, is_non_null, .unwrap_null);
25796+
}
25797+
25798+
if (slice_ty.isSlice()) {
25799+
const slice_len_inst = try block.addTyOp(.slice_len, Type.usize, ptr_or_slice);
25800+
const actual_len = if (slice_ty.sentinel() == null)
25801+
slice_len_inst
25802+
else
25803+
try sema.analyzeArithmetic(block, .add, slice_len_inst, .one, src, end_src, end_src);
25804+
25805+
const actual_end = if (slice_sentinel != null)
25806+
try sema.analyzeArithmetic(block, .add, end, .one, src, end_src, end_src)
25807+
else
25808+
end;
25809+
25810+
try sema.panicIndexOutOfBounds(block, src, actual_end, actual_len, .cmp_lte);
25811+
}
25812+
2571125813
// requirement: result[new_len] == slice_sentinel
2571225814
try sema.panicSentinelMismatch(block, src, slice_sentinel, elem_ty, result, new_len);
2571325815
}
@@ -25769,7 +25871,11 @@ fn analyzeSlice(
2576925871
break :blk try sema.analyzeArithmetic(block, .add, slice_len_inst, .one, src, end_src, end_src);
2577025872
} else null;
2577125873
if (opt_len_inst) |len_inst| {
25772-
try sema.panicIndexOutOfBounds(block, src, end, len_inst, .cmp_lte);
25874+
const actual_end = if (slice_sentinel != null)
25875+
try sema.analyzeArithmetic(block, .add, end, .one, src, end_src, end_src)
25876+
else
25877+
end;
25878+
try sema.panicIndexOutOfBounds(block, src, actual_end, len_inst, .cmp_lte);
2577325879
}
2577425880

2577525881
// requirement: start <= end

src/arch/aarch64/CodeGen.zig

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -777,6 +777,8 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
777777
.float_to_int_optimized,
778778
=> return self.fail("TODO implement optimized float mode", .{}),
779779

780+
.is_named_enum_value => return self.fail("TODO implement is_named_enum_value", .{}),
781+
780782
.wasm_memory_size => unreachable,
781783
.wasm_memory_grow => unreachable,
782784
// zig fmt: on

src/arch/arm/CodeGen.zig

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -768,6 +768,8 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
768768
.float_to_int_optimized,
769769
=> return self.fail("TODO implement optimized float mode", .{}),
770770

771+
.is_named_enum_value => return self.fail("TODO implement is_named_enum_value", .{}),
772+
771773
.wasm_memory_size => unreachable,
772774
.wasm_memory_grow => unreachable,
773775
// zig fmt: on

src/arch/riscv64/CodeGen.zig

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -693,6 +693,8 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
693693
.float_to_int_optimized,
694694
=> return self.fail("TODO implement optimized float mode", .{}),
695695

696+
.is_named_enum_value => return self.fail("TODO implement is_named_enum_value", .{}),
697+
696698
.wasm_memory_size => unreachable,
697699
.wasm_memory_grow => unreachable,
698700
// zig fmt: on

src/arch/sparc64/CodeGen.zig

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -705,6 +705,8 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
705705
.float_to_int_optimized,
706706
=> @panic("TODO implement optimized float mode"),
707707

708+
.is_named_enum_value => @panic("TODO implement is_named_enum_value"),
709+
708710
.wasm_memory_size => unreachable,
709711
.wasm_memory_grow => unreachable,
710712
// zig fmt: on

0 commit comments

Comments
 (0)