-
Notifications
You must be signed in to change notification settings - Fork 168
/
Copy pathexpiring_dict.py
161 lines (139 loc) · 5.62 KB
/
expiring_dict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
# pyre-ignore-all-errors[14, 15, 58]
from __future__ import annotations
from collections import OrderedDict
from collections.abc import Iterable
from threading import RLock
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from sc2.bot_ai import BotAI
class ExpiringDict(OrderedDict):
    """
    An expiring dict that uses bot.state.game_loop to only return items that are
    valid for a specific amount of time.

    An item set at game loop ``t`` is valid while
    ``bot.state.game_loop - t < max_age_frames``. Expired items are removed lazily:
    ``key in d``, ``d[key]`` and ``pop`` delete an expired entry when they meet it,
    while ``keys``/``values``/``items``/``len``/``repr`` simply skip expired entries.
    Internally every value is stored as a ``(value, frame_set)`` tuple.

    Example usage::

        # Create the dict once (not on every step). Values stay readable for
        # 20 game loops (frames) after they were set.
        self.my_dict = ExpiringDict(self, max_age_frames=20)

        async def on_step(self, iteration: int):
            if iteration == 0:
                # Add item
                self.my_dict["test"] = "something"
            if iteration == 2:
                # By default, one iteration is called every 8 frames
                if "test" in self.my_dict:
                    print("test is in dict")
            if iteration == 20:
                if "test" not in self.my_dict:
                    print("test is not anymore in dict")
    """

    def __init__(self, bot: BotAI, max_age_frames: int = 1) -> None:
        """
        :param bot: Object exposing ``bot.state.game_loop``; it is read on every access.
        :param max_age_frames: Number of game loops an item stays valid (must be >= -1).
            With the default of 1 an item is only visible during the game loop it was set in.
        """
        assert max_age_frames >= -1
        assert bot
        OrderedDict.__init__(self)
        self.bot: BotAI = bot
        self.max_age: int | float = max_age_frames
        # Reentrant because methods call each other while holding it (e.g. __len__ -> values).
        self.lock: RLock = RLock()

    @property
    def frame(self) -> int:
        """Current game loop of the bot."""
        # pyre-ignore[16]
        return self.bot.state.game_loop

    def _is_fresh(self, frame_set: int) -> bool:
        """Return True if an item stored at game loop ``frame_set`` has not expired yet."""
        return self.frame - frame_set < self.max_age

    def __contains__(self, key) -> bool:
        """Return True if dict has a non-expired entry for key, e.g. 'key in dict'.

        An expired entry is deleted as a side effect.
        """
        with self.lock:
            if OrderedDict.__contains__(self, key):
                # Each item is a tuple of (value, frame time)
                item = OrderedDict.__getitem__(self, key)
                if self._is_fresh(item[1]):
                    return True
                del self[key]
        return False

    def __getitem__(self, key, with_age: bool = False) -> Any:
        """Return the item of the dict using d[key].

        :param with_age: If True, return ``(value, frame_set)`` instead of just the value.
        :raises KeyError: If the key is missing or expired (an expired entry is deleted).
        """
        with self.lock:
            # Each item is a tuple of (value, frame time)
            item = OrderedDict.__getitem__(self, key)
            if self._is_fresh(item[1]):
                if with_age:
                    return item[0], item[1]
                return item[0]
            OrderedDict.__delitem__(self, key)
        raise KeyError(key)

    def __setitem__(self, key, value) -> None:
        """Set d[key] = value, stamping it with the current game loop."""
        with self.lock:
            OrderedDict.__setitem__(self, key, (value, self.frame))

    def __repr__(self) -> str:
        """Printable version of the dict instead of getting memory address.

        Shows the stored ``(value, frame_set)`` tuple of every non-expired entry.
        """
        print_list = []
        with self.lock:
            for key, value in OrderedDict.items(self):
                if self._is_fresh(value[1]):
                    print_list.append(f"{repr(key)}: {repr(value)}")
        print_str = ", ".join(print_list)
        return f"ExpiringDict({print_str})"

    def __str__(self):
        return self.__repr__()

    def __iter__(self):
        """Override 'for key in dict:' to only iterate over non-expired keys."""
        # keys() is a generator that acquires the lock itself while it is iterated.
        return self.keys()

    # TODO find a way to improve len
    def __len__(self) -> int:
        """Override len method as key value pairs aren't instantly being deleted, but only on access.
        This function is slow because it has to check if each element is not expired yet."""
        with self.lock:
            return sum(1 for _ in self.values())

    def pop(self, key, default=None, with_age: bool = False):
        """Return the item and remove it.

        An expired entry is removed and treated as missing.

        :param default: Returned when the key is missing or expired. NOTE: ``None`` means
            "no default", so ``pop(key, None)`` raises ``KeyError`` for a missing key.
        :param with_age: If True, return ``(value, frame_set)``, or ``(default, current_frame)``
            when falling back to the default.
        :raises KeyError: If the key is missing or expired and ``default`` is None.
        """
        with self.lock:
            if OrderedDict.__contains__(self, key):
                item = OrderedDict.__getitem__(self, key)
                if self._is_fresh(item[1]):
                    del self[key]
                    if with_age:
                        return item[0], item[1]
                    return item[0]
                del self[key]
            if default is None:
                raise KeyError(key)
            if with_age:
                return default, self.frame
            return default

    def get(self, key, default=None, with_age: bool = False):
        """Return the value for key if it is present and not expired, else default.

        Unlike ``d[key]`` this never raises ``KeyError`` and does not delete expired entries.

        :param default: Returned when the key is missing or expired.
        :param with_age: If True, return ``(value, frame_set)`` for a valid item, and
            ``(default, current_frame)`` for a missing/expired item when a non-None
            default was given.
        :return: The value (or value/age tuple). When the key is missing or expired and
            ``default`` is None, a bare ``None`` is returned, even with ``with_age=True``.
        """
        with self.lock:
            if OrderedDict.__contains__(self, key):
                item = OrderedDict.__getitem__(self, key)
                if self._is_fresh(item[1]):
                    if with_age:
                        return item[0], item[1]
                    return item[0]
            # Missing or expired. A None default stays a bare None (also with with_age)
            # so existing `get(key, with_age=True) is None` checks keep working.
            if default is None:
                return None
            if with_age:
                return default, self.frame
            return default

    def update(self, other_dict: dict) -> None:
        """Set every key/value pair of other_dict, stamped with the current game loop."""
        with self.lock:
            for key, value in other_dict.items():
                self[key] = value

    def items(self) -> Iterable:
        """Yield (key, value) pairs of all non-expired items."""
        with self.lock:
            for key, value in OrderedDict.items(self):
                if self._is_fresh(value[1]):
                    yield key, value[0]

    def keys(self) -> Iterable:
        """Yield all non-expired keys."""
        with self.lock:
            for key, value in OrderedDict.items(self):
                if self._is_fresh(value[1]):
                    yield key

    def values(self) -> Iterable:
        """Yield the values of all non-expired items."""
        with self.lock:
            for value in OrderedDict.values(self):
                if self._is_fresh(value[1]):
                    yield value[0]