From 301902751b0e0083cff551b115441959b45b1a41 Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 20 May 2026 17:08:56 +0000 Subject: [PATCH] Allocate a precondition pvar only for blocks that need one A basic block needs a precondition of its own only if it is START_BLOCK or is reached by more than one CFG edge (multiple predecessors, or multiple edges from a single predecessor, e.g. SwitchInt arms that share a target). Every other block has a unique incoming edge, so its precondition is exactly the predecessor's outgoing env state and needs no fresh predicate variable. (This is also the set of CFG cutpoints, so it still cuts every cycle.) - needs_own_precondition(bb) decides this. refine_basic_blocks builds each block's layout eagerly: blocks that need their own precondition get a pvar (build_basic_block via for_template) and has_precondition = true; the rest are built with a top precondition (no pvar, guarded by a FakeRegistry) and has_precondition = false. - At type_goto(target): if the target needs its own precondition, emit the usual subtyping clauses against its pvar. Otherwise install_inherited_bb_ty fills the target's precondition from the current env state (PrecondCapture, the focused successor of the removed UnbindAtoms) via register_basic_block_precondition, which flips has_precondition = true. No subtyping clause is emitted on the edge. - analyze_basic_blocks visits blocks in reverse postorder so each inheriting block is analyzed after the predecessor that materialized its precondition; cleanup/unwind blocks are skipped. basic_block_ty_with_precondition asserts the precondition was materialized. This shrinks the CHC system on every function with straight-line block sequences. 
https://claude.ai/code/session_01WB28auaD8dSQrckqBwJWBt Co-Authored-By: coord_e --- src/analyze.rs | 82 +++++++++++++++++-- src/analyze/basic_block.rs | 157 ++++++++++++++++++++++++++++++++++++- src/analyze/local_def.rs | 36 +++++++-- src/refine/basic_block.rs | 11 +++ src/refine/env.rs | 6 +- src/refine/template.rs | 54 +++++++++++-- 6 files changed, 320 insertions(+), 26 deletions(-) diff --git a/src/analyze.rs b/src/analyze.rs index 9339b1d..45f4755 100644 --- a/src/analyze.rs +++ b/src/analyze.rs @@ -164,6 +164,12 @@ enum DefTy<'tcx> { Deferred(DeferredDefTy<'tcx>), } +#[derive(Debug, Clone)] +struct BasicBlockDef { + ty: BasicBlockType, + has_precondition: bool, +} + #[derive(Debug, Clone, Default)] pub struct EnumDefs { defs: HashMap, @@ -213,7 +219,7 @@ pub struct Analyzer<'tcx> { /// Resulting CHC system. system: Rc>, - basic_blocks: HashMap>, + basic_blocks: HashMap>, def_ids: did_cache::DefIdCache<'tcx>, enum_defs: Rc>, @@ -483,18 +489,84 @@ impl<'tcx> Analyzer<'tcx> { ); } - pub fn register_basic_block_ty( + pub fn register_basic_block_ty_with_precondition( &mut self, def_id: LocalDefId, bb: BasicBlock, rty: BasicBlockType, ) { - tracing::debug!(def_id = ?def_id, ?bb, rty = %rty.display(), "register_basic_block_ty"); - self.basic_blocks.entry(def_id).or_default().insert(bb, rty); + self.register_basic_block_def( + def_id, + bb, + BasicBlockDef { + ty: rty, + has_precondition: true, + }, + ); + } + + pub fn register_basic_block_ty_without_precondition( + &mut self, + def_id: LocalDefId, + bb: BasicBlock, + rty: BasicBlockType, + ) { + self.register_basic_block_def( + def_id, + bb, + BasicBlockDef { + ty: rty, + has_precondition: false, + }, + ); + } + + fn register_basic_block_def(&mut self, def_id: LocalDefId, bb: BasicBlock, def: BasicBlockDef) { + tracing::debug!( + def_id = ?def_id, + ?bb, + rty = %def.ty.display(), + has_precondition = def.has_precondition, + "register_basic_block_def", + ); + self.basic_blocks.entry(def_id).or_default().insert(bb, 
def); + } + + pub fn register_basic_block_precondition( + &mut self, + def_id: LocalDefId, + bb: BasicBlock, + precondition: rty::Refinement, + ) { + let bb_def = &mut self + .basic_blocks + .get_mut(&def_id) + .unwrap() + .get_mut(&bb) + .unwrap(); + assert!( + !bb_def.has_precondition, + "precondition is already registered for basic block" + ); + bb_def.has_precondition = true; + bb_def.ty.set_precondition(precondition); } pub fn basic_block_ty(&self, def_id: LocalDefId, bb: BasicBlock) -> &BasicBlockType { - &self.basic_blocks[&def_id][&bb] + &self.basic_blocks[&def_id][&bb].ty + } + + pub fn basic_block_ty_with_precondition( + &self, + def_id: LocalDefId, + bb: BasicBlock, + ) -> &BasicBlockType { + let def = &self.basic_blocks[&def_id][&bb]; + assert!( + def.has_precondition, + "basic block does not have precondition" + ); + &def.ty } pub fn register_well_known_defs(&mut self) { diff --git a/src/analyze/basic_block.rs b/src/analyze/basic_block.rs index eb2e7cb..b50d132 100644 --- a/src/analyze/basic_block.rs +++ b/src/analyze/basic_block.rs @@ -24,6 +24,113 @@ mod drop_point; mod visitor; pub use drop_point::DropPoints; +/// Whether a basic block needs a precondition of its own, rather than +/// inheriting its predecessor's outgoing env state. +/// +/// This holds for `START_BLOCK` (whose precondition comes from the function +/// signature, not a predecessor) and for every block reached by more than one +/// CFG edge — i.e. join points with multiple predecessors, or multiple edges +/// from a single predecessor (e.g. `SwitchInt` arms that share a target). +/// +/// A block with a unique incoming edge can inherit that edge's env state, so it +/// needs no precondition of its own. A block that does need one currently models +/// it with a fresh predicate variable; this is also the set of CFG cutpoints, so +/// it cuts every cycle (a loop header always has in-degree >= 2). 
+pub fn needs_own_precondition(body: &Body<'_>, bb: BasicBlock) -> bool { + if bb == mir::START_BLOCK { + return true; + } + let preds = &body.basic_blocks.predecessors()[bb]; + if preds.len() != 1 { + return true; + } + let pred = preds[0]; + let pred_term = body.basic_blocks[pred].terminator(); + pred_term.successors().filter(|s| *s == bb).count() > 1 +} + +/// Converts the current env state into a `Refinement` to be +/// used as the inherited precondition of a successor block. +/// +/// Each [`push`](PrecondCapture::push) records that a function parameter (or the +/// last param's value) equals an env-side [`PlaceType`]; +/// [`push_env_state`](PrecondCapture::push_env_state) folds in the env's refined +/// vars and accumulated assumptions; [`finish`](PrecondCapture::finish) +/// existentially closes over the env-side variables and emits the refinement. +/// +/// This is the focused successor of the former `UnbindAtoms`, specialized to the +/// only case that still needs it: capturing a predecessor's env into a goto +/// target's precondition. 
+#[derive(Default)] +struct PrecondCapture { + existentials: IndexVec, + body: chc::Body>, + target_equations: Vec<( + rty::RefinedTypeVar, + chc::Term>, + )>, +} + +impl PrecondCapture { + fn push(&mut self, target: rty::RefinedTypeVar, var_ty: PlaceType) { + self.body.push_conj( + var_ty + .formula + .map_var(|v| v.shift_existential(self.existentials.len()).into()), + ); + self.target_equations.push(( + target, + var_ty + .term + .map_var(|v| v.shift_existential(self.existentials.len()).into()), + )); + self.existentials.extend(var_ty.existentials); + } + + fn push_env_state(&mut self, env: &analyze::Env) { + for (var, rty) in env.vars() { + let base = self.existentials.len(); + self.existentials + .extend(rty.refinement.existentials.iter().cloned()); + let body = rty.refinement.body.clone().map_var(|v| match v { + rty::RefinedTypeVar::Value => rty::RefinedTypeVar::Free(var), + rty::RefinedTypeVar::Free(v) => rty::RefinedTypeVar::Free(v), + rty::RefinedTypeVar::Existential(ev) => rty::RefinedTypeVar::Existential(ev + base), + }); + self.body.push_conj(body); + } + for assumption in env.assumptions() { + let base = self.existentials.len(); + self.existentials + .extend(assumption.existentials.iter().cloned()); + let body = assumption.body.clone().map_var(|v| match v { + PlaceTypeVar::Var(v) => rty::RefinedTypeVar::Free(v), + PlaceTypeVar::Existential(ev) => rty::RefinedTypeVar::Existential(ev + base), + }); + self.body.push_conj(body); + } + } + + fn finish(mut self, env: &analyze::Env) -> rty::Refinement { + let mut substs = HashMap::new(); + for (v, sort) in env.dependencies() { + let ev = self.existentials.push(sort); + substs.insert(v, ev); + } + + let map = |v| match v { + rty::RefinedTypeVar::Value => rty::RefinedTypeVar::Value, + rty::RefinedTypeVar::Free(v) => rty::RefinedTypeVar::Existential(substs[&v]), + rty::RefinedTypeVar::Existential(ev) => rty::RefinedTypeVar::Existential(ev), + }; + let mut body = self.body.map_var(map); + for (t, term) in 
self.target_equations { + body.push_conj(chc::Term::var(t).equal_to(term.map_var(map))); + } + rty::Refinement::new(self.existentials, body) + } +} + pub struct Analyzer<'tcx, 'ctx> { ctx: &'ctx mut analyze::Analyzer<'tcx>, tcx: TyCtxt<'tcx>, @@ -61,8 +168,9 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { visitor::RustCallVisitor::new(self) } - fn basic_block_ty(&self, bb: BasicBlock) -> &BasicBlockType { - self.ctx.basic_block_ty(self.local_def_id, bb) + fn basic_block_ty_with_precondition(&self, bb: BasicBlock) -> &BasicBlockType { + self.ctx + .basic_block_ty_with_precondition(self.local_def_id, bb) } fn bind_local(&mut self, local: Local, rty: rty::RefinedType) { @@ -580,7 +688,11 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { bb: BasicBlock, outer_fn_param_vars: &HashMap, ) { - let bty = self.basic_block_ty(bb); + if !needs_own_precondition(&self.body, bb) { + self.install_inherited_bb_ty(bb, outer_fn_param_vars); + return; + } + let bty = self.basic_block_ty_with_precondition(bb); let expected_args: IndexVec<_, _> = bty .as_ref() .params @@ -614,6 +726,41 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { self.ctx.extend_clauses(clauses); } + /// Materializes the precondition of a target that inherits it: takes the + /// target's eagerly built (pvar-free) layout and overwrites the last param's + /// refinement with the current env state. 
+ fn install_inherited_bb_ty( + &mut self, + bb: BasicBlock, + outer_fn_param_vars: &HashMap, + ) { + let bty = self.ctx.basic_block_ty(self.local_def_id, bb); + + let mut capture = PrecondCapture::default(); + for (param_idx, param_rty) in bty.as_ref().params.iter_enumerated() { + if param_rty.ty.to_sort().is_singleton() { + continue; + } + let pty = match bty.param_kind(param_idx) { + BasicBlockTypeParamKind::Local(local, _) => self.env.local_type(local), + BasicBlockTypeParamKind::OuterFnParam(outer_idx) => { + let outer_var = outer_fn_param_vars[&outer_idx]; + PlaceType::with_ty_and_term( + param_rty.ty.clone().assert_closed().vacuous(), + chc::Term::var(outer_var), + ) + } + BasicBlockTypeParamKind::Synthetic => continue, + }; + capture.push(rty::RefinedTypeVar::Free(param_idx), pty); + } + capture.push_env_state(&self.env); + let precondition = capture.finish(&self.env); + + self.ctx + .register_basic_block_precondition(self.local_def_id, bb, precondition); + } + fn with_assumptions(&mut self, assumptions: Vec>, callback: F) -> T where F: FnOnce(&mut Self) -> T, @@ -1078,7 +1225,9 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let mut outer_fn_param_vars = HashMap::new(); - let bb_ty = self.basic_block_ty(self.basic_block).clone(); + let bb_ty = self + .basic_block_ty_with_precondition(self.basic_block) + .clone(); let params = &bb_ty.as_ref().params; assert!(!params.is_empty()); for (param_idx, param_rty) in params.iter_enumerated() { diff --git a/src/analyze/local_def.rs b/src/analyze/local_def.rs index 1b05008..9f7f57e 100644 --- a/src/analyze/local_def.rs +++ b/src/analyze/local_def.rs @@ -851,18 +851,38 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { } // function return type is basic block return type let ret_ty = self.body.local_decls[mir::RETURN_PLACE].ty; - let rty = self - .type_builder - .for_template(&mut self.ctx) - .build_basic_block(&self.body, live_locals, ret_ty); - self.ctx.register_basic_block_ty(self.local_def_id, bb, rty); + if 
analyze::basic_block::needs_own_precondition(&self.body, bb) { + let bty = self + .type_builder + .for_template(&mut self.ctx) + .build_basic_block(&self.body, live_locals, ret_ty); + self.ctx + .register_basic_block_ty_with_precondition(self.local_def_id, bb, bty); + } else { + // The block inherits its predecessor's outgoing env state as its + // precondition, materialized lazily during the predecessor's + // analysis. Record only the unrefined type here. + let bty = self + .type_builder + .build_basic_block(&self.body, live_locals, ret_ty); + self.ctx + .register_basic_block_ty_without_precondition(self.local_def_id, bb, bty); + }; } } fn analyze_basic_blocks(&mut self, expected_fn_ty: &rty::RefinedType) { let expected_fn_ty = expected_fn_ty.ty.as_function().unwrap(); - for bb in self.body.basic_blocks.indices() { - let rty = self.ctx.basic_block_ty(self.local_def_id, bb).clone(); + // Reverse postorder guarantees each block that inherits its precondition + // is visited after the predecessor that materialized that precondition. 
+ for (bb, data) in mir::traversal::reverse_postorder(&self.body) { + if data.is_cleanup { + continue; + } + let rty = self + .ctx + .basic_block_ty_with_precondition(self.local_def_id, bb) + .clone(); let drop_points = self.drop_points[&bb].clone(); self.ctx .basic_block_analyzer(self.local_def_id, bb) @@ -972,7 +992,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { fn assert_entry(&mut self, expected: &rty::RefinedType) { let mut entry_ty = self .ctx - .basic_block_ty(self.local_def_id, mir::START_BLOCK) + .basic_block_ty_with_precondition(self.local_def_id, mir::START_BLOCK) .clone(); tracing::debug!(expected = %expected.display(), entry = %entry_ty.display(), "assert_entry before"); let mut expected = expected.ty.as_function().cloned().unwrap(); diff --git a/src/refine/basic_block.rs b/src/refine/basic_block.rs index ca6e6b5..ca217c8 100644 --- a/src/refine/basic_block.rs +++ b/src/refine/basic_block.rs @@ -126,6 +126,17 @@ impl BasicBlockType { self.ty.clone() } + pub fn set_precondition(&mut self, refinement: rty::Refinement) { + let last_param_idx = self.ty.params.last_index().unwrap(); + self.ty.params.raw.last_mut().unwrap().refinement = refinement.map_var(|v| { + if v == rty::RefinedTypeVar::Free(last_param_idx) { + rty::RefinedTypeVar::Value + } else { + v + } + }); + } + /// Inner function type of BasicBlockType contains extra parameters that carry original /// function parameter values. `truncate_outer_fn_params` removes these extra parameters /// to subtype output of [`BasicBlockType::to_function_ty`] against the function type. 
diff --git a/src/refine/env.rs b/src/refine/env.rs index 536e838..d9c3dfb 100644 --- a/src/refine/env.rs +++ b/src/refine/env.rs @@ -902,7 +902,7 @@ where ) } - fn vars(&self) -> impl Iterator)> + '_ { + pub fn vars(&self) -> impl Iterator)> + '_ { self.locals .iter() .map(|(local, rty)| (Var::Local(*local), rty)) @@ -914,6 +914,10 @@ where .filter(|(_var, rty)| rty.is_refined()) } + pub fn assumptions(&self) -> &[Assumption] { + &self.assumptions + } + pub fn contains_local(&self, local: Local) -> bool { self.locals.contains_key(&local) || self.flow_locals.contains_key(&local) } diff --git a/src/refine/template.rs b/src/refine/template.rs index 5204094..ed0762e 100644 --- a/src/refine/template.rs +++ b/src/refine/template.rs @@ -262,6 +262,30 @@ impl<'tcx> TypeBuilder<'tcx> { } } + pub fn build_basic_block( + &mut self, + body: &rustc_middle::mir::Body<'tcx>, + live_locals: I, + ret_ty: mir_ty::Ty<'tcx>, + ) -> BasicBlockType + where + I: IntoIterator)>, + { + struct FakeRegistry; + impl TemplateRegistry for FakeRegistry { + fn register_template(&mut self, _tmpl: rty::Template) -> rty::RefinedType { + panic!("unexpected template registration") + } + } + self.for_template(&mut FakeRegistry) + .build_basic_block_with_precondition( + body, + live_locals.into_iter().collect(), + ret_ty, + Some(rty::Refinement::top()), + ) + } + pub fn for_template<'a, R>( &self, registry: &'a mut R, @@ -424,17 +448,14 @@ where self.registry.register_template(tmpl) } - pub fn build_basic_block( + fn build_basic_block_with_precondition( &mut self, body: &rustc_middle::mir::Body<'tcx>, - live_locals: I, + mut live_locals: Vec<(Local, mir_ty::TypeAndMut<'tcx>)>, ret_ty: mir_ty::Ty<'tcx>, - ) -> BasicBlockType - where - I: IntoIterator)>, - { + precondition: Option>, + ) -> BasicBlockType { // this is necessary for local_def::Analyzer::elaborate_unused_args - let mut live_locals: Vec<_> = live_locals.into_iter().collect(); live_locals.sort_by_key(|(local, _)| *local); let mut locals = 
IndexVec::::new(); @@ -458,7 +479,7 @@ where param_tys: tys, ret_ty, param_rtys: Default::default(), - param_refinement: None, + param_refinement: precondition, // not generating pvar of BB post ret_rty: Some(rty::RefinedType::unrefined( self.inner.build(ret_ty).vacuous(), @@ -472,6 +493,23 @@ where outer_fn_param_count: body.arg_count, } } + + pub fn build_basic_block( + &mut self, + body: &rustc_middle::mir::Body<'tcx>, + live_locals: I, + ret_ty: mir_ty::Ty<'tcx>, + ) -> BasicBlockType + where + I: IntoIterator)>, + { + self.build_basic_block_with_precondition( + body, + live_locals.into_iter().collect(), + ret_ty, + None, + ) + } } /// A builder for function template types.