[BUGFIX][TIR] Skip bool-typed expressions in CSE#19502
Conversation
The TIR CSE pass currently lifts bool-typed sub-expressions like `i < n` or `a && b` into `cse_v: bool = ...` bindings whenever they appear twice. Boolean expressions are almost always predicates feeding `if` / `Select` / `assert`, where reading the condition inline is clearer than going through a boolean temporary, and where downstream simplification (ProveCondition, branch elimination) benefits from seeing the predicate directly. Extend CSEPlanner::IsEligible to reject any expression whose result dtype is_bool().
There was a problem hiding this comment.
Code Review
This pull request modifies the Common Subexpression Elimination (CSE) pass to prevent the hoisting of boolean-typed expressions. This change ensures that predicates remain inline for better readability in control flow structures like 'if' and 'Select' statements, while also preserving opportunities for downstream simplifications. The update includes documentation and new test cases for comparison and logical expressions. Feedback was provided to optimize performance by reordering the boolean check to occur before more computationally expensive tree traversals.
| if (IsForbiddenNode(expr)) return false; | ||
| if (expr.as<RampNode>() || expr.as<BroadcastNode>()) return false; | ||
| if (CheckContains::ExprContains(expr, IsForbiddenNode)) return false; | ||
| // Reject bool-typed expressions. Boolean predicates almost always feed an | ||
| // if / Select / assert, where reading the condition inline is clearer than | ||
| // going through a `cse_v: bool = (a < b)` temporary, and where downstream | ||
| // simplification (ProveCondition, branch elimination) benefits from seeing | ||
| // the predicate directly. BoolImm is already filtered above as an IntImm | ||
| // leaf, so this rule only affects compound bool expressions | ||
| // (LT/LE/GT/GE/EQ/NE/And/Or/Not/Cast-to-bool/Select-of-bool). | ||
| if (expr.dtype().is_bool()) return false; |
There was a problem hiding this comment.
The check for boolean-typed expressions should be moved before the recursive CheckContains::ExprContains call. CheckContains performs a full traversal of the expression tree (O(N)), whereas checking the dtype is an O(1) operation. By reordering these checks, we can avoid the expensive traversal for all boolean predicates, which are common in TIR, thereby improving the efficiency of the CSE pass.
| if (IsForbiddenNode(expr)) return false; | |
| if (expr.as<RampNode>() || expr.as<BroadcastNode>()) return false; | |
| if (CheckContains::ExprContains(expr, IsForbiddenNode)) return false; | |
| // Reject bool-typed expressions. Boolean predicates almost always feed an | |
| // if / Select / assert, where reading the condition inline is clearer than | |
| // going through a `cse_v: bool = (a < b)` temporary, and where downstream | |
| // simplification (ProveCondition, branch elimination) benefits from seeing | |
| // the predicate directly. BoolImm is already filtered above as an IntImm | |
| // leaf, so this rule only affects compound bool expressions | |
| // (LT/LE/GT/GE/EQ/NE/And/Or/Not/Cast-to-bool/Select-of-bool). | |
| if (expr.dtype().is_bool()) return false; | |
| // Reject bool-typed expressions. Boolean predicates almost always feed an | |
| // if / Select / assert, where reading the condition inline is clearer than | |
| // going through a `cse_v: bool = (a < b)` temporary, and where downstream | |
| // simplification (ProveCondition, branch elimination) benefits from seeing | |
| // the predicate directly. BoolImm is already filtered above as an IntImm | |
| // leaf, so this rule only affects compound bool expressions | |
| // (LT/LE/GT/GE/EQ/NE/And/Or/Not/Cast-to-bool/Select-of-bool). | |
| if (expr.dtype().is_bool()) return false; | |
| if (IsForbiddenNode(expr)) return false; | |
| if (expr.as<RampNode>() || expr.as<BroadcastNode>()) return false; | |
| if (CheckContains::ExprContains(expr, IsForbiddenNode)) return false; |
The dtype check is O(1) and the ExprContains traversal is O(N); reordering avoids the traversal for boolean predicates, which are common in TIR. Pure ordering change, no behavior difference.
Summary
The TIR CSE pass currently lifts bool-typed sub-expressions like
i < nora && bintocse_v: bool = ...bindings whenever they appear twice. Boolean expressions are almost always predicates feedingif/Select/assert, where reading the condition inline is clearer than going through a boolean temporary, and where downstream simplification (ProveCondition, branch elimination) benefits from seeing the predicate directly.CSEPlanner::IsEligibleinsrc/tirx/transform/common_subexpr_elim.ccto reject any compound expression whose result dtype isbool.Eligibility rulesdoc-comment and the per-functionIsEligibledocstring to document the new rule.test_no_lift_bool_predicate,test_no_lift_bool_logical) covering comparison predicates and logical-And predicates respectively.