Skip to content

Improve type inference for member access in unions and intersections #461

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ use smol_str::SmolStr;
use crate::{
db_index::{
AnalyzeError, LuaAliasCallType, LuaFunctionType, LuaGenericType, LuaIndexAccessKey,
LuaIntersectionType, LuaObjectType, LuaStringTplType, LuaTupleType, LuaType, LuaUnionType,
LuaObjectType, LuaStringTplType, LuaTupleType, LuaType,
},
DiagnosticCode, GenericTpl, InFiled, LuaAliasCallKind, LuaMultiLineUnion, LuaTypeDeclId,
TypeOps, VariadicType,
make_intersection, make_union, DiagnosticCode, GenericTpl, InFiled, LuaAliasCallKind,
LuaMultiLineUnion, LuaTypeDeclId, TypeOps, VariadicType,
};

use super::{preprocess_description, DocAnalyzer};
Expand Down Expand Up @@ -289,53 +289,12 @@ fn infer_binary_type(analyzer: &mut DocAnalyzer, binary_type: &LuaDocBinaryType)

if let Some(op) = binary_type.get_op_token() {
match op.get_op() {
LuaTypeBinaryOperator::Union => match (left_type, right_type) {
(LuaType::Union(left_type_union), LuaType::Union(right_type_union)) => {
let mut left_types = left_type_union.into_types();
let right_types = right_type_union.into_types();
left_types.extend(right_types);
return LuaType::Union(LuaUnionType::new(left_types).into());
}
(LuaType::Union(left_type_union), right) => {
let mut left_types = (*left_type_union).into_types();
left_types.push(right);
return LuaType::Union(LuaUnionType::new(left_types).into());
}
(left, LuaType::Union(right_type_union)) => {
let mut right_types = (*right_type_union).into_types();
right_types.push(left);
return LuaType::Union(LuaUnionType::new(right_types).into());
}
(left, right) => {
return LuaType::Union(LuaUnionType::new(vec![left, right]).into());
}
},
LuaTypeBinaryOperator::Intersection => match (left_type, right_type) {
(
LuaType::Intersection(left_type_union),
LuaType::Intersection(right_type_union),
) => {
let mut left_types = left_type_union.into_types();
let right_types = right_type_union.into_types();
left_types.extend(right_types);
return LuaType::Intersection(LuaIntersectionType::new(left_types).into());
}
(LuaType::Intersection(left_type_union), right) => {
let mut left_types = left_type_union.into_types();
left_types.push(right);
return LuaType::Intersection(LuaIntersectionType::new(left_types).into());
}
(left, LuaType::Intersection(right_type_union)) => {
let mut right_types = right_type_union.into_types();
right_types.push(left);
return LuaType::Intersection(LuaIntersectionType::new(right_types).into());
}
(left, right) => {
return LuaType::Intersection(
LuaIntersectionType::new(vec![left, right]).into(),
);
}
},
LuaTypeBinaryOperator::Union => {
return make_union(left_type, right_type);
}
LuaTypeBinaryOperator::Intersection => {
return make_intersection(left_type, right_type);
}
LuaTypeBinaryOperator::Extends => {
return LuaType::Call(
LuaAliasCallType::new(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,4 +102,64 @@ mod test {
assert_eq!(e_ty, LuaType::Integer);
assert_eq!(f_ty, LuaType::Integer);
}

#[test]
fn test_issue_454_unions() {
let mut ws = VirtualWorkspace::new();

ws.def(
r#"
--- @class A
--- @field a integer
--- @field c integer
--- @field d integer

--- @class B
--- @field b integer
--- @field c integer
--- @field d string

ab --- @type A | B
"#,
);

assert_eq!(ws.expr_ty("ab.a"), LuaType::Unknown);
assert_eq!(ws.expr_ty("ab.b"), LuaType::Unknown);
assert_eq!(ws.expr_ty("ab.c"), ws.ty("integer"));
assert_eq!(ws.expr_ty("ab.d"), ws.ty("integer | string"));
assert_eq!(ws.expr_ty("ab['a']"), LuaType::Unknown);
assert_eq!(ws.expr_ty("ab['b']"), LuaType::Unknown);
assert_eq!(ws.expr_ty("ab['c']"), ws.ty("integer"));
assert_eq!(ws.expr_ty("ab['d']"), ws.ty("integer | string"));
}

#[test]
fn test_issue_454_intersections() {
let mut ws = VirtualWorkspace::new();

ws.def(
r#"
--- @class A
--- @field a integer
--- @field c integer
--- @field d integer

--- @class B
--- @field b integer
--- @field c integer
--- @field d string

ab --- @type A & B
"#,
);

assert_eq!(ws.expr_ty("ab.a"), ws.ty("integer"));
assert_eq!(ws.expr_ty("ab.b"), ws.ty("integer"));
assert_eq!(ws.expr_ty("ab.c"), ws.ty("integer"));
assert_eq!(ws.expr_ty("ab.d"), ws.ty("integer & string"));
assert_eq!(ws.expr_ty("ab['a']"), ws.ty("integer"));
assert_eq!(ws.expr_ty("ab['b']"), ws.ty("integer"));
assert_eq!(ws.expr_ty("ab['c']"), ws.ty("integer"));
assert_eq!(ws.expr_ty("ab['d']"), ws.ty("integer & string"));
}
}
46 changes: 46 additions & 0 deletions crates/emmylua_code_analysis/src/db_index/type/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -725,6 +725,29 @@ impl From<LuaUnionType> for LuaType {
}
}

pub fn make_union(left_type: LuaType, right_type: LuaType) -> LuaType {
match (left_type, right_type) {
(left, right) if left == right => left,
(LuaType::Union(left_type_union), LuaType::Union(right_type_union)) => {
let mut left_types = left_type_union.into_types();
let right_types = right_type_union.into_types();
left_types.extend(right_types);
LuaType::Union(LuaUnionType::new(left_types).into())
}
(LuaType::Union(left_type_union), right) => {
let mut left_types = (*left_type_union).into_types();
left_types.push(right);
LuaType::Union(LuaUnionType::new(left_types).into())
}
(left, LuaType::Union(right_type_union)) => {
let mut right_types = (*right_type_union).into_types();
right_types.push(left);
LuaType::Union(LuaUnionType::new(right_types).into())
}
(left, right) => LuaType::Union(LuaUnionType::new(vec![left, right]).into()),
}
}

#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct LuaIntersectionType {
types: Vec<LuaType>,
Expand Down Expand Up @@ -754,6 +777,29 @@ impl From<LuaIntersectionType> for LuaType {
}
}

pub fn make_intersection(left_type: LuaType, right_type: LuaType) -> LuaType {
match (left_type, right_type) {
(left, right) if left == right => left,
(LuaType::Intersection(left_type_union), LuaType::Intersection(right_type_union)) => {
let mut left_types = left_type_union.into_types();
let right_types = right_type_union.into_types();
left_types.extend(right_types);
LuaType::Intersection(LuaIntersectionType::new(left_types).into())
}
(LuaType::Intersection(left_type_union), right) => {
let mut left_types = left_type_union.into_types();
left_types.push(right);
LuaType::Intersection(LuaIntersectionType::new(left_types).into())
}
(left, LuaType::Intersection(right_type_union)) => {
let mut right_types = right_type_union.into_types();
right_types.push(left);
LuaType::Intersection(LuaIntersectionType::new(right_types).into())
}
(left, right) => LuaType::Intersection(LuaIntersectionType::new(vec![left, right]).into()),
}
}

#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum LuaAliasCallKind {
KeyOf,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ mod test {
#[test]
fn test_1() {
let mut ws = VirtualWorkspace::new();
assert!(ws.check_code_for(
assert!(!ws.check_code_for(
DiagnosticCode::UndefinedField,
r#"
---@alias std.NotNull<T> T - ?
Expand All @@ -15,7 +15,7 @@ mod test {
---@return fun(tbl: any):int, std.NotNull<V>
function ipairs(t) end

---@type {[integer]: string|table}
---@type {[integer]: integer|table}
local a = {}

for i, extendsName in ipairs(a) do
Expand Down
76 changes: 51 additions & 25 deletions crates/emmylua_code_analysis/src/semantic/infer/infer_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -471,31 +471,35 @@ fn infer_union_member(
union_type: &LuaUnionType,
index_expr: LuaIndexMemberExpr,
) -> InferResult {
let mut member_types = Vec::new();
for sub_type in union_type.get_types() {
let mut member_type = LuaType::Unknown;
for member in union_type.get_types() {
if member == &LuaType::Nil {
// Allow inferring keys of nullable types.
continue;
}

let result = infer_member_by_member_key(
db,
cache,
sub_type,
member,
index_expr.clone(),
&mut InferGuard::new(),
);
match result {
Ok(typ) => {
if !typ.is_nil() {
member_types.push(typ);
}
member_type = TypeOps::Union.apply(db, &member_type, &typ);
}
Err(err) => {
return Err(err);
}
_ => {}
}
}

member_types.dedup();
match member_types.len() {
0 => Ok(LuaType::Nil),
1 => Ok(member_types[0].clone()),
_ => Ok(LuaType::Union(LuaUnionType::new(member_types).into())),
if member_type.is_unknown() {
return Err(InferFailReason::FieldDotFound);
}

Ok(member_type)
}

fn infer_intersection_member(
Expand All @@ -504,21 +508,31 @@ fn infer_intersection_member(
intersection_type: &LuaIntersectionType,
index_expr: LuaIndexMemberExpr,
) -> InferResult {
for member in intersection_type.get_types() {
match infer_member_by_member_key(
let mut member_types = Vec::new();
for sub_type in intersection_type.get_types() {
let result = infer_member_by_member_key(
db,
cache,
member,
sub_type,
index_expr.clone(),
&mut InferGuard::new(),
) {
Ok(ty) => return Ok(ty),
Err(InferFailReason::FieldDotFound) => continue,
Err(reason) => return Err(reason),
);
match result {
Ok(typ) => {
member_types.push(typ);
}
_ => {}
}
}

Err(InferFailReason::FieldDotFound)
member_types.dedup();
match member_types.len() {
0 => Err(InferFailReason::FieldDotFound),
1 => Ok(member_types[0].clone()),
_ => Ok(LuaType::Intersection(
LuaIntersectionType::new(member_types).into(),
)),
}
}

fn infer_generic_members_from_super_generics(
Expand Down Expand Up @@ -835,6 +849,10 @@ fn infer_member_by_index_union(
) -> InferResult {
let mut member_type = LuaType::Unknown;
for member in union.get_types() {
if member == &LuaType::Nil {
// Allow inferring keys of nullable types.
continue;
}
let result = infer_member_by_operator(
db,
cache,
Expand All @@ -846,7 +864,6 @@ fn infer_member_by_index_union(
Ok(typ) => {
member_type = TypeOps::Union.apply(db, &member_type, &typ);
}
Err(InferFailReason::FieldDotFound) => {}
Err(err) => {
return Err(err);
}
Expand All @@ -866,6 +883,7 @@ fn infer_member_by_index_intersection(
intersection: &LuaIntersectionType,
index_expr: LuaIndexMemberExpr,
) -> InferResult {
let mut member_types = Vec::new();
for member in intersection.get_types() {
match infer_member_by_operator(
db,
Expand All @@ -874,13 +892,21 @@ fn infer_member_by_index_intersection(
index_expr.clone(),
&mut InferGuard::new(),
) {
Ok(ty) => return Ok(ty),
Err(InferFailReason::FieldDotFound) => continue,
Err(reason) => return Err(reason),
Ok(typ) => {
member_types.push(typ);
}
_ => {}
}
}

Err(InferFailReason::FieldDotFound)
member_types.dedup();
match member_types.len() {
0 => Err(InferFailReason::FieldDotFound),
1 => Ok(member_types[0].clone()),
_ => Ok(LuaType::Intersection(
LuaIntersectionType::new(member_types).into(),
)),
}
}

fn infer_member_by_index_generic(
Expand Down
Loading