Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions include/tvm/ir/base_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -299,12 +299,11 @@ class ExprNode : public ffi::Object {

static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
// span and ty do not participate in structural equal and hash.
// span does not participate in structural equal and hash.
refl::ObjectDef<ExprNode>()
.def_ro("span", &ExprNode::span, refl::DefaultValue(Span()),
refl::AttachFieldFlag::SEqHashIgnore())
.def_ro("ty", &ExprNode::ty, refl::DefaultValue(Type::Missing()),
refl::AttachFieldFlag::SEqHashIgnore());
.def_ro("ty", &ExprNode::ty, refl::DefaultValue(Type::Missing()));
}

static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
Expand Down
5 changes: 5 additions & 0 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,11 @@ class GlobalVarNode : public ExprNode {
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<GlobalVarNode>().def_ro("name_hint", &GlobalVarNode::name_hint);
// A GlobalVar identifies a module-level symbol. Its type is derived from the
// corresponding function definition and is not part of the symbol identity.
refl::TypeAttrDef<GlobalVarNode>()
.def("__s_equal__", &GlobalVarNode::SEqual)
.def("__s_hash__", &GlobalVarNode::SHash);
}

bool SEqual(const GlobalVarNode* other,
Expand Down
19 changes: 19 additions & 0 deletions include/tvm/relax/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,25 @@ class SeqExprNode : public ExprNode {
refl::ObjectDef<SeqExprNode>()
.def_ro("blocks", &SeqExprNode::blocks)
.def_ro("body", &SeqExprNode::body);
refl::TypeAttrDef<SeqExprNode>()
.def("__s_equal__", &SeqExprNode::SEqual)
.def("__s_hash__", &SeqExprNode::SHash);
}

bool SEqual(const SeqExprNode* other,
ffi::TypedFunction<bool(AnyView, AnyView, bool, AnyView)> equal) const {
// Establish mappings for symbolic variables defined by bindings before
// comparing their uses in the SeqExpr result type and body.
return equal(blocks, other->blocks, false, "blocks") && equal(ty, other->ty, false, "ty") &&
equal(body, other->body, false, "body");
}

int64_t SHash(int64_t init_hash, ffi::TypedFunction<int64_t(AnyView, int64_t, bool)> hash) const {
int64_t hash_value = init_hash;
hash_value = hash(blocks, hash_value, false);
hash_value = hash(ty, hash_value, false);
hash_value = hash(body, hash_value, false);
return hash_value;
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.SeqExpr", SeqExprNode, ExprNode);
};
Expand Down
25 changes: 25 additions & 0 deletions include/tvm/tirx/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,31 @@ class PrimFuncNode : public BaseFuncNode {
.def_ro("ret_type", &PrimFuncNode::ret_type)
.def_ro("buffer_map", &PrimFuncNode::buffer_map)
.def_ro("body", &PrimFuncNode::body);
refl::TypeAttrDef<PrimFuncNode>()
.def("__s_equal__", &PrimFuncNode::SEqual)
.def("__s_hash__", &PrimFuncNode::SHash);
}

bool SEqual(const PrimFuncNode* other,
ffi::TypedFunction<bool(AnyView, AnyView, bool, AnyView)> equal) const {
// `ty` is derived from the fields below. PrimFunc transformations update
// those source fields without maintaining this redundant cache eagerly.
// Remove this exception once all PrimFunc mutation paths recompute `ty`.
return equal(attrs, other->attrs, false, "attrs") &&
equal(params, other->params, true, "params") &&
equal(ret_type, other->ret_type, false, "ret_type") &&
equal(buffer_map, other->buffer_map, false, "buffer_map") &&
equal(body, other->body, false, "body");
}

int64_t SHash(int64_t init_hash, ffi::TypedFunction<int64_t(AnyView, int64_t, bool)> hash) const {
int64_t hash_value = init_hash;
hash_value = hash(attrs, hash_value, false);
hash_value = hash(params, hash_value, true);
hash_value = hash(ret_type, hash_value, false);
hash_value = hash(buffer_map, hash_value, false);
hash_value = hash(body, hash_value, false);
return hash_value;
}

/*!
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/relax/op/memory/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ def alloc_tensor(
shape = convert_to_expr(shape)
if isinstance(dtype, str):
dtype = DataTypeImm(dtype)
if isinstance(runtime_device_ind, int):
runtime_device_ind = prim_value(runtime_device_ind)
return _ffi_api.alloc_tensor(storage, offset, shape, dtype, runtime_device_ind) # type: ignore


Expand Down
2 changes: 2 additions & 0 deletions python/tvm/relax/op/vm/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ def alloc_tensor(
shape = convert_to_expr(shape)
if isinstance(dtype, str):
dtype = DataTypeImm(dtype)
if isinstance(runtime_device_ind, int):
runtime_device_ind = prim_value(runtime_device_ind)
return _ffi_api.alloc_tensor(storage, offset, shape, dtype, runtime_device_ind) # type: ignore


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,9 @@ class DistributedBufferCompactor : StmtExprMutator {
}
Stmt new_body = compactor(prim_func->body);
new_body = DistBufferReplacer::BufferReplace(new_body, replace_buffer_map);
ffi::ObjectPtr<PrimFuncNode> new_func = ffi::make_object<PrimFuncNode>(*prim_func.get());
new_func->buffer_map = new_func_buffer_map;
new_func->body = new_body;
return std::make_tuple(PrimFunc(new_func), compactor.add_allreduce_kind_);
PrimFunc new_func(prim_func->params, new_body, prim_func->ret_type, new_func_buffer_map,
prim_func->attrs, prim_func->span);
return std::make_tuple(new_func, compactor.add_allreduce_kind_);
}

private:
Expand Down
5 changes: 4 additions & 1 deletion src/relax/script/printer/dependent_type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,10 @@ ExprDoc PrintShapeVar(const PrimExpr& e, const AccessPath& e_p, const IRDocsifie
}
// Step 3. Stringify the PrimExpr if func var exists
if (func_var_mode) {
return LiteralDoc::Str(DocToPythonScript(expr_doc, d->cfg), e_p);
// This nested render is only converting a shape expression into one token.
// Give it an independent configuration so it cannot consume or emit the
// enclosing invocation's access-path diagnostics.
return LiteralDoc::Str(DocToPythonScript(expr_doc, PrinterConfig()), e_p);
}
return expr_doc;
}
Expand Down
6 changes: 3 additions & 3 deletions tests/python/relax/test_frontend_from_exported_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -6227,9 +6227,9 @@ def main(input: R.Tensor((10, 10), dtype="float32")) -> R.Tuple(
with R.dataflow():
lv: R.Tensor((20,), dtype="float32") = R.hamming_window(
R.prim_value(20),
R.prim_value(1),
R.prim_value(T.float32(0.54000000000000004)),
R.prim_value(T.float32(0.46000000000000002)),
R.prim_value(True),
R.prim_value(T.float64(0.54000000000000004)),
R.prim_value(T.float64(0.46000000000000002)),
dtype="float32",
)
gv: R.Tuple(R.Tensor((20,), dtype="float32")) = (lv,)
Expand Down
2 changes: 2 additions & 0 deletions tests/python/relax/test_frontend_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -3945,6 +3945,8 @@ def test_dynamic_shape_squeeze(axis):
axes = relax.Var("axes", relax.TensorType([1], "int64"))
gv = relax.Var("gv", tvm.ir.PrimType("int64"))
body = relax.SeqExpr([relax.DataflowBlock([relax.VarBinding(gv, a)])], gv)
# Match the importer boundary, where BlockBuilder populates the SeqExpr result type.
body = relax.BlockBuilder().normalize(body)
expected_func = relax.Function([x, axes], body, tvm.ir.PrimType("int64")).with_attrs(
{"num_input": 1, "global_symbol": "main"}
)
Expand Down
8 changes: 7 additions & 1 deletion tests/python/relax/test_frontend_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,7 +643,13 @@ def main(
R.func_attr({"num_input": 1})
with R.dataflow():
lv: R.Tuple(R.Tensor(dtype="int32", ndim=1), R.Tensor(dtype="int64", ndim=1)) = (
R.unique(x, R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0))
R.unique(
x,
R.prim_value(False),
R.prim_value(False),
R.prim_value(True),
R.prim_value(False),
)
)
lv1: R.Tensor(dtype="int32", ndim=1) = lv[0]
lv2: R.Tensor(dtype="int64", ndim=1) = lv[1]
Expand Down
2 changes: 1 addition & 1 deletion tests/python/relax/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def func_2(A: R.Tensor([16, 16], "float32")):

with pytest.raises(
ValueError,
match=re.escape("<root>.body.blocks[0].bindings[0].value.op"),
match=re.escape("<root>.ty.ret.shape.ty.values[0].value"),
):
assert_structural_equal(func_1, func_2)

Expand Down
Loading