diff --git a/pybloom/pybloom.py b/pybloom/pybloom.py
index beeefe4..eb02291 100644
--- a/pybloom/pybloom.py
+++ b/pybloom/pybloom.py
@@ -222,6 +222,22 @@ def union(self, other):
 both the same capacity and error rate")
         new_bloom = self.copy()
         new_bloom.bitarray = new_bloom.bitarray | other.bitarray
+        # Estimate how many distinct elements the merged filter holds from
+        # its fill ratio: n ~= -(m / k) * ln(1 - X / m), where m / k is
+        # bits_per_slice and X is the number of set bits.
+        # https://en.wikipedia.org/wiki/Bloom_filter#The_union_and_intersection_of_sets
+        upper_bound = self.count + other.count
+        empty_ratio = 1.0 - float(new_bloom.bitarray.count(1)) / self.num_bits
+        if empty_ratio > 0:
+            estimate = int(round(
+                -self.bits_per_slice * math.log(empty_ratio)))
+        else:
+            # Every bit is set and log(0) is undefined: use the upper bound.
+            estimate = upper_bound
+        # The union holds at least as many elements as the larger input and
+        # at most as many as both inputs together.
+        new_bloom.count = min(
+            max(estimate, self.count, other.count), upper_bound)
         return new_bloom
 
     def __or__(self, other):
@@ -236,6 +252,25 @@ def intersection(self, other):
 have equal capacity and error rate")
         new_bloom = self.copy()
         new_bloom.bitarray = new_bloom.bitarray & other.bitarray
+        # Estimate the size of the intersection by inclusion-exclusion:
+        # len(A & B) = len(A) + len(B) - len(A | B), where len(A | B) is
+        # estimated from the fill ratio of the OR-ed bit arrays.
+        # https://en.wikipedia.org/wiki/Bloom_filter#The_union_and_intersection_of_sets
+        # The intersected filter contains every element of the true
+        # intersection, but its false positive rate may be higher than that
+        # of a filter built from scratch from the intersection of the sets.
+        upper_bound = min(self.count, other.count)
+        union_bits = (self.bitarray | other.bitarray).count(1)
+        empty_ratio = 1.0 - float(union_bits) / self.num_bits
+        if empty_ratio > 0:
+            union_estimate = -self.bits_per_slice * math.log(empty_ratio)
+            estimate = int(round(self.count + other.count - union_estimate))
+        else:
+            # Every bit of the union is set and log(0) is undefined: use
+            # the upper bound.
+            estimate = upper_bound
+        # An intersection is never negative nor larger than the smaller set.
+        new_bloom.count = min(max(estimate, 0), upper_bound)
         return new_bloom
 
     def __and__(self, other):
diff --git a/pybloom/tests.py b/pybloom/tests.py
index 13d9b7d..8ec1712 100644
--- a/pybloom/tests.py
+++ b/pybloom/tests.py
@@ -35,6 +35,53 @@ def test_union(self):
         for char in chars:
             self.assertTrue(char in new_bloom)
 
+    def test_union_size(self):
+        # Large, partly overlapping inputs with duplicates: the estimate is
+        # only meaningful once the filters hold a fair number of elements.
+        bloom_one = BloomFilter(100000, 0.001)
+        bloom_two = BloomFilter(100000, 0.001)
+        list_a = [str(i % 10000) for i in range(15000)]
+        list_b = [str(5000 + (i % 10000)) for i in range(15000)]
+        for item in list_a:
+            bloom_one.add(item)
+        for item in list_b:
+            bloom_two.add(item)
+
+        merged_bloom = bloom_one.union(bloom_two)
+
+        # Duplicates must not inflate the exact per-filter counts.
+        self.assertEqual(bloom_one.count, len(set(list_a)))
+        self.assertEqual(bloom_two.count, len(set(list_b)))
+        # The merged count is an integer estimate, so allow 2% slack.
+        expected = len(set(list_a) | set(list_b))
+        self.assertTrue(isinstance(merged_bloom.count, int))
+        self.assertTrue(
+            abs(merged_bloom.count - expected) <= 0.02 * expected)
+
+    def test_intersection_size(self):
+        # Large, partly overlapping inputs with duplicates: the estimate is
+        # only meaningful once the filters hold a fair number of elements.
+        bloom_one = BloomFilter(100000, 0.001)
+        bloom_two = BloomFilter(100000, 0.001)
+        list_a = [str(i % 10000) for i in range(15000)]
+        list_b = [str(5000 + (i % 10000)) for i in range(15000)]
+        for item in list_a:
+            bloom_one.add(item)
+        for item in list_b:
+            bloom_two.add(item)
+
+        merged_bloom = bloom_one.intersection(bloom_two)
+
+        # Duplicates must not inflate the exact per-filter counts.
+        self.assertEqual(bloom_one.count, len(set(list_a)))
+        self.assertEqual(bloom_two.count, len(set(list_b)))
+        # The intersected count is an integer estimate derived from the
+        # fill ratio of the union, so allow 2% slack.
+        expected = len(set(list_a) & set(list_b))
+        self.assertTrue(isinstance(merged_bloom.count, int))
+        self.assertTrue(
+            abs(merged_bloom.count - expected) <= 0.02 * expected)
+
     def test_intersection(self):
         bloom_one = BloomFilter(100, 0.001)
         bloom_two = BloomFilter(100, 0.001)