Skip to content

Commit 773b66c

Browse files
Partial
1 parent 8926781 commit 773b66c

File tree

1 file changed

+120
-21
lines changed

1 file changed

+120
-21
lines changed

python/tests/test_haplotype_matching.py

+120-21
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
"""
2323
Python implementation of the Li and Stephens forwards and backwards algorithms.
2424
"""
25+
import io
2526
import warnings
2627

2728
import lshmm as ls
@@ -37,7 +38,8 @@
3738
MISSING = -1
3839

3940

40-
# np.set_printoptions(linewidth=1000, precision=3)
41+
# For debugging
42+
np.set_printoptions(linewidth=1000, precision=3)
4143

4244

4345
def check_alleles(alleles, m):
@@ -151,7 +153,7 @@ def node_values(self):
151153
def print_state(self):
152154
print("LsHMM state")
153155
print("match_all_nodes =", self.match_all_nodes)
154-
print("Tree =")
156+
print("Tree = ", self.tree.index, self.tree.interval)
155157
node_labels = {}
156158
for u, value in self.node_values().items():
157159
label = f"{u}"
@@ -434,11 +436,13 @@ def update_probabilities(self, site, haplotype_state):
434436
def process_site(self, site, haplotype_state):
435437
self.update_probabilities(site, haplotype_state)
436438
# d1 = self.node_values()
439+
print("PRE")
440+
self.print_state()
437441
self.compress()
438442
# d2 = self.node_values()
439443
# assert d1 == d2
440-
# print("AFTER COMPRESS")
441-
# self.print_state()
444+
print("AFTER COMPRESS")
445+
self.print_state()
442446
s = self.compute_normalisation_factor()
443447
for st in self.T:
444448
assert st.tree_node != tskit.NULL
@@ -489,8 +493,13 @@ def run(self, h):
489493
self.initialise(1 / n)
490494
while self.tree.next():
491495
self.update_tree()
496+
if self.tree.index != 0:
497+
print("AFTER UPDATE TREE")
498+
self.print_state()
492499
for site in self.tree.sites():
493500
self.process_site(site, h[site.id])
501+
print("BEFORE UPDATE TREE")
502+
self.print_state()
494503
return self.output
495504

496505
def compute_normalisation_factor(self):
@@ -1182,7 +1191,6 @@ def verify(self, ts):
11821191
self.assertAllClose(ll, ll_check)
11831192

11841193

1185-
# TODO add params to run the various checks
11861194
def check_viterbi(
11871195
ts,
11881196
h,
@@ -1212,10 +1220,10 @@ def check_viterbi(
12121220
cm = ls_viterbi_tree(
12131221
h, ts, rho=recombination, mu=mutation, match_all_nodes=match_all_nodes
12141222
)
1223+
cm.print_state()
12151224
path_tree = cm.traceback(match_all_nodes=match_all_nodes)
12161225
ll_tree = np.sum(np.log10(cm.normalisation_factor))
12171226
assert np.isscalar(ll_tree)
1218-
# print(cm)
12191227
# print("path tree = ", path_tree)
12201228

12211229
if compare_lshmm:
@@ -1437,8 +1445,8 @@ def test_match_sample(self, j):
14371445
ts = self.ts()
14381446
h = np.zeros(4)
14391447
h[j] = 1
1440-
# path = check_viterbi(ts, h)
1441-
# nt.assert_array_equal([j, j, j, j], path)
1448+
path = check_viterbi(ts, h)
1449+
nt.assert_array_equal([j, j, j, j], path)
14421450
cm = check_forward_matrix(ts, h)
14431451
check_backward_matrix(ts, h, cm)
14441452

@@ -1525,6 +1533,19 @@ def test_match_sample(self, u, h):
15251533
)
15261534

15271535

1536+
def validate_match_all_nodes(ts, h, expected_path):
1537+
path = check_viterbi(
1538+
ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=False
1539+
)
1540+
nt.assert_array_equal(expected_path, path)
1541+
cm = check_forward_matrix(
1542+
ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=False
1543+
)
1544+
bm = check_backward_matrix(
1545+
ts, h, cm, match_all_nodes=True, compare_lib=False, compare_lshmm=False
1546+
)
1547+
1548+
15281549
class TestSingleBalancedTreeAllNodesExample:
15291550
# 3.00┊ 6 ┊
15301551
# ┊ ┏━┻━┓ ┊
@@ -1540,7 +1561,6 @@ def ts():
15401561
tables.tree_sequence(), start=1, nodes=np.arange(len(tables.nodes) - 1)
15411562
)
15421563

1543-
# def test_match_sample(self, u, h):
15441564
@pytest.mark.parametrize(
15451565
("h", "expected_path"),
15461566
[
@@ -1558,20 +1578,99 @@ def ts():
15581578
([0, 0, 0, 0, 0, 0], [6] * 6),
15591579
],
15601580
)
1561-
def test_match_sample(self, h, expected_path):
1562-
ts = self.ts()
1563-
path = check_viterbi(
1564-
ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=False
1581+
def test_exact_match(self, h, expected_path):
1582+
validate_match_all_nodes(self.ts(), h, expected_path)
1583+
1584+
1585+
class TestMultiTreeExample:
1586+
# 0.84┊ 7 ┊ 7 ┊
1587+
# ┊ ┏━┻━┓ ┊ ┏━┻━┓ ┊
1588+
# 0.42┊ ┃ ┃ ┊ 6 ┃ ┊
1589+
# ┊ ┃ ┃ ┊ ┏┻┓ ┃ ┊
1590+
# 0.05┊ 5 ┃ ┊ ┃ ┃ ┃ ┊
1591+
# ┊ ┏━┻┓ ┃ ┊ ┃ ┃ ┃ ┊
1592+
# 0.04┊ ┃ 4 ┃ ┊ ┃ ┃ 4 ┊
1593+
# ┊ ┃ ┏┻┓ ┃ ┊ ┃ ┃ ┏┻┓ ┊
1594+
# 0.00┊ 0 1 2 3 ┊ 0 3 1 2 ┊
1595+
# 0 6 7
1596+
@staticmethod
1597+
def ts():
1598+
nodes = """\
1599+
is_sample time
1600+
1 0.000000
1601+
1 0.000000
1602+
1 0.000000
1603+
1 0.000000
1604+
0 0.041304
1605+
0 0.045967
1606+
0 0.416719
1607+
0 0.838075
1608+
"""
1609+
edges = """\
1610+
left right parent child
1611+
0.000000 7.000000 4 1
1612+
0.000000 7.000000 4 2
1613+
0.000000 6.000000 5 0
1614+
0.000000 6.000000 5 4
1615+
6.000000 7.000000 6 0
1616+
6.000000 7.000000 6 3
1617+
0.000000 6.000000 7 3
1618+
6.000000 7.000000 7 4
1619+
0.000000 6.000000 7 5
1620+
6.000000 7.000000 7 6
1621+
"""
1622+
ts = tskit.load_text(
1623+
nodes=io.StringIO(nodes), edges=io.StringIO(edges), strict=False
15651624
)
1625+
return add_unique_node_mutations(ts, nodes=range(7))
1626+
1627+
# 0.84┊ 7 ┊ 7 ┊
1628+
# ┊ ┏━┻━┓ ┊ ┏━┻━┓ ┊
1629+
# 0.42┊ ┃ ┃ ┊ 6 ┃ ┊
1630+
# ┊ ┃ ┃ ┊ ┏┻┓ ┃ ┊
1631+
# 0.05┊ 5 ┃ ┊ ┃ ┃ ┃ ┊
1632+
# ┊ ┏━┻┓ ┃ ┊ ┃ ┃ ┃ ┊
1633+
# 0.04┊ ┃ 4 ┃ ┊ ┃ ┃ 4 ┊
1634+
# ┊ ┃ ┏┻┓ ┃ ┊ ┃ ┃ ┏┻┓ ┊
1635+
# 0.00┊ 0 1 2 3 ┊ 0 3 1 2 ┊
1636+
# 0 6 7
1637+
1638+
@pytest.mark.parametrize(
1639+
("h", "expected_path"),
1640+
[
1641+
# Just samples
1642+
([1, 0, 0, 0, 0, 1, 1], [0] * 7),
1643+
([0, 1, 0, 0, 1, 1, 0], [1] * 7),
1644+
([0, 0, 1, 0, 1, 1, 0], [2] * 7),
1645+
([0, 0, 0, 1, 0, 0, 1], [3] * 7),
1646+
# Match root
1647+
([0, 0, 0, 0, 0, 0, 0], [7] * 7),
1648+
],
1649+
)
1650+
def test_match_all_nodes(self, h, expected_path):
1651+
# print()
1652+
# print(self.ts().draw_text())
1653+
# with open("tmp.svg", "w") as f:
1654+
# f.write(self.ts().draw_svg())
1655+
validate_match_all_nodes(self.ts(), h, expected_path)
1656+
1657+
@pytest.mark.parametrize(
1658+
("h", "expected_path"),
1659+
[
1660+
([1, 0, 0, 0, 0, 1, 1], [0] * 7),
1661+
([0, 1, 0, 0, 1, 1, 0], [1] * 7),
1662+
([0, 0, 1, 0, 1, 1, 0], [2] * 7),
1663+
([0, 0, 0, 1, 0, 0, 1], [3] * 7),
1664+
# Switch between each of the samples
1665+
([1, 1, 1, 1, 0, 0, 1], [0, 1, 2, 3, 3, 3, 3]),
1666+
],
1667+
)
1668+
def test_match_samples(self, h, expected_path):
1669+
ts = self.ts()
1670+
path = check_viterbi(ts, h)
15661671
nt.assert_array_equal(expected_path, path)
1567-
cm = check_forward_matrix(
1568-
ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=False
1569-
)
1570-
print(cm.decode())
1571-
bm = check_backward_matrix(
1572-
ts, h, cm, match_all_nodes=True, compare_lib=False, compare_lshmm=False
1573-
)
1574-
print(bm.decode())
1672+
cm = check_forward_matrix(ts, h)
1673+
check_backward_matrix(ts, h, cm)
15751674

15761675

15771676
class TestSimulationExamples:

0 commit comments

Comments
 (0)