diff --git a/tensorflow_transform/info_theory.py b/tensorflow_transform/info_theory.py
index f536d105..e0497c5b 100644
--- a/tensorflow_transform/info_theory.py
+++ b/tensorflow_transform/info_theory.py
@@ -52,12 +52,12 @@ def calculate_partial_expected_mutual_information(n, x_i, y_j):
   if x_i == 0 or y_j == 0:
     return 0
   coefficient = (-log2(x_i) - log2(y_j) + log2(n))
-  sum_probability = 0.0
-  partial_result = 0.0
-  for n_j, p_j in _hypergeometric_pmf(n, x_i, y_j):
-    if n_j != 0:
-      partial_result += n_j * (coefficient + log2(n_j)) * p_j
-    sum_probability += p_j
+  # Materialize the pmf once: _hypergeometric_pmf looks like a generator, and
+  # iterating it in two separate passes would leave the second pass with an
+  # exhausted iterator (an empty sum). Built-in sum() keeps the result a
+  # plain Python float, so no numpy dependency is needed here.
+  hyp_geo_pmf = list(_hypergeometric_pmf(n, x_i, y_j))
+  sum_probability = sum(p_j for _, p_j in hyp_geo_pmf)
+  partial_result = sum(n_j * (coefficient + log2(n_j)) * p_j
+                       for n_j, p_j in hyp_geo_pmf if n_j != 0)
   # The values of p_j should sum to 1, but given approximate calculations for
   # log2(x) and exp2(x) with large x, the full pmf might not sum to exactly 1.
   # We correct for this by dividing by the sum of the probabilities.