diff --git a/examples/tree_traversals.rs b/examples/tree_traversals.rs index 5cd89faa3..1d4c58743 100644 --- a/examples/tree_traversals.rs +++ b/examples/tree_traversals.rs @@ -21,7 +21,19 @@ fn traverse_upwards_with_iterator(tree: &tskit::Tree) { } fn preorder_traversal(tree: &tskit::Tree) { - for _ in tree.traverse_nodes(tskit::NodeTraversalOrder::Preorder) {} + // Iterate over nodes. + // For preorder traversal, this avoids allocation. + // (But we collect the data for this example, which does allocate.) + let nodes_from_iter = tree + .traverse_nodes(tskit::NodeTraversalOrder::Preorder) + .collect::>(); + // Get a COPY of all nodes as a boxed slice + let nodes_as_slice = tree.nodes(tskit::NodeTraversalOrder::Preorder).unwrap(); + assert_eq!(nodes_as_slice.len(), nodes_from_iter.len()); + nodes_from_iter + .iter() + .zip(nodes_as_slice.iter()) + .for_each(|(i, j)| assert_eq!(i, j)); } #[derive(clap::Parser)] diff --git a/src/sys/tree.rs b/src/sys/tree.rs index e79628214..84b51cf94 100644 --- a/src/sys/tree.rs +++ b/src/sys/tree.rs @@ -156,6 +156,46 @@ impl<'treeseq> LLTree<'treeseq> { } } + pub fn nodes(&self, order: NodeTraversalOrder) -> Result, TskitError> { + let mut nodes: Vec = vec![ + NodeId::NULL; + unsafe { super::bindings::tsk_tree_get_size_bound(self.as_ll_ref()) } + as usize + ]; + + let mut num_nodes: super::bindings::tsk_size_t = 0; + let ptr = std::ptr::addr_of_mut!(num_nodes); + unsafe { + super::bindings::tsk_tree_preorder( + self.as_ll_ref(), + nodes.as_mut_ptr() as *mut super::bindings::tsk_id_t, + ptr, + ); + } + + let code = match order { + NodeTraversalOrder::Preorder => unsafe { + super::bindings::tsk_tree_preorder( + self.as_ll_ref(), + nodes.as_mut_ptr() as *mut super::bindings::tsk_id_t, + ptr, + ) + }, + NodeTraversalOrder::Postorder => unsafe { + super::bindings::tsk_tree_preorder( + self.as_ll_ref(), + nodes.as_mut_ptr() as *mut super::bindings::tsk_id_t, + ptr, + ) + }, + }; + if code == 0 { + nodes.resize(num_nodes as usize, NodeId::NULL); + } + + handle_tsk_return_value!(code, nodes.into_boxed_slice()) + } + pub fn children(&self, u: NodeId) -> impl Iterator + '_ { NodeIteratorAdapter(ChildIterator::new(self, u)) } diff --git a/src/trees/tree.rs b/src/trees/tree.rs index 5891acd54..ee06b677a 100644 --- a/src/trees/tree.rs +++ b/src/trees/tree.rs @@ -313,6 +313,12 @@ impl<'treeseq> Tree<'treeseq> { self.inner.traverse_nodes(order) } + pub fn nodes( + &self, + order: crate::NodeTraversalOrder, + ) -> Result, TskitError> { + } + /// Return an [`Iterator`] over the children of node `u`. /// # Returns ///