Skip to content

Commit b5ac079

Browse files
authored
Merge pull request #4191 from Vexu/non-exhaustive-enums
Implement non-exhaustive enums
2 parents d9be6e5 + 39f92a9 commit b5ac079

File tree

11 files changed

+330
-131
lines changed

11 files changed

+330
-131
lines changed

doc/langref.html.in

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2893,6 +2893,47 @@ test "switch using enum literals" {
28932893
}
28942894
{#code_end#}
28952895
{#header_close#}
2896+
2897+
{#header_open|Non-exhaustive enum#}
2898+
<p>
2899+
A Non-exhaustive enum can be created by adding a trailing '_' field.
2900+
It must specify a tag type and cannot consume every enumeration value.
2901+
</p>
2902+
<p>
2903+
{#link|@intToEnum#} on a non-exhaustive enum cannot fail.
2904+
</p>
2905+
<p>
2906+
A switch on a non-exhaustive enum can include a '_' prong as an alternative to an {#syntax#}else{#endsyntax#} prong
2907+
with the difference being that it makes it a compile error if all the known tag names are not handled by the switch.
2908+
</p>
2909+
{#code_begin|test#}
2910+
const std = @import("std");
2911+
const assert = std.debug.assert;
2912+
2913+
const Number = enum(u8) {
2914+
One,
2915+
Two,
2916+
Three,
2917+
_,
2918+
};
2919+
2920+
test "switch on non-exhaustive enum" {
2921+
const number = Number.One;
2922+
const result = switch (number) {
2923+
.One => true,
2924+
.Two,
2925+
.Three => false,
2926+
_ => false,
2927+
};
2928+
assert(result);
2929+
const is_one = switch (number) {
2930+
.One => true,
2931+
else => false,
2932+
};
2933+
assert(is_one);
2934+
}
2935+
{#code_end#}
2936+
{#header_close#}
28962937
{#header_close#}
28972938

28982939
{#header_open|union#}

lib/std/builtin.zig

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,7 @@ pub const TypeInfo = union(enum) {
254254
tag_type: type,
255255
fields: []EnumField,
256256
decls: []Declaration,
257+
is_exhaustive: bool,
257258
};
258259

259260
/// This data structure is used by the Zig language code generation and

src-self-hosted/translate_c.zig

Lines changed: 42 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -289,8 +289,7 @@ pub fn translate(
289289
tree.errors = ast.Tree.ErrorList.init(arena);
290290

291291
tree.root_node = try arena.create(ast.Node.Root);
292-
tree.root_node.* = ast.Node.Root{
293-
.base = ast.Node{ .id = ast.Node.Id.Root },
292+
tree.root_node.* = .{
294293
.decls = ast.Node.Root.DeclList.init(arena),
295294
// initialized with the eof token at the end
296295
.eof_token = undefined,
@@ -440,7 +439,6 @@ fn visitFnDecl(c: *Context, fn_decl: *const ZigClangFunctionDecl) Error!void {
440439
.PrivateExtern => return failDecl(c, fn_decl_loc, fn_name, "unsupported storage class: private extern", .{}),
441440
.Auto => unreachable, // Not legal on functions
442441
.Register => unreachable, // Not legal on functions
443-
else => unreachable,
444442
},
445443
};
446444

@@ -877,25 +875,23 @@ fn transEnumDecl(c: *Context, enum_decl: *const ZigClangEnumDecl) Error!?*ast.No
877875
// types, while that's not ISO-C compliant many compilers allow this and
878876
// default to the usual integer type used for all the enums.
879877

880-
// TODO only emit this tag type if the enum tag type is not the default.
881-
// I don't know what the default is, need to figure out how clang is deciding.
882-
// it appears to at least be different across gcc/msvc
883-
if (int_type.ptr != null and
884-
!isCBuiltinType(int_type, .UInt) and
885-
!isCBuiltinType(int_type, .Int))
886-
{
887-
_ = try appendToken(c, .LParen, "(");
888-
container_node.init_arg_expr = .{
889-
.Type = transQualType(rp, int_type, enum_loc) catch |err| switch (err) {
878+
// default to c_int since msvc and gcc default to different types
879+
_ = try appendToken(c, .LParen, "(");
880+
container_node.init_arg_expr = .{
881+
.Type = if (int_type.ptr != null and
882+
!isCBuiltinType(int_type, .UInt) and
883+
!isCBuiltinType(int_type, .Int))
884+
transQualType(rp, int_type, enum_loc) catch |err| switch (err) {
890885
error.UnsupportedType => {
891886
try failDecl(c, enum_loc, name, "unable to translate enum tag type", .{});
892887
return null;
893888
},
894889
else => |e| return e,
895-
},
896-
};
897-
_ = try appendToken(c, .RParen, ")");
898-
}
890+
}
891+
else
892+
try transCreateNodeIdentifier(c, "c_int"),
893+
};
894+
_ = try appendToken(c, .RParen, ")");
899895

900896
container_node.lbrace_token = try appendToken(c, .LBrace, "{");
901897

@@ -953,6 +949,19 @@ fn transEnumDecl(c: *Context, enum_decl: *const ZigClangEnumDecl) Error!?*ast.No
953949
tld_node.semicolon_token = try appendToken(c, .Semicolon, ";");
954950
try addTopLevelDecl(c, field_name, &tld_node.base);
955951
}
952+
// make non exhaustive
953+
const field_node = try c.a().create(ast.Node.ContainerField);
954+
field_node.* = .{
955+
.doc_comments = null,
956+
.comptime_token = null,
957+
.name_token = try appendIdentifier(c, "_"),
958+
.type_expr = null,
959+
.value_expr = null,
960+
.align_expr = null,
961+
};
962+
963+
try container_node.fields_and_decls.push(&field_node.base);
964+
_ = try appendToken(c, .Comma, ",");
956965
container_node.rbrace_token = try appendToken(c, .RBrace, "}");
957966

958967
break :blk &container_node.base;
@@ -1231,18 +1240,6 @@ fn transBinaryOperator(
12311240
op_id = .BitOr;
12321241
op_token = try appendToken(rp.c, .Pipe, "|");
12331242
},
1234-
.Assign,
1235-
.MulAssign,
1236-
.DivAssign,
1237-
.RemAssign,
1238-
.AddAssign,
1239-
.SubAssign,
1240-
.ShlAssign,
1241-
.ShrAssign,
1242-
.AndAssign,
1243-
.XorAssign,
1244-
.OrAssign,
1245-
=> unreachable,
12461243
else => unreachable,
12471244
}
12481245

@@ -1678,7 +1675,6 @@ fn transStringLiteral(
16781675
"TODO: support string literal kind {}",
16791676
.{kind},
16801677
),
1681-
else => unreachable,
16821678
}
16831679
}
16841680

@@ -2206,6 +2202,19 @@ fn transDoWhileLoop(
22062202
.id = .Loop,
22072203
};
22082204

2205+
// if (!cond) break;
2206+
const if_node = try transCreateNodeIf(rp.c);
2207+
var cond_scope = Scope{
2208+
.parent = scope,
2209+
.id = .Condition,
2210+
};
2211+
const prefix_op = try transCreateNodePrefixOp(rp.c, .BoolNot, .Bang, "!");
2212+
prefix_op.rhs = try transBoolExpr(rp, &cond_scope, @ptrCast(*const ZigClangExpr, ZigClangDoStmt_getCond(stmt)), .used, .r_value, true);
2213+
_ = try appendToken(rp.c, .RParen, ")");
2214+
if_node.condition = &prefix_op.base;
2215+
if_node.body = &(try transCreateNodeBreak(rp.c, null)).base;
2216+
_ = try appendToken(rp.c, .Semicolon, ";");
2217+
22092218
const body_node = if (ZigClangStmt_getStmtClass(ZigClangDoStmt_getBody(stmt)) == .CompoundStmtClass) blk: {
22102219
// there's already a block in C, so we'll append our condition to it.
22112220
// c: do {
@@ -2217,10 +2226,7 @@ fn transDoWhileLoop(
22172226
// zig: b;
22182227
// zig: if (!cond) break;
22192228
// zig: }
2220-
const body = (try transStmt(rp, &loop_scope, ZigClangDoStmt_getBody(stmt), .unused, .r_value)).cast(ast.Node.Block).?;
2221-
// if this is used as an expression in Zig it needs to be immediately followed by a semicolon
2222-
_ = try appendToken(rp.c, .Semicolon, ";");
2223-
break :blk body;
2229+
break :blk (try transStmt(rp, &loop_scope, ZigClangDoStmt_getBody(stmt), .unused, .r_value)).cast(ast.Node.Block).?;
22242230
} else blk: {
22252231
// the C statement is without a block, so we need to create a block to contain it.
22262232
// c: do
@@ -2236,19 +2242,6 @@ fn transDoWhileLoop(
22362242
break :blk block;
22372243
};
22382244

2239-
// if (!cond) break;
2240-
const if_node = try transCreateNodeIf(rp.c);
2241-
var cond_scope = Scope{
2242-
.parent = scope,
2243-
.id = .Condition,
2244-
};
2245-
const prefix_op = try transCreateNodePrefixOp(rp.c, .BoolNot, .Bang, "!");
2246-
prefix_op.rhs = try transBoolExpr(rp, &cond_scope, @ptrCast(*const ZigClangExpr, ZigClangDoStmt_getCond(stmt)), .used, .r_value, true);
2247-
_ = try appendToken(rp.c, .RParen, ")");
2248-
if_node.condition = &prefix_op.base;
2249-
if_node.body = &(try transCreateNodeBreak(rp.c, null)).base;
2250-
_ = try appendToken(rp.c, .Semicolon, ";");
2251-
22522245
try body_node.statements.push(&if_node.base);
22532246
if (new)
22542247
body_node.rbrace = try appendToken(rp.c, .RBrace, "}");
@@ -4783,8 +4776,7 @@ fn appendIdentifier(c: *Context, name: []const u8) !ast.TokenIndex {
47834776
fn transCreateNodeIdentifier(c: *Context, name: []const u8) !*ast.Node {
47844777
const token_index = try appendIdentifier(c, name);
47854778
const identifier = try c.a().create(ast.Node.Identifier);
4786-
identifier.* = ast.Node.Identifier{
4787-
.base = ast.Node{ .id = ast.Node.Id.Identifier },
4779+
identifier.* = .{
47884780
.token = token_index,
47894781
};
47904782
return &identifier.base;
@@ -4923,8 +4915,7 @@ fn transMacroFnDefine(c: *Context, it: *ctok.TokenList.Iterator, name: []const u
49234915

49244916
const token_index = try appendToken(c, .Keyword_var, "var");
49254917
const identifier = try c.a().create(ast.Node.Identifier);
4926-
identifier.* = ast.Node.Identifier{
4927-
.base = ast.Node{ .id = ast.Node.Id.Identifier },
4918+
identifier.* = .{
49284919
.token = token_index,
49294920
};
49304921

src/all_types.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1385,6 +1385,7 @@ struct ZigTypeEnum {
13851385
ContainerLayout layout;
13861386
ResolveStatus resolve_status;
13871387

1388+
bool non_exhaustive;
13881389
bool resolve_loop_flag;
13891390
};
13901391

@@ -3669,6 +3670,7 @@ struct IrInstructionCheckSwitchProngs {
36693670
IrInstructionCheckSwitchProngsRange *ranges;
36703671
size_t range_count;
36713672
bool have_else_prong;
3673+
bool have_underscore_prong;
36723674
};
36733675

36743676
struct IrInstructionCheckStatementIsVoid {

src/analyze.cpp

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2569,15 +2569,8 @@ static Error resolve_enum_zero_bits(CodeGen *g, ZigType *enum_type) {
25692569
return ErrorSemanticAnalyzeFail;
25702570
}
25712571

2572-
enum_type->data.enumeration.src_field_count = field_count;
2573-
enum_type->data.enumeration.fields = allocate<TypeEnumField>(field_count);
2574-
enum_type->data.enumeration.fields_by_name.init(field_count);
2575-
25762572
Scope *scope = &enum_type->data.enumeration.decls_scope->base;
25772573

2578-
HashMap<BigInt, AstNode *, bigint_hash, bigint_eql> occupied_tag_values = {};
2579-
occupied_tag_values.init(field_count);
2580-
25812574
ZigType *tag_int_type;
25822575
if (enum_type->data.enumeration.layout == ContainerLayoutExtern) {
25832576
tag_int_type = get_c_int_type(g, CIntTypeInt);
@@ -2619,6 +2612,7 @@ static Error resolve_enum_zero_bits(CodeGen *g, ZigType *enum_type) {
26192612
}
26202613
}
26212614

2615+
enum_type->data.enumeration.non_exhaustive = false;
26222616
enum_type->data.enumeration.tag_int_type = tag_int_type;
26232617
enum_type->size_in_bits = tag_int_type->size_in_bits;
26242618
enum_type->abi_size = tag_int_type->abi_size;
@@ -2627,6 +2621,31 @@ static Error resolve_enum_zero_bits(CodeGen *g, ZigType *enum_type) {
26272621
BigInt bi_one;
26282622
bigint_init_unsigned(&bi_one, 1);
26292623

2624+
AstNode *last_field_node = decl_node->data.container_decl.fields.at(field_count - 1);
2625+
if (buf_eql_str(last_field_node->data.struct_field.name, "_")) {
2626+
field_count -= 1;
2627+
if (field_count > 1 && log2_u64(field_count) == enum_type->size_in_bits) {
2628+
add_node_error(g, last_field_node, buf_sprintf("non-exhaustive enum specifies every value"));
2629+
enum_type->data.enumeration.resolve_status = ResolveStatusInvalid;
2630+
}
2631+
if (decl_node->data.container_decl.init_arg_expr == nullptr) {
2632+
add_node_error(g, last_field_node, buf_sprintf("non-exhaustive enum must specify size"));
2633+
enum_type->data.enumeration.resolve_status = ResolveStatusInvalid;
2634+
}
2635+
if (last_field_node->data.struct_field.value != nullptr) {
2636+
add_node_error(g, last_field_node, buf_sprintf("value assigned to '_' field of non-exhaustive enum"));
2637+
enum_type->data.enumeration.resolve_status = ResolveStatusInvalid;
2638+
}
2639+
enum_type->data.enumeration.non_exhaustive = true;
2640+
}
2641+
2642+
enum_type->data.enumeration.src_field_count = field_count;
2643+
enum_type->data.enumeration.fields = allocate<TypeEnumField>(field_count);
2644+
enum_type->data.enumeration.fields_by_name.init(field_count);
2645+
2646+
HashMap<BigInt, AstNode *, bigint_hash, bigint_eql> occupied_tag_values = {};
2647+
occupied_tag_values.init(field_count);
2648+
26302649
TypeEnumField *last_enum_field = nullptr;
26312650

26322651
for (uint32_t field_i = 0; field_i < field_count; field_i += 1) {
@@ -2648,6 +2667,11 @@ static Error resolve_enum_zero_bits(CodeGen *g, ZigType *enum_type) {
26482667
buf_sprintf("consider 'union(enum)' here"));
26492668
}
26502669

2670+
if (buf_eql_str(type_enum_field->name, "_")) {
2671+
add_node_error(g, field_node, buf_sprintf("'_' field of non-exhaustive enum must be last"));
2672+
enum_type->data.enumeration.resolve_status = ResolveStatusInvalid;
2673+
}
2674+
26512675
auto field_entry = enum_type->data.enumeration.fields_by_name.put_unique(type_enum_field->name, type_enum_field);
26522676
if (field_entry != nullptr) {
26532677
ErrorMsg *msg = add_node_error(g, field_node,

src/codegen.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3356,7 +3356,7 @@ static LLVMValueRef ir_render_int_to_enum(CodeGen *g, IrExecutable *executable,
33563356
LLVMValueRef tag_int_value = gen_widen_or_shorten(g, ir_want_runtime_safety(g, &instruction->base),
33573357
instruction->target->value->type, tag_int_type, target_val);
33583358

3359-
if (ir_want_runtime_safety(g, &instruction->base) && wanted_type->data.enumeration.layout != ContainerLayoutExtern) {
3359+
if (ir_want_runtime_safety(g, &instruction->base) && !wanted_type->data.enumeration.non_exhaustive) {
33603360
LLVMBasicBlockRef bad_value_block = LLVMAppendBasicBlock(g->cur_fn_val, "BadValue");
33613361
LLVMBasicBlockRef ok_value_block = LLVMAppendBasicBlock(g->cur_fn_val, "OkValue");
33623362
size_t field_count = wanted_type->data.enumeration.src_field_count;
@@ -5065,6 +5065,11 @@ static LLVMValueRef ir_render_enum_tag_name(CodeGen *g, IrExecutable *executable
50655065
{
50665066
ZigType *enum_type = instruction->target->value->type;
50675067
assert(enum_type->id == ZigTypeIdEnum);
5068+
if (enum_type->data.enumeration.non_exhaustive) {
5069+
add_node_error(g, instruction->base.source_node,
5070+
buf_sprintf("TODO @tagName on non-exhaustive enum https://github.com/ziglang/zig/issues/3991"));
5071+
codegen_report_errors_and_exit(g);
5072+
}
50685073

50695074
LLVMValueRef enum_name_function = get_enum_tag_name_function(g, enum_type);
50705075

0 commit comments

Comments
 (0)