Skip to content

stage1: Implement non-exhaustive tagged unions #8163

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 42 additions & 1 deletion src/stage1/analyze.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3189,9 +3189,35 @@ static Error resolve_union_zero_bits(CodeGen *g, ZigType *union_type) {
union_type->data.unionation.resolve_loop_flag_zero_bits = true;

uint32_t field_count;
bool container_non_exhaustive;
if (decl_node->type == NodeTypeContainerDecl) {
AstNode *last_field_node = decl_node->data.container_decl.fields.at(decl_node->data.container_decl.fields.length - 1);
if (buf_eql_str(last_field_node->data.struct_field.name, "_")) {
if (last_field_node->data.struct_field.value != nullptr) {
add_node_error(g, last_field_node, buf_sprintf("value assigned to '_' field of non-exhaustive union"));
union_type->data.unionation.resolve_status = ResolveStatusInvalid;
return ErrorSemanticAnalyzeFail;
}
if (decl_node->data.container_decl.init_arg_expr == nullptr) {
add_node_error(g, decl_node, buf_sprintf("non-exhaustive enum must specify size"));
union_type->data.unionation.resolve_status = ResolveStatusInvalid;
return ErrorSemanticAnalyzeFail;
}
bool is_auto_enum = decl_node->data.container_decl.auto_enum;
bool is_explicit_enum = decl_node->data.container_decl.init_arg_expr != nullptr;
if (!is_auto_enum && !is_explicit_enum) {
add_node_error(g, decl_node, buf_sprintf("untagged union cannot be non-exhaustive"));
union_type->data.unionation.resolve_status = ResolveStatusInvalid;
return ErrorSemanticAnalyzeFail;
}
container_non_exhaustive = true;
} else {
container_non_exhaustive = false;
}

assert(union_type->data.unionation.fields == nullptr);
field_count = (uint32_t)decl_node->data.container_decl.fields.length;
field_count = (uint32_t)decl_node->data.container_decl.fields.length
- container_non_exhaustive;
union_type->data.unionation.src_field_count = field_count;
union_type->data.unionation.fields = heap::c_allocator.allocate<TypeUnionField>(field_count);
union_type->data.unionation.fields_by_name.init(field_count);
Expand Down Expand Up @@ -3287,6 +3313,7 @@ static Error resolve_union_zero_bits(CodeGen *g, ZigType *union_type) {
tag_type->data.enumeration.fields_by_name.init(field_count);
tag_type->data.enumeration.decls_scope = create_decls_scope(
g, nullptr, nullptr, tag_type, get_scope_import(scope), &tag_type->name);
tag_type->data.enumeration.non_exhaustive = container_non_exhaustive;
} else if (enum_type_node != nullptr) {
tag_type = analyze_type_expr(g, scope, enum_type_node);
} else {
Expand All @@ -3311,6 +3338,20 @@ static Error resolve_union_zero_bits(CodeGen *g, ZigType *union_type) {
assert(g->errors.length != 0);
return err;
}
if (decl_node->type == NodeTypeContainerDecl) {
if (container_non_exhaustive && !tag_type->data.enumeration.non_exhaustive) {
add_node_error(g, enum_type_node != nullptr ? enum_type_node : decl_node,
buf_sprintf("enum tag of non-exhaustive union must be non-exhaustive"));
union_type->data.unionation.resolve_status = ResolveStatusInvalid;
return ErrorSemanticAnalyzeFail;
}
if (!container_non_exhaustive && tag_type->data.enumeration.non_exhaustive) {
add_node_error(g, decl_node,
buf_sprintf("union with non-exhaustive enum tag must be non-exhaustive"));
union_type->data.unionation.resolve_status = ResolveStatusInvalid;
return ErrorSemanticAnalyzeFail;
}
}
covered_enum_fields = heap::c_allocator.allocate<bool>(tag_type->data.enumeration.src_field_count);
}
union_type->data.unionation.tag_type = tag_type;
Expand Down
9 changes: 1 addition & 8 deletions src/stage1/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29709,10 +29709,6 @@ static IrInstGen *ir_analyze_instruction_check_switch_prongs(IrAnalyze *ira,
if (type_is_invalid(switch_type))
return ira->codegen->invalid_inst_gen;

ZigValue *original_value = ((IrInstSrcSwitchTarget *)(instruction->target_value))->target_value_ptr->child->value;
bool target_is_originally_union = original_value->type->id == ZigTypeIdPointer &&
original_value->type->data.pointer.child_type->id == ZigTypeIdUnion;

if (switch_type->id == ZigTypeIdEnum) {
HashMap<BigInt, AstNode *, bigint_hash, bigint_eql> field_prev_uses = {};
field_prev_uses.init(switch_type->data.enumeration.src_field_count);
Expand Down Expand Up @@ -29771,9 +29767,6 @@ static IrInstGen *ir_analyze_instruction_check_switch_prongs(IrAnalyze *ira,
if (!switch_type->data.enumeration.non_exhaustive) {
ir_add_error(ira, &instruction->base.base,
buf_sprintf("switch on exhaustive enum has `_` prong"));
} else if (target_is_originally_union) {
ir_add_error(ira, &instruction->base.base,
buf_sprintf("`_` prong not allowed when switching on tagged union"));
}
for (uint32_t i = 0; i < switch_type->data.enumeration.src_field_count; i += 1) {
TypeEnumField *enum_field = &switch_type->data.enumeration.fields[i];
Expand All @@ -29788,7 +29781,7 @@ static IrInstGen *ir_analyze_instruction_check_switch_prongs(IrAnalyze *ira,
}
}
} else if (instruction->else_prong == nullptr) {
if (switch_type->data.enumeration.non_exhaustive && !target_is_originally_union) {
if (switch_type->data.enumeration.non_exhaustive) {
ir_add_error(ira, &instruction->base.base,
buf_sprintf("switch on non-exhaustive enum must include `else` or `_` prong"));
}
Expand Down
60 changes: 60 additions & 0 deletions test/compile_errors.zig
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,66 @@ const tests = @import("tests.zig");
const std = @import("std");

pub fn addCases(cases: *tests.CompileErrorContext) void {
cases.add("non-exhaustive switch on non-exhaustive unions",
\\const A = union((enum(u32) { a, b, _ })) { a: u32, b: u32, _ };
\\const B = union(enum(u32)) { a: u32, b: u32, _ };
\\fn check(x: anytype) void {
\\ switch (x) {
\\ .a => |n| {},
\\ .b => |n| {},
\\ }
\\ switch (x) {
\\ .a => |n| {},
\\ _ => {},
\\ }
\\ switch (x) {
\\ .a => |n| {},
\\ }
\\}
\\export fn foo() void {
\\ var x = A{ .a = 0 };
\\ var y = B{ .a = 0 };
\\ check(x);
\\ check(y);
\\}
, &[_][]const u8{
"tmp.zig:4:5: error: switch on non-exhaustive enum must include `else` or `_` prong",
"tmp.zig:8:5: error: enumeration value 'enum:1:18.b' not handled in switch",
"tmp.zig:12:5: error: switch on non-exhaustive enum must include `else` or `_` prong",
"tmp.zig:12:5: error: enumeration value 'enum:1:18.b' not handled in switch",
"tmp.zig:4:5: error: switch on non-exhaustive enum must include `else` or `_` prong",
"tmp.zig:8:5: error: enumeration value '@typeInfo(B).Union.tag_type.?.b' not handled in switch",
"tmp.zig:12:5: error: switch on non-exhaustive enum must include `else` or `_` prong",
"tmp.zig:12:5: error: enumeration value '@typeInfo(B).Union.tag_type.?.b' not handled in switch",
});

cases.add("exhaustive enum tag on non-exhaustive union",
\\const C = union((enum(u32) { a, b })) { a: u32, b: u32, _ };
\\export fn foo() void {
\\ _ = C{ .a = 0 };
\\}
, &[_][]const u8{
"tmp.zig:1:17: error: enum tag of non-exhaustive union must be non-exhaustive",
});

cases.add("non-exhautive enum tag on exhaustive union",
\\const D = union((enum(u32) { a, b, _ })) { a: u32, b: u32 };
\\export fn foo() void {
\\ _ = D{ .a = 0 };
\\}
, &[_][]const u8{
"tmp.zig:1:11: error: union with non-exhaustive enum tag must be non-exhaustive",
});

cases.add("non-exhaustive union without explicit enum size",
\\const E = union(enum) { a: u32, b: u32, _ };
\\export fn foo() void {
\\ _ = E{ .a = 0 };
\\}
, &[_][]const u8{
"tmp.zig:1:11: error: non-exhaustive enum must specify size",
});

cases.add("lazy pointer with undefined element type",
\\export fn foo() void {
\\ comptime var T: type = undefined;
Expand Down
25 changes: 25 additions & 0 deletions test/stage1/behavior/union.zig
Original file line number Diff line number Diff line change
Expand Up @@ -726,12 +726,14 @@ test "switching on non exhaustive union" {
const U = union(E) {
a: i32,
b: u32,
_,
};
fn doTheTest() void {
var a = U{ .a = 2 };
switch (a) {
.a => |val| expect(val == 2),
.b => unreachable,
_ => unreachable,
}
}
};
Expand Down Expand Up @@ -804,3 +806,26 @@ test "union enum type gets a separate scope" {

S.doTheTest();
}

fn testNonExhaustive(x: anytype) void {
switch (x) {
.a => |n| {},
.b => |n| {},
_ => {},
}
switch (x) {
.a => |n| {},
else => {},
}
switch (x) {
.a => |n| {},
else => {},
}
}

test "non-exhaustive unions" {
const A = union((enum(u32) { a, b, _ })) { a: u32, b: u32, _ };
const B = union(enum(u32)) { a: u32, b: u32, _ };
testNonExhaustive(A{ .a = 0 });
testNonExhaustive(B{ .a = 0 });
}