diff --git a/pybloom/pybloom.py b/pybloom/pybloom.py index beeefe4..6f9ef61 100644 --- a/pybloom/pybloom.py +++ b/pybloom/pybloom.py @@ -6,7 +6,7 @@ Requires the bitarray library: http://pypi.python.org/pypi/bitarray/ >>> from pybloom import BloomFilter - >>> f = BloomFilter(capacity=10000, error_rate=0.001) + >>> f = BloomFilter(capacity=10000, error_rate=0.001, hashmac=1) >>> for i in range_fn(0, f.capacity): ... _ = f.add(i) ... @@ -20,7 +20,7 @@ True >>> from pybloom import ScalableBloomFilter - >>> sbf = ScalableBloomFilter(mode=ScalableBloomFilter.SMALL_SET_GROWTH) + >>> sbf = ScalableBloomFilter(mode=ScalableBloomFilter.SMALL_SET_GROWTH, hashmac=0) >>> count = 10000 >>> for i in range_fn(0, count): ... _ = sbf.add(i) @@ -36,7 +36,9 @@ from __future__ import absolute_import import math import hashlib -from pybloom.utils import range_fn, is_string_io, running_python_3 +import hmac +import random +from utils import range_fn, is_string_io, running_python_3 from struct import unpack, pack, calcsize try: @@ -50,8 +52,10 @@ Alex Brasetvik ,\ Matt Bachmann ,\ " +# the last parameter hashmac determines whether +# HMAC is to be used -def make_hashfuncs(num_slices, num_bits): +def make_hashfuncs(num_slices, num_bits, hashmac): if num_bits >= (1 << 31): fmt_code, chunk_size = 'Q', 8 elif num_bits >= (1 << 15): @@ -59,21 +63,61 @@ def make_hashfuncs(num_slices, num_bits): else: fmt_code, chunk_size = 'H', 2 total_hash_bits = 8 * num_slices * chunk_size + if total_hash_bits > 384: - hashfn = hashlib.sha512 + digest_size = 64 + if hashmac==0: + hashfn = hashlib.sha512 + else: + hmackey = random.getrandbits(digest_size) + hashfn = hmac elif total_hash_bits > 256: - hashfn = hashlib.sha384 + digest_size = 48 + if hashmac==0: + hashfn = hashlib.sha384 + else: + hmackey = random.getrandbits(digest_size) + hashfn = hmac elif total_hash_bits > 160: - hashfn = hashlib.sha256 + digest_size = 32 + if hashmac==0: + hashfn = hashlib.sha256 + else: + hmackey = random.getrandbits(digest_size) + hashfn = hmac elif total_hash_bits > 128: - hashfn = hashlib.sha1 + digest_size = 20 + if hashmac==0: + hashfn = hashlib.sha1 + else: + hmackey = random.getrandbits(digest_size) + hashfn = hmac else: - hashfn = hashlib.md5 - fmt = fmt_code * (hashfn().digest_size // chunk_size) + digest_size=16 + if hashmac==0: + hashfn = hashlib.md5 + else: + hmackey = random.getrandbits(digest_size) + hashfn = hmac + fmt = fmt_code * (digest_size // chunk_size) num_salts, extra = divmod(num_slices, len(fmt)) if extra: num_salts += 1 - salts = tuple(hashfn(hashfn(pack('I', i)).digest()) for i in range_fn(num_salts)) + + if hashmac==0: + salts = tuple(hashfn(hashfn(pack('I', i)).digest()) for i in range_fn(num_salts)) + else: + if digest_size==64: + salts = tuple(hashfn.new(str(hmackey),hashfn.new(str(hmackey), pack('I',i), hashlib.sha512).digest(), hashlib.sha512) for i in range_fn(num_salts)) + elif digest_size==48: + salts =tuple(hashfn.new(str(hmackey),hashfn.new(str(hmackey), pack('I',i), hashlib.sha384).digest(), hashlib.sha384) for i in range_fn(num_salts)) + elif digest_size==32: + salts = tuple(hashfn.new(str(hmackey),hashfn.new(str(hmackey), pack('I',i), hashlib.sha256).digest(), hashlib.sha256) for i in range_fn(num_salts)) + elif digest_size==20: + salts = tuple(hashfn.new(str(hmackey),hashfn.new(str(hmackey), pack('I',i), hashlib.sha1).digest(), hashlib.sha1) for i in range_fn(num_salts)) + else: + salts = tuple(hashfn.new(str(hmackey),hashfn.new(str(hmackey), pack('I',i), hashlib.md5).digest(), hashlib.md5) for i in range_fn(num_salts)) + def _make_hashfuncs(key): if running_python_3: if isinstance(key, str): @@ -89,6 +133,7 @@ def _make_hashfuncs(key): for salt in salts: h = salt.copy() h.update(key) + h.hexdigest() for uint in unpack(fmt, h.digest()): yield uint % num_bits i += 1 @@ -101,7 +146,7 @@ def _make_hashfuncs(key): class BloomFilter(object): FILE_FMT = b'>> b = BloomFilter(capacity=100000, error_rate=0.001) + >>> b = BloomFilter(capacity=100000, error_rate=0.001, hashmac=0) >>> b.add("test") False >>> "test" in b @@ -134,18 +181,19 @@ def __init__(self, capacity, error_rate=0.001): bits_per_slice = int(math.ceil( (capacity * abs(math.log(error_rate))) / (num_slices * (math.log(2) ** 2)))) - self._setup(error_rate, num_slices, bits_per_slice, capacity, 0) + self._setup(error_rate, num_slices, bits_per_slice, capacity,0, hashmac) self.bitarray = bitarray.bitarray(self.num_bits, endian='little') self.bitarray.setall(False) - def _setup(self, error_rate, num_slices, bits_per_slice, capacity, count): + def _setup(self, error_rate, num_slices, bits_per_slice, capacity, count, hashmac): self.error_rate = error_rate self.num_slices = num_slices self.bits_per_slice = bits_per_slice self.capacity = capacity self.num_bits = num_slices * bits_per_slice self.count = count - self.make_hashes = make_hashfuncs(self.num_slices, self.bits_per_slice) + self.make_hashes = make_hashfuncs(self.num_slices, self.bits_per_slice, hashmac) + self.hashmac = hashmac def __contains__(self, key): """Tests a key's membership in this bloom filter. @@ -282,15 +330,16 @@ def __getstate__(self): def __setstate__(self, d): self.__dict__.update(d) - self.make_hashes = make_hashfuncs(self.num_slices, self.bits_per_slice) + self.make_hashes = make_hashfuncs(self.num_slices, self.bits_per_slice, self.hashmac) + + class ScalableBloomFilter(object): SMALL_SET_GROWTH = 2 # slower, but takes up less memory LARGE_SET_GROWTH = 4 # faster, but takes up more memory faster FILE_FMT = '>> b = ScalableBloomFilter(initial_capacity=512, error_rate=0.001, \ - mode=ScalableBloomFilter.SMALL_SET_GROWTH) + mode=ScalableBloomFilter.SMALL_SET_GROWTH,\ + hashmac = 1) >>> b.add("test") False >>> "test" in b @@ -321,20 +373,20 @@ def __init__(self, initial_capacity=100, error_rate=0.001, """ if not error_rate or error_rate < 0: raise ValueError("Error_Rate must be a decimal less than 0.") - self._setup(mode, 0.9, initial_capacity, error_rate) + self._setup(mode, 0.9, initial_capacity, error_rate, hashmac) self.filters = [] - def _setup(self, mode, ratio, initial_capacity, error_rate): + def _setup(self, mode, ratio, initial_capacity, error_rate, hashmac): self.scale = mode self.ratio = ratio self.initial_capacity = initial_capacity self.error_rate = error_rate + self.hashmac = hashmac def __contains__(self, key): """Tests a key's membership in this bloom filter. - >>> b = ScalableBloomFilter(initial_capacity=100, error_rate=0.001, \ - mode=ScalableBloomFilter.SMALL_SET_GROWTH) + >>> b = ScalableBloomFilter(initial_capacity=100, error_rate=0.001, mode=ScalableBloomFilter.SMALL_SET_GROWTH, hashmac=0) >>> b.add("hello") False >>> "hello" in b @@ -352,7 +404,7 @@ def add(self, key): Otherwise False. >>> b = ScalableBloomFilter(initial_capacity=100, error_rate=0.001, \ - mode=ScalableBloomFilter.SMALL_SET_GROWTH) + mode=ScalableBloomFilter.SMALL_SET_GROWTH, hashmac=0) >>> b.add("hello") False >>> b.add("hello") @@ -431,6 +483,10 @@ def __len__(self): return sum(f.count for f in self.filters) + + if __name__ == "__main__": import doctest doctest.testmod() + +