Skip to content

Commit c8568d5

Browse files
Merge pull request #648 from jeromekelleher/fix-from-ts
Add support for multiple mutations
2 parents 6ed592f + 519905f commit c8568d5

File tree

2 files changed

+31
-29
lines changed

2 files changed

+31
-29
lines changed

tests/test_evaluation.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#
2-
# Copyright (C) 2018-2020 University of Oxford
2+
# Copyright (C) 2018-2022 University of Oxford
33
#
44
# This file is part of tsinfer.
55
#
@@ -527,9 +527,8 @@ class TestMakeAncestorsTs:
527527
Tests for the process of generating an ancestors tree sequence.
528528
"""
529529

530-
def verify_from_source(self, remove_leaves):
531-
ts = msprime.simulate(15, recombination_rate=1, mutation_rate=2, random_seed=3)
532-
samples = tsinfer.SampleData.from_tree_sequence(ts)
530+
def verify_from_source(self, ts, remove_leaves):
531+
samples = tsinfer.SampleData.from_tree_sequence(ts, use_sites_time=False)
533532
ancestors_ts = tsinfer.make_ancestors_ts(
534533
samples, ts, remove_leaves=remove_leaves
535534
)
@@ -538,11 +537,10 @@ def verify_from_source(self, remove_leaves):
538537
final_ts = tsinfer.match_samples(samples, ancestors_ts, engine=engine)
539538
tsinfer.verify(samples, final_ts)
540539

541-
def test_infer_from_source_no_leaves(self):
542-
self.verify_from_source(True)
543-
544-
def test_infer_from_source(self):
545-
self.verify_from_source(True)
540+
@pytest.mark.parametrize("remove_leaves", [True, False])
541+
def test_infer_from_source(self, remove_leaves):
542+
ts = msprime.simulate(15, recombination_rate=1, mutation_rate=2, random_seed=3)
543+
self.verify_from_source(ts, remove_leaves=remove_leaves)
546544

547545
def verify_from_inferred(self, remove_leaves):
548546
ts = msprime.simulate(15, recombination_rate=1, mutation_rate=2, random_seed=3)
@@ -556,11 +554,16 @@ def verify_from_inferred(self, remove_leaves):
556554
final_ts = tsinfer.match_samples(samples, ancestors_ts, engine=engine)
557555
tsinfer.verify(samples, final_ts)
558556

559-
def test_infer_from_inferred_no_leaves(self):
560-
self.verify_from_inferred(True)
557+
@pytest.mark.parametrize("remove_leaves", [True, False])
558+
def test_infer_from_inferred(self, remove_leaves):
559+
self.verify_from_inferred(remove_leaves)
561560

562-
def test_infer_from_inferred(self):
563-
self.verify_from_inferred(False)
561+
@pytest.mark.parametrize("remove_leaves", [True, False])
562+
def test_infer_from_source_multiple_mutations(self, remove_leaves):
563+
ts = msprime.sim_ancestry(5, sequence_length=100, random_seed=3)
564+
mts = msprime.sim_mutations(ts, rate=0.1, random_seed=3)
565+
assert mts.num_mutations > mts.num_sites
566+
self.verify_from_source(mts, remove_leaves=remove_leaves)
564567

565568

566569
class TestCheckAncestorsTs:

tsinfer/eval_util.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (C) 2018 University of Oxford
1+
# Copyright (C) 2018-2022 University of Oxford
22
#
33
# This file is part of tsinfer.
44
#
@@ -450,12 +450,17 @@ def make_ancestors_ts(samples, ts, remove_leaves=False):
450450
msprime.simulate. We remove populations, as normally ancestors tree sequences
451451
do not have populations defined.
452452
"""
453-
# Get the non-singleton sites
454-
position = []
455-
for var in ts.variants():
456-
if np.sum(var.genotypes) > 1:
457-
position.append(var.site.position)
458-
reduced = subset_sites(ts, position)
453+
# Collect sites to remove: those without exactly one mutation, and singletons
454+
remove_sites = []
455+
for tree in ts.trees():
456+
for site in tree.sites():
457+
if len(site.mutations) != 1:
458+
remove_sites.append(site.id)
459+
else:
460+
if tree.num_samples(site.mutations[0].node) < 2:
461+
remove_sites.append(site.id)
462+
463+
reduced = ts.delete_sites(remove_sites)
459464
minimised = inference.minimise(reduced)
460465

461466
tables = minimised.dump_tables()
@@ -478,15 +483,9 @@ def make_ancestors_ts(samples, ts, remove_leaves=False):
478483
parent=tables.edges.parent + 1,
479484
child=tables.edges.child + 1,
480485
)
481-
tables.mutations.set_columns(
482-
node=tables.mutations.node + 1,
483-
site=tables.mutations.site,
484-
parent=tables.mutations.parent,
485-
derived_state=tables.mutations.derived_state,
486-
derived_state_offset=tables.mutations.derived_state_offset,
487-
metadata=tables.mutations.metadata,
488-
metadata_offset=tables.mutations.metadata_offset,
489-
)
486+
tables.mutations.node += 1
487+
# We could also set the time to UNKNOWN_TIME, but shifting it by 1 is a bit easier.
488+
tables.mutations.time += 1
490489

491490
trees = minimised.trees()
492491
tree = next(trees)

0 commit comments

Comments
 (0)