Skip to content

Commit 6e6cf7b

Browse files
awohnshyanwong
authored andcommitted
Allow ancient samples
Rework build-prior and inside / outside logic to allow historical samples And speed up time constraint algorithms while also allowing nodes to be out of time order
1 parent 38493ff commit 6e6cf7b

File tree

6 files changed

+260
-98
lines changed

6 files changed

+260
-98
lines changed

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
tskit>=0.4.0
1+
tskit>=0.5.2
22
tsinfer>=0.3.0
33
flake8
44
numpy

tests/test_functions.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1048,7 +1048,7 @@ def test_dangling_fails(self):
10481048
print(ts.draw_text())
10491049
print("Samples:", ts.samples())
10501050
Ne = 0.5
1051-
with pytest.raises(ValueError, match="simplified"):
1051+
with pytest.raises(ValueError, match="simplify"):
10521052
tsdate.build_prior_grid(ts, Ne, timepoints=np.array([0, 1.2, 2]))
10531053
# mut_rate = 1
10541054
# eps = 1e-6
@@ -1421,7 +1421,7 @@ def test_date_input(self):
14211421

14221422
def test_sample_as_parent_fails(self):
14231423
ts = utility_functions.single_tree_ts_n3_sample_as_parent()
1424-
with pytest.raises(NotImplementedError):
1424+
with pytest.raises(ValueError, match="samples at non-zero times"):
14251425
tsdate.date(ts, mutation_rate=None, Ne=1)
14261426

14271427
def test_recombination_not_implemented(self):
@@ -1532,18 +1532,7 @@ def test_constrain_ages_topo(self):
15321532
ts = utility_functions.two_tree_ts()
15331533
post_mn = np.array([0.0, 0.0, 0.0, 2.0, 1.0, 3.0])
15341534
eps = 1e-6
1535-
nodes_to_date = np.array([3, 4, 5])
1536-
constrained_ages = constrain_ages_topo(ts, post_mn, eps, nodes_to_date)
1537-
assert np.array_equal(
1538-
np.array([0.0, 0.0, 0.0, 2.0, 2.000001, 3.0]), constrained_ages
1539-
)
1540-
1541-
def test_constrain_ages_topo_no_nodes_to_date(self):
1542-
ts = utility_functions.two_tree_ts()
1543-
post_mn = np.array([0.0, 0.0, 0.0, 2.0, 1.0, 3.0])
1544-
eps = 1e-6
1545-
nodes_to_date = None
1546-
constrained_ages = constrain_ages_topo(ts, post_mn, eps, nodes_to_date)
1535+
constrained_ages = constrain_ages_topo(ts, post_mn, eps)
15471536
assert np.array_equal(
15481537
np.array([0.0, 0.0, 0.0, 2.0, 2.000001, 3.0]), constrained_ages
15491538
)

tests/test_inference.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def test_bad_Ne(self):
6161

6262
def test_dangling_failure(self):
6363
ts = utility_functions.single_tree_ts_n2_dangling()
64-
with pytest.raises(ValueError, match="simplified"):
64+
with pytest.raises(ValueError, match="simplify"):
6565
tsdate.date(ts, mutation_rate=None, Ne=1)
6666

6767
def test_unary_failure(self):
@@ -271,16 +271,29 @@ def test_fails_multi_root(self):
271271
with pytest.raises(ValueError):
272272
tsdate.date(multiroot_ts, Ne=1, mutation_rate=2, priors=good_priors)
273273

274-
def test_non_contemporaneous(self):
274+
def test_non_contemporaneous_warn(self):
275275
samples = [
276276
msprime.Sample(population=0, time=0),
277277
msprime.Sample(population=0, time=0),
278278
msprime.Sample(population=0, time=0),
279279
msprime.Sample(population=0, time=1.0),
280280
]
281281
ts = msprime.simulate(samples=samples, Ne=1, mutation_rate=2, random_seed=12)
282-
with pytest.raises(NotImplementedError):
282+
with pytest.raises(ValueError, match="samples at non-zero times"):
283283
tsdate.date(ts, Ne=1, mutation_rate=2)
284+
with pytest.raises(ValueError, match="samples at non-zero times"):
285+
tsdate.build_prior_grid(ts, Ne=1)
286+
287+
def test_non_contemporaneous(self):
288+
samples = [
289+
msprime.Sample(population=0, time=0),
290+
msprime.Sample(population=0, time=0),
291+
msprime.Sample(population=0, time=0),
292+
msprime.Sample(population=0, time=1.0),
293+
]
294+
ts = msprime.simulate(samples=samples, Ne=1, mutation_rate=2, random_seed=12)
295+
priors = tsdate.build_prior_grid(ts, Ne=1, allow_historical_samples=True)
296+
tsdate.date(ts, priors=priors, mutation_rate=2)
284297

285298
def test_no_mutation_times(self):
286299
ts = msprime.simulate(20, Ne=1, mutation_rate=1, random_seed=12)

tsdate/base.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,12 @@ def __init__(
9595
] = (-np.arange(num_nodes - self.num_nonfixed) - 1)
9696
self.probability_space = LIN
9797

98+
def fixed_node_ids(self):
99+
return np.where(self.row_lookup < 0)[0]
100+
101+
def nonfixed_node_ids(self):
102+
return np.where(self.row_lookup >= 0)[0]
103+
98104
def force_probability_space(self, probability_space):
99105
"""
100106
probability_space can be "logarithmic" or "linear": this function will force
@@ -140,6 +146,9 @@ def normalize(self):
140146
else:
141147
raise RuntimeError("Probability space is not", LIN, "or", LOG)
142148

149+
def is_fixed(self, node_id):
150+
return self.row_lookup[node_id] < 0
151+
143152
def __getitem__(self, node_id):
144153
index = self.row_lookup[node_id]
145154
if index < 0:
@@ -207,8 +216,7 @@ def fill_fixed(orig, fixed_data):
207216
new_obj.fixed_data = fill_fixed(
208217
self, grid_data if fixed_data is None else fixed_data
209218
)
210-
if probability_space is None:
211-
new_obj.probability_space = self.probability_space
212-
else:
213-
new_obj.probability_space = probability_space
219+
new_obj.probability_space = self.probability_space
220+
if probability_space is not None:
221+
new_obj.force_probability_space(probability_space)
214222
return new_obj

tsdate/core.py

Lines changed: 77 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def _lik(muts, span, dt, mutation_rate, normalize=True):
151151
"""
152152
ll = scipy.stats.poisson.pmf(muts, dt * mutation_rate * span)
153153
if normalize:
154-
return ll / np.max(ll)
154+
return ll / np.nanmax(ll)
155155
else:
156156
return ll
157157

@@ -258,15 +258,28 @@ def get_mut_lik_fixed_node(self, edge):
258258

259259
mutations_on_edge = self.mut_edges[edge.id]
260260
child_time = self.ts.node(edge.child).time
261-
#assert child_time == 0
262-
# Temporary hack - we should really take a more precise likelihood
263-
return self._lik(
264-
mutations_on_edge,
265-
edge.span,
266-
self.timediff,
267-
self.mut_rate,
268-
normalize=self.normalize,
269-
)
261+
if child_time == 0:
262+
return self._lik(
263+
mutations_on_edge,
264+
edge.span,
265+
self.timediff,
266+
self.mut_rate,
267+
normalize=self.normalize,
268+
)
269+
else:
270+
timediff = self.timepoints - child_time + 1e-8
271+
# Temporary hack - we should really take a more precise likelihood
272+
likelihood = self._lik(
273+
mutations_on_edge,
274+
edge.span,
275+
timediff,
276+
self.mut_rate,
277+
normalize=self.normalize,
278+
)
279+
# Prevent child from being older than parent
280+
likelihood[timediff < 0] = 0
281+
282+
return likelihood
270283

271284
def get_mut_lik_lower_tri(self, edge):
272285
"""
@@ -389,7 +402,7 @@ def get_fixed(self, arr, edge):
389402
return arr * liks
390403

391404
def scale_geometric(self, fraction, value):
392-
return value**fraction
405+
return value ** fraction
393406

394407

395408
class LogLikelihoods(Likelihoods):
@@ -429,7 +442,7 @@ def _lik(muts, span, dt, mutation_rate, normalize=True):
429442
"""
430443
ll = scipy.stats.poisson.logpmf(muts, dt * mutation_rate * span)
431444
if normalize:
432-
return ll - np.max(ll)
445+
return ll - np.nanmax(ll)
433446
else:
434447
return ll
435448

@@ -634,11 +647,22 @@ def inside_pass(self, *, normalize=True, cache_inside=False, progress=None):
634647
inside = self.priors.clone_with_new_data( # store inside matrix values
635648
grid_data=np.nan, fixed_data=self.lik.identity_constant
636649
)
650+
# It is possible that a simple node is non-fixed, in which case we want to
651+
# provide an inside array that reflects the prior distribution
652+
nonfixed_samples = np.intersect1d(inside.nonfixed_node_ids(), self.ts.samples())
653+
for u in nonfixed_samples:
654+
# this is in the same probability space as the prior, so we should be
655+
# OK just to copy the prior values straight in. It's unclear to me (Yan)
656+
# how/if they should be normalised, however
657+
inside[u][:] = self.priors[u]
658+
637659
if cache_inside:
638660
g_i = np.full(
639661
(self.ts.num_edges, self.lik.grid_size), self.lik.identity_constant
640662
)
641663
norm = np.full(self.ts.num_nodes, np.nan)
664+
to_visit = np.zeros(self.ts.num_nodes, dtype=bool)
665+
to_visit[inside.nonfixed_node_ids()] = True
642666
# Iterate through the nodes via groupby on parent node
643667
for parent, edges in tqdm(
644668
self.edges_by_parent_asc(),
@@ -673,14 +697,23 @@ def inside_pass(self, *, normalize=True, cache_inside=False, progress=None):
673697
"dangling nodes: please simplify it"
674698
)
675699
daughter_val = self.lik.scale_geometric(
676-
spanfrac, self.lik.make_lower_tri(inside[edge.child])
700+
spanfrac, self.lik.make_lower_tri(inside_values)
677701
)
678702
edge_lik = self.lik.get_inside(daughter_val, edge)
679703
val = self.lik.combine(val, edge_lik)
704+
if np.all(val == 0):
705+
raise ValueError
680706
if cache_inside:
681707
g_i[edge.id] = edge_lik
682-
norm[parent] = np.max(val) if normalize else 1
708+
norm[parent] = np.max(val) if normalize else self.lik.identity_constant
683709
inside[parent] = self.lik.reduce(val, norm[parent])
710+
to_visit[parent] = False
711+
712+
# There may be nodes that are not parents but are also not fixed (e.g.
713+
# undated sample nodes). These need an identity normalization constant
714+
for unfixed_unvisited in np.where(to_visit)[0]:
715+
norm[unfixed_unvisited] = self.lik.identity_constant
716+
684717
if cache_inside:
685718
self.g_i = self.lik.reduce(g_i, norm[self.ts.tables.edges.child, None])
686719
# Keep the results in this object
@@ -732,10 +765,10 @@ def outside_pass(
732765
if ignore_oldest_root:
733766
if edge.parent == self.ts.num_nodes - 1:
734767
continue
735-
#if edge.parent in self.fixednodes:
736-
# raise RuntimeError(
737-
# "Fixed nodes cannot currently be parents in the TS"
738-
# )
768+
if edge.parent in self.fixednodes:
769+
raise RuntimeError(
770+
"Fixed nodes cannot currently be parents in the TS"
771+
)
739772
# Geometric scaling works exactly for all nodes fixed in graph
740773
# but is an approximation when times are unknown.
741774
spanfrac = edge.span / self.spans[child]
@@ -897,34 +930,32 @@ def posterior_mean_var(ts, posterior, *, fixed_node_set=None):
897930
return ts, mn_post, vr_post
898931

899932

900-
def constrain_ages_topo(ts, post_mn, eps, nodes_to_date=None, progress=False):
933+
def constrain_ages_topo(ts, node_times, eps, progress=False):
901934
"""
902-
If predicted node times violate topology, restrict node ages so that they
903-
must be older than all their children.
935+
If node_times violate topology, return increased node_times so that each node is
936+
guaranteed to be older than any of its their children.
904937
"""
905-
new_mn_post = np.copy(post_mn)
906-
if nodes_to_date is None:
907-
nodes_to_date = np.arange(ts.num_nodes, dtype=np.uint64)
908-
nodes_to_date = nodes_to_date[~np.isin(nodes_to_date, ts.samples())]
909-
910-
tables = ts.tables
911-
parents = tables.edges.parent
912-
nd_children = tables.edges.child[np.argsort(parents)]
913-
parents = sorted(parents)
914-
parents_unique = np.unique(parents, return_index=True)
915-
parent_indices = parents_unique[1][np.isin(parents_unique[0], nodes_to_date)]
916-
for index, nd in tqdm(
917-
enumerate(sorted(nodes_to_date)), desc="Constrain Ages", disable=not progress
938+
edges_parent = ts.edges_parent
939+
edges_child = ts.edges_child
940+
941+
new_node_times = np.copy(node_times)
942+
# Traverse through the ARG, ensuring children come before parents.
943+
# This can be done by iterating over groups of edges with the same parent
944+
new_parent_edge_idx = np.where(np.diff(edges_parent) != 0)[0] + 1
945+
for edges_start, edges_end in tqdm(
946+
zip(
947+
itertools.chain([0], new_parent_edge_idx),
948+
itertools.chain(new_parent_edge_idx, [len(edges_parent)]),
949+
),
950+
desc="Constrain Ages",
951+
disable=not progress,
918952
):
919-
if index + 1 != len(nodes_to_date):
920-
children_index = np.arange(parent_indices[index], parent_indices[index + 1])
921-
else:
922-
children_index = np.arange(parent_indices[index], ts.num_edges)
923-
children = nd_children[children_index]
924-
time = np.max(new_mn_post[children])
925-
if new_mn_post[nd] <= time:
926-
new_mn_post[nd] = time + eps
927-
return new_mn_post
953+
parent = edges_parent[edges_start]
954+
child_ids = edges_child[edges_start:edges_end] # May contain dups
955+
oldest_child_time = np.max(new_node_times[child_ids])
956+
if oldest_child_time >= new_node_times[parent]:
957+
new_node_times[parent] = oldest_child_time + eps
958+
return new_node_times
928959

929960

930961
def date(
@@ -1015,7 +1046,7 @@ def date(
10151046
progress=progress,
10161047
**kwargs
10171048
)
1018-
constrained = constrain_ages_topo(tree_sequence, dates, eps, nds, progress)
1049+
constrained = constrain_ages_topo(tree_sequence, dates, eps, progress)
10191050
tables = tree_sequence.dump_tables()
10201051
tables.time_units = time_units
10211052
tables.nodes.time = constrained
@@ -1064,12 +1095,6 @@ def get_dates(
10641095
10651096
:return: tuple(mn_post, posterior, timepoints, eps, nodes_to_date)
10661097
"""
1067-
# Stuff yet to be implemented. These can be deleted once fixed
1068-
#for sample in tree_sequence.samples():
1069-
# if tree_sequence.node(sample).time != 0:
1070-
# raise NotImplementedError("Samples must all be at time 0")
1071-
fixed_nodes = set(tree_sequence.samples())
1072-
10731098
# Default to not creating approximate priors unless ts has > 1000 samples
10741099
approx_priors = False
10751100
if tree_sequence.num_samples > 1000:
@@ -1097,6 +1122,8 @@ def get_dates(
10971122
)
10981123
priors = priors
10991124

1125+
fixed_nodes = set(priors.fixed_node_ids())
1126+
11001127
if probability_space != base.LOG:
11011128
liklhd = Likelihoods(
11021129
tree_sequence,

0 commit comments

Comments
 (0)