diff --git a/src/analyze.rs b/src/analyze.rs index 9339b1d..45f4755 100644 --- a/src/analyze.rs +++ b/src/analyze.rs @@ -164,6 +164,12 @@ enum DefTy<'tcx> { Deferred(DeferredDefTy<'tcx>), } +#[derive(Debug, Clone)] +struct BasicBlockDef { + ty: BasicBlockType, + has_precondition: bool, +} + #[derive(Debug, Clone, Default)] pub struct EnumDefs { defs: HashMap, @@ -213,7 +219,7 @@ pub struct Analyzer<'tcx> { /// Resulting CHC system. system: Rc>, - basic_blocks: HashMap>, + basic_blocks: HashMap>, def_ids: did_cache::DefIdCache<'tcx>, enum_defs: Rc>, @@ -483,18 +489,84 @@ impl<'tcx> Analyzer<'tcx> { ); } - pub fn register_basic_block_ty( + pub fn register_basic_block_ty_with_precondition( &mut self, def_id: LocalDefId, bb: BasicBlock, rty: BasicBlockType, ) { - tracing::debug!(def_id = ?def_id, ?bb, rty = %rty.display(), "register_basic_block_ty"); - self.basic_blocks.entry(def_id).or_default().insert(bb, rty); + self.register_basic_block_def( + def_id, + bb, + BasicBlockDef { + ty: rty, + has_precondition: true, + }, + ); + } + + pub fn register_basic_block_ty_without_precondition( + &mut self, + def_id: LocalDefId, + bb: BasicBlock, + rty: BasicBlockType, + ) { + self.register_basic_block_def( + def_id, + bb, + BasicBlockDef { + ty: rty, + has_precondition: false, + }, + ); + } + + fn register_basic_block_def(&mut self, def_id: LocalDefId, bb: BasicBlock, def: BasicBlockDef) { + tracing::debug!( + def_id = ?def_id, + ?bb, + rty = %def.ty.display(), + has_precondition = def.has_precondition, + "register_basic_block_def", + ); + self.basic_blocks.entry(def_id).or_default().insert(bb, def); + } + + pub fn register_basic_block_precondition( + &mut self, + def_id: LocalDefId, + bb: BasicBlock, + precondition: rty::Refinement, + ) { + let bb_def = &mut self + .basic_blocks + .get_mut(&def_id) + .unwrap() + .get_mut(&bb) + .unwrap(); + assert!( + !bb_def.has_precondition, + "precondition is already registered for basic block" + ); + bb_def.has_precondition = true; + bb_def.ty.set_precondition(precondition); } pub fn basic_block_ty(&self, def_id: LocalDefId, bb: BasicBlock) -> &BasicBlockType { - &self.basic_blocks[&def_id][&bb] + &self.basic_blocks[&def_id][&bb].ty + } + + pub fn basic_block_ty_with_precondition( + &self, + def_id: LocalDefId, + bb: BasicBlock, + ) -> &BasicBlockType { + let def = &self.basic_blocks[&def_id][&bb]; + assert!( + def.has_precondition, + "basic block does not have precondition" + ); + &def.ty } pub fn register_well_known_defs(&mut self) { diff --git a/src/analyze/basic_block.rs b/src/analyze/basic_block.rs index eb2e7cb..b50d132 100644 --- a/src/analyze/basic_block.rs +++ b/src/analyze/basic_block.rs @@ -24,6 +24,113 @@ mod drop_point; mod visitor; pub use drop_point::DropPoints; +/// Whether a basic block needs a precondition of its own, rather than +/// inheriting its predecessor's outgoing env state. +/// +/// This holds for `START_BLOCK` (whose precondition comes from the function +/// signature, not a predecessor) and for every block reached by more than one +/// CFG edge — i.e. join points with multiple predecessors, or multiple edges +/// from a single predecessor (e.g. `SwitchInt` arms that share a target). +/// +/// A block with a unique incoming edge can inherit that edge's env state, so it +/// needs no precondition of its own. A block that does need one currently models +/// it with a fresh predicate variable; this is also the set of CFG cutpoints, so +/// it cuts every cycle (a loop header always has in-degree >= 2). +pub fn needs_own_precondition(body: &Body<'_>, bb: BasicBlock) -> bool { + if bb == mir::START_BLOCK { + return true; + } + let preds = &body.basic_blocks.predecessors()[bb]; + if preds.len() != 1 { + return true; + } + let pred = preds[0]; + let pred_term = body.basic_blocks[pred].terminator(); + pred_term.successors().filter(|s| *s == bb).count() > 1 +} + +/// Converts the current env state into a `Refinement` to be +/// used as the inherited precondition of a successor block. +/// +/// Each [`push`](PrecondCapture::push) records that a function parameter (or the +/// last param's value) equals an env-side [`PlaceType`]; +/// [`push_env_state`](PrecondCapture::push_env_state) folds in the env's refined +/// vars and accumulated assumptions; [`finish`](PrecondCapture::finish) +/// existentially closes over the env-side variables and emits the refinement. +/// +/// This is the focused successor of the former `UnbindAtoms`, specialized to the +/// only case that still needs it: capturing a predecessor's env into a goto +/// target's precondition. +#[derive(Default)] +struct PrecondCapture { + existentials: IndexVec, + body: chc::Body>, + target_equations: Vec<( + rty::RefinedTypeVar, + chc::Term>, + )>, +} + +impl PrecondCapture { + fn push(&mut self, target: rty::RefinedTypeVar, var_ty: PlaceType) { + self.body.push_conj( + var_ty + .formula + .map_var(|v| v.shift_existential(self.existentials.len()).into()), + ); + self.target_equations.push(( + target, + var_ty + .term + .map_var(|v| v.shift_existential(self.existentials.len()).into()), + )); + self.existentials.extend(var_ty.existentials); + } + + fn push_env_state(&mut self, env: &analyze::Env) { + for (var, rty) in env.vars() { + let base = self.existentials.len(); + self.existentials + .extend(rty.refinement.existentials.iter().cloned()); + let body = rty.refinement.body.clone().map_var(|v| match v { + rty::RefinedTypeVar::Value => rty::RefinedTypeVar::Free(var), + rty::RefinedTypeVar::Free(v) => rty::RefinedTypeVar::Free(v), + rty::RefinedTypeVar::Existential(ev) => rty::RefinedTypeVar::Existential(ev + base), + }); + self.body.push_conj(body); + } + for assumption in env.assumptions() { + let base = self.existentials.len(); + self.existentials + .extend(assumption.existentials.iter().cloned()); + let body = assumption.body.clone().map_var(|v| match v { + PlaceTypeVar::Var(v) => rty::RefinedTypeVar::Free(v), + PlaceTypeVar::Existential(ev) => rty::RefinedTypeVar::Existential(ev + base), + }); + self.body.push_conj(body); + } + } + + fn finish(mut self, env: &analyze::Env) -> rty::Refinement { + let mut substs = HashMap::new(); + for (v, sort) in env.dependencies() { + let ev = self.existentials.push(sort); + substs.insert(v, ev); + } + + let map = |v| match v { + rty::RefinedTypeVar::Value => rty::RefinedTypeVar::Value, + rty::RefinedTypeVar::Free(v) => rty::RefinedTypeVar::Existential(substs[&v]), + rty::RefinedTypeVar::Existential(ev) => rty::RefinedTypeVar::Existential(ev), + }; + let mut body = self.body.map_var(map); + for (t, term) in self.target_equations { + body.push_conj(chc::Term::var(t).equal_to(term.map_var(map))); + } + rty::Refinement::new(self.existentials, body) + } +} + pub struct Analyzer<'tcx, 'ctx> { ctx: &'ctx mut analyze::Analyzer<'tcx>, tcx: TyCtxt<'tcx>, @@ -61,8 +168,9 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { visitor::RustCallVisitor::new(self) } - fn basic_block_ty(&self, bb: BasicBlock) -> &BasicBlockType { - self.ctx.basic_block_ty(self.local_def_id, bb) + fn basic_block_ty_with_precondition(&self, bb: BasicBlock) -> &BasicBlockType { + self.ctx + .basic_block_ty_with_precondition(self.local_def_id, bb) } fn bind_local(&mut self, local: Local, rty: rty::RefinedType) { @@ -580,7 +688,11 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { bb: BasicBlock, outer_fn_param_vars: &HashMap, ) { - let bty = self.basic_block_ty(bb); + if !needs_own_precondition(&self.body, bb) { + self.install_inherited_bb_ty(bb, outer_fn_param_vars); + return; + } + let bty = self.basic_block_ty_with_precondition(bb); let expected_args: IndexVec<_, _> = bty .as_ref() .params @@ -614,6 +726,41 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { self.ctx.extend_clauses(clauses); } + /// Materializes the `BasicBlockType` for a target that inherits its + /// precondition by building its (pvar-free) layout and overwriting the last + /// param's refinement with the current env state. + fn install_inherited_bb_ty( + &mut self, + bb: BasicBlock, + outer_fn_param_vars: &HashMap, + ) { + let bty = self.ctx.basic_block_ty(self.local_def_id, bb); + + let mut capture = PrecondCapture::default(); + for (param_idx, param_rty) in bty.as_ref().params.iter_enumerated() { + if param_rty.ty.to_sort().is_singleton() { + continue; + } + let pty = match bty.param_kind(param_idx) { + BasicBlockTypeParamKind::Local(local, _) => self.env.local_type(local), + BasicBlockTypeParamKind::OuterFnParam(outer_idx) => { + let outer_var = outer_fn_param_vars[&outer_idx]; + PlaceType::with_ty_and_term( + param_rty.ty.clone().assert_closed().vacuous(), + chc::Term::var(outer_var), + ) + } + BasicBlockTypeParamKind::Synthetic => continue, + }; + capture.push(rty::RefinedTypeVar::Free(param_idx), pty); + } + capture.push_env_state(&self.env); + let precondition = capture.finish(&self.env); + + self.ctx + .register_basic_block_precondition(self.local_def_id, bb, precondition); + } + fn with_assumptions(&mut self, assumptions: Vec>, callback: F) -> T where F: FnOnce(&mut Self) -> T, @@ -1078,7 +1225,9 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let mut outer_fn_param_vars = HashMap::new(); - let bb_ty = self.basic_block_ty(self.basic_block).clone(); + let bb_ty = self + .basic_block_ty_with_precondition(self.basic_block) + .clone(); let params = &bb_ty.as_ref().params; assert!(!params.is_empty()); for (param_idx, param_rty) in params.iter_enumerated() { diff --git a/src/analyze/local_def.rs b/src/analyze/local_def.rs index 1b05008..9f7f57e 100644 --- a/src/analyze/local_def.rs +++ b/src/analyze/local_def.rs @@ -851,18 +851,38 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { } // function return type is basic block return type let ret_ty = self.body.local_decls[mir::RETURN_PLACE].ty; - let rty = self - .type_builder - .for_template(&mut self.ctx) - .build_basic_block(&self.body, live_locals, ret_ty); - self.ctx.register_basic_block_ty(self.local_def_id, bb, rty); + if analyze::basic_block::needs_own_precondition(&self.body, bb) { + let bty = self + .type_builder + .for_template(&mut self.ctx) + .build_basic_block(&self.body, live_locals, ret_ty); + self.ctx + .register_basic_block_ty_with_precondition(self.local_def_id, bb, bty); + } else { + // The block inherits its predecessor's outgoing env state as its + // precondition, materialized lazily during the predecessor's + // analysis. Record only unrefined type here. + let bty = self + .type_builder + .build_basic_block(&self.body, live_locals, ret_ty); + self.ctx + .register_basic_block_ty_without_precondition(self.local_def_id, bb, bty); + }; } } fn analyze_basic_blocks(&mut self, expected_fn_ty: &rty::RefinedType) { let expected_fn_ty = expected_fn_ty.ty.as_function().unwrap(); - for bb in self.body.basic_blocks.indices() { - let rty = self.ctx.basic_block_ty(self.local_def_id, bb).clone(); + // Reverse postorder guarantees each block that inherits its precondition + // is visited after the predecessor that lazily materialized its type. + for (bb, data) in mir::traversal::reverse_postorder(&self.body) { + if data.is_cleanup { + continue; + } + let rty = self + .ctx + .basic_block_ty_with_precondition(self.local_def_id, bb) + .clone(); let drop_points = self.drop_points[&bb].clone(); self.ctx .basic_block_analyzer(self.local_def_id, bb) @@ -972,7 +992,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { fn assert_entry(&mut self, expected: &rty::RefinedType) { let mut entry_ty = self .ctx - .basic_block_ty(self.local_def_id, mir::START_BLOCK) + .basic_block_ty_with_precondition(self.local_def_id, mir::START_BLOCK) .clone(); tracing::debug!(expected = %expected.display(), entry = %entry_ty.display(), "assert_entry before"); let mut expected = expected.ty.as_function().cloned().unwrap(); diff --git a/src/refine/basic_block.rs b/src/refine/basic_block.rs index ca6e6b5..ca217c8 100644 --- a/src/refine/basic_block.rs +++ b/src/refine/basic_block.rs @@ -126,6 +126,17 @@ impl BasicBlockType { self.ty.clone() } + pub fn set_precondition(&mut self, refinement: rty::Refinement) { + let last_param_idx = self.ty.params.last_index().unwrap(); + self.ty.params.raw.last_mut().unwrap().refinement = refinement.map_var(|v| { + if v == rty::RefinedTypeVar::Free(last_param_idx) { + rty::RefinedTypeVar::Value + } else { + v + } + }); + } + /// Inner function type of BasicBlockType contains extra parameters that carry original /// function parameter values. `truncate_outer_fn_params` removes these extra parameters /// to subtype output of [`BasicBlockType::to_function_ty`] against the function type. diff --git a/src/refine/env.rs b/src/refine/env.rs index 536e838..d9c3dfb 100644 --- a/src/refine/env.rs +++ b/src/refine/env.rs @@ -902,7 +902,7 @@ where ) } - fn vars(&self) -> impl Iterator)> + '_ { + pub fn vars(&self) -> impl Iterator)> + '_ { self.locals .iter() .map(|(local, rty)| (Var::Local(*local), rty)) @@ -914,6 +914,10 @@ where .filter(|(_var, rty)| rty.is_refined()) } + pub fn assumptions(&self) -> &[Assumption] { + &self.assumptions + } + pub fn contains_local(&self, local: Local) -> bool { self.locals.contains_key(&local) || self.flow_locals.contains_key(&local) } diff --git a/src/refine/template.rs b/src/refine/template.rs index 5204094..ed0762e 100644 --- a/src/refine/template.rs +++ b/src/refine/template.rs @@ -262,6 +262,30 @@ impl<'tcx> TypeBuilder<'tcx> { } } + pub fn build_basic_block( + &mut self, + body: &rustc_middle::mir::Body<'tcx>, + live_locals: I, + ret_ty: mir_ty::Ty<'tcx>, + ) -> BasicBlockType + where + I: IntoIterator)>, + { + struct FakeRegistry; + impl TemplateRegistry for FakeRegistry { + fn register_template(&mut self, _tmpl: rty::Template) -> rty::RefinedType { + panic!("unexpected template registration") + } + } + self.for_template(&mut FakeRegistry) + .build_basic_block_with_precondition( + body, + live_locals.into_iter().collect(), + ret_ty, + Some(rty::Refinement::top()), + ) + } + pub fn for_template<'a, R>( &self, registry: &'a mut R, @@ -424,17 +448,14 @@ where self.registry.register_template(tmpl) } - pub fn build_basic_block( + fn build_basic_block_with_precondition( &mut self, body: &rustc_middle::mir::Body<'tcx>, - live_locals: I, + mut live_locals: Vec<(Local, mir_ty::TypeAndMut<'tcx>)>, ret_ty: mir_ty::Ty<'tcx>, - ) -> BasicBlockType - where - I: IntoIterator)>, - { + precondition: Option>, + ) -> BasicBlockType { // this is necessary for local_def::Analyzer::elaborate_unused_args - let mut live_locals: Vec<_> = live_locals.into_iter().collect(); live_locals.sort_by_key(|(local, _)| *local); let mut locals = IndexVec::::new(); @@ -458,7 +479,7 @@ where param_tys: tys, ret_ty, param_rtys: Default::default(), - param_refinement: None, + param_refinement: precondition, // not generating pvar of BB post ret_rty: Some(rty::RefinedType::unrefined( self.inner.build(ret_ty).vacuous(), @@ -472,6 +493,23 @@ where outer_fn_param_count: body.arg_count, } } + + pub fn build_basic_block( + &mut self, + body: &rustc_middle::mir::Body<'tcx>, + live_locals: I, + ret_ty: mir_ty::Ty<'tcx>, + ) -> BasicBlockType + where + I: IntoIterator)>, + { + self.build_basic_block_with_precondition( + body, + live_locals.into_iter().collect(), + ret_ty, + None, + ) + } } /// A builder for function template types.