Skip to content

Commit 566ab42

Browse files
hyanwongmergify[bot]
authored andcommitted
Check indexes out of order
1 parent 0a3444a commit 566ab42

File tree

2 files changed

+18
-14
lines changed

2 files changed

+18
-14
lines changed

tests/test_inference.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1933,18 +1933,20 @@ def test_partial_samples(self):
19331933
ts2 = tsinfer.match_samples(sd, anc_ts, indexes=samples).simplify()
19341934
assert ts1.simplify(samples).equals(ts2, ignore_provenance=True)
19351935

1936-
def test_partial_bad_indexes(self):
1937-
sd = tsinfer.SampleData.from_tree_sequence(
1938-
msprime.simulate(
1939-
10, mutation_rate=2, recombination_rate=2, random_seed=233
1940-
),
1941-
use_sites_time=False,
1942-
)
1943-
ancestors = tsinfer.generate_ancestors(sd)
1944-
a_ts = tsinfer.match_ancestors(sd, ancestors)
1945-
for bad_samples in [[], [-1, 0], [0, 10]]:
1946-
with pytest.raises(ValueError):
1947-
tsinfer.match_samples(sd, a_ts, indexes=bad_samples)
1936+
@pytest.mark.parametrize(
1937+
"bad_indexes, match",
1938+
[
1939+
([], "at least one"),
1940+
([-1, 0], "bounds"),
1941+
([0, 1000], "bounds"),
1942+
([1, 0], "increasing"),
1943+
],
1944+
)
1945+
def test_partial_bad_indexes(self, small_sd_fixture, bad_indexes, match):
1946+
ancestors = tsinfer.generate_ancestors(small_sd_fixture)
1947+
a_ts = tsinfer.match_ancestors(small_sd_fixture, ancestors)
1948+
with pytest.raises(ValueError, match=match):
1949+
tsinfer.match_samples(small_sd_fixture, a_ts, indexes=bad_indexes)
19481950

19491951
def test_time_units_default_uncalibrated(self):
19501952
with tsinfer.SampleData(1.0) as sample_data:

tsinfer/inference.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,8 @@ def check_sample_indexes(sample_data, indexes):
207207
raise ValueError("Must supply at least one sample to match")
208208
if np.any(indexes < 0) or np.any(indexes >= sample_data.num_samples):
209209
raise ValueError("Sample index out of bounds")
210+
if np.any(indexes[:-1] >= indexes[1:]):
211+
raise ValueError("Sample indexes must be in increasing order")
210212
return indexes
211213

212214

@@ -491,7 +493,7 @@ def augment_ancestors(
491493
:class:`tskit.TreeSequence` instance representing the inferred
492494
history among ancestral ancestral haplotypes.
493495
:param array indexes: The sample indexes to insert into the ancestors
494-
tree sequence.
496+
tree sequence, in increasing order.
495497
:param recombination_rate: Either a floating point value giving a constant rate
496498
:math:`\\rho` per unit length of genome, or an :class:`msprime.RateMap`
497499
object. This is used to calculate the probability of recombination between
@@ -591,7 +593,7 @@ def match_samples(
591593
``filter_individuals`` set to False and ``keep_unary`` set to True
592594
(default = ``True``).
593595
:param array_like indexes: An array of indexes into the sample_data file of
594-
the samples to match, or None for all samples.
596+
the samples to match (in increasing order) or None for all samples.
595597
:param bool force_sample_times: After matching, should an attempt be made to
596598
adjust the time of "historical samples" (those associated with an individual
597599
having a non-zero time) such that the sample nodes in the tree sequence

0 commit comments

Comments
 (0)