diff --git a/Cargo.toml b/Cargo.toml index 808538b..0ab6576 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "orx-tree" -version = "1.4.0" +version = "1.5.0" edition = "2024" authors = ["orxfun "] description = "A beautiful tree 🌳 with convenient and efficient growth, mutation and traversal features." diff --git a/README.md b/README.md index a298561..b6662cd 100644 --- a/README.md +++ b/README.md @@ -375,6 +375,10 @@ let remaining_bfs: Vec<_> = tree.root().walk::().copied().collect(); assert_eq!(remaining_bfs, [1, 3, 6, 9]); ``` +### More Examples + +* [mutable_recursive_traversal](https://github.com/orxfun/orx-tree/blob/main/examples/mutable_recursive_traversal.rs) demonstrates different approaches to achieve a recursive mutation of all nodes in the tree. + ## Contributing Contributions are welcome! If you notice an error, have a question or think something could be added or improved, please open an [issue](https://github.com/orxfun/orx-tree/issues/new) or create a PR. diff --git a/examples/mutable_recursive_traversal.rs b/examples/mutable_recursive_traversal.rs new file mode 100644 index 0000000..b87c225 --- /dev/null +++ b/examples/mutable_recursive_traversal.rs @@ -0,0 +1,238 @@ +// # EXAMPLE DEFINITION +// +// cargo run --example mutable_recursive_traversal +// +// This example demonstrates a use case where value of a node is defined +// as a function of the values of its children. Since the value of a child +// of the node also depends on values of its own children, it follows that +// the value of a node is a function of values of all of its descendants. +// +// The task is to compute and set all values of a tree given the values of +// the leaves. +// +// This is a interesting and common case in terms of requiring mutable +// recursive traversal over the tree that can be handled with different +// approaches. Some of these are demonstrated in this example. + +use orx_tree::*; +use std::fmt::Display; + +#[derive(Debug, Clone, Copy)] +enum Instruction { + Input(usize), + Add, + AddI { val: f32 }, +} + +impl Display for Instruction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Input(x) => write!(f, "Input({})", x), + Self::Add => write!(f, "Add"), + Self::AddI { val } => write!(f, "AddI({})", val), + } + } +} + +#[derive(Debug)] +struct InstructionNode { + instruction: Instruction, + value: f32, +} + +impl InstructionNode { + fn new(instruction: Instruction, value: f32) -> Self { + Self { instruction, value } + } +} + +impl Display for InstructionNode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self.instruction { + Instruction::Input(x) => write!(f, "Input({}) => {}", x, self.value), + Instruction::Add => write!(f, "Add => {}", self.value), + Instruction::AddI { val } => write!(f, "AddI({}) => {}", val, self.value), + } + } +} + +#[derive(Debug)] +struct Instructions { + tree: DynTree, +} + +impl Instructions { + fn example() -> Self { + let mut tree = DynTree::new(InstructionNode::new(Instruction::AddI { val: 100.0 }, 0.0)); + + let mut n0 = tree.root_mut(); + let [n1, n2] = n0.push_children([ + InstructionNode::new(Instruction::Input(1), 0.0), + InstructionNode::new(Instruction::AddI { val: 2.0 }, 0.0), + ]); + let _n3 = tree + .node_mut(&n1) + .push_child(InstructionNode::new(Instruction::Input(0), 0.0)); + let [_n4, _n5] = tree.node_mut(&n2).push_children([ + InstructionNode::new(Instruction::Add, 0.0), + InstructionNode::new(Instruction::AddI { val: 5.0 }, 0.0), + ]); + + Self { tree } + } +} + +/// Demonstrates manual mutable and recursive traversal over the tree. +/// +/// Notice that we can freely walk the tree while always having a single +/// mutable reference to one node. This satisfies the borrow checker rules +/// and further allows for calling the function recursively. +/// +/// Note also that, although it is not necessary in this scenario, we are +/// free to change the shape of the tree during our walk by adding nodes, +/// moving around or pruning subtrees, etc. In other words, it enables the +/// greatest freedom while it requires us to make sure that we do not have +/// errors, such as out-of-bounds errors with the `into_child_mut` call. +/// +/// * Pros +/// * Complete freedom to mutate the nodes and the tree structure during +/// the walk. +/// * No intermediate allocation is required; borrow checker rules are +/// satisfied without the need to collect indices. +/// * Cons +/// * Implementor is required to define the walk. This example demonstrates +/// a depth-first walk due to the recursive calls, which is straightforward +/// to implement. +/// * Due to lack of tail-call optimization in rust, this function is likely +/// to encounter stack overflow for very deep trees. +fn recursive_traversal_over_nodes<'a>( + inputs: &[f32], + mut node: NodeMut<'a, Dyn>, +) -> (NodeMut<'a, Dyn>, f32) { + let num_children = node.num_children(); + + let mut children_sum = 0.0; + for i in 0..num_children { + let child = node.into_child_mut(i).unwrap(); + let (child, child_value) = recursive_traversal_over_nodes(inputs, child); + children_sum += child_value; + node = child.into_parent_mut().unwrap(); + } + + let new_value = match node.data().instruction { + Instruction::Input(i) => inputs[i], + Instruction::Add => children_sum, + Instruction::AddI { val } => val + children_sum, + }; + + (*node.data_mut()).value = new_value; + + (node, new_value) +} + +/// Demonstrates recursive mutable traversal by internally collecting and storing +/// the child node indices. +/// +/// This simplifies the borrow relations and allows for the recursive calls only +/// having a single mutable reference to the tree; however, each recursive call +/// requires an internal allocation. +/// +/// * Pros +/// * Complete freedom to mutate the nodes and the tree structure during +/// the walk. +/// * Cons +/// * Requires to collect indices and results into an internal vector for each +/// recursive call, requiring additional allocation. +/// * Implementor is required to define the walk. This example demonstrates +/// a depth-first walk due to the recursive calls, which is straightforward +/// to implement. +/// * Due to lack of tail-call optimization in rust, this function is likely +/// to encounter stack overflow for very deep trees. +fn recursive_traversal_over_indices( + tree: &mut DynTree, + inputs: &[f32], + node_idx: NodeIdx>, +) -> f32 { + let node = tree.node(&node_idx); + + let children_ids: Vec<_> = node.children().map(|child| child.idx()).collect(); + let children: Vec<_> = children_ids + .into_iter() + .map(|node| recursive_traversal_over_indices(tree, inputs, node)) + .collect(); + + let mut node = tree.node_mut(&node_idx); + + let new_value = match node.data().instruction { + Instruction::Input(i) => inputs[i], + Instruction::Add => children.into_iter().sum(), + Instruction::AddI { val } => children.into_iter().sum::() + val, + }; + (*node.data_mut()).value = new_value; + + new_value +} + +/// Demonstrates the use of [`recursive_set`] method: +/// +/// *Recursively sets the data of all nodes belonging to the subtree rooted +/// at this node using the compute_data function.* +/// +/// This function fits perfectly to this and similar scenarios where we want +/// to compute values of all nodes of a tree such that the value of a node +/// depends on the values of all of its descendants, and hence the name +/// *recursive*. +/// +/// * Pros +/// * More expressive in the sense that the implementor only defines how the +/// value of a node should be computed given its prior value and values of +/// its children. Iteration is abstracted away. +/// * Despite the name, the implementation actually does not require recursive +/// function calls; and hence, can work with trees of arbitrary depth without +/// the risk of stack overflow. Instead, it internally uses the [`PostOrder`] +/// traverser. +/// * Cons +/// * It only allows to set the data of the nodes; however, does not allow for +/// structural mutations. +/// +/// [`recursive_set`]: orx_tree::NodeMut::recursive_set +/// [`PostOrder`]: orx_tree::PostOrder +fn recursive_set(inputs: &[f32], mut node: NodeMut>) { + node.recursive_set(|node_data, children_data| { + let instruction = node_data.instruction; + let children_sum: f32 = children_data.iter().map(|x| x.value).sum(); + let value = match node_data.instruction { + Instruction::Input(i) => inputs[i], + Instruction::Add => children_sum, + Instruction::AddI { val } => val + children_sum, + }; + + InstructionNode { instruction, value } + }); +} + +fn main() { + fn test_implementation(method: &str, f: impl FnOnce(&[f32], &mut Instructions)) { + let inputs = [10.0, 20.0]; + let mut instructions = Instructions::example(); + println!("\n\n### {}", method); + f(&inputs, &mut instructions); + println!("\n{}\n", &instructions.tree); + } + + test_implementation( + "recursive_traversal_over_indices", + |inputs, instructions| { + let root_idx = instructions.tree.root().idx(); + recursive_traversal_over_indices(&mut instructions.tree, inputs, root_idx); + }, + ); + + test_implementation("recursive_traversal_over_nodes", |inputs, instructions| { + recursive_traversal_over_nodes(&inputs, instructions.tree.root_mut()); + }); + + test_implementation("recursive_set", |inputs, instructions| { + recursive_set(inputs, instructions.tree.root_mut()); + }); +} diff --git a/src/node_mut.rs b/src/node_mut.rs index ac083d2..9c463f2 100644 --- a/src/node_mut.rs +++ b/src/node_mut.rs @@ -1,5 +1,5 @@ use crate::{ - NodeIdx, NodeRef, SubTree, Traverser, Tree, TreeVariant, + NodeIdx, NodeRef, PostOrder, SubTree, Traverser, Tree, TreeVariant, aliases::{Col, N}, iter::ChildrenMutIter, memory::{Auto, MemoryPolicy}, @@ -10,12 +10,15 @@ use crate::{ traversal::{ OverData, OverMut, enumerations::Val, + over::OverPtr, over_mut::{OverItemInto, OverItemMut}, post_order::iter_ptr::PostOrderIterPtr, + traverser_core::TraverserCore, }, tree_node_idx::INVALID_IDX_ERROR, tree_variant::RefsChildren, }; +use alloc::vec::Vec; use core::{fmt::Debug, marker::PhantomData}; use orx_selfref_col::{NodePtr, Refs}; @@ -2455,6 +2458,173 @@ where traverser.into_iter(self) } + // recursive + + /// Recursively sets the data of all nodes belonging to the subtree rooted at this node using the `compute_data` + /// function. + /// + /// This method provides an expressive way to update the values of a tree where value of a node is a function of + /// its prior value and values of its children. Since the values of its children subsequently depend on their own + /// children, it immediately follows that the value of the node depends on values of all of its descendants that + /// must be computed to be able to compute the node's value. + /// + /// The `compute_data` function takes two arguments: + /// + /// * current value (data) of this node, and + /// * slice of values of children of this node that are computed recursively using `compute_data` (*); + /// + /// and then, computes the new value of this node. + /// + /// The method is named *recursive* (*) due to the fact that, + /// + /// * before computing the value of this node; + /// * values of all of its children are also computed and set using the `compute_data` function. + /// + /// *Note that this method does **not** actually make recursive method calls. Instead, it internally uses the [`PostOrder`] + /// traverser which ensures that all required values are computed before they are used for another computation. This + /// is a guard against potential stack overflow issues, and hence, can be used for trees of arbitrary depth.* + /// + /// [`PostOrder`]: crate::PostOrder + /// + /// # Examples + /// + /// In the following example, we set the value of every node to the sum of values of all its descendants. + /// + /// While building the tree, we set only the values of the leaves. + /// We initially set values of all other nodes to zero as a placeholder. + /// Then, we call `recursive_set` to compute them. + /// + /// ``` + /// use orx_tree::*; + /// + /// let mut tree = DynTree::<_>::new(0); + /// let [id1, id2] = tree.root_mut().push_children([0, 0]); + /// tree.node_mut(&id1).push_children([1, 3]); + /// tree.node_mut(&id2).push_children([7, 2, 4]); + /// // 0 + /// // ╱ ╲ + /// // ╱ ╲ + /// // 0 0 + /// // ╱ ╲ ╱|╲ + /// // 1 3 7 2 4 + /// + /// tree.root_mut() + /// .recursive_set( + /// |current_value, children_values| match children_values.is_empty() { + /// true => *current_value, // is a leaf + /// false => children_values.iter().copied().sum(), + /// }, + /// ); + /// // 17 + /// // ╱ ╲ + /// // ╱ ╲ + /// // 4 13 + /// // ╱ ╲ ╱|╲ + /// // 1 3 7 2 4 + /// + /// let bfs: Vec<_> = tree.root().walk::().copied().collect(); + /// assert_eq!(bfs, [17, 4, 13, 1, 3, 7, 2, 4]); + /// ``` + /// + /// The following is a similar example where leaf nodes represent deterministic outcomes of + /// a process. + /// The root represents the current state. + /// The remaining nodes represent intermediate states that we can reach from its parent with + /// the given `probability`. + /// Our task is to compute `expected_value` of each state. + /// + /// Since we know the value of the leaves with certainty, we set them while constructing the + /// tree. Then, we call `recursive_set` to compute the expected value of every other node. + /// + /// ``` + /// use orx_tree::*; + /// + /// #[derive(Clone)] + /// struct State { + /// /// Probability of reaching this state from its parent. + /// probability: f64, + /// /// Expected value of the state; i.e., average of values of all leaves weighted by + /// /// the probability of being reached from this state. + /// expected_value: f64, + /// } + /// + /// fn state(probability: f64, expected_value: f64) -> State { + /// State { + /// probability, + /// expected_value, + /// } + /// } + /// + /// // (1.0, ???) + /// // ╱ ╲ + /// // ╱ ╲ + /// // ╱ ╲ + /// // ╱ ╲ + /// // (.3, ???) (.7, ???) + /// // ╱ ╲ | ╲ + /// // ╱ ╲ | ╲ + /// // (.2, 9) (.8, 2) (.9, 5) (.1, 4) + /// + /// let mut tree = DynTree::<_>::new(state(1.0, 0.0)); + /// + /// let [id1, id2] = tree + /// .root_mut() + /// .push_children([state(0.3, 0.0), state(0.7, 0.0)]); + /// tree.node_mut(&id1) + /// .push_children([state(0.2, 9.0), state(0.8, 2.0)]); + /// tree.node_mut(&id2) + /// .push_children([state(0.9, 5.0), state(0.1, 4.0)]); + /// + /// tree.root_mut() + /// .recursive_set( + /// |current_value, children_values| match children_values.is_empty() { + /// true => current_value.clone(), // is a leaf, we know expected value + /// false => { + /// let expected_value = children_values + /// .iter() + /// .fold(0.0, |a, x| a + x.probability * x.expected_value); + /// state(current_value.probability, expected_value) + /// } + /// }, + /// ); + /// // (1.0, 4.45) + /// // ╱ ╲ + /// // ╱ ╲ + /// // ╱ ╲ + /// // ╱ ╲ + /// // (.3, 3.4) (.7, 4.9) + /// // ╱ ╲ | ╲ + /// // ╱ ╲ | ╲ + /// // (.2, 9) (.8, 2) (.9, 5) (.1, 4) + /// + /// let equals = |a: f64, b: f64| (a - b).abs() < 1e-5; + /// + /// assert!(equals(tree.root().data().expected_value, 4.45)); + /// assert!(equals(tree.node(&id1).data().expected_value, 3.40)); + /// assert!(equals(tree.node(&id2).data().expected_value, 4.90)); + /// ``` + #[allow(clippy::missing_panics_doc)] + pub fn recursive_set(&mut self, compute_data: impl Fn(&V::Item, &[&V::Item]) -> V::Item) { + let iter = PostOrder::::iter_ptr_with_owned_storage(self.node_ptr.clone()); + let mut children_data = Vec::<&V::Item>::new(); + + for ptr in iter { + let node = unsafe { &mut *ptr.ptr_mut() }; + let node_data = node.data().expect("is not closed"); + + for child_ptr in node.next().children_ptr() { + let data = unsafe { &*child_ptr.ptr() }.data().expect("is not closed"); + children_data.push(data); + } + + let new_data = compute_data(node_data, &children_data); + + *node.data_mut().expect("is not closed") = new_data; + + children_data.clear(); + } + } + // subtree /// Creates a subtree view including this node as the root and all of its descendants with their orientation relative diff --git a/src/traversal/post_order/traverser.rs b/src/traversal/post_order/traverser.rs index 46648c7..09e8042 100644 --- a/src/traversal/post_order/traverser.rs +++ b/src/traversal/post_order/traverser.rs @@ -13,6 +13,7 @@ use core::marker::PhantomData; /// # Construction /// /// A post order traverser can be created, +/// /// * either by using Default trait and providing its two generic type parameters /// * `PostOrder::<_, OverData>::default()` or `PostOrder::<_, OverDepthSiblingIdxData>::default()`, or /// * `PostOrder::, OverData>::default()` or `PostOrder::, OverDepthSiblingIdxData>::default()`