diff --git a/gcc/rust/typecheck/rust-autoderef.cc b/gcc/rust/typecheck/rust-autoderef.cc index 7d91156c0593..ae41663a8156 100644 --- a/gcc/rust/typecheck/rust-autoderef.cc +++ b/gcc/rust/typecheck/rust-autoderef.cc @@ -170,18 +170,20 @@ resolve_operator_overload_fn ( } } - bool have_implementation_for_lang_item = resolved_candidates.size () > 0; + auto selected_candidates + = MethodResolver::Select (resolved_candidates, lhs, {}); + bool have_implementation_for_lang_item = selected_candidates.size () > 0; if (!have_implementation_for_lang_item) return false; - if (resolved_candidates.size () > 1) + if (selected_candidates.size () > 1) { // no need to error out as we are just trying to see if there is a fit return false; } // Get the adjusted self - MethodCandidate candidate = *resolved_candidates.begin (); + MethodCandidate candidate = *selected_candidates.begin (); Adjuster adj (lhs); TyTy::BaseType *adjusted_self = adj.adjust_type (candidate.adjustments); diff --git a/gcc/rust/typecheck/rust-hir-dot-operator.cc b/gcc/rust/typecheck/rust-hir-dot-operator.cc index 323902947cfb..775a4a4db6f2 100644 --- a/gcc/rust/typecheck/rust-hir-dot-operator.cc +++ b/gcc/rust/typecheck/rust-hir-dot-operator.cc @@ -41,6 +41,45 @@ MethodResolver::Probe (TyTy::BaseType *receiver, return resolver.result; } +std::set +MethodResolver::Select (std::set &candidates, + TyTy::BaseType *receiver, + std::vector arguments) +{ + std::set selected; + for (auto &candidate : candidates) + { + TyTy::BaseType *candidate_type = candidate.candidate.ty; + rust_assert (candidate_type->get_kind () == TyTy::TypeKind::FNDEF); + TyTy::FnType &fn = *static_cast (candidate_type); + + // match the number of arguments + if (fn.num_params () != (arguments.size () + 1)) + continue; + + // match the arguments + bool failed = false; + for (size_t i = 0; i < arguments.size (); i++) + { + TyTy::BaseType *arg = arguments.at (i); + TyTy::BaseType *param = fn.get_params ().at (i + 1).second; + TyTy::BaseType *coerced + = try_coercion (0, TyTy::TyWithLocation (param), + TyTy::TyWithLocation (arg), Location ()); + if (coerced->get_kind () == TyTy::TypeKind::ERROR) + { + failed = true; + break; + } + } + + if (!failed) + selected.insert (candidate); + } + + return selected; +} + void MethodResolver::try_hook (const TyTy::BaseType &r) { diff --git a/gcc/rust/typecheck/rust-hir-dot-operator.h b/gcc/rust/typecheck/rust-hir-dot-operator.h index db04ad8a56fb..5451c13ccde6 100644 --- a/gcc/rust/typecheck/rust-hir-dot-operator.h +++ b/gcc/rust/typecheck/rust-hir-dot-operator.h @@ -57,6 +57,10 @@ class MethodResolver : private TypeCheckBase, protected AutoderefCycle Probe (TyTy::BaseType *receiver, const HIR::PathIdentSegment &segment_name, bool autoderef_flag = false); + static std::set + Select (std::set &candidates, TyTy::BaseType *receiver, + std::vector arguments); + static std::vector get_predicate_items ( const HIR::PathIdentSegment &segment_name, const TyTy::BaseType &receiver, const std::vector &specified_bounds); diff --git a/gcc/rust/typecheck/rust-hir-type-check-expr.cc b/gcc/rust/typecheck/rust-hir-type-check-expr.cc index 1db4cff420b7..641f314a0ce2 100644 --- a/gcc/rust/typecheck/rust-hir-type-check-expr.cc +++ b/gcc/rust/typecheck/rust-hir-type-check-expr.cc @@ -1618,11 +1618,17 @@ TypeCheckExpr::resolve_operator_overload ( } } - bool have_implementation_for_lang_item = resolved_candidates.size () > 0; + std::vector select_args = {}; + if (rhs != nullptr) + select_args = {rhs}; + auto selected_candidates + = MethodResolver::Select (resolved_candidates, lhs, select_args); + + bool have_implementation_for_lang_item = selected_candidates.size () > 0; if (!have_implementation_for_lang_item) return false; - if (resolved_candidates.size () > 1) + if (selected_candidates.size () > 1) { // mutliple candidates RichLocation r (expr.get_locus ()); @@ -1636,7 +1642,7 @@ TypeCheckExpr::resolve_operator_overload ( } // Get the adjusted self - MethodCandidate candidate = *resolved_candidates.begin (); + MethodCandidate candidate = *selected_candidates.begin (); Adjuster adj (lhs); TyTy::BaseType *adjusted_self = adj.adjust_type (candidate.adjustments); diff --git a/gcc/rust/typecheck/rust-type-util.cc b/gcc/rust/typecheck/rust-type-util.cc index 1c93c60f29c9..e9e2f7e1183c 100644 --- a/gcc/rust/typecheck/rust-type-util.cc +++ b/gcc/rust/typecheck/rust-type-util.cc @@ -231,6 +231,24 @@ coercion_site (HirId id, TyTy::TyWithLocation lhs, TyTy::TyWithLocation rhs, return coerced; } +TyTy::BaseType * +try_coercion (HirId id, TyTy::TyWithLocation lhs, TyTy::TyWithLocation rhs, + Location locus) +{ + TyTy::BaseType *expected = lhs.get_ty (); + TyTy::BaseType *expr = rhs.get_ty (); + + rust_debug ("try_coercion_site id={%u} expected={%s} expr={%s}", id, + expected->debug_str ().c_str (), expr->debug_str ().c_str ()); + + auto result = TypeCoercionRules::TryCoerce (expr, expected, locus, + true /*allow-autodref*/); + if (result.is_error ()) + return new TyTy::ErrorType (id); + + return result.tyty; +} + TyTy::BaseType * cast_site (HirId id, TyTy::TyWithLocation from, TyTy::TyWithLocation to, Location cast_locus) diff --git a/gcc/rust/typecheck/rust-type-util.h b/gcc/rust/typecheck/rust-type-util.h index d6f1b8cf2d10..9d3a59b2674a 100644 --- a/gcc/rust/typecheck/rust-type-util.h +++ b/gcc/rust/typecheck/rust-type-util.h @@ -45,6 +45,10 @@ TyTy::BaseType * coercion_site (HirId id, TyTy::TyWithLocation lhs, TyTy::TyWithLocation rhs, Location coercion_locus); +TyTy::BaseType * +try_coercion (HirId id, TyTy::TyWithLocation lhs, TyTy::TyWithLocation rhs, + Location coercion_locus); + TyTy::BaseType * cast_site (HirId id, TyTy::TyWithLocation from, TyTy::TyWithLocation to, Location cast_locus); diff --git a/gcc/testsuite/rust/compile/issue-2304.rs b/gcc/testsuite/rust/compile/issue-2304.rs new file mode 100644 index 000000000000..243cf100539a --- /dev/null +++ b/gcc/testsuite/rust/compile/issue-2304.rs @@ -0,0 +1,23 @@ +#[lang = "add"] +pub trait Add { + type Output; + + fn add(self, rhs: RHS) -> Self::Output; +} +macro_rules! add_impl { + ($($t:ty)*) => ($( + impl Add for $t { + type Output = $t; + + fn add(self, other: $t) -> $t { self + other } + } + )*) +} + +add_impl! { usize u8 u16 u32 u64 /*isize i8 i16 i32 i64*/ f32 f64 } + +pub fn test() { + let x: usize = 123; + let mut i = 0; + let _bug = i + x; +}