Skip to content
This repository was archived by the owner on Jul 30, 2024. It is now read-only.

HMAC enabled Bloom filter and Scalable Bloom Filter #15

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 81 additions & 25 deletions pybloom/pybloom.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
Requires the bitarray library: http://pypi.python.org/pypi/bitarray/

>>> from pybloom import BloomFilter
>>> f = BloomFilter(capacity=10000, error_rate=0.001)
>>> f = BloomFilter(capacity=10000, error_rate=0.001, hashmac=1)
>>> for i in range_fn(0, f.capacity):
... _ = f.add(i)
...
Expand All @@ -20,7 +20,7 @@
True

>>> from pybloom import ScalableBloomFilter
>>> sbf = ScalableBloomFilter(mode=ScalableBloomFilter.SMALL_SET_GROWTH)
>>> sbf = ScalableBloomFilter(mode=ScalableBloomFilter.SMALL_SET_GROWTH, hashmac=0)
>>> count = 10000
>>> for i in range_fn(0, count):
... _ = sbf.add(i)
Expand All @@ -36,7 +36,9 @@
from __future__ import absolute_import
import math
import hashlib
from pybloom.utils import range_fn, is_string_io, running_python_3
import hmac
import random
from utils import range_fn, is_string_io, running_python_3
from struct import unpack, pack, calcsize

try:
Expand All @@ -50,30 +52,72 @@
Alex Brasetvik <[email protected]>,\
Matt Bachmann <[email protected]>,\
"
# The last parameter, hashmac, selects the hashing mode: 0 uses plain hashlib
# digests; any other value uses HMAC keyed with a randomly generated key.

def make_hashfuncs(num_slices, num_bits):
def make_hashfuncs(num_slices, num_bits, hashmac):
if num_bits >= (1 << 31):
fmt_code, chunk_size = 'Q', 8
elif num_bits >= (1 << 15):
fmt_code, chunk_size = 'I', 4
else:
fmt_code, chunk_size = 'H', 2
total_hash_bits = 8 * num_slices * chunk_size

if total_hash_bits > 384:
hashfn = hashlib.sha512
digest_size = 64
if hashmac==0:
hashfn = hashlib.sha512
else:
hmackey = random.getrandbits(digest_size)
hashfn = hmac
elif total_hash_bits > 256:
hashfn = hashlib.sha384
digest_size = 48
if hashmac==0:
hashfn = hashlib.sha384
else:
hmackey = random.getrandbits(digest_size)
hashfn = hmac
elif total_hash_bits > 160:
hashfn = hashlib.sha256
digest_size = 32
if hashmac==0:
hashfn = hashlib.sha256
else:
hmackey = random.getrandbits(digest_size)
hashfn = hmac
elif total_hash_bits > 128:
hashfn = hashlib.sha1
digest_size = 20
if hashmac==0:
hashfn = hashlib.sha1
else:
hmackey = random.getrandbits(digest_size)
hashfn = hmac
else:
hashfn = hashlib.md5
fmt = fmt_code * (hashfn().digest_size // chunk_size)
digest_size=16
if hashmac==0:
hashfn = hashlib.md5
else:
hmackey = random.getrandbits(digest_size)
hashfn = hmac
fmt = fmt_code * (digest_size // chunk_size)
num_salts, extra = divmod(num_slices, len(fmt))
if extra:
num_salts += 1
salts = tuple(hashfn(hashfn(pack('I', i)).digest()) for i in range_fn(num_salts))

if hashmac==0:
salts = tuple(hashfn(hashfn(pack('I', i)).digest()) for i in range_fn(num_salts))
else:
if digest_size==64:
salts = tuple(hashfn.new(str(hmackey),hashfn.new(str(hmackey), pack('I',i), hashlib.sha512).digest(), hashlib.sha512) for i in range_fn(num_salts))
elif digest_size==48:
salts =tuple(hashfn.new(str(hmackey),hashfn.new(str(hmackey), pack('I',i), hashlib.sha384).digest(), hashlib.sha384) for i in range_fn(num_salts))
elif digest_size==32:
salts = tuple(hashfn.new(str(hmackey),hashfn.new(str(hmackey), pack('I',i), hashlib.sha256).digest(), hashlib.sha256) for i in range_fn(num_salts))
elif digest_size==20:
salts = tuple(hashfn.new(str(hmackey),hashfn.new(str(hmackey), pack('I',i), hashlib.sha1).digest(), hashlib.sha1) for i in range_fn(num_salts))
else:
salts = tuple(hashfn.new(str(hmackey),hashfn.new(str(hmackey), pack('I',i), hashlib.md5).digest(), hashlib.md5) for i in range_fn(num_salts))

def _make_hashfuncs(key):
if running_python_3:
if isinstance(key, str):
Expand All @@ -89,6 +133,7 @@ def _make_hashfuncs(key):
for salt in salts:
h = salt.copy()
h.update(key)
h.hexdigest()
for uint in unpack(fmt, h.digest()):
yield uint % num_bits
i += 1
Expand All @@ -101,7 +146,7 @@ def _make_hashfuncs(key):
class BloomFilter(object):
FILE_FMT = b'<dQQQQ'

def __init__(self, capacity, error_rate=0.001):
def __init__(self, capacity, error_rate=0.001, hashmac=0):
"""Implements a space-efficient probabilistic data structure

capacity
Expand All @@ -112,8 +157,10 @@ def __init__(self, capacity, error_rate=0.001):
the error_rate of the filter returning false positives. This
determines the filters capacity. Inserting more than capacity
elements greatly increases the chance of false positives.
hashmac
0 (the default) selects plain hash functions; any other value selects
keyed-hash (HMAC) functions.

>>> b = BloomFilter(capacity=100000, error_rate=0.001)
>>> b = BloomFilter(capacity=100000, error_rate=0.001, hashmac=0)
>>> b.add("test")
False
>>> "test" in b
Expand All @@ -134,18 +181,19 @@ def __init__(self, capacity, error_rate=0.001):
bits_per_slice = int(math.ceil(
(capacity * abs(math.log(error_rate))) /
(num_slices * (math.log(2) ** 2))))
self._setup(error_rate, num_slices, bits_per_slice, capacity, 0)
self._setup(error_rate, num_slices, bits_per_slice, capacity,0, hashmac)
self.bitarray = bitarray.bitarray(self.num_bits, endian='little')
self.bitarray.setall(False)

def _setup(self, error_rate, num_slices, bits_per_slice, capacity, count):
def _setup(self, error_rate, num_slices, bits_per_slice, capacity, count, hashmac):
self.error_rate = error_rate
self.num_slices = num_slices
self.bits_per_slice = bits_per_slice
self.capacity = capacity
self.num_bits = num_slices * bits_per_slice
self.count = count
self.make_hashes = make_hashfuncs(self.num_slices, self.bits_per_slice)
self.make_hashes = make_hashfuncs(self.num_slices, self.bits_per_slice, hashmac)
self.hashmac = hashmac

def __contains__(self, key):
"""Tests a key's membership in this bloom filter.
Expand Down Expand Up @@ -282,15 +330,16 @@ def __getstate__(self):

def __setstate__(self, d):
self.__dict__.update(d)
self.make_hashes = make_hashfuncs(self.num_slices, self.bits_per_slice)
self.make_hashes = make_hashfuncs(self.num_slices, self.bits_per_slice, self.hashmac)



class ScalableBloomFilter(object):
SMALL_SET_GROWTH = 2 # slower, but takes up less memory
LARGE_SET_GROWTH = 4 # faster, but takes up more memory faster
FILE_FMT = '<idQd'

def __init__(self, initial_capacity=100, error_rate=0.001,
mode=SMALL_SET_GROWTH):
def __init__(self, initial_capacity=100, error_rate=0.001,mode=SMALL_SET_GROWTH, hashmac=0):
"""Implements a space-efficient probabilistic data structure that
grows as more items are added while maintaining a steady false
positive rate
Expand All @@ -306,9 +355,12 @@ def __init__(self, initial_capacity=100, error_rate=0.001,
ScalableBloomFilter.LARGE_SET_GROWTH. SMALL_SET_GROWTH is slower
but uses less memory. LARGE_SET_GROWTH is faster but consumes
memory faster.
hashmac
flag requesting keyed-hash (HMAC) functions instead of plain hash
functions; defaults to 0

>>> b = ScalableBloomFilter(initial_capacity=512, error_rate=0.001, \
mode=ScalableBloomFilter.SMALL_SET_GROWTH)
mode=ScalableBloomFilter.SMALL_SET_GROWTH,\
hashmac = 1)
>>> b.add("test")
False
>>> "test" in b
Expand All @@ -321,20 +373,20 @@ def __init__(self, initial_capacity=100, error_rate=0.001,
"""
if not error_rate or error_rate < 0:
raise ValueError("Error_Rate must be a decimal less than 0.")
self._setup(mode, 0.9, initial_capacity, error_rate)
self._setup(mode, 0.9, initial_capacity, error_rate, hashmac)
self.filters = []

def _setup(self, mode, ratio, initial_capacity, error_rate):
def _setup(self, mode, ratio, initial_capacity, error_rate, hashmac):
self.scale = mode
self.ratio = ratio
self.initial_capacity = initial_capacity
self.error_rate = error_rate
self.hashmac = hashmac

def __contains__(self, key):
"""Tests a key's membership in this bloom filter.

>>> b = ScalableBloomFilter(initial_capacity=100, error_rate=0.001, \
mode=ScalableBloomFilter.SMALL_SET_GROWTH)
>>> b = ScalableBloomFilter(initial_capacity=100, error_rate=0.001, mode=ScalableBloomFilter.SMALL_SET_GROWTH, hashmac=0)
>>> b.add("hello")
False
>>> "hello" in b
Expand All @@ -352,7 +404,7 @@ def add(self, key):
Otherwise False.

>>> b = ScalableBloomFilter(initial_capacity=100, error_rate=0.001, \
mode=ScalableBloomFilter.SMALL_SET_GROWTH)
mode=ScalableBloomFilter.SMALL_SET_GROWTH, hashmac=0)
>>> b.add("hello")
False
>>> b.add("hello")
Expand Down Expand Up @@ -431,6 +483,10 @@ def __len__(self):
return sum(f.count for f in self.filters)




if __name__ == "__main__":
    # Executed as a script: run every doctest embedded in this module's
    # docstrings.
    from doctest import testmod

    testmod()