@@ -1027,14 +1027,10 @@ def shape_scale_from_mean_var(mean, var):
1027
1027
1028
1028
def _truncate_priors (ts , priors , progress = False ):
1029
1029
"""
1030
- Truncate priors for the nodes listed in truncate_nodes (or all nonfixed nodes
1031
- if truncate_nodes is None) so they conform to the age of fixed nodes in the tree
1032
- sequence
1030
+ Truncate priors for all nonfixed nodes
1031
+ so they conform to the age of fixed nodes in the tree sequence
1033
1032
"""
1034
1033
tables = ts .tables
1035
- truncate_nodes = priors .nonfixed_node_ids ()
1036
- # ensure truncate_nodes is ordered by node time
1037
- truncate_nodes = truncate_nodes [np .argsort (tables .nodes .time [truncate_nodes ])]
1038
1034
1039
1035
fixed_nodes = priors .fixed_node_ids ()
1040
1036
fixed_times = tables .nodes .time [fixed_nodes ]
@@ -1050,29 +1046,29 @@ def _truncate_priors(ts, priors, progress=False):
1050
1046
constrained_min_times = np .zeros_like (tables .nodes .time )
1051
1047
# Set the min times of fixed nodes to those in the tree sequence
1052
1048
constrained_min_times [fixed_nodes ] = fixed_times
1053
- constrained_max_times = np .full_like (constrained_min_times , np .inf )
1054
-
1055
- parents = tables .edges .parent
1056
- nd_children = tables .edges .child [np .argsort (parents )]
1057
- parents = sorted (parents )
1058
- parents_unique = np .unique (parents , return_index = True )
1059
- parent_indices = parents_unique [1 ][np .isin (parents_unique [0 ], truncate_nodes )]
1060
- for index , nd in tqdm (
1061
- enumerate (truncate_nodes ), desc = "Constrain Ages" , disable = not progress
1049
+
1050
+ # Traverse through the ARG, ensuring children come before parents.
1051
+ # This can be done by iterating over groups of edges with the same parent
1052
+ new_parent_edge_idx = np .concatenate (
1053
+ (
1054
+ [0 ],
1055
+ np .where (np .diff (tables .edges .parent ) != 0 )[0 ] + 1 ,
1056
+ [tables .edges .num_rows ],
1057
+ )
1058
+ )
1059
+ for edges_start , edges_end in zip (
1060
+ new_parent_edge_idx [:- 1 ], new_parent_edge_idx [1 :]
1062
1061
):
1063
- if index + 1 != len (truncate_nodes ):
1064
- children_index = np .arange (parent_indices [index ], parent_indices [index + 1 ])
1065
- else :
1066
- children_index = np .arange (parent_indices [index ], ts .num_edges )
1067
- children = nd_children [children_index ]
1068
- time = np .max (constrained_min_times [children ])
1069
- # The constrained time of the node should be the age of the oldest child
1070
- if constrained_min_times [nd ] <= time :
1071
- constrained_min_times [nd ] = time
1072
- nearest_time = np .argmin (np .abs (timepoints - time ))
1073
- lookup_index = priors .row_lookup [int (nd )]
1074
- grid_data [lookup_index ][:nearest_time ] = zero_value
1075
- assert np .all (constrained_min_times < constrained_max_times )
1062
+ parent = tables .edges .parent [edges_start ]
1063
+ child_ids = tables .edges .child [edges_start :edges_end ] # May contain dups
1064
+ oldest_child_time = np .max (constrained_min_times [child_ids ])
1065
+ if oldest_child_time > constrained_min_times [parent ]:
1066
+ constrained_min_times [parent ] = oldest_child_time
1067
+ if constrained_min_times [parent ] > 0 :
1068
+ # What if the parent here is a fixed node?
1069
+ nearest_time = np .argmin (np .abs (timepoints - constrained_min_times [parent ]))
1070
+ lookup_index = priors .row_lookup [parent ]
1071
+ grid_data [lookup_index ][:nearest_time ] = zero_value
1076
1072
1077
1073
rowmax = grid_data [:, 1 :].max (axis = 1 )
1078
1074
if priors .probability_space == "linear" :
0 commit comments