diff --git a/pipelinerl/run_llm.py b/pipelinerl/run_llm.py index d518a1d..0852ee6 100644 --- a/pipelinerl/run_llm.py +++ b/pipelinerl/run_llm.py @@ -31,6 +31,7 @@ from vllm.usage.usage_lib import UsageContext from vllm.worker.multi_step_worker import MultiStepWorker from vllm.worker.multi_step_model_runner import MultiStepModelRunner +from vllm.core.scheduler import Scheduler import torch.distributed as dist @@ -47,6 +48,22 @@ handler.setFormatter(formatter) logger.addHandler(handler) +old_schedule_method = Scheduler.schedule +def new_schedule_method(self, *args, **kwargs): + result = old_schedule_method(self, *args, **kwargs) + if getattr(self, "_force_recompute_kv_cache", True): + logger.info(f"Clear the force recompute flag") + self._force_recompute_kv_cache = False + return result +Scheduler.schedule = new_schedule_method + +old_can_append_slots = Scheduler._can_append_slots +def new_can_append_slots(self, *args, **kwargs): + if getattr(self, "_force_recompute_kv_cache", True): + logger.info(f"Return False from can_append_slots because of force recompute") + return False + return old_can_append_slots(self, *args, **kwargs) +Scheduler._can_append_slots = new_can_append_slots def make_worker_class(multi_step: bool): @@ -209,17 +226,6 @@ def signal_handler(*_) -> None: if not args.disable_weight_updates: weight_update_manager.input_process_groups() - # weight_update_stream = SingleStreamSpec(exp_path=args.exp_root_dir, topic="weight_update_request") - # async def weight_update_receiver(): - # async with AsyncStreamReader(weight_update_stream) as reader: - # async for line in reader.read(): - # message = TypeAdapter(TrainerMessage).validate_python(line) - # if isinstance(message, WeightUpdateRequest): - # await weight_update_manager.receive_weight_update(message) - # if not args.disable_weight_updates: - # logger.info(f"Create weight update background task") - # asyncio.create_task(weight_update_receiver()) - # Run HTTP server sock_addr = (args.host or "", args.port) sock = create_server_socket(sock_addr) @@ -228,6 +234,8 @@ def signal_handler(*_) -> None: @app.post("/receive_weight_update") async def _receive_weight_update(request: WeightUpdateRequest): await weight_update_manager.receive_weight_update(request) + for scheduler in engine.engine.scheduler: + scheduler._force_recompute_kv_cache = True return {"status": "ok"} model_config = await engine.get_model_config()