@@ -151,7 +151,7 @@ def _lik(muts, span, dt, mutation_rate, normalize=True):
         """
         ll = scipy.stats.poisson.pmf(muts, dt * mutation_rate * span)
         if normalize:
-            return ll / np.max(ll)
+            return ll / np.nanmax(ll)
         else:
             return ll

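The switch from `np.max` to `np.nanmax` matters once edges to non-contemporaneous samples can yield NaN likelihoods: `scipy.stats.poisson.pmf` returns NaN for a negative mean, and `np.max` would then turn every normalized entry into NaN. A minimal standalone sketch with made-up numbers (not tsdate code):

```python
import numpy as np
import scipy.stats

muts, span, rate = 3, 100.0, 1e-4
dt = np.array([-5.0, 1.0, 50.0, 200.0])  # a negative time delta, as can arise below
ll = scipy.stats.poisson.pmf(muts, dt * rate * span)

print(ll)                  # first entry is nan (negative Poisson mean)
print(ll / np.max(ll))     # max is nan, so every entry becomes nan
print(ll / np.nanmax(ll))  # only the invalid entry stays nan
```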
@@ -258,15 +258,28 @@ def get_mut_lik_fixed_node(self, edge):

         mutations_on_edge = self.mut_edges[edge.id]
         child_time = self.ts.node(edge.child).time
-        #assert child_time == 0
-        # Temporary hack - we should really take a more precise likelihood
-        return self._lik(
-            mutations_on_edge,
-            edge.span,
-            self.timediff,
-            self.mut_rate,
-            normalize=self.normalize,
-        )
+        if child_time == 0:
+            return self._lik(
+                mutations_on_edge,
+                edge.span,
+                self.timediff,
+                self.mut_rate,
+                normalize=self.normalize,
+            )
+        else:
+            timediff = self.timepoints - child_time + 1e-8
+            # Temporary hack - we should really take a more precise likelihood
+            likelihood = self._lik(
+                mutations_on_edge,
+                edge.span,
+                timediff,
+                self.mut_rate,
+                normalize=self.normalize,
+            )
+            # Prevent child from being older than parent
+            likelihood[timediff < 0] = 0
+
+            return likelihood

     def get_mut_lik_lower_tri(self, edge):
         """
@@ -389,7 +402,7 @@ def get_fixed(self, arr, edge):
         return arr * liks

     def scale_geometric(self, fraction, value):
-        return value ** fraction
+        return value**fraction


 class LogLikelihoods(Likelihoods):
@@ -429,7 +442,7 @@ def _lik(muts, span, dt, mutation_rate, normalize=True):
         """
         ll = scipy.stats.poisson.logpmf(muts, dt * mutation_rate * span)
         if normalize:
-            return ll - np.max(ll)
+            return ll - np.nanmax(ll)
         else:
             return ll

@@ -634,11 +647,22 @@ def inside_pass(self, *, normalize=True, cache_inside=False, progress=None):
         inside = self.priors.clone_with_new_data(  # store inside matrix values
             grid_data=np.nan, fixed_data=self.lik.identity_constant
         )
+        # It is possible that a simple node is non-fixed, in which case we want to
+        # provide an inside array that reflects the prior distribution
+        nonfixed_samples = np.intersect1d(inside.nonfixed_node_ids(), self.ts.samples())
+        for u in nonfixed_samples:
+            # this is in the same probability space as the prior, so we should be
+            # OK just to copy the prior values straight in. It's unclear to me (Yan)
+            # how/if they should be normalised, however
+            inside[u][:] = self.priors[u]
+
         if cache_inside:
             g_i = np.full(
                 (self.ts.num_edges, self.lik.grid_size), self.lik.identity_constant
             )
         norm = np.full(self.ts.num_nodes, np.nan)
+        to_visit = np.zeros(self.ts.num_nodes, dtype=bool)
+        to_visit[inside.nonfixed_node_ids()] = True
         # Iterate through the nodes via groupby on parent node
         for parent, edges in tqdm(
             self.edges_by_parent_asc(),
@@ -673,14 +697,23 @@ def inside_pass(self, *, normalize=True, cache_inside=False, progress=None):
                             "dangling nodes: please simplify it"
                         )
                     daughter_val = self.lik.scale_geometric(
-                        spanfrac, self.lik.make_lower_tri(inside[edge.child])
+                        spanfrac, self.lik.make_lower_tri(inside_values)
                     )
                     edge_lik = self.lik.get_inside(daughter_val, edge)
                 val = self.lik.combine(val, edge_lik)
+                if np.all(val == 0):
+                    raise ValueError
                 if cache_inside:
                     g_i[edge.id] = edge_lik
-            norm[parent] = np.max(val) if normalize else 1
+            norm[parent] = np.max(val) if normalize else self.lik.identity_constant
             inside[parent] = self.lik.reduce(val, norm[parent])
+            to_visit[parent] = False
+
+        # There may be nodes that are not parents but are also not fixed (e.g.
+        # undated sample nodes). These need an identity normalization constant
+        for unfixed_unvisited in np.where(to_visit)[0]:
+            norm[unfixed_unvisited] = self.lik.identity_constant
+
         if cache_inside:
             self.g_i = self.lik.reduce(g_i, norm[self.ts.tables.edges.child, None])
         # Keep the results in this object
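Two of the changes here guard the non-fixed-sample case: the un-normalized fallback is now `self.lik.identity_constant` rather than a literal 1, and nodes that are never a parent (e.g. undated samples) still get an identity entry in `norm`. A rough illustration of why the literal 1 is only right in linear space, assuming (as the class names suggest, not verified here) that `reduce` divides in `Likelihoods` and subtracts in `LogLikelihoods`:

```python
import numpy as np

def reduce_lin(arr, norm):  # stand-in for a linear-space reduce
    return arr / norm

def reduce_log(arr, norm):  # stand-in for a log-space reduce
    return arr - norm

val = np.array([0.2, 0.5, 0.1])
log_val = np.log(val)

# With normalize=False the reduction should be a no-op:
assert np.allclose(reduce_lin(val, 1.0), val)          # linear identity constant is 1
assert np.allclose(reduce_log(log_val, 0.0), log_val)  # log identity constant is 0
# A hard-coded 1 silently shifts every log value down by 1:
print(reduce_log(log_val, 1.0) - log_val)              # [-1. -1. -1.]
```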
@@ -732,10 +765,10 @@ def outside_pass(
                 if ignore_oldest_root:
                     if edge.parent == self.ts.num_nodes - 1:
                         continue
-                # if edge.parent in self.fixednodes:
-                #     raise RuntimeError(
-                #         "Fixed nodes cannot currently be parents in the TS"
-                #     )
+                if edge.parent in self.fixednodes:
+                    raise RuntimeError(
+                        "Fixed nodes cannot currently be parents in the TS"
+                    )
                 # Geometric scaling works exactly for all nodes fixed in graph
                 # but is an approximation when times are unknown.
                 spanfrac = edge.span / self.spans[child]
@@ -897,34 +930,32 @@ def posterior_mean_var(ts, posterior, *, fixed_node_set=None):
     return ts, mn_post, vr_post


-def constrain_ages_topo(ts, post_mn, eps, nodes_to_date=None, progress=False):
+def constrain_ages_topo(ts, node_times, eps, progress=False):
     """
-    If predicted node times violate topology, restrict node ages so that they
-    must be older than all their children.
+    If node_times violate topology, return increased node_times so that each node is
+    guaranteed to be older than any of its children.
     """
-    new_mn_post = np.copy(post_mn)
-    if nodes_to_date is None:
-        nodes_to_date = np.arange(ts.num_nodes, dtype=np.uint64)
-        nodes_to_date = nodes_to_date[~np.isin(nodes_to_date, ts.samples())]
-
-    tables = ts.tables
-    parents = tables.edges.parent
-    nd_children = tables.edges.child[np.argsort(parents)]
-    parents = sorted(parents)
-    parents_unique = np.unique(parents, return_index=True)
-    parent_indices = parents_unique[1][np.isin(parents_unique[0], nodes_to_date)]
-    for index, nd in tqdm(
-        enumerate(sorted(nodes_to_date)), desc="Constrain Ages", disable=not progress
+    edges_parent = ts.edges_parent
+    edges_child = ts.edges_child
+
+    new_node_times = np.copy(node_times)
+    # Traverse through the ARG, ensuring children come before parents.
+    # This can be done by iterating over groups of edges with the same parent
+    new_parent_edge_idx = np.where(np.diff(edges_parent) != 0)[0] + 1
+    for edges_start, edges_end in tqdm(
+        zip(
+            itertools.chain([0], new_parent_edge_idx),
+            itertools.chain(new_parent_edge_idx, [len(edges_parent)]),
+        ),
+        desc="Constrain Ages",
+        disable=not progress,
     ):
-        if index + 1 != len(nodes_to_date):
-            children_index = np.arange(parent_indices[index], parent_indices[index + 1])
-        else:
-            children_index = np.arange(parent_indices[index], ts.num_edges)
-        children = nd_children[children_index]
-        time = np.max(new_mn_post[children])
-        if new_mn_post[nd] <= time:
-            new_mn_post[nd] = time + eps
-    return new_mn_post
+        parent = edges_parent[edges_start]
+        child_ids = edges_child[edges_start:edges_end]  # May contain dups
+        oldest_child_time = np.max(new_node_times[child_ids])
+        if oldest_child_time >= new_node_times[parent]:
+            new_node_times[parent] = oldest_child_time + eps
+    return new_node_times


 def date(
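The rewritten `constrain_ages_topo` leans on tskit's edge ordering: edges sharing a parent are stored contiguously and parents appear in time order, so each group of children is visited before its own parent. A toy demonstration of the group-boundary bookkeeping (made-up arrays, no tree sequence needed):

```python
import itertools
import numpy as np

edges_parent = np.array([3, 3, 4, 4, 4, 5])   # grouped by parent, parents time-ordered
edges_child = np.array([0, 1, 2, 3, 3, 4])    # duplicate children are harmless
node_times = np.array([0.0, 0.0, 0.0, 1.0, 0.5, 2.0])  # node 4 younger than child 3
eps = 1e-6

boundaries = np.where(np.diff(edges_parent) != 0)[0] + 1  # first edge of each new parent
for start, end in zip(
    itertools.chain([0], boundaries),
    itertools.chain(boundaries, [len(edges_parent)]),
):
    parent = edges_parent[start]
    oldest_child = np.max(node_times[edges_child[start:end]])
    if oldest_child >= node_times[parent]:
        node_times[parent] = oldest_child + eps

print(node_times)  # node 4 is pushed just above its oldest child (1.0 + eps)
```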
@@ -1015,7 +1046,7 @@ def date(
         progress=progress,
         **kwargs
     )
-    constrained = constrain_ages_topo(tree_sequence, dates, eps, nds, progress)
+    constrained = constrain_ages_topo(tree_sequence, dates, eps, progress)
     tables = tree_sequence.dump_tables()
     tables.time_units = time_units
     tables.nodes.time = constrained
@@ -1064,12 +1095,6 @@ def get_dates(

     :return: tuple(mn_post, posterior, timepoints, eps, nodes_to_date)
     """
-    # Stuff yet to be implemented. These can be deleted once fixed
-    #for sample in tree_sequence.samples():
-    #    if tree_sequence.node(sample).time != 0:
-    #        raise NotImplementedError("Samples must all be at time 0")
-    fixed_nodes = set(tree_sequence.samples())
-
     # Default to not creating approximate priors unless ts has > 1000 samples
     approx_priors = False
     if tree_sequence.num_samples > 1000:
@@ -1097,6 +1122,8 @@ def get_dates(
         )
         priors = priors

+    fixed_nodes = set(priors.fixed_node_ids())
+
     if probability_space != base.LOG:
         liklhd = Likelihoods(
             tree_sequence,
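Together with the removal above, the fixed set is now read from the priors rather than assumed to be "all samples, at time 0". A purely illustrative sketch of the behavioural difference (made-up node IDs, not the tsdate API):

```python
samples = {0, 1, 2, 3}        # node 3 is a sample whose age is unknown
dated_nodes = {0, 1, 2}       # nodes the priors treat as having known times

fixed_old = set(samples)      # old: fixed_nodes = set(tree_sequence.samples())
fixed_new = set(dated_nodes)  # new: fixed_nodes = set(priors.fixed_node_ids())

print(fixed_old - fixed_new)  # {3}: now dated like any other non-fixed node
```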