diff --git a/.gitignore b/.gitignore index 00e50326..69ca3670 100644 --- a/.gitignore +++ b/.gitignore @@ -185,3 +185,9 @@ research_*.json research_*.jsonc daemon_logs* paper +val_results.md +cocktail_vs_separate* +cocktail_results_* +PRINCIPLE.md +TODO.md +plot_mean_reward.py diff --git a/ajet/backbone/main_verl.py b/ajet/backbone/main_verl.py index 05ec0c2e..2b1bf14c 100644 --- a/ajet/backbone/main_verl.py +++ b/ajet/backbone/main_verl.py @@ -132,9 +132,16 @@ def run(self, config): # Instantiate the tokenizer and processor. from verl.utils import hf_processor, hf_tokenizer + from ajet.tokenizer.service import start_tokenizer_service trust_remote_code = config.data.get("trust_remote_code", False) - tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) + local_tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) + # Cache hot tokenization calls (encode / decode / apply_chat_template) + # in a sidecar process; every other tokenizer attribute is served by + # the local instance directly. + tokenizer = start_tokenizer_service( + local_tokenizer, local_path, trust_remote_code=trust_remote_code + ) # Used for multimodal LLM, could be None processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True) diff --git a/ajet/backbone/trainer_verl.py b/ajet/backbone/trainer_verl.py index 540c6a15..47338c89 100644 --- a/ajet/backbone/trainer_verl.py +++ b/ajet/backbone/trainer_verl.py @@ -511,15 +511,16 @@ def fit(self): # noqa: C901 ] ) ) - logger.info("start fit rollout") + logger.info("start batch rollout") self.parallel_env.current_global_steps = self.global_steps + # rollout stage begin ✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨ context_tracker_arr: List[SingleAgentContextTracker] = self.parallel_env.rollout( tasks, mode="sample", epoch=f"train.{epoch}" ) # from ajet import bp; bp("BATCH") - logger.info("end fit rollout") + logger.info("end batch rollout") gen_batch_output = self.parallel_env.to_dataproto(context_tracker_arr) logger.info("end dataproto convertion") @@ -710,7 +711,7 @@ def fit(self): # noqa: C901 # implement critic warmup if self.config.trainer.critic_warmup <= self.global_steps: - # update actor + # update actor ✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨ with marked_timer("update_actor", timing_raw, color="red"): actor_output = self._update_actor(batch) actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) diff --git a/ajet/context_tracker/multiagent_tracking.py b/ajet/context_tracker/multiagent_tracking.py index 87c69e73..a21dc70d 100644 --- a/ajet/context_tracker/multiagent_tracking.py +++ b/ajet/context_tracker/multiagent_tracking.py @@ -130,6 +130,8 @@ def step_spawn_timeline(self, messages: List[dict], tools: List = [], disable_to if disable_toolcalls: consider_roles.remove("tool") + previous_message_encounter_user_role = False + for i, msg in enumerate(messages): if (disable_toolcalls) and (not isinstance(msg["content"], str)): @@ -166,6 +168,11 @@ def step_spawn_timeline(self, messages: List[dict], tools: List = [], disable_to else: author = "env" + if msg["role"] == "user": + previous_message_encounter_user_role = True + + any_later_msg_has_user_role = any((m["role"] == "user") for m in messages[i+1:]) + # extract content block from openai-competible messages and convert to ExtendedMessage timeline += [ ExtendedMessage( @@ -179,8 +186,11 @@ def step_spawn_timeline(self, messages: List[dict], tools: List = [], disable_to token_generator="auto", name = (msg["name"] if "name" in msg else ""), first_message=(i == 0), + before_last_query=any_later_msg_has_user_role ) ] + if ("" in msg["content"]) and (not previous_message_encounter_user_role): + logger.warning(f"Warning! Message content contains tag, but no prior message has `user` role! This is not a common scenario. Please check your agent loop carefully.") return timeline @@ -270,7 +280,6 @@ def step_track( if ( "prompt_text" in llm_output and "prompt_token_ids" in llm_output ): - # currently we make this patch to better compat with Trinity training backend # fix Retokenization Drift timeline = self.patch_prompt_tokens( prompt_text=llm_output["prompt_text"], @@ -321,7 +330,6 @@ def save_llm_interaction_timeline(self, tools, llm_ext_msg, timeline): ) ): logger.bind(exception=True).info(f"General Warning: merge failure discovered.\n") - # from ajet import bp; bp("SWARM") return @@ -393,13 +401,19 @@ def patch_prompt_tokens( prompt_token_ids: List[int], previous_ext_context: List[ExtendedMessage], ) -> List[ExtendedMessage]: + """ + fix retokenization drift + prompt_text = llm_output["prompt_text"]: [this llm call] the prompt in text format used in generation + prompt_token_ids = llm_output["prompt_token_ids"]: [this llm call] the prompt token ids used in generation (prompt_text->prompt_token_ids using tokenizer) + previous_ext_context: [from previous context] the context history + """ - # remove tailing + # remove tailing, usually `<|im_start|> assistant` if prompt_text.endswith(self.generation_prompt): prompt_text = prompt_text[: -len(self.generation_prompt)] # prompt_token_ids = prompt_token_ids[: -len(self.generation_prompt_token)] - # split prompt token ids into message level + # split CURRENT prompt token ids into message level (split_prompt_token_ids is List[List[int]]) split_prompt_token_ids = [] tmp = [] for i in range(len(prompt_token_ids)): @@ -412,17 +426,19 @@ def patch_prompt_tokens( if len(tmp) > 0: split_prompt_token_ids += [tmp] - # split prompt text into message level + # split CURRENT prompt text into message level (corresponding to split_prompt_token_ids) prompt_text_split = prompt_text.split("<|im_start|>") assert prompt_text_split[0] == "", "Prompt text should start with <|im_start|>" prompt_text_split = prompt_text_split[1:] # remove the first empty string for i in range(len(prompt_text_split)): prompt_text_split[i] = "<|im_start|>" + prompt_text_split[i] + # context HISTORY prompt text current_prompt_text = [] for j in range(len(previous_ext_context)): current_prompt_text += [self.tokenizer.decode(previous_ext_context[j].token_arr)] + # HISTORY context length vs CURRENT prompt length if len(previous_ext_context) != len(prompt_text_split): logger.bind(exception=True).error( f"Length mismatch when patching prompt tokens. Previous ext context length: {len(previous_ext_context)}, prompt text split length: {len(prompt_text_split)}. Replacing all tokens." @@ -430,7 +446,12 @@ def patch_prompt_tokens( # try to recover tokens if self.config.ajet.context_tracker.fix_retokenization_drift: - self.ensure_retokenization_perfect_match(previous_ext_context, split_prompt_token_ids, prompt_text_split, current_prompt_text) + self.ensure_retokenization_perfect_match( + previous_ext_context, # HISTORY + split_prompt_token_ids, # CURRENT + prompt_text_split, # CURRENT + current_prompt_text # HISTORY + ) # remove extra messages if len(previous_ext_context) != len(prompt_text_split): @@ -440,39 +461,38 @@ def patch_prompt_tokens( def ensure_retokenization_perfect_match(self, previous_ext_context, split_prompt_token_ids, prompt_text_split, current_prompt_text): + """ + Ensure the retokenization is perfectly matched between HISTORY and CURRENT + + previous_ext_context: the context history in ExtendedMessage format, which contains token_arr (token ids) + split_prompt_token_ids: the prompt token ids of CURRENT prompt, split into message level (List[List[int]]) + prompt_text_split: the prompt text of CURRENT prompt, split into message level (List[str]) + current_prompt_text: the prompt text of HISTORY context, converted from token_arr to text using tokenizer, in message level (List[str]) + """ + for j in range(len(previous_ext_context)): - if prompt_text_split[j] != current_prompt_text[j]: - # if prompt text mismatch, we can replace the tokens + vllm_token_array = split_prompt_token_ids[j] + tracker_token_array = previous_ext_context[j].token_arr + if vllm_token_array == tracker_token_array: + # good, everything is perfect + continue + else: + from ajet import bp; bp("SWARM") + # otherwise, we throw a warning (do not worry, this causes almost no influence in the training) print_dict( { - "expected_prompt_text": prompt_text_split[j], - "current_prompt_text": current_prompt_text[j], + "expected_prompt_text": prompt_text_split[j], # from llm_output["prompt_text"] + "current_prompt_text": current_prompt_text[j], # history prompt text converted from token_arr to text using tokenizer + "expected_token_ids": vllm_token_array, # from llm_output["prompt_token_ids"] + "current_token_ids": tracker_token_array, # from previous_ext_context[j].token_arr }, mod="exception", - header="Prompt text mismatch, Please report a github issue", + header="Prompt token ids mismatch.", ) - previous_ext_context[j].token_arr = self.tokenizer( - prompt_text_split[j], return_tensors="pt", padding=False - ) - else: - # if prompt text match - # we further check whether all token ids matches - vllm_token_array = split_prompt_token_ids[j] - tracker_token_array = previous_ext_context[j].token_arr - if vllm_token_array == tracker_token_array: - # good, everything is perfect - continue - else: - # otherwise, we throw a warning (do not worry, this causes almost no influence in the training) - print_dict( - { - "expected_token_ids": split_prompt_token_ids[j], - "current_token_ids": previous_ext_context[j].token_arr, - }, - mod="exception", - header="Prompt token ids mismatch, Please report a github issue", - ) - + # # fix drift + # previous_ext_context[j].token_arr = self.tokenizer( + # prompt_text_split[j], return_tensors="pt", padding=False + # )["input_ids"].tolist() def process_reward(self, reward_structure: Reward): self.reward_structure = reward_structure diff --git a/ajet/context_tracker/single_agent_tracking.py b/ajet/context_tracker/single_agent_tracking.py index 775abf3b..b48bcb02 100644 --- a/ajet/context_tracker/single_agent_tracking.py +++ b/ajet/context_tracker/single_agent_tracking.py @@ -91,7 +91,7 @@ def get_token_inc_from_llm_response( self.generated_token_cnt += len(vllm_output_raw_token) if not self.generation_prompt_token: self.generation_prompt_token = self.get_generation_prompt_token() - final_token_arr, token_logprob_arr, loss_mask, lack_normal_eos = replace_token_ids( + final_token_arr, token_logprob_arr, loss_mask, lack_normal_eos = replace_token_ids( # pad tokens and logprobs with begin_ids / other_ids / NA token_container=completion_token_arr, precise_token=vllm_output_raw_token, precise_logprob=vllm_output_raw_logprob, @@ -187,7 +187,7 @@ def to_role_content(self, ext_msg_array: List[ExtendedMessage]) -> List: for ext_msg in ext_msg_array: d: dict = { "role": ext_msg.role, - "content": ext_msg.content_for_compare, + "content": ext_msg.text_content_for_compare, } if ext_msg.tool_calls: d.update({"tool_calls": ext_msg.tool_calls}) diff --git a/ajet/context_tracker/timeline_merging/timeline_merging.py b/ajet/context_tracker/timeline_merging/timeline_merging.py index fcc3b052..6fa17a53 100644 --- a/ajet/context_tracker/timeline_merging/timeline_merging.py +++ b/ajet/context_tracker/timeline_merging/timeline_merging.py @@ -21,8 +21,8 @@ def is_timeline_mergeable( for i in range(len(target_timeline)): if timeline_compare_level == "text": same = ( - source_timeline[i].content_for_compare - == target_timeline[i].content_for_compare + source_timeline[i].text_content_for_compare + == target_timeline[i].text_content_for_compare ) elif timeline_compare_level == "token": same = source_timeline[i].token_arr == target_timeline[i].token_arr @@ -52,12 +52,12 @@ def is_timeline_mergeable( # all_msg_match = False # for i in range(len(target_timeline)): # d = {} - # d["source"] = source_timeline[i].content_for_compare - # d["target"] = target_timeline[i].content_for_compare + # d["source"] = source_timeline[i].text_content_for_compare + # d["target"] = target_timeline[i].text_content_for_compare # if timeline_compare_level == "text": # same = ( - # source_timeline[i].content_for_compare - # == target_timeline[i].content_for_compare + # source_timeline[i].text_content_for_compare + # == target_timeline[i].text_content_for_compare # ) # elif timeline_compare_level == "token": # same = source_timeline[i].token_arr == target_timeline[i].token_arr diff --git a/ajet/copilot/create-keep-think-model-chat-template/SKILL.md b/ajet/copilot/create-keep-think-model-chat-template/SKILL.md new file mode 100644 index 00000000..136b11b9 --- /dev/null +++ b/ajet/copilot/create-keep-think-model-chat-template/SKILL.md @@ -0,0 +1,16 @@ + + +Your task is to investigate the chat template of given model, go to its tokenizer config and check whether the following behavior exists: + +> +> Remove history block from the input when apply chat template when converting messages. +> + +This behavior will make RL training slower, if this behavior exists, please change the chat template to forbid such behavior. + +You must not do this in-place, instead, please create another model. +E.g., "/mnt/data_cpfs/xielipeng.xlp/models/Qwen3-8B" -> "/mnt/data_cpfs/xielipeng.xlp/models/Qwen3-8B-Keep-History" +For all files within the original model path, please create symbolic links instead of copying files. +With only one exception, the tokenizer config file, which should be copied and modified to change the chat template. + + diff --git a/ajet/copilot/job.py b/ajet/copilot/job.py index 9182e00b..c254a9a6 100644 --- a/ajet/copilot/job.py +++ b/ajet/copilot/job.py @@ -67,6 +67,14 @@ class AgentJetJob: whose ``num_repeat`` episodes do *not* all share the same reward. Tasks with uniform reward (e.g. all 0 or all 1) produce zero advantage under GRPO and are skipped — useful when the dataset contains many too-easy or too-hard prompts. + - "rollout_until_any_client_agree_sync_weight": defer the stop decision to the swarm + clients themselves. Stops as soon as **any** active swarm client invokes + ``SwarmClient.agree_sync_weight()``. A client is "active" once it has successfully + ``end_episode``'d at least one rewarded (non-abort) episode since the last weight + sync, and falls off the active list after 10 minutes of no chat-completion or + ``begin_episode`` activity. + - "rollout_until_all_clients_agree_sync_weight": like the above, but stops only when + **every** active swarm client has agreed (and there is at least one active client). max_env_worker: an estimation about how many episodes will be running in parallel (all swarm clients combined). backbone: Training backbone framework (e.g., 'verl'). max_prompt_length: Maximum token length for input prompts (token length before the first llm-generated token, default 3000). @@ -90,6 +98,7 @@ class AgentJetJob: val_print_to_markdown_file_path: Path to a file where validation metrics are appended after every validation pass (default None, disabled). train_print_to_markdown_file_path: Path to a file where training metrics are appended after every training step (default None, disabled). total_training_steps: Hard cap on total training steps. If None (default), training runs for `total_epochs` epochs. + timeline_compare_level: Comparison granularity used by the context tracker's timeline merging policy. One of 'text' (relaxed text compare, more aggressive merging, very low cost) or 'token' (strict token compare, less aggressive merging). Default 'text'. """ def __init__( @@ -130,6 +139,7 @@ def __init__( val_print_to_markdown_file_path: str | None = None, train_print_to_markdown_file_path: str | None = None, total_training_steps: int | None = None, + timeline_compare_level: str | None = None, ) -> None: if base_yaml_config is None: @@ -195,6 +205,7 @@ def __init__( self.val_print_to_markdown_file_path: str = cast(str, val_print_to_markdown_file_path) self.train_print_to_markdown_file_path: str = cast(str, train_print_to_markdown_file_path) self.total_training_steps: int = cast(int, total_training_steps) + self.timeline_compare_level: str = cast(str, timeline_compare_level) # see `ajet/default_config/ajet_swarm_default.yaml` overrides = { @@ -230,9 +241,10 @@ def __init__( "ajet.trainer_common.use_kl_in_reward": "use_kl_in_reward", "ajet.trainer_common.kl_penalty_type": "kl_penalty_type", "ajet.rollout.compute_madness_checklist": "compute_madness_checklist", - "ajet.trainer_common.val_print_to_markdown_file_path": "val_print_to_markdown_file_path", - "ajet.trainer_common.train_print_to_markdown_file_path": "train_print_to_markdown_file_path", "ajet.trainer_common.total_training_steps": "total_training_steps", + "ajet.trainer_common.val_print_to_markdown_file_path": "val_print_to_markdown_file_path", + "ajet.trainer_common.train_print_to_markdown_file_path": "train_print_to_markdown_file_path", + "ajet.context_tracker.timeline_merging_policy.timeline_compare_level": "timeline_compare_level", } # if any value given in kwargs, override the corresponding value in config diff --git a/ajet/copilot/monitor-with-tmux/SKILL.md b/ajet/copilot/monitor-with-tmux/SKILL.md index 7da08a71..ead5a48b 100644 --- a/ajet/copilot/monitor-with-tmux/SKILL.md +++ b/ajet/copilot/monitor-with-tmux/SKILL.md @@ -180,3 +180,11 @@ $ python3 /tmp/tmux_wait.py ajet_session 240 && tmux capture-pane -t ajet_sessio tmux kill-session -t ajet_session ``` + + +## For AgentJet Swarm + +- You should create seperate tmux session for each agentjet swarm servers and each agentjet swarm clients +- When debugging, please do not restart agentjet swarm servers frequently, that waste too much time +- When you really having difficulty for clearing GPU memory, run `ajet --autokill` to automatically kill all python and ray processes (however, I still recommend using this as a last resort). +- For AgentJet, always use tmux session name that starts with `ajet-*` diff --git a/ajet/copilot/swarm-configuration/SKILL.md b/ajet/copilot/swarm-configuration/SKILL.md new file mode 100644 index 00000000..62200032 --- /dev/null +++ b/ajet/copilot/swarm-configuration/SKILL.md @@ -0,0 +1,19 @@ +--- +name: swarm-configuration +description: How `max_env_worker` caps the "Running Episodes" gauge, and how `AgentJetJob` relates to the YAML config. +license: Complete terms in LICENSE.txt +--- + +## Running-episodes cap + +The `Running Episodes (Episodes: N)` number in `swarm_overwatch` is bounded by the **engine-side** `max_env_worker` (set on the job config, e.g. `CocktailV2Config.max_env_worker`, then forwarded into `AgentJetJob` and read at `ajet/backbone/trainer_verl.py` as `max_parallel`). In `ajet/task_rollout/native_parallel_worker.py::rollout_swarm`, the engine spawns `ceil(max_env_worker / grpo_n) * grpo_n` long-lived worker threads, each looping `register_episode` → wait-for-claim → repeat, so the total in-flight episodes (summed across **all** swarm clients) cannot exceed that count. `total_batch_size`, per-client `max_env_worker`, `grpo_n`, and the number of clients do **not** raise this cap , to lift it, raise the engine's `max_env_worker` (keep it divisible by `grpo_n`) and restart. + +## AgentJetJob ↔ YAML + +When using Agentjet Swarm, please first use `AgentJetJob` as the primary configuration interface. + +If there are fields you want to set that are not exposed as `AgentJetJob` kwargs, use yaml as the primary configuration interface. + +In general, you should place most configuration in a place (either `AgentJetJob` or yaml), and MUST NOT place configuration here and there at the same time. + +`AgentJetJob` (`ajet/copilot/job.py`) is a thin **YAML overlay**, not a separate config system. On `__init__` it loads a base YAML (default `ajet/default_config/ajet_swarm_default.yaml`, or whatever path is passed via `base_yaml_config=`) into `self.config`, then walks an `overrides` table that maps each constructor kwarg to a deep YAML key (e.g. `max_env_worker` → `ajet.rollout.max_env_worker`, `batch_size` → `ajet.data.train_batch_size`, `model` → `ajet.model.path`). For each entry: if the kwarg is `None` the YAML value wins; if non-`None` it overwrites the YAML value in-place. Anything not listed in `overrides` (e.g. `rollout.temperature`, `rollout.multi_turn`, `trainer_common.save_freq`) has no kwarg shortcut and must be set by mutating `ajet_job.config.ajet.*` directly after construction , this is what `build_cocktail_ajet_job` does in the cocktail_rl_v2 tutorial. `dump_job_as_yaml(path)` serialises the merged result back out, and that dumped YAML is the file the engine subprocess actually consumes. Net effect: **YAML is the source of truth for defaults; `AgentJetJob` kwargs are sparse overrides; post-construction attribute writes are the escape hatch for fields without a kwarg.** diff --git a/ajet/default_config/ajet_config_schema.py b/ajet/default_config/ajet_config_schema.py index f0bdb35b..56e07aa2 100644 --- a/ajet/default_config/ajet_config_schema.py +++ b/ajet/default_config/ajet_config_schema.py @@ -40,6 +40,7 @@ class AjetModel: class AjetData: max_prompt_length: int = 3000 max_response_length: int = 15000 + # Note that this value is ignored when swarm_mode_sample_collection_method="rollout_until_all_clients_agree_sync_weight" train_batch_size: int = 32 @@ -69,7 +70,7 @@ class AjetInterchangeServer: interchange_server_port: Any = "auto" num_fastapi_process: int = 1 max_fastapi_threads: int = 512 - max_inference_tracker_threads: int = 64 + max_inference_tracker_threads: int = 128 already_started: bool = False @@ -97,6 +98,17 @@ class AjetTaskReader: huggingface_dat_repo: HuggingfaceDatRepo = field(default_factory=HuggingfaceDatRepo) jsonl_dataset_file: JsonlDatasetFile = field(default_factory=JsonlDatasetFile) +@dataclass +class AjetTimelineMergingPolicy: + timeline_compare_level: str = "text" + ignore_tools: bool = True + + +@dataclass +class AjetContextTracker: + timeline_merging_policy: AjetTimelineMergingPolicy = field(default_factory=AjetTimelineMergingPolicy) + + @dataclass class AjetDefaultConfig: project_name: str = "ajet_default_project" @@ -110,6 +122,7 @@ class AjetDefaultConfig: trainer_common: AjetTrainerCommon = field(default_factory=AjetTrainerCommon) task_reader: AjetTaskReader = field(default_factory=AjetTaskReader) lora: AjetLora = field(default_factory=AjetLora) + context_tracker: AjetContextTracker = field(default_factory=AjetContextTracker) enable_swarm_mode: bool = True swarm_mode_sample_collection_method: str = "rollout_until_finish_enough_tasks" execute_test: bool = False diff --git a/ajet/default_config/ajet_default.yaml b/ajet/default_config/ajet_default.yaml index 8e425cb5..68915b54 100644 --- a/ajet/default_config/ajet_default.yaml +++ b/ajet/default_config/ajet_default.yaml @@ -16,6 +16,7 @@ ajet: # max number of tokens for response max_response_length: 15000 # how many tasks per training batch + # Note that this value is ignored when swarm_mode_sample_collection_method="rollout_until_all_clients_agree_sync_weight" train_batch_size: 32 # [Hint]: The final number of samples per update will be: N_{sample} = (data.train_batch_size * rollout.num_repeat * rollout.multi_turn.expected_steps) @@ -334,7 +335,7 @@ ajet: interchange_server_port: 'auto' num_fastapi_process: 1 # 1, 2 or 4 is fine max_fastapi_threads: 512 # 64 or 128 is fine - max_inference_tracker_threads: 64 # recommend to be equal to `ajet.rollout.max_env_worker` + max_inference_tracker_threads: 128 # recommend to be equal to `ajet.rollout.max_env_worker` already_started: False # do not edit, used by `swarm` # what is the stop condition for swarm mode sample collection # "rollout_until_finish_enough_episodes": @@ -345,6 +346,12 @@ ajet: # "rollout_until_finish_enough_non_dummy_tasks": # AgentJet will identify the **task_id** of each episode, and stop when it has collected [>= ajet.data.train_batch_size] unique & FINISHED & NON-DUMMY **task_id**. # (Hint: a **task_id** is considered "NON-DUMMY" at least one of **episodes** of **task_id** has **different** reward value.) + # "rollout_until_any_client_agree_sync_weight": + # AgentJet defers the stop decision to swarm clients: stop as soon as ANY active swarm client has called `SwarmClient.agree_sync_weight()`. + # (Hint: a swarm client becomes "active" once it has successfully `end_episode`'d a rewarded (non-abort) episode since the last weight sync, + # and falls off the active list if it does no chat-completion / begin_episode for 10 minutes.) + # "rollout_until_all_clients_agree_sync_weight": + # Like the above, but stop only when EVERY active swarm client has agreed (and there is at least one active client). swarm_mode_sample_collection_method: "rollout_until_finish_enough_tasks" swarm_mode_sample_collection_max_cached_episodes: 9999 diff --git a/ajet/default_config/ajet_swarm_default.yaml b/ajet/default_config/ajet_swarm_default.yaml index 2e975dad..8e35179f 100644 --- a/ajet/default_config/ajet_swarm_default.yaml +++ b/ajet/default_config/ajet_swarm_default.yaml @@ -41,7 +41,7 @@ ajet: interchange_server_port: 10086 num_fastapi_process: 1 # 1, 2 or 4 is fine max_fastapi_threads: 512 # 64 or 128 is fine - max_inference_tracker_threads: 64 # recommend to be equal to `ajet.rollout.max_env_worker` + max_inference_tracker_threads: 128 # recommend to be equal to `ajet.rollout.max_env_worker` already_started: False # do not edit, used by `swarm` # the method to determine when to stop rollout in swarm mode. Options: @@ -53,6 +53,12 @@ ajet: # "rollout_until_finish_enough_non_dummy_tasks": # AgentJet will identify the **task_id** of each episode, and stop when it has collected [>= ajet.data.train_batch_size] unique & FINISHED & NON-DUMMY **task_id**. # (Hint: a **task_id** is considered "NON-DUMMY" at least one of **episodes** of **task_id** has **different** reward value.) + # "rollout_until_any_client_agree_sync_weight": + # AgentJet defers the stop decision to swarm clients: stop as soon as ANY active swarm client has called `SwarmClient.agree_sync_weight()`. + # (Hint: a swarm client becomes "active" once it has successfully `end_episode`'d a rewarded (non-abort) episode since the last weight sync, + # and falls off the active list if it does no chat-completion / begin_episode for 10 minutes.) + # "rollout_until_all_clients_agree_sync_weight": + # Like the above, but stop only when EVERY active swarm client has agreed (and there is at least one active client). swarm_mode_sample_collection_method: "rollout_until_finish_enough_tasks" data: @@ -61,6 +67,7 @@ ajet: # max number of tokens for response max_response_length: 15000 # how many tasks per training batch + # Note that this value is ignored when swarm_mode_sample_collection_method="rollout_until_all_clients_agree_sync_weight" train_batch_size: 32 # [Hint]: The final number of samples per update will be: N_{sample} = (data.train_batch_size * rollout.num_repeat * rollout.multi_turn.expected_steps) diff --git a/ajet/launcher.py b/ajet/launcher.py index c30ec2e5..4e0ca205 100644 --- a/ajet/launcher.py +++ b/ajet/launcher.py @@ -4,7 +4,7 @@ from dotenv import load_dotenv from loguru import logger -from ajet.utils.cleaner import fast_kill_by_keyword_bash +from ajet.utils.cleaner import AUTOKILL_KEYWORDS, fast_kill_by_keyword_bash from ajet.utils.config_utils import prepare_experiment_config from ajet.utils.launch_utils import ( execute_training_process, @@ -177,7 +177,7 @@ def main(): check_avail_gpu(min_free_ratio=0.95) if args.autokill: - args.kill = "ray|vllm|VLLM|python" + args.kill = AUTOKILL_KEYWORDS # Handle kill-keywords argument if provided if args.kill: diff --git a/ajet/schema/extended_msg.py b/ajet/schema/extended_msg.py index f97813d3..a7330906 100644 --- a/ajet/schema/extended_msg.py +++ b/ajet/schema/extended_msg.py @@ -20,7 +20,7 @@ "memory", "llm(do_not_train)", ] -DUMMY_MSG = [{"role": "assistant", "content": "dummy text"}] + def find_sublist_indices(large_list, small_list, reverse=False): @@ -77,6 +77,7 @@ def __init__( token_logprob_arr=[], name="", # preserved field, not used currently first_message=False, + before_last_query=True, # whether this message is before the last user query in the conversation, used for auto tokenization logic ): self.author = author self.role = role @@ -86,7 +87,6 @@ def __init__( self.token_begin_index = token_begin_index self.token_end_index = token_end_index self.invalid_log_prob_value = INVALID_LOG_PROB_VALUE - self._content_for_compare = "" self._info = "" self.tools = tools self.tool_calls = tool_calls @@ -101,11 +101,14 @@ def __init__( self.manual_loss_mask_override = [] self.lack_normal_eos = False + # text content to compare when timeline merging + self._text_content_for_compare = "" self.generate_content_for_compare(content = self.content) self.eos_token_id = tokenizer.eos_token_id if token_generator == "auto": + self.before_last_query = before_last_query self.token_arr = self.auto_tokenize( tokenizer=tokenizer, tools=tools, @@ -123,7 +126,7 @@ def auto_tokenize(self, tokenizer, tools): else: auto_tokenize_target:dict = { "role": self.role, - "content": self.content_for_compare, + "content": self.text_content_for_compare, } if self.tool_calls: auto_tokenize_target.update({"tool_calls": self.tool_calls}) @@ -136,11 +139,17 @@ def auto_tokenize(self, tokenizer, tools): return self.token_arr def auto_tokenize_non_first_message(self, tokenizer, tools): + if self.before_last_query: + # for example, this will remove the block for qwen3's chat template + dummy_msg = [{"role": "assistant", "content": "dummy text"}] + else: + dummy_msg = [{"role": "user", "content": "dummy text"}] + try: # completion_token_arr will contain generation_prompt header auto_tokenize_target:dict = { "role": self.role, - "content": self.content_for_compare, + "content": self.text_content_for_compare, } if self.tool_calls: auto_tokenize_target.update({"tool_calls": self.tool_calls}) @@ -148,18 +157,18 @@ def auto_tokenize_non_first_message(self, tokenizer, tools): auto_tokenize_target.update({"tool_call_id": self.tool_call_id}) text_frag_to = ajet_apply_chat_template( tokenizer=tokenizer, - conversation=DUMMY_MSG + [auto_tokenize_target], + conversation=dummy_msg + [auto_tokenize_target], tokenize=False, tools=tools, ) except Exception as e: raise ValueError( - f"Cannot tokenize {self.role} --- {self.content_for_compare}, \n\n Error: {e}" + f"Cannot tokenize {self.role} --- {self.text_content_for_compare}, \n\n Error: {e}" ) self.token_arr, _ = self.get_inc_simple( text_frag_from=ajet_apply_chat_template( tokenizer=tokenizer, - conversation=DUMMY_MSG, + conversation=dummy_msg, tokenize=False, tools=tools, ), @@ -169,12 +178,11 @@ def auto_tokenize_non_first_message(self, tokenizer, tools): return self.token_arr @property - def content_for_compare(self): - if self._content_for_compare == "": + def text_content_for_compare(self): + if self._text_content_for_compare == "": if not self.tool_calls: - logger.exception("content_for_compare is not set, or previous llm output is empty!") - # self._content_for_compare - return self._content_for_compare + logger.exception("text_content_for_compare is not set, or previous llm output is empty!") + return self._text_content_for_compare @property def need_training(self): @@ -186,7 +194,7 @@ def need_training(self): return self.author in NEED_TRAIN_AUTHORS def generate_content_for_compare(self, content): - self._content_for_compare = content + self._text_content_for_compare = content def get_loss_mask(self, blackout_token_combo): if self.need_training: @@ -314,21 +322,23 @@ def merge_tool_group(group, tokenizer): token_logprob_arr=msg0.token_logprob_arr, first_message=msg0.first_message, ) + # a dummy msg, not necessary, can be [] + dummy_msg = [{"role": "user", "content": "dummy text"}] # re-compute token_arr auto_tokenize_targets = [ - {"role": msg.role, "content": msg.content_for_compare} for msg in group + {"role": msg.role, "content": msg.text_content_for_compare} for msg in group ] merged.token_arr, _ = merged.get_inc_simple( text_frag_from=ajet_apply_chat_template( tokenizer=tokenizer, - conversation=DUMMY_MSG, + conversation=dummy_msg, tokenize=False, tools=merged.tools, add_generation_prompt=False, ), text_frag_to=ajet_apply_chat_template( tokenizer, - conversation=DUMMY_MSG + auto_tokenize_targets, + conversation=dummy_msg + auto_tokenize_targets, tokenize=False, tools=merged.tools, add_generation_prompt=False, diff --git a/ajet/swarm_cli.py b/ajet/swarm_cli.py index 54787386..6ccdce58 100644 --- a/ajet/swarm_cli.py +++ b/ajet/swarm_cli.py @@ -5,6 +5,7 @@ from dotenv import load_dotenv from loguru import logger +from ajet.utils.cleaner import AUTOKILL_KEYWORDS, fast_kill_by_keyword_bash from ajet.utils.config_utils import prepare_experiment_config from ajet.utils.launch_utils import ( dict_to_namespace, @@ -41,6 +42,21 @@ def start_swarm_server(env, config, port): def cmd_start(args): """Handle the 'start' subcommand.""" + if args.autokill: + args.kill = AUTOKILL_KEYWORDS + + if args.kill: + logger.info(f"Killing processes matching keywords: {args.kill}") + for keyword in args.kill.split("|"): + logger.info(f"Killing processes matching keyword: {keyword}") + killed_pids = fast_kill_by_keyword_bash(keyword) + if killed_pids: + logger.success( + f"Successfully killed processes with PIDs: {killed_pids}" + ) + else: + logger.warning(f"No processes found matching keyword: {keyword}") + # Use default config if not provided exp_base_dir = args.exp_dir or DEFAULT_DIR if not args.conf: @@ -126,6 +142,19 @@ def main(): required=False, help="Debug tags; enables Ray post-mortem and DEBUG_TAGS env", ) + parser_start.add_argument( + "--kill", + type=str, + default="", + required=False, + help="list of keywords for killing processes", + ) + parser_start.add_argument( + "--autokill", + action="store_true", + default=False, + help="Kill system processes (ray + vllm + python) that may block the current experiment", + ) parser_start.set_defaults(func=cmd_start) diff --git a/ajet/task_rollout/async_llm_bridge.py b/ajet/task_rollout/async_llm_bridge.py index 685458a9..99b4e655 100644 --- a/ajet/task_rollout/async_llm_bridge.py +++ b/ajet/task_rollout/async_llm_bridge.py @@ -83,6 +83,7 @@ async def llm_chat_verl( updated_sampling_params.update(custom_sampling_params) input_messages = copy.deepcopy(messages) + # the input (prompt) sequence as text prompt_text = ajet_apply_chat_template( tokenizer=self.tokenizer, conversation=input_messages, @@ -90,6 +91,7 @@ async def llm_chat_verl( add_generation_prompt=True, tokenize=False, ) + # the input (prompt) sequence as input_ids prompt_token_ids = self.tokenizer(prompt_text)["input_ids"] final_res: TokenOutput = await self.async_rollout_manager.generate( @@ -122,7 +124,7 @@ async def llm_chat_verl( ): parsed_tool_calls = self.tool_parser.extract_tool_calls(decoded_text, None) # type: ignore - parsed_tool_calls = parsed_tool_calls.model_dump() + parsed_tool_calls = parsed_tool_calls.model_dump(mode='json') model_called = parsed_tool_calls["tools_called"] if model_called: @@ -155,11 +157,13 @@ async def llm_chat_verl( "completion_tokens": len(token_array), # type: ignore "total_tokens": len(prompt_token_ids) + len(token_array), # type: ignore } - # from ajet import bp; bp("DECODE") + return { "role": "assistant", "request_id": request_id, "content": decoded_text, + "prompt_text": prompt_text, + "prompt_token_ids": prompt_token_ids, "tool_calls": tool_calls, "finish_reason": finish_reason, "usage": usage, @@ -327,7 +331,7 @@ async def chat_completion_request( episode_uuid: str, ): from openai.types.chat.chat_completion import ChatCompletion - req_as_dict = req.model_dump() + req_as_dict = req.model_dump(mode='json') # infer + process with context tracker llm_output = await self.run_infer( diff --git a/ajet/task_rollout/native_parallel_worker.py b/ajet/task_rollout/native_parallel_worker.py index b97e4fa2..58cd8951 100644 --- a/ajet/task_rollout/native_parallel_worker.py +++ b/ajet/task_rollout/native_parallel_worker.py @@ -26,8 +26,9 @@ from ajet.context_tracker.single_agent_tracking import SingleAgentContextTracker from ajet.tuner_lib.experimental.interchange_utils import ( http_change_engine_status, - http_update_rollout_pool_information, + http_update_rollout_pool_information_and_fetch_instruction, CurrentBatchRolloutPoolInformation, + SwarmClientInstruction, ) @@ -264,7 +265,7 @@ def rollout_static( def rollout_swarm( # noqa: C901 self, - tasks: List[Task], + tasks: List[Task], # this is dummy task list, the size is `ajet.data.train_batch_size` * `ajet.rollout.num_repeat` mode: Literal["sample", "validate"], epoch: str, allow_sample_num_change=True, @@ -281,7 +282,10 @@ def rollout_swarm( # noqa: C901 tracker_array: List[SingleAgentContextTracker] = [] rollout_n = self.rollout_n n_batch_task = len(tasks) - n_task = min(len(tasks), ceil(self.max_parallel / rollout_n)) + n_task = min( + len(tasks), # `ajet.data.train_batch_size` * `ajet.rollout.num_repeat` / `ajet.rollout.num_repeat` = `ajet.data.train_batch_size` + ceil(self.max_parallel / rollout_n) # `ajet.rollout.max_env_worker` / `ajet.rollout.num_repeat` + ) assert n_task > 0, f"n_task is not valid, n_task = min(len(tasks), self.max_parallel // rollout_n) = {n_task}" self.current_token_count_time = time.time() @@ -292,6 +296,14 @@ def rollout_swarm( # noqa: C901 completed_task_id_map_ct: Dict[str, List[SingleAgentContextTracker]] = IterationSafeDict() executor_lock = threading.Lock() + accept_client_control = ("client" in self.config.ajet.swarm_mode_sample_collection_method) + if accept_client_control: + # Latest active-client / agreed-sync-weight snapshot from the swarm server. Refreshed on every pool-information update; + # consumed by the `rollout_until_*_agree_sync_weight` stop conditions. + latest_swarm_client_instructions: Dict[str, SwarmClientInstruction | None] = {"swarm_clients": None} + else: + latest_swarm_client_instructions = None + # count tasks to see whether we have reach the finish line for next weight update def count_tasks(completed_task_id_map_ct): total_completed_episodes = 0 @@ -340,6 +352,20 @@ def enough_finished_task_stop_condition(completed_task_id_map_ct) -> bool: completed_task_id_map_ct.clear() return (total_completed_tasks >= n_batch_task) + def any_client_agree_sync_weight_stop_condition(completed_task_id_map_ct) -> bool: + # ajet.swarm_mode_sample_collection_method == "rollout_until_any_client_agree_sync_weight" + instr = latest_swarm_client_instructions["swarm_clients"] + if instr is None: + return False + return any(c.allowed_sync_weight for c in instr.active_clients) + + def all_clients_agree_sync_weight_stop_condition(completed_task_id_map_ct) -> bool: + # ajet.swarm_mode_sample_collection_method == "rollout_until_all_clients_agree_sync_weight" + instr = latest_swarm_client_instructions["swarm_clients"] + if instr is None or not instr.active_clients: + return False + return all(c.allowed_sync_weight for c in instr.active_clients) + def enough_non_dummy_task_stop_condition(completed_task_id_map_ct) -> bool: # ajet.swarm_mode_sample_collection_method == "rollout_until_finish_enough_non_dummy_tasks" counts = count_tasks(completed_task_id_map_ct) @@ -372,6 +398,10 @@ def enough_non_dummy_task_stop_condition(completed_task_id_map_ct) -> bool: stop_condition = enough_finished_task_stop_condition elif self.config.ajet.swarm_mode_sample_collection_method == "rollout_until_finish_enough_non_dummy_tasks": stop_condition = enough_non_dummy_task_stop_condition + elif self.config.ajet.swarm_mode_sample_collection_method == "rollout_until_any_client_agree_sync_weight": + stop_condition = any_client_agree_sync_weight_stop_condition + elif self.config.ajet.swarm_mode_sample_collection_method == "rollout_until_all_clients_agree_sync_weight": + stop_condition = all_clients_agree_sync_weight_stop_condition else: logger.error(f"Invalid swarm_mode_sample_collection_method: {self.config.ajet.swarm_mode_sample_collection_method}, fallback to default method: rollout_until_finish_enough_tasks") stop_condition = enough_finished_task_stop_condition @@ -442,9 +472,16 @@ def update_rollout_result_array_preview(observation_window, completed_task_id_ma buffer += f"Total completed tasks: {counts['total_completed_tasks']} (target {n_batch_task})\n" buffer += f"Total completed non-dummy tasks: {counts['total_completed_non_dummy_tasks']} (target {n_batch_task})\n" buffer += f"Current stop condition: {self.config.ajet.swarm_mode_sample_collection_method}\n" + if accept_client_control: + sc_inst = latest_swarm_client_instructions["swarm_clients"] + if sc_inst is not None: + n_active = len(sc_inst.active_clients) + n_agreed = sum(1 for c in sc_inst.active_clients if c.allowed_sync_weight) + buffer += f"Active clients: {n_active} (agreed: {n_agreed})\n" observation_window["info"][-1] = buffer - # Update rollout pool information via API + # Update rollout pool information via API and pull the latest + # active-client / agreed-sync-weight instruction from the server. pool_info = CurrentBatchRolloutPoolInformation( sample_collection_method=self.config.ajet.swarm_mode_sample_collection_method, completed_episodes=counts['total_completed_episodes'], @@ -457,7 +494,10 @@ def update_rollout_result_array_preview(observation_window, completed_task_id_ma completed_tasks_details=completed_tasks_details, completed_tasks_rewards=completed_tasks_rewards, ) - http_update_rollout_pool_information(self.config, pool_info) + if accept_client_control: + instruction = http_update_rollout_pool_information_and_fetch_instruction(self.config, pool_info) + if instruction is not None: + latest_swarm_client_instructions["swarm_clients"] = instruction return update_rollout_result_array_preview(observation_window, completed_task_id_map_ct) diff --git a/ajet/tokenizer/__init__.py b/ajet/tokenizer/__init__.py new file mode 100644 index 00000000..e3d55066 --- /dev/null +++ b/ajet/tokenizer/__init__.py @@ -0,0 +1,7 @@ +# Intentionally empty. Import directly from ``ajet.tokenizer.service``: +# +# from ajet.tokenizer.service import RemoteTokenizer, start_tokenizer_service +# +# Re-exporting here would cause ``python -m ajet.tokenizer.service`` to load +# the package eagerly before running the service as __main__, which trips a +# RuntimeWarning from runpy. diff --git a/ajet/tokenizer/service.py b/ajet/tokenizer/service.py new file mode 100644 index 00000000..78b8f2c4 --- /dev/null +++ b/ajet/tokenizer/service.py @@ -0,0 +1,352 @@ +"""Lightweight tokenizer cache service. + +A companion process holds an HF tokenizer and serves ``encode``, ``decode`` +and ``apply_chat_template`` calls over a ZMQ ``ipc://`` socket with an LRU +cache. The caller keeps a local tokenizer for everything else (attributes, +``__call__``, etc.) — only those three hot methods cross the wire. +""" + +from __future__ import annotations + +import argparse +import atexit +import os +import shutil +import signal +import subprocess +import sys +import tempfile +import threading +import time +from collections import OrderedDict +from typing import Any, Optional + +import msgpack +import zmq + + +_CACHE_OPS = ("encode", "decode", "apply_chat_template") +_DEFAULT_CACHE_SIZE = 4096 +_DEFAULT_RECV_TIMEOUT_MS = 120_000 + + +# --------------------------------------------------------------------------- +# Server +# --------------------------------------------------------------------------- + + +def _serve( + model_path: str, + ipc_path: str, + *, + trust_remote_code: bool, + cache_size: int, + ready_file: Optional[str], +) -> None: + from loguru import logger + from verl.utils import hf_tokenizer + + tokenizer = hf_tokenizer(model_path, trust_remote_code=trust_remote_code) + + cache: "OrderedDict[bytes, bytes]" = OrderedDict() + ctx = zmq.Context.instance() + sock = ctx.socket(zmq.REP) + sock.setsockopt(zmq.LINGER, 0) + sock.bind(f"ipc://{ipc_path}") + + if ready_file: + try: + with open(ready_file, "w") as fh: + fh.write(str(os.getpid())) + except OSError as exc: + logger.warning(f"failed to write ready file {ready_file}: {exc}") + + logger.info(f"Tokenizer cache service ready at ipc://{ipc_path} (pid={os.getpid()})") + + hits = misses = 0 + while True: + try: + raw = sock.recv() + except (zmq.ContextTerminated, KeyboardInterrupt): + break + + # The request bytes are a stable cache key — msgpack is deterministic + # for our payloads (lists/dicts of primitives), so identical calls + # produce identical raw frames. + if raw in cache: + cache.move_to_end(raw) + hits += 1 + sock.send(cache[raw]) + continue + + try: + req = msgpack.unpackb(raw, raw=False) + op = req.get("op") + args = req.get("args") or [] + kwargs = req.get("kwargs") or {} + except Exception as exc: + sock.send(msgpack.packb({"ok": False, "error": f"bad request: {exc}"}, use_bin_type=True)) + continue + + if op == "shutdown": + sock.send(msgpack.packb({"ok": True}, use_bin_type=True)) + break + if op == "stats": + total = hits + misses + payload = { + "hits": hits, + "misses": misses, + "hit_rate": hits / total if total else 0.0, + "size": len(cache), + "max_size": cache_size, + } + sock.send(msgpack.packb({"ok": True, "result": payload}, use_bin_type=True)) + continue + + try: + if op == "encode": + result = tokenizer.encode(*args, **kwargs) + elif op == "decode": + result = tokenizer.decode(*args, **kwargs) + elif op == "apply_chat_template": + result = tokenizer.apply_chat_template(*args, **kwargs) + else: + sock.send( + msgpack.packb({"ok": False, "error": f"unknown op {op!r}"}, use_bin_type=True) + ) + continue + except Exception as exc: + import traceback + + err = f"{type(exc).__name__}: {exc}\n{traceback.format_exc()}" + sock.send(msgpack.packb({"ok": False, "error": err}, use_bin_type=True)) + continue + + misses += 1 + payload = msgpack.packb({"ok": True, "result": result}, use_bin_type=True) + if op in _CACHE_OPS: + cache[raw] = payload + while len(cache) > cache_size: + cache.popitem(last=False) + sock.send(payload) + + sock.close(linger=0) + ctx.term() + logger.info("Tokenizer cache service stopped") + + +# --------------------------------------------------------------------------- +# Client wrapper +# --------------------------------------------------------------------------- + + +class CachedTokenizer: + """Wraps a local HF tokenizer; routes ``encode`` / ``decode`` / + ``apply_chat_template`` through a cache process, falls back to the local + instance for every other attribute.""" + + def __init__( + self, + local_tokenizer, + ipc_path: str, + *, + recv_timeout_ms: int = _DEFAULT_RECV_TIMEOUT_MS, + ): + self._local = local_tokenizer + self._ipc_path = ipc_path + self._recv_timeout_ms = recv_timeout_ms + self._tls = threading.local() # one REQ socket per thread + + def _socket(self) -> zmq.Socket: + sock = getattr(self._tls, "sock", None) + if sock is not None: + return sock + ctx = zmq.Context.instance() + sock = ctx.socket(zmq.REQ) + sock.setsockopt(zmq.LINGER, 0) + sock.setsockopt(zmq.RCVTIMEO, self._recv_timeout_ms) + sock.connect(f"ipc://{self._ipc_path}") + self._tls.sock = sock + return sock + + def _call(self, op: str, args, kwargs): + sock = self._socket() + sock.send(msgpack.packb({"op": op, "args": list(args), "kwargs": kwargs}, use_bin_type=True)) + try: + raw = sock.recv() + except zmq.Again as exc: + # REQ socket is stuck after a timeout; recreate on next call. + try: + sock.close(linger=0) + except zmq.ZMQError: + pass + self._tls.sock = None + raise RuntimeError(f"tokenizer service timeout (op={op})") from exc + + resp = msgpack.unpackb(raw, raw=False) + if not resp.get("ok"): + raise RuntimeError(f"tokenizer service error (op={op}): {resp.get('error')}") + return resp.get("result") + + # -- the three cached methods ---------------------------------------- + + def encode(self, text, **kwargs): + return self._call("encode", [text], kwargs) + + def decode(self, ids, skip_special_tokens: bool = False, **kwargs): + # HF accepts both ``decode(123)`` and ``decode([123, 456])``. Normalize + # to a list so the cache key is stable and msgpack-serializable. + if isinstance(ids, int): + ids = [ids] + elif hasattr(ids, "tolist"): + ids = ids.tolist() + if isinstance(ids, int): # 0-d tensor + ids = [ids] + else: + ids = list(ids) + return self._call("decode", [ids], {"skip_special_tokens": skip_special_tokens, **kwargs}) + + def apply_chat_template( + self, + conversation, + tools=None, + *, + add_generation_prompt: bool = False, + tokenize: bool = True, + **kwargs, + ): + kw = {"add_generation_prompt": add_generation_prompt, "tokenize": tokenize, **kwargs} + if tools is not None: + kw["tools"] = tools + return self._call("apply_chat_template", [conversation], kw) + + # -- everything else: defer to local tokenizer ----------------------- + + def __getattr__(self, name): + if name.startswith("_"): + raise AttributeError(name) + return getattr(self._local, name) + + def __call__(self, *args, **kwargs): + return self._local(*args, **kwargs) + + def stats(self) -> dict: + return self._call("stats", [], {}) + + def shutdown_service(self) -> None: + try: + self._call("shutdown", [], {}) + except Exception: + pass + + +# --------------------------------------------------------------------------- +# Launcher +# --------------------------------------------------------------------------- + + +def start_tokenizer_service( + local_tokenizer, + model_path: str, + *, + trust_remote_code: bool = False, + ipc_path: Optional[str] = None, + cache_size: int = _DEFAULT_CACHE_SIZE, + ready_timeout_s: float = 180.0, + recv_timeout_ms: int = _DEFAULT_RECV_TIMEOUT_MS, +) -> CachedTokenizer: + """Spawn the cache service subprocess and return a wrapper around the + caller-supplied local tokenizer. + """ + from loguru import logger + + cleanup_dir: Optional[str] = None + if ipc_path is None: + cleanup_dir = tempfile.mkdtemp(prefix="ajet-tok-") + ipc_path = os.path.join(cleanup_dir, "sock") + + ready_file = ipc_path + ".ready" + if os.path.exists(ready_file): + try: + os.remove(ready_file) + except OSError: + pass + + cmd = [ + sys.executable, "-m", "ajet.tokenizer.service", "serve", + "--model-path", model_path, + "--ipc-path", ipc_path, + "--cache-size", str(cache_size), + "--ready-file", ready_file, + ] + if trust_remote_code: + cmd.append("--trust-remote-code") + + logger.info(f"Launching tokenizer cache service: {' '.join(cmd)}") + proc = subprocess.Popen(cmd, env=os.environ.copy()) + + def _cleanup() -> None: + if proc.poll() is None: + try: + proc.terminate() + try: + proc.wait(timeout=5) + except subprocess.TimeoutExpired: + proc.kill() + proc.wait(timeout=5) + except Exception: + pass + if cleanup_dir and os.path.isdir(cleanup_dir): + shutil.rmtree(cleanup_dir, ignore_errors=True) + + atexit.register(_cleanup) + + deadline = time.time() + ready_timeout_s + while time.time() < deadline: + if proc.poll() is not None: + raise RuntimeError(f"tokenizer service exited with code {proc.returncode}") + if os.path.exists(ready_file): + break + time.sleep(0.1) + else: + _cleanup() + raise RuntimeError( + f"tokenizer service at ipc://{ipc_path} did not become ready within {ready_timeout_s}s" + ) + + return CachedTokenizer(local_tokenizer, ipc_path, recv_timeout_ms=recv_timeout_ms) + + +# --------------------------------------------------------------------------- +# CLI entry point (used by the subprocess) +# --------------------------------------------------------------------------- + + +def _main(argv: Optional[list] = None) -> None: + parser = argparse.ArgumentParser(description="ZMQ IPC tokenizer cache service") + sub = parser.add_subparsers(dest="cmd", required=True) + + p_serve = sub.add_parser("serve") + p_serve.add_argument("--model-path", required=True) + p_serve.add_argument("--ipc-path", required=True) + p_serve.add_argument("--cache-size", type=int, default=_DEFAULT_CACHE_SIZE) + p_serve.add_argument("--ready-file", default=None) + p_serve.add_argument("--trust-remote-code", action="store_true") + + args = parser.parse_args(argv) + if args.cmd == "serve": + def _on_term(*_): + os._exit(0) + + signal.signal(signal.SIGTERM, _on_term) + _serve( + model_path=args.model_path, + ipc_path=args.ipc_path, + trust_remote_code=args.trust_remote_code, + cache_size=args.cache_size, + ready_file=args.ready_file, + ) + + +if __name__ == "__main__": + _main() diff --git a/ajet/tuner_lib/experimental/interchange_utils.py b/ajet/tuner_lib/experimental/interchange_utils.py index 8fcb92f0..01e82f86 100644 --- a/ajet/tuner_lib/experimental/interchange_utils.py +++ b/ajet/tuner_lib/experimental/interchange_utils.py @@ -64,6 +64,7 @@ class EndEpisodeRequest(BaseModel): episode_uuid: str workflow_output: WorkflowOutput task_id: str + declare_client_active: bool = True class EndEpisodeResponse(BaseModel): success: bool @@ -118,6 +119,138 @@ class VerboseLogsResponse(BaseModel): entries: List[VerboseLogEntry] = [] +class AgreeSyncWeightRequest(BaseModel): + client_uuid: str + + +class ActiveSwarmClient(BaseModel): + """Server-tracked record for one active swarm client. + + A swarm client enters this list once it has successfully `end_episode`'d + a rewarded (non-abort) episode since the last weight sync, and falls off + after `CLIENT_ACTIVE_TIMEOUT` seconds of no chat-completion / + `begin_episode` activity. The whole list is reset whenever the engine + leaves ROLLING/ROLLING_POST. + + Used both as the swarm server's authoritative storage (single + `shared_mem_dict["active_swarm_clients"]: List[ActiveSwarmClient]` key) + and as the wire payload sent back to the trainer in + `SwarmClientInstruction`. Add future per-client signals (e.g. + `requested_pause`, custom metrics) here -- pydantic field defaults keep + the wire format backwards-compatible across server/trainer versions. + + Fields: + client_uuid: the client_uuid as generated in `SwarmClient.__init__`. + last_activity_at: unix timestamp of the most recent chat-completion, + `begin_episode`, or `end_episode` from this client. Used by the + server's expiry sweep. + allowed_sync_weight: True iff this client has explicitly agreed to + the next weight sync via `SwarmClient.agree_sync_weight()`. + """ + client_uuid: str + last_activity_at: float + allowed_sync_weight: bool = False + + +class SwarmClientInstruction(BaseModel): + """Server -> trainer instruction returned alongside pool-info updates. + + Fields: + active_clients: list of `ActiveSwarmClient` records, one per + currently active client. + + Example wire payload: + ```json + { + "active_clients": [ + {"client_uuid": "9f3c-...-aaaa", "last_activity_at": 1746513900.1, "allowed_sync_weight": true}, + {"client_uuid": "9f3c-...-bbbb", "last_activity_at": 1746513912.4, "allowed_sync_weight": false}, + {"client_uuid": "9f3c-...-cccc", "last_activity_at": 1746513918.7, "allowed_sync_weight": false} + ] + } + ``` + + Example trainer-side use (matches DynamicRolloutManager.rollout_swarm): + ```python + # rollout_until_any_client_agree_sync_weight + if any(c.allowed_sync_weight for c in instr.active_clients): + stop() + + # rollout_until_all_clients_agree_sync_weight + if instr.active_clients and all( + c.allowed_sync_weight for c in instr.active_clients + ): + stop() + ``` + + For the payload above: + - "any" stop-condition evaluates True (one client agreed). + - "all" stop-condition evaluates False (two of three not yet agreed). + """ + active_clients: List[ActiveSwarmClient] = [] + + +# Active-client tracking timeout (seconds): a client falls off the active list +# if it has done no chat-completion or begin_episode call within this window. +CLIENT_ACTIVE_TIMEOUT = 10 * 60 + + +# -------------------------------------------------------------------- +# active-client tracking helpers +# -------------------------------------------------------------------- +# All active-client state lives behind a single shared_mem_dict key: +# "active_swarm_clients": List[ActiveSwarmClient] +# (See `ActiveSwarmClient` for field semantics and lifecycle.) The helpers +# below are imported by the swarm server's FastAPI routes and by the +# OAI-mode chat-completion handler. + + +def _refresh_client_activity(client_uuid: str, shared_mem_dict) -> None: + """If client is in the active list, refresh its last-activity timestamp. + + Called on chat-completion and begin_episode (claim_episode). Does NOT + add the client to the list -- only end_episode (success, non-abort) does. + """ + if not client_uuid: + return + clients: List[ActiveSwarmClient] = list(shared_mem_dict.get("active_swarm_clients", [])) + for i, c in enumerate(clients): + if c.client_uuid == client_uuid: + clients[i] = c.model_copy(update={"last_activity_at": time.time()}) + shared_mem_dict["active_swarm_clients"] = clients + return + + +def _register_active_client(client_uuid: str, shared_mem_dict) -> None: + """Add client to the active list (idempotent) and refresh its timestamp.""" + if not client_uuid: + return + clients: List[ActiveSwarmClient] = list(shared_mem_dict.get("active_swarm_clients", [])) + now = time.time() + for i, c in enumerate(clients): + if c.client_uuid == client_uuid: + clients[i] = c.model_copy(update={"last_activity_at": now}) + shared_mem_dict["active_swarm_clients"] = clients + return + clients.append(ActiveSwarmClient(client_uuid=client_uuid, last_activity_at=now)) + shared_mem_dict["active_swarm_clients"] = clients + + +def _expire_inactive_clients(shared_mem_dict) -> None: + """Drop clients whose last activity is older than CLIENT_ACTIVE_TIMEOUT.""" + now = time.time() + clients: List[ActiveSwarmClient] = list(shared_mem_dict.get("active_swarm_clients", [])) + if not clients: + return + kept = [c for c in clients if (now - c.last_activity_at) <= CLIENT_ACTIVE_TIMEOUT] + if len(kept) != len(clients): + shared_mem_dict["active_swarm_clients"] = kept + + +def _reset_active_client_tracking(shared_mem_dict) -> None: + """Clear all active-client state.""" + shared_mem_dict["active_swarm_clients"] = [] + DEBUG = False # DEBUG = True @@ -233,24 +366,34 @@ def http_push_verbose_log(message: str, tag: str = "", config=None): logger.warning(f"Failed to push verbose log: {e}") -def http_update_rollout_pool_information(config, pool_info: CurrentBatchRolloutPoolInformation): +def http_update_rollout_pool_information_and_fetch_instruction( + config, pool_info: CurrentBatchRolloutPoolInformation +) -> SwarmClientInstruction | None: """ - Update the rollout pool information on the interchange server. + Update the rollout pool information on the interchange server, and fetch + the swarm server's view of currently-active clients and their + agree-to-sync-weight state. Args: config: The configuration object pool_info: CurrentBatchRolloutPoolInformation object with rollout statistics + + Returns: + SwarmClientInstruction with `active_clients` (List[ActiveSwarmClient]), + or None if the request failed. """ try: resp = httpx.post( - f"{get_interchange_server_url(config)}/update_current_batch_rollout_pool_information", + f"{get_interchange_server_url(config)}/update_current_batch_rollout_pool_information_and_fetch_instruction", json=pool_info.model_dump(), timeout=5 ) resp.raise_for_status() + return SwarmClientInstruction.model_validate(resp.json()) except Exception as e: if DEBUG: logger.warning(f"Failed to update rollout pool information: {e}") + return None def get_zmq_socket(config, episode_uuid: str, tag: str = ""): diff --git a/ajet/tuner_lib/experimental/oai_model_server.py b/ajet/tuner_lib/experimental/oai_model_server.py index 3d3a3091..fa823af4 100644 --- a/ajet/tuner_lib/experimental/oai_model_server.py +++ b/ajet/tuner_lib/experimental/oai_model_server.py @@ -301,6 +301,7 @@ async def chat_completions(request: Request, authorization: str = Header(None)): # enable_swarm_mode if enable_swarm_mode: from ajet.tuner_lib.experimental.swarm_server import ep_key + from ajet.tuner_lib.experimental.interchange_utils import _refresh_client_activity assert shared_mem_dict is not None assert shared_mem_dict_lock is not None @@ -319,6 +320,9 @@ async def chat_completions(request: Request, authorization: str = Header(None)): shared_mem_dict[ep_key(episode_uuid)] = es if es.episode_type == "eval": preserve_sampling_params = True + # chat-completion counts as activity for keeping the owning client + # in the swarm-server active list (no-op if it's not active yet). + _refresh_client_activity(es.client_uuid, shared_mem_dict) # For streaming, we process as non-streaming but return in streaming format original_stream = new_req.stream diff --git a/ajet/tuner_lib/experimental/swarm_client.py b/ajet/tuner_lib/experimental/swarm_client.py index 81de62d7..791e53d7 100644 --- a/ajet/tuner_lib/experimental/swarm_client.py +++ b/ajet/tuner_lib/experimental/swarm_client.py @@ -24,6 +24,8 @@ EpisodeStatus, EpisodeBufferResponse, SwarmThrottlePolicy, + AgreeSyncWeightRequest, + BoolResponse, ) # general http timeout @@ -33,8 +35,24 @@ START_EPISODE_RETRY_DELAY = 15 TROTTLE_EPISODE_RETRY_DELAY = 2 WAIT_MORE_AVAIL_EPISODE_RETRY_DELAY = 2 +# agree_sync_weight retry policy. The call must succeed -- a dropped +# agreement can stall the trainer's stop condition. Retries cover both +# transport errors and server-side rejection (e.g. when a just-completed +# end_episode hasn't yet propagated to the server's active list). +AGREE_SYNC_WEIGHT_MAX_RETRIES = 60 +AGREE_SYNC_WEIGHT_RETRY_DELAY = 2.0 +DELAY_AFTER_AGREE_SYNC_WEIGHT = 30 def raise_for_status_with_detail(resp): + """ + Raise an exception with detailed error information if the response indicates an error. + + Args: + resp: The httpx response object to check. + + Raises: + RuntimeError: If the response status code indicates an error. + """ try: resp.raise_for_status() except httpx.HTTPStatusError as e: @@ -64,8 +82,20 @@ class SwarmClientBase(object): "broken pipe", "disconnected", "connection reset", "connection closed", "connection aborted", "bad file descriptor", ) + # Force-refresh the http client after this many consecutive poll failures, + # even if none of the error messages match REFRESH_TRIGGER_KEYWORDS. Covers + # cases where httpx wedges in ways our keyword heuristic can't detect + # (HTTP/2 protocol stalls, sticky timeouts, stale pools). + POLL_FORCE_REFRESH_AFTER = 3 def __init__(self, server_url: str, verbose: bool = True): + """ + Initialize the SwarmClientBase. + + Args: + server_url: The URL of the swarm server. + verbose: If True, enable verbose logging output. + """ self.server_url = server_url self.verbose = verbose self.client_uuid = str(uuid.uuid4()) @@ -80,6 +110,10 @@ def __init__(self, server_url: str, verbose: bool = True): self._engine_status_ready = threading.Event() self._engine_status_last_error_log_time = 0.0 self._engine_status_poll_interval = self.SLOW_POLL + # consecutive failures since the last successful poll. Used to force an + # http client refresh when the keyword-based heuristic in + # `_should_refresh_client_on_error` misses a wedged connection. + self._engine_status_consecutive_failures = 0 # fast-poll window: True for FAST_POLL_WINDOW seconds after each get_engine_status() call self._high_freq_update_status = False @@ -116,20 +150,36 @@ def logger_info(self, message): # ---- http client -------------------------------------------------- def _refresh_http_client(self): - """Close the existing http client and create a fresh one.""" + """Close the existing http client and create a fresh one. + + HTTP/1.1 only on purpose: swarm endpoints are small, low-frequency + polls/heartbeats, so multiplexing buys nothing — and HTTP/2 has a class + of stall failures (flow-control deadlock, GOAWAY mishandling, ping + timeouts, HPACK desync, server-restart-behind-LB) where the TCP + connection stays "alive" but every stream hangs without ever raising + connection-reset/broken-pipe. That makes the keyword-based refresh + heuristic miss them. Plain HTTP/1.1 fails loudly on the same + scenarios, which our refresh logic can detect. + """ with self._http_client_lock: try: self._http_client.close() except Exception: pass - try: - self._http_client = httpx.Client(timeout=GENERAL_TIMEOUT, http2=True) - except Exception: - self._http_client = httpx.Client(timeout=GENERAL_TIMEOUT, http2=False) + self._http_client = httpx.Client(timeout=GENERAL_TIMEOUT, http2=False) logger.warning("swarm client httpx client refreshed.") return self._http_client def _should_refresh_client_on_error(self, error: Exception) -> bool: + """ + Check if the HTTP client should be refreshed based on the error message. + + Args: + error: The exception that occurred during an HTTP request. + + Returns: + True if the error message contains keywords indicating a connection issue. + """ msg = str(error).lower() return any(k in msg for k in self.REFRESH_TRIGGER_KEYWORDS) @@ -141,6 +191,12 @@ def add_entering_weight_sync_callback(self, callback): self._entering_weight_sync_callbacks.append(callback) def _observe_engine_status(self, new_status: str): + """ + Observe engine status changes and fire callbacks on transitions. + + Args: + new_status: The new engine status string. + """ with self._engine_status_callback_lock: fresh_entry = ( new_status == "ENGINE.WEIGHT_SYNCING" @@ -172,15 +228,27 @@ def get_global_step(self) -> int: return status_json.get("global_step", 0) def _engine_status_poll_loop(self): - """Background thread: fetch engine status at _engine_status_poll_interval.""" + """Background thread: fetch engine status at _engine_status_poll_interval. + + Top-level try/except is a final safety net: if it dies the cache freezes + forever and only a process restart recovers — exactly the failure mode + this whole module is trying to prevent. + """ while not self._engine_status_poll_stop.is_set(): - if self._high_freq_update_status and time.time() >= self._high_freq_update_expiry: - self._high_freq_update_status = False - self._engine_status_poll_interval = self.SLOW_POLL - self._poll_engine_status_once() + try: + if self._high_freq_update_status and time.time() >= self._high_freq_update_expiry: + self._high_freq_update_status = False + self._engine_status_poll_interval = self.SLOW_POLL + self._poll_engine_status_once() + except Exception as e: + now = time.time() + if now - self._engine_status_last_error_log_time > 30: + logger.exception(f"Unexpected error in engine_status poll loop (continuing): {e}") + self._engine_status_last_error_log_time = now self._engine_status_poll_stop.wait(self._engine_status_poll_interval) def _poll_engine_status_once(self): + """Fetch engine status from the server once and update the cache.""" try: resp = self._http_client.get(f"{self.server_url}/get_engine_status", timeout=10) raise_for_status_with_detail(resp) @@ -190,17 +258,34 @@ def _poll_engine_status_once(self): logger.warning(f"get_engine_status: {resp_json}") self._engine_status_cache = (status, resp_json) self._engine_status_ready.set() + self._engine_status_consecutive_failures = 0 self._observe_engine_status(status) except Exception as e: - if self._should_refresh_client_on_error(e): - self._refresh_http_client() + self._engine_status_consecutive_failures += 1 + # Refresh on either: (a) a known-transient error pattern, or + # (b) sustained failure even if the error doesn't match — httpx can + # wedge in ways the keyword heuristic doesn't catch, and without + # this the same broken connection keeps failing forever and the + # cached status stays stale until the process is restarted. + try: + if ( + self._should_refresh_client_on_error(e) + or self._engine_status_consecutive_failures >= self.POLL_FORCE_REFRESH_AFTER + ): + self._refresh_http_client() + self._engine_status_consecutive_failures = 0 + except Exception as refresh_err: + logger.error(f"engine_status poll: http client refresh failed: {refresh_err}") if self._engine_status_cache is None: # unblock waiters on the very first call when the server is unreachable self._engine_status_cache = ("ENGINE.CANNOT_CONNECT", {}) self._engine_status_ready.set() now = time.time() if now - self._engine_status_last_error_log_time > 30: - logger.error(f"Error getting engine status in poll loop: {e}") + logger.error( + f"Error getting engine status in poll loop " + f"(consecutive failures: {self._engine_status_consecutive_failures}): {e}" + ) self._engine_status_last_error_log_time = now def _wait_until_status_change_to(self, desired_status="ENGINE.ROLLING", verbose=True, timeout=3600): @@ -246,8 +331,16 @@ def _wait_until_status_change_to(self, desired_status="ENGINE.ROLLING", verbose= class SwarmClient(SwarmClientBase): + """HTTP client for interacting with the Swarm server for distributed RL training.""" def __init__(self, server_url: str, verbose: bool = True): + """ + Initialize the SwarmClient. + + Args: + server_url: The URL of the swarm server. + verbose: If True, enable verbose logging output. + """ super().__init__(server_url=server_url, verbose=verbose) self.previous_warning_time = 0 self.record_episode_expire_time = {} @@ -259,7 +352,7 @@ def __init__(self, server_url: str, verbose: bool = True): self._recent_seen_tasks = [] def _clean_up_expired_records(self): - # remove records that have expired and expired at least CLEAN_RECORD_TIMEOUT seconds ago + """Remove episode records that have expired beyond CLEAN_RECORD_TIMEOUT seconds.""" current_time = time.time() expired_episodes = [ episode_uuid for episode_uuid, expire_time in self.record_episode_expire_time.items() @@ -346,6 +439,14 @@ def _check_throttle_policy(self, throttle_policy: SwarmThrottlePolicy, pool_info return False, "" def _remember_seen_task(self, task_id: str, batch_size, num_repeat): + """ + Record a task_id as recently seen for throttle policy tracking. + + Args: + task_id: The task ID to remember. + batch_size: Expected batch size, used to calculate buffer limit. + num_repeat: Expected number of repeats per task, used to calculate buffer limit. + """ MAX_SEEN_TASK_BUFFER_SIZE = batch_size*num_repeat*3 # keep buffer size manageable, can be tuned if task_id not in self._recent_seen_tasks: self._recent_seen_tasks.append(task_id) @@ -353,6 +454,16 @@ def _remember_seen_task(self, task_id: str, batch_size, num_repeat): self._recent_seen_tasks = self._recent_seen_tasks[-MAX_SEEN_TASK_BUFFER_SIZE:] def _should_throttle(self, throttle_policy: SwarmThrottlePolicy, pool_info: CurrentBatchRolloutPoolInformation) -> bool: + """ + Determine if the client should throttle based on the throttle policy. + + Args: + throttle_policy: The throttle policy configuration. + pool_info: Current batch rollout pool information from the server. + + Returns: + True if the client should throttle and delay starting a new episode. + """ should_throttle, throttle_reason = self._check_throttle_policy(throttle_policy, pool_info) if not should_throttle: # direct start this episode @@ -375,6 +486,20 @@ def begin_episode(self, discard_episode_timeout=240, episode_type="train", throt return self._begin_episode_auto_retry(discard_episode_timeout, episode_type, throttle_policy) def _begin_episode_auto_retry(self, discard_episode_timeout=240, episode_type="train", throttle_policy: SwarmThrottlePolicy|None = None) -> Tuple[str, OpenaiBaseUrlAndApiKey]: + """ + Internal method to claim an episode with automatic retry logic. + + Args: + discard_episode_timeout: Idle timeout in seconds before the server discards the episode. + episode_type: Type of episode, either "train" or "eval". + throttle_policy: Optional throttle policy for task distribution control. + + Returns: + A tuple of (episode_uuid, OpenaiBaseUrlAndApiKey). + + Raises: + SwarmServerOfflineError: If the server goes offline during the operation. + """ # max_episode_time: when an episode has **lasted** for more than X seconds, it will be terminated **locally** by client (call `end_episode` will be re-route to `abort_episode`) max_episode_time = 8*discard_episode_timeout status, status_json = self.get_engine_status() # warm up connection and log the status @@ -479,8 +604,21 @@ def _begin_episode_auto_retry(self, discard_episode_timeout=240, episode_type="t if self._begin_episode_lock.locked(): self._begin_episode_lock.release() - def end_episode(self, task:Task, episode_uuid: str, workflow_output: WorkflowOutput): + def end_episode(self, task:Task, episode_uuid: str, workflow_output: WorkflowOutput, declare_client_active: bool = True): + """ + End an episode and submit the workflow output to the server. + Args: + task: The task associated with this episode. + episode_uuid: The UUID of the episode to end. + workflow_output: The workflow output containing reward and metadata. + declare_client_active: If True, register this client as active on the server. + This is only useful when you select `rollout_until_all_clients_agree_sync_weight`, + because in this case the server has to know how many client nodes are active. + + Raises: + RuntimeError: If the server fails to end the episode. + """ if not episode_uuid: logger.error("No episode to end.") return @@ -505,7 +643,8 @@ def end_episode(self, task:Task, episode_uuid: str, workflow_output: WorkflowOut client_uuid=self.client_uuid, episode_uuid=episode_uuid, workflow_output=workflow_output, - task_id=task_id + task_id=task_id, + declare_client_active=declare_client_active ) resp = self._http_client.post( @@ -532,7 +671,16 @@ def end_episode(self, task:Task, episode_uuid: str, workflow_output: WorkflowOut raise RuntimeError(f"Failed to end episode {episode_uuid}") - def abort_episode(self, episode_uuid: str): + def abort_episode(self, episode_uuid: str, declare_client_active: bool = True): + """ + Abort an episode without submitting a valid workflow output. + + Args: + episode_uuid: The UUID of the episode to abort. + declare_client_active: If True, register this client as active on the server. + This is only useful when you select `rollout_until_all_clients_agree_sync_weight`, + because in this case the server has to know how many client nodes are active. + """ if not episode_uuid: logger.error("No episode to end.") return @@ -543,7 +691,8 @@ def abort_episode(self, episode_uuid: str): client_uuid=self.client_uuid, episode_uuid=episode_uuid, workflow_output=workflow_output, - task_id="" + task_id="", + declare_client_active=declare_client_active ) resp = self._http_client.post( @@ -624,6 +773,15 @@ def start_engine(self): def can_continue_episode(self, episode_uuid: str) -> bool: + """ + Check if an episode can continue (still claimed and engine is rolling). + + Args: + episode_uuid: The UUID of the episode to check. + + Returns: + True if the episode can continue, False otherwise. + """ if not episode_uuid: return False try: @@ -646,6 +804,12 @@ def can_continue_episode(self, episode_uuid: str) -> bool: return False def get_episode_buffer(self) -> List[EpisodeStatus]: + """ + Get the current episode buffer from the server. + + Returns: + A list of EpisodeStatus objects representing all active episodes. + """ try: resp = self._http_client.post( f"{self.server_url}/get_episode_buffer", @@ -661,7 +825,7 @@ def get_episode_buffer(self) -> List[EpisodeStatus]: logger.error(f"Error getting episode buffer: {e}") return [] - def auto_sync_train_config_and_start_engine(self, agent_jet_job: AgentJetJob, force_restart=False): + def auto_sync_train_config_and_start_engine(self, agent_jet_job: AgentJetJob, force_restart=False, _retry_once=True) -> None: """ Automatically sync training configuration and start the engine if needed. This checks the current engine status and performs actions accordingly. @@ -694,6 +858,9 @@ def auto_sync_train_config_and_start_engine(self, agent_jet_job: AgentJetJob, fo self.logger_info("Engine is already ROLLING. No action needed.") elif current_status in ["ENGINE.CANNOT_CONNECT"]: logger.error("Unable to connect to swarm server.") + if _retry_once: + time.sleep(16) + return self.auto_sync_train_config_and_start_engine(agent_jet_job, force_restart=force_restart, _retry_once=False) raise RuntimeError(f"Unable to connect to swarm server.") elif current_status in ["ENGINE.BOOTING", "ENGINE.WEIGHT_SYNCING"]: self.logger_info(f"Engine is {current_status}. Waiting until it becomes ROLLING...") @@ -767,6 +934,76 @@ def server_experiment_dir(self) -> str: except Exception as e: return "saved_experiments" + def agree_sync_weight(self) -> bool: + """Notify the swarm server that this client agrees to a weight sync. + + The server only accepts the agreement if this client is in its + active-client list (i.e. has end_episode'd at least one rewarded + episode since the last sync). Used together with the + `rollout_until_any_client_agree_sync_weight` / + `rollout_until_all_clients_agree_sync_weight` stop conditions so the + client can decide for itself when its current batch is "good enough". + + Important: this call retries on failure. A dropped agreement can + stall the trainer indefinitely (e.g. under "all clients agree"), and + the most common rejection -- "client not yet in active list" -- + clears itself once the just-finished end_episode propagates. Only + gives up after AGREE_SYNC_WEIGHT_MAX_RETRIES attempts, or if the + engine has left ROLLING/ROLLING_POST (the agreement would be wiped + by the server-side reset anyway). + + Returns: True if the agreement was registered, False after + exhausting retries (or after the engine left rolling state). + """ + last_failure = "" + for attempt in range(1, AGREE_SYNC_WEIGHT_MAX_RETRIES + 1): + engine_status, _ = self.get_engine_status() + if engine_status not in ("ENGINE.ROLLING", "ENGINE.ROLLING_POST"): + logger.warning( + f"agree_sync_weight: engine is {engine_status}, abandoning " + f"agreement (would be reset by server-side cleanup anyway)." + ) + return False + try: + req_obj = AgreeSyncWeightRequest(client_uuid=self.client_uuid) + resp = self._http_client.post( + f"{self.server_url}/agree_sync_weight", + json=req_obj.model_dump(), + timeout=10, + ) + raise_for_status_with_detail(resp) + data = BoolResponse.model_validate(resp.json()) + if data.success: + if self.verbose: + self.logger_info( + f"agree_sync_weight: registered with server " + f"(attempt {attempt})" + ) + # time.sleep(DELAY_AFTER_AGREE_SYNC_WEIGHT) + self._wait_until_status_change_to(desired_status="ENGINE.ROLLING_POST") + return True + last_failure = data.failure_reason + logger.warning( + f"agree_sync_weight rejected (attempt " + f"{attempt}/{AGREE_SYNC_WEIGHT_MAX_RETRIES}): " + f"{data.failure_reason}. Retrying in " + f"{AGREE_SYNC_WEIGHT_RETRY_DELAY}s..." + ) + except Exception as e: + last_failure = str(e) + if self._should_refresh_client_on_error(e): + self._refresh_http_client() + logger.error( + f"agree_sync_weight errored (attempt " + f"{attempt}/{AGREE_SYNC_WEIGHT_MAX_RETRIES}): {e}. Retrying..." + ) + time.sleep(AGREE_SYNC_WEIGHT_RETRY_DELAY) + logger.error( + f"agree_sync_weight: gave up after {AGREE_SYNC_WEIGHT_MAX_RETRIES} " + f"attempts. Last failure: {last_failure}" + ) + return False + def get_rollout_stat(self) -> CurrentBatchRolloutPoolInformation: """ Get the current batch rollout pool information from the Swarm server. @@ -813,6 +1050,16 @@ def print_rollout_stat(self): pass def auto_train_with_dataset(dataset, swarm_worker: SwarmClient, execute_agent, local_grpo_n=2, remote_batch_size=8): + """ + Automatically train with a dataset using the swarm worker. + + Args: + dataset: The dataset providing training tasks via generate_training_tasks(). + swarm_worker: The SwarmClient instance for communication with the server. + execute_agent: A callable that executes the agent on a task and returns WorkflowOutput. + local_grpo_n: Number of local GRPO repeats per task. + remote_batch_size: Number of parallel remote workers. + """ from ajet.utils.thread_executors import PeriodicDrainThreadPoolExecutor def rollout(task) -> float | None: diff --git a/ajet/tuner_lib/experimental/swarm_overwatch_utils.py b/ajet/tuner_lib/experimental/swarm_overwatch_utils.py index 06355e50..1300133c 100644 --- a/ajet/tuner_lib/experimental/swarm_overwatch_utils.py +++ b/ajet/tuner_lib/experimental/swarm_overwatch_utils.py @@ -19,3 +19,4 @@ class CurrentBatchRolloutPoolInformation(BaseModel): global_step: int | None = None booting_start_time: float | None = None # timestamp when ENGINE.BOOTING started training_model_path: str | None = None # model path from synced training config + swarm_client_instruction: dict = {} diff --git a/ajet/tuner_lib/experimental/swarm_server.py b/ajet/tuner_lib/experimental/swarm_server.py index a4c1a346..9d3d60c0 100644 --- a/ajet/tuner_lib/experimental/swarm_server.py +++ b/ajet/tuner_lib/experimental/swarm_server.py @@ -12,7 +12,7 @@ from typing import Coroutine, Optional, Tuple, List from ajet.utils.process_killer import kill_process_tree from ajet.tuner_lib.experimental.swarm_overwatch_utils import CurrentBatchRolloutPoolInformation -from ajet.tuner_lib.experimental.interchange_utils import DEBUG, VERBOSE +from ajet.tuner_lib.experimental.interchange_utils import DEBUG, VERBOSE, CLIENT_ACTIVE_TIMEOUT from ajet.tuner_lib.experimental.interchange_utils import ( SyncTrainConfigRequest, ClaimEpisodeRequest, @@ -30,9 +30,17 @@ PushVerboseLogRequest, VerboseLogEntry, VerboseLogsResponse, + AgreeSyncWeightRequest, + SwarmClientInstruction, + ActiveSwarmClient, + _refresh_client_activity, + _register_active_client, + _expire_inactive_clients, + _reset_active_client_tracking, VALID_STATUSES, ) + VERBOSE_LOG_TTL_SECONDS = 30.0 VERBOSE_LOG_MAX_ENTRIES = 50 @@ -69,6 +77,11 @@ def register_enable_swarm_mode_routes( if "current_batch_rollout_pool_information" not in shared_mem_dict: shared_mem_dict["current_batch_rollout_pool_information"] = CurrentBatchRolloutPoolInformation() + # active swarm client tracking (List[ActiveSwarmClient]; helpers live in + # interchange_utils) + if "active_swarm_clients" not in shared_mem_dict: + shared_mem_dict["active_swarm_clients"] = [] + # ------------------------------------------------------------------------------------------------ # ------ Recycle claimed episodes that client failed to complete in (promised) time -------------- # --------------------------------- claimed -> unclaimed ---------------------------------------- @@ -227,6 +240,7 @@ async def register_episode_ready_listener(): while True: await asyncio.sleep(10) # check every 10 seconds await find_claimed_episodes_that_need_to_be_unclaimed() + _expire_inactive_clients(shared_mem_dict) # read_all_episode_status() if DEBUG: _write_swarm_server_dynamic_log(shared_mem_dict) @@ -280,6 +294,10 @@ def _clean_up_engine_status(shared_mem_dict_lock, shared_mem_dict): shared_mem_dict["unclaimed_episodes"] = [] logger.info(f"[_clean_up_engine_status] Cleared {num_unclaimed} unclaimed episodes") + # reset active-client tracking (cleared each time we leave ROLLING/ + # ROLLING_POST -- i.e. on entering WEIGHT_SYNCING etc.) + _reset_active_client_tracking(shared_mem_dict) + # -------------------------------------------------------------------------------------- # -------------------------- fastapi routes -------------------------------------------- # -------------------------------------------------------------------------------------- @@ -602,6 +620,11 @@ async def claim_episode(req: ClaimEpisodeRequest): if VERBOSE: logger.info(f"Running [{episode_uuid}]: /claim_episode") + # begin_episode counts as activity for keeping a client in the + # active list (only refreshes if already active; first activation + # comes from a successful end_episode). + _refresh_client_activity(req.client_uuid, shared_mem_dict) + return ClaimEpisodeResponse( success=True, client_uuid=req.client_uuid, @@ -668,6 +691,9 @@ async def end_episode(req: EndEpisodeRequest): shared_mem_dict, shared_mem_dict_lock, ) + # successful, non-abort end_episode marks the client "active" + if req.declare_client_active: + _register_active_client(client_uuid, shared_mem_dict) elif episode_type == "eval": if engine_status in ["ENGINE.ROLLING"]: @@ -707,6 +733,9 @@ async def abort_episode(req: EndEpisodeRequest): else: _delete_episode_record(episode_uuid, shared_mem_dict, shared_mem_dict_lock) + if req.declare_client_active: + _register_active_client(req.client_uuid, shared_mem_dict) + return EndEpisodeResponse(success=True) @app.post("/can_continue_episode", response_model=CanContinueEpisodeResponse) @@ -742,11 +771,23 @@ async def get_episode_buffer(): result = [v for k, v in shared_mem_dict.items() if is_key_episode_status(k)] return EpisodeBufferResponse(buffer=result) - @app.post("/update_current_batch_rollout_pool_information", response_model=BoolResponse) - async def update_current_batch_rollout_pool_information(req: CurrentBatchRolloutPoolInformation): - """Update the current batch rollout pool information.""" + @app.post( + "/update_current_batch_rollout_pool_information_and_fetch_instruction", + response_model=SwarmClientInstruction, + ) + async def update_current_batch_rollout_pool_information_and_fetch_instruction( + req: CurrentBatchRolloutPoolInformation, + ): + """Update pool information and return the active-client instruction. + + The trainer pushes its latest pool snapshot here every few seconds; + in the same call we hand back the server-maintained + `active_swarm_clients` list so the trainer can evaluate + `rollout_until_*_agree_sync_weight` stop conditions without an extra + round-trip. + """ if DEBUG: - logger.info(f"Running /update_current_batch_rollout_pool_information") + logger.info(f"Running /update_current_batch_rollout_pool_information_and_fetch_instruction") try: with shared_mem_dict_lock: # Ignore fields that are only maintained in shared_mem_dict @@ -755,10 +796,62 @@ async def update_current_batch_rollout_pool_information(req: CurrentBatchRollout req.global_step = None req.completed_tasks_client_uuids = {} shared_mem_dict["current_batch_rollout_pool_information"] = req - return BoolResponse(success=True) + instruction = SwarmClientInstruction( + active_clients=list(shared_mem_dict.get("active_swarm_clients", [])) + ) + return instruction except Exception as e: logger.error(f"Error updating current batch rollout pool information: {e}") - return BoolResponse(success=False, failure_reason=str(e)) + return SwarmClientInstruction() + + AGREE_SYNC_WEIGHT_VALID_METHODS = ( + "rollout_until_any_client_agree_sync_weight", + "rollout_until_all_clients_agree_sync_weight", + ) + + @app.post("/agree_sync_weight", response_model=BoolResponse) + async def agree_sync_weight(req: AgreeSyncWeightRequest): + """Mark a client as having agreed to the next weight sync. + + Only counts when the client is currently in the active list (otherwise + the agreement would be silently expired anyway). The set is cleared + whenever the engine leaves ROLLING/ROLLING_POST. + + Refuses the call unless the trainer is configured with one of the + agree-driven sample-collection methods, since under any other policy + the agreement would have no effect on when the trainer stops. + """ + if VERBOSE: + logger.info(f"Running /agree_sync_weight: {req.client_uuid}") + client_uuid = req.client_uuid + if not client_uuid: + return BoolResponse(success=False, failure_reason="client_uuid required") + pool_info: CurrentBatchRolloutPoolInformation = shared_mem_dict.get( + "current_batch_rollout_pool_information", + CurrentBatchRolloutPoolInformation(), + ) + assert pool_info.sample_collection_method in AGREE_SYNC_WEIGHT_VALID_METHODS, ( + f"agree_sync_weight is only valid when " + f"ajet.swarm_mode_sample_collection_method is one of " + f"{AGREE_SYNC_WEIGHT_VALID_METHODS}, but the trainer is currently " + f"running with '{pool_info.sample_collection_method}'." + ) + with shared_mem_dict_lock: + clients: List[ActiveSwarmClient] = list( + shared_mem_dict.get("active_swarm_clients", []) + ) + for i, c in enumerate(clients): + if c.client_uuid == client_uuid: + if not c.allowed_sync_weight: + clients[i] = c.model_copy(update={"allowed_sync_weight": True}) + shared_mem_dict["active_swarm_clients"] = clients + return BoolResponse(success=True) + return BoolResponse( + success=False, + failure_reason=( + f"Client {client_uuid} is not in the active list -- it must have completed at least one rewarded (non-abort) episode since the last weight sync before agreeing." + ), + ) @app.get("/get_current_batch_rollout_pool_information", response_model=CurrentBatchRolloutPoolInformation) async def get_current_batch_rollout_pool_information(): @@ -773,6 +866,9 @@ async def get_current_batch_rollout_pool_information(): pool_info.global_step = shared_mem_dict.get("global_step", None) pool_info.booting_start_time = shared_mem_dict.get("booting_start_time", None) pool_info.training_model_path = shared_mem_dict.get("training_model_path", None) + pool_info.swarm_client_instruction = SwarmClientInstruction( + active_clients=list(shared_mem_dict.get("active_swarm_clients", [])) + ).model_dump() # Build running_episode_details for claimed episodes running_episode_details = {} diff --git a/ajet/utils/cleaner.py b/ajet/utils/cleaner.py index 9bb37f80..163619eb 100644 --- a/ajet/utils/cleaner.py +++ b/ajet/utils/cleaner.py @@ -4,6 +4,13 @@ import time +# Canonical autokill keyword set. The string is `|`-separated to match the +# `--kill` CLI flag, which expects `kw1|kw2|...`. Every site that wants the +# "kill all training-related processes" behavior must import this constant +# rather than re-declaring the literal. +AUTOKILL_KEYWORDS = "ray|vllm|VLLM|python" + + def kill_ray_processes(): """run ray stop command to kill ray processes""" try: diff --git a/ajet/utils/env_service_client/env_client_ng.py b/ajet/utils/env_service_client/env_client_ng.py index a8e1112f..3aba8c0b 100644 --- a/ajet/utils/env_service_client/env_client_ng.py +++ b/ajet/utils/env_service_client/env_client_ng.py @@ -207,9 +207,13 @@ def call(): messages=action, params=params, ) - return resp["data"] + data = resp["data"] + while "data" in data and "state" not in data: + data = data["data"] + data["state"] = data["state"][0] + return data - res = retry_call( + return retry_call( call, max_retry=max_retry, fail_return=fallback, @@ -217,8 +221,6 @@ def call(): instance_id=instance_id, action_name="step", ) - res["state"] = res["state"][0] - return res def evaluate( self, diff --git a/ajet/utils/message_utils.py b/ajet/utils/message_utils.py index 3792161c..b079eed5 100644 --- a/ajet/utils/message_utils.py +++ b/ajet/utils/message_utils.py @@ -3,6 +3,30 @@ from loguru import logger +_TOKEN_OVERFLOW_SIGNATURE = "Exceeded max model context length. token_overflow" + +def is_token_overflow_message(content) -> bool: + """Return True if `content` represents the AgentJet token-overflow output + (prompt would exceed max_model_len). Accepts a raw string, a message dict + with a "content" field, bytes, or None. Match is substring-based so the + signal survives whitespace, the "AgentJet:" prefix being stripped, or the + content being embedded in a larger blob. + """ + if content is None: + return False + if isinstance(content, dict): + content = content.get("content") + if not isinstance(content, str): + return False + elif isinstance(content, (bytes, bytearray)): + try: + content = content.decode("utf-8", errors="ignore") + except Exception: + return False + elif not isinstance(content, str): + return False + return _TOKEN_OVERFLOW_SIGNATURE in content + def log_empty_content_messages(messages: List[Dict], episode_uuid: str = "") -> None: """Scan an OpenAI-compatible message list and log an error for any message diff --git a/ajet/utils/swarm_overwatch.py b/ajet/utils/swarm_overwatch.py index d076e330..09d00062 100644 --- a/ajet/utils/swarm_overwatch.py +++ b/ajet/utils/swarm_overwatch.py @@ -44,7 +44,15 @@ def __init__(self, server_url: str, refresh_interval: float = 2.0): self.last_update_time = None self.error_count = 0 self.total_requests = 0 - self._httpx_client = httpx.Client(timeout=5.0) + # Disable keep-alive: each poll opens a fresh TCP connection so the kernel + # re-picks a worker every time. Without this, a long-lived connection can + # stick to a stale/zombie worker after a server restart and silently keep + # returning frozen state. + self._httpx_client = httpx.Client( + timeout=5.0, + limits=httpx.Limits(max_keepalive_connections=0, keepalive_expiry=0), + headers={"Connection": "close"}, + ) self._verbose_logs: list = [] # list of dicts {timestamp, tag, message} def _refresh_http_client(self): @@ -53,7 +61,11 @@ def _refresh_http_client(self): self._httpx_client.close() except Exception: pass - self._httpx_client = httpx.Client(timeout=5.0) + self._httpx_client = httpx.Client( + timeout=5.0, + limits=httpx.Limits(max_keepalive_connections=0, keepalive_expiry=0), + headers={"Connection": "close"}, + ) logger.warning("swarm overwatch httpx client refreshed.") def _should_refresh_client_on_error(self, error: Exception) -> bool: @@ -116,7 +128,18 @@ def create_header( header_text = Text() header_text.append("AgentJet Swarm Overwatch", style="bold cyan") - header_text.append(f"\nServer: {self.server_url}", style="dim") + header_text.append(f" | Server: {self.server_url}", style="dim") + + instr = info.swarm_client_instruction if info else {} + active_clients = instr.get("active_clients", []) + agreed = sum(1 for c in active_clients if c.get("allowed_sync_weight")) + total = len(active_clients) + header_text.append(f"\nActive Clients: {total}", style="bold white") + if total: + parts = [f"{c.get('client_uuid','?')[:8]}{'✓' if c.get('allowed_sync_weight') else ''}" for c in active_clients[:8]] + suffix = f", +{total - 8} more" if total > 8 else "" + header_text.append(f" [{', '.join(parts)}{suffix}]", style="cyan") + header_text.append(f"\nCurrent Time: {now}", style="green") header_text.append(f" | Last Update: {last_update}", style="yellow") header_text.append(f" | Refresh: {self.refresh_interval}s", style="blue") @@ -196,6 +219,13 @@ def create_summary_table(self, info: CurrentBatchRolloutPoolInformation) -> Tabl info.sample_collection_method == "rollout_until_finish_enough_non_dummy_tasks" ) + highlight_any_agree = ( + info.sample_collection_method == "rollout_until_any_client_agree_sync_weight" + ) + highlight_all_agree = ( + info.sample_collection_method == "rollout_until_all_clients_agree_sync_weight" + ) + highlight_agree = highlight_any_agree or highlight_all_agree # Episodes ep_cur, ep_tgt, ep_pct = self.create_progress_bar( @@ -277,6 +307,30 @@ def create_summary_table(self, info: CurrentBatchRolloutPoolInformation) -> Tabl "-" ) + # Clients agree sync weight (only shown for agree_sync_weight methods) + if highlight_agree: + instr = info.swarm_client_instruction if info.swarm_client_instruction else {} + active_clients = instr.get("active_clients", []) + total_clients = len(active_clients) + agreed_clients = sum(1 for c in active_clients if c.get("allowed_sync_weight")) + # Target depends on mode: any=1, all=total + if highlight_any_agree: + agree_target = 1 if total_clients > 0 else 0 + else: + agree_target = total_clients + agree_pct = (agreed_clients / agree_target * 100) if agree_target > 0 else 0.0 + agree_bar = self._create_text_bar(agree_pct) + agree_metric = "-> *Clients Agree Sync Weight (chosen)*" + agree_style = "bold green" + table.add_row( + f"[{agree_style}]{agree_metric}[/{agree_style}]", + f"{agreed_clients:,}", + f"{agree_target:,}", + f"{agree_pct:.1f}%", + agree_bar, + style=agree_style, + ) + return table def _create_text_bar(self, percentage: float, width: int = 20) -> str: diff --git a/ajet/utils/thread_executors.py b/ajet/utils/thread_executors.py index fa1c2a9f..005adc0b 100644 --- a/ajet/utils/thread_executors.py +++ b/ajet/utils/thread_executors.py @@ -83,6 +83,10 @@ def submit_with_periodic_drain(self, fn, *args, **kwargs): """Submit a task, draining all in-flight work every `drain_every_n_job` submissions.""" drain_every_n_job = self._max_workers results = [] + self._submitted_count += 1 + future = self.submit(fn, *args, **kwargs) + self.current_futures.append(future) + if self._submitted_count > 0 and self._submitted_count % drain_every_n_job == 0: pbar = tqdm(total=len(self.current_futures), desc="Draining in-flight tasks") for _ in as_completed(self.current_futures): @@ -95,9 +99,6 @@ def submit_with_periodic_drain(self, fn, *args, **kwargs): logger.exception(f"Error in task execution: {e}") self.current_futures = [] - self._submitted_count += 1 - future = self.submit(fn, *args, **kwargs) - self.current_futures.append(future) return future, results def shutdown(self, wait=True): diff --git a/docs/en/configuration.md b/docs/en/configuration.md index 4a5db22f..6d0f4685 100644 --- a/docs/en/configuration.md +++ b/docs/en/configuration.md @@ -599,7 +599,7 @@ Controls the Context Tracker, which intercepts LLM calls, builds aligned timelin - **Type:** str. - **Default:** `"text"`. - **Description:** Controls how timelines are compared when deciding whether to merge shared conversation prefixes: - - `"text"` (relaxed) — Compares `content_for_compare` strings between timeline messages. More aggressive merging at very little cost, resulting in higher training speedup. + - `"text"` (relaxed) — Compares `text_content_for_compare` strings between timeline messages. More aggressive merging at very little cost, resulting in higher training speedup. - `"token"` (strict) — Compares exact `token_arr` sequences between timeline messages. Less aggressive merging since tokenization differences (e.g. whitespace handling) prevent matches. Use when tokenization fidelity is critical. ### `ajet.context_tracker.timeline_merging_policy.ignore_tools` diff --git a/scripts/deploy_model.py b/scripts/deploy_model.py index e030e4c6..9464f8ee 100644 --- a/scripts/deploy_model.py +++ b/scripts/deploy_model.py @@ -9,7 +9,7 @@ from loguru import logger # noqa: E402 -from ajet.utils.cleaner import fast_kill_by_keyword_bash # noqa: E402 +from ajet.utils.cleaner import AUTOKILL_KEYWORDS, fast_kill_by_keyword_bash # noqa: E402 from ajet.utils.smart_daemon import LaunchCommandWhenAbsent # noqa: E402 parser = argparse.ArgumentParser(description="deploy Hugging Face model") @@ -42,7 +42,7 @@ args = parser.parse_args() if args.autokill: - args.kill = "ray|vllm|VLLM|python" + args.kill = AUTOKILL_KEYWORDS # Handle kill-keywords argument if provided if args.kill: diff --git a/tests/bench/benchmark_appworld/execute_benchmark_appworld.py b/tests/bench/benchmark_appworld/execute_benchmark_appworld.py index 0541f59d..e6abd64c 100644 --- a/tests/bench/benchmark_appworld/execute_benchmark_appworld.py +++ b/tests/bench/benchmark_appworld/execute_benchmark_appworld.py @@ -55,7 +55,7 @@ def test_02_begin_trinity(self): def clear_system_processes(self): # kill all python + ray + vllm processes - from ajet.utils.cleaner import fast_kill_by_keyword_bash + from ajet.utils.cleaner import AUTOKILL_KEYWORDS, fast_kill_by_keyword_bash total_seconds = 15 for i in range(total_seconds): @@ -64,8 +64,7 @@ def clear_system_processes(self): ) time.sleep(1) - kill = "ray|vllm|VLLM|python" - for keyword in kill.split("|"): + for keyword in AUTOKILL_KEYWORDS.split("|"): logger.info(f"Killing processes matching keyword: {keyword}") killed_pids = fast_kill_by_keyword_bash(keyword) if killed_pids: diff --git a/tests/bench/benchmark_appworldlora/execute_benchmark_appworldlora.py b/tests/bench/benchmark_appworldlora/execute_benchmark_appworldlora.py index 01e09ebc..bbc7fba3 100644 --- a/tests/bench/benchmark_appworldlora/execute_benchmark_appworldlora.py +++ b/tests/bench/benchmark_appworldlora/execute_benchmark_appworldlora.py @@ -55,7 +55,7 @@ def test_02_begin_trinity(self): def clear_system_processes(self): # kill all python + ray + vllm processes - from ajet.utils.cleaner import fast_kill_by_keyword_bash + from ajet.utils.cleaner import AUTOKILL_KEYWORDS, fast_kill_by_keyword_bash total_seconds = 15 for i in range(total_seconds): @@ -64,8 +64,7 @@ def clear_system_processes(self): ) time.sleep(1) - kill = "ray|vllm|VLLM|python" - for keyword in kill.split("|"): + for keyword in AUTOKILL_KEYWORDS.split("|"): logger.info(f"Killing processes matching keyword: {keyword}") killed_pids = fast_kill_by_keyword_bash(keyword) if killed_pids: diff --git a/tutorial/example_appworld/appworld.md b/tutorial/example_appworld/appworld.md index 3a759a8a..cd91deac 100644 --- a/tutorial/example_appworld/appworld.md +++ b/tutorial/example_appworld/appworld.md @@ -1,8 +1,19 @@ ## Run Appworld AgentScope Agent -### 1. Prepare dataset +### 1. Install and Run Appworld + +- Install: +``` +rm -rf /tmp/pack_all_in_one & wget https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/astuner_archive/appworld_pack_v3.tar.gz && tar -xzf ./appworld_pack_v3.tar.gz -C /tmp +``` + +- Run: +``` +export APPWORLD_PATH="/tmp/pack_all_in_one" +export APPWORLD_SCRIPT="bash EnvService/env_sandbox/appworld.sh" +ajet --with-appworld --skip-check-avail-gpu +``` -Please download `env_service` and `appworld`. For specific steps, please refer to [EnvService Documentation](https://code.alibaba-inc.com/EconML/EnvService) ### 2. Prepare AgentScope Workflow diff --git a/tutorial/example_appworld_swarm/README.md b/tutorial/example_appworld_swarm/README.md new file mode 100644 index 00000000..5d1211aa --- /dev/null +++ b/tutorial/example_appworld_swarm/README.md @@ -0,0 +1,45 @@ +## AppWorld swarm mode + +Swarm-mode rewrite of `tutorial/example_appworld`. +The training engine runs remotely (server side), while task enumeration, +env_service instance lifecycle and reward evaluation all happen locally +in the rollout client. + +Files: +- `appworld_swarm.py` — workflow + lightweight `EnvClient` gym wrapper +- `agent_roll.py` — rollout driver (calls `begin_episode` / `end_episode`) +- `appworld.yaml` — swarm-mode training config + +Required env vars (with sensible defaults): +- `AJET_SWARM_URL` — swarm server URL (default `http://localhost:10086`) +- `APPWORLD_ENV_URL` — appworld env_service URL (default `http://127.0.0.1:8080`) +- `APPWORLD_ENV_TYPE` — env_type passed to env_service (default `appworld`) +- `APPWORLD_TRAINING_SPLIT` — train split for `get_env_profile` (default `train`) +- `APPWORLD_VALIDATION_SPLIT` — eval split for `get_env_profile` (default `dev`) +- `APPWORLD_MAX_STEPS` — per-episode step cap (default `25`) +- `APPWORLD_EVAL_INTERVAL` — run eval every N global steps (default `10`) +- `APPWORLD_EVAL_K` — rollouts per eval task, pass@k (default `1`) +- `APPWORLD_TOTAL_TRAINING_STEPS`— hard cap on global steps (default `200`) +- `APPWORLD_RESULT_DIR` — where eval logs / `val_results.md` are written (default `./appworld_swarm_results`) +- `APPWORLD_MAX_ENV_WORKER` — max parallel env workers for both train and eval (default `64`) + + +## Run swarm + +``` +tmux new-session -d -s "SWARM_SERVER" +tmux send-keys -t "SWARM_SERVER" "cd /mnt/data_cpfs/qingxu.fu/agentjet/hello-agentjet" Enter +tmux send-keys -t "SWARM_SERVER" "source .venv/bin/activate" Enter +tmux send-keys -t "SWARM_SERVER" "export SETUPTOOLS_USE_DISTUTILS=local" Enter +tmux send-keys -t "SWARM_SERVER" "ajet-swarm start" Enter +ta "SWARM_SERVER" + + +tmux new-session -d -s "SWARM_CLIENT" +tmux send-keys -t "SWARM_CLIENT" "cd /mnt/data_cpfs/qingxu.fu/agentjet/hello-agentjet" Enter +tmux send-keys -t "SWARM_CLIENT" "source .venv/bin/activate" Enter +tmux send-keys -t "SWARM_CLIENT" "export SETUPTOOLS_USE_DISTUTILS=local" Enter +tmux send-keys -t "SWARM_CLIENT" "sleep 30s" Enter +tmux send-keys -t "SWARM_CLIENT" "python -m tutorial.example_appworld_swarm.agent_roll" Enter +ta "SWARM_CLIENT" +``` diff --git a/tutorial/example_appworld_swarm/agent_roll.py b/tutorial/example_appworld_swarm/agent_roll.py new file mode 100644 index 00000000..4b86255b --- /dev/null +++ b/tutorial/example_appworld_swarm/agent_roll.py @@ -0,0 +1,228 @@ +# -*- coding: utf-8 -*- + +# python -m tutorial.example_appworld_swarm.agent_roll + +import os +import statistics +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Generator, List + +from tqdm import tqdm + +from ajet.copilot.job import AgentJetJob +from ajet.schema.task import Task +from ajet.tuner_lib.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey +from ajet.tuner_lib.experimental.swarm_client import SwarmClient +from ajet.utils.env_service_client.env_client_ng import EnvClient +from ajet.utils.thread_executors import PeriodicDrainThreadPoolExecutor + +NUM_EPOCH = 10000 +AJET_SWARM_URL = os.getenv("AJET_SWARM_URL", "http://localhost:10086") + +ENV_URL = os.getenv("APPWORLD_ENV_URL", "http://127.0.0.1:8080") +ENV_TYPE = os.getenv("APPWORLD_ENV_TYPE", "appworld") +TRAINING_SPLIT = os.getenv("APPWORLD_TRAINING_SPLIT", "train") +VALIDATION_SPLIT = os.getenv("APPWORLD_VALIDATION_SPLIT", "dev") +MAX_STEPS = int(os.getenv("APPWORLD_MAX_STEPS", "25")) + +EVAL_INTERVAL = int(os.getenv("APPWORLD_EVAL_INTERVAL", "10")) +EVAL_K = int(os.getenv("APPWORLD_EVAL_K", "1")) +TOTAL_TRAINING_STEPS = int(os.getenv("APPWORLD_TOTAL_TRAINING_STEPS", "200")) +RESULT_DIR = os.getenv("APPWORLD_RESULT_DIR", "./appworld_swarm_results") +MAX_ENV_WORKER = int(os.getenv("APPWORLD_MAX_ENV_WORKER", "64")) + + +def get_appworld_tasks(split: str) -> List[Task]: + """Enumerate appworld task ids from env_service for the given split. + + The swarm client owns task generation, so we hit env_service directly + (rather than going through `EnvServiceTaskReader`) to keep the config + surface flat. + """ + env_client = EnvClient(base_url=ENV_URL) + task_id_array = env_client.get_env_profile(ENV_TYPE, split=split) + if len(task_id_array) == 0: + raise ValueError( + f"No task_id found for env_type={ENV_TYPE}, split={split}, " + f"check connection to {ENV_URL}" + ) + return [ + Task( + main_query="[not defined]", + init_messages=[], + task_id=str(task_id), + env_type=ENV_TYPE, + metadata={}, + ) + for task_id in task_id_array + ] + + +def generate_training_tasks() -> Generator[Task, None, None]: + for task in get_appworld_tasks(TRAINING_SPLIT): + yield task + + +def execute_agent(task: Task, api_baseurl_key: OpenaiBaseUrlAndApiKey): + import asyncio + from tutorial.example_appworld_swarm.appworld_swarm import ExampleAgentScopeWorkflow + workflow = ExampleAgentScopeWorkflow( + env_url=ENV_URL, + env_type=ENV_TYPE, + max_steps=MAX_STEPS, + ) + return asyncio.run(workflow.execute(task, api_baseurl_key)) + + +def main(): + + ajet_job = AgentJetJob( + base_yaml_config="tutorial/example_appworld_swarm/appworld.yaml", + algorithm="grpo", + experiment_name="appworld_swarm_14b", + max_env_worker=MAX_ENV_WORKER, + ) + + # Hand shake with remote swarm server + swarm_worker = SwarmClient(AJET_SWARM_URL) + swarm_worker.auto_sync_train_config_and_start_engine( + ajet_job, + # force_restart=True, + ) + + GRPO_N = ajet_job.num_repeat + REMOTE_BATCH_SIZE = ajet_job.batch_size + + os.makedirs(RESULT_DIR, exist_ok=True) + eval_log_path = os.path.join(RESULT_DIR, "eval_results.log") + val_result_path = os.path.join(RESULT_DIR, "val_results.md") + + eval_tasks = get_appworld_tasks(VALIDATION_SPLIT) + print(f"[INFO] Loaded {len(eval_tasks)} eval tasks (split={VALIDATION_SPLIT})") + + def rollout(task: Task) -> float: + # begin episode + episode_uuid, api_baseurl_key = swarm_worker.begin_episode(discard_episode_timeout=600) + # execute agent ( base_url = api_baseurl_key.base_url, api_key = api_baseurl_key.api_key ) + workflow_output = execute_agent(task, api_baseurl_key) + # report output back to swarm remote + swarm_worker.end_episode(task, episode_uuid, workflow_output) + return workflow_output.reward + + def eval_rollout(task: Task) -> float: + episode_uuid, api_baseurl_key = swarm_worker.begin_episode( + discard_episode_timeout=600, episode_type="eval" + ) + try: + workflow_output = execute_agent(task, api_baseurl_key) + return workflow_output.reward + finally: + # eval samples must NOT be fed back into the training pool + swarm_worker.abort_episode(episode_uuid) + + def run_eval(n_global_step: int): + if not eval_tasks: + return + k = EVAL_K + total_rollouts = len(eval_tasks) * k + print(f"\n[EVAL @ step {n_global_step}] {len(eval_tasks)} tasks x {k} (pass@{k})...") + per_task_rewards: List[List[float]] = [[] for _ in eval_tasks] + pbar = tqdm(total=total_rollouts, desc=f"EVAL @ step {n_global_step}") + + with ThreadPoolExecutor(max_workers=MAX_ENV_WORKER) as eval_executor: + future_to_idx = { + eval_executor.submit(eval_rollout, t): i + for i, t in enumerate(eval_tasks) + for _ in range(k) + } + for fut in as_completed(future_to_idx): + idx = future_to_idx[fut] + try: + per_task_rewards[idx].append(fut.result()) + except Exception as e: + print(f"[EVAL] future error: {e}") + pbar.update(1) + pbar.close() + + flat = [r for rs in per_task_rewards for r in rs if r is not None] + if not flat: + print(f"[EVAL @ step {n_global_step}] no valid rewards") + return + + avg = sum(flat) / len(flat) + std_reward = statistics.pstdev(flat) if len(flat) > 1 else 0.0 + # Full success requires raw_reward >= 1 (final_reward >= 1.5). + # Partial-credit rollouts have 0 < final_reward <= 0.5, so they must NOT + # count as passes; see EnvServiceJudge.compute_reward. + SUCCESS_THRESHOLD = 1.0 + pass1 = sum(1 for r in flat if r >= SUCCESS_THRESHOLD) / len(flat) + num_all_success_tasks = sum( + 1 + for rs in per_task_rewards + if rs and all((r is not None and r >= SUCCESS_THRESHOLD) for r in rs) + ) + num_pass_n_tasks = sum( + 1 + for rs in per_task_rewards + if any((r is not None and r >= SUCCESS_THRESHOLD) for r in rs) + ) + passk = num_pass_n_tasks / len(per_task_rewards) + summary = ( + f"[EVAL @ step {n_global_step}] mean_reward={avg:.4f} std_reward={std_reward:.4f} " + f"task_pass_rate@1={pass1*100:.2f}% task_pass_rate@{k}={passk*100:.2f}% " + f"n_tasks={len(per_task_rewards)} n_rollouts={len(flat)}" + ) + print(summary) + with open(eval_log_path, "a") as f: + f.write(summary + "\n") + with open(val_result_path, "a") as f: + f.write(f"\n## Step {n_global_step}\n") + f.write(f"- pass_n: {k}\n") + f.write(f"- total_tasks: {len(per_task_rewards)}\n") + f.write(f"- num_all_success_tasks: {num_all_success_tasks}\n") + f.write(f"- num_pass_n_tasks: {num_pass_n_tasks}\n") + f.write(f"- task_pass_rate@1: {pass1*100:.2f}%\n") + f.write(f"- task_pass_rate@{k}: {passk*100:.2f}%\n") + f.write(f"- mean_reward: {avg:.4f}\n") + f.write(f"- std_reward: {std_reward:.4f}\n") + f.write(f"- n_rollouts: {len(flat)}\n") + + # step-0 eval (swarm mode does not support val_before_train) + last_eval_step = 0 + run_eval(0) + + executor = PeriodicDrainThreadPoolExecutor( + workers=GRPO_N * REMOTE_BATCH_SIZE, max_parallel=64, auto_retry=True + ) + + n_global_step = 0 + for _ in range(NUM_EPOCH): + for task in generate_training_tasks(): + for _ in range(GRPO_N): + # `submit_with_periodic_drain` returns drained results only when the + # in-flight pool was actually drained on this submission. Each drain + # boundary corresponds to a fully-collected local batch -- exactly + # when this client should agree to a weight sync under + # `rollout_until_all_clients_agree_sync_weight`. + _, drained_results = executor.submit_with_periodic_drain( + fn=rollout, task=task + ) + if drained_results: + swarm_worker.agree_sync_weight() + + n_global_step = swarm_worker.get_global_step() + if n_global_step >= last_eval_step + EVAL_INTERVAL: + run_eval(n_global_step) + last_eval_step = n_global_step + + if n_global_step >= TOTAL_TRAINING_STEPS: + break + + if n_global_step >= TOTAL_TRAINING_STEPS: + break + + print("[INFO] Training complete.") + + +if __name__ == "__main__": + main() diff --git a/tutorial/example_appworld_swarm/appworld.yaml b/tutorial/example_appworld_swarm/appworld.yaml new file mode 100644 index 00000000..da4160a3 --- /dev/null +++ b/tutorial/example_appworld_swarm/appworld.yaml @@ -0,0 +1,84 @@ +# ------------------ main config ------------------ +# Swarm-mode counterpart of tutorial/example_appworld/appworld.yaml. +# Settings unrelated to swarm wiring are kept identical to the original yaml. +ajet: + project_name: example_appworld_swarm + experiment_dir: "auto" # {exp-dir}/{experiment_name} + + task_judge: + # reward is computed by the swarm workflow on the client side + judge_protocol: null + + task_reader: + # tasks are enumerated by the swarm client (env_service is queried in agent_roll.py) + type: random_dummy + + model: + # ✨ select model to be trained (matches original appworld.yaml) + path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-14B-Instruct + + rollout: + # workflow is driven from the swarm client, not from the server + user_workflow: null + force_disable_toolcalls: True + temperature: 0.9 + max_env_worker: 64 + num_repeat: 6 + agent_madness_reward: -1.0 + tensor_model_parallel_size: 1 + max_num_seqs: 64 + compute_madness_checklist: + - "nonsense" + max_response_length_in_one_turn: 4096 + max_model_len: 18000 + multi_turn: + max_sample_per_task: 25 + max_steps: 25 + + # the experimental ZeroMQ interchange server feature that allows `tuner.as_oai_baseurl_apikey` feature + enable_interchange_server: True + # train in cloud, run episode locally + enable_swarm_mode: True + # both swarm / oai share the same interchange server + interchange_server: + interchange_method: 'ipc' # options: 'tcp' (multi-nodes) or 'ipc' (1 node) + interchange_server_port: 10086 + num_fastapi_process: 1 + max_fastapi_threads: 512 + max_inference_tracker_threads: 64 + already_started: False # do not edit, used by `swarm` + + # Stop the rollout phase only when **every** active swarm client has called + # `SwarmClient.agree_sync_weight()`. The driver (see `agent_roll.py`) calls + # `agree_sync_weight()` each time `submit_with_periodic_drain` actually + # drains, so each weight sync lines up with a complete drain boundary. + swarm_mode_sample_collection_method: "rollout_until_all_clients_agree_sync_weight" + + debug: + debug_max_parallel: 1 + debug_first_n_tasks: 1 + + data: + train_batch_size: 64 + max_prompt_length: 3000 + max_response_length: 15000 + + trainer_common: + save_freq: 99999 + test_freq: 99999 + total_epochs: 99999 + nnodes: 1 + n_gpus_per_node: 8 + + +# ------------------ do not edit ------------------ +hydra: + searchpath: + - file://ajet/default_config + - file://ajet/default_config/verl + +# ------------------ do not edit ------------------ +defaults: + - verl_default + - ajet_default + - _self_ diff --git a/tutorial/example_appworld_swarm/appworld_swarm.py b/tutorial/example_appworld_swarm/appworld_swarm.py new file mode 100644 index 00000000..5735f357 --- /dev/null +++ b/tutorial/example_appworld_swarm/appworld_swarm.py @@ -0,0 +1,169 @@ +from typing import Any, Tuple + +from loguru import logger +from openai.types.chat.chat_completion import ChatCompletion + +from ajet import WorkflowOutput +from ajet.schema.task import Task +from ajet.tuner_lib.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey +from ajet.utils.env_service_client.env_client_ng import EnvClient +from ajet.utils.message_utils import is_token_overflow_message + + +class AppworldGymWrapper: + """Mirror of ajet.task_rollout.resource_keeper.BaseGymEnv for swarm-mode clients. + + The swarm runner does not build a `gym_env` for us, so we wrap `EnvClient` + directly to keep the `step()/evaluate()` surface that the agent loop expects. + """ + + def __init__(self, env_client: EnvClient, episode_uuid: str): + self.env_client = env_client + self.episode_uuid = episode_uuid + + def step(self, action: dict) -> Tuple[Any, float, bool, dict]: + env_output = self.env_client.step( + instance_id=self.episode_uuid, + action=action, + ) + obs: Any = "" + reward: float = 0 + info: dict = {} + if isinstance(env_output["state"], list): + obs = env_output["state"] + reward = env_output["reward"] + info = env_output["info"] + else: + if ("content" not in env_output["state"]) and ("error" in env_output["state"]): + obs = f"[Error from environment: {env_output['error']}]" + elif env_output["state"].get("content", "") == "": + obs = "Warning: the environment does not provide any feedback, please provide valid input and try again." + else: + obs = env_output["state"]["content"] + terminate = env_output["is_terminated"] + return obs, reward, terminate, info + + def evaluate(self, params=None): + return self.env_client.evaluate(self.episode_uuid, params=params or {"sparse": False}) + + +class ExampleAgentScopeWorkflow: + """Swarm-mode appworld workflow. + + Unlike the in-process workflow (which receives a fully initialized + `WorkflowTask` with `gym_env` populated by the framework), the swarm + client is responsible for the env_service instance lifecycle and reward + evaluation locally. + """ + + def __init__( + self, + env_url: str = "http://127.0.0.1:8080", + env_type: str = "appworld", + max_steps: int = 25, + ): + self.env_url = env_url + self.env_type = env_type + self.max_steps = max_steps + + async def execute(self, task: Task, api_baseurl_key: OpenaiBaseUrlAndApiKey) -> WorkflowOutput: + episode_uuid = api_baseurl_key.episode_uuid + env_client = EnvClient(base_url=self.env_url) + + try: + create_response = env_client.create_instance( + env_type=self.env_type, + task_id=task.task_id, + instance_id=episode_uuid, + params={}, + ) + state_message = create_response["state"] + if isinstance(state_message, dict): + raw_init_messages = [state_message] + elif isinstance(state_message, list): + raw_init_messages = state_message + else: + raise ValueError( + f"state_message should be dict or list, got {type(state_message)}" + ) + + if len(raw_init_messages) >= 2: + first_msg, init_messages = raw_init_messages[0], raw_init_messages[1:] + else: + first_msg = {"content": "You're a helpful assistant."} + init_messages = [] + + interaction_message = [ + { + "content": first_msg.get("content", "You're a helpful assistant."), + "role": "system", + } + ] + for msg in init_messages: + interaction_message.append( + { + "content": msg.get("content", ""), + "role": msg.get("role", "user"), + } + ) + + client = api_baseurl_key.as_raw_openai_sdk_client() + env = AppworldGymWrapper(env_client, episode_uuid) + step = 0 + for step in range(self.max_steps): + reply_message: ChatCompletion = await client.chat.completions.create( + model="ajet-model", + messages=interaction_message, + ) + reply_content = reply_message.choices[0].message.content + # AgentJet signals prompt overflow via a synthetic assistant message; further turns will only push the prompt further past max_model_len, + if is_token_overflow_message(reply_content): + logger.warning(f"[appworld_swarm] token overflow detected at step={step} (task_id={task.task_id}); aborting rollout.") + break + obs, _, terminate, _ = env.step( + action={"content": reply_content, "role": "assistant"} + ) + interaction_message.extend( + [ + { + "content": reply_message.choices[0].message.content, + "role": "assistant", + }, + { + "content": obs, + "role": "user", + } + ] + ) + if terminate: + break + + try: + raw_reward = env.evaluate(params={"sparse": False}) + except Exception: + logger.exception("Evaluation failed; defaulting raw_reward=0.0") + raw_reward = 0.0 + + # mirror EnvServiceJudge.compute_reward + if raw_reward >= 1: + is_success = True + final_reward = 1.0 + raw_reward * 0.5 + else: + is_success = False + final_reward = 0.0 + raw_reward * 0.5 + + return WorkflowOutput( + reward=final_reward, + is_success=is_success, + metadata={"total_step": step}, + ) + except Exception: + logger.bind(exception=True).exception( + f"Error during appworld swarm episode (task_id={task.task_id})." + ) + return WorkflowOutput(reward=0.0, is_success=False, metadata={"total_step": 0}) + finally: + try: + env_client.release_instance(episode_uuid) + except Exception: + logger.exception("Failed to release env instance") diff --git a/tutorial/example_cocktail_rl_v2/cocktail_v2_config.py b/tutorial/example_cocktail_rl_v2/cocktail_v2_config.py new file mode 100644 index 00000000..5c17ebfc --- /dev/null +++ b/tutorial/example_cocktail_rl_v2/cocktail_v2_config.py @@ -0,0 +1,185 @@ +# -*- coding: utf-8 -*- +""" +Single source of truth for example_cocktail_rl_v2. + +Every config value used anywhere in this tutorial -- v2 schedule knobs, engine +knobs, per-domain knobs -- lives on `CocktailV2Config`. There are no YAMLs, no +hardcoded constants in the runner or clients, no `.get(key, default)` fallback +patterns that could drift. To change anything, edit a default here. +""" + +from __future__ import annotations + +import math +import os +from dataclasses import dataclass, field +from typing import List + + +SCHEDULE_TYPES = ("linear", "cos", "constant") + + +# ============================ Per-domain sub-configs ============================ + +@dataclass +class AppWorldConfig: + env_url: str = "http://127.0.0.1:8080" + env_type: str = "appworld" + training_split: str = "train" + validation_split: str = "dev" + episode_timeout: int = 60 + + +@dataclass +class AimeConfig: + episode_timeout: int = 60 + # Filenames resolve under ../opencode_build_aime/data relative to this tutorial. + train_dataset_filename: str = "dapo-math-17k.parquet" + test_dataset_filenames: dict = field(default_factory=lambda: { + "AIME-2026": "aime-2026.parquet", + "DAPO-Math-Tiny-Val": "dapo-math-tiny-val.parquet", + }) + + +# ============================ Top-level config ============================ + +@dataclass +class CocktailV2Config: + """Single source of truth. Both client_0 and client_1 must agree on these + values, so the dataclass defaults ARE the canonical config. + + Schedule semantics for client_0's batch ratio: + schedule_type == "constant": ratio is always `schedule_start`. + schedule_type == "linear": linear from `schedule_start` at step 0 to + `schedule_end` at `schedule_end_step`, + then stays at `schedule_end`. + schedule_type == "cos": cosine anneal from `schedule_start` to + `schedule_end` over `schedule_end_step`, + then stays at `schedule_end`. + client_1's ratio is always 1 - client_0's ratio. + """ + # ---- v2 batching / schedule ---- + total_batch_size: int = 64 + grpo_n: int = 8 + schedule_type: str = "linear" + schedule_start: float = 0.5 + schedule_end: float = 0.0 + schedule_end_step: int = 200 + + # ---- v2 client-side runtime ---- + max_env_worker: int = 64 * 8 + max_inference_tracker_threads: int = 128 + eval_interval: int = 10 + eval_k: int = 4 + total_training_steps: int = 200 + swarm_url: str = "http://localhost:10086" + result_dir: str = "./cocktail_results_v2" + + # ---- engine-global per-rollout knobs (read by engine + per-client agents) ---- + max_response_length: int = 20000 + max_steps: int = 25 + + # ---- engine-only knobs (consumed by build_cocktail_ajet_job) ---- + project_name: str = "cocktail_rl" + experiment_name: str = "cocktail_rl_v2" + experiment_dir: str = "auto" + model_path: str = "/mnt/data_cpfs/xielipeng.xlp/models/Qwen3-8B" + algorithm: str = "grpo" + swarm_mode: bool = True + swarm_mode_sample_collection_method: str = "rollout_until_all_clients_agree_sync_weight" + logging: str = "swanlab" + compute_madness_checklist: List[str] = field(default_factory=lambda: ["nonsense"]) + max_prompt_length: int = 3000 + max_response_length_in_one_turn: int = 12000 + max_model_len: int = 23000 + max_num_seqs: int = 128 + n_gpu: int = 8 + use_kl_loss: bool = True + use_kl_in_reward: bool = False + kl_penalty_type: str = "kl" + + # ---- engine knobs not exposed as AgentJetJob kwargs ---- + temperature: float = 0.9 + force_disable_toolcalls: bool = False + agent_madness_reward: float = 0.0 + tensor_model_parallel_size: int = 1 + multi_turn_max_sample_per_task: int = 25 + save_freq: int = 1_000_000_000 + test_freq: int = 10 + total_epochs: int = 99_999 + nnodes: int = 1 + val_pass_n: int = 4 + val_before_train: bool = False + debug_max_parallel: int = 1 + debug_first_n_tasks: int = 1 + + # ---- per-domain ---- + appworld: AppWorldConfig = field(default_factory=AppWorldConfig) + aime: AimeConfig = field(default_factory=AimeConfig) + + def __post_init__(self) -> None: + assert self.total_batch_size >= 1, "total_batch_size must be >= 1" + assert self.grpo_n >= 1, "grpo_n must be >= 1" + assert self.schedule_type in SCHEDULE_TYPES, \ + f"schedule_type must be one of {SCHEDULE_TYPES}, got {self.schedule_type}" + assert 0.0 <= self.schedule_start <= 1.0, "schedule_start must be in [0, 1]" + assert 0.0 <= self.schedule_end <= 1.0, "schedule_end must be in [0, 1]" + assert self.schedule_end_step >= 0, "schedule_end_step must be >= 0" + + def get_client_0_ratio(self, global_step: int) -> float: + if self.schedule_type == "constant" or self.schedule_end_step <= 0: + return self.schedule_start + if global_step >= self.schedule_end_step: + return self.schedule_end + t = global_step / self.schedule_end_step + if self.schedule_type == "linear": + return self.schedule_start + t * (self.schedule_end - self.schedule_start) + if self.schedule_type == "cos": + cos_factor = 0.5 * (1.0 + math.cos(math.pi * t)) # 1 at t=0, 0 at t=1 + return self.schedule_end + (self.schedule_start - self.schedule_end) * cos_factor + raise ValueError(f"Unknown schedule_type: {self.schedule_type}") + + def split_local_batch_sizes(self, global_step: int) -> tuple[int, int]: + """Return (client_0_local_batch_size, client_1_local_batch_size) -- the + number of distinct prompts each client should contribute this round. + Uses round() on client_0; client_1 = total - client_0. Sum == total exactly.""" + r0 = max(0.0, min(1.0, self.get_client_0_ratio(global_step))) + client_0_local_batch_size = int(round(self.total_batch_size * r0)) + client_1_local_batch_size = self.total_batch_size - client_0_local_batch_size + return client_0_local_batch_size, client_1_local_batch_size + + +def cocktail_v2_config_from_env() -> CocktailV2Config: + """Build the v2 config and apply env-var overrides. + + Currently supported env vars: + COCKTAIL_RATIO_SCHEDULE = linear | cos | constant + Override schedule_type. The same value MUST be exported in both + clients' shells, otherwise the two will compute different per-round + local batch sizes. + COCKTAIL_RESULT_DIR = + Override result_dir (default './cocktail_results_v2'). Both clients + must export the same value; otherwise their logs will diverge. + COCKTAIL_SCHEDULE_START = + Override schedule_start (client_0's batch ratio at step 0; for + schedule_type=constant this is the ratio at every step). Both clients + must export the same value, or they will compute different local + batch sizes. + """ + cfg = CocktailV2Config() + sched_type = os.getenv("COCKTAIL_RATIO_SCHEDULE") + if sched_type is not None: + cfg.schedule_type = sched_type + # Re-validate since we mutated. + cfg.__post_init__() + print(f"[INFO] env override: COCKTAIL_RATIO_SCHEDULE = {sched_type!r}") + result_dir = os.getenv("COCKTAIL_RESULT_DIR") + if result_dir is not None: + cfg.result_dir = result_dir + print(f"[INFO] env override: COCKTAIL_RESULT_DIR = {result_dir!r}") + sched_start = os.getenv("COCKTAIL_SCHEDULE_START") + if sched_start is not None: + cfg.schedule_start = float(sched_start) + cfg.__post_init__() + print(f"[INFO] env override: COCKTAIL_SCHEDULE_START = {cfg.schedule_start!r}") + return cfg diff --git a/tutorial/example_cocktail_rl_v2/cocktail_v2_runner.py b/tutorial/example_cocktail_rl_v2/cocktail_v2_runner.py new file mode 100644 index 00000000..dd330320 --- /dev/null +++ b/tutorial/example_cocktail_rl_v2/cocktail_v2_runner.py @@ -0,0 +1,272 @@ +# -*- coding: utf-8 -*- +""" +Shared base class for example_cocktail_rl_v2. + +Each per-domain client (AppWorld / AIME) subclasses CocktailSwarmRunner and +implements four methods (setup_data, rollout, eval_rollout, is_success). +The driver subclass additionally overrides `build_ajet_job()`. The follower +inherits the default (returns None) and waits for the engine to roll. + +All configuration lives in `cocktail_v2_config.CocktailV2Config` -- this file +contains zero config defaults. +""" + +from __future__ import annotations + +import os +import time +import statistics +from abc import ABC, abstractmethod +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import List, Optional + +from tqdm import tqdm + +from ajet.copilot.job import AgentJetJob +from ajet.schema.task import Task +from ajet.tuner_lib.experimental.swarm_client import SwarmClient +from ajet.utils.thread_executors import PeriodicDrainThreadPoolExecutor + +from tutorial.example_cocktail_rl_v2.cocktail_v2_config import CocktailV2Config + + +class CocktailSwarmRunner(ABC): + ROLE: str = "" # "client_0" | "client_1" + IS_DRIVER: bool = False # whether this client drives engine startup + CLIENT_LABEL: str = "" # e.g. "appworld" | "aime", used in subdir + log lines + EPISODE_TIMEOUT: int = 60 + + def __init__(self, v2_config: CocktailV2Config): + assert self.ROLE in ("client_0", "client_1"), \ + f"subclass must set ROLE; got {self.ROLE!r}" + assert self.CLIENT_LABEL, "subclass must set CLIENT_LABEL" + + self.config = v2_config + self.swarm_worker: Optional[SwarmClient] = None + self.dataset = None # must have generate_training_tasks() method + self.eval_tasks_by_set: dict[str, list[Task]] = {} + + self.client_result_dir = os.path.join( + v2_config.result_dir, f"results_{self.CLIENT_LABEL}" + ) + os.makedirs(self.client_result_dir, exist_ok=True) + + # ---------------- to override ---------------- + + @abstractmethod + def setup_data(self) -> None: + """Populate self.dataset (with generate_training_tasks() method) and self.eval_tasks_by_set.""" + + @abstractmethod + def rollout(self, task: Task) -> float: + """Train rollout: begin_episode -> execute -> end_episode -> return reward.""" + + @abstractmethod + def eval_rollout(self, task: Task) -> float: + """Eval rollout: begin_episode(episode_type='eval') -> execute -> abort_episode.""" + + @abstractmethod + def is_success(self, reward: float) -> bool: + """Domain-specific success threshold for logging.""" + + def build_ajet_job(self) -> Optional[AgentJetJob]: + """Driver-only hook. Return a configured AgentJetJob; followers return None.""" + return None + + # ---------------- shared lifecycle ---------------- + + def setup(self) -> None: + self.swarm_worker = SwarmClient(self.config.swarm_url, verbose=False) + if self.IS_DRIVER: + ajet_job = self.build_ajet_job() + assert ajet_job is not None, f"{type(self).__name__}.build_ajet_job() must return AgentJetJob (IS_DRIVER=True)" + self.swarm_worker.auto_sync_train_config_and_start_engine(ajet_job) + else: + print("[INFO] Waiting for swarm server (ENGINE.ROLLING)...") + self.swarm_worker._wait_until_status_change_to(desired_status="ENGINE.ROLLING") + print("[INFO] Swarm server is ready.") + + self.setup_data() + + def run(self) -> None: + self.setup() + self.run_eval(n_global_step=0) + self.train_loop() + + + + + + + + + + + + + # ---------------- + # ---------------- shared training ---------------- + # ---------------- + + def _get_local_batch_size(self, step: int) -> int: + client_0_batch, client_1_batch = self.config.split_local_batch_sizes(step) + return client_0_batch if self.ROLE == "client_0" else client_1_batch + + def train_loop(self) -> None: + assert self.swarm_worker is not None and self.dataset is not None + + train_log_path = os.path.join( + self.client_result_dir, f"train_results_{self.CLIENT_LABEL}.log" + ) + last_eval_step = 0 + + num_epochs = 10000 + for epoch in range(num_epochs): + step = self.swarm_worker.get_global_step() + local_batch_size = self._get_local_batch_size(step) + + executor = PeriodicDrainThreadPoolExecutor( + workers=local_batch_size * self.config.grpo_n, + max_parallel=self.config.max_env_worker, + auto_retry=True, + ) + + for _, task in enumerate(self.dataset.generate_training_tasks()): + for _ in range(self.config.grpo_n): + _, drained_results = executor.submit_with_periodic_drain( # ✨✨✨✨ + fn=self.rollout, task=task + ) + if drained_results: + rewards = [r for r in drained_results if r is not None] + step = self.swarm_worker.get_global_step() + if rewards: + avg_reward = sum(rewards) / len(rewards) + std_reward = statistics.pstdev(rewards) if len(rewards) > 1 else 0.0 + success_rate = sum(1 for r in rewards if self.is_success(r)) / len(rewards) + line = ( + f"[TRAIN @ step {step}] client={self.CLIENT_LABEL} " + f"batch_size={len(rewards)} mean_reward={avg_reward:.4f} " + f"std_reward={std_reward:.4f} success_rate={success_rate*100:.2f}%" + ) + print(line) + with open(train_log_path, "a") as f: + f.write(line + "\n") + + self.swarm_worker.agree_sync_weight() + if step >= last_eval_step + self.config.eval_interval: + self.run_eval(step) + last_eval_step = step + + if step >= self.config.total_training_steps: + break + + executor.shutdown(wait=False) + if self.swarm_worker.get_global_step() >= self.config.total_training_steps: + break + + finish_flag = os.path.join(self.client_result_dir, "finish.flag") + with open(finish_flag, "w") as f: + f.write(f"Training completed at {time.time()}\n") + print(f"[INFO] {self.CLIENT_LABEL} training complete.") + + + + + + + + + + + + + + + + # ---------------- + # ---------------- shared eval ---------------- + # ---------------- + + def run_eval(self, n_global_step: int) -> None: + if not self.eval_tasks_by_set: + return + eval_log_path = os.path.join( + self.client_result_dir, f"eval_results_{self.CLIENT_LABEL}.log" + ) + for label, eval_tasks in self.eval_tasks_by_set.items(): + self.run_eval_for_one_benchmark_dataset(n_global_step, label, eval_tasks, eval_log_path) + + def run_eval_for_one_benchmark_dataset( + self, + n_global_step: int, + label: str, + eval_tasks: List[Task], + eval_log_path: str, + ) -> None: + k = self.config.eval_k + total_rollouts = len(eval_tasks) * k + print( + f"\n[EVAL @ step {n_global_step}] {self.CLIENT_LABEL}/{label}: " + f"{len(eval_tasks)} tasks x {k} (pass@{k})..." + ) + per_task_rewards: List[List[float]] = [[] for _ in eval_tasks] + pbar = tqdm(total=total_rollouts, desc=f"EVAL {label} @ step {n_global_step}") + + with ThreadPoolExecutor(max_workers=self.config.max_env_worker) as eval_executor: + future_to_idx = { + eval_executor.submit(self.eval_rollout, t): i + for i, t in enumerate(eval_tasks) + for _ in range(k) + } + for fut in as_completed(future_to_idx): + idx = future_to_idx[fut] + try: + per_task_rewards[idx].append(fut.result()) + except Exception as e: + print(f"[EVAL] future error: {e}") + pbar.update(1) + pbar.close() + + flat = [r for rs in per_task_rewards for r in rs if r is not None] + if not flat: + print(f"[EVAL @ step {n_global_step}] {self.CLIENT_LABEL}/{label} no valid rewards") + return + + avg = sum(flat) / len(flat) + std = statistics.pstdev(flat) if len(flat) > 1 else 0.0 + pass1 = sum(1 for r in flat if self.is_success(r)) / len(flat) + num_all_success_tasks = sum( + 1 + for rs in per_task_rewards + if rs and all((r is not None and self.is_success(r)) for r in rs) + ) + num_pass_n_tasks = sum( + 1 + for rs in per_task_rewards + if any((r is not None and self.is_success(r)) for r in rs) + ) + passk = num_pass_n_tasks / len(per_task_rewards) + summary = ( + f"[EVAL @ step {n_global_step}] {self.CLIENT_LABEL}/{label} " + f"mean_reward={avg:.4f} std_reward={std:.4f} " + f"task_pass_rate@1={pass1*100:.2f}% task_pass_rate@{k}={passk*100:.2f}% " + f"n_tasks={len(per_task_rewards)} n_rollouts={len(flat)}" + ) + print(summary) + with open(eval_log_path, "a") as f: + f.write(summary + "\n") + + val_result_path = os.path.join( + self.client_result_dir, f"val_results_{self.CLIENT_LABEL}.md" + ) + with open(val_result_path, "a") as f: + f.write(f"\n## Step {n_global_step} ({label})\n") + f.write(f"- pass_n: {k}\n") + f.write(f"- total_tasks: {len(per_task_rewards)}\n") + f.write(f"- num_all_success_tasks: {num_all_success_tasks}\n") + f.write(f"- num_pass_n_tasks: {num_pass_n_tasks}\n") + f.write(f"- task_pass_rate@1: {pass1*100:.2f}%\n") + f.write(f"- task_pass_rate@{k}: {passk*100:.2f}%\n") + f.write(f"- mean_reward: {avg:.4f}\n") + f.write(f"- std_reward: {std:.4f}\n") + f.write(f"- n_rollouts: {len(flat)}\n") diff --git a/tutorial/example_cocktail_rl_v2/readme.md b/tutorial/example_cocktail_rl_v2/readme.md new file mode 100644 index 00000000..7ac76696 --- /dev/null +++ b/tutorial/example_cocktail_rl_v2/readme.md @@ -0,0 +1,118 @@ +# example_cocktail_rl_v2 + +Cocktail RL on AppWorld + AIME with configurable per-client batch ratios and an optional dynamic schedule. + + + +```bash +# install appworld +rm -rf /tmp/pack_all_in_one & wget https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/astuner_archive/appworld_pack_v3.tar.gz && tar -xzf ./appworld_pack_v3.tar.gz -C /tmp + +cd /mnt/data_cpfs/qingxu.fu/alpha_auto_research/agentjet_codebase + +rm -rf cocktail_results_v2 +source .venv/bin/activate && ajet --autokill +tmux kill-session -t ajet_swarm + +# 创建一个 session,名字叫 ajet_swarm,第一个 pane 跑 appworld +tmux new -d -s ajet_swarm -n main +tmux send-keys -t ajet_swarm:main.0 "bash /tmp/pack_all_in_one/EnvService/env_sandbox/appworld.sh" Enter + +# 第二个 pane:server +tmux split-window -t ajet_swarm:main +tmux send-keys -t ajet_swarm:main.1 "source .venv/bin/activate && ajet-swarm start" Enter + +# 第三个 pane:client 0 (appworld) +tmux split-window -t ajet_swarm:main +tmux send-keys -t ajet_swarm:main.2 "export COCKTAIL_RATIO_SCHEDULE=constant && source .venv/bin/activate && python -m tutorial.example_cocktail_rl_v2.train_appworld_as_swarm_client_0" Enter + +# 第四个 pane:client 1 (aime) +tmux split-window -t ajet_swarm:main +tmux send-keys -t ajet_swarm:main.3 "export COCKTAIL_RATIO_SCHEDULE=constant && source .venv/bin/activate && python -m tutorial.example_cocktail_rl_v2.train_aime_as_swarm_client_1" Enter + +# 把四个 pane 平铺成 2x2 网格 +tmux select-layout -t ajet_swarm:main tiled + +# 进入 session 查看 +tmux attach -t ajet_swarm + +``` + +Edit `CocktailV2Config` defaults (cocktail_v2_runner.py) for `total_batch_size`, `schedule_start`/`schedule_end`/`schedule_end_step`. Engine knobs live in `build_cocktail_ajet_job()` (train_appworld_as_swarm_client_0.py). Both clients must agree on these. + +## Custom result dir + ratio 0.25 (client_0 = appworld) + +`COCKTAIL_RESULT_DIR` overrides `result_dir` (default `./cocktail_results_v2`); `COCKTAIL_SCHEDULE_START` overrides `schedule_start` (client_0's ratio; under `constant` schedule it is the ratio at every step). Both env vars MUST be set to the same value in both client panes. With `total_batch_size=64`, ratio 0.25 → client_0 (appworld) = 16, client_1 (aime) = 48. + +```bash + +# install appworld +rm -rf /tmp/pack_all_in_one & wget https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/astuner_archive/appworld_pack_v3.tar.gz && tar -xzf ./appworld_pack_v3.tar.gz -C /tmp + +cd /mnt/data_cpfs/qingxu.fu/alpha_auto_research/agentjet_codebase + +rm -rf cocktail_results_v2_r025 +source .venv/bin/activate && ajet --autokill +tmux kill-session -t ajet_swarm + +# 创建一个 session,名字叫 ajet_swarm,第一个 pane 跑 appworld +tmux new -d -s ajet_swarm -n main +tmux send-keys -t ajet_swarm:main.0 "bash /tmp/pack_all_in_one/EnvService/env_sandbox/appworld.sh" Enter + +# 第二个 pane:server +tmux split-window -t ajet_swarm:main +tmux send-keys -t ajet_swarm:main.1 "source .venv/bin/activate && ajet-swarm start" Enter + +# 第三个 pane:client 0 (appworld) +tmux split-window -t ajet_swarm:main +tmux send-keys -t ajet_swarm:main.2 "export COCKTAIL_RATIO_SCHEDULE=constant && export COCKTAIL_SCHEDULE_START=0.25 && export COCKTAIL_RESULT_DIR=./cocktail_results_v2_r025 && source .venv/bin/activate && python -m tutorial.example_cocktail_rl_v2.train_appworld_as_swarm_client_0" Enter + +# 第四个 pane:client 1 (aime) +tmux split-window -t ajet_swarm:main +tmux send-keys -t ajet_swarm:main.3 "export COCKTAIL_RATIO_SCHEDULE=constant && export COCKTAIL_SCHEDULE_START=0.25 && export COCKTAIL_RESULT_DIR=./cocktail_results_v2_r025 && source .venv/bin/activate && python -m tutorial.example_cocktail_rl_v2.train_aime_as_swarm_client_1" Enter + +# 把四个 pane 平铺成 2x2 网格 +tmux select-layout -t ajet_swarm:main tiled + +# 进入 session 查看 +tmux attach -t ajet_swarm + +``` + +## Custom result dir + ratio 0.75 (client_0 = appworld) + +Same setup as above, ratio flipped. With `total_batch_size=64`, ratio 0.75 → client_0 (appworld) = 48, client_1 (aime) = 16. + +```bash + +rm -rf /tmp/pack_all_in_one & wget https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/astuner_archive/appworld_pack_v3.tar.gz && tar -xzf ./appworld_pack_v3.tar.gz -C /tmp + +cd /mnt/data_cpfs/qingxu.fu/alpha_auto_research/agentjet_codebase + +rm -rf cocktail_results_v2_r075 +source .venv/bin/activate && ajet --autokill +tmux kill-session -t ajet_swarm + +# 创建一个 session,名字叫 ajet_swarm,第一个 pane 跑 appworld +tmux new -d -s ajet_swarm -n main +tmux send-keys -t ajet_swarm:main.0 "bash /tmp/pack_all_in_one/EnvService/env_sandbox/appworld.sh" Enter + +# 第二个 pane:server +tmux split-window -t ajet_swarm:main +tmux send-keys -t ajet_swarm:main.1 "source .venv/bin/activate && ajet-swarm start" Enter + +# 第三个 pane:client 0 (appworld) +tmux split-window -t ajet_swarm:main +tmux send-keys -t ajet_swarm:main.2 "export COCKTAIL_RATIO_SCHEDULE=constant && export COCKTAIL_SCHEDULE_START=0.75 && export COCKTAIL_RESULT_DIR=./cocktail_results_v2_r075 && source .venv/bin/activate && python -m tutorial.example_cocktail_rl_v2.train_appworld_as_swarm_client_0" Enter + +# 第四个 pane:client 1 (aime) +tmux split-window -t ajet_swarm:main +tmux send-keys -t ajet_swarm:main.3 "export COCKTAIL_RATIO_SCHEDULE=constant && export COCKTAIL_SCHEDULE_START=0.75 && export COCKTAIL_RESULT_DIR=./cocktail_results_v2_r075 && source .venv/bin/activate && python -m tutorial.example_cocktail_rl_v2.train_aime_as_swarm_client_1" Enter + +# 把四个 pane 平铺成 2x2 网格 +tmux select-layout -t ajet_swarm:main tiled + +# 进入 session 查看 +tmux attach -t ajet_swarm + +``` diff --git a/tutorial/example_cocktail_rl_v2/train_aime_as_swarm_client_1.py b/tutorial/example_cocktail_rl_v2/train_aime_as_swarm_client_1.py new file mode 100644 index 00000000..d15cafd7 --- /dev/null +++ b/tutorial/example_cocktail_rl_v2/train_aime_as_swarm_client_1.py @@ -0,0 +1,140 @@ +# -*- coding: utf-8 -*- +""" +AIME swarm client (follower) for example_cocktail_rl_v2. + +python -m tutorial.example_cocktail_rl_v2.train_aime_as_swarm_client_1 +""" + +from __future__ import annotations + +import os +from dataclasses import dataclass +from typing import List + +from ajet.default_config.ajet_config_schema import AjetTaskReader, HuggingfaceDatRepo +from ajet.schema.task import Task +from ajet.task_reader import HuggingFaceTaskReader, RouterTaskReader + +from tutorial.example_cocktail_rl_v2.cocktail_v2_config import ( + CocktailV2Config, + cocktail_v2_config_from_env, +) +from tutorial.example_cocktail_rl_v2.cocktail_v2_runner import CocktailSwarmRunner +from tutorial.opencode_build_aime import download_data +from tutorial.opencode_build_aime.agent_run_v3 import execute_agent as _execute_aime_agent + + +_THIS_DIR = os.path.dirname(__file__) + + +@dataclass +class _AimeAgentConfig: + """Duck-types the subset of AgentJetJob that execute_agent reads.""" + model: str + max_response_length: int + + +def _load_eval_tasks(test_dataset: str, label: str = "") -> List[Task]: + eval_tasks: List[Task] = [] + if not os.path.exists(test_dataset): + print(f"[WARN] Eval dataset not found: {test_dataset}. Skipping {label or test_dataset}.") + return eval_tasks + + eval_reader = HuggingFaceTaskReader( + AjetTaskReader(huggingface_dat_repo=HuggingfaceDatRepo(dataset_path=test_dataset)) + ) + for t in eval_reader.generate_training_tasks(): + eval_tasks.append(t) + print(f"[INFO] Loaded {len(eval_tasks)} eval tasks from {label or test_dataset}") + return eval_tasks + + + + +class AimeRunner(CocktailSwarmRunner): + ROLE = "client_1" + IS_DRIVER = False + CLIENT_LABEL = "aime" + + def __init__(self, v2_config: CocktailV2Config): + super().__init__(v2_config) + am = v2_config.aime + self.EPISODE_TIMEOUT = am.episode_timeout + self.agent_config = _AimeAgentConfig( + model="dummy", + max_response_length=v2_config.max_response_length, + ) + + data_dir = os.path.join(_THIS_DIR, "..", "opencode_build_aime", "data") + self.train_dataset = os.path.join(data_dir, am.train_dataset_filename) + self.test_datasets = { + label: os.path.join(data_dir, fname) + for label, fname in am.test_dataset_filenames.items() + } + + def setup_data(self) -> None: + if not os.path.exists(self.train_dataset): + raise FileNotFoundError( + f"AIME training dataset missing: {self.train_dataset}\n" + "Please run: proxychains python -m tutorial.opencode_build_aime.download_data" + ) + + train_reader = RouterTaskReader( + reader_type="huggingface_dat_repo", + reader_config=AjetTaskReader( + huggingface_dat_repo=HuggingfaceDatRepo(dataset_path=self.train_dataset) + ), + ) + self.dataset = train_reader + + eval_downloaders = { + "AIME-2026": download_data.ensure_aime_2026, + } + for label, path in self.test_datasets.items(): + if not os.path.exists(path): + downloader = eval_downloaders.get(label) + if downloader is None: + print(f"[WARN] {label} parquet missing at {path} and no downloader registered. Skipping.") + continue + print(f"[INFO] {label} parquet missing, downloading...") + try: + downloader() + except Exception as e: + print(f"[WARN] Failed to download {label}: {e}") + continue + tasks = _load_eval_tasks(path, label=label) + if tasks: + self.eval_tasks_by_set[label] = tasks + + def rollout(self, task: Task) -> float: + assert self.swarm_worker is not None + episode_uuid, api_baseurl_key = self.swarm_worker.begin_episode( + discard_episode_timeout=self.EPISODE_TIMEOUT + ) + out = _execute_aime_agent(task, api_baseurl_key, self.agent_config) + self.swarm_worker.end_episode(task, episode_uuid, out) + return out.reward + + def eval_rollout(self, task: Task) -> float: + assert self.swarm_worker is not None + episode_uuid, api_baseurl_key = self.swarm_worker.begin_episode( + discard_episode_timeout=self.EPISODE_TIMEOUT, episode_type="eval" + ) + try: + out = _execute_aime_agent(task, api_baseurl_key, self.agent_config) + return out.reward + finally: + self.swarm_worker.abort_episode(episode_uuid) + + def is_success(self, reward: float) -> bool: + return reward > 0 + + +def main(): + cfg = cocktail_v2_config_from_env() + runner = AimeRunner(cfg) + runner.run() + + +if __name__ == "__main__": + main() diff --git a/tutorial/example_cocktail_rl_v2/train_appworld_as_swarm_client_0.py b/tutorial/example_cocktail_rl_v2/train_appworld_as_swarm_client_0.py new file mode 100644 index 00000000..86e8e14a --- /dev/null +++ b/tutorial/example_cocktail_rl_v2/train_appworld_as_swarm_client_0.py @@ -0,0 +1,192 @@ +# -*- coding: utf-8 -*- +""" +AppWorld swarm client (driver) for example_cocktail_rl_v2. + +python -m tutorial.example_cocktail_rl_v2.train_appworld_as_swarm_client_0 +""" + +from __future__ import annotations + +import os +import random +from typing import Iterator, List, Optional + +from ajet.copilot.job import AgentJetJob +from ajet.schema.task import Task +from ajet.utils.env_service_client.env_client_ng import EnvClient + +from tutorial.example_cocktail_rl_v2.cocktail_v2_config import CocktailV2Config, cocktail_v2_config_from_env +from tutorial.example_cocktail_rl_v2.cocktail_v2_runner import CocktailSwarmRunner + + +# ---------------- Engine config (was cocktail_rl_conf.yaml) ---------------- + +def build_cocktail_ajet_job(cfg: CocktailV2Config) -> AgentJetJob: + """Construct the AgentJetJob that drives the swarm engine. + + Every value is read from `cfg`. There are no hardcoded constants in this + function -- CocktailV2Config is the single source of truth for the entire + engine config. Fields not exposed as AgentJetJob kwargs are set on + `ajet_job.config.ajet.*` after construction and shipped to the engine via + `Config.to_dict()`. + """ + ajet_job = AgentJetJob( + # base_yaml_config=None -> use ajet/default_config/ajet_swarm_default.yaml + project_name=cfg.project_name, + experiment_name=cfg.experiment_name, + experiment_dir=cfg.experiment_dir, + model=cfg.model_path, + algorithm=cfg.algorithm, + num_repeat=cfg.grpo_n, + # batch_size is ignored under rollout_until_all_clients_agree_sync_weight, + # but we mirror cfg.total_batch_size so the dumped engine config reads coherently. + batch_size=cfg.total_batch_size, + swarm_mode=cfg.swarm_mode, + swarm_mode_sample_collection_method=cfg.swarm_mode_sample_collection_method, + max_env_worker=cfg.max_env_worker, + max_prompt_length=cfg.max_prompt_length, + max_response_length=cfg.max_response_length, + max_response_length_in_one_turn=cfg.max_response_length_in_one_turn, + max_model_len=cfg.max_model_len, + max_num_seqs=cfg.max_num_seqs, + compute_madness_checklist=list(cfg.compute_madness_checklist), + n_gpu=cfg.n_gpu, + logging=cfg.logging, + use_kl_loss=cfg.use_kl_loss, + use_kl_in_reward=cfg.use_kl_in_reward, + kl_penalty_type=cfg.kl_penalty_type, + total_training_steps=cfg.total_training_steps, + timeline_compare_level="token", + ) + + # Fields not exposed as AgentJetJob kwargs. + rollout = ajet_job.config.ajet.rollout + rollout.temperature = cfg.temperature + rollout.force_disable_toolcalls = cfg.force_disable_toolcalls + rollout.agent_madness_reward = cfg.agent_madness_reward + rollout.tensor_model_parallel_size = cfg.tensor_model_parallel_size + rollout.multi_turn = { + "max_sample_per_task": cfg.multi_turn_max_sample_per_task, + "max_steps": cfg.max_steps, + } + + trainer = ajet_job.config.ajet.trainer_common + trainer.save_freq = cfg.save_freq + trainer.test_freq = cfg.test_freq + trainer.total_epochs = cfg.total_epochs + trainer.nnodes = cfg.nnodes + trainer.val_pass_n = cfg.val_pass_n + trainer.val_before_train = cfg.val_before_train + + ajet_job.config.ajet.debug = { + "debug_max_parallel": cfg.debug_max_parallel, + "debug_first_n_tasks": cfg.debug_first_n_tasks, + } + + return ajet_job + + +# ---------------- AppWorld task / runner glue ---------------- + +def _get_appworld_tasks(env_url: str, env_type: str, split: str) -> List[Task]: + env_client = EnvClient(base_url=env_url) + task_id_array = env_client.get_env_profile(env_type, split=split) + if len(task_id_array) == 0: + raise ValueError( + f"No task_id found for env_type={env_type}, split={split}, " + f"check connection to {env_url}" + ) + return [ + Task( + main_query="[not defined]", + init_messages=[], + task_id=str(task_id), + env_type=env_type, + metadata={}, + ) + for task_id in task_id_array + ] + + +class ShuffledTaskDataset: + def __init__(self, tasks: List[Task]): + self.tasks = list(tasks) + + def generate_training_tasks(self) -> Iterator[Task]: + pool = list(self.tasks) + random.shuffle(pool) + for t in pool: + yield t + + +class AppWorldRunner(CocktailSwarmRunner): + ROLE = "client_0" + IS_DRIVER = True + CLIENT_LABEL = "appworld" + + def __init__(self, v2_config: CocktailV2Config): + super().__init__(v2_config) + ap = v2_config.appworld + self.env_url: str = ap.env_url + self.env_type: str = ap.env_type + self.training_split: str = ap.training_split + self.validation_split: str = ap.validation_split + self.max_steps: int = v2_config.max_steps + self.EPISODE_TIMEOUT = ap.episode_timeout + + def build_ajet_job(self) -> Optional[AgentJetJob]: + return build_cocktail_ajet_job(self.config) + + def setup_data(self) -> None: + train_tasks = _get_appworld_tasks(self.env_url, self.env_type, self.training_split) + print(f"[INFO] AppWorld training: {len(train_tasks)} tasks (split={self.training_split})") + self.dataset = ShuffledTaskDataset(train_tasks) + + eval_tasks = _get_appworld_tasks(self.env_url, self.env_type, self.validation_split) + print(f"[INFO] AppWorld eval: {len(eval_tasks)} tasks (split={self.validation_split})") + self.eval_tasks_by_set = {self.validation_split: eval_tasks} + + def rollout(self, task: Task) -> float: + assert self.swarm_worker is not None + episode_uuid, api_baseurl_key = self.swarm_worker.begin_episode( + discard_episode_timeout=self.EPISODE_TIMEOUT + ) + out = self._execute(task, api_baseurl_key) + self.swarm_worker.end_episode(task, episode_uuid, out) + return out.reward + + def eval_rollout(self, task: Task) -> float: + assert self.swarm_worker is not None + episode_uuid, api_baseurl_key = self.swarm_worker.begin_episode( + discard_episode_timeout=self.EPISODE_TIMEOUT, episode_type="eval" + ) + try: + out = self._execute(task, api_baseurl_key) + return out.reward + finally: + self.swarm_worker.abort_episode(episode_uuid) + + def is_success(self, reward: float) -> bool: + # Mirrors EnvServiceJudge partial-credit shaping: full success requires + # raw_reward >= 1, which corresponds to final_reward >= 1.0 here. + return reward >= 1.0 + + def _execute(self, task: Task, api_baseurl_key): + import asyncio + from tutorial.example_appworld_swarm.appworld_swarm import ExampleAgentScopeWorkflow + wf = ExampleAgentScopeWorkflow( + env_url=self.env_url, + env_type=self.env_type, + max_steps=self.max_steps, + ) + return asyncio.run(wf.execute(task, api_baseurl_key)) + + +def main(): + cfg = cocktail_v2_config_from_env() + runner = AppWorldRunner(cfg) + runner.run() + + +if __name__ == "__main__": + main() diff --git a/tutorial/opencode_build_aime/agent_roll_v3.py b/tutorial/opencode_build_aime/agent_roll_v3.py index e7b80895..7ec62a7d 100644 --- a/tutorial/opencode_build_aime/agent_roll_v3.py +++ b/tutorial/opencode_build_aime/agent_roll_v3.py @@ -217,8 +217,8 @@ def train(self): """Main training loop.""" # Run eval once before training starts (baseline) self.run_eval(0) + last_eval_step = 0 - task_count = 0 max_parallel = 64 executor = TaskCountLimitedThreadPoolExecutor( max_parallel_groups=BATCH_SIZE, @@ -233,13 +233,12 @@ def train(self): args_list = [{"task": task} for _ in range(self.grpo_n)] executor.submit_group(task_id=task.task_id, fn=self.rollout, args_list=args_list) - task_count += 1 + n_global_step = self.swarm_worker.get_global_step() - # Periodic evaluation every EVAL_INTERVAL * REMOTE_BATCH_SIZE tasks - time_to_eval = task_count % (EVAL_INTERVAL * self.remote_batch_size) == 0 - n_global_step = task_count // self.remote_batch_size + time_to_eval = n_global_step >= last_eval_step + EVAL_INTERVAL if time_to_eval: self.run_eval(n_global_step) + last_eval_step = n_global_step print("\n[INFO] Training complete!") diff --git a/tutorial/opencode_build_aime/agent_run_v3.py b/tutorial/opencode_build_aime/agent_run_v3.py index 228b5971..68f1ceac 100644 --- a/tutorial/opencode_build_aime/agent_run_v3.py +++ b/tutorial/opencode_build_aime/agent_run_v3.py @@ -28,6 +28,7 @@ from ajet.copilot.job import AgentJetJob from ajet.schema.task import Task, WorkflowOutput from ajet.tuner_lib.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey +from ajet.utils.message_utils import is_token_overflow_message TIMEOUT_EXIT_CODE = -101 @@ -359,6 +360,11 @@ async def run(self, messages: list[dict], sampling_params: dict) -> tuple[str, l total_tokens_used += response.usage.total_tokens if response.usage else 0 + # AgentJet signals prompt overflow via a synthetic assistant message; further turns would only grow the prompt, so stop now. + if is_token_overflow_message(response_content): + formatted_messages.append({"role": "assistant", "content": response_content}) + break + if response_message.tool_calls: for tool_call in response_message.tool_calls: history_tool_calls.append({