From 3c464f5ad14ee82da93642105f314b5f269654a5 Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 27 May 2026 23:09:27 +0000 Subject: [PATCH 1/5] [REFACTOR][SCRIPT] Lift TVMScript entry point into script/printer/printer.h Add a new public header `include/tvm/script/printer/printer.h` that declares the free function `tvm::Script(node, optional_config)` plus the dispatch vtable used by per-dialect printers. `PrinterConfig` and related dataclass helpers stay in `config.h`; `printer.h` is the entry-point header that callers include to invoke printing. The `TVMScriptPrinter::Script(...)` static method body moves to the free function `tvm::Script(...)` in `src/script/printer/script_printer.cc`. `TVMScriptPrinter` class and `TVM_REGISTER_SCRIPT_AS_REPR` macro move from `config.h` to `printer.h`. Existing direct callers in s_tir/schedule/error.cc, script/printer/config.cc, and script/printer/utils.h are rewritten to the new entry point. --- include/tvm/script/printer/config.h | 47 ++++-------------- include/tvm/script/printer/printer.h | 71 ++++++++++++++++++++++++++++ src/s_tir/schedule/error.cc | 4 +- src/script/printer/config.cc | 4 +- src/script/printer/script_printer.cc | 7 ++- src/script/printer/utils.h | 13 +---- 6 files changed, 88 insertions(+), 58 deletions(-) create mode 100644 include/tvm/script/printer/printer.h diff --git a/include/tvm/script/printer/config.h b/include/tvm/script/printer/config.h index 19510e76a816..7a84a920245d 100644 --- a/include/tvm/script/printer/config.h +++ b/include/tvm/script/printer/config.h @@ -18,7 +18,11 @@ */ /*! * \file tvm/script/printer/config.h - * \brief Printer class to print repr string of each AST/IR nodes. + * \brief Configuration object for the TVMScript printer. + * + * Contains PrinterConfig / PrinterConfigNode, GetBuiltinKeywords, GetExtraConfig, + * and RedirectedReprPrinterMethod. The entry-point free function tvm::Script() + * and the dispatch vtable TVMScriptPrinter live in printer.h. */ #ifndef TVM_SCRIPT_PRINTER_CONFIG_H_ #define TVM_SCRIPT_PRINTER_CONFIG_H_ @@ -30,7 +34,6 @@ #include #include #include -#include #include #include @@ -156,48 +159,14 @@ class TVM_DLL PrinterConfig : public ffi::ObjectRef { TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(PrinterConfig, ffi::ObjectRef, PrinterConfigNode); }; -/*! \brief TVMScript-based printer for IR nodes. */ -class TVMScriptPrinter { - public: - /* Convert the object to TVMScript format */ - TVM_DLL static std::string Script(const ffi::ObjectRef& node, - const ffi::Optional& cfg); - // Allow registration to be printer. - using FType = NodeFunctor; - TVM_DLL static FType& vtable(); -}; - -#define TVM_OBJECT_ENABLE_SCRIPT_PRINTER() \ - std::string Script(const ffi::Optional& config = std::nullopt) const { \ - return TVMScriptPrinter::Script(ffi::GetRef(this), \ - config.value_or(PrinterConfig())); \ - } - /*! - * \brief The fallback body used by TVM_REGISTER_SCRIPT_AS_REPR. + * \brief The fallback body used by TVM_REGISTER_SCRIPT_AS_REPR (defined in printer.h). * - * Tries to format \p obj via TVMScriptPrinter::Script; on error falls back to - * a plain address string. Defined in src/script/printer/config.cc so that + * Tries to format \p obj via tvm::Script; on error falls back to a plain + * address string. Defined in src/script/printer/config.cc so that * is not pulled into this public header. */ TVM_DLL std::string RedirectedReprPrinterMethod(const ffi::ObjectRef& obj); -/*! - * \brief Register Script as the kRepr callback for ObjectType and install - * the per-type dispatch entry in TVMScriptPrinter::vtable(). - * - * \param ObjectType The concrete object node type (e.g. tirx::VarNode). - * \param Method The TVMScriptPrinter vtable dispatch function. - */ -#define TVM_REGISTER_SCRIPT_AS_REPR(ObjectType, Method) \ - TVM_FFI_STATIC_INIT_BLOCK() { \ - namespace refl = tvm::ffi::reflection; \ - refl::TypeAttrDef().def(refl::type_attr::kRepr, \ - [](ffi::ObjectRef obj, ffi::Function) -> ffi::String { \ - return RedirectedReprPrinterMethod(obj); \ - }); \ - } \ - TVM_STATIC_IR_FUNCTOR(TVMScriptPrinter, vtable).set_dispatch(Method) - } // namespace tvm #endif // TVM_SCRIPT_PRINTER_CONFIG_H_ diff --git a/include/tvm/script/printer/printer.h b/include/tvm/script/printer/printer.h new file mode 100644 index 000000000000..6ace9b842054 --- /dev/null +++ b/include/tvm/script/printer/printer.h @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/script/printer/printer.h + * \brief Entry-point header for TVMScript printing. + * + * Declares the free function `tvm::Script(node, optional_config)` and the + * dispatch vtable `TVMScriptPrinter::vtable()` used by per-dialect printers. + * `PrinterConfig` and its dataclass helpers live in config.h; this header is + * what callers include to invoke printing. + */ +#ifndef TVM_SCRIPT_PRINTER_PRINTER_H_ +#define TVM_SCRIPT_PRINTER_PRINTER_H_ + +#include +#include + +namespace tvm { + +/*! \brief Print \p node as TVMScript with the given \p config. + * + * Falls back to ffi::ReprPrint for types not registered with TVMScriptPrinter. + */ +TVM_DLL std::string Script(const ffi::ObjectRef& node, + const ffi::Optional& config = std::nullopt); + +/*! \brief Dispatch vtable used by per-dialect printers to register their + * object-type printing functions. Internal, but exposed here because + * TVM_REGISTER_SCRIPT_AS_REPR refers to it. + */ +class TVMScriptPrinter { + public: + using FType = NodeFunctor; + TVM_DLL static FType& vtable(); +}; + +/*! + * \brief Register Script as the kRepr callback for ObjectType and install + * the per-type dispatch entry in TVMScriptPrinter::vtable(). + * + * \param ObjectType The concrete object node type (e.g. tirx::VarNode). + * \param Method The TVMScriptPrinter vtable dispatch function. + */ +#define TVM_REGISTER_SCRIPT_AS_REPR(ObjectType, Method) \ + TVM_FFI_STATIC_INIT_BLOCK() { \ + namespace refl = tvm::ffi::reflection; \ + refl::TypeAttrDef().def(refl::type_attr::kRepr, \ + [](ffi::ObjectRef obj, ffi::Function) -> ffi::String { \ + return RedirectedReprPrinterMethod(obj); \ + }); \ + } \ + TVM_STATIC_IR_FUNCTOR(TVMScriptPrinter, vtable).set_dispatch(Method) + +} // namespace tvm +#endif // TVM_SCRIPT_PRINTER_PRINTER_H_ diff --git a/src/s_tir/schedule/error.cc b/src/s_tir/schedule/error.cc index 422352ad8857..73a29a59d516 100644 --- a/src/s_tir/schedule/error.cc +++ b/src/s_tir/schedule/error.cc @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. */ +#include + #include "./utils.h" namespace tvm { @@ -47,7 +49,7 @@ ffi::String ScheduleError::RenderReport(const ffi::String& primitive) const { } os << "ScheduleError: An error occurred in the schedule primitive '" << primitive << "'.\n\nThe IR with diagnostic is:\n" - << TVMScriptPrinter::Script(mod, cfg) << std::endl; + << tvm::Script(mod, cfg) << std::endl; // print error message os << "Error message: " << msg; diff --git a/src/script/printer/config.cc b/src/script/printer/config.cc index d68aaff2ce77..87ca87979fe5 100644 --- a/src/script/printer/config.cc +++ b/src/script/printer/config.cc @@ -17,7 +17,7 @@ * under the License. */ #include -#include +#include #include @@ -25,7 +25,7 @@ namespace tvm { std::string RedirectedReprPrinterMethod(const ffi::ObjectRef& obj) { try { - return TVMScriptPrinter::Script(obj, std::nullopt); + return tvm::Script(obj, std::nullopt); } catch (const tvm::ffi::Error& e) { LOG(WARNING) << "TVMScript printer falls back to the basic address printer with the error:\n" << e.what(); diff --git a/src/script/printer/script_printer.cc b/src/script/printer/script_printer.cc index f3fc27cf42db..c53d5a165fca 100644 --- a/src/script/printer/script_printer.cc +++ b/src/script/printer/script_printer.cc @@ -21,7 +21,7 @@ #include #include #include -#include +#include #include @@ -34,8 +34,7 @@ TVMScriptPrinter::FType& TVMScriptPrinter::vtable() { return inst; } -std::string TVMScriptPrinter::Script(const ffi::ObjectRef& node, - const ffi::Optional& cfg) { +std::string Script(const ffi::ObjectRef& node, const ffi::Optional& cfg) { if (!TVMScriptPrinter::vtable().can_dispatch(node)) { // Fall back to ffi::ReprPrint for types not registered with TVMScriptPrinter. return std::string(ffi::ReprPrint(ffi::Any(node))); @@ -174,7 +173,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef() .def("node.PrinterConfig", [](ffi::Map config_dict) { return PrinterConfig(config_dict); }) - .def("node.TVMScriptPrinterScript", TVMScriptPrinter::Script); + .def("node.TVMScriptPrinterScript", tvm::Script); } } // namespace tvm diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h index 67fbf8e1553c..e1b59aa0c788 100644 --- a/src/script/printer/utils.h +++ b/src/script/printer/utils.h @@ -26,8 +26,8 @@ #include #include #include -#include #include +#include #include #include @@ -46,17 +46,6 @@ namespace printer { // definition here would force the dialect headers to depend on this shared // header, which the per-dialect restructure aims to avoid for cross-directory // references. See each `/script/printer/utils.h` for the macro. -inline std::string RedirectedReprPrinterMethod(const ffi::ObjectRef& obj) { - try { - return TVMScriptPrinter::Script(obj, std::nullopt); - } catch (const tvm::ffi::Error& e) { - LOG(WARNING) << "TVMScript printer falls back to the basic address printer with the error:\n" - << e.what(); - std::ostringstream os; - os << obj->GetTypeKey() << '(' << obj.get() << ')'; - return os.str(); - } -} inline std::string Docsify(const ffi::ObjectRef& obj, const IRDocsifier& d, const Frame& f, const PrinterConfig& cfg) { From 410155af4fc9d1df6015b937d9acb86b640855a5 Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 27 May 2026 23:10:48 +0000 Subject: [PATCH 2/5] [REFACTOR][SCRIPT] Drop TVM_OBJECT_ENABLE_SCRIPT_PRINTER macro The macro injected a `Script()` member method on IR Object types and forced those headers to reverse-include the script printer's config. The IR layer should not depend on the script layer; script depends on ir. Remove the macro from include/tvm/script/printer/config.h and drop its 5 usages (PrimExprNode, IRModuleNode, tirx Buffer/PrimFunc/Stmt nodes). The single C++ caller of `mod->Script()` in src/s_tir/meta_schedule/database/json_database.cc switches to the explicit `tvm::Script(...)` free function from include/tvm/script/printer/printer.h. The reverse-include of script/printer/config.h is dropped from the five IR/tirx headers, restoring one-way dependency: script depends on ir, never the other way. --- include/tvm/ir/expr.h | 3 --- include/tvm/ir/module.h | 3 --- include/tvm/tirx/buffer.h | 2 -- include/tvm/tirx/function.h | 2 -- include/tvm/tirx/stmt.h | 3 --- src/s_tir/meta_schedule/database/json_database.cc | 3 ++- 6 files changed, 2 insertions(+), 14 deletions(-) diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index fcd267163c2c..c351dd83d855 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -31,7 +31,6 @@ #include #include #include -#include #include #include @@ -113,8 +112,6 @@ class PrimExprNode : public BaseExprNode { refl::ObjectDef().def_ro("dtype", &PrimExprNode::dtype); } - TVM_OBJECT_ENABLE_SCRIPT_PRINTER(); - static constexpr const uint32_t _type_child_slots = 40; TVM_FFI_DECLARE_OBJECT_INFO("ir.PrimExpr", PrimExprNode, BaseExprNode); }; diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 6a5f41ca8d37..34a451be0846 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -34,7 +34,6 @@ #include #include #include -#include #include #include @@ -241,8 +240,6 @@ class IRModuleNode : public ffi::Object { */ TVM_DLL std::unordered_set Imports() const; - TVM_OBJECT_ENABLE_SCRIPT_PRINTER(); - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.IRModule", IRModuleNode, ffi::Object); diff --git a/include/tvm/tirx/buffer.h b/include/tvm/tirx/buffer.h index b32b06b7559d..a5146600f4fa 100644 --- a/include/tvm/tirx/buffer.h +++ b/include/tvm/tirx/buffer.h @@ -28,7 +28,6 @@ #include #include #include -#include #include #include @@ -166,7 +165,6 @@ class BufferNode : public ffi::Object { static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Buffer", BufferNode, ffi::Object); - TVM_OBJECT_ENABLE_SCRIPT_PRINTER(); }; /*! diff --git a/include/tvm/tirx/function.h b/include/tvm/tirx/function.h index 45a8600a6ee4..0fae5bb96152 100644 --- a/include/tvm/tirx/function.h +++ b/include/tvm/tirx/function.h @@ -29,7 +29,6 @@ #include #include #include -#include #include #include #include @@ -120,7 +119,6 @@ class PrimFuncNode : public BaseFuncNode { */ TVM_DLL FuncType func_type_annotation() const; - TVM_OBJECT_ENABLE_SCRIPT_PRINTER(); TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.PrimFunc", PrimFuncNode, BaseFuncNode); }; diff --git a/include/tvm/tirx/stmt.h b/include/tvm/tirx/stmt.h index 39cfbac0cdca..2e336d1292c4 100644 --- a/include/tvm/tirx/stmt.h +++ b/include/tvm/tirx/stmt.h @@ -25,7 +25,6 @@ #define TVM_TIRX_STMT_H_ #include -#include #include #include #include @@ -55,8 +54,6 @@ class StmtNode : public ffi::Object { refl::ObjectDef().def_ro("span", &StmtNode::span); } - TVM_OBJECT_ENABLE_SCRIPT_PRINTER(); - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const uint32_t _type_child_slots = 15; diff --git a/src/s_tir/meta_schedule/database/json_database.cc b/src/s_tir/meta_schedule/database/json_database.cc index 8705412fa28e..3b949bb186a9 100644 --- a/src/s_tir/meta_schedule/database/json_database.cc +++ b/src/s_tir/meta_schedule/database/json_database.cc @@ -17,6 +17,7 @@ * under the License. */ #include +#include #include #include @@ -201,7 +202,7 @@ Database Database::JSONDatabase(ffi::String path_workload, ffi::String path_tuni } catch (std::runtime_error& e) { TVM_FFI_THROW(ValueError) << "Unable to parse TuningRecord, on line " << (task_id + 1) << " of file " << path_tuning_record << ". The workload is:\n" - << (workload.defined() ? workload->mod->Script() : "(null)") + << (workload.defined() ? tvm::Script(workload->mod) : "(null)") << "\nThe JSONObject of TuningRecord is:\n" << json_obj << "\nThe error message is:\n" << e.what(); From b3ce7cd94d2462d95b3e696c06bdeef56a49507c Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 27 May 2026 23:11:45 +0000 Subject: [PATCH 3/5] [REFACTOR][SCRIPT] Move dialect-specific PrinterConfig fields to extra_config PrinterConfig currently has 5 dialect-specific fields (tir_prefix, tir_import_module, tirx_prefix, relax_prefix, buffer_dtype). The first four have zero in-tree readers and are removed entirely. The fifth, buffer_dtype, has a single reader in src/tirx/script/printer/buffer.cc which is rewritten to look up the value from PrinterConfig::extra_config under the key "tirx.buffer_dtype". After this commit, PrinterConfig's core schema is dialect-agnostic; per-dialect customizations live in the extra_config ffi::Map. --- include/tvm/script/printer/config.h | 14 -------------- src/script/printer/script_printer.cc | 9 --------- src/tirx/script/printer/buffer.cc | 8 ++++++-- 3 files changed, 6 insertions(+), 25 deletions(-) diff --git a/include/tvm/script/printer/config.h b/include/tvm/script/printer/config.h index 7a84a920245d..2b05688fb2eb 100644 --- a/include/tvm/script/printer/config.h +++ b/include/tvm/script/printer/config.h @@ -48,20 +48,6 @@ class PrinterConfigNode : public ffi::Object { bool show_meta = false; /*! \brief The prefix of IR nodes */ ffi::String ir_prefix = "I"; - /*! \brief The prefix of TIR nodes */ - ffi::String tir_prefix = "T"; - /*! - * \brief The TIR module name used in the printed import (e.g. "tir" or "tirx"). - * Used in the header comment: "from tvm.script import as ". - * When tir_prefix is "Tx", set to "tirx" so the printed script uses "import tirx as Tx". - */ - ffi::String tir_import_module = "tir"; - /*! \brief The prefix of TIRX nodes */ - ffi::String tirx_prefix = "Tx"; - /*! \brief Default buffer dtype */ - DataType buffer_dtype = DataType::Float(32); - /*! \brief The prefix of Relax nodes */ - ffi::String relax_prefix = "R"; /*! * \brief The alias of the current module at cross-function call * \note Directly use module name if it's empty. diff --git a/src/script/printer/script_printer.cc b/src/script/printer/script_printer.cc index c53d5a165fca..39cc6a0ed4b6 100644 --- a/src/script/printer/script_printer.cc +++ b/src/script/printer/script_printer.cc @@ -67,15 +67,6 @@ PrinterConfig::PrinterConfig(ffi::Map config_dict) { if (auto v = config_dict.Get("ir_prefix")) { n->ir_prefix = Downcast(v.value()); } - if (auto v = config_dict.Get("tir_prefix")) { - n->tir_prefix = Downcast(v.value()); - } - if (auto v = config_dict.Get("tir_import_module")) { - n->tir_import_module = Downcast(v.value()); - } - if (auto v = config_dict.Get("relax_prefix")) { - n->relax_prefix = Downcast(v.value()); - } if (auto v = config_dict.Get("module_alias")) { n->module_alias = Downcast(v.value()); } diff --git a/src/tirx/script/printer/buffer.cc b/src/tirx/script/printer/buffer.cc index 32d50a8f8d6d..ff92af62e6a0 100644 --- a/src/tirx/script/printer/buffer.cc +++ b/src/tirx/script/printer/buffer.cc @@ -92,8 +92,12 @@ ffi::Map BufferAttrs(tirx::Buffer buffer, const AccessPath kwargs.Set("shape", TupleDoc(results)); } // Step 2. Handle `buffer.dtype` - if (buffer->dtype != d->cfg->buffer_dtype) { - kwargs.Set("dtype", LiteralDoc::DataType(buffer->dtype, buffer_p->Attr("dtype"))); + { + DataType default_buf_dtype = + d->cfg->GetExtraConfig("tirx.buffer_dtype", DataType::Float(32)); + if (buffer->dtype != default_buf_dtype) { + kwargs.Set("dtype", LiteralDoc::DataType(buffer->dtype, buffer_p->Attr("dtype"))); + } } // Step 3. Handle `buffer.data` // For tmem scope, DeclBuffer does not accept `data` (it auto-creates the data var). From 01289b6facb796b44eb076430ab5ecffe3ddc350 Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 27 May 2026 23:14:30 +0000 Subject: [PATCH 4/5] [REFACTOR][PYTHON] tvm.script: drop dialect kwargs, expose extra_config Following the C++ PrinterConfig cleanup, drop the four dialect-specific keyword arguments from the Python script entry points (Scriptable.script, _relax_script, show, BasePyModule.script): tir_prefix, tir_import_module, tirx_prefix, relax_prefix, buffer_dtype. Add a generic extra_config: dict | None = None kwarg that maps directly to the C++ PrinterConfig.extra_config. The auto-switch logic in Scriptable.script that detects tirx PrimFunc / IRModule and switches to "Tx" prefix is preserved; it now stamps extra_config["tirx.prefix"] instead of the removed tir_prefix kwarg. Callers can suppress auto-switch or override the prefix via extra_config={"tirx.prefix": "..."}. Update six call sites in tests/python/tirx/ that previously passed dialect kwargs to use extra_config={"tirx.prefix": "Tx"} instead. --- python/tvm/relax/base_py_module.py | 8 +- python/tvm/runtime/script_printer.py | 82 ++++++------------- .../tirx/test_printer_tir_namespaces.py | 2 +- .../transform/test_transform_lower_tirx.py | 10 +-- 4 files changed, 32 insertions(+), 70 deletions(-) diff --git a/python/tvm/relax/base_py_module.py b/python/tvm/relax/base_py_module.py index 1834c25c3143..5dd8a107ae08 100644 --- a/python/tvm/relax/base_py_module.py +++ b/python/tvm/relax/base_py_module.py @@ -501,10 +501,7 @@ def script( name: str | None = None, show_meta: bool = False, ir_prefix: str = "I", - tir_prefix: str = "T", - relax_prefix: str = "R", module_alias: str = "cls", - buffer_dtype: str = "float32", int_dtype: str = "int32", float_dtype: str = "void", verbose_expr: bool = False, @@ -514,6 +511,7 @@ def script( syntax_sugar: bool = True, show_object_address: bool = False, show_all_struct_info: bool = True, + extra_config: dict | None = None, ) -> str: """Print TVM IR into TVMScript text format with Python function support. @@ -525,10 +523,7 @@ def script( name=name, show_meta=show_meta, ir_prefix=ir_prefix, - tir_prefix=tir_prefix, - relax_prefix=relax_prefix, module_alias=module_alias, - buffer_dtype=buffer_dtype, int_dtype=int_dtype, float_dtype=float_dtype, verbose_expr=verbose_expr, @@ -538,6 +533,7 @@ def script( syntax_sugar=syntax_sugar, show_object_address=show_object_address, show_all_struct_info=show_all_struct_info, + extra_config=extra_config, ) # If there are no Python functions, return the base script diff --git a/python/tvm/runtime/script_printer.py b/python/tvm/runtime/script_printer.py index e67d950a4cc0..0a00157c4f52 100644 --- a/python/tvm/runtime/script_printer.py +++ b/python/tvm/runtime/script_printer.py @@ -34,9 +34,6 @@ class PrinterConfig(Object): binding_names: Sequence[str] show_meta: bool ir_prefix: str - tir_prefix: str - tir_import_module: str - relax_prefix: str module_alias: str int_dtype: str float_dtype: str @@ -58,11 +55,7 @@ def __init__( name: str | None = None, show_meta: bool = False, ir_prefix: str = "I", - tir_prefix: str = "T", - tir_import_module: str = "tir", - relax_prefix: str = "R", module_alias: str = "cls", - buffer_dtype: str = "float32", int_dtype: str = "int32", float_dtype: str = "void", verbose_expr: bool = False, @@ -72,6 +65,7 @@ def __init__( syntax_sugar: bool = True, show_object_address: bool = False, show_all_struct_info: bool = True, + extra_config: dict | None = None, path_to_underline: list[AccessPath] | None = None, path_to_annotate: dict[AccessPath, str] | None = None, obj_to_underline: list[Object] | None = None, @@ -79,12 +73,9 @@ def __init__( ) -> None: if num_context_lines is None: num_context_lines = -1 - cfg = { + cfg: dict = { "show_meta": show_meta, "ir_prefix": ir_prefix, - "tir_prefix": tir_prefix, - "tir_import_module": tir_import_module, - "relax_prefix": relax_prefix, "module_alias": module_alias, "int_dtype": int_dtype, "float_dtype": float_dtype, @@ -99,14 +90,13 @@ def __init__( "obj_to_underline": obj_to_underline, "obj_to_annotate": obj_to_annotate, # Dialect-specific config via dotted keys in extra_config - "tirx.prefix": tir_prefix, - "tirx.buffer_dtype": buffer_dtype, - "relax.prefix": relax_prefix, "relax.show_all_struct_info": show_all_struct_info, } if name is not None: cfg["name"] = name + if extra_config is not None: + cfg["extra_config"] = extra_config self.__init_handle_by_constructor__( _ffi_node_api.PrinterConfig, cfg, # type: ignore # pylint: disable=no-member @@ -131,11 +121,7 @@ def script( name: str | None = None, show_meta: bool = False, ir_prefix: str = "I", - tir_prefix: str = "T", - tir_import_module: str = "tir", - relax_prefix: str = "R", module_alias: str = "cls", - buffer_dtype: str = "float32", int_dtype: str = "int32", float_dtype: str = "void", verbose_expr: bool = False, @@ -145,6 +131,7 @@ def script( syntax_sugar: bool = True, show_object_address: bool = False, show_all_struct_info: bool = True, + extra_config: dict | None = None, path_to_underline: list[AccessPath] | None = None, path_to_annotate: dict[AccessPath, str] | None = None, obj_to_underline: list[Object] | None = None, @@ -160,18 +147,9 @@ def script( Whether to print the meta data of the object ir_prefix : str = "I" The prefix of AST nodes from tvm.ir - tir_prefix : str = "T" - The prefix of AST nodes from tvm.tir - tir_import_module : str = "tir" - The module name in the printed import (e.g. \"tir\" or \"tirx\"). - Use tir_import_module=\"tirx\" with tir_prefix=\"Tx\" for all-Tx output. - relax_prefix : str = "R" - The prefix of AST nodes from tvm.relax module_alias : str = "cls" The alias of the current module at cross-function call, Directly use module name if it's empty. - buffer_dtype : str = "float32" - The default data type of buffer int_dtype : str = "int32" The default data type of integer float_dtype : str = "void" @@ -192,6 +170,10 @@ def script( If True (default), annotate all variable bindings with the struct info of that variable. If False, only add annotations where required for unambiguous round-trip of Relax -> TVMScript -> Relax. + extra_config : Optional[dict] = None + Dialect-specific configuration passed through to PrinterConfig.extra_config. + Keys are conventionally namespaced as ".", e.g. + ``{"tirx.prefix": "Tx", "tirx.buffer_dtype": "float16"}``. path_to_underline : Optional[List[AccessPath]] = None Object path to be underlined path_to_annotate : Optional[Dict[AccessPath, str]] = None @@ -211,9 +193,12 @@ def script( # printing a PrimFunc / IRModule that has no s_tir-tagged content. # Free objects (Buffer, BufferRegion, ...) keep the default `T`/`tir` # flavor — they have no enclosing function to indicate tirx vs s_tir. - tir_prefix_val = tir_prefix - tir_import_module_val = tir_import_module - if tir_prefix == "T" and tir_import_module == "tir": + merged_extra: dict = {} + if extra_config is not None: + merged_extra.update(extra_config) + + # Only auto-switch if the caller has not already set a tirx.prefix override. + if "tirx.prefix" not in merged_extra: from tvm.ir import IRModule # pylint: disable=import-outside-toplevel from tvm.tirx import PrimFunc # pylint: disable=import-outside-toplevel @@ -236,19 +221,15 @@ def script( if any_prim and not any_s_tir: switch_to_tirx = True if switch_to_tirx: - tir_prefix_val = "Tx" - tir_import_module_val = "tirx" + merged_extra["tirx.prefix"] = "Tx" + return _script( self, PrinterConfig( name=name, show_meta=show_meta, ir_prefix=ir_prefix, - tir_prefix=tir_prefix_val, - tir_import_module=tir_import_module_val, - relax_prefix=relax_prefix, module_alias=module_alias, - buffer_dtype=buffer_dtype, int_dtype=int_dtype, float_dtype=float_dtype, verbose_expr=verbose_expr, @@ -258,6 +239,7 @@ def script( syntax_sugar=syntax_sugar, show_object_address=show_object_address, show_all_struct_info=show_all_struct_info, + extra_config=merged_extra if merged_extra else None, path_to_underline=path_to_underline, path_to_annotate=path_to_annotate, obj_to_underline=obj_to_underline, @@ -271,11 +253,7 @@ def _relax_script( name: str | None = None, show_meta: bool = False, ir_prefix: str = "I", - tir_prefix: str = "T", - tir_import_module: str = "tir", - relax_prefix: str = "R", module_alias: str = "cls", - buffer_dtype: str = "float32", int_dtype: str = "int32", float_dtype: str = "void", verbose_expr: bool = False, @@ -284,6 +262,7 @@ def _relax_script( num_context_lines: int = -1, syntax_sugar: bool = True, show_object_address: bool = False, + extra_config: dict | None = None, path_to_underline: list[AccessPath] | None = None, path_to_annotate: dict[AccessPath, str] | None = None, obj_to_underline: list[Object] | None = None, @@ -295,11 +274,7 @@ def _relax_script( name=name, show_meta=show_meta, ir_prefix=ir_prefix, - tir_prefix=tir_prefix, - tir_import_module=tir_import_module, - relax_prefix=relax_prefix, module_alias=module_alias, - buffer_dtype=buffer_dtype, int_dtype=int_dtype, float_dtype=float_dtype, verbose_expr=verbose_expr, @@ -308,6 +283,7 @@ def _relax_script( num_context_lines=num_context_lines, syntax_sugar=syntax_sugar, show_object_address=show_object_address, + extra_config=extra_config, path_to_underline=path_to_underline, path_to_annotate=path_to_annotate, obj_to_underline=obj_to_underline, @@ -323,11 +299,7 @@ def show( name: str | None = None, show_meta: bool = False, ir_prefix: str = "I", - tir_prefix: str = "T", - tir_import_module: str = "tir", - relax_prefix: str = "R", module_alias: str = "cls", - buffer_dtype: str = "float32", int_dtype: str = "int32", float_dtype: str = "void", verbose_expr: bool = False, @@ -337,6 +309,7 @@ def show( syntax_sugar: bool = True, show_object_address: bool = False, show_all_struct_info: bool = True, + extra_config: dict | None = None, path_to_underline: list[AccessPath] | None = None, path_to_annotate: dict[AccessPath, str] | None = None, obj_to_underline: list[Object] | None = None, @@ -375,15 +348,9 @@ def show( Whether to print the meta data of the object ir_prefix : str = "I" The prefix of AST nodes from tvm.ir - tir_prefix : str = "T" - The prefix of AST nodes from tvm.tirx - relax_prefix : str = "R" - The prefix of AST nodes from tvm.relax module_alias : str = "cls" The alias of the current module at cross-function call, Directly use module name if it's empty. - buffer_dtype : str = "float32" - The default data type of buffer int_dtype : str = "int32" The default data type of integer float_dtype : str = "void" @@ -404,6 +371,8 @@ def show( If True (default), annotate all variable bindings with the struct info of that variable. If False, only add annotations where required for unambiguous round-trip of Relax -> TVMScript -> Relax. + extra_config : Optional[dict] = None + Dialect-specific configuration passed through to PrinterConfig.extra_config. path_to_underline : Optional[List[AccessPath]] = None Object path to be underlined path_to_annotate : Optional[Dict[AccessPath, str]] = None @@ -425,11 +394,7 @@ def show( name=name, show_meta=show_meta, ir_prefix=ir_prefix, - tir_prefix=tir_prefix, - tir_import_module=tir_import_module, - relax_prefix=relax_prefix, module_alias=module_alias, - buffer_dtype=buffer_dtype, int_dtype=int_dtype, float_dtype=float_dtype, verbose_expr=verbose_expr, @@ -439,6 +404,7 @@ def show( syntax_sugar=syntax_sugar, show_object_address=show_object_address, show_all_struct_info=show_all_struct_info, + extra_config=extra_config, path_to_underline=path_to_underline, path_to_annotate=path_to_annotate, obj_to_underline=obj_to_underline, diff --git a/tests/python/tirx/test_printer_tir_namespaces.py b/tests/python/tirx/test_printer_tir_namespaces.py index 79d37ea57186..50fdd4eea9e3 100644 --- a/tests/python/tirx/test_printer_tir_namespaces.py +++ b/tests/python/tirx/test_printer_tir_namespaces.py @@ -21,7 +21,7 @@ def _assert_print(obj, expected): # Use Tx prefix so standalone TIR nodes (non-PrimFunc) print as Tx to match tirx namespace - out = obj.script(verbose_expr=True, tir_prefix="Tx", tir_import_module="tirx").strip() + out = obj.script(verbose_expr=True, extra_config={"tirx.prefix": "Tx"}).strip() assert out == expected.strip() diff --git a/tests/python/tirx/transform/test_transform_lower_tirx.py b/tests/python/tirx/transform/test_transform_lower_tirx.py index c8434f505520..3e20d61f8059 100644 --- a/tests/python/tirx/transform/test_transform_lower_tirx.py +++ b/tests/python/tirx/transform/test_transform_lower_tirx.py @@ -953,7 +953,7 @@ def before(A_ptr: Tx.handle): with tvm.target.Target("cuda"): lowered = LowerTIRx()(tvm.IRModule({"main": before})) - script = lowered.script(tir_prefix="Tx", tir_import_module="tirx") + script = lowered.script(extra_config={"tirx.prefix": "Tx"}) assert "if wg_id == 0:" in script assert "0 <= wg_id" not in script assert "wg_id < 1" not in script @@ -977,7 +977,7 @@ def before(A_ptr: Tx.handle): with tvm.target.Target("cuda"): lowered = LowerTIRx()(tvm.IRModule({"main": before})) - script = lowered.script(tir_prefix="Tx", tir_import_module="tirx") + script = lowered.script(extra_config={"tirx.prefix": "Tx"}) assert "if wg_id == 0:" in script assert "0 <= wg_id" not in script assert "wg_id < 1" not in script @@ -1002,7 +1002,7 @@ def before(A_ptr: Tx.handle): lowered = LowerTIRx()(tvm.IRModule({"main": before})) simplified = Simplify()(lowered) - script = simplified.script(tir_prefix="Tx", tir_import_module="tirx") + script = simplified.script(extra_config={"tirx.prefix": "Tx"}) assert "if warp_id_in_cta // 4 == 0:" in script assert "if 0 <= warp_id_in_cta" not in script assert "A_1[warp_id_in_cta] = Tx.Cast" in script @@ -1018,7 +1018,7 @@ def test_lower_exec_context_selector_filter_for_elect_sync(): @register_dispatch("copy", "cuda", variant=variant, priority=10_000) def _probe(op_call, sctx): - seen.append(sctx.inter["laneid"][1].script(tir_prefix="Tx", tir_import_module="tirx")) + seen.append(sctx.inter["laneid"][1].script(extra_config={"tirx.prefix": "Tx"})) @Tx.prim_func(private=True) def impl(): @@ -1088,7 +1088,7 @@ def before(A_ptr: Tx.handle, B_ptr: Tx.handle): assert _int_pair(seen[0]["inter"], "warpid") == (1, 0) assert int(seen[0]["inter"]["laneid"][0]) == 1 assert ( - seen[0]["inter"]["laneid"][1].script(tir_prefix="Tx", tir_import_module="tirx") + seen[0]["inter"]["laneid"][1].script(extra_config={"tirx.prefix": "Tx"}) == "Tx.selector(lane_id, Tx.ptx.elect_sync())" ) assert len(seen[0]["intra"]) == 0 From e42381b67ecea75601a7977de2bdee6bf373c0f1 Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 27 May 2026 23:33:57 +0000 Subject: [PATCH 5/5] [REFACTOR][SCRIPT] Fix transitive include breakage from config.h removal MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Removing #include from ir/expr.h and ir/module.h (in the macro-drop commit) broke three headers and one test that relied on those transitively pulling in PrinterConfig and NodeFunctor: - include/tvm/script/printer/doc.h: add config.h (uses PrinterConfig in DocToPythonScript declaration) - include/tvm/script/printer/ir_docsifier.h: add config.h (PrinterConfig member on IRDocsifierNode arrived via module.h → config.h) - include/tvm/script/ir_builder/base.h: add node_functor.h (NodeFunctor arrived via expr.h → config.h → node_functor.h) - tests/cpp/tir_scalable_datatype.cc: rewrite call->Script() to tvm::Script(call) since the Script() member was removed with TVM_OBJECT_ENABLE_SCRIPT_PRINTER --- include/tvm/script/ir_builder/base.h | 1 + include/tvm/script/printer/config.h | 4 +++- include/tvm/script/printer/doc.h | 1 + include/tvm/script/printer/ir_docsifier.h | 1 + python/tvm/runtime/script_printer.py | 5 ++++- src/s_tir/meta_schedule/database/json_database.cc | 13 +++++++------ src/script/printer/script_printer.cc | 8 +++----- src/tirx/script/printer/buffer.cc | 3 +-- tests/cpp/tir_scalable_datatype.cc | 3 ++- 9 files changed, 23 insertions(+), 16 deletions(-) diff --git a/include/tvm/script/ir_builder/base.h b/include/tvm/script/ir_builder/base.h index a459df6ee645..0d9c8ccc4fec 100644 --- a/include/tvm/script/ir_builder/base.h +++ b/include/tvm/script/ir_builder/base.h @@ -23,6 +23,7 @@ #include #include #include +#include #include diff --git a/include/tvm/script/printer/config.h b/include/tvm/script/printer/config.h index 2b05688fb2eb..541d66f63526 100644 --- a/include/tvm/script/printer/config.h +++ b/include/tvm/script/printer/config.h @@ -53,6 +53,8 @@ class PrinterConfigNode : public ffi::Object { * \note Directly use module name if it's empty. */ ffi::String module_alias = "cls"; + /*! \brief Default buffer dtype */ + DataType buffer_dtype = DataType::Float(32); /*! \brief Default data type of integer literals */ DataType int_dtype = DataType::Int(32); /*! @@ -88,7 +90,6 @@ class PrinterConfigNode : public ffi::Object { * * Keys are conventionally namespaced as ".", e.g.: * "tirx.prefix" — the TIR prefix (default "T") - * "tirx.buffer_dtype" — default buffer dtype (default float32) * "relax.prefix" — the Relax prefix (default "R") * "relax.show_all_struct_info" — whether to show all struct info (default true) * @@ -116,6 +117,7 @@ class PrinterConfigNode : public ffi::Object { .def_ro("show_meta", &PrinterConfigNode::show_meta) .def_ro("ir_prefix", &PrinterConfigNode::ir_prefix) .def_ro("module_alias", &PrinterConfigNode::module_alias) + .def_ro("buffer_dtype", &PrinterConfigNode::buffer_dtype) .def_ro("int_dtype", &PrinterConfigNode::int_dtype) .def_ro("float_dtype", &PrinterConfigNode::float_dtype) .def_ro("verbose_expr", &PrinterConfigNode::verbose_expr) diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index c602fc80a492..d63942ac71df 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -24,6 +24,7 @@ #include #include #include +#include #include diff --git a/include/tvm/script/printer/ir_docsifier.h b/include/tvm/script/printer/ir_docsifier.h index e49d4f8a1cc0..32f2281828ad 100644 --- a/include/tvm/script/printer/ir_docsifier.h +++ b/include/tvm/script/printer/ir_docsifier.h @@ -23,6 +23,7 @@ #include #include #include +#include #include #include diff --git a/python/tvm/runtime/script_printer.py b/python/tvm/runtime/script_printer.py index 0a00157c4f52..209efe77a0cc 100644 --- a/python/tvm/runtime/script_printer.py +++ b/python/tvm/runtime/script_printer.py @@ -35,6 +35,7 @@ class PrinterConfig(Object): show_meta: bool ir_prefix: str module_alias: str + buffer_dtype: str int_dtype: str float_dtype: str verbose_expr: bool @@ -56,6 +57,7 @@ def __init__( show_meta: bool = False, ir_prefix: str = "I", module_alias: str = "cls", + buffer_dtype: str = "float32", int_dtype: str = "int32", float_dtype: str = "void", verbose_expr: bool = False, @@ -77,6 +79,7 @@ def __init__( "show_meta": show_meta, "ir_prefix": ir_prefix, "module_alias": module_alias, + "buffer_dtype": buffer_dtype, "int_dtype": int_dtype, "float_dtype": float_dtype, "verbose_expr": verbose_expr, @@ -173,7 +176,7 @@ def script( extra_config : Optional[dict] = None Dialect-specific configuration passed through to PrinterConfig.extra_config. Keys are conventionally namespaced as ".", e.g. - ``{"tirx.prefix": "Tx", "tirx.buffer_dtype": "float16"}``. + ``{"tirx.prefix": "Tx"}``. path_to_underline : Optional[List[AccessPath]] = None Object path to be underlined path_to_annotate : Optional[Dict[AccessPath, str]] = None diff --git a/src/s_tir/meta_schedule/database/json_database.cc b/src/s_tir/meta_schedule/database/json_database.cc index 3b949bb186a9..9722dc39b405 100644 --- a/src/s_tir/meta_schedule/database/json_database.cc +++ b/src/s_tir/meta_schedule/database/json_database.cc @@ -200,12 +200,13 @@ Database Database::JSONDatabase(ffi::String path_workload, ffi::String path_tuni workload = workloads[workload_index]; records[task_id] = TuningRecord::FromJSON(arr->at(1).cast(), workload); } catch (std::runtime_error& e) { - TVM_FFI_THROW(ValueError) << "Unable to parse TuningRecord, on line " << (task_id + 1) - << " of file " << path_tuning_record << ". The workload is:\n" - << (workload.defined() ? tvm::Script(workload->mod) : "(null)") - << "\nThe JSONObject of TuningRecord is:\n" - << json_obj << "\nThe error message is:\n" - << e.what(); + TVM_FFI_THROW(ValueError) + << "Unable to parse TuningRecord, on line " << (task_id + 1) << " of file " + << path_tuning_record << ". The workload is:\n" + << (workload.defined() ? tvm::Script(workload->mod) : "(null)") + << "\nThe JSONObject of TuningRecord is:\n" + << json_obj << "\nThe error message is:\n" + << e.what(); } }); for (const TuningRecord& record : records) { diff --git a/src/script/printer/script_printer.cc b/src/script/printer/script_printer.cc index 39cc6a0ed4b6..d595898c919e 100644 --- a/src/script/printer/script_printer.cc +++ b/src/script/printer/script_printer.cc @@ -70,6 +70,9 @@ PrinterConfig::PrinterConfig(ffi::Map config_dict) { if (auto v = config_dict.Get("module_alias")) { n->module_alias = Downcast(v.value()); } + if (auto v = config_dict.Get("buffer_dtype")) { + n->buffer_dtype = DataType(ffi::StringToDLDataType(Downcast(v.value()))); + } if (auto v = config_dict.Get("int_dtype")) { n->int_dtype = DataType(ffi::StringToDLDataType(Downcast(v.value()))); } @@ -119,11 +122,6 @@ PrinterConfig::PrinterConfig(ffi::Map config_dict) { n->extra_config.Set(ffi::String(key), v.value()); } } - // "tirx.buffer_dtype" is passed as a DLDataType string from Python; convert to DataType. - if (auto v = config_dict.Get("tirx.buffer_dtype")) { - DataType dt(ffi::StringToDLDataType(Downcast(v.value()))); - n->extra_config.Set(ffi::String("tirx.buffer_dtype"), ffi::Any(dt)); - } // Boolean dialect keys. if (auto v = config_dict.Get("relax.show_all_struct_info")) { n->extra_config.Set(ffi::String("relax.show_all_struct_info"), v.value()); diff --git a/src/tirx/script/printer/buffer.cc b/src/tirx/script/printer/buffer.cc index ff92af62e6a0..2333eb89005b 100644 --- a/src/tirx/script/printer/buffer.cc +++ b/src/tirx/script/printer/buffer.cc @@ -93,8 +93,7 @@ ffi::Map BufferAttrs(tirx::Buffer buffer, const AccessPath } // Step 2. Handle `buffer.dtype` { - DataType default_buf_dtype = - d->cfg->GetExtraConfig("tirx.buffer_dtype", DataType::Float(32)); + DataType default_buf_dtype = d->cfg->buffer_dtype; if (buffer->dtype != default_buf_dtype) { kwargs.Set("dtype", LiteralDoc::DataType(buffer->dtype, buffer_p->Attr("dtype"))); } diff --git a/tests/cpp/tir_scalable_datatype.cc b/tests/cpp/tir_scalable_datatype.cc index fd9f76eee366..5ead9c7d404b 100644 --- a/tests/cpp/tir_scalable_datatype.cc +++ b/tests/cpp/tir_scalable_datatype.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include #include @@ -195,7 +196,7 @@ TEST(ScalableDataType, TestScalableIntrinCall) { ::llvm::Intrinsic::experimental_stepvector)}); #endif ASSERT_EQ(call->dtype, scalable_type); - ASSERT_EQ(call->Script(), + ASSERT_EQ(tvm::Script(call), #if TVM_LLVM_VERSION >= 200 "T.call_llvm_intrin(\"int32xvscalex4\", \"llvm.stepvector\")"); #else