Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 121 additions & 31 deletions src/arith/transitive_comparison_analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,19 @@ class TransitiveComparisonAnalyzer::Impl {
*/
bool Implies(const Comparison& other) const;

/*! \brief Structural equality over all four fields.
*
* Two Comparisons compare equal iff their `lhs_`, `rhs_`,
* `offset_`, and `result_` all match. Used by the Bind-override
* cleanup path to locate and remove a Comparison's companion
* entry stored under its partner key in `knowns_by_key_`.
*/
friend bool operator==(const Comparison& a, const Comparison& b) {
return a.lhs_ == b.lhs_ && a.rhs_ == b.rhs_ && a.offset_ == b.offset_ &&
a.result_ == b.result_;
}
friend bool operator!=(const Comparison& a, const Comparison& b) { return !(a == b); }

// The LHS of the comparison
Key lhs_;

Expand All @@ -191,15 +204,26 @@ class TransitiveComparisonAnalyzer::Impl {
/*! \brief Generate a Comparison representing the given expression */
std::optional<Comparison> FromExpr(const PrimExpr& expr);

/*! \brief Utility function used by Bind and EnterConstraint
/*! \brief Convert a comparison expression into Comparison objects
* and insert them into the persistent per-key index `knowns_by_key_`.
*
* Used by Bind for definitionally-true facts about variables
* (e.g. loop iterator ranges).
*
* \param expr The comparison expression.
*/
void AddKnown(const PrimExpr& expr);

/*! \brief Convert a comparison expression into Comparison objects
* and append them to `scoped_knowns_`.
*
* \param expr The comparison expression, to be converted into
* internal Comparison objects.
* Used by EnterConstraint to track facts that only hold within an
* active `With<ConstraintContext>`; the scope's recovery lambda
* truncates these entries on exit.
*
* \param vec The vector to which the Comparison objects should be
* appended.
* \param expr The comparison expression.
*/
void AddKnown(const PrimExpr& expr, std::vector<Comparison>* vec);
void AddScopedKnown(const PrimExpr& expr);

/*! Collect known comparisons between LHS and RHS, without propagation
*
Expand All @@ -210,9 +234,10 @@ class TransitiveComparisonAnalyzer::Impl {
*
* \param rhs_key The right-hand side of the comparison
*
* \returns A subset of `knowns_` and `scoped_knowns_`, filtered to
* only include comparisons between `lhs_key` and `rhs_key`,
* normalized such that `lhs_key` is on the left-hand side.
* \returns A subset of `knowns_by_key_` and `scoped_knowns_`,
* filtered to only include comparisons between `lhs_key` and
* `rhs_key`, normalized such that `lhs_key` is on the left-hand
* side.
*/
std::vector<Comparison> CollectDirectComparisons(Key lhs_key, Key rhs_key) const;

Expand All @@ -223,8 +248,8 @@ class TransitiveComparisonAnalyzer::Impl {
* \param rhs_key The right-hand side of the comparison
*
* \returns All comparisons between `lhs_key` and `rhs_key`,
* including the explicitly-provided comparisons in `knowns_` and
* `scoped_knowns_`, and comparisons provable through a series of
* including the explicitly-provided comparisons in `knowns_by_key_`
* and `scoped_knowns_`, and comparisons provable through a series of
* comparisons through other values. All comparisons returned are
* between `lhs_key` and `rhs_key`, and are normalized such that
* `lhs_key` is on the left-hand side.
Expand Down Expand Up @@ -284,8 +309,23 @@ class TransitiveComparisonAnalyzer::Impl {
* known statements are always true, based on the definition site of
* the variable. e.g. A loop iterator may never exceed the bounds
* of its loop.
*
* Indexed by `Key`: each Comparison is stored under both its `lhs_`
* and `rhs_` keys (collapsed to one entry when they are equal) so
* that `CollectDirectComparisons` and `DFSFromLHS` can look up only
* the bucket(s) that mention the query keys.
*/
std::unordered_map<Key, std::vector<Comparison>> knowns_by_key_;

/*! \brief Append `cmp` to `knowns_by_key_` under both its keys.
* When `lhs_ == rhs_`, insert only once.
*/
std::vector<Comparison> knowns_;
void IndexKnown(const Comparison& cmp) {
knowns_by_key_[cmp.lhs_].push_back(cmp);
if (cmp.rhs_ != cmp.lhs_) {
knowns_by_key_[cmp.rhs_].push_back(cmp);
}
}

/*! \brief Known comparisons based on scoped conditions
*
Expand Down Expand Up @@ -544,12 +584,25 @@ std::function<void()> TransitiveComparisonAnalyzer::EnterConstraint(const PrimEx
return impl_->EnterConstraint(constraint);
}

void TransitiveComparisonAnalyzer::Impl::AddKnown(const PrimExpr& expr,
std::vector<Comparison>* vec) {
void TransitiveComparisonAnalyzer::Impl::AddKnown(const PrimExpr& expr) {
// Bind path: route every Comparison into the persistent per-key
// index so it is reachable by future TryCompare queries.
for (const auto& subexpr : ExtractConstraints(expr, false)) {
if (tirx::SideEffect(expr) <= tirx::CallEffectKind::kPure) {
if (auto cmp = FromExpr(subexpr)) {
vec->push_back(cmp.value());
IndexKnown(cmp.value());
}
}
}
}

void TransitiveComparisonAnalyzer::Impl::AddScopedKnown(const PrimExpr& expr) {
// EnterConstraint path: append to `scoped_knowns_` so the scope's
// recovery lambda can truncate these entries on exit.
for (const auto& subexpr : ExtractConstraints(expr, false)) {
if (tirx::SideEffect(expr) <= tirx::CallEffectKind::kPure) {
if (auto cmp = FromExpr(subexpr)) {
scoped_knowns_.push_back(cmp.value());
}
}
}
Expand All @@ -566,20 +619,38 @@ void TransitiveComparisonAnalyzer::Impl::Bind(const tirx::Var& var, const Range&
TVM_FFI_ICHECK(allow_override) << "Binding of variable " << var << " as " << range
<< " conflicts with previous binding as " << (*it).second;
if (auto key = ExprToPreviousKey(var)) {
knowns_.erase(std::remove_if(knowns_.begin(), knowns_.end(),
[&](const auto& known) { return known.lhs_ == key.value(); }),
knowns_.end());
Key old_key = key.value();

// Every entry in `knowns_by_key_[old_key]` involves old_key by
// construction (on either side). Remove each from its partner
// bucket and then drop the whole old_key bucket in one go.
auto idx_it = knowns_by_key_.find(old_key);
if (idx_it != knowns_by_key_.end()) {
const std::vector<Comparison>& to_remove = idx_it->second;
for (const auto& cmp : to_remove) {
Key partner_key = (cmp.lhs_ == old_key) ? cmp.rhs_ : cmp.lhs_;
// self-comparison (lhs_ == rhs_): only stored once, in
// the bucket we are about to erase.
if (partner_key == old_key) continue;
auto other_it = knowns_by_key_.find(partner_key);
if (other_it == knowns_by_key_.end()) continue;
other_it->second.erase(
std::remove(other_it->second.begin(), other_it->second.end(), cmp),
other_it->second.end());
}
knowns_by_key_.erase(idx_it);
}
}
}
}

prev_bindings_.Set(var, range);

if (is_const_int(range->extent, 1)) {
AddKnown(var == range->min, &knowns_);
AddKnown(var == range->min);
} else {
AddKnown(var >= range->min, &knowns_);
AddKnown(var < range->min + range->extent, &knowns_);
AddKnown(var >= range->min);
AddKnown(var < range->min + range->extent);
}
Comment on lines +622 to 654
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current implementation of Bind (inherited from the previous version) has a logic issue where it calls AddKnown even when the range hasn't changed (i.e., when differs_from_previous is false). This leads to redundant entries being pushed into the knowns_by_key_ vectors every time Bind is called with the same range.

While the cleanup logic handles this by removing all matching comparisons when a variable is finally overridden, the accumulation of duplicates can degrade performance and increase memory usage during the lifetime of the analyzer.

Consider wrapping the prev_bindings_.Set and AddKnown calls in a check that ensures they only run if the binding is new or has actually changed.

}

Expand All @@ -590,7 +661,7 @@ void TransitiveComparisonAnalyzer::Impl::Bind(const tirx::Var& var, const PrimEx

std::function<void()> TransitiveComparisonAnalyzer::Impl::EnterConstraint(const PrimExpr& expr) {
size_t old_literal_size = scoped_knowns_.size();
AddKnown(expr, &scoped_knowns_);
AddScopedKnown(expr);
size_t new_literal_size = scoped_knowns_.size();

auto frecover = [old_literal_size, new_literal_size, this]() {
Expand Down Expand Up @@ -652,9 +723,20 @@ TransitiveComparisonAnalyzer::Impl::CollectDirectComparisons(Key lhs_key, Key rh
}
};

for (const auto& known : knowns_) {
append_known(known);
// Use the per-key index to look up only the comparisons that
// actually mention `lhs_key` or `rhs_key`. A matching Comparison is
// stored under both of its keys, so we only need to walk the
// smaller of the two buckets. If either key has never been seen,
// there is nothing to find.
auto it_l = knowns_by_key_.find(lhs_key);
auto it_r = knowns_by_key_.find(rhs_key);
if (it_l != knowns_by_key_.end() && it_r != knowns_by_key_.end()) {
const auto& bucket = (it_l->second.size() <= it_r->second.size()) ? it_l->second : it_r->second;
for (const auto& known : bucket) {
append_known(known);
}
}

for (const auto& known : scoped_knowns_) {
append_known(known);
}
Expand Down Expand Up @@ -713,10 +795,14 @@ TransitiveComparisonAnalyzer::Impl::DFSFromLHS(Key lhs_key, Key rhs_key) const {
};

// Initialize the search based on any known (in)equalities that use
// the LHS of the comparison.
for (const auto& known : knowns_) {
if (auto normalized = known.WithLHS(lhs_key)) {
declare_known(normalized.value());
// the LHS of the comparison. Iterate only the bucket of knowns
// that mention lhs_key.
auto seed_it = knowns_by_key_.find(lhs_key);
if (seed_it != knowns_by_key_.end()) {
for (const auto& known : seed_it->second) {
if (auto normalized = known.WithLHS(lhs_key)) {
declare_known(normalized.value());
}
}
}
for (const auto& known : scoped_knowns_) {
Expand Down Expand Up @@ -787,9 +873,13 @@ TransitiveComparisonAnalyzer::Impl::DFSFromLHS(Key lhs_key, Key rhs_key) const {
// we must first combine `a<=b` and `b<=c` into `a<=c`. During
// this first step, `b` is the "middle" and `c` is the "right".
// The next step can then combind `a<=c` and `c<=d` into `a<=d`.
for (const auto& known : knowns_) {
if (auto cmp = known.WithLHS(middle_key)) {
attempt_transitive(cmp.value());
// Iterate only the bucket of knowns that mention middle_key.
auto mid_it = knowns_by_key_.find(middle_key);
if (mid_it != knowns_by_key_.end()) {
for (const auto& known : mid_it->second) {
if (auto cmp = known.WithLHS(middle_key)) {
attempt_transitive(cmp.value());
}
}
}

Expand Down
Loading
Loading