
feat: add logprob cross-engine benchmark#107

Open
inaniloquentee wants to merge 1 commit into main from feat/logprob-cross-engine-tool

Conversation

@inaniloquentee

@inaniloquentee inaniloquentee commented Jun 13, 2026

Collaborator

Summary

  • Add an end-to-end selected-logprob cross-engine benchmark harness for rollout vs training parity.
  • Support synthetic smoke rollout, Hugging Face rollout, vLLM rollout adapter, and production fixture/JSONL replay.
  • Compare policy/old/reference logprob channels with reproducible JSON, JSONL, and markdown reports.
  • Document smoke, HF, and production fixture workflows.

Closes #106

Validation

  • py -3.13 -m pre_commit run --all-files
  • py -3.13 -m pytest -q
  • py -3.13 -m ruff check rl_engine/benchmarks/logprob_cross_engine.py benchmarks/logprob_cross_engine.py tests/test_logprob_cross_engine.py
  • py -3.13 benchmarks/logprob_cross_engine.py --smoke --device cpu --output-dir artifacts/logprob_cross_engine/final-smoke --no-summary

DCO

  • Commit includes Signed-off-by: inaniloquentee <3051000145@qq.com>.

Summary by CodeRabbit

  • New Features

    • Added a cross‑engine logprob validation benchmark CLI with smoke mode, configurable thresholds, optional mismatch injection, and detailed per-run outputs (reports, fixtures, token drift records).
  • Documentation

    • Added a benchmarking guide with workflows, example commands, expected artifacts, and triage/reporting fields.
  • Tests

    • Added end‑to‑end and unit tests covering smoke runs, fixture ingestion/replay, training replay behavior, mismatch injection, and CLI exit conditions.
  • Chores

    • Updated ignore rules to exclude artifacts/ and reports/ directories.

@coderabbitai

coderabbitai Bot commented Jun 13, 2026


📝 Walkthrough


Adds an end-to-end logprob cross-engine benchmark: fixture/schema, synthetic/HF/vLLM rollout builders, training replay scoring, per-token drift comparison with thresholds, CLI entrypoint and wrapper, output persistence, docs, tests, and .gitignore updates.

Changes

Logprob Cross-Engine Benchmark

| Layer / File(s) | Summary |
| --- | --- |
| Config and data schema: rl_engine/benchmarks/logprob_cross_engine.py (lines 1–236) | Defines immutable dataclasses LogprobSequence and LogprobBenchmarkFixture with JSON serialization, schema versioning, token normalization, and completion-mask handling; adds DriftThresholds and LogprobCrossBenchmarkConfig with defaults for rollout/training engines and drift tolerances. |
| Synthetic model & compute primitives: rl_engine/benchmarks/logprob_cross_engine.py (lines 240–279, 1116–1239) | Adds a deterministic tiny causal LM and internal batch-layout utilities for padded micro-batches, logit-position mapping, dtype helpers, and tolerant model-calling utilities used in smoke and test paths. |
| Rollout fixture builders: rl_engine/benchmarks/logprob_cross_engine.py (lines 281–622, 1480–1537) | Implements synthetic token-by-token fixture generation, HuggingFace batched generation with transition-score extraction, and vLLM sampler integration; includes prompt loading/encoding, trimming generated tokens at EOS/PAD, and selected-logprob extraction from rollout payloads. |
| Training logprob scoring: rl_engine/benchmarks/logprob_cross_engine.py (lines 624–877, 1241–1406) | Replays rollout sequences through training engines with per-model caching, micro-batched scoring (score_sequences_with_model), multi-shape logits extraction, gathering logits at precise positions, and computing selected logprobs via rl_engine.testing.selected_logprobs_reference. |
| Comparison, reporting, and output: rl_engine/benchmarks/logprob_cross_engine.py (lines 880–976, 1540–1541) | Compares rollout vs training logprobs per active token, computes absolute/relative errors and aggregates, flags pass/fail against thresholds, optionally attaches full token-drift records, and persists rollout_fixture.json, report.json, token_drifts.jsonl, and summary.md. |
| CLI and orchestration: rl_engine/benchmarks/logprob_cross_engine.py, benchmarks/logprob_cross_engine.py | Adds build_arg_parser, config_from_args (with smoke-mode parameter clamping), main, _validate_config, and a top-level executable wrapper that runs main() when invoked as a script. |
| Comprehensive test coverage: tests/test_logprob_cross_engine.py | Nine tests covering synthetic smoke validation with output checks, JSONL fixture ingestion/replay, ragged batch micro-batching, multi-channel (policy + ref) comparison, token-shift mismatch detection with worst-drift metadata, and CLI integration for success and failure paths. |
| Documentation and project setup: docs/benchmarking/README.md, rl_engine/benchmarks/__init__.py, .gitignore, benchmarks/logprob_cross_engine.py | Adds a “Train-Inference Logprob Cross-Benchmark” docs section and smoke command example, package initializer for rl_engine.benchmarks, executable wrapper script, and .gitignore updates to ignore generated artifacts/ and reports/ directories. |
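
The comparison layer summarized above can be pictured with a minimal sketch. It is not the PR's implementation: the `Thresholds` class, its default tolerances, the field names in the returned report, and the rule that both the absolute and the relative limit must hold are all assumptions made for illustration.

```python
from dataclasses import dataclass
from typing import Any, Sequence


@dataclass(frozen=True)
class Thresholds:
    # Hypothetical tolerances; the PR's DriftThresholds defaults are not shown here.
    max_abs_error: float = 1e-3
    max_rel_error: float = 1e-2


def compare_logprobs(
    rollout: Sequence[float],
    training: Sequence[float],
    mask: Sequence[bool],
    thresholds: Thresholds = Thresholds(),
) -> dict[str, Any]:
    """Compare two logprob channels on active (masked-in) tokens only."""
    if not (len(rollout) == len(training) == len(mask)):
        raise ValueError("rollout, training, and mask must have equal length")
    drifts = []
    for position, (r, t, active) in enumerate(zip(rollout, training, mask)):
        if not active:
            continue  # inactive tokens never contribute to the drift statistics
        abs_error = abs(r - t)
        # Guard the denominator so a training logprob of exactly 0 cannot divide by zero.
        rel_error = abs_error / max(abs(t), 1e-12)
        drifts.append({"position": position, "abs_error": abs_error, "rel_error": rel_error})
    max_abs = max((d["abs_error"] for d in drifts), default=0.0)
    max_rel = max((d["rel_error"] for d in drifts), default=0.0)
    return {
        "active_tokens": len(drifts),
        "max_abs_error": max_abs,
        "max_rel_error": max_rel,
        # Assumption: a run passes only if both limits are respected.
        "passed": max_abs <= thresholds.max_abs_error and max_rel <= thresholds.max_rel_error,
        "token_drifts": drifts,
    }
```

Keeping the per-token records alongside the aggregates mirrors the optional token-drift output described in the table, so the worst offending position can be located without rerunning the comparison.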

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Poem

🐰 A rabbit's ode to logprob harmony:
I hop through tokens, counting each small log,
Rollout and replay dance along the log,
I sniff the drift, I write the tidy files,
Tiny models hum and tests break into smiles.

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 1.45%, which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title 'feat: add logprob cross-engine benchmark' directly and clearly describes the main changeset, which adds a new end-to-end logprob cross-engine benchmark harness. |
| Linked Issues check | ✅ Passed | The pull request comprehensively implements all primary coding objectives from issue #106: fixture schema, reference training replay, rollout adapters, comparator reporting, deterministic smoke mode, and extensive test coverage for mismatch detection. |
| Out of Scope Changes check | ✅ Passed | All changes are directly related to the logprob cross-engine benchmark implementation. The .gitignore update, documentation, new benchmark modules, and tests are all in-scope and support the core objective of issue #106. |


@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
rl_engine/benchmarks/logprob_cross_engine.py (1)

559-562: 💤 Low value

Consider logging the exception when tokenizer loading fails silently.

The broad except Exception swallows every Exception subclass, including unexpected ones such as MemoryError (KeyboardInterrupt and SystemExit derive from BaseException and are not caught). While the fallback to tokenizer = None is intentional, logging the actual exception would help debugging when tokenizer loading fails unexpectedly.

Suggested improvement
         try:
             transformers = _import_transformers()
             tokenizer = transformers.AutoTokenizer.from_pretrained(
                 config.tokenizer or config.model,
                 revision=config.model_revision,
                 trust_remote_code=config.trust_remote_code,
             )
-        except Exception:
+        except Exception as exc:
+            import logging
+            logging.debug("Optional tokenizer loading failed: %s", exc)
             tokenizer = None
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@rl_engine/benchmarks/logprob_cross_engine.py` around lines 559 - 562, The
try/except that sets tokenizer = None swallows errors silently; modify the
except block around the tokenizer load (the try that uses
trust_remote_code=config.trust_remote_code and assigns to tokenizer) to capture
the exception (e.g. except Exception as e:) and log it before falling back, e.g.
logger.exception or logger.error(..., exc_info=True), then keep tokenizer = None
so behavior doesn't change.

Source: Linters/SAST tools

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@rl_engine/benchmarks/logprob_cross_engine.py`:
- Line 1517: The condition that checks token_id uses item.get("token_id"), which
mypy flags as possibly None even after an 'in' check; change the expression in
the if-condition to use direct key access so mypy can infer non-None: replace
int(item.get("token_id")) with int(item["token_id"]) in the condition that
begins with if "logprob" in item and ("token_id" not in item or ...), keeping
the same logic and cast to int(item["token_id"]) == token_id.


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: b645d588-5550-404d-80a9-46964cad797d

📥 Commits

Reviewing files that changed from the base of the PR and between 04c014d and c76f784.

📒 Files selected for processing (6)
  • .gitignore
  • benchmarks/logprob_cross_engine.py
  • docs/benchmarking/README.md
  • rl_engine/benchmarks/__init__.py
  • rl_engine/benchmarks/logprob_cross_engine.py
  • tests/test_logprob_cross_engine.py

Comment thread rl_engine/benchmarks/logprob_cross_engine.py Outdated
Signed-off-by: inaniloquentee <3051000145@qq.com>
@inaniloquentee inaniloquentee force-pushed the feat/logprob-cross-engine-tool branch from 847abb6 to 33cadd0 Compare June 13, 2026 09:44

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 4

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@rl_engine/benchmarks/logprob_cross_engine.py`:
- Around line 907-923: The JSONL ingestion flattens nested
LogprobBenchmarkFixture instances but discards their metadata (model, tokenizer,
dtype, rollout_engine, schema_version), causing downstream replay to use wrong
defaults; update the logic in the JSONL loader (the branch that calls
LogprobBenchmarkFixture.from_dict and extends sequences) to also capture and
merge the first non-default metadata fields from the nested fixture into the
top-level LogprobBenchmarkFixture you return (e.g., set model, tokenizer, dtype,
rollout_engine, and schema_version from the nested fixture when they are present
and not placeholder values like "fixture"/"unknown"/None), and also merge/append
nested metadata dicts into the top-level metadata field so provenance is
preserved; keep using LogprobSequence.from_dict for individual sequence lines
and preserve default_sequence_id behavior.
- Around line 1183-1192: The helper _call_model should not blanket-catch all
TypeError instances; instead, inspect the model's callable signature and pass
only kwargs that the model actually accepts. Replace the current
retry-on-TypeError logic with code that obtains the callable (prefer
model.forward if available, else model.__call__), uses inspect.signature(...) to
collect parameter names, builds a filtered_kwargs containing only keys present
in that parameter set (keep input_ids as a positional/keyword argument), and
then calls the model once with input_ids and filtered_kwargs; remove the nested
TypeError retries so internal TypeErrors are not swallowed.
- Around line 177-180: The current list comprehension silently drops non-object
entries from sequences_payload; change it to explicitly validate each item and
reject malformed entries by raising a clear exception instead of filtering them
out: iterate over sequences_payload with enumerate, and for each (index, item)
if not isinstance(item, Mapping) raise a TypeError/ValueError that includes the
index and the offending item, otherwise call LogprobSequence.from_dict(item,
default_sequence_id=f"seq-{index}") and append to sequences; this ensures the
code using sequences (and the LogprobSequence.from_dict call) only runs on
validated inputs and surfaces corrupted fixtures immediately.
- Around line 428-447: HF generation and vLLM sampling don't receive the global
seed/do_sample flags, causing non-reproducible and inconsistent behavior across
engines. In the HF path around generation_kwargs/model.generate, seed the RNG
from config.seed immediately before generation (e.g.,
torch.manual_seed(config.seed) or transformers.set_seed(config.seed)), and only
include temperature/top_p/top_k in generation_kwargs when config.do_sample is
true (preserve the existing top_k > 0 guard). For the vLLM path, forward
config.seed into vllm.SamplingParams (e.g., seed=int(config.seed) when set).
SamplingParams has no do_sample parameter, so express config.do_sample=False as
greedy decoding with temperature=0, and pass temperature/top_p/top_k (same
top_k > 0 rule) only when config.do_sample is true, so both engines honor the
same sampling flags and seed. Ensure both places mentioned in the comment (the
HF block around generation_kwargs/model.generate and the vLLM SamplingParams
construction) are updated accordingly.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 7c391178-ba8e-4e3b-9c1a-fcfaa51a0ad0

📥 Commits

Reviewing files that changed from the base of the PR and between 847abb6 and 33cadd0.

📒 Files selected for processing (6)
  • .gitignore
  • benchmarks/logprob_cross_engine.py
  • docs/benchmarking/README.md
  • rl_engine/benchmarks/__init__.py
  • rl_engine/benchmarks/logprob_cross_engine.py
  • tests/test_logprob_cross_engine.py
✅ Files skipped from review due to trivial changes (3)
  • rl_engine/benchmarks/__init__.py
  • .gitignore
  • docs/benchmarking/README.md
🚧 Files skipped from review as they are similar to previous changes (1)
  • benchmarks/logprob_cross_engine.py

Comment on lines +177 to +180
sequences=[
LogprobSequence.from_dict(item, default_sequence_id=f"seq-{index}")
for index, item in enumerate(sequences_payload)
if isinstance(item, Mapping)


⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Reject malformed sequences entries instead of silently dropping them.

This comprehension filters out non-object items and keeps benchmarking the remaining subset. A partially corrupt fixture then changes the evaluated sample set without any ingest failure, which is a bad contract for a parity benchmark.
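
A fail-fast version of this ingest step might look like the sketch below. `parse_sequences` is a hypothetical helper, and the `dict(item)` copy with a defaulted `sequence_id` is only a stand-in for the PR's `LogprobSequence.from_dict(item, default_sequence_id=...)`, whose body is not shown in this thread.

```python
from collections.abc import Mapping
from typing import Any


def parse_sequences(sequences_payload: list[Any]) -> list[dict[str, Any]]:
    """Validate every entry and raise on the first malformed one."""
    sequences: list[dict[str, Any]] = []
    for index, item in enumerate(sequences_payload):
        if not isinstance(item, Mapping):
            # Surface corrupt fixtures immediately instead of shrinking the sample set.
            raise ValueError(
                f"sequences[{index}] must be an object, "
                f"got {type(item).__name__}: {item!r}"
            )
        # Stand-in for LogprobSequence.from_dict(item, default_sequence_id=f"seq-{index}").
        sequence = dict(item)
        sequence.setdefault("sequence_id", f"seq-{index}")
        sequences.append(sequence)
    return sequences
```

Including the index and the offending value in the message makes it possible to locate the bad entry in a large fixture without a debugger.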


Comment on lines +428 to +447
generation_kwargs = {
"max_new_tokens": config.max_new_tokens,
"do_sample": config.do_sample,
"temperature": config.temperature if config.do_sample else None,
"top_p": config.top_p if config.do_sample else None,
"top_k": config.top_k if config.do_sample and config.top_k > 0 else None,
"num_return_sequences": config.num_generations,
"return_dict_in_generate": True,
"output_scores": True,
"pad_token_id": tokenizer.pad_token_id,
}
generation_kwargs = {
key: value for key, value in generation_kwargs.items() if value is not None
}
with torch.no_grad():
outputs = model.generate(
input_ids=encoded["input_ids"],
attention_mask=encoded["attention_mask"],
**generation_kwargs,
)


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

FILE="rl_engine/benchmarks/logprob_cross_engine.py"
echo "== File exists? =="
ls -l "$FILE" || true

echo "== Show relevant sections around 380-620 =="
nl -ba "$FILE" | sed -n '380,620p'

echo "== Search for config.seed usage in this file =="
rg -n "config\.seed|\bseed\b" "$FILE"

echo "== Search for do_sample forwarding in this file =="
rg -n "do_sample|max_new_tokens|temperature|top_p|top_k|num_return_sequences" "$FILE"

echo "== Find vLLM-specific code paths =="
rg -n "vllm|VLLM|LLM\(" "$FILE"

echo "== Find HuggingFace generate() usage =="
rg -n "model\.generate|generate\(" "$FILE"

echo "== List other files that might handle rollout/backends (if referenced) =="
rg -n "rollout|backend|engine|hf|huggingface|vllm" "$FILE"

Repository: RL-Align/RL-Kernel

Length of output: 281


🏁 Script executed:

#!/bin/bash
set -euo pipefail

FILE="rl_engine/benchmarks/logprob_cross_engine.py"
echo "== Show relevant sections around 380-620 =="
cat -n "$FILE" | sed -n '380,620p'

echo "== Search for config.seed usage in this file =="
rg -n "config\.seed|\bseed\b" "$FILE" || true

echo "== Search for do_sample forwarding patterns in this file =="
rg -n "do_sample|max_new_tokens|temperature|top_p|top_k|num_return_sequences" "$FILE" || true

echo "== Find vLLM-specific code paths in this file =="
rg -n "vllm|VLLM" "$FILE" || true

echo "== Find HuggingFace generate() usage in this file =="
rg -n "model\.generate|generate\(" "$FILE" || true

Repository: RL-Align/RL-Kernel

Length of output: 16378


🏁 Script executed:

#!/bin/bash
set -euo pipefail

FILE="rl_engine/benchmarks/logprob_cross_engine.py"
echo "== Show sections around seed/generator in this file (150-360) =="
cat -n "$FILE" | sed -n '150,360p'

echo "== Search for torch.manual_seed / torch.Generator usage in this file =="
rg -n "manual_seed|torch\.Generator|Generator\(" "$FILE" || true

echo "== Inspect lines around model/generation kwargs for HF rollout (410-470) =="
cat -n "$FILE" | sed -n '400,480p'

echo "== Inspect vLLM sampler config call site (520-620) =="
cat -n "$FILE" | sed -n '520,630p'

# locate vLLM sampler implementation
echo "== Locate vllm_sampler implementation =="
ls -R rl_engine/executors | head -n 200

Repository: RL-Align/RL-Kernel

Length of output: 20184


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== Inspect rl_engine/executors/vllm_sampler.py (seed/do_sample related parts) =="
FILE="rl_engine/executors/vllm_sampler.py"
ls -l "$FILE"
wc -l "$FILE"

# show head and then relevant sections by grep context
rg -n "seed|do_sample|temperature|top_p|top_k|sampling_params|generator" "$FILE" || true

echo "== Show sections around any seed occurrences =="
rg -n "seed" "$FILE" | cut -d: -f1 | head -n 50 | while read -r line; do
  start=$((line-30)); end=$((line+30));
  sed -n "${start},${end}p" "$FILE" | sed -n '1,120p' >/dev/null
done

# Instead, print exact line windows for first few occurrences
lines=$(rg -n "seed" "$FILE" | cut -d: -f1 | head -n 10 | tr '\n' ' ')
if [ -n "$lines" ]; then
  for line in $lines; do
    start=$((line-25)); end=$((line+25));
    echo "--- around seed at line $line ---"
    sed -n "${start},${end}p" "$FILE" | cat
  done
fi

echo "== Show sections around VLLMSamplerConfig definition =="
rg -n "class VLLMSamplerConfig|VLLMSamplerConfig\\(" "$FILE" || true
rg -n "def generate\\(" "$FILE" || true

# Show around config/dataclass and generate method
cfg_line=$(rg -n "class VLLMSamplerConfig" "$FILE" | head -n1 | cut -d: -f1 || true)
gen_line=$(rg -n "def generate\\(" "$FILE" | head -n1 | cut -d: -f1 || true)
if [ -n "${cfg_line:-}" ]; then
  echo "--- around VLLMSamplerConfig at line $cfg_line ---"
  sed -n "$((cfg_line-40)),$((cfg_line+80))p" "$FILE" | cat
fi
if [ -n "${gen_line:-}" ]; then
  echo "--- around generate at line $gen_line ---"
  sed -n "$((gen_line-40)),$((gen_line+120))p" "$FILE" | cat
fi

echo "== Search other seeding call sites in logprob_cross_engine.py and repo =="
FILE2="rl_engine/benchmarks/logprob_cross_engine.py"
rg -n "manual_seed|torch\.Generator\\(|seed\\s*=\\s*config\\.seed|generator\\s*=" "$FILE2" || true

# broader search for torch.manual_seed usage
rg -n "torch\\.manual_seed\\(" . || true

Repository: RL-Align/RL-Kernel

Length of output: 1304


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== Inspect rl_engine/executors/vllm_sampler.py (seed/do_sample related parts) =="
FILE="rl_engine/executors/vllm_sampler.py"
ls -l "$FILE"
wc -l "$FILE"

rg -n "seed|do_sample|temperature|top_p|top_k|sampling_params|generator" "$FILE" || true

echo "== Show sections around any seed occurrences =="
seed_lines=$(rg -n "seed" "$FILE" | cut -d: -f1 | head -n 20 || true)
if [ -n "${seed_lines:-}" ]; then
  for line in $seed_lines; do
    start=$((line-25)); end=$((line+25));
    echo "--- around seed at line $line ---"
    sed -n "${start},${end}p" "$FILE" | cat
  done
fi

echo "== Show around VLLMSamplerConfig and generate() =="
cfg_line=$(rg -n "class VLLMSamplerConfig" "$FILE" | head -n1 | cut -d: -f1 || true)
gen_line=$(rg -n "def generate\\(" "$FILE" | head -n1 | cut -d: -f1 || true)

if [ -n "${cfg_line:-}" ]; then
  echo "--- around VLLMSamplerConfig at line $cfg_line ---"
  sed -n "$((cfg_line-40)),$((cfg_line+120))p" "$FILE" | cat
fi

if [ -n "${gen_line:-}" ]; then
  echo "--- around generate() at line $gen_line ---"
  sed -n "$((gen_line-60)),$((gen_line+160))p" "$FILE" | cat
fi

echo "== HF path: check for global torch seeding usage =="
FILE2="rl_engine/benchmarks/logprob_cross_engine.py"
rg -n "manual_seed|torch\.Generator\\(|generator\\s*=" "$FILE2" || true

echo "== Repo-wide torch.manual_seed usage =="
rg -n "torch\\.manual_seed\\(" . || true

Repository: RL-Align/RL-Kernel

Length of output: 3240


🏁 Script executed:

#!/bin/bash
set -euo pipefail

FILE="rl_engine/executors/vllm_sampler.py"

echo "== Print vllm_sampler.py with line numbers (selected ranges) =="
awk 'NR>=1 && NR<=220 {printf "%4d\t%s\n", NR, $0}' "$FILE" | sed -n '1,220p'
echo "----"
awk 'NR>=220 && NR<=320 {printf "%4d\t%s\n", NR, $0}' "$FILE" | sed -n '1,120p'

echo "== Search for 'seed' in vllm_sampler.py =="
rg -n "\bseed\b" "$FILE" || true

echo "== Search for vLLM SamplingParamsConfig seed / generator / do_sample usage =="
rg -n "SamplingParams|do_sample|temperature|top_p|top_k|seed|rng|generator" "$FILE" || true

Repository: RL-Align/RL-Kernel

Length of output: 12621


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== Search repo for torch.manual_seed / set_seed usage =="
rg -n "torch\.manual_seed\(|set_seed\(" . || true

echo "== Search within logprob_cross_engine.py for any seeding calls =="
rg -n "manual_seed|set_seed|Generator\(|torch\.Generator" rl_engine/benchmarks/logprob_cross_engine.py || true

echo "== Inspect CLI/entrypoint around --seed parsing and any seeding logic =="
cat -n rl_engine/benchmarks/logprob_cross_engine.py | sed -n '980,1160p'

Repository: RL-Align/RL-Kernel

Length of output: 10030


Propagate --seed and --do-sample into every rollout backend (HF + vLLM)

config.seed is only used for the synthetic rollout generator; the HF path calls model.generate(...) without a generator/seed, and the vLLM path never forwards seed (or any do_sample flag) into vllm.SamplingParams, so sampled runs won’t be reproducible across engines. Also, --do-sample changes behavior in the HF path but is effectively ignored in the vLLM path (it always passes temperature/top_p), breaking cross-engine parity.

Applies to: 428-447 and 531-549
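
One way to keep the two backends aligned is to derive both sets of sampling arguments from the same config, as in the sketch below. `SamplingConfig`, `hf_generation_kwargs`, and `vllm_sampling_kwargs` are hypothetical names, and the dictionaries are plain Python so the example runs without either library installed. It assumes that the vLLM dictionary would be unpacked into `vllm.SamplingParams`, that greedy decoding is expressed there as `temperature=0`, and that `top_k=-1` disables top-k filtering; seeding the HF path (for example with `torch.manual_seed(config.seed)` before `model.generate`) is left out of the sketch.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class SamplingConfig:
    # Hypothetical subset of the benchmark config fields quoted in this review.
    seed: int = 0
    do_sample: bool = False
    temperature: float = 1.0
    top_p: float = 1.0
    top_k: int = 0
    max_new_tokens: int = 16


def hf_generation_kwargs(config: SamplingConfig) -> dict[str, Any]:
    """Keyword arguments intended for model.generate(...)."""
    kwargs: dict[str, Any] = {
        "max_new_tokens": config.max_new_tokens,
        "do_sample": config.do_sample,
    }
    if config.do_sample:
        # Sampling knobs are only meaningful when sampling is enabled.
        kwargs["temperature"] = config.temperature
        kwargs["top_p"] = config.top_p
        if config.top_k > 0:
            kwargs["top_k"] = config.top_k
    return kwargs


def vllm_sampling_kwargs(config: SamplingConfig) -> dict[str, Any]:
    """Keyword arguments intended for vllm.SamplingParams(...)."""
    kwargs: dict[str, Any] = {"max_tokens": config.max_new_tokens, "seed": config.seed}
    if not config.do_sample:
        # Assumption: temperature 0 selects greedy decoding in vLLM.
        kwargs["temperature"] = 0.0
        return kwargs
    kwargs["temperature"] = config.temperature
    kwargs["top_p"] = config.top_p
    # Assumption: -1 means "no top-k filtering" in vLLM.
    kwargs["top_k"] = config.top_k if config.top_k > 0 else -1
    return kwargs
```

Building both dictionaries from one object makes it harder for a flag such as `--do-sample` to change behavior in one engine while being ignored in the other.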


Comment on lines +907 to +923
if "sequences" in payload:
nested = LogprobBenchmarkFixture.from_dict(payload)
sequences.extend(nested.sequences)
continue
sequences.append(
LogprobSequence.from_dict(payload, default_sequence_id=f"jsonl-{index}")
)
return LogprobBenchmarkFixture(
schema_version=SCHEMA_VERSION,
created_at=_utc_now(),
rollout_engine="fixture",
model="fixture",
tokenizer=None,
dtype="unknown",
sequences=sequences,
metadata={"source_path": str(fixture_path)},
)


⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

JSONL ingest currently throws away the metadata needed for exact replay.

When a .jsonl file is loaded, nested fixtures are flattened into sequences and the returned LogprobBenchmarkFixture is rebuilt with model="fixture", dtype="unknown", and tokenizer=None. Downstream, load_training_model() uses fixture.model to infer the replay model, so JSONL replay can silently score against the wrong weights and report incomplete provenance.
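
A possible shape for the metadata merge is sketched below. It works on plain dictionaries rather than the PR's `LogprobBenchmarkFixture`, and `merge_fixture_metadata` plus the "first non-placeholder value wins" policy are assumptions chosen for illustration, not the author's fix.

```python
from typing import Any, Iterable, Mapping

# Values the JSONL loader currently hard-codes; treated here as "not set".
PLACEHOLDERS = (None, "fixture", "unknown")
FIELDS = ("model", "tokenizer", "dtype", "rollout_engine")


def merge_fixture_metadata(payloads: Iterable[Mapping[str, Any]]) -> dict[str, Any]:
    """Carry the first real value of each field from nested fixtures to the top level."""
    merged: dict[str, Any] = {
        "model": "fixture",
        "tokenizer": None,
        "dtype": "unknown",
        "rollout_engine": "fixture",
    }
    for payload in payloads:
        if "sequences" not in payload:
            continue  # a bare sequence line carries no fixture-level metadata
        for field in FIELDS:
            value = payload.get(field)
            # First non-placeholder value wins; later fixtures cannot overwrite it.
            if merged[field] in PLACEHOLDERS and value not in PLACEHOLDERS:
                merged[field] = value
    return merged
```

A stricter variant could raise when two nested fixtures disagree on `model` or `dtype`, since mixing sequences from different models in a single replay is probably a mistake.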


Comment on lines +1183 to +1192
def _call_model(model: torch.nn.Module, input_ids: torch.Tensor, **kwargs: Any) -> Any:
try:
return model(input_ids=input_ids, **kwargs)
except TypeError:
kwargs.pop("use_cache", None)
try:
return model(input_ids=input_ids, **kwargs)
except TypeError:
kwargs.pop("attention_mask", None)
return model(input_ids)


⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Do not treat any TypeError as a signature mismatch.

If model.forward() raises TypeError from inside its implementation, this helper retries with fewer arguments and can silently benchmark a different code path instead of surfacing the real failure. Filter unsupported kwargs before the call, or only swallow explicit unexpected-argument errors.
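
The "filter unsupported kwargs before the call" option can be sketched with `inspect.signature`. `call_with_supported_kwargs` is a hypothetical helper that takes any callable so the example runs without torch; for a `torch.nn.Module` one would presumably inspect `model.forward` and then call the module itself.

```python
import inspect
from typing import Any, Callable


def call_with_supported_kwargs(fn: Callable[..., Any], input_ids: Any, **kwargs: Any) -> Any:
    """Call fn once, passing only the keyword arguments its signature accepts."""
    parameters = inspect.signature(fn).parameters
    accepts_any_kwarg = any(
        parameter.kind is inspect.Parameter.VAR_KEYWORD for parameter in parameters.values()
    )
    if accepts_any_kwarg:
        # A **kwargs catch-all means nothing needs to be filtered out.
        filtered = dict(kwargs)
    else:
        filtered = {name: value for name, value in kwargs.items() if name in parameters}
    # Single call, no retry: a TypeError raised inside fn propagates to the caller.
    return fn(input_ids, **filtered)
```

Because there is exactly one call, a `TypeError` raised from inside the model is no longer mistaken for a signature mismatch and silently retried with fewer arguments.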


@inaniloquentee inaniloquentee force-pushed the feat/logprob-cross-engine-tool branch from 33cadd0 to 7549409 Compare June 13, 2026 15:05

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@rl_engine/benchmarks/logprob_cross_engine.py`:
- Around line 106-110: The code currently casts every item in
payload["completion_mask"] through bool(...), which silently accepts stringy
values like "false" and corrupts the mask; instead validate items explicitly:
when completion_mask_value is provided, iterate over completion_mask_value and
accept only actual booleans (bool) or explicit integer encodings 0/1 (int ->
bool), otherwise raise a clear ValueError indicating the invalid mask entry and
its index; update the creation of completion_mask (the variable currently
computed from completion_mask_value and completion_token_ids) to enforce the
length matches completion_token_ids and to fail fast on malformed entries rather
than truthifying them.
ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: dc96046b-a92d-47af-ad20-6a2fc99f3db2

📥 Commits

Reviewing files that changed from the base of the PR and between 33cadd0 and 7549409.

📒 Files selected for processing (6)
  • .gitignore
  • benchmarks/logprob_cross_engine.py
  • docs/benchmarking/README.md
  • rl_engine/benchmarks/__init__.py
  • rl_engine/benchmarks/logprob_cross_engine.py
  • tests/test_logprob_cross_engine.py
✅ Files skipped from review due to trivial changes (2)
  • rl_engine/benchmarks/__init__.py
  • .gitignore
🚧 Files skipped from review as they are similar to previous changes (2)
  • docs/benchmarking/README.md
  • benchmarks/logprob_cross_engine.py

Comment on lines +106 to +110
completion_mask_value = payload.get("completion_mask")
if completion_mask_value is None:
    completion_mask = [True] * len(completion_token_ids)
else:
    completion_mask = [bool(item) for item in completion_mask_value]


⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Reject string-typed completion_mask values instead of coercing them with bool().

Lines 106-110 run every entry through bool(...), so payloads like ["false", "false"] or ["0", "1"] become all-True masks and silently change which tokens are benchmarked. For a replay fixture, malformed mask encodings should fail fast or be parsed explicitly.
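
The pitfall can be reproduced in isolation. The sample values below are illustrative only and are not taken from the PR's fixtures:

```python
# Any non-empty string is truthy in Python, regardless of its text.
raw_mask = ["false", "0", "", 0, False]
truthified = [bool(item) for item in raw_mask]
print(truthified)  # [True, True, False, False, False]
```

Only the empty string, the integer `0`, and `False` map to `False`; `"false"` and `"0"` both come out as `True`.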

Suggested fix
         completion_mask_value = payload.get("completion_mask")
         if completion_mask_value is None:
             completion_mask = [True] * len(completion_token_ids)
         else:
-            completion_mask = [bool(item) for item in completion_mask_value]
+            if not isinstance(completion_mask_value, Sequence) or isinstance(
+                completion_mask_value, (str, bytes)
+            ):
+                raise ValueError("completion_mask must be a sequence of booleans")
+            completion_mask = []
+            for index, item in enumerate(completion_mask_value):
+                if isinstance(item, bool):
+                    completion_mask.append(item)
+                elif isinstance(item, int) and item in {0, 1}:
+                    completion_mask.append(bool(item))
+                else:
+                    raise ValueError(
+                        f"completion_mask[{index}] must be a boolean or 0/1 integer"
+                    )
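
The diff above does not show the length check that the review prompt also asks for. A standalone sketch that combines both checks might look like the following; the function name `parse_completion_mask` and the exact error messages are assumptions, not code from the PR.

```python
from collections.abc import Sequence
from typing import Any


def parse_completion_mask(value: Any, expected_len: int) -> list[bool]:
    """Validate a completion mask from a fixture payload; fail fast on bad input."""
    if value is None:
        # Missing mask: treat every completion token as included.
        return [True] * expected_len
    if not isinstance(value, Sequence) or isinstance(value, (str, bytes)):
        raise ValueError("completion_mask must be a sequence of booleans")
    if len(value) != expected_len:
        raise ValueError(
            f"completion_mask has {len(value)} entries, expected {expected_len}"
        )

    mask: list[bool] = []
    for index, item in enumerate(value):
        if isinstance(item, bool):
            mask.append(item)
        elif isinstance(item, int) and item in (0, 1):
            # Accept explicit 0/1 integer encodings only.
            mask.append(bool(item))
        else:
            raise ValueError(
                f"completion_mask[{index}] must be a boolean or 0/1 integer, "
                f"got {item!r}"
            )
    return mask
```

The `bool` check comes first because `bool` is a subclass of `int` in Python; strings, floats, and integers other than 0 and 1 all fall through to the error branch.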

Development

Successfully merging this pull request may close these issues.

[FEAT] End-to-end log-prob cross-benchmark tool for rollout vs training engines

1 participant