Skip to content

Commit dd71eb9

Browse files
duncanMR, mergify[bot]
authored and committed
Revise ancestor generation algorithm to improve performance
1 parent ca2cf16 commit dd71eb9

File tree

8 files changed

+1499
-1528
lines changed

8 files changed

+1499
-1528
lines changed

_tsinfermodule.c

Lines changed: 1370 additions & 1415 deletions
Large diffs are not rendered by default.

lib/ancestor_builder.c

Lines changed: 67 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -419,8 +419,9 @@ ancestor_builder_compute_ancestral_states(const ancestor_builder_t *self, int di
419419
tsk_id_t last_site = focal_site;
420420
int64_t l;
421421
tsk_id_t u;
422-
size_t j, ones, zeros, tmp_size, sample_set_size, min_sample_set_size;
422+
size_t j, ones, zeros, tmp_size, sample_set_size, min_sample_set_size, derived_count;
423423
double focal_site_time = self->sites[focal_site].time;
424+
double site_time;
424425
const site_t *restrict sites = self->sites;
425426
const size_t num_sites = self->num_sites;
426427
allele_t consensus;
@@ -440,73 +441,78 @@ ancestor_builder_compute_ancestral_states(const ancestor_builder_t *self, int di
440441
/* printf("\tl = %d\n", (int) l); */
441442
ancestor[l] = 0;
442443
last_site = (tsk_id_t) l;
443-
if (sites[l].time > focal_site_time) {
444444

445-
/* printf("\t%d\t%d:", (int) l, (int) sample_set_size); */
446-
/* for (j = 0; j < sample_set_size; j++) { */
447-
/* printf("%d, ", sample_set[j]); */
448-
/* } */
449-
/* printf("\n"); */
450-
451-
ancestor_builder_get_site_genotypes_subset(
452-
self, (tsk_id_t) l, sample_set, sample_set_size, genotypes);
453-
ones = 0;
454-
zeros = 0;
455-
for (j = 0; j < sample_set_size; j++) {
456-
switch (genotypes[j]) {
457-
case 0:
458-
zeros++;
459-
break;
460-
case 1:
461-
ones++;
462-
break;
463-
}
445+
/* printf("\t%d\t%d:", (int) l, (int) sample_set_size); */
446+
/* for (j = 0; j < sample_set_size; j++) { */
447+
/* printf("%d, ", sample_set[j]); */
448+
/* } */
449+
/* printf("\n"); */
450+
451+
ancestor_builder_get_site_genotypes_subset(
452+
self, (tsk_id_t) l, sample_set, sample_set_size, genotypes);
453+
ones = 0;
454+
zeros = 0;
455+
for (j = 0; j < sample_set_size; j++) {
456+
switch (genotypes[j]) {
457+
case 0:
458+
zeros++;
459+
break;
460+
case 1:
461+
ones++;
462+
break;
464463
}
464+
}
465+
if (ones >= zeros) {
466+
consensus = 1;
467+
} else {
468+
consensus = 0;
469+
}
470+
/* printf("\t:ones=%d, consensus=%d\n", (int) ones, consensus); */
471+
/* fflush(stdout); */
472+
for (j = 0; j < sample_set_size; j++) {
473+
u = sample_set[j];
474+
if (disagree[u] && (genotypes[j] != consensus)
475+
&& (genotypes[j] != TSK_MISSING_DATA)) {
476+
/* This sample has disagreed with consensus twice in a row,
477+
* so remove it */
478+
/* printf("\t\tremoving %d\n", sample_set[j]); */
479+
sample_set[j] = -1;
480+
}
481+
}
482+
483+
site_time = sites[l].time;
484+
if (site_time > focal_site_time) {
465485
if (ones + zeros == 0) {
466486
ancestor[l] = TSK_MISSING_DATA;
467487
} else {
468-
if (ones >= zeros) {
469-
consensus = 1;
470-
} else {
471-
consensus = 0;
472-
}
473-
/* printf("\t:ones=%d, consensus=%d\n", (int) ones, consensus); */
474-
/* fflush(stdout); */
475-
for (j = 0; j < sample_set_size; j++) {
476-
u = sample_set[j];
477-
if (disagree[u] && (genotypes[j] != consensus)
478-
&& (genotypes[j] != TSK_MISSING_DATA)) {
479-
/* This sample has disagreed with consensus twice in a row,
480-
* so remove it */
481-
/* printf("\t\tremoving %d\n", sample_set[j]); */
482-
sample_set[j] = -1;
483-
}
484-
}
485488
ancestor[l] = consensus;
486-
/* For the remaining samples, set the disagree flags based
487-
* on whether they agree with the consensus for this site. */
488-
for (j = 0; j < sample_set_size; j++) {
489-
u = sample_set[j];
490-
if (u != -1) {
491-
disagree[u] = ((genotypes[j] != consensus)
492-
&& (genotypes[j] != TSK_MISSING_DATA));
493-
}
494-
}
495-
/* Repack the sample set */
496-
tmp_size = 0;
497-
for (j = 0; j < sample_set_size; j++) {
498-
if (sample_set[j] != -1) {
499-
sample_set[tmp_size] = sample_set[j];
500-
tmp_size++;
501-
}
502-
}
503-
sample_set_size = tmp_size;
504-
if (sample_set_size <= min_sample_set_size) {
505-
/* printf("BREAK\n"); */
506-
break;
489+
}
490+
}
491+
/* For the remaining samples, set the disagree flags based
492+
* on whether they agree with the consensus for this site. */
493+
derived_count = sites[l].derived_count;
494+
if ((site_time > focal_site_time) || (derived_count > ones)) {
495+
for (j = 0; j < sample_set_size; j++) {
496+
u = sample_set[j];
497+
if (u != -1) {
498+
disagree[u] = ((genotypes[j] != consensus)
499+
&& (genotypes[j] != TSK_MISSING_DATA));
507500
}
508501
}
509502
}
503+
/* Repack the sample set */
504+
tmp_size = 0;
505+
for (j = 0; j < sample_set_size; j++) {
506+
if (sample_set[j] != -1) {
507+
sample_set[tmp_size] = sample_set[j];
508+
tmp_size++;
509+
}
510+
}
511+
sample_set_size = tmp_size;
512+
if (sample_set_size <= min_sample_set_size) {
513+
/* printf("BREAK\n"); */
514+
break;
515+
}
510516
}
511517
*last_site_ret = last_site;
512518
return ret;
@@ -647,7 +653,7 @@ ancestor_builder_allocate_genotypes(ancestor_builder_t *self)
647653
}
648654

649655
int WARN_UNUSED
650-
ancestor_builder_add_site(ancestor_builder_t *self, double time, allele_t *genotypes)
656+
ancestor_builder_add_site(ancestor_builder_t *self, double time, allele_t *genotypes, tsk_size_t derived_count)
651657
{
652658
int ret = 0;
653659
site_t *site;
@@ -676,6 +682,7 @@ ancestor_builder_add_site(ancestor_builder_t *self, double time, allele_t *genot
676682
pattern_map = &time_map->pattern_map;
677683
site = &self->sites[site_id];
678684
site->time = time;
685+
site->derived_count = derived_count;
679686

680687
search.encoded_genotypes = encoded_genotypes;
681688
search.encoded_genotypes_size = self->encoded_genotypes_size;

lib/tests/tests.c

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ run_random_data(size_t num_samples, size_t num_sites, int seed,
375375
genotypes[k] = samples[k][j];
376376
time += genotypes[k];
377377
}
378-
ret = ancestor_builder_add_site(&ancestor_builder, time, genotypes);
378+
ret = ancestor_builder_add_site(&ancestor_builder, time, genotypes, 1);
379379
CU_ASSERT_EQUAL_FATAL(ret, 0);
380380
}
381381
/* ancestor_builder_print_state(&ancestor_builder, stdout); */
@@ -478,15 +478,15 @@ test_ancestor_builder_errors(void)
478478
ret = ancestor_builder_alloc(&ancestor_builder, 2, 0, -1, 0);
479479
CU_ASSERT_EQUAL_FATAL(ret, 0);
480480
CU_ASSERT_EQUAL_FATAL(ancestor_builder.num_sites, 0);
481-
ret = ancestor_builder_add_site(&ancestor_builder, 4, genotypes_ones);
481+
ret = ancestor_builder_add_site(&ancestor_builder, 4, genotypes_ones, 1);
482482
CU_ASSERT_EQUAL_FATAL(ret, TSI_ERR_TOO_MANY_SITES);
483483
ancestor_builder_free(&ancestor_builder);
484484

485485
ret = ancestor_builder_alloc(&ancestor_builder, 4, 2, -1, 0);
486486
CU_ASSERT_EQUAL_FATAL(ret, 0);
487-
ret = ancestor_builder_add_site(&ancestor_builder, 4, genotypes_zeros);
487+
ret = ancestor_builder_add_site(&ancestor_builder, 4, genotypes_zeros, 1);
488488
CU_ASSERT_EQUAL_FATAL(ret, 0);
489-
ret = ancestor_builder_add_site(&ancestor_builder, 4, genotypes_ones);
489+
ret = ancestor_builder_add_site(&ancestor_builder, 4, genotypes_ones, 1);
490490
CU_ASSERT_EQUAL_FATAL(ret, 0);
491491
CU_ASSERT_EQUAL_FATAL(ancestor_builder.num_sites, 2);
492492
ret = ancestor_builder_finalise(&ancestor_builder);
@@ -509,7 +509,7 @@ test_ancestor_builder_one_site(void)
509509

510510
ret = ancestor_builder_alloc(&ancestor_builder, 4, 1, -1, 0);
511511
CU_ASSERT_EQUAL_FATAL(ret, 0);
512-
ret = ancestor_builder_add_site(&ancestor_builder, 4, genotypes);
512+
ret = ancestor_builder_add_site(&ancestor_builder, 4, genotypes, 1);
513513
CU_ASSERT_EQUAL_FATAL(ret, 0);
514514
ret = ancestor_builder_finalise(&ancestor_builder);
515515
CU_ASSERT_EQUAL_FATAL(ret, 0);

lib/tsinfer.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ typedef struct _node_segment_list_node_t {
6969
typedef struct {
7070
double time;
7171
uint8_t *encoded_genotypes;
72+
tsk_size_t derived_count;
7273
} site_t;
7374

7475
typedef struct {
@@ -214,7 +215,7 @@ int ancestor_builder_alloc(ancestor_builder_t *self, size_t num_samples,
214215
int ancestor_builder_free(ancestor_builder_t *self);
215216
int ancestor_builder_print_state(ancestor_builder_t *self, FILE *out);
216217
int ancestor_builder_add_site(
217-
ancestor_builder_t *self, double time, allele_t *genotypes);
218+
ancestor_builder_t *self, double time, allele_t *genotypes, tsk_size_t derived_count);
218219
int ancestor_builder_finalise(ancestor_builder_t *self);
219220
int ancestor_builder_make_ancestor(const ancestor_builder_t *self,
220221
size_t num_focal_sites, const tsk_id_t *focal_sites, tsk_id_t *start, tsk_id_t *end,

tests/test_inference.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1702,7 +1702,7 @@ def test_array_args(self, tmp_path, tmpdir):
17021702
ts.tables.assert_equals(ts_batch.tables, ignore_provenance=True)
17031703

17041704

1705-
class TestAncestorGeneratorsEquivalant:
1705+
class TestAncestorGeneratorsEquivalent:
17061706
"""
17071707
Tests for the ancestor generation process.
17081708
"""
@@ -1920,7 +1920,7 @@ def test_bad_focal_sites(self):
19201920
g = np.zeros(2, dtype=np.int8)
19211921
h = np.zeros(1, dtype=np.int8)
19221922
generator = tsinfer.AncestorsGenerator(sample_data, None, {}, engine=engine)
1923-
generator.ancestor_builder.add_site(1, g)
1923+
generator.ancestor_builder.add_site(1, g, derived_count=0)
19241924
with pytest.raises(error):
19251925
generator.ancestor_builder.make_ancestor([0], h)
19261926

@@ -2725,7 +2725,10 @@ def test_ancestor_builder_print_state(self):
27252725
sample_data = self.sample_example(n_samples, n_sites)
27262726
ancestor_builder = tsinfer.algorithm.AncestorBuilder(n_samples, n_sites)
27272727
for variant in sample_data.variants():
2728-
ancestor_builder.add_site(variant.site.time, variant.genotypes)
2728+
derived_count = np.sum(variant.genotypes)
2729+
ancestor_builder.add_site(
2730+
variant.site.time, variant.genotypes, derived_count
2731+
)
27292732
with mock.patch("sys.stdout", new=io.StringIO()) as mock_output:
27302733
ancestor_builder.print_state()
27312734
# Simply check some text is output
@@ -3533,16 +3536,16 @@ def test_simple_case(self):
35333536

35343537
def test_simulation_with_error(self):
35353538
ts = msprime.simulate(
3536-
50, mutation_rate=10, random_seed=4, recombination_rate=15
3539+
50, mutation_rate=10, random_seed=5, recombination_rate=15
35373540
)
3538-
ts = eval_util.insert_errors(ts, 0.2, seed=32)
3541+
ts = eval_util.insert_errors(ts, 0.2, seed=33)
35393542
sample_data = tsinfer.SampleData.from_tree_sequence(ts, use_sites_time=False)
35403543
self.verify(sample_data)
35413544

35423545
def test_small_random_data(self):
35433546
n = 25
35443547
m = 20
3545-
G, positions = get_random_data_example(n, m)
3548+
G, positions = get_random_data_example(n, m, seed=101)
35463549
with tsinfer.SampleData(sequence_length=m) as sample_data:
35473550
for genotypes, position in zip(G, positions):
35483551
sample_data.add_site(position, genotypes)
@@ -3551,7 +3554,7 @@ def test_small_random_data(self):
35513554
def test_large_random_data(self):
35523555
n = 100
35533556
m = 30
3554-
G, positions = get_random_data_example(n, m)
3557+
G, positions = get_random_data_example(n, m, seed=100)
35553558
with tsinfer.SampleData(sequence_length=m) as sample_data:
35563559
for genotypes, position in zip(G, positions):
35573560
sample_data.add_site(position, genotypes)

tests/test_low_level.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,19 +134,22 @@ def test_add_site(self):
134134
ab = _tsinfer.AncestorBuilder(num_samples=2, max_sites=10)
135135
for bad_type in ["sdf", {}, None]:
136136
with pytest.raises(TypeError):
137-
ab.add_site(time=bad_type, genotypes=[0, 0])
137+
ab.add_site(time=bad_type, genotypes=[0, 0], derived_count=0)
138138
for bad_genotypes in ["asdf", [[], []], [0, 1, 2]]:
139139
with pytest.raises(ValueError):
140-
ab.add_site(time=0, genotypes=bad_genotypes)
140+
ab.add_site(time=0, genotypes=bad_genotypes, derived_count=0)
141+
for bad_derived_count in ["asdf", 1.2, [0, 1]]:
142+
with pytest.raises(TypeError):
143+
ab.add_site(time=0, genotypes=[0, 0], derived_count=bad_derived_count)
141144

142145
def test_add_too_many_sites(self):
143146
for max_sites in range(10):
144147
ab = _tsinfer.AncestorBuilder(num_samples=2, max_sites=max_sites)
145148
for _ in range(max_sites):
146-
ab.add_site(time=1, genotypes=[0, 1])
149+
ab.add_site(time=1, genotypes=[0, 1], derived_count=0)
147150
for _ in range(2 * max_sites):
148151
with pytest.raises(_tsinfer.LibraryError) as record:
149-
ab.add_site(time=1, genotypes=[0, 1])
152+
ab.add_site(time=1, genotypes=[0, 1], derived_count=0)
150153
msg = "Cannot add more sites than the specified maximum."
151154
assert str(record.value) == msg
152155

0 commit comments

Comments
 (0)