22
22
"""
23
23
Python implementation of the Li and Stephens forwards and backwards algorithms.
24
24
"""
25
+ import io
25
26
import warnings
26
27
27
28
import lshmm as ls
37
38
MISSING = - 1
38
39
39
40
40
- # np.set_printoptions(linewidth=1000, precision=3)
41
+ # For debugging
42
+ np .set_printoptions (linewidth = 1000 , precision = 3 )
41
43
42
44
43
45
def check_alleles (alleles , m ):
@@ -151,7 +153,7 @@ def node_values(self):
151
153
def print_state (self ):
152
154
print ("LsHMM state" )
153
155
print ("match_all_nodes =" , self .match_all_nodes )
154
- print ("Tree =" )
156
+ print ("Tree = " , self . tree . index , self . tree . interval )
155
157
node_labels = {}
156
158
for u , value in self .node_values ().items ():
157
159
label = f"{ u } "
@@ -434,11 +436,13 @@ def update_probabilities(self, site, haplotype_state):
434
436
def process_site (self , site , haplotype_state ):
435
437
self .update_probabilities (site , haplotype_state )
436
438
# d1 = self.node_values()
439
+ print ("PRE" )
440
+ self .print_state ()
437
441
self .compress ()
438
442
# d2 = self.node_values()
439
443
# assert d1 == d2
440
- # print("AFTER COMPRESS")
441
- # self.print_state()
444
+ print ("AFTER COMPRESS" )
445
+ self .print_state ()
442
446
s = self .compute_normalisation_factor ()
443
447
for st in self .T :
444
448
assert st .tree_node != tskit .NULL
@@ -489,8 +493,13 @@ def run(self, h):
489
493
self .initialise (1 / n )
490
494
while self .tree .next ():
491
495
self .update_tree ()
496
+ if self .tree .index != 0 :
497
+ print ("AFTER UPDATE TREE" )
498
+ self .print_state ()
492
499
for site in self .tree .sites ():
493
500
self .process_site (site , h [site .id ])
501
+ print ("BEFORE UPDATE TREE" )
502
+ self .print_state ()
494
503
return self .output
495
504
496
505
def compute_normalisation_factor (self ):
@@ -1182,7 +1191,6 @@ def verify(self, ts):
1182
1191
self .assertAllClose (ll , ll_check )
1183
1192
1184
1193
1185
- # TODO add params to run the various checks
1186
1194
def check_viterbi (
1187
1195
ts ,
1188
1196
h ,
@@ -1212,10 +1220,10 @@ def check_viterbi(
1212
1220
cm = ls_viterbi_tree (
1213
1221
h , ts , rho = recombination , mu = mutation , match_all_nodes = match_all_nodes
1214
1222
)
1223
+ cm .print_state ()
1215
1224
path_tree = cm .traceback (match_all_nodes = match_all_nodes )
1216
1225
ll_tree = np .sum (np .log10 (cm .normalisation_factor ))
1217
1226
assert np .isscalar (ll_tree )
1218
- # print(cm)
1219
1227
# print("path tree = ", path_tree)
1220
1228
1221
1229
if compare_lshmm :
@@ -1437,8 +1445,8 @@ def test_match_sample(self, j):
1437
1445
ts = self .ts ()
1438
1446
h = np .zeros (4 )
1439
1447
h [j ] = 1
1440
- # path = check_viterbi(ts, h)
1441
- # nt.assert_array_equal([j, j, j, j], path)
1448
+ path = check_viterbi (ts , h )
1449
+ nt .assert_array_equal ([j , j , j , j ], path )
1442
1450
cm = check_forward_matrix (ts , h )
1443
1451
check_backward_matrix (ts , h , cm )
1444
1452
@@ -1525,6 +1533,19 @@ def test_match_sample(self, u, h):
1525
1533
)
1526
1534
1527
1535
1536
+ def validate_match_all_nodes (ts , h , expected_path ):
1537
+ path = check_viterbi (
1538
+ ts , h , match_all_nodes = True , compare_lib = False , compare_lshmm = False
1539
+ )
1540
+ nt .assert_array_equal (expected_path , path )
1541
+ cm = check_forward_matrix (
1542
+ ts , h , match_all_nodes = True , compare_lib = False , compare_lshmm = False
1543
+ )
1544
+ bm = check_backward_matrix (
1545
+ ts , h , cm , match_all_nodes = True , compare_lib = False , compare_lshmm = False
1546
+ )
1547
+
1548
+
1528
1549
class TestSingleBalancedTreeAllNodesExample :
1529
1550
# 3.00┊ 6 ┊
1530
1551
# ┊ ┏━┻━┓ ┊
@@ -1540,7 +1561,6 @@ def ts():
1540
1561
tables .tree_sequence (), start = 1 , nodes = np .arange (len (tables .nodes ) - 1 )
1541
1562
)
1542
1563
1543
- # def test_match_sample(self, u, h):
1544
1564
@pytest .mark .parametrize (
1545
1565
("h" , "expected_path" ),
1546
1566
[
@@ -1558,20 +1578,99 @@ def ts():
1558
1578
([0 , 0 , 0 , 0 , 0 , 0 ], [6 ] * 6 ),
1559
1579
],
1560
1580
)
1561
- def test_match_sample (self , h , expected_path ):
1562
- ts = self .ts ()
1563
- path = check_viterbi (
1564
- ts , h , match_all_nodes = True , compare_lib = False , compare_lshmm = False
1581
+ def test_exact_match (self , h , expected_path ):
1582
+ validate_match_all_nodes (self .ts (), h , expected_path )
1583
+
1584
+
1585
+ class TestMultiTreeExample :
1586
+ # 0.84┊ 7 ┊ 7 ┊
1587
+ # ┊ ┏━┻━┓ ┊ ┏━┻━┓ ┊
1588
+ # 0.42┊ ┃ ┃ ┊ 6 ┃ ┊
1589
+ # ┊ ┃ ┃ ┊ ┏┻┓ ┃ ┊
1590
+ # 0.05┊ 5 ┃ ┊ ┃ ┃ ┃ ┊
1591
+ # ┊ ┏━┻┓ ┃ ┊ ┃ ┃ ┃ ┊
1592
+ # 0.04┊ ┃ 4 ┃ ┊ ┃ ┃ 4 ┊
1593
+ # ┊ ┃ ┏┻┓ ┃ ┊ ┃ ┃ ┏┻┓ ┊
1594
+ # 0.00┊ 0 1 2 3 ┊ 0 3 1 2 ┊
1595
+ # 0 6 7
1596
+ @staticmethod
1597
+ def ts ():
1598
+ nodes = """\
1599
+ is_sample time
1600
+ 1 0.000000
1601
+ 1 0.000000
1602
+ 1 0.000000
1603
+ 1 0.000000
1604
+ 0 0.041304
1605
+ 0 0.045967
1606
+ 0 0.416719
1607
+ 0 0.838075
1608
+ """
1609
+ edges = """\
1610
+ left right parent child
1611
+ 0.000000 7.000000 4 1
1612
+ 0.000000 7.000000 4 2
1613
+ 0.000000 6.000000 5 0
1614
+ 0.000000 6.000000 5 4
1615
+ 6.000000 7.000000 6 0
1616
+ 6.000000 7.000000 6 3
1617
+ 0.000000 6.000000 7 3
1618
+ 6.000000 7.000000 7 4
1619
+ 0.000000 6.000000 7 5
1620
+ 6.000000 7.000000 7 6
1621
+ """
1622
+ ts = tskit .load_text (
1623
+ nodes = io .StringIO (nodes ), edges = io .StringIO (edges ), strict = False
1565
1624
)
1625
+ return add_unique_node_mutations (ts , nodes = range (7 ))
1626
+
1627
+ # 0.84┊ 7 ┊ 7 ┊
1628
+ # ┊ ┏━┻━┓ ┊ ┏━┻━┓ ┊
1629
+ # 0.42┊ ┃ ┃ ┊ 6 ┃ ┊
1630
+ # ┊ ┃ ┃ ┊ ┏┻┓ ┃ ┊
1631
+ # 0.05┊ 5 ┃ ┊ ┃ ┃ ┃ ┊
1632
+ # ┊ ┏━┻┓ ┃ ┊ ┃ ┃ ┃ ┊
1633
+ # 0.04┊ ┃ 4 ┃ ┊ ┃ ┃ 4 ┊
1634
+ # ┊ ┃ ┏┻┓ ┃ ┊ ┃ ┃ ┏┻┓ ┊
1635
+ # 0.00┊ 0 1 2 3 ┊ 0 3 1 2 ┊
1636
+ # 0 6 7
1637
+
1638
+ @pytest .mark .parametrize (
1639
+ ("h" , "expected_path" ),
1640
+ [
1641
+ # Just samples
1642
+ ([1 , 0 , 0 , 0 , 0 , 1 , 1 ], [0 ] * 7 ),
1643
+ ([0 , 1 , 0 , 0 , 1 , 1 , 0 ], [1 ] * 7 ),
1644
+ ([0 , 0 , 1 , 0 , 1 , 1 , 0 ], [2 ] * 7 ),
1645
+ ([0 , 0 , 0 , 1 , 0 , 0 , 1 ], [3 ] * 7 ),
1646
+ # Match root
1647
+ ([0 , 0 , 0 , 0 , 0 , 0 , 0 ], [7 ] * 7 ),
1648
+ ],
1649
+ )
1650
+ def test_match_all_nodes (self , h , expected_path ):
1651
+ # print()
1652
+ # print(self.ts().draw_text())
1653
+ # with open("tmp.svg", "w") as f:
1654
+ # f.write(self.ts().draw_svg())
1655
+ validate_match_all_nodes (self .ts (), h , expected_path )
1656
+
1657
+ @pytest .mark .parametrize (
1658
+ ("h" , "expected_path" ),
1659
+ [
1660
+ ([1 , 0 , 0 , 0 , 0 , 1 , 1 ], [0 ] * 7 ),
1661
+ ([0 , 1 , 0 , 0 , 1 , 1 , 0 ], [1 ] * 7 ),
1662
+ ([0 , 0 , 1 , 0 , 1 , 1 , 0 ], [2 ] * 7 ),
1663
+ ([0 , 0 , 0 , 1 , 0 , 0 , 1 ], [3 ] * 7 ),
1664
+ # Switch between each of the samples
1665
+ ([1 , 1 , 1 , 1 , 0 , 0 , 1 ], [0 , 1 , 2 , 3 , 3 , 3 , 3 ]),
1666
+ ],
1667
+ )
1668
+ def test_match_samples (self , h , expected_path ):
1669
+ ts = self .ts ()
1670
+ path = check_viterbi (ts , h )
1566
1671
nt .assert_array_equal (expected_path , path )
1567
- cm = check_forward_matrix (
1568
- ts , h , match_all_nodes = True , compare_lib = False , compare_lshmm = False
1569
- )
1570
- print (cm .decode ())
1571
- bm = check_backward_matrix (
1572
- ts , h , cm , match_all_nodes = True , compare_lib = False , compare_lshmm = False
1573
- )
1574
- print (bm .decode ())
1672
+ cm = check_forward_matrix (ts , h )
1673
+ check_backward_matrix (ts , h , cm )
1575
1674
1576
1675
1577
1676
class TestSimulationExamples :
0 commit comments