From 04a2eec304c687e8098f4d1d0e10a57b988924f2 Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Wed, 18 Jun 2025 17:46:05 +0000 Subject: [PATCH] Cache instantiation of canonical binder --- .../src/infer/canonical/instantiate.rs | 189 ++++++++++++++++-- compiler/rustc_middle/src/ty/context.rs | 8 + 2 files changed, 177 insertions(+), 20 deletions(-) diff --git a/compiler/rustc_infer/src/infer/canonical/instantiate.rs b/compiler/rustc_infer/src/infer/canonical/instantiate.rs index 67f13192b5220..2385c68ef6bb7 100644 --- a/compiler/rustc_infer/src/infer/canonical/instantiate.rs +++ b/compiler/rustc_infer/src/infer/canonical/instantiate.rs @@ -7,8 +7,11 @@ //! [c]: https://rust-lang.github.io/chalk/book/canonical_queries/canonicalization.html use rustc_macros::extension; -use rustc_middle::bug; -use rustc_middle::ty::{self, FnMutDelegate, GenericArgKind, TyCtxt, TypeFoldable}; +use rustc_middle::ty::{ + self, DelayedMap, Ty, TyCtxt, TypeFoldable, TypeFolder, TypeSuperFoldable, TypeSuperVisitable, + TypeVisitableExt, TypeVisitor, +}; +use rustc_type_ir::TypeVisitable; use crate::infer::canonical::{Canonical, CanonicalVarValues}; @@ -58,23 +61,169 @@ where T: TypeFoldable>, { if var_values.var_values.is_empty() { - value - } else { - let delegate = FnMutDelegate { - regions: &mut |br: ty::BoundRegion| match var_values[br.var].kind() { - GenericArgKind::Lifetime(l) => l, - r => bug!("{:?} is a region but value is {:?}", br, r), - }, - types: &mut |bound_ty: ty::BoundTy| match var_values[bound_ty.var].kind() { - GenericArgKind::Type(ty) => ty, - r => bug!("{:?} is a type but value is {:?}", bound_ty, r), - }, - consts: &mut |bound_ct: ty::BoundVar| match var_values[bound_ct].kind() { - GenericArgKind::Const(ct) => ct, - c => bug!("{:?} is a const but value is {:?}", bound_ct, c), - }, - }; - - tcx.replace_escaping_bound_vars_uncached(value, delegate) + return value; } + + value.fold_with(&mut CanonicalInstantiator { + tcx, + current_index: ty::INNERMOST, + var_values: var_values.var_values, + cache: Default::default(), + }) +} + +/// Replaces the bound vars in a canonical binder with var values. +struct CanonicalInstantiator<'tcx> { + tcx: TyCtxt<'tcx>, + + // The values that the bound vars are are being instantiated with. + var_values: ty::GenericArgsRef<'tcx>, + + /// As with `BoundVarReplacer`, represents the index of a binder *just outside* + /// the ones we have visited. + current_index: ty::DebruijnIndex, + + // Instantiation is a pure function of `DebruijnIndex` and `Ty`. + cache: DelayedMap<(ty::DebruijnIndex, Ty<'tcx>), Ty<'tcx>>, +} + +impl<'tcx> TypeFolder> for CanonicalInstantiator<'tcx> { + fn cx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn fold_binder>>( + &mut self, + t: ty::Binder<'tcx, T>, + ) -> ty::Binder<'tcx, T> { + self.current_index.shift_in(1); + let t = t.super_fold_with(self); + self.current_index.shift_out(1); + t + } + + fn fold_ty(&mut self, t: Ty<'tcx>) -> Ty<'tcx> { + match *t.kind() { + ty::Bound(debruijn, bound_ty) if debruijn == self.current_index => { + self.var_values[bound_ty.var.as_usize()].expect_ty() + } + _ => { + if !t.has_vars_bound_at_or_above(self.current_index) { + t + } else if let Some(&t) = self.cache.get(&(self.current_index, t)) { + t + } else { + let res = t.super_fold_with(self); + assert!(self.cache.insert((self.current_index, t), res)); + res + } + } + } + } + + fn fold_region(&mut self, r: ty::Region<'tcx>) -> ty::Region<'tcx> { + match r.kind() { + ty::ReBound(debruijn, br) if debruijn == self.current_index => { + self.var_values[br.var.as_usize()].expect_region() + } + _ => r, + } + } + + fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> { + match ct.kind() { + ty::ConstKind::Bound(debruijn, bound_const) if debruijn == self.current_index => { + self.var_values[bound_const.as_usize()].expect_const() + } + _ => ct.super_fold_with(self), + } + } + + fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> { + if p.has_vars_bound_at_or_above(self.current_index) { p.super_fold_with(self) } else { p } + } + + fn fold_clauses(&mut self, c: ty::Clauses<'tcx>) -> ty::Clauses<'tcx> { + if !c.has_vars_bound_at_or_above(self.current_index) { + return c; + } + + // Since instantiation is a function of `DebruijnIndex`, we don't want + // to have to cache more copies of clauses when we're inside of binders. + // Since we currently expect to only have clauses in the outermost + // debruijn index, we just fold if we're inside of a binder. + if self.current_index > ty::INNERMOST { + return c.super_fold_with(self); + } + + // Our cache key is `(clauses, var_values)`, but we also don't care about + // var values that aren't named in the clauses, since they can change without + // affecting the output. Since `ParamEnv`s are cached first, we compute the + // last var value that is mentioned in the clauses, and cut off the list so + // that we have more hits in the cache. + + // We also cache the computation of "highest var named by clauses" since that + // is both expensive (depending on the size of the clauses) and a pure function. + let index = *self + .tcx + .highest_var_in_clauses_cache + .lock() + .entry(c) + .or_insert_with(|| highest_var_in_clauses(c)); + let c_args = &self.var_values[..=index]; + + if let Some(c) = self.tcx.clauses_cache.lock().get(&(c, c_args)) { + c + } else { + let folded = c.super_fold_with(self); + self.tcx.clauses_cache.lock().insert((c, c_args), folded); + folded + } + } +} + +fn highest_var_in_clauses<'tcx>(c: ty::Clauses<'tcx>) -> usize { + struct HighestVarInClauses { + max_var: usize, + current_index: ty::DebruijnIndex, + } + impl<'tcx> TypeVisitor> for HighestVarInClauses { + fn visit_binder>>( + &mut self, + t: &ty::Binder<'tcx, T>, + ) -> Self::Result { + self.current_index.shift_in(1); + let t = t.super_visit_with(self); + self.current_index.shift_out(1); + t + } + fn visit_ty(&mut self, t: Ty<'tcx>) { + if let ty::Bound(debruijn, bound_ty) = *t.kind() + && debruijn == self.current_index + { + self.max_var = self.max_var.max(bound_ty.var.as_usize()); + } else if t.has_vars_bound_at_or_above(self.current_index) { + t.super_visit_with(self); + } + } + fn visit_region(&mut self, r: ty::Region<'tcx>) { + if let ty::ReBound(debruijn, bound_region) = r.kind() + && debruijn == self.current_index + { + self.max_var = self.max_var.max(bound_region.var.as_usize()); + } + } + fn visit_const(&mut self, ct: ty::Const<'tcx>) { + if let ty::ConstKind::Bound(debruijn, bound_const) = ct.kind() + && debruijn == self.current_index + { + self.max_var = self.max_var.max(bound_const.as_usize()); + } else if ct.has_vars_bound_at_or_above(self.current_index) { + ct.super_visit_with(self); + } + } + } + let mut visitor = HighestVarInClauses { max_var: 0, current_index: ty::INNERMOST }; + c.visit_with(&mut visitor); + visitor.max_var } diff --git a/compiler/rustc_middle/src/ty/context.rs b/compiler/rustc_middle/src/ty/context.rs index f1395c242f271..12eb626a1f68d 100644 --- a/compiler/rustc_middle/src/ty/context.rs +++ b/compiler/rustc_middle/src/ty/context.rs @@ -1479,6 +1479,12 @@ pub struct GlobalCtxt<'tcx> { pub canonical_param_env_cache: CanonicalParamEnvCache<'tcx>, + /// Caches the index of the highest bound var in clauses in a canonical binder. + pub highest_var_in_clauses_cache: Lock, usize>>, + /// Caches the instantiation of a canonical binder given a set of args. + pub clauses_cache: + Lock, &'tcx [ty::GenericArg<'tcx>]), ty::Clauses<'tcx>>>, + /// Data layout specification for the current target. pub data_layout: TargetDataLayout, @@ -1727,6 +1733,8 @@ impl<'tcx> TyCtxt<'tcx> { new_solver_evaluation_cache: Default::default(), new_solver_canonical_param_env_cache: Default::default(), canonical_param_env_cache: Default::default(), + highest_var_in_clauses_cache: Default::default(), + clauses_cache: Default::default(), data_layout, alloc_map: interpret::AllocMap::new(), current_gcx,