Skip to content

Commit 200b6ab

Browse files
committed
Add function to get forward and backward matrices from _tskit.lshmm WIP
1 parent 1679b1a commit 200b6ab

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

python/tests/test_imputation.py

+11
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77
import pandas as pd
88

9+
import _tskit
910
import tskit
1011

1112

@@ -259,3 +260,13 @@ def get_beagle_data(matrix_text, data_type):
259260
return df.sumSite.to_numpy().reshape((4, 4))[:, 0]
260261
else:
261262
raise ValueError(f"Unknown data type: {data_type}")
263+
264+
265+
def get_tskit_forward_backward_matrices(ts, h):
266+
m = ts.num_sites
267+
fm = _tskit.CompressedMatrix(ts)
268+
bm = _tskit.CompressedMatrix(ts)
269+
ls_hmm = _tskit.LsHmm(ts, np.zeros(m) + 0.1, np.zeros(m) + 0.1)
270+
ls_hmm.forward_matrix(h, fm)
271+
ls_hmm.backward_matrix(h, fm.normalisation_factor, bm)
272+
return [fm, bm]

0 commit comments

Comments
 (0)