diff --git a/cache.go b/cache.go index e9f6dea..0648807 100644 --- a/cache.go +++ b/cache.go @@ -241,31 +241,24 @@ func (c *Cache[K, V]) GetOrSet(key K, val V, opts ...ItemOption) (actual V, load // DeleteExpired all expired items from the cache. func (c *Cache[K, V]) DeleteExpired() { - c.mu.Lock() - l := c.expManager.len() - c.mu.Unlock() - evict := func() bool { - key := c.expManager.pop() - // if is expired, delete it and return nil instead - item, ok := c.cache.Get(key) - if ok { - if item.Expired() { - c.cache.Delete(key) - return false - } - c.expManager.update(key, item.Expiration) + c.mu.Lock() + defer c.mu.Unlock() + if c.expManager.len() == 0 { + return false + } + key, expiration := c.expManager.pop() + if nowFunc().After(expiration) { + c.cache.Delete(key) + return true } - return true + if _, ok := c.cache.Get(key); ok { + c.expManager.update(key, expiration) + } + return false } - for i := 0; i < l; i++ { - c.mu.Lock() - shouldBreak := evict() - c.mu.Unlock() - if shouldBreak { - break - } + for evict() { } } diff --git a/cache_test.go b/cache_test.go index 5e2cef2..c2fe02d 100644 --- a/cache_test.go +++ b/cache_test.go @@ -2,7 +2,9 @@ package cache_test import ( "math/rand" + "strconv" "sync" + "sync/atomic" "testing" "time" @@ -121,3 +123,38 @@ func TestCallJanitor(t *testing.T) { t.Errorf("want items is empty but got %d", len(keys)) } } + +func TestConcurrentDelete(t *testing.T) { + c := cache.New[string, int]() + var ( + wg sync.WaitGroup + stop atomic.Bool + timeout = 10 * time.Second + ) + + if testing.Short() { + timeout = 100 * time.Millisecond + } + time.AfterFunc(timeout, func() { + stop.Store(true) + }) + + wg.Add(1) + go func() { + defer wg.Done() + for k := 1; !stop.Load(); k++ { + c.Set(strconv.Itoa(k), k, cache.WithExpiration(0)) + c.Delete(strconv.Itoa(k)) + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + for !stop.Load() { + c.DeleteExpired() + } + }() + + wg.Wait() +} diff --git a/expiration.go b/expiration.go index 44ee4ea..ee5d9c2 100644 --- a/expiration.go +++ b/expiration.go @@ -37,11 +37,12 @@ func (m *expirationManager[K]) len() int { return m.queue.Len() } -func (m *expirationManager[K]) pop() K { +func (m *expirationManager[K]) pop() (K, time.Time) { v := heap.Pop(&m.queue) key := v.(*expirationKey[K]).key + exp := v.(*expirationKey[K]).expiration delete(m.mapping, key) - return key + return key, exp } func (m *expirationManager[K]) remove(key K) {