Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/tirx/transform/common_subexpr_elim.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@
* - It is not a leaf (Var, IntImm, FloatImm, StringImm).
* - It does not contain Call or BufferLoad (side-effects / memory dependence).
* - It is not Ramp or Broadcast (hardware-specific vector ops).
* - It is not bool-typed. Boolean predicates are kept inline because the
* consumer (if / Select / assert) reads more clearly with the condition
* spelled out, and downstream simplification benefits from seeing the
* predicate directly.
*
* Scope tree
* ----------
Expand Down Expand Up @@ -263,6 +267,8 @@ class CSEPlanner : public StmtExprVisitor {
* - Not a Call or BufferLoad (side effects / memory dependence).
* - Not Ramp or Broadcast (hardware-specific vector construction).
* - Does not transitively contain any forbidden node.
* - Is not bool-typed (predicates are kept inline for readability and
* downstream simplification).
*
* \param expr The expression to check.
* \return true if the expression can participate in CSE.
Expand All @@ -274,6 +280,14 @@ class CSEPlanner : public StmtExprVisitor {
}
if (IsForbiddenNode(expr)) return false;
if (expr.as<RampNode>() || expr.as<BroadcastNode>()) return false;
// Reject bool-typed expressions. Boolean predicates almost always feed an
// if / Select / assert, where reading the condition inline is clearer than
// going through a `cse_v: bool = (a < b)` temporary, and where downstream
// simplification (ProveCondition, branch elimination) benefits from seeing
// the predicate directly. BoolImm is already filtered above as an IntImm
// leaf, so this rule only affects compound bool expressions
// (LT/LE/GT/GE/EQ/NE/And/Or/Not/Cast-to-bool/Select-of-bool).
if (expr.dtype().is_bool()) return false;
Comment on lines 281 to +290
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The check for boolean-typed expressions should be moved before the recursive CheckContains::ExprContains call. CheckContains performs a full traversal of the expression tree (O(N)), whereas checking the dtype is an O(1) operation. By reordering these checks, we can avoid the expensive traversal for all boolean predicates, which are common in TIR, thereby improving the efficiency of the CSE pass.

Suggested change
if (IsForbiddenNode(expr)) return false;
if (expr.as<RampNode>() || expr.as<BroadcastNode>()) return false;
if (CheckContains::ExprContains(expr, IsForbiddenNode)) return false;
// Reject bool-typed expressions. Boolean predicates almost always feed an
// if / Select / assert, where reading the condition inline is clearer than
// going through a `cse_v: bool = (a < b)` temporary, and where downstream
// simplification (ProveCondition, branch elimination) benefits from seeing
// the predicate directly. BoolImm is already filtered above as an IntImm
// leaf, so this rule only affects compound bool expressions
// (LT/LE/GT/GE/EQ/NE/And/Or/Not/Cast-to-bool/Select-of-bool).
if (expr.dtype().is_bool()) return false;
// Reject bool-typed expressions. Boolean predicates almost always feed an
// if / Select / assert, where reading the condition inline is clearer than
// going through a `cse_v: bool = (a < b)` temporary, and where downstream
// simplification (ProveCondition, branch elimination) benefits from seeing
// the predicate directly. BoolImm is already filtered above as an IntImm
// leaf, so this rule only affects compound bool expressions
// (LT/LE/GT/GE/EQ/NE/And/Or/Not/Cast-to-bool/Select-of-bool).
if (expr.dtype().is_bool()) return false;
if (IsForbiddenNode(expr)) return false;
if (expr.as<RampNode>() || expr.as<BroadcastNode>()) return false;
if (CheckContains::ExprContains(expr, IsForbiddenNode)) return false;

if (CheckContains::ExprContains(expr, IsForbiddenNode)) return false;
return true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,47 @@ def test_let_floordiv_pattern():
assert "cse_v" not in script, f"CSE incorrectly extracted from Let body:\n{script}"


# =====================================================================
# T22: No lifting of bool predicate (comparison expression)
# A duplicated `i < n` feeds two if-statements. CSE must leave it
# inline rather than hoisting a `cse_v: bool = (i < n)` binding.
# =====================================================================
def test_no_lift_bool_predicate():
@tvm.script.ir_module
class Before:
@T.prim_func
def main(B: T.Buffer((50,), "int32"), n: T.int32, x: T.int32):
for i in range(50):
if i < n:
B[i] = x
if i < n:
B[i] = x + 1

after = tvm.tirx.transform.CommonSubexprElim()(Before)
tvm.ir.assert_structural_equal(after, Before)
assert "cse_v" not in after["main"].script()


# =====================================================================
# T23: No lifting of bool logical expression (And)
# A duplicated `a && b` feeds two if-statements. CSE must leave it
# inline rather than hoisting a `cse_v: bool = T.And(a, b)` binding.
# =====================================================================
def test_no_lift_bool_logical():
@tvm.script.ir_module
class Before:
@T.prim_func
def main(B: T.Buffer((50,), "int32"), a: T.bool, b: T.bool, x: T.int32):
if T.And(a, b):
B[0] = x
if T.And(a, b):
B[1] = x + 1

after = tvm.tirx.transform.CommonSubexprElim()(Before)
tvm.ir.assert_structural_equal(after, Before)
assert "cse_v" not in after["main"].script()


if __name__ == "__main__":
test_basic()
test_if_single_branch()
Expand All @@ -735,3 +776,5 @@ def test_let_floordiv_pattern():
test_let_value_cse()
test_nested_let_no_extraction()
test_let_floordiv_pattern()
test_no_lift_bool_predicate()
test_no_lift_bool_logical()
Loading