Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 0 additions & 120 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -557,108 +557,6 @@ class FloatImm : public PrimExpr {
TVM_DEFINE_OBJECT_REF_COW_METHOD(FloatImmNode);
};

/*!
* \brief Boolean constant.
*
* This reference type is useful to add additional compile-time
* type checks and helper functions for Integer equal comparisons.
*/
class Bool : public IntImm {
public:
explicit Bool(bool value, Span span = Span()) : IntImm(DataType::Bool(), value, span) {}
Bool operator!() const { return Bool((*this)->value == 0); }
operator bool() const { return (*this)->value != 0; }

TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Bool, IntImm, IntImmNode);
};

// Overload operators to make sure we have the most fine grained types.
inline Bool operator||(const Bool& a, bool b) { return Bool(a.operator bool() || b); }
inline Bool operator||(bool a, const Bool& b) { return Bool(a || b.operator bool()); }
inline Bool operator||(const Bool& a, const Bool& b) {
return Bool(a.operator bool() || b.operator bool());
}
inline Bool operator&&(const Bool& a, bool b) { return Bool(a.operator bool() && b); }
inline Bool operator&&(bool a, const Bool& b) { return Bool(a && b.operator bool()); }
inline Bool operator&&(const Bool& a, const Bool& b) {
return Bool(a.operator bool() && b.operator bool());
}

inline bool operator==(const Bool& a, bool b) { return a.operator bool() == b; }
inline bool operator==(bool a, const Bool& b) { return a == b.operator bool(); }
inline bool operator==(const Bool& a, const Bool& b) {
return a.operator bool() == b.operator bool();
}

/*!
* \brief Container of constant int that adds more constructors.
*
* This is used to store and automate type check
* attributes that must be constant integer.
*
* \sa IntImm
*/
class Integer : public IntImm {
public:
Integer() {}
/*!
* \brief constructor from node.
*/
explicit Integer(ffi::ObjectPtr<IntImmNode> node) : IntImm(node) {}
/*!
* \brief constructor with UnsafeInit
*/
explicit Integer(ffi::UnsafeInit tag) : IntImm(tag) {}
/*!
* \brief Construct integer from int value.
*/
Integer(int value, Span span = Span()) : IntImm(DataType::Int(32), value, span) {} // NOLINT(*)
/*!
* \brief Construct integer from int imm.
* \param other The other value.
*/
Integer(IntImm other) : IntImm(std::move(other)) {} // NOLINT(*)
/*!
* \brief Constructor from enum
* \tparam Enum The enum type.
* \param value The enum value.
*/
template <typename Enum, typename = typename std::enable_if<std::is_enum<Enum>::value>::type>
explicit Integer(Enum value) : Integer(static_cast<int>(value)) {
static_assert(std::is_same<int, typename std::underlying_type<Enum>::type>::value,
"declare enum to be enum int to use visitor");
}
/*!
* \brief Assign an expression to integer.
* \param other another expression.
*/
Integer& operator=(const IntImm& other) {
data_ = ffi::details::ObjectUnsafe::ObjectPtrFromObjectRef<ffi::Object>(other);
return *this;
}
/*!
* \brief convert to int64_t
*/
int64_t IntValue() const {
TVM_FFI_ICHECK(data_ != nullptr) << " Trying to reference a null Integer";
return (*this)->value;
}
// comparators
Bool operator==(int other) const {
if (data_ == nullptr) return Bool(false);
return Bool((*this)->value == other);
}
Bool operator!=(int other) const { return !(*this == other); }
template <typename Enum, typename = typename std::enable_if<std::is_enum<Enum>::value>::type>
Bool operator==(Enum other) const {
return *this == static_cast<int>(other);
}
template <typename Enum, typename = typename std::enable_if<std::is_enum<Enum>::value>::type>
Bool operator!=(Enum other) const {
return *this != static_cast<int>(other);
}
};

/*! \brief range over one dimension */
class RangeNode : public ffi::Object {
public:
Expand Down Expand Up @@ -729,16 +627,6 @@ struct TypeTraits<IntImm> : public ObjectRefWithFallbackTraitsBase<IntImm, int64
}
};

template <>
inline constexpr bool use_default_type_traits_v<Integer> = false;

template <>
struct TypeTraits<Integer> : public ObjectRefWithFallbackTraitsBase<Integer, int64_t> {
TVM_FFI_INLINE static Integer ConvertFallbackValue(int64_t value) {
return Integer(TypeTraits<IntImm>::ConvertFallbackValue(value));
}
};

template <>
inline constexpr bool use_default_type_traits_v<FloatImm> = false;

Expand All @@ -749,14 +637,6 @@ struct TypeTraits<FloatImm> : public ObjectRefWithFallbackTraitsBase<FloatImm, d
}
};

template <>
inline constexpr bool use_default_type_traits_v<Bool> = false;

template <>
struct TypeTraits<Bool> : public ObjectRefWithFallbackTraitsBase<Bool, int64_t> {
TVM_FFI_INLINE static Bool ConvertFallbackValue(int64_t value) { return Bool(value != 0); }
};

// define automatic conversion from bool, int64_t, double to PrimExpr
TVM_FFI_INLINE PrimExpr TypeTraits<PrimExpr>::ConvertFallbackValue(StrictBool value) {
return IntImm(DataType::Bool(), value, Span());
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/ir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ namespace attr {
/*!
* \brief Indicates the special calling convention.
*
* Type: Integer
* Type: IntImm
*
* \sa tvm::CallingConv
*/
Expand Down Expand Up @@ -131,7 +131,7 @@ constexpr const char* kGlobalSymbol = "global_symbol";
* and printer emits `s_tir=True` on the decorator.
* Default (attr absent or False) is tirx semantics.
*
* Type: Bool
* Type: IntImm (bool dtype)
*/
constexpr const char* kSTir = "s_tir";

Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relax/distributed/struct_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class PlacementSpec : public ffi::ObjectRef {
class ShardingNode : public PlacementSpecNode {
public:
/*! \brief The dimension of tensor we shard*/
Integer sharding_dim;
int64_t sharding_dim;

static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relax/nested_msg.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ class NestedMsg {
}

// delete the int constructor
// since NestedMsg<Integer>(0) is ambiguous
// since NestedMsg<IntImm>(0) is ambiguous
// 0 can be implicitly casted to nullptr_t
explicit NestedMsg(int val) = delete;
NestedMsg<T>& operator=(int val) = delete;
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/s_tir/meta_schedule/schedule_rule.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ class ScheduleRule : public ffi::ObjectRef {
* \return The schedule rule created
*/
TVM_DLL static ScheduleRule MultiLevelTilingWideVector(
ffi::String structure, Integer vector_length_in_bits,
ffi::String structure, int64_t vector_length_in_bits,
ffi::Optional<int64_t> max_innermost_factor,
ffi::Optional<ffi::Map<ffi::String, ffi::Any>> reuse_read,
ffi::Optional<ffi::Map<ffi::String, ffi::Any>> reuse_write);
Expand Down
5 changes: 2 additions & 3 deletions include/tvm/tirx/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,16 +160,15 @@ TVM_DLL size_t CalculateExprComplexity(const PrimExpr& expr);
* \param func The TIR PrimFunc for which the constants size to be calculated
* \param constant_byte_alignment The byte alignment required for each constant allocated
*/
TVM_DLL size_t CalculateConstantBytes(const PrimFunc& func, const Integer& constant_byte_alignment);
TVM_DLL size_t CalculateConstantBytes(const PrimFunc& func, int64_t constant_byte_alignment);

/*!
* \brief Calculate the workspace size in bytes needed by the TIR allocates inside the TIR PrimFunc
* \param func The TIR PrimFunc for which the workspace size to be calculated
* \param workspace_byte_alignment The byte alignment required for each tensor allocated in this
* workspace
*/
TVM_DLL size_t CalculateWorkspaceBytes(const PrimFunc& func,
const Integer& workspace_byte_alignment);
TVM_DLL size_t CalculateWorkspaceBytes(const PrimFunc& func, int64_t workspace_byte_alignment);

/*!
* \brief Verify if the given TIR is well-formed. The verification includes:
Expand Down
10 changes: 5 additions & 5 deletions include/tvm/tirx/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -310,15 +310,15 @@ constexpr const char* kKernelLaunchParams = "tirx.kernel_launch_params";
/*!
* \brief Whether to set noalias rule on the function arguments.
*
* Type: Integer
* Type: IntImm
*/
constexpr const char* kNoAlias = "tirx.noalias";

/*!
* \brief Mark the function as the entry function of
* the final generated runtime module.
*
* Type: Integer
* Type: IntImm
*
* \note There can only be one entry function per module.
*/
Expand All @@ -327,21 +327,21 @@ constexpr const char* kIsEntryFunc = "tirx.is_entry_func";
/*!
* \brief Mark the function as the global function called from the host.
*
* Type: Integer
* Type: IntImm
*/
constexpr const char* kIsGlobalFunc = "tirx.is_global_func";

/*!
* \brief Mark the function as run on the host, mutually exclusive with kTarget.
*
* Type: Integer
* Type: IntImm
*/
constexpr const char* kIsHostFunc = "tirx.is_host_func";

/*!
* \brief Mark the function as scheduled, so the default schedule will pass will skip it.
*
* Type: Integer
* Type: IntImm
*/
constexpr const char* kIsScheduled = "tirx.is_scheduled";

Expand Down
33 changes: 17 additions & 16 deletions include/tvm/topi/detail/strided_slice.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,12 @@ inline int64_t CanonicalizeIndex(int64_t index, int64_t extent, int64_t stride)
}

inline std::tuple<std::vector<int64_t>, std::vector<int64_t>, std::vector<int64_t>> ConvertToVec(
const ffi::Array<Integer>& begin, const ffi::Array<Integer>& end,
const ffi::Array<Integer>& strides, std::string slice_mode) {
const ffi::Array<ffi::Optional<IntImm>>& begin, const ffi::Array<ffi::Optional<IntImm>>& end,
const ffi::Array<IntImm>& strides, std::string slice_mode) {
std::vector<int64_t> stride_vec(strides.size(), 1);
if (slice_mode == "end") {
for (size_t i = 0; i < strides.size(); ++i) {
TVM_FFI_ICHECK(strides[i].defined());
stride_vec[i] = GetConstInt(strides[i]);
stride_vec[i] = strides[i]->value;
}
}
const int64_t max_range = std::numeric_limits<int64_t>::max();
Expand All @@ -66,7 +65,7 @@ inline std::tuple<std::vector<int64_t>, std::vector<int64_t>, std::vector<int64_
// value=None
begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range);
} else {
begin_vec.push_back(GetConstInt(begin[i]));
begin_vec.push_back(begin[i].value()->value);
}
}
std::vector<int64_t> end_vec;
Expand All @@ -75,14 +74,14 @@ inline std::tuple<std::vector<int64_t>, std::vector<int64_t>, std::vector<int64_
if (!end[i].defined()) {
end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
} else if (slice_mode == "size") {
int64_t end_val = GetConstInt(end[i]);
int64_t end_val = end[i].value()->value;
if (end_val < 0) {
end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
} else {
end_vec.push_back(begin_vec[i] + end_val);
}
} else {
end_vec.push_back(GetConstInt(end[i]));
end_vec.push_back(end[i].value()->value);
}
}
return std::make_tuple(begin_vec, end_vec, stride_vec);
Expand All @@ -91,17 +90,18 @@ inline std::tuple<std::vector<int64_t>, std::vector<int64_t>, std::vector<int64_
inline ffi::Array<PrimExpr> StridedSliceCanonicalizeBegin(const ffi::Array<PrimExpr>& ishape,
const std::vector<int64_t>& begin,
const std::vector<int64_t>& strides,
const ffi::Array<Integer>& axes,
const ffi::Array<int64_t>& axes,
DataType dtype,
std::string slice_mode = "end") {
ffi::Array<PrimExpr> begin_expr;
for (size_t i = 0; i < axes.size(); ++i) {
if (ishape[axes[i].IntValue()]->IsInstance<tvm::IntImmNode>()) {
int64_t dim_i = GetConstInt(ishape[axes[i].IntValue()]);
int64_t ax = axes[i];
if (ishape[ax]->IsInstance<tvm::IntImmNode>()) {
int64_t dim_i = GetConstInt(ishape[ax]);
int64_t begin_i = CanonicalizeIndex(begin[i], dim_i, strides[i]);
begin_expr.push_back(make_const(dtype, begin_i));
} else {
auto idim = ishape[axes[i].IntValue()];
auto idim = ishape[ax];
auto b_expr = make_const(dtype, begin[i]);
PrimExpr b = begin[i] < 0 ? b_expr + idim : b_expr;
auto s = strides[i];
Expand All @@ -119,7 +119,7 @@ inline ffi::Array<PrimExpr> StridedSliceCanonicalizeBegin(const ffi::Array<PrimE
inline ffi::Array<PrimExpr> StridedSliceOutputShape(
const ffi::Array<PrimExpr>& ishape, const std::vector<int64_t>& begin,
const std::vector<int64_t>& end, const std::vector<int64_t>& strides,
const ffi::Array<Integer>& axes, std::string slice_mode,
const ffi::Array<int64_t>& axes, std::string slice_mode,
const ffi::Array<PrimExpr>& begin_canonicalized, bool use_any = false) {
TVM_FFI_ICHECK(!use_any) << "StridedSliceOutputShape does not legacy use_any";
const size_t src_tensor_dim = ishape.size();
Expand All @@ -129,8 +129,9 @@ inline ffi::Array<PrimExpr> StridedSliceOutputShape(
}

for (size_t i = 0; i < axes.size(); ++i) {
if (ishape[axes[i].IntValue()]->IsInstance<tvm::IntImmNode>()) {
const int64_t dim_i = GetConstInt(ishape[axes[i].IntValue()]);
int64_t ax = axes[i];
if (ishape[ax]->IsInstance<tvm::IntImmNode>()) {
const int64_t dim_i = GetConstInt(ishape[ax]);
TVM_FFI_ICHECK(begin_canonicalized[i]->IsInstance<tvm::IntImmNode>());
int64_t begin_i = GetConstInt(begin_canonicalized[i]);
int64_t end_i = CanonicalizeIndex(end[i], dim_i, strides[i]);
Expand All @@ -139,9 +140,9 @@ inline ffi::Array<PrimExpr> StridedSliceOutputShape(
static_cast<int>((interval + std::abs(strides[i]) - 1) / std::abs(strides[i]));
TVM_FFI_ICHECK(strides[i] < 0 ? (end_i <= begin_i) : (begin_i <= end_i))
<< ": Input [Begin=" << begin[i] << ", End=" << end[i] << "] is invalid for axis=" << i;
out_shape.Set(axes[i].IntValue(), cast(out_shape[i].dtype(), PrimExpr(slice_size)));
out_shape.Set(ax, cast(out_shape[i].dtype(), PrimExpr(slice_size)));
} else {
out_shape.Set(axes[i].IntValue(), tvm::tirx::Var("dim", out_shape[i]->dtype));
out_shape.Set(ax, tvm::tirx::Var("dim", out_shape[i]->dtype));
}
}

Expand Down
Loading
Loading