diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index c7706bf6a..75bbea399 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -1,53 +1,147 @@ -//! This module provides functions for inspecting and modifying the nature of -//! non local edges in a Hugr. -// -//TODO Add `remove_nonlocal_edges` and `add_nonlocal_edges` functions +//! This module provides functions for finding non-local edges +//! in a Hugr and converting them to local edges. +#![warn(missing_docs)] use itertools::Itertools as _; -use thiserror::Error; -use hugr_core::{HugrView, IncomingPort}; +use hugr_core::{ + HugrView, IncomingPort, Wire, + hugr::hugrmut::HugrMut, + types::{EdgeKind, Type}, +}; -/// Returns an iterator over all non local edges in a Hugr. +use crate::ComposablePass; + +mod localize; +use localize::ExtraSourceReqs; + +/// [ComposablePass] wrapper for [remove_nonlocal_edges] +#[derive(Clone, Debug, Hash)] +pub struct LocalizeEdges; + +/// Error from [LocalizeEdges] or [remove_nonlocal_edges] +#[derive(derive_more::Error, derive_more::Display, derive_more::From, Debug, PartialEq)] +#[non_exhaustive] +pub enum LocalizeEdgesError {} + +impl ComposablePass for LocalizeEdges { + type Error = LocalizeEdgesError; + + type Result = (); + + fn run(&self, hugr: &mut H) -> Result { + remove_nonlocal_edges(hugr) + } +} + +/// Returns an iterator over all non local edges in a Hugr beneath the entrypoint. /// -/// All `(node, in_port)` pairs are returned where `in_port` is a value port -/// connected to a node with a parent other than the parent of `node`. +/// All `(node, in_port)` pairs are returned where `in_port` is a value port connected to a +/// node whose parent is both beneath the entrypoint and different from the parent of `node`. pub fn nonlocal_edges(hugr: &H) -> impl Iterator + '_ { hugr.entry_descendants().flat_map(move |node| { hugr.in_value_types(node).filter_map(move |(in_p, _)| { - let parent = hugr.get_parent(node); - hugr.linked_outputs(node, in_p) - .any(|(neighbour_node, _)| parent != hugr.get_parent(neighbour_node)) - .then_some((node, in_p)) + let (src, _) = hugr.single_linked_output(node, in_p)?; + (hugr.get_parent(node) != hugr.get_parent(src) + && ancestors(src, hugr).any(|a| a == hugr.entrypoint())) + .then_some((node, in_p)) }) }) } -#[derive(Error, Debug, Clone, PartialEq, Eq)] +/// Legacy alias of [FindNonLocalEdgesError] +#[deprecated(note = "Use FindNonLocalEdgesError")] +pub type NonLocalEdgesError = FindNonLocalEdgesError; + +/// An error from [ensure_no_nonlocal_edges] +#[derive(Clone, derive_more::Error, derive_more::Display, Debug, PartialEq, Eq)] #[non_exhaustive] -pub enum NonLocalEdgesError { - #[error("Found {} nonlocal edges", .0.len())] +pub enum FindNonLocalEdgesError { + /// Some nonlocal edges were found + #[display("Found {} nonlocal edges", _0.len())] + #[error(ignore)] // Vec not convertible Edges(Vec<(N, IncomingPort)>), } /// Verifies that there are no non local value edges in the Hugr. -pub fn ensure_no_nonlocal_edges(hugr: &H) -> Result<(), NonLocalEdgesError> { +pub fn ensure_no_nonlocal_edges( + hugr: &H, +) -> Result<(), FindNonLocalEdgesError> { let non_local_edges: Vec<_> = nonlocal_edges(hugr).collect_vec(); if non_local_edges.is_empty() { Ok(()) } else { - Err(NonLocalEdgesError::Edges(non_local_edges))? + Err(FindNonLocalEdgesError::Edges(non_local_edges))? } } +fn just_types<'a, X: 'a>(v: impl IntoIterator) -> impl Iterator { + v.into_iter().map(|(_, t)| t.clone()) +} + +/// Converts all non-local edges in a Hugr into local ones, by inserting extra inputs +/// to container nodes and extra outports to Input nodes (and conversely to outputs of +/// [DataflowBlock]s). +/// +/// [DataflowBlock]: hugr_core::ops::DataflowBlock +pub fn remove_nonlocal_edges(hugr: &mut H) -> Result<(), LocalizeEdgesError> { + // Group all the non-local edges in the graph by target node, + // storing for each the source and type (well-defined as these are Value edges). + let nonlocal_edges: Vec<_> = nonlocal_edges(hugr) + .map(|(node, inport)| { + // unwrap because nonlocal_edges(hugr) already skips in-ports with !=1 linked outputs. + let (src_n, outp) = hugr.single_linked_output(node, inport).unwrap(); + debug_assert!(hugr.get_parent(src_n).unwrap() != hugr.get_parent(node).unwrap()); + let Some(EdgeKind::Value(ty)) = hugr.get_optype(src_n).port_kind(outp) else { + panic!("impossible") + }; + (node, (Wire::new(src_n, outp), ty)) + }) + .collect(); + + if nonlocal_edges.is_empty() { + return Ok(()); + } + + // We now compute the sources needed by each parent node. + let needs_sources_map = { + let mut bnsm = ExtraSourceReqs::default(); + for (target_node, (source, ty)) in nonlocal_edges.iter() { + let parent = hugr.get_parent(*target_node).unwrap(); + debug_assert!(hugr.get_parent(parent).is_some()); + bnsm.add_edge(&*hugr, parent, *source, ty.clone()); + } + bnsm + }; + + debug_assert!(nonlocal_edges.iter().all(|(n, (source, _))| { + let source_parent = hugr.get_parent(source.node()).unwrap(); + let source_gp = hugr.get_parent(source_parent); + ancestors(*n, hugr) + .skip(1) + .take_while(|&a| a != source_parent && source_gp.is_none_or(|gp| a != gp)) + .all(|parent| needs_sources_map.parent_needs(parent, *source)) + })); + + needs_sources_map.thread_hugr(hugr); + + Ok(()) +} + +fn ancestors(n: H::Node, h: &H) -> impl Iterator { + std::iter::successors(Some(n), |n| h.get_parent(*n)) +} + #[cfg(test)] mod test { use hugr_core::{ - builder::{DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer}, - extension::prelude::{Noop, bool_t}, - ops::handle::NodeHandle, + builder::{DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, SubContainer}, + extension::prelude::{Noop, bool_t, either_type}, + ops::handle::{BasicBlockID, NodeHandle}, + ops::{Tag, TailLoop, Value}, type_row, types::Signature, }; + use rstest::rstest; use super::*; @@ -90,7 +184,268 @@ mod test { }; assert_eq!( ensure_no_nonlocal_edges(&hugr).unwrap_err(), - NonLocalEdgesError::Edges(vec![edge]) + FindNonLocalEdgesError::Edges(vec![edge]) + ); + } + + #[rstest] + fn localize_dfg(#[values(true, false)] same_src: bool) { + let mut hugr = { + let mut outer = DFGBuilder::new(Signature::new_endo(vec![bool_t(); 2])).unwrap(); + let [w0, mut w1] = outer.input_wires_arr(); + if !same_src { + [w1] = outer + .add_dataflow_op(Noop::new(bool_t()), [w1]) + .unwrap() + .outputs_arr(); + } + let inner_outs = { + let inner = outer + .dfg_builder(Signature::new(vec![], vec![bool_t(); 2]), []) + .unwrap(); + // Note two `ext` edges to the same (Input) node here + inner.finish_with_outputs([w0, w1]).unwrap().outputs() + }; + outer.finish_hugr_with_outputs(inner_outs).unwrap() + }; + assert!(ensure_no_nonlocal_edges(&hugr).is_err()); + remove_nonlocal_edges(&mut hugr).unwrap(); + hugr.validate().unwrap(); + assert!(ensure_no_nonlocal_edges(&hugr).is_ok()); + } + + #[test] + fn localize_tailloop() { + let (t1, t2, t3) = (Type::UNIT, bool_t(), Type::new_unit_sum(3)); + let mut hugr = { + let mut outer = DFGBuilder::new(Signature::new_endo(vec![ + t1.clone(), + t2.clone(), + t3.clone(), + ])) + .unwrap(); + let [s1, s2, s3] = outer.input_wires_arr(); + let [s2, s3] = { + let mut inner = outer + .tail_loop_builder( + [(t1.clone(), s1)], + [(t3.clone(), s3)], + vec![t2.clone()].into(), + ) + .unwrap(); + let [_s1, s3] = inner.input_wires_arr(); + let control = inner + .add_dataflow_op( + Tag::new( + TailLoop::BREAK_TAG, + vec![vec![t1.clone()].into(), vec![t2.clone()].into()], + ), + [s2], + ) + .unwrap() + .out_wire(0); + inner + .finish_with_outputs(control, [s3]) + .unwrap() + .outputs_arr() + }; + outer.finish_hugr_with_outputs([s1, s2, s3]).unwrap() + }; + assert!(ensure_no_nonlocal_edges(&hugr).is_err()); + remove_nonlocal_edges(&mut hugr).unwrap(); + hugr.validate().unwrap_or_else(|e| panic!("{e}")); + assert!(ensure_no_nonlocal_edges(&hugr).is_ok()); + } + + #[test] + fn localize_conditional() { + let (t1, t2, t3) = (Type::UNIT, bool_t(), Type::new_unit_sum(3)); + let out_variants = vec![t1.clone().into(), t2.clone().into()]; + let out_type = Type::new_sum(out_variants.clone()); + let mut hugr = { + let mut outer = DFGBuilder::new(Signature::new( + vec![t1.clone(), t2.clone(), t3.clone()], + out_type.clone(), + )) + .unwrap(); + let [s1, s2, s3] = outer.input_wires_arr(); + let [out] = { + let mut cond = outer + .conditional_builder((vec![type_row![]; 3], s3), [], out_type.into()) + .unwrap(); + + { + let mut case = cond.case_builder(0).unwrap(); + let [r] = case + .add_dataflow_op(Tag::new(0, out_variants.clone()), [s1]) + .unwrap() + .outputs_arr(); + case.finish_with_outputs([r]).unwrap(); + } + { + let mut case = cond.case_builder(1).unwrap(); + let [r] = case + .add_dataflow_op(Tag::new(1, out_variants.clone()), [s2]) + .unwrap() + .outputs_arr(); + case.finish_with_outputs([r]).unwrap(); + } + { + let mut case = cond.case_builder(2).unwrap(); + let u = case.add_load_value(Value::unit()); + let [r] = case + .add_dataflow_op(Tag::new(0, out_variants.clone()), [u]) + .unwrap() + .outputs_arr(); + case.finish_with_outputs([r]).unwrap(); + } + cond.finish_sub_container().unwrap().outputs_arr() + }; + outer.finish_hugr_with_outputs([out]).unwrap() + }; + assert!(ensure_no_nonlocal_edges(&hugr).is_err()); + remove_nonlocal_edges(&mut hugr).unwrap(); + hugr.validate().unwrap_or_else(|e| panic!("{e}")); + assert!(ensure_no_nonlocal_edges(&hugr).is_ok()); + } + + #[test] + fn localize_cfg() { + // Cfg consists of 4 dataflow blocks and an exit block + // + // The 4 dataflow blocks form a diamond, and the bottom block branches + // either to the entry block or the exit block. + // + // The left block contains non-local uses of a value from outside the CFG (ext edge) + // and a value from the entry block (dom edge) - the `ext` must be threaded through + // all blocks because of the loop, the `dom` stays within (the same iter of) the loop. + // + // All non-trivial(i.e. more than one choice of successor) branching is + // done on an option type to exercise both empty and occupied control + // sums. + // + // All branches have an other-output. + let branch_sum_type = either_type(Type::UNIT, Type::UNIT); + let branch_type = Type::from(branch_sum_type.clone()); + let branch_variants = branch_sum_type + .variants() + .cloned() + .map(|x| x.try_into().unwrap()) + .collect_vec(); + let ext_edge_type = bool_t(); + let dom_edge_type = Type::new_unit_sum(3); + let other_output_type = branch_type.clone(); + let mut outer = DFGBuilder::new(Signature::new( + vec![branch_type.clone(), ext_edge_type.clone(), Type::UNIT], + vec![Type::UNIT, other_output_type.clone()], + )) + .unwrap(); + let [b, src_ext, unit] = outer.input_wires_arr(); + let mut cfg = outer + .cfg_builder( + [(Type::UNIT, unit), (branch_type.clone(), b)], + vec![Type::UNIT, other_output_type.clone()].into(), + ) + .unwrap(); + + let (entry, src_dom) = { + let mut entry = cfg + .entry_builder(branch_variants.clone(), other_output_type.clone().into()) + .unwrap(); + let [_, b] = entry.input_wires_arr(); + + let cst = entry.add_load_value(Value::unit_sum(1, 3).unwrap()); + + (entry.finish_with_outputs(b, [b]).unwrap(), cst) + }; + let exit = cfg.exit_block(); + + let (bb_left, tgt_ext, tgt_dom) = { + let mut bb = cfg + .block_builder( + vec![Type::UNIT, other_output_type.clone()].into(), + [type_row![]], + other_output_type.clone().into(), + ) + .unwrap(); + let [unit, oo] = bb.input_wires_arr(); + let tgt_ext = bb + .add_dataflow_op(Noop::new(ext_edge_type.clone()), [src_ext]) + .unwrap(); + + let tgt_dom = bb + .add_dataflow_op(Noop::new(dom_edge_type.clone()), [src_dom]) + .unwrap(); + ( + bb.finish_with_outputs(unit, [oo]).unwrap(), + tgt_ext, + tgt_dom, + ) + }; + + let bb_right = { + let mut bb = cfg + .block_builder( + vec![Type::UNIT, other_output_type.clone()].into(), + [type_row![]], + other_output_type.clone().into(), + ) + .unwrap(); + let [_b, oo] = bb.input_wires_arr(); + let unit = bb.add_load_value(Value::unit()); + bb.finish_with_outputs(unit, [oo]).unwrap() + }; + + let bb_bottom = { + let bb = cfg + .block_builder( + branch_type.clone().into(), + branch_variants, + other_output_type.clone().into(), + ) + .unwrap(); + let [oo] = bb.input_wires_arr(); + bb.finish_with_outputs(oo, [oo]).unwrap() + }; + cfg.branch(&entry, 0, &bb_left).unwrap(); + cfg.branch(&entry, 1, &bb_right).unwrap(); + cfg.branch(&bb_left, 0, &bb_bottom).unwrap(); + cfg.branch(&bb_right, 0, &bb_bottom).unwrap(); + cfg.branch(&bb_bottom, 0, &entry).unwrap(); + cfg.branch(&bb_bottom, 1, &exit).unwrap(); + let [unit, out] = cfg.finish_sub_container().unwrap().outputs_arr(); + + let mut hugr = outer.finish_hugr_with_outputs([unit, out]).unwrap(); + let Err(FindNonLocalEdgesError::Edges(es)) = ensure_no_nonlocal_edges(&hugr) else { + panic!() + }; + assert_eq!( + es, + vec![ + (tgt_ext.node(), IncomingPort::from(0)), + (tgt_dom.node(), IncomingPort::from(0)) + ] + ); + remove_nonlocal_edges(&mut hugr).unwrap(); + hugr.validate().unwrap(); + assert!(ensure_no_nonlocal_edges(&hugr).is_ok()); + let dfb = |bb: BasicBlockID| hugr.get_optype(bb.node()).as_dataflow_block().unwrap(); + // Entry node gets ext_edge_type added, only + assert_eq!( + dfb(entry).inputs[..], + [ext_edge_type.clone(), Type::UNIT, branch_type.clone()] + ); + // Left node gets both ext_edge_type and dom_edge_type + assert_eq!( + dfb(bb_left).inputs[..], + [ + ext_edge_type.clone(), + dom_edge_type, + Type::UNIT, + other_output_type + ] ); + // Bottom node gets ext_edge_type added, only + assert_eq!(dfb(bb_bottom).inputs[..], [ext_edge_type, branch_type]); } } diff --git a/hugr-passes/src/non_local/localize.rs b/hugr-passes/src/non_local/localize.rs new file mode 100644 index 000000000..15f5dedc0 --- /dev/null +++ b/hugr-passes/src/non_local/localize.rs @@ -0,0 +1,395 @@ +//! Implementation of [super::LocalizeEdgesPass] + +use std::collections::{BTreeMap, HashMap}; + +use hugr_core::{ + Direction, HugrView, IncomingPort, Wire, + builder::{ConditionalBuilder, Dataflow, DataflowSubContainer, HugrBuilder}, + core::HugrNode, + hugr::hugrmut::HugrMut, + ops::{DataflowOpTrait, OpType, Tag, TailLoop}, + types::{EdgeKind, Type, TypeRow}, +}; +use itertools::Itertools as _; + +use super::just_types; + +#[derive(Debug, Clone)] +// For each parent/container node, a map from the source Wires that need to be added +// as extra inputs to that container, to the Type of each. +pub(super) struct ExtraSourceReqs(BTreeMap, Type>>); + +impl Default for ExtraSourceReqs { + fn default() -> Self { + Self(BTreeMap::default()) + } +} + +impl ExtraSourceReqs { + fn insert(&mut self, node: N, source: Wire, ty: Type) -> bool { + self.0.entry(node).or_default().insert(source, ty).is_none() + } + + fn get(&self, node: N) -> impl Iterator, &Type)> + '_ { + self.0.get(&node).into_iter().flat_map(BTreeMap::iter) + } + + pub fn parent_needs(&self, parent: N, source: Wire) -> bool { + self.get(parent).any(|(w, _)| *w == source) + } + + /// Identify all required extra inputs (deals with both Dom and Ext edges). + /// Every intermediate node in the hierarchy + /// between the source's parent and the target needs that source. + pub fn add_edge( + &mut self, + hugr: &impl HugrView, + mut parent: N, + source: Wire, + ty: Type, + ) { + let source_parent = hugr.get_parent(source.node()).unwrap(); + while source_parent != parent { + debug_assert!(hugr.get_parent(parent).is_some()); + if !self.insert(parent, source, ty.clone()) { + break; + } + if hugr.get_optype(parent).is_conditional() { + // One of these we must have just done on the previous iteration + for case in hugr.children(parent) { + // Full recursion unnecessary as we've just added parent: + self.insert(case, source, ty.clone()); + } + } + // this will eventually panic if source_parent is not an ancestor of target + let parent_parent = hugr.get_parent(parent).unwrap(); + + if hugr.get_optype(parent).is_dataflow_block() { + assert!(hugr.get_optype(parent_parent).is_cfg()); + // For both Dom edges and Ext edges from outside the CFG, also add to all + // reaching BBs (for a Dom edge, up to but not including the source BB: + // all paths eventually come from the source since it dominates the target). + for pred in hugr.input_neighbours(parent).collect::>() { + self.add_edge(hugr, pred, source, ty.clone()); + } + if Some(parent) == hugr.children(parent_parent).next() { + // We've just added to entry node - so carry on and add to CFG as well + } else { + // Recursive calls on predecessors will have traced back to entry block + // (or source_parent itself if a dominating Basic Block) + break; + } + } + parent = parent_parent; + } + } + + /// Threads the extra connections required throughout the Hugr + pub(super) fn thread_hugr(&self, hugr: &mut impl HugrMut) { + self.thread_node(hugr, hugr.entrypoint(), &HashMap::new()) + } + + // keys of `locals` are the *original* sources of the non-local edges, in self.0. + fn thread_node( + &self, + hugr: &mut impl HugrMut, + node: N, + locals: &HashMap, Wire>, + ) { + if self.get(node).next().is_none() { + // No edges incoming into this subtree, but there could still be nonlocal edges internal to it + for ch in hugr.children(node).collect::>() { + self.thread_node(hugr, ch, &HashMap::new()) + } + return; + } + + let sources: Vec<(Wire, Type)> = self.get(node).map(|(w, t)| (*w, t.clone())).collect(); + let src_wires: Vec> = sources.iter().map(|(w, _)| *w).collect(); + + // `match` must deal with everything inside the node, and update the signature (per OpType) + let start_new_port_index = match hugr.optype_mut(node) { + OpType::DFG(dfg) => { + let ins = dfg.signature.input.to_mut(); + let start_new_port_index = ins.len(); + ins.extend(just_types(&sources)); + + self.thread_dataflow_parent(hugr, node, start_new_port_index, sources); + start_new_port_index + } + OpType::Conditional(cond) => { + let start_new_port_index = cond.signature().input.len(); + cond.other_inputs.to_mut().extend(just_types(&sources)); + + self.thread_conditional(hugr, node, sources); + start_new_port_index + } + OpType::TailLoop(tail_op) => { + vec_prepend(tail_op.just_inputs.to_mut(), just_types(&sources)); + self.thread_tailloop(hugr, node, sources); + 0 + } + OpType::CFG(cfg) => { + vec_prepend(cfg.signature.input.to_mut(), just_types(&sources)); + assert_eq!( + self.get(node).collect::>(), + self.get(hugr.children(node).next().unwrap()) + .collect::>() + ); // Entry node + for bb in hugr.children(node).collect::>() { + if hugr.get_optype(bb).is_dataflow_block() { + self.thread_bb(hugr, bb); + } + } + 0 + } + _ => panic!( + "All containers handled except Module/FuncDefn or root Case/DFB, which should not have incoming nonlocal edges" + ), + }; + + let new_dfg_ports = hugr.insert_ports( + node, + Direction::Incoming, + start_new_port_index, + src_wires.len(), + ); + let local_srcs = src_wires.into_iter().map(|w| *locals.get(&w).unwrap_or(&w)); + for (w, tgt_port) in local_srcs.zip_eq(new_dfg_ports) { + assert_eq!(hugr.get_parent(w.node()), hugr.get_parent(node)); + hugr.connect(w.node(), w.source(), node, tgt_port) + } + } + + // Add to Input node; assume container type already updated. + fn thread_dataflow_parent( + &self, + hugr: &mut impl HugrMut, + node: N, + start_new_port_index: usize, + srcs: Vec<(Wire, Type)>, + ) -> HashMap, Wire> { + let nlocals = if srcs.is_empty() { + HashMap::new() + } else { + let (srcs, tys): (Vec<_>, Vec) = srcs.into_iter().unzip(); + let [inp, _] = hugr.get_io(node).unwrap(); + let OpType::Input(in_op) = hugr.optype_mut(inp) else { + panic!("Expected Input node") + }; + vec_insert(in_op.types.to_mut(), tys, start_new_port_index); + let new_outports = + hugr.insert_ports(inp, Direction::Outgoing, start_new_port_index, srcs.len()); + + srcs.into_iter() + .zip_eq(new_outports) + .map(|(w, p)| (w, Wire::new(inp, p))) + .collect() + }; + for ch in hugr.children(node).collect::>() { + for (inp, _) in hugr.in_value_types(ch).collect::>() { + if let Some((src_n, src_p)) = hugr.single_linked_output(ch, inp) { + if hugr.get_parent(src_n) != Some(node) { + hugr.disconnect(ch, inp); + let new_p = nlocals.get(&Wire::new(src_n, src_p)).unwrap(); + hugr.connect(new_p.node(), new_p.source(), ch, inp); + } + } + } + self.thread_node(hugr, ch, &nlocals); + } + nlocals + } + + // Add to children (assuming conditional already updated). + fn thread_conditional( + &self, + hugr: &mut impl HugrMut, + node: N, + srcs: Vec<(Wire, Type)>, + ) { + for case in hugr.children(node).collect::>() { + let OpType::Case(case_op) = hugr.optype_mut(case) else { + continue; + }; + let ins = case_op.signature.input.to_mut(); + let start_case_port_index = ins.len(); + ins.extend(just_types(&srcs)); + self.thread_dataflow_parent(hugr, case, start_case_port_index, srcs.clone()); + } + } + + // Add to body of loop (assume TailLoop node itself already updated). + fn thread_tailloop( + &self, + hugr: &mut impl HugrMut, + node: N, + srcs: Vec<(Wire, Type)>, + ) { + let [_, o] = hugr.get_io(node).unwrap(); + let new_sum_row_prefixes = { + let mut v = vec![vec![]; 2]; + v[TailLoop::CONTINUE_TAG] = srcs.clone(); + v + }; + add_control_prefixes(hugr, o, new_sum_row_prefixes); + self.thread_dataflow_parent(hugr, node, 0, srcs); + } + + // Add to DataflowBlock *and* inner dataflow sibling subgraph + fn thread_bb(&self, hugr: &mut impl HugrMut, node: N) { + let OpType::DataflowBlock(this_dfb) = hugr.optype_mut(node) else { + panic!("Expected dataflow block") + }; + let my_inputs: Vec<_> = self.get(node).map(|(w, t)| (*w, t.clone())).collect(); + vec_prepend(this_dfb.inputs.to_mut(), just_types(&my_inputs)); + let locals = self.thread_dataflow_parent(hugr, node, 0, my_inputs); + let variant_source_prefixes: Vec, Type)>> = hugr + .output_neighbours(node) + .map(|succ| { + // The wires required for each successor block, should be available in the predecessor + self.get(succ) + .map(|(w, ty)| { + ( + if hugr.get_parent(w.node()) == Some(node) { + *w + } else { + *locals.get(w).unwrap() + }, + ty.clone(), + ) + }) + .collect() + }) + .collect(); + let OpType::DataflowBlock(this_dfb) = hugr.optype_mut(node) else { + panic!("It worked earlier!") + }; + for (source_prefix, sum_row) in variant_source_prefixes + .iter() + .zip_eq(this_dfb.sum_rows.iter_mut()) + { + vec_prepend(sum_row.to_mut(), just_types(source_prefix)); + } + let [_, output_node] = hugr.get_io(node).unwrap(); + add_control_prefixes(hugr, output_node, variant_source_prefixes); + } +} + +/// `variant_source_prefixes` are extra wires/types to prepend onto each variant +/// (must have one element per variant of control Sum) +fn add_control_prefixes( + hugr: &mut H, + output_node: H::Node, + variant_source_prefixes: Vec, Type)>>, +) { + debug_assert!(hugr.get_optype(output_node).is_output()); // Just to fail fast + let parent = hugr.get_parent(output_node).unwrap(); + let mut needed_sources = BTreeMap::new(); + let (cond, new_control_type) = { + let Some(EdgeKind::Value(control_type)) = hugr + .get_optype(output_node) + .port_kind(IncomingPort::from(0)) + else { + panic!("impossible") + }; + let Some(sum_type) = control_type.as_sum() else { + panic!("impossible") + }; + + let mut type_for_source = |source: &(Wire, Type)| { + let (w, t) = source; + let replaced = needed_sources.insert(*w, (*w, t.clone())); + debug_assert!(!replaced.is_some_and(|x| x != (*w, t.clone()))); + t.clone() + }; + let old_sum_rows: Vec = sum_type + .variants() + .map(|x| x.clone().try_into().unwrap()) + .collect_vec(); + let new_sum_rows: Vec = + itertools::zip_eq(variant_source_prefixes.iter(), old_sum_rows.iter()) + .map(|(new_sources, old_tys)| { + new_sources + .iter() + .map(&mut type_for_source) + .chain(old_tys.iter().cloned()) + .collect_vec() + .into() + }) + .collect_vec(); + + let new_control_type = Type::new_sum(new_sum_rows.clone()); + let mut cond = ConditionalBuilder::new( + old_sum_rows.clone(), + just_types(needed_sources.values()).collect_vec(), + new_control_type.clone(), + ) + .unwrap(); + for (i, new_sources) in variant_source_prefixes.into_iter().enumerate() { + let mut case = cond.case_builder(i).unwrap(); + let case_inputs = case.input_wires().collect_vec(); + let mut args = new_sources + .into_iter() + .map(|(s, _ty)| { + case_inputs[old_sum_rows[i].len() + + needed_sources + .iter() + .find_position(|(w, _)| **w == s) + .unwrap() + .0] + }) + .collect_vec(); + args.extend(&case_inputs[..old_sum_rows[i].len()]); + let case_outputs = case + .add_dataflow_op(Tag::new(i, new_sum_rows.clone()), args) + .unwrap() + .outputs(); + case.finish_with_outputs(case_outputs).unwrap(); + } + (cond.finish_hugr().unwrap(), new_control_type) + }; + let cond_node = hugr.insert_hugr(parent, cond).inserted_entrypoint; + let (old_output_source_node, old_output_source_port) = + hugr.single_linked_output(output_node, 0).unwrap(); + debug_assert_eq!(hugr.get_parent(old_output_source_node).unwrap(), parent); + hugr.connect(old_output_source_node, old_output_source_port, cond_node, 0); + for (i, &(w, _)) in needed_sources.values().enumerate() { + hugr.connect(w.node(), w.source(), cond_node, i + 1); + } + hugr.disconnect(output_node, IncomingPort::from(0)); + hugr.connect(cond_node, 0, output_node, 0); + let OpType::Output(output) = hugr.optype_mut(output_node) else { + panic!("impossible") + }; + output.types.to_mut()[0] = new_control_type; +} + +fn vec_prepend(v: &mut Vec, ts: impl IntoIterator) { + vec_insert(v, ts, 0) +} + +fn vec_insert(v: &mut Vec, ts: impl IntoIterator, index: usize) { + let mut old_v_iter = std::mem::take(v).into_iter(); + v.extend(old_v_iter.by_ref().take(index).chain(ts)); + v.extend(old_v_iter); +} + +#[cfg(test)] +mod test { + use super::vec_insert; + + #[test] + fn vec_insert0() { + let mut v = vec![5, 7, 9]; + vec_insert(&mut v, [1, 2], 0); + assert_eq!(v, [1, 2, 5, 7, 9]); + } + + #[test] + fn vec_insert1() { + let mut v = vec![5, 7, 9]; + vec_insert(&mut v, [1, 2], 1); + assert_eq!(v, [5, 1, 2, 7, 9]); + } +}