diff --git a/README.md b/README.md
index f404701..2aec71b 100644
--- a/README.md
+++ b/README.md
@@ -34,8 +34,8 @@ cd pipelinerl
 Create the environments with dependencies.
 ```bash
 conda create -n pipeline-rl -y python=3.11
-conda run --no-capture-output -n pipeline-rl pip install torch==2.6.0 --index-url https://download.pytorch.org/whl/cu121
-conda run --no-capture-output -n pipeline-rl pip install -r requirements.txt --no-build-isolation
+conda run --no-capture-output -n pipeline-rl pip install torch==2.6.0
+conda run --no-capture-output -n pipeline-rl pip install -e . --no-build-isolation
 ```
 
 By default Pipeline-RL will use the file system as the medium for streaming the generated data to the trainer processes. This works on one node, but the files can get quite large. To use Redis instead you will need to install the Redis server in the same conda environment:
diff --git a/conf/actor/math.yaml b/conf/actor/math.yaml
new file mode 100644
index 0000000..109dd76
--- /dev/null
+++ b/conf/actor/math.yaml
@@ -0,0 +1,8 @@
+log_each_n_secs: 10
+llm_max_rollouts: 128
+rollout_workers: 1
+rollout_policy: pipelinerl.math.rollouts.generate_math_rollout
+discount_factor: 1
+system_prompt: Please reason step by step, and put your final answer within \boxed{}.
+task_template: |-
+  {task}
diff --git a/conf/actor/web.yaml b/conf/actor/web.yaml
new file mode 100644
index 0000000..f331e87
--- /dev/null
+++ b/conf/actor/web.yaml
@@ -0,0 +1,109 @@
+log_each_n_secs: 10
+llm_max_rollouts: 128
+rollout_workers: 1
+rollout_policy: pipelinerl.tapeagents_rollouts.generate_rollout
+
+environment:
+  _target_: tapeagents.mcp.MCPEnvironment
+  config_path: conf/mcp/web.json
+
+llm:
+  _target_: tapeagents.llms.LiteLLM
+  model_name: o4-mini-2025-04-16
+  use_cache: true
+  context_size: 200000
+  parameters:
+    temperature: 1
+    max_completion_tokens: 16000
+
+agent:
+  _target_: tapeagents.agent.Agent
+  name: web_agent
+  llms:
+    default: ${llm}
+  templates:
+    system_prompt: |
+      You are an expert AI Agent trained to assist users with complex information processing tasks.
+      Your role is to understand user queries and respond in a helpful and accurate manner.
+      Keep your replies concise and direct. Prioritize clarity and avoid over-elaboration.
+      Do not express emotions or opinions about user questions.
+    allowed_tools: |
+      You have access to the following tools:
+      {tools_description}
+    thought_format: |
+      Important! Respond with the plain text, do not include any JSON or code.
+      Do not output anything besides what I asked in this message.
+    allowed_steps: |
+      You have access to the following tools:
+      {tools_description}
+      You are allowed to produce ONLY steps with the following JSON schemas:
+      {allowed_steps}
+      Do not reproduce the schema when producing steps; use it as a reference.
+    format: >
+      Output only a single JSON dict or a single JSON list.
+      DO NOT OUTPUT ANYTHING BESIDES THE JSON! DO NOT PLACE ANY COMMENTS INSIDE THE JSON.
+      It will break the system that processes the output.
+
+  nodes:
+    - _target_: tapeagents.nodes.StandardNode
+      name: plan
+      system_prompt: ${agent.templates.system_prompt}
+      guidance: |
+        Write a concise multi-step plan explaining which steps should be performed to find the answer for the given task.
+        Be specific about how each step should be performed. Only describe the intended actions here, do not perform them yet.
+        Consider that next steps may depend on results of previous steps, so include conditional branching using "if" statements where needed.
+        Start with the title "Plan". Every step should have a short name and description.
+        ${agent.templates.thought_format}
+      steps_prompt: ${agent.templates.allowed_tools}
+
+    - _target_: tapeagents.nodes.StandardNode
+      name: reflect
+      system_prompt: ${agent.templates.system_prompt}
+      guidance: |
+        Observe the current state of the task and produce the reflection text strictly following these rules:
+        1. Evaluate the action's success and explain its impact on the task and our plan.
+        2. If the last action was not successful, describe errors and the possible reasons for failure.
+        3. List the next steps to accomplish the current plan step and propose a single next immediate action.
+        4. When proposing webpage interactions:
+           - Always accept cookies and close popups first before interacting.
+           - If the last action was not successful, check if the target element is visible and use scrolling if it is not.
+        5. Describe the expected effect of the proposed action.
+        ${agent.templates.thought_format}
+      steps_prompt: ${agent.templates.allowed_tools}
+
+    - _target_: tapeagents.nodes.StandardNode
+      name: act
+      system_prompt: ${agent.templates.system_prompt}
+      guidance: Then produce a single function call for the next step. If the answer is ready, call the FinalStep function.
+      steps:
+        - tapeagents.steps.ReasoningThought
+        - tapeagents.core.FinalStep
+      use_known_actions: true
+      use_function_calls: true
+      next_node: act
+
+    - _target_: tapeagents.nodes.StandardNode
+      name: summarize
+      system_prompt: ${agent.templates.system_prompt}
+      guidance: |
+        Summarize the last observation. If it is an image, thoroughly describe it with all details.
+        Describe the results of the last action and observed changes. Discuss its impact on the task and our plan.
+        Do not hallucinate or make up any information, only describe what you see in the observation.
+        Do not guess or assume action effects, describe only visible changes.
+        ${agent.templates.thought_format}
+      steps_prompt: ${agent.templates.allowed_tools}
+      next_node: reflect
+
+split: validation
+batch: 2
+retry_unsolved: true
+
+only_tasks: #[] # list of (level, task_num)
+- [1, 0]
+- [1, 1]
+- [1, 2]
+- [1, 3]
+- [1, 4]
+- [1, 5]
+- [1, 6]
+- [1, 7]
\ No newline at end of file
diff --git a/conf/base.yaml b/conf/base.yaml
index b56d3fb..674d15c 100644
--- a/conf/base.yaml
+++ b/conf/base.yaml
@@ -2,6 +2,7 @@ defaults:
   - finetune: base
   - rewards: pure_success
   - streams: files
+  - actor: math
   - _self_
 
 finetune:
@@ -37,9 +38,6 @@ finetune:
   max_lag: ${..max_lag}
   weight_update_interval: 1
   pop_old_data: ${..pop_old_data}
-actor:
-  llm_max_rollouts: 128
-  rollout_workers: 1
 verifier:
   host: localhost
   port: 7777
@@ -90,11 +88,6 @@ world:
 
 actor_group_port: 9000
 
-# changed
-system_prompt: Please reason step by step, and put your final answer within \boxed{}.
-task_template: |-
-  {task}
-
 eval_every_n_versions: 78000
 
 # changed
@@ -115,7 +108,6 @@ force_restart: false
 pop_old_data: true
 max_lag: null
 attempts: 8
-discount_factor: 1
 train_dataset_names:
 - open_reasoner_zero_57k
 - open_reasoner_zero_extended_72k
diff --git a/conf/debug.yaml b/conf/debug.yaml
new file mode 100644
index 0000000..2a3ab03
--- /dev/null
+++ b/conf/debug.yaml
@@ -0,0 +1,27 @@
+defaults:
+  - base
+  - override streams: redis
+  - _self_
+
+finetune:
+  seq_length: 5000
+  gradient_accumulation_passes: 1024
+
+llm:
+  parameters:
+    max_tokens: 4096
+
+test_llm:
+  parameters:
+    max_tokens: 4096
+
+# debug:
+#   mode: open_loop
+
+output_dir: results/debug_4gpu_7b/${now:%Y_%m_%d}/${now:start_at_%H_%M_%S}
+
+# model_path: Qwen/Qwen2.5-0.5B
+
+# vllm_config:
+#   vllm_kwargs:
+#     enforce_eager: ""
\ No newline at end of file
diff --git a/conf/mcp/web.json b/conf/mcp/web.json
new file mode 100644
index 0000000..47c63d5
--- /dev/null
+++ b/conf/mcp/web.json
@@ -0,0 +1,23 @@
+{
+  "mcpServers": {
+    "serper-search": {
+      "command": "uv",
+      "args": ["run", "tapeagents/tools/mcp_servers/web_search.py"],
+      "env": {"SERPER_API_KEY": ""}
+    },
+    "fetch": {
+      "command": "uvx",
+      "args": [
+        "mcp-server-fetch"
+      ]
+    },
+    "python_exec": {
+      "command": "npx",
+      "args": [
+        "-y",
+        "@pydantic/mcp-run-python",
+        "stdio"
+      ]
+    }
+  }
+}
\ No newline at end of file
diff --git a/pipelinerl/debug.py b/pipelinerl/debug.py
new file mode 100644
index 0000000..5d1da60
--- /dev/null
+++ b/pipelinerl/debug.py
@@ -0,0 +1,13 @@
+import hydra
+from omegaconf import DictConfig
+
+from pipelinerl.launch import main as launch_main
+
+
+@hydra.main(config_path="../conf/", config_name="debug", version_base="1.3.2")
+def main(cfg: DictConfig):
+    launch_main(cfg)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/pipelinerl/debug_rollout.py b/pipelinerl/debug_rollout.py
new file mode 100644
index 0000000..ae00af5
--- /dev/null
+++ b/pipelinerl/debug_rollout.py
@@ -0,0 +1,17 @@
+import hydra
+from omegaconf import DictConfig
+
+from pipelinerl.tapeagents_rollouts import generate_rollout
+
+
+@hydra.main(config_path="../conf/", config_name="debug", version_base="1.3.2")
+def main(cfg: DictConfig):
+    llm = None
+    problem = None
+    session = None
+    result = generate_rollout(cfg, llm, problem, session)
+    print(result)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/pipelinerl/entrypoints/verifier.py b/pipelinerl/entrypoints/verifier.py
index 89b4d53..80f23e0 100644
--- a/pipelinerl/entrypoints/verifier.py
+++ b/pipelinerl/entrypoints/verifier.py
@@ -1,7 +1,7 @@
 import hydra
 from omegaconf import DictConfig
 
-from pipelinerl.verifier_api import run_verifier
+from pipelinerl.math.verifier_api import run_verifier
 from pipelinerl.utils import better_crashing
 
 
diff --git a/pipelinerl/launch.py b/pipelinerl/launch.py
index cbcf752..7e55c43 100644
--- a/pipelinerl/launch.py
+++ b/pipelinerl/launch.py
@@ -55,7 +55,7 @@ def run_ref_llm(cfg: DictConfig, preprocessor_llm_idx: int, local_idx: int, gpus
     kwargs = cfg.vllm_config.vllm_kwargs
     if kwargs["num-scheduler-steps"] > 1:
         kwargs["num-scheduler-steps"] = 1
-        logger.warning(f"Set num-scheduler-steps to 1 for reference vLLM")
+        logger.warning("Set num-scheduler-steps to 1 for reference vLLM")
 
     log_dir = exp_dir / f"ref_vllm_{preprocessor_llm_idx}"
     os.makedirs(log_dir, exist_ok=True)
@@ -81,8 +81,8 @@
     gpu_str = ",".join([str(gpu) for gpu in gpus])
     logger.info(f"Running reference LLM with command: {' '.join(cmd)} with gpus: {gpu_str}")
 
-    log_file_path = os.path.join(log_dir, f"stdout.log")
-    err_file_path = os.path.join(log_dir, f"stderr.log")
+    log_file_path = os.path.join(log_dir, "stdout.log")
+    err_file_path = os.path.join(log_dir, "stderr.log")
     with open(log_file_path, "a") as log_file, open(err_file_path, "a") as err_file:
         yield _popen(
             cmd,
@@ -138,8 +138,9 @@ def run_actor_llm(
     gpu_str = ",".join([str(gpu) for gpu in gpus])
     logger.info(f"Running actor_llm with command: {' '.join(cmd)} on gpus: {gpu_str}")
 
-    log_file_path = os.path.join(log_dir, f"stdout.log")
-    err_file_path = os.path.join(log_dir, f"stderr.log")
+    save_command(log_dir, cmd)
+    log_file_path = os.path.join(log_dir, "stdout.log")
+    err_file_path = os.path.join(log_dir, "stderr.log")
     with open(log_file_path, "a") as log_file, open(err_file_path, "a") as err_file:
         yield _popen(
             cmd,
@@ -166,6 +167,7 @@
         f"+me.llm_urls={llm_urls}",
     ]
     logger.info(f"Running actor with command: {' '.join(cmd)}")
+    save_command(exp_dir / "actor", cmd)
     yield _popen(
         cmd,
         env=dict(os.environ),
@@ -186,10 +188,11 @@ def run_verifier(cfg: DictConfig):
         f"hydra.run.dir={cfg.output_dir}/verifier",
     ]
     logger.info(f"Running verifier with command: {' '.join(cmd)}")
+    save_command(Path(cfg.output_dir) / "verifier", cmd)
     log_dir = os.path.join(cfg.output_dir, "verifier")
     os.makedirs(log_dir, exist_ok=True)
-    log_file_path = os.path.join(log_dir, f"stdout.log")
-    err_file_path = os.path.join(log_dir, f"stderr.log")
+    log_file_path = os.path.join(log_dir, "stdout.log")
+    err_file_path = os.path.join(log_dir, "stderr.log")
     with open(log_file_path, "a") as log_file, open(err_file_path, "a") as err_file:
         yield _popen(
             cmd,
@@ -285,6 +288,7 @@ def run_finetune(cfg: DictConfig, world_map: WorldMap, gpus: list[int], exp_dir:
         cmd.append("finetune.send_weight_updates=False")
 
     logger.info(f"Running finetune with command: {' '.join(cmd)}")
+    save_command(exp_dir / "finetune", cmd)
     env = dict(os.environ)
     env["DS_ENV_FILE"] = str(exp_dir / ".deepspeed_env")
     yield _popen(cmd, env=env)
@@ -307,6 +311,7 @@ def run_preprocess(world_map: WorldMap, preprocessor_idx: int, exp_dir: Path):
         f"+me.llm_urls={llm_urls}",
     ]
     logger.info(f"Running preprocess with command: {' '.join(cmd)}")
+    save_command(exp_dir / "preprocess", cmd)
     yield _popen(
         cmd,
         env=dict(os.environ),
@@ -329,9 +334,22 @@ def run_redis(cfg: DictConfig):
         cfg.streams.save,
     ]
     logger.info(f"Running redis with command: {' '.join(cmd)}")
+    save_command(Path(cfg.output_dir) / "redis", cmd)
     yield _popen(cmd, env=dict(os.environ))
 
 
+def save_command(script_dir: Path, cmd):
+    os.makedirs(script_dir, exist_ok=True)
+    script_path = script_dir / "start.sh"
+    with open(script_path, "w") as f:
+        f.write("#!/bin/bash\n")
+        # Properly quote arguments for the shell script
+        quoted_cmd = [f"'{arg}'" if " " in arg or "$" in arg else arg for arg in cmd]
+        f.write(" ".join(quoted_cmd) + "\n")
+    os.chmod(script_path, 0o755)
+    logger.info(f"Saved start script to {script_path}")
+
+
 def clean_up(exp_dir, force_restart):
     logger.info("Cleaning up streams directory")
     if os.path.exists(f"{exp_dir}/streams"):
@@ -386,7 +404,7 @@ def gently_stop_all_processes():
             gently_stop_all_processes()
             sys.exit(1)
         # TODO: make the watcdog code below more stable
-        # if (trainer_state is not Noneq
+        # if (trainer_state is not None
         #     and (version := trainer_state.propagated_weight_version is not None)
        #     and version > last_trainer_version):
        #     last_trainer_version = version
@@ -493,7 +511,7 @@ def main(cfg: DictConfig):
         clean_up(exp_dir, cfg.force_restart)
         os.makedirs(config_dir, exist_ok=True)
         OmegaConf.save(cfg, config_dir / "exp_config.yaml")
-        logger.info(f"Orchestrator 0 created the exp folder")
+        logger.info("Orchestrator 0 created the exp folder")
     if cfg.streams.backend == "redis":
         processes.extend(run_redis(cfg))
         redis = connect_to_redis(cfg.streams)
diff --git a/pipelinerl/math_rollouts.py b/pipelinerl/math/rollouts.py
similarity index 91%
rename from pipelinerl/math_rollouts.py
rename to pipelinerl/math/rollouts.py
index e28d9cf..5af8cba 100644
--- a/pipelinerl/math_rollouts.py
+++ b/pipelinerl/math/rollouts.py
@@ -1,13 +1,14 @@
 import time
+
 import aiohttp
 from omegaconf import DictConfig
 from pydantic import BaseModel
-from tapeagents.core import Prompt, LLMCall, TrainingText
+from tapeagents.core import LLMCall, Prompt, TrainingText
 from tapeagents.llms.trainable import TrainableLLM
 
-from pipelinerl.finetune.data import MASKED_TOKEN_ID
 from pipelinerl.async_llm import llm_async_generate
-from pipelinerl.verifier_api import verify_answer_rpc
+from pipelinerl.finetune.data import MASKED_TOKEN_ID
+from pipelinerl.math.verifier_api import verify_answer_rpc
 
 
 class RewardTable(BaseModel):
@@ -33,9 +34,9 @@ class RolloutResult(BaseModel):
 
 def make_prompt(problem: dict, cfg: DictConfig) -> Prompt:
     messages = []
-    if cfg.system_prompt:
-        messages.append({"role": "system", "content": cfg.system_prompt})
-    messages.append({"role": "user", "content": cfg.task_template.format(task=problem["task"])})
+    if cfg.actor.system_prompt:
+        messages.append({"role": "system", "content": cfg.actor.system_prompt})
+    messages.append({"role": "user", "content": cfg.actor.task_template.format(task=problem["task"])})
     return Prompt(messages=messages)
@@ -123,6 +124,6 @@ async def generate_math_rollout(
         llm=llm,
         answer=problem["answer"],
         rewards=RewardTable(**dict(cfg.rewards)),
-        discount_factor=cfg.discount_factor,
+        discount_factor=cfg.actor.discount_factor,
     )
     return RolloutResult(training_texts=[sample], metrics=metrics, latency=latency, dataset_name=problem.get("dataset"))
diff --git a/pipelinerl/verifier_api.py b/pipelinerl/math/verifier_api.py
similarity index 100%
rename from pipelinerl/verifier_api.py
rename to pipelinerl/math/verifier_api.py
diff --git a/pipelinerl/run_actor.py b/pipelinerl/run_actor.py
index 8fc7d4d..f7995af 100644
--- a/pipelinerl/run_actor.py
+++ b/pipelinerl/run_actor.py
@@ -1,31 +1,28 @@
+import asyncio
 import logging
 import math
-from multiprocessing.managers import SharedMemoryManager
+import multiprocessing as mp
 import os
 import queue
 import random
 import time
-import multiprocessing as mp
 from collections import defaultdict
+from multiprocessing.managers import SharedMemoryManager
 from pathlib import Path
 
-import uvloop
 import aiohttp
-
+import hydra
+import uvloop
 from omegaconf import DictConfig
 from pydantic import BaseModel, Field
-
-from pipelinerl.shared_memory_array import SharedMemoryArray
-from pipelinerl.verifier_api import wait_for_verifier
 from tapeagents.llms import TrainableLLM
-from pipelinerl.finetune.logging_ import flatten_dict_config, init_wandb
 import wandb
 
+from pipelinerl.finetune.logging_ import flatten_dict_config, init_wandb
 from pipelinerl.load_datasets import load_datasets
-from pipelinerl.math_rollouts import RolloutResult, generate_math_rollout
+from pipelinerl.math.rollouts import RolloutResult
+from pipelinerl.shared_memory_array import SharedMemoryArray
 from pipelinerl.state import TrainerState
-import asyncio
-from collections import defaultdict
 from pipelinerl.streams import (
     SingleStreamSpec,
     StreamSpec,
@@ -33,11 +30,12 @@
     set_streams_backend,
     write_to_streams,
 )
+from pipelinerl.math.verifier_api import wait_for_verifier
 from .utils import (
     always_or_never_success_stats,
-    calculate_stats,
     calculate_per_group_stats,
+    calculate_stats,
     setup_logging,
     wait_for_inference_servers,
 )
@@ -130,6 +128,8 @@ async def schedule_rollouts(
     max_group_size_bytes = 0
     # Track rollouts per problem group
     group_rollouts = {}
+    rollout_policy = hydra.utils.get_method(cfg.actor.rollout_policy)
+    logger.info(f"Use rollout policy: {rollout_policy}")
 
     async def rollout_and_maybe_produce_result(
         problem: dict,
@@ -143,7 +143,7 @@ async def rollout_and_maybe_produce_result(
             llm = llms[llm_index]
             model_version = trainer_state.propagated_weight_version
             assert model_version is not None
-            rollout_result = await generate_math_rollout(cfg, llm, problem, session)
+            rollout_result = await rollout_policy(cfg, llm, problem, session)
             rollout_result.model_version = model_version
             # Make a group id that will be different from groups made by another rollout maker
             full_group_id = f"{scheduler_name}_{group_id}"
@@ -164,7 +164,7 @@
             finished_rollouts += 1
         except Exception as e:
             # Cancel all tasks except the current one
-            logger.error("Exception in rollout", exc_info=e)
+            logger.error("Exception in rollout, stop all other rollout tasks", exc_info=e)
             current_task = asyncio.current_task(loop=loop)
             for task in asyncio.all_tasks(loop=loop):
                 if task != current_task:
@@ -356,6 +356,7 @@ def run(self, dataset: list[tuple[str, dict]]):
         published_samples = 0
         submitted_groups = 0
         finished_groups = 0
+        last_log_time = 0
         expected_number_of_samples = -1 if self.is_training else len(dataset)
         if expected_number_of_samples > 0:
             logger.info(f"Will stop after {expected_number_of_samples} samples")
@@ -477,22 +478,22 @@ def run(self, dataset: list[tuple[str, dict]]):
             # if we are training publish stats at every step else if all tapes are finished, publish stats
             if self.is_training or published_samples == expected_number_of_samples:
                 if self.is_training:
-                    loop_stats = {
-                        "published_samples": published_samples,
-                        "samples_in_queue": samples_in_queue,
-                        "finished_groups": finished_groups,
-                        "published_model_version": max_model_version,
-                        "latency": max_latency,
-                        "time_since_start": time.time() - loop_start_time,
-                    }
+                    log_time = time.monotonic()
+                    if log_time - last_log_time > self.cfg.actor.log_each_n_secs:
+                        loop_stats = {
+                            "published_samples": published_samples,
+                            "queue/problems": self.problem_queue.qsize(),
+                            "queue/samples": samples_in_queue,
+                            "finished_groups": finished_groups,
+                            "published_model_version": max_model_version,
+                            "latency": max_latency,
+                            "time_since_start": time.time() - loop_start_time,
+                        }
+                        self.publish_stats(stats_writer=stats_writer, loop_stats=loop_stats, split_name=split_name)
+                        last_log_time = log_time
                 else:
                     loop_stats = {"published_model_version": max_model_version}
-
-                self.publish_stats(
-                    stats_writer=stats_writer,
-                    loop_stats=loop_stats,
-                    split_name=split_name,
-                )
+                    self.publish_stats(stats_writer=stats_writer, loop_stats=loop_stats, split_name=split_name)
 
             if published_samples == expected_number_of_samples:
                 logger.info(f"Finished {expected_number_of_samples} samples, stopping actor loop")
@@ -545,13 +546,14 @@ def publish_stats(self, stats_writer: StreamWriter, loop_stats, split_name: str
                 | {"output_tokens_" + k: v for k, v in calculate_stats(self.output_tokens[dataset_name]).items()}
                 | {"overflows_" + k: v for k, v in calculate_stats(self.overflows[dataset_name]).items()}
             )
-            sub_stats = {dataset_name + "_" + k: v for k, v in sub_stats.items()}
+            sub_stats = {dataset_name + "/" + k: v for k, v in sub_stats.items()}
             stats |= sub_stats
         stats |= loop_stats
         if loop_stats.get("finished_groups", 0) >= 2 * self.window_size:
             stats |= sliding_stats
-        wandb.log({"actor/" + k: v for k, v in stats.items()})
+        logger.info("Publishing stats to wandb")
+        wandb.log({f"actor/{split_name}/{k}": v for k, v in stats.items()})
         stats_writer.write(stats)
         self.init_stats()
diff --git a/pipelinerl/run_finetune.py b/pipelinerl/run_finetune.py
index ade2073..efbe445 100644
--- a/pipelinerl/run_finetune.py
+++ b/pipelinerl/run_finetune.py
@@ -1,30 +1,30 @@
-from concurrent.futures import ThreadPoolExecutor
-import logging
-
-import deepspeed
-from accelerate.utils import FullyShardedDataParallelPlugin
-
 import contextlib
 import json
+import logging
 import os
 import threading
 import time
 from collections import defaultdict
+from concurrent.futures import ThreadPoolExecutor
 from dataclasses import asdict
 from functools import partial
 from pathlib import Path
 from queue import Empty, Queue
-from typing import Any, List, Literal, Dict
+from typing import Any, Dict, List, Literal
 
+import deepspeed
 import requests
 import torch
 import torch.distributed as dist
+from accelerate.utils import FullyShardedDataParallelPlugin
+from omegaconf import DictConfig
+from pydantic import BaseModel
 from torch.distributed.fsdp import FullStateDictConfig, StateDictType
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.api import MixedPrecision
+from transformers import PreTrainedTokenizerFast, get_scheduler, set_seed
 
-from omegaconf import DictConfig
-from pydantic import BaseModel
+import pipelinerl.torch_utils
 from pipelinerl.finetune.checkpoints import (
     load_model,
     load_tokenizer,
@@ -37,23 +37,20 @@ from pipelinerl.finetune.data import collate, collate_packed
 from pipelinerl.finetune.logging_ import log_metrics, log_time, setup_logging
 from pipelinerl.finetune.optim import get_optimizer
-from pipelinerl.finetune.utils import create_sentinel_batch, VersionedTensors
 from pipelinerl.finetune.rl import (
     RLConfig,
     rl_step,
 )
 from pipelinerl.finetune.rl.utils import get_avg_rl_stats
 from pipelinerl.finetune.types import TrainingMetrics
-from transformers import get_scheduler, set_seed, PreTrainedTokenizerFast
-
-from pipelinerl.utils import wait_for_inference_servers
-import pipelinerl.torch_utils
+from pipelinerl.finetune.utils import VersionedTensors, create_sentinel_batch
 from pipelinerl.streams import (
     SingleStreamSpec,
     read_stream,
     set_streams_backend,
     write_to_streams,
 )
+from pipelinerl.utils import wait_for_inference_servers
 
 
 logger = logging.getLogger(__name__)
@@ -177,6 +174,7 @@ def run_dynamic_batch_size_data_loader(
     current_length = 0
     samples_in_step = 0
     sample_generator = sample_generator_fn(sample_queue)
+    skip_count = 0
     while True:
         try:
             while True:
@@ -185,9 +183,11 @@ def run_dynamic_batch_size_data_loader(
                 sample_length = len(entry["input_ids"]) if entry else 0
 
                 if sample_length > max_seq_length:
-                    raise ValueError(
-                        f"Sample is of length {sample_length}, exceeding the max length of {max_seq_length}"
+                    skip_count += 1
+                    logger.warning(
+                        f"Sample length {sample_length} > max allowed length {max_seq_length}, skipping. Total {skip_count} samples skipped so far."
                     )
+                    continue
 
                 # check if adding current sample would exceed max_seq_length or if we've reached sample limit boundary
                 boundary = samples_in_step == samples_per_worker_per_step
@@ -683,6 +683,9 @@ def batch_generator_fn():
             ):
                 training_metrics.samples_too_old_to_train += args.train_batch_size
             batch = versioned_batch.tensors
+            logger.info(
+                f"Got batch with version {versioned_batch.model_version} and size {len(batch['input_ids'])}X{len(batch['input_ids'][0])}"
+            )
             lag_stats["min_version"] = min(
                 lag_stats.get("min_version", versioned_batch.model_version), versioned_batch.model_version
             )
@@ -711,6 +714,9 @@
             dist.all_gather(all_samples, local_samples)
             total_samples = sum(int(tensor.item()) for tensor in all_samples)
             do_optimizer_step = total_samples == target_samples
+            logger.info(
+                f"{total_samples} out of {target_samples} samples before optimizer step. Do step: {do_optimizer_step}"
+            )
 
             using_deepspeed = isinstance(model, deepspeed.DeepSpeedEngine)
             def backward(loss, is_final_micro_batch=False):
@@ -820,7 +826,7 @@ def toggle_sync(sync: bool):
                 "stats/epoch": training_metrics.epoch,
                 "stats/min_actor_version": lag_stats["min_version"],
                 "stats/max_actor_version": lag_stats["max_version"],
-                "stats/queue_size": sample_queue.qsize(),
+                "stats/queue/samples": sample_queue.qsize(),
                 "stats/time_waiting_for_data": training_metrics.time_waiting_for_data,
                 "stats/lag": training_metrics.last_broadcasted_version - lag_stats["min_version"],
                 "throughput/tokens_perGPU_per_sec": this_worker_tokens / sum(passes_took) if passes_took else 0,
diff --git a/pipelinerl/run_preprocess.py b/pipelinerl/run_preprocess.py
index fc03531..d9593dd 100644
--- a/pipelinerl/run_preprocess.py
+++ b/pipelinerl/run_preprocess.py
@@ -2,42 +2,39 @@
 
 os.environ["HF_DATASETS_DISABLE_PROGRESS_BARS"] = "1"
 
-import multiprocessing as mp
-from multiprocessing.managers import SharedMemoryManager
-from concurrent.futures import ProcessPoolExecutor
 import logging
+import multiprocessing as mp
 import queue
-import time
-
-from litellm import BaseModel, Field
-
-from pipelinerl.utils import wait_for_inference_servers
-from pipelinerl.world import WorldMap
-from pipelinerl.finetune.logging_ import flatten_dict_config, init_wandb
-from pipelinerl.shared_memory_array import SharedMemoryArray
-
-logger = logging.getLogger(__name__)
 import threading
+import time
+from concurrent.futures import ProcessPoolExecutor
 from functools import partial
+from multiprocessing.managers import SharedMemoryManager
 from pathlib import Path
 from queue import Empty, Queue
 from typing import List
 import random
 
-import transformers
 import datasets
+import transformers
+from litellm import BaseModel, Field
+
+from pipelinerl.finetune.logging_ import flatten_dict_config, init_wandb
+from pipelinerl.shared_memory_array import SharedMemoryArray
+from pipelinerl.utils import wait_for_inference_servers
+from pipelinerl.world import WorldMap
 
 datasets.disable_caching()
 
 from datasets.arrow_dataset import Dataset
 from datasets.fingerprint import Hasher
 from omegaconf import DictConfig
+from tapeagents.llms import TrainableLLM
+
 from pipelinerl.finetune.checkpoints import (
     load_tokenizer,
 )
 from pipelinerl.finetune.data import preprocess_fn
 from pipelinerl.finetune.rl import RL_DATA_COLUMNS, RLConfig, populate_rl_data
-from tapeagents.llms import TrainableLLM
-
 from pipelinerl.streams import (
     SingleStreamSpec,
     StreamRangeSpec,
@@ -113,7 +110,7 @@ def replace_oov_tokens_with_the(data: list[dict], tokenizer: transformers.PreTra
         completion_start = len(entry["input_ids"]) - completion_length
         for i, token_id in enumerate(invalid_token_ids):
             if i + completion_start < len(entry["input_ids"]):
-                logger.warning(f"Invalid token in completion part, logprobs may be inconsistent")
+                logger.warning("Invalid token in completion part, logprobs may be inconsistent")
 
         entry["input_ids"] = new_input_ids
         new_data.append(entry)
@@ -174,7 +171,7 @@ def run_dataset_loader(
                 if len(buffer) == chunk_size:
                     break
             if not _check_group_sizes(buffer, check_group_size):
-                raise ValueError(f"Invalid group sizes in data")
+                raise ValueError("Invalid group sizes in data")
             try:
                 raw_chunk_queue.put_nowait(buffer)
             except queue.Full:
@@ -294,7 +291,7 @@ def run_preprocessing_loop(
         partition_range=(0, max(world_map.total_finetune_gpus, 1)),
     )
     stats_streams = SingleStreamSpec(exp_path=exp_root_dir, topic="preprocessor_stats")
-    logger.info(f"Streams initialized")
+    logger.info("Streams initialized")
 
     raw_chunk_queue = Queue(cfg.preprocess.queue_size)
     rl_config = RLConfig(**cfg.finetune.rl)
@@ -405,8 +402,10 @@ def run_preprocessing_loop(
             stats = {
                 "preprocessor/published_samples": published_samples,
                 "preprocessor/published_model_version": max_model_version,
-                "preprocessor/samples_in_input_queue": raw_chunk_queue.qsize() * cfg.preprocess.chunk_size,
-                "preprocessor/samples_in_output_queue": samples_in_queue,
+                "preprocessor/queue/raw_samples": raw_chunk_queue.qsize() * cfg.preprocess.chunk_size,
+                "preprocessor/queue/raw": raw_chunk_queue.qsize(),
+                "preprocessor/queue/dataset_samples": samples_in_queue,
+                "preprocessor/queue/dataset": dataset_queue.qsize(),
             }
             if stats_aggregator.has_enough_data():
                 stats.update({"preprocessor/" + k: v for k, v in stats_aggregator.get_stats().items()})
diff --git a/pipelinerl/shared_memory_array.py b/pipelinerl/shared_memory_array.py
index cc9b37f..b27cb4d 100644
--- a/pipelinerl/shared_memory_array.py
+++ b/pipelinerl/shared_memory_array.py
@@ -1,8 +1,8 @@
-import multiprocessing as mp
-from multiprocessing.managers import SharedMemoryManager
 import pickle
 import struct
-from typing import Any, Dict, List, Optional, Tuple, Union
+from multiprocessing.managers import SharedMemoryManager
+from typing import Any
+
 
 class SharedMemoryArray:
@@ -99,7 +99,7 @@ def __len__(self) -> int:
         """Return the number of entries in the array."""
         return self.num_entries
 
-    def clear(self, index: int = None) -> None:
+    def clear(self, index: int | None = None) -> None:
        """
        Clear an entry or the entire array.
diff --git a/pipelinerl/streams.py b/pipelinerl/streams.py
index b10f3c7..a898cfb 100644
--- a/pipelinerl/streams.py
+++ b/pipelinerl/streams.py
@@ -1,16 +1,16 @@
-from abc import ABC, abstractmethod
-import orjson
 import json
 import logging
 import os
-from pathlib import Path
 import time
+from abc import ABC, abstractmethod
+from pathlib import Path
 from typing import Any, Iterator, Literal, Self, TextIO
+
+import orjson
 import redis
 from pydantic import BaseModel
 import redis.exceptions
-
 
 logger = logging.getLogger(__name__)
 
 # If we try to read too often, the code will be to slow. Too rarely, and delays will be too big.
@@ -78,6 +78,9 @@ def __enter__(self) -> Self:
     def __exit__(self, exc_type, exc_value, traceback):
         pass
 
+    def __len__(self) -> int:
+        return 0
+
     @abstractmethod
     def write(self, data: Any):
         pass
@@ -92,6 +95,9 @@ def __enter__(self) -> Self:
     def __exit__(self, exc_type, exc_value, traceback):
         pass
 
+    def __len__(self) -> int:
+        return 0
+
     @abstractmethod
     def read(self) -> Iterator[Any]:
         pass
@@ -107,7 +113,7 @@ def connect_to_redis(config: RedisConfig):
             logger.info(f"Trying to connect to Redis server at {config.host}:{config.port}")
             client = redis.Redis(host=config.host, port=config.port)
             client.ping()
-            logger.info(f"Connected to Redis server")
+            logger.info("Connected to Redis server")
             return client
         except (redis.exceptions.TimeoutError, redis.ConnectionError) as e:
             logger.info(f"Waiting for Redis server ({type(e)}). Retrying in 5 seconds.")
@@ -146,6 +152,9 @@ def __enter__(self):
     def __exit__(self, exc_type, exc_value, traceback):
         self._redis.close()
 
+    def __len__(self):
+        return self._redis.xlen(self._stream_name)
+
     def write(self, data):
         if isinstance(data, BaseModel):
             data = data.model_dump()
@@ -169,6 +178,9 @@ def __enter__(self):
     def __exit__(self, exc_type, exc_value, traceback):
         self._redis.close()
 
+    def __len__(self):
+        return self._redis.xlen(self._stream_name)
+
     def read(self):
         block = int(_REREAD_DELAY * 1000)
         while True:
@@ -251,6 +263,9 @@ def __enter__(self):
     def __exit__(self, exc_type, exc_value, traceback):
         self._file.close()
 
+    def __len__(self):
+        return get_json_lines_number(self._file)  # type: ignore
+
     def write(self, data):
         if isinstance(data, BaseModel):
             data = data.model_dump()
@@ -259,6 +274,14 @@ def write(self, data):
         self._file.flush()
 
 
+def get_json_lines_number(f: TextIO) -> int:
+    # use wc -l to count the number of lines in the file
+    cmd = f"wc -l {f.name}"
+    result = os.popen(cmd).read()
+    lines = result.split()[0]
+    return int(lines)
+
+
 def read_jsonl_stream(f: TextIO, retry_delay: float = _REREAD_DELAY) -> Iterator[Any]:
     position = f.tell()
@@ -299,6 +322,9 @@ def __enter__(self):
     def __exit__(self, exc_type, exc_value, traceback):
         self._file.close()
 
+    def __len__(self):
+        return get_json_lines_number(self._file)  # type: ignore
+
     def read(self):
         retry_time = 0.01
         cur_retries = 0
diff --git a/pipelinerl/tapeagents_rollouts.py b/pipelinerl/tapeagents_rollouts.py
new file mode 100644
index 0000000..fbb3b69
--- /dev/null
+++ b/pipelinerl/tapeagents_rollouts.py
@@ -0,0 +1,68 @@
+import asyncio
+import time
+
+import aiohttp
+from omegaconf import DictConfig
+from tapeagents.agent import Agent, LLMEvent, LLMStream
+from tapeagents.core import Prompt, StopStep, TrainingText
+from tapeagents.dialog_tape import DialogTape, UserStep
+from tapeagents.environment import Environment
+from tapeagents.llms.trainable import TrainableLLM
+from tapeagents.orchestrator import get_agent_and_env_from_config, main_loop
+
+from pipelinerl.async_llm import llm_async_generate
+from pipelinerl.math.rollouts import RolloutResult
+
+
+def run_tapeagent(
+    task: str, agent: Agent, environment: Environment, max_loops: int
+) -> tuple[list[TrainingText], dict[str, float]]:
+    start_tape = DialogTape(steps=[UserStep(content=task)])
+    tape: DialogTape | None = None
+    for event in main_loop(agent, start_tape, environment, max_loops):
+        if event.agent_tape:
+            tape = event.agent_tape
+        elif event.env_tape:
+            tape = event.env_tape
+    assert tape is not None, "No tape generated"
+    has_errors = any([1 for s in tape.steps if s.llm_dict().get("error")])
+    has_answer = any([isinstance(s, StopStep) for s in tape.steps])
+    _, llm_calls = agent.reuse(tape)
+    samples = [agent.make_training_text(llm_call) for llm_call in llm_calls]
+    reward = 0  # TODO: implement verifier usage and reward calculation
+    metrics = {
+        "reward": reward,
+        "success": reward > 0,
+        "no_error": not has_errors,
+        "no_answer": not has_answer,
+        "prompt_tokens": sum([llm_call.prompt_length_tokens for llm_call in llm_calls]),
+        "output_tokens": sum([llm_call.output_length_tokens for llm_call in llm_calls]),
+        "overflow": 0,  # TODO: should we treat max_loops stop as overflow?
+    }
+    return samples, metrics
+
+
+async def generate_rollout(
+    cfg: DictConfig,
+    llm: TrainableLLM,
+    problem: dict,
+    session: aiohttp.ClientSession,
+) -> RolloutResult:
+    def generate(self, prompt: Prompt):
+        # !! should be called in a separate thread only !!
+        # 'self' here is the llm instance (agent.llms["default"])
+        # 'session' is captured from the outer scope of generate_rollout
+        def _implementation():
+            llm_call = asyncio.run(llm_async_generate(self, prompt, session))
+            yield LLMEvent(output=llm_call.output, llm_call=llm_call)
+
+        return LLMStream(_implementation(), prompt)
+
+    time_start = time.time()
+    task: str = cfg.task_template.format(task=problem["task"])
+    agent, environment = get_agent_and_env_from_config(cfg)
+    agent.llms = {"default": llm.model_copy()}
+    agent.llms["default"].generate = generate  # type: ignore
+    samples, metrics = await asyncio.to_thread(run_tapeagent, task, agent, environment, cfg.max_loops)
+    latency = time.time() - time_start
+    return RolloutResult(training_texts=samples, metrics=metrics, latency=latency, dataset_name=problem.get("dataset"))
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..c2fbcae
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,30 @@
+[build-system]
+requires = ["setuptools>=61.0", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "pipelinerl"
+version = "0.1.0"
+description = "A scalable asynchronous reinforcement learning implementation with in-flight weight updates."
+readme = "README.md"
+requires-python = ">=3.11"
+license = { file = "LICENSE" }
+authors = [
+    { name = "ServiceNow" },
+]
+dependencies = [
+    "torch>=2.5.1",
+    "vllm==0.7.3",
+    "accelerate @ git+https://github.com/ServiceNow/accelerate.git@v1.2.0-hotfix",
+    "Tapeagents[finetune]==0.1.8",
+    "transformers==4.48.3",
+    "flash-attn==2.7.4.post1",
+    "math-verify[antlr4_9_3]==0.7.0",
+    "orjson==3.10.16",
+    "redis==5.2.1",
+    "hydra-core>=1.3.2",
+]
+
+[tool.setuptools.packages.find]
+where = ["."]
+include = ["pipelinerl*"]
diff --git a/requirements.txt b/requirements.txt
deleted file mode 100644
index f67cf13..0000000
--- a/requirements.txt
+++ /dev/null
@@ -1,8 +0,0 @@
-vllm==0.8.3
-accelerate @ git+https://github.com/ServiceNow/accelerate.git@v1.2.0-hotfix
-Tapeagents[finetune]==0.1.8
-transformers==4.51.0
-flash-attn==2.7.4.post1
-math-verify[antlr4_9_3]==0.7.0
-orjson==3.10.16
-redis==5.2.1
\ No newline at end of file
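
With this change the actor no longer hard-codes `generate_math_rollout`: `schedule_rollouts` resolves the dotted path in `actor.rollout_policy` with `hydra.utils.get_method` and awaits it as `rollout_policy(cfg, llm, problem, session)`, expecting a `RolloutResult` back. A minimal sketch of a custom policy that could be plugged in this way (the module name `my_rollouts.py` and the placeholder scoring are illustrative assumptions, not part of this diff):

```python
# my_rollouts.py -- hypothetical example of a pluggable rollout policy
import time

import aiohttp
from omegaconf import DictConfig
from tapeagents.llms.trainable import TrainableLLM

from pipelinerl.async_llm import llm_async_generate
from pipelinerl.math.rollouts import RolloutResult, make_prompt


async def generate_custom_rollout(
    cfg: DictConfig,
    llm: TrainableLLM,
    problem: dict,
    session: aiohttp.ClientSession,
) -> RolloutResult:
    # Same signature that schedule_rollouts awaits after resolving
    # cfg.actor.rollout_policy with hydra.utils.get_method.
    start = time.time()
    llm_call = await llm_async_generate(llm, make_prompt(problem, cfg), session)
    # Building TrainingText samples from llm_call is omitted here; see
    # pipelinerl.math.rollouts.generate_math_rollout for the full version.
    metrics = {"reward": 0, "success": False}  # placeholder scoring, no verifier call
    return RolloutResult(
        training_texts=[],
        metrics=metrics,
        latency=time.time() - start,
        dataset_name=problem.get("dataset"),
    )
```

A matching actor config under `conf/actor/` would then point at it with `rollout_policy: my_rollouts.generate_custom_rollout`, mirroring how `conf/actor/math.yaml` and `conf/actor/web.yaml` select their rollout generators.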