.gitignore (6 changes: 6 additions & 0 deletions)
@@ -185,3 +185,9 @@ research_*.json
research_*.jsonc
daemon_logs*
paper
val_results.md
cocktail_vs_separate*
cocktail_results_*
PRINCIPLE.md
TODO.md
plot_mean_reward.py
ajet/backbone/main_verl.py (9 changes: 8 additions & 1 deletion)
@@ -132,9 +132,16 @@ def run(self, config):

# Instantiate the tokenizer and processor.
from verl.utils import hf_processor, hf_tokenizer
from ajet.tokenizer.service import start_tokenizer_service

trust_remote_code = config.data.get("trust_remote_code", False)
tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
local_tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
# Cache hot tokenization calls (encode / decode / apply_chat_template)
# in a sidecar process; every other tokenizer attribute is served by
# the local instance directly.
tokenizer = start_tokenizer_service(
local_tokenizer, local_path, trust_remote_code=trust_remote_code
)
# Used for multimodal LLM, could be None
processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)

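The new `start_tokenizer_service` call wraps the plain HF tokenizer so that hot calls (`encode` / `decode` / `apply_chat_template`) are served from a cache while every other attribute falls through to the local instance. The service implementation is not part of this diff, so the following is only a minimal in-process sketch of the idea; the real version runs in a sidecar process, and every name below other than `start_tokenizer_service` is an assumption for illustration:

```python
# Minimal sketch of the caching idea behind start_tokenizer_service.
# Assumption: the real service caches in a sidecar process; this sketch
# caches in-process to stay self-contained (apply_chat_template would be
# memoized the same way as encode/decode).
from functools import lru_cache


class CachingTokenizerProxy:
    def __init__(self, tokenizer, cache_size: int = 65536):
        self._tokenizer = tokenizer
        # Hot path: memoize text -> ids on the text itself.
        self._encode = lru_cache(maxsize=cache_size)(
            lambda text: tuple(tokenizer.encode(text))
        )
        # Hot path: memoize ids -> text; lists are unhashable, so key on tuples.
        self._decode = lru_cache(maxsize=cache_size)(
            lambda ids: tokenizer.decode(list(ids))
        )

    def encode(self, text: str):
        return list(self._encode(text))

    def decode(self, token_ids):
        return self._decode(tuple(token_ids))

    def __getattr__(self, name):
        # Every other tokenizer attribute is served by the local instance.
        return getattr(self._tokenizer, name)
```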
ajet/backbone/trainer_verl.py (7 changes: 4 additions & 3 deletions)
@@ -511,15 +511,16 @@ def fit(self): # noqa: C901
]
)
)
logger.info("start fit rollout")
logger.info("start batch rollout")
self.parallel_env.current_global_steps = self.global_steps
# rollout stage begin ✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨
context_tracker_arr: List[SingleAgentContextTracker] = self.parallel_env.rollout(
tasks, mode="sample", epoch=f"train.{epoch}"
)

# from ajet import bp; bp("BATCH")

logger.info("end fit rollout")
logger.info("end batch rollout")
gen_batch_output = self.parallel_env.to_dataproto(context_tracker_arr)
logger.info("end dataproto convertion")

@@ -710,7 +711,7 @@ def fit(self): # noqa: C901

# implement critic warmup
if self.config.trainer.critic_warmup <= self.global_steps:
# update actor
# update actor ✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨
with marked_timer("update_actor", timing_raw, color="red"):
actor_output = self._update_actor(batch)
actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
ajet/context_tracker/multiagent_tracking.py (86 changes: 53 additions & 33 deletions)
@@ -130,6 +130,8 @@ def step_spawn_timeline(self, messages: List[dict], tools: List = [], disable_to
if disable_toolcalls:
consider_roles.remove("tool")

previous_message_encounter_user_role = False

for i, msg in enumerate(messages):

if (disable_toolcalls) and (not isinstance(msg["content"], str)):
@@ -166,6 +168,11 @@ def step_spawn_timeline(self, messages: List[dict], tools: List = [], disable_to
else:
author = "env"

if msg["role"] == "user":
previous_message_encounter_user_role = True

any_later_msg_has_user_role = any((m["role"] == "user") for m in messages[i+1:])

# extract content block from openai-competible messages and convert to ExtendedMessage
timeline += [
ExtendedMessage(
@@ -179,8 +186,11 @@ def step_spawn_timeline(self, messages: List[dict], tools: List = [], disable_to
token_generator="auto",
name = (msg["name"] if "name" in msg else ""),
first_message=(i == 0),
before_last_query=any_later_msg_has_user_role
)
]
if ("<think>" in msg["content"]) and (not previous_message_encounter_user_role):
logger.warning(f"Warning! Message content contains <think> tag, but no prior message has `user` role! This is not a common scenario. Please check your agent loop carefully.")

return timeline

@@ -270,7 +280,6 @@ def step_track(
if (
"prompt_text" in llm_output and "prompt_token_ids" in llm_output
):
# currently we make this patch to better compat with Trinity training backend
# fix Retokenization Drift
timeline = self.patch_prompt_tokens(
prompt_text=llm_output["prompt_text"],
@@ -321,7 +330,6 @@ def save_llm_interaction_timeline(self, tools, llm_ext_msg, timeline):
)
):
logger.bind(exception=True).info(f"General Warning: merge failure discovered.\n")
# from ajet import bp; bp("SWARM")
return


@@ -393,13 +401,19 @@ def patch_prompt_tokens(
prompt_token_ids: List[int],
previous_ext_context: List[ExtendedMessage],
) -> List[ExtendedMessage]:
"""
fix retokenization drift
prompt_text = llm_output["prompt_text"]: [this llm call] the prompt in text format used in generation
prompt_token_ids = llm_output["prompt_token_ids"]: [this llm call] the prompt token ids used in generation (prompt_text->prompt_token_ids using tokenizer)
previous_ext_context: [from previous context] the context history
"""

# remove tailing
# remove tailing, usually `<|im_start|> assistant`
if prompt_text.endswith(self.generation_prompt):
prompt_text = prompt_text[: -len(self.generation_prompt)]
# prompt_token_ids = prompt_token_ids[: -len(self.generation_prompt_token)]

# split prompt token ids into message level
# split CURRENT prompt token ids into message level (split_prompt_token_ids is List[List[int]])
split_prompt_token_ids = []
tmp = []
for i in range(len(prompt_token_ids)):
@@ -412,25 +426,32 @@
if len(tmp) > 0:
split_prompt_token_ids += [tmp]

# split prompt text into message level
# split CURRENT prompt text into message level (corresponding to split_prompt_token_ids)
prompt_text_split = prompt_text.split("<|im_start|>")
assert prompt_text_split[0] == "", "Prompt text should start with <|im_start|>"
prompt_text_split = prompt_text_split[1:] # remove the first empty string
for i in range(len(prompt_text_split)):
prompt_text_split[i] = "<|im_start|>" + prompt_text_split[i]

# context HISTORY prompt text
current_prompt_text = []
for j in range(len(previous_ext_context)):
current_prompt_text += [self.tokenizer.decode(previous_ext_context[j].token_arr)]

# HISTORY context length vs CURRENT prompt length
if len(previous_ext_context) != len(prompt_text_split):
logger.bind(exception=True).error(
f"Length mismatch when patching prompt tokens. Previous ext context length: {len(previous_ext_context)}, prompt text split length: {len(prompt_text_split)}. Replacing all tokens."
)

# try to recover tokens
if self.config.ajet.context_tracker.fix_retokenization_drift:
self.ensure_retokenization_perfect_match(previous_ext_context, split_prompt_token_ids, prompt_text_split, current_prompt_text)
self.ensure_retokenization_perfect_match(
previous_ext_context, # HISTORY
split_prompt_token_ids, # CURRENT
prompt_text_split, # CURRENT
current_prompt_text # HISTORY
)

# remove extra messages
if len(previous_ext_context) != len(prompt_text_split):
@@ -440,39 +461,38 @@


def ensure_retokenization_perfect_match(self, previous_ext_context, split_prompt_token_ids, prompt_text_split, current_prompt_text):
"""
Ensure the retokenization is perfectly matched between HISTORY and CURRENT

previous_ext_context: the context history in ExtendedMessage format, which contains token_arr (token ids)
split_prompt_token_ids: the prompt token ids of CURRENT prompt, split into message level (List[List[int]])
prompt_text_split: the prompt text of CURRENT prompt, split into message level (List[str])
current_prompt_text: the prompt text of HISTORY context, converted from token_arr to text using tokenizer, in message level (List[str])
"""

for j in range(len(previous_ext_context)):
if prompt_text_split[j] != current_prompt_text[j]:
# if prompt text mismatch, we can replace the tokens
vllm_token_array = split_prompt_token_ids[j]
tracker_token_array = previous_ext_context[j].token_arr
if vllm_token_array == tracker_token_array:
# good, everything is perfect
continue
else:
from ajet import bp; bp("SWARM")
# otherwise, we throw a warning (do not worry, this causes almost no influence in the training)
print_dict(
{
"expected_prompt_text": prompt_text_split[j],
"current_prompt_text": current_prompt_text[j],
"expected_prompt_text": prompt_text_split[j], # from llm_output["prompt_text"]
"current_prompt_text": current_prompt_text[j], # history prompt text converted from token_arr to text using tokenizer
"expected_token_ids": vllm_token_array, # from llm_output["prompt_token_ids"]
"current_token_ids": tracker_token_array, # from previous_ext_context[j].token_arr
},
mod="exception",
header="Prompt text mismatch, Please report a github issue",
header="Prompt token ids mismatch.",
)
previous_ext_context[j].token_arr = self.tokenizer(
prompt_text_split[j], return_tensors="pt", padding=False
)
else:
# if prompt text match
# we further check whether all token ids matches
vllm_token_array = split_prompt_token_ids[j]
tracker_token_array = previous_ext_context[j].token_arr
if vllm_token_array == tracker_token_array:
# good, everything is perfect
continue
else:
# otherwise, we throw a warning (do not worry, this causes almost no influence in the training)
print_dict(
{
"expected_token_ids": split_prompt_token_ids[j],
"current_token_ids": previous_ext_context[j].token_arr,
},
mod="exception",
header="Prompt token ids mismatch, Please report a github issue",
)

# # fix drift
# previous_ext_context[j].token_arr = self.tokenizer(
# prompt_text_split[j], return_tensors="pt", padding=False
# )["input_ids"].tolist()

def process_reward(self, reward_structure: Reward):
self.reward_structure = reward_structure
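`patch_prompt_tokens` and `ensure_retokenization_perfect_match` exist because of retokenization drift: the token ids an inference engine actually used are not always reproduced by freshly re-encoding the corresponding text. A self-contained illustration of the effect is below; the model name is a placeholder, and any BPE tokenizer shows the same behavior:

```python
# Illustration of retokenization drift: tokenization is not stable under
# text round-trips or concatenation, which is why the tracker trusts the
# engine's prompt_token_ids over a fresh encode of prompt_text.
# Assumption: the model name is a placeholder; any BPE tokenizer works.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")

part_a, part_b = "Hello wor", "ld!"
piecewise_ids = tok.encode(part_a) + tok.encode(part_b)  # ids built per message
fresh_ids = tok.encode(part_a + part_b)                  # ids from one fresh encode

print(tok.decode(piecewise_ids) == tok.decode(fresh_ids))  # True: same text
print(piecewise_ids == fresh_ids)                          # may be False: drifted ids
```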
ajet/context_tracker/single_agent_tracking.py (4 changes: 2 additions & 2 deletions)
@@ -91,7 +91,7 @@ def get_token_inc_from_llm_response(
self.generated_token_cnt += len(vllm_output_raw_token)
if not self.generation_prompt_token:
self.generation_prompt_token = self.get_generation_prompt_token()
final_token_arr, token_logprob_arr, loss_mask, lack_normal_eos = replace_token_ids(
final_token_arr, token_logprob_arr, loss_mask, lack_normal_eos = replace_token_ids( # pad tokens and logprobs with begin_ids / other_ids / NA
token_container=completion_token_arr,
precise_token=vllm_output_raw_token,
precise_logprob=vllm_output_raw_logprob,
@@ -187,7 +187,7 @@ def to_role_content(self, ext_msg_array: List[ExtendedMessage]) -> List:
for ext_msg in ext_msg_array:
d: dict = {
"role": ext_msg.role,
"content": ext_msg.content_for_compare,
"content": ext_msg.text_content_for_compare,
}
if ext_msg.tool_calls:
d.update({"tool_calls": ext_msg.tool_calls})
ajet/context_tracker/timeline_merging/timeline_merging.py (12 changes: 6 additions & 6 deletions)
@@ -21,8 +21,8 @@ def is_timeline_mergeable(
for i in range(len(target_timeline)):
if timeline_compare_level == "text":
same = (
source_timeline[i].content_for_compare
== target_timeline[i].content_for_compare
source_timeline[i].text_content_for_compare
== target_timeline[i].text_content_for_compare
)
elif timeline_compare_level == "token":
same = source_timeline[i].token_arr == target_timeline[i].token_arr
@@ -52,12 +52,12 @@
# all_msg_match = False
# for i in range(len(target_timeline)):
# d = {}
# d["source"] = source_timeline[i].content_for_compare
# d["target"] = target_timeline[i].content_for_compare
# d["source"] = source_timeline[i].text_content_for_compare
# d["target"] = target_timeline[i].text_content_for_compare
# if timeline_compare_level == "text":
# same = (
# source_timeline[i].content_for_compare
# == target_timeline[i].content_for_compare
# source_timeline[i].text_content_for_compare
# == target_timeline[i].text_content_for_compare
# )
# elif timeline_compare_level == "token":
# same = source_timeline[i].token_arr == target_timeline[i].token_arr
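The rename from `content_for_compare` to `text_content_for_compare` lines up with the two comparison granularities the merging policy supports (see the new `timeline_compare_level` option in `ajet/copilot/job.py` below). A hedged sketch of the per-message check follows; `ExtendedMessage` is reduced to the two fields the comparison touches, and the real class and function signatures in the repo differ:

```python
# Sketch of the two comparison levels used by is_timeline_mergeable.
# Assumption: ExtendedMessage is reduced to the two compared fields here.
from dataclasses import dataclass, field
from typing import List


@dataclass
class ExtendedMessage:
    text_content_for_compare: str = ""
    token_arr: List[int] = field(default_factory=list)


def messages_match(a: ExtendedMessage, b: ExtendedMessage, level: str) -> bool:
    if level == "text":
        # Relaxed: tolerant of token-level drift, so timelines merge more often.
        return a.text_content_for_compare == b.text_content_for_compare
    if level == "token":
        # Strict: any retokenization drift blocks the merge.
        return a.token_arr == b.token_arr
    raise ValueError(f"unknown timeline_compare_level: {level!r}")
```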
ajet/copilot/create-keep-think-model-chat-template/SKILL.md (16 changes: 16 additions & 0 deletions)
@@ -0,0 +1,16 @@


Your task is to investigate the chat template of the given model: go to its tokenizer config and check whether the following behavior exists:

>
> Remove history <think> blocks from the input when applying the chat template to convert messages.
>

This behavior will make RL training slower (rewriting earlier turns defeats prefix caching and reintroduces retokenization drift); if it exists, please change the chat template to forbid it.

You must not do this in place; instead, please create another model.
E.g., "/mnt/data_cpfs/xielipeng.xlp/models/Qwen3-8B" -> "/mnt/data_cpfs/xielipeng.xlp/models/Qwen3-8B-Keep-History"
For all files within the original model path, please create symbolic links instead of copying files.
The only exception is the tokenizer config file, which should be copied and modified to change the chat template.


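A minimal sketch of the derivation this skill asks for: symlink every file from the source model directory and copy only the tokenizer config so its chat template can be edited. It assumes the usual HF layout (the template lives in `tokenizer_config.json`), and the template edit itself is left as a comment since it depends on the model:

```python
# Sketch: create a "-Keep-History" model directory by symlinking all files
# and copying only the tokenizer config for editing.
# Assumption: the chat template lives in tokenizer_config.json (HF layout).
import json
import os


def make_keep_history_model(src: str, dst: str) -> None:
    os.makedirs(dst, exist_ok=True)
    for name in os.listdir(src):
        src_path = os.path.join(src, name)
        dst_path = os.path.join(dst, name)
        if name == "tokenizer_config.json":
            with open(src_path) as f:
                cfg = json.load(f)
            # cfg["chat_template"] = ...  # edit here so history <think>
            # blocks are no longer stripped when the template is applied
            with open(dst_path, "w") as f:
                json.dump(cfg, f, indent=2, ensure_ascii=False)
        else:
            os.symlink(src_path, dst_path)


make_keep_history_model(
    "/mnt/data_cpfs/xielipeng.xlp/models/Qwen3-8B",
    "/mnt/data_cpfs/xielipeng.xlp/models/Qwen3-8B-Keep-History",
)
```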
ajet/copilot/job.py (16 changes: 14 additions & 2 deletions)
@@ -67,6 +67,14 @@ class AgentJetJob:
whose ``num_repeat`` episodes do *not* all share the same reward. Tasks with uniform
reward (e.g. all 0 or all 1) produce zero advantage under GRPO and are skipped —
useful when the dataset contains many too-easy or too-hard prompts.
- "rollout_until_any_client_agree_sync_weight": defer the stop decision to the swarm
clients themselves. Stops as soon as **any** active swarm client invokes
``SwarmClient.agree_sync_weight()``. A client is "active" once it has successfully
``end_episode``'d at least one rewarded (non-abort) episode since the last weight
sync, and falls off the active list after 10 minutes of no chat-completion or
``begin_episode`` activity.
- "rollout_until_all_clients_agree_sync_weight": like the above, but stops only when
**every** active swarm client has agreed (and there is at least one active client).
max_env_worker: an estimation about how many episodes will be running in parallel (all swarm clients combined).
backbone: Training backbone framework (e.g., 'verl').
max_prompt_length: Maximum token length for input prompts (token length before the first llm-generated token, default 3000).
@@ -90,6 +98,7 @@ class AgentJetJob:
val_print_to_markdown_file_path: Path to a file where validation metrics are appended after every validation pass (default None, disabled).
train_print_to_markdown_file_path: Path to a file where training metrics are appended after every training step (default None, disabled).
total_training_steps: Hard cap on total training steps. If None (default), training runs for `total_epochs` epochs.
timeline_compare_level: Comparison granularity used by the context tracker's timeline merging policy. One of 'text' (relaxed text compare, more aggressive merging, very low cost) or 'token' (strict token compare, less aggressive merging). Default 'text'.
"""

def __init__(
Expand Down Expand Up @@ -130,6 +139,7 @@ def __init__(
val_print_to_markdown_file_path: str | None = None,
train_print_to_markdown_file_path: str | None = None,
total_training_steps: int | None = None,
timeline_compare_level: str | None = None,
) -> None:

if base_yaml_config is None:
Expand Down Expand Up @@ -195,6 +205,7 @@ def __init__(
self.val_print_to_markdown_file_path: str = cast(str, val_print_to_markdown_file_path)
self.train_print_to_markdown_file_path: str = cast(str, train_print_to_markdown_file_path)
self.total_training_steps: int = cast(int, total_training_steps)
self.timeline_compare_level: str = cast(str, timeline_compare_level)

# see `ajet/default_config/ajet_swarm_default.yaml`
overrides = {
@@ -230,9 +241,10 @@ def __init__(
"ajet.trainer_common.use_kl_in_reward": "use_kl_in_reward",
"ajet.trainer_common.kl_penalty_type": "kl_penalty_type",
"ajet.rollout.compute_madness_checklist": "compute_madness_checklist",
"ajet.trainer_common.val_print_to_markdown_file_path": "val_print_to_markdown_file_path",
"ajet.trainer_common.train_print_to_markdown_file_path": "train_print_to_markdown_file_path",
"ajet.trainer_common.total_training_steps": "total_training_steps",
"ajet.trainer_common.val_print_to_markdown_file_path": "val_print_to_markdown_file_path",
"ajet.trainer_common.train_print_to_markdown_file_path": "train_print_to_markdown_file_path",
"ajet.context_tracker.timeline_merging_policy.timeline_compare_level": "timeline_compare_level",
}

# if any value given in kwargs, override the corresponding value in config
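A hedged usage sketch for the knobs added in this diff. The remaining constructor arguments (task set, model path, etc.) are elided and assumed to have workable defaults, so treat this as an illustration of the new parameters rather than a complete job definition:

```python
# Usage sketch for the options added in this diff; other constructor
# arguments are elided (assumption: their defaults are workable here).
from ajet.copilot.job import AgentJetJob

job = AgentJetJob(
    # Maps to ajet.context_tracker.timeline_merging_policy.timeline_compare_level:
    # "text" = relaxed compare / aggressive merging, "token" = strict compare.
    timeline_compare_level="text",
    # Hard cap on training steps; None falls back to total_epochs.
    total_training_steps=500,
    # Append validation metrics to this file after every validation pass.
    val_print_to_markdown_file_path="val_results.md",
)
```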
ajet/copilot/monitor-with-tmux/SKILL.md (8 changes: 8 additions & 0 deletions)
@@ -180,3 +180,11 @@ $ python3 /tmp/tmux_wait.py ajet_session 240 && tmux capture-pane -t ajet_sessio
tmux kill-session -t ajet_session

```


## For AgentJet Swarm

- You should create seperate tmux session for each agentjet swarm servers and each agentjet swarm clients
- When debugging, please do not restart agentjet swarm servers frequently, that waste too much time
- When you really having difficulty for clearing GPU memory, run `ajet --autokill` to automatically kill all python and ray processes (however, I still recommend using this as a last resort).
Comment on lines +187 to +189

medium

There are a few typos and grammatical errors in this section that could be corrected for clarity.

Suggested change
- You should create seperate tmux session for each agentjet swarm servers and each agentjet swarm clients
- When debugging, please do not restart agentjet swarm servers frequently, that waste too much time
- When you really having difficulty for clearing GPU memory, run `ajet --autokill` to automatically kill all python and ray processes (however, I still recommend using this as a last resort).
- You should create separate tmux sessions for each AgentJet swarm server and each AgentJet swarm client.
- When debugging, please do not restart AgentJet swarm servers frequently, as that wastes a lot of time.
- If you are having difficulty clearing GPU memory, run `ajet --autokill` to automatically kill all Python and Ray processes (however, I still recommend using this as a last resort).
- For AgentJet, always use a tmux session name that starts with `ajet-*`