Skip to content

Commit c96f9a0

Browse files
committed
Sema: implement @Splat for arrays
Resolves: #20433
1 parent 072e062 commit c96f9a0

File tree

7 files changed

+138
-32
lines changed

7 files changed

+138
-32
lines changed

lib/std/zig/AstGen.zig

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2716,7 +2716,7 @@ fn addEnsureResult(gz: *GenZir, maybe_unused_result: Zir.Inst.Ref, statement: As
27162716
.array_type_sentinel,
27172717
.elem_type,
27182718
.indexable_ptr_elem_type,
2719-
.vector_elem_type,
2719+
.vec_arr_elem_type,
27202720
.vector_type,
27212721
.indexable_ptr_len,
27222722
.anyframe_type,
@@ -9529,7 +9529,7 @@ fn builtinCall(
95299529

95309530
.splat => {
95319531
const result_type = try ri.rl.resultTypeForCast(gz, node, builtin_name);
9532-
const elem_type = try gz.addUnNode(.vector_elem_type, result_type, node);
9532+
const elem_type = try gz.addUnNode(.vec_arr_elem_type, result_type, node);
95339533
const scalar = try expr(gz, scope, .{ .rl = .{ .ty = elem_type } }, params[0]);
95349534
const result = try gz.addPlNode(.splat, node, Zir.Inst.Bin{
95359535
.lhs = result_type,

lib/std/zig/Zir.zig

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -247,9 +247,9 @@ pub const Inst = struct {
247247
/// element type. Emits a compile error if the type is not an indexable pointer.
248248
/// Uses the `un_node` field.
249249
indexable_ptr_elem_type,
250-
/// Given a vector type, returns its element type.
250+
/// Given a vector or array type, returns its element type.
251251
/// Uses the `un_node` field.
252-
vector_elem_type,
252+
vec_arr_elem_type,
253253
/// Given a pointer to an indexable object, returns the len property. This is
254254
/// used by for loops. This instruction also emits a for-loop specific compile
255255
/// error if the indexable object is not indexable.
@@ -1065,7 +1065,7 @@ pub const Inst = struct {
10651065
.vector_type,
10661066
.elem_type,
10671067
.indexable_ptr_elem_type,
1068-
.vector_elem_type,
1068+
.vec_arr_elem_type,
10691069
.indexable_ptr_len,
10701070
.anyframe_type,
10711071
.as_node,
@@ -1375,7 +1375,7 @@ pub const Inst = struct {
13751375
.vector_type,
13761376
.elem_type,
13771377
.indexable_ptr_elem_type,
1378-
.vector_elem_type,
1378+
.vec_arr_elem_type,
13791379
.indexable_ptr_len,
13801380
.anyframe_type,
13811381
.as_node,
@@ -1607,7 +1607,7 @@ pub const Inst = struct {
16071607
.vector_type = .pl_node,
16081608
.elem_type = .un_node,
16091609
.indexable_ptr_elem_type = .un_node,
1610-
.vector_elem_type = .un_node,
1610+
.vec_arr_elem_type = .un_node,
16111611
.indexable_ptr_len = .un_node,
16121612
.anyframe_type = .un_node,
16131613
.as_node = .pl_node,
@@ -3781,7 +3781,7 @@ fn findDeclsInner(
37813781
.vector_type,
37823782
.elem_type,
37833783
.indexable_ptr_elem_type,
3784-
.vector_elem_type,
3784+
.vec_arr_elem_type,
37853785
.indexable_ptr_len,
37863786
.anyframe_type,
37873787
.as_node,

src/Sema.zig

Lines changed: 55 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1087,7 +1087,7 @@ fn analyzeBodyInner(
10871087
.elem_val_imm => try sema.zirElemValImm(block, inst),
10881088
.elem_type => try sema.zirElemType(block, inst),
10891089
.indexable_ptr_elem_type => try sema.zirIndexablePtrElemType(block, inst),
1090-
.vector_elem_type => try sema.zirVectorElemType(block, inst),
1090+
.vec_arr_elem_type => try sema.zirVecArrElemType(block, inst),
10911091
.enum_literal => try sema.zirEnumLiteral(block, inst),
10921092
.decl_literal => try sema.zirDeclLiteral(block, inst, true),
10931093
.decl_literal_no_coerce => try sema.zirDeclLiteral(block, inst, false),
@@ -2046,7 +2046,7 @@ fn genericPoisonReason(sema: *Sema, block: *Block, ref: Zir.Inst.Ref) GenericPoi
20462046
const bin = sema.code.instructions.items(.data)[@intFromEnum(inst)].bin;
20472047
cur = bin.lhs;
20482048
},
2049-
.indexable_ptr_elem_type, .vector_elem_type => {
2049+
.indexable_ptr_elem_type, .vec_arr_elem_type => {
20502050
const un_node = sema.code.instructions.items(.data)[@intFromEnum(inst)].un_node;
20512051
cur = un_node.operand;
20522052
},
@@ -8603,7 +8603,7 @@ fn zirIndexablePtrElemType(sema: *Sema, block: *Block, inst: Zir.Inst.Index) Com
86038603
return Air.internedToRef(elem_ty.toIntern());
86048604
}
86058605

8606-
fn zirVectorElemType(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
8606+
fn zirVecArrElemType(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
86078607
const pt = sema.pt;
86088608
const zcu = pt.zcu;
86098609
const un_node = sema.code.instructions.items(.data)[@intFromEnum(inst)].un_node;
@@ -8615,8 +8615,9 @@ fn zirVectorElemType(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileEr
86158615
error.GenericPoison => return .generic_poison_type,
86168616
else => |e| return e,
86178617
};
8618-
if (!vec_ty.isVector(zcu)) {
8619-
return sema.fail(block, block.nodeOffset(un_node.src_node), "expected vector type, found '{}'", .{vec_ty.fmt(pt)});
8618+
switch (vec_ty.zigTypeTag(zcu)) {
8619+
.array, .vector => {},
8620+
else => return sema.fail(block, block.nodeOffset(un_node.src_node), "expected array or vector type, found '{}'", .{vec_ty.fmt(pt)}),
86208621
}
86218622
return Air.internedToRef(vec_ty.childType(zcu).toIntern());
86228623
}
@@ -24804,26 +24805,66 @@ fn zirSplat(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.I
2480424805
const scalar_src = block.builtinCallArgSrc(inst_data.src_node, 0);
2480524806
const dest_ty = try sema.resolveDestType(block, src, extra.lhs, .remove_eu_opt, "@splat");
2480624807

24807-
if (!dest_ty.isVector(zcu)) return sema.fail(block, src, "expected vector type, found '{}'", .{dest_ty.fmt(pt)});
24808+
switch (dest_ty.zigTypeTag(zcu)) {
24809+
.array, .vector => {},
24810+
else => return sema.fail(block, src, "expected array or vector type, found '{}'", .{dest_ty.fmt(pt)}),
24811+
}
2480824812

24809-
if (!dest_ty.hasRuntimeBits(zcu)) {
24813+
const operand = try sema.resolveInst(extra.rhs);
24814+
const scalar_ty = dest_ty.childType(zcu);
24815+
const scalar = try sema.coerce(block, scalar_ty, operand, scalar_src);
24816+
24817+
const len = try sema.usizeCast(block, src, dest_ty.arrayLen(zcu));
24818+
24819+
// `len == 0` because `[0:s]T` always has a comptime-known splat.
24820+
if (!dest_ty.hasRuntimeBits(zcu) or len == 0) {
2481024821
const empty_aggregate = try pt.intern(.{ .aggregate = .{
2481124822
.ty = dest_ty.toIntern(),
24812-
.storage = .{ .elems = &[_]InternPool.Index{} },
24823+
.storage = .{ .elems = &.{} },
2481324824
} });
2481424825
return Air.internedToRef(empty_aggregate);
2481524826
}
2481624827

24817-
const operand = try sema.resolveInst(extra.rhs);
24818-
const scalar_ty = dest_ty.childType(zcu);
24819-
const scalar = try sema.coerce(block, scalar_ty, operand, scalar_src);
24828+
const maybe_sentinel = dest_ty.sentinel(zcu);
24829+
2482024830
if (try sema.resolveValue(scalar)) |scalar_val| {
24821-
if (scalar_val.isUndef(zcu)) return pt.undefRef(dest_ty);
24822-
return Air.internedToRef((try sema.splat(dest_ty, scalar_val)).toIntern());
24831+
if (scalar_val.isUndef(zcu) and maybe_sentinel == null) {
24832+
return pt.undefRef(dest_ty);
24833+
}
24834+
// TODO: I didn't want to put `.aggregate` on a separate line here; `zig fmt` bugs have forced my hand
24835+
return Air.internedToRef(try pt.intern(.{
24836+
.aggregate = .{
24837+
.ty = dest_ty.toIntern(),
24838+
.storage = s: {
24839+
full: {
24840+
if (dest_ty.zigTypeTag(zcu) == .vector) break :full;
24841+
const sentinel = maybe_sentinel orelse break :full;
24842+
if (sentinel.toIntern() == scalar_val.toIntern()) break :full;
24843+
// This is a array with non-zero length and a sentinel which does not match the element.
24844+
// We have to use the full `elems` representation.
24845+
const elems = try sema.arena.alloc(InternPool.Index, len + 1);
24846+
@memset(elems[0..len], scalar_val.toIntern());
24847+
elems[len] = sentinel.toIntern();
24848+
break :s .{ .elems = elems };
24849+
}
24850+
break :s .{ .repeated_elem = scalar_val.toIntern() };
24851+
},
24852+
},
24853+
}));
2482324854
}
2482424855

2482524856
try sema.requireRuntimeBlock(block, src, scalar_src);
24826-
return block.addTyOp(.splat, dest_ty, scalar);
24857+
24858+
switch (dest_ty.zigTypeTag(zcu)) {
24859+
.array => {
24860+
const elems = try sema.arena.alloc(Air.Inst.Ref, len + @intFromBool(maybe_sentinel != null));
24861+
@memset(elems[0..len], scalar);
24862+
if (maybe_sentinel) |s| elems[len] = Air.internedToRef(s.toIntern());
24863+
return block.addAggregateInit(dest_ty, elems);
24864+
},
24865+
.vector => return block.addTyOp(.splat, dest_ty, scalar),
24866+
else => unreachable,
24867+
}
2482724868
}
2482824869

2482924870
fn zirReduce(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {

src/print_zir.zig

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ const Writer = struct {
203203
.alloc_comptime_mut,
204204
.elem_type,
205205
.indexable_ptr_elem_type,
206-
.vector_elem_type,
206+
.vec_arr_elem_type,
207207
.indexable_ptr_len,
208208
.anyframe_type,
209209
.bit_not,

test/behavior/array.zig

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,3 +1021,70 @@ test "runtime index of array of zero-bit values" {
10211021
try std.testing.expect(result.index == 0);
10221022
try std.testing.expect(result.value == {});
10231023
}
1024+
1025+
test "@splat array" {
1026+
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
1027+
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
1028+
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
1029+
const S = struct {
1030+
fn doTheTest(comptime T: type, x: T) !void {
1031+
const arr: [10]T = @splat(x);
1032+
for (arr) |elem| {
1033+
try expectEqual(x, elem);
1034+
}
1035+
}
1036+
};
1037+
1038+
try S.doTheTest(u32, 123);
1039+
try comptime S.doTheTest(u32, 123);
1040+
1041+
const Foo = struct { x: u8 };
1042+
try S.doTheTest(Foo, .{ .x = 10 });
1043+
try comptime S.doTheTest(Foo, .{ .x = 10 });
1044+
}
1045+
1046+
test "@splat array with sentinel" {
1047+
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
1048+
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
1049+
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest;
1050+
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
1051+
const S = struct {
1052+
fn doTheTest(comptime T: type, x: T, comptime s: T) !void {
1053+
const arr: [10:s]T = @splat(x);
1054+
for (arr) |elem| {
1055+
try expectEqual(x, elem);
1056+
}
1057+
const ptr: [*]const T = &arr;
1058+
try expectEqual(s, ptr[10]); // sentinel correct
1059+
}
1060+
};
1061+
1062+
try S.doTheTest(u32, 100, 42);
1063+
try comptime S.doTheTest(u32, 100, 42);
1064+
1065+
try S.doTheTest(?*anyopaque, @ptrFromInt(0x1000), null);
1066+
try comptime S.doTheTest(?*anyopaque, @ptrFromInt(0x1000), null);
1067+
}
1068+
1069+
test "@splat zero-length array" {
1070+
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
1071+
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
1072+
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest;
1073+
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
1074+
const S = struct {
1075+
fn doTheTest(comptime T: type, comptime s: T) !void {
1076+
var runtime_undef: T = undefined;
1077+
runtime_undef = undefined;
1078+
// The array should be comptime-known despite the `@splat` operand being runtime-known.
1079+
const arr: [0:s]T = @splat(runtime_undef);
1080+
const ptr: [*]const T = &arr;
1081+
comptime assert(ptr[0] == s);
1082+
}
1083+
};
1084+
1085+
try S.doTheTest(u32, 42);
1086+
try comptime S.doTheTest(u32, 42);
1087+
1088+
try S.doTheTest(?*anyopaque, null);
1089+
try comptime S.doTheTest(?*anyopaque, null);
1090+
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
export fn f() void {
2+
_ = @as(u32, @splat(5));
3+
}
4+
5+
// error
6+
//
7+
// :2:18: error: expected array or vector type, found 'u32'

test/cases/compile_errors/splat_result_type_non_vector.zig

Lines changed: 0 additions & 9 deletions
This file was deleted.

0 commit comments

Comments
 (0)