Skip to content

AstGen: avoid intermediate loads during field/array access #19347

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 85 additions & 41 deletions lib/std/zig/AstGen.zig
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,8 @@ const ResultInfo = struct {
/// The expression must generate a pointer rather than a value. For example, the left hand side
/// of an assignment uses this kind of result location.
ref,
/// Same as `ty` but will not mark variables as being used as an lvalue.
pseudo_ref,
/// The expression must generate a pointer rather than a value, and the pointer will be coerced
/// by other code to this type, which is guaranteed by earlier instructions to be a pointer type.
ref_coerced_ty: Zir.Inst.Ref,
Expand Down Expand Up @@ -320,7 +322,7 @@ const ResultInfo = struct {
/// the given node.
fn resultType(rl: Loc, gz: *GenZir, node: Ast.Node.Index) !?Zir.Inst.Ref {
return switch (rl) {
.discard, .none, .ref, .inferred_ptr, .destructure => null,
.discard, .none, .ref, .pseudo_ref, .inferred_ptr, .destructure => null,
.ty, .coerced_ty => |ty_ref| ty_ref,
.ref_coerced_ty => |ptr_ty| try gz.addUnNode(.elem_type, ptr_ty, node),
.ptr => |ptr| {
Expand Down Expand Up @@ -970,7 +972,7 @@ fn expr(gz: *GenZir, scope: *Scope, ri: ResultInfo, node: Ast.Node.Index) InnerE
const lhs = try expr(gz, scope, .{ .rl = .none }, node_datas[node].lhs);
_ = try gz.addUnNode(.validate_deref, lhs, node);
switch (ri.rl) {
.ref, .ref_coerced_ty => return lhs,
.ref, .pseudo_ref, .ref_coerced_ty => return lhs,
else => {
const result = try gz.addUnNode(.load, lhs, node);
return rvalue(gz, ri, result, node);
Expand All @@ -991,7 +993,7 @@ fn expr(gz: *GenZir, scope: *Scope, ri: ResultInfo, node: Ast.Node.Index) InnerE
return rvalue(gz, ri, result, node);
},
.unwrap_optional => switch (ri.rl) {
.ref, .ref_coerced_ty => {
.ref, .pseudo_ref, .ref_coerced_ty => {
const lhs = try expr(gz, scope, .{ .rl = .ref }, node_datas[node].lhs);

const cursor = maybeAdvanceSourceCursorToMainToken(gz, node);
Expand Down Expand Up @@ -1053,7 +1055,7 @@ fn expr(gz: *GenZir, scope: *Scope, ri: ResultInfo, node: Ast.Node.Index) InnerE
return switchExprErrUnion(gz, scope, ri.br(), node, .@"catch");
}
switch (ri.rl) {
.ref, .ref_coerced_ty => return orelseCatchExpr(
.ref, .pseudo_ref, .ref_coerced_ty => return orelseCatchExpr(
gz,
scope,
ri,
Expand All @@ -1080,7 +1082,7 @@ fn expr(gz: *GenZir, scope: *Scope, ri: ResultInfo, node: Ast.Node.Index) InnerE
}
},
.@"orelse" => switch (ri.rl) {
.ref, .ref_coerced_ty => return orelseCatchExpr(
.ref, .pseudo_ref, .ref_coerced_ty => return orelseCatchExpr(
gz,
scope,
ri,
Expand Down Expand Up @@ -1523,7 +1525,7 @@ fn arrayInitExpr(
}
return Zir.Inst.Ref.void_value;
},
.ref => {
.ref, .pseudo_ref => {
const result = try arrayInitExprAnon(gz, scope, node, array_init.ast.elements);
return gz.addUnTok(.ref, result, tree.firstToken(node));
},
Expand Down Expand Up @@ -1697,7 +1699,7 @@ fn structInitExpr(
const val = try gz.addUnNode(.struct_init_empty_result, ty_inst, node);
return rvalue(gz, ri, val, node);
},
.none, .ref, .inferred_ptr => {
.none, .ref, .pseudo_ref, .inferred_ptr => {
return rvalue(gz, ri, .empty_struct, node);
},
.destructure => |destructure| {
Expand Down Expand Up @@ -1830,7 +1832,7 @@ fn structInitExpr(
}
return .void_value;
},
.ref => {
.ref, .pseudo_ref => {
const result = try structInitExprAnon(gz, scope, node, struct_init);
return gz.addUnTok(.ref, result, tree.firstToken(node));
},
Expand Down Expand Up @@ -5912,21 +5914,25 @@ fn tryExpr(
const try_lc = LineColumn{ astgen.source_line - parent_gz.decl_line, astgen.source_column };

const operand_ri: ResultInfo = switch (ri.rl) {
.pseudo_ref => .{ .rl = .pseudo_ref, .ctx = .error_handling_expr },
.ref, .ref_coerced_ty => .{ .rl = .ref, .ctx = .error_handling_expr },
else => .{ .rl = .none, .ctx = .error_handling_expr },
};
// This could be a pointer or value depending on the `ri` parameter.
const operand = try reachableExpr(parent_gz, scope, operand_ri, operand_node, node);
const block_tag: Zir.Inst.Tag = if (operand_ri.rl == .ref) .try_ptr else .@"try";
const block_tag: Zir.Inst.Tag = switch (ri.rl) {
.ref, .pseudo_ref, .ref_coerced_ty => .try_ptr,
else => .@"try",
};
const try_inst = try parent_gz.makeBlockInst(block_tag, node);
try parent_gz.instructions.append(astgen.gpa, try_inst);

var else_scope = parent_gz.makeSubBlock(scope);
defer else_scope.unstack();

const err_tag = switch (ri.rl) {
.ref, .ref_coerced_ty => Zir.Inst.Tag.err_union_code_ptr,
else => Zir.Inst.Tag.err_union_code,
const err_tag: Zir.Inst.Tag = switch (ri.rl) {
.ref, .pseudo_ref, .ref_coerced_ty => .err_union_code_ptr,
else => .err_union_code,
};
const err_code = try else_scope.addUnNode(err_tag, operand, node);
try genDefers(&else_scope, &fn_block.base, scope, .{ .both = err_code });
Expand All @@ -5936,7 +5942,7 @@ fn tryExpr(
try else_scope.setTryBody(try_inst, operand);
const result = try_inst.toRef();
switch (ri.rl) {
.ref, .ref_coerced_ty => return result,
.ref, .pseudo_ref, .ref_coerced_ty => return result,
else => return rvalue(parent_gz, ri, result, node),
}
}
Expand Down Expand Up @@ -5977,6 +5983,7 @@ fn orelseCatchExpr(
defer block_scope.unstack();

const operand_ri: ResultInfo = switch (block_scope.break_result_info.rl) {
.pseudo_ref => .{ .rl = .pseudo_ref, .ctx = if (do_err_trace) .error_handling_expr else .none },
.ref, .ref_coerced_ty => .{ .rl = .ref, .ctx = if (do_err_trace) .error_handling_expr else .none },
else => .{ .rl = .none, .ctx = if (do_err_trace) .error_handling_expr else .none },
};
Expand All @@ -5999,7 +6006,7 @@ fn orelseCatchExpr(
// This could be a pointer or value depending on `unwrap_op`.
const unwrapped_payload = try then_scope.addUnNode(unwrap_op, operand, node);
const then_result = switch (ri.rl) {
.ref, .ref_coerced_ty => unwrapped_payload,
.ref, .pseudo_ref, .ref_coerced_ty => unwrapped_payload,
else => try rvalue(&then_scope, block_scope.break_result_info, unwrapped_payload, node),
};
_ = try then_scope.addBreakWithSrcNode(.@"break", block, then_result, node);
Expand Down Expand Up @@ -6071,9 +6078,14 @@ fn fieldAccess(
) InnerError!Zir.Inst.Ref {
switch (ri.rl) {
.ref, .ref_coerced_ty => return addFieldAccess(.field_ptr, gz, scope, .{ .rl = .ref }, node),
.pseudo_ref => return addFieldAccess(.field_ptr, gz, scope, .{ .rl = .pseudo_ref }, node),
else => {
const access = try addFieldAccess(.field_val, gz, scope, .{ .rl = .none }, node);
return rvalue(gz, ri, access, node);
if ((gz.is_comptime and gz.astgen.fn_block != null) or !nodeAccessesIdentifier(gz.astgen.tree, node)) {
return rvalue(gz, ri, try addFieldAccess(.field_val, gz, scope, .{ .rl = .none }, node), node);
}
const ptr = try addFieldAccess(.field_ptr, gz, scope, .{ .rl = .pseudo_ref }, node);
const result = try gz.addUnNode(.load, ptr, node);
return rvalue(gz, ri, result, node);
},
}
}
Expand Down Expand Up @@ -6113,28 +6125,32 @@ fn arrayAccess(
) InnerError!Zir.Inst.Ref {
const tree = gz.astgen.tree;
const node_datas = tree.nodes.items(.data);
switch (ri.rl) {
.ref, .ref_coerced_ty => {
const lhs = try expr(gz, scope, .{ .rl = .ref }, node_datas[node].lhs);
const lhs_result: ResultInfo, const need_load = switch (ri.rl) {
.ref, .ref_coerced_ty => .{ .{ .rl = .ref }, false },
.pseudo_ref => .{ .{ .rl = .pseudo_ref }, false },
else => if ((gz.is_comptime and gz.astgen.fn_block != null) or !nodeAccessesIdentifier(tree, node)) {
const lhs = try expr(gz, scope, .{ .rl = .none }, node_datas[node].lhs);

const cursor = maybeAdvanceSourceCursorToMainToken(gz, node);

const rhs = try expr(gz, scope, .{ .rl = .{ .coerced_ty = .usize_type } }, node_datas[node].rhs);
try emitDbgStmt(gz, cursor);

return gz.addPlNode(.elem_ptr_node, node, Zir.Inst.Bin{ .lhs = lhs, .rhs = rhs });
},
else => {
const lhs = try expr(gz, scope, .{ .rl = .none }, node_datas[node].lhs);
return rvalue(gz, ri, try gz.addPlNode(.elem_val_node, node, Zir.Inst.Bin{ .lhs = lhs, .rhs = rhs }), node);
} else .{ .{ .rl = .pseudo_ref }, true },
};
const lhs = try expr(gz, scope, lhs_result, node_datas[node].lhs);

const cursor = maybeAdvanceSourceCursorToMainToken(gz, node);
const cursor = maybeAdvanceSourceCursorToMainToken(gz, node);

const rhs = try expr(gz, scope, .{ .rl = .{ .coerced_ty = .usize_type } }, node_datas[node].rhs);
try emitDbgStmt(gz, cursor);
const rhs = try expr(gz, scope, .{ .rl = .{ .coerced_ty = .usize_type } }, node_datas[node].rhs);
try emitDbgStmt(gz, cursor);

return rvalue(gz, ri, try gz.addPlNode(.elem_val_node, node, Zir.Inst.Bin{ .lhs = lhs, .rhs = rhs }), node);
},
}
const ptr = try gz.addPlNode(.elem_ptr_node, node, Zir.Inst.Bin{ .lhs = lhs, .rhs = rhs });
if (!need_load) return ptr;

const loaded = try gz.addUnNode(.load, ptr, node);
return rvalue(gz, ri, loaded, node);
}

fn simpleBinOp(
Expand Down Expand Up @@ -7201,7 +7217,7 @@ fn switchExprErrUnion(
switch (node_ty) {
.@"catch" => {
const case_result = switch (ri.rl) {
.ref, .ref_coerced_ty => unwrapped_payload,
.ref, .pseudo_ref, .ref_coerced_ty => unwrapped_payload,
else => try rvalue(
&case_scope,
block_scope.break_result_info,
Expand Down Expand Up @@ -8305,6 +8321,7 @@ fn localVarRef(
) else local_ptr.ptr;

switch (ri.rl) {
.pseudo_ref => return ptr_inst,
.ref, .ref_coerced_ty => {
local_ptr.used_as_lvalue = true;
return ptr_inst;
Expand Down Expand Up @@ -8349,7 +8366,7 @@ fn localVarRef(

if (found_namespaces_out > 0 and found_needs_tunnel) {
switch (ri.rl) {
.ref, .ref_coerced_ty => return tunnelThroughClosure(
.ref, .pseudo_ref, .ref_coerced_ty => return tunnelThroughClosure(
gz,
ident,
found_namespaces_out,
Expand All @@ -8370,7 +8387,7 @@ fn localVarRef(
}

switch (ri.rl) {
.ref, .ref_coerced_ty => return gz.addStrTok(.decl_ref, name_str_index, ident_token),
.ref, .pseudo_ref, .ref_coerced_ty => return gz.addStrTok(.decl_ref, name_str_index, ident_token),
else => {
const result = try gz.addStrTok(.decl_val, name_str_index, ident_token);
return rvalueNoCoercePreRef(gz, ri, result, ident);
Expand Down Expand Up @@ -9103,18 +9120,20 @@ fn builtinCall(
const result = try gz.addExtendedMultiOpPayloadIndex(.compile_log, payload_index, params.len);
return rvalue(gz, ri, result, node);
},
.field => {
if (ri.rl == .ref or ri.rl == .ref_coerced_ty) {
.field => switch (ri.rl) {
.ref, .pseudo_ref, .ref_coerced_ty => {
return gz.addPlNode(.field_ptr_named, node, Zir.Inst.FieldNamed{
.lhs = try expr(gz, scope, .{ .rl = .ref }, params[0]),
.field_name = try comptimeExpr(gz, scope, .{ .rl = .{ .coerced_ty = .slice_const_u8_type } }, params[1]),
});
}
const result = try gz.addPlNode(.field_val_named, node, Zir.Inst.FieldNamed{
.lhs = try expr(gz, scope, .{ .rl = .none }, params[0]),
.field_name = try comptimeExpr(gz, scope, .{ .rl = .{ .coerced_ty = .slice_const_u8_type } }, params[1]),
});
return rvalue(gz, ri, result, node);
},
else => {
const result = try gz.addPlNode(.field_val_named, node, Zir.Inst.FieldNamed{
.lhs = try expr(gz, scope, .{ .rl = .none }, params[0]),
.field_name = try comptimeExpr(gz, scope, .{ .rl = .{ .coerced_ty = .slice_const_u8_type } }, params[1]),
});
return rvalue(gz, ri, result, node);
},
},

// zig fmt: off
Expand Down Expand Up @@ -10929,6 +10948,31 @@ fn nodeUsesAnonNameStrategy(tree: *const Ast, node: Ast.Node.Index) bool {
}
}

/// Returns `true` if field/array access chain ultimately refers to an identifier.
fn nodeAccessesIdentifier(tree: *const Ast, start_node: Ast.Node.Index) bool {
const node_tags = tree.nodes.items(.tag);
const node_datas = tree.nodes.items(.data);

var node = start_node;
while (true) {
switch (node_tags[node]) {
.identifier => return true,

// Forward the question to the LHS sub-expression.
.grouped_expression,
.@"try",
.@"nosuspend",
.unwrap_optional,
.deref,
.field_access,
.array_access,
=> node = node_datas[node].lhs,

else => return false,
}
}
}

/// Applies `rl` semantics to `result`. Expressions which do not do their own handling of
/// result locations must call this function on their result.
/// As an example, if `ri.rl` is `.ptr`, it will write the result to the pointer.
Expand Down Expand Up @@ -10981,7 +11025,7 @@ fn rvalueInner(
_ = try gz.addUnNode(.ensure_result_non_error, result, src_node);
return .void_value;
},
.ref, .ref_coerced_ty => {
.ref, .pseudo_ref, .ref_coerced_ty => {
const coerced_result = if (allow_coerce_pre_ref and ri.rl == .ref_coerced_ty) res: {
const ptr_ty = ri.rl.ref_coerced_ty;
break :res try gz.addPlNode(.coerce_ptr_elem_ty, src_node, Zir.Inst.Bin{
Expand Down
31 changes: 31 additions & 0 deletions src/Sema.zig
Original file line number Diff line number Diff line change
Expand Up @@ -27074,6 +27074,37 @@ fn fieldPtr(
if (ip.stringEqlSlice(field_name, "len")) {
const int_val = try mod.intValue(Type.usize, inner_ty.arrayLen(mod));
return anonDeclRef(sema, int_val.toIntern());
} else if (ip.stringEqlSlice(field_name, "ptr") and is_pointer_to) {
const ptr_info = object_ty.ptrInfo(mod);
const new_ptr_ty = try sema.ptrType(.{
.child = Type.fromInterned(ptr_info.child).childType(mod).toIntern(),
.sentinel = if (object_ty.sentinel(mod)) |s| s.toIntern() else .none,
.flags = .{
.size = .Many,
.alignment = ptr_info.flags.alignment,
.is_const = ptr_info.flags.is_const,
.is_volatile = ptr_info.flags.is_volatile,
.is_allowzero = ptr_info.flags.is_allowzero,
.address_space = ptr_info.flags.address_space,
.vector_index = ptr_info.flags.vector_index,
},
.packed_offset = ptr_info.packed_offset,
});
const ptr_ptr_info = object_ptr_ty.ptrInfo(mod);
const result_ty = try sema.ptrType(.{
.child = new_ptr_ty.toIntern(),
.sentinel = if (object_ptr_ty.sentinel(mod)) |s| s.toIntern() else .none,
.flags = .{
.alignment = ptr_ptr_info.flags.alignment,
.is_const = ptr_ptr_info.flags.is_const,
.is_volatile = ptr_ptr_info.flags.is_volatile,
.is_allowzero = ptr_ptr_info.flags.is_allowzero,
.address_space = ptr_ptr_info.flags.address_space,
.vector_index = ptr_ptr_info.flags.vector_index,
},
.packed_offset = ptr_ptr_info.packed_offset,
});
return sema.bitCast(block, result_ty, object_ptr, src, null);
} else {
return sema.fail(
block,
Expand Down
29 changes: 27 additions & 2 deletions src/codegen/llvm.zig
Original file line number Diff line number Diff line change
Expand Up @@ -3816,7 +3816,10 @@ pub const Object = struct {
.elem,
.field,
=> try o.lowerParentPtr(val),
.comptime_field => unreachable,
.comptime_field => |field_val| {
const ptr_ty = Type.fromInterned(ptr.ty);
return o.lowerComptimeField(ptr_ty, field_val);
},
},
.slice => |slice| return o.builder.structConst(try o.lowerType(ty), &.{
try o.lowerValue(slice.ptr),
Expand Down Expand Up @@ -4311,7 +4314,10 @@ pub const Object = struct {

return o.builder.gepConst(.inbounds, try o.lowerType(opt_ty), parent_ptr, null, &.{ .@"0", .@"0" });
},
.comptime_field => unreachable,
.comptime_field => |field_val| {
const ptr_ty = Type.fromInterned(ptr.ty);
return o.lowerComptimeField(ptr_ty, field_val);
},
.elem => |elem_ptr| {
const parent_ptr = try o.lowerParentPtr(Value.fromInterned(elem_ptr.base));
const elem_ty = Type.fromInterned(ip.typeOf(elem_ptr.base)).elemType2(mod);
Expand Down Expand Up @@ -4388,6 +4394,25 @@ pub const Object = struct {
};
}

fn lowerComptimeField(o: *Object, ptr_ty: Type, field_val: InternPool.Index) Error!Builder.Constant {
const mod = o.module;
const target = mod.getTarget();
const llvm_addr_space = toLlvmAddressSpace(ptr_ty.ptrAddressSpace(mod), target);
const alignment = ptr_ty.ptrAlignment(mod);
const llvm_global = (try o.resolveGlobalAnonDecl(field_val, llvm_addr_space, alignment)).ptrConst(&o.builder).global;

const llvm_val = try o.builder.convConst(
.unneeded,
llvm_global.toConst(),
try o.builder.ptrType(llvm_addr_space),
);

return o.builder.convConst(if (ptr_ty.isAbiInt(mod)) switch (ptr_ty.intInfo(mod).signedness) {
.signed => .signed,
.unsigned => .unsigned,
} else .unneeded, llvm_val, try o.lowerType(ptr_ty));
}

/// This logic is very similar to `lowerDeclRefValue` but for anonymous declarations.
/// Maybe the logic could be unified.
fn lowerAnonDeclRef(
Expand Down
Loading