diff --git a/include/tvm/ir/base_expr.h b/include/tvm/ir/base_expr.h index 6d566bd5c92e..a836d0410a8a 100644 --- a/include/tvm/ir/base_expr.h +++ b/include/tvm/ir/base_expr.h @@ -299,12 +299,11 @@ class ExprNode : public ffi::Object { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - // span and ty do not participate in structural equal and hash. + // span does not participate in structural equal and hash. refl::ObjectDef() .def_ro("span", &ExprNode::span, refl::DefaultValue(Span()), refl::AttachFieldFlag::SEqHashIgnore()) - .def_ro("ty", &ExprNode::ty, refl::DefaultValue(Type::Missing()), - refl::AttachFieldFlag::SEqHashIgnore()); + .def_ro("ty", &ExprNode::ty, refl::DefaultValue(Type::Missing())); } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 7f1c84f50f75..dd168a2462c8 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -276,6 +276,11 @@ class GlobalVarNode : public ExprNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("name_hint", &GlobalVarNode::name_hint); + // A GlobalVar identifies a module-level symbol. Its type is derived from the + // corresponding function definition and is not part of the symbol identity. + refl::TypeAttrDef() + .def("__s_equal__", &GlobalVarNode::SEqual) + .def("__s_hash__", &GlobalVarNode::SHash); } bool SEqual(const GlobalVarNode* other, diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index 83e03c13e7b2..51b4ed43bdb9 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -501,6 +501,25 @@ class SeqExprNode : public ExprNode { refl::ObjectDef() .def_ro("blocks", &SeqExprNode::blocks) .def_ro("body", &SeqExprNode::body); + refl::TypeAttrDef() + .def("__s_equal__", &SeqExprNode::SEqual) + .def("__s_hash__", &SeqExprNode::SHash); + } + + bool SEqual(const SeqExprNode* other, + ffi::TypedFunction equal) const { + // Establish mappings for symbolic variables defined by bindings before + // comparing their uses in the SeqExpr result type and body. + return equal(blocks, other->blocks, false, "blocks") && equal(ty, other->ty, false, "ty") && + equal(body, other->body, false, "body"); + } + + int64_t SHash(int64_t init_hash, ffi::TypedFunction hash) const { + int64_t hash_value = init_hash; + hash_value = hash(blocks, hash_value, false); + hash_value = hash(ty, hash_value, false); + hash_value = hash(body, hash_value, false); + return hash_value; } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.expr.SeqExpr", SeqExprNode, ExprNode); }; diff --git a/include/tvm/tirx/function.h b/include/tvm/tirx/function.h index e4e33f35760c..a69b32a43377 100644 --- a/include/tvm/tirx/function.h +++ b/include/tvm/tirx/function.h @@ -108,6 +108,31 @@ class PrimFuncNode : public BaseFuncNode { .def_ro("ret_type", &PrimFuncNode::ret_type) .def_ro("buffer_map", &PrimFuncNode::buffer_map) .def_ro("body", &PrimFuncNode::body); + refl::TypeAttrDef() + .def("__s_equal__", &PrimFuncNode::SEqual) + .def("__s_hash__", &PrimFuncNode::SHash); + } + + bool SEqual(const PrimFuncNode* other, + ffi::TypedFunction equal) const { + // `ty` is derived from the fields below. PrimFunc transformations update + // those source fields without maintaining this redundant cache eagerly. + // Remove this exception once all PrimFunc mutation paths recompute `ty`. + return equal(attrs, other->attrs, false, "attrs") && + equal(params, other->params, true, "params") && + equal(ret_type, other->ret_type, false, "ret_type") && + equal(buffer_map, other->buffer_map, false, "buffer_map") && + equal(body, other->body, false, "body"); + } + + int64_t SHash(int64_t init_hash, ffi::TypedFunction hash) const { + int64_t hash_value = init_hash; + hash_value = hash(attrs, hash_value, false); + hash_value = hash(params, hash_value, true); + hash_value = hash(ret_type, hash_value, false); + hash_value = hash(buffer_map, hash_value, false); + hash_value = hash(body, hash_value, false); + return hash_value; } /*! diff --git a/python/tvm/relax/op/memory/memory.py b/python/tvm/relax/op/memory/memory.py index 624d39a1336a..a1dec11069e9 100644 --- a/python/tvm/relax/op/memory/memory.py +++ b/python/tvm/relax/op/memory/memory.py @@ -98,6 +98,8 @@ def alloc_tensor( shape = convert_to_expr(shape) if isinstance(dtype, str): dtype = DataTypeImm(dtype) + if isinstance(runtime_device_ind, int): + runtime_device_ind = prim_value(runtime_device_ind) return _ffi_api.alloc_tensor(storage, offset, shape, dtype, runtime_device_ind) # type: ignore diff --git a/python/tvm/relax/op/vm/vm.py b/python/tvm/relax/op/vm/vm.py index 0fc236d59ba4..afa9782e3dd3 100644 --- a/python/tvm/relax/op/vm/vm.py +++ b/python/tvm/relax/op/vm/vm.py @@ -98,6 +98,8 @@ def alloc_tensor( shape = convert_to_expr(shape) if isinstance(dtype, str): dtype = DataTypeImm(dtype) + if isinstance(runtime_device_ind, int): + runtime_device_ind = prim_value(runtime_device_ind) return _ffi_api.alloc_tensor(storage, offset, shape, dtype, runtime_device_ind) # type: ignore diff --git a/src/relax/distributed/transform/lower_global_view_to_local_view.cc b/src/relax/distributed/transform/lower_global_view_to_local_view.cc index f37c0d3863dc..c04970f6316d 100644 --- a/src/relax/distributed/transform/lower_global_view_to_local_view.cc +++ b/src/relax/distributed/transform/lower_global_view_to_local_view.cc @@ -166,10 +166,9 @@ class DistributedBufferCompactor : StmtExprMutator { } Stmt new_body = compactor(prim_func->body); new_body = DistBufferReplacer::BufferReplace(new_body, replace_buffer_map); - ffi::ObjectPtr new_func = ffi::make_object(*prim_func.get()); - new_func->buffer_map = new_func_buffer_map; - new_func->body = new_body; - return std::make_tuple(PrimFunc(new_func), compactor.add_allreduce_kind_); + PrimFunc new_func(prim_func->params, new_body, prim_func->ret_type, new_func_buffer_map, + prim_func->attrs, prim_func->span); + return std::make_tuple(new_func, compactor.add_allreduce_kind_); } private: diff --git a/src/relax/script/printer/dependent_type.cc b/src/relax/script/printer/dependent_type.cc index d5c09cbcc704..c3b77223b57d 100644 --- a/src/relax/script/printer/dependent_type.cc +++ b/src/relax/script/printer/dependent_type.cc @@ -54,7 +54,10 @@ ExprDoc PrintShapeVar(const PrimExpr& e, const AccessPath& e_p, const IRDocsifie } // Step 3. Stringify the PrimExpr if func var exists if (func_var_mode) { - return LiteralDoc::Str(DocToPythonScript(expr_doc, d->cfg), e_p); + // This nested render is only converting a shape expression into one token. + // Give it an independent configuration so it cannot consume or emit the + // enclosing invocation's access-path diagnostics. + return LiteralDoc::Str(DocToPythonScript(expr_doc, PrinterConfig()), e_p); } return expr_doc; } diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 1ee88ea846f7..afd53f1b7404 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -6227,9 +6227,9 @@ def main(input: R.Tensor((10, 10), dtype="float32")) -> R.Tuple( with R.dataflow(): lv: R.Tensor((20,), dtype="float32") = R.hamming_window( R.prim_value(20), - R.prim_value(1), - R.prim_value(T.float32(0.54000000000000004)), - R.prim_value(T.float32(0.46000000000000002)), + R.prim_value(True), + R.prim_value(T.float64(0.54000000000000004)), + R.prim_value(T.float64(0.46000000000000002)), dtype="float32", ) gv: R.Tuple(R.Tensor((20,), dtype="float32")) = (lv,) diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index f4dbac634a17..9058cdf70ae7 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -3945,6 +3945,8 @@ def test_dynamic_shape_squeeze(axis): axes = relax.Var("axes", relax.TensorType([1], "int64")) gv = relax.Var("gv", tvm.ir.PrimType("int64")) body = relax.SeqExpr([relax.DataflowBlock([relax.VarBinding(gv, a)])], gv) + # Match the importer boundary, where BlockBuilder populates the SeqExpr result type. + body = relax.BlockBuilder().normalize(body) expected_func = relax.Function([x, axes], body, tvm.ir.PrimType("int64")).with_attrs( {"num_input": 1, "global_symbol": "main"} ) diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index 3397d08abb57..c756227fe09b 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -643,7 +643,13 @@ def main( R.func_attr({"num_input": 1}) with R.dataflow(): lv: R.Tuple(R.Tensor(dtype="int32", ndim=1), R.Tensor(dtype="int64", ndim=1)) = ( - R.unique(x, R.prim_value(0), R.prim_value(0), R.prim_value(1), R.prim_value(0)) + R.unique( + x, + R.prim_value(False), + R.prim_value(False), + R.prim_value(True), + R.prim_value(False), + ) ) lv1: R.Tensor(dtype="int32", ndim=1) = lv[0] lv2: R.Tensor(dtype="int64", ndim=1) = lv[1] diff --git a/tests/python/relax/test_utils.py b/tests/python/relax/test_utils.py index 9ee05b7fdd8f..b654edfc310b 100644 --- a/tests/python/relax/test_utils.py +++ b/tests/python/relax/test_utils.py @@ -143,7 +143,7 @@ def func_2(A: R.Tensor([16, 16], "float32")): with pytest.raises( ValueError, - match=re.escape(".body.blocks[0].bindings[0].value.op"), + match=re.escape(".ty.ret.shape.ty.values[0].value"), ): assert_structural_equal(func_1, func_2)