Skip to content
18 changes: 0 additions & 18 deletions src/arith/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
#include <tvm/tirx/expr.h>
#include <tvm/tirx/op.h>

#include "./scalable_expression.h"
#include "const_fold.h"
#include "product_normal_form.h"

Expand Down Expand Up @@ -231,23 +230,6 @@ bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) {
}
}

// Current analysis may not be powerful enough to prove expressions containing
// the same symbolic value multiple times. However, when the symbolic values are
// "T.vscale" and the compile target uses a scalable architecture extension like
// VLA, we can make some assumptions about the value of vscale and iterate over a
// space of pre-defined values to attempt to prove the expression.
Target curr_target = Target::Current();
if (ContainsVscaleCall(simplified)) {
if (TargetHasVLA(curr_target)) {
auto kVScaleValues = GetVScaleValues(curr_target);
return CanProveVscaleExpressionFromKnownValues(this, simplified, kVScaleValues);
}
LOG(WARNING)
<< "The expression contains scalable values. An attempt to prove by substituting "
"with known values of vscale was not performed. This proof currently only supports "
"VLA targets, but the target was "
<< curr_target;
}
return false;
}

Expand Down
6 changes: 0 additions & 6 deletions src/arith/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
#include "constraint_extract.h"
#include "int_operator.h"
#include "pattern_match.h"
#include "scalable_expression.h"

namespace tvm {
namespace arith {
Expand Down Expand Up @@ -417,17 +416,12 @@ class ConstIntBoundAnalyzer::Impl
// only special handle >> and & which can be
// used for index calculation.

auto curr_target = Target::Current();
if (op->op.same_as(tirx::builtin::shift_right())) {
return VisitRightShift(op);
} else if (op->op.same_as(tirx::builtin::shift_left())) {
return VisitLeftShift(op);
} else if (op->op.same_as(tirx::builtin::bitwise_and())) {
return VisitBitwiseAnd(op);
} else if (op->op.same_as(tirx::builtin::vscale()) && TargetHasVLA(curr_target)) {
auto kVScaleValues = GetVScaleValues(curr_target);
unsigned int max_val = *std::max_element(kVScaleValues.begin(), kVScaleValues.end());
return MakeBound(1, max_val);
} else {
return Everything(op->dtype);
}
Expand Down
36 changes: 31 additions & 5 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,41 @@
#include <utility>

#include "../target/datatype/registry.h"
#include "../tirx/analysis/check_contains.h"
#include "conjunctive_normal_form.h"
#include "const_fold.h"
#include "constraint_extract.h"
#include "pattern_match.h"
#include "scalable_expression.h"

namespace tvm {
namespace arith {

namespace {
// File-local helper: true if `expr` is a call to tirx::builtin::vscale().
bool IsVScaleCall(const PrimExpr& expr) {
if (const auto* call = expr.as<tirx::CallNode>()) {
return call->op.same_as(tirx::builtin::vscale());
}
return false;
}

// File-local helper: true if `expr` contains a call to tirx::builtin::vscale().
bool ContainsVscaleCall(const PrimExpr& expr) {
return tirx::CheckContains::ExprContains(expr, IsVScaleCall);
}

// File-local helper: returns the vscale multiplier if `lanes` is of the form
// `multiplier * vscale()` or `vscale() * multiplier`, nullopt otherwise.
std::optional<int> ExtractVscaleFactor(const PrimExpr& lanes) {
PVar<IntImm> multiplier;
PCallExpr<PVscaleOp> vscale;
if (PMatchesOneOf(multiplier * vscale, vscale * multiplier).Match(lanes)) {
return multiplier.Eval()->value;
}
Comment thread
tqchen marked this conversation as resolved.
return std::nullopt;
}
} // namespace

using namespace tirx;

TVM_FFI_STATIC_INIT_BLOCK() { RewriteSimplifierStatsNode::RegisterReflection(); }
Expand Down Expand Up @@ -789,7 +815,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) {
return ramp(div(b1, c2), div(c1, c2), lanes).Eval();
}
// If all possible indices in ramp are the same.
if (CanProveGreaterEqual(b1.Eval(), 0) && !arith::ExtractVscaleFactor(lanes.Eval())) {
if (CanProveGreaterEqual(b1.Eval(), 0) && !ExtractVscaleFactor(lanes.Eval())) {
ModularSet bmod = analyzer_->modular_set(b1.Eval());
int64_t ramp_min = bmod->base / c2val;
auto lanes_int = lanes.Eval().as<IntImmNode>()->value;
Expand Down Expand Up @@ -946,7 +972,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const ModNode* op) {
// If all possible indices in ramp are the same.
if (CanProveGreaterEqual(b1.Eval(), 0)) {
ModularSet bmod = analyzer_->modular_set(b1.Eval());
if (!arith::ExtractVscaleFactor(lanes.Eval())) {
if (!ExtractVscaleFactor(lanes.Eval())) {
auto lanes_int = lanes.Eval().as<IntImmNode>()->value;
int64_t ramp_min = bmod->base / c2val;
int64_t ramp_max = (bmod->base + (lanes_int - 1) * c1val) / c2val;
Expand Down Expand Up @@ -1032,7 +1058,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
return ramp(floordiv(b1, c2), floordiv(c1, c2), lanes).Eval();
}
// If all possible indices in ramp are the same.
if (!arith::ExtractVscaleFactor(lanes.Eval())) {
if (!ExtractVscaleFactor(lanes.Eval())) {
ModularSet bmod = analyzer_->modular_set(b1.Eval());
int64_t ramp_min = floordiv(bmod->base, c2val);
auto lanes_int = lanes.Eval().as<IntImmNode>()->value;
Expand Down Expand Up @@ -1186,7 +1212,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
}
// If all possible indices in ramp are the same.
ModularSet bmod = analyzer_->modular_set(b1.Eval());
if (!arith::ExtractVscaleFactor(lanes.Eval())) {
if (!ExtractVscaleFactor(lanes.Eval())) {
int64_t ramp_min = floordiv(bmod->base, c2val);
auto lanes_int = lanes.Eval().as<IntImmNode>()->value;
int64_t ramp_max = floordiv(bmod->base + (lanes_int - 1) * c1val, c2val);
Expand Down
127 changes: 0 additions & 127 deletions src/arith/scalable_expression.cc

This file was deleted.

96 changes: 0 additions & 96 deletions src/arith/scalable_expression.h

This file was deleted.

20 changes: 18 additions & 2 deletions src/s_tir/schedule/ir_comparator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,27 @@
#include "./ir_comparator.h"

#include <tvm/ffi/cast.h>
#include <tvm/tirx/builtin.h>

#include "../../arith/scalable_expression.h"
#include "../../tirx/analysis/check_contains.h"

namespace tvm {

namespace {
// File-local helper: true if `expr` is a call to tirx::builtin::vscale().
bool IsVScaleCall(const PrimExpr& expr) {
if (const auto* call = expr.as<tirx::CallNode>()) {
return call->op.same_as(tirx::builtin::vscale());
}
return false;
}

// File-local helper: true if `expr` contains a call to tirx::builtin::vscale().
bool ContainsVscaleCall(const PrimExpr& expr) {
return tirx::CheckContains::ExprContains(expr, IsVScaleCall);
}
} // namespace

namespace s_tir {
using namespace tvm::tirx;

Expand Down Expand Up @@ -80,7 +96,7 @@ bool TensorizeComparator::VisitExpr(const PrimExpr& n, const PrimExpr& other) {
bool equal = n.same_as(other) ||
((n->type_index() == other->type_index()) &&
n.dtype().code() == other.dtype().code() && ExprComparator::VisitExpr(n, other)) ||
(tvm::arith::ContainsVscaleCall(n) && analyzer_.CanProveEqual(n, other));
(ContainsVscaleCall(n) && analyzer_.CanProveEqual(n, other));

if (!equal && assert_mode_) {
std::ostringstream os;
Expand Down
Loading
Loading