feat: add logprob cross-engine benchmark #107
📝 Walkthrough
Adds an end-to-end logprob cross-engine benchmark: fixture/schema, synthetic/HF/vLLM rollout builders, training replay scoring, per-token drift comparison with thresholds, CLI entrypoint and wrapper, output persistence, docs, tests, and .gitignore updates.
Changes: Logprob Cross-Engine Benchmark
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Actionable comments posted: 1
🧹 Nitpick comments (1)
rl_engine/benchmarks/logprob_cross_engine.py (1)
559-562: 💤 Low value
Consider logging the exception when tokenizer loading fails silently.
The broad `except Exception` swallows every ordinary error, including unexpected ones such as `MemoryError`. (It does not catch `KeyboardInterrupt`, which derives from `BaseException`.) While the fallback to `tokenizer = None` is intentional, logging the actual exception would help debugging when tokenizer loading fails unexpectedly.
Suggested improvement
      try:
          transformers = _import_transformers()
          tokenizer = transformers.AutoTokenizer.from_pretrained(
              config.tokenizer or config.model,
              revision=config.model_revision,
              trust_remote_code=config.trust_remote_code,
          )
    - except Exception:
    + except Exception as exc:
    +     import logging
    +     logging.debug("Optional tokenizer loading failed: %s", exc)
          tokenizer = None
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@rl_engine/benchmarks/logprob_cross_engine.py` around lines 559 - 562, The try/except that sets tokenizer = None swallows errors silently; modify the except block around the tokenizer load (the try that uses trust_remote_code=config.trust_remote_code and assigns to tokenizer) to capture the exception (e.g. except Exception as e:) and log it before falling back, e.g. logger.exception or logger.error(..., exc_info=True), then keep tokenizer = None so behavior doesn't change.
Source: Linters/SAST tools
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@rl_engine/benchmarks/logprob_cross_engine.py`:
- Line 1517: The condition that checks token_id uses item.get("token_id"), which
mypy flags as possibly None even after an 'in' check; change the expression in
the if-condition to use direct key access so mypy can infer non-None: replace
int(item.get("token_id")) with int(item["token_id"]) in the condition that
begins with if "logprob" in item and ("token_id" not in item or ...), keeping
the same logic and cast to int(item["token_id"]) == token_id.
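The narrowing issue described above can be sketched in isolation. This is a minimal illustration, not the PR's code: `matches_token` is a hypothetical helper, and the claim that a checker such as mypy treats `.get()` as possibly `None` depends on how `item` is annotated in the real module.

```python
from typing import Any, Mapping


def matches_token(item: Mapping[str, Any], token_id: int) -> bool:
    """Return True when `item` carries a logprob that applies to `token_id`.

    Hypothetical helper; the name and signature are illustrative only.
    """
    if "logprob" not in item:
        return False
    # After the membership check, direct key access yields the mapping's value
    # type, whereas item.get("token_id") is typed as possibly None.
    return "token_id" not in item or int(item["token_id"]) == token_id
```

The behavior is unchanged relative to the `.get()` form; only the static type seen by the checker differs.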
📒 Files selected for processing (6)
- .gitignore
- benchmarks/logprob_cross_engine.py
- docs/benchmarking/README.md
- rl_engine/benchmarks/__init__.py
- rl_engine/benchmarks/logprob_cross_engine.py
- tests/test_logprob_cross_engine.py
c76f784 to 847abb6
Signed-off-by: inaniloquentee <3051000145@qq.com>
847abb6 to 33cadd0
Actionable comments posted: 4
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@rl_engine/benchmarks/logprob_cross_engine.py`:
- Around line 907-923: The JSONL ingestion flattens nested
LogprobBenchmarkFixture instances but discards their metadata (model, tokenizer,
dtype, rollout_engine, schema_version), causing downstream replay to use wrong
defaults; update the logic in the JSONL loader (the branch that calls
LogprobBenchmarkFixture.from_dict and extends sequences) to also capture and
merge the first non-default metadata fields from the nested fixture into the
top-level LogprobBenchmarkFixture you return (e.g., set model, tokenizer, dtype,
rollout_engine, and schema_version from the nested fixture when they are present
and not placeholder values like "fixture"/"unknown"/None), and also merge/append
nested metadata dicts into the top-level metadata field so provenance is
preserved; keep using LogprobSequence.from_dict for individual sequence lines
and preserve default_sequence_id behavior.
- Around line 1183-1192: The helper _call_model should not blanket-catch all
TypeError instances; instead, inspect the model's callable signature and pass
only kwargs that the model actually accepts. Replace the current
retry-on-TypeError logic with code that obtains the callable (prefer
model.forward if available, else model.__call__), uses inspect.signature(...) to
collect parameter names, builds a filtered_kwargs containing only keys present
in that parameter set (keep input_ids as a positional/keyword argument), and
then calls the model once with input_ids and filtered_kwargs; remove the nested
TypeError retries so internal TypeErrors are not swallowed.
- Around line 177-180: The current list comprehension silently drops non-object
entries from sequences_payload; change it to explicitly validate each item and
reject malformed entries by raising a clear exception instead of filtering them
out: iterate over sequences_payload with enumerate, and for each (index, item)
if not isinstance(item, Mapping) raise a TypeError/ValueError that includes the
index and the offending item, otherwise call LogprobSequence.from_dict(item,
default_sequence_id=f"seq-{index}") and append to sequences; this ensures the
code using sequences (and the LogprobSequence.from_dict call) only runs on
validated inputs and surfaces corrupted fixtures immediately.
- Around line 428-447: HF generation and vLLM sampling don't receive the global
seed/do_sample flags, causing non-reproducible and inconsistent behavior across
engines; update the HF path around generation_kwargs/model.generate to create a
torch.Generator seeded with config.seed (e.g., g =
torch.Generator(device=encoded["input_ids"].device).manual_seed(config.seed))
and pass generator=g into model.generate, and only include
temperature/top_p/top_k in generation_kwargs when config.do_sample is true
(preserve the existing top_k > 0 guard). For the vLLM path, forward config.seed
and config.do_sample into vllm.SamplingParams (e.g.,
SamplingParams(seed=int(config.seed) if set, do_sample=bool(config.do_sample),
temperature=..., top_p=..., top_k=... with same top_k > 0 rule) so both engines
honor the same sampling flags and seed. Ensure any places mentioned in the
comment (the HF block around generation_kwargs/model.generate and the vLLM
SamplingParams construction) are updated accordingly.
✅ Files skipped from review due to trivial changes (3)
- rl_engine/benchmarks/__init__.py
- .gitignore
- docs/benchmarking/README.md
🚧 Files skipped from review as they are similar to previous changes (1)
- benchmarks/logprob_cross_engine.py
    sequences=[
        LogprobSequence.from_dict(item, default_sequence_id=f"seq-{index}")
        for index, item in enumerate(sequences_payload)
        if isinstance(item, Mapping)
Reject malformed sequences entries instead of silently dropping them.
This comprehension filters out non-object items and keeps benchmarking the remaining subset. A partially corrupt fixture then changes the evaluated sample set without any ingest failure, which is a bad contract for a parity benchmark.
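A fail-fast version of the ingest check might look like the sketch below. It is an assumption-laden illustration: `validate_sequence_entries` is a hypothetical name, and in the real module each validated item would presumably be passed on to `LogprobSequence.from_dict(item, default_sequence_id=f"seq-{index}")`.

```python
from collections.abc import Mapping
from typing import Any


def validate_sequence_entries(sequences_payload: list[Any]) -> list[Mapping[str, Any]]:
    """Return the entries unchanged, or raise on the first non-mapping item.

    Hypothetical helper; it only demonstrates the validation step.
    """
    validated: list[Mapping[str, Any]] = []
    for index, item in enumerate(sequences_payload):
        if not isinstance(item, Mapping):
            # Surface the corrupt entry instead of silently shrinking the sample set.
            raise ValueError(
                f"sequences[{index}] must be an object, got {type(item).__name__}"
            )
        validated.append(item)
    return validated
```

Raising keeps the evaluated sample set identical to what the fixture declares, which is the property a parity benchmark relies on.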
    generation_kwargs = {
        "max_new_tokens": config.max_new_tokens,
        "do_sample": config.do_sample,
        "temperature": config.temperature if config.do_sample else None,
        "top_p": config.top_p if config.do_sample else None,
        "top_k": config.top_k if config.do_sample and config.top_k > 0 else None,
        "num_return_sequences": config.num_generations,
        "return_dict_in_generate": True,
        "output_scores": True,
        "pad_token_id": tokenizer.pad_token_id,
    }
    generation_kwargs = {
        key: value for key, value in generation_kwargs.items() if value is not None
    }
    with torch.no_grad():
        outputs = model.generate(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            **generation_kwargs,
        )
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
set -euo pipefail
FILE="rl_engine/benchmarks/logprob_cross_engine.py"
echo "== File exists? =="
ls -l "$FILE" || true
echo "== Show relevant sections around 380-620 =="
nl -ba "$FILE" | sed -n '380,620p'
echo "== Search for config.seed usage in this file =="
rg -n "config\.seed|\bseed\b" "$FILE"
echo "== Search for do_sample forwarding in this file =="
rg -n "do_sample|max_new_tokens|temperature|top_p|top_k|num_return_sequences" "$FILE"
echo "== Find vLLM-specific code paths =="
rg -n "vllm|VLLM|LLM\(" "$FILE"
echo "== Find HuggingFace generate() usage =="
rg -n "model\.generate|generate\(" "$FILE"
echo "== List other files that might handle rollout/backends (if referenced) =="
rg -n "rollout|backend|engine|hf|huggingface|vllm" "$FILE"
Repository: RL-Align/RL-Kernel
Length of output: 281
🏁 Script executed:
#!/bin/bash
set -euo pipefail
FILE="rl_engine/benchmarks/logprob_cross_engine.py"
echo "== Show relevant sections around 380-620 =="
cat -n "$FILE" | sed -n '380,620p'
echo "== Search for config.seed usage in this file =="
rg -n "config\.seed|\bseed\b" "$FILE" || true
echo "== Search for do_sample forwarding patterns in this file =="
rg -n "do_sample|max_new_tokens|temperature|top_p|top_k|num_return_sequences" "$FILE" || true
echo "== Find vLLM-specific code paths in this file =="
rg -n "vllm|VLLM" "$FILE" || true
echo "== Find HuggingFace generate() usage in this file =="
rg -n "model\.generate|generate\(" "$FILE" || true
Repository: RL-Align/RL-Kernel
Length of output: 16378
🏁 Script executed:
#!/bin/bash
set -euo pipefail
FILE="rl_engine/benchmarks/logprob_cross_engine.py"
echo "== Show sections around seed/generator in this file (150-360) =="
cat -n "$FILE" | sed -n '150,360p'
echo "== Search for torch.manual_seed / torch.Generator usage in this file =="
rg -n "manual_seed|torch\.Generator|Generator\(" "$FILE" || true
echo "== Inspect lines around model/generation kwargs for HF rollout (410-470) =="
cat -n "$FILE" | sed -n '400,480p'
echo "== Inspect vLLM sampler config call site (520-620) =="
cat -n "$FILE" | sed -n '520,630p'
# locate vLLM sampler implementation
echo "== Locate vllm_sampler implementation =="
ls -R rl_engine/executors | head -n 200
Repository: RL-Align/RL-Kernel
Length of output: 20184
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "== Inspect rl_engine/executors/vllm_sampler.py (seed/do_sample related parts) =="
FILE="rl_engine/executors/vllm_sampler.py"
ls -l "$FILE"
wc -l "$FILE"
# show head and then relevant sections by grep context
rg -n "seed|do_sample|temperature|top_p|top_k|sampling_params|generator" "$FILE" || true
echo "== Show sections around any seed occurrences =="
rg -n "seed" "$FILE" | cut -d: -f1 | head -n 50 | while read -r line; do
start=$((line-30)); end=$((line+30));
sed -n "${start},${end}p" "$FILE" | sed -n '1,120p' >/dev/null
done
# Instead, print exact line windows for first few occurrences
lines=$(rg -n "seed" "$FILE" | cut -d: -f1 | head -n 10 | tr '\n' ' ')
if [ -n "$lines" ]; then
for line in $lines; do
start=$((line-25)); end=$((line+25));
echo "--- around seed at line $line ---"
sed -n "${start},${end}p" "$FILE" | cat
done
fi
echo "== Show sections around VLLMSamplerConfig definition =="
rg -n "class VLLMSamplerConfig|VLLMSamplerConfig\\(" "$FILE" || true
rg -n "def generate\\(" "$FILE" || true
# Show around config/dataclass and generate method
cfg_line=$(rg -n "class VLLMSamplerConfig" "$FILE" | head -n1 | cut -d: -f1 || true)
gen_line=$(rg -n "def generate\\(" "$FILE" | head -n1 | cut -d: -f1 || true)
if [ -n "${cfg_line:-}" ]; then
echo "--- around VLLMSamplerConfig at line $cfg_line ---"
sed -n "$((cfg_line-40)),$((cfg_line+80))p" "$FILE" | cat
fi
if [ -n "${gen_line:-}" ]; then
echo "--- around generate at line $gen_line ---"
sed -n "$((gen_line-40)),$((gen_line+120))p" "$FILE" | cat
fi
echo "== Search other seeding call sites in logprob_cross_engine.py and repo =="
FILE2="rl_engine/benchmarks/logprob_cross_engine.py"
rg -n "manual_seed|torch\.Generator\\(|seed\\s*=\\s*config\\.seed|generator\\s*=" "$FILE2" || true
# broader search for torch.manual_seed usage
rg -n "torch\\.manual_seed\\(" . || true
Repository: RL-Align/RL-Kernel
Length of output: 1304
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "== Inspect rl_engine/executors/vllm_sampler.py (seed/do_sample related parts) =="
FILE="rl_engine/executors/vllm_sampler.py"
ls -l "$FILE"
wc -l "$FILE"
rg -n "seed|do_sample|temperature|top_p|top_k|sampling_params|generator" "$FILE" || true
echo "== Show sections around any seed occurrences =="
seed_lines=$(rg -n "seed" "$FILE" | cut -d: -f1 | head -n 20 || true)
if [ -n "${seed_lines:-}" ]; then
for line in $seed_lines; do
start=$((line-25)); end=$((line+25));
echo "--- around seed at line $line ---"
sed -n "${start},${end}p" "$FILE" | cat
done
fi
echo "== Show around VLLMSamplerConfig and generate() =="
cfg_line=$(rg -n "class VLLMSamplerConfig" "$FILE" | head -n1 | cut -d: -f1 || true)
gen_line=$(rg -n "def generate\\(" "$FILE" | head -n1 | cut -d: -f1 || true)
if [ -n "${cfg_line:-}" ]; then
echo "--- around VLLMSamplerConfig at line $cfg_line ---"
sed -n "$((cfg_line-40)),$((cfg_line+120))p" "$FILE" | cat
fi
if [ -n "${gen_line:-}" ]; then
echo "--- around generate() at line $gen_line ---"
sed -n "$((gen_line-60)),$((gen_line+160))p" "$FILE" | cat
fi
echo "== HF path: check for global torch seeding usage =="
FILE2="rl_engine/benchmarks/logprob_cross_engine.py"
rg -n "manual_seed|torch\.Generator\\(|generator\\s*=" "$FILE2" || true
echo "== Repo-wide torch.manual_seed usage =="
rg -n "torch\\.manual_seed\\(" . || true
Repository: RL-Align/RL-Kernel
Length of output: 3240
🏁 Script executed:
#!/bin/bash
set -euo pipefail
FILE="rl_engine/executors/vllm_sampler.py"
echo "== Print vllm_sampler.py with line numbers (selected ranges) =="
awk 'NR>=1 && NR<=220 {printf "%4d\t%s\n", NR, $0}' "$FILE" | sed -n '1,220p'
echo "----"
awk 'NR>=220 && NR<=320 {printf "%4d\t%s\n", NR, $0}' "$FILE" | sed -n '1,120p'
echo "== Search for 'seed' in vllm_sampler.py =="
rg -n "\bseed\b" "$FILE" || true
echo "== Search for vLLM SamplingParamsConfig seed / generator / do_sample usage =="
rg -n "SamplingParams|do_sample|temperature|top_p|top_k|seed|rng|generator" "$FILE" || true
Repository: RL-Align/RL-Kernel
Length of output: 12621
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "== Search repo for torch.manual_seed / set_seed usage =="
rg -n "torch\.manual_seed\(|set_seed\(" . || true
echo "== Search within logprob_cross_engine.py for any seeding calls =="
rg -n "manual_seed|set_seed|Generator\(|torch\.Generator" rl_engine/benchmarks/logprob_cross_engine.py || true
echo "== Inspect CLI/entrypoint around --seed parsing and any seeding logic =="
cat -n rl_engine/benchmarks/logprob_cross_engine.py | sed -n '980,1160p'
Repository: RL-Align/RL-Kernel
Length of output: 10030
Propagate --seed and --do-sample into every rollout backend (HF + vLLM)
config.seed is only used for the synthetic rollout generator; the HF path calls model.generate(...) without a generator/seed, and the vLLM path never forwards seed (or any do_sample flag) into vllm.SamplingParams, so sampled runs won’t be reproducible across engines. Also, --do-sample changes behavior in the HF path but is effectively ignored in the vLLM path (it always passes temperature/top_p), breaking cross-engine parity.
Applies to: 428-447 and 531-549
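One way to keep the two backends on the same sampling flags is to derive both argument sets from a single place. The sketch below is pure Python with hypothetical helper names and builds plain dicts rather than calling either library. It rests on two assumptions that should be checked against the installed versions: as far as I know, vLLM's `SamplingParams` accepts `seed`, `temperature`, `top_p`, and `top_k` but has no `do_sample` field, so greedy decoding is expressed here as `temperature=0.0`; and Hugging Face `generate()` is not known to me to take a `generator` argument, so seeding on that side would likely go through `torch.manual_seed(config.seed)` before the call.

```python
from typing import Any, Optional


def hf_generation_kwargs(
    do_sample: bool, temperature: float, top_p: float, top_k: int
) -> dict[str, Any]:
    """Sampling-related kwargs for a Hugging Face generate() call (hypothetical helper)."""
    kwargs: dict[str, Any] = {"do_sample": do_sample}
    if do_sample:
        # Only forward sampling knobs when sampling is actually enabled.
        kwargs["temperature"] = temperature
        kwargs["top_p"] = top_p
        if top_k > 0:
            kwargs["top_k"] = top_k
    return kwargs


def vllm_sampling_kwargs(
    do_sample: bool,
    temperature: float,
    top_p: float,
    top_k: int,
    seed: Optional[int],
) -> dict[str, Any]:
    """Kwargs intended for vllm.SamplingParams (hypothetical helper).

    Assumption: greedy decoding is requested with temperature 0.0 because
    SamplingParams is assumed to have no do_sample field.
    """
    if do_sample:
        kwargs: dict[str, Any] = {"temperature": temperature, "top_p": top_p}
        if top_k > 0:
            kwargs["top_k"] = top_k
    else:
        kwargs = {"temperature": 0.0}
    if seed is not None:
        kwargs["seed"] = int(seed)
    return kwargs
```

With both dicts built from the same config values, a run with `do_sample=False` is greedy on both engines, and a sampled run carries the same temperature, top-p, and top-k settings to each.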
            if "sequences" in payload:
                nested = LogprobBenchmarkFixture.from_dict(payload)
                sequences.extend(nested.sequences)
                continue
            sequences.append(
                LogprobSequence.from_dict(payload, default_sequence_id=f"jsonl-{index}")
            )
        return LogprobBenchmarkFixture(
            schema_version=SCHEMA_VERSION,
            created_at=_utc_now(),
            rollout_engine="fixture",
            model="fixture",
            tokenizer=None,
            dtype="unknown",
            sequences=sequences,
            metadata={"source_path": str(fixture_path)},
        )
JSONL ingest currently throws away the metadata needed for exact replay.
When a .jsonl file is loaded, nested fixtures are flattened into sequences and the returned LogprobBenchmarkFixture is rebuilt with model="fixture", dtype="unknown", and tokenizer=None. Downstream, load_training_model() uses fixture.model to infer the replay model, so JSONL replay can silently score against the wrong weights and report incomplete provenance.
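The merge rule implied here (take the first real value, ignore placeholders) can be sketched on plain dicts. `merge_fixture_metadata`, `PLACEHOLDERS`, and `FIELDS` are hypothetical names; the real loader works with `LogprobBenchmarkFixture` objects, and the placeholder values are taken from the snippet above.

```python
from typing import Any, Mapping

# Placeholder values the JSONL loader currently writes into the rebuilt fixture.
PLACEHOLDERS = (None, "fixture", "unknown")
FIELDS = ("model", "tokenizer", "dtype", "rollout_engine")


def merge_fixture_metadata(nested_fixtures: list[Mapping[str, Any]]) -> dict[str, Any]:
    """Pick, per field, the first non-placeholder value across nested fixtures.

    Hypothetical helper operating on dicts for illustration.
    """
    merged: dict[str, Any] = {
        "model": "fixture",
        "tokenizer": None,
        "dtype": "unknown",
        "rollout_engine": "fixture",
    }
    for fixture in nested_fixtures:
        for field in FIELDS:
            value = fixture.get(field)
            # Keep an already-merged real value; only fill placeholders.
            if merged[field] in PLACEHOLDERS and value not in PLACEHOLDERS:
                merged[field] = value
    return merged
```

A stricter variant could raise when two nested fixtures disagree on `model` or `dtype`, since mixing rollouts from different weights in one replay is probably a fixture error rather than something to merge.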
    def _call_model(model: torch.nn.Module, input_ids: torch.Tensor, **kwargs: Any) -> Any:
        try:
            return model(input_ids=input_ids, **kwargs)
        except TypeError:
            kwargs.pop("use_cache", None)
            try:
                return model(input_ids=input_ids, **kwargs)
            except TypeError:
                kwargs.pop("attention_mask", None)
                return model(input_ids)
Do not treat any TypeError as a signature mismatch.
If model.forward() raises TypeError from inside its implementation, this helper retries with fewer arguments and can silently benchmark a different code path instead of surfacing the real failure. Filter unsupported kwargs before the call, or only swallow explicit unexpected-argument errors.
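Filtering kwargs up front could be done with `inspect.signature`, as in the sketch below. `call_with_supported_kwargs` is a hypothetical, torch-free stand-in for `_call_model`; for an `nn.Module` one would presumably inspect `model.forward` and still invoke `model(...)`.

```python
import inspect
from typing import Any, Callable


def call_with_supported_kwargs(fn: Callable[..., Any], /, *args: Any, **kwargs: Any) -> Any:
    """Call `fn` once, passing only the keyword arguments its signature accepts.

    Hypothetical helper; a TypeError raised inside `fn` propagates unchanged.
    """
    params = inspect.signature(fn).parameters
    accepts_var_kw = any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    )
    if not accepts_var_kw:
        # Names that may legally be passed by keyword.
        accepted = {
            name
            for name, p in params.items()
            if p.kind
            in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
        }
        kwargs = {key: value for key, value in kwargs.items() if key in accepted}
    return fn(*args, **kwargs)
```

Because there is no retry, an internal `TypeError` surfaces as a real failure instead of silently switching the benchmark to a call with fewer arguments.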
33cadd0 to 7549409
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@rl_engine/benchmarks/logprob_cross_engine.py`:
- Around line 106-110: The code currently casts every item in
payload["completion_mask"] through bool(...), which silently accepts stringy
values like "false" and corrupts the mask; instead validate items explicitly:
when completion_mask_value is provided, iterate over completion_mask_value and
accept only actual booleans (bool) or explicit integer encodings 0/1 (int ->
bool), otherwise raise a clear ValueError indicating the invalid mask entry and
its index; update the creation of completion_mask (the variable currently
computed from completion_mask_value and completion_token_ids) to enforce the
length matches completion_token_ids and to fail fast on malformed entries rather
than truthifying them.
✅ Files skipped from review due to trivial changes (2)
- rl_engine/benchmarks/__init__.py
- .gitignore
🚧 Files skipped from review as they are similar to previous changes (2)
- docs/benchmarking/README.md
- benchmarks/logprob_cross_engine.py
    completion_mask_value = payload.get("completion_mask")
    if completion_mask_value is None:
        completion_mask = [True] * len(completion_token_ids)
    else:
        completion_mask = [bool(item) for item in completion_mask_value]
Reject stringly completion_mask values instead of truthifying them.
Lines 106-110 run every entry through bool(...), so payloads like ["false", "false"] or ["0", "1"] become all-True masks and silently change which tokens are benchmarked. For a replay fixture, malformed mask encodings should fail fast or be parsed explicitly.
Suggested fix
      completion_mask_value = payload.get("completion_mask")
      if completion_mask_value is None:
          completion_mask = [True] * len(completion_token_ids)
      else:
    -     completion_mask = [bool(item) for item in completion_mask_value]
    +     if not isinstance(completion_mask_value, Sequence) or isinstance(
    +         completion_mask_value, (str, bytes)
    +     ):
    +         raise ValueError("completion_mask must be a sequence of booleans")
    +     completion_mask = []
    +     for index, item in enumerate(completion_mask_value):
    +         if isinstance(item, bool):
    +             completion_mask.append(item)
    +         elif isinstance(item, int) and item in {0, 1}:
    +             completion_mask.append(bool(item))
    +         else:
    +             raise ValueError(
    +                 f"completion_mask[{index}] must be a boolean or 0/1 integer"
    +             )
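The truthiness problem is easy to reproduce: `bool("false")` is `True` in Python, because any non-empty string is truthy. A standalone sketch of strict parsing that also enforces the length requirement is shown below; `parse_completion_mask` is a hypothetical name, and the error messages are illustrative.

```python
from collections.abc import Sequence
from typing import Any, Optional


def parse_completion_mask(mask_value: Optional[Any], num_tokens: int) -> list[bool]:
    """Parse a completion mask strictly (hypothetical helper).

    Accepts real booleans and the integers 0/1; rejects strings and any
    mask whose length differs from the number of completion tokens.
    """
    if mask_value is None:
        # No mask supplied: every completion token is scored.
        return [True] * num_tokens
    if not isinstance(mask_value, Sequence) or isinstance(mask_value, (str, bytes)):
        raise ValueError("completion_mask must be a sequence of booleans")
    if len(mask_value) != num_tokens:
        raise ValueError(
            f"completion_mask has {len(mask_value)} entries, expected {num_tokens}"
        )
    mask: list[bool] = []
    for index, item in enumerate(mask_value):
        if isinstance(item, bool):
            mask.append(item)
        elif isinstance(item, int) and item in (0, 1):
            mask.append(bool(item))
        else:
            raise ValueError(
                f"completion_mask[{index}] must be a boolean or 0/1 integer"
            )
    return mask
```

The `bool` check comes before the `int` check because `bool` is a subclass of `int`; the order only matters for readability here, since both branches produce the same result for `True` and `False`.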
Summary
Closes #106
Validation
- py -3.13 -m pre_commit run --all-files
- py -3.13 -m pytest -q
- py -3.13 -m ruff check rl_engine/benchmarks/logprob_cross_engine.py benchmarks/logprob_cross_engine.py tests/test_logprob_cross_engine.py
- py -3.13 benchmarks/logprob_cross_engine.py --smoke --device cpu --output-dir artifacts/logprob_cross_engine/final-smoke --no-summary
DCO
Signed-off-by: inaniloquentee <3051000145@qq.com>.
Summary by CodeRabbit
New Features
Documentation
Tests
Chores