Skip to content

Commit 44827ed

Browse files
duncanMR authored and mergify[bot] committed
Simplify handling of derived counts
1 parent dd71eb9 commit 44827ed

File tree

8 files changed

+1453
-1398
lines changed

8 files changed

+1453
-1398
lines changed

_tsinfermodule.c

Lines changed: 1415 additions & 1370 deletions
Large diffs are not rendered by default.

lib/ancestor_builder.c

Lines changed: 13 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -474,12 +474,12 @@ ancestor_builder_compute_ancestral_states(const ancestor_builder_t *self, int di
474474
if (disagree[u] && (genotypes[j] != consensus)
475475
&& (genotypes[j] != TSK_MISSING_DATA)) {
476476
/* This sample has disagreed with consensus twice in a row,
477-
* so remove it */
477+
* so remove it */
478478
/* printf("\t\tremoving %d\n", sample_set[j]); */
479479
sample_set[j] = -1;
480480
}
481481
}
482-
482+
483483
site_time = sites[l].time;
484484
if (site_time > focal_site_time) {
485485
if (ones + zeros == 0) {
@@ -489,14 +489,14 @@ ancestor_builder_compute_ancestral_states(const ancestor_builder_t *self, int di
489489
}
490490
}
491491
/* For the remaining samples, set the disagree flags based
492-
* on whether they agree with the consensus for this site. */
492+
* on whether they agree with the consensus for this site. */
493493
derived_count = sites[l].derived_count;
494494
if ((site_time > focal_site_time) || (derived_count > ones)) {
495495
for (j = 0; j < sample_set_size; j++) {
496496
u = sample_set[j];
497497
if (u != -1) {
498498
disagree[u] = ((genotypes[j] != consensus)
499-
&& (genotypes[j] != TSK_MISSING_DATA));
499+
&& (genotypes[j] != TSK_MISSING_DATA));
500500
}
501501
}
502502
}
@@ -653,7 +653,7 @@ ancestor_builder_allocate_genotypes(ancestor_builder_t *self)
653653
}
654654

655655
int WARN_UNUSED
656-
ancestor_builder_add_site(ancestor_builder_t *self, double time, allele_t *genotypes, tsk_size_t derived_count)
656+
ancestor_builder_add_site(ancestor_builder_t *self, double time, allele_t *genotypes)
657657
{
658658
int ret = 0;
659659
site_t *site;
@@ -664,8 +664,16 @@ ancestor_builder_add_site(ancestor_builder_t *self, double time, allele_t *genot
664664
uint8_t *stored_genotypes = NULL;
665665
avl_tree_t *pattern_map;
666666
tsk_id_t site_id = (tsk_id_t) self->num_sites;
667+
size_t derived_count, j;
667668
time_map_t *time_map = ancestor_builder_get_time_map(self, time);
668669

670+
derived_count = 0;
671+
for (j = 0; j < (size_t) self->num_samples; j++) {
672+
if (genotypes[j] == 1) {
673+
derived_count++;
674+
}
675+
}
676+
669677
if (time_map == NULL) {
670678
ret = TSI_ERR_NO_MEMORY;
671679
goto out;

lib/tests/tests.c

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -375,7 +375,7 @@ run_random_data(size_t num_samples, size_t num_sites, int seed,
375375
genotypes[k] = samples[k][j];
376376
time += genotypes[k];
377377
}
378-
ret = ancestor_builder_add_site(&ancestor_builder, time, genotypes, 1);
378+
ret = ancestor_builder_add_site(&ancestor_builder, time, genotypes);
379379
CU_ASSERT_EQUAL_FATAL(ret, 0);
380380
}
381381
/* ancestor_builder_print_state(&ancestor_builder, stdout); */
@@ -478,15 +478,15 @@ test_ancestor_builder_errors(void)
478478
ret = ancestor_builder_alloc(&ancestor_builder, 2, 0, -1, 0);
479479
CU_ASSERT_EQUAL_FATAL(ret, 0);
480480
CU_ASSERT_EQUAL_FATAL(ancestor_builder.num_sites, 0);
481-
ret = ancestor_builder_add_site(&ancestor_builder, 4, genotypes_ones, 1);
481+
ret = ancestor_builder_add_site(&ancestor_builder, 4, genotypes_ones);
482482
CU_ASSERT_EQUAL_FATAL(ret, TSI_ERR_TOO_MANY_SITES);
483483
ancestor_builder_free(&ancestor_builder);
484484

485485
ret = ancestor_builder_alloc(&ancestor_builder, 4, 2, -1, 0);
486486
CU_ASSERT_EQUAL_FATAL(ret, 0);
487-
ret = ancestor_builder_add_site(&ancestor_builder, 4, genotypes_zeros, 1);
487+
ret = ancestor_builder_add_site(&ancestor_builder, 4, genotypes_zeros);
488488
CU_ASSERT_EQUAL_FATAL(ret, 0);
489-
ret = ancestor_builder_add_site(&ancestor_builder, 4, genotypes_ones, 1);
489+
ret = ancestor_builder_add_site(&ancestor_builder, 4, genotypes_ones);
490490
CU_ASSERT_EQUAL_FATAL(ret, 0);
491491
CU_ASSERT_EQUAL_FATAL(ancestor_builder.num_sites, 2);
492492
ret = ancestor_builder_finalise(&ancestor_builder);
@@ -509,7 +509,7 @@ test_ancestor_builder_one_site(void)
509509

510510
ret = ancestor_builder_alloc(&ancestor_builder, 4, 1, -1, 0);
511511
CU_ASSERT_EQUAL_FATAL(ret, 0);
512-
ret = ancestor_builder_add_site(&ancestor_builder, 4, genotypes, 1);
512+
ret = ancestor_builder_add_site(&ancestor_builder, 4, genotypes);
513513
CU_ASSERT_EQUAL_FATAL(ret, 0);
514514
ret = ancestor_builder_finalise(&ancestor_builder);
515515
CU_ASSERT_EQUAL_FATAL(ret, 0);

lib/tsinfer.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -215,7 +215,7 @@ int ancestor_builder_alloc(ancestor_builder_t *self, size_t num_samples,
215215
int ancestor_builder_free(ancestor_builder_t *self);
216216
int ancestor_builder_print_state(ancestor_builder_t *self, FILE *out);
217217
int ancestor_builder_add_site(
218-
ancestor_builder_t *self, double time, allele_t *genotypes, tsk_size_t derived_count);
218+
ancestor_builder_t *self, double time, allele_t *genotypes);
219219
int ancestor_builder_finalise(ancestor_builder_t *self);
220220
int ancestor_builder_make_ancestor(const ancestor_builder_t *self,
221221
size_t num_focal_sites, const tsk_id_t *focal_sites, tsk_id_t *start, tsk_id_t *end,

tests/test_inference.py

Lines changed: 3 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1702,7 +1702,7 @@ def test_array_args(self, tmp_path, tmpdir):
17021702
ts.tables.assert_equals(ts_batch.tables, ignore_provenance=True)
17031703

17041704

1705-
class TestAncestorGeneratorsEquivalent:
1705+
class TestAncestorGeneratorsEquivalant:
17061706
"""
17071707
Tests for the ancestor generation process.
17081708
"""
@@ -1920,7 +1920,7 @@ def test_bad_focal_sites(self):
19201920
g = np.zeros(2, dtype=np.int8)
19211921
h = np.zeros(1, dtype=np.int8)
19221922
generator = tsinfer.AncestorsGenerator(sample_data, None, {}, engine=engine)
1923-
generator.ancestor_builder.add_site(1, g, derived_count=0)
1923+
generator.ancestor_builder.add_site(1, g)
19241924
with pytest.raises(error):
19251925
generator.ancestor_builder.make_ancestor([0], h)
19261926

@@ -2725,10 +2725,7 @@ def test_ancestor_builder_print_state(self):
27252725
sample_data = self.sample_example(n_samples, n_sites)
27262726
ancestor_builder = tsinfer.algorithm.AncestorBuilder(n_samples, n_sites)
27272727
for variant in sample_data.variants():
2728-
derived_count = np.sum(variant.genotypes)
2729-
ancestor_builder.add_site(
2730-
variant.site.time, variant.genotypes, derived_count
2731-
)
2728+
ancestor_builder.add_site(variant.site.time, variant.genotypes)
27322729
with mock.patch("sys.stdout", new=io.StringIO()) as mock_output:
27332730
ancestor_builder.print_state()
27342731
# Simply check some text is output

tests/test_low_level.py

Lines changed: 4 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -134,22 +134,19 @@ def test_add_site(self):
134134
ab = _tsinfer.AncestorBuilder(num_samples=2, max_sites=10)
135135
for bad_type in ["sdf", {}, None]:
136136
with pytest.raises(TypeError):
137-
ab.add_site(time=bad_type, genotypes=[0, 0], derived_count=0)
137+
ab.add_site(time=bad_type, genotypes=[0, 0])
138138
for bad_genotypes in ["asdf", [[], []], [0, 1, 2]]:
139139
with pytest.raises(ValueError):
140-
ab.add_site(time=0, genotypes=bad_genotypes, derived_count=0)
141-
for bad_derived_count in ["asdf", 1.2, [0, 1]]:
142-
with pytest.raises(TypeError):
143-
ab.add_site(time=0, genotypes=[0, 0], derived_count=bad_derived_count)
140+
ab.add_site(time=0, genotypes=bad_genotypes)
144141

145142
def test_add_too_many_sites(self):
146143
for max_sites in range(10):
147144
ab = _tsinfer.AncestorBuilder(num_samples=2, max_sites=max_sites)
148145
for _ in range(max_sites):
149-
ab.add_site(time=1, genotypes=[0, 1], derived_count=0)
146+
ab.add_site(time=1, genotypes=[0, 1])
150147
for _ in range(2 * max_sites):
151148
with pytest.raises(_tsinfer.LibraryError) as record:
152-
ab.add_site(time=1, genotypes=[0, 1], derived_count=0)
149+
ab.add_site(time=1, genotypes=[0, 1])
153150
msg = "Cannot add more sites than the specified maximum."
154151
assert str(record.value) == msg
155152

tsinfer/algorithm.py

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -137,11 +137,12 @@ def store_site_genotypes(self, site_id, genotypes):
137137
stop = start + self.encoded_genotypes_size
138138
self.genotype_store[start:stop] = genotypes
139139

140-
def add_site(self, time, genotypes, derived_count):
140+
def add_site(self, time, genotypes):
141141
"""
142142
Adds a new site at the specified ID to the builder.
143143
"""
144144
site_id = len(self.sites)
145+
derived_count = np.sum(genotypes == 1)
145146
self.store_site_genotypes(site_id, genotypes)
146147
self.sites.append(Site(site_id, time, derived_count))
147148
sites_at_fixed_timepoint = self.time_map[time]
@@ -203,6 +204,15 @@ def ancestor_descriptors(self):
203204
return ret
204205

205206
def compute_ancestral_states(self, a, focal_site, sites):
207+
"""
208+
For a given focal site, and set of sites to fill in (usually all the ones
209+
leftwards or rightwards), augment the haplotype array `a` with the inferred sites.
210+
Together with `make_ancestor`, which calls this function, these describe the main
211+
algorithm as implemented in Fig S2 of the preprint, with the buffer.
212+
213+
At the moment we assume that the derived state is 1. We should alter this so
214+
that we allow the derived state to be a different non-zero integer.
215+
"""
206216
focal_time = self.sites[focal_site].time
207217
g = self.get_site_genotypes(focal_site)
208218
sample_set = np.where(g == 1)[0]

tsinfer/inference.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1881,9 +1881,7 @@ def add_sites(self, exclude_positions=None):
18811881
if np.isnan(time):
18821882
use_site = False # Site with meaningless time value: skip inference
18831883
if use_site:
1884-
self.ancestor_builder.add_site(
1885-
time, variant.genotypes, int(counts.derived)
1886-
)
1884+
self.ancestor_builder.add_site(time, variant.genotypes)
18871885
inference_site_id.append(site.id)
18881886
self.num_sites += 1
18891887
progress.update()

0 commit comments

Comments
 (0)