From 9e7c88543596c69d34e74192be76c3f20b3dded2 Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 28 May 2026 12:17:55 +0000 Subject: [PATCH 1/5] [REFACTOR][TIR] Replace IR-position Integer/Bool constructors with IntImm/const_true/const_false Bool(true) and Bool(false) in PrimExpr position add no information over IntImm(DataType::Bool(), 1) and IntImm(DataType::Bool(), 0) (which const_true() and const_false() produce). Integer(N) in PrimExpr position adds no information over IntImm(DataType::Int(32), N); both wrap the same IntImmNode. Replace throughout the 62-file IR, arith, relax, s_tir, tirx, topi, and target/cuda call sites. --- include/tvm/topi/nn.h | 14 ++-- include/tvm/topi/nn/softmax.h | 2 +- include/tvm/topi/transform.h | 4 +- src/arith/conjunctive_normal_form.cc | 13 +-- src/arith/iter_affine_map.cc | 6 +- src/arith/modular_set.cc | 2 +- src/arith/presburger_set.cc | 4 +- src/arith/rewrite_simplify.cc | 6 +- src/relax/analysis/struct_info_analysis.cc | 67 +++++++-------- src/relax/analysis/tir_op_pattern_kind.cc | 5 +- src/relax/backend/contrib/clml/codegen.cc | 6 +- src/relax/ir/dataflow_matcher.cc | 8 +- src/relax/ir/dataflow_matcher.h | 3 +- src/relax/ir/expr_functor.cc | 3 +- src/relax/op/memory/view.cc | 2 +- src/relax/op/nn/convolution.cc | 84 +++++++++---------- src/relax/op/nn/pooling.cc | 84 +++++++++---------- src/relax/op/vision/multibox_transform_loc.cc | 2 +- src/relax/op/vision/roi_align.cc | 8 +- src/relax/op/vision/roi_pool.cc | 2 +- src/relax/transform/adjust_matmul_order.cc | 3 +- src/relax/transform/allocate_workspace.cc | 4 +- src/relax/transform/fuse_tir.cc | 5 +- src/relax/transform/infer_amp_utils.cc | 4 +- src/s_tir/analysis/identify_memcpy.cc | 5 +- src/s_tir/meta_schedule/utils.h | 4 +- src/s_tir/schedule/analysis/layout.cc | 4 +- .../schedule/primitive/blockize_tensorize.cc | 4 +- .../schedule/primitive/cache_read_write.cc | 20 ++--- src/s_tir/schedule/primitive/compute_at.cc | 2 +- .../schedule/primitive/compute_inline.cc | 2 +- 
.../primitive/layout_transformation.cc | 4 +- .../schedule/primitive/loop_transformation.cc | 4 +- src/s_tir/schedule/primitive/pad_einsum.cc | 4 +- src/s_tir/schedule/primitive/read_write_at.cc | 2 +- src/s_tir/schedule/primitive/reduction.cc | 4 +- src/s_tir/schedule/state.cc | 8 +- src/s_tir/schedule/trace.cc | 2 +- src/s_tir/schedule/transform.cc | 4 +- src/s_tir/support/nd_int_set.h | 2 +- src/s_tir/transform/default_gpu_schedule.cc | 6 +- .../transform/inject_software_pipeline.cc | 8 +- .../transform/lower_cross_thread_reduction.cc | 10 +-- src/s_tir/transform/lower_opaque_block.cc | 2 +- src/s_tir/transform/memhammer_coalesce.cc | 2 +- .../transform/memhammer_intermediate_stage.cc | 8 +- .../transform/memhammer_lower_auto_copy.cc | 4 +- .../transform/memhammer_tensorcore_rewrite.cc | 20 ++--- .../plan_update_buffer_allocation_location.cc | 2 +- .../transform/transform_mma_buffer_layout.cc | 6 +- .../using_assume_to_reduce_branches.cc | 4 +- src/script/printer/ir/distributed.cc | 2 +- src/target/cuda/codegen_cuda.cc | 12 +-- src/te/operation/create_primfunc.cc | 9 +- src/tirx/ir/expr.cc | 2 +- src/tirx/ir/index_map.cc | 2 +- src/tirx/ir/script/script_complete.cc | 3 +- src/tirx/script/builder/frame.cc | 2 +- .../transform/lower_device_kernel_launch.cc | 2 +- src/tirx/transform/lower_tvm_builtin.cc | 2 +- src/tirx/transform/make_packed_api.cc | 2 +- src/topi/einsum.cc | 2 +- 62 files changed, 269 insertions(+), 259 deletions(-) diff --git a/include/tvm/topi/nn.h b/include/tvm/topi/nn.h index 979cb2148c63..226dd88511f8 100644 --- a/include/tvm/topi/nn.h +++ b/include/tvm/topi/nn.h @@ -537,7 +537,7 @@ inline tvm::te::Tensor space_to_batch_nd(const tvm::te::Tensor& data, r_shape.push_back(div(padded_shape[i], block_shape[i - 1])); r_shape.push_back(block_shape[i - 1]); block_shape_prod *= block_shape[i - 1]; - axis.push_back(Integer(r_shape.size() - 1)); // index of block_shape[i - 1] + axis.push_back(IntImm(DataType::Int(32), r_shape.size() - 1)); // index of 
block_shape[i - 1] } size_t n = axis.size(); @@ -553,7 +553,7 @@ inline tvm::te::Tensor space_to_batch_nd(const tvm::te::Tensor& data, // append remaining shape for (size_t i = num_block_dims + 1; i < input_shape.size(); i++) { r_shape.push_back(input_shape[i]); - axis.push_back(Integer(r_shape.size() - 1)); // index of remaining shape in r_shape + axis.push_back(IntImm(DataType::Int(32), r_shape.size() - 1)); // index of remaining shape in r_shape o_shape.push_back(input_shape[i]); } @@ -595,13 +595,13 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data, r_shape.push_back(block_shape[i]); block_shape_prod *= block_shape[i]; } - axis.push_back(Integer(r_shape.size())); // axis of (batch / block_shape_prod) + axis.push_back(IntImm(DataType::Int(32), r_shape.size())); // axis of (batch / block_shape_prod) r_shape.push_back(batch / block_shape_prod); for (size_t i = 1; i < num_input_dims; i++) { - axis.push_back(Integer(r_shape.size())); // axis of in_shape[i] + axis.push_back(IntImm(DataType::Int(32), r_shape.size())); // axis of in_shape[i] if (axis.size() < (num_block_dims + num_input_dims)) { - axis.push_back(Integer(r_shape.size() - (num_block_dims + 1))); // axis of block_shape[i] + axis.push_back(IntImm(DataType::Int(32), r_shape.size() - (num_block_dims + 1))); // axis of block_shape[i] } r_shape.push_back(in_shape[i]); } @@ -623,7 +623,7 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data, // Crop the start and end of dimensions of out ffi::Array begin_idx, end_idx, strides; for (size_t i = 0; i < r_p_shape.size(); ++i) { - strides.push_back(Integer(1)); + strides.push_back(IntImm(DataType::Int(32), 1)); if (i > 0 && i <= num_block_dims) { // prepare begin and end index for spatial dimensions int begin_i = static_cast(GetConstInt(crop_begin_list[i - 1])); @@ -636,7 +636,7 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data, end_idx.push_back(out_i - end_i); } else { // ignore the batch and remaining 
dimension - begin_idx.push_back(Integer(0)); + begin_idx.push_back(IntImm(DataType::Int(32), 0)); end_idx.push_back(static_cast(GetConstInt(r_p_shape[i]))); } } diff --git a/include/tvm/topi/nn/softmax.h b/include/tvm/topi/nn/softmax.h index 8b18ebe4b686..9786099f9edb 100644 --- a/include/tvm/topi/nn/softmax.h +++ b/include/tvm/topi/nn/softmax.h @@ -61,7 +61,7 @@ inline Tensor softmax(const Tensor& x, int axis = -1, std::string name = "tensor auto reduced_shape = MakeReduceTargetShape({axis}, x, false, false); tvm::ffi::Map attrs; - attrs.Set("axis", Integer(axis)); + attrs.Set("axis", IntImm(DataType::Int(32), axis)); auto insert_reduce_index = [axis, ndim](const ffi::Array& indices, const IterVar& reduce_index) { diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index db53b8b64f33..dda18baa15fb 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -919,7 +919,7 @@ inline Tensor strided_slice_with_axes(const Tensor& x, const ffi::Array TVM_FFI_ICHECK(axis >= 0 && axis < src_tensor_dim) << "Axis " << axes[i].IntValue() << " is out of bounds for tensor with " << src_tensor_dim << " dimensions"; - normalized_axes.push_back(Integer(axis)); + normalized_axes.push_back(IntImm(DataType::Int(32), axis)); } std::vector begin_vec, end_vec, strides_vec; @@ -2044,7 +2044,7 @@ inline Tensor one_hot(const Tensor& indices, const PrimExpr on_value, const Prim int indices_index = 0; for (int i = 0; i < ndim; i++) { if (i == true_axis) { - oshape.push_back(Integer(depth)); + oshape.push_back(IntImm(DataType::Int(32), depth)); } else { oshape.push_back(indices->shape[indices_index++]); } diff --git a/src/arith/conjunctive_normal_form.cc b/src/arith/conjunctive_normal_form.cc index 92afb242313a..17df960a127c 100644 --- a/src/arith/conjunctive_normal_form.cc +++ b/src/arith/conjunctive_normal_form.cc @@ -25,6 +25,7 @@ #include #include +#include #include #include @@ -138,15 +139,15 @@ class AndOfOrs { /*! 
\brief Mapping from PrimExpr to internal Key */ std::unordered_map expr_to_key_; - /*! \brief Cached key representing tirx::Bool(true) */ + /*! \brief Cached key for boolean true, i.e. IntImm(DataType::Bool(), 1) */ Key key_true_; - /*! \brief Cached key representing tirx::Bool(false) */ + /*! \brief Cached key for boolean false, i.e. IntImm(DataType::Bool(), 0) */ Key key_false_; }; AndOfOrs::AndOfOrs(const PrimExpr& expr) - : key_true_(GetKey(Bool(true))), key_false_(GetKey(Bool(false))) { + : key_true_(GetKey(IntImm(DataType::Bool(), 1))), key_false_(GetKey(IntImm(DataType::Bool(), 0))) { VisitAndExpressions(expr, [&](const PrimExpr& outer_expr) { std::vector or_components; VisitOrExpressions(outer_expr, [&](const PrimExpr& inner_expr) { @@ -233,9 +234,9 @@ PrimExpr AndOfOrs::GetExpr(AndOfOrs::Key key) const { } PrimExpr AndOfOrs::AsPrimExpr() const { - PrimExpr expr = Bool(true); + PrimExpr expr = IntImm(DataType::Bool(), 1); for (const auto& chunk : chunks_) { - PrimExpr chunk_expr = Bool(false); + PrimExpr chunk_expr = IntImm(DataType::Bool(), 0); for (Key j : chunk) { chunk_expr = chunk_expr || GetExpr(j); } @@ -366,7 +367,7 @@ void AndOfOrs::SimplifyAcrossChunks(Analyzer* analyzer) { // When attempting to simplify (B and C), the analyzer may // assume that A is false. 
PrimExpr known = [&]() { - PrimExpr known = Bool(true); + PrimExpr known = IntImm(DataType::Bool(), 1); for (const auto& key : i_chunk) { if (&key != &key_i) { known = known && analyzer->Simplify(!GetExpr(key)); diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index a8233d183418..2f9111a0c03a 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -1711,7 +1711,7 @@ PrimExpr ApproxLeastCommonMultiple(const PrimExpr& a, const PrimExpr& b, Analyze }; auto p1 = fsplit(a); auto p2 = fsplit(b); - auto const_lcm = Integer(LeastCommonMultiple(p1.second, p2.second)); + auto const_lcm = IntImm(DataType::Int(32), LeastCommonMultiple(p1.second, p2.second)); if (analyzer->CanProveEqual(p1.first, p2.first)) { return p1.first * const_lcm; } else if (analyzer->CanProveEqual(floormod(p1.first, p2.first), 0)) { @@ -2479,7 +2479,7 @@ class SubspaceDivider { std::unordered_map split_map_; // predicate of outer space and inner space; - PrimExpr outer_preds_{Bool(true)}, inner_preds_{Bool(true)}; + PrimExpr outer_preds_{const_true()}, inner_preds_{const_true()}; }; ffi::Array> SubspaceDivide(const ffi::Array& bindings, @@ -2540,7 +2540,7 @@ class InverseAffineIterMapTransformer { // initialize back propagation accumulator for (const IterMapExprNode* node : post_dfs_order) { - backprop_.Set(ffi::GetRef(node), Integer(0)); + backprop_.Set(ffi::GetRef(node), IntImm(DataType::Int(32), 0)); } for (size_t i = 0; i < iter_map.size(); i++) { backprop_.Set(iter_map[i], outputs[i]); diff --git a/src/arith/modular_set.cc b/src/arith/modular_set.cc index 840b27941158..f01972351b3e 100644 --- a/src/arith/modular_set.cc +++ b/src/arith/modular_set.cc @@ -303,7 +303,7 @@ class ModularSetAnalyzer::Impl : public ExprFunctorargs[1]); if (b.is_const()) { int shift; - if (is_const_power_of_two_integer(Integer(b.base + 1), &shift)) { + if (is_const_power_of_two_integer(IntImm(DataType::Int(32), b.base + 1), &shift)) { return ModByConst(op->args[0], 
static_cast(1) << shift, true); } } diff --git a/src/arith/presburger_set.cc b/src/arith/presburger_set.cc index 0cf7b57b9593..c36a19349305 100644 --- a/src/arith/presburger_set.cc +++ b/src/arith/presburger_set.cc @@ -126,9 +126,9 @@ void PresburgerSetNode::UpdateConstraint(const PrimExpr& constraint, const ffi:: } PrimExpr PresburgerSetNode::GenerateConstraint() const { - PrimExpr constraint = Bool(0); + PrimExpr constraint = const_false(); for (const IntegerRelation& disjunct : disjuncts) { - PrimExpr union_entry = Bool(1); + PrimExpr union_entry = const_true(); for (unsigned i = 0, e = disjunct.getNumEqualities(); i < e; ++i) { PrimExpr linear_eq = IntImm(DataType::Int(64), 0); if (disjunct.getNumCols() > 1) { diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 804cb3cd976c..0d13a5ecd375 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -211,7 +211,7 @@ CompareResult RewriteSimplifier::Impl::TryComparisonOfProductAndSum(const PrimEx (B * A) + (A + B) * C, } .Match(diff)) { - return std::tuple{A.Eval(), B.Eval(), C.Eval(), Integer(-1)}; + return std::tuple{A.Eval(), B.Eval(), C.Eval(), IntImm(DataType::Int(32), -1)}; } else { return std::nullopt; } @@ -1063,7 +1063,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { floordiv(y + x * c1, c2).Match(ret)) { int64_t c1val = c1.Eval()->value; int64_t c2val = c2.Eval()->value; - PrimExpr yval = y.EvalOr(Integer(0)); + PrimExpr yval = y.EvalOr(IntImm(DataType::Int(32), 0)); if (c2val == 0) return ret; // try eliminate residue part @@ -1072,7 +1072,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { PrimExpr y_div = CanProveEqual(floordiv(yval, c2val), 0) ? 
0 : floordiv(yval, c2val); auto bound = analyzer_->const_int_bound(residue); if (bound.defined() && bound->max_value == bound->min_value) { - return x.Eval() * floordiv(c1val, c2.Eval()) + (y_div + Integer(bound->max_value)); + return x.Eval() * floordiv(c1val, c2.Eval()) + (y_div + IntImm(DataType::Int(32), bound->max_value)); } // try simplify divisor diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc index ff89c6347fb4..704e40c6b191 100644 --- a/src/relax/analysis/struct_info_analysis.cc +++ b/src/relax/analysis/struct_info_analysis.cc @@ -30,6 +30,7 @@ #include #include #include +#include namespace tvm { namespace relax { @@ -632,97 +633,97 @@ class StructInfoBasePreconditionCollector PrimExpr VisitStructInfo(const StructInfo& lhs, const StructInfo& other) override { if (lhs.same_as(other)) { // Early bail-out if the StructInfo has reference equality. - return Bool(true); + return IntImm(DataType::Bool(), 1); } else { return StructInfoFunctor::VisitStructInfo(lhs, other); } } PrimExpr VisitStructInfo_(const ObjectStructInfoNode* lhs, const StructInfo& other) final { - return Bool(true); + return IntImm(DataType::Bool(), 1); } PrimExpr VisitStructInfo_(const PrimStructInfoNode* lhs, const StructInfo& other) final { auto* rhs = other.as(); if (rhs == nullptr) { - return Bool(false); + return IntImm(DataType::Bool(), 0); } if (lhs->dtype != rhs->dtype) { - return Bool(false); + return IntImm(DataType::Bool(), 0); } if (lhs->value.defined() && rhs->value.defined()) { return lhs->value.value() == rhs->value.value(); } else if (lhs->value.defined() && !rhs->value.defined()) { - return Bool(false); + return IntImm(DataType::Bool(), 0); } else { - return Bool(true); + return IntImm(DataType::Bool(), 1); } } PrimExpr VisitStructInfo_(const ShapeStructInfoNode* lhs, const StructInfo& other) final { auto* rhs = other.as(); if (rhs == nullptr) { - return Bool(false); + return IntImm(DataType::Bool(), 0); } // lhs have 
unknown ndim if (lhs->IsUnknownNdim()) { - return Bool(true); + return IntImm(DataType::Bool(), 1); } // ndim must match if (lhs->ndim != rhs->ndim) { - return Bool(false); + return IntImm(DataType::Bool(), 0); } if (lhs->values.defined() && rhs->values.defined()) { return ArrayCheck(lhs->values.value(), rhs->values.value()); } else if (lhs->values.defined() && !rhs->values.defined()) { - return Bool(false); + return IntImm(DataType::Bool(), 0); } else { - return Bool(true); + return IntImm(DataType::Bool(), 1); } } PrimExpr VisitStructInfo_(const TensorStructInfoNode* lhs, const StructInfo& other) final { auto* rhs = other.as(); if (rhs == nullptr) { - return Bool(false); + return IntImm(DataType::Bool(), 0); } // dtype mismatch if (!lhs->IsUnknownDtype() && lhs->dtype != rhs->dtype) { - return Bool(false); + return IntImm(DataType::Bool(), 0); } // ndim mismatch if (!lhs->IsUnknownNdim() && lhs->ndim != rhs->ndim) { - return Bool(false); + return IntImm(DataType::Bool(), 0); } // vdevice mismatch if (lhs->vdevice.defined() && !rhs->vdevice.defined()) { - return Bool(false); + return IntImm(DataType::Bool(), 0); } if (lhs->vdevice.defined() && rhs->vdevice.defined()) { VDevice lhs_vdevice = lhs->vdevice.value(); VDevice rhs_vdevice = rhs->vdevice.value(); if (lhs_vdevice->target.defined() && !rhs_vdevice->target.defined()) { - return Bool(false); + return IntImm(DataType::Bool(), 0); } // mismatch in either the target, vdevice_id, or memory_scope if ((lhs_vdevice->target.defined() && rhs_vdevice->target.defined()) && (lhs_vdevice->target != rhs_vdevice->target || lhs_vdevice->vdevice_id != rhs_vdevice->vdevice_id || lhs_vdevice->memory_scope != rhs_vdevice->memory_scope)) { - return Bool(false); + return IntImm(DataType::Bool(), 0); } } if (lhs->shape.same_as(rhs->shape)) { - return Bool(true); + return IntImm(DataType::Bool(), 1); } else if (lhs->shape.defined() && !rhs->shape.defined()) { - return Bool(false); + return IntImm(DataType::Bool(), 0); } auto* 
lhs_shape = lhs->shape.as(); @@ -730,23 +731,23 @@ class StructInfoBasePreconditionCollector if (lhs_shape && rhs_shape) { return ArrayCheck(lhs_shape->values, rhs_shape->values); } else if (lhs_shape && !rhs_shape) { - return Bool(false); + return IntImm(DataType::Bool(), 0); } - return Bool(true); + return IntImm(DataType::Bool(), 1); } PrimExpr VisitStructInfo_(const distributed::DTensorStructInfoNode* lhs, const StructInfo& other) final { auto* rhs = other.as(); if (rhs == nullptr) { - return Bool(false); + return IntImm(DataType::Bool(), 0); } ffi::StructuralEqual struct_equal; if (!struct_equal(lhs->device_mesh, rhs->device_mesh) || !struct_equal(lhs->placement, rhs->placement)) { - return Bool(false); + return IntImm(DataType::Bool(), 0); } return this->VisitStructInfo(lhs->tensor_sinfo, rhs->tensor_sinfo); @@ -755,7 +756,7 @@ class StructInfoBasePreconditionCollector PrimExpr VisitStructInfo_(const TupleStructInfoNode* lhs, const StructInfo& other) final { auto* rhs = other.as(); if (rhs == nullptr) { - return Bool(false); + return IntImm(DataType::Bool(), 0); } return ArrayCheck(lhs->fields, rhs->fields); } @@ -763,19 +764,19 @@ class StructInfoBasePreconditionCollector PrimExpr VisitStructInfo_(const FuncStructInfoNode* lhs, const StructInfo& other) override { auto* rhs = other.as(); if (rhs == nullptr) { - return Bool(false); + return IntImm(DataType::Bool(), 0); } // Check purity: Pure functions are a subtype of impure functions if (lhs->purity && !rhs->purity) { - return Bool(false); + return IntImm(DataType::Bool(), 0); } if (lhs->derive_func.defined() && !lhs->derive_func.same_as(rhs->derive_func)) { - return Bool(false); + return IntImm(DataType::Bool(), 0); } if (lhs->params.defined() && !rhs->params.defined()) { - return Bool(false); + return IntImm(DataType::Bool(), 0); } PrimExpr all_match = VisitStructInfo(lhs->ret, rhs->ret); @@ -784,7 +785,7 @@ class StructInfoBasePreconditionCollector if (lhs->params.defined()) { param_check = 
ArrayCheck(lhs->params.value(), rhs->params.value()); } else { - param_check = Bool(true); + param_check = IntImm(DataType::Bool(), 1); } PrimExpr ret_check = VisitStructInfo(lhs->ret, rhs->ret); @@ -795,10 +796,10 @@ class StructInfoBasePreconditionCollector private: PrimExpr ArrayCheck(const ffi::Array& lhs, const ffi::Array& rhs) { if (lhs.size() != rhs.size()) { - return Bool(false); + return IntImm(DataType::Bool(), 0); } - PrimExpr all_equal = Bool(true); + PrimExpr all_equal = IntImm(DataType::Bool(), 1); for (size_t i = 0; i < lhs.size(); i++) { all_equal = all_equal && (lhs[i] == rhs[i]); } @@ -807,10 +808,10 @@ class StructInfoBasePreconditionCollector PrimExpr ArrayCheck(const ffi::Array& lhs, const ffi::Array& rhs) { if (lhs.size() != rhs.size()) { - return Bool(false); + return IntImm(DataType::Bool(), 0); } - PrimExpr all_pass = Bool(true); + PrimExpr all_pass = IntImm(DataType::Bool(), 1); for (size_t i = 0; i < lhs.size(); ++i) { all_pass = all_pass && VisitStructInfo(lhs[i], rhs[i]); diff --git a/src/relax/analysis/tir_op_pattern_kind.cc b/src/relax/analysis/tir_op_pattern_kind.cc index cdf3bf21ebab..ace88a5ce801 100644 --- a/src/relax/analysis/tir_op_pattern_kind.cc +++ b/src/relax/analysis/tir_op_pattern_kind.cc @@ -26,6 +26,7 @@ #include #include #include +#include namespace tvm { namespace relax { @@ -444,7 +445,7 @@ bool HasReshapePattern(const PrimFunc& func) { return arith::IterMapSimplify( /*indices=*/{idx}, /*input_iters=*/var_range, - /*input_pred=*/Bool(true), + /*input_pred=*/const_true(), /*check_level=*/arith::IterMapLevel::Surjective, /*analyzer=*/&ana_, /*simplify_trivial_iterators=*/true)[0]; @@ -494,7 +495,7 @@ bool HasReshapePattern(const PrimFunc& func) { ffi::Array simplify_res = arith::IterMapSimplify( /*indices=*/{flattened_idx}, /*input_iters=*/{{fused_var, Range(IntImm(dtype, /*value=*/0), stride)}}, - /*input_pred=*/Bool(true), + /*input_pred=*/const_true(), /*check_level=*/arith::IterMapLevel::Surjective, 
/*analyzer=*/&this->ana_, /*simplify_trivial_iterators=*/true); diff --git a/src/relax/backend/contrib/clml/codegen.cc b/src/relax/backend/contrib/clml/codegen.cc index c58c2ee9aa92..94083c8b8f18 100644 --- a/src/relax/backend/contrib/clml/codegen.cc +++ b/src/relax/backend/contrib/clml/codegen.cc @@ -48,7 +48,7 @@ struct OpenCLMLCompilerConfigNode : public ffi::Object { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro( "clml_version", &OpenCLMLCompilerConfigNode::clml_version, - "OpenCLML version as (major, minor, patch).", refl::DefaultValue(Integer(3))); + "OpenCLML version as (major, minor, patch).", refl::DefaultValue(IntImm(DataType::Int(32), 3))); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.ext.attrs.OpenCLMLCompilerConfig", OpenCLMLCompilerConfigNode, ffi::Object); @@ -334,9 +334,9 @@ inline constexpr bool IsOpenCLMLRuntimeEnabled() { */ Integer GetOpenCLMLVersion() { #if TVM_GRAPH_EXECUTOR_CLML - return Integer(TVM_CLML_VERSION); + return IntImm(DataType::Int(32), TVM_CLML_VERSION); #else - return Integer(3); + return IntImm(DataType::Int(32), 3); #endif // TVM_GRAPH_EXECUTOR_CLML } diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index e8eafde31747..08d39ac29c42 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -471,7 +471,7 @@ PrimExpr DFPatternMatcher::SimplifyCondition(PrimExpr condition) { constraints.begin(), constraints.end(), [&sort_key](const PrimExpr& a, const PrimExpr& b) { return sort_key(a) < sort_key(b); }); - PrimExpr sorted_condition = Bool(true); + PrimExpr sorted_condition = IntImm(DataType::Bool(), 1); for (const PrimExpr& constraint : constraints) { sorted_condition = sorted_condition && constraint; } @@ -504,7 +504,7 @@ std::tuple SameShapeConstraintNode::AsPrimExpr( bool all_shapes_defined = true; // The expression that must be true in order - PrimExpr all_dimensions_equal = Bool(true); + PrimExpr all_dimensions_equal = IntImm(DataType::Bool(), 1); for 
(const auto& arg : args) { if (auto opt_var = match_state(arg.get())) { @@ -523,7 +523,7 @@ std::tuple SameShapeConstraintNode::AsPrimExpr( if (!opt_var_shape.defined()) { // The pattern has matched to something without a shape. // Therefore, it cannot have the same shape as something else. - return {PrimExpr(Bool(false)), true}; + return {PrimExpr(IntImm(DataType::Bool(), 0)), true}; } auto var_shape = opt_var_shape.value(); @@ -540,7 +540,7 @@ std::tuple SameShapeConstraintNode::AsPrimExpr( // The shapes have different dimensionality. No need to // perform potentially-expensive simplifications, because // the dimensions do not match. - return {PrimExpr(Bool(false)), true}; + return {PrimExpr(IntImm(DataType::Bool(), 0)), true}; } } else { diff --git a/src/relax/ir/dataflow_matcher.h b/src/relax/ir/dataflow_matcher.h index ca6b5a97087a..45b76de68ad0 100644 --- a/src/relax/ir/dataflow_matcher.h +++ b/src/relax/ir/dataflow_matcher.h @@ -32,6 +32,7 @@ #include #include #include +#include namespace tvm { namespace relax { @@ -93,7 +94,7 @@ class DFPatternMatcher : public DFPatternFunctor memo_; var2val_t var2val_; std::vector matched_nodes_; - PrimExpr symbolic_expr_condition_{Bool(true)}; + PrimExpr symbolic_expr_condition_{IntImm(DataType::Bool(), 1)}; arith::Analyzer analyzer_; bool memoize_ = true; }; diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc index c203e59d4e35..b69f58ebb7af 100644 --- a/src/relax/ir/expr_functor.cc +++ b/src/relax/ir/expr_functor.cc @@ -30,6 +30,7 @@ #include #include #include +#include // functions to be overriden. 
#define RELAX_VISIT_BINDING_DISPATCH(OP) \ @@ -798,7 +799,7 @@ Expr ExprMutator::VisitWithNewScope(const Expr& expr, ffi::OptionalIsInstance()) << "Normal form requires all new scope is stored as SeqExpr"; - PrimExpr constraint = Bool(true); + PrimExpr constraint = IntImm(DataType::Bool(), 1); if (params.defined()) { auto non_negative_expressions = CollectNonNegativeExpressions(TupleStructInfo(params.value().Map(GetStructInfo))); diff --git a/src/relax/op/memory/view.cc b/src/relax/op/memory/view.cc index 62bddebb0483..74b1e0c69519 100644 --- a/src/relax/op/memory/view.cc +++ b/src/relax/op/memory/view.cc @@ -188,7 +188,7 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { return std::nullopt; } - PrimExpr num_elements = Integer(1); + PrimExpr num_elements = IntImm(DataType::Int(32), 1); for (const auto& dim : shape.value()) { num_elements *= dim; } diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc index 1b77b4225203..2fdea26bd7ed 100644 --- a/src/relax/op/nn/convolution.cc +++ b/src/relax/op/nn/convolution.cc @@ -128,15 +128,15 @@ StructInfo InferStructInfoConv1d(const Call& call, const BlockBuilder& ctx) { PrimExpr input_w = data_NCW_shape[2]; PrimExpr kernel_w = weight_OIW_shape[2]; - PrimExpr padding_w = Integer(attrs->padding[0]) + Integer(attrs->padding[1]); + PrimExpr padding_w = IntImm(DataType::Int(32), attrs->padding[0]) + IntImm(DataType::Int(32), attrs->padding[1]); std::vector out_NCW_shape; out_NCW_shape.resize(3); out_NCW_shape[0] = data_NCW_shape[0]; out_NCW_shape[1] = weight_OIW_shape[0]; - PrimExpr numerator_w = input_w + padding_w - Integer(attrs->dilation[0]) * (kernel_w - 1) - 1; - out_NCW_shape[2] = analyzer->Simplify(floordiv(numerator_w, Integer(attrs->strides[0])) + 1); + PrimExpr numerator_w = input_w + padding_w - IntImm(DataType::Int(32), attrs->dilation[0]) * (kernel_w - 1) - 1; + out_NCW_shape[2] = analyzer->Simplify(floordiv(numerator_w, IntImm(DataType::Int(32), 
attrs->strides[0])) + 1); ffi::Array out_shape = out2NCW.BackwardShape(out_NCW_shape); return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice); @@ -299,18 +299,18 @@ StructInfo InferStructInfoConv2d(const Call& call, const BlockBuilder& ctx) { PrimExpr input_w = data_NCHW_shape[3]; PrimExpr kernel_h = weight_OIHW_shape[2]; PrimExpr kernel_w = weight_OIHW_shape[3]; - PrimExpr padding_h = Integer(attrs->padding[0]) + Integer(attrs->padding[2]); - PrimExpr padding_w = Integer(attrs->padding[1]) + Integer(attrs->padding[3]); + PrimExpr padding_h = IntImm(DataType::Int(32), attrs->padding[0]) + IntImm(DataType::Int(32), attrs->padding[2]); + PrimExpr padding_w = IntImm(DataType::Int(32), attrs->padding[1]) + IntImm(DataType::Int(32), attrs->padding[3]); std::vector out_NCHW_shape; out_NCHW_shape.resize(4); out_NCHW_shape[0] = data_NCHW_shape[0]; out_NCHW_shape[1] = weight_OIHW_shape[0]; - PrimExpr numerator_h = input_h + padding_h - Integer(attrs->dilation[0]) * (kernel_h - 1) - 1; - PrimExpr numerator_w = input_w + padding_w - Integer(attrs->dilation[1]) * (kernel_w - 1) - 1; - out_NCHW_shape[2] = analyzer->Simplify(floordiv(numerator_h, Integer(attrs->strides[0])) + 1); - out_NCHW_shape[3] = analyzer->Simplify(floordiv(numerator_w, Integer(attrs->strides[1])) + 1); + PrimExpr numerator_h = input_h + padding_h - IntImm(DataType::Int(32), attrs->dilation[0]) * (kernel_h - 1) - 1; + PrimExpr numerator_w = input_w + padding_w - IntImm(DataType::Int(32), attrs->dilation[1]) * (kernel_w - 1) - 1; + out_NCHW_shape[2] = analyzer->Simplify(floordiv(numerator_h, IntImm(DataType::Int(32), attrs->strides[0])) + 1); + out_NCHW_shape[3] = analyzer->Simplify(floordiv(numerator_w, IntImm(DataType::Int(32), attrs->strides[1])) + 1); ffi::Array out_shape = out2NCHW.BackwardShape(out_NCHW_shape); return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice); @@ -512,21 +512,21 @@ StructInfo InferStructInfoConv3d(const Call& call, const BlockBuilder& ctx) { PrimExpr 
kernel_d = weight_OIDHW_shape[2]; PrimExpr kernel_h = weight_OIDHW_shape[3]; PrimExpr kernel_w = weight_OIDHW_shape[4]; - PrimExpr padding_d = Integer(attrs->padding[0]) + Integer(attrs->padding[3]); - PrimExpr padding_h = Integer(attrs->padding[1]) + Integer(attrs->padding[4]); - PrimExpr padding_w = Integer(attrs->padding[2]) + Integer(attrs->padding[5]); + PrimExpr padding_d = IntImm(DataType::Int(32), attrs->padding[0]) + IntImm(DataType::Int(32), attrs->padding[3]); + PrimExpr padding_h = IntImm(DataType::Int(32), attrs->padding[1]) + IntImm(DataType::Int(32), attrs->padding[4]); + PrimExpr padding_w = IntImm(DataType::Int(32), attrs->padding[2]) + IntImm(DataType::Int(32), attrs->padding[5]); std::vector out_NCDHW_shape; out_NCDHW_shape.resize(5); out_NCDHW_shape[0] = data_NCDHW_shape[0]; out_NCDHW_shape[1] = weight_OIDHW_shape[0]; - PrimExpr numerator_d = input_d + padding_d - Integer(attrs->dilation[0]) * (kernel_d - 1) - 1; - PrimExpr numerator_h = input_h + padding_h - Integer(attrs->dilation[1]) * (kernel_h - 1) - 1; - PrimExpr numerator_w = input_w + padding_w - Integer(attrs->dilation[2]) * (kernel_w - 1) - 1; - out_NCDHW_shape[2] = analyzer->Simplify(floordiv(numerator_d, Integer(attrs->strides[0])) + 1); - out_NCDHW_shape[3] = analyzer->Simplify(floordiv(numerator_h, Integer(attrs->strides[1])) + 1); - out_NCDHW_shape[4] = analyzer->Simplify(floordiv(numerator_w, Integer(attrs->strides[2])) + 1); + PrimExpr numerator_d = input_d + padding_d - IntImm(DataType::Int(32), attrs->dilation[0]) * (kernel_d - 1) - 1; + PrimExpr numerator_h = input_h + padding_h - IntImm(DataType::Int(32), attrs->dilation[1]) * (kernel_h - 1) - 1; + PrimExpr numerator_w = input_w + padding_w - IntImm(DataType::Int(32), attrs->dilation[2]) * (kernel_w - 1) - 1; + out_NCDHW_shape[2] = analyzer->Simplify(floordiv(numerator_d, IntImm(DataType::Int(32), attrs->strides[0])) + 1); + out_NCDHW_shape[3] = analyzer->Simplify(floordiv(numerator_h, IntImm(DataType::Int(32), 
attrs->strides[1])) + 1); + out_NCDHW_shape[4] = analyzer->Simplify(floordiv(numerator_w, IntImm(DataType::Int(32), attrs->strides[2])) + 1); ffi::Array out_shape = out2NCDHW.BackwardShape(out_NCDHW_shape); return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice); @@ -701,16 +701,16 @@ StructInfo InferStructInfoConv1dTranspose(const Call& call, const BlockBuilder& PrimExpr input_w = data_NCW_shape[2]; PrimExpr kernel_w = weight_IOW_shape[2]; - PrimExpr padding_w = Integer(attrs->padding[0]) + Integer(attrs->padding[1]); + PrimExpr padding_w = IntImm(DataType::Int(32), attrs->padding[0]) + IntImm(DataType::Int(32), attrs->padding[1]); std::vector out_NCW_shape; out_NCW_shape.resize(3); out_NCW_shape[0] = data_NCW_shape[0]; out_NCW_shape[1] = weight_IOW_shape[1] * attrs->groups; - PrimExpr out_w = (input_w - 1) * Integer(attrs->strides[0]) - padding_w + - Integer(attrs->dilation[0]) * (kernel_w - 1) + - Integer(attrs->output_padding[0]) + 1; + PrimExpr out_w = (input_w - 1) * IntImm(DataType::Int(32), attrs->strides[0]) - padding_w + + IntImm(DataType::Int(32), attrs->dilation[0]) * (kernel_w - 1) + + IntImm(DataType::Int(32), attrs->output_padding[0]) + 1; out_NCW_shape[2] = analyzer->Simplify(out_w); ffi::Array out_shape = out2NCW.BackwardShape(out_NCW_shape); @@ -895,20 +895,20 @@ StructInfo InferStructInfoConv2dTranspose(const Call& call, const BlockBuilder& PrimExpr input_w = data_NCHW_shape[3]; PrimExpr kernel_h = weight_IOHW_shape[2]; PrimExpr kernel_w = weight_IOHW_shape[3]; - PrimExpr padding_h = Integer(attrs->padding[0]) + Integer(attrs->padding[2]); - PrimExpr padding_w = Integer(attrs->padding[1]) + Integer(attrs->padding[3]); + PrimExpr padding_h = IntImm(DataType::Int(32), attrs->padding[0]) + IntImm(DataType::Int(32), attrs->padding[2]); + PrimExpr padding_w = IntImm(DataType::Int(32), attrs->padding[1]) + IntImm(DataType::Int(32), attrs->padding[3]); std::vector out_NCHW_shape; out_NCHW_shape.resize(4); out_NCHW_shape[0] = 
data_NCHW_shape[0]; out_NCHW_shape[1] = weight_IOHW_shape[1] * attrs->groups; - PrimExpr out_h = (input_h - 1) * Integer(attrs->strides[0]) - padding_h + - Integer(attrs->dilation[0]) * (kernel_h - 1) + - Integer(attrs->output_padding[0]) + 1; - PrimExpr out_w = (input_w - 1) * Integer(attrs->strides[1]) - padding_w + - Integer(attrs->dilation[1]) * (kernel_w - 1) + - Integer(attrs->output_padding[1]) + 1; + PrimExpr out_h = (input_h - 1) * IntImm(DataType::Int(32), attrs->strides[0]) - padding_h + + IntImm(DataType::Int(32), attrs->dilation[0]) * (kernel_h - 1) + + IntImm(DataType::Int(32), attrs->output_padding[0]) + 1; + PrimExpr out_w = (input_w - 1) * IntImm(DataType::Int(32), attrs->strides[1]) - padding_w + + IntImm(DataType::Int(32), attrs->dilation[1]) * (kernel_w - 1) + + IntImm(DataType::Int(32), attrs->output_padding[1]) + 1; out_NCHW_shape[2] = analyzer->Simplify(out_h); out_NCHW_shape[3] = analyzer->Simplify(out_w); @@ -1132,24 +1132,24 @@ StructInfo InferStructInfoConv3dTranspose(const Call& call, const BlockBuilder& PrimExpr kernel_d = weight_IODHW_shape[2]; PrimExpr kernel_h = weight_IODHW_shape[3]; PrimExpr kernel_w = weight_IODHW_shape[4]; - PrimExpr padding_d = Integer(attrs->padding[0]) + Integer(attrs->padding[3]); - PrimExpr padding_h = Integer(attrs->padding[1]) + Integer(attrs->padding[4]); - PrimExpr padding_w = Integer(attrs->padding[2]) + Integer(attrs->padding[5]); + PrimExpr padding_d = IntImm(DataType::Int(32), attrs->padding[0]) + IntImm(DataType::Int(32), attrs->padding[3]); + PrimExpr padding_h = IntImm(DataType::Int(32), attrs->padding[1]) + IntImm(DataType::Int(32), attrs->padding[4]); + PrimExpr padding_w = IntImm(DataType::Int(32), attrs->padding[2]) + IntImm(DataType::Int(32), attrs->padding[5]); std::vector out_NCDHW_shape; out_NCDHW_shape.resize(5); out_NCDHW_shape[0] = data_NCDHW_shape[0]; out_NCDHW_shape[1] = weight_IODHW_shape[1] * attrs->groups; - PrimExpr out_d = (input_d - 1) * Integer(attrs->strides[0]) - padding_d + 
- Integer(attrs->dilation[0]) * (kernel_d - 1) + - Integer(attrs->output_padding[0]) + 1; - PrimExpr out_h = (input_h - 1) * Integer(attrs->strides[1]) - padding_h + - Integer(attrs->dilation[1]) * (kernel_h - 1) + - Integer(attrs->output_padding[1]) + 1; - PrimExpr out_w = (input_w - 1) * Integer(attrs->strides[2]) - padding_w + - Integer(attrs->dilation[2]) * (kernel_w - 1) + - Integer(attrs->output_padding[2]) + 1; + PrimExpr out_d = (input_d - 1) * IntImm(DataType::Int(32), attrs->strides[0]) - padding_d + + IntImm(DataType::Int(32), attrs->dilation[0]) * (kernel_d - 1) + + IntImm(DataType::Int(32), attrs->output_padding[0]) + 1; + PrimExpr out_h = (input_h - 1) * IntImm(DataType::Int(32), attrs->strides[1]) - padding_h + + IntImm(DataType::Int(32), attrs->dilation[1]) * (kernel_h - 1) + + IntImm(DataType::Int(32), attrs->output_padding[1]) + 1; + PrimExpr out_w = (input_w - 1) * IntImm(DataType::Int(32), attrs->strides[2]) - padding_w + + IntImm(DataType::Int(32), attrs->dilation[2]) * (kernel_w - 1) + + IntImm(DataType::Int(32), attrs->output_padding[2]) + 1; out_NCDHW_shape[2] = analyzer->Simplify(out_d); out_NCDHW_shape[3] = analyzer->Simplify(out_h); out_NCDHW_shape[4] = analyzer->Simplify(out_w); diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc index 60430519111d..df432f9b8e46 100644 --- a/src/relax/op/nn/pooling.cc +++ b/src/relax/op/nn/pooling.cc @@ -99,8 +99,8 @@ StructInfo InferStructInfoPool1D(const Call& call, const BlockBuilder& ctx) { ffi::Array data_NCW_shape = data2NCW.ForwardShape(data_shape.value()->values); PrimExpr input_w = data_NCW_shape[2]; - PrimExpr kernel_w = Integer(attrs->pool_size[0]); - PrimExpr padding_w = Integer(attrs->padding[0]) + Integer(attrs->padding[1]); + PrimExpr kernel_w = IntImm(DataType::Int(32), attrs->pool_size[0]); + PrimExpr padding_w = IntImm(DataType::Int(32), attrs->padding[0]) + IntImm(DataType::Int(32), attrs->padding[1]); arith::Analyzer* analyzer = ctx->GetAnalyzer(); std::vector 
out_NCW_shape; @@ -108,14 +108,14 @@ StructInfo InferStructInfoPool1D(const Call& call, const BlockBuilder& ctx) { out_NCW_shape[0] = data_NCW_shape[0]; out_NCW_shape[1] = data_NCW_shape[1]; - PrimExpr numerator_w = input_w + padding_w - Integer(attrs->dilation[0]) * (kernel_w - 1) - 1; + PrimExpr numerator_w = input_w + padding_w - IntImm(DataType::Int(32), attrs->dilation[0]) * (kernel_w - 1) - 1; if (attrs->ceil_mode) { - numerator_w += Integer(attrs->strides[0]) - 1; + numerator_w += IntImm(DataType::Int(32), attrs->strides[0]) - 1; } - PrimExpr raw_out_w = floordiv(numerator_w, Integer(attrs->strides[0])) + 1; + PrimExpr raw_out_w = floordiv(numerator_w, IntImm(DataType::Int(32), attrs->strides[0])) + 1; if (attrs->ceil_mode) { PrimExpr invalid_last_w = - (raw_out_w - 1) * Integer(attrs->strides[0]) >= input_w + Integer(attrs->padding[0]); + (raw_out_w - 1) * IntImm(DataType::Int(32), attrs->strides[0]) >= input_w + IntImm(DataType::Int(32), attrs->padding[0]); out_NCW_shape[2] = analyzer->Simplify(if_then_else(invalid_last_w, raw_out_w - 1, raw_out_w)); } else { out_NCW_shape[2] = analyzer->Simplify(raw_out_w); @@ -223,10 +223,10 @@ StructInfo InferStructInfoPool2D(const Call& call, const BlockBuilder& ctx) { PrimExpr input_h = data_NCHW_shape[2]; PrimExpr input_w = data_NCHW_shape[3]; - PrimExpr kernel_h = Integer(attrs->pool_size[0]); - PrimExpr kernel_w = Integer(attrs->pool_size[1]); - PrimExpr padding_h = Integer(attrs->padding[0]) + Integer(attrs->padding[2]); - PrimExpr padding_w = Integer(attrs->padding[1]) + Integer(attrs->padding[3]); + PrimExpr kernel_h = IntImm(DataType::Int(32), attrs->pool_size[0]); + PrimExpr kernel_w = IntImm(DataType::Int(32), attrs->pool_size[1]); + PrimExpr padding_h = IntImm(DataType::Int(32), attrs->padding[0]) + IntImm(DataType::Int(32), attrs->padding[2]); + PrimExpr padding_w = IntImm(DataType::Int(32), attrs->padding[1]) + IntImm(DataType::Int(32), attrs->padding[3]); arith::Analyzer* analyzer = ctx->GetAnalyzer(); 
std::vector out_NCHW_shape; @@ -234,19 +234,19 @@ StructInfo InferStructInfoPool2D(const Call& call, const BlockBuilder& ctx) { out_NCHW_shape[0] = data_NCHW_shape[0]; out_NCHW_shape[1] = data_NCHW_shape[1]; - PrimExpr numerator_h = input_h + padding_h - Integer(attrs->dilation[0]) * (kernel_h - 1) - 1; - PrimExpr numerator_w = input_w + padding_w - Integer(attrs->dilation[1]) * (kernel_w - 1) - 1; + PrimExpr numerator_h = input_h + padding_h - IntImm(DataType::Int(32), attrs->dilation[0]) * (kernel_h - 1) - 1; + PrimExpr numerator_w = input_w + padding_w - IntImm(DataType::Int(32), attrs->dilation[1]) * (kernel_w - 1) - 1; if (attrs->ceil_mode) { - numerator_h += Integer(attrs->strides[0]) - 1; - numerator_w += Integer(attrs->strides[1]) - 1; + numerator_h += IntImm(DataType::Int(32), attrs->strides[0]) - 1; + numerator_w += IntImm(DataType::Int(32), attrs->strides[1]) - 1; } - PrimExpr raw_out_h = floordiv(numerator_h, Integer(attrs->strides[0])) + 1; - PrimExpr raw_out_w = floordiv(numerator_w, Integer(attrs->strides[1])) + 1; + PrimExpr raw_out_h = floordiv(numerator_h, IntImm(DataType::Int(32), attrs->strides[0])) + 1; + PrimExpr raw_out_w = floordiv(numerator_w, IntImm(DataType::Int(32), attrs->strides[1])) + 1; if (attrs->ceil_mode) { PrimExpr invalid_last_h = - (raw_out_h - 1) * Integer(attrs->strides[0]) >= input_h + Integer(attrs->padding[0]); + (raw_out_h - 1) * IntImm(DataType::Int(32), attrs->strides[0]) >= input_h + IntImm(DataType::Int(32), attrs->padding[0]); PrimExpr invalid_last_w = - (raw_out_w - 1) * Integer(attrs->strides[1]) >= input_w + Integer(attrs->padding[1]); + (raw_out_w - 1) * IntImm(DataType::Int(32), attrs->strides[1]) >= input_w + IntImm(DataType::Int(32), attrs->padding[1]); out_NCHW_shape[2] = analyzer->Simplify(if_then_else(invalid_last_h, raw_out_h - 1, raw_out_h)); out_NCHW_shape[3] = analyzer->Simplify(if_then_else(invalid_last_w, raw_out_w - 1, raw_out_w)); } else { @@ -378,12 +378,12 @@ StructInfo InferStructInfoPool3D(const 
Call& call, const BlockBuilder& ctx) { PrimExpr input_d = data_NCDHW_shape[2]; PrimExpr input_h = data_NCDHW_shape[3]; PrimExpr input_w = data_NCDHW_shape[4]; - PrimExpr kernel_d = Integer(attrs->pool_size[0]); - PrimExpr kernel_h = Integer(attrs->pool_size[1]); - PrimExpr kernel_w = Integer(attrs->pool_size[2]); - PrimExpr padding_d = Integer(attrs->padding[0]) + Integer(attrs->padding[3]); - PrimExpr padding_h = Integer(attrs->padding[1]) + Integer(attrs->padding[4]); - PrimExpr padding_w = Integer(attrs->padding[2]) + Integer(attrs->padding[5]); + PrimExpr kernel_d = IntImm(DataType::Int(32), attrs->pool_size[0]); + PrimExpr kernel_h = IntImm(DataType::Int(32), attrs->pool_size[1]); + PrimExpr kernel_w = IntImm(DataType::Int(32), attrs->pool_size[2]); + PrimExpr padding_d = IntImm(DataType::Int(32), attrs->padding[0]) + IntImm(DataType::Int(32), attrs->padding[3]); + PrimExpr padding_h = IntImm(DataType::Int(32), attrs->padding[1]) + IntImm(DataType::Int(32), attrs->padding[4]); + PrimExpr padding_w = IntImm(DataType::Int(32), attrs->padding[2]) + IntImm(DataType::Int(32), attrs->padding[5]); arith::Analyzer* analyzer = ctx->GetAnalyzer(); std::vector out_NCDHW_shape; @@ -391,24 +391,24 @@ StructInfo InferStructInfoPool3D(const Call& call, const BlockBuilder& ctx) { out_NCDHW_shape[0] = data_NCDHW_shape[0]; out_NCDHW_shape[1] = data_NCDHW_shape[1]; - PrimExpr numerator_d = input_d + padding_d - Integer(attrs->dilation[0]) * (kernel_d - 1) - 1; - PrimExpr numerator_h = input_h + padding_h - Integer(attrs->dilation[1]) * (kernel_h - 1) - 1; - PrimExpr numerator_w = input_w + padding_w - Integer(attrs->dilation[2]) * (kernel_w - 1) - 1; + PrimExpr numerator_d = input_d + padding_d - IntImm(DataType::Int(32), attrs->dilation[0]) * (kernel_d - 1) - 1; + PrimExpr numerator_h = input_h + padding_h - IntImm(DataType::Int(32), attrs->dilation[1]) * (kernel_h - 1) - 1; + PrimExpr numerator_w = input_w + padding_w - IntImm(DataType::Int(32), attrs->dilation[2]) * (kernel_w 
- 1) - 1; if (attrs->ceil_mode) { - numerator_d += Integer(attrs->strides[0]) - 1; - numerator_h += Integer(attrs->strides[1]) - 1; - numerator_w += Integer(attrs->strides[2]) - 1; + numerator_d += IntImm(DataType::Int(32), attrs->strides[0]) - 1; + numerator_h += IntImm(DataType::Int(32), attrs->strides[1]) - 1; + numerator_w += IntImm(DataType::Int(32), attrs->strides[2]) - 1; } - PrimExpr raw_out_d = floordiv(numerator_d, Integer(attrs->strides[0])) + 1; - PrimExpr raw_out_h = floordiv(numerator_h, Integer(attrs->strides[1])) + 1; - PrimExpr raw_out_w = floordiv(numerator_w, Integer(attrs->strides[2])) + 1; + PrimExpr raw_out_d = floordiv(numerator_d, IntImm(DataType::Int(32), attrs->strides[0])) + 1; + PrimExpr raw_out_h = floordiv(numerator_h, IntImm(DataType::Int(32), attrs->strides[1])) + 1; + PrimExpr raw_out_w = floordiv(numerator_w, IntImm(DataType::Int(32), attrs->strides[2])) + 1; if (attrs->ceil_mode) { PrimExpr invalid_last_d = - (raw_out_d - 1) * Integer(attrs->strides[0]) >= input_d + Integer(attrs->padding[0]); + (raw_out_d - 1) * IntImm(DataType::Int(32), attrs->strides[0]) >= input_d + IntImm(DataType::Int(32), attrs->padding[0]); PrimExpr invalid_last_h = - (raw_out_h - 1) * Integer(attrs->strides[1]) >= input_h + Integer(attrs->padding[1]); + (raw_out_h - 1) * IntImm(DataType::Int(32), attrs->strides[1]) >= input_h + IntImm(DataType::Int(32), attrs->padding[1]); PrimExpr invalid_last_w = - (raw_out_w - 1) * Integer(attrs->strides[2]) >= input_w + Integer(attrs->padding[2]); + (raw_out_w - 1) * IntImm(DataType::Int(32), attrs->strides[2]) >= input_w + IntImm(DataType::Int(32), attrs->padding[2]); out_NCDHW_shape[2] = analyzer->Simplify(if_then_else(invalid_last_d, raw_out_d - 1, raw_out_d)); out_NCDHW_shape[3] = analyzer->Simplify(if_then_else(invalid_last_h, raw_out_h - 1, raw_out_h)); out_NCDHW_shape[4] = analyzer->Simplify(if_then_else(invalid_last_w, raw_out_w - 1, raw_out_w)); @@ -563,7 +563,7 @@ StructInfo 
InferStructInfoAdaptiveAvgPool1D(const Call& call, const BlockBuilder ffi::Array data_NCW_shape = data2NCW.ForwardShape(data_shape.value()->values); ffi::Array out_NCW_shape(data_NCW_shape); if (attrs->output_size.defined()) { - out_NCW_shape.Set(2, Integer(attrs->output_size.value()[0])); + out_NCW_shape.Set(2, IntImm(DataType::Int(32), attrs->output_size.value()[0])); } ffi::Array out_shape = out2NCW.BackwardShape(out_NCW_shape); @@ -648,8 +648,8 @@ StructInfo InferStructInfoAdaptiveAvgPool2D(const Call& call, const BlockBuilder ffi::Array data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values); ffi::Array out_NCHW_shape(data_NCHW_shape); if (attrs->output_size.defined()) { - out_NCHW_shape.Set(2, Integer(attrs->output_size.value()[0])); - out_NCHW_shape.Set(3, Integer(attrs->output_size.value()[1])); + out_NCHW_shape.Set(2, IntImm(DataType::Int(32), attrs->output_size.value()[0])); + out_NCHW_shape.Set(3, IntImm(DataType::Int(32), attrs->output_size.value()[1])); } ffi::Array out_shape = out2NCHW.BackwardShape(out_NCHW_shape); @@ -750,9 +750,9 @@ StructInfo InferStructInfoAdaptiveAvgPool3D(const Call& call, const BlockBuilder ffi::Array data_NCDHW_shape = data2NCDHW.ForwardShape(data_shape.value()->values); ffi::Array out_NCDHW_shape(data_NCDHW_shape); if (attrs->output_size.defined()) { - out_NCDHW_shape.Set(2, Integer(attrs->output_size.value()[0])); - out_NCDHW_shape.Set(3, Integer(attrs->output_size.value()[1])); - out_NCDHW_shape.Set(4, Integer(attrs->output_size.value()[2])); + out_NCDHW_shape.Set(2, IntImm(DataType::Int(32), attrs->output_size.value()[0])); + out_NCDHW_shape.Set(3, IntImm(DataType::Int(32), attrs->output_size.value()[1])); + out_NCDHW_shape.Set(4, IntImm(DataType::Int(32), attrs->output_size.value()[2])); } ffi::Array out_shape = out2NCDHW.BackwardShape(out_NCDHW_shape); diff --git a/src/relax/op/vision/multibox_transform_loc.cc b/src/relax/op/vision/multibox_transform_loc.cc index 070c81bbe97d..cffa876235ce 100644 --- 
a/src/relax/op/vision/multibox_transform_loc.cc +++ b/src/relax/op/vision/multibox_transform_loc.cc @@ -179,7 +179,7 @@ StructInfo InferStructInfoMultiboxTransformLoc(const Call& call, const BlockBuil } } - ffi::Array boxes_shape = {batch, num_anchors, Integer(4)}; + ffi::Array boxes_shape = {batch, num_anchors, IntImm(DataType::Int(32), 4)}; ffi::Array scores_shape = {batch, num_classes, num_anchors}; ffi::Array fields = { TensorStructInfo(ShapeExpr(boxes_shape), cls_sinfo->dtype, vdev), diff --git a/src/relax/op/vision/roi_align.cc b/src/relax/op/vision/roi_align.cc index e1be949fce52..5c3ef52c6a62 100644 --- a/src/relax/op/vision/roi_align.cc +++ b/src/relax/op/vision/roi_align.cc @@ -118,11 +118,11 @@ StructInfo InferStructInfoROIAlign(const Call& call, const BlockBuilder& ctx) { ffi::Array data_shape = data_sinfo->shape.as()->values; ffi::Array out_shape; if (attrs->layout == "NCHW") { - out_shape = {rois_shape->values[0], data_shape[1], Integer(attrs->pooled_size[0]), - Integer(attrs->pooled_size[1])}; + out_shape = {rois_shape->values[0], data_shape[1], IntImm(DataType::Int(32), attrs->pooled_size[0]), + IntImm(DataType::Int(32), attrs->pooled_size[1])}; } else { - out_shape = {rois_shape->values[0], Integer(attrs->pooled_size[0]), - Integer(attrs->pooled_size[1]), data_shape[3]}; + out_shape = {rois_shape->values[0], IntImm(DataType::Int(32), attrs->pooled_size[0]), + IntImm(DataType::Int(32), attrs->pooled_size[1]), data_shape[3]}; } return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice); } diff --git a/src/relax/op/vision/roi_pool.cc b/src/relax/op/vision/roi_pool.cc index ffba294c5a77..25e529308882 100644 --- a/src/relax/op/vision/roi_pool.cc +++ b/src/relax/op/vision/roi_pool.cc @@ -110,7 +110,7 @@ StructInfo InferStructInfoROIPool(const Call& call, const BlockBuilder& ctx) { ffi::Array data_shape = data_sinfo->shape.as()->values; ffi::Array out_shape = {rois_shape->values[0], data_shape[1], - 
Integer(attrs->pooled_size[0]), Integer(attrs->pooled_size[1])}; + IntImm(DataType::Int(32), attrs->pooled_size[0]), IntImm(DataType::Int(32), attrs->pooled_size[1])}; return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice); } diff --git a/src/relax/transform/adjust_matmul_order.cc b/src/relax/transform/adjust_matmul_order.cc index 84ad94c3887e..54bca2aaefdf 100644 --- a/src/relax/transform/adjust_matmul_order.cc +++ b/src/relax/transform/adjust_matmul_order.cc @@ -35,6 +35,7 @@ #include "../op/tensor/linear_algebra.h" #include "../op/tensor/manipulate.h" +#include namespace tvm { namespace relax { @@ -72,7 +73,7 @@ std::tuple)>> auto pat = pat_matmul_on_lhs | pat_matmul_on_rhs | pat_permuted_matmul_on_lhs | pat_permuted_matmul_on_rhs; - PrimExpr symbolic_var_constraints = Bool(true); + PrimExpr symbolic_var_constraints = IntImm(DataType::Bool(), 1); auto upper_bounds = func->GetAttr>("tir_var_upper_bound"); auto lower_bounds = func->GetAttr>("tir_var_lower_bound"); diff --git a/src/relax/transform/allocate_workspace.cc b/src/relax/transform/allocate_workspace.cc index 8049b5f0257f..6b8f3c776185 100644 --- a/src/relax/transform/allocate_workspace.cc +++ b/src/relax/transform/allocate_workspace.cc @@ -61,7 +61,7 @@ class ExternFunctionRewriter : ExprMutator { // Append the workspace parameter to this function. 
ffi::Array new_params = func_node->params; - auto sinfo = TensorStructInfo(ShapeExpr({Integer(max_workspace_size_)}), DataType::UInt(8)); + auto sinfo = TensorStructInfo(ShapeExpr({IntImm(DataType::Int(32), max_workspace_size_)}), DataType::UInt(8)); Var workspace_param(name_sup_->FreshName("workspace"), sinfo); if (func_node->GetAttr(attr::kCodegen)) { @@ -148,7 +148,7 @@ class WorkspaceProvider : ExprMutator { BindingBlock VisitBindingBlock_(const DataflowBlockNode* block_node) final { builder_->BeginDataflowBlock(); if (!workspace_var_main_.defined()) { - auto shape = ShapeExpr({Integer(max_workspace_size_)}); + auto shape = ShapeExpr({IntImm(DataType::Int(32), max_workspace_size_)}); auto ty = DataTypeImm(DataType::UInt(8)); auto workspace = MakeAllocTensor(shape, ty, PrimValue::Int64(0)); workspace_var_main_ = builder_->Emit(workspace, "workspace_main"); diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index 52f38d1a8c3e..3db3c12f1e96 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include @@ -154,7 +155,7 @@ class SymbolicMatcher : ExprFunctor* var_remap_; - PrimExpr must_prove_ = Bool(true); + PrimExpr must_prove_ = IntImm(DataType::Bool(), 1); }; /*! 
@@ -1020,7 +1021,7 @@ class FusedTIRConstructor : public ExprVisitor { body = subst.Substitute(body); body = tirx::SBlock({}, {}, {}, "root", std::move(body), std::nullopt, alloc_buffers); - body = tirx::SBlockRealize({}, Bool(true), Downcast(body)); + body = tirx::SBlockRealize({}, IntImm(DataType::Bool(), 1), Downcast(body)); tirx::PrimFunc func(func_info_.params, body, VoidType(), func_info_.buffer_map, DictAttrs(attr_map)); // Renew function defs to prevent using the same symbolic vars in different functions diff --git a/src/relax/transform/infer_amp_utils.cc b/src/relax/transform/infer_amp_utils.cc index 2b2bb1949d60..94fe226146fc 100644 --- a/src/relax/transform/infer_amp_utils.cc +++ b/src/relax/transform/infer_amp_utils.cc @@ -54,11 +54,11 @@ NType NTypeMerge(const NType& a, const NType& b) { } ffi::Array InferMixedPrecisionFollow(const Call& call, const DataType& out_dtype) { - return {Integer(MixedPrecisionPolicyKind::kFollow), call}; + return {IntImm(DataType::Int(32), MixedPrecisionPolicyKind::kFollow), call}; } ffi::Array InferMixedPrecisionNever(const Call& call, const DataType& out_dtype) { - return {Integer(MixedPrecisionPolicyKind::kNever), call}; + return {IntImm(DataType::Int(32), MixedPrecisionPolicyKind::kNever), call}; } } // namespace relax diff --git a/src/s_tir/analysis/identify_memcpy.cc b/src/s_tir/analysis/identify_memcpy.cc index 91ccf1e89783..11cdc2487548 100644 --- a/src/s_tir/analysis/identify_memcpy.cc +++ b/src/s_tir/analysis/identify_memcpy.cc @@ -29,6 +29,7 @@ #include #include #include +#include #include #include @@ -105,7 +106,7 @@ std::variant IdentifyMemCpyImpl(const For& loop, // for i in T.serial(16): // B[i] = A[T.abs(i-8)] - auto src_iter_map = arith::DetectIterMap({src_index}, loop_ranges, Bool(true), + auto src_iter_map = arith::DetectIterMap({src_index}, loop_ranges, const_true(), arith::IterMapLevel::Bijective, analyzer); if (src_iter_map->errors.size()) { return static_cast(std::stringstream() @@ -115,7 +116,7 @@ 
std::variant IdentifyMemCpyImpl(const For& loop, << " for src_index = " << src_index) .str(); } - auto dst_iter_map = arith::DetectIterMap({dst_index}, loop_ranges, Bool(true), + auto dst_iter_map = arith::DetectIterMap({dst_index}, loop_ranges, const_true(), arith::IterMapLevel::Bijective, analyzer); if (dst_iter_map->errors.size()) { return static_cast(std::stringstream() diff --git a/src/s_tir/meta_schedule/utils.h b/src/s_tir/meta_schedule/utils.h index 5576594f757b..738e8ac95c9d 100644 --- a/src/s_tir/meta_schedule/utils.h +++ b/src/s_tir/meta_schedule/utils.h @@ -655,9 +655,9 @@ class SBlockCollector : public tirx::StmtVisitor { // If filter function is provided, use it to selectively collect blocks. // Otherwise collect all blocks. - Bool collect_block = Bool(true); + bool collect_block = true; if (f_block_filter_ != nullptr) { - collect_block = f_block_filter_(ffi::GetRef(block)).cast(); + collect_block = f_block_filter_(ffi::GetRef(block)).cast()->value != 0; } if (collect_block) { blocks_to_collect_.push_back(block->name_hint); diff --git a/src/s_tir/schedule/analysis/layout.cc b/src/s_tir/schedule/analysis/layout.cc index ef7acb1163ba..7700a94e2d54 100644 --- a/src/s_tir/schedule/analysis/layout.cc +++ b/src/s_tir/schedule/analysis/layout.cc @@ -218,14 +218,14 @@ ffi::Optional SuggestIndexMap(const Buffer& buffer, const ffi::Array

Bind(index, Range::FromMinExtent(0, Integer(split_exprs[i].extent))); + analyzer->Bind(index, Range::FromMinExtent(0, IntImm(DataType::Int(32), split_exprs[i].extent))); } // Step 6.2: Fuse all the indices. This is the inverse of Step 5.2. PrimExpr flattened_index = make_const(indices[0]->dtype, 0); int64_t stride = 1; for (int i = static_cast(split_exprs.size()) - 1; i >= 0; --i) { - flattened_index = inv_permuted_indices[i] * Integer(stride) + flattened_index; + flattened_index = inv_permuted_indices[i] * IntImm(DataType::Int(32), stride) + flattened_index; stride *= split_exprs[i].extent; } // Step 6.3: Split the flattened index into multiple indices. This is the inverse of Step 5.1. diff --git a/src/s_tir/schedule/primitive/blockize_tensorize.cc b/src/s_tir/schedule/primitive/blockize_tensorize.cc index a2f915b0bb86..5c55f5d7578f 100644 --- a/src/s_tir/schedule/primitive/blockize_tensorize.cc +++ b/src/s_tir/schedule/primitive/blockize_tensorize.cc @@ -139,8 +139,8 @@ ffi::Array> TrivialSubspaceDivision( return {}; } } - res.push_back({arith::IterMark(arith::IterSumExpr({}, 0), Bool(true)), - arith::IterMark(arith::IterSumExpr({}, 0), Bool(true))}); + res.push_back({arith::IterMark(arith::IterSumExpr({}, 0), const_true()), + arith::IterMark(arith::IterSumExpr({}, 0), const_true())}); return res; } diff --git a/src/s_tir/schedule/primitive/cache_read_write.cc b/src/s_tir/schedule/primitive/cache_read_write.cc index 3c754d1fa3af..626eaa57f3d4 100644 --- a/src/s_tir/schedule/primitive/cache_read_write.cc +++ b/src/s_tir/schedule/primitive/cache_read_write.cc @@ -189,13 +189,13 @@ SBlock MakeReindexCacheStage(const BufferRegion& cache_region, ReindexCacheStage Region& old_region = (is_cache_read) ? 
read_access_region : write_access_region; for (const Range& range : cache_region->region) { old_indices.push_back(Substitute(range->min, var_map)); - old_region.push_back(Range::FromMinExtent(old_indices.back(), Integer(1))); + old_region.push_back(Range::FromMinExtent(old_indices.back(), IntImm(DataType::Int(32), 1))); } ffi::Array& new_indices = (is_cache_read) ? write_access_indices : read_access_indices; Region& new_region = (is_cache_read) ? write_access_region : read_access_region; for (const PrimExpr& idx : info->indices) { new_indices.push_back(Substitute((idx), var_map)); - new_region.push_back(Range::FromMinExtent(new_indices.back(), Integer(1))); + new_region.push_back(Range::FromMinExtent(new_indices.back(), IntImm(DataType::Int(32), 1))); } // Create New Block @@ -562,7 +562,7 @@ static PrimExpr CollectNestedBlockPredicates(const Stmt& body, const Buffer& buf BufferIndexType index_type) { struct Collector : public StmtVisitor { Collector(const Buffer& buf, BufferIndexType idx_type) - : buffer_(buf), index_type_(idx_type), result_(Bool(false)), found_(false) {} + : buffer_(buf), index_type_(idx_type), result_(const_false()), found_(false) {} void VisitStmt_(const SBlockRealizeNode* realize) final { const SBlockNode* block = realize->block.get(); @@ -604,7 +604,7 @@ static PrimExpr CollectNestedBlockPredicates(const Stmt& body, const Buffer& buf collector(body); // If no nested block accessed the buffer, return true (no restriction — the caller // will fall back to the original scope-block reads / FullRegion path). - return collector.found_ ? collector.result_ : Bool(true); + return collector.found_ ? collector.result_ : const_true(); } /*! 
@@ -621,7 +621,7 @@ static PrimExpr CollectNestedBlockPredicates(const Stmt& body, const Buffer& buf BufferRegion RelaxBufferRegion(ScheduleState self, const BufferRegion& buffer_region, const StmtSRef& block_sref, const StmtSRef& dom_low_inclusive, const StmtSRef& dom_high_exclusive, - PrimExpr extra_predicate = Bool(true)) { + PrimExpr extra_predicate = const_true()) { SBlockRealize realize = GetSBlockRealize(self, block_sref); ffi::Map binding = GetBindings(realize); const Buffer& buffer = buffer_region->buffer; @@ -1089,7 +1089,7 @@ class ReindexCacheReadRewriter : public CacheReadRewriter { if (buf_region->buffer.same_as(info_->read_buffer)) { Region region; for (const PrimExpr index : new_indices_) { - region.push_back(Range::FromMinExtent(index, Integer(1))); + region.push_back(Range::FromMinExtent(index, IntImm(DataType::Int(32), 1))); } new_reads.push_back(BufferRegion(info_->write_buffer, region)); } else { @@ -1105,7 +1105,7 @@ class ReindexCacheReadRewriter : public CacheReadRewriter { if (source->buffer.same_as(info_->read_buffer)) { Region region; for (const PrimExpr index : new_indices_) { - region.push_back(Range::FromMinExtent(index, Integer(1))); + region.push_back(Range::FromMinExtent(index, IntImm(DataType::Int(32), 1))); } new_match_buffers.push_back(MatchBufferRegion(match_buffer_region->buffer, BufferRegion(info_->write_buffer, region))); @@ -1378,7 +1378,7 @@ class ReindexCacheWriteRewriter : public CacheWriteRewriter { if (buf_region->buffer.same_as(info_->write_buffer)) { Region region; for (const PrimExpr index : new_indices_) { - region.push_back(Range::FromMinExtent(index, Integer(1))); + region.push_back(Range::FromMinExtent(index, IntImm(DataType::Int(32), 1))); } new_reads.push_back(BufferRegion(info_->read_buffer, region)); } else { @@ -1394,7 +1394,7 @@ class ReindexCacheWriteRewriter : public CacheWriteRewriter { if (source->buffer.same_as(info_->write_buffer)) { Region region; for (const PrimExpr index : new_indices_) { - 
region.push_back(Range::FromMinExtent(index, Integer(1))); + region.push_back(Range::FromMinExtent(index, IntImm(DataType::Int(32), 1))); } new_match_buffers.push_back(MatchBufferRegion(match_buffer_region->buffer, BufferRegion(info_->read_buffer, region))); @@ -1781,7 +1781,7 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff GetBufferRegionFromBuffer(block->reads, read_buffer); PrimExpr nested_pred = read_region_opt ? CollectNestedBlockPredicates(block->body, read_buffer, BufferIndexType::kRead) - : Bool(true); + : const_true(); if (read_region_opt && !is_one(nested_pred) && block_sref->parent != nullptr) { StmtSRef parent_sref = ffi::GetRef(block_sref->parent); cache_region = RelaxBufferRegion(self, read_region_opt.value(), block_sref, parent_sref, diff --git a/src/s_tir/schedule/primitive/compute_at.cc b/src/s_tir/schedule/primitive/compute_at.cc index a611a1bee347..0affecd2d5c6 100644 --- a/src/s_tir/schedule/primitive/compute_at.cc +++ b/src/s_tir/schedule/primitive/compute_at.cc @@ -300,7 +300,7 @@ class ScopeReconstructor : private StmtMutator { const Var& loop_var = loop_vars[i]; const PrimExpr& loop_extent = loop_extents[i]; new_subtree = For(/*loop_var=*/loop_var, - /*min=*/Integer(0), + /*min=*/IntImm(DataType::Int(32), 0), /*extent=*/loop_extent, /*ForKind=*/ForKind::kSerial, /*body=*/std::move(new_subtree)); diff --git a/src/s_tir/schedule/primitive/compute_inline.cc b/src/s_tir/schedule/primitive/compute_inline.cc index 19cbe9217655..20043b720a39 100644 --- a/src/s_tir/schedule/primitive/compute_inline.cc +++ b/src/s_tir/schedule/primitive/compute_inline.cc @@ -625,7 +625,7 @@ class ReverseComputeInliner : public BaseInliner { producer_block_(producer_block), consumer_block_(consumer_block_realize->block.get()) { // Initialize the predicates to ensure consumer block iters are in-bound - consumer_iter_in_bound_ = Bool(true); + consumer_iter_in_bound_ = const_true(); for (const IterVar& iter : 
consumer_block_realize->block->iter_vars) { consumer_iter_in_bound_ = consumer_iter_in_bound_ && diff --git a/src/s_tir/schedule/primitive/layout_transformation.cc b/src/s_tir/schedule/primitive/layout_transformation.cc index d9c729dd9078..e9fa97772862 100644 --- a/src/s_tir/schedule/primitive/layout_transformation.cc +++ b/src/s_tir/schedule/primitive/layout_transformation.cc @@ -492,7 +492,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { std::stringstream block_name; block_name << "buffer_" << new_buffer->name << "_assumptions"; auto read_region = BufferRegion::FromPoint(new_buffer, indices); - stmt = SBlockRealize(iter_values, Bool(true), + stmt = SBlockRealize(iter_values, const_true(), SBlock(iter_vars, {read_region}, {}, block_name.str(), stmt)); for (size_t rev_i = 0; rev_i < inverse->initial_indices.size(); rev_i++) { @@ -1187,7 +1187,7 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_ const SBlockNode* scope_block = TVM_SREF_TO_SBLOCK(scope_sref); ffi::Optional opt_inverse = std::nullopt; - PrimExpr padding_predicate = Bool(false); + PrimExpr padding_predicate = const_false(); if (!assume_injective_transform) { std::tie(opt_inverse, padding_predicate) = [&]() { ffi::Array region; diff --git a/src/s_tir/schedule/primitive/loop_transformation.cc b/src/s_tir/schedule/primitive/loop_transformation.cc index 14223925a3cb..87b5b4042f77 100644 --- a/src/s_tir/schedule/primitive/loop_transformation.cc +++ b/src/s_tir/schedule/primitive/loop_transformation.cc @@ -743,13 +743,13 @@ class LoopReconstructor : private StmtMutator { new_stmts.push_back(new_stmt); this->need_remove_loop_.push_back(loops_[i].back()); } - auto new_loop = For(new_loop_vars[0], Integer(0), new_loop_extents[0], ForKind::kSerial, + auto new_loop = For(new_loop_vars[0], IntImm(DataType::Int(32), 0), new_loop_extents[0], ForKind::kSerial, SeqStmt(std::move(new_stmts))); this->new_inner_loop_ = new_loop; for (size_t i = 1; i < new_loop_vars.size(); 
++i) { const Var& loop_var = new_loop_vars[i]; const PrimExpr& loop_extent = new_loop_extents[i]; - new_loop = For(loop_var, Integer(0), loop_extent, ForKind::kSerial, new_loop); + new_loop = For(loop_var, IntImm(DataType::Int(32), 0), loop_extent, ForKind::kSerial, new_loop); } this->new_outer_loop_ = new_loop; } diff --git a/src/s_tir/schedule/primitive/pad_einsum.cc b/src/s_tir/schedule/primitive/pad_einsum.cc index b4f3f3a46b18..e805ff1e7df3 100644 --- a/src/s_tir/schedule/primitive/pad_einsum.cc +++ b/src/s_tir/schedule/primitive/pad_einsum.cc @@ -183,7 +183,7 @@ struct BufferPadding { } Stmt body{nullptr}; if (is_read) { - PrimExpr predicate = Bool(true); + PrimExpr predicate = const_true(); for (int i = 0; i < ndim; ++i) { if (!analyzer->CanProveEqual(buffer->shape[i], padded_buffer->shape[i])) { predicate = predicate && (indices[i] < buffer->shape[i]); @@ -203,7 +203,7 @@ struct BufferPadding { SBlock new_block(iter_vars, {read_region}, {write_region}, padded_buffer->name, std::move(body)); blocks->push_back(new_block); - body = SBlockRealize(ffi::Array{loop_vars.begin(), loop_vars.end()}, Bool(true), + body = SBlockRealize(ffi::Array{loop_vars.begin(), loop_vars.end()}, const_true(), new_block); for (int i = ndim - 1; i >= 0; --i) { body = For(loop_vars[i], loop_doms[i]->min, loop_doms[i]->extent, ForKind::kSerial, diff --git a/src/s_tir/schedule/primitive/read_write_at.cc b/src/s_tir/schedule/primitive/read_write_at.cc index 793927322598..73990add29b1 100644 --- a/src/s_tir/schedule/primitive/read_write_at.cc +++ b/src/s_tir/schedule/primitive/read_write_at.cc @@ -306,7 +306,7 @@ struct ReadWriteAtImpl { } Stmt stmt = BufferStore(copy_to, /*value=*/BufferLoad(copy_from, indices), /*indices=*/indices); for (int i = n - 1; i >= 0; --i) { - stmt = For(loop_vars[i], Integer(0), domain[i]->extent, ForKind::kSerial, stmt); + stmt = For(loop_vars[i], IntImm(DataType::Int(32), 0), domain[i]->extent, ForKind::kSerial, stmt); } return SBlockRealize( 
/*values=*/iter_values, diff --git a/src/s_tir/schedule/primitive/reduction.cc b/src/s_tir/schedule/primitive/reduction.cc index c4183e05f02a..dc900f94cdb6 100644 --- a/src/s_tir/schedule/primitive/reduction.cc +++ b/src/s_tir/schedule/primitive/reduction.cc @@ -158,8 +158,8 @@ class LoopHeightError : public ScheduleError { }; PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set& discarded_loops) { - if (is_one(pred)) return Bool(true); - PrimExpr new_pred = Bool(true); + if (is_one(pred)) return const_true(); + PrimExpr new_pred = const_true(); auto f = [&](const VarNode* var) { return discarded_loops.count(var); }; arith::PVar lhs, rhs, rest; for (;;) { diff --git a/src/s_tir/schedule/state.cc b/src/s_tir/schedule/state.cc index 1914f48bad08..6ddc3358106b 100644 --- a/src/s_tir/schedule/state.cc +++ b/src/s_tir/schedule/state.cc @@ -1013,11 +1013,11 @@ void ScheduleStateNode::UpdateScopeSBlockInfo(const Stmt& stmt) { SBlockInfoCollector::Collect(this, stmt); } -TVM_DLL ffi::Array GetCachedFlags(const ScheduleState& self, const StmtSRef& block_sref) { +TVM_DLL ffi::Array GetCachedFlags(const ScheduleState& self, const StmtSRef& block_sref) { const SBlockInfo& info = self->GetSBlockInfo(block_sref); - return {Bool(info.affine_binding), // - Bool(info.region_cover), // - Bool(info.stage_pipeline)}; + return {IntImm(DataType::Bool(), info.affine_binding), // + IntImm(DataType::Bool(), info.region_cover), // + IntImm(DataType::Bool(), info.stage_pipeline)}; } /**************** FFI ****************/ diff --git a/src/s_tir/schedule/trace.cc b/src/s_tir/schedule/trace.cc index 6e4de3c5f8ed..2702b4daa371 100644 --- a/src/s_tir/schedule/trace.cc +++ b/src/s_tir/schedule/trace.cc @@ -405,7 +405,7 @@ ffi::ObjectRef TraceNode::AsJSON(bool remove_postproc) const { Any decision = this->GetDecision(inst); if (decision != nullptr) { json_decisions.push_back(ffi::Array{ - /* 0: index */ Integer(i), + /* 0: index */ IntImm(DataType::Int(32), i), /* 1: decision */ 
decision, }); } diff --git a/src/s_tir/schedule/transform.cc b/src/s_tir/schedule/transform.cc index ff343a700825..ee273597c841 100644 --- a/src/s_tir/schedule/transform.cc +++ b/src/s_tir/schedule/transform.cc @@ -407,7 +407,7 @@ ffi::Optional TileWithTensorIntrin(const s_tir::Schedule& sch, // Do the split. Leave the outer extent as std::nullopt (unspecified) so that the split factors // can be used for different extents (needed during tuning). ffi::Array split = - sch->Split(loop2rv.at(block_loop_sref), {std::nullopt, Integer(inner)}); + sch->Split(loop2rv.at(block_loop_sref), {std::nullopt, IntImm(DataType::Int(32), inner)}); TVM_FFI_ICHECK_EQ(split.size(), 2); inner_loops.insert(sch->GetSRef(split[1]).operator->()); // The inner split will be reordered to the loop domain that is tensorized @@ -549,7 +549,7 @@ ffi::Optional NormalizePrimFunc(Schedule sch) { bool is_reduction = IsReductionBlock(sch->state(), // sch->GetSRef(block), // sch->GetSRef(root_block)); - block_is_reduction.push_back(Bool(is_reduction)); + block_is_reduction.push_back(IntImm(DataType::Bool(), is_reduction)); } return ffi::Array{leaf_blocks, block_loops, block_iters, block_is_reduction}; } diff --git a/src/s_tir/support/nd_int_set.h b/src/s_tir/support/nd_int_set.h index 03f3672b452d..c46aff83600d 100644 --- a/src/s_tir/support/nd_int_set.h +++ b/src/s_tir/support/nd_int_set.h @@ -51,7 +51,7 @@ inline NDIntSet NDIntSetFromRegion(const tirx::Region& region) { * \return The constructed set. 
*/ inline NDIntSet NDIntSetFromShape(const ffi::Array& shape) { - PrimExpr zero = Integer(0); + PrimExpr zero = IntImm(DataType::Int(32), 0); NDIntSet result; result.reserve(shape.size()); for (const PrimExpr& extent : shape) { diff --git a/src/s_tir/transform/default_gpu_schedule.cc b/src/s_tir/transform/default_gpu_schedule.cc index cbcc4972033d..dddaed193b4c 100644 --- a/src/s_tir/transform/default_gpu_schedule.cc +++ b/src/s_tir/transform/default_gpu_schedule.cc @@ -72,13 +72,13 @@ void ThreadBind(s_tir::Schedule sch, const s_tir::SBlockRV& block, int64_t max_t if (product > max_thread_per_block * max_threadblocks) { ffi::Array splits = sch->Split( fused, - /*factors=*/{std::nullopt, Integer(max_threadblocks), Integer(max_thread_per_block)}); + /*factors=*/{std::nullopt, IntImm(DataType::Int(32), max_threadblocks), IntImm(DataType::Int(32), max_thread_per_block)}); sch->Reorder(/*ordered_loop_rvs=*/{splits[1], splits[2], splits[0]}); sch->Bind(splits[1], "blockIdx.x"); sch->Bind(splits[2], "threadIdx.x"); } else { ffi::Array splits = sch->Split( - fused, /*factors=*/{std::nullopt, Integer(std::min(product, max_thread_per_block))}); + fused, /*factors=*/{std::nullopt, IntImm(DataType::Int(32), std::min(product, max_thread_per_block))}); sch->Bind(splits[0], "blockIdx.x"); sch->Bind(splits[1], "threadIdx.x"); } @@ -146,7 +146,7 @@ tirx::PrimFunc WrapBareSBlockBody(const tirx::PrimFunc& func) { /*writes=*/ffi::Array{}, /*name_hint=*/"root", /*body=*/for_stmt); tirx::SBlockRealize root_realize(/*iter_values=*/ffi::Array{}, - /*predicate=*/tvm::Bool(true), root_block); + /*predicate=*/const_true(), root_block); tirx::PrimFunc result = func; result.CopyOnWrite()->body = std::move(root_realize); return result; diff --git a/src/s_tir/transform/inject_software_pipeline.cc b/src/s_tir/transform/inject_software_pipeline.cc index ba6c3bf666b2..f85918e511f5 100644 --- a/src/s_tir/transform/inject_software_pipeline.cc +++ b/src/s_tir/transform/inject_software_pipeline.cc @@ 
-246,7 +246,7 @@ class PipelineBodyRewriter : public StmtExprMutator { ? Range::FromMinExtent(0, new_buffer->shape[0]) : Range::FromMinExtent(floormod((pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]), - Integer(1)); + IntImm(DataType::Int(32), 1)); new_region.insert(new_region.begin(), accessed_version); return BufferRegion(new_buffer, new_region); } @@ -397,7 +397,7 @@ class PipelineRewriter : public StmtExprMutator { } SBlock block = MakeSBlock(stmt, buffer_data_to_buffer_); block.CopyOnWrite()->alloc_buffers = std::move(alloc_buffers); - return SBlockRealize({}, Bool(true), block); + return SBlockRealize({}, const_true(), block); } private: @@ -824,7 +824,7 @@ class PipelineRewriter : public StmtExprMutator { PrimExpr new_loop_var; PrimExpr extent = end - start; - auto make_nop = []() { return SBlockRealize({}, Bool(true), MakeSBlock(Evaluate(0), {})); }; + auto make_nop = []() { return SBlockRealize({}, const_true(), MakeSBlock(Evaluate(0), {})); }; if (analyzer_.CanProve(extent <= 0)) { return make_nop(); @@ -970,7 +970,7 @@ class PipelineRewriter : public StmtExprMutator { } } - return SBlockRealize({}, Bool(true), MakeSBlock(std::move(new_loop), buffer_data_to_buffer_)); + return SBlockRealize({}, const_true(), MakeSBlock(std::move(new_loop), buffer_data_to_buffer_)); } arith::Analyzer analyzer_; diff --git a/src/s_tir/transform/lower_cross_thread_reduction.cc b/src/s_tir/transform/lower_cross_thread_reduction.cc index ba7dd6962576..a07ecb5dd6eb 100644 --- a/src/s_tir/transform/lower_cross_thread_reduction.cc +++ b/src/s_tir/transform/lower_cross_thread_reduction.cc @@ -149,8 +149,8 @@ ffi::Array MakeScratchpads(const ffi::Array& reduction_buffers, name = name + "_thread_" + buffer->name; new_buffers.push_back(Buffer(/*ptr=*/Var(name, PointerType(PrimType(buffer->dtype), "local")), /*dtype=*/buffer->dtype, - /*shape=*/{Integer(1)}, - /*strides=*/{Integer(1)}, + /*shape=*/{IntImm(DataType::Int(32), 1)}, + 
/*strides=*/{IntImm(DataType::Int(32), 1)}, /*elem_offset=*/PrimExpr{nullptr}, /*name=*/name, /*data_alignment=*/0, @@ -336,7 +336,7 @@ Stmt TransformReductionBlock(const SBlockRealizeNode* realize, inits.reserve(n_buffers); for (int i = 0; i < n_buffers; ++i) { inits.push_back( - BufferStore(it_buffers.value()[i], reducer->identity_element[i], {Integer(0)})); + BufferStore(it_buffers.value()[i], reducer->identity_element[i], {IntImm(DataType::Int(32), 0)})); } stmts.push_back(SBlockRealize(/*iter_values=*/{}, /*predicate=*/const_true(), @@ -380,7 +380,7 @@ Stmt TransformReductionBlock(const SBlockRealizeNode* realize, // Next `n_buffers` arguments: sources if (it_buffers.defined()) { for (int i = 0; i < n_buffers; ++i) { - parameters.push_back(BufferLoad(it_buffers.value()[i], {Integer(0)})); + parameters.push_back(BufferLoad(it_buffers.value()[i], {IntImm(DataType::Int(32), 0)})); } } else { parameters.insert(parameters.end(), combiner_rhs.begin(), combiner_rhs.end()); @@ -465,7 +465,7 @@ Stmt TransformReductionBlock(const SBlockRealizeNode* realize, } for (int i = 0; i < n_buffers; ++i) { wb_updates.push_back( - BufferStore(wb_buffers[i], BufferLoad(ct_buffers[i], {Integer(0)}), wb_indices)); + BufferStore(wb_buffers[i], BufferLoad(ct_buffers[i], {IntImm(DataType::Int(32), 0)}), wb_indices)); wb_regions.push_back(BufferRegion(wb_buffers[i], region)); } diff --git a/src/s_tir/transform/lower_opaque_block.cc b/src/s_tir/transform/lower_opaque_block.cc index fad67115ecdb..99468d6f975d 100644 --- a/src/s_tir/transform/lower_opaque_block.cc +++ b/src/s_tir/transform/lower_opaque_block.cc @@ -80,7 +80,7 @@ class OpaqueBlockLower : public StmtExprMutator { std::vector> pragma_attrs; HandleAnnotations(new_block->annotations, &pragma_attrs, /*is_block=*/true); for (auto it = pragma_attrs.rbegin(); it != pragma_attrs.rend(); ++it) { - body = AttrStmt(Integer(0), it->first, it->second, std::move(body)); + body = AttrStmt(IntImm(DataType::Int(32), 0), it->first, it->second, 
std::move(body)); } return body; } diff --git a/src/s_tir/transform/memhammer_coalesce.cc b/src/s_tir/transform/memhammer_coalesce.cc index 52d00d88e6b6..7785cab3bfdd 100644 --- a/src/s_tir/transform/memhammer_coalesce.cc +++ b/src/s_tir/transform/memhammer_coalesce.cc @@ -191,7 +191,7 @@ Stmt InverseMapping::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, arith::Analyzer analyzer; DiagnosticContext diag_ctx(DiagnosticContext::Default(IRModule())); auto iter_map = - arith::DetectIterMap(mapping_pattern, var_range, Bool(true), arith::Bijective, &analyzer); + arith::DetectIterMap(mapping_pattern, var_range, const_true(), arith::Bijective, &analyzer); TVM_FFI_ICHECK_EQ(iter_map->indices.size(), loop_vars.size()); ffi::Map inverse_mapping = arith::InverseAffineIterMap(iter_map->indices, loop_vars); diff --git a/src/s_tir/transform/memhammer_intermediate_stage.cc b/src/s_tir/transform/memhammer_intermediate_stage.cc index 9baf203b911d..63e51cd7b8f9 100644 --- a/src/s_tir/transform/memhammer_intermediate_stage.cc +++ b/src/s_tir/transform/memhammer_intermediate_stage.cc @@ -131,7 +131,7 @@ class IndexPatternFinder : public ExprVisitor { switch (o.kind) { case Operator::OpKind::Mul: max *= o.operand; - index = index * Integer(o.operand); + index = index * IntImm(DataType::Int(32), o.operand); break; case Operator::OpKind::FloorDiv: if (max % o.operand != 0 && o.operand % max != 0) { @@ -146,7 +146,7 @@ class IndexPatternFinder : public ExprVisitor { success_ = false; return; } - index = floordiv(index, Integer(o.operand)); + index = floordiv(index, IntImm(DataType::Int(32), o.operand)); break; case Operator::OpKind::FloorMod: int64_t step = max / extent; @@ -161,12 +161,12 @@ class IndexPatternFinder : public ExprVisitor { extent = std::max(static_cast(1), std::min(extent, o.operand / step)); max = extent * step; } - index = floormod(index, Integer(o.operand)); + index = floormod(index, IntImm(DataType::Int(32), o.operand)); } } if (extent > 1) { 
TVM_FFI_ICHECK(max % extent == 0); - access_shape_.push_back(Integer(extent)); + access_shape_.push_back(IntImm(DataType::Int(32), extent)); resulting_index_->push_back(floordiv(index, max / extent)); } } diff --git a/src/s_tir/transform/memhammer_lower_auto_copy.cc b/src/s_tir/transform/memhammer_lower_auto_copy.cc index 478e12bba9c3..245c40318ae3 100644 --- a/src/s_tir/transform/memhammer_lower_auto_copy.cc +++ b/src/s_tir/transform/memhammer_lower_auto_copy.cc @@ -464,14 +464,14 @@ class AutoPadder { bool CheckVarContiguous(PrimExpr e, Var var, const ffi::Map& subst_map) { PrimExpr e1 = Substitute(e, [var](const Var& v) -> ffi::Optional { if (v.same_as(var)) { - return Integer(0); + return IntImm(DataType::Int(32), 0); } else { return v; } }); PrimExpr e2 = Substitute(e, [var](const Var& v) -> ffi::Optional { if (v.same_as(var)) { - return Integer(1); + return IntImm(DataType::Int(32), 1); } else { return v; } diff --git a/src/s_tir/transform/memhammer_tensorcore_rewrite.cc b/src/s_tir/transform/memhammer_tensorcore_rewrite.cc index ef046cf9fc42..1a4532b8a4aa 100644 --- a/src/s_tir/transform/memhammer_tensorcore_rewrite.cc +++ b/src/s_tir/transform/memhammer_tensorcore_rewrite.cc @@ -129,7 +129,7 @@ Stmt RewriteWmmaLoad(Stmt stmt) { Buffer new_src_buffer( /*data=*/Var("src", PointerType(PrimType(dtype), src_buffer.scope())), /*dtype=*/dtype, - /*shape=*/{Integer(16), Integer(16)}, + /*shape=*/{IntImm(DataType::Int(32), 16), IntImm(DataType::Int(32), 16)}, /*strides=*/{Var("s1", int32), Var("s0", int32)}, /*elem_offset=*/Var("src_elem_offset", int32), /*name=*/"src", @@ -139,7 +139,7 @@ Stmt RewriteWmmaLoad(Stmt stmt) { Buffer new_tgt_buffer( /*data=*/Var("tgt", PointerType(PrimType(dtype), tgt_buffer.scope())), /*dtype=*/dtype, - /*shape=*/{Integer(16), Integer(16)}, + /*shape=*/{IntImm(DataType::Int(32), 16), IntImm(DataType::Int(32), 16)}, /*strides=*/{}, /*elem_offset=*/Var("tgt_elem_offset", int32), /*name=*/"tgt", @@ -150,7 +150,7 @@ Stmt 
RewriteWmmaLoad(Stmt stmt) { ffi::Array write_region = RelaxIndices(buf_store->indices, tgt_buffer->shape, var_dom); Stmt wmma_body = SBlockRealize( /*iter_values=*/{}, - /*predicate=*/Bool(true), + /*predicate=*/const_true(), SBlock( /*iter_vars=*/{}, /*reads=*/{BufferRegion(src_buffer, read_region)}, @@ -238,7 +238,7 @@ Stmt RewriteWmmaStore(Stmt stmt) { Buffer new_src_buffer(/*data=*/Var("src", PointerType(PrimType(dtype), src_buffer.scope())), /*dtype=*/dtype, - /*shape=*/{Integer(16), Integer(16)}, + /*shape=*/{IntImm(DataType::Int(32), 16), IntImm(DataType::Int(32), 16)}, /*strides=*/{}, /*elem_offset=*/Var("src_elem_offset", int32), /*name=*/"src", @@ -247,7 +247,7 @@ Stmt RewriteWmmaStore(Stmt stmt) { /*buffer_type=*/kDefault); Buffer new_tgt_buffer(/*data=*/Var("tgt", PointerType(PrimType(dtype), tgt_buffer.scope())), /*dtype=*/dtype, - /*shape=*/{Integer(16), Integer(16)}, + /*shape=*/{IntImm(DataType::Int(32), 16), IntImm(DataType::Int(32), 16)}, /*strides=*/{Var("s1", int32), Var("s0", int32)}, /*elem_offset=*/Var("tgt_elem_offset", int32), /*name=*/"tgt", @@ -259,7 +259,7 @@ Stmt RewriteWmmaStore(Stmt stmt) { ffi::Array write_region = RelaxIndices(buf_store->indices, tgt_buffer->shape, var_dom); Stmt wmma_body = SBlockRealize( /*iter_values=*/{}, // - /*predicate=*/Bool(true), + /*predicate=*/const_true(), SBlock(/*iter_vars=*/{}, /*reads=*/{BufferRegion(src_buffer, read_region)}, /*writes=*/{BufferRegion(tgt_buffer, write_region)}, @@ -458,7 +458,7 @@ Stmt RewriteMmaStore(Stmt stmt) { const DataType dtype = src_buffer->dtype; Buffer new_src_buffer(/*data=*/Var("src", PointerType(PrimType(dtype), src_buffer.scope())), /*dtype=*/dtype, - /*shape=*/{Integer(8), Integer(8)}, + /*shape=*/{IntImm(DataType::Int(32), 8), IntImm(DataType::Int(32), 8)}, /*strides=*/{}, /*elem_offset=*/Var("src_elem_offset", int32), /*name=*/"src", @@ -467,7 +467,7 @@ Stmt RewriteMmaStore(Stmt stmt) { /*buffer_type=*/kDefault); Buffer new_tgt_buffer(/*data=*/Var("tgt", 
PointerType(PrimType(dtype), tgt_buffer.scope())), /*dtype=*/dtype, - /*shape=*/{Integer(8), Integer(8)}, + /*shape=*/{IntImm(DataType::Int(32), 8), IntImm(DataType::Int(32), 8)}, /*strides=*/{Var("s1", int32), Var("s0", int32)}, /*elem_offset=*/Var("tgt_elem_offset", int32), /*name=*/"tgt", @@ -486,7 +486,7 @@ Stmt RewriteMmaStore(Stmt stmt) { Var vec = Var("vec"); Stmt mma_body = SBlockRealize( /*iter_values=*/{}, // - /*predicate=*/Bool(true), + /*predicate=*/const_true(), SBlock(/*iter_vars=*/{}, /*reads=*/{BufferRegion(src_buffer, read_region)}, /*writes=*/{BufferRegion(tgt_buffer, write_region)}, @@ -498,7 +498,7 @@ Stmt RewriteMmaStore(Stmt stmt) { /*iter_type=*/IterVarType::kThreadIndex, /*thread_tag=*/"threadIdx.x"), /*attr_key=*/"thread_extent", - /*value=*/Integer(32), + /*value=*/IntImm(DataType::Int(32), 32), /*body=*/ For(vec, 0, 2, ForKind::kVectorized, /*body=*/ diff --git a/src/s_tir/transform/plan_update_buffer_allocation_location.cc b/src/s_tir/transform/plan_update_buffer_allocation_location.cc index c46947a093fa..e727f167b843 100644 --- a/src/s_tir/transform/plan_update_buffer_allocation_location.cc +++ b/src/s_tir/transform/plan_update_buffer_allocation_location.cc @@ -216,7 +216,7 @@ class BufferAllocationLocator : public StmtExprMutator { GetSBlockReadWriteRegion(opaque_block, buffer_data_to_buffer_); n->reads = access[0]; n->writes = access[1]; - SBlockRealize realize({}, Bool(true), SBlock(n)); + SBlockRealize realize({}, const_true(), SBlock(n)); return realize; } diff --git a/src/s_tir/transform/transform_mma_buffer_layout.cc b/src/s_tir/transform/transform_mma_buffer_layout.cc index ac4073e33f94..a6451286d108 100644 --- a/src/s_tir/transform/transform_mma_buffer_layout.cc +++ b/src/s_tir/transform/transform_mma_buffer_layout.cc @@ -68,7 +68,7 @@ class MmaBufferLayoutTransformer : public StmtExprMutator { new_shape.push_back(buffer->shape[i]); } new_shape.insert(new_shape.end(), - {Integer(dim0->value / 16), Integer(dim1->value / 8), 2, 
2}); + {IntImm(DataType::Int(32), dim0->value / 16), IntImm(DataType::Int(32), dim1->value / 8), 2, 2}); Buffer new_buffer = decl_buffer(std::move(new_shape), buffer->dtype, buffer->name, "local", buffer->axis_separators); @@ -90,7 +90,7 @@ class MmaBufferLayoutTransformer : public StmtExprMutator { new_shape.push_back(buffer->shape[i]); } new_shape.insert(new_shape.end(), - {Integer(dim0->value / 32), Integer(dim1->value / 8), 4, 2}); + {IntImm(DataType::Int(32), dim0->value / 32), IntImm(DataType::Int(32), dim1->value / 8), 4, 2}); Buffer new_buffer = decl_buffer(std::move(new_shape), buffer->dtype, buffer->name, "local", buffer->axis_separators); @@ -112,7 +112,7 @@ class MmaBufferLayoutTransformer : public StmtExprMutator { new_shape.push_back(buffer->shape[i]); } new_shape.insert(new_shape.end(), - {Integer(dim0->value / 8), Integer(dim1->value / 32), 1, 8}); + {IntImm(DataType::Int(32), dim0->value / 8), IntImm(DataType::Int(32), dim1->value / 32), 1, 8}); Buffer new_buffer = decl_buffer(std::move(new_shape), buffer->dtype, buffer->name, "local", buffer->axis_separators); diff --git a/src/s_tir/transform/using_assume_to_reduce_branches.cc b/src/s_tir/transform/using_assume_to_reduce_branches.cc index daf72d54f310..672769949c03 100644 --- a/src/s_tir/transform/using_assume_to_reduce_branches.cc +++ b/src/s_tir/transform/using_assume_to_reduce_branches.cc @@ -177,7 +177,7 @@ class ParseAssumeAndOvercompute : public IRMutatorWithAnalyzer { PrimExpr CurrentScopePredicate() const { /* This combines all the constraints in a scope */ - PrimExpr predicate = Bool(true); + PrimExpr predicate = const_true(); for (const auto& condition : conditions_) { predicate = predicate && condition; } @@ -281,7 +281,7 @@ class ParseAssumeAndOvercompute : public IRMutatorWithAnalyzer { } void AssumeConstraintComponent(PrimExpr assumption) { - PrimExpr additional_predicate = Bool(true); + PrimExpr additional_predicate = const_true(); assume_struct buf_data; std::vector buffer_exprs; 
diff --git a/src/script/printer/ir/distributed.cc b/src/script/printer/ir/distributed.cc index 5abc316154e0..f2ca5d356693 100644 --- a/src/script/printer/ir/distributed.cc +++ b/src/script/printer/ir/distributed.cc @@ -29,7 +29,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) ffi::Array results; results.reserve(s); for (int i = 0; i < s; ++i) { - results.push_back(d->AsDoc(Integer(n[i]), n_p->ArrayItem(i))); + results.push_back(d->AsDoc(IntImm(DataType::Int(32), n[i]), n_p->ArrayItem(i))); } return TupleDoc(results); }); diff --git a/src/target/cuda/codegen_cuda.cc b/src/target/cuda/codegen_cuda.cc index 27cad36735b6..a57e70ff5fc1 100644 --- a/src/target/cuda/codegen_cuda.cc +++ b/src/target/cuda/codegen_cuda.cc @@ -197,12 +197,12 @@ class ThreadIdxExtractor : public tirx::StmtVisitor { } public: - PrimExpr threadIdx_x_ext = Integer(1); - PrimExpr threadIdx_y_ext = Integer(1); - PrimExpr threadIdx_z_ext = Integer(1); - PrimExpr clusterCtaIdx_x_ext = Integer(1); - PrimExpr clusterCtaIdx_y_ext = Integer(1); - PrimExpr clusterCtaIdx_z_ext = Integer(1); + PrimExpr threadIdx_x_ext = IntImm(DataType::Int(32), 1); + PrimExpr threadIdx_y_ext = IntImm(DataType::Int(32), 1); + PrimExpr threadIdx_z_ext = IntImm(DataType::Int(32), 1); + PrimExpr clusterCtaIdx_x_ext = IntImm(DataType::Int(32), 1); + PrimExpr clusterCtaIdx_y_ext = IntImm(DataType::Int(32), 1); + PrimExpr clusterCtaIdx_z_ext = IntImm(DataType::Int(32), 1); bool is_persistent_kernel = false; }; diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index c8dee88794c8..e9e5ee233053 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -28,6 +28,7 @@ #include #include #include +#include #include #include @@ -544,7 +545,7 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in Stmt body = GenerateBodyStmt(leaf.store_indices, buffers, leaf.axes_remap, expr_body, info, analyzer); 
seq_stmt.push_back(SBlockRealize(/*iter_values=*/leaf.bindings, - /*predicate=*/Bool(true), + /*predicate=*/IntImm(DataType::Bool(), 1), /*block=*/ SBlock(/*iter_vars=*/leaf.block_iters, /*reads=*/{}, @@ -566,7 +567,7 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in Stmt body = GenerateBodyStmt(leaf.store_indices, {buffers[i]}, leaf.axes_remap, expr_body, info, analyzer); seq_stmt.push_back(SBlockRealize(/*iter_values=*/leaf.bindings, - /*predicate=*/Bool(true), + /*predicate=*/IntImm(DataType::Bool(), 1), /*block=*/ SBlock(/*iter_vars=*/leaf.block_iters, /*reads=*/{}, @@ -599,7 +600,7 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in // wrap nested block body = SBlockRealize(/*iter_values=*/cur.bindings, - /*predicate=*/Bool(true), + /*predicate=*/IntImm(DataType::Bool(), 1), /*block=*/ SBlock(/*iter_vars=*/block_iters, /*reads=*/{}, @@ -659,7 +660,7 @@ Stmt GenerateStmtFromExternOp(const te::ExternOp& extern_op, CreateFuncInfo* inf // Step 4. Generate opaque block as body. return SBlockRealize(/*iter_values=*/{}, - /*predicate=*/Bool(true), + /*predicate=*/IntImm(DataType::Bool(), 1), /*block=*/ SBlock(/*iter_vars=*/{}, /*reads=*/{}, diff --git a/src/tirx/ir/expr.cc b/src/tirx/ir/expr.cc index d458c091807b..ffb48e883548 100644 --- a/src/tirx/ir/expr.cc +++ b/src/tirx/ir/expr.cc @@ -692,7 +692,7 @@ PrimExpr Shuffle::Concat(ffi::Array vectors, Span span) { } PrimExpr Shuffle::ExtractElement(PrimExpr vector, int index, Span span) { - return Shuffle({vector}, {Integer(index)}, span); + return Shuffle({vector}, {IntImm(DataType::Int(32), index)}, span); } TVM_FFI_STATIC_INIT_BLOCK() { diff --git a/src/tirx/ir/index_map.cc b/src/tirx/ir/index_map.cc index 46885bd16c3a..43b379351a70 100644 --- a/src/tirx/ir/index_map.cc +++ b/src/tirx/ir/index_map.cc @@ -67,7 +67,7 @@ std::pair IndexMapInverseImpl(const IndexMap& self, // return the pre-defined inverse index map if exists. 
In this // case, the user-defined inverse is assumed to be correct and // bijective. - PrimExpr padding_predicate = Bool(false); + PrimExpr padding_predicate = IntImm(DataType::Bool(), 0); return {Downcast(self->inverse_index_map.value()), padding_predicate}; } diff --git a/src/tirx/ir/script/script_complete.cc b/src/tirx/ir/script/script_complete.cc index f9e213190a54..b986597c8e63 100644 --- a/src/tirx/ir/script/script_complete.cc +++ b/src/tirx/ir/script/script_complete.cc @@ -28,6 +28,7 @@ #include #include #include +#include #include @@ -153,7 +154,7 @@ PrimFunc ScriptComplete(PrimFunc func, const ffi::Array& root_allocates, if (s_tir && should_insert_root) { SBlock root_block({}, {}, {}, "root", std::move(res), std::nullopt, root_allocates); - res = SBlockRealize({}, Bool(true), std::move(root_block)); + res = SBlockRealize({}, IntImm(DataType::Bool(), 1), std::move(root_block)); } // generate surrounding loops automatically diff --git a/src/tirx/script/builder/frame.cc b/src/tirx/script/builder/frame.cc index 7c36da768b96..6d7628ad7e99 100644 --- a/src/tirx/script/builder/frame.cc +++ b/src/tirx/script/builder/frame.cc @@ -195,7 +195,7 @@ void SBlockFrameNode::ExitWithScope() { << "`T.where` is not allowed when `no_realize=True`"; AddToParent(block); } else { - AddToParent(tvm::tirx::SBlockRealize(iter_values, predicate.value_or(Bool(true)), block)); + AddToParent(tvm::tirx::SBlockRealize(iter_values, predicate.value_or(IntImm(DataType::Bool(), 1)), block)); } } diff --git a/src/tirx/transform/lower_device_kernel_launch.cc b/src/tirx/transform/lower_device_kernel_launch.cc index af30af6bfb37..29a9c02e4351 100644 --- a/src/tirx/transform/lower_device_kernel_launch.cc +++ b/src/tirx/transform/lower_device_kernel_launch.cc @@ -142,7 +142,7 @@ class DeviceInfoCollector : public StmtVisitor { << "Only one dynamic shared memory allocation is allowed."; TVM_FFI_ICHECK_GT(op->buffer->shape.size(), 0); - PrimExpr dyn_size = Integer(1); + PrimExpr dyn_size = 
IntImm(DataType::Int(32), 1); for (const auto& extent : op->buffer->shape) { dyn_size *= extent; } diff --git a/src/tirx/transform/lower_tvm_builtin.cc b/src/tirx/transform/lower_tvm_builtin.cc index cf3c53f37dcb..f522f46c6c27 100644 --- a/src/tirx/transform/lower_tvm_builtin.cc +++ b/src/tirx/transform/lower_tvm_builtin.cc @@ -45,7 +45,7 @@ class BuiltinLower : public StmtExprMutator { static PrimFunc Build(PrimFunc func) { ffi::Optional device_type = std::nullopt; if (auto target = func->GetAttr(tvm::attr::kTarget)) { - device_type = Integer(target.value()->kind->default_device_type); + device_type = IntImm(DataType::Int(32), target.value()->kind->default_device_type); } BuiltinLower mutator(device_type); diff --git a/src/tirx/transform/make_packed_api.cc b/src/tirx/transform/make_packed_api.cc index 7d3c2e29bf6e..c7125d82fa3b 100644 --- a/src/tirx/transform/make_packed_api.cc +++ b/src/tirx/transform/make_packed_api.cc @@ -268,7 +268,7 @@ PrimFunc MakePackedAPI(PrimFunc func) { } // Return error code of zero on success - body = SeqStmt({body, Evaluate(ret(Integer(0)))}); + body = SeqStmt({body, Evaluate(ret(IntImm(DataType::Int(32), 0)))}); body = MergeNest({std::move(result.init_nest), seq_check, std::move(result.asserts), std::move(result.decl_buffers)}, diff --git a/src/topi/einsum.cc b/src/topi/einsum.cc index 53f604defd96..2a807b7e261e 100644 --- a/src/topi/einsum.cc +++ b/src/topi/einsum.cc @@ -109,7 +109,7 @@ PrimExpr GetBroadcastedExtent(const PrimExpr& extent1, const PrimExpr& extent2) if (extent1_imm->value == extent2_imm->value) { return extent1; } else if (extent1_imm->value == 1 || extent2_imm->value == 1) { - return Integer(std::max(extent1_imm->value, extent2_imm->value)); + return IntImm(DataType::Int(32), std::max(extent1_imm->value, extent2_imm->value)); } TVM_FFI_THROW(InternalError) << "Cannot broadcast extents " << extent1 << " and " << extent2; throw; From ee41dab6c835e0448e9ef48539b161a50857dd66 Mon Sep 17 00:00:00 2001 From: tqchen Date: 
Thu, 28 May 2026 12:21:15 +0000 Subject: [PATCH 2/5] [REFACTOR][SCHEDULE] Replace Integer() trace boxing with IntImm in Schedule and MetaSchedule Schedule trace attrs and MetaSchedule decision payloads boxed ints into Integer() to pass through ffi::Any. IntImm(DataType::Int(32), N) is the correct canonical form for integer constants in IR position; Integer() was a redundant wrapper. Replace across traced_schedule.cc, concrete_schedule, and ~15 MetaSchedule files. Also change MultiLevelTilingWideVector's vector_length_in_bits parameter from Integer to int64_t in both the header declaration and implementation, matching the underlying field type. --- .../tvm/s_tir/meta_schedule/schedule_rule.h | 2 +- .../meta_schedule/database/json_database.cc | 2 +- .../meta_schedule/mutator/mutate_parallel.cc | 2 +- .../postproc/rewrite_cooperative_fetch.cc | 16 +++--- .../rewrite_parallel_vectorize_unroll.cc | 2 +- .../postproc/rewrite_unbound_block.cc | 2 +- .../meta_schedule/postproc/verify_gpu_code.cc | 4 +- .../schedule/cuda/thread_bind.cc | 4 +- .../meta_schedule/schedule/cuda/winograd.cc | 4 +- .../schedule_rule/add_rfactor.cc | 2 +- .../schedule_rule/multi_level_tiling.cc | 4 +- .../multi_level_tiling_tensor_core.cc | 16 +++--- .../multi_level_tiling_wide_vector.cc | 4 +- .../parallel_vectorize_unroll.cc | 4 +- src/s_tir/schedule/concrete_schedule.cc | 4 +- src/s_tir/schedule/concrete_schedule.h | 4 +- src/s_tir/schedule/traced_schedule.cc | 56 +++++++++---------- 17 files changed, 66 insertions(+), 66 deletions(-) diff --git a/include/tvm/s_tir/meta_schedule/schedule_rule.h b/include/tvm/s_tir/meta_schedule/schedule_rule.h index de4d212db36d..e1964628e369 100644 --- a/include/tvm/s_tir/meta_schedule/schedule_rule.h +++ b/include/tvm/s_tir/meta_schedule/schedule_rule.h @@ -231,7 +231,7 @@ class ScheduleRule : public ffi::ObjectRef { * \return The schedule rule created */ TVM_DLL static ScheduleRule MultiLevelTilingWideVector( - ffi::String structure, Integer 
vector_length_in_bits, + ffi::String structure, int64_t vector_length_in_bits, ffi::Optional max_innermost_factor, ffi::Optional> reuse_read, ffi::Optional> reuse_write); diff --git a/src/s_tir/meta_schedule/database/json_database.cc b/src/s_tir/meta_schedule/database/json_database.cc index 8705412fa28e..a6c656f5098b 100644 --- a/src/s_tir/meta_schedule/database/json_database.cc +++ b/src/s_tir/meta_schedule/database/json_database.cc @@ -116,7 +116,7 @@ class JSONDatabaseNode : public DatabaseNode { this->tuning_records_.insert(record); JSONFileAppendLine(this->path_tuning_record, JSONDumps(ffi::Array{ - /*workload_index=*/Integer(this->workloads2idx_.at(record->workload)), + /*workload_index=*/IntImm(DataType::Int(32), this->workloads2idx_.at(record->workload)), /*tuning_record=*/record->AsJSON() // })); } diff --git a/src/s_tir/meta_schedule/mutator/mutate_parallel.cc b/src/s_tir/meta_schedule/mutator/mutate_parallel.cc index d3f74554c741..7e3b20fc3dea 100644 --- a/src/s_tir/meta_schedule/mutator/mutate_parallel.cc +++ b/src/s_tir/meta_schedule/mutator/mutate_parallel.cc @@ -54,7 +54,7 @@ bool IsAnnotateWithParallel(const Instruction& inst) { Instruction ReplaceAnnValue(Instruction inst, int64_t ann_val) { TVM_FFI_ICHECK_EQ(inst->inputs.size(), 2); return Instruction(/*kind=*/inst->kind, // - /*inputs=*/{inst->inputs[0], Integer(ann_val)}, // + /*inputs=*/{inst->inputs[0], IntImm(DataType::Int(32), ann_val)}, // /*attrs=*/inst->attrs, /*outputs=*/inst->outputs); } diff --git a/src/s_tir/meta_schedule/postproc/rewrite_cooperative_fetch.cc b/src/s_tir/meta_schedule/postproc/rewrite_cooperative_fetch.cc index d3a860eb0512..c27f3196b03b 100644 --- a/src/s_tir/meta_schedule/postproc/rewrite_cooperative_fetch.cc +++ b/src/s_tir/meta_schedule/postproc/rewrite_cooperative_fetch.cc @@ -199,29 +199,29 @@ bool RewriteCooperativeFetchNode::Apply(const s_tir::Schedule& sch) { if (thread_extent_y != -1) { if (vector_lane > 1) { ffi::Array split = sch->Split(fused, 
{std::nullopt, // - Integer(thread_extent_y), // - Integer(thread_extent_x), // - Integer(vector_lane)}); + IntImm(DataType::Int(32), thread_extent_y), // + IntImm(DataType::Int(32), thread_extent_x), // + IntImm(DataType::Int(32), vector_lane)}); sch->Vectorize(split[3]); sch->Bind(split[2], "threadIdx.x"); sch->Bind(split[1], "threadIdx.y"); } else { ffi::Array split = sch->Split(fused, {std::nullopt, // - Integer(thread_extent_y), // - Integer(thread_extent_x)}); + IntImm(DataType::Int(32), thread_extent_y), // + IntImm(DataType::Int(32), thread_extent_x)}); sch->Bind(split[2], "threadIdx.x"); sch->Bind(split[1], "threadIdx.y"); } } else { if (vector_lane > 1) { ffi::Array split = sch->Split(fused, {std::nullopt, // - Integer(thread_extent_x), // - Integer(vector_lane)}); + IntImm(DataType::Int(32), thread_extent_x), // + IntImm(DataType::Int(32), vector_lane)}); sch->Vectorize(split[2]); sch->Bind(split[1], "threadIdx.x"); } else { ffi::Array split = - sch->Split(fused, {std::nullopt, Integer(thread_extent_x)}); + sch->Split(fused, {std::nullopt, IntImm(DataType::Int(32), thread_extent_x)}); sch->Bind(split[1], "threadIdx.x"); } } diff --git a/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc b/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc index b4e89b6bb79e..b77355ee3bb2 100644 --- a/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc +++ b/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc @@ -380,7 +380,7 @@ void RewriteFuseSplitParallelVectorize(const Schedule& sch, ffi::Array* int vec_len) { size_t n_loops = loop_rvs->size(); LoopRV fused = sch->Fuse({loop_rvs->begin(), loop_rvs->end()}); - ffi::Array split = sch->Split(fused, {std::nullopt, Integer(vec_len)}); + ffi::Array split = sch->Split(fused, {std::nullopt, IntImm(DataType::Int(32), vec_len)}); TVM_FFI_ICHECK_EQ(split.size(), 2); const LoopRV& outer = split[0]; const LoopRV& inner = split[1]; diff --git 
a/src/s_tir/meta_schedule/postproc/rewrite_unbound_block.cc b/src/s_tir/meta_schedule/postproc/rewrite_unbound_block.cc index 14bf177ec3a7..002dc62612f2 100644 --- a/src/s_tir/meta_schedule/postproc/rewrite_unbound_block.cc +++ b/src/s_tir/meta_schedule/postproc/rewrite_unbound_block.cc @@ -127,7 +127,7 @@ bool RewriteUnboundBlockNode::Apply(const s_tir::Schedule& sch) { using s_tir::Schedule; TVM_FFI_ICHECK_NE(this->max_threads_per_block_, -1); auto get_factor = [t = this->max_threads_per_block_](int max_extent) -> ExprRV { - return Integer(std::min(t, max_extent)); + return IntImm(DataType::Int(32), std::min(t, max_extent)); }; std::vector> unbound_blocks = s_tir::UnboundBlockFinder::Find(sch->state()); diff --git a/src/s_tir/meta_schedule/postproc/verify_gpu_code.cc b/src/s_tir/meta_schedule/postproc/verify_gpu_code.cc index e8d9b8e85627..ee67d7275f9c 100644 --- a/src/s_tir/meta_schedule/postproc/verify_gpu_code.cc +++ b/src/s_tir/meta_schedule/postproc/verify_gpu_code.cc @@ -129,8 +129,8 @@ class VerifyGPUCodeNode : public PostprocNode { this->target_constraints_ = ffi::Map{ {"max_shared_memory_per_block", Extract(this->target_, "max_shared_memory_per_block")}, {"max_threads_per_block", Extract(this->target_, "max_threads_per_block")}, - {"max_vthread", Integer(8)}, - {"max_vector_bytes", Integer(16)}, + {"max_vthread", IntImm(DataType::Int(32), 8)}, + {"max_vector_bytes", IntImm(DataType::Int(32), 16)}, }; thread_warp_size_ = Extract(this->target_, "thread_warp_size").IntValue(); } diff --git a/src/s_tir/meta_schedule/schedule/cuda/thread_bind.cc b/src/s_tir/meta_schedule/schedule/cuda/thread_bind.cc index 32659488739d..0fa916786787 100644 --- a/src/s_tir/meta_schedule/schedule/cuda/thread_bind.cc +++ b/src/s_tir/meta_schedule/schedule/cuda/thread_bind.cc @@ -86,8 +86,8 @@ ffi::Array BindSpatialLoop(Schedule sch, LoopRV loop, int64_t max_thread return {splits[0], splits[1]}; } else { ffi::Array splits = sch->Split(loop, {std::nullopt, - 
Integer(max_threadblocks), // - Integer(max_threads_per_block)}); + IntImm(DataType::Int(32), max_threadblocks), // + IntImm(DataType::Int(32), max_threads_per_block)}); TVM_FFI_ICHECK_EQ(splits.size(), 3); sch->Reorder({splits[1], splits[2], splits[0]}); sch->Bind(splits[1], "blockIdx.x"); diff --git a/src/s_tir/meta_schedule/schedule/cuda/winograd.cc b/src/s_tir/meta_schedule/schedule/cuda/winograd.cc index 5a75e000d6d9..7beaca5698a7 100644 --- a/src/s_tir/meta_schedule/schedule/cuda/winograd.cc +++ b/src/s_tir/meta_schedule/schedule/cuda/winograd.cc @@ -150,8 +150,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { SBlockRV output = sch->GetConsumers(inverse)[0]; ffi::Array nchw = sch->GetLoops(output); TVM_FFI_ICHECK_EQ(nchw.size(), 4); - ffi::Array hs = sch->Split(nchw[2], {std::nullopt, Integer(tile_size)}); - ffi::Array ws = sch->Split(nchw[3], {std::nullopt, Integer(tile_size)}); + ffi::Array hs = sch->Split(nchw[2], {std::nullopt, IntImm(DataType::Int(32), tile_size)}); + ffi::Array ws = sch->Split(nchw[3], {std::nullopt, IntImm(DataType::Int(32), tile_size)}); sch->Reorder({hs[0], ws[0], hs[1], ws[1]}); outer = ws[0]; } diff --git a/src/s_tir/meta_schedule/schedule_rule/add_rfactor.cc b/src/s_tir/meta_schedule/schedule_rule/add_rfactor.cc index 933b41dbb169..e5436e5efc41 100644 --- a/src/s_tir/meta_schedule/schedule_rule/add_rfactor.cc +++ b/src/s_tir/meta_schedule/schedule_rule/add_rfactor.cc @@ -115,7 +115,7 @@ ffi::Array AddRFactorNode::Apply(const s_tir::Schedule& sch, // Annotate that the rfactor block, which is now the producer of the original block, needs to // be considered by the rule Random-Compute-Location. 
- sch_tmp->Annotate(block_rv, s_tir::attr::meta_schedule_random_compute_producer, Integer(1)); + sch_tmp->Annotate(block_rv, s_tir::attr::meta_schedule_random_compute_producer, IntImm(DataType::Int(32), 1)); res.push_back(sch_tmp); } catch (const tvm::ffi::Error& e) { } diff --git a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling.cc index 4471e877c13b..2360e2f538f1 100644 --- a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -284,9 +284,9 @@ std::vector MultiLevelTilingNode::TileLoopNest(State state, low_inclusive = this->thread_warp_size_; } sch->Annotate(block_rv, s_tir::attr::meta_schedule_thread_extent_low_inclusive, - Integer(low_inclusive)); + IntImm(DataType::Int(32), low_inclusive)); sch->Annotate(block_rv, s_tir::attr::meta_schedule_thread_extent_high_inclusive, - Integer(high_inclusive)); + IntImm(DataType::Int(32), high_inclusive)); } return {state}; } diff --git a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc index 674dc4de13bc..68bdf960734b 100644 --- a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc +++ b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc @@ -425,9 +425,9 @@ std::vector MultiLevelTilingTensorCoreNode::MMATileLoopNest(TensorCoreSta low_inclusive = this->thread_warp_size_; } sch->Annotate(block_rv, s_tir::attr::meta_schedule_thread_extent_low_inclusive, - Integer(low_inclusive)); + IntImm(DataType::Int(32), low_inclusive)); sch->Annotate(block_rv, s_tir::attr::meta_schedule_thread_extent_high_inclusive, - Integer(high_inclusive)); + IntImm(DataType::Int(32), high_inclusive)); } return {state}; } @@ -668,15 +668,15 @@ std::vector MultiLevelTilingTensorCoreNode::AddSoftwarePipeline( const s_tir::SBlockRV cache_read = state->read_reuse.at(i); if 
(state->is_mma) { // Add vector bytes for memhammer - sch->Annotate(cache_read, s_tir::attr::vector_bytes, Integer(16)); + sch->Annotate(cache_read, s_tir::attr::vector_bytes, IntImm(DataType::Int(32), 16)); if (!state->use_async) { - sch->Annotate(cache_read, s_tir::attr::local_stage, Integer(1)); - sch->Annotate(cache_read, s_tir::attr::double_buffer_scope, Integer(0)); + sch->Annotate(cache_read, s_tir::attr::local_stage, IntImm(DataType::Int(32), 1)); + sch->Annotate(cache_read, s_tir::attr::double_buffer_scope, IntImm(DataType::Int(32), 0)); } } else { // Add local stage and double buffering - sch->Annotate(cache_read, s_tir::attr::manifest_shared_memory_local_stage, Integer(1)); - sch->Annotate(cache_read, s_tir::attr::double_buffer_scope, Integer(0)); + sch->Annotate(cache_read, s_tir::attr::manifest_shared_memory_local_stage, IntImm(DataType::Int(32), 1)); + sch->Annotate(cache_read, s_tir::attr::double_buffer_scope, IntImm(DataType::Int(32), 0)); } } @@ -908,7 +908,7 @@ inline std::vector MultiLevelTilingTensorCoreNode::TransformForTensorizat state->intrin_group.compute_intrin); state->sch->Annotate(state->block_rv, s_tir::attr::meta_schedule_auto_tensorize_init, state->intrin_group.init_intrin); - state->sch->Annotate(state->block_rv, s_tir::attr::warp_execution, Integer(1)); + state->sch->Annotate(state->block_rv, s_tir::attr::warp_execution, IntImm(DataType::Int(32), 1)); return {std::move(state)}; } diff --git a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc index 271ede9fec72..1dee2fe1d007 100644 --- a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc +++ b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc @@ -125,13 +125,13 @@ MultiLevelTilingWideVectorNode::SplitLoop(const Schedule& sch, SBlockRV block_rv } ScheduleRule ScheduleRule::MultiLevelTilingWideVector( - ffi::String structure, Integer 
vector_length_in_bits, + ffi::String structure, int64_t vector_length_in_bits, ffi::Optional max_innermost_factor, ffi::Optional> reuse_read, ffi::Optional> reuse_write) { auto node = MultiLevelTilingInitCommon( structure, std::nullopt, max_innermost_factor, std::nullopt, reuse_read, reuse_write); - node->vector_length_in_bits = vector_length_in_bits->value; + node->vector_length_in_bits = vector_length_in_bits; return ScheduleRule(node); } diff --git a/src/s_tir/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc b/src/s_tir/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc index 7fc8d57f8138..8115dc91a8ee 100644 --- a/src/s_tir/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc +++ b/src/s_tir/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc @@ -64,11 +64,11 @@ class ParallelizeVectorizeUnrollNode : public ScheduleRuleNode { // Parallelization if (max_jobs_per_core != -1) { sch->Annotate(root_rv, s_tir::attr::meta_schedule_parallel, - Integer(this->max_parallel_extent_)); + IntImm(DataType::Int(32), this->max_parallel_extent_)); } // Vectorization if (max_vectorize_extent != -1) { - sch->Annotate(root_rv, s_tir::attr::meta_schedule_vectorize, Integer(max_vectorize_extent)); + sch->Annotate(root_rv, s_tir::attr::meta_schedule_vectorize, IntImm(DataType::Int(32), max_vectorize_extent)); } // Unroll if (!unroll_max_steps.empty() && !s_tir::CheckSpatialPrimFunc(sch, root_rv)) { diff --git a/src/s_tir/schedule/concrete_schedule.cc b/src/s_tir/schedule/concrete_schedule.cc index 9520332b1502..e69b5b2f47e7 100644 --- a/src/s_tir/schedule/concrete_schedule.cc +++ b/src/s_tir/schedule/concrete_schedule.cc @@ -488,7 +488,7 @@ ffi::Array ConcreteScheduleNode::Split(const LoopRV& loop_rv, // infer factor if needed and check validity of factors for (size_t i = 0; i < factor_rvs.size(); i++) { if (!factor_rvs[i].defined()) { - factors.push_back(Integer(-1)); + factors.push_back(IntImm(DataType::Int(32), -1)); if (infer_index != -1) { throw 
NotSingleInferFactorError(state_->mod); } @@ -555,7 +555,7 @@ ffi::Array ConcreteScheduleNode::LoopPartition( // infer factor if needed and check validity of factors for (size_t i = 0; i < factor_rvs.size(); i++) { if (!factor_rvs[i].defined()) { - factors.push_back(Integer(-1)); + factors.push_back(IntImm(DataType::Int(32), -1)); if (infer_index != -1) { throw NotSingleInferFactorError(state_->mod); } diff --git a/src/s_tir/schedule/concrete_schedule.h b/src/s_tir/schedule/concrete_schedule.h index 428ab0188dd8..848965208cc3 100644 --- a/src/s_tir/schedule/concrete_schedule.h +++ b/src/s_tir/schedule/concrete_schedule.h @@ -268,7 +268,7 @@ inline PrimExpr ConcreteScheduleNode::Get(const ExprRV& expr_rv) const { } const ffi::ObjectRef& obj = (*it).second; const auto* int_imm = TVM_TYPE_AS(obj, IntImmNode); - return Integer(int_imm->value); + return IntImm(DataType::Int(32), int_imm->value); }); return this->analyzer_->Simplify(transformed); } @@ -370,7 +370,7 @@ inline T ConcreteScheduleNode::CreateRV(const StmtSRef& sref) { inline ExprRV ConcreteScheduleNode::CreateRV(int64_t value) { Var rv("v" + std::to_string(this->symbol_table_.size() + 1), DataType::Int(32)); - this->symbol_table_.Set(rv, Integer(static_cast(value))); + this->symbol_table_.Set(rv, IntImm(DataType::Int(32), static_cast(value))); return rv; } diff --git a/src/s_tir/schedule/traced_schedule.cc b/src/s_tir/schedule/traced_schedule.cc index 9dd113b446d8..5ee3c377cc31 100644 --- a/src/s_tir/schedule/traced_schedule.cc +++ b/src/s_tir/schedule/traced_schedule.cc @@ -78,7 +78,7 @@ ffi::Array TracedScheduleNode::SamplePerfectTile( static const InstructionKind& kind = InstructionKind::Get("SamplePerfectTile"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // /*inputs=*/{loop_rv}, - /*attrs=*/{Integer(n), Integer(max_innermost_factor)}, + /*attrs=*/{IntImm(DataType::Int(32), n), IntImm(DataType::Int(32), max_innermost_factor)}, /*outputs=*/results), /*decision=*/decision); return results; @@ -94,7 
+94,7 @@ ffi::Array TracedScheduleNode::SamplePartitionedTile( trace_->Append(/*inst=*/Instruction( /*kind=*/kind, // /*inputs=*/{loop_rv}, - /*attrs=*/{Integer(n), Integer(partition_pos), Integer(innerpart_factor)}, + /*attrs=*/{IntImm(DataType::Int(32), n), IntImm(DataType::Int(32), partition_pos), IntImm(DataType::Int(32), innerpart_factor)}, /*outputs=*/results), /*decision=*/decision); return results; @@ -223,7 +223,7 @@ LoopRV TracedScheduleNode::Fuse(const ffi::Array& loop_rvs, bool preserv static const InstructionKind& kind = InstructionKind::Get("Fuse"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, /*inputs=*/loop_rvs, - /*attrs=*/{Integer(preserve_unit_loops)}, + /*attrs=*/{IntImm(DataType::Int(32), preserve_unit_loops)}, /*outputs=*/{result})); return result; } @@ -266,7 +266,7 @@ ffi::Array TracedScheduleNode::LoopPartition( static const InstructionKind& kind = InstructionKind::Get("LoopPartition"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, /*inputs=*/inputs, - /*attrs=*/{Integer(preserve_unit_iters)}, + /*attrs=*/{IntImm(DataType::Int(32), preserve_unit_iters)}, /*outputs=*/results)); return results; } @@ -364,7 +364,7 @@ SBlockRV TracedScheduleNode::CacheRead(const SBlockRV& block_rv, int read_buffer static const InstructionKind& kind = InstructionKind::Get("CacheRead"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, /*inputs=*/{block_rv, consumer_blocks}, - /*attrs=*/{Integer(read_buffer_index), storage_scope}, + /*attrs=*/{IntImm(DataType::Int(32), read_buffer_index), storage_scope}, /*outputs=*/{result})); return result; } @@ -378,7 +378,7 @@ SBlockRV TracedScheduleNode::CacheWrite(const SBlockRV& block_rv, int write_buff static const InstructionKind& kind = InstructionKind::Get("CacheWrite"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, /*inputs=*/{block_rv, consumer_blocks}, - /*attrs=*/{Integer(write_buffer_index), storage_scope}, + /*attrs=*/{IntImm(DataType::Int(32), write_buffer_index), storage_scope}, 
/*outputs=*/{result})); return result; } @@ -394,7 +394,7 @@ SBlockRV TracedScheduleNode::ReindexCacheRead(const SBlockRV& block_rv, int read /*inst=*/Instruction( /*kind=*/kind, /*inputs=*/{block_rv, index_map}, - /*attrs=*/{Integer(read_buffer_index), storage_scope}, + /*attrs=*/{IntImm(DataType::Int(32), read_buffer_index), storage_scope}, /*outputs=*/{result})); return result; } @@ -410,7 +410,7 @@ SBlockRV TracedScheduleNode::ReindexCacheWrite(const SBlockRV& block_rv, int wri /*inst=*/Instruction( /*kind=*/kind, /*inputs=*/{block_rv, index_map}, - /*attrs=*/{Integer(write_buffer_index), storage_scope}, + /*attrs=*/{IntImm(DataType::Int(32), write_buffer_index), storage_scope}, /*outputs=*/{result})); return result; } @@ -427,7 +427,7 @@ ffi::Array TracedScheduleNode::CacheInplace(const SBlockRV& block_rv, static const InstructionKind& kind = InstructionKind::Get("CacheInplace"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, /*inputs=*/{block_rv}, - /*attrs=*/{Integer(read_buffer_index), storage_scope}, + /*attrs=*/{IntImm(DataType::Int(32), read_buffer_index), storage_scope}, /*outputs=*/results)); return result; } @@ -444,7 +444,7 @@ ffi::Array TracedScheduleNode::CacheIndex(const SBlockRV& block_rv, static const InstructionKind& kind = InstructionKind::Get("CacheIndex"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, /*inputs=*/{block_rv}, - /*attrs=*/{storage_scope, Integer(cse_thresh)}, + /*attrs=*/{storage_scope, IntImm(DataType::Int(32), cse_thresh)}, /*outputs=*/outputs)); return result; } @@ -456,7 +456,7 @@ SBlockRV TracedScheduleNode::ReIndex(const SBlockRV& block_rv, int buffer_index, static const InstructionKind& kind = InstructionKind::Get("ReIndex"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, /*inputs=*/{block_rv}, - /*attrs=*/{Integer(buffer_index), Integer(buffer_index_type)}, + /*attrs=*/{IntImm(DataType::Int(32), buffer_index), IntImm(DataType::Int(32), static_cast(buffer_index_type))}, /*outputs=*/{result})); return result; 
} @@ -471,7 +471,7 @@ SBlockRV TracedScheduleNode::ReadAt(const LoopRV& loop_rv, const SBlockRV& block static const InstructionKind& kind = InstructionKind::Get("ReadAt"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, /*inputs=*/{loop_rv, block_rv}, - /*attrs=*/{Integer(read_buffer_index), storage_scope}, + /*attrs=*/{IntImm(DataType::Int(32), read_buffer_index), storage_scope}, /*outputs=*/{result})); return result; } @@ -484,7 +484,7 @@ SBlockRV TracedScheduleNode::WriteAt(const LoopRV& loop_rv, const SBlockRV& bloc static const InstructionKind& kind = InstructionKind::Get("WriteAt"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, /*inputs=*/{loop_rv, block_rv}, - /*attrs=*/{Integer(write_buffer_index), storage_scope}, + /*attrs=*/{IntImm(DataType::Int(32), write_buffer_index), storage_scope}, /*outputs=*/{result})); return result; } @@ -499,7 +499,7 @@ void TracedScheduleNode::ComputeAt(const SBlockRV& block_rv, const LoopRV& loop_ trace_->Append( /*inst=*/Instruction(/*kind=*/kind, /*inputs=*/{block_rv, loop_rv}, - /*attrs=*/{Integer(preserve_unit_loops), Integer(index)}, + /*attrs=*/{IntImm(DataType::Int(32), preserve_unit_loops), IntImm(DataType::Int(32), index)}, /*outputs=*/{})); } @@ -510,7 +510,7 @@ void TracedScheduleNode::ReverseComputeAt(const SBlockRV& block_rv, const LoopRV static const InstructionKind& kind = InstructionKind::Get("ReverseComputeAt"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, /*inputs=*/{block_rv, loop_rv}, - /*attrs=*/{Integer(preserve_unit_loops), Integer(index)}, + /*attrs=*/{IntImm(DataType::Int(32), preserve_unit_loops), IntImm(DataType::Int(32), index)}, /*outputs=*/{})); } @@ -562,7 +562,7 @@ SBlockRV TracedScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) { static const InstructionKind& kind = InstructionKind::Get("RFactor"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, /*inputs=*/{loop_rv}, - /*attrs=*/{Integer(factor_axis)}, + /*attrs=*/{IntImm(DataType::Int(32), factor_axis)}, 
/*outputs=*/{result})); return result; } @@ -576,7 +576,7 @@ void TracedScheduleNode::StorageAlign(const SBlockRV& block_rv, int buffer_index trace_->Append(/*inst=*/Instruction( /*kind=*/kind, /*inputs=*/{block_rv}, - /*attrs=*/{Integer(buffer_index), Integer(axis), Integer(factor), Integer(offset)}, + /*attrs=*/{IntImm(DataType::Int(32), buffer_index), IntImm(DataType::Int(32), axis), IntImm(DataType::Int(32), factor), IntImm(DataType::Int(32), offset)}, /*outputs=*/{})); } @@ -587,7 +587,7 @@ void TracedScheduleNode::SetScope(const SBlockRV& block_rv, int buffer_index, trace_->Append(/*inst=*/Instruction( /*kind=*/kind, /*inputs=*/{block_rv}, - /*attrs=*/{Integer(buffer_index), storage_scope}, + /*attrs=*/{IntImm(DataType::Int(32), buffer_index), storage_scope}, /*outputs=*/{})); } @@ -598,7 +598,7 @@ void TracedScheduleNode::UnsafeSetDType(const SBlockRV& block_rv, int buffer_ind trace_->Append(/*inst=*/Instruction( /*kind=*/kind, /*inputs=*/{block_rv}, - /*attrs=*/{Integer(buffer_index), dtype}, + /*attrs=*/{IntImm(DataType::Int(32), buffer_index), dtype}, /*outputs=*/{})); } @@ -610,7 +610,7 @@ SBlockRV TracedScheduleNode::Blockize(const LoopRV& loop_rv, bool preserve_unit_ trace_->Append(/*inst=*/Instruction( /*kind=*/kind, /*inputs=*/{loop_rv}, - /*attrs=*/{Bool(preserve_unit_iters)}, + /*attrs=*/{IntImm(DataType::Bool(), preserve_unit_iters)}, /*outputs=*/{new_block})); return new_block; } @@ -622,7 +622,7 @@ SBlockRV TracedScheduleNode::Blockize(const ffi::Array& blocks, trace_->Append(/*inst=*/Instruction( /*kind=*/kind, /*inputs=*/{blocks}, - /*attrs=*/{Bool(preserve_unit_iters)}, + /*attrs=*/{IntImm(DataType::Bool(), preserve_unit_iters)}, /*outputs=*/{new_block})); return new_block; } @@ -634,7 +634,7 @@ void TracedScheduleNode::Tensorize(const LoopRV& loop_rv, const ffi::String& int trace_->Append(/*inst=*/Instruction( /*kind=*/kind, /*inputs=*/{loop_rv}, - /*attrs=*/{intrin, Bool(preserve_unit_iters)}, + /*attrs=*/{intrin, IntImm(DataType::Bool(), 
preserve_unit_iters)}, /*outputs=*/{})); } @@ -645,7 +645,7 @@ void TracedScheduleNode::Tensorize(const SBlockRV& block_rv, const ffi::String& trace_->Append(/*inst=*/Instruction( /*kind=*/kind, /*inputs=*/{block_rv}, - /*attrs=*/{intrin, Bool(preserve_unit_iters)}, + /*attrs=*/{intrin, IntImm(DataType::Bool(), preserve_unit_iters)}, /*outputs=*/{})); } @@ -704,8 +704,8 @@ void TracedScheduleNode::TransformLayout(const SBlockRV& block_rv, int buffer_in /*kind=*/kind, /*inputs=*/{block_rv, index_map}, /*attrs=*/ - {Integer(buffer_index), Integer(buffer_index_type), pad_value, - Bool(assume_injective_transform)}, + {IntImm(DataType::Int(32), buffer_index), IntImm(DataType::Int(32), static_cast(buffer_index_type)), pad_value, + IntImm(DataType::Bool(), assume_injective_transform)}, /*outputs=*/{})); } @@ -728,7 +728,7 @@ void TracedScheduleNode::SetAxisSeparator(const SBlockRV& block_rv, int buffer_i trace_->Append(/*inst=*/Instruction( /*kind=*/kind, /*inputs=*/{block_rv}, - /*attrs=*/{Integer(buffer_index), Integer(buffer_index_type), axis_separators}, + /*attrs=*/{IntImm(DataType::Int(32), buffer_index), IntImm(DataType::Int(32), static_cast(buffer_index_type)), axis_separators}, /*outputs=*/{})); } @@ -762,7 +762,7 @@ void TracedScheduleNode::RollingBuffer(const SBlockRV& block_rv, int write_buffe trace_->Append(/*inst=*/Instruction( /*kind=*/kind, /*inputs=*/{block_rv}, - /*attrs=*/{Integer(write_buffer_index)}, + /*attrs=*/{IntImm(DataType::Int(32), write_buffer_index)}, /*outputs=*/{})); } @@ -796,7 +796,7 @@ void TracedScheduleNode::AnnotateBufferAccess(const SBlockRV& block_rv, int buff static const InstructionKind& kind = InstructionKind::Get("AnnotateBufferAccess"); trace_->Append(/*inst=*/Instruction( /*kind=*/kind, - /*inputs=*/{block_rv, Integer(buffer_index), Integer(buffer_index_type), index_map}, + /*inputs=*/{block_rv, IntImm(DataType::Int(32), buffer_index), IntImm(DataType::Int(32), static_cast(buffer_index_type)), index_map}, /*attrs=*/{}, 
/*outputs=*/{})); } From dffc9b982603fd727382cd2f732c05e5aaa92754 Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 28 May 2026 12:49:34 +0000 Subject: [PATCH 3/5] [REFACTOR][TOPI] Migrate topi container signatures + stray Integer/Bool callers - Change ffi::Array to ffi::Array in topi headers (strided_slice, reduction, transform, nn, utils, nn/group_norm, nn/instance_norm, nn/layer_norm, nn/rms_norm) and corresponding sources; update internal usages (.IntValue(), .defined(), ->value, ->dtype) to plain int64_t arithmetic - Migrate ShardingNode::sharding_dim and CalculateConstantBytes/ CalculateWorkspaceBytes signatures from Integer to int64_t - Migrate codegen_c.h constants_byte_alignment_ field from Integer to int64_t - Replace Downcast with Downcast in schedule primitives, metaschedule helpers, and misc transform passes - Convert remaining Integer/Bool local variables and function return types (clml codegen, make_packed_api, infer_layout_utils, etc.) to IntImm or plain int64_t - Update stale doc-comment "Type: Integer/Bool" annotations in function.h and tirx/function.h to "Type: IntImm" - Fix test files: nested_msg_test.cc (NestedMsg), ir_functor_test.cc, arith_simplify_test.cc --- include/tvm/ir/function.h | 4 +- include/tvm/relax/distributed/struct_info.h | 2 +- include/tvm/relax/nested_msg.h | 2 +- include/tvm/tirx/analysis.h | 5 +- include/tvm/tirx/function.h | 10 +- include/tvm/topi/detail/strided_slice.h | 43 +++----- include/tvm/topi/nn.h | 60 +++++----- include/tvm/topi/nn/group_norm.h | 6 +- include/tvm/topi/nn/instance_norm.h | 2 +- include/tvm/topi/nn/layer_norm.h | 2 +- include/tvm/topi/nn/rms_norm.h | 2 +- include/tvm/topi/reduction.h | 26 ++--- include/tvm/topi/transform.h | 103 +++++++++--------- include/tvm/topi/utils.h | 6 +- src/relax/backend/contrib/clml/codegen.cc | 6 +- src/relax/distributed/axis_group_graph.cc | 8 +- src/relax/ir/expr.cc | 5 +- src/relax/op/tensor/index.cc | 2 +- src/relax/transform/dataflow_inplace.cc | 4 +- 
src/relax/transform/infer_layout_utils.h | 4 +- .../reorder_permute_dims_after_concat.cc | 4 +- .../transform/split_call_tir_by_pattern.cc | 2 +- src/s_tir/meta_schedule/arg_info.cc | 2 +- .../meta_schedule/mutator/mutate_tile_size.cc | 2 +- .../postproc/rewrite_cooperative_fetch.cc | 2 +- .../meta_schedule/postproc/rewrite_layout.cc | 6 +- .../meta_schedule/postproc/verify_gpu_code.cc | 6 +- .../space_generator/space_generator.cc | 4 +- src/s_tir/schedule/instruction_traits.h | 8 +- .../primitive/annotate_buffer_access.cc | 6 +- .../schedule/primitive/block_annotate.cc | 16 +-- .../schedule/primitive/blockize_tensorize.cc | 20 ++-- src/s_tir/schedule/primitive/cache_index.cc | 4 +- .../schedule/primitive/cache_read_write.cc | 28 ++--- src/s_tir/schedule/primitive/compute_at.cc | 16 +-- .../primitive/layout_transformation.cc | 24 ++-- .../schedule/primitive/loop_transformation.cc | 32 +++--- src/s_tir/schedule/primitive/read_write_at.cc | 8 +- src/s_tir/schedule/primitive/reduction.cc | 4 +- .../primitive/reorder_block_iter_var.cc | 4 +- .../schedule/primitive/rolling_buffer.cc | 6 +- src/s_tir/schedule/primitive/sampling.cc | 12 +- .../transform/inject_software_pipeline.cc | 2 +- src/s_tir/transform/memhammer_coalesce.cc | 2 +- .../transform/memhammer_lower_auto_copy.cc | 4 +- src/s_tir/transform/memhammer_rewrite_rule.h | 4 +- src/target/cuda/codegen_cuda.cc | 14 +-- src/target/source/codegen_c.h | 4 +- src/target/target.cc | 2 +- src/target/target_kind.cc | 4 +- src/target/vulkan/codegen_spirv.cc | 2 +- src/tirx/ir/data_type_rewriter.cc | 2 +- .../transform/force_narrow_index_to_i32.cc | 2 +- src/tirx/transform/lower_tvm_builtin.cc | 2 +- src/tirx/transform/make_packed_api.cc | 2 +- src/tirx/transform/unroll_loop.cc | 4 +- src/topi/nn.cc | 14 +-- src/topi/reduction.cc | 2 +- src/topi/transform.cc | 18 +-- tests/cpp/arith_simplify_test.cc | 4 +- tests/cpp/ir_functor_test.cc | 4 +- tests/cpp/nested_msg_test.cc | 80 +++++++------- 62 files changed, 343 
insertions(+), 347 deletions(-) diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index a03233b6d076..b0b9d06b5954 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -89,7 +89,7 @@ namespace attr { /*! * \brief Indicates the special calling convention. * - * Type: Integer + * Type: IntImm * * \sa tvm::CallingConv */ @@ -131,7 +131,7 @@ constexpr const char* kGlobalSymbol = "global_symbol"; * and printer emits `s_tir=True` on the decorator. * Default (attr absent or False) is tirx semantics. * - * Type: Bool + * Type: IntImm (bool dtype) */ constexpr const char* kSTir = "s_tir"; diff --git a/include/tvm/relax/distributed/struct_info.h b/include/tvm/relax/distributed/struct_info.h index f663c9145091..81fdf0fb3ffc 100644 --- a/include/tvm/relax/distributed/struct_info.h +++ b/include/tvm/relax/distributed/struct_info.h @@ -71,7 +71,7 @@ class PlacementSpec : public ffi::ObjectRef { class ShardingNode : public PlacementSpecNode { public: /*! \brief The dimension of tensor we shard*/ - Integer sharding_dim; + int64_t sharding_dim; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; diff --git a/include/tvm/relax/nested_msg.h b/include/tvm/relax/nested_msg.h index 20495e00102b..4b11e9d2b043 100644 --- a/include/tvm/relax/nested_msg.h +++ b/include/tvm/relax/nested_msg.h @@ -157,7 +157,7 @@ class NestedMsg { } // delete the int constructor - // since NestedMsg(0) is ambiguous + // since NestedMsg(0) is ambiguous // 0 can be implicitly casted to nullptr_t explicit NestedMsg(int val) = delete; NestedMsg& operator=(int val) = delete; diff --git a/include/tvm/tirx/analysis.h b/include/tvm/tirx/analysis.h index 66378503b60f..1279455c8e2b 100644 --- a/include/tvm/tirx/analysis.h +++ b/include/tvm/tirx/analysis.h @@ -160,7 +160,7 @@ TVM_DLL size_t CalculateExprComplexity(const PrimExpr& expr); * \param func The TIR PrimFunc for which the constants size to be calculated * \param constant_byte_alignment The byte 
alignment required for each constant allocated */ -TVM_DLL size_t CalculateConstantBytes(const PrimFunc& func, const Integer& constant_byte_alignment); +TVM_DLL size_t CalculateConstantBytes(const PrimFunc& func, int64_t constant_byte_alignment); /*! * \brief Calculate the workspace size in bytes needed by the TIR allocates inside the TIR PrimFunc @@ -168,8 +168,7 @@ TVM_DLL size_t CalculateConstantBytes(const PrimFunc& func, const Integer& const * \param workspace_byte_alignment The byte alignment required for each tensor allocated in this * workspace */ -TVM_DLL size_t CalculateWorkspaceBytes(const PrimFunc& func, - const Integer& workspace_byte_alignment); +TVM_DLL size_t CalculateWorkspaceBytes(const PrimFunc& func, int64_t workspace_byte_alignment); /*! * \brief Verify if the given TIR is well-formed. The verification includes: diff --git a/include/tvm/tirx/function.h b/include/tvm/tirx/function.h index 45a8600a6ee4..38c78db872f9 100644 --- a/include/tvm/tirx/function.h +++ b/include/tvm/tirx/function.h @@ -310,7 +310,7 @@ constexpr const char* kKernelLaunchParams = "tirx.kernel_launch_params"; /*! * \brief Whether to set noalias rule on the function arguments. * - * Type: Integer + * Type: IntImm */ constexpr const char* kNoAlias = "tirx.noalias"; @@ -318,7 +318,7 @@ constexpr const char* kNoAlias = "tirx.noalias"; * \brief Mark the function as the entry function of * the final generated runtime module. * - * Type: Integer + * Type: IntImm * * \note There can only be one entry function per module. */ @@ -327,21 +327,21 @@ constexpr const char* kIsEntryFunc = "tirx.is_entry_func"; /*! * \brief Mark the function as the global function called from the host. * - * Type: Integer + * Type: IntImm */ constexpr const char* kIsGlobalFunc = "tirx.is_global_func"; /*! * \brief Mark the function as run on the host, mutually exclusive with kTarget. * - * Type: Integer + * Type: IntImm */ constexpr const char* kIsHostFunc = "tirx.is_host_func"; /*! 
* \brief Mark the function as scheduled, so the default schedule will pass will skip it. * - * Type: Integer + * Type: IntImm */ constexpr const char* kIsScheduled = "tirx.is_scheduled"; diff --git a/include/tvm/topi/detail/strided_slice.h b/include/tvm/topi/detail/strided_slice.h index e70b1542d4a4..b85908a88ba9 100644 --- a/include/tvm/topi/detail/strided_slice.h +++ b/include/tvm/topi/detail/strided_slice.h @@ -50,39 +50,30 @@ inline int64_t CanonicalizeIndex(int64_t index, int64_t extent, int64_t stride) } inline std::tuple, std::vector, std::vector> ConvertToVec( - const ffi::Array& begin, const ffi::Array& end, - const ffi::Array& strides, std::string slice_mode) { + const ffi::Array& begin, const ffi::Array& end, + const ffi::Array& strides, std::string slice_mode) { std::vector stride_vec(strides.size(), 1); if (slice_mode == "end") { for (size_t i = 0; i < strides.size(); ++i) { - TVM_FFI_ICHECK(strides[i].defined()); - stride_vec[i] = GetConstInt(strides[i]); + stride_vec[i] = strides[i]; } } const int64_t max_range = std::numeric_limits::max(); std::vector begin_vec; for (size_t i = 0; i < begin.size(); ++i) { - if (!begin[i].defined()) { - // value=None - begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range); - } else { - begin_vec.push_back(GetConstInt(begin[i])); - } + begin_vec.push_back(begin[i]); } std::vector end_vec; for (size_t i = 0; i < end.size(); ++i) { - // allow end to be None - if (!end[i].defined()) { - end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range); - } else if (slice_mode == "size") { - int64_t end_val = GetConstInt(end[i]); + if (slice_mode == "size") { + int64_t end_val = end[i]; if (end_val < 0) { end_vec.push_back(stride_vec[i] < 0 ? 
0 : max_range); } else { end_vec.push_back(begin_vec[i] + end_val); } } else { - end_vec.push_back(GetConstInt(end[i])); + end_vec.push_back(end[i]); } } return std::make_tuple(begin_vec, end_vec, stride_vec); @@ -91,17 +82,18 @@ inline std::tuple, std::vector, std::vector StridedSliceCanonicalizeBegin(const ffi::Array& ishape, const std::vector& begin, const std::vector& strides, - const ffi::Array& axes, + const ffi::Array& axes, DataType dtype, std::string slice_mode = "end") { ffi::Array begin_expr; for (size_t i = 0; i < axes.size(); ++i) { - if (ishape[axes[i].IntValue()]->IsInstance()) { - int64_t dim_i = GetConstInt(ishape[axes[i].IntValue()]); + int64_t ax = axes[i]; + if (ishape[ax]->IsInstance()) { + int64_t dim_i = GetConstInt(ishape[ax]); int64_t begin_i = CanonicalizeIndex(begin[i], dim_i, strides[i]); begin_expr.push_back(make_const(dtype, begin_i)); } else { - auto idim = ishape[axes[i].IntValue()]; + auto idim = ishape[ax]; auto b_expr = make_const(dtype, begin[i]); PrimExpr b = begin[i] < 0 ? 
b_expr + idim : b_expr; auto s = strides[i]; @@ -119,7 +111,7 @@ inline ffi::Array StridedSliceCanonicalizeBegin(const ffi::Array StridedSliceOutputShape( const ffi::Array& ishape, const std::vector& begin, const std::vector& end, const std::vector& strides, - const ffi::Array& axes, std::string slice_mode, + const ffi::Array& axes, std::string slice_mode, const ffi::Array& begin_canonicalized, bool use_any = false) { TVM_FFI_ICHECK(!use_any) << "StridedSliceOutputShape does not legacy use_any"; const size_t src_tensor_dim = ishape.size(); @@ -129,8 +121,9 @@ inline ffi::Array StridedSliceOutputShape( } for (size_t i = 0; i < axes.size(); ++i) { - if (ishape[axes[i].IntValue()]->IsInstance()) { - const int64_t dim_i = GetConstInt(ishape[axes[i].IntValue()]); + int64_t ax = axes[i]; + if (ishape[ax]->IsInstance()) { + const int64_t dim_i = GetConstInt(ishape[ax]); TVM_FFI_ICHECK(begin_canonicalized[i]->IsInstance()); int64_t begin_i = GetConstInt(begin_canonicalized[i]); int64_t end_i = CanonicalizeIndex(end[i], dim_i, strides[i]); @@ -139,9 +132,9 @@ inline ffi::Array StridedSliceOutputShape( static_cast((interval + std::abs(strides[i]) - 1) / std::abs(strides[i])); TVM_FFI_ICHECK(strides[i] < 0 ? 
(end_i <= begin_i) : (begin_i <= end_i)) << ": Input [Begin=" << begin[i] << ", End=" << end[i] << "] is invalid for axis=" << i; - out_shape.Set(axes[i].IntValue(), cast(out_shape[i].dtype(), PrimExpr(slice_size))); + out_shape.Set(ax, cast(out_shape[i].dtype(), PrimExpr(slice_size))); } else { - out_shape.Set(axes[i].IntValue(), tvm::tirx::Var("dim", out_shape[i]->dtype)); + out_shape.Set(ax, tvm::tirx::Var("dim", out_shape[i]->dtype)); } } diff --git a/include/tvm/topi/nn.h b/include/tvm/topi/nn.h index 226dd88511f8..dd8e03aeac5a 100644 --- a/include/tvm/topi/nn.h +++ b/include/tvm/topi/nn.h @@ -481,7 +481,7 @@ inline tvm::te::Tensor group_conv2d_ngchw(const tvm::te::Tensor& I, const tvm::t * \return A Tensor whose op member is the space_to_batch_nd operation */ inline tvm::te::Tensor space_to_batch_nd(const tvm::te::Tensor& data, - const tvm::ffi::Array& block_shape, + const tvm::ffi::Array& block_shape, const tvm::ffi::Array& pad_before, const tvm::ffi::Array& pad_after, PrimExpr pad_value = PrimExpr(), @@ -516,7 +516,7 @@ inline tvm::te::Tensor space_to_batch_nd(const tvm::te::Tensor& data, // infer shapes tvm::ffi::Array r_shape; - tvm::ffi::Array axis; + tvm::ffi::Array axis; tvm::ffi::Array o_shape; size_t num_block_dims = block_shape.size(); @@ -526,7 +526,7 @@ inline tvm::te::Tensor space_to_batch_nd(const tvm::te::Tensor& data, for (size_t i = 1; i <= num_block_dims; i++) { int padded_input = static_cast(GetConstInt(padded_shape[i])); - int block_size = static_cast(GetConstInt(block_shape[i - 1])); + int block_size = static_cast(block_shape[i - 1]); TVM_FFI_ICHECK_EQ((padded_input % block_size), 0) << "(" << i << ")th " @@ -534,26 +534,28 @@ inline tvm::te::Tensor space_to_batch_nd(const tvm::te::Tensor& data, << padded_input << ")" << " must be divisible by its block size (" << block_size << ")"; - r_shape.push_back(div(padded_shape[i], block_shape[i - 1])); - r_shape.push_back(block_shape[i - 1]); - block_shape_prod *= block_shape[i - 1]; - 
axis.push_back(IntImm(DataType::Int(32), r_shape.size() - 1)); // index of block_shape[i - 1] + PrimExpr bs = IntImm(DataType::Int(64), block_shape[i - 1]); + r_shape.push_back(div(padded_shape[i], bs)); + r_shape.push_back(bs); + block_shape_prod *= bs; + axis.push_back(static_cast(r_shape.size() - 1)); // index of block_shape[i - 1] } size_t n = axis.size(); axis.push_back(0); // batch is at index 0 // index of (padded_shape[i] / block_shape[i - 1]) in r_shape for (size_t i = 0; i < n; i++) { - axis.push_back(static_cast(GetConstInt(axis[i] - 1))); + axis.push_back(axis[i] - 1); } o_shape.push_back(tvm::PrimExpr(batch) * block_shape_prod); for (size_t i = 1; i <= num_block_dims; i++) { - o_shape.push_back(div(padded_shape[i], block_shape[i - 1])); + PrimExpr bs = IntImm(DataType::Int(64), block_shape[i - 1]); + o_shape.push_back(div(padded_shape[i], bs)); } // append remaining shape for (size_t i = num_block_dims + 1; i < input_shape.size(); i++) { r_shape.push_back(input_shape[i]); - axis.push_back(IntImm(DataType::Int(32), r_shape.size() - 1)); // index of remaining shape in r_shape + axis.push_back(static_cast(r_shape.size() - 1)); // index of remaining shape in r_shape o_shape.push_back(input_shape[i]); } @@ -577,7 +579,7 @@ inline tvm::te::Tensor space_to_batch_nd(const tvm::te::Tensor& data, * \return A Tensor whose op member is the batch_to_space_nd operation */ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data, - const tvm::ffi::Array& block_shape, + const tvm::ffi::Array& block_shape, const tvm::ffi::Array& crop_begin_list, const tvm::ffi::Array& crop_end_list, std::string name = "batch_to_space_nd", @@ -585,23 +587,24 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data, // Construct shapes for reshape and transpose operation ffi::Array in_shape = data->shape; ffi::Array r_shape; - ffi::Array axis; + ffi::Array axis; size_t num_block_dims = block_shape.size(); size_t num_input_dims = in_shape.size(); tvm::PrimExpr 
block_shape_prod(1); int batch = static_cast(GetConstInt(in_shape[0])); for (size_t i = 0; i < num_block_dims; i++) { - r_shape.push_back(block_shape[i]); - block_shape_prod *= block_shape[i]; + PrimExpr bs = IntImm(DataType::Int(64), block_shape[i]); + r_shape.push_back(bs); + block_shape_prod *= bs; } - axis.push_back(IntImm(DataType::Int(32), r_shape.size())); // axis of (batch / block_shape_prod) + axis.push_back(static_cast(r_shape.size())); // axis of (batch / block_shape_prod) r_shape.push_back(batch / block_shape_prod); for (size_t i = 1; i < num_input_dims; i++) { - axis.push_back(IntImm(DataType::Int(32), r_shape.size())); // axis of in_shape[i] + axis.push_back(static_cast(r_shape.size())); // axis of in_shape[i] if (axis.size() < (num_block_dims + num_input_dims)) { - axis.push_back(IntImm(DataType::Int(32), r_shape.size() - (num_block_dims + 1))); // axis of block_shape[i] + axis.push_back(static_cast(r_shape.size() - (num_block_dims + 1))); // axis of block_shape[i] } r_shape.push_back(in_shape[i]); } @@ -609,7 +612,8 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data, ffi::Array r_p_shape; r_p_shape.push_back(batch / block_shape_prod); for (size_t i = 1; i <= num_block_dims; i++) { - r_p_shape.push_back(in_shape[i] * block_shape[i - 1]); + PrimExpr bs = IntImm(DataType::Int(64), block_shape[i - 1]); + r_p_shape.push_back(in_shape[i] * bs); } for (size_t i = num_block_dims + 1; i < num_input_dims; i++) { r_p_shape.push_back(in_shape[i]); @@ -621,14 +625,14 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data, out = reshape(out, r_p_shape); // Crop the start and end of dimensions of out - ffi::Array begin_idx, end_idx, strides; + ffi::Array begin_idx, end_idx, strides; for (size_t i = 0; i < r_p_shape.size(); ++i) { - strides.push_back(IntImm(DataType::Int(32), 1)); + strides.push_back(int64_t(1)); if (i > 0 && i <= num_block_dims) { // prepare begin and end index for spatial dimensions - int begin_i = 
static_cast(GetConstInt(crop_begin_list[i - 1])); - int end_i = static_cast(GetConstInt(crop_end_list[i - 1])); - int out_i = static_cast(GetConstInt(r_p_shape[i])); + int64_t begin_i = GetConstInt(crop_begin_list[i - 1]); + int64_t end_i = GetConstInt(crop_end_list[i - 1]); + int64_t out_i = GetConstInt(r_p_shape[i]); TVM_FFI_ICHECK_GT(out_i, (begin_i + end_i)) << "Incorrect crop sizes for (" << i << ")th dim, can not crop more than" << " output size" << out_i << " vs " << (begin_i + end_i); @@ -636,8 +640,8 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data, end_idx.push_back(out_i - end_i); } else { // ignore the batch and remaining dimension - begin_idx.push_back(IntImm(DataType::Int(32), 0)); - end_idx.push_back(static_cast(GetConstInt(r_p_shape[i]))); + begin_idx.push_back(int64_t(0)); + end_idx.push_back(GetConstInt(r_p_shape[i])); } } @@ -710,10 +714,10 @@ inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const T tvm::tirx::make_const(predictions->dtype, 0)); }, name, tag); - return topi::divide(topi::sum(T, tvm::ffi::Array(nullptr)), - topi::sum(W, tvm::ffi::Array(nullptr))); + return topi::divide(topi::sum(T, tvm::ffi::Array(nullptr)), + topi::sum(W, tvm::ffi::Array(nullptr))); } else if (reduction == "sum") { - return topi::sum(T, tvm::ffi::Array(nullptr)); + return topi::sum(T, tvm::ffi::Array(nullptr)); } else { // reduction == "none" return T; } diff --git a/include/tvm/topi/nn/group_norm.h b/include/tvm/topi/nn/group_norm.h index b0e71c7cf777..1f1ac91867af 100644 --- a/include/tvm/topi/nn/group_norm.h +++ b/include/tvm/topi/nn/group_norm.h @@ -37,7 +37,7 @@ namespace nn { using namespace tvm::te; inline Tensor group_norm(const Tensor& data, const Tensor& gamma, const Tensor& beta, - int num_groups, int channel_axis, const ffi::Array& axes, + int num_groups, int channel_axis, const ffi::Array& axes, double epsilon, std::string name = "T_group_norm", std::string tag = kInjective) { const auto& data_type = 
data->dtype; @@ -50,7 +50,7 @@ inline Tensor group_norm(const Tensor& data, const Tensor& gamma, const Tensor& bool is_float16 = data_type == DataType::Float(16); // reshape data C -> G, C/G int ndim = data->shape.size(); - channel_axis = GetRealAxis(static_cast(ndim), ffi::Array({channel_axis}))[0]; + channel_axis = GetRealAxis(static_cast(ndim), ffi::Array({channel_axis}))[0]; auto shape = data->shape; auto group_size = floordiv(shape[channel_axis], num_groups); @@ -82,7 +82,7 @@ inline Tensor group_norm(const Tensor& data, const Tensor& gamma, const Tensor& // get the new axes to normalize after reshape std::vector new_axes{channel_axis + 1}; for (auto axis : axes) { - int new_axis = GetRealAxis(static_cast(ndim), ffi::Array({axis}))[0]; + int new_axis = GetRealAxis(static_cast(ndim), ffi::Array({axis}))[0]; if (new_axis < channel_axis) { new_axes.push_back(new_axis); } else if (new_axis > channel_axis) { diff --git a/include/tvm/topi/nn/instance_norm.h b/include/tvm/topi/nn/instance_norm.h index 66baf3e2f5c1..48fcf23904d5 100644 --- a/include/tvm/topi/nn/instance_norm.h +++ b/include/tvm/topi/nn/instance_norm.h @@ -51,7 +51,7 @@ using namespace tvm::te; * \return The normalized tensor, with the same shape as data. */ inline Tensor instance_norm(const Tensor& data, const Tensor& gamma, const Tensor& beta, - int channel_axis, const ffi::Array& axis, double epsilon, + int channel_axis, const ffi::Array& axis, double epsilon, std::string name = "T_instance_norm", std::string tag = kInjective) { const auto& data_type = data->dtype; const auto& gamma_type = gamma.defined() ? gamma->dtype : data_type; diff --git a/include/tvm/topi/nn/layer_norm.h b/include/tvm/topi/nn/layer_norm.h index 6c3409aca3a9..873a5fd1b2d2 100644 --- a/include/tvm/topi/nn/layer_norm.h +++ b/include/tvm/topi/nn/layer_norm.h @@ -49,7 +49,7 @@ using namespace tvm::te; * \return The normalized tensor, with the same shape as data. 
*/ inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const Tensor& beta, - const ffi::Array& axis, double epsilon, + const ffi::Array& axis, double epsilon, std::string name = "T_layer_norm", std::string tag = kInjective) { const auto& data_type = data->dtype; const auto& gamma_type = gamma.defined() ? gamma->dtype : data_type; diff --git a/include/tvm/topi/nn/rms_norm.h b/include/tvm/topi/nn/rms_norm.h index 4f6292d968ac..ac36e5badd41 100644 --- a/include/tvm/topi/nn/rms_norm.h +++ b/include/tvm/topi/nn/rms_norm.h @@ -47,7 +47,7 @@ using namespace tvm::te; * \param tag The tag to mark the operation. * \return The normalized tensor, with the same shape as data. */ -inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const ffi::Array& axis, +inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const ffi::Array& axis, double epsilon, std::string name = "T_rms_norm", std::string tag = kInjective) { const auto& data_type = data->dtype; diff --git a/include/tvm/topi/reduction.h b/include/tvm/topi/reduction.h index 73c5fc31ce77..e3f5444efe38 100644 --- a/include/tvm/topi/reduction.h +++ b/include/tvm/topi/reduction.h @@ -62,7 +62,7 @@ using FCommReduce = std::function( * If any input element is negative, it will be treated as an offset from the * last dimension (same as python indexing rules). */ -inline std::vector GetRealAxis(int ndim, const ffi::Optional>& axis) { +inline std::vector GetRealAxis(int ndim, const ffi::Optional>& axis) { std::vector real_axis; if (!axis.has_value()) { for (int i = 0; i < ndim; ++i) { @@ -70,8 +70,8 @@ inline std::vector GetRealAxis(int ndim, const ffi::Optionalvalue; + for (int64_t elem : axis.value()) { + int64_t val = elem; if (val < 0) { val += ndim; } @@ -181,7 +181,7 @@ inline Tensor DoCommReduce(const Tensor& data, FReduce func, * * \return The result tensor. 
*/ -inline Tensor CommReduce(const Tensor& data, const ffi::Optional>& axis, +inline Tensor CommReduce(const Tensor& data, const ffi::Optional>& axis, FReduce func, bool keepdims, bool atleast1d) { auto ndim = data->shape.size(); TVM_FFI_ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor"; @@ -204,7 +204,7 @@ inline Tensor CommReduce(const Tensor& data, const ffi::Optional>& axis, +inline Tensor CommReduceIdx(const Tensor& data, const ffi::Optional>& axis, FCommReduce func, bool keepdims, bool atleast1d) { auto ndim = data->shape.size(); TVM_FFI_ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor"; @@ -325,7 +325,7 @@ inline PrimExpr ProdOp(PrimExpr source, ffi::Array axis, ffi::Array>& axis, +inline Tensor sum(const Tensor& data, const ffi::Optional>& axis, bool keepdims = false, bool atleast1d = false) { if (data->dtype.is_bool()) { return CommReduce(data, axis, tvm::any, keepdims, atleast1d); @@ -382,7 +382,7 @@ inline Tensor collapse_sum(const Tensor& data, ffi::Array target_shape * * \return A Tensor whose op member is the all operation */ -inline Tensor all(const Tensor& data, const ffi::Optional>& axis, +inline Tensor all(const Tensor& data, const ffi::Optional>& axis, bool keepdims = false, bool atleast1d = false) { return CommReduce(data, axis, tvm::all, keepdims, atleast1d); } @@ -401,7 +401,7 @@ inline Tensor all(const Tensor& data, const ffi::Optional>& * * \return A Tensor whose op member is the all operation */ -inline Tensor any(const Tensor& data, const ffi::Optional>& axis, +inline Tensor any(const Tensor& data, const ffi::Optional>& axis, bool keepdims = false, bool atleast1d = false) { return CommReduce(data, axis, tvm::any, keepdims, atleast1d); } @@ -420,7 +420,7 @@ inline Tensor any(const Tensor& data, const ffi::Optional>& * * \return A Tensor whose op member is the min operation */ -inline Tensor min(const Tensor& data, const ffi::Optional>& axis, +inline Tensor min(const Tensor& data, const ffi::Optional>& axis, bool keepdims = false, bool 
atleast1d = false) { return CommReduce(data, axis, MinOp, keepdims, atleast1d); } @@ -439,7 +439,7 @@ inline Tensor min(const Tensor& data, const ffi::Optional>& * * \return A Tensor whose op member is the max operation */ -inline Tensor max(const Tensor& data, const ffi::Optional>& axis, +inline Tensor max(const Tensor& data, const ffi::Optional>& axis, bool keepdims = false, bool atleast1d = false) { return CommReduce(data, axis, MaxOp, keepdims, atleast1d); } @@ -499,7 +499,7 @@ inline FCommReduce MakeArgminReducer(bool select_last_index = false) { * * \return A Tensor whose op member is the argmin operation */ -inline Tensor argmin(const Tensor& data, const ffi::Optional>& axis, +inline Tensor argmin(const Tensor& data, const ffi::Optional>& axis, bool keepdims = false, bool atleast1d = false, bool select_last_index = false) { auto reducer = MakeArgminReducer(select_last_index); @@ -560,7 +560,7 @@ inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) { * appears multiple times, else select the first index. 
* \return A Tensor whose op member is the argmax operation */ -inline Tensor argmax(const Tensor& data, const ffi::Optional>& axis, +inline Tensor argmax(const Tensor& data, const ffi::Optional>& axis, bool keepdims = false, bool atleast1d = false, bool select_last_index = false) { auto reducer = MakeArgmaxReducer(select_last_index); @@ -580,7 +580,7 @@ inline Tensor argmax(const Tensor& data, const ffi::Optional * * \return A Tensor whose op member is the prod operation */ -inline Tensor prod(const Tensor& data, const ffi::Optional>& axis, +inline Tensor prod(const Tensor& data, const ffi::Optional>& axis, bool keepdims = false, bool atleast1d = false) { return CommReduce(data, axis, ProdOp, keepdims, atleast1d); } diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index dda18baa15fb..3c458435b37b 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -73,8 +73,8 @@ using namespace topi::detail; * * \return A Tensor whose op member is the sliding_window operation */ -inline Tensor sliding_window(const Tensor& x, int axis, ffi::Array window_shape, - ffi::Array strides, std::string name = "T_sliding_window", +inline Tensor sliding_window(const Tensor& x, int axis, ffi::Array window_shape, + ffi::Array strides, std::string name = "T_sliding_window", std::string tag = "") { TVM_FFI_ICHECK_GE(axis, 0); auto _axis = size_t(axis); @@ -98,16 +98,16 @@ inline Tensor sliding_window(const Tensor& x, int axis, ffi::Array wind // Length of the shape along this dimension. auto dim_len = x->shape[_axis + i]; // Length of the window along this dimension. - auto window_len = window_shape[i]; + PrimExpr window_len = IntImm(DataType::Int(64), window_shape[i]); // Strides along this dimension. - auto stride = strides[i]; + PrimExpr stride = IntImm(DataType::Int(64), strides[i]); new_shape.push_back(floordiv(dim_len - (window_len - 1) + stride - 1, stride)); } // Dimensions comprising the window. 
for (size_t i = 0; i < window_shape.size(); ++i) { - new_shape.push_back(window_shape[i]); + new_shape.push_back(IntImm(DataType::Int(64), window_shape[i])); } TVM_FFI_ICHECK(new_shape.size() == _axis + 2 * window_shape.size()); @@ -129,7 +129,7 @@ inline Tensor sliding_window(const Tensor& x, int axis, ffi::Array wind // Which index within the window we are indexing. auto idx_within_window = indices[_axis + window_shape.size() + i]; // Stride value for this dimension. - auto stride = strides[i]; + PrimExpr stride = IntImm(DataType::Int(64), strides[i]); idx.push_back(window_idx * stride + idx_within_window); } @@ -202,9 +202,9 @@ inline Tensor expand_dims(const Tensor& x, int axis, int num_newaxis = 1, * * \return A Tensor whose op member is the transpose operation */ -inline Tensor transpose(const Tensor& x, ffi::Optional> opt_axes, +inline Tensor transpose(const Tensor& x, ffi::Optional> opt_axes, std::string name = "T_transpose", std::string tag = kInjective) { - ffi::Array axes = opt_axes.value_or({}); + ffi::Array axes = opt_axes.value_or({}); if (axes.size() == 0) { for (int i = static_cast(x->shape.size()) - 1; i >= 0; --i) { axes.push_back(i); @@ -213,7 +213,7 @@ inline Tensor transpose(const Tensor& x, ffi::Optional> opt_ ffi::Array new_shape; for (size_t i = 0; i < axes.size(); ++i) { - int axis = static_cast(axes[i]->value); + int axis = static_cast(axes[i]); int new_axis = axis; if (axis < 0) { new_axis = static_cast(x->shape.size()) + axis; @@ -225,7 +225,7 @@ inline Tensor transpose(const Tensor& x, ffi::Optional> opt_ for (size_t j = 0; j < axes.size(); ++j) { if (i != j) { - TVM_FFI_ICHECK(new_axis != static_cast(axes[j]->value)) + TVM_FFI_ICHECK(new_axis != static_cast(axes[j])) << "repeated axis in transpose"; } } @@ -240,7 +240,7 @@ inline Tensor transpose(const Tensor& x, ffi::Optional> opt_ idx.push_back(1); } for (size_t i = 0; i < axes.size(); ++i) { - int axis = static_cast(axes[i]->value); + int axis = static_cast(axes[i]); idx[axis] = 
indices[i]; } return x(idx); @@ -412,7 +412,7 @@ inline Tensor unravel_index(const Tensor& x, const Tensor& shape, std::string na * * \return A Tensor whose op member is the squeeze operation */ -inline Tensor squeeze(const Tensor& x, ffi::Optional> opt_axes, +inline Tensor squeeze(const Tensor& x, ffi::Optional> opt_axes, bool atleast1d = false, std::string name = "T_squeeze", std::string tag = kInjective) { auto ndim = x->shape.size(); @@ -424,9 +424,9 @@ inline Tensor squeeze(const Tensor& x, ffi::Optional> opt_ax } } } else { - ffi::Array axis = *std::move(opt_axes); + ffi::Array axis = *std::move(opt_axes); for (size_t i = 0; i < axis.size(); ++i) { - int64_t val = axis[i]->value; + int64_t val = axis[i]; if (val < 0) { val += static_cast(x->shape.size()); } @@ -715,7 +715,7 @@ inline PrimExpr GetLength(PrimExpr begin, PrimExpr end, PrimExpr stride, PrimExp */ inline te::Tensor dynamic_strided_slice_with_axes( const te::Tensor& x, const ffi::Array& begin, const ffi::Array& end, - const ffi::Array& strides, const ffi::Array& axes, + const ffi::Array& strides, const ffi::Array& axes, bool assume_inbound = true, std::string name = "T_dynamic_strided_slice_with_axes", std::string tag = kInjective) { const size_t src_tensor_dim = x->shape.size(); @@ -725,7 +725,7 @@ inline te::Tensor dynamic_strided_slice_with_axes( TVM_FFI_ICHECK_LE(begin.size(), src_tensor_dim); for (const auto& axis_imm : axes) { - int axis = axis_imm->value; + int axis = static_cast(axis_imm); TVM_FFI_ICHECK_LT(axis, src_tensor_dim); } @@ -733,7 +733,7 @@ inline te::Tensor dynamic_strided_slice_with_axes( ffi::Array out_shape = x->shape; for (size_t i = 0; i < begin.size(); i++) { - int axis = axes[i]->value; + int axis = static_cast(axes[i]); PrimExpr new_shape = analyzer.Simplify(GetLength(begin[i], end[i], strides[i], out_shape[axis], assume_inbound)); out_shape.Set(axis, new_shape); @@ -746,7 +746,7 @@ inline te::Tensor dynamic_strided_slice_with_axes( indices.Map([](const auto& var) -> 
PrimExpr { return var; }); for (size_t i = 0; i < begin.size(); i++) { - int axis = axes[i]->value; + int axis = static_cast(axes[i]); PrimExpr new_index = indices[axis] * strides[i] + begin[i]; real_indices.Set(axis, new_index); } @@ -866,17 +866,17 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b * \return The output shape of strided_slice using the arguments above */ inline ffi::Array StridedSliceOutputShape(const ffi::Array& ishape, - const ffi::Array& begin, - const ffi::Array& end, - const ffi::Array& strides, - const ffi::Array& axes, + const ffi::Array& begin, + const ffi::Array& end, + const ffi::Array& strides, + const ffi::Array& axes, const std::string& slice_mode) { TVM_FFI_ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size()); std::vector begin_vec, end_vec, strides_vec; std::tie(begin_vec, end_vec, strides_vec) = ConvertToVec(begin, end, strides, slice_mode); auto begin_canonicalized = StridedSliceCanonicalizeBegin(ishape, begin_vec, strides_vec, axes, - begin[0]->dtype, slice_mode); + DataType::Int(64), slice_mode); return StridedSliceOutputShape(ishape, begin_vec, end_vec, strides_vec, axes, slice_mode, begin_canonicalized, true); } @@ -897,10 +897,10 @@ inline ffi::Array StridedSliceOutputShape(const ffi::Array& * * \return A Tensor whose op member is the sstrided_slice operation */ -inline Tensor strided_slice_with_axes(const Tensor& x, const ffi::Array& begin, - const ffi::Array& end, - const ffi::Array& strides, - const ffi::Array& axes, +inline Tensor strided_slice_with_axes(const Tensor& x, const ffi::Array& begin, + const ffi::Array& end, + const ffi::Array& strides, + const ffi::Array& axes, std::string slice_mode = "end", std::string name = "T_strided_slice_with_axes", std::string tag = kInjective) { @@ -910,23 +910,24 @@ inline Tensor strided_slice_with_axes(const Tensor& x, const ffi::Array axes.size() == strides.size()); // Normalize negative axes - 
ffi::Array normalized_axes; + ffi::Array normalized_axes; for (size_t i = 0; i < axes.size(); ++i) { - int64_t axis = axes[i].IntValue(); + int64_t axis = axes[i]; if (axis < 0) { axis += src_tensor_dim; } TVM_FFI_ICHECK(axis >= 0 && axis < src_tensor_dim) - << "Axis " << axes[i].IntValue() << " is out of bounds for tensor with " << src_tensor_dim + << "Axis " << axes[i] << " is out of bounds for tensor with " << src_tensor_dim << " dimensions"; - normalized_axes.push_back(IntImm(DataType::Int(32), axis)); + normalized_axes.push_back(axis); } std::vector begin_vec, end_vec, strides_vec; std::tie(begin_vec, end_vec, strides_vec) = ConvertToVec(begin, end, strides, slice_mode); + DataType index_dtype = begin.size() > 0 ? DataType::Int(64) : DataType::Int(64); auto begin_expr = StridedSliceCanonicalizeBegin(x->shape, begin_vec, strides_vec, normalized_axes, - begin[0]->dtype, slice_mode); + index_dtype, slice_mode); auto out_shape = StridedSliceOutputShape(x->shape, begin_vec, end_vec, strides_vec, normalized_axes, slice_mode, begin_expr); @@ -936,9 +937,10 @@ inline Tensor strided_slice_with_axes(const Tensor& x, const ffi::Array ffi::Array real_indices; for (size_t i = 0; i < out_shape.size(); ++i) real_indices.push_back(indices[i]); for (size_t i = 0; i < normalized_axes.size(); ++i) { - auto stride = make_const(strides[i].dtype(), strides_vec[i]); - PrimExpr ind = indices[normalized_axes[i].IntValue()] * stride + begin_expr[i]; - real_indices.Set(normalized_axes[i].IntValue(), ind); + int64_t ax = normalized_axes[i]; + auto stride = make_const(DataType::Int(64), strides_vec[i]); + PrimExpr ind = indices[ax] * stride + begin_expr[i]; + real_indices.Set(ax, ind); } return x(real_indices); }, @@ -959,30 +961,29 @@ inline Tensor strided_slice_with_axes(const Tensor& x, const ffi::Array * * \return A Tensor whose op member is the strided_slice operation */ -inline Tensor strided_slice(const Tensor& x, const ffi::Array& begin, - const ffi::Array& end, const ffi::Array& 
strides, +inline Tensor strided_slice(const Tensor& x, const ffi::Array& begin, + const ffi::Array& end, const ffi::Array& strides, std::string slice_mode = "end", std::string name = "T_strided_slice", std::string tag = kInjective) { size_t src_tensor_dim = static_cast(x->shape.size()); - ffi::Array axes; + ffi::Array axes; for (size_t i = 0; i < src_tensor_dim; ++i) axes.push_back(i); - ffi::Array begin_full(begin); - ffi::Array end_full(end); - ffi::Array strides_full(strides); + ffi::Array begin_full(begin); + ffi::Array end_full(end); + ffi::Array strides_full(strides); - DataType index_dtype = begin.size() > 0 ? begin[0]->dtype : DataType::Int(64); - const IntImm one = IntImm(index_dtype, 1); - const IntImm zero = IntImm(index_dtype, 0); - const IntImm max_range = Downcast(max_value(index_dtype)); + constexpr int64_t one = 1; + constexpr int64_t zero = 0; + const int64_t max_range = std::numeric_limits::max(); for (size_t i = strides.size(); i < src_tensor_dim; ++i) { strides_full.push_back(one); } for (size_t i = begin.size(); i < src_tensor_dim; ++i) { - begin_full.push_back(GetConstInt(strides_full[i]) > 0 ? zero : max_range); + begin_full.push_back(strides_full[i] > 0 ? zero : max_range); } for (size_t i = end.size(); i < src_tensor_dim; ++i) { - end_full.push_back(GetConstInt(strides_full[i]) < 0 ? zero : max_range); + end_full.push_back(strides_full[i] < 0 ? 
zero : max_range); } return strided_slice_with_axes(x, begin_full, end_full, strides_full, axes, slice_mode, name, @@ -1414,7 +1415,7 @@ inline Tensor repeat(const Tensor& x, int repeats, int axis, std::string name = * * \return A Tensor whose op member is the tile operation */ -inline Tensor tile(const Tensor& x, ffi::Array reps, std::string name = "T_tile", +inline Tensor tile(const Tensor& x, ffi::Array reps, std::string name = "T_tile", std::string tag = kBroadcast) { size_t ndim = x->shape.size(); size_t rdim = reps.size(); @@ -1425,16 +1426,16 @@ inline Tensor tile(const Tensor& x, ffi::Array reps, std::string name = if (ndim == rdim) { for (size_t i = 0; i < ndim; ++i) { data_shape.push_back(x->shape[i]); - reps_shape.push_back(reps[i]); + reps_shape.push_back(IntImm(DataType::Int(64), reps[i])); } } else if (ndim > rdim) { for (size_t i = 0; i < ndim; ++i) data_shape.push_back(x->shape[i]); for (size_t i = 0; i < (ndim - rdim); ++i) reps_shape.push_back(1); - for (size_t i = 0; i < rdim; ++i) reps_shape.push_back(reps[i]); + for (size_t i = 0; i < rdim; ++i) reps_shape.push_back(IntImm(DataType::Int(64), reps[i])); } else { for (size_t i = 0; i < (rdim - ndim); ++i) data_shape.push_back(1); for (size_t i = 0; i < ndim; ++i) data_shape.push_back(x->shape[i]); - for (size_t i = 0; i < rdim; ++i) reps_shape.push_back(reps[i]); + for (size_t i = 0; i < rdim; ++i) reps_shape.push_back(IntImm(DataType::Int(64), reps[i])); } for (size_t i = 0; i < tdim; ++i) new_shape.push_back(data_shape[i] * reps_shape[i]); diff --git a/include/tvm/topi/utils.h b/include/tvm/topi/utils.h index 41a2cce0e4f9..33ddaaf6533c 100644 --- a/include/tvm/topi/utils.h +++ b/include/tvm/topi/utils.h @@ -33,16 +33,16 @@ namespace topi { using namespace tvm::runtime; /*! 
\brief Canonicalize an argument that may be ffi::Array or int to ffi::Array */ -inline ffi::Optional> ArrayOrInt(AnyView arg) { +inline ffi::Optional> ArrayOrInt(AnyView arg) { if (arg == nullptr) { return std::nullopt; } if (auto opt_int = arg.try_cast()) { - ffi::Array result; + ffi::Array result; result.push_back(opt_int.value()); return result; } else { - return arg.cast>(); + return arg.cast>(); } } } // namespace topi diff --git a/src/relax/backend/contrib/clml/codegen.cc b/src/relax/backend/contrib/clml/codegen.cc index 94083c8b8f18..14c8230d0b81 100644 --- a/src/relax/backend/contrib/clml/codegen.cc +++ b/src/relax/backend/contrib/clml/codegen.cc @@ -42,7 +42,7 @@ namespace contrib { /*! \brief Attributes to store the compiler options for OpenCLML. */ struct OpenCLMLCompilerConfigNode : public ffi::Object { - Integer clml_version; + IntImm clml_version; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -269,7 +269,7 @@ class OpenCLMLJSONSerializer : public JSONSerializer { if (!cfg.defined()) { cfg = transform::PassConfigWithDefaults(); } - node->SetAttr("clml_version", static_cast(cfg.value()->clml_version.IntValue())); + node->SetAttr("clml_version", static_cast(cfg.value()->clml_version->value)); } private: @@ -332,7 +332,7 @@ inline constexpr bool IsOpenCLMLRuntimeEnabled() { * \brief Get OpenCLML version that TVM is built against. * \return The OpenCLML SDK version. 
*/ -Integer GetOpenCLMLVersion() { +IntImm GetOpenCLMLVersion() { #if TVM_GRAPH_EXECUTOR_CLML return IntImm(DataType::Int(32), TVM_CLML_VERSION); #else diff --git a/src/relax/distributed/axis_group_graph.cc b/src/relax/distributed/axis_group_graph.cc index 961c074d466e..c805ea6a5c7f 100644 --- a/src/relax/distributed/axis_group_graph.cc +++ b/src/relax/distributed/axis_group_graph.cc @@ -181,8 +181,8 @@ void BuildAxisGraphReduce(const Var& output_var, const Call& call, int ndim = GetTensorStructInfo(input_tensor)->ndim; std::unordered_set normalized_axes; - for (const Integer& i : axes) { - int val = i->value; + for (int64_t i : axes) { + int val = static_cast(i); TVM_FFI_ICHECK(val < ndim && val >= -ndim); if (val < 0) { val = ndim + val; @@ -289,8 +289,8 @@ void BuildAxisGraphPermuteDims(const Var& output_var, const Call& call, int ndim = GetTensorStructInfo(input_tensor)->ndim; std::vector normalized_axes; if (attrs->axes.defined()) { - for (const Integer& i : attrs->axes.value()) { - int val = i->value; + for (int64_t i : attrs->axes.value()) { + int val = static_cast(i); TVM_FFI_ICHECK(val < ndim && val >= -ndim); if (val < 0) { val = ndim + val; diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index 5c2419209b42..cec5ae65fbc2 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -231,15 +231,14 @@ TupleGetItem::TupleGetItem(Expr tuple, int index, Span span) { TupleGetItem WithFields(TupleGetItem tuple_get_item, ffi::Optional opt_tuple, ffi::Optional opt_index, ffi::Optional opt_span) { Expr tuple = opt_tuple.value_or(tuple_get_item->tuple); - Integer index = opt_index.value_or(tuple_get_item->index); + int64_t index = opt_index.value_or(tuple_get_item->index); Span span = opt_span.value_or(tuple_get_item->span); bool unchanged = tuple.same_as(tuple_get_item->tuple) && (index == tuple_get_item->index) && span.same_as(tuple_get_item->span); if (!unchanged) { TupleGetItemNode* cow_tuple_get_item_node = tuple_get_item.CopyOnWrite(); - 
cow_tuple_get_item_node->tuple = tuple; - cow_tuple_get_item_node->index = index.IntValue(); + cow_tuple_get_item_node->index = static_cast(index); cow_tuple_get_item_node->span = span; } return tuple_get_item; diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc index 6b02ca050bea..79bedfdc485c 100644 --- a/src/relax/op/tensor/index.cc +++ b/src/relax/op/tensor/index.cc @@ -474,7 +474,7 @@ InferLayoutOutput InferLayoutStridedSlice( } return InferLayoutOutput({existing_layout}, {existing_layout}, call->attrs, - {{1, relax::Tuple(new_axes)}}); + {{IntImm(DataType::Int(32), 1), relax::Tuple(new_axes)}}); } TVM_REGISTER_OP("relax.strided_slice") diff --git a/src/relax/transform/dataflow_inplace.cc b/src/relax/transform/dataflow_inplace.cc index 9777638d79a8..8072ee5d146f 100644 --- a/src/relax/transform/dataflow_inplace.cc +++ b/src/relax/transform/dataflow_inplace.cc @@ -981,7 +981,7 @@ ffi::Array DataflowAliasAnalysis(const DataflowBlock& block, auto alias_sets = res.first; auto tuple_map = res.second; ffi::Map> new_alias_sets; - ffi::Map>> new_tuple_map; + ffi::Map>> new_tuple_map; for (auto kv : alias_sets) { ffi::Array aliases; for (auto alias : kv.second) { @@ -998,7 +998,7 @@ ffi::Array DataflowAliasAnalysis(const DataflowBlock& block, } elem_aliases.push_back(dim_aliases); } - new_tuple_map.Set(kv.first, elem_aliases); + new_tuple_map.Set(IntImm(DataType::Int(32), kv.first), elem_aliases); } return {new_alias_sets, new_tuple_map}; } diff --git a/src/relax/transform/infer_layout_utils.h b/src/relax/transform/infer_layout_utils.h index 60bb3db63a38..724464a945c9 100644 --- a/src/relax/transform/infer_layout_utils.h +++ b/src/relax/transform/infer_layout_utils.h @@ -106,7 +106,7 @@ class InferLayoutOutputNode : public ffi::Object { ffi::Array input_layouts; ffi::Array output_layouts; Attrs new_attrs; - ffi::Map new_args; + ffi::Map new_args; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -124,7 +124,7 @@ class 
InferLayoutOutputNode : public ffi::Object { class InferLayoutOutput : public ffi::ObjectRef { public: explicit InferLayoutOutput(ffi::Array input_layouts, ffi::Array output_layouts, - Attrs new_attrs, ffi::Map new_args = {}) { + Attrs new_attrs, ffi::Map new_args = {}) { auto n = ffi::make_object(); n->input_layouts = std::move(input_layouts); n->output_layouts = std::move(output_layouts); diff --git a/src/relax/transform/reorder_permute_dims_after_concat.cc b/src/relax/transform/reorder_permute_dims_after_concat.cc index 9e0067471071..88c64521b047 100644 --- a/src/relax/transform/reorder_permute_dims_after_concat.cc +++ b/src/relax/transform/reorder_permute_dims_after_concat.cc @@ -163,9 +163,9 @@ std::tuple)>> TVM_FFI_ICHECK_LT(old_concat_axis, ndim) << "concat axis " << old_concat_axis << " out of range for " << ndim << "-D input"; - Integer new_concat_axis = permute_dims_axes[static_cast(old_concat_axis)]; + int64_t new_concat_axis = permute_dims_axes[static_cast(old_concat_axis)]; - auto new_concat = concat(Tuple(args), new_concat_axis->value); + auto new_concat = concat(Tuple(args), new_concat_axis); auto new_permute_dims = permute_dims(new_concat, permute_axes); return new_permute_dims; diff --git a/src/relax/transform/split_call_tir_by_pattern.cc b/src/relax/transform/split_call_tir_by_pattern.cc index 856742810858..45c0e61a25f1 100644 --- a/src/relax/transform/split_call_tir_by_pattern.cc +++ b/src/relax/transform/split_call_tir_by_pattern.cc @@ -581,7 +581,7 @@ std::pair> SplitFunctions( ffi::Array codegen_result = f_codegen(match_results); TVM_FFI_ICHECK(codegen_result.size() == 3); ffi::String library_code = Downcast(codegen_result[0]); - int num_matched_ops = Downcast(codegen_result[1])->value; + int num_matched_ops = Downcast(codegen_result[1])->value; ffi::Array func1_args = Downcast>(codegen_result[2]); if (num_matched_ops == 0) { return {func, std::nullopt}; diff --git a/src/s_tir/meta_schedule/arg_info.cc b/src/s_tir/meta_schedule/arg_info.cc 
index 4259ac999bc7..4163a2b8b552 100644 --- a/src/s_tir/meta_schedule/arg_info.cc +++ b/src/s_tir/meta_schedule/arg_info.cc @@ -150,7 +150,7 @@ TensorInfo TensorInfo::FromJSON(const ffi::ObjectRef& json_obj) { } std::vector s; std::transform(shape.begin(), shape.end(), std::back_inserter(s), - [](Integer i) { return i.IntValue(); }); + [](int64_t i) { return i; }); return TensorInfo(DataType(dtype), ffi::Shape(s.begin(), s.end())); } diff --git a/src/s_tir/meta_schedule/mutator/mutate_tile_size.cc b/src/s_tir/meta_schedule/mutator/mutate_tile_size.cc index e5e145bf37e8..ad58b293f87b 100644 --- a/src/s_tir/meta_schedule/mutator/mutate_tile_size.cc +++ b/src/s_tir/meta_schedule/mutator/mutate_tile_size.cc @@ -214,7 +214,7 @@ ffi::Optional MutateSampleTileSize(const Trace& trace, Instruction inst, if (y != n_splits - 1) { divide_factor = factors[s_tir::SampleInt(rand_state, 1, factors.size())]; } else { - int64_t limit = Downcast(inst->attrs[1])->value; + int64_t limit = Downcast(inst->attrs[1])->value; int max_factor_index = static_cast(factors.size()) - 1; for (; max_factor_index >= 1; max_factor_index--) { if (factors[max_factor_index] * tiles[y] <= limit) { diff --git a/src/s_tir/meta_schedule/postproc/rewrite_cooperative_fetch.cc b/src/s_tir/meta_schedule/postproc/rewrite_cooperative_fetch.cc index c27f3196b03b..3e0cdd8ac88d 100644 --- a/src/s_tir/meta_schedule/postproc/rewrite_cooperative_fetch.cc +++ b/src/s_tir/meta_schedule/postproc/rewrite_cooperative_fetch.cc @@ -66,7 +66,7 @@ ffi::Optional ParseAnnotate(const Schedule& sch, const Instruction& in if (ann_key != s_tir::attr::meta_schedule_cooperative_fetch) { return std::nullopt; } - *vector_lane = Downcast(sch->Get(Downcast(inst->inputs[1])))->value; + *vector_lane = Downcast(sch->Get(Downcast(inst->inputs[1])))->value; return Downcast(inst->inputs[0]); } diff --git a/src/s_tir/meta_schedule/postproc/rewrite_layout.cc b/src/s_tir/meta_schedule/postproc/rewrite_layout.cc index 1517e0f6e109..d53e53969ad0 
100644 --- a/src/s_tir/meta_schedule/postproc/rewrite_layout.cc +++ b/src/s_tir/meta_schedule/postproc/rewrite_layout.cc @@ -126,9 +126,9 @@ ffi::Array CollectLayoutFreeBuffers(const PrimFuncNode* func) { func->GetAttr(s_tir::attr::layout_free_buffers, ffi::Array()).value(); ffi::Array layout_free_buffers; - for (const Integer& index : layout_free_buffer_index) { - TVM_FFI_ICHECK(static_cast(index->value) < func->params.size()); - const Var& param = func->params[index->value]; + for (int64_t index : layout_free_buffer_index) { + TVM_FFI_ICHECK(static_cast(index) < func->params.size()); + const Var& param = func->params[index]; layout_free_buffers.push_back(func->buffer_map.at(param)); } diff --git a/src/s_tir/meta_schedule/postproc/verify_gpu_code.cc b/src/s_tir/meta_schedule/postproc/verify_gpu_code.cc index ee67d7275f9c..52d17f038332 100644 --- a/src/s_tir/meta_schedule/postproc/verify_gpu_code.cc +++ b/src/s_tir/meta_schedule/postproc/verify_gpu_code.cc @@ -107,10 +107,10 @@ namespace s_tir { namespace meta_schedule { /*! \brief Extract attribute from a target. 
*/ -Integer Extract(const Target& target, const char* name) { +IntImm Extract(const Target& target, const char* name) { TVM_FFI_ICHECK(target.defined()); if (ffi::Optional v = target->GetAttr(name)) { - return v.value(); + return IntImm(DataType::Int(64), v.value()); } TVM_FFI_THROW(AttributedError) << "\"" << name << "\" is not defined in the target"; throw; @@ -132,7 +132,7 @@ class VerifyGPUCodeNode : public PostprocNode { {"max_vthread", IntImm(DataType::Int(32), 8)}, {"max_vector_bytes", IntImm(DataType::Int(32), 16)}, }; - thread_warp_size_ = Extract(this->target_, "thread_warp_size").IntValue(); + thread_warp_size_ = static_cast(Extract(this->target_, "thread_warp_size")->value); } bool Verify(const IRModule& mod) const { diff --git a/src/s_tir/meta_schedule/space_generator/space_generator.cc b/src/s_tir/meta_schedule/space_generator/space_generator.cc index da5f5f399833..890511ad3bca 100644 --- a/src/s_tir/meta_schedule/space_generator/space_generator.cc +++ b/src/s_tir/meta_schedule/space_generator/space_generator.cc @@ -49,10 +49,10 @@ ffi::String GetRuleKindFromTarget(const Target& target) { ffi::Map target_json = target::canonicalizer::llvm::aprofile::Canonicalize(target->ToConfig()); - if (Downcast(target_json.at("feature.has_dotprod"))) { + if (Downcast(target_json.at("feature.has_dotprod"))->value) { return "dotprod"; } - if (Downcast(target_json.at("feature.has_asimd"))) { + if (Downcast(target_json.at("feature.has_asimd"))->value) { return "asimd"; } return "llvm"; diff --git a/src/s_tir/schedule/instruction_traits.h b/src/s_tir/schedule/instruction_traits.h index a083f53d16ab..d37e075424a0 100644 --- a/src/s_tir/schedule/instruction_traits.h +++ b/src/s_tir/schedule/instruction_traits.h @@ -112,8 +112,8 @@ using namespace tvm::tirx; * static ffi::Array UnpackedApplyToSchedule( * Schedule sch, * LoopRV loop_rv, - * Integer n, - * Integer max_innermost_factor, + * IntImm n, + * IntImm max_innermost_factor, * ffi::Optional> decision) { * return 
sch->SamplePerfectTile(loop_rv, n->value, max_innermost_factor->value, decision); * } @@ -127,8 +127,8 @@ using namespace tvm::tirx; * static ffi::String UnpackedAsPython( * ffi::Array outputs, * ffi::String loop_rv, - * Integer n, - * Integer max_innermost_factor, + * IntImm n, + * IntImm max_innermost_factor, * ffi::Optional> decision) { * PythonAPICall py("sample_perfect_tile"); * py.Input("loop", loop_rv); diff --git a/src/s_tir/schedule/primitive/annotate_buffer_access.cc b/src/s_tir/schedule/primitive/annotate_buffer_access.cc index 82d1e6a1c888..82a3a0de1cfe 100644 --- a/src/s_tir/schedule/primitive/annotate_buffer_access.cc +++ b/src/s_tir/schedule/primitive/annotate_buffer_access.cc @@ -122,8 +122,8 @@ struct AnnotateBufferAccessTraits : public UnpackedInstTraitsAnnotateBufferAccess(block, buffer_index->value, static_cast(buffer_index_type->value), index_map); @@ -150,7 +150,7 @@ struct AnnotateBufferAccessTraits : public UnpackedInstTraits outputs, ffi::String block, - Integer buffer_index, Integer buffer_index_type, + IntImm buffer_index, IntImm buffer_index_type, IndexMap index_map) { PythonAPICall py("annotate_buffer_access"); py.Input("block", block); diff --git a/src/s_tir/schedule/primitive/block_annotate.cc b/src/s_tir/schedule/primitive/block_annotate.cc index 752bf6692d1f..3734fc3f3fce 100644 --- a/src/s_tir/schedule/primitive/block_annotate.cc +++ b/src/s_tir/schedule/primitive/block_annotate.cc @@ -383,15 +383,15 @@ struct StorageAlignTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 4; static constexpr size_t kNumDecisions = 0; - static void UnpackedApplyToSchedule(Schedule sch, SBlockRV block_rv, Integer buffer_index, - Integer axis, Integer factor, Integer offset) { + static void UnpackedApplyToSchedule(Schedule sch, SBlockRV block_rv, IntImm buffer_index, + IntImm axis, IntImm factor, IntImm offset) { return sch->StorageAlign(block_rv, buffer_index->value, axis->value, factor->value, offset->value); } static 
ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block_rv, - Integer buffer_index, Integer axis, Integer factor, - Integer offset) { + IntImm buffer_index, IntImm axis, IntImm factor, + IntImm offset) { PythonAPICall py("storage_align"); py.Input("block", block_rv); py.Input("buffer_index", buffer_index); @@ -414,13 +414,13 @@ struct SetScopeTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 2; static constexpr size_t kNumDecisions = 0; - static void UnpackedApplyToSchedule(Schedule sch, SBlockRV block_rv, Integer buffer_index, + static void UnpackedApplyToSchedule(Schedule sch, SBlockRV block_rv, IntImm buffer_index, ffi::String storage_scope) { return sch->SetScope(block_rv, buffer_index->value, storage_scope); } static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block_rv, - Integer buffer_index, ffi::String storage_scope) { + IntImm buffer_index, ffi::String storage_scope) { PythonAPICall py("set_scope"); py.Input("block", block_rv); py.Input("buffer_index", buffer_index); @@ -441,13 +441,13 @@ struct UnsafeSetDTypeTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 2; static constexpr size_t kNumDecisions = 0; - static void UnpackedApplyToSchedule(Schedule sch, SBlockRV block_rv, Integer buffer_index, + static void UnpackedApplyToSchedule(Schedule sch, SBlockRV block_rv, IntImm buffer_index, ffi::String dtype) { return sch->UnsafeSetDType(block_rv, buffer_index->value, dtype); } static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block_rv, - Integer buffer_index, ffi::String dtype) { + IntImm buffer_index, ffi::String dtype) { PythonAPICall py("unsafe_set_dtype"); py.Input("block", block_rv); py.Input("buffer_index", buffer_index); diff --git a/src/s_tir/schedule/primitive/blockize_tensorize.cc b/src/s_tir/schedule/primitive/blockize_tensorize.cc index 5c55f5d7578f..4848c582c234 100644 --- a/src/s_tir/schedule/primitive/blockize_tensorize.cc +++ 
b/src/s_tir/schedule/primitive/blockize_tensorize.cc @@ -876,21 +876,21 @@ struct BlockizeTraits : public UnpackedInstTraits { static constexpr size_t kNumDecisions = 0; static SBlockRV UnpackedApplyToSchedule(Schedule sch, ffi::ObjectRef target, - Bool preserve_unit_iters) { + IntImm preserve_unit_iters) { if (auto loop = target.as()) { - return sch->Blockize(loop.value(), preserve_unit_iters.operator bool()); + return sch->Blockize(loop.value(), preserve_unit_iters->value != 0); } else if (auto blocks = target.as>()) { - return sch->Blockize(blocks.value(), preserve_unit_iters.operator bool()); + return sch->Blockize(blocks.value(), preserve_unit_iters->value != 0); } TVM_FFI_THROW(TypeError) << "expect Loop or list of SBlocks, but gets:" << target->GetTypeKey(); TVM_FFI_UNREACHABLE(); } static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::ObjectRef target, - Bool preserve_unit_iters) { + IntImm preserve_unit_iters) { PythonAPICall py("blockize"); py.Input("target", target); - py.Input("preserve_unit_iters", preserve_unit_iters.operator bool()); + py.Input("preserve_unit_iters", preserve_unit_iters->value != 0); py.SingleOutput(outputs); return py.Str(); } @@ -909,11 +909,11 @@ struct TensorizeTraits : public UnpackedInstTraits { static constexpr size_t kNumDecisions = 0; static void UnpackedApplyToSchedule(Schedule sch, ffi::ObjectRef block_or_loop_rv, - ffi::String intrin, Bool preserve_unit_iters) { + ffi::String intrin, IntImm preserve_unit_iters) { if (auto block = block_or_loop_rv.as()) { - sch->Tensorize(block.value(), intrin, preserve_unit_iters.operator bool()); + sch->Tensorize(block.value(), intrin, preserve_unit_iters->value != 0); } else if (auto loop = block_or_loop_rv.as()) { - sch->Tensorize(loop.value(), intrin, preserve_unit_iters.operator bool()); + sch->Tensorize(loop.value(), intrin, preserve_unit_iters->value != 0); } else { TVM_FFI_THROW(TypeError) << "Expected SBlock or Loop, but gets: " << block_or_loop_rv->GetTypeKey(); @@ 
-921,11 +921,11 @@ struct TensorizeTraits : public UnpackedInstTraits { } static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block_or_loop_rv, - ffi::String intrin, Bool preserve_unit_iters) { + ffi::String intrin, IntImm preserve_unit_iters) { PythonAPICall py("tensorize"); py.Input("block_or_loop", block_or_loop_rv); py.Input("tensor_intrin", intrin); - py.Input("preserve_unit_iters", preserve_unit_iters.operator bool()); + py.Input("preserve_unit_iters", preserve_unit_iters->value != 0); return py.Str(); } diff --git a/src/s_tir/schedule/primitive/cache_index.cc b/src/s_tir/schedule/primitive/cache_index.cc index 9566817f8015..3cd33aea0c51 100644 --- a/src/s_tir/schedule/primitive/cache_index.cc +++ b/src/s_tir/schedule/primitive/cache_index.cc @@ -507,12 +507,12 @@ struct CacheIndexTraits : public UnpackedInstTraits { static ffi::Array UnpackedApplyToSchedule(Schedule sch, SBlockRV block, ffi::String storage_scope, - Integer cse_thresh) { + IntImm cse_thresh) { return sch->CacheIndex(block, storage_scope, cse_thresh->value); } static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block, - ffi::String storage_scope, Integer cse_thresh) { + ffi::String storage_scope, IntImm cse_thresh) { PythonAPICall py("cache_index"); py.Input("block", block); py.Input("storage_scope", storage_scope); diff --git a/src/s_tir/schedule/primitive/cache_read_write.cc b/src/s_tir/schedule/primitive/cache_read_write.cc index 626eaa57f3d4..2cb9b5ac9484 100644 --- a/src/s_tir/schedule/primitive/cache_read_write.cc +++ b/src/s_tir/schedule/primitive/cache_read_write.cc @@ -2396,13 +2396,13 @@ struct CacheReadTraits : public UnpackedInstTraits { static SBlockRV UnpackedApplyToSchedule(Schedule sch, SBlockRV block, ffi::Array consumer_blocks, - Integer read_buffer_index, ffi::String storage_scope) { + IntImm read_buffer_index, ffi::String storage_scope) { return sch->CacheRead(block, read_buffer_index->value, storage_scope, consumer_blocks); } static 
ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block, ffi::Array consumer_blocks, - Integer read_buffer_index, ffi::String storage_scope) { + IntImm read_buffer_index, ffi::String storage_scope) { PythonAPICall py("cache_read"); py.Input("block", block); py.Input("read_buffer_index", read_buffer_index->value); @@ -2430,13 +2430,13 @@ struct CacheWriteTraits : public UnpackedInstTraits { static SBlockRV UnpackedApplyToSchedule(Schedule sch, SBlockRV block, ffi::Array consumer_blocks, - Integer write_buffer_index, ffi::String storage_scope) { + IntImm write_buffer_index, ffi::String storage_scope) { return sch->CacheWrite(block, write_buffer_index->value, storage_scope, consumer_blocks); } static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block, ffi::Array consumer_blocks, - Integer write_buffer_index, ffi::String storage_scope) { + IntImm write_buffer_index, ffi::String storage_scope) { PythonAPICall py("cache_write"); py.Input("block", block); py.Input("write_buffer_index", write_buffer_index->value); @@ -2463,13 +2463,13 @@ struct CacheInplaceTraits : public UnpackedInstTraits { static constexpr size_t kNumDecisions = 0; static ffi::Array UnpackedApplyToSchedule(Schedule sch, SBlockRV block, - Integer read_buffer_index, + IntImm read_buffer_index, ffi::String storage_scope) { return sch->CacheInplace(block, read_buffer_index->value, storage_scope); } static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block, - Integer read_buffer_index, ffi::String storage_scope) { + IntImm read_buffer_index, ffi::String storage_scope) { PythonAPICall py("cache_inplace"); py.Input("block", block); py.Input("read_buffer_index", read_buffer_index->value); @@ -2491,14 +2491,14 @@ struct ReIndexTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 2; static constexpr size_t kNumDecisions = 0; - static SBlockRV UnpackedApplyToSchedule(Schedule sch, SBlockRV block, Integer buffer_index, - Integer buffer_index_type) { - 
return sch->ReIndex(block, buffer_index.IntValue(), + static SBlockRV UnpackedApplyToSchedule(Schedule sch, SBlockRV block, IntImm buffer_index, + IntImm buffer_index_type) { + return sch->ReIndex(block, buffer_index->value, static_cast(buffer_index_type->value)); } static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block, - Integer buffer_index, Integer buffer_index_type) { + IntImm buffer_index, IntImm buffer_index_type) { PythonAPICall py("reindex"); py.Input("block", block); std::ostringstream os; @@ -2523,12 +2523,12 @@ struct ReindexCacheReadTraits : public UnpackedInstTraitsReindexCacheRead(block, read_buffer_index->value, storage_scope, index_map); } static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block, - IndexMap index_map, Integer read_buffer_index, + IndexMap index_map, IntImm read_buffer_index, ffi::String storage_scope) { PythonAPICall py("reindex_cache_read"); py.Input("block", block); @@ -2553,12 +2553,12 @@ struct ReindexCacheWriteTraits : public UnpackedInstTraitsReindexCacheWrite(block, write_buffer_index->value, storage_scope, index_map); } static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block, - IndexMap index_map, Integer write_buffer_index, + IndexMap index_map, IntImm write_buffer_index, ffi::String storage_scope) { PythonAPICall py("reindex_cache_write"); py.Input("block", block); diff --git a/src/s_tir/schedule/primitive/compute_at.cc b/src/s_tir/schedule/primitive/compute_at.cc index 0affecd2d5c6..0ea79faab317 100644 --- a/src/s_tir/schedule/primitive/compute_at.cc +++ b/src/s_tir/schedule/primitive/compute_at.cc @@ -815,16 +815,16 @@ struct ComputeAtTraits : public UnpackedInstTraits { static constexpr size_t kNumDecisions = 0; static void UnpackedApplyToSchedule(Schedule sch, SBlockRV block_rv, LoopRV loop_rv, - Bool preserve_unit_loops, IntImm index) { - return sch->ComputeAt(block_rv, loop_rv, preserve_unit_loops.operator bool(), index->value); + IntImm preserve_unit_loops, 
IntImm index) { + return sch->ComputeAt(block_rv, loop_rv, preserve_unit_loops->value != 0, index->value); } static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block_rv, - ffi::String loop_rv, Bool preserve_unit_loops, IntImm index) { + ffi::String loop_rv, IntImm preserve_unit_loops, IntImm index) { PythonAPICall py("compute_at"); py.Input("block", block_rv); py.Input("loop", loop_rv); - py.Input("preserve_unit_loops", preserve_unit_loops.operator bool()); + py.Input("preserve_unit_loops", preserve_unit_loops->value != 0); py.Input("index", index); return py.Str(); } @@ -843,17 +843,17 @@ struct ReverseComputeAtTraits : public UnpackedInstTraitsReverseComputeAt(block_rv, loop_rv, preserve_unit_loops.operator bool(), + IntImm preserve_unit_loops, IntImm index) { + return sch->ReverseComputeAt(block_rv, loop_rv, preserve_unit_loops->value != 0, index->value); } static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block_rv, - ffi::String loop_rv, Bool preserve_unit_loops, IntImm index) { + ffi::String loop_rv, IntImm preserve_unit_loops, IntImm index) { PythonAPICall py("reverse_compute_at"); py.Input("block", block_rv); py.Input("loop", loop_rv); - py.Input("preserve_unit_loops", preserve_unit_loops.operator bool()); + py.Input("preserve_unit_loops", preserve_unit_loops->value != 0); py.Input("index", index); return py.Str(); } diff --git a/src/s_tir/schedule/primitive/layout_transformation.cc b/src/s_tir/schedule/primitive/layout_transformation.cc index e9fa97772862..9878828e3eb9 100644 --- a/src/s_tir/schedule/primitive/layout_transformation.cc +++ b/src/s_tir/schedule/primitive/layout_transformation.cc @@ -1579,18 +1579,18 @@ struct TransformLayoutTraits : public UnpackedInstTraits static constexpr size_t kNumDecisions = 0; static void UnpackedApplyToSchedule(Schedule sch, SBlockRV block_rv, IndexMap index_map, - Integer buffer_index, Integer buffer_index_type, + IntImm buffer_index, IntImm buffer_index_type, ffi::Optional 
pad_value, - Bool assume_injective_transform) { - return sch->TransformLayout(block_rv, buffer_index.IntValue(), + IntImm assume_injective_transform) { + return sch->TransformLayout(block_rv, buffer_index->value, static_cast(buffer_index_type->value), index_map, - pad_value, assume_injective_transform.operator bool()); + pad_value, assume_injective_transform->value != 0); } static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block_rv, - IndexMap index_map, Integer buffer_index, - Integer buffer_index_type, ffi::Optional pad_value, - Bool assume_injective_transform) { + IndexMap index_map, IntImm buffer_index, + IntImm buffer_index_type, ffi::Optional pad_value, + IntImm assume_injective_transform) { PythonAPICall py("transform_layout"); py.Input("block", block_rv); @@ -1600,7 +1600,7 @@ struct TransformLayoutTraits : public UnpackedInstTraits py.Input("buffer", os.str()); py.Input("index_map", index_map->ToPythonString()); py.Input("pad_value", pad_value ? pad_value.value()->ToPythonString() : "None"); - py.Input("assume_injective_transform", assume_injective_transform.operator bool()); + py.Input("assume_injective_transform", assume_injective_transform->value != 0); return py.Str(); } @@ -1691,16 +1691,16 @@ struct SetAxisSeparatorTraits : public UnpackedInstTraits axis_separators) { - return sch->SetAxisSeparator(block_rv, buffer_index.IntValue(), + return sch->SetAxisSeparator(block_rv, buffer_index->value, static_cast(buffer_index_type->value), axis_separators); } static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block_rv, - Integer buffer_index, Integer buffer_index_type, + IntImm buffer_index, IntImm buffer_index_type, ffi::Array axis_separators) { PythonAPICall py("set_axis_separator"); py.Input("block", block_rv); diff --git a/src/s_tir/schedule/primitive/loop_transformation.cc b/src/s_tir/schedule/primitive/loop_transformation.cc index 87b5b4042f77..f86f5d3b3fa5 100644 --- 
a/src/s_tir/schedule/primitive/loop_transformation.cc +++ b/src/s_tir/schedule/primitive/loop_transformation.cc @@ -1200,20 +1200,20 @@ struct SplitTraits : public UnpackedInstTraits { static ffi::Array UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, ffi::Array> factors, - Bool preserve_unit_iters, - Bool disable_predication) { - return sch->Split(loop_rv, factors, preserve_unit_iters.operator bool(), - disable_predication.operator bool()); + IntImm preserve_unit_iters, + IntImm disable_predication) { + return sch->Split(loop_rv, factors, preserve_unit_iters->value != 0, + disable_predication->value != 0); } static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String loop_rv, - ffi::Array factors, Bool preserve_unit_iters, - Bool disable_predication) { + ffi::Array factors, IntImm preserve_unit_iters, + IntImm disable_predication) { PythonAPICall py("split"); py.Input("loop", loop_rv); py.Input("factors", factors); - py.Input("preserve_unit_iters", preserve_unit_iters.operator bool()); - py.Input("disable_predication", disable_predication.operator bool()); + py.Input("preserve_unit_iters", preserve_unit_iters->value != 0); + py.Input("disable_predication", disable_predication->value != 0); py.OutputList(outputs); return py.Str(); } @@ -1243,16 +1243,16 @@ struct LoopPartitionTraits : public UnpackedInstTraits { static ffi::Array UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, ffi::Array> factors, - Bool preserve_unit_iters) { - return sch->LoopPartition(loop_rv, factors, preserve_unit_iters.operator bool()); + IntImm preserve_unit_iters) { + return sch->LoopPartition(loop_rv, factors, preserve_unit_iters->value != 0); } static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String loop_rv, - ffi::Array factors, Bool preserve_unit_iters) { + ffi::Array factors, IntImm preserve_unit_iters) { PythonAPICall py("loop_partition"); py.Input("loop", loop_rv); py.Input("factors", factors); - py.Input("preserve_unit_iters", 
preserve_unit_iters.operator bool()); + py.Input("preserve_unit_iters", preserve_unit_iters->value != 0); py.OutputList(outputs); return py.Str(); } @@ -1308,17 +1308,17 @@ struct FuseTraits : public UnpackedInstTraits { } static LoopRV UnpackedApplyToSchedule(Schedule sch, ffi::Array loop_rvs, - Bool preserve_unit_iters) { - return sch->Fuse(loop_rvs, preserve_unit_iters.operator bool()); + IntImm preserve_unit_iters) { + return sch->Fuse(loop_rvs, preserve_unit_iters->value != 0); } static ffi::String UnpackedAsPython(ffi::Array outputs, - ffi::Array loop_rvs, Bool preserve_unit_iters) { + ffi::Array loop_rvs, IntImm preserve_unit_iters) { PythonAPICall py("fuse"); for (const ffi::String& loop_rv : loop_rvs) { py.Input("", loop_rv); } - py.Input("preserve_unit_iters", preserve_unit_iters.operator bool()); + py.Input("preserve_unit_iters", preserve_unit_iters->value != 0); py.SingleOutput(outputs); return py.Str(); } diff --git a/src/s_tir/schedule/primitive/read_write_at.cc b/src/s_tir/schedule/primitive/read_write_at.cc index 73990add29b1..04ef08b9d738 100644 --- a/src/s_tir/schedule/primitive/read_write_at.cc +++ b/src/s_tir/schedule/primitive/read_write_at.cc @@ -371,12 +371,12 @@ struct ReadAtTraits : public UnpackedInstTraits { StmtSRef ReadAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, int buffer_index, const ffi::String& storage_scope); static SBlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop, SBlockRV block, - Integer read_buffer_index, ffi::String storage_scope) { + IntImm read_buffer_index, ffi::String storage_scope) { return sch->ReadAt(loop, block, read_buffer_index->value, storage_scope); } static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String loop, - ffi::String block, Integer read_buffer_index, + ffi::String block, IntImm read_buffer_index, ffi::String storage_scope) { PythonAPICall py("read_at"); py.Input("loop", loop); @@ -401,12 +401,12 @@ struct WriteAtTraits : public UnpackedInstTraits { 
static constexpr size_t kNumDecisions = 0; static SBlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop, SBlockRV block, - Integer write_buffer_index, ffi::String storage_scope) { + IntImm write_buffer_index, ffi::String storage_scope) { return sch->WriteAt(loop, block, write_buffer_index->value, storage_scope); } static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String loop, - ffi::String block, Integer write_buffer_index, + ffi::String block, IntImm write_buffer_index, ffi::String storage_scope) { PythonAPICall py("write_at"); py.Input("loop", loop); diff --git a/src/s_tir/schedule/primitive/reduction.cc b/src/s_tir/schedule/primitive/reduction.cc index dc900f94cdb6..c36dc86ec907 100644 --- a/src/s_tir/schedule/primitive/reduction.cc +++ b/src/s_tir/schedule/primitive/reduction.cc @@ -1334,12 +1334,12 @@ struct RFactorTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 1; static constexpr size_t kNumDecisions = 0; - static SBlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, Integer factor_axis) { + static SBlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, IntImm factor_axis) { return sch->RFactor(loop_rv, factor_axis->value); } static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String loop_rv, - Integer factor_axis) { + IntImm factor_axis) { PythonAPICall py("rfactor"); py.Input("loop", loop_rv); py.Input("factor_axis", factor_axis->value); diff --git a/src/s_tir/schedule/primitive/reorder_block_iter_var.cc b/src/s_tir/schedule/primitive/reorder_block_iter_var.cc index a3246b7c9d20..753b593ef357 100644 --- a/src/s_tir/schedule/primitive/reorder_block_iter_var.cc +++ b/src/s_tir/schedule/primitive/reorder_block_iter_var.cc @@ -88,8 +88,8 @@ void ReorderBlockIterVar(ScheduleState self, const StmtSRef& block_sref, const ffi::Array& new_order) { const SBlockNode* block_n = TVM_SREF_TO_SBLOCK(block_sref); std::vector new_order_vec; - for (const Integer& x : new_order) { - 
new_order_vec.push_back(x->value); + for (int64_t x : new_order) { + new_order_vec.push_back(static_cast(x)); } // check whether new_order is valid or not; size_t num_block_itervars = block_n->iter_vars.size(); diff --git a/src/s_tir/schedule/primitive/rolling_buffer.cc b/src/s_tir/schedule/primitive/rolling_buffer.cc index 85e4d3b2a8bb..402cb8aef106 100644 --- a/src/s_tir/schedule/primitive/rolling_buffer.cc +++ b/src/s_tir/schedule/primitive/rolling_buffer.cc @@ -458,12 +458,12 @@ struct RollingBufferTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 1; static constexpr size_t kNumDecisions = 0; - static void UnpackedApplyToSchedule(Schedule sch, SBlockRV block, Integer write_buffer_index) { - return sch->RollingBuffer(block, write_buffer_index.IntValue()); + static void UnpackedApplyToSchedule(Schedule sch, SBlockRV block, IntImm write_buffer_index) { + return sch->RollingBuffer(block, write_buffer_index->value); } static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block, - Integer write_buffer_index) { + IntImm write_buffer_index) { PythonAPICall py("rolling_buffer"); py.Input("block", block); py.Input("write_buffer_index", write_buffer_index); diff --git a/src/s_tir/schedule/primitive/sampling.cc b/src/s_tir/schedule/primitive/sampling.cc index 72388e505e2f..337e57b4ad49 100644 --- a/src/s_tir/schedule/primitive/sampling.cc +++ b/src/s_tir/schedule/primitive/sampling.cc @@ -493,14 +493,14 @@ struct SamplePerfectTileTraits : public UnpackedInstTraits UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, Integer n, - Integer max_innermost_factor, + static ffi::Array UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, IntImm n, + IntImm max_innermost_factor, ffi::Optional> decision) { return sch->SamplePerfectTile(loop_rv, n->value, max_innermost_factor->value, decision); } static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String loop_rv, - Integer n, Integer max_innermost_factor, + IntImm n, IntImm 
max_innermost_factor, ffi::Optional> decision) { PythonAPICall py("sample_perfect_tile"); py.Input("loop", loop_rv); @@ -524,15 +524,15 @@ struct SamplePartitionedTileTraits : public UnpackedInstTraits UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, Integer n, - Integer partition_pos, Integer innerpart_factor, + static ffi::Array UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, IntImm n, + IntImm partition_pos, IntImm innerpart_factor, ffi::Optional> decision) { return sch->SamplePartitionedTile(loop_rv, n->value, partition_pos->value, innerpart_factor->value, decision); } static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String loop_rv, - Integer n, Integer partition_pos, Integer innerpart_factor, + IntImm n, IntImm partition_pos, IntImm innerpart_factor, ffi::Optional> decision) { PythonAPICall py("sample_partitioned_tile"); py.Input("loop", loop_rv); diff --git a/src/s_tir/transform/inject_software_pipeline.cc b/src/s_tir/transform/inject_software_pipeline.cc index f85918e511f5..79e3289d04be 100644 --- a/src/s_tir/transform/inject_software_pipeline.cc +++ b/src/s_tir/transform/inject_software_pipeline.cc @@ -1218,7 +1218,7 @@ class PipelineInjector : private StmtExprMutator { auto it = op->annotations.find(s_tir::attr::double_buffer_scope); if (it != op->annotations.end()) { - int buffer_index = Downcast((*it).second).IntValue(); + int buffer_index = static_cast(Downcast((*it).second)->value); TVM_FFI_CHECK(buffer_index >= 0 && static_cast(buffer_index) < op->writes.size(), ValueError) << "Index of the buffer exceeds the size of the write regions of the block. 
(" diff --git a/src/s_tir/transform/memhammer_coalesce.cc b/src/s_tir/transform/memhammer_coalesce.cc index 7785cab3bfdd..fb67c3eae1b0 100644 --- a/src/s_tir/transform/memhammer_coalesce.cc +++ b/src/s_tir/transform/memhammer_coalesce.cc @@ -67,7 +67,7 @@ Stmt FuseNestLoops(Stmt body) { */ Stmt SplitBindVectorize(const Stmt& stmt, const ConstraintSet& constraints) { const ForNode* loop = TVM_TYPE_AS(stmt, ForNode); - int loop_extent = Downcast(loop->extent)->value; + int loop_extent = Downcast(loop->extent)->value; int vector_bytes = constraints.vector_bytes; int data_bits = constraints.data_bits; int vector_len = std::max(1, vector_bytes * 8 / data_bits); diff --git a/src/s_tir/transform/memhammer_lower_auto_copy.cc b/src/s_tir/transform/memhammer_lower_auto_copy.cc index 245c40318ae3..76e1d0302b70 100644 --- a/src/s_tir/transform/memhammer_lower_auto_copy.cc +++ b/src/s_tir/transform/memhammer_lower_auto_copy.cc @@ -484,9 +484,9 @@ class AutoPadder { if (op->kind != ForKind::kThreadBinding) { substitute_map_.Set(op->loop_var, op->min); } else { - Integer extent = + int64_t extent = warp_thread_extent_.Get(op->thread_binding.value()->thread_tag).value_or(1); - var_range_.Set(op->loop_var, Range::FromMinExtent(op->min, extent)); + var_range_.Set(op->loop_var, Range::FromMinExtent(op->min, IntImm(DataType::Int(64), extent))); } if (op->kind == ForKind::kVectorized) { vector_var = op->loop_var; diff --git a/src/s_tir/transform/memhammer_rewrite_rule.h b/src/s_tir/transform/memhammer_rewrite_rule.h index 1c5e3bf45b78..2f8442e17e51 100644 --- a/src/s_tir/transform/memhammer_rewrite_rule.h +++ b/src/s_tir/transform/memhammer_rewrite_rule.h @@ -64,10 +64,10 @@ struct ConstraintSet { write_region(write_region), data_bits(data_bits) { if (auto add_local_stage = ann.Get("local_stage")) { - this->add_local_stage = Downcast(add_local_stage.value())->value; + this->add_local_stage = Downcast(add_local_stage.value())->value; } if (auto vector_bytes = ann.Get("vector_bytes")) { 
- this->vector_bytes = Downcast(vector_bytes.value())->value; + this->vector_bytes = Downcast(vector_bytes.value())->value; } } }; diff --git a/src/target/cuda/codegen_cuda.cc b/src/target/cuda/codegen_cuda.cc index a57e70ff5fc1..5036fb208e3f 100644 --- a/src/target/cuda/codegen_cuda.cc +++ b/src/target/cuda/codegen_cuda.cc @@ -1051,7 +1051,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { std::string b_bias = this->PrintExpr(op->args[9]); std::string c_ref = this->PrintExpr(op->args[10]); std::string c_bias = this->PrintExpr(op->args[11]); - bool saturate = Downcast(op->args[12])->value; + bool saturate = Downcast(op->args[12])->value; std::string bit_op = op->args.size() > 13 ? Downcast(op->args[13])->value : ""; std::string asm_code = PrintMMAAssembly(shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, b_ref, @@ -1091,14 +1091,14 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { std::string metadata = this->PrintExpr(op->args[12]); std::string metadata_offset = this->PrintExpr(op->args[13]); std::string sparse_selector = this->PrintExpr(op->args[14]); - bool saturate = Downcast(op->args[15])->value; + bool saturate = Downcast(op->args[15])->value; std::string asm_code = PrintMMAAssembly( shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_offset, b_ref, b_offset, c_ref, c_offset, metadata, metadata_offset, sparse_selector, "", true, saturate); this->stream << asm_code; } else if (op->op.same_as(builtin::mma_store())) { - int m = Downcast(op->args[0])->value; - int n = Downcast(op->args[1])->value; + int m = Downcast(op->args[0])->value; + int n = Downcast(op->args[1])->value; std::string dst = this->PrintExpr(op->args[2]); std::string src = this->PrintExpr(op->args[3]); std::string src_offset = this->PrintExpr(op->args[4]); @@ -1172,7 +1172,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { std::string b_bias = this->PrintExpr(op->args[9]); std::string c_ref = 
this->PrintExpr(op->args[10]); std::string c_bias = this->PrintExpr(op->args[11]); - bool saturate = Downcast(op->args[12])->value; + bool saturate = Downcast(op->args[12])->value; std::string bit_op = op->args.size() > 13 ? Downcast(op->args[13])->value : ""; this->stream << PrintMMAAssembly(shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, b_ref, b_bias, c_ref, c_bias, "", "", "", bit_op, @@ -1209,8 +1209,8 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { // args: m, n, dst_ptr, src_ptr_var, src_offset, dst_stride // (dst_ptr is typically an access_ptr Call that already encodes // dst.elem_offset and the global pointer cast.) - int m = Downcast(op->args[0])->value; - int n = Downcast(op->args[1])->value; + int m = Downcast(op->args[0])->value; + int n = Downcast(op->args[1])->value; std::string dst = this->PrintExpr(op->args[2]); std::string src = this->PrintExpr(op->args[3]); std::string src_offset = this->PrintExpr(op->args[4]); diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index f1d04bf4aa84..6f44a8d3d0f8 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -232,7 +232,7 @@ class CodeGenC : public ExprFunctor, // Print restrict keyword for a given Var if applicable virtual void PrintRestrict(const Var& v, std::ostream& os); - virtual void SetConstantsByteAlignment(Integer constants_byte_alignment) { + virtual void SetConstantsByteAlignment(int64_t constants_byte_alignment) { constants_byte_alignment_ = constants_byte_alignment; } @@ -323,7 +323,7 @@ class CodeGenC : public ExprFunctor, // cache commonly used ops const Op& builtin_call_extern_ = builtin::call_extern(); const Op& builtin_call_pure_extern_ = builtin::call_pure_extern(); - Integer constants_byte_alignment_ = 16; + int64_t constants_byte_alignment_ = 16; /*! \brief whether to print in SSA form */ bool print_ssa_form_{false}; /*! 
\brief whether the module has a main function declared */ diff --git a/src/target/target.cc b/src/target/target.cc index f1cf5a007bb6..89cf328cf959 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -177,7 +177,7 @@ Target Target::WithoutHost() const { int TargetNode::GetTargetDeviceType() const { if (ffi::Optional device_type = GetAttr("target_device_type")) { - return Downcast(device_type)->value; + return Downcast(device_type)->value; } return kind->default_device_type; } diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index 5779b4da0ec2..cbad63fdaf18 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -277,11 +277,11 @@ ffi::Map UpdateROCmAttrs(ffi::Map ffi::Map UpdateWebGPUAttrs(ffi::Map target) { bool subgroups = false; if (target.count("supports_subgroups")) { - subgroups = Downcast(target.at("supports_subgroups")); + subgroups = Downcast(target.at("supports_subgroups"))->value != 0; } if (target.count("thread_warp_size")) { - int64_t thread_warp_size = Downcast(target.at("thread_warp_size"))->value; + int64_t thread_warp_size = Downcast(target.at("thread_warp_size"))->value; TVM_FFI_ICHECK(subgroups || thread_warp_size <= 1) << "WebGPU target with thread_warp_size=" << thread_warp_size << " requires supports_subgroups=true"; diff --git a/src/target/vulkan/codegen_spirv.cc b/src/target/vulkan/codegen_spirv.cc index 3e67d2ea1fd6..7e9fa2b8a3df 100644 --- a/src/target/vulkan/codegen_spirv.cc +++ b/src/target/vulkan/codegen_spirv.cc @@ -627,7 +627,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const ShuffleNode* op) { << "SPIR-V codegen only supports shuffle " << "of one vector with one index"; spirv::Value vector = MakeValue(op->vectors[0]); - int index = Downcast(op->indices[0])->value; + int index = Downcast(op->indices[0])->value; spirv::SType etype = builder_->GetSType(op->dtype); spirv::Value element = builder_->MakeValue(spv::OpCompositeExtract, etype, vector, index); return element; diff --git 
a/src/tirx/ir/data_type_rewriter.cc b/src/tirx/ir/data_type_rewriter.cc index cc4c2d5f78df..9b030c560e09 100644 --- a/src/tirx/ir/data_type_rewriter.cc +++ b/src/tirx/ir/data_type_rewriter.cc @@ -630,7 +630,7 @@ bool IndexDataTypeNormalizer::CanRewriteDType(DataType dtype) const { PrimExpr IndexDataTypeNormalizer::VisitExpr_(const IntImmNode* op) { if (is_enabled_ && CanRewriteDType(op->dtype)) { - TVM_FFI_ICHECK_LE(op->value, Downcast(max_value(target_data_type_))->value); + TVM_FFI_ICHECK_LE(op->value, Downcast(max_value(target_data_type_))->value); return cast(target_data_type_, ffi::GetRef(op)); } return ffi::GetRef(op); diff --git a/src/tirx/transform/force_narrow_index_to_i32.cc b/src/tirx/transform/force_narrow_index_to_i32.cc index b38b2588992c..82a23d3b4f17 100644 --- a/src/tirx/transform/force_narrow_index_to_i32.cc +++ b/src/tirx/transform/force_narrow_index_to_i32.cc @@ -56,7 +56,7 @@ class Int32DTypeNarrower : public IndexDataTypeNormalizer { PrimExpr VisitExpr_(const IntImmNode* op) final { // ignore the enabled condition and always rewrite i64 if (op->dtype == DataType::Int(64)) { - TVM_FFI_ICHECK_LE(op->value, Downcast(max_value(target_data_type_))->value); + TVM_FFI_ICHECK_LE(op->value, Downcast(max_value(target_data_type_))->value); return IntImm(DataType::Int(32), op->value); } return ffi::GetRef(op); diff --git a/src/tirx/transform/lower_tvm_builtin.cc b/src/tirx/transform/lower_tvm_builtin.cc index f522f46c6c27..3b1336515721 100644 --- a/src/tirx/transform/lower_tvm_builtin.cc +++ b/src/tirx/transform/lower_tvm_builtin.cc @@ -241,7 +241,7 @@ class BuiltinLower : public StmtExprMutator { Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); if (op->annotations.count(transform::kDisableLowerTVMBuiltin)) { - if (Downcast(op->annotations[transform::kDisableLowerTVMBuiltin])) { + if (Downcast(op->annotations[transform::kDisableLowerTVMBuiltin])->value) { return stmt; } } diff --git a/src/tirx/transform/make_packed_api.cc 
b/src/tirx/transform/make_packed_api.cc index c7125d82fa3b..4f8229080f9c 100644 --- a/src/tirx/transform/make_packed_api.cc +++ b/src/tirx/transform/make_packed_api.cc @@ -228,7 +228,7 @@ PrimFunc MakePackedAPI(PrimFunc func) { // The device context Var device_id("dev_id"); - Integer device_type(target_device_type); + IntImm device_type(DataType::Int(32), target_device_type); // Create TVMFFIABIBuilder and decode all packed args TVMFFIABIBuilder binder(name_hint, func_ptr->params, func_ptr->buffer_map, v_packed_args, diff --git a/src/tirx/transform/unroll_loop.cc b/src/tirx/transform/unroll_loop.cc index ae99410ceea0..4a6beae92f0f 100644 --- a/src/tirx/transform/unroll_loop.cc +++ b/src/tirx/transform/unroll_loop.cc @@ -103,13 +103,13 @@ class LoopUnroller : public StmtExprMutator { Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == "pragma_auto_unroll_max_step") { - int value = static_cast(Downcast(op->value)->value); + int value = static_cast(Downcast(op->value)->value); std::swap(value, auto_max_step_); Stmt ret = this->VisitStmt(op->body); std::swap(value, auto_max_step_); return ret; } else if (op->attr_key == "pragma_unroll_explicit") { - bool explicit_unroll = Downcast(op->value)->value; + bool explicit_unroll = Downcast(op->value)->value; std::swap(explicit_unroll, explicit_unroll_); Stmt ret = this->VisitStmt(op->body); std::swap(explicit_unroll, explicit_unroll_); diff --git a/src/topi/nn.cc b/src/topi/nn.cc index 1f8118231fae..e7b0d9c69e44 100644 --- a/src/topi/nn.cc +++ b/src/topi/nn.cc @@ -68,14 +68,14 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def_packed("topi.nn.space_to_batch_nd", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = space_to_batch_nd( - args[0].cast(), args[1].cast>(), + args[0].cast(), args[1].cast>(), args[2].cast>(), args[3].cast>(), args[4].cast()); }) .def_packed("topi.nn.batch_to_space_nd", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = batch_to_space_nd( - args[0].cast(), args[1].cast>(), + args[0].cast(), args[1].cast>(), 
args[2].cast>(), args[3].cast>(), args[4].cast()); }) @@ -107,7 +107,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("topi.nn.dilate", [](ffi::PackedArgs args, ffi::Any* rv) { - *rv = nn::dilate(args[0].cast(), args[1].cast>(), + *rv = nn::dilate(args[0].cast(), args[1].cast>(), args[2].cast()); }); } @@ -239,7 +239,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("topi.nn.layer_norm", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::layer_norm(args[0].cast(), args[1].cast(), - args[2].cast(), args[3].cast>(), + args[2].cast(), args[3].cast>(), args[4].cast()); }); } @@ -250,7 +250,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def_packed("topi.nn.group_norm", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::group_norm(args[0].cast(), args[1].cast(), args[2].cast(), args[3].cast(), args[4].cast(), - args[5].cast>(), args[6].cast()); + args[5].cast>(), args[6].cast()); }); } @@ -260,7 +260,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def_packed("topi.nn.instance_norm", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::instance_norm(args[0].cast(), args[1].cast(), args[2].cast(), args[3].cast(), - args[4].cast>(), args[5].cast()); + args[4].cast>(), args[5].cast()); }); } @@ -269,7 +269,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("topi.nn.rms_norm", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::rms_norm(args[0].cast(), args[1].cast(), - args[2].cast>(), args[3].cast()); + args[2].cast>(), args[3].cast()); }); } diff --git a/src/topi/reduction.cc b/src/topi/reduction.cc index 0f2a7f49fc73..3ab084e8cb99 100644 --- a/src/topi/reduction.cc +++ b/src/topi/reduction.cc @@ -76,7 +76,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { args[2].cast()); }) .def_packed("topi.collapse_sum", [](ffi::PackedArgs args, ffi::Any* rv) { - *rv = topi::collapse_sum(args[0].cast(), 
args[1].cast>()); + *rv = topi::collapse_sum(args[0].cast(), args[1].cast>()); }); } diff --git a/src/topi/transform.cc b/src/topi/transform.cc index 09f9a9be5ea7..203e1b7da6f5 100644 --- a/src/topi/transform.cc +++ b/src/topi/transform.cc @@ -48,7 +48,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def_packed("topi.transpose", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = transpose(args[0].cast(), - args[1].cast>>()); + args[1].cast>>()); }) .def_packed("topi.flip", [](ffi::PackedArgs args, ffi::Any* rv) { @@ -68,8 +68,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def_packed("topi.sliding_window", [](ffi::PackedArgs args, ffi::Any* rv) { *rv = sliding_window(args[0].cast(), args[1].cast(), - args[2].cast>(), - args[3].cast>()); + args[2].cast>(), + args[3].cast>()); }) .def_packed("topi.squeeze", [](ffi::PackedArgs args, ffi::Any* rv) { @@ -98,7 +98,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { args[2].cast()); } else { *rv = split_indices_array(args[0].cast(), - args[1].cast>(), + args[1].cast>(), args[2].cast()); } }) @@ -154,7 +154,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { }) .def_packed("topi.tile", [](ffi::PackedArgs args, ffi::Any* rv) { - *rv = tile(args[0].cast(), args[1].cast>()); + *rv = tile(args[0].cast(), args[1].cast>()); }) .def_packed("topi.dyn_tile", [](ffi::PackedArgs args, ffi::Any* rv) { @@ -220,13 +220,13 @@ TVM_FFI_STATIC_INIT_BLOCK() { ffi::Array begin = args[1].cast>(); ffi::Array end = args[2].cast>(); ffi::Array strides = args[3].cast>(); - ffi::Array axes = args[4].cast>(); + ffi::Array axes = args[4].cast>(); bool assume_inbound = args[6].cast(); if (IsConstIntArray(begin) && IsConstIntArray(end) && IsConstIntArray(strides) && IsConstIntArray(x->shape)) { - ffi::Array begin_static = args[1].cast>(); - ffi::Array end_static = args[2].cast>(); - ffi::Array strides_static = args[3].cast>(); + ffi::Array begin_static = args[1].cast>(); + ffi::Array end_static = args[2].cast>(); + ffi::Array strides_static = args[3].cast>(); auto slice_mode = args[5].cast(); if (axes.size()) { 
*rv = strided_slice_with_axes(x, begin_static, end_static, strides_static, axes, diff --git a/tests/cpp/arith_simplify_test.cc b/tests/cpp/arith_simplify_test.cc index 2724f3a04245..0c2084e8fd3a 100644 --- a/tests/cpp/arith_simplify_test.cc +++ b/tests/cpp/arith_simplify_test.cc @@ -45,8 +45,8 @@ TEST(Simplify, Mul) { TEST(Simplify, Mod) { tvm::arith::Analyzer ana; - auto x = tvm::Integer(10); - auto y = tvm::Integer(12); + auto x = tvm::IntImm(DataType::Int(32), 10); + auto y = tvm::IntImm(DataType::Int(32), 12); // Mod::make is used instead of % to avoid constant folding during // calling operator%(x,y). Mod::make doesn't try constant folding, // and therefore, the constant folding will be attempted in CanonicalSimplify diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc index e7a1715cc7bf..8fdd768b81c6 100644 --- a/tests/cpp/ir_functor_test.cc +++ b/tests/cpp/ir_functor_test.cc @@ -58,8 +58,8 @@ TEST(IRF, CountVar) { TEST(IRF, PreOrderVisit) { using namespace tvm; using namespace tvm::tirx; - Stmt init = IfThenElse(const_true(), Evaluate(Integer(0)), Evaluate(Integer(0))); - Stmt body = Evaluate(Integer(1)); + Stmt init = IfThenElse(const_true(), Evaluate(IntImm(DataType::Int(32), 0)), Evaluate(IntImm(DataType::Int(32), 0))); + Stmt body = Evaluate(IntImm(DataType::Int(32), 1)); SBlock block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"block", /*body=*/body, /*init=*/init); diff --git a/tests/cpp/nested_msg_test.cc b/tests/cpp/nested_msg_test.cc index 54594cb0f118..7df9888b689c 100644 --- a/tests/cpp/nested_msg_test.cc +++ b/tests/cpp/nested_msg_test.cc @@ -152,23 +152,23 @@ TEST(NestedMsg, MapAndDecompose) { relax::Expr t0 = bb->Normalize(Tuple({x, y})); relax::Expr t1 = bb->Normalize(Tuple({t0, x, z, t0})); - auto c0 = Integer(0); - auto c1 = Integer(1); - auto c2 = Integer(2); + auto c0 = IntImm(DataType::Int(32), 0); + auto c1 = IntImm(DataType::Int(32), 1); + auto c2 = IntImm(DataType::Int(32), 2); - auto output = 
MapToNestedMsg(t1, [&](Expr value) { + auto output = MapToNestedMsg(t1, [&](Expr value) { if (value.same_as(x)) return c0; if (value.same_as(y)) return c1; return c2; }); - NestedMsg expected = {{c0, c1}, c0, c2, {c0, c1}}; + NestedMsg expected = {{c0, c1}, c0, c2, {c0, c1}}; EXPECT_TRUE(Equal(output, expected, - [](Integer lhs, Integer rhs) -> bool { return lhs->value == rhs->value; })); + [](IntImm lhs, IntImm rhs) -> bool { return lhs->value == rhs->value; })); auto output2 = - MapToNestedMsg(GetStructInfo(t1), [&](StructInfo sinfo) -> NestedMsg { + MapToNestedMsg(GetStructInfo(t1), [&](StructInfo sinfo) -> NestedMsg { const auto* prim_sinfo = sinfo.as(); if (prim_sinfo == nullptr) return std::nullopt; int bits = prim_sinfo->dtype.bits(); @@ -179,11 +179,11 @@ TEST(NestedMsg, MapAndDecompose) { }); EXPECT_TRUE(Equal(output2, expected, - [](Integer lhs, Integer rhs) -> bool { return lhs->value == rhs->value; })); + [](IntImm lhs, IntImm rhs) -> bool { return lhs->value == rhs->value; })); int x_count = 0, y_count = 0, z_count = 0; - DecomposeNestedMsg(t1, expected, [&](Expr value, NestedMsg msg) { + DecomposeNestedMsg(t1, expected, [&](Expr value, NestedMsg msg) { if (value.same_as(x)) { EXPECT_TRUE(msg.LeafValue().same_as(c0)); ++x_count; @@ -226,16 +226,16 @@ TEST(NestedMsg, NestedMsgToExpr) { auto sf0 = TensorStructInfo(DataType::Float(32), /*ndim=*/0); auto sf1 = TupleStructInfo({sf0, sf0}); - auto c0 = Integer(0); - auto c1 = Integer(1); - auto c2 = Integer(2); + auto c0 = IntImm(DataType::Int(32), 0); + auto c1 = IntImm(DataType::Int(32), 1); + auto c2 = IntImm(DataType::Int(32), 2); relax::Var x("x", sf0), y("y", sf0), z("z", sf0); - NestedMsg msg = {c0, {c0, c1}, {c0, {c1, c2}}}; - auto expr = NestedMsgToExpr(msg, [&](ffi::Optional leaf) { + NestedMsg msg = {c0, {c0, c1}, {c0, {c1, c2}}}; + auto expr = NestedMsgToExpr(msg, [&](ffi::Optional leaf) { TVM_FFI_ICHECK(leaf.defined()); - int value = leaf.value().IntValue(); + int value = leaf.value()->value; 
switch (value) { case 0: return x; @@ -257,51 +257,51 @@ TEST(NestedMsg, NestedMsgToExpr) { } TEST(NestedMsg, CombineNestedMsg) { - auto c0 = Integer(0); - auto c1 = Integer(1); - auto c2 = Integer(2); + auto c0 = IntImm(DataType::Int(32), 0); + auto c1 = IntImm(DataType::Int(32), 1); + auto c2 = IntImm(DataType::Int(32), 2); - NestedMsg lhs = {c0, {c0, c1}, std::nullopt, {c0, {c1, c2}}}; - NestedMsg rhs = {c1, {c2, std::nullopt}, std::nullopt, {c1, {c2, c2}}}; - NestedMsg expected = {c1, {c2, c1}, std::nullopt, {c1, {c2, c2}}}; + NestedMsg lhs = {c0, {c0, c1}, std::nullopt, {c0, {c1, c2}}}; + NestedMsg rhs = {c1, {c2, std::nullopt}, std::nullopt, {c1, {c2, c2}}}; + NestedMsg expected = {c1, {c2, c1}, std::nullopt, {c1, {c2, c2}}}; - auto output = CombineNestedMsg(lhs, rhs, [](Integer x, Integer y) { + auto output = CombineNestedMsg(lhs, rhs, [](IntImm x, IntImm y) { if (x->value > y->value) return x; return y; }); EXPECT_TRUE(Equal(output, expected, - [](Integer lhs, Integer rhs) -> bool { return lhs->value == rhs->value; })); + [](IntImm lhs, IntImm rhs) -> bool { return lhs->value == rhs->value; })); } TEST(NestedMsg, MapNestedMsg) { - auto c0 = Integer(0); - auto c1 = Integer(1); - auto c2 = Integer(2); - auto c3 = Integer(3); + auto c0 = IntImm(DataType::Int(32), 0); + auto c1 = IntImm(DataType::Int(32), 1); + auto c2 = IntImm(DataType::Int(32), 2); + auto c3 = IntImm(DataType::Int(32), 3); - NestedMsg msg = {c0, {c0, c1}, std::nullopt, {c0, {c2, c1}}}; - NestedMsg expected = {c3, {c3, std::nullopt}, std::nullopt, {c3, {c2, std::nullopt}}}; + NestedMsg msg = {c0, {c0, c1}, std::nullopt, {c0, {c2, c1}}}; + NestedMsg expected = {c3, {c3, std::nullopt}, std::nullopt, {c3, {c2, std::nullopt}}}; - auto output = MapNestedMsg(msg, [](Integer x) { + auto output = MapNestedMsg(msg, [](IntImm x) { if (x->value == 0) { - return NestedMsg(Integer(3)); + return NestedMsg(IntImm(DataType::Int(32), 3)); } else if (x->value == 1) { - return NestedMsg(); + return NestedMsg(); 
} else { - return NestedMsg(x); + return NestedMsg(x); } }); EXPECT_TRUE(Equal(output, expected, - [](Integer lhs, Integer rhs) -> bool { return lhs->value == rhs->value; })); + [](IntImm lhs, IntImm rhs) -> bool { return lhs->value == rhs->value; })); } TEST(NestedMsg, TransformTupleLeaf) { - auto c0 = Integer(0); - auto c1 = Integer(1); - auto c2 = Integer(2); - using NInt = NestedMsg; + auto c0 = IntImm(DataType::Int(32), 0); + auto c1 = IntImm(DataType::Int(32), 1); + auto c2 = IntImm(DataType::Int(32), 2); + using NInt = NestedMsg; NInt msg1 = {c0, {c0, c1}, c2, {c0, {c1, c2}}}; NInt msg2 = {c1, {c2, c0}, c2, {c1, {c2, c0}}}; @@ -312,8 +312,8 @@ TEST(NestedMsg, TransformTupleLeaf) { Expr expr = bb->Normalize(Tuple({x, Tuple({x, x}), x, Tuple({x, Tuple({x, x})})})); auto ftransleaf = [&](Expr value, std::array msgs) -> Expr { - int lhs = Downcast(msgs[0].LeafValue())->value; - int rhs = Downcast(msgs[1].LeafValue())->value; + int lhs = Downcast(msgs[0].LeafValue())->value; + int rhs = Downcast(msgs[1].LeafValue())->value; if (lhs > rhs) return z; else if (lhs == rhs) From 7186af5f6593921decc4139f79e03ef766f38f90 Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 28 May 2026 12:57:18 +0000 Subject: [PATCH 4/5] [REFACTOR][IR] Delete class Bool and class Integer from include/tvm/ir/expr.h MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Remove the Bool and Integer thin-wrapper classes (which were both subclasses of IntImm sharing IntImmNode) along with their TypeTraits<Bool> and TypeTraits<Integer> specializations. All call sites were migrated to IntImm in the preceding three commits. 
The canonical replacements are: - Integer(N) → IntImm(DataType::Int(32), N) - Bool(b) → IntImm(DataType::Bool(), b) - x.IntValue() → x->value - x operator bool → x->value != 0 --- include/tvm/ir/expr.h | 120 ---------------------- include/tvm/topi/nn.h | 6 +- include/tvm/topi/transform.h | 3 +- src/relax/backend/contrib/clml/codegen.cc | 3 +- tests/cpp/arith_simplify_test.cc | 4 +- 5 files changed, 9 insertions(+), 127 deletions(-) diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index fcd267163c2c..9a3e5b1843f1 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -557,108 +557,6 @@ class FloatImm : public PrimExpr { TVM_DEFINE_OBJECT_REF_COW_METHOD(FloatImmNode); }; -/*! - * \brief Boolean constant. - * - * This reference type is useful to add additional compile-time - * type checks and helper functions for Integer equal comparisons. - */ -class Bool : public IntImm { - public: - explicit Bool(bool value, Span span = Span()) : IntImm(DataType::Bool(), value, span) {} - Bool operator!() const { return Bool((*this)->value == 0); } - operator bool() const { return (*this)->value != 0; } - - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Bool, IntImm, IntImmNode); -}; - -// Overload operators to make sure we have the most fine grained types. 
-inline Bool operator||(const Bool& a, bool b) { return Bool(a.operator bool() || b); } -inline Bool operator||(bool a, const Bool& b) { return Bool(a || b.operator bool()); } -inline Bool operator||(const Bool& a, const Bool& b) { - return Bool(a.operator bool() || b.operator bool()); -} -inline Bool operator&&(const Bool& a, bool b) { return Bool(a.operator bool() && b); } -inline Bool operator&&(bool a, const Bool& b) { return Bool(a && b.operator bool()); } -inline Bool operator&&(const Bool& a, const Bool& b) { - return Bool(a.operator bool() && b.operator bool()); -} - -inline bool operator==(const Bool& a, bool b) { return a.operator bool() == b; } -inline bool operator==(bool a, const Bool& b) { return a == b.operator bool(); } -inline bool operator==(const Bool& a, const Bool& b) { - return a.operator bool() == b.operator bool(); -} - -/*! - * \brief Container of constant int that adds more constructors. - * - * This is used to store and automate type check - * attributes that must be constant integer. - * - * \sa IntImm - */ -class Integer : public IntImm { - public: - Integer() {} - /*! - * \brief constructor from node. - */ - explicit Integer(ffi::ObjectPtr node) : IntImm(node) {} - /*! - * \brief constructor with UnsafeInit - */ - explicit Integer(ffi::UnsafeInit tag) : IntImm(tag) {} - /*! - * \brief Construct integer from int value. - */ - Integer(int value, Span span = Span()) : IntImm(DataType::Int(32), value, span) {} // NOLINT(*) - /*! - * \brief Construct integer from int imm. - * \param other The other value. - */ - Integer(IntImm other) : IntImm(std::move(other)) {} // NOLINT(*) - /*! - * \brief Constructor from enum - * \tparam Enum The enum type. - * \param value The enum value. - */ - template ::value>::type> - explicit Integer(Enum value) : Integer(static_cast(value)) { - static_assert(std::is_same::type>::value, - "declare enum to be enum int to use visitor"); - } - /*! - * \brief Assign an expression to integer. 
- * \param other another expression. - */ - Integer& operator=(const IntImm& other) { - data_ = ffi::details::ObjectUnsafe::ObjectPtrFromObjectRef(other); - return *this; - } - /*! - * \brief convert to int64_t - */ - int64_t IntValue() const { - TVM_FFI_ICHECK(data_ != nullptr) << " Trying to reference a null Integer"; - return (*this)->value; - } - // comparators - Bool operator==(int other) const { - if (data_ == nullptr) return Bool(false); - return Bool((*this)->value == other); - } - Bool operator!=(int other) const { return !(*this == other); } - template ::value>::type> - Bool operator==(Enum other) const { - return *this == static_cast(other); - } - template ::value>::type> - Bool operator!=(Enum other) const { - return *this != static_cast(other); - } -}; - /*! \brief range over one dimension */ class RangeNode : public ffi::Object { public: @@ -729,16 +627,6 @@ struct TypeTraits : public ObjectRefWithFallbackTraitsBase -inline constexpr bool use_default_type_traits_v = false; - -template <> -struct TypeTraits : public ObjectRefWithFallbackTraitsBase { - TVM_FFI_INLINE static Integer ConvertFallbackValue(int64_t value) { - return Integer(TypeTraits::ConvertFallbackValue(value)); - } -}; - template <> inline constexpr bool use_default_type_traits_v = false; @@ -749,14 +637,6 @@ struct TypeTraits : public ObjectRefWithFallbackTraitsBase -inline constexpr bool use_default_type_traits_v = false; - -template <> -struct TypeTraits : public ObjectRefWithFallbackTraitsBase { - TVM_FFI_INLINE static Bool ConvertFallbackValue(int64_t value) { return Bool(value != 0); } -}; - // define automatic conversion from bool, int64_t, double to PrimExpr TVM_FFI_INLINE PrimExpr TypeTraits::ConvertFallbackValue(StrictBool value) { return IntImm(DataType::Bool(), value, Span()); diff --git a/include/tvm/topi/nn.h b/include/tvm/topi/nn.h index dd8e03aeac5a..81c35d890a9d 100644 --- a/include/tvm/topi/nn.h +++ b/include/tvm/topi/nn.h @@ -555,7 +555,8 @@ inline tvm::te::Tensor 
space_to_batch_nd(const tvm::te::Tensor& data, // append remaining shape for (size_t i = num_block_dims + 1; i < input_shape.size(); i++) { r_shape.push_back(input_shape[i]); - axis.push_back(static_cast(r_shape.size() - 1)); // index of remaining shape in r_shape + axis.push_back( + static_cast(r_shape.size() - 1)); // index of remaining shape in r_shape o_shape.push_back(input_shape[i]); } @@ -604,7 +605,8 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data, for (size_t i = 1; i < num_input_dims; i++) { axis.push_back(static_cast(r_shape.size())); // axis of in_shape[i] if (axis.size() < (num_block_dims + num_input_dims)) { - axis.push_back(static_cast(r_shape.size() - (num_block_dims + 1))); // axis of block_shape[i] + axis.push_back( + static_cast(r_shape.size() - (num_block_dims + 1))); // axis of block_shape[i] } r_shape.push_back(in_shape[i]); } diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 3c458435b37b..1178fcae3667 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -225,8 +225,7 @@ inline Tensor transpose(const Tensor& x, ffi::Optional> opt_ for (size_t j = 0; j < axes.size(); ++j) { if (i != j) { - TVM_FFI_ICHECK(new_axis != static_cast(axes[j])) - << "repeated axis in transpose"; + TVM_FFI_ICHECK(new_axis != static_cast(axes[j])) << "repeated axis in transpose"; } } new_shape.push_back(x->shape[new_axis]); diff --git a/src/relax/backend/contrib/clml/codegen.cc b/src/relax/backend/contrib/clml/codegen.cc index 14c8230d0b81..75073de17da4 100644 --- a/src/relax/backend/contrib/clml/codegen.cc +++ b/src/relax/backend/contrib/clml/codegen.cc @@ -48,7 +48,8 @@ struct OpenCLMLCompilerConfigNode : public ffi::Object { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro( "clml_version", &OpenCLMLCompilerConfigNode::clml_version, - "OpenCLML version as (major, minor, patch).", refl::DefaultValue(IntImm(DataType::Int(32), 3))); + "OpenCLML version as (major, minor, 
patch).", + refl::DefaultValue(IntImm(DataType::Int(32), 3))); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.ext.attrs.OpenCLMLCompilerConfig", OpenCLMLCompilerConfigNode, ffi::Object); diff --git a/tests/cpp/arith_simplify_test.cc b/tests/cpp/arith_simplify_test.cc index 0c2084e8fd3a..2c7b9cea2472 100644 --- a/tests/cpp/arith_simplify_test.cc +++ b/tests/cpp/arith_simplify_test.cc @@ -45,8 +45,8 @@ TEST(Simplify, Mul) { TEST(Simplify, Mod) { tvm::arith::Analyzer ana; - auto x = tvm::IntImm(DataType::Int(32), 10); - auto y = tvm::IntImm(DataType::Int(32), 12); + auto x = tvm::IntImm(tvm::DataType::Int(32), 10); + auto y = tvm::IntImm(tvm::DataType::Int(32), 12); // Mod::make is used instead of % to avoid constant folding during // calling operator%(x,y). Mod::make doesn't try constant folding, // and therefore, the constant folding will be attempted in CanonicalSimplify From 8bb7fd6803878e42a9278b95bb1a696788bb2376 Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 28 May 2026 15:02:47 +0000 Subject: [PATCH 5/5] =?UTF-8?q?[REFACTOR][RELAX]=20Fix=20Integer=E2=86=92i?= =?UTF-8?q?nt64=5Ft=20Python=20construction=20sites=20and=20apply=20const?= =?UTF-8?q?=5Ftrue=20cleanups?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit After migrating topi container signatures from Array, the strided_slice family lost two features the original Array parameters carried: 1. Nullable entries. begin/end accept None entries meaning "use stride-direction default". The Array port silently dropped the !defined() branches. 2. Index dtype propagation. begin[0]->dtype drove the index dtype for IntImm/PrimExpr construction; hardcoding Int(64) widened int32 callers. Restore both features by porting begin/end to ffi::Array> and strides to ffi::Array: - include/tvm/topi/detail/strided_slice.h: ConvertToVec gets Optional/IntImm params; reinstates the None branches for begin and end. 
- include/tvm/topi/transform.h: strided_slice / strided_slice_with_axes / StridedSliceOutputShape match the new types; index_dtype is derived from begin[0]->dtype again (falling back to Int(64) when begin is empty or begin[0] is None); max_range uses max_value(index_dtype); stride dtype from strides[i]->dtype. - src/topi/transform.cc: static-path cast now produces Array<Optional<IntImm>> for begin/end and Array<IntImm> for strides. - include/tvm/topi/nn.h: the strided_slice caller in batch_to_space_nd builds Array<Optional<IntImm>>/Array<IntImm> instead of Array<int64_t>. axes remains Array<int64_t> on the C++ side (only host integers are read). python/tvm/topi/transform.py drops the begin/end/strides _unbox shim — the FFI now marshals Python ints and IntImm directly — but keeps a small axes-only IntImm→int conversion for relax legalize callers that produce IntImm-valued axes lists. Also apply six mechanical cleanups from review feedback: - include/tvm/topi/transform.h: remove redundant ternary (begin.size() > 0 ? Int(64) : Int(64)) - src/te/operation/create_primfunc.cc: IntImm(Bool(),1) → const_true() - src/relax/analysis/struct_info_analysis.cc: same - src/relax/ir/dataflow_matcher.cc: same - src/relax/transform/adjust_matmul_order.cc: same - src/relax/transform/fuse_tir.cc: same (no namespace qualifier needed, already in tvm::tirx; relax-namespace sites use tirx::const_true()) Clang-format and ruff-format reformats from pre-commit included. 
--- include/tvm/topi/detail/strided_slice.h | 22 ++-- include/tvm/topi/nn.h | 14 ++- include/tvm/topi/transform.h | 55 +++++---- python/tvm/topi/transform.py | 3 + src/arith/conjunctive_normal_form.cc | 3 +- src/arith/rewrite_simplify.cc | 3 +- src/relax/analysis/struct_info_analysis.cc | 2 +- src/relax/analysis/tir_op_pattern_kind.cc | 2 +- src/relax/ir/dataflow_matcher.cc | 2 +- src/relax/ir/dataflow_matcher.h | 2 +- src/relax/op/nn/convolution.cc | 72 +++++++---- src/relax/op/nn/pooling.cc | 60 +++++---- src/relax/op/vision/roi_align.cc | 3 +- src/relax/op/vision/roi_pool.cc | 3 +- src/relax/transform/adjust_matmul_order.cc | 4 +- src/relax/transform/allocate_workspace.cc | 3 +- src/relax/transform/fuse_tir.cc | 4 +- src/s_tir/meta_schedule/arg_info.cc | 3 +- .../meta_schedule/database/json_database.cc | 11 +- .../meta_schedule/mutator/mutate_parallel.cc | 2 +- .../postproc/rewrite_cooperative_fetch.cc | 23 ++-- .../schedule/cuda/thread_bind.cc | 6 +- .../meta_schedule/schedule/cuda/winograd.cc | 6 +- .../schedule_rule/add_rfactor.cc | 3 +- .../multi_level_tiling_tensor_core.cc | 3 +- .../parallel_vectorize_unroll.cc | 3 +- src/s_tir/schedule/analysis/layout.cc | 6 +- src/s_tir/schedule/primitive/compute_at.cc | 9 +- .../schedule/primitive/loop_transformation.cc | 10 +- src/s_tir/schedule/primitive/read_write_at.cc | 3 +- src/s_tir/schedule/traced_schedule.cc | 115 +++++++++++------- src/s_tir/transform/default_gpu_schedule.cc | 10 +- .../transform/lower_cross_thread_reduction.cc | 8 +- .../transform/memhammer_lower_auto_copy.cc | 3 +- .../transform/transform_mma_buffer_layout.cc | 12 +- src/te/operation/create_primfunc.cc | 4 +- src/tirx/script/builder/frame.cc | 3 +- src/topi/transform.cc | 8 +- tests/cpp/ir_functor_test.cc | 3 +- 39 files changed, 306 insertions(+), 205 deletions(-) diff --git a/include/tvm/topi/detail/strided_slice.h b/include/tvm/topi/detail/strided_slice.h index b85908a88ba9..2e5df30808be 100644 --- 
a/include/tvm/topi/detail/strided_slice.h +++ b/include/tvm/topi/detail/strided_slice.h @@ -50,30 +50,38 @@ inline int64_t CanonicalizeIndex(int64_t index, int64_t extent, int64_t stride) } inline std::tuple, std::vector, std::vector> ConvertToVec( - const ffi::Array& begin, const ffi::Array& end, - const ffi::Array& strides, std::string slice_mode) { + const ffi::Array>& begin, const ffi::Array>& end, + const ffi::Array& strides, std::string slice_mode) { std::vector stride_vec(strides.size(), 1); if (slice_mode == "end") { for (size_t i = 0; i < strides.size(); ++i) { - stride_vec[i] = strides[i]; + stride_vec[i] = strides[i]->value; } } const int64_t max_range = std::numeric_limits::max(); std::vector begin_vec; for (size_t i = 0; i < begin.size(); ++i) { - begin_vec.push_back(begin[i]); + if (!begin[i].defined()) { + // value=None + begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range); + } else { + begin_vec.push_back(begin[i].value()->value); + } } std::vector end_vec; for (size_t i = 0; i < end.size(); ++i) { - if (slice_mode == "size") { - int64_t end_val = end[i]; + // allow end to be None + if (!end[i].defined()) { + end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range); + } else if (slice_mode == "size") { + int64_t end_val = end[i].value()->value; if (end_val < 0) { end_vec.push_back(stride_vec[i] < 0 ? 
0 : max_range); } else { end_vec.push_back(begin_vec[i] + end_val); } } else { - end_vec.push_back(end[i]); + end_vec.push_back(end[i].value()->value); } } return std::make_tuple(begin_vec, end_vec, stride_vec); diff --git a/include/tvm/topi/nn.h b/include/tvm/topi/nn.h index 81c35d890a9d..23a22359d261 100644 --- a/include/tvm/topi/nn.h +++ b/include/tvm/topi/nn.h @@ -627,9 +627,11 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data, out = reshape(out, r_p_shape); // Crop the start and end of dimensions of out - ffi::Array begin_idx, end_idx, strides; + ffi::Array> begin_idx, end_idx; + ffi::Array strides; + DataType index_dtype = DataType::Int(64); for (size_t i = 0; i < r_p_shape.size(); ++i) { - strides.push_back(int64_t(1)); + strides.push_back(IntImm(index_dtype, 1)); if (i > 0 && i <= num_block_dims) { // prepare begin and end index for spatial dimensions int64_t begin_i = GetConstInt(crop_begin_list[i - 1]); @@ -638,12 +640,12 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data, TVM_FFI_ICHECK_GT(out_i, (begin_i + end_i)) << "Incorrect crop sizes for (" << i << ")th dim, can not crop more than" << " output size" << out_i << " vs " << (begin_i + end_i); - begin_idx.push_back(begin_i); - end_idx.push_back(out_i - end_i); + begin_idx.push_back(IntImm(index_dtype, begin_i)); + end_idx.push_back(IntImm(index_dtype, out_i - end_i)); } else { // ignore the batch and remaining dimension - begin_idx.push_back(int64_t(0)); - end_idx.push_back(GetConstInt(r_p_shape[i])); + begin_idx.push_back(IntImm(index_dtype, 0)); + end_idx.push_back(IntImm(index_dtype, GetConstInt(r_p_shape[i]))); } } diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 1178fcae3667..ed72c08e5a87 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -865,17 +865,19 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b * \return The output shape of strided_slice using the 
arguments above */ inline ffi::Array StridedSliceOutputShape(const ffi::Array& ishape, - const ffi::Array& begin, - const ffi::Array& end, - const ffi::Array& strides, + const ffi::Array>& begin, + const ffi::Array>& end, + const ffi::Array& strides, const ffi::Array& axes, const std::string& slice_mode) { TVM_FFI_ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size()); std::vector begin_vec, end_vec, strides_vec; std::tie(begin_vec, end_vec, strides_vec) = ConvertToVec(begin, end, strides, slice_mode); - auto begin_canonicalized = StridedSliceCanonicalizeBegin(ishape, begin_vec, strides_vec, axes, - DataType::Int(64), slice_mode); + DataType index_dtype = + (begin.size() > 0 && begin[0].defined()) ? begin[0].value()->dtype : DataType::Int(64); + auto begin_canonicalized = + StridedSliceCanonicalizeBegin(ishape, begin_vec, strides_vec, axes, index_dtype, slice_mode); return StridedSliceOutputShape(ishape, begin_vec, end_vec, strides_vec, axes, slice_mode, begin_canonicalized, true); } @@ -896,13 +898,11 @@ inline ffi::Array StridedSliceOutputShape(const ffi::Array& * * \return A Tensor whose op member is the sstrided_slice operation */ -inline Tensor strided_slice_with_axes(const Tensor& x, const ffi::Array& begin, - const ffi::Array& end, - const ffi::Array& strides, - const ffi::Array& axes, - std::string slice_mode = "end", - std::string name = "T_strided_slice_with_axes", - std::string tag = kInjective) { +inline Tensor strided_slice_with_axes( + const Tensor& x, const ffi::Array>& begin, + const ffi::Array>& end, const ffi::Array& strides, + const ffi::Array& axes, std::string slice_mode = "end", + std::string name = "T_strided_slice_with_axes", std::string tag = kInjective) { const int64_t src_tensor_dim = static_cast(x->shape.size()); TVM_FFI_ICHECK(static_cast(axes.size()) <= src_tensor_dim); TVM_FFI_ICHECK(axes.size() == begin.size() && axes.size() == end.size() && @@ -924,7 +924,8 @@ inline Tensor 
strided_slice_with_axes(const Tensor& x, const ffi::Array std::vector begin_vec, end_vec, strides_vec; std::tie(begin_vec, end_vec, strides_vec) = ConvertToVec(begin, end, strides, slice_mode); - DataType index_dtype = begin.size() > 0 ? DataType::Int(64) : DataType::Int(64); + DataType index_dtype = + (begin.size() > 0 && begin[0].defined()) ? begin[0].value()->dtype : DataType::Int(64); auto begin_expr = StridedSliceCanonicalizeBegin(x->shape, begin_vec, strides_vec, normalized_axes, index_dtype, slice_mode); auto out_shape = StridedSliceOutputShape(x->shape, begin_vec, end_vec, strides_vec, @@ -937,7 +938,7 @@ inline Tensor strided_slice_with_axes(const Tensor& x, const ffi::Array for (size_t i = 0; i < out_shape.size(); ++i) real_indices.push_back(indices[i]); for (size_t i = 0; i < normalized_axes.size(); ++i) { int64_t ax = normalized_axes[i]; - auto stride = make_const(DataType::Int(64), strides_vec[i]); + auto stride = make_const(strides[i]->dtype, strides_vec[i]); PrimExpr ind = indices[ax] * stride + begin_expr[i]; real_indices.Set(ax, ind); } @@ -960,29 +961,31 @@ inline Tensor strided_slice_with_axes(const Tensor& x, const ffi::Array * * \return A Tensor whose op member is the strided_slice operation */ -inline Tensor strided_slice(const Tensor& x, const ffi::Array& begin, - const ffi::Array& end, const ffi::Array& strides, - std::string slice_mode = "end", std::string name = "T_strided_slice", - std::string tag = kInjective) { +inline Tensor strided_slice(const Tensor& x, const ffi::Array>& begin, + const ffi::Array>& end, + const ffi::Array& strides, std::string slice_mode = "end", + std::string name = "T_strided_slice", std::string tag = kInjective) { size_t src_tensor_dim = static_cast(x->shape.size()); ffi::Array axes; for (size_t i = 0; i < src_tensor_dim; ++i) axes.push_back(i); - ffi::Array begin_full(begin); - ffi::Array end_full(end); - ffi::Array strides_full(strides); + ffi::Array> begin_full(begin); + ffi::Array> end_full(end); + ffi::Array 
strides_full(strides); - constexpr int64_t one = 1; - constexpr int64_t zero = 0; - const int64_t max_range = std::numeric_limits::max(); + DataType index_dtype = + (begin.size() > 0 && begin[0].defined()) ? begin[0].value()->dtype : DataType::Int(64); + const IntImm one = IntImm(index_dtype, 1); + const IntImm zero = IntImm(index_dtype, 0); + const IntImm max_range = Downcast(max_value(index_dtype)); for (size_t i = strides.size(); i < src_tensor_dim; ++i) { strides_full.push_back(one); } for (size_t i = begin.size(); i < src_tensor_dim; ++i) { - begin_full.push_back(strides_full[i] > 0 ? zero : max_range); + begin_full.push_back(strides_full[i]->value > 0 ? zero : max_range); } for (size_t i = end.size(); i < src_tensor_dim; ++i) { - end_full.push_back(strides_full[i] < 0 ? zero : max_range); + end_full.push_back(strides_full[i]->value < 0 ? zero : max_range); } return strided_slice_with_axes(x, begin_full, end_full, strides_full, axes, slice_mode, name, diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index fba3eb4cfa7e..4d5266c3bac3 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -229,6 +229,9 @@ def strided_slice(a, begin, end, strides=None, axes=None, slice_mode="end", assu strides = [] if axes is None: axes = [] + # axes is a list of host integers on the C++ side (Array); unwrap any + # IntImm entries that callers may pass through (e.g. relax legalize pipeline). 
+ axes = [int(v) if isinstance(v, tvm.tirx.IntImm) else v for v in axes] return cpp.strided_slice(a, begin, end, strides, axes, slice_mode, assume_inbound) diff --git a/src/arith/conjunctive_normal_form.cc b/src/arith/conjunctive_normal_form.cc index 17df960a127c..6aaef8327003 100644 --- a/src/arith/conjunctive_normal_form.cc +++ b/src/arith/conjunctive_normal_form.cc @@ -147,7 +147,8 @@ class AndOfOrs { }; AndOfOrs::AndOfOrs(const PrimExpr& expr) - : key_true_(GetKey(IntImm(DataType::Bool(), 1))), key_false_(GetKey(IntImm(DataType::Bool(), 0))) { + : key_true_(GetKey(IntImm(DataType::Bool(), 1))), + key_false_(GetKey(IntImm(DataType::Bool(), 0))) { VisitAndExpressions(expr, [&](const PrimExpr& outer_expr) { std::vector or_components; VisitOrExpressions(outer_expr, [&](const PrimExpr& inner_expr) { diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 0d13a5ecd375..1ae8b012860c 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -1072,7 +1072,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { PrimExpr y_div = CanProveEqual(floordiv(yval, c2val), 0) ? 0 : floordiv(yval, c2val); auto bound = analyzer_->const_int_bound(residue); if (bound.defined() && bound->max_value == bound->min_value) { - return x.Eval() * floordiv(c1val, c2.Eval()) + (y_div + IntImm(DataType::Int(32), bound->max_value)); + return x.Eval() * floordiv(c1val, c2.Eval()) + + (y_div + IntImm(DataType::Int(32), bound->max_value)); } // try simplify divisor diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc index 704e40c6b191..66062c1870c3 100644 --- a/src/relax/analysis/struct_info_analysis.cc +++ b/src/relax/analysis/struct_info_analysis.cc @@ -633,7 +633,7 @@ class StructInfoBasePreconditionCollector PrimExpr VisitStructInfo(const StructInfo& lhs, const StructInfo& other) override { if (lhs.same_as(other)) { // Early bail-out if the StructInfo has reference equality. 
- return IntImm(DataType::Bool(), 1); + return tirx::const_true(); } else { return StructInfoFunctor::VisitStructInfo(lhs, other); } diff --git a/src/relax/analysis/tir_op_pattern_kind.cc b/src/relax/analysis/tir_op_pattern_kind.cc index ace88a5ce801..26041475c64d 100644 --- a/src/relax/analysis/tir_op_pattern_kind.cc +++ b/src/relax/analysis/tir_op_pattern_kind.cc @@ -25,8 +25,8 @@ #include #include #include -#include #include +#include namespace tvm { namespace relax { diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index 08d39ac29c42..57578773c675 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -471,7 +471,7 @@ PrimExpr DFPatternMatcher::SimplifyCondition(PrimExpr condition) { constraints.begin(), constraints.end(), [&sort_key](const PrimExpr& a, const PrimExpr& b) { return sort_key(a) < sort_key(b); }); - PrimExpr sorted_condition = IntImm(DataType::Bool(), 1); + PrimExpr sorted_condition = tirx::const_true(); for (const PrimExpr& constraint : constraints) { sorted_condition = sorted_condition && constraint; } diff --git a/src/relax/ir/dataflow_matcher.h b/src/relax/ir/dataflow_matcher.h index 45b76de68ad0..e4006e2bc4bb 100644 --- a/src/relax/ir/dataflow_matcher.h +++ b/src/relax/ir/dataflow_matcher.h @@ -28,11 +28,11 @@ #include #include #include +#include #include #include #include -#include namespace tvm { namespace relax { diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc index 2fdea26bd7ed..8916e430822c 100644 --- a/src/relax/op/nn/convolution.cc +++ b/src/relax/op/nn/convolution.cc @@ -128,15 +128,18 @@ StructInfo InferStructInfoConv1d(const Call& call, const BlockBuilder& ctx) { PrimExpr input_w = data_NCW_shape[2]; PrimExpr kernel_w = weight_OIW_shape[2]; - PrimExpr padding_w = IntImm(DataType::Int(32), attrs->padding[0]) + IntImm(DataType::Int(32), attrs->padding[1]); + PrimExpr padding_w = + IntImm(DataType::Int(32), attrs->padding[0]) + 
IntImm(DataType::Int(32), attrs->padding[1]); std::vector out_NCW_shape; out_NCW_shape.resize(3); out_NCW_shape[0] = data_NCW_shape[0]; out_NCW_shape[1] = weight_OIW_shape[0]; - PrimExpr numerator_w = input_w + padding_w - IntImm(DataType::Int(32), attrs->dilation[0]) * (kernel_w - 1) - 1; - out_NCW_shape[2] = analyzer->Simplify(floordiv(numerator_w, IntImm(DataType::Int(32), attrs->strides[0])) + 1); + PrimExpr numerator_w = + input_w + padding_w - IntImm(DataType::Int(32), attrs->dilation[0]) * (kernel_w - 1) - 1; + out_NCW_shape[2] = + analyzer->Simplify(floordiv(numerator_w, IntImm(DataType::Int(32), attrs->strides[0])) + 1); ffi::Array out_shape = out2NCW.BackwardShape(out_NCW_shape); return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice); @@ -299,18 +302,24 @@ StructInfo InferStructInfoConv2d(const Call& call, const BlockBuilder& ctx) { PrimExpr input_w = data_NCHW_shape[3]; PrimExpr kernel_h = weight_OIHW_shape[2]; PrimExpr kernel_w = weight_OIHW_shape[3]; - PrimExpr padding_h = IntImm(DataType::Int(32), attrs->padding[0]) + IntImm(DataType::Int(32), attrs->padding[2]); - PrimExpr padding_w = IntImm(DataType::Int(32), attrs->padding[1]) + IntImm(DataType::Int(32), attrs->padding[3]); + PrimExpr padding_h = + IntImm(DataType::Int(32), attrs->padding[0]) + IntImm(DataType::Int(32), attrs->padding[2]); + PrimExpr padding_w = + IntImm(DataType::Int(32), attrs->padding[1]) + IntImm(DataType::Int(32), attrs->padding[3]); std::vector out_NCHW_shape; out_NCHW_shape.resize(4); out_NCHW_shape[0] = data_NCHW_shape[0]; out_NCHW_shape[1] = weight_OIHW_shape[0]; - PrimExpr numerator_h = input_h + padding_h - IntImm(DataType::Int(32), attrs->dilation[0]) * (kernel_h - 1) - 1; - PrimExpr numerator_w = input_w + padding_w - IntImm(DataType::Int(32), attrs->dilation[1]) * (kernel_w - 1) - 1; - out_NCHW_shape[2] = analyzer->Simplify(floordiv(numerator_h, IntImm(DataType::Int(32), attrs->strides[0])) + 1); - out_NCHW_shape[3] = 
analyzer->Simplify(floordiv(numerator_w, IntImm(DataType::Int(32), attrs->strides[1])) + 1); + PrimExpr numerator_h = + input_h + padding_h - IntImm(DataType::Int(32), attrs->dilation[0]) * (kernel_h - 1) - 1; + PrimExpr numerator_w = + input_w + padding_w - IntImm(DataType::Int(32), attrs->dilation[1]) * (kernel_w - 1) - 1; + out_NCHW_shape[2] = + analyzer->Simplify(floordiv(numerator_h, IntImm(DataType::Int(32), attrs->strides[0])) + 1); + out_NCHW_shape[3] = + analyzer->Simplify(floordiv(numerator_w, IntImm(DataType::Int(32), attrs->strides[1])) + 1); ffi::Array out_shape = out2NCHW.BackwardShape(out_NCHW_shape); return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice); @@ -512,21 +521,30 @@ StructInfo InferStructInfoConv3d(const Call& call, const BlockBuilder& ctx) { PrimExpr kernel_d = weight_OIDHW_shape[2]; PrimExpr kernel_h = weight_OIDHW_shape[3]; PrimExpr kernel_w = weight_OIDHW_shape[4]; - PrimExpr padding_d = IntImm(DataType::Int(32), attrs->padding[0]) + IntImm(DataType::Int(32), attrs->padding[3]); - PrimExpr padding_h = IntImm(DataType::Int(32), attrs->padding[1]) + IntImm(DataType::Int(32), attrs->padding[4]); - PrimExpr padding_w = IntImm(DataType::Int(32), attrs->padding[2]) + IntImm(DataType::Int(32), attrs->padding[5]); + PrimExpr padding_d = + IntImm(DataType::Int(32), attrs->padding[0]) + IntImm(DataType::Int(32), attrs->padding[3]); + PrimExpr padding_h = + IntImm(DataType::Int(32), attrs->padding[1]) + IntImm(DataType::Int(32), attrs->padding[4]); + PrimExpr padding_w = + IntImm(DataType::Int(32), attrs->padding[2]) + IntImm(DataType::Int(32), attrs->padding[5]); std::vector out_NCDHW_shape; out_NCDHW_shape.resize(5); out_NCDHW_shape[0] = data_NCDHW_shape[0]; out_NCDHW_shape[1] = weight_OIDHW_shape[0]; - PrimExpr numerator_d = input_d + padding_d - IntImm(DataType::Int(32), attrs->dilation[0]) * (kernel_d - 1) - 1; - PrimExpr numerator_h = input_h + padding_h - IntImm(DataType::Int(32), attrs->dilation[1]) * (kernel_h - 1) - 1; - 
PrimExpr numerator_w = input_w + padding_w - IntImm(DataType::Int(32), attrs->dilation[2]) * (kernel_w - 1) - 1; - out_NCDHW_shape[2] = analyzer->Simplify(floordiv(numerator_d, IntImm(DataType::Int(32), attrs->strides[0])) + 1); - out_NCDHW_shape[3] = analyzer->Simplify(floordiv(numerator_h, IntImm(DataType::Int(32), attrs->strides[1])) + 1); - out_NCDHW_shape[4] = analyzer->Simplify(floordiv(numerator_w, IntImm(DataType::Int(32), attrs->strides[2])) + 1); + PrimExpr numerator_d = + input_d + padding_d - IntImm(DataType::Int(32), attrs->dilation[0]) * (kernel_d - 1) - 1; + PrimExpr numerator_h = + input_h + padding_h - IntImm(DataType::Int(32), attrs->dilation[1]) * (kernel_h - 1) - 1; + PrimExpr numerator_w = + input_w + padding_w - IntImm(DataType::Int(32), attrs->dilation[2]) * (kernel_w - 1) - 1; + out_NCDHW_shape[2] = + analyzer->Simplify(floordiv(numerator_d, IntImm(DataType::Int(32), attrs->strides[0])) + 1); + out_NCDHW_shape[3] = + analyzer->Simplify(floordiv(numerator_h, IntImm(DataType::Int(32), attrs->strides[1])) + 1); + out_NCDHW_shape[4] = + analyzer->Simplify(floordiv(numerator_w, IntImm(DataType::Int(32), attrs->strides[2])) + 1); ffi::Array out_shape = out2NCDHW.BackwardShape(out_NCDHW_shape); return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice); @@ -701,7 +719,8 @@ StructInfo InferStructInfoConv1dTranspose(const Call& call, const BlockBuilder& PrimExpr input_w = data_NCW_shape[2]; PrimExpr kernel_w = weight_IOW_shape[2]; - PrimExpr padding_w = IntImm(DataType::Int(32), attrs->padding[0]) + IntImm(DataType::Int(32), attrs->padding[1]); + PrimExpr padding_w = + IntImm(DataType::Int(32), attrs->padding[0]) + IntImm(DataType::Int(32), attrs->padding[1]); std::vector out_NCW_shape; out_NCW_shape.resize(3); @@ -895,8 +914,10 @@ StructInfo InferStructInfoConv2dTranspose(const Call& call, const BlockBuilder& PrimExpr input_w = data_NCHW_shape[3]; PrimExpr kernel_h = weight_IOHW_shape[2]; PrimExpr kernel_w = weight_IOHW_shape[3]; - PrimExpr 
padding_h = IntImm(DataType::Int(32), attrs->padding[0]) + IntImm(DataType::Int(32), attrs->padding[2]); - PrimExpr padding_w = IntImm(DataType::Int(32), attrs->padding[1]) + IntImm(DataType::Int(32), attrs->padding[3]); + PrimExpr padding_h = + IntImm(DataType::Int(32), attrs->padding[0]) + IntImm(DataType::Int(32), attrs->padding[2]); + PrimExpr padding_w = + IntImm(DataType::Int(32), attrs->padding[1]) + IntImm(DataType::Int(32), attrs->padding[3]); std::vector out_NCHW_shape; out_NCHW_shape.resize(4); @@ -1132,9 +1153,12 @@ StructInfo InferStructInfoConv3dTranspose(const Call& call, const BlockBuilder& PrimExpr kernel_d = weight_IODHW_shape[2]; PrimExpr kernel_h = weight_IODHW_shape[3]; PrimExpr kernel_w = weight_IODHW_shape[4]; - PrimExpr padding_d = IntImm(DataType::Int(32), attrs->padding[0]) + IntImm(DataType::Int(32), attrs->padding[3]); - PrimExpr padding_h = IntImm(DataType::Int(32), attrs->padding[1]) + IntImm(DataType::Int(32), attrs->padding[4]); - PrimExpr padding_w = IntImm(DataType::Int(32), attrs->padding[2]) + IntImm(DataType::Int(32), attrs->padding[5]); + PrimExpr padding_d = + IntImm(DataType::Int(32), attrs->padding[0]) + IntImm(DataType::Int(32), attrs->padding[3]); + PrimExpr padding_h = + IntImm(DataType::Int(32), attrs->padding[1]) + IntImm(DataType::Int(32), attrs->padding[4]); + PrimExpr padding_w = + IntImm(DataType::Int(32), attrs->padding[2]) + IntImm(DataType::Int(32), attrs->padding[5]); std::vector out_NCDHW_shape; out_NCDHW_shape.resize(5); diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc index df432f9b8e46..2be119b788ec 100644 --- a/src/relax/op/nn/pooling.cc +++ b/src/relax/op/nn/pooling.cc @@ -100,7 +100,8 @@ StructInfo InferStructInfoPool1D(const Call& call, const BlockBuilder& ctx) { PrimExpr input_w = data_NCW_shape[2]; PrimExpr kernel_w = IntImm(DataType::Int(32), attrs->pool_size[0]); - PrimExpr padding_w = IntImm(DataType::Int(32), attrs->padding[0]) + IntImm(DataType::Int(32), attrs->padding[1]); + 
PrimExpr padding_w = + IntImm(DataType::Int(32), attrs->padding[0]) + IntImm(DataType::Int(32), attrs->padding[1]); arith::Analyzer* analyzer = ctx->GetAnalyzer(); std::vector out_NCW_shape; @@ -108,14 +109,15 @@ StructInfo InferStructInfoPool1D(const Call& call, const BlockBuilder& ctx) { out_NCW_shape[0] = data_NCW_shape[0]; out_NCW_shape[1] = data_NCW_shape[1]; - PrimExpr numerator_w = input_w + padding_w - IntImm(DataType::Int(32), attrs->dilation[0]) * (kernel_w - 1) - 1; + PrimExpr numerator_w = + input_w + padding_w - IntImm(DataType::Int(32), attrs->dilation[0]) * (kernel_w - 1) - 1; if (attrs->ceil_mode) { numerator_w += IntImm(DataType::Int(32), attrs->strides[0]) - 1; } PrimExpr raw_out_w = floordiv(numerator_w, IntImm(DataType::Int(32), attrs->strides[0])) + 1; if (attrs->ceil_mode) { - PrimExpr invalid_last_w = - (raw_out_w - 1) * IntImm(DataType::Int(32), attrs->strides[0]) >= input_w + IntImm(DataType::Int(32), attrs->padding[0]); + PrimExpr invalid_last_w = (raw_out_w - 1) * IntImm(DataType::Int(32), attrs->strides[0]) >= + input_w + IntImm(DataType::Int(32), attrs->padding[0]); out_NCW_shape[2] = analyzer->Simplify(if_then_else(invalid_last_w, raw_out_w - 1, raw_out_w)); } else { out_NCW_shape[2] = analyzer->Simplify(raw_out_w); @@ -225,8 +227,10 @@ StructInfo InferStructInfoPool2D(const Call& call, const BlockBuilder& ctx) { PrimExpr input_w = data_NCHW_shape[3]; PrimExpr kernel_h = IntImm(DataType::Int(32), attrs->pool_size[0]); PrimExpr kernel_w = IntImm(DataType::Int(32), attrs->pool_size[1]); - PrimExpr padding_h = IntImm(DataType::Int(32), attrs->padding[0]) + IntImm(DataType::Int(32), attrs->padding[2]); - PrimExpr padding_w = IntImm(DataType::Int(32), attrs->padding[1]) + IntImm(DataType::Int(32), attrs->padding[3]); + PrimExpr padding_h = + IntImm(DataType::Int(32), attrs->padding[0]) + IntImm(DataType::Int(32), attrs->padding[2]); + PrimExpr padding_w = + IntImm(DataType::Int(32), attrs->padding[1]) + IntImm(DataType::Int(32), 
attrs->padding[3]); arith::Analyzer* analyzer = ctx->GetAnalyzer(); std::vector out_NCHW_shape; @@ -234,8 +238,10 @@ StructInfo InferStructInfoPool2D(const Call& call, const BlockBuilder& ctx) { out_NCHW_shape[0] = data_NCHW_shape[0]; out_NCHW_shape[1] = data_NCHW_shape[1]; - PrimExpr numerator_h = input_h + padding_h - IntImm(DataType::Int(32), attrs->dilation[0]) * (kernel_h - 1) - 1; - PrimExpr numerator_w = input_w + padding_w - IntImm(DataType::Int(32), attrs->dilation[1]) * (kernel_w - 1) - 1; + PrimExpr numerator_h = + input_h + padding_h - IntImm(DataType::Int(32), attrs->dilation[0]) * (kernel_h - 1) - 1; + PrimExpr numerator_w = + input_w + padding_w - IntImm(DataType::Int(32), attrs->dilation[1]) * (kernel_w - 1) - 1; if (attrs->ceil_mode) { numerator_h += IntImm(DataType::Int(32), attrs->strides[0]) - 1; numerator_w += IntImm(DataType::Int(32), attrs->strides[1]) - 1; @@ -243,10 +249,10 @@ StructInfo InferStructInfoPool2D(const Call& call, const BlockBuilder& ctx) { PrimExpr raw_out_h = floordiv(numerator_h, IntImm(DataType::Int(32), attrs->strides[0])) + 1; PrimExpr raw_out_w = floordiv(numerator_w, IntImm(DataType::Int(32), attrs->strides[1])) + 1; if (attrs->ceil_mode) { - PrimExpr invalid_last_h = - (raw_out_h - 1) * IntImm(DataType::Int(32), attrs->strides[0]) >= input_h + IntImm(DataType::Int(32), attrs->padding[0]); - PrimExpr invalid_last_w = - (raw_out_w - 1) * IntImm(DataType::Int(32), attrs->strides[1]) >= input_w + IntImm(DataType::Int(32), attrs->padding[1]); + PrimExpr invalid_last_h = (raw_out_h - 1) * IntImm(DataType::Int(32), attrs->strides[0]) >= + input_h + IntImm(DataType::Int(32), attrs->padding[0]); + PrimExpr invalid_last_w = (raw_out_w - 1) * IntImm(DataType::Int(32), attrs->strides[1]) >= + input_w + IntImm(DataType::Int(32), attrs->padding[1]); out_NCHW_shape[2] = analyzer->Simplify(if_then_else(invalid_last_h, raw_out_h - 1, raw_out_h)); out_NCHW_shape[3] = analyzer->Simplify(if_then_else(invalid_last_w, raw_out_w - 1, 
raw_out_w)); } else { @@ -381,9 +387,12 @@ StructInfo InferStructInfoPool3D(const Call& call, const BlockBuilder& ctx) { PrimExpr kernel_d = IntImm(DataType::Int(32), attrs->pool_size[0]); PrimExpr kernel_h = IntImm(DataType::Int(32), attrs->pool_size[1]); PrimExpr kernel_w = IntImm(DataType::Int(32), attrs->pool_size[2]); - PrimExpr padding_d = IntImm(DataType::Int(32), attrs->padding[0]) + IntImm(DataType::Int(32), attrs->padding[3]); - PrimExpr padding_h = IntImm(DataType::Int(32), attrs->padding[1]) + IntImm(DataType::Int(32), attrs->padding[4]); - PrimExpr padding_w = IntImm(DataType::Int(32), attrs->padding[2]) + IntImm(DataType::Int(32), attrs->padding[5]); + PrimExpr padding_d = + IntImm(DataType::Int(32), attrs->padding[0]) + IntImm(DataType::Int(32), attrs->padding[3]); + PrimExpr padding_h = + IntImm(DataType::Int(32), attrs->padding[1]) + IntImm(DataType::Int(32), attrs->padding[4]); + PrimExpr padding_w = + IntImm(DataType::Int(32), attrs->padding[2]) + IntImm(DataType::Int(32), attrs->padding[5]); arith::Analyzer* analyzer = ctx->GetAnalyzer(); std::vector out_NCDHW_shape; @@ -391,9 +400,12 @@ StructInfo InferStructInfoPool3D(const Call& call, const BlockBuilder& ctx) { out_NCDHW_shape[0] = data_NCDHW_shape[0]; out_NCDHW_shape[1] = data_NCDHW_shape[1]; - PrimExpr numerator_d = input_d + padding_d - IntImm(DataType::Int(32), attrs->dilation[0]) * (kernel_d - 1) - 1; - PrimExpr numerator_h = input_h + padding_h - IntImm(DataType::Int(32), attrs->dilation[1]) * (kernel_h - 1) - 1; - PrimExpr numerator_w = input_w + padding_w - IntImm(DataType::Int(32), attrs->dilation[2]) * (kernel_w - 1) - 1; + PrimExpr numerator_d = + input_d + padding_d - IntImm(DataType::Int(32), attrs->dilation[0]) * (kernel_d - 1) - 1; + PrimExpr numerator_h = + input_h + padding_h - IntImm(DataType::Int(32), attrs->dilation[1]) * (kernel_h - 1) - 1; + PrimExpr numerator_w = + input_w + padding_w - IntImm(DataType::Int(32), attrs->dilation[2]) * (kernel_w - 1) - 1; if 
(attrs->ceil_mode) { numerator_d += IntImm(DataType::Int(32), attrs->strides[0]) - 1; numerator_h += IntImm(DataType::Int(32), attrs->strides[1]) - 1; @@ -403,12 +415,12 @@ StructInfo InferStructInfoPool3D(const Call& call, const BlockBuilder& ctx) { PrimExpr raw_out_h = floordiv(numerator_h, IntImm(DataType::Int(32), attrs->strides[1])) + 1; PrimExpr raw_out_w = floordiv(numerator_w, IntImm(DataType::Int(32), attrs->strides[2])) + 1; if (attrs->ceil_mode) { - PrimExpr invalid_last_d = - (raw_out_d - 1) * IntImm(DataType::Int(32), attrs->strides[0]) >= input_d + IntImm(DataType::Int(32), attrs->padding[0]); - PrimExpr invalid_last_h = - (raw_out_h - 1) * IntImm(DataType::Int(32), attrs->strides[1]) >= input_h + IntImm(DataType::Int(32), attrs->padding[1]); - PrimExpr invalid_last_w = - (raw_out_w - 1) * IntImm(DataType::Int(32), attrs->strides[2]) >= input_w + IntImm(DataType::Int(32), attrs->padding[2]); + PrimExpr invalid_last_d = (raw_out_d - 1) * IntImm(DataType::Int(32), attrs->strides[0]) >= + input_d + IntImm(DataType::Int(32), attrs->padding[0]); + PrimExpr invalid_last_h = (raw_out_h - 1) * IntImm(DataType::Int(32), attrs->strides[1]) >= + input_h + IntImm(DataType::Int(32), attrs->padding[1]); + PrimExpr invalid_last_w = (raw_out_w - 1) * IntImm(DataType::Int(32), attrs->strides[2]) >= + input_w + IntImm(DataType::Int(32), attrs->padding[2]); out_NCDHW_shape[2] = analyzer->Simplify(if_then_else(invalid_last_d, raw_out_d - 1, raw_out_d)); out_NCDHW_shape[3] = analyzer->Simplify(if_then_else(invalid_last_h, raw_out_h - 1, raw_out_h)); out_NCDHW_shape[4] = analyzer->Simplify(if_then_else(invalid_last_w, raw_out_w - 1, raw_out_w)); diff --git a/src/relax/op/vision/roi_align.cc b/src/relax/op/vision/roi_align.cc index 5c3ef52c6a62..e2dc4396a6d1 100644 --- a/src/relax/op/vision/roi_align.cc +++ b/src/relax/op/vision/roi_align.cc @@ -118,7 +118,8 @@ StructInfo InferStructInfoROIAlign(const Call& call, const BlockBuilder& ctx) { ffi::Array data_shape = 
data_sinfo->shape.as()->values; ffi::Array out_shape; if (attrs->layout == "NCHW") { - out_shape = {rois_shape->values[0], data_shape[1], IntImm(DataType::Int(32), attrs->pooled_size[0]), + out_shape = {rois_shape->values[0], data_shape[1], + IntImm(DataType::Int(32), attrs->pooled_size[0]), IntImm(DataType::Int(32), attrs->pooled_size[1])}; } else { out_shape = {rois_shape->values[0], IntImm(DataType::Int(32), attrs->pooled_size[0]), diff --git a/src/relax/op/vision/roi_pool.cc b/src/relax/op/vision/roi_pool.cc index 25e529308882..4a98a3629008 100644 --- a/src/relax/op/vision/roi_pool.cc +++ b/src/relax/op/vision/roi_pool.cc @@ -110,7 +110,8 @@ StructInfo InferStructInfoROIPool(const Call& call, const BlockBuilder& ctx) { ffi::Array data_shape = data_sinfo->shape.as()->values; ffi::Array out_shape = {rois_shape->values[0], data_shape[1], - IntImm(DataType::Int(32), attrs->pooled_size[0]), IntImm(DataType::Int(32), attrs->pooled_size[1])}; + IntImm(DataType::Int(32), attrs->pooled_size[0]), + IntImm(DataType::Int(32), attrs->pooled_size[1])}; return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice); } diff --git a/src/relax/transform/adjust_matmul_order.cc b/src/relax/transform/adjust_matmul_order.cc index 54bca2aaefdf..9ea47aa64844 100644 --- a/src/relax/transform/adjust_matmul_order.cc +++ b/src/relax/transform/adjust_matmul_order.cc @@ -28,6 +28,7 @@ #include #include #include +#include #include #include @@ -35,7 +36,6 @@ #include "../op/tensor/linear_algebra.h" #include "../op/tensor/manipulate.h" -#include namespace tvm { namespace relax { @@ -73,7 +73,7 @@ std::tuple)>> auto pat = pat_matmul_on_lhs | pat_matmul_on_rhs | pat_permuted_matmul_on_lhs | pat_permuted_matmul_on_rhs; - PrimExpr symbolic_var_constraints = IntImm(DataType::Bool(), 1); + PrimExpr symbolic_var_constraints = tirx::const_true(); auto upper_bounds = func->GetAttr>("tir_var_upper_bound"); auto lower_bounds = func->GetAttr>("tir_var_lower_bound"); diff --git 
a/src/relax/transform/allocate_workspace.cc b/src/relax/transform/allocate_workspace.cc index 6b8f3c776185..718214d49157 100644 --- a/src/relax/transform/allocate_workspace.cc +++ b/src/relax/transform/allocate_workspace.cc @@ -61,7 +61,8 @@ class ExternFunctionRewriter : ExprMutator { // Append the workspace parameter to this function. ffi::Array new_params = func_node->params; - auto sinfo = TensorStructInfo(ShapeExpr({IntImm(DataType::Int(32), max_workspace_size_)}), DataType::UInt(8)); + auto sinfo = TensorStructInfo(ShapeExpr({IntImm(DataType::Int(32), max_workspace_size_)}), + DataType::UInt(8)); Var workspace_param(name_sup_->FreshName("workspace"), sinfo); if (func_node->GetAttr(attr::kCodegen)) { diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index 3db3c12f1e96..d0089734ad24 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -24,8 +24,8 @@ #include #include #include -#include #include +#include #include #include @@ -155,7 +155,7 @@ class SymbolicMatcher : ExprFunctor* var_remap_; - PrimExpr must_prove_ = IntImm(DataType::Bool(), 1); + PrimExpr must_prove_ = const_true(); }; /*! 
diff --git a/src/s_tir/meta_schedule/arg_info.cc b/src/s_tir/meta_schedule/arg_info.cc index 4163a2b8b552..dc452b370037 100644 --- a/src/s_tir/meta_schedule/arg_info.cc +++ b/src/s_tir/meta_schedule/arg_info.cc @@ -149,8 +149,7 @@ TensorInfo TensorInfo::FromJSON(const ffi::ObjectRef& json_obj) { << "\nThe error is: " << e.what(); } std::vector s; - std::transform(shape.begin(), shape.end(), std::back_inserter(s), - [](int64_t i) { return i; }); + std::transform(shape.begin(), shape.end(), std::back_inserter(s), [](int64_t i) { return i; }); return TensorInfo(DataType(dtype), ffi::Shape(s.begin(), s.end())); } diff --git a/src/s_tir/meta_schedule/database/json_database.cc b/src/s_tir/meta_schedule/database/json_database.cc index a6c656f5098b..5b5395f33b37 100644 --- a/src/s_tir/meta_schedule/database/json_database.cc +++ b/src/s_tir/meta_schedule/database/json_database.cc @@ -114,11 +114,12 @@ class JSONDatabaseNode : public DatabaseNode { void CommitTuningRecord(const TuningRecord& record) { this->tuning_records_.insert(record); - JSONFileAppendLine(this->path_tuning_record, - JSONDumps(ffi::Array{ - /*workload_index=*/IntImm(DataType::Int(32), this->workloads2idx_.at(record->workload)), - /*tuning_record=*/record->AsJSON() // - })); + JSONFileAppendLine( + this->path_tuning_record, + JSONDumps(ffi::Array{ + /*workload_index=*/IntImm(DataType::Int(32), this->workloads2idx_.at(record->workload)), + /*tuning_record=*/record->AsJSON() // + })); } ffi::Array GetTopK(const Workload& workload, int top_k) { diff --git a/src/s_tir/meta_schedule/mutator/mutate_parallel.cc b/src/s_tir/meta_schedule/mutator/mutate_parallel.cc index 7e3b20fc3dea..95a2c03b8df1 100644 --- a/src/s_tir/meta_schedule/mutator/mutate_parallel.cc +++ b/src/s_tir/meta_schedule/mutator/mutate_parallel.cc @@ -53,7 +53,7 @@ bool IsAnnotateWithParallel(const Instruction& inst) { */ Instruction ReplaceAnnValue(Instruction inst, int64_t ann_val) { TVM_FFI_ICHECK_EQ(inst->inputs.size(), 2); - return 
Instruction(/*kind=*/inst->kind, // + return Instruction(/*kind=*/inst->kind, // /*inputs=*/{inst->inputs[0], IntImm(DataType::Int(32), ann_val)}, // /*attrs=*/inst->attrs, /*outputs=*/inst->outputs); diff --git a/src/s_tir/meta_schedule/postproc/rewrite_cooperative_fetch.cc b/src/s_tir/meta_schedule/postproc/rewrite_cooperative_fetch.cc index 3e0cdd8ac88d..ac85b92dc63a 100644 --- a/src/s_tir/meta_schedule/postproc/rewrite_cooperative_fetch.cc +++ b/src/s_tir/meta_schedule/postproc/rewrite_cooperative_fetch.cc @@ -198,25 +198,28 @@ bool RewriteCooperativeFetchNode::Apply(const s_tir::Schedule& sch) { } if (thread_extent_y != -1) { if (vector_lane > 1) { - ffi::Array split = sch->Split(fused, {std::nullopt, // - IntImm(DataType::Int(32), thread_extent_y), // - IntImm(DataType::Int(32), thread_extent_x), // - IntImm(DataType::Int(32), vector_lane)}); + ffi::Array split = + sch->Split(fused, {std::nullopt, // + IntImm(DataType::Int(32), thread_extent_y), // + IntImm(DataType::Int(32), thread_extent_x), // + IntImm(DataType::Int(32), vector_lane)}); sch->Vectorize(split[3]); sch->Bind(split[2], "threadIdx.x"); sch->Bind(split[1], "threadIdx.y"); } else { - ffi::Array split = sch->Split(fused, {std::nullopt, // - IntImm(DataType::Int(32), thread_extent_y), // - IntImm(DataType::Int(32), thread_extent_x)}); + ffi::Array split = + sch->Split(fused, {std::nullopt, // + IntImm(DataType::Int(32), thread_extent_y), // + IntImm(DataType::Int(32), thread_extent_x)}); sch->Bind(split[2], "threadIdx.x"); sch->Bind(split[1], "threadIdx.y"); } } else { if (vector_lane > 1) { - ffi::Array split = sch->Split(fused, {std::nullopt, // - IntImm(DataType::Int(32), thread_extent_x), // - IntImm(DataType::Int(32), vector_lane)}); + ffi::Array split = + sch->Split(fused, {std::nullopt, // + IntImm(DataType::Int(32), thread_extent_x), // + IntImm(DataType::Int(32), vector_lane)}); sch->Vectorize(split[2]); sch->Bind(split[1], "threadIdx.x"); } else { diff --git 
a/src/s_tir/meta_schedule/schedule/cuda/thread_bind.cc b/src/s_tir/meta_schedule/schedule/cuda/thread_bind.cc index 0fa916786787..365a558930ba 100644 --- a/src/s_tir/meta_schedule/schedule/cuda/thread_bind.cc +++ b/src/s_tir/meta_schedule/schedule/cuda/thread_bind.cc @@ -85,9 +85,9 @@ ffi::Array BindSpatialLoop(Schedule sch, LoopRV loop, int64_t max_thread sch->Bind(splits[1], "threadIdx.x"); return {splits[0], splits[1]}; } else { - ffi::Array splits = sch->Split(loop, {std::nullopt, - IntImm(DataType::Int(32), max_threadblocks), // - IntImm(DataType::Int(32), max_threads_per_block)}); + ffi::Array splits = + sch->Split(loop, {std::nullopt, IntImm(DataType::Int(32), max_threadblocks), // + IntImm(DataType::Int(32), max_threads_per_block)}); TVM_FFI_ICHECK_EQ(splits.size(), 3); sch->Reorder({splits[1], splits[2], splits[0]}); sch->Bind(splits[1], "blockIdx.x"); diff --git a/src/s_tir/meta_schedule/schedule/cuda/winograd.cc b/src/s_tir/meta_schedule/schedule/cuda/winograd.cc index 7beaca5698a7..47e559d157b5 100644 --- a/src/s_tir/meta_schedule/schedule/cuda/winograd.cc +++ b/src/s_tir/meta_schedule/schedule/cuda/winograd.cc @@ -150,8 +150,10 @@ TVM_FFI_STATIC_INIT_BLOCK() { SBlockRV output = sch->GetConsumers(inverse)[0]; ffi::Array nchw = sch->GetLoops(output); TVM_FFI_ICHECK_EQ(nchw.size(), 4); - ffi::Array hs = sch->Split(nchw[2], {std::nullopt, IntImm(DataType::Int(32), tile_size)}); - ffi::Array ws = sch->Split(nchw[3], {std::nullopt, IntImm(DataType::Int(32), tile_size)}); + ffi::Array hs = + sch->Split(nchw[2], {std::nullopt, IntImm(DataType::Int(32), tile_size)}); + ffi::Array ws = + sch->Split(nchw[3], {std::nullopt, IntImm(DataType::Int(32), tile_size)}); sch->Reorder({hs[0], ws[0], hs[1], ws[1]}); outer = ws[0]; } diff --git a/src/s_tir/meta_schedule/schedule_rule/add_rfactor.cc b/src/s_tir/meta_schedule/schedule_rule/add_rfactor.cc index e5436e5efc41..2399739ff93e 100644 --- a/src/s_tir/meta_schedule/schedule_rule/add_rfactor.cc +++ 
b/src/s_tir/meta_schedule/schedule_rule/add_rfactor.cc @@ -115,7 +115,8 @@ ffi::Array AddRFactorNode::Apply(const s_tir::Schedule& sch, // Annotate that the rfactor block, which is now the producer of the original block, needs to // be considered by the rule Random-Compute-Location. - sch_tmp->Annotate(block_rv, s_tir::attr::meta_schedule_random_compute_producer, IntImm(DataType::Int(32), 1)); + sch_tmp->Annotate(block_rv, s_tir::attr::meta_schedule_random_compute_producer, + IntImm(DataType::Int(32), 1)); res.push_back(sch_tmp); } catch (const tvm::ffi::Error& e) { } diff --git a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc index 68bdf960734b..7431e433969e 100644 --- a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc +++ b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc @@ -675,7 +675,8 @@ std::vector MultiLevelTilingTensorCoreNode::AddSoftwarePipeline( } } else { // Add local stage and double buffering - sch->Annotate(cache_read, s_tir::attr::manifest_shared_memory_local_stage, IntImm(DataType::Int(32), 1)); + sch->Annotate(cache_read, s_tir::attr::manifest_shared_memory_local_stage, + IntImm(DataType::Int(32), 1)); sch->Annotate(cache_read, s_tir::attr::double_buffer_scope, IntImm(DataType::Int(32), 0)); } } diff --git a/src/s_tir/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc b/src/s_tir/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc index 8115dc91a8ee..bcbbf6746ed3 100644 --- a/src/s_tir/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc +++ b/src/s_tir/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc @@ -68,7 +68,8 @@ class ParallelizeVectorizeUnrollNode : public ScheduleRuleNode { } // Vectorization if (max_vectorize_extent != -1) { - sch->Annotate(root_rv, s_tir::attr::meta_schedule_vectorize, IntImm(DataType::Int(32), max_vectorize_extent)); + sch->Annotate(root_rv, 
s_tir::attr::meta_schedule_vectorize, + IntImm(DataType::Int(32), max_vectorize_extent)); } // Unroll if (!unroll_max_steps.empty() && !s_tir::CheckSpatialPrimFunc(sch, root_rv)) { diff --git a/src/s_tir/schedule/analysis/layout.cc b/src/s_tir/schedule/analysis/layout.cc index 7700a94e2d54..035faee48436 100644 --- a/src/s_tir/schedule/analysis/layout.cc +++ b/src/s_tir/schedule/analysis/layout.cc @@ -218,14 +218,16 @@ ffi::Optional SuggestIndexMap(const Buffer& buffer, const ffi::Array

Bind(index, Range::FromMinExtent(0, IntImm(DataType::Int(32), split_exprs[i].extent))); + analyzer->Bind(index, + Range::FromMinExtent(0, IntImm(DataType::Int(32), split_exprs[i].extent))); } // Step 6.2: Fuse all the indices. This is the inverse of Step 5.2. PrimExpr flattened_index = make_const(indices[0]->dtype, 0); int64_t stride = 1; for (int i = static_cast(split_exprs.size()) - 1; i >= 0; --i) { - flattened_index = inv_permuted_indices[i] * IntImm(DataType::Int(32), stride) + flattened_index; + flattened_index = + inv_permuted_indices[i] * IntImm(DataType::Int(32), stride) + flattened_index; stride *= split_exprs[i].extent; } // Step 6.3: Split the flattened index into multiple indices. This is the inverse of Step 5.1. diff --git a/src/s_tir/schedule/primitive/compute_at.cc b/src/s_tir/schedule/primitive/compute_at.cc index 0ea79faab317..79dd56241cf1 100644 --- a/src/s_tir/schedule/primitive/compute_at.cc +++ b/src/s_tir/schedule/primitive/compute_at.cc @@ -820,7 +820,8 @@ struct ComputeAtTraits : public UnpackedInstTraits { } static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block_rv, - ffi::String loop_rv, IntImm preserve_unit_loops, IntImm index) { + ffi::String loop_rv, IntImm preserve_unit_loops, + IntImm index) { PythonAPICall py("compute_at"); py.Input("block", block_rv); py.Input("loop", loop_rv); @@ -844,12 +845,12 @@ struct ReverseComputeAtTraits : public UnpackedInstTraitsReverseComputeAt(block_rv, loop_rv, preserve_unit_loops->value != 0, - index->value); + return sch->ReverseComputeAt(block_rv, loop_rv, preserve_unit_loops->value != 0, index->value); } static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block_rv, - ffi::String loop_rv, IntImm preserve_unit_loops, IntImm index) { + ffi::String loop_rv, IntImm preserve_unit_loops, + IntImm index) { PythonAPICall py("reverse_compute_at"); py.Input("block", block_rv); py.Input("loop", loop_rv); diff --git a/src/s_tir/schedule/primitive/loop_transformation.cc 
b/src/s_tir/schedule/primitive/loop_transformation.cc index f86f5d3b3fa5..8011b09d0c29 100644 --- a/src/s_tir/schedule/primitive/loop_transformation.cc +++ b/src/s_tir/schedule/primitive/loop_transformation.cc @@ -743,13 +743,14 @@ class LoopReconstructor : private StmtMutator { new_stmts.push_back(new_stmt); this->need_remove_loop_.push_back(loops_[i].back()); } - auto new_loop = For(new_loop_vars[0], IntImm(DataType::Int(32), 0), new_loop_extents[0], ForKind::kSerial, - SeqStmt(std::move(new_stmts))); + auto new_loop = For(new_loop_vars[0], IntImm(DataType::Int(32), 0), new_loop_extents[0], + ForKind::kSerial, SeqStmt(std::move(new_stmts))); this->new_inner_loop_ = new_loop; for (size_t i = 1; i < new_loop_vars.size(); ++i) { const Var& loop_var = new_loop_vars[i]; const PrimExpr& loop_extent = new_loop_extents[i]; - new_loop = For(loop_var, IntImm(DataType::Int(32), 0), loop_extent, ForKind::kSerial, new_loop); + new_loop = + For(loop_var, IntImm(DataType::Int(32), 0), loop_extent, ForKind::kSerial, new_loop); } this->new_outer_loop_ = new_loop; } @@ -1313,7 +1314,8 @@ struct FuseTraits : public UnpackedInstTraits { } static ffi::String UnpackedAsPython(ffi::Array outputs, - ffi::Array loop_rvs, IntImm preserve_unit_iters) { + ffi::Array loop_rvs, + IntImm preserve_unit_iters) { PythonAPICall py("fuse"); for (const ffi::String& loop_rv : loop_rvs) { py.Input("", loop_rv); diff --git a/src/s_tir/schedule/primitive/read_write_at.cc b/src/s_tir/schedule/primitive/read_write_at.cc index 04ef08b9d738..7a9e00cbf371 100644 --- a/src/s_tir/schedule/primitive/read_write_at.cc +++ b/src/s_tir/schedule/primitive/read_write_at.cc @@ -306,7 +306,8 @@ struct ReadWriteAtImpl { } Stmt stmt = BufferStore(copy_to, /*value=*/BufferLoad(copy_from, indices), /*indices=*/indices); for (int i = n - 1; i >= 0; --i) { - stmt = For(loop_vars[i], IntImm(DataType::Int(32), 0), domain[i]->extent, ForKind::kSerial, stmt); + stmt = For(loop_vars[i], IntImm(DataType::Int(32), 0), 
domain[i]->extent, ForKind::kSerial, + stmt); } return SBlockRealize( /*values=*/iter_values, diff --git a/src/s_tir/schedule/traced_schedule.cc b/src/s_tir/schedule/traced_schedule.cc index 5ee3c377cc31..22465846e86c 100644 --- a/src/s_tir/schedule/traced_schedule.cc +++ b/src/s_tir/schedule/traced_schedule.cc @@ -76,11 +76,13 @@ ffi::Array TracedScheduleNode::SamplePerfectTile( max_innermost_factor, &decision), /*convert_negone_to_none=*/true); static const InstructionKind& kind = InstructionKind::Get("SamplePerfectTile"); - trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // - /*inputs=*/{loop_rv}, - /*attrs=*/{IntImm(DataType::Int(32), n), IntImm(DataType::Int(32), max_innermost_factor)}, - /*outputs=*/results), - /*decision=*/decision); + trace_->Append( + /*inst=*/Instruction( + /*kind=*/kind, // + /*inputs=*/{loop_rv}, + /*attrs=*/{IntImm(DataType::Int(32), n), IntImm(DataType::Int(32), max_innermost_factor)}, + /*outputs=*/results), + /*decision=*/decision); return results; } @@ -94,7 +96,9 @@ ffi::Array TracedScheduleNode::SamplePartitionedTile( trace_->Append(/*inst=*/Instruction( /*kind=*/kind, // /*inputs=*/{loop_rv}, - /*attrs=*/{IntImm(DataType::Int(32), n), IntImm(DataType::Int(32), partition_pos), IntImm(DataType::Int(32), innerpart_factor)}, + /*attrs=*/ + {IntImm(DataType::Int(32), n), IntImm(DataType::Int(32), partition_pos), + IntImm(DataType::Int(32), innerpart_factor)}, /*outputs=*/results), /*decision=*/decision); return results; @@ -362,10 +366,11 @@ SBlockRV TracedScheduleNode::CacheRead(const SBlockRV& block_rv, int read_buffer ConcreteScheduleNode::CacheRead(block_rv, read_buffer_index, storage_scope, consumer_blocks); static const InstructionKind& kind = InstructionKind::Get("CacheRead"); - trace_->Append(/*inst=*/Instruction(/*kind=*/kind, - /*inputs=*/{block_rv, consumer_blocks}, - /*attrs=*/{IntImm(DataType::Int(32), read_buffer_index), storage_scope}, - /*outputs=*/{result})); + trace_->Append( + /*inst=*/Instruction(/*kind=*/kind, 
+ /*inputs=*/{block_rv, consumer_blocks}, + /*attrs=*/{IntImm(DataType::Int(32), read_buffer_index), storage_scope}, + /*outputs=*/{result})); return result; } @@ -376,10 +381,11 @@ SBlockRV TracedScheduleNode::CacheWrite(const SBlockRV& block_rv, int write_buff consumer_blocks); static const InstructionKind& kind = InstructionKind::Get("CacheWrite"); - trace_->Append(/*inst=*/Instruction(/*kind=*/kind, - /*inputs=*/{block_rv, consumer_blocks}, - /*attrs=*/{IntImm(DataType::Int(32), write_buffer_index), storage_scope}, - /*outputs=*/{result})); + trace_->Append( + /*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{block_rv, consumer_blocks}, + /*attrs=*/{IntImm(DataType::Int(32), write_buffer_index), storage_scope}, + /*outputs=*/{result})); return result; } @@ -425,10 +431,11 @@ ffi::Array TracedScheduleNode::CacheInplace(const SBlockRV& block_rv, results.push_back(r); } static const InstructionKind& kind = InstructionKind::Get("CacheInplace"); - trace_->Append(/*inst=*/Instruction(/*kind=*/kind, - /*inputs=*/{block_rv}, - /*attrs=*/{IntImm(DataType::Int(32), read_buffer_index), storage_scope}, - /*outputs=*/results)); + trace_->Append( + /*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{block_rv}, + /*attrs=*/{IntImm(DataType::Int(32), read_buffer_index), storage_scope}, + /*outputs=*/results)); return result; } @@ -442,10 +449,11 @@ ffi::Array TracedScheduleNode::CacheIndex(const SBlockRV& block_rv, outputs.push_back(r); } static const InstructionKind& kind = InstructionKind::Get("CacheIndex"); - trace_->Append(/*inst=*/Instruction(/*kind=*/kind, - /*inputs=*/{block_rv}, - /*attrs=*/{storage_scope, IntImm(DataType::Int(32), cse_thresh)}, - /*outputs=*/outputs)); + trace_->Append( + /*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{block_rv}, + /*attrs=*/{storage_scope, IntImm(DataType::Int(32), cse_thresh)}, + /*outputs=*/outputs)); return result; } @@ -454,10 +462,13 @@ SBlockRV TracedScheduleNode::ReIndex(const SBlockRV& block_rv, int buffer_index, SBlockRV 
result = ConcreteScheduleNode::ReIndex(block_rv, buffer_index, buffer_index_type); static const InstructionKind& kind = InstructionKind::Get("ReIndex"); - trace_->Append(/*inst=*/Instruction(/*kind=*/kind, - /*inputs=*/{block_rv}, - /*attrs=*/{IntImm(DataType::Int(32), buffer_index), IntImm(DataType::Int(32), static_cast(buffer_index_type))}, - /*outputs=*/{result})); + trace_->Append( + /*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{block_rv}, + /*attrs=*/ + {IntImm(DataType::Int(32), buffer_index), + IntImm(DataType::Int(32), static_cast(buffer_index_type))}, + /*outputs=*/{result})); return result; } @@ -469,10 +480,11 @@ SBlockRV TracedScheduleNode::ReadAt(const LoopRV& loop_rv, const SBlockRV& block ConcreteScheduleNode::ReadAt(loop_rv, block_rv, read_buffer_index, storage_scope); static const InstructionKind& kind = InstructionKind::Get("ReadAt"); - trace_->Append(/*inst=*/Instruction(/*kind=*/kind, - /*inputs=*/{loop_rv, block_rv}, - /*attrs=*/{IntImm(DataType::Int(32), read_buffer_index), storage_scope}, - /*outputs=*/{result})); + trace_->Append( + /*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{loop_rv, block_rv}, + /*attrs=*/{IntImm(DataType::Int(32), read_buffer_index), storage_scope}, + /*outputs=*/{result})); return result; } @@ -482,10 +494,11 @@ SBlockRV TracedScheduleNode::WriteAt(const LoopRV& loop_rv, const SBlockRV& bloc ConcreteScheduleNode::WriteAt(loop_rv, block_rv, write_buffer_index, storage_scope); static const InstructionKind& kind = InstructionKind::Get("WriteAt"); - trace_->Append(/*inst=*/Instruction(/*kind=*/kind, - /*inputs=*/{loop_rv, block_rv}, - /*attrs=*/{IntImm(DataType::Int(32), write_buffer_index), storage_scope}, - /*outputs=*/{result})); + trace_->Append( + /*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{loop_rv, block_rv}, + /*attrs=*/{IntImm(DataType::Int(32), write_buffer_index), storage_scope}, + /*outputs=*/{result})); return result; } @@ -497,10 +510,12 @@ void TracedScheduleNode::ComputeAt(const SBlockRV& 
block_rv, const LoopRV& loop_ static const InstructionKind& kind = InstructionKind::Get("ComputeAt"); trace_->Append( - /*inst=*/Instruction(/*kind=*/kind, - /*inputs=*/{block_rv, loop_rv}, - /*attrs=*/{IntImm(DataType::Int(32), preserve_unit_loops), IntImm(DataType::Int(32), index)}, - /*outputs=*/{})); + /*inst=*/Instruction( + /*kind=*/kind, + /*inputs=*/{block_rv, loop_rv}, + /*attrs=*/ + {IntImm(DataType::Int(32), preserve_unit_loops), IntImm(DataType::Int(32), index)}, + /*outputs=*/{})); } void TracedScheduleNode::ReverseComputeAt(const SBlockRV& block_rv, const LoopRV& loop_rv, @@ -508,10 +523,11 @@ void TracedScheduleNode::ReverseComputeAt(const SBlockRV& block_rv, const LoopRV ConcreteScheduleNode::ReverseComputeAt(block_rv, loop_rv, preserve_unit_loops, index); static const InstructionKind& kind = InstructionKind::Get("ReverseComputeAt"); - trace_->Append(/*inst=*/Instruction(/*kind=*/kind, - /*inputs=*/{block_rv, loop_rv}, - /*attrs=*/{IntImm(DataType::Int(32), preserve_unit_loops), IntImm(DataType::Int(32), index)}, - /*outputs=*/{})); + trace_->Append(/*inst=*/Instruction( + /*kind=*/kind, + /*inputs=*/{block_rv, loop_rv}, + /*attrs=*/{IntImm(DataType::Int(32), preserve_unit_loops), IntImm(DataType::Int(32), index)}, + /*outputs=*/{})); } void TracedScheduleNode::ComputeInline(const SBlockRV& block_rv) { @@ -576,7 +592,9 @@ void TracedScheduleNode::StorageAlign(const SBlockRV& block_rv, int buffer_index trace_->Append(/*inst=*/Instruction( /*kind=*/kind, /*inputs=*/{block_rv}, - /*attrs=*/{IntImm(DataType::Int(32), buffer_index), IntImm(DataType::Int(32), axis), IntImm(DataType::Int(32), factor), IntImm(DataType::Int(32), offset)}, + /*attrs=*/ + {IntImm(DataType::Int(32), buffer_index), IntImm(DataType::Int(32), axis), + IntImm(DataType::Int(32), factor), IntImm(DataType::Int(32), offset)}, /*outputs=*/{})); } @@ -704,7 +722,8 @@ void TracedScheduleNode::TransformLayout(const SBlockRV& block_rv, int buffer_in /*kind=*/kind, /*inputs=*/{block_rv, 
index_map}, /*attrs=*/ - {IntImm(DataType::Int(32), buffer_index), IntImm(DataType::Int(32), static_cast(buffer_index_type)), pad_value, + {IntImm(DataType::Int(32), buffer_index), + IntImm(DataType::Int(32), static_cast(buffer_index_type)), pad_value, IntImm(DataType::Bool(), assume_injective_transform)}, /*outputs=*/{})); } @@ -728,7 +747,9 @@ void TracedScheduleNode::SetAxisSeparator(const SBlockRV& block_rv, int buffer_i trace_->Append(/*inst=*/Instruction( /*kind=*/kind, /*inputs=*/{block_rv}, - /*attrs=*/{IntImm(DataType::Int(32), buffer_index), IntImm(DataType::Int(32), static_cast(buffer_index_type)), axis_separators}, + /*attrs=*/ + {IntImm(DataType::Int(32), buffer_index), + IntImm(DataType::Int(32), static_cast(buffer_index_type)), axis_separators}, /*outputs=*/{})); } @@ -796,7 +817,9 @@ void TracedScheduleNode::AnnotateBufferAccess(const SBlockRV& block_rv, int buff static const InstructionKind& kind = InstructionKind::Get("AnnotateBufferAccess"); trace_->Append(/*inst=*/Instruction( /*kind=*/kind, - /*inputs=*/{block_rv, IntImm(DataType::Int(32), buffer_index), IntImm(DataType::Int(32), static_cast(buffer_index_type)), index_map}, + /*inputs=*/ + {block_rv, IntImm(DataType::Int(32), buffer_index), + IntImm(DataType::Int(32), static_cast(buffer_index_type)), index_map}, /*attrs=*/{}, /*outputs=*/{})); } diff --git a/src/s_tir/transform/default_gpu_schedule.cc b/src/s_tir/transform/default_gpu_schedule.cc index dddaed193b4c..da57252541ad 100644 --- a/src/s_tir/transform/default_gpu_schedule.cc +++ b/src/s_tir/transform/default_gpu_schedule.cc @@ -70,15 +70,17 @@ void ThreadBind(s_tir::Schedule sch, const s_tir::SBlockRV& block, int64_t max_t } // schedule the fused loop if (product > max_thread_per_block * max_threadblocks) { - ffi::Array splits = sch->Split( - fused, - /*factors=*/{std::nullopt, IntImm(DataType::Int(32), max_threadblocks), IntImm(DataType::Int(32), max_thread_per_block)}); + ffi::Array splits = + sch->Split(fused, + 
/*factors=*/{std::nullopt, IntImm(DataType::Int(32), max_threadblocks), + IntImm(DataType::Int(32), max_thread_per_block)}); sch->Reorder(/*ordered_loop_rvs=*/{splits[1], splits[2], splits[0]}); sch->Bind(splits[1], "blockIdx.x"); sch->Bind(splits[2], "threadIdx.x"); } else { ffi::Array splits = sch->Split( - fused, /*factors=*/{std::nullopt, IntImm(DataType::Int(32), std::min(product, max_thread_per_block))}); + fused, /*factors=*/{std::nullopt, + IntImm(DataType::Int(32), std::min(product, max_thread_per_block))}); sch->Bind(splits[0], "blockIdx.x"); sch->Bind(splits[1], "threadIdx.x"); } diff --git a/src/s_tir/transform/lower_cross_thread_reduction.cc b/src/s_tir/transform/lower_cross_thread_reduction.cc index a07ecb5dd6eb..361466a2f6a1 100644 --- a/src/s_tir/transform/lower_cross_thread_reduction.cc +++ b/src/s_tir/transform/lower_cross_thread_reduction.cc @@ -335,8 +335,8 @@ Stmt TransformReductionBlock(const SBlockRealizeNode* realize, ffi::Array inits; inits.reserve(n_buffers); for (int i = 0; i < n_buffers; ++i) { - inits.push_back( - BufferStore(it_buffers.value()[i], reducer->identity_element[i], {IntImm(DataType::Int(32), 0)})); + inits.push_back(BufferStore(it_buffers.value()[i], reducer->identity_element[i], + {IntImm(DataType::Int(32), 0)})); } stmts.push_back(SBlockRealize(/*iter_values=*/{}, /*predicate=*/const_true(), @@ -464,8 +464,8 @@ Stmt TransformReductionBlock(const SBlockRealizeNode* realize, wb_indices.push_back(Substitute(old_wb_indices[d], var_map)); } for (int i = 0; i < n_buffers; ++i) { - wb_updates.push_back( - BufferStore(wb_buffers[i], BufferLoad(ct_buffers[i], {IntImm(DataType::Int(32), 0)}), wb_indices)); + wb_updates.push_back(BufferStore( + wb_buffers[i], BufferLoad(ct_buffers[i], {IntImm(DataType::Int(32), 0)}), wb_indices)); wb_regions.push_back(BufferRegion(wb_buffers[i], region)); } diff --git a/src/s_tir/transform/memhammer_lower_auto_copy.cc b/src/s_tir/transform/memhammer_lower_auto_copy.cc index 
76e1d0302b70..3db122b2ea4e 100644 --- a/src/s_tir/transform/memhammer_lower_auto_copy.cc +++ b/src/s_tir/transform/memhammer_lower_auto_copy.cc @@ -486,7 +486,8 @@ class AutoPadder { } else { int64_t extent = warp_thread_extent_.Get(op->thread_binding.value()->thread_tag).value_or(1); - var_range_.Set(op->loop_var, Range::FromMinExtent(op->min, IntImm(DataType::Int(64), extent))); + var_range_.Set(op->loop_var, + Range::FromMinExtent(op->min, IntImm(DataType::Int(64), extent))); } if (op->kind == ForKind::kVectorized) { vector_var = op->loop_var; diff --git a/src/s_tir/transform/transform_mma_buffer_layout.cc b/src/s_tir/transform/transform_mma_buffer_layout.cc index a6451286d108..d3518ccd81ca 100644 --- a/src/s_tir/transform/transform_mma_buffer_layout.cc +++ b/src/s_tir/transform/transform_mma_buffer_layout.cc @@ -67,8 +67,8 @@ class MmaBufferLayoutTransformer : public StmtExprMutator { for (size_t i = 0; i < size - 2; ++i) { new_shape.push_back(buffer->shape[i]); } - new_shape.insert(new_shape.end(), - {IntImm(DataType::Int(32), dim0->value / 16), IntImm(DataType::Int(32), dim1->value / 8), 2, 2}); + new_shape.insert(new_shape.end(), {IntImm(DataType::Int(32), dim0->value / 16), + IntImm(DataType::Int(32), dim1->value / 8), 2, 2}); Buffer new_buffer = decl_buffer(std::move(new_shape), buffer->dtype, buffer->name, "local", buffer->axis_separators); @@ -89,8 +89,8 @@ class MmaBufferLayoutTransformer : public StmtExprMutator { for (size_t i = 0; i < size - 2; ++i) { new_shape.push_back(buffer->shape[i]); } - new_shape.insert(new_shape.end(), - {IntImm(DataType::Int(32), dim0->value / 32), IntImm(DataType::Int(32), dim1->value / 8), 4, 2}); + new_shape.insert(new_shape.end(), {IntImm(DataType::Int(32), dim0->value / 32), + IntImm(DataType::Int(32), dim1->value / 8), 4, 2}); Buffer new_buffer = decl_buffer(std::move(new_shape), buffer->dtype, buffer->name, "local", buffer->axis_separators); @@ -111,8 +111,8 @@ class MmaBufferLayoutTransformer : public StmtExprMutator 
{ for (size_t i = 0; i < size - 2; ++i) { new_shape.push_back(buffer->shape[i]); } - new_shape.insert(new_shape.end(), - {IntImm(DataType::Int(32), dim0->value / 8), IntImm(DataType::Int(32), dim1->value / 32), 1, 8}); + new_shape.insert(new_shape.end(), {IntImm(DataType::Int(32), dim0->value / 8), + IntImm(DataType::Int(32), dim1->value / 32), 1, 8}); Buffer new_buffer = decl_buffer(std::move(new_shape), buffer->dtype, buffer->name, "local", buffer->axis_separators); diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index e9e5ee233053..14a0549ecb1d 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -27,8 +27,8 @@ #include #include #include -#include #include +#include #include #include @@ -545,7 +545,7 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in Stmt body = GenerateBodyStmt(leaf.store_indices, buffers, leaf.axes_remap, expr_body, info, analyzer); seq_stmt.push_back(SBlockRealize(/*iter_values=*/leaf.bindings, - /*predicate=*/IntImm(DataType::Bool(), 1), + /*predicate=*/const_true(), /*block=*/ SBlock(/*iter_vars=*/leaf.block_iters, /*reads=*/{}, diff --git a/src/tirx/script/builder/frame.cc b/src/tirx/script/builder/frame.cc index 6d7628ad7e99..7a3974e94d6f 100644 --- a/src/tirx/script/builder/frame.cc +++ b/src/tirx/script/builder/frame.cc @@ -195,7 +195,8 @@ void SBlockFrameNode::ExitWithScope() { << "`T.where` is not allowed when `no_realize=True`"; AddToParent(block); } else { - AddToParent(tvm::tirx::SBlockRealize(iter_values, predicate.value_or(IntImm(DataType::Bool(), 1)), block)); + AddToParent(tvm::tirx::SBlockRealize(iter_values, + predicate.value_or(IntImm(DataType::Bool(), 1)), block)); } } diff --git a/src/topi/transform.cc b/src/topi/transform.cc index 203e1b7da6f5..5e81e95c6015 100644 --- a/src/topi/transform.cc +++ b/src/topi/transform.cc @@ -224,9 +224,11 @@ TVM_FFI_STATIC_INIT_BLOCK() { bool assume_inbound = args[6].cast(); if 
(IsConstIntArray(begin) && IsConstIntArray(end) && IsConstIntArray(strides) && IsConstIntArray(x->shape)) { - ffi::Array begin_static = args[1].cast>(); - ffi::Array end_static = args[2].cast>(); - ffi::Array strides_static = args[3].cast>(); + ffi::Array> begin_static = + args[1].cast>>(); + ffi::Array> end_static = + args[2].cast>>(); + ffi::Array strides_static = args[3].cast>(); auto slice_mode = args[5].cast(); if (axes.size()) { *rv = strided_slice_with_axes(x, begin_static, end_static, strides_static, axes, diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc index 8fdd768b81c6..ecc6f1199b8e 100644 --- a/tests/cpp/ir_functor_test.cc +++ b/tests/cpp/ir_functor_test.cc @@ -58,7 +58,8 @@ TEST(IRF, CountVar) { TEST(IRF, PreOrderVisit) { using namespace tvm; using namespace tvm::tirx; - Stmt init = IfThenElse(const_true(), Evaluate(IntImm(DataType::Int(32), 0)), Evaluate(IntImm(DataType::Int(32), 0))); + Stmt init = IfThenElse(const_true(), Evaluate(IntImm(DataType::Int(32), 0)), + Evaluate(IntImm(DataType::Int(32), 0))); Stmt body = Evaluate(IntImm(DataType::Int(32), 1)); SBlock block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"block", /*body=*/body,