Skip to content

Commit c3166d9

Browse files
committed
test
1 parent 417ab2d commit c3166d9

File tree

1 file changed

+43
-2
lines changed

1 file changed

+43
-2
lines changed

tests/test_learn.py

+43-2
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,43 @@
1-
def test_noop():
2-
assert 1 == 1
1+
from unittest import TestCase
2+
3+
from lassobbn.learn import learn_parameters, learn_structure, to_bbn, to_join_tree
4+
5+
6+
class ApiTest(TestCase):
7+
def test_process(self):
8+
df_path = './data/data-binary.csv'
9+
meta_path = './data/data-binary-complete.json'
10+
11+
parents = learn_structure(df_path, meta_path, n_way=2, ignore_neg_gt=-0.01, ignore_pos_lt=0.05)
12+
self.assertDictEqual(parents, {'e': ['d!b'], 'd': ['b!a']})
13+
14+
# Learn the parameters
15+
d, g, p = learn_parameters(df_path, parents)
16+
self.assertDictEqual(d,
17+
{'d!b': ['0', '1'], 'e': ['0', '1'], 'd': ['0', '1'], 'b': ['0', '1'], 'b!a': ['0', '1'],
18+
'a': ['0', '1']})
19+
20+
edges = [f'{pa} -> {ch}' for pa, ch in g.edges()]
21+
self.assertEqual(edges, ['d!b -> e', 'd -> d!b', 'b -> d!b', 'b -> b!a', 'b!a -> d', 'a -> b!a'])
22+
23+
self.assertDictEqual(p, {'d!b': [1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0],
24+
'e': [0.7674080232799834, 0.23259197672001664, 0.08465608465608465,
25+
0.9153439153439153],
26+
'd': [0.7935015472506546, 0.2064984527493454, 0.8041301627033792, 0.19586983729662077],
27+
'b': [0.8029, 0.1971], 'b!a': [1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0],
28+
'a': [0.1893, 0.8107]}
29+
)
30+
31+
# Get the BBN
32+
bbn = to_bbn(d, g, p)
33+
jt = to_join_tree(bbn)
34+
35+
# get posteriors
36+
posteriors = [{**{'name': node}, **{val: prob for val, prob in posteriors.items()}}
37+
for node, posteriors in jt.get_posteriors().items()]
38+
self.assertEqual(posteriors, [{'name': 'd!b', '0': 0.960997490478821, '1': 0.03900250952117903},
39+
{'name': 'e', '0': 0.7407789842932012, '1': 0.2592210157067987},
40+
{'name': 'd', '0': 0.7951998827663717, '1': 0.20480011723362845},
41+
{'name': 'b', '0': 0.8029, '1': 0.1971},
42+
{'name': 'b!a', '0': 0.8402110300000001, '1': 0.15978897},
43+
{'name': 'a', '0': 0.18929999999999997, '1': 0.8107}])

0 commit comments

Comments
 (0)