|
| 1 | +import functools |
| 2 | +from django.db.models.fields.related import ManyToManyField, ReverseManyRelatedObjectsDescriptor, ManyRelatedObjectsDescriptor |
| 3 | +from django.db.models.query import QuerySet |
| 4 | +from django.db.models import signals |
| 5 | +from cache import cache |
| 6 | +from types import MethodType |
| 7 | + |
| 8 | +CACHE_DURATION = 60 * 30 |
| 9 | + |
| 10 | +def invalidate_cache(obj, field): |
| 11 | + cache.set(obj._get_cache_key(field=field), None, 5) |
| 12 | + |
| 13 | +def fix_where(where, modified=False): |
| 14 | + def wrap_add(f): |
| 15 | + @functools.wraps(f) |
| 16 | + def add(self, *args, **kwargs): |
| 17 | + """ |
| 18 | + Wraps django.db.models.sql.where.add to indicate that a new |
| 19 | + 'where' condition has been added. |
| 20 | + """ |
| 21 | + self.modified = True |
| 22 | + return f(*args, **kwargs) |
| 23 | + return add |
| 24 | + where.modified = modified |
| 25 | + where.add = MethodType(wrap_add(where.add), where, where.__class__) |
| 26 | + return where |
| 27 | + |
| 28 | + |
| 29 | +def get_pk_list_query_set(superclass): |
| 30 | + class PKListQuerySet(superclass): |
| 31 | + """ |
| 32 | + QuerySet that, when unfiltered, fetches objects individually from |
| 33 | + the datastore by pk. |
| 34 | +
|
| 35 | + The `pk_list` attribute is a list of primary keys for objects that |
| 36 | + should be fetched. |
| 37 | +
|
| 38 | + """ |
| 39 | + def __init__(self, pk_list=[], from_cache=False, *args, **kwargs): |
| 40 | + super(PKListQuerySet, self).__init__(*args, **kwargs) |
| 41 | + self.pk_list = pk_list |
| 42 | + self.from_cache = from_cache |
| 43 | + self.query.where = fix_where(self.query.where) |
| 44 | + |
| 45 | + def iterator(self): |
| 46 | + if not self.query.where.modified: |
| 47 | + for pk in self.pk_list: |
| 48 | + yield self.model._default_manager.get(pk=pk) |
| 49 | + else: |
| 50 | + superiter = super(PKListQuerySet, self).iterator() |
| 51 | + while True: |
| 52 | + yield superiter.next() |
| 53 | + |
| 54 | + def _clone(self, *args, **kwargs): |
| 55 | + c = super(PKListQuerySet, self)._clone(*args, **kwargs) |
| 56 | + c.query.where = fix_where(c.query.where, modified=self.query.where.modified) |
| 57 | + c.pk_list = self.pk_list |
| 58 | + c.from_cache = self.from_cache |
| 59 | + return c |
| 60 | + return PKListQuerySet |
| 61 | + |
| 62 | + |
| 63 | +def get_caching_related_manager(superclass, instance, field_name, related_name): |
| 64 | + class CachingRelatedManager(superclass): |
| 65 | + def all(self): |
| 66 | + key = instance._get_cache_key(field=field_name) |
| 67 | + qs = super(CachingRelatedManager, self).get_query_set() |
| 68 | + PKListQuerySet = get_pk_list_query_set(qs.__class__) |
| 69 | + qs = qs._clone(klass=PKListQuerySet) |
| 70 | + pk_list = cache.get(key) |
| 71 | + if pk_list is None: |
| 72 | + pk_list = qs.values_list('pk', flat=True) |
| 73 | + cache.add(key, pk_list, CACHE_DURATION) |
| 74 | + else: |
| 75 | + qs.from_cache = True |
| 76 | + qs.pk_list = pk_list |
| 77 | + return qs |
| 78 | + |
| 79 | + def add(self, *objs): |
| 80 | + super(CachingRelatedManager, self).add(*objs) |
| 81 | + for obj in objs: |
| 82 | + invalidate_cache(obj, related_name) |
| 83 | + invalidate_cache(instance, field_name) |
| 84 | + |
| 85 | + def remove(self, *objs): |
| 86 | + super(CachingRelatedManager, self).remove(*objs) |
| 87 | + for obj in objs: |
| 88 | + invalidate_cache(obj, related_name) |
| 89 | + invalidate_cache(instance, field_name) |
| 90 | + |
| 91 | + def clear(self): |
| 92 | + objs = list(self.all()) |
| 93 | + super(CachingRelatedManager, self).clear() |
| 94 | + for obj in objs: |
| 95 | + invalidate_cache(obj, related_name) |
| 96 | + invalidate_cache(instance, field_name) |
| 97 | + return CachingRelatedManager |
| 98 | + |
| 99 | + |
| 100 | +class CachingReverseManyRelatedObjectsDescriptor(ReverseManyRelatedObjectsDescriptor): |
| 101 | + def __get__(self, instance, cls=None): |
| 102 | + manager = super(CachingReverseManyRelatedObjectsDescriptor, self).__get__(instance, cls) |
| 103 | + |
| 104 | + CachingRelatedManager = get_caching_related_manager(manager.__class__, |
| 105 | + instance, |
| 106 | + self.field.name, |
| 107 | + self.field.rel.related_name) |
| 108 | + |
| 109 | + manager.__class__ = CachingRelatedManager |
| 110 | + return manager |
| 111 | + |
| 112 | + |
| 113 | +class CachingManyRelatedObjectsDescriptor(ManyRelatedObjectsDescriptor): |
| 114 | + def __get__(self, instance, cls=None): |
| 115 | + manager = super(CachingManyRelatedObjectsDescriptor, self).__get__(instance, cls) |
| 116 | + |
| 117 | + CachingRelatedManager = get_caching_related_manager(manager.__class__, |
| 118 | + instance, |
| 119 | + self.related.get_accessor_name(), |
| 120 | + self.related.field.name) |
| 121 | + |
| 122 | + manager.__class__ = CachingRelatedManager |
| 123 | + return manager |
| 124 | + |
| 125 | + |
| 126 | +class CachingManyToManyField(ManyToManyField): |
| 127 | + def contribute_to_class(self, cls, name): |
| 128 | + super(CachingManyToManyField, self).contribute_to_class(cls, name) |
| 129 | + setattr(cls, self.name, CachingReverseManyRelatedObjectsDescriptor(self)) |
| 130 | + |
| 131 | + def contribute_to_related_class(self, cls, related): |
| 132 | + super(CachingManyToManyField, self).contribute_to_related_class(cls, related) |
| 133 | + setattr(cls, related.get_accessor_name(), CachingManyRelatedObjectsDescriptor(related)) |
0 commit comments