Skip to content

Commit 02a9b67

Browse files
committed
Use new algo for truncating priors
1 parent da58644 commit 02a9b67

File tree

1 file changed

+24
-28
lines changed

1 file changed

+24
-28
lines changed

tsdate/prior.py

Lines changed: 24 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1027,14 +1027,10 @@ def shape_scale_from_mean_var(mean, var):
10271027

10281028
def _truncate_priors(ts, priors, progress=False):
10291029
"""
1030-
Truncate priors for the nodes listed in truncate_nodes (or all nonfixed nodes
1031-
if truncate_nodes is None) so they conform to the age of fixed nodes in the tree
1032-
sequence
1030+
Truncate priors for all nonfixed nodes
1031+
so they conform to the age of fixed nodes in the tree sequence
10331032
"""
10341033
tables = ts.tables
1035-
truncate_nodes = priors.nonfixed_node_ids()
1036-
# ensure truncate_nodes is ordered by node time
1037-
truncate_nodes = truncate_nodes[np.argsort(tables.nodes.time[truncate_nodes])]
10381034

10391035
fixed_nodes = priors.fixed_node_ids()
10401036
fixed_times = tables.nodes.time[fixed_nodes]
@@ -1050,29 +1046,29 @@ def _truncate_priors(ts, priors, progress=False):
10501046
constrained_min_times = np.zeros_like(tables.nodes.time)
10511047
# Set the min times of fixed nodes to those in the tree sequence
10521048
constrained_min_times[fixed_nodes] = fixed_times
1053-
constrained_max_times = np.full_like(constrained_min_times, np.inf)
1054-
1055-
parents = tables.edges.parent
1056-
nd_children = tables.edges.child[np.argsort(parents)]
1057-
parents = sorted(parents)
1058-
parents_unique = np.unique(parents, return_index=True)
1059-
parent_indices = parents_unique[1][np.isin(parents_unique[0], truncate_nodes)]
1060-
for index, nd in tqdm(
1061-
enumerate(truncate_nodes), desc="Constrain Ages", disable=not progress
1049+
1050+
# Traverse through the ARG, ensuring children come before parents.
1051+
# This can be done by iterating over groups of edges with the same parent
1052+
new_parent_edge_idx = np.concatenate(
1053+
(
1054+
[0],
1055+
np.where(np.diff(tables.edges.parent) != 0)[0] + 1,
1056+
[tables.edges.num_rows],
1057+
)
1058+
)
1059+
for edges_start, edges_end in zip(
1060+
new_parent_edge_idx[:-1], new_parent_edge_idx[1:]
10621061
):
1063-
if index + 1 != len(truncate_nodes):
1064-
children_index = np.arange(parent_indices[index], parent_indices[index + 1])
1065-
else:
1066-
children_index = np.arange(parent_indices[index], ts.num_edges)
1067-
children = nd_children[children_index]
1068-
time = np.max(constrained_min_times[children])
1069-
# The constrained time of the node should be the age of the oldest child
1070-
if constrained_min_times[nd] <= time:
1071-
constrained_min_times[nd] = time
1072-
nearest_time = np.argmin(np.abs(timepoints - time))
1073-
lookup_index = priors.row_lookup[int(nd)]
1074-
grid_data[lookup_index][:nearest_time] = zero_value
1075-
assert np.all(constrained_min_times < constrained_max_times)
1062+
parent = tables.edges.parent[edges_start]
1063+
child_ids = tables.edges.child[edges_start:edges_end] # May contain dups
1064+
oldest_child_time = np.max(constrained_min_times[child_ids])
1065+
if oldest_child_time > constrained_min_times[parent]:
1066+
constrained_min_times[parent] = oldest_child_time
1067+
if constrained_min_times[parent] > 0:
1068+
# What if the parent here is a fixed node?
1069+
nearest_time = np.argmin(np.abs(timepoints - constrained_min_times[parent]))
1070+
lookup_index = priors.row_lookup[parent]
1071+
grid_data[lookup_index][:nearest_time] = zero_value
10761072

10771073
rowmax = grid_data[:, 1:].max(axis=1)
10781074
if priors.probability_space == "linear":

0 commit comments

Comments
 (0)