Skip to content

Commit b47db86

Browse files
committed
Add tests to check LS HMM of tskit compared to BEAGLE
1 parent b6f9872 commit b47db86

File tree

1 file changed

+272
-0
lines changed

1 file changed

+272
-0
lines changed

python/tests/test_imputation.py

+272
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
"""
2+
Tests for genotype imputation (forward and Baum-Welsh algorithms).
3+
"""
4+
import io
5+
6+
import numpy as np
7+
import pandas as pd
8+
9+
import _tskit
10+
import tskit
11+
12+
13+
# A tree sequence containing 3 diploid individuals with 5 sites and 5 mutations
14+
# (one per site). The first 2 individuals are used as reference panel,
15+
# the last one is the target individual.
16+
17+
toy_ts_nodes_text = """\
18+
id is_sample time population individual metadata
19+
0 1 0.000000 0 0
20+
1 1 0.000000 0 0
21+
2 1 0.000000 0 1
22+
3 1 0.000000 0 1
23+
4 1 0.000000 0 2
24+
5 1 0.000000 0 2
25+
6 0 0.029768 0 -1
26+
7 0 0.133017 0 -1
27+
8 0 0.223233 0 -1
28+
9 0 0.651586 0 -1
29+
10 0 0.698831 0 -1
30+
11 0 2.114867 0 -1
31+
12 0 4.322031 0 -1
32+
13 0 7.432311 0 -1
33+
"""
34+
35+
toy_ts_edges_text = """\
36+
left right parent child metadata
37+
0.000000 1000000.000000 6 0
38+
0.000000 1000000.000000 6 3
39+
0.000000 1000000.000000 7 2
40+
0.000000 1000000.000000 7 5
41+
0.000000 1000000.000000 8 1
42+
0.000000 1000000.000000 8 4
43+
0.000000 781157.000000 9 6
44+
0.000000 781157.000000 9 7
45+
0.000000 505438.000000 10 8
46+
0.000000 505438.000000 10 9
47+
505438.000000 549484.000000 11 8
48+
505438.000000 549484.000000 11 9
49+
781157.000000 1000000.000000 12 6
50+
781157.000000 1000000.000000 12 7
51+
549484.000000 1000000.000000 13 8
52+
549484.000000 781157.000000 13 9
53+
781157.000000 1000000.000000 13 12
54+
"""
55+
56+
toy_ts_sites_text = """\
57+
position ancestral_state metadata
58+
200000.000000 A
59+
300000.000000 C
60+
520000.000000 G
61+
600000.000000 T
62+
900000.000000 A
63+
"""
64+
65+
toy_ts_mutations_text = """\
66+
site node time derived_state parent metadata
67+
0 9 unknown G -1
68+
1 8 unknown A -1
69+
2 9 unknown T -1
70+
3 9 unknown C -1
71+
4 12 unknown C -1
72+
"""
73+
74+
toy_ts_individuals_text = """\
75+
flags
76+
0
77+
0
78+
0
79+
"""
80+
81+
82+
def get_toy_data():
83+
"""
84+
Returns toy data contained in the toy tree sequence in text format above.
85+
86+
:param: None
87+
:return: Reference panel tree sequence and query haplotypes.
88+
:rtype: list
89+
"""
90+
ts = tskit.load_text(
91+
nodes=io.StringIO(toy_ts_nodes_text),
92+
edges=io.StringIO(toy_ts_edges_text),
93+
sites=io.StringIO(toy_ts_sites_text),
94+
mutations=io.StringIO(toy_ts_mutations_text),
95+
individuals=io.StringIO(toy_ts_individuals_text),
96+
strict=False,
97+
)
98+
ref_ts = ts.simplify(samples=np.arange(2 * 2), filter_sites=False)
99+
query_ts = ts.simplify(samples=[5, 6], filter_sites=False)
100+
query_h = query_ts.genotype_matrix().T
101+
return [ref_ts, query_h]
102+
103+
104+
# BEAGLE 4.1 was run on the toy data set above using default parameters.
105+
#
106+
# In the query VCF, the site at position 520,000 was redacted and then imputed.
107+
# Note that the ancestral allele in the simulated tree sequence is
108+
# treated as the REF in the VCFs.
109+
#
110+
# The following are the forward probability matrices and backward probability
111+
# matrices calculated when imputing into the third individual above. There are
112+
# two sets of matrices, one for each haplotype.
113+
#
114+
# Notes about calculations:
115+
# n = number of haplotypes in ref. panel
116+
# M = number of markers
117+
# m = index of marker (site)
118+
# h = index of haplotype in ref. panel
119+
#
120+
# In forward probability matrix,
121+
# fwd[m][h] = emission prob., if m = 0 (first marker)
122+
# fwd[m][h] = emission prob. * (scale * fwd[m - 1][h] + shift), otherwise
123+
# where scale = (1 - switch prob.)/sum of fwd[m - 1],
124+
# and shift = switch prob./n.
125+
#
126+
# In backward probability matrix,
127+
# bwd[m][h] = 1, if m = M - 1 (last marker) // DON'T SEE THIS IN BEAGLE
128+
# unadj. bwd[m][h] = emission prob. / n
129+
# bwd[m][h] = (unadj. bwd[m][h] + shift) * scale, otherwise
130+
# where scale = (1 - switch prob.)/sum of unadj. bwd[m],
131+
# and shift = switch prob./n.
132+
#
133+
# For each site, the sum of backward value over all haplotypes is calculated
134+
# before scaling and shifting.
135+
136+
beagle_forward_matrix_text_1 = """
137+
m,h,probRec,probNoRec,noErrProb,errProb,refAl,queryAl,shiftFac,scaleFac,sumSite,val
138+
0,0,0.000000,1.000000,0.999900,0.000100,1,0,0.000000,1.000000,0.000100,0.000100
139+
0,1,0.000000,1.000000,0.999900,0.000100,0,0,0.000000,1.000000,1.000000,0.999900
140+
0,2,0.000000,1.000000,0.999900,0.000100,1,0,0.000000,1.000000,1.000100,0.000100
141+
0,3,0.000000,1.000000,0.999900,0.000100,1,0,0.000000,1.000000,1.000200,0.000100
142+
1,0,1.000000,0.000000,0.999900,0.000100,0,1,0.250000,0.000000,0.000025,0.000025
143+
1,1,1.000000,0.000000,0.999900,0.000100,1,1,0.250000,0.000000,0.250000,0.249975
144+
1,2,1.000000,0.000000,0.999900,0.000100,0,1,0.250000,0.000000,0.250025,0.000025
145+
1,3,1.000000,0.000000,0.999900,0.000100,0,1,0.250000,0.000000,0.250050,0.000025
146+
2,0,1.000000,0.000000,0.999900,0.000100,1,0,0.250000,0.000000,0.000025,0.000025
147+
2,1,1.000000,0.000000,0.999900,0.000100,0,0,0.250000,0.000000,0.250000,0.249975
148+
2,2,1.000000,0.000000,0.999900,0.000100,1,0,0.250000,0.000000,0.250025,0.000025
149+
2,3,1.000000,0.000000,0.999900,0.000100,1,0,0.250000,0.000000,0.250050,0.000025
150+
3,0,1.000000,0.000000,0.999900,0.000100,1,0,0.250000,0.000000,0.000025,0.000025
151+
3,1,1.000000,0.000000,0.999900,0.000100,0,0,0.250000,0.000000,0.250000,0.249975
152+
3,2,1.000000,0.000000,0.999900,0.000100,1,0,0.250000,0.000000,0.250025,0.000025
153+
3,3,1.000000,0.000000,0.999900,0.000100,1,0,0.250000,0.000000,0.250050,0.000025
154+
"""
155+
156+
beagle_backward_matrix_text_1 = """
157+
m,h,probRec,probNoRec,noErrProb,errProb,refAl,queryAl,shiftFac,scaleFac,sumSite,val
158+
3,0,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000
159+
3,1,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000
160+
3,2,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000
161+
3,3,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000
162+
2,0,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.250000,0.250050,0.250000
163+
2,1,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.250000,0.250050,0.250000
164+
2,2,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.250000,0.250050,0.250000
165+
2,3,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.250000,0.250050,0.250000
166+
1,0,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.250000,0.250050,0.250000
167+
1,1,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.250000,0.250050,0.250000
168+
1,2,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.250000,0.250050,0.250000
169+
1,3,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.250000,0.250050,0.250000
170+
0,0,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.250000,0.250050,0.250000
171+
0,1,1.000000,0.000000,0.999900,0.000100,0,1,0.000000,0.250000,0.250050,0.250000
172+
0,2,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.250000,0.250050,0.250000
173+
0,3,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.250000,0.250050,0.250000
174+
"""
175+
176+
beagle_forward_matrix_text_2 = """
177+
m,h,probRec,probNoRec,noErrProb,errProb,refAl,queryAl,shiftFac,scaleFac,sumSite,val
178+
0,0,0.000000,1.000000,0.999900,0.000100,1,1,0.000000,1.000000,0.999900,0.999900
179+
0,1,0.000000,1.000000,0.999900,0.000100,0,1,0.000000,1.000000,1.000000,0.000100
180+
0,2,0.000000,1.000000,0.999900,0.000100,1,1,0.000000,1.000000,1.999900,0.999900
181+
0,3,0.000000,1.000000,0.999900,0.000100,1,1,0.000000,1.000000,2.999800,0.999900
182+
1,0,1.000000,0.000000,0.999900,0.000100,0,0,0.250000,0.000000,0.249975,0.249975
183+
1,1,1.000000,0.000000,0.999900,0.000100,1,0,0.250000,0.000000,0.250000,0.000025
184+
1,2,1.000000,0.000000,0.999900,0.000100,0,0,0.250000,0.000000,0.499975,0.249975
185+
1,3,1.000000,0.000000,0.999900,0.000100,0,0,0.250000,0.000000,0.749950,0.249975
186+
2,0,1.000000,0.000000,0.999900,0.000100,1,1,0.250000,0.000000,0.249975,0.249975
187+
2,1,1.000000,0.000000,0.999900,0.000100,0,1,0.250000,0.000000,0.250000,0.000025
188+
2,2,1.000000,0.000000,0.999900,0.000100,1,1,0.250000,0.000000,0.499975,0.249975
189+
2,3,1.000000,0.000000,0.999900,0.000100,1,1,0.250000,0.000000,0.749950,0.249975
190+
3,0,1.000000,0.000000,0.999900,0.000100,1,1,0.250000,0.000000,0.249975,0.249975
191+
3,1,1.000000,0.000000,0.999900,0.000100,0,1,0.250000,0.000000,0.250000,0.000025
192+
3,2,1.000000,0.000000,0.999900,0.000100,1,1,0.250000,0.000000,0.499975,0.249975
193+
3,3,1.000000,0.000000,0.999900,0.000100,1,1,0.250000,0.000000,0.749950,0.249975
194+
"""
195+
196+
beagle_backward_matrix_text_2 = """
197+
m,h,probRec,probNoRec,noErrProb,errProb,refAl,queryAl,shiftFac,scaleFac,sumSite,val
198+
3,0,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000
199+
3,1,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000
200+
3,2,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000
201+
3,3,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000
202+
2,0,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.250000,0.749950,0.250000
203+
2,1,1.000000,0.000000,0.999900,0.000100,0,1,0.000000,0.250000,0.749950,0.250000
204+
2,2,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.250000,0.749950,0.250000
205+
2,3,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.250000,0.749950,0.250000
206+
1,0,1.000000,0.000000,0.999900,0.000100,0,1,0.000000,0.250000,0.749950,0.250000
207+
1,1,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.250000,0.749950,0.250000
208+
1,2,1.000000,0.000000,0.999900,0.000100,0,1,0.000000,0.250000,0.749950,0.250000
209+
1,3,1.000000,0.000000,0.999900,0.000100,0,1,0.000000,0.250000,0.749950,0.250000
210+
0,0,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.250000,0.749950,0.250000
211+
0,1,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.250000,0.749950,0.250000
212+
0,2,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.250000,0.749950,0.250000
213+
0,3,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.250000,0.749950,0.250000
214+
"""
215+
216+
217+
def convert_to_numpy(matrix_text):
218+
"""Converts a forward or backward matrix in text format to numpy."""
219+
df = pd.read_csv(io.StringIO(matrix_text))
220+
# Check that switch and non-switch probabilities sum to 1
221+
assert np.all(np.isin(df.probRec + df.probNoRec, [1, -2]))
222+
# Check that non-mismatch and mismatch probabilities sum to 1
223+
assert np.all(np.isin(df.noErrProb + df.errProb, [1, -2]))
224+
return df.val.to_numpy().reshape((4, 4))
225+
226+
227+
def get_beagle_forward_backward_matrices():
228+
fwd_matrix_1 = convert_to_numpy(beagle_forward_matrix_text_1)
229+
bwd_matrix_1 = convert_to_numpy(beagle_backward_matrix_text_1)
230+
fwd_matrix_2 = convert_to_numpy(beagle_forward_matrix_text_2)
231+
bwd_matrix_2 = convert_to_numpy(beagle_backward_matrix_text_2)
232+
return [fwd_matrix_1, bwd_matrix_1, fwd_matrix_2, bwd_matrix_2]
233+
234+
235+
def get_beagle_data(matrix_text, data_type):
236+
"""Extracts data to check forward or backward probability matrix calculations."""
237+
df = pd.read_csv(io.StringIO(matrix_text))
238+
if data_type == "switch":
239+
# Switch probability, one per site
240+
return df.probRec.to_numpy().reshape((4, 4))[:, 0]
241+
elif data_type == "mismatch":
242+
# Mismatch probability, one per site
243+
return df.errProb.to_numpy().reshape((4, 4))[:, 0]
244+
elif data_type == "ref_hap_allele":
245+
# Allele in haplotype in reference panel
246+
# 0 = ref allele, 1 = alt allele
247+
return df.refAl.to_numpy().reshape((4, 4))
248+
elif data_type == "query_hap_allele":
249+
# Allele in haplotype in query
250+
# 0 = ref allele, 1 = alt allele
251+
return df.queryAl.to_numpy().reshape((4, 4))[:, 0]
252+
elif data_type == "shift":
253+
# Shift factor, one per site
254+
return df.shiftFac.to_numpy().reshape((4, 4))[:, 0]
255+
elif data_type == "scale":
256+
# Scale factor, one per site
257+
return df.scaleFac.to_numpy().reshape((4, 4))[:, 0]
258+
elif data_type == "sum":
259+
# Sum of values over haplotypes
260+
return df.sumSite.to_numpy().reshape((4, 4))[:, 0]
261+
else:
262+
raise ValueError(f"Unknown data type: {data_type}")
263+
264+
265+
def get_tskit_forward_backward_matrices(ts, h):
266+
m = ts.num_sites
267+
fm = _tskit.CompressedMatrix(ts)
268+
bm = _tskit.CompressedMatrix(ts)
269+
ls_hmm = _tskit.LsHmm(ts, np.zeros(m) + 0.1, np.zeros(m) + 0.1)
270+
ls_hmm.forward_matrix(h, fm)
271+
ls_hmm.backward_matrix(h, fm.normalisation_factor, bm)
272+
return [fm, bm]

0 commit comments

Comments
 (0)