Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions src/s_tir/transform/default_gpu_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,55 @@ IRModule MarkScheduled(const IRModule& mod) {
mod->global_infos); // global_infos
}

/*!
 * \brief Ensure a PrimFunc whose body is a bare \c SBlockRealize (no
 * surrounding loops, no block iter vars) gains one extra wrapping level, so
 * the realized block is no longer the function's root sref.
 *
 * Without the wrapping, \c ThreadBind ends up calling
 * \c Schedule::AddUnitLoop(block) on a block that *is* the prim_func's root
 * sref and trips the "Cannot add loops on top of the root block" check in
 * \c s_tir::AddUnitLoop. The schedule infrastructure also requires the
 * prim_func body to be an \c SBlockRealize, so that shape is preserved: the
 * original block is pushed one level deeper, underneath a new "root" block
 * holding a unit serial loop. A synthetic data-parallel iter var keeps the
 * iter_values/iter_vars counts consistent for downstream checks.
 *
 * \param func The PrimFunc to inspect.
 * \return \p func unchanged when no wrapping is needed; otherwise a copy
 *         whose body is root block -> unit loop -> original block.
 */
tirx::PrimFunc WrapBareSBlockBody(const tirx::PrimFunc& func) {
  const auto* bare_realize = func->body.as<tirx::SBlockRealizeNode>();
  // Only act on a body that is an SBlockRealize whose block has no iter vars.
  if (bare_realize == nullptr || !bare_realize->block->iter_vars.empty()) {
    return func;
  }
  // Only leaf computations need wrapping. A well-formed PrimFunc from the
  // rest of the pipeline has an implicit root SBlockRealize whose block body
  // is a For loop (or a nested SBlockRealize) — that case already has a spot
  // for thread bindings, so it is left untouched.
  const tirx::Stmt& leaf_body = bare_realize->block->body;
  if (leaf_body->IsInstance<tirx::ForNode>() ||
      leaf_body->IsInstance<tirx::SBlockRealizeNode>()) {
    return func;
  }
  const tvm::DataType i32 = tvm::DataType::Int(32);
  tvm::IntImm imm_zero(i32, 0);
  tvm::IntImm imm_one(i32, 1);
  // The unit serial loop [0, 1) supplies the iter value for the moved block.
  tirx::Var unit_loop_var("u", i32);
  tirx::Var block_iter_var("vu", i32);
  tirx::IterVar synthetic_iter(tvm::Range::FromMinExtent(imm_zero, imm_one), block_iter_var,
                               tirx::IterVarType::kDataPar);
  // Give the original block the synthetic iter var, bound to the loop var.
  tirx::SBlock moved_block = bare_realize->block;
  moved_block.CopyOnWrite()->iter_vars = ffi::Array<tirx::IterVar>{synthetic_iter};
  tirx::SBlockRealize moved_realize(/*iter_values=*/ffi::Array<tvm::PrimExpr>{unit_loop_var},
                                    /*predicate=*/bare_realize->predicate, moved_block);
  tirx::Stmt unit_loop =
      tirx::For(unit_loop_var, imm_zero, imm_one, tirx::ForKind::kSerial, moved_realize);
  // Fresh root block: no iter vars, reads, or writes of its own.
  tirx::SBlock root_block(/*iter_vars=*/ffi::Array<tirx::IterVar>{},
                          /*reads=*/ffi::Array<tirx::BufferRegion>{},
                          /*writes=*/ffi::Array<tirx::BufferRegion>{},
                          /*name_hint=*/"root", /*body=*/unit_loop);
  tirx::SBlockRealize root_realize(/*iter_values=*/ffi::Array<tvm::PrimExpr>{},
                                   /*predicate=*/tvm::Bool(true), root_block);
  tirx::PrimFunc wrapped = func;
  wrapped.CopyOnWrite()->body = std::move(root_realize);
  return wrapped;
}

bool IsScheduledOnGPU(const BaseFunc& func) {
// the target from context.
tvm::Target target = tvm::Target::Current();
Expand All @@ -125,6 +174,27 @@ bool IsScheduledOnGPU(const BaseFunc& func) {
Pass DefaultGPUSchedule() {
auto pass_func = //
[=](IRModule m, PassContext pc) {
// Wrap any GPU-bound PrimFunc whose body is a bare SBlockRealize
// (e.g. a scalar op) so ThreadBind below has a loop to operate on.
ffi::Map<GlobalVar, BaseFunc> wrapped;
bool any_wrapped = false;
for (const auto& [gv, base_func] : m->functions) {
if (const auto* prim_func_node = base_func.as<tirx::PrimFuncNode>();
prim_func_node != nullptr && IsScheduledOnGPU(base_func) &&
!base_func->HasNonzeroAttr(tirx::attr::kIsScheduled)) {
tirx::PrimFunc func = ffi::GetRef<tirx::PrimFunc>(prim_func_node);
tirx::PrimFunc new_func = WrapBareSBlockBody(func);
if (!new_func.same_as(func)) {
wrapped.Set(gv, new_func);
any_wrapped = true;
continue;
}
}
wrapped.Set(gv, base_func);
}
if (any_wrapped) {
m = IRModule(wrapped, m->source_map, m->attrs, m->global_infos);
}
s_tir::Schedule sch = s_tir::Schedule::Traced(m, /*seed=*/-1, /*debug_mask=*/0,
s_tir::ScheduleErrorRenderLevel::kDetail);
for (const auto& [gv, func] : m->functions) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -567,5 +567,39 @@ def sum(A: T.Buffer((T.int64(2), T.int64(2)), "float64"), A_red: T.Buffer((), "f
tvm.ir.assert_structural_equal(mod, Expected)


def test_scalar_block_no_loops():
    """DefaultGPUSchedule handles a PrimFunc whose body has no loops.

    A PrimFunc whose body is a bare SBlockRealize (e.g. a fully-scalar op)
    used to crash DefaultGPUSchedule with "Cannot add loops on top of the
    root block" because the realized block was the function's root sref.
    The pass now wraps such a block in a root block with a unit loop, and
    the expected output below shows that loop thread-bound to a 1x1
    blockIdx.x/threadIdx.x launch.
    """
    # A PrimFunc whose body is a bare SBlockRealize (e.g. a fully-scalar op)
    # used to crash DefaultGPUSchedule with "Cannot add loops on top of the
    # root block" because the realized block was the function's root sref.
    # pylint: disable=no-self-argument,missing-class-docstring,line-too-long
    # fmt: off
    @tvm.script.ir_module
    class Before:
        @T.prim_func
        def scalar_add(a: T.Buffer((), "float32"), b: T.Buffer((), "float32"), c: T.Buffer((), "float32")):
            with T.sblock("scalar_add"):
                c[()] = a[()] + b[()]

    @tvm.script.ir_module
    class Expected:
        @T.prim_func
        def scalar_add(a: T.Buffer((), "float32"), b: T.Buffer((), "float32"), c: T.Buffer((), "float32")):
            T.func_attr({"tirx.is_scheduled": True})
            # with T.sblock("root"):
            for u_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
                for u_fused_1 in T.thread_binding(1, thread="threadIdx.x"):
                    with T.sblock("scalar_add"):
                        vu = T.axis.spatial(1, 0)
                        T.reads()
                        T.writes()
                        c[()] = a[()] + b[()]
    # fmt: on
    # pylint: enable=no-self-argument,missing-class-docstring,line-too-long
    # Run the pass under a concrete GPU target so IsScheduledOnGPU is true.
    target = tvm.target.Target("nvidia/geforce-rtx-3070")
    with target, tvm.transform.PassContext(opt_level=0):
        mod = DefaultGPUSchedule()(Before)
    tvm.ir.assert_structural_equal(mod, Expected)


# Allow running this test file directly, outside the pytest collector.
if __name__ == "__main__":
    tvm.testing.main()
Loading