|
1 |
| -def test_noop(): |
2 |
| - assert 1 == 1 |
| 1 | +from unittest import TestCase |
| 2 | + |
| 3 | +from lassobbn.learn import learn_parameters, learn_structure, to_bbn, to_join_tree |
| 4 | + |
| 5 | + |
| 6 | +class ApiTest(TestCase): |
| 7 | + def test_process(self): |
| 8 | + df_path = './data/data-binary.csv' |
| 9 | + meta_path = './data/data-binary-complete.json' |
| 10 | + |
| 11 | + parents = learn_structure(df_path, meta_path, n_way=2, ignore_neg_gt=-0.01, ignore_pos_lt=0.05) |
| 12 | + self.assertDictEqual(parents, {'e': ['d!b'], 'd': ['b!a']}) |
| 13 | + |
| 14 | + # Learn the parameters |
| 15 | + d, g, p = learn_parameters(df_path, parents) |
| 16 | + self.assertDictEqual(d, |
| 17 | + {'d!b': ['0', '1'], 'e': ['0', '1'], 'd': ['0', '1'], 'b': ['0', '1'], 'b!a': ['0', '1'], |
| 18 | + 'a': ['0', '1']}) |
| 19 | + |
| 20 | + edges = [f'{pa} -> {ch}' for pa, ch in g.edges()] |
| 21 | + self.assertEqual(edges, ['d!b -> e', 'd -> d!b', 'b -> d!b', 'b -> b!a', 'b!a -> d', 'a -> b!a']) |
| 22 | + |
| 23 | + self.assertDictEqual(p, {'d!b': [1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0], |
| 24 | + 'e': [0.7674080232799834, 0.23259197672001664, 0.08465608465608465, |
| 25 | + 0.9153439153439153], |
| 26 | + 'd': [0.7935015472506546, 0.2064984527493454, 0.8041301627033792, 0.19586983729662077], |
| 27 | + 'b': [0.8029, 0.1971], 'b!a': [1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0], |
| 28 | + 'a': [0.1893, 0.8107]} |
| 29 | + ) |
| 30 | + |
| 31 | + # Get the BBN |
| 32 | + bbn = to_bbn(d, g, p) |
| 33 | + jt = to_join_tree(bbn) |
| 34 | + |
| 35 | + # get posteriors |
| 36 | + posteriors = [{**{'name': node}, **{val: prob for val, prob in posteriors.items()}} |
| 37 | + for node, posteriors in jt.get_posteriors().items()] |
| 38 | + self.assertEqual(posteriors, [{'name': 'd!b', '0': 0.960997490478821, '1': 0.03900250952117903}, |
| 39 | + {'name': 'e', '0': 0.7407789842932012, '1': 0.2592210157067987}, |
| 40 | + {'name': 'd', '0': 0.7951998827663717, '1': 0.20480011723362845}, |
| 41 | + {'name': 'b', '0': 0.8029, '1': 0.1971}, |
| 42 | + {'name': 'b!a', '0': 0.8402110300000001, '1': 0.15978897}, |
| 43 | + {'name': 'a', '0': 0.18929999999999997, '1': 0.8107}]) |
0 commit comments