Skip to content

Commit 04626c1

Browse files
committed
Merge remote-tracking branch 'origin/master' into io-stream-iface
2 parents 431d76c + 51c6bb9 commit 04626c1

File tree

9 files changed

+173
-36
lines changed

9 files changed

+173
-36
lines changed

lib/std/fmt.zig

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -492,10 +492,7 @@ pub fn formatType(
492492
},
493493
.Type => return output(context, @typeName(T)),
494494
.EnumLiteral => {
495-
const name = @tagName(value);
496-
var buffer: [name.len + 1]u8 = undefined;
497-
buffer[0] = '.';
498-
std.mem.copy(u8, buffer[1..], name);
495+
const buffer = [_]u8{'.'} ++ @tagName(value);
499496
return formatType(buffer, fmt, options, context, Errors, output, max_depth);
500497
},
501498
else => @compileError("Unable to format type '" ++ @typeName(T) ++ "'"),

lib/std/io/bit_in_stream.zig

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,18 @@ pub fn BitInStream(endian: builtin.Endian, comptime InStreamType: type) type {
6666
switch (endian) {
6767
.Big => {
6868
out_buffer = @as(Buf, self.bit_buffer >> shift);
69-
self.bit_buffer <<= n;
69+
if (n >= u7_bit_count)
70+
self.bit_buffer = 0
71+
else
72+
self.bit_buffer <<= n;
7073
},
7174
.Little => {
7275
const value = (self.bit_buffer << shift) >> shift;
7376
out_buffer = @as(Buf, value);
74-
self.bit_buffer >>= n;
77+
if (n >= u7_bit_count)
78+
self.bit_buffer = 0
79+
else
80+
self.bit_buffer >>= n;
7581
},
7682
}
7783
self.bit_count -= n;

lib/std/mem.zig

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -935,6 +935,9 @@ pub fn writeInt(comptime T: type, buffer: *[@divExact(T.bit_count, 8)]u8, value:
935935
pub fn writeIntSliceLittle(comptime T: type, buffer: []u8, value: T) void {
936936
assert(buffer.len >= @divExact(T.bit_count, 8));
937937

938+
if (T.bit_count == 0)
939+
return set(u8, buffer, 0);
940+
938941
// TODO I want to call writeIntLittle here but comptime eval facilities aren't good enough
939942
const uint = std.meta.IntType(false, T.bit_count);
940943
var bits = @truncate(uint, value);
@@ -952,6 +955,9 @@ pub fn writeIntSliceLittle(comptime T: type, buffer: []u8, value: T) void {
952955
pub fn writeIntSliceBig(comptime T: type, buffer: []u8, value: T) void {
953956
assert(buffer.len >= @divExact(T.bit_count, 8));
954957

958+
if (T.bit_count == 0)
959+
return set(u8, buffer, 0);
960+
955961
// TODO I want to call writeIntBig here but comptime eval facilities aren't good enough
956962
const uint = std.meta.IntType(false, T.bit_count);
957963
var bits = @truncate(uint, value);
@@ -1821,7 +1827,7 @@ test "sliceAsBytes" {
18211827
}
18221828

18231829
test "sliceAsBytes with sentinel slice" {
1824-
const empty_string:[:0]const u8 = "";
1830+
const empty_string: [:0]const u8 = "";
18251831
const bytes = sliceAsBytes(empty_string);
18261832
testing.expect(bytes.len == 0);
18271833
}

src/all_types.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1834,6 +1834,7 @@ enum PanicMsgId {
18341834
PanicMsgIdBadNoAsyncCall,
18351835
PanicMsgIdResumeNotSuspendedFn,
18361836
PanicMsgIdBadSentinel,
1837+
PanicMsgIdShxTooBigRhs,
18371838

18381839
PanicMsgIdCount,
18391840
};

src/codegen.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -974,6 +974,8 @@ static Buf *panic_msg_buf(PanicMsgId msg_id) {
974974
return buf_create_from_str("resumed a non-suspended function");
975975
case PanicMsgIdBadSentinel:
976976
return buf_create_from_str("sentinel mismatch");
977+
case PanicMsgIdShxTooBigRhs:
978+
return buf_create_from_str("shift amount is greater than the type size");
977979
}
978980
zig_unreachable();
979981
}
@@ -2841,6 +2843,26 @@ static LLVMValueRef gen_rem(CodeGen *g, bool want_runtime_safety, bool want_fast
28412843

28422844
}
28432845

2846+
static void gen_shift_rhs_check(CodeGen *g, ZigType *lhs_type, ZigType *rhs_type, LLVMValueRef value) {
2847+
// We only check if the rhs value of the shift expression is greater or
2848+
// equal to the number of bits of the lhs if it's not a power of two,
2849+
// otherwise the check is useful as the allowed values are limited by the
2850+
// operand type itself
2851+
if (!is_power_of_2(lhs_type->data.integral.bit_count)) {
2852+
LLVMValueRef bit_count_value = LLVMConstInt(get_llvm_type(g, rhs_type),
2853+
lhs_type->data.integral.bit_count, false);
2854+
LLVMValueRef less_than_bit = LLVMBuildICmp(g->builder, LLVMIntULT, value, bit_count_value, "");
2855+
LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "CheckFail");
2856+
LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "CheckOk");
2857+
LLVMBuildCondBr(g->builder, less_than_bit, ok_block, fail_block);
2858+
2859+
LLVMPositionBuilderAtEnd(g->builder, fail_block);
2860+
gen_safety_crash(g, PanicMsgIdShxTooBigRhs);
2861+
2862+
LLVMPositionBuilderAtEnd(g->builder, ok_block);
2863+
}
2864+
}
2865+
28442866
static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutableGen *executable,
28452867
IrInstGenBinOp *bin_op_instruction)
28462868
{
@@ -2949,6 +2971,11 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutableGen *executable,
29492971
{
29502972
assert(scalar_type->id == ZigTypeIdInt);
29512973
LLVMValueRef op2_casted = gen_widen_or_shorten(g, false, op2->value->type, scalar_type, op2_value);
2974+
2975+
if (want_runtime_safety) {
2976+
gen_shift_rhs_check(g, scalar_type, op2->value->type, op2_value);
2977+
}
2978+
29522979
bool is_sloppy = (op_id == IrBinOpBitShiftLeftLossy);
29532980
if (is_sloppy) {
29542981
return LLVMBuildShl(g->builder, op1_value, op2_casted, "");
@@ -2965,6 +2992,11 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutableGen *executable,
29652992
{
29662993
assert(scalar_type->id == ZigTypeIdInt);
29672994
LLVMValueRef op2_casted = gen_widen_or_shorten(g, false, op2->value->type, scalar_type, op2_value);
2995+
2996+
if (want_runtime_safety) {
2997+
gen_shift_rhs_check(g, scalar_type, op2->value->type, op2_value);
2998+
}
2999+
29683000
bool is_sloppy = (op_id == IrBinOpBitShiftRightLossy);
29693001
if (is_sloppy) {
29703002
if (scalar_type->data.integral.is_signed) {

src/ir.cpp

Lines changed: 41 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -16635,49 +16635,69 @@ static IrInstGen *ir_analyze_bit_shift(IrAnalyze *ira, IrInstSrcBinOp *bin_op_in
1663516635
IrInstGen *casted_op2;
1663616636
IrBinOp op_id = bin_op_instruction->op_id;
1663716637
if (op1->value->type->id == ZigTypeIdComptimeInt) {
16638+
// comptime_int has no finite bit width
1663816639
casted_op2 = op2;
1663916640

1664016641
if (op_id == IrBinOpBitShiftLeftLossy) {
1664116642
op_id = IrBinOpBitShiftLeftExact;
1664216643
}
1664316644

16644-
if (casted_op2->value->data.x_bigint.is_negative) {
16645+
if (!instr_is_comptime(op2)) {
16646+
ir_add_error(ira, &bin_op_instruction->base.base,
16647+
buf_sprintf("LHS of shift must be an integer type, or RHS must be compile-time known"));
16648+
return ira->codegen->invalid_inst_gen;
16649+
}
16650+
16651+
ZigValue *op2_val = ir_resolve_const(ira, casted_op2, UndefBad);
16652+
if (op2_val == nullptr)
16653+
return ira->codegen->invalid_inst_gen;
16654+
16655+
if (op2_val->data.x_bigint.is_negative) {
1664516656
Buf *val_buf = buf_alloc();
16646-
bigint_append_buf(val_buf, &casted_op2->value->data.x_bigint, 10);
16647-
ir_add_error(ira, &casted_op2->base, buf_sprintf("shift by negative value %s", buf_ptr(val_buf)));
16657+
bigint_append_buf(val_buf, &op2_val->data.x_bigint, 10);
16658+
ir_add_error(ira, &casted_op2->base,
16659+
buf_sprintf("shift by negative value %s", buf_ptr(val_buf)));
1664816660
return ira->codegen->invalid_inst_gen;
1664916661
}
1665016662
} else {
16663+
const unsigned bit_count = op1->value->type->data.integral.bit_count;
1665116664
ZigType *shift_amt_type = get_smallest_unsigned_int_type(ira->codegen,
16652-
op1->value->type->data.integral.bit_count - 1);
16653-
if (bin_op_instruction->op_id == IrBinOpBitShiftLeftLossy &&
16654-
op2->value->type->id == ZigTypeIdComptimeInt) {
16665+
bit_count > 0 ? bit_count - 1 : 0);
1665516666

16656-
ZigValue *op2_val = ir_resolve_const(ira, op2, UndefBad);
16667+
casted_op2 = ir_implicit_cast(ira, op2, shift_amt_type);
16668+
if (type_is_invalid(casted_op2->value->type))
16669+
return ira->codegen->invalid_inst_gen;
16670+
16671+
// This check is only valid iff op1 has at least one bit
16672+
if (bit_count > 0 && instr_is_comptime(casted_op2)) {
16673+
ZigValue *op2_val = ir_resolve_const(ira, casted_op2, UndefBad);
1665716674
if (op2_val == nullptr)
1665816675
return ira->codegen->invalid_inst_gen;
16659-
if (!bigint_fits_in_bits(&op2_val->data.x_bigint,
16660-
shift_amt_type->data.integral.bit_count,
16661-
op2_val->data.x_bigint.is_negative)) {
16662-
Buf *val_buf = buf_alloc();
16663-
bigint_append_buf(val_buf, &op2_val->data.x_bigint, 10);
16676+
16677+
BigInt bit_count_value = {0};
16678+
bigint_init_unsigned(&bit_count_value, bit_count);
16679+
16680+
if (bigint_cmp(&op2_val->data.x_bigint, &bit_count_value) != CmpLT) {
1666416681
ErrorMsg* msg = ir_add_error(ira,
1666516682
&bin_op_instruction->base.base,
1666616683
buf_sprintf("RHS of shift is too large for LHS type"));
16667-
add_error_note(
16668-
ira->codegen,
16669-
msg,
16670-
op2->base.source_node,
16671-
buf_sprintf("value %s cannot fit into type %s",
16672-
buf_ptr(val_buf),
16673-
buf_ptr(&shift_amt_type->name)));
16684+
add_error_note(ira->codegen, msg, op1->base.source_node,
16685+
buf_sprintf("type %s has only %u bits",
16686+
buf_ptr(&op1->value->type->name), bit_count));
16687+
1667416688
return ira->codegen->invalid_inst_gen;
1667516689
}
1667616690
}
16691+
}
1667716692

16678-
casted_op2 = ir_implicit_cast(ira, op2, shift_amt_type);
16679-
if (type_is_invalid(casted_op2->value->type))
16693+
// Fast path for zero RHS
16694+
if (instr_is_comptime(casted_op2)) {
16695+
ZigValue *op2_val = ir_resolve_const(ira, casted_op2, UndefBad);
16696+
if (op2_val == nullptr)
1668016697
return ira->codegen->invalid_inst_gen;
16698+
16699+
if (bigint_cmp_zero(&op2_val->data.x_bigint) == CmpEQ)
16700+
return ir_analyze_cast(ira, &bin_op_instruction->base.base, op1->value->type, op1);
1668116701
}
1668216702

1668316703
if (instr_is_comptime(op1) && instr_is_comptime(casted_op2)) {
@@ -16690,12 +16710,6 @@ static IrInstGen *ir_analyze_bit_shift(IrAnalyze *ira, IrInstSrcBinOp *bin_op_in
1669016710
return ira->codegen->invalid_inst_gen;
1669116711

1669216712
return ir_analyze_math_op(ira, &bin_op_instruction->base.base, op1->value->type, op1_val, op_id, op2_val);
16693-
} else if (op1->value->type->id == ZigTypeIdComptimeInt) {
16694-
ir_add_error(ira, &bin_op_instruction->base.base,
16695-
buf_sprintf("LHS of shift must be an integer type, or RHS must be compile-time known"));
16696-
return ira->codegen->invalid_inst_gen;
16697-
} else if (instr_is_comptime(casted_op2) && bigint_cmp_zero(&casted_op2->value->data.x_bigint) == CmpEQ) {
16698-
return ir_build_cast(ira, &bin_op_instruction->base.base, op1->value->type, op1, CastOpNoop);
1669916713
}
1670016714

1670116715
return ir_build_bin_op_gen(ira, &bin_op_instruction->base.base, op1->value->type,

test/compile_errors.zig

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,38 @@ const tests = @import("tests.zig");
22
const std = @import("std");
33

44
pub fn addCases(cases: *tests.CompileErrorContext) void {
5+
cases.addTest("shift on type with non-power-of-two size",
6+
\\export fn entry() void {
7+
\\ const S = struct {
8+
\\ fn a() void {
9+
\\ var x: u24 = 42;
10+
\\ _ = x >> 24;
11+
\\ }
12+
\\ fn b() void {
13+
\\ var x: u24 = 42;
14+
\\ _ = x << 24;
15+
\\ }
16+
\\ fn c() void {
17+
\\ var x: u24 = 42;
18+
\\ _ = @shlExact(x, 24);
19+
\\ }
20+
\\ fn d() void {
21+
\\ var x: u24 = 42;
22+
\\ _ = @shrExact(x, 24);
23+
\\ }
24+
\\ };
25+
\\ S.a();
26+
\\ S.b();
27+
\\ S.c();
28+
\\ S.d();
29+
\\}
30+
, &[_][]const u8{
31+
"tmp.zig:5:19: error: RHS of shift is too large for LHS type",
32+
"tmp.zig:9:19: error: RHS of shift is too large for LHS type",
33+
"tmp.zig:13:17: error: RHS of shift is too large for LHS type",
34+
"tmp.zig:17:17: error: RHS of shift is too large for LHS type",
35+
});
36+
537
cases.addTest("combination of noasync and async",
638
\\export fn entry() void {
739
\\ noasync {
@@ -4029,8 +4061,7 @@ pub fn addCases(cases: *tests.CompileErrorContext) void {
40294061
\\}
40304062
\\export fn entry() u16 { return f(); }
40314063
, &[_][]const u8{
4032-
"tmp.zig:3:14: error: RHS of shift is too large for LHS type",
4033-
"tmp.zig:3:17: note: value 8 cannot fit into type u3",
4064+
"tmp.zig:3:17: error: integer value 8 cannot be coerced to type 'u3'",
40344065
});
40354066

40364067
cases.add("missing function call param",

test/runtime_safety.zig

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,37 @@
11
const tests = @import("tests.zig");
22

33
pub fn addCases(cases: *tests.CompareOutputContext) void {
4+
cases.addRuntimeSafety("shift left by huge amount",
5+
\\const std = @import("std");
6+
\\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {
7+
\\ std.debug.warn("{}\n", .{message});
8+
\\ if (std.mem.eql(u8, message, "shift amount is greater than the type size")) {
9+
\\ std.process.exit(126); // good
10+
\\ }
11+
\\ std.process.exit(0); // test failed
12+
\\}
13+
\\pub fn main() void {
14+
\\ var x: u24 = 42;
15+
\\ var y: u5 = 24;
16+
\\ var z = x >> y;
17+
\\}
18+
);
19+
20+
cases.addRuntimeSafety("shift right by huge amount",
21+
\\const std = @import("std");
22+
\\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {
23+
\\ if (std.mem.eql(u8, message, "shift amount is greater than the type size")) {
24+
\\ std.process.exit(126); // good
25+
\\ }
26+
\\ std.process.exit(0); // test failed
27+
\\}
28+
\\pub fn main() void {
29+
\\ var x: u24 = 42;
30+
\\ var y: u5 = 24;
31+
\\ var z = x << y;
32+
\\}
33+
);
34+
435
cases.addRuntimeSafety("slice sentinel mismatch - optional pointers",
536
\\const std = @import("std");
637
\\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {

test/stage1/behavior/math.zig

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,25 @@ fn testShrExact(x: u8) void {
453453
expect(shifted == 0b00101101);
454454
}
455455

456+
test "shift left/right on u0 operand" {
457+
const S = struct {
458+
fn doTheTest() void {
459+
var x: u0 = 0;
460+
var y: u0 = 0;
461+
expectEqual(@as(u0, 0), x << 0);
462+
expectEqual(@as(u0, 0), x >> 0);
463+
expectEqual(@as(u0, 0), x << y);
464+
expectEqual(@as(u0, 0), x >> y);
465+
expectEqual(@as(u0, 0), @shlExact(x, 0));
466+
expectEqual(@as(u0, 0), @shrExact(x, 0));
467+
expectEqual(@as(u0, 0), @shlExact(x, y));
468+
expectEqual(@as(u0, 0), @shrExact(x, y));
469+
}
470+
};
471+
S.doTheTest();
472+
comptime S.doTheTest();
473+
}
474+
456475
test "comptime_int addition" {
457476
comptime {
458477
expect(35361831660712422535336160538497375248 + 101752735581729509668353361206450473702 == 137114567242441932203689521744947848950);

0 commit comments

Comments
 (0)