diff --git a/aikido_zen/background_process/commands/__init__.py b/aikido_zen/background_process/commands/__init__.py
index dd244219..0cd75ccc 100644
--- a/aikido_zen/background_process/commands/__init__.py
+++ b/aikido_zen/background_process/commands/__init__.py
@@ -3,7 +3,6 @@
 from aikido_zen.helpers.logging import logger
 from .attack import process_attack
 from .read_property import process_read_property
-from .initialize_route import process_initialize_route
 from .user import process_user
 from .should_ratelimit import process_should_ratelimit
 from .kill import process_kill
@@ -16,7 +15,6 @@
     # This maps to a tuple : (function, returns_data?)
     # Commands that don't return data :
     "ATTACK": (process_attack, False),
-    "INITIALIZE_ROUTE": (process_initialize_route, False),
     "USER": (process_user, False),
     "KILL": (process_kill, False),
     "STATISTICS": (process_statistics, False),
diff --git a/aikido_zen/background_process/commands/initialize_route.py b/aikido_zen/background_process/commands/initialize_route.py
deleted file mode 100644
index 83737e98..00000000
--- a/aikido_zen/background_process/commands/initialize_route.py
+++ /dev/null
@@ -1,11 +0,0 @@
-"""Exports `process_initialize_route`"""
-
-
-def process_initialize_route(connection_manager, data, queue=None):
-    """
-    This is called the first time a route is discovered to initialize it and add one hit.
-    data is a dictionary called route_metadata which includes: route, method and url.
-    """
-    if connection_manager:
-        connection_manager.routes.initialize_route(route_metadata=data)
-        connection_manager.routes.increment_route(route_metadata=data)
diff --git a/aikido_zen/background_process/commands/initialize_route_test.py b/aikido_zen/background_process/commands/initialize_route_test.py
deleted file mode 100644
index 69e7abbb..00000000
--- a/aikido_zen/background_process/commands/initialize_route_test.py
+++ /dev/null
@@ -1,37 +0,0 @@
-import pytest
-from unittest.mock import MagicMock, patch
-from .initialize_route import process_initialize_route
-
-
-@pytest.fixture
-def mock_connection_manager():
-    """Fixture to create a mock connection_manager with a routes attribute."""
-    connection_manager = MagicMock()
-    connection_manager.routes = MagicMock()
-    return connection_manager
-
-
-def test_process_initialize_route(mock_connection_manager):
-    """Test that process_initialize_route adds a route when connection_manager is present."""
-    data = 123456
-
-    process_initialize_route(
-        mock_connection_manager, data, None
-    )  # conn is not used in this function
-
-    # Check that increment_route and initialize_route methods were called with the correct arguments
-    mock_connection_manager.routes.initialize_route.assert_called_once_with(
-        route_metadata=123456
-    )
-    mock_connection_manager.routes.increment_route.assert_called_once_with(
-        route_metadata=123456
-    )
-
-
-def test_process_initialize_route_no_connection_manager():
-    """Test that process_initialize_route does nothing when connection_manager is not present."""
-    data = 123456
-
-    process_initialize_route(None, data, None)  # conn is not used in this function
-
-    # Check that no error occurs
diff --git a/aikido_zen/background_process/routes/__init__.py b/aikido_zen/background_process/routes/__init__.py
index 27b21d53..094ff6e6 100644
--- a/aikido_zen/background_process/routes/__init__.py
+++ b/aikido_zen/background_process/routes/__init__.py
@@ -19,20 +19,17 @@ def __init__(self, max_size=1000):
 
     def initialize_route(self, route_metadata):
         """
-        Initializes a route for the first time.
+        Initializes a route for the first time. `hits_delta_since_sync` tracks the hits since the last sync.
         """
         self.manage_routes_size()
         key = route_to_key(route_metadata)
-        if self.routes.get(key):
-            return
         self.routes[key] = {
             "method": route_metadata.get("method"),
             "path": route_metadata.get("route"),
             "hits": 0,
+            "hits_delta_since_sync": 0,
             "apispec": {},
         }
-        # This field counts the difference in hits in between synchronisation for threads :
-        self.routes[key]["hits_delta_since_sync"] = 0
 
     def increment_route(self, route_metadata):
         """
@@ -41,8 +38,8 @@ def increment_route(self, route_metadata):
         """
         key = route_to_key(route_metadata)
         if not self.routes.get(key):
-            return
-        # Add a hit to the route :
+            self.initialize_route(route_metadata)
+        # Add hits to the route :
         route = self.routes.get(key)
         route["hits"] += 1
         route["hits_delta_since_sync"] += 1
diff --git a/aikido_zen/background_process/routes/init_test.py b/aikido_zen/background_process/routes/init_test.py
index 73dd9dd5..44471f62 100644
--- a/aikido_zen/background_process/routes/init_test.py
+++ b/aikido_zen/background_process/routes/init_test.py
@@ -102,13 +102,6 @@ def test_increment_route_twice():
     assert routes.routes["GET:/api/resource"]["hits"] == 2
 
 
-def test_increment_route_that_does_not_exist():
-    routes = Routes(max_size=3)
-    routes.increment_route(gen_route_metadata(route="/api/resource"))
-    routes.increment_route(gen_route_metadata(route="/api/resource"))
-    assert len(routes.routes) == 0
-
-
 def test_clear_routes():
     routes = Routes(max_size=3)
     routes.initialize_route(gen_route_metadata(route="/api/resource"))
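The delta counter introduced above follows a count-locally, ship-the-delta, then reset cycle between request threads and the background agent. A minimal standalone sketch of that contract, with a hypothetical RouteStore standing in for the real Routes class:

# Sketch (assumed names): workers count hits locally; only the delta since
# the last sync is shipped to the agent and then zeroed, mirroring renew().
class RouteStore:
    def __init__(self):
        self.routes = {}

    def increment_route(self, key):
        # Auto-initialize unknown routes, as the new increment_route does.
        if key not in self.routes:
            self.routes[key] = {"hits": 0, "hits_delta_since_sync": 0}
        self.routes[key]["hits"] += 1
        self.routes[key]["hits_delta_since_sync"] += 1

    def drain_deltas(self):
        # Collect the per-route deltas for a sync, then reset them.
        deltas = {key: r["hits_delta_since_sync"] for key, r in self.routes.items()}
        for r in self.routes.values():
            r["hits_delta_since_sync"] = 0
        return deltas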
diff --git a/aikido_zen/middleware/init_test.py b/aikido_zen/middleware/init_test.py
index dea1e69a..8ee6f2a2 100644
--- a/aikido_zen/middleware/init_test.py
+++ b/aikido_zen/middleware/init_test.py
@@ -2,16 +2,18 @@
 import pytest
 
 from aikido_zen.context import current_context, Context, get_current_context
-from aikido_zen.thread.thread_cache import ThreadCache, threadlocal_storage
+from aikido_zen.thread.thread_cache import ThreadCache, get_cache
 from . import should_block_request
 
 
 @pytest.fixture(autouse=True)
 def run_around_tests():
+    get_cache().reset()
     yield
     # Make sure to reset context and cache after every test so it does not
    # interfere with other tests
     current_context.set(None)
+    get_cache().reset()
 
 
 def test_without_context():
@@ -39,20 +41,15 @@ def set_context(user=None, executed_middleware=False):
     ).set_as_current_context()
 
 
-class MyThreadCache(ThreadCache):
-    def renew_if_ttl_expired(self):
-        return
-
-
 def test_with_context_without_cache():
     set_context()
-    threadlocal_storage.cache = None
+    get_cache().cache = None
     assert should_block_request() == {"block": False}
 
 
 def test_with_context_with_cache():
     set_context(user={"id": "123"})
-    thread_cache = MyThreadCache()
+    thread_cache = get_cache()
     thread_cache.config.blocked_uids = ["123"]
 
     assert get_current_context().executed_middleware == False
@@ -76,7 +73,7 @@ def test_with_context_with_cache():
 
 def test_cache_comms_with_endpoints():
     set_context(user={"id": "456"})
-    thread_cache = MyThreadCache()
+    thread_cache = get_cache()
     thread_cache.config.blocked_uids = ["123"]
     thread_cache.config.endpoints = [
         {
diff --git a/aikido_zen/sinks/tests/requests_test.py b/aikido_zen/sinks/tests/requests_test.py
index ebc04dfd..628a78c7 100644
--- a/aikido_zen/sinks/tests/requests_test.py
+++ b/aikido_zen/sinks/tests/requests_test.py
@@ -1,7 +1,7 @@
 import os
 import pytest
 from aikido_zen.context import Context, current_context
-from aikido_zen.thread.thread_cache import ThreadCache, threadlocal_storage
+from aikido_zen.thread.thread_cache import ThreadCache, get_cache
 from aikido_zen.errors import AikidoSSRF
 from aikido_zen.background_process.comms import reset_comms
 import aikido_zen.sinks.socket
@@ -19,11 +19,12 @@
 
 @pytest.fixture(autouse=True)
 def run_around_tests():
+    get_cache().reset()
     yield
     # Make sure to reset context and cache after every test so it does not
     # interfere with other tests
     current_context.set(None)
-    setattr(threadlocal_storage, "cache", None)
+    get_cache().reset()
 
 
 def set_context_and_lifecycle(url):
@@ -48,7 +49,6 @@ def set_context_and_lifecycle(url):
         source="flask",
     )
     context.set_as_current_context()
-    ThreadCache()
 
 
 def test_srrf_test(monkeypatch):
diff --git a/aikido_zen/sinks/tests/urrlib3_test.py b/aikido_zen/sinks/tests/urrlib3_test.py
index aa827839..13efd24c 100644
--- a/aikido_zen/sinks/tests/urrlib3_test.py
+++ b/aikido_zen/sinks/tests/urrlib3_test.py
@@ -1,7 +1,7 @@
 import os
 import pytest
 from aikido_zen.context import Context, current_context
-from aikido_zen.thread.thread_cache import ThreadCache, threadlocal_storage
+from aikido_zen.thread.thread_cache import get_cache
 from aikido_zen.errors import AikidoSSRF
 from aikido_zen.background_process.comms import reset_comms
 import aikido_zen.sinks.socket
@@ -19,11 +19,12 @@
 
 @pytest.fixture(autouse=True)
 def run_around_tests():
+    get_cache().reset()
     yield
     # Make sure to reset context and cache after every test so it does not
     # interfere with other tests
     current_context.set(None)
-    setattr(threadlocal_storage, "cache", None)
+    get_cache().reset()
 
 
 def set_context_and_lifecycle(url, host=None):
@@ -50,7 +51,6 @@ def set_context_and_lifecycle(url, host=None):
         source="flask",
     )
     context.set_as_current_context()
-    ThreadCache()
 
 
 def test_srrf_test(monkeypatch):
diff --git a/aikido_zen/sources/functions/request_handler.py b/aikido_zen/sources/functions/request_handler.py
index 0ea7cb00..16b58788 100644
--- a/aikido_zen/sources/functions/request_handler.py
+++ b/aikido_zen/sources/functions/request_handler.py
@@ -1,6 +1,5 @@
 """Exports request_handler function"""
 
-import aikido_zen.background_process as communications
 import aikido_zen.context as ctx
 from aikido_zen.api_discovery.get_api_info import get_api_info
 from aikido_zen.api_discovery.update_route_info import update_route_info
@@ -15,12 +14,9 @@ def request_handler(stage, status_code=0):
     """This will check for rate limiting, Allowed IP's, useful routes, etc."""
     try:
         if stage == "init":
-            # Initial stage of the request, called after context is stored.
             thread_cache = get_cache()
-            thread_cache.renew_if_ttl_expired()  # Only check TTL at the start of a request.
             if ctx.get_current_context() and thread_cache:
                 thread_cache.increment_stats()  # Increment request statistics if a context exists.
-
         if stage == "pre_response":
             return pre_response()
         if stage == "post_response":
@@ -76,9 +72,10 @@ def pre_response():
 
 def post_response(status_code):
     """Checks if the current route is useful"""
     context = ctx.get_current_context()
-    comms = communications.get_comms()
-    if not context or not comms:
+    if not context:
         return
+    route_metadata = context.get_route_metadata()
+
     is_curr_route_useful = is_useful_route(
         status_code,
         context.route,
@@ -86,17 +83,12 @@
     )
     if not is_curr_route_useful:
         return
-    route_metadata = context.get_route_metadata()
+
     cache = get_cache()
     if cache:
-        route = cache.routes.get(route_metadata)
-        if not route:
-            # This route does not exist yet, initialize it:
-            cache.routes.initialize_route(route_metadata)
-            comms.send_data_to_bg_process("INITIALIZE_ROUTE", route_metadata)
+        cache.routes.increment_route(route_metadata)
         # Run API Discovery :
         update_route_info(
             new_apispec=get_api_info(context), route=cache.routes.get(route_metadata)
         )
-        # Add hit :
-        cache.routes.increment_route(route_metadata)
diff --git a/aikido_zen/sources/functions/request_handler_test.py b/aikido_zen/sources/functions/request_handler_test.py
index 9648d04e..44d73c90 100644
--- a/aikido_zen/sources/functions/request_handler_test.py
+++ b/aikido_zen/sources/functions/request_handler_test.py
@@ -1,7 +1,7 @@
 import pytest
 from unittest.mock import patch, MagicMock
-from aikido_zen.thread.thread_cache import get_cache, ThreadCache, threadlocal_storage
-from .request_handler import request_handler, post_response
+from aikido_zen.thread.thread_cache import get_cache, ThreadCache
+from .request_handler import request_handler
 from ...background_process.service_config import ServiceConfig
 from ...context import Context, current_context
 
@@ -22,19 +22,17 @@
 
 @pytest.fixture(autouse=True)
 def run_around_tests():
+    get_cache().reset()
     current_context.set(None)
     yield
-    setattr(threadlocal_storage, "cache", None)
+    get_cache().reset()
     current_context.set(None)
 
 
-@patch("aikido_zen.background_process.get_comms")
-def test_post_response_useful_route(mock_get_comms, mock_context):
+def test_post_response_useful_route(mock_context):
     """Test post_response when the route is useful."""
-    comms = MagicMock()
-    mock_get_comms.return_value = comms
-    cache = ThreadCache()  # Creates a new cache
+    cache = get_cache()  # Creates a new cache
     assert cache.routes.routes == {}
     with patch("aikido_zen.context.get_current_context", return_value=mock_context):
         request_handler("post_response", status_code=200)
@@ -50,15 +48,6 @@ def test_post_response_useful_route(mock_context):
         }
     }
 
-    comms.send_data_to_bg_process.assert_called_once_with(
-        "INITIALIZE_ROUTE",
-        {
-            "route": "/test/route",
-            "method": "GET",
-            "url": "http://localhost:8080/test/route",
-        },
-    )
-
 
 @patch("aikido_zen.background_process.get_comms")
 def test_post_response_not_useful_route(mock_get_comms, mock_context):
@@ -129,7 +118,7 @@ def create_service_config(blocked_ips=None):
     )
     if blocked_ips:
         config.set_blocked_ips(blocked_ips)
-    ThreadCache().config = config
+    get_cache().config = config
     return config
 
 
diff --git a/aikido_zen/thread/process_worker.py b/aikido_zen/thread/process_worker.py
new file mode 100644
index 00000000..bdf14b7b
--- /dev/null
+++ b/aikido_zen/thread/process_worker.py
@@ -0,0 +1,28 @@
+import multiprocessing
+import time
+
+from aikido_zen.helpers.logging import logger
+from aikido_zen.thread import thread_cache
+
+# Renew the cache from this background worker every 5 seconds
+RENEW_CACHE_EVERY_X_SEC = 5
+
+
+def aikido_process_worker_thread():
+    """
+    process worker -> when a web server such as Gunicorn forks new worker processes (each of which
+    may run multiple threads), an Aikido process worker is attached to each new process, so in
+    essence it's one extra thread. It is responsible for syncing statistics, route data, configuration, ...
+    """
+    # Get the current process
+    current_process = multiprocessing.current_process()
+
+    while True:
+        # Log information about the current process
+        logger.debug(
+            f"Process ID: {current_process.pid}, Name: {current_process.name} - process_worker renewing thread cache."
+        )
+
+        # Renew the cache every RENEW_CACHE_EVERY_X_SEC seconds
+        thread_cache.renew()
+        time.sleep(RENEW_CACHE_EVERY_X_SEC)
diff --git a/aikido_zen/thread/process_worker_loader.py b/aikido_zen/thread/process_worker_loader.py
new file mode 100644
index 00000000..bfb6a183
--- /dev/null
+++ b/aikido_zen/thread/process_worker_loader.py
@@ -0,0 +1,23 @@
+import multiprocessing
+import threading
+
+from aikido_zen.context import get_current_context
+from aikido_zen.thread.process_worker import aikido_process_worker_thread
+
+
+def load_worker():
+    """
+    Loads in a new process worker if one does not already exist for the current process
+    """
+    if get_current_context() is None:
+        return  # don't start a worker if it's not related to a request.
+
+    # The name is aikido-process-worker- + the current PID
+    thread_name = "aikido-process-worker-" + str(multiprocessing.current_process().pid)
+    if any(thread.name == thread_name for thread in threading.enumerate()):
+        return  # The thread already exists.
+
+    # Create a new daemon thread that will handle communication to and from the background agent
+    thread = threading.Thread(target=aikido_process_worker_thread, name=thread_name)
+    thread.daemon = True
+    thread.start()
diff --git a/aikido_zen/thread/thread_cache.py b/aikido_zen/thread/thread_cache.py
index 0dd8dcc1..0b6fe1ef 100644
--- a/aikido_zen/thread/thread_cache.py
+++ b/aikido_zen/thread/thread_cache.py
@@ -1,39 +1,20 @@
 """Exports class ThreadConfig"""
 
-from threading import local
 import aikido_zen.background_process.comms as comms
-import aikido_zen.helpers.get_current_unixtime_ms as t
-from aikido_zen.context import get_current_context
 from aikido_zen.background_process.routes import Routes
 from aikido_zen.background_process.service_config import ServiceConfig
+from aikido_zen.context import get_current_context
 from aikido_zen.helpers.logging import logger
-
-THREAD_CONFIG_TTL_MS = 60 * 1000  # Time-To-Live is 60 seconds for the thread cache
-
-threadlocal_storage = local()
-
-
-def get_cache():
-    """Returns the current ThreadCache"""
-    cache = getattr(threadlocal_storage, "cache", None)
-    if not cache:
-        return ThreadCache()
-    return cache
+from aikido_zen.thread import process_worker_loader
 
 
 class ThreadCache:
     """
-    A thread-local cache object that holds routes, bypassed ips, endpoints amount of requests
-    With a Time-To-Live given by THREAD_CONFIG_TTL_MS
+    A process-local cache object that holds routes, bypassed ips, endpoints, and the number of requests
     """
 
     def __init__(self):
-        # Load initial data :
-        self.reset()
-        self.renew()
-
-        # Save as a thread-local object :
-        threadlocal_storage.cache = self
+        self.reset()  # Initialize values
 
     def is_bypassed_ip(self, ip):
         """Checks the given IP against the list of bypassed ips"""
@@ -43,14 +24,6 @@ def is_user_blocked(self, user_id):
         """Checks if the user id is blocked"""
         return user_id in self.config.blocked_uids
 
-    def renew_if_ttl_expired(self):
-        """Renews the data only if TTL has expired"""
-        ttl_has_expired = (
-            t.get_unixtime_ms(monotonic=True) - self.last_renewal > THREAD_CONFIG_TTL_MS
-        )
-        if ttl_has_expired:
-            self.renew()
-
     def get_endpoints(self):
         return self.config.endpoints
 
@@ -65,18 +38,13 @@ def reset(self):
             received_any_stats=False,
         )
         self.reqs = 0
-        self.last_renewal = 0
        self.middleware_installed = False
 
     def renew(self):
-        """
-        Makes an IPC call to store the amount of hits and requests and renew the config
-        """
-        # Don't try to fetch a thread cache if communications don't work or
-        # if we are not inside the context of a web request
-        if not comms.get_comms() or not get_current_context():
+        if not comms.get_comms():
             return
+        # send stored data and receive new config and routes
         res = comms.get_comms().send_data_to_bg_process(
             action="SYNC_DATA",
             obj={
@@ -86,18 +54,40 @@
             },
             receive=True,
         )
+        if not res["success"] or not res["data"]:
+            return
         self.reset()
-        if res["success"] and res["data"]:
-            if isinstance(res["data"].get("config"), ServiceConfig):
-                self.config = res["data"]["config"]
-            if isinstance(res["data"].get("routes"), dict):
-                self.routes.routes = res["data"]["routes"]
-                for route in self.routes.routes.values():
-                    route["hits_delta_since_sync"] = 0
-            self.last_renewal = t.get_unixtime_ms(monotonic=True)
-            logger.debug("Renewed thread cache")
+
+        # update config
+        if isinstance(res["data"].get("config"), ServiceConfig):
+            self.config = res["data"]["config"]
+
+        # update routes
+        if isinstance(res["data"].get("routes"), dict):
+            self.routes.routes = res["data"]["routes"]
+            for route in self.routes.routes.values():
+                route["hits_delta_since_sync"] = 0
 
     def increment_stats(self):
         """Increments the requests"""
         self.reqs += 1
 
+
+# For these 2 functions and the data they process, we rely on Python's GIL
+# See here: https://wiki.python.org/moin/GlobalInterpreterLock
+global_thread_cache = ThreadCache()
+
+
+def get_cache():
+    """
+    Returns the cache, protected by Python's GIL (so not our own mutex),
+    and starts the process worker (which syncs info between the cache and agent), if it doesn't already exist.
+    """
+    global global_thread_cache
+    process_worker_loader.load_worker()
+    return global_thread_cache
+
+
+def renew():
+    global global_thread_cache
+    global_thread_cache.renew()
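For context, the intended call pattern after this refactor: request code always goes through get_cache(), which also spins up the per-process sync worker. The two handlers below are hypothetical stand-ins for the request_handler stages, not the real functions:

from aikido_zen.thread.thread_cache import get_cache

def on_init():  # roughly what the "init" stage now does
    cache = get_cache()      # also ensures the per-process sync worker is running
    cache.increment_stats()  # counted locally; the worker ships it via SYNC_DATA

def on_post_response(route_metadata):  # roughly the "post_response" stage
    get_cache().routes.increment_route(route_metadata)  # auto-initializes unseen routes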
diff --git a/aikido_zen/thread/thread_cache_test.py b/aikido_zen/thread/thread_cache_test.py
index 284ad56e..7259d227 100644
--- a/aikido_zen/thread/thread_cache_test.py
+++ b/aikido_zen/thread/thread_cache_test.py
@@ -1,7 +1,7 @@
 import pytest
 from unittest.mock import patch, MagicMock
 from aikido_zen.background_process.routes import Routes
-from .thread_cache import ThreadCache, THREAD_CONFIG_TTL_MS, threadlocal_storage
+from .thread_cache import ThreadCache, get_cache
 from ..background_process.service_config import ServiceConfig
 from ..context import current_context, Context
 from aikido_zen.helpers.iplist import IPList
@@ -24,7 +24,7 @@ def run_around_tests():
     yield
     # Make sure to reset thread cache after every test so it does not
     # interfere with other tests
-    setattr(threadlocal_storage, "cache", None)
+    get_cache().reset()
     current_context.set(None)
 
 
@@ -35,7 +35,6 @@ def test_initialization(thread_cache: ThreadCache):
     assert thread_cache.get_endpoints() == []
     assert thread_cache.config.blocked_uids == set()
     assert thread_cache.reqs == 0
-    assert thread_cache.last_renewal == 0
 
 
 def test_is_bypassed_ip(thread_cache: ThreadCache):
@@ -55,70 +54,6 @@ def test_is_user_blocked(thread_cache: ThreadCache):
     assert thread_cache.is_user_blocked("user456") is False
 
 
-@patch("aikido_zen.background_process.comms.get_comms")
-@patch("aikido_zen.helpers.get_current_unixtime_ms.get_unixtime_ms")
-def test_renew_if_ttl_expired(
-    mock_get_unixtime_ms, mock_get_comms, thread_cache: ThreadCache
-):
-    """Test renewing the cache if TTL has expired."""
-    mock_get_unixtime_ms.return_value = (
-        THREAD_CONFIG_TTL_MS + 1
-    )  # Simulate TTL expiration
-    mock_get_comms.return_value = MagicMock()
-    mock_get_comms.return_value.send_data_to_bg_process.return_value = {
-        "success": True,
-        "data": {
-            "config": ServiceConfig(
-                endpoints=[
-                    {
-                        "graphql": False,
-                        "method": "POST",
-                        "route": "/v2",
-                        "rate_limiting": {
-                            "enabled": False,
-                        },
-                        "force_protection_off": False,
-                    }
-                ],
-                bypassed_ips=["192.168.1.1"],
-                blocked_uids={"user123"},
-                last_updated_at=-1,
-                received_any_stats=True,
-            ),
-            "routes": {},
-        },
-    }
-
-    thread_cache.renew_if_ttl_expired()
-    assert thread_cache.is_bypassed_ip("192.168.1.1")
-    assert thread_cache.get_endpoints() == [
-        {
-            "graphql": False,
-            "method": "POST",
-            "route": "/v2",
-            "rate_limiting": {
-                "enabled": False,
-            },
-            "force_protection_off": False,
-        }
-    ]
-    assert thread_cache.is_user_blocked("user123")
-    assert thread_cache.last_renewal > 0
-
-
-@patch("aikido_zen.background_process.comms.get_comms")
-@patch("aikido_zen.helpers.get_current_unixtime_ms.get_unixtime_ms")
-def test_renew_if_ttl_not_expired(
-    mock_get_unixtime_ms, mock_get_comms, thread_cache: ThreadCache
-):
-    """Test that renew is not called if TTL has not expired."""
-    mock_get_unixtime_ms.return_value = 0  # Simulate TTL not expired
-    thread_cache.last_renewal = 0  # Set last renewal to 0
-
-    thread_cache.renew_if_ttl_expired()
-    assert thread_cache.last_renewal == 0  # Should not change
-
-
 def test_reset(thread_cache: ThreadCache):
     """Test that reset empties the cache."""
     thread_cache.config.bypassed_ips.add("192.168.1.1")
@@ -128,7 +63,6 @@ def test_reset(thread_cache: ThreadCache):
     assert isinstance(thread_cache.config.bypassed_ips, IPList)
     assert thread_cache.config.blocked_uids == set()
     assert thread_cache.reqs == 0
-    assert thread_cache.last_renewal == 0
 
 
 def test_increment_stats(thread_cache):
@@ -148,7 +82,6 @@ def test_renew_with_no_comms(thread_cache: ThreadCache):
     assert thread_cache.get_endpoints() == []
     assert thread_cache.config.blocked_uids == set()
     assert thread_cache.reqs == 0
-    assert thread_cache.last_renewal == 0
 
 
 @patch("aikido_zen.background_process.comms.get_comms")
@@ -167,7 +100,6 @@ def test_renew_with_invalid_response(mock_get_comms, thread_cache: ThreadCache):
     assert isinstance(thread_cache.config.bypassed_ips, IPList)
     assert thread_cache.get_endpoints() == []
     assert thread_cache.config.blocked_uids == set()
-    assert thread_cache.last_renewal > 0  # Should update last_renewal
 
 
 def test_is_bypassed_ip_case_insensitivity(thread_cache: ThreadCache):
@@ -194,83 +126,12 @@ def increment_in_thread():
     assert thread_cache.reqs == 1000  # 10 threads incrementing 100 times
 
 
-@patch("aikido_zen.background_process.comms.get_comms")
-@patch("aikido_zen.helpers.get_current_unixtime_ms.get_unixtime_ms")
-def test_renew_if_ttl_expired_multiple_times(
-    mock_get_unixtime_ms, mock_get_comms, thread_cache: ThreadCache
-):
-    """Test renewing the cache multiple times if TTL has expired."""
-    mock_get_unixtime_ms.return_value = (
-        THREAD_CONFIG_TTL_MS + 1
-    )  # Simulate TTL expiration
-    mock_get_comms.return_value = MagicMock()
-    mock_get_comms.return_value.send_data_to_bg_process.return_value = {
-        "success": True,
-        "data": {
-            "config": ServiceConfig(
-                endpoints=[
-                    {
-                        "graphql": False,
-                        "method": "POST",
-                        "route": "/v2",
-                        "rate_limiting": {
-                            "enabled": False,
-                        },
-                        "force_protection_off": False,
-                    }
-                ],
-                bypassed_ips=["192.168.1.1"],
-                blocked_uids={"user123"},
-                last_updated_at=-1,
-                received_any_stats=True,
-            ),
-            "routes": {},
-        },
-    }
-
-    # First renewal
-    thread_cache.renew_if_ttl_expired()
-    assert thread_cache.is_bypassed_ip("192.168.1.1")
-    assert thread_cache.get_endpoints() == [
-        {
-            "graphql": False,
-            "method": "POST",
-            "route": "/v2",
-            "rate_limiting": {
-                "enabled": False,
-            },
-            "force_protection_off": False,
-        }
-    ]
-    assert thread_cache.is_user_blocked("user123")
-
-    # Simulate another TTL expiration
-    mock_get_unixtime_ms.return_value += THREAD_CONFIG_TTL_MS + 1
-    thread_cache.renew_if_ttl_expired()
-    assert thread_cache.is_bypassed_ip("192.168.1.1")
-    assert thread_cache.get_endpoints() == [
-        {
-            "graphql": False,
-            "method": "POST",
-            "route": "/v2",
-            "rate_limiting": {
-                "enabled": False,
-            },
-            "force_protection_off": False,
-        }
-    ]
-    assert thread_cache.is_user_blocked("user123")
-
-
 @patch("aikido_zen.background_process.comms.get_comms")
 @patch("aikido_zen.helpers.get_current_unixtime_ms.get_unixtime_ms")
 def test_parses_routes_correctly(
     mock_get_unixtime_ms, mock_get_comms, thread_cache: ThreadCache
 ):
     """Test renewing the cache multiple times if TTL has expired."""
-    mock_get_unixtime_ms.return_value = (
-        THREAD_CONFIG_TTL_MS + 1
-    )  # Simulate TTL expiration
     mock_get_comms.return_value = MagicMock()
     mock_get_comms.return_value.send_data_to_bg_process.return_value = {
         "success": True,
@@ -312,7 +173,7 @@ def test_parses_routes_correctly(
     }
 
     # First renewal
-    thread_cache.renew_if_ttl_expired()
+    thread_cache.renew()
     assert thread_cache.is_bypassed_ip("192.168.1.1")
     assert thread_cache.get_endpoints() == [
         {
diff --git a/aikido_zen/vulnerabilities/init_test.py b/aikido_zen/vulnerabilities/init_test.py
index 32264c46..21d2ac9a 100644
--- a/aikido_zen/vulnerabilities/init_test.py
+++ b/aikido_zen/vulnerabilities/init_test.py
@@ -3,17 +3,18 @@
 from . import run_vulnerability_scan
 from aikido_zen.context import current_context, Context
 from aikido_zen.errors import AikidoSQLInjection
-from aikido_zen.thread.thread_cache import ThreadCache, threadlocal_storage
+from aikido_zen.thread.thread_cache import get_cache
 from aikido_zen.helpers.iplist import IPList
 
 
 @pytest.fixture(autouse=True)
 def run_around_tests():
+    get_cache().reset()
     yield
     # Make sure to reset context and cache after every test so it does not
     # interfere with other tests
     current_context.set(None)
-    setattr(threadlocal_storage, "cache", None)
+    get_cache().reset()
 
 
 @pytest.fixture
@@ -49,14 +50,14 @@ def get_context():
 
 def test_run_vulnerability_scan_no_context(caplog):
     current_context.set(None)
-    threadlocal_storage.cache = 1
+    get_cache().cache = 1
     run_vulnerability_scan(kind="test", op="test", args=tuple())
     assert len(caplog.text) == 0
 
 
 def test_run_vulnerability_scan_no_context_no_lifecycle(caplog):
     current_context.set(None)
-    threadlocal_storage.cache = None
+    get_cache().cache = None
     run_vulnerability_scan(kind="test", op="test", args=tuple())
     assert len(caplog.text) == 0
 
@@ -64,13 +65,12 @@
 def test_run_vulnerability_scan_context_no_lifecycle(caplog):
     with pytest.raises(Exception):
         current_context.set(1)
-        threadlocal_storage.cache = None
+        get_cache().cache = None
         run_vulnerability_scan(kind="test", op="test", args=tuple())
 
 
 def test_lifecycle_cache_ok(caplog, get_context):
     get_context.set_as_current_context()
-    cache = ThreadCache()
     run_vulnerability_scan(kind="test", op="test", args=tuple())
     assert "Vulnerability type test currently has no scans implemented" in caplog.text
 
@@ -82,7 +82,7 @@ def test_ssrf(caplog, get_context):
 
 def test_lifecycle_cache_bypassed_ip(caplog, get_context):
     get_context.set_as_current_context()
-    cache = ThreadCache()
+    cache = get_cache()
     cache.config.bypassed_ips = IPList()
     cache.config.bypassed_ips.add("198.51.100.23")
     assert cache.is_bypassed_ip("198.51.100.23")
@@ -92,7 +92,6 @@
 
 def test_sql_injection(caplog, get_context, monkeypatch):
     get_context.set_as_current_context()
-    cache = ThreadCache()
     monkeypatch.setenv("AIKIDO_BLOCK", "1")
     with pytest.raises(AikidoSQLInjection):
         run_vulnerability_scan(
@@ -104,7 +103,6 @@
 
 def test_sql_injection_with_route_params(caplog, get_context, monkeypatch):
     get_context.set_as_current_context()
-    cache = ThreadCache()
    monkeypatch.setenv("AIKIDO_BLOCK", "1")
     with pytest.raises(AikidoSQLInjection):
         run_vulnerability_scan(
@@ -116,8 +114,6 @@
 
 def test_sql_injection_with_comms(caplog, get_context, monkeypatch):
     get_context.set_as_current_context()
-    cache = ThreadCache()
-    cache.last_renewal = 9999999999999999999999
     monkeypatch.setenv("AIKIDO_BLOCK", "1")
     with patch("aikido_zen.background_process.comms.get_comms") as mock_get_comms:
         # Create a mock comms object
diff --git a/end2end/django_mysql_test.py b/end2end/django_mysql_test.py
index 7e499be4..95c8f456 100644
--- a/end2end/django_mysql_test.py
+++ b/end2end/django_mysql_test.py
@@ -79,11 +79,11 @@ def test_initial_heartbeat():
     assert len(heartbeat_events) == 1
     validate_heartbeat(heartbeat_events[0],
         [{
-            "apispec": {},
+            "apispec": {'body': {'type': 'form-urlencoded', 'schema': {'type': 'object', 'properties': {'dog_name': {'type': 'string'}}}}, 'query': None, 'auth': None},
             "hits": 1,
-            "hits_delta_since_sync": 1,
+            "hits_delta_since_sync": 0,
             "method": "POST",
             "path": "/app/create"
         }],
-        {"aborted":0,"attacksDetected":{"blocked":2,"total":2},"total":0}
+        {"aborted":0,"attacksDetected":{"blocked":2,"total":2},"total":3}
     )
diff --git a/end2end/server/check_events_from_mock.py b/end2end/server/check_events_from_mock.py
index af8724be..8da473f3 100644
--- a/end2end/server/check_events_from_mock.py
+++ b/end2end/server/check_events_from_mock.py
@@ -22,6 +22,7 @@ def validate_started_event(event, stack, dry_mode=False, serverless=False, os_na
     # assert set(event["agent"]["stack"]) == set(stack)
 
 def validate_heartbeat(event, routes, req_stats):
-    assert event["type"] == "heartbeat"
-    assert event["routes"] == routes
-    assert event["stats"]["requests"] == req_stats
+    assert event["type"] == "heartbeat", f"Expected event type 'heartbeat', but got '{event['type']}'"
+    assert event["routes"] == routes, f"Expected routes '{routes}', but got '{event['routes']}'"
+    assert event["stats"]["requests"] == req_stats, f"Expected request stats '{req_stats}', but got '{event['stats']['requests']}'"
+