diff --git a/src/tirx/transform/common_subexpr_elim.cc b/src/tirx/transform/common_subexpr_elim.cc index 38925dc25a8d..9e7b2b1fb70b 100644 --- a/src/tirx/transform/common_subexpr_elim.cc +++ b/src/tirx/transform/common_subexpr_elim.cc @@ -49,6 +49,10 @@ * - It is not a leaf (Var, IntImm, FloatImm, StringImm). * - It does not contain Call or BufferLoad (side-effects / memory dependence). * - It is not Ramp or Broadcast (hardware-specific vector ops). + * - It is not bool-typed. Boolean predicates are kept inline because the + * consumer (if / Select / assert) reads more clearly with the condition + * spelled out, and downstream simplification benefits from seeing the + * predicate directly. * * Scope tree * ---------- @@ -263,6 +267,8 @@ class CSEPlanner : public StmtExprVisitor { * - Not a Call or BufferLoad (side effects / memory dependence). * - Not Ramp or Broadcast (hardware-specific vector construction). * - Does not transitively contain any forbidden node. + * - Is not bool-typed (predicates are kept inline for readability and + * downstream simplification). * * \param expr The expression to check. * \return true if the expression can participate in CSE. @@ -274,6 +280,14 @@ class CSEPlanner : public StmtExprVisitor { } if (IsForbiddenNode(expr)) return false; if (expr.as() || expr.as()) return false; + // Reject bool-typed expressions. Boolean predicates almost always feed an + // if / Select / assert, where reading the condition inline is clearer than + // going through a `cse_v: bool = (a < b)` temporary, and where downstream + // simplification (ProveCondition, branch elimination) benefits from seeing + // the predicate directly. BoolImm is already filtered above as an IntImm + // leaf, so this rule only affects compound bool expressions + // (LT/LE/GT/GE/EQ/NE/And/Or/Not/Cast-to-bool/Select-of-bool). + if (expr.dtype().is_bool()) return false; if (CheckContains::ExprContains(expr, IsForbiddenNode)) return false; return true; } diff --git a/tests/python/tirx-transform/test_tir_transform_common_subexpr_elim.py b/tests/python/tirx-transform/test_tir_transform_common_subexpr_elim.py index 8786720a2522..e025ae88a9f0 100644 --- a/tests/python/tirx-transform/test_tir_transform_common_subexpr_elim.py +++ b/tests/python/tirx-transform/test_tir_transform_common_subexpr_elim.py @@ -713,6 +713,47 @@ def test_let_floordiv_pattern(): assert "cse_v" not in script, f"CSE incorrectly extracted from Let body:\n{script}" +# ===================================================================== +# T22: No lifting of bool predicate (comparison expression) +# A duplicated `i < n` feeds two if-statements. CSE must leave it +# inline rather than hoisting a `cse_v: bool = (i < n)` binding. +# ===================================================================== +def test_no_lift_bool_predicate(): + @tvm.script.ir_module + class Before: + @T.prim_func + def main(B: T.Buffer((50,), "int32"), n: T.int32, x: T.int32): + for i in range(50): + if i < n: + B[i] = x + if i < n: + B[i] = x + 1 + + after = tvm.tirx.transform.CommonSubexprElim()(Before) + tvm.ir.assert_structural_equal(after, Before) + assert "cse_v" not in after["main"].script() + + +# ===================================================================== +# T23: No lifting of bool logical expression (And) +# A duplicated `a && b` feeds two if-statements. CSE must leave it +# inline rather than hoisting a `cse_v: bool = T.And(a, b)` binding. +# ===================================================================== +def test_no_lift_bool_logical(): + @tvm.script.ir_module + class Before: + @T.prim_func + def main(B: T.Buffer((50,), "int32"), a: T.bool, b: T.bool, x: T.int32): + if T.And(a, b): + B[0] = x + if T.And(a, b): + B[1] = x + 1 + + after = tvm.tirx.transform.CommonSubexprElim()(Before) + tvm.ir.assert_structural_equal(after, Before) + assert "cse_v" not in after["main"].script() + + if __name__ == "__main__": test_basic() test_if_single_branch() @@ -735,3 +776,5 @@ def test_let_floordiv_pattern(): test_let_value_cse() test_nested_let_no_extraction() test_let_floordiv_pattern() + test_no_lift_bool_predicate() + test_no_lift_bool_logical()