From a8edc505328dec42c061afa725a97ff7d3a90784 Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 28 May 2026 15:29:15 +0000 Subject: [PATCH] [REFACTOR][SCRIPT] Revive buffer_dtype as a top-level PrinterConfig field buffer_dtype was moved to extra_config in the tvmscript streamline refactor, but it is conceptually a shared scalar-literal default alongside int_dtype and float_dtype, not a dialect-specific knob. Restore it as a top-level named field on PrinterConfigNode (and the Python PrinterConfig mirror), with the same DataType::Float(32) default and the same top-level cfg-dict conversion pattern as the sibling dtype fields. Update the single C++ reader in src/tirx/script/printer/buffer.cc to use cfg->buffer_dtype directly. --- include/tvm/script/printer/config.h | 2 +- python/tvm/runtime/script_printer.py | 3 ++- src/script/printer/script_printer.cc | 8 +++----- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/include/tvm/script/printer/config.h b/include/tvm/script/printer/config.h index 19510e76a816..8b5357ac4d9b 100644 --- a/include/tvm/script/printer/config.h +++ b/include/tvm/script/printer/config.h @@ -99,7 +99,6 @@ class PrinterConfigNode : public ffi::Object { * * Keys are conventionally namespaced as ".", e.g.: * "tirx.prefix" — the TIR prefix (default "T") - * "tirx.buffer_dtype" — default buffer dtype (default float32) * "relax.prefix" — the Relax prefix (default "R") * "relax.show_all_struct_info" — whether to show all struct info (default true) * @@ -127,6 +126,7 @@ class PrinterConfigNode : public ffi::Object { .def_ro("show_meta", &PrinterConfigNode::show_meta) .def_ro("ir_prefix", &PrinterConfigNode::ir_prefix) .def_ro("module_alias", &PrinterConfigNode::module_alias) + .def_ro("buffer_dtype", &PrinterConfigNode::buffer_dtype) .def_ro("int_dtype", &PrinterConfigNode::int_dtype) .def_ro("float_dtype", &PrinterConfigNode::float_dtype) .def_ro("verbose_expr", &PrinterConfigNode::verbose_expr) diff --git a/python/tvm/runtime/script_printer.py b/python/tvm/runtime/script_printer.py index e67d950a4cc0..bb573f747012 100644 --- a/python/tvm/runtime/script_printer.py +++ b/python/tvm/runtime/script_printer.py @@ -38,6 +38,7 @@ class PrinterConfig(Object): tir_import_module: str relax_prefix: str module_alias: str + buffer_dtype: str int_dtype: str float_dtype: str verbose_expr: bool @@ -86,6 +87,7 @@ def __init__( "tir_import_module": tir_import_module, "relax_prefix": relax_prefix, "module_alias": module_alias, + "buffer_dtype": buffer_dtype, "int_dtype": int_dtype, "float_dtype": float_dtype, "verbose_expr": verbose_expr, @@ -100,7 +102,6 @@ def __init__( "obj_to_annotate": obj_to_annotate, # Dialect-specific config via dotted keys in extra_config "tirx.prefix": tir_prefix, - "tirx.buffer_dtype": buffer_dtype, "relax.prefix": relax_prefix, "relax.show_all_struct_info": show_all_struct_info, } diff --git a/src/script/printer/script_printer.cc b/src/script/printer/script_printer.cc index f3fc27cf42db..dd9989630e21 100644 --- a/src/script/printer/script_printer.cc +++ b/src/script/printer/script_printer.cc @@ -80,6 +80,9 @@ PrinterConfig::PrinterConfig(ffi::Map config_dict) { if (auto v = config_dict.Get("module_alias")) { n->module_alias = Downcast(v.value()); } + if (auto v = config_dict.Get("buffer_dtype")) { + n->buffer_dtype = DataType(ffi::StringToDLDataType(Downcast(v.value()))); + } if (auto v = config_dict.Get("int_dtype")) { n->int_dtype = DataType(ffi::StringToDLDataType(Downcast(v.value()))); } @@ -129,11 +132,6 @@ PrinterConfig::PrinterConfig(ffi::Map config_dict) { n->extra_config.Set(ffi::String(key), v.value()); } } - // "tirx.buffer_dtype" is passed as a DLDataType string from Python; convert to DataType. - if (auto v = config_dict.Get("tirx.buffer_dtype")) { - DataType dt(ffi::StringToDLDataType(Downcast(v.value()))); - n->extra_config.Set(ffi::String("tirx.buffer_dtype"), ffi::Any(dt)); - } // Boolean dialect keys. if (auto v = config_dict.Get("relax.show_all_struct_info")) { n->extra_config.Set(ffi::String("relax.show_all_struct_info"), v.value());