diff --git a/dialects/hir/src/transforms/spill.rs b/dialects/hir/src/transforms/spill.rs index a981d5f57..973e63fa8 100644 --- a/dialects/hir/src/transforms/spill.rs +++ b/dialects/hir/src/transforms/spill.rs @@ -3,7 +3,7 @@ use alloc::rc::Rc; use midenc_hir::{ adt::SmallDenseMap, dialects::builtin::{Function, FunctionRef, LocalVariable}, - pass::{Pass, PassExecutionState}, + pass::{Pass, PassExecutionState, PostPassStatus}, BlockRef, BuilderExt, EntityMut, Op, OpBuilder, OperationName, OperationRef, Report, Rewriter, SourceSpan, Spanned, Symbol, ValueRef, }; @@ -38,6 +38,7 @@ impl Pass for TransformSpills { if function.is_declaration() { log::debug!(target: "insert-spills", "function has no body, no spills needed!"); state.preserved_analyses_mut().preserve_all(); + state.set_post_pass_status(PostPassStatus::Unchanged); return Ok(()); } let mut analysis = @@ -46,6 +47,7 @@ impl Pass for TransformSpills { if !analysis.has_spills() { log::debug!(target: "insert-spills", "no spills needed!"); state.preserved_analyses_mut().preserve_all(); + state.set_post_pass_status(PostPassStatus::Unchanged); return Ok(()); } @@ -62,7 +64,17 @@ impl Pass for TransformSpills { let op = function.as_operation_ref(); drop(function); - transforms::transform_spills(op, analysis, &mut interface, state.analysis_manager().clone()) + + let transform_result = transforms::transform_spills( + op, + analysis, + &mut interface, + state.analysis_manager().clone(), + )?; + + state.set_post_pass_status(transform_result); + + return Ok(()); } } diff --git a/dialects/scf/src/transforms/cfg_to_scf.rs b/dialects/scf/src/transforms/cfg_to_scf.rs index 0ca813aea..f54aa919b 100644 --- a/dialects/scf/src/transforms/cfg_to_scf.rs +++ b/dialects/scf/src/transforms/cfg_to_scf.rs @@ -7,7 +7,7 @@ use midenc_hir::{ diagnostics::Severity, dialects::builtin, dominance::DominanceInfo, - pass::{Pass, PassExecutionState}, + pass::{Pass, PassExecutionState, PostPassStatus}, Builder, EntityMut, Forward, Op, Operation, OperationName, OperationRef, RawWalk, Report, SmallVec, Spanned, Type, ValueRange, ValueRef, WalkResult, }; @@ -130,6 +130,7 @@ impl Pass for LiftControlFlowToSCF { }); if result.was_interrupted() { + state.set_post_pass_status(PostPassStatus::Unchanged); return result.into_result(); } @@ -141,6 +142,8 @@ impl Pass for LiftControlFlowToSCF { state.preserved_analyses_mut().preserve_all(); } + state.set_post_pass_status(changed.into()); + Ok(()) } } diff --git a/hir-transform/src/canonicalization.rs b/hir-transform/src/canonicalization.rs index 5e600a774..8adac3ed4 100644 --- a/hir-transform/src/canonicalization.rs +++ b/hir-transform/src/canonicalization.rs @@ -1,7 +1,7 @@ use alloc::{boxed::Box, format, rc::Rc}; use midenc_hir::{ - pass::{OperationPass, Pass, PassExecutionState}, + pass::{OperationPass, Pass, PassExecutionState, PostPassStatus}, patterns::{self, FrozenRewritePatternSet, GreedyRewriteConfig, RewritePatternSet}, Context, EntityMut, Operation, OperationName, Report, Spanned, }; @@ -96,6 +96,7 @@ impl Pass for Canonicalizer { ) -> Result<(), Report> { let Some(rewrites) = self.rewrites.as_ref() else { log::debug!("skipping canonicalization as there are no rewrite patterns to apply"); + state.set_post_pass_status(PostPassStatus::Unchanged); return Ok(()); }; let op = { @@ -129,15 +130,21 @@ impl Pass for Canonicalizer { } let op = op.borrow(); - match converged { + let changed = match converged { Ok(changed) => { - log::debug!("canonicalization converged for '{}', changed={changed}", op.name()) + log::debug!("canonicalization converged for '{}', changed={changed}", op.name()); + changed } - Err(changed) => log::warn!( - "canonicalization failed to converge for '{}', changed={changed}", - op.name() - ), - } + Err(changed) => { + log::warn!( + "canonicalization failed to converge for '{}', changed={changed}", + op.name() + ); + changed + } + }; + let ir_changed = changed.into(); + state.set_post_pass_status(ir_changed); Ok(()) } diff --git a/hir-transform/src/sccp.rs b/hir-transform/src/sccp.rs index f12c85685..a80a6e194 100644 --- a/hir-transform/src/sccp.rs +++ b/hir-transform/src/sccp.rs @@ -56,7 +56,7 @@ impl SparseConditionalConstantPropagation { fn rewrite( &mut self, op: &mut Operation, - _state: &mut PassExecutionState, + state: &mut PassExecutionState, solver: &DataFlowSolver, ) -> Result<(), Report> { let mut worklist = SmallVec::<[BlockRef; 8]>::default(); @@ -76,6 +76,7 @@ impl SparseConditionalConstantPropagation { add_to_worklist(op.regions(), &mut worklist); + let mut replaced_any = false; while let Some(mut block) = worklist.pop() { let mut block = block.borrow_mut(); let body = block.body_mut(); @@ -91,8 +92,10 @@ impl SparseConditionalConstantPropagation { let mut replaced_all = num_results != 0; for index in 0..num_results { let result = { op.borrow().get_result(index).borrow().as_value_ref() }; - replaced_all &= - replace_with_constant(solver, &mut builder, &mut folder, result); + let replaced = replace_with_constant(solver, &mut builder, &mut folder, result); + + replaced_any |= replaced; + replaced_all &= replaced; } // If all of the results of the operation were replaced, try to erase the operation @@ -112,7 +115,7 @@ impl SparseConditionalConstantPropagation { builder.set_insertion_point_to_start(block.as_block_ref()); for arg in block.arguments() { - replace_with_constant( + replaced_any |= replace_with_constant( solver, &mut builder, &mut folder, @@ -121,6 +124,8 @@ impl SparseConditionalConstantPropagation { } } + state.set_post_pass_status(replaced_any.into()); + Ok(()) } } diff --git a/hir-transform/src/sink.rs b/hir-transform/src/sink.rs index 3504c1398..c5c97ef1f 100644 --- a/hir-transform/src/sink.rs +++ b/hir-transform/src/sink.rs @@ -4,7 +4,7 @@ use midenc_hir::{ adt::SmallDenseMap, dominance::DominanceInfo, matchers::{self, Matcher}, - pass::{Pass, PassExecutionState}, + pass::{Pass, PassExecutionState, PostPassStatus}, traits::{ConstantLike, Terminator}, Backward, Builder, EntityMut, Forward, FxHashSet, OpBuilder, Operation, OperationName, OperationRef, ProgramPoint, RawWalk, Region, RegionBranchOpInterface, @@ -105,6 +105,7 @@ impl Pass for ControlFlowSink { let dominfo = state.analysis_manager().get_analysis::()?; + let mut sunk = PostPassStatus::Unchanged; operation.raw_prewalk_all::(|op: OperationRef| { let regions_to_sink = { let op = op.borrow(); @@ -118,7 +119,7 @@ impl Pass for ControlFlowSink { }; // Sink side-effect free operations. - control_flow_sink( + sunk = control_flow_sink( ®ions_to_sink, &dominfo, |op: &Operation, _region: &Region| op.is_memory_effect_free(), @@ -132,6 +133,8 @@ impl Pass for ControlFlowSink { ); }); + state.set_post_pass_status(sunk); + Ok(()) } } @@ -171,7 +174,7 @@ impl Pass for SinkOperandDefs { fn run_on_operation( &mut self, op: EntityMut<'_, Self::Target>, - _state: &mut PassExecutionState, + state: &mut PassExecutionState, ) -> Result<(), Report> { let operation = op.as_operation_ref(); drop(op); @@ -184,6 +187,7 @@ impl Pass for SinkOperandDefs { // then process the worklist, moving everything into position. let mut worklist = alloc::collections::VecDeque::default(); + let mut changed = PostPassStatus::Unchanged; // Visit ops in "true" post-order (i.e. block bodies are visited bottom-up). operation.raw_postwalk_all::(|operation: OperationRef| { // Determine if any of this operation's operands represent one of the following: @@ -308,6 +312,7 @@ impl Pass for SinkOperandDefs { log::trace!(target: "sink-operand-defs", " rewriting operand {operand_value} as {replacement}"); operand.borrow_mut().set(replacement); + changed = PostPassStatus::Changed; // If no other uses of this value remain, then remove the original // operation, as it is now dead. if !operand_value.borrow().is_used() { @@ -354,6 +359,7 @@ impl Pass for SinkOperandDefs { log::trace!(target: "sink-operand-defs", " rewriting operand {operand_value} as {replacement}"); sink_state.replacements.insert(operand_value, replacement); operand.borrow_mut().set(replacement); + changed = PostPassStatus::Changed; } else { log::trace!(target: "sink-operand-defs", " defining op is a constant with no other uses, moving into place"); // The original op can be moved @@ -397,6 +403,7 @@ impl Pass for SinkOperandDefs { } } + state.set_post_pass_status(changed); Ok(()) } } @@ -548,12 +555,14 @@ pub fn control_flow_sink( dominfo: &DominanceInfo, should_move_into_region: P, move_into_region: F, -) where +) -> PostPassStatus +where P: Fn(&Operation, &Region) -> bool, F: Fn(OperationRef, RegionRef), { let sinker = Sinker::new(dominfo, should_move_into_region, move_into_region); - sinker.sink_regions(regions); + let sunk_regions = sinker.sink_regions(regions); + (sunk_regions > 0).into() } /// Populates `regions` with regions of the provided region branch op that are executed at most once diff --git a/hir-transform/src/spill.rs b/hir-transform/src/spill.rs index d758ede48..65f4f814e 100644 --- a/hir-transform/src/spill.rs +++ b/hir-transform/src/spill.rs @@ -4,7 +4,7 @@ use midenc_hir::{ adt::{SmallDenseMap, SmallSet}, cfg::Graph, dominance::{DomTreeNode, DominanceFrontier, DominanceInfo}, - pass::AnalysisManager, + pass::{AnalysisManager, PostPassStatus}, traits::SingleRegion, BlockRef, Builder, Context, FxHashMap, OpBuilder, OpOperand, Operation, OperationRef, ProgramPoint, Region, RegionBranchOpInterface, RegionBranchPoint, RegionRef, Report, Rewriter, @@ -113,7 +113,7 @@ pub fn transform_spills( analysis: &mut SpillAnalysis, interface: &mut dyn TransformSpillsInterface, analysis_manager: AnalysisManager, -) -> Result<(), Report> { +) -> Result { assert!( op.borrow().implements::(), "the spills transformation is not supported when the root op is multi-region" @@ -252,7 +252,7 @@ pub fn transform_spills( )?; } - Ok(()) + Ok(PostPassStatus::Changed) } fn rewrite_single_block_spills( diff --git a/hir/src/ir/print.rs b/hir/src/ir/print.rs index 86381db07..ca3bfb3ea 100644 --- a/hir/src/ir/print.rs +++ b/hir/src/ir/print.rs @@ -9,6 +9,7 @@ use crate::{ AttributeValue, EntityWithId, SuccessorOperands, Value, }; +#[derive(Debug)] pub struct OpPrintingFlags { pub print_entry_block_headers: bool, pub print_intrinsic_attributes: bool, diff --git a/hir/src/pass.rs b/hir/src/pass.rs index d87999128..8e2182a8f 100644 --- a/hir/src/pass.rs +++ b/hir/src/pass.rs @@ -7,24 +7,57 @@ pub mod registry; mod specialization; pub mod statistics; +use alloc::string::String; + pub use self::{ analysis::{Analysis, AnalysisManager, OperationAnalysis, PreservedAnalyses}, instrumentation::{PassInstrumentation, PassInstrumentor, PipelineParentInfo}, - manager::{Nesting, OpPassManager, PassDisplayMode, PassManager}, - pass::{OperationPass, Pass, PassExecutionState}, + manager::{IRPrintingConfig, Nesting, OpPassManager, PassDisplayMode, PassManager}, + pass::{OperationPass, Pass, PassExecutionState, PostPassStatus}, registry::{PassInfo, PassPipelineInfo}, specialization::PassTarget, statistics::{PassStatistic, Statistic, StatisticValue}, }; +use crate::{EntityRef, Operation, OperationName, OperationRef, SmallVec}; -/// A `Pass` which prints IR it is run on, based on provided configuration. +/// Handles IR printing, based on the [`IRPrintingConfig`] passed in +/// [Print::new]. Currently, this struct is managed by the [`PassManager`]'s [`PassInstrumentor`], +/// which calls the Print struct via its [`PassInstrumentation`] trait implementation. +/// +/// The configuration passed by [`IRPrintingConfig`] controls *when* the IR gets displayed, rather +/// than *how*. The display format itself depends on the `Display` implementation done by each +/// [`Operation`]. +/// +/// [`Print::selected_passes`] controls which passes are selected to be printable. This means that +/// those selected passes will run all the configured filters; which will determine whether +/// the pass displays the IR or not. The available options are [`SelectedPasses::All`] to enable all +/// the passes and [`SelectedPasses::Just`] to enable a select set of passes. +/// +/// The filters that run on the selected passes are: +/// - [`Print::only_when_modified`] will only print the IR if said pass modified the IR. +/// +/// - [`Print::op_filter`] will only display a specific subset of operations. #[derive(Default)] pub struct Print { - filter: OpFilter, + selected_passes: Option, + + only_when_modified: bool, + op_filter: Option, + target: Option, } -#[derive(Default)] +/// Which passes are enabled for IR printing. +#[derive(Debug)] +enum SelectedPasses { + /// Enable all passes for IR Printing. + All, + /// Just select a subset of passes for IR printing. + Just(SmallVec<[String; 1]>), +} + +#[allow(dead_code)] +#[derive(Default, Debug)] enum OpFilter { /// Print all operations #[default] @@ -40,30 +73,66 @@ enum OpFilter { } impl Print { - /// Create a printer that prints any operation - pub fn any() -> Self { - Self { - filter: OpFilter::All, - target: None, - } + pub fn new(config: &IRPrintingConfig) -> Option { + let print = if config.print_ir_after_all + || !config.print_ir_after_pass.is_empty() + || config.print_ir_after_modified + { + Some(Self::default()) + } else { + None + }; + print.map(|p| p.with_pass_filter(config)).map(|p| p.with_symbol_filter(config)) } - /// Create a printer that only prints operations of type `T` - pub fn only() -> Self { + pub fn with_type_filter(mut self) -> Self { let dialect = ::dialect_name(); let op = ::name(); - Self { - filter: OpFilter::Type { dialect, op }, - target: None, - } + self.op_filter = Some(OpFilter::Type { dialect, op }); + self } + #[allow(dead_code)] /// Create a printer that only prints `Symbol` operations containing `name` - pub fn symbol_matching(name: &'static str) -> Self { - Self { - filter: OpFilter::Symbol(Some(name)), - target: None, - } + fn with_symbol_matching(mut self, name: &'static str) -> Self { + self.op_filter = Some(OpFilter::Symbol(Some(name))); + self + } + + #[allow(unused_mut)] + fn with_symbol_filter(mut self, _config: &IRPrintingConfig) -> Self { + // NOTE: At the moment, symbol filtering is not processed by the CLI. However, were it to be + // added, it could be done inside this function + self.with_all_symbols() + } + + fn with_all_symbols(mut self) -> Self { + self.op_filter = Some(OpFilter::All); + self + } + + fn with_pass_filter(mut self, config: &IRPrintingConfig) -> Self { + let is_ir_filter_set = if config.print_ir_after_all { + self.selected_passes = Some(SelectedPasses::All); + true + } else if !config.print_ir_after_pass.is_empty() { + self.selected_passes = Some(SelectedPasses::Just(config.print_ir_after_pass.clone())); + true + } else { + false + }; + + if config.print_ir_after_modified { + self.only_when_modified = true; + // NOTE: If the user specified the "print after modification" flag, but didn't specify + // any IR pass filter flag; then we assume that the desired behavior is to set the "all + // pass" filter. + if !is_ir_filter_set { + self.selected_passes = Some(SelectedPasses::All); + } + }; + + self } /// Specify the `log` target to write the IR output to. @@ -75,55 +144,99 @@ impl Print { self.target = Some(target); self } -} - -impl Pass for Print { - type Target = crate::Operation; - - fn name(&self) -> &'static str { - "print" - } - - fn can_schedule_on(&self, _name: &crate::OperationName) -> bool { - true - } - fn run_on_operation( - &mut self, - op: crate::EntityMut<'_, Self::Target>, - _state: &mut PassExecutionState, - ) -> Result<(), crate::Report> { - let op = op.into_entity_ref(); - match self.filter { - OpFilter::All => { + fn print_ir(&self, op: EntityRef<'_, Operation>) { + match self.op_filter { + Some(OpFilter::All) => { let target = self.target.as_deref().unwrap_or("printer"); log::trace!(target: target, "{op}"); } - OpFilter::Type { + Some(OpFilter::Type { dialect, op: op_name, - } => { + }) => { let name = op.name(); if name.dialect() == dialect && name.name() == op_name { let target = self.target.as_deref().unwrap_or("printer"); log::trace!(target: target, "{op}"); } } - OpFilter::Symbol(None) => { + Some(OpFilter::Symbol(None)) => { if let Some(sym) = op.as_symbol() { let name = sym.name().as_str(); let target = self.target.as_deref().unwrap_or(name); log::trace!(target: target, "{}", sym.as_symbol_operation()); } } - OpFilter::Symbol(Some(filter)) => { + Some(OpFilter::Symbol(Some(filter))) => { if let Some(sym) = op.as_symbol().filter(|sym| sym.name().as_str().contains(filter)) { let target = self.target.as_deref().unwrap_or(filter); log::trace!(target: target, "{}", sym.as_symbol_operation()); } } + None => (), + } + } + + fn pass_filter(&self, pass: &dyn OperationPass) -> bool { + match &self.selected_passes { + Some(SelectedPasses::All) => true, + Some(SelectedPasses::Just(passes)) => passes.iter().any(|p| pass.name() == *p), + None => false, + } + } + + fn should_print(&self, pass: &dyn OperationPass, ir_changed: &PostPassStatus) -> bool { + let pass_filter = self.pass_filter(pass); + + // Always print, unless "only_when_modified" has been set and there have not been changes. + let modification_filter = + !matches!((self.only_when_modified, ir_changed), (true, PostPassStatus::Unchanged)); + + pass_filter && modification_filter + } +} + +impl PassInstrumentation for Print { + fn run_before_pipeline( + &mut self, + _name: Option<&OperationName>, + _parent_info: &PipelineParentInfo, + op: OperationRef, + ) { + if !self.only_when_modified { + return; + } + + log::trace!("IR before the pass pipeline"); + let op = op.borrow(); + self.print_ir(op); + } + + fn run_before_pass(&mut self, pass: &dyn OperationPass, op: &OperationRef) { + if self.only_when_modified { + return; + } + if self.pass_filter(pass) { + log::trace!("Before the {} pass", pass.name()); + let op = op.borrow(); + self.print_ir(op); + } + } + + fn run_after_pass( + &mut self, + pass: &dyn OperationPass, + op: &OperationRef, + post_execution_state: &PassExecutionState, + ) { + let changed = post_execution_state.post_pass_status(); + + if self.should_print(pass, changed) { + log::trace!("After the {} pass", pass.name()); + let op = op.borrow(); + self.print_ir(op); } - Ok(()) } } diff --git a/hir/src/pass/instrumentation.rs b/hir/src/pass/instrumentation.rs index c9d3f15ff..ccddaac05 100644 --- a/hir/src/pass/instrumentation.rs +++ b/hir/src/pass/instrumentation.rs @@ -5,7 +5,7 @@ use compact_str::CompactString; use smallvec::SmallVec; use super::OperationPass; -use crate::{OperationName, OperationRef}; +use crate::{pass::PassExecutionState, OperationName, OperationRef}; #[allow(unused_variables)] pub trait PassInstrumentation { @@ -13,6 +13,7 @@ pub trait PassInstrumentation { &mut self, name: Option<&OperationName>, parent_info: &PipelineParentInfo, + op: OperationRef, ) { } fn run_after_pipeline( @@ -22,7 +23,13 @@ pub trait PassInstrumentation { ) { } fn run_before_pass(&mut self, pass: &dyn OperationPass, op: &OperationRef) {} - fn run_after_pass(&mut self, pass: &dyn OperationPass, op: &OperationRef) {} + fn run_after_pass( + &mut self, + pass: &dyn OperationPass, + op: &OperationRef, + post_execution_state: &PassExecutionState, + ) { + } fn run_after_pass_failed(&mut self, pass: &dyn OperationPass, op: &OperationRef) {} fn run_before_analysis(&mut self, name: &str, id: &TypeId, op: &OperationRef) {} fn run_after_analysis(&mut self, name: &str, id: &TypeId, op: &OperationRef) {} @@ -38,8 +45,9 @@ impl PassInstrumentation for Box

{ &mut self, name: Option<&OperationName>, parent_info: &PipelineParentInfo, + op: OperationRef, ) { - (**self).run_before_pipeline(name, parent_info); + (**self).run_before_pipeline(name, parent_info, op); } fn run_after_pipeline( @@ -54,8 +62,13 @@ impl PassInstrumentation for Box

{ (**self).run_before_pass(pass, op); } - fn run_after_pass(&mut self, pass: &dyn OperationPass, op: &OperationRef) { - (**self).run_after_pass(pass, op); + fn run_after_pass( + &mut self, + pass: &dyn OperationPass, + op: &OperationRef, + post_execution_state: &PassExecutionState, + ) { + (**self).run_after_pass(pass, op, post_execution_state); } fn run_after_pass_failed(&mut self, pass: &dyn OperationPass, op: &OperationRef) { @@ -81,8 +94,9 @@ impl PassInstrumentor { &self, name: Option<&OperationName>, parent_info: &PipelineParentInfo, + op: OperationRef, ) { - self.instrument(|pi| pi.run_before_pipeline(name, parent_info)); + self.instrument(|pi| pi.run_before_pipeline(name, parent_info, op)); } pub fn run_after_pipeline( @@ -97,8 +111,13 @@ impl PassInstrumentor { self.instrument(|pi| pi.run_before_pass(pass, op)); } - pub fn run_after_pass(&self, pass: &dyn OperationPass, op: &OperationRef) { - self.instrument(|pi| pi.run_after_pass(pass, op)); + pub fn run_after_pass( + &self, + pass: &dyn OperationPass, + op: &OperationRef, + post_execution_state: &PassExecutionState, + ) { + self.instrument(|pi| pi.run_after_pass(pass, op, post_execution_state)); } pub fn run_after_pass_failed(&self, pass: &dyn OperationPass, op: &OperationRef) { diff --git a/hir/src/pass/manager.rs b/hir/src/pass/manager.rs index a9e4cc23d..c0e2b9ad5 100644 --- a/hir/src/pass/manager.rs +++ b/hir/src/pass/manager.rs @@ -1,7 +1,13 @@ -use alloc::{boxed::Box, collections::BTreeMap, format, rc::Rc, string::ToString}; +use alloc::{ + boxed::Box, + collections::BTreeMap, + format, + rc::Rc, + string::{String, ToString}, +}; use compact_str::{CompactString, ToCompactString}; -use midenc_session::diagnostics::Severity; +use midenc_session::{diagnostics::Severity, Options}; use smallvec::{smallvec, SmallVec}; use super::{ @@ -9,8 +15,10 @@ use super::{ PassInstrumentor, PipelineParentInfo, Statistic, }; use crate::{ - traits::IsolatedFromAbove, Context, EntityMut, OpPrintingFlags, OpRegistration, Operation, - OperationName, OperationRef, Report, + pass::{PostPassStatus, Print}, + traits::IsolatedFromAbove, + Context, EntityMut, OpPrintingFlags, OpRegistration, Operation, OperationName, OperationRef, + Report, }; #[derive(Debug, Default, Copy, Clone, PartialEq, Eq)] @@ -29,11 +37,38 @@ pub enum PassDisplayMode { // TODO(pauls) #[allow(unused)] +#[derive(Default, Debug)] pub struct IRPrintingConfig { - print_module_scope: bool, - print_after_only_on_change: bool, - print_after_only_on_failure: bool, - flags: OpPrintingFlags, + pub print_module_scope: bool, + pub print_after_only_on_failure: bool, + // NOTE: Taken from the Options struct + pub print_ir_after_all: bool, + pub print_ir_after_pass: SmallVec<[String; 1]>, + pub print_ir_after_modified: bool, + pub flags: OpPrintingFlags, +} + +impl TryFrom<&Options> for IRPrintingConfig { + type Error = Report; + + fn try_from(options: &Options) -> Result { + let pass_filters = options.print_ir_after_pass.clone(); + + if options.print_ir_after_all && !pass_filters.is_empty() { + return Err(Report::msg( + "Flags `print_ir_after_all` and `print_ir_after_pass` are mutually exclusive. \ + Please select only one." + .to_string(), + )); + }; + + Ok(IRPrintingConfig { + print_ir_after_all: options.print_ir_after_all, + print_ir_after_pass: pass_filters.into(), + print_ir_after_modified: options.print_ir_after_modified, + ..Default::default() + }) + } } /// The main pass manager and pipeline builder @@ -169,8 +204,14 @@ impl PassManager { self } - pub fn enable_ir_printing(&mut self, _config: IRPrintingConfig) { - todo!() + pub fn enable_ir_printing(mut self, config: IRPrintingConfig) -> Self { + let print = Print::new(&config); + + if let Some(print) = print { + let print = Box::new(print); + self.add_instrumentation(print); + } + self } pub fn enable_timing(&mut self, yes: bool) -> &mut Self { @@ -820,7 +861,7 @@ impl OpToOpPassAdaptor { let mut op_name = None; if let Some(instrumentor) = instrumentor.as_deref() { op_name = pm.name().cloned(); - instrumentor.run_before_pipeline(op_name.as_ref(), parent_info.as_ref().unwrap()); + instrumentor.run_before_pipeline(op_name.as_ref(), parent_info.as_ref().unwrap(), op); } for pass in pm.passes_mut() { @@ -943,8 +984,11 @@ impl OpToOpPassAdaptor { // // * If the pass said that it preserved all analyses then it can't have permuted the IR let run_verifier_now = !execution_state.preserved_analyses().is_all(); + if run_verifier_now { - result = Self::verify(&op, run_verifier_recursively); + if let Err(verification_result) = Self::verify(&op, run_verifier_recursively) { + result = result.map_err(|_| verification_result); + } } } @@ -952,12 +996,12 @@ impl OpToOpPassAdaptor { if result.is_err() { instrumentor.run_after_pass_failed(pass, &op); } else { - instrumentor.run_after_pass(pass, &op); + instrumentor.run_after_pass(pass, &op, &execution_state); } } // Return the pass result - result + result.map(|_| ()) } fn verify(op: &OperationRef, verify_recursively: bool) -> Result<(), Report> { @@ -1008,7 +1052,7 @@ impl OpToOpPassAdaptor { } } } - + state.set_post_pass_status(PostPassStatus::Unchanged); Ok(()) } } diff --git a/hir/src/pass/pass.rs b/hir/src/pass/pass.rs index 231ea5c09..224ee31c6 100644 --- a/hir/src/pass/pass.rs +++ b/hir/src/pass/pass.rs @@ -15,6 +15,7 @@ pub trait OperationPass { fn as_any_mut(&mut self) -> &mut dyn Any; fn into_any(self: Box) -> Box; fn name(&self) -> &'static str; + fn argument(&self) -> &'static str { // NOTE: Could we compute an argument string from the type name? "" @@ -136,6 +137,22 @@ where } } +#[derive(Debug, PartialEq, Clone, Copy)] +pub enum PostPassStatus { + Unchanged, + Changed, +} + +impl From for PostPassStatus { + fn from(ir_was_changed: bool) -> Self { + if ir_was_changed { + PostPassStatus::Changed + } else { + PostPassStatus::Unchanged + } + } +} + /// A compiler pass which operates on an [Operation] of some kind. #[allow(unused_variables)] pub trait Pass: Sized + Any { @@ -361,6 +378,7 @@ pub struct PassExecutionState { // rooted at the provided operation. #[allow(unused)] pipeline_executor: Option>, + post_pass_status: PostPassStatus, } impl PassExecutionState { pub fn new( @@ -375,6 +393,7 @@ impl PassExecutionState { analysis_manager, preserved_analyses: Default::default(), pipeline_executor, + post_pass_status: PostPassStatus::Unchanged, } } @@ -403,6 +422,16 @@ impl PassExecutionState { &mut self.preserved_analyses } + #[inline(always)] + pub fn post_pass_status(&self) -> &PostPassStatus { + &self.post_pass_status + } + + #[inline(always)] + pub fn set_post_pass_status(&mut self, post_pass_status: PostPassStatus) { + self.post_pass_status = post_pass_status; + } + pub fn run_pipeline( &mut self, pipeline: &mut OpPassManager, diff --git a/midenc-compile/src/compiler.rs b/midenc-compile/src/compiler.rs index 1a6f27ecd..e9e377d23 100644 --- a/midenc-compile/src/compiler.rs +++ b/midenc-compile/src/compiler.rs @@ -339,6 +339,13 @@ pub struct UnstableOptions { ) )] pub print_ir_after_pass: Vec, + /// Only print the IR if the pass modified the IR structure. If this flag is set, and no IR + /// filter flag is; then the default behavior is to print the IR after every pass. + #[cfg_attr( + feature = "std", + arg(long, default_value_t = false, help_heading = "Passes") + )] + pub print_ir_after_modified: bool, } impl CodegenOptions { @@ -505,6 +512,7 @@ impl Compiler { options.print_cfg_after_pass = unstable.print_cfg_after_pass; options.print_ir_after_all = unstable.print_ir_after_all; options.print_ir_after_pass = unstable.print_ir_after_pass; + options.print_ir_after_modified = unstable.print_ir_after_modified; // Establish --target-dir let target_dir = if self.target_dir.is_absolute() { diff --git a/midenc-compile/src/stages/rewrite.rs b/midenc-compile/src/stages/rewrite.rs index 5fa1448c9..a7d8ffa06 100644 --- a/midenc-compile/src/stages/rewrite.rs +++ b/midenc-compile/src/stages/rewrite.rs @@ -3,7 +3,7 @@ use alloc::boxed::Box; use midenc_dialect_hir::transforms::TransformSpills; use midenc_dialect_scf::transforms::LiftControlFlowToSCF; use midenc_hir::{ - pass::{Nesting, PassManager}, + pass::{IRPrintingConfig, Nesting, PassManager}, patterns::{GreedyRewriteConfig, RegionSimplificationLevel}, Op, }; @@ -22,6 +22,7 @@ impl Stage for ApplyRewritesStage { } fn run(&mut self, input: Self::Input, context: Rc) -> CompilerResult { + let ir_print_config: IRPrintingConfig = (&context.as_ref().session().options).try_into()?; log::debug!(target: "driver", "applying rewrite passes"); // TODO(pauls): Set up pass registration for new pass infra /* @@ -48,7 +49,8 @@ impl Stage for ApplyRewritesStage { */ // Construct a pass manager with the default pass pipeline - let mut pm = PassManager::on::(context.clone(), Nesting::Implicit); + let mut pm = PassManager::on::(context.clone(), Nesting::Implicit) + .enable_ir_printing(ir_print_config); let mut rewrite_config = GreedyRewriteConfig::default(); rewrite_config.with_region_simplification_level(RegionSimplificationLevel::Normal); diff --git a/midenc-session/src/options/mod.rs b/midenc-session/src/options/mod.rs index af547aa60..359e21ee7 100644 --- a/midenc-session/src/options/mod.rs +++ b/midenc-session/src/options/mod.rs @@ -52,6 +52,8 @@ pub struct Options { pub print_ir_after_all: bool, /// Print IR to stdout each time the named passes are applied pub print_ir_after_pass: Vec, + /// Only print the IR if the pass modified the IR structure. + pub print_ir_after_modified: bool, /// Save intermediate artifacts in memory during compilation pub save_temps: bool, /// We store any leftover argument matches in the session options for use @@ -126,6 +128,7 @@ impl Options { print_cfg_after_pass: vec![], print_ir_after_all: false, print_ir_after_pass: vec![], + print_ir_after_modified: false, flags: CompileFlags::default(), } }