diff --git a/src/arith/transitive_comparison_analyzer.cc b/src/arith/transitive_comparison_analyzer.cc index e7deea4cfd56..27b97de230b6 100644 --- a/src/arith/transitive_comparison_analyzer.cc +++ b/src/arith/transitive_comparison_analyzer.cc @@ -175,6 +175,19 @@ class TransitiveComparisonAnalyzer::Impl { */ bool Implies(const Comparison& other) const; + /*! \brief Structural equality over all four fields. + * + * Two Comparisons compare equal iff their `lhs_`, `rhs_`, + * `offset_`, and `result_` all match. Used by the Bind-override + * cleanup path to locate and remove a Comparison's companion + * entry stored under its partner key in `knowns_by_key_`. + */ + friend bool operator==(const Comparison& a, const Comparison& b) { + return a.lhs_ == b.lhs_ && a.rhs_ == b.rhs_ && a.offset_ == b.offset_ && + a.result_ == b.result_; + } + friend bool operator!=(const Comparison& a, const Comparison& b) { return !(a == b); } + // The LHS of the comparison Key lhs_; @@ -191,15 +204,26 @@ class TransitiveComparisonAnalyzer::Impl { /*! \brief Generate a Comparison representing the given expression */ std::optional FromExpr(const PrimExpr& expr); - /*! \brief Utility function used by Bind and EnterConstraint + /*! \brief Convert a comparison expression into Comparison objects + * and insert them into the persistent per-key index `knowns_by_key_`. + * + * Used by Bind for definitionally-true facts about variables + * (e.g. loop iterator ranges). + * + * \param expr The comparison expression. + */ + void AddKnown(const PrimExpr& expr); + + /*! \brief Convert a comparison expression into Comparison objects + * and append them to `scoped_knowns_`. * - * \param expr The comparison expression, to be converted into - * internal Comparison objects. + * Used by EnterConstraint to track facts that only hold within an + * active `With`; the scope's recovery lambda + * truncates these entries on exit. * - * \param vec The vector to which the Comparison objects should be - * appended. + * \param expr The comparison expression. */ - void AddKnown(const PrimExpr& expr, std::vector* vec); + void AddScopedKnown(const PrimExpr& expr); /*! Collect known comparisons between LHS and RHS, without propagation * @@ -210,9 +234,10 @@ class TransitiveComparisonAnalyzer::Impl { * * \param rhs_key The right-hand side of the comparison * - * \returns A subset of `knowns_` and `scoped_knowns_`, filtered to - * only include comparisons between `lhs_key` and `rhs_key`, - * normalized such that `lhs_key` is on the left-hand side. + * \returns A subset of `knowns_by_key_` and `scoped_knowns_`, + * filtered to only include comparisons between `lhs_key` and + * `rhs_key`, normalized such that `lhs_key` is on the left-hand + * side. */ std::vector CollectDirectComparisons(Key lhs_key, Key rhs_key) const; @@ -223,8 +248,8 @@ class TransitiveComparisonAnalyzer::Impl { * \param rhs_key The right-hand side of the comparison * * \returns All comparisons between `lhs_key` and `rhs_key`, - * including the explicitly-provided comparisons in `knowns_` and - * `scoped_knowns_`, and comparisons provable through a series of + * including the explicitly-provided comparisons in `knowns_by_key_` + * and `scoped_knowns_`, and comparisons provable through a series of * comparisons through other values. All comparisons returned are * between `lhs_key` and `rhs_key`, and are normalized such that * `lhs_key` is on the left-hand side. @@ -284,8 +309,23 @@ class TransitiveComparisonAnalyzer::Impl { * known statements are always true, based on the definition site of * the variable. e.g. A loop iterator may never exceed the bounds * of its loop. + * + * Indexed by `Key`: each Comparison is stored under both its `lhs_` + * and `rhs_` keys (collapsed to one entry when they are equal) so + * that `CollectDirectComparisons` and `DFSFromLHS` can look up only + * the bucket(s) that mention the query keys. + */ + std::unordered_map> knowns_by_key_; + + /*! \brief Append `cmp` to `knowns_by_key_` under both its keys. + * When `lhs_ == rhs_`, insert only once. */ - std::vector knowns_; + void IndexKnown(const Comparison& cmp) { + knowns_by_key_[cmp.lhs_].push_back(cmp); + if (cmp.rhs_ != cmp.lhs_) { + knowns_by_key_[cmp.rhs_].push_back(cmp); + } + } /*! \brief Known comparisons based on scoped conditions * @@ -544,12 +584,25 @@ std::function TransitiveComparisonAnalyzer::EnterConstraint(const PrimEx return impl_->EnterConstraint(constraint); } -void TransitiveComparisonAnalyzer::Impl::AddKnown(const PrimExpr& expr, - std::vector* vec) { +void TransitiveComparisonAnalyzer::Impl::AddKnown(const PrimExpr& expr) { + // Bind path: route every Comparison into the persistent per-key + // index so it is reachable by future TryCompare queries. for (const auto& subexpr : ExtractConstraints(expr, false)) { if (tirx::SideEffect(expr) <= tirx::CallEffectKind::kPure) { if (auto cmp = FromExpr(subexpr)) { - vec->push_back(cmp.value()); + IndexKnown(cmp.value()); + } + } + } +} + +void TransitiveComparisonAnalyzer::Impl::AddScopedKnown(const PrimExpr& expr) { + // EnterConstraint path: append to `scoped_knowns_` so the scope's + // recovery lambda can truncate these entries on exit. + for (const auto& subexpr : ExtractConstraints(expr, false)) { + if (tirx::SideEffect(expr) <= tirx::CallEffectKind::kPure) { + if (auto cmp = FromExpr(subexpr)) { + scoped_knowns_.push_back(cmp.value()); } } } @@ -566,9 +619,27 @@ void TransitiveComparisonAnalyzer::Impl::Bind(const tirx::Var& var, const Range& TVM_FFI_ICHECK(allow_override) << "Binding of variable " << var << " as " << range << " conflicts with previous binding as " << (*it).second; if (auto key = ExprToPreviousKey(var)) { - knowns_.erase(std::remove_if(knowns_.begin(), knowns_.end(), - [&](const auto& known) { return known.lhs_ == key.value(); }), - knowns_.end()); + Key old_key = key.value(); + + // Every entry in `knowns_by_key_[old_key]` involves old_key by + // construction (on either side). Remove each from its partner + // bucket and then drop the whole old_key bucket in one go. + auto idx_it = knowns_by_key_.find(old_key); + if (idx_it != knowns_by_key_.end()) { + const std::vector& to_remove = idx_it->second; + for (const auto& cmp : to_remove) { + Key partner_key = (cmp.lhs_ == old_key) ? cmp.rhs_ : cmp.lhs_; + // self-comparison (lhs_ == rhs_): only stored once, in + // the bucket we are about to erase. + if (partner_key == old_key) continue; + auto other_it = knowns_by_key_.find(partner_key); + if (other_it == knowns_by_key_.end()) continue; + other_it->second.erase( + std::remove(other_it->second.begin(), other_it->second.end(), cmp), + other_it->second.end()); + } + knowns_by_key_.erase(idx_it); + } } } } @@ -576,10 +647,10 @@ void TransitiveComparisonAnalyzer::Impl::Bind(const tirx::Var& var, const Range& prev_bindings_.Set(var, range); if (is_const_int(range->extent, 1)) { - AddKnown(var == range->min, &knowns_); + AddKnown(var == range->min); } else { - AddKnown(var >= range->min, &knowns_); - AddKnown(var < range->min + range->extent, &knowns_); + AddKnown(var >= range->min); + AddKnown(var < range->min + range->extent); } } @@ -590,7 +661,7 @@ void TransitiveComparisonAnalyzer::Impl::Bind(const tirx::Var& var, const PrimEx std::function TransitiveComparisonAnalyzer::Impl::EnterConstraint(const PrimExpr& expr) { size_t old_literal_size = scoped_knowns_.size(); - AddKnown(expr, &scoped_knowns_); + AddScopedKnown(expr); size_t new_literal_size = scoped_knowns_.size(); auto frecover = [old_literal_size, new_literal_size, this]() { @@ -652,9 +723,20 @@ TransitiveComparisonAnalyzer::Impl::CollectDirectComparisons(Key lhs_key, Key rh } }; - for (const auto& known : knowns_) { - append_known(known); + // Use the per-key index to look up only the comparisons that + // actually mention `lhs_key` or `rhs_key`. A matching Comparison is + // stored under both of its keys, so we only need to walk the + // smaller of the two buckets. If either key has never been seen, + // there is nothing to find. + auto it_l = knowns_by_key_.find(lhs_key); + auto it_r = knowns_by_key_.find(rhs_key); + if (it_l != knowns_by_key_.end() && it_r != knowns_by_key_.end()) { + const auto& bucket = (it_l->second.size() <= it_r->second.size()) ? it_l->second : it_r->second; + for (const auto& known : bucket) { + append_known(known); + } } + for (const auto& known : scoped_knowns_) { append_known(known); } @@ -713,10 +795,14 @@ TransitiveComparisonAnalyzer::Impl::DFSFromLHS(Key lhs_key, Key rhs_key) const { }; // Initialize the search based on any known (in)equalities that use - // the LHS of the comparison. - for (const auto& known : knowns_) { - if (auto normalized = known.WithLHS(lhs_key)) { - declare_known(normalized.value()); + // the LHS of the comparison. Iterate only the bucket of knowns + // that mention lhs_key. + auto seed_it = knowns_by_key_.find(lhs_key); + if (seed_it != knowns_by_key_.end()) { + for (const auto& known : seed_it->second) { + if (auto normalized = known.WithLHS(lhs_key)) { + declare_known(normalized.value()); + } } } for (const auto& known : scoped_knowns_) { @@ -787,9 +873,13 @@ TransitiveComparisonAnalyzer::Impl::DFSFromLHS(Key lhs_key, Key rhs_key) const { // we must first combine `a<=b` and `b<=c` into `a<=c`. During // this first step, `b` is the "middle" and `c` is the "right". // The next step can then combind `a<=c` and `c<=d` into `a<=d`. - for (const auto& known : knowns_) { - if (auto cmp = known.WithLHS(middle_key)) { - attempt_transitive(cmp.value()); + // Iterate only the bucket of knowns that mention middle_key. + auto mid_it = knowns_by_key_.find(middle_key); + if (mid_it != knowns_by_key_.end()) { + for (const auto& known : mid_it->second) { + if (auto cmp = known.WithLHS(middle_key)) { + attempt_transitive(cmp.value()); + } } } diff --git a/tests/python/arith/test_arith_transitive_comparison.py b/tests/python/arith/test_arith_transitive_comparison.py new file mode 100644 index 000000000000..d8dc58fa7045 --- /dev/null +++ b/tests/python/arith/test_arith_transitive_comparison.py @@ -0,0 +1,178 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +"""Tests for TransitiveComparisonAnalyzer and the per-key index.""" + +import tvm +import tvm.ir +import tvm.testing +from tvm import tirx +from tvm.script import tirx as T + + +def test_single_bind_provability(): + analyzer = tvm.arith.Analyzer() + x = tirx.Var("x", "int32") + analyzer.bind(x, tvm.ir.Range.from_min_extent(0, 100)) + assert analyzer.can_prove(x >= 0) + assert analyzer.can_prove(x < 100) + assert analyzer.can_prove(x <= 99) + assert not analyzer.can_prove(x >= 1) + + +def test_many_binds_correctness_preserved(): + analyzer = tvm.arith.Analyzer() + vars_ = [tirx.Var(f"v{i}", "int32") for i in range(2048)] + for i, v in enumerate(vars_): + analyzer.bind(v, tvm.ir.Range.from_min_extent(i, 10)) + for i in (0, len(vars_) // 2, len(vars_) - 1): + v = vars_[i] + assert analyzer.can_prove(v >= i) + assert analyzer.can_prove(v < i + 10) + assert not analyzer.can_prove(v >= i + 1) + + +def test_bind_override_clears_old_constraints(): + analyzer = tvm.arith.Analyzer() + x = tirx.Var("x", "int32") + analyzer.bind(x, tvm.ir.Range.from_min_extent(0, 100)) + assert analyzer.can_prove(x < 100) + analyzer.bind(x, tvm.ir.Range.from_min_extent(200, 100), allow_override=True) + assert analyzer.can_prove(x >= 200) + assert analyzer.can_prove(x < 300) + assert not analyzer.can_prove(x < 100) + assert not analyzer.can_prove(x < 200) + + +def test_bind_override_clears_constraints_where_var_is_rhs(): + analyzer = tvm.arith.Analyzer() + x = tirx.Var("x", "int32") + y = tirx.Var("y", "int32") + analyzer.bind(y, tvm.ir.Range.from_min_extent(0, 10)) + analyzer.bind(x, y + 5) + assert analyzer.can_prove(x < 15) + analyzer.bind(y, tvm.ir.Range.from_min_extent(200, 100), allow_override=True) + assert not analyzer.can_prove(x < 15) + assert analyzer.can_prove(x >= 205) + + +def test_scoped_constraint_enter_and_exit(): + analyzer = tvm.arith.Analyzer() + x = tirx.Var("x", "int32") + y = tirx.Var("y", "int32") + analyzer.bind(x, tvm.ir.Range.from_min_extent(0, 100)) + with analyzer.constraint_scope(y < x): + assert analyzer.can_prove(y < x) + assert not analyzer.can_prove(y < x) + + +def test_cross_key_lookup(): + analyzer = tvm.arith.Analyzer() + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") + analyzer.bind(a, tvm.ir.Range.from_min_extent(0, 100)) + with analyzer.constraint_scope(b > a): + assert analyzer.can_prove(a < b) + + +def test_nested_constraint_scopes(): + analyzer = tvm.arith.Analyzer() + x = tirx.Var("x", "int32") + y = tirx.Var("y", "int32") + z = tirx.Var("z", "int32") + analyzer.bind(x, tvm.ir.Range.from_min_extent(0, 100)) + with analyzer.constraint_scope(y < x): + assert analyzer.can_prove(y < x) + with analyzer.constraint_scope(z < y): + assert analyzer.can_prove(y < x) + assert analyzer.can_prove(z < y) + assert analyzer.can_prove(y < x) + assert not analyzer.can_prove(z < y) + assert not analyzer.can_prove(y < x) + assert not analyzer.can_prove(z < y) + + +def test_unrelated_binds_do_not_match(): + analyzer = tvm.arith.Analyzer() + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") + c = tirx.Var("c", "int32") + d = tirx.Var("d", "int32") + analyzer.bind(a, tvm.ir.Range.from_min_extent(0, 10)) + analyzer.bind(b, tvm.ir.Range.from_min_extent(0, 10)) + analyzer.bind(c, tvm.ir.Range.from_min_extent(0, 10)) + assert not analyzer.can_prove(a < b) + assert not analyzer.can_prove(b < c) + assert not analyzer.can_prove(c < d) + + +def test_scoped_then_global_bind_interaction(): + analyzer = tvm.arith.Analyzer() + y = tirx.Var("y", "int32") + x = tirx.Var("x", "int32") + with analyzer.constraint_scope(y > 0): + analyzer.bind(x, tvm.ir.Range.from_min_extent(0, 100)) + assert analyzer.can_prove(x < 100) + assert analyzer.can_prove(y > 0) + assert not analyzer.can_prove(y > 0) + assert analyzer.can_prove(x < 100) + + +def test_self_comparison_indexed_once(): + # `x == x` produces a Comparison with lhs_ == rhs_; IndexKnown + # must store it once, not twice. + analyzer = tvm.arith.Analyzer() + x = tirx.Var("x", "int32") + with analyzer.constraint_scope(x == x): + assert analyzer.can_prove(x == x) + analyzer.bind(x, tvm.ir.Range.from_min_extent(0, 10)) + assert analyzer.can_prove(x >= 0) + assert analyzer.can_prove(x < 10) + + +def test_transitively_prove_inequalities_uses_dfs_path(): + # `i < j` and `j < k` (from For ranges) compose into `i < k` only + # when the DFS path runs (transitively_prove_inequalities=True). + + @T.prim_func + def before(A: T.Buffer((1,), "int32")): + for i in T.serial(0, 50): + for j in T.serial(i + 1, 50): + for k in T.serial(j + 1, 50): + if i < k: + A[0] = 1 + else: + A[0] = 0 + + @T.prim_func + def after_dfs(A: T.Buffer((1,), "int32")): + T.func_attr({"global_symbol": "before"}) + for i in T.serial(0, 50): + for j in T.serial(i + 1, 50): + for k in T.serial(j + 1, 50): + A[0] = 1 + + mod = tvm.IRModule({"main": before}) + expected = tvm.IRModule({"main": after_dfs}) + + with tvm.transform.PassContext( + config={"tirx.Simplify": {"transitively_prove_inequalities": True}} + ): + out_with_dfs = tvm.tirx.transform.Simplify()(mod) + tvm.ir.assert_structural_equal(out_with_dfs, expected) + + # Negative control: without the flag the if-guard must remain, so + # the result must NOT match `expected` (proves the positive + # assertion above actually exercises the DFS path). + out_no_dfs = tvm.tirx.transform.Simplify()(mod) + assert not tvm.ir.structural_equal(out_no_dfs, expected) + + +if __name__ == "__main__": + tvm.testing.main()