Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
292 changes: 282 additions & 10 deletions python/tvm/relax/frontend/tflite/tflite_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,18 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None):
"lowered_while_functions": {},
"lowering_stack": [],
"module_builder": ctx,
"resource_values": {},
"hashtable_values": {},
"in_call_once_init": False,
}
else:
conversion_state.setdefault("module_builder", ctx)
conversion_state.setdefault("resource_values", {})
conversion_state.setdefault("hashtable_values", {})
conversion_state.setdefault("in_call_once_init", False)
self.conversion_state = conversion_state
self.resource_handles = {}
self.hashtable_handles = {}

# Add more operators
self.convert_map = {
Expand All @@ -187,6 +195,7 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None):
"ADD_N": self.convert_add_n,
"ARG_MAX": functools.partial(self._convert_arg_min_max, relax_op=_op.argmax),
"ARG_MIN": functools.partial(self._convert_arg_min_max, relax_op=_op.argmin),
"ASSIGN_VARIABLE": self.convert_assign_variable,
"ATAN2": functools.partial(self._convert_elemwise, relax_op=_op.atan2),
"AVERAGE_POOL_2D": functools.partial(self.convert_pool2d, pool_type="average"),
"BATCH_TO_SPACE_ND": self.convert_batch_to_space_nd,
Expand Down Expand Up @@ -234,6 +243,10 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None):
),
"GELU": self.convert_gelu,
"HARD_SWISH": self.convert_hard_swish,
"HASHTABLE": self.convert_hashtable,
"HASHTABLE_FIND": self.convert_hashtable_find,
"HASHTABLE_IMPORT": self.convert_hashtable_import,
"HASHTABLE_SIZE": self.convert_hashtable_size,
"IF": self.convert_if,
"L2_NORMALIZATION": self.convert_l2_normalization,
"L2_POOL_2D": functools.partial(self.convert_pool2d, pool_type="l2"),
Expand Down Expand Up @@ -276,6 +289,7 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None):
"QUANTIZE": self.convert_quantize,
"RANDOM_STANDARD_NORMAL": self.convert_random_standard_normal,
"RANDOM_UNIFORM": self.convert_random_uniform,
"READ_VARIABLE": self.convert_read_variable,
"REDUCE_ALL": functools.partial(self._convert_reduce_bool, relax_op=_op.min),
"REDUCE_ANY": functools.partial(self._convert_reduce_bool, relax_op=_op.max),
"REDUCE_MAX": functools.partial(self._convert_reduce, relax_op=_op.max),
Expand Down Expand Up @@ -389,6 +403,7 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None):
self._convert_segment_op, op_name="UNSORTED_SEGMENT_PROD", reduction="mul"
),
# "UNIDIRECTIONAL_SEQUENCE_LSTM": self.convert_unidirectional_sequence_lstm,
"VAR_HANDLE": self.convert_var_handle,
"WHERE": self.convert_select,
"WHILE": self.convert_while,
"ZEROS_LIKE": self.convert_zeros_like,
Expand Down Expand Up @@ -516,6 +531,244 @@ def convert_op_to_relax(self):
get_tensor_name(self.subgraph, output_tensor.tensor_idx), ret[idx]
)

@staticmethod
def _decode_tflite_string(value):
"""Decode a TFLite string field."""
if value is None:
return ""
if isinstance(value, bytes | bytearray):
return value.decode("utf-8")
return str(value)

def _get_var_handle_resource_key(self, op, fallback_tensor=None):
"""Return a stable resource key for a VAR_HANDLE op."""
container = ""
shared_name = ""
if op.BuiltinOptions() is not None:
try:
from tflite.VarHandleOptions import VarHandleOptions

opts = self._get_builtin_options(op, VarHandleOptions)
if hasattr(opts, "Container"):
container = self._decode_tflite_string(opts.Container())
if hasattr(opts, "SharedName"):
shared_name = self._decode_tflite_string(opts.SharedName())
except (ImportError, ModuleNotFoundError):
pass

if container or shared_name:
return (container, shared_name)
if fallback_tensor is not None:
return ("", get_tensor_name(self.subgraph, fallback_tensor.tensor_idx))
raise tvm.error.OpNotImplemented("VAR_HANDLE requires VarHandleOptions")
Comment thread
Aharrypotter marked this conversation as resolved.

def _get_resource_key_for_handle(self, tensor, op_name):
tensor_name = get_tensor_name(self.subgraph, tensor.tensor_idx)
if tensor_name not in self.resource_handles:
raise tvm.error.OpNotImplemented(
f"{op_name} requires a VAR_HANDLE in the same TFLite subgraph"
)
return self.resource_handles[tensor_name]

def convert_var_handle(self, op):
"""Convert a TFLite VAR_HANDLE into an importer-local resource handle."""
input_tensors = self.get_input_tensors(op)
output_tensors = self.get_output_tensors(op)
if len(input_tensors) != 0 or len(output_tensors) != 1:
raise tvm.error.OpNotImplemented("VAR_HANDLE expects no inputs and one output")

resource_key = self._get_var_handle_resource_key(op, output_tensors[0])
resource_tensor_name = get_tensor_name(self.subgraph, output_tensors[0].tensor_idx)
self.resource_handles[resource_tensor_name] = resource_key
return None

def convert_assign_variable(self, op):
"""Convert the CALL_ONCE initialization subset of ASSIGN_VARIABLE."""
if not self.conversion_state["in_call_once_init"]:
raise tvm.error.OpNotImplemented(
"ASSIGN_VARIABLE outside CALL_ONCE initialization is not supported by the "
"Relax TFLite frontend yet because it requires mutable resource state modeling."
)

input_tensors = self.get_input_tensors(op)
output_tensors = self.get_output_tensors(op)
if len(input_tensors) != 2 or len(output_tensors) != 0:
raise tvm.error.OpNotImplemented(
"ASSIGN_VARIABLE expects a resource handle and value input with no outputs"
)

resource_key = self._get_resource_key_for_handle(input_tensors[0], "ASSIGN_VARIABLE")
self.conversion_state["resource_values"][resource_key] = self.get_tensor_expr(
input_tensors[1]
)
return None

def convert_read_variable(self, op):
"""Convert READ_VARIABLE for resources initialized by CALL_ONCE."""
input_tensors = self.get_input_tensors(op)
output_tensors = self.get_output_tensors(op)
if len(input_tensors) != 1 or len(output_tensors) != 1:
raise tvm.error.OpNotImplemented("READ_VARIABLE expects one input and one output")

resource_key = self._get_resource_key_for_handle(input_tensors[0], "READ_VARIABLE")
resource_values = self.conversion_state["resource_values"]
if resource_key not in resource_values:
raise tvm.error.OpNotImplemented(
"READ_VARIABLE requires a resource initialized by a supported CALL_ONCE subgraph"
)
return resource_values[resource_key]

def _is_tflite_string_type(self, tensor_type):
from tflite.TensorType import TensorType

return hasattr(TensorType, "STRING") and tensor_type == TensorType.STRING

def _is_supported_hashtable_type_pair(self, key_dtype, value_dtype):
from tflite.TensorType import TensorType

return (key_dtype == TensorType.INT64 and self._is_tflite_string_type(value_dtype)) or (
self._is_tflite_string_type(key_dtype) and value_dtype == TensorType.INT64
)

def _get_hashtable_key(self, op, fallback_tensor=None):
"""Return a stable key and TFLite dtype pair for a HASHTABLE resource."""
table_id = None
key_dtype = None
value_dtype = None
if op.BuiltinOptions() is not None:
try:
from tflite.HashtableOptions import HashtableOptions

opts = self._get_builtin_options(op, HashtableOptions)
table_id = int(opts.TableId())
key_dtype = int(opts.KeyDtype())
value_dtype = int(opts.ValueDtype())
except (ImportError, ModuleNotFoundError):
pass

if key_dtype is None or value_dtype is None:
raise tvm.error.OpNotImplemented("HASHTABLE requires HashtableOptions")
if not self._is_supported_hashtable_type_pair(key_dtype, value_dtype):
raise tvm.error.OpNotImplemented(
"TFLite HASHTABLE only supports int64/string or string/int64 tables"
)

if table_id is not None:
return table_id, key_dtype, value_dtype
if fallback_tensor is not None:
return (
get_tensor_name(self.subgraph, fallback_tensor.tensor_idx),
key_dtype,
value_dtype,
)
raise tvm.error.OpNotImplemented("HASHTABLE requires HashtableOptions")

def _get_hashtable_info_for_handle(self, tensor, op_name):
tensor_name = get_tensor_name(self.subgraph, tensor.tensor_idx)
if tensor_name not in self.hashtable_handles:
raise tvm.error.OpNotImplemented(
f"{op_name} requires a HASHTABLE in the same TFLite subgraph"
)
return self.hashtable_handles[tensor_name]

@staticmethod
def _get_tensor_shape_tuple(tensor_wrapper):
if tensor_wrapper.tensor.ShapeLength() == 0:
return ()
return tuple(int(dim) for dim in tensor_wrapper.tensor.ShapeAsNumpy())

@staticmethod
def _has_tensor_buffer_data(tensor_wrapper):
return (
tensor_wrapper.buffer is not None
and hasattr(tensor_wrapper.buffer, "DataLength")
and tensor_wrapper.buffer.DataLength() > 0
)

def convert_hashtable(self, op):
"""Convert a TFLite HASHTABLE into an importer-local table handle."""
input_tensors = self.get_input_tensors(op)
output_tensors = self.get_output_tensors(op)
if len(input_tensors) != 0 or len(output_tensors) != 1:
raise tvm.error.OpNotImplemented("HASHTABLE expects no inputs and one output")

table_key, key_dtype, value_dtype = self._get_hashtable_key(op, output_tensors[0])
table_tensor_name = get_tensor_name(self.subgraph, output_tensors[0].tensor_idx)
self.hashtable_handles[table_tensor_name] = {
"table_key": table_key,
"key_dtype": key_dtype,
"value_dtype": value_dtype,
}
return None

def convert_hashtable_import(self, op):
"""Convert static metadata for the CALL_ONCE HASHTABLE_IMPORT subset."""
if not self.conversion_state["in_call_once_init"]:
raise tvm.error.OpNotImplemented(
"HASHTABLE_IMPORT outside CALL_ONCE initialization is not supported by the "
"Relax TFLite frontend yet because it requires mutable resource state modeling."
)

input_tensors = self.get_input_tensors(op)
output_tensors = self.get_output_tensors(op)
if len(input_tensors) != 3 or len(output_tensors) != 0:
raise tvm.error.OpNotImplemented(
"HASHTABLE_IMPORT expects table, keys, and values inputs with no outputs"
)

table_info = self._get_hashtable_info_for_handle(input_tensors[0], "HASHTABLE_IMPORT")
key_tensor = input_tensors[1]
value_tensor = input_tensors[2]
if (
key_tensor.tensor.Type() != table_info["key_dtype"]
or value_tensor.tensor.Type() != table_info["value_dtype"]
):
raise tvm.error.OpNotImplemented("HASHTABLE_IMPORT key/value dtypes mismatch")
key_shape = self._get_tensor_shape_tuple(key_tensor)
value_shape = self._get_tensor_shape_tuple(value_tensor)
if key_shape != value_shape:
raise tvm.error.OpNotImplemented("HASHTABLE_IMPORT requires keys and values same shape")
if not self._has_tensor_buffer_data(key_tensor) or not self._has_tensor_buffer_data(
value_tensor
):
raise tvm.error.OpNotImplemented("HASHTABLE_IMPORT requires constant keys and values")

hashtable_values = self.conversion_state["hashtable_values"]
table_key = table_info["table_key"]
if table_key not in hashtable_values:
hashtable_values[table_key] = {
"size": math.prod(key_shape) if key_shape else 1,
"key_dtype": table_info["key_dtype"],
"value_dtype": table_info["value_dtype"],
}
return None

def convert_hashtable_find(self, op):
"""Reject HASHTABLE_FIND until Relax can represent TFLite string tensors."""
raise tvm.error.OpNotImplemented(
"HASHTABLE_FIND requires TensorType.STRING support in Relax TFLite frontend"
)

def convert_hashtable_size(self, op):
"""Convert HASHTABLE_SIZE for a statically imported TFLite hashtable."""
input_tensors = self.get_input_tensors(op)
output_tensors = self.get_output_tensors(op)
if len(input_tensors) != 1 or len(output_tensors) != 1:
raise tvm.error.OpNotImplemented("HASHTABLE_SIZE expects one input and one output")

from tflite.TensorType import TensorType

if output_tensors[0].tensor.Type() != TensorType.INT64:
raise tvm.error.OpNotImplemented("HASHTABLE_SIZE output must be int64")
table_info = self._get_hashtable_info_for_handle(input_tensors[0], "HASHTABLE_SIZE")
table_key = table_info["table_key"]
hashtable_values = self.conversion_state["hashtable_values"]
if table_key not in hashtable_values:
raise tvm.error.OpNotImplemented(
"HASHTABLE_SIZE requires a table initialized by a supported CALL_ONCE subgraph"
)
return relax.const(np.array([hashtable_values[table_key]["size"]], dtype=np.int64), "int64")

def get_op_code_str(self, op):
"""Get TFLite ops string representation"""

Expand Down Expand Up @@ -2288,13 +2541,7 @@ def convert_while(self, op):
return relax.Call(loop_gv, args)

def convert_call_once(self, op):
"""Convert the no-op subset of TFLite CALL_ONCE.

Non-empty CALL_ONCE init subgraphs are used for resource initialization
side effects in TFLite. The Relax TFLite frontend does not yet support
TFLite resource variable operators, so only the empty no-op form is safe
to import.
"""
"""Convert TFLite CALL_ONCE for no-op and resource-variable initialization subsets."""
from tflite.CallOnceOptions import CallOnceOptions

opts = self._get_builtin_options(op, CallOnceOptions)
Expand All @@ -2310,11 +2557,36 @@ def convert_call_once(self, op):
"CALL_ONCE with non-empty init subgraph I/O is not supported"
)
if init_subgraph.OperatorsLength() != 0:
raise tvm.error.OpNotImplemented(
"CALL_ONCE with non-empty init subgraphs is not supported"
)
self._convert_call_once_init_subgraph(init_subgraph)
return None

def _convert_call_once_init_subgraph(self, init_subgraph):
"""Convert the resource-variable initialization subset of a CALL_ONCE subgraph."""
supported_init_ops = {"VAR_HANDLE", "ASSIGN_VARIABLE", "HASHTABLE", "HASHTABLE_IMPORT"}
for op_idx in range(init_subgraph.OperatorsLength()):
op_name = self.get_op_code_str(init_subgraph.Operators(op_idx))
if op_name not in supported_init_ops:
raise tvm.error.OpNotImplemented(
f"CALL_ONCE init subgraph operator {op_name} is not supported"
)

old_in_call_once_init = self.conversion_state["in_call_once_init"]
self.conversion_state["in_call_once_init"] = True
try:
# The supported init ops below only update importer state and return None.
# If future CALL_ONCE ops emit Relax bindings, revisit sharing the parent builder.
subgraph_converter = type(self)(
self.model,
init_subgraph,
ExprTable(),
self.bb,
self.conversion_state,
)
subgraph_converter.check_unsupported_ops()
subgraph_converter.convert_op_to_relax()
finally:
self.conversion_state["in_call_once_init"] = old_in_call_once_init

def _convert_stablehlo_convert(self, op):
"""Convert STABLEHLO_CONVERT to Relax (astype).

Expand Down
Loading
Loading