Skip to content

Commit a6525c1

Browse files
authored
Merge pull request #22529 from xtexx/x86-64/shl-sat-int
x86_64: Implement integer saturating left shifting codegen
2 parents 235001a + 1da909a commit a6525c1

File tree

2 files changed

+184
-6
lines changed

2 files changed

+184
-6
lines changed

src/arch/x86_64/CodeGen.zig

Lines changed: 137 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85078,10 +85078,132 @@ fn airShlShrBinOp(self: *CodeGen, inst: Air.Inst.Index) !void {
8507885078
}
8507985079

8508085080
fn airShlSat(self: *CodeGen, inst: Air.Inst.Index) !void {
85081+
const zcu = self.pt.zcu;
8508185082
const bin_op = self.air.instructions.items(.data)[@intFromEnum(inst)].bin_op;
85082-
_ = bin_op;
85083-
return self.fail("TODO implement shl_sat for {}", .{self.target.cpu.arch});
85084-
//return self.finishAir(inst, result, .{ bin_op.lhs, bin_op.rhs, .none });
85083+
const lhs_ty = self.typeOf(bin_op.lhs);
85084+
const rhs_ty = self.typeOf(bin_op.rhs);
85085+
85086+
const result: MCValue = result: {
85087+
switch (lhs_ty.zigTypeTag(zcu)) {
85088+
.int => {
85089+
const lhs_bits = lhs_ty.bitSize(zcu);
85090+
const rhs_bits = rhs_ty.bitSize(zcu);
85091+
if (!(lhs_bits <= 32 and rhs_bits <= 5) and !(lhs_bits > 32 and lhs_bits <= 64 and rhs_bits <= 6) and !(rhs_bits <= std.math.log2(lhs_bits))) {
85092+
return self.fail("TODO implement shl_sat for {} with lhs bits {}, rhs bits {}", .{ self.target.cpu.arch, lhs_bits, rhs_bits });
85093+
}
85094+
85095+
// clobberred by genShiftBinOp
85096+
try self.spillRegisters(&.{.rcx});
85097+
85098+
const lhs_mcv = try self.resolveInst(bin_op.lhs);
85099+
var lhs_temp1 = try self.tempInit(lhs_ty, lhs_mcv);
85100+
const rhs_mcv = try self.resolveInst(bin_op.rhs);
85101+
85102+
const lhs_lock = switch (lhs_mcv) {
85103+
.register => |reg| self.register_manager.lockRegAssumeUnused(reg),
85104+
else => null,
85105+
};
85106+
defer if (lhs_lock) |lock| self.register_manager.unlockReg(lock);
85107+
85108+
// shift left
85109+
const dst_mcv = try self.genShiftBinOp(.shl, null, lhs_mcv, rhs_mcv, lhs_ty, rhs_ty);
85110+
switch (dst_mcv) {
85111+
.register => |dst_reg| try self.truncateRegister(lhs_ty, dst_reg),
85112+
.register_pair => |dst_regs| try self.truncateRegister(lhs_ty, dst_regs[1]),
85113+
.load_frame => |frame_addr| {
85114+
const tmp_reg =
85115+
try self.register_manager.allocReg(null, abi.RegisterClass.gp);
85116+
const tmp_lock = self.register_manager.lockRegAssumeUnused(tmp_reg);
85117+
defer self.register_manager.unlockReg(tmp_lock);
85118+
85119+
const lhs_bits_u31: u31 = @intCast(lhs_bits);
85120+
const tmp_ty: Type = if (lhs_bits_u31 > 64) .usize else lhs_ty;
85121+
const off = frame_addr.off + (lhs_bits_u31 - 1) / 64 * 8;
85122+
try self.genSetReg(
85123+
tmp_reg,
85124+
tmp_ty,
85125+
.{ .load_frame = .{ .index = frame_addr.index, .off = off } },
85126+
.{},
85127+
);
85128+
try self.truncateRegister(lhs_ty, tmp_reg);
85129+
try self.genSetMem(
85130+
.{ .frame = frame_addr.index },
85131+
off,
85132+
tmp_ty,
85133+
.{ .register = tmp_reg },
85134+
.{},
85135+
);
85136+
},
85137+
else => {},
85138+
}
85139+
const dst_lock = switch (dst_mcv) {
85140+
.register => |reg| self.register_manager.lockRegAssumeUnused(reg),
85141+
else => null,
85142+
};
85143+
defer if (dst_lock) |lock| self.register_manager.unlockReg(lock);
85144+
85145+
// shift right
85146+
const tmp_mcv = try self.genShiftBinOp(.shr, null, dst_mcv, rhs_mcv, lhs_ty, rhs_ty);
85147+
var tmp_temp = try self.tempInit(lhs_ty, tmp_mcv);
85148+
85149+
// check if overflow happens
85150+
const cc_temp = lhs_temp1.cmpInts(.neq, &tmp_temp, self) catch |err| switch (err) {
85151+
error.SelectFailed => unreachable,
85152+
else => |e| return e,
85153+
};
85154+
try lhs_temp1.die(self);
85155+
try tmp_temp.die(self);
85156+
const overflow_reloc = try self.genCondBrMir(lhs_ty, cc_temp.tracking(self).short);
85157+
try cc_temp.die(self);
85158+
85159+
// if overflow,
85160+
// for unsigned integers, the saturating result is just its max
85161+
// for signed integers,
85162+
// if lhs is positive, the result is its max
85163+
// if lhs is negative, it is min
85164+
switch (lhs_ty.intInfo(zcu).signedness) {
85165+
.unsigned => {
85166+
const bound_mcv = try self.genTypedValue(try lhs_ty.maxIntScalar(self.pt, lhs_ty));
85167+
try self.genCopy(lhs_ty, dst_mcv, bound_mcv, .{});
85168+
},
85169+
.signed => {
85170+
// check the sign of lhs
85171+
// TODO: optimize this.
85172+
// we only need the highest bit so shifting the highest part of lhs_mcv
85173+
// is enough to check the signedness. other parts can be skipped here.
85174+
var lhs_temp2 = try self.tempInit(lhs_ty, lhs_mcv);
85175+
var zero_temp = try self.tempInit(lhs_ty, try self.genTypedValue(try self.pt.intValue(lhs_ty, 0)));
85176+
const sign_cc_temp = lhs_temp2.cmpInts(.lt, &zero_temp, self) catch |err| switch (err) {
85177+
error.SelectFailed => unreachable,
85178+
else => |e| return e,
85179+
};
85180+
try lhs_temp2.die(self);
85181+
try zero_temp.die(self);
85182+
const sign_reloc_condbr = try self.genCondBrMir(lhs_ty, sign_cc_temp.tracking(self).short);
85183+
try sign_cc_temp.die(self);
85184+
85185+
// if it is negative
85186+
const min_mcv = try self.genTypedValue(try lhs_ty.minIntScalar(self.pt, lhs_ty));
85187+
try self.genCopy(lhs_ty, dst_mcv, min_mcv, .{});
85188+
const sign_reloc_br = try self.asmJmpReloc(undefined);
85189+
self.performReloc(sign_reloc_condbr);
85190+
85191+
// if it is positive
85192+
const max_mcv = try self.genTypedValue(try lhs_ty.maxIntScalar(self.pt, lhs_ty));
85193+
try self.genCopy(lhs_ty, dst_mcv, max_mcv, .{});
85194+
self.performReloc(sign_reloc_br);
85195+
},
85196+
}
85197+
85198+
self.performReloc(overflow_reloc);
85199+
break :result dst_mcv;
85200+
},
85201+
else => {
85202+
return self.fail("TODO implement shl_sat for {} op type {}", .{ self.target.cpu.arch, lhs_ty.zigTypeTag(zcu) });
85203+
},
85204+
}
85205+
};
85206+
return self.finishAir(inst, result, .{ bin_op.lhs, bin_op.rhs, .none });
8508585207
}
8508685208

8508785209
fn airOptionalPayload(self: *CodeGen, inst: Air.Inst.Index) !void {
@@ -88466,7 +88588,7 @@ fn genShiftBinOpMir(
8846688588
) !void {
8846788589
const pt = self.pt;
8846888590
const zcu = pt.zcu;
88469-
const abi_size: u32 = @intCast(lhs_ty.abiSize(zcu));
88591+
const abi_size: u31 = @intCast(lhs_ty.abiSize(zcu));
8847088592
const shift_abi_size: u32 = @intCast(rhs_ty.abiSize(zcu));
8847188593
try self.spillEflagsIfOccupied();
8847288594

@@ -88650,7 +88772,17 @@ fn genShiftBinOpMir(
8865088772
.immediate => {},
8865188773
else => self.performReloc(skip),
8865288774
}
88653-
}
88775+
} else try self.asmRegisterMemory(.{ ._, .mov }, temp_regs[2].to64(), .{
88776+
.base = .{ .frame = lhs_mcv.load_frame.index },
88777+
.mod = .{ .rm = .{
88778+
.size = .qword,
88779+
.disp = switch (tag[0]) {
88780+
._l => lhs_mcv.load_frame.off,
88781+
._r => lhs_mcv.load_frame.off + abi_size - 8,
88782+
else => unreachable,
88783+
},
88784+
} },
88785+
});
8865488786
switch (rhs_mcv) {
8865588787
.immediate => |shift_imm| try self.asmRegisterImmediate(
8865688788
tag,

test/behavior/bit_shifting.zig

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
const std = @import("std");
22
const expect = std.testing.expect;
3+
const expectEqual = std.testing.expectEqual;
34
const builtin = @import("builtin");
45

56
fn ShardedTable(comptime Key: type, comptime mask_bit_count: comptime_int, comptime V: type) type {
@@ -111,7 +112,6 @@ test "comptime shift safety check" {
111112
}
112113

113114
test "Saturating Shift Left where lhs is of a computed type" {
114-
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
115115
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
116116
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
117117
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
@@ -159,3 +159,49 @@ comptime {
159159
_ = &image;
160160
_ = @shlExact(@as(u16, image[0]), 8);
161161
}
162+
163+
test "Saturating Shift Left" {
164+
if (builtin.zig_backend == .stage2_c) return error.SkipZigTest;
165+
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest;
166+
if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
167+
if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
168+
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest;
169+
if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest;
170+
171+
const S = struct {
172+
fn shlSat(x: anytype, y: std.math.Log2Int(@TypeOf(x))) @TypeOf(x) {
173+
// workaround https://github.com/ziglang/zig/issues/23033
174+
@setRuntimeSafety(false);
175+
return x <<| y;
176+
}
177+
178+
fn testType(comptime T: type) !void {
179+
comptime var rhs: std.math.Log2Int(T) = 0;
180+
inline while (true) : (rhs += 1) {
181+
comptime var lhs: T = std.math.minInt(T);
182+
inline while (true) : (lhs += 1) {
183+
try expectEqual(lhs <<| rhs, shlSat(lhs, rhs));
184+
if (lhs == std.math.maxInt(T)) break;
185+
}
186+
if (rhs == @bitSizeOf(T) - 1) break;
187+
}
188+
}
189+
};
190+
191+
try S.testType(u2);
192+
try S.testType(i2);
193+
try S.testType(u3);
194+
try S.testType(i3);
195+
try S.testType(u4);
196+
try S.testType(i4);
197+
198+
try expectEqual(0xfffffffffffffff0fffffffffffffff0, S.shlSat(@as(u128, 0x0fffffffffffffff0fffffffffffffff), 4));
199+
try expectEqual(0xffffffffffffffffffffffffffffffff, S.shlSat(@as(u128, 0x0fffffffffffffff0fffffffffffffff), 5));
200+
try expectEqual(-0x80000000000000000000000000000000, S.shlSat(@as(i128, -0x0fffffffffffffff0fffffffffffffff), 5));
201+
202+
// TODO
203+
// try expectEqual(51146728248377216718956089012931236753385031969422887335676427626502090568823039920051095192592252455482604439493126109519019633529459266458258243583, S.shlSat(@as(i495, 0x2fe6bc5448c55ce18252e2c9d44777505dfe63ff249a8027a6626c7d8dd9893fd5731e51474727be556f757facb586a4e04bbc0148c6c7ad692302f46fbd), 0x31));
204+
try expectEqual(-57896044618658097711785492504343953926634992332820282019728792003956564819968, S.shlSat(@as(i256, -0x53d4148cee74ea43477a65b3daa7b8fdadcbf4508e793f4af113b8d8da5a7eb6), 0x91));
205+
try expectEqual(170141183460469231731687303715884105727, S.shlSat(@as(i128, 0x2fe6bc5448c55ce18252e2c9d4477750), 0x31));
206+
try expectEqual(0, S.shlSat(@as(i128, 0), 127));
207+
}

0 commit comments

Comments
 (0)