[BugFix][S-TIR] Wrap bare scalar bodies in DefaultGPUSchedule to avoid root-block crash#19514

Merged
tlopex merged 1 commit into apache:main from swjng:fix/default-gpu-schedule-root-block
May 6, 2026

@swjng swjng commented May 6, 2026

Problem

Closes #17873.

DefaultGPUSchedule crashes when a PrimFunc body is a bare
SBlockRealize (a fully-scalar op with no enclosing loops and no iter
vars):

ValueError: Check failed: (sref->parent != nullptr) is false:
  Cannot add loops on top of the root block

Minimal repro (TVMScript decorators are omitted in this snippet to
satisfy the PR-body lint; the regression test uses the regular
T.prim_func form):

ir_module:
  prim_func main(a: Buffer((), "float32"),
                 b: Buffer((), "float32"),
                 c: Buffer((), "float32")):
      func_attr({"target": target("nvidia/geforce-rtx-3080")})
      with sblock("scalar_add"):
          c[()] = a[()] + b[()]

s_tir.transform.DefaultGPUSchedule()(M)  # crashes

Root Cause

The realized scalar_add block is itself the prim_func body's root
sref — it has no parent stmt to mutate. ThreadBind
(src/s_tir/transform/default_gpu_schedule.cc) reaches the
loops.empty() branch and calls sch->AddUnitLoop(block), which fails
the sref->parent != nullptr check in s_tir::AddUnitLoop
(src/s_tir/schedule/primitive/loop_transformation.cc:1166).
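
The failure mode can be illustrated with a small plain-Python model. This is only a sketch: `ToySRef` and `toy_add_unit_loop` are hypothetical stand-ins invented for this illustration, not TVM classes or functions. They mimic just the parent check described above.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToySRef:
    """Hypothetical stand-in for a schedule sref: a name plus a parent link."""
    name: str
    parent: Optional["ToySRef"] = None


def toy_add_unit_loop(sref: ToySRef) -> ToySRef:
    """Insert a unit loop above `sref`, mimicking the parent check described above."""
    if sref.parent is None:
        # The root has no parent statement whose body could be replaced.
        raise ValueError("Cannot add loops on top of the root block")
    loop = ToySRef(name=f"unit_loop({sref.name})", parent=sref.parent)
    sref.parent = loop
    return loop


# Bare scalar block: the block itself is the root, so the call fails.
bare = ToySRef("scalar_add")
try:
    toy_add_unit_loop(bare)
    bare_error = None
except ValueError as err:
    bare_error = str(err)

# Same block nested under a separate root: the call succeeds.
root = ToySRef("root")
nested = ToySRef("scalar_add", parent=root)
loop = toy_add_unit_loop(nested)
```

In this model the only difference between the failing and the succeeding case is whether a distinct root sits above the block, which is the property the fix restores.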

The schedule infrastructure additionally requires the prim_func body
to be an SBlockRealize whose block is the function's root
(GetRootPrimFunc in src/s_tir/schedule/analysis/analysis.cc:53),
so the body cannot simply be wrapped in a top-level For.

Fix

Before constructing the schedule, rewrite GPU-bound PrimFuncs whose
body is a bare-leaf SBlockRealize so the realized block is no longer
the root. The wrap conditions are intentionally narrow:

  1. func->body is SBlockRealize,
  2. the realized block has empty iter_vars, and
  3. the block's body is not For or SBlockRealize (i.e. it is a leaf
    computation, not the well-formed implicit root that wraps a loop
    nest produced by the rest of the pipeline).
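
The three conditions can be read as a single predicate. The sketch below uses hypothetical plain-Python node classes (`For`, `SBlock`, `SBlockRealize` here are toy dataclasses that only borrow the names from the text, and `needs_wrap` is an invented helper); it is not the code in default_gpu_schedule.cc.

```python
from dataclasses import dataclass
from typing import Any, Tuple


@dataclass
class For:
    """Toy loop node: a loop variable and a body."""
    var: str
    body: Any


@dataclass
class SBlock:
    """Toy block node: a name, a body, and its iter vars."""
    name: str
    body: Any
    iter_vars: Tuple[str, ...] = ()


@dataclass
class SBlockRealize:
    """Toy realize node wrapping a block."""
    block: SBlock
    iter_values: Tuple[str, ...] = ()


def needs_wrap(body: Any) -> bool:
    """Return True only when all three wrap conditions hold (toy model)."""
    if not isinstance(body, SBlockRealize):  # condition 1
        return False
    if body.block.iter_vars:  # condition 2: iter_vars must be empty
        return False
    # condition 3: the block body must be a leaf, not a For or SBlockRealize
    return not isinstance(body.block.body, (For, SBlockRealize))


# A bare scalar block: all three conditions hold.
bare_scalar = SBlockRealize(SBlock("scalar_add", body="c[()] = a[()] + b[()]"))

# A well-formed implicit root that already wraps a loop nest: left alone.
inner = SBlockRealize(SBlock("add", body="C[vi] = A[vi] + B[vi]", iter_vars=("vi",)), ("i",))
implicit_root = SBlockRealize(SBlock("root", body=For("i", inner)))

# A body that is a plain loop rather than a block realize: left alone.
loop_body = For("i", inner)
```

With these inputs, `needs_wrap` returns `True` only for `bare_scalar`; the implicit root fails condition 3 and the plain loop fails condition 1.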

When all three hold, the body becomes:

SBlockRealize(
  block=SBlock("root", body=
    For(u, 0, 1, kSerial,
      SBlockRealize(iter_values=[u],
        block=<original block, iter_vars=[IterVar(0..1, vu, kDataPar)]>))))

The synthesised 1-extent data-parallel iter keeps
iter_values.size() == iter_vars.size() for downstream checks, and the
new For loop gives ThreadBind a real loop to bind to blockIdx.x /
threadIdx.x. Already-scheduled functions and host-only PrimFuncs are
skipped via the existing IsScheduledOnGPU / kIsScheduled gating.
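
The rewritten shape can be sketched the same way. The classes and the `wrap_bare_body` helper below are hypothetical plain-Python stand-ins written for this illustration, with loop kinds and iter types kept as strings; they are not TVM's node types or the pass's actual helper.

```python
from dataclasses import dataclass, replace
from typing import Any, Tuple


@dataclass(frozen=True)
class IterVar:
    """Toy iter var: variable name, (min, extent) domain, and iter type."""
    var: str
    dom: Tuple[int, int]
    kind: str


@dataclass(frozen=True)
class SBlock:
    name: str
    body: Any
    iter_vars: Tuple[IterVar, ...] = ()


@dataclass(frozen=True)
class SBlockRealize:
    block: SBlock
    iter_values: Tuple[str, ...] = ()


@dataclass(frozen=True)
class For:
    var: str
    min: int
    extent: int
    kind: str
    body: Any


def wrap_bare_body(realize: SBlockRealize) -> SBlockRealize:
    """Wrap a bare leaf block in a unit loop under a fresh root block (toy model)."""
    # Give the original block one 1-extent data-parallel iter var.
    inner_block = replace(
        realize.block, iter_vars=(IterVar("vu", (0, 1), "kDataPar"),)
    )
    # Bind that iter var to the new loop variable u.
    inner = SBlockRealize(block=inner_block, iter_values=("u",))
    loop = For("u", 0, 1, "kSerial", inner)
    # The new root block owns the loop, so the original block is no longer the root.
    return SBlockRealize(block=SBlock("root", body=loop))


bare = SBlockRealize(SBlock("scalar_add", body="c[()] = a[()] + b[()]"))
wrapped = wrap_bare_body(bare)
```

After the call, `wrapped.block` is the new `"root"` block, its body is the unit loop over `u`, and the original `scalar_add` block sits inside that loop with one iter value matching its one iter var.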

Testing

pytest tests/python/s_tir/transform/test_s_tir_transform_default_gpu_schedule.py

10 passed (9 existing + 1 new test_scalar_block_no_loops). End-to-end
compile + execute on RTX 3080 (sm_86): the scalar repro returns the
expected 2.0 + 3.0 = 5.0.

…d root-block crash

When a PrimFunc body is a bare `SBlockRealize` (a fully-scalar op with
no enclosing loops and no iter vars), the realized block is itself the
function's root sref. `ThreadBind` reaches the `loops.empty()` branch
and calls `Schedule::AddUnitLoop(block)`, which fails the
`sref->parent != nullptr` check in `s_tir::AddUnitLoop` with
"Cannot add loops on top of the root block".

Before constructing the schedule, rewrite GPU-bound PrimFuncs whose
body is a bare-leaf `SBlockRealize` so the realized block is no longer
the root. The wrap conditions are intentionally narrow: body is
`SBlockRealize`, the block has empty `iter_vars`, and the block's body
is not `For` or `SBlockRealize` (so that well-formed implicit roots
already wrapping a loop nest are left alone). The new shape is

  SBlockRealize(block=SBlock("root", body=
    For(u, 0, 1, kSerial, SBlockRealize(iter_values=[u],
      block=<original block, iter_vars=[IterVar(0..1, vu, kDataPar)]>))))

The synthesised 1-extent data-parallel iter keeps iter_values and
iter_vars counts consistent for downstream checks, and the new For
gives ThreadBind a real loop to bind to blockIdx.x / threadIdx.x.

Closes apache#17873.

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces a mechanism to handle PrimFunc bodies consisting of a bare SBlockRealize, such as scalar operations, within the DefaultGPUSchedule pass. It adds a WrapBareSBlockBody helper function to wrap these blocks in a unit loop, which prevents crashes during thread binding by ensuring there is a loop to operate on. A new test case for scalar addition has been included to verify the implementation. I have no feedback to provide as there were no review comments.

@swjng swjng force-pushed the fix/default-gpu-schedule-root-block branch from b92e53b to 727d931 Compare May 6, 2026 13:53
@tlopex tlopex merged commit 446bd2d into apache:main May 6, 2026
10 of 14 checks passed