Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
#include <tvm/ir/cow.h>
#include <tvm/ir/source_map.h>
#include <tvm/ir/type.h>
#include <tvm/script/printer/config.h>

#include <algorithm>
#include <functional>
Expand Down Expand Up @@ -113,8 +112,6 @@ class PrimExprNode : public BaseExprNode {
refl::ObjectDef<PrimExprNode>().def_ro("dtype", &PrimExprNode::dtype);
}

TVM_OBJECT_ENABLE_SCRIPT_PRINTER();

static constexpr const uint32_t _type_child_slots = 40;
TVM_FFI_DECLARE_OBJECT_INFO("ir.PrimExpr", PrimExprNode, BaseExprNode);
};
Expand Down
3 changes: 0 additions & 3 deletions include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
#include <tvm/ir/global_info.h>
#include <tvm/ir/source_map.h>
#include <tvm/ir/type.h>
#include <tvm/script/printer/config.h>

#include <string>
#include <unordered_map>
Expand Down Expand Up @@ -241,8 +240,6 @@ class IRModuleNode : public ffi::Object {
*/
TVM_DLL std::unordered_set<ffi::String> Imports() const;

TVM_OBJECT_ENABLE_SCRIPT_PRINTER();

static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;

TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.IRModule", IRModuleNode, ffi::Object);
Expand Down
1 change: 1 addition & 0 deletions include/tvm/script/ir_builder/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <tvm/ir/cast.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/function.h>
#include <tvm/ir/node_functor.h>

#include <vector>

Expand Down
65 changes: 11 additions & 54 deletions include/tvm/script/printer/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
*/
/*!
* \file tvm/script/printer/config.h
* \brief Printer class to print repr string of each AST/IR nodes.
* \brief Configuration object for the TVMScript printer.
*
* Contains PrinterConfig / PrinterConfigNode, GetBuiltinKeywords, GetExtraConfig,
* and RedirectedReprPrinterMethod. The entry-point free function tvm::Script()
* and the dispatch vtable TVMScriptPrinter live in printer.h.
*/
#ifndef TVM_SCRIPT_PRINTER_CONFIG_H_
#define TVM_SCRIPT_PRINTER_CONFIG_H_
Expand All @@ -30,7 +34,6 @@
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ffi/string.h>
#include <tvm/ir/cast.h>
#include <tvm/ir/node_functor.h>
#include <tvm/runtime/data_type.h>

#include <string>
Expand All @@ -45,25 +48,13 @@ class PrinterConfigNode : public ffi::Object {
bool show_meta = false;
/*! \brief The prefix of IR nodes */
ffi::String ir_prefix = "I";
/*! \brief The prefix of TIR nodes */
ffi::String tir_prefix = "T";
/*!
* \brief The TIR module name used in the printed import (e.g. "tir" or "tirx").
* Used in the header comment: "from tvm.script import <tir_import_module> as <tir_prefix>".
* When tir_prefix is "Tx", set to "tirx" so the printed script uses "import tirx as Tx".
*/
ffi::String tir_import_module = "tir";
/*! \brief The prefix of TIRX nodes */
ffi::String tirx_prefix = "Tx";
/*! \brief Default buffer dtype */
DataType buffer_dtype = DataType::Float(32);
/*! \brief The prefix of Relax nodes */
ffi::String relax_prefix = "R";
/*!
* \brief The alias of the current module at cross-function call
* \note Directly use module name if it's empty.
*/
ffi::String module_alias = "cls";
/*! \brief Default buffer dtype */
DataType buffer_dtype = DataType::Float(32);
/*! \brief Default data type of integer literals */
DataType int_dtype = DataType::Int(32);
/*!
Expand Down Expand Up @@ -99,7 +90,6 @@ class PrinterConfigNode : public ffi::Object {
*
* Keys are conventionally namespaced as "<dialect>.<knob>", e.g.:
* "tirx.prefix" — the TIR prefix (default "T")
* "tirx.buffer_dtype" — default buffer dtype (default float32)
* "relax.prefix" — the Relax prefix (default "R")
* "relax.show_all_struct_info" — whether to show all struct info (default true)
*
Expand Down Expand Up @@ -127,6 +117,7 @@ class PrinterConfigNode : public ffi::Object {
.def_ro("show_meta", &PrinterConfigNode::show_meta)
.def_ro("ir_prefix", &PrinterConfigNode::ir_prefix)
.def_ro("module_alias", &PrinterConfigNode::module_alias)
.def_ro("buffer_dtype", &PrinterConfigNode::buffer_dtype)
.def_ro("int_dtype", &PrinterConfigNode::int_dtype)
.def_ro("float_dtype", &PrinterConfigNode::float_dtype)
.def_ro("verbose_expr", &PrinterConfigNode::verbose_expr)
Expand Down Expand Up @@ -156,48 +147,14 @@ class TVM_DLL PrinterConfig : public ffi::ObjectRef {
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(PrinterConfig, ffi::ObjectRef, PrinterConfigNode);
};

/*! \brief TVMScript-based printer for IR nodes. */
class TVMScriptPrinter {
public:
/* Convert the object to TVMScript format */
TVM_DLL static std::string Script(const ffi::ObjectRef& node,
const ffi::Optional<PrinterConfig>& cfg);
// Allow registration to be printer.
using FType = NodeFunctor<std::string(const ffi::ObjectRef&, const PrinterConfig&)>;
TVM_DLL static FType& vtable();
};

#define TVM_OBJECT_ENABLE_SCRIPT_PRINTER() \
std::string Script(const ffi::Optional<PrinterConfig>& config = std::nullopt) const { \
return TVMScriptPrinter::Script(ffi::GetRef<ffi::ObjectRef>(this), \
config.value_or(PrinterConfig())); \
}

/*!
* \brief The fallback body used by TVM_REGISTER_SCRIPT_AS_REPR.
* \brief The fallback body used by TVM_REGISTER_SCRIPT_AS_REPR (defined in printer.h).
*
* Tries to format \p obj via TVMScriptPrinter::Script; on error falls back to
* a plain address string. Defined in src/script/printer/config.cc so that
* Tries to format \p obj via tvm::Script; on error falls back to a plain
* address string. Defined in src/script/printer/config.cc so that
* <tvm/runtime/logging.h> is not pulled into this public header.
*/
TVM_DLL std::string RedirectedReprPrinterMethod(const ffi::ObjectRef& obj);

/*!
* \brief Register Script as the kRepr callback for ObjectType and install
* the per-type dispatch entry in TVMScriptPrinter::vtable().
*
* \param ObjectType The concrete object node type (e.g. tirx::VarNode).
* \param Method The TVMScriptPrinter vtable dispatch function.
*/
#define TVM_REGISTER_SCRIPT_AS_REPR(ObjectType, Method) \
TVM_FFI_STATIC_INIT_BLOCK() { \
namespace refl = tvm::ffi::reflection; \
refl::TypeAttrDef<ObjectType>().def(refl::type_attr::kRepr, \
[](ffi::ObjectRef obj, ffi::Function) -> ffi::String { \
return RedirectedReprPrinterMethod(obj); \
}); \
} \
TVM_STATIC_IR_FUNCTOR(TVMScriptPrinter, vtable).set_dispatch<ObjectType>(Method)

} // namespace tvm
#endif // TVM_SCRIPT_PRINTER_CONFIG_H_
1 change: 1 addition & 0 deletions include/tvm/script/printer/doc.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <tvm/ir/expr.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/device_api.h>
#include <tvm/script/printer/config.h>

#include <string>

Expand Down
1 change: 1 addition & 0 deletions include/tvm/script/printer/ir_docsifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/cast.h>
#include <tvm/ir/module.h>
#include <tvm/script/printer/config.h>
#include <tvm/script/printer/doc.h>
#include <tvm/script/printer/ir_docsifier_functor.h>

Expand Down
71 changes: 71 additions & 0 deletions include/tvm/script/printer/printer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/script/printer/printer.h
* \brief Entry-point header for TVMScript printing.
*
* Declares the free function `tvm::Script(node, optional_config)` and the
* dispatch vtable `TVMScriptPrinter::vtable()` used by per-dialect printers.
* `PrinterConfig` and its dataclass helpers live in config.h; this header is
* what callers include to invoke printing.
*/
#ifndef TVM_SCRIPT_PRINTER_PRINTER_H_
#define TVM_SCRIPT_PRINTER_PRINTER_H_

#include <tvm/ir/node_functor.h>
#include <tvm/script/printer/config.h>

namespace tvm {

/*! \brief Print \p node as TVMScript with the given \p config.
*
* Falls back to ffi::ReprPrint for types not registered with TVMScriptPrinter.
*/
TVM_DLL std::string Script(const ffi::ObjectRef& node,
const ffi::Optional<PrinterConfig>& config = std::nullopt);

/*! \brief Dispatch vtable used by per-dialect printers to register their
* object-type printing functions. Internal, but exposed here because
* TVM_REGISTER_SCRIPT_AS_REPR refers to it.
*/
class TVMScriptPrinter {
public:
using FType = NodeFunctor<std::string(const ffi::ObjectRef&, const PrinterConfig&)>;
TVM_DLL static FType& vtable();
};

/*!
* \brief Register Script as the kRepr callback for ObjectType and install
* the per-type dispatch entry in TVMScriptPrinter::vtable().
*
* \param ObjectType The concrete object node type (e.g. tirx::VarNode).
* \param Method The TVMScriptPrinter vtable dispatch function.
*/
#define TVM_REGISTER_SCRIPT_AS_REPR(ObjectType, Method) \
TVM_FFI_STATIC_INIT_BLOCK() { \
namespace refl = tvm::ffi::reflection; \
refl::TypeAttrDef<ObjectType>().def(refl::type_attr::kRepr, \
[](ffi::ObjectRef obj, ffi::Function) -> ffi::String { \
return RedirectedReprPrinterMethod(obj); \
}); \
} \
TVM_STATIC_IR_FUNCTOR(TVMScriptPrinter, vtable).set_dispatch<ObjectType>(Method)

} // namespace tvm
#endif // TVM_SCRIPT_PRINTER_PRINTER_H_
2 changes: 0 additions & 2 deletions include/tvm/tirx/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ffi/string.h>
#include <tvm/ir/expr.h>
#include <tvm/script/printer/config.h>
#include <tvm/tirx/layout.h>
#include <tvm/tirx/var.h>

Expand Down Expand Up @@ -166,7 +165,6 @@ class BufferNode : public ffi::Object {
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;

TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Buffer", BufferNode, ffi::Object);
TVM_OBJECT_ENABLE_SCRIPT_PRINTER();
};

/*!
Expand Down
2 changes: 0 additions & 2 deletions include/tvm/tirx/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
#include <tvm/ir/cow.h>
#include <tvm/ir/function.h>
#include <tvm/runtime/tensor.h>
#include <tvm/script/printer/config.h>
#include <tvm/tirx/buffer.h>
#include <tvm/tirx/expr.h>
#include <tvm/tirx/stmt.h>
Expand Down Expand Up @@ -120,7 +119,6 @@ class PrimFuncNode : public BaseFuncNode {
*/
TVM_DLL FuncType func_type_annotation() const;

TVM_OBJECT_ENABLE_SCRIPT_PRINTER();
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.PrimFunc", PrimFuncNode, BaseFuncNode);
};

Expand Down
3 changes: 0 additions & 3 deletions include/tvm/tirx/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
#define TVM_TIRX_STMT_H_

#include <tvm/ffi/reflection/registry.h>
#include <tvm/script/printer/config.h>
#include <tvm/tirx/exec_scope.h>
#include <tvm/tirx/expr.h>
#include <tvm/tirx/layout.h>
Expand Down Expand Up @@ -55,8 +54,6 @@ class StmtNode : public ffi::Object {
refl::ObjectDef<StmtNode>().def_ro("span", &StmtNode::span);
}

TVM_OBJECT_ENABLE_SCRIPT_PRINTER();

static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;

static constexpr const uint32_t _type_child_slots = 15;
Expand Down
8 changes: 2 additions & 6 deletions python/tvm/relax/base_py_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,10 +501,7 @@ def script(
name: str | None = None,
show_meta: bool = False,
ir_prefix: str = "I",
tir_prefix: str = "T",
relax_prefix: str = "R",
module_alias: str = "cls",
buffer_dtype: str = "float32",
int_dtype: str = "int32",
float_dtype: str = "void",
verbose_expr: bool = False,
Expand All @@ -514,6 +511,7 @@ def script(
syntax_sugar: bool = True,
show_object_address: bool = False,
show_all_struct_info: bool = True,
extra_config: dict | None = None,
) -> str:
"""Print TVM IR into TVMScript text format with Python function support.

Expand All @@ -525,10 +523,7 @@ def script(
name=name,
show_meta=show_meta,
ir_prefix=ir_prefix,
tir_prefix=tir_prefix,
relax_prefix=relax_prefix,
module_alias=module_alias,
buffer_dtype=buffer_dtype,
int_dtype=int_dtype,
float_dtype=float_dtype,
verbose_expr=verbose_expr,
Expand All @@ -538,6 +533,7 @@ def script(
syntax_sugar=syntax_sugar,
show_object_address=show_object_address,
show_all_struct_info=show_all_struct_info,
extra_config=extra_config,
)

# If there are no Python functions, return the base script
Expand Down
Loading
Loading