diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/infer_type.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/infer_type.rs index c5446b24..62c0df0b 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/infer_type.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/infer_type.rs @@ -12,10 +12,10 @@ use smol_str::SmolStr; use crate::{ db_index::{ AnalyzeError, LuaAliasCallType, LuaFunctionType, LuaGenericType, LuaIndexAccessKey, - LuaIntersectionType, LuaObjectType, LuaStringTplType, LuaTupleType, LuaType, LuaUnionType, + LuaObjectType, LuaStringTplType, LuaTupleType, LuaType, }, - DiagnosticCode, GenericTpl, InFiled, LuaAliasCallKind, LuaMultiLineUnion, LuaTypeDeclId, - TypeOps, VariadicType, + make_intersection, make_union, DiagnosticCode, GenericTpl, InFiled, LuaAliasCallKind, + LuaMultiLineUnion, LuaTypeDeclId, TypeOps, VariadicType, }; use super::{preprocess_description, DocAnalyzer}; @@ -289,53 +289,12 @@ fn infer_binary_type(analyzer: &mut DocAnalyzer, binary_type: &LuaDocBinaryType) if let Some(op) = binary_type.get_op_token() { match op.get_op() { - LuaTypeBinaryOperator::Union => match (left_type, right_type) { - (LuaType::Union(left_type_union), LuaType::Union(right_type_union)) => { - let mut left_types = left_type_union.into_types(); - let right_types = right_type_union.into_types(); - left_types.extend(right_types); - return LuaType::Union(LuaUnionType::new(left_types).into()); - } - (LuaType::Union(left_type_union), right) => { - let mut left_types = (*left_type_union).into_types(); - left_types.push(right); - return LuaType::Union(LuaUnionType::new(left_types).into()); - } - (left, LuaType::Union(right_type_union)) => { - let mut right_types = (*right_type_union).into_types(); - right_types.push(left); - return LuaType::Union(LuaUnionType::new(right_types).into()); - } - (left, right) => { - return LuaType::Union(LuaUnionType::new(vec![left, right]).into()); - } - }, - LuaTypeBinaryOperator::Intersection => match (left_type, right_type) { - ( - LuaType::Intersection(left_type_union), - LuaType::Intersection(right_type_union), - ) => { - let mut left_types = left_type_union.into_types(); - let right_types = right_type_union.into_types(); - left_types.extend(right_types); - return LuaType::Intersection(LuaIntersectionType::new(left_types).into()); - } - (LuaType::Intersection(left_type_union), right) => { - let mut left_types = left_type_union.into_types(); - left_types.push(right); - return LuaType::Intersection(LuaIntersectionType::new(left_types).into()); - } - (left, LuaType::Intersection(right_type_union)) => { - let mut right_types = right_type_union.into_types(); - right_types.push(left); - return LuaType::Intersection(LuaIntersectionType::new(right_types).into()); - } - (left, right) => { - return LuaType::Intersection( - LuaIntersectionType::new(vec![left, right]).into(), - ); - } - }, + LuaTypeBinaryOperator::Union => { + return make_union(left_type, right_type); + } + LuaTypeBinaryOperator::Intersection => { + return make_intersection(left_type, right_type); + } LuaTypeBinaryOperator::Extends => { return LuaType::Call( LuaAliasCallType::new( diff --git a/crates/emmylua_code_analysis/src/compilation/test/member_infer_test.rs b/crates/emmylua_code_analysis/src/compilation/test/member_infer_test.rs index 386f21bd..2e5ade71 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/member_infer_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/member_infer_test.rs @@ -102,4 +102,64 @@ mod test { assert_eq!(e_ty, LuaType::Integer); assert_eq!(f_ty, LuaType::Integer); } + + #[test] + fn test_issue_454_unions() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + --- @class A + --- @field a integer + --- @field c integer + --- @field d integer + + --- @class B + --- @field b integer + --- @field c integer + --- @field d string + + ab --- @type A | B + "#, + ); + + assert_eq!(ws.expr_ty("ab.a"), LuaType::Unknown); + assert_eq!(ws.expr_ty("ab.b"), LuaType::Unknown); + assert_eq!(ws.expr_ty("ab.c"), ws.ty("integer")); + assert_eq!(ws.expr_ty("ab.d"), ws.ty("integer | string")); + assert_eq!(ws.expr_ty("ab['a']"), LuaType::Unknown); + assert_eq!(ws.expr_ty("ab['b']"), LuaType::Unknown); + assert_eq!(ws.expr_ty("ab['c']"), ws.ty("integer")); + assert_eq!(ws.expr_ty("ab['d']"), ws.ty("integer | string")); + } + + #[test] + fn test_issue_454_intersections() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + --- @class A + --- @field a integer + --- @field c integer + --- @field d integer + + --- @class B + --- @field b integer + --- @field c integer + --- @field d string + + ab --- @type A & B + "#, + ); + + assert_eq!(ws.expr_ty("ab.a"), ws.ty("integer")); + assert_eq!(ws.expr_ty("ab.b"), ws.ty("integer")); + assert_eq!(ws.expr_ty("ab.c"), ws.ty("integer")); + assert_eq!(ws.expr_ty("ab.d"), ws.ty("integer & string")); + assert_eq!(ws.expr_ty("ab['a']"), ws.ty("integer")); + assert_eq!(ws.expr_ty("ab['b']"), ws.ty("integer")); + assert_eq!(ws.expr_ty("ab['c']"), ws.ty("integer")); + assert_eq!(ws.expr_ty("ab['d']"), ws.ty("integer & string")); + } } diff --git a/crates/emmylua_code_analysis/src/db_index/type/types.rs b/crates/emmylua_code_analysis/src/db_index/type/types.rs index 607580b7..d6a8b58e 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/types.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/types.rs @@ -725,6 +725,29 @@ impl From for LuaType { } } +pub fn make_union(left_type: LuaType, right_type: LuaType) -> LuaType { + match (left_type, right_type) { + (left, right) if left == right => left, + (LuaType::Union(left_type_union), LuaType::Union(right_type_union)) => { + let mut left_types = left_type_union.into_types(); + let right_types = right_type_union.into_types(); + left_types.extend(right_types); + LuaType::Union(LuaUnionType::new(left_types).into()) + } + (LuaType::Union(left_type_union), right) => { + let mut left_types = (*left_type_union).into_types(); + left_types.push(right); + LuaType::Union(LuaUnionType::new(left_types).into()) + } + (left, LuaType::Union(right_type_union)) => { + let mut right_types = (*right_type_union).into_types(); + right_types.push(left); + LuaType::Union(LuaUnionType::new(right_types).into()) + } + (left, right) => LuaType::Union(LuaUnionType::new(vec![left, right]).into()), + } +} + #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub struct LuaIntersectionType { types: Vec, @@ -754,6 +777,29 @@ impl From for LuaType { } } +pub fn make_intersection(left_type: LuaType, right_type: LuaType) -> LuaType { + match (left_type, right_type) { + (left, right) if left == right => left, + (LuaType::Intersection(left_type_union), LuaType::Intersection(right_type_union)) => { + let mut left_types = left_type_union.into_types(); + let right_types = right_type_union.into_types(); + left_types.extend(right_types); + LuaType::Intersection(LuaIntersectionType::new(left_types).into()) + } + (LuaType::Intersection(left_type_union), right) => { + let mut left_types = left_type_union.into_types(); + left_types.push(right); + LuaType::Intersection(LuaIntersectionType::new(left_types).into()) + } + (left, LuaType::Intersection(right_type_union)) => { + let mut right_types = right_type_union.into_types(); + right_types.push(left); + LuaType::Intersection(LuaIntersectionType::new(right_types).into()) + } + (left, right) => LuaType::Intersection(LuaIntersectionType::new(vec![left, right]).into()), + } +} + #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] pub enum LuaAliasCallKind { KeyOf, diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/undefined_field_test.rs b/crates/emmylua_code_analysis/src/diagnostic/test/undefined_field_test.rs index b2ea1e3a..d8106263 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/test/undefined_field_test.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/test/undefined_field_test.rs @@ -5,7 +5,7 @@ mod test { #[test] fn test_1() { let mut ws = VirtualWorkspace::new(); - assert!(ws.check_code_for( + assert!(!ws.check_code_for( DiagnosticCode::UndefinedField, r#" ---@alias std.NotNull T - ? @@ -15,7 +15,7 @@ mod test { ---@return fun(tbl: any):int, std.NotNull function ipairs(t) end - ---@type {[integer]: string|table} + ---@type {[integer]: integer|table} local a = {} for i, extendsName in ipairs(a) do diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_index.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_index.rs index 75578c8c..12e3292e 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_index.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_index.rs @@ -471,31 +471,35 @@ fn infer_union_member( union_type: &LuaUnionType, index_expr: LuaIndexMemberExpr, ) -> InferResult { - let mut member_types = Vec::new(); - for sub_type in union_type.get_types() { + let mut member_type = LuaType::Unknown; + for member in union_type.get_types() { + if member == &LuaType::Nil { + // Allow inferring keys of nullable types. + continue; + } + let result = infer_member_by_member_key( db, cache, - sub_type, + member, index_expr.clone(), &mut InferGuard::new(), ); match result { Ok(typ) => { - if !typ.is_nil() { - member_types.push(typ); - } + member_type = TypeOps::Union.apply(db, &member_type, &typ); + } + Err(err) => { + return Err(err); } - _ => {} } } - member_types.dedup(); - match member_types.len() { - 0 => Ok(LuaType::Nil), - 1 => Ok(member_types[0].clone()), - _ => Ok(LuaType::Union(LuaUnionType::new(member_types).into())), + if member_type.is_unknown() { + return Err(InferFailReason::FieldDotFound); } + + Ok(member_type) } fn infer_intersection_member( @@ -504,21 +508,31 @@ fn infer_intersection_member( intersection_type: &LuaIntersectionType, index_expr: LuaIndexMemberExpr, ) -> InferResult { - for member in intersection_type.get_types() { - match infer_member_by_member_key( + let mut member_types = Vec::new(); + for sub_type in intersection_type.get_types() { + let result = infer_member_by_member_key( db, cache, - member, + sub_type, index_expr.clone(), &mut InferGuard::new(), - ) { - Ok(ty) => return Ok(ty), - Err(InferFailReason::FieldDotFound) => continue, - Err(reason) => return Err(reason), + ); + match result { + Ok(typ) => { + member_types.push(typ); + } + _ => {} } } - Err(InferFailReason::FieldDotFound) + member_types.dedup(); + match member_types.len() { + 0 => Err(InferFailReason::FieldDotFound), + 1 => Ok(member_types[0].clone()), + _ => Ok(LuaType::Intersection( + LuaIntersectionType::new(member_types).into(), + )), + } } fn infer_generic_members_from_super_generics( @@ -835,6 +849,10 @@ fn infer_member_by_index_union( ) -> InferResult { let mut member_type = LuaType::Unknown; for member in union.get_types() { + if member == &LuaType::Nil { + // Allow inferring keys of nullable types. + continue; + } let result = infer_member_by_operator( db, cache, @@ -846,7 +864,6 @@ fn infer_member_by_index_union( Ok(typ) => { member_type = TypeOps::Union.apply(db, &member_type, &typ); } - Err(InferFailReason::FieldDotFound) => {} Err(err) => { return Err(err); } @@ -866,6 +883,7 @@ fn infer_member_by_index_intersection( intersection: &LuaIntersectionType, index_expr: LuaIndexMemberExpr, ) -> InferResult { + let mut member_types = Vec::new(); for member in intersection.get_types() { match infer_member_by_operator( db, @@ -874,13 +892,21 @@ fn infer_member_by_index_intersection( index_expr.clone(), &mut InferGuard::new(), ) { - Ok(ty) => return Ok(ty), - Err(InferFailReason::FieldDotFound) => continue, - Err(reason) => return Err(reason), + Ok(typ) => { + member_types.push(typ); + } + _ => {} } } - Err(InferFailReason::FieldDotFound) + member_types.dedup(); + match member_types.len() { + 0 => Err(InferFailReason::FieldDotFound), + 1 => Ok(member_types[0].clone()), + _ => Ok(LuaType::Intersection( + LuaIntersectionType::new(member_types).into(), + )), + } } fn infer_member_by_index_generic( diff --git a/crates/emmylua_code_analysis/src/semantic/member/infer_members.rs b/crates/emmylua_code_analysis/src/semantic/member/infer_members.rs index 32ab056d..b7ac7e88 100644 --- a/crates/emmylua_code_analysis/src/semantic/member/infer_members.rs +++ b/crates/emmylua_code_analysis/src/semantic/member/infer_members.rs @@ -1,8 +1,6 @@ -use std::collections::HashSet; - -use smol_str::SmolStr; - +use super::{get_buildin_type_map_type_id, InferMembersResult, LuaMemberInfo}; use crate::{ + make_intersection, semantic::{ generic::{instantiate_type_generic, TypeSubstitutor}, InferGuard, @@ -11,8 +9,9 @@ use crate::{ LuaMemberOwner, LuaObjectType, LuaSemanticDeclId, LuaTupleType, LuaType, LuaTypeDeclId, LuaUnionType, }; - -use super::{get_buildin_type_map_type_id, InferMembersResult, LuaMemberInfo}; +use smol_str::SmolStr; +use std::collections::hash_map::Entry; +use std::collections::HashMap; pub fn infer_members(db: &DbIndex, prefix_type: &LuaType) -> InferMembersResult { infer_members_guard(db, prefix_type, &mut InferGuard::new()) @@ -174,15 +173,62 @@ fn infer_union_members( union_type: &LuaUnionType, infer_guard: &mut InferGuard, ) -> InferMembersResult { - let mut members = Vec::new(); - for typ in union_type.get_types().iter() { - let sub_members = infer_members_guard(db, typ, infer_guard); - if let Some(sub_members) = sub_members { - members.extend(sub_members); + let union_types = union_type.get_types(); + + if union_types.is_empty() { + return None; + } else if union_types.len() == 1 { + return infer_members_guard(db, &union_types[0], infer_guard); + } else { + let mut seen_keys = Vec::new(); + let mut members_per_key: HashMap> = HashMap::new(); + + for members in union_types + .iter() + .filter(|typ| typ != &&LuaType::Nil) + .filter_map(|typ| infer_members_guard(db, typ, infer_guard)) + { + for member in members { + let key = member.key.clone(); + let typ = member.typ.clone(); + + match members_per_key.entry(key) { + Entry::Occupied(mut entry) => { + entry.get_mut().push(typ); + } + Entry::Vacant(entry) => { + seen_keys.push(entry.key().clone()); + entry.insert(vec![typ]); + } + } + } } - } - Some(members) + return Some( + seen_keys + .into_iter() + .filter_map(|key| { + let types = members_per_key.get_mut(&key).unwrap(); + + if types.len() == union_types.len() { + // Member is present in all union types. + types.dedup(); + Some(LuaMemberInfo { + property_owner_id: None, + key, + typ: LuaType::Union(LuaUnionType::new(std::mem::take(types)).into()), + feature: None, + overload_index: None, + }) + } else { + // Member is absent in one of union types. Accessing it may result + // in error. + None + } + }) + .collect(), + ); + } } fn infer_intersection_members( @@ -190,40 +236,46 @@ fn infer_intersection_members( intersection_type: &LuaIntersectionType, infer_guard: &mut InferGuard, ) -> InferMembersResult { - let mut members = Vec::new(); - for typ in intersection_type.get_types().iter() { - let sub_members = infer_members_guard(db, typ, infer_guard); - if let Some(sub_members) = sub_members { - members.push(sub_members); - } - } + let intersection_types = intersection_type.get_types(); - if members.is_empty() { + if intersection_types.is_empty() { return None; - } else if members.len() == 1 { - return Some(members.remove(0)); + } else if intersection_types.len() == 1 { + return infer_members_guard(db, &intersection_types[0], infer_guard); } else { - let mut result = Vec::new(); - let mut member_set = HashSet::new(); + let mut member_set: HashMap = HashMap::new(); - for member in members.iter().flatten() { + for member in intersection_types + .iter() + .filter_map(|typ| infer_members_guard(db, typ, infer_guard)) + .flatten() + { let key = member.key.clone(); let typ = member.typ.clone(); - if member_set.contains(&key) { - continue; - } - member_set.insert(key.clone()); - result.push(LuaMemberInfo { - property_owner_id: None, - key, - typ, - feature: None, - overload_index: None, - }); + match member_set.entry(key) { + Entry::Occupied(mut entry) => { + let left_type = entry.get().clone(); + entry.insert(make_intersection(left_type, typ)); + } + Entry::Vacant(entry) => { + entry.insert(typ); + } + } } - return Some(result); + return Some( + member_set + .into_iter() + .map(|(key, typ)| LuaMemberInfo { + property_owner_id: None, + key, + typ, + feature: None, + overload_index: None, + }) + .collect(), + ); } }