diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 8183f64f7305..b3614ba89242 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -175,10 +175,18 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None): "lowered_while_functions": {}, "lowering_stack": [], "module_builder": ctx, + "resource_values": {}, + "hashtable_values": {}, + "in_call_once_init": False, } else: conversion_state.setdefault("module_builder", ctx) + conversion_state.setdefault("resource_values", {}) + conversion_state.setdefault("hashtable_values", {}) + conversion_state.setdefault("in_call_once_init", False) self.conversion_state = conversion_state + self.resource_handles = {} + self.hashtable_handles = {} # Add more operators self.convert_map = { @@ -187,6 +195,7 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None): "ADD_N": self.convert_add_n, "ARG_MAX": functools.partial(self._convert_arg_min_max, relax_op=_op.argmax), "ARG_MIN": functools.partial(self._convert_arg_min_max, relax_op=_op.argmin), + "ASSIGN_VARIABLE": self.convert_assign_variable, "ATAN2": functools.partial(self._convert_elemwise, relax_op=_op.atan2), "AVERAGE_POOL_2D": functools.partial(self.convert_pool2d, pool_type="average"), "BATCH_TO_SPACE_ND": self.convert_batch_to_space_nd, @@ -234,6 +243,10 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None): ), "GELU": self.convert_gelu, "HARD_SWISH": self.convert_hard_swish, + "HASHTABLE": self.convert_hashtable, + "HASHTABLE_FIND": self.convert_hashtable_find, + "HASHTABLE_IMPORT": self.convert_hashtable_import, + "HASHTABLE_SIZE": self.convert_hashtable_size, "IF": self.convert_if, "L2_NORMALIZATION": self.convert_l2_normalization, "L2_POOL_2D": functools.partial(self.convert_pool2d, pool_type="l2"), @@ -276,6 +289,7 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None): "QUANTIZE": self.convert_quantize, "RANDOM_STANDARD_NORMAL": self.convert_random_standard_normal, "RANDOM_UNIFORM": self.convert_random_uniform, + "READ_VARIABLE": self.convert_read_variable, "REDUCE_ALL": functools.partial(self._convert_reduce_bool, relax_op=_op.min), "REDUCE_ANY": functools.partial(self._convert_reduce_bool, relax_op=_op.max), "REDUCE_MAX": functools.partial(self._convert_reduce, relax_op=_op.max), @@ -389,6 +403,7 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None): self._convert_segment_op, op_name="UNSORTED_SEGMENT_PROD", reduction="mul" ), # "UNIDIRECTIONAL_SEQUENCE_LSTM": self.convert_unidirectional_sequence_lstm, + "VAR_HANDLE": self.convert_var_handle, "WHERE": self.convert_select, "WHILE": self.convert_while, "ZEROS_LIKE": self.convert_zeros_like, @@ -516,6 +531,244 @@ def convert_op_to_relax(self): get_tensor_name(self.subgraph, output_tensor.tensor_idx), ret[idx] ) + @staticmethod + def _decode_tflite_string(value): + """Decode a TFLite string field.""" + if value is None: + return "" + if isinstance(value, bytes | bytearray): + return value.decode("utf-8") + return str(value) + + def _get_var_handle_resource_key(self, op, fallback_tensor=None): + """Return a stable resource key for a VAR_HANDLE op.""" + container = "" + shared_name = "" + if op.BuiltinOptions() is not None: + try: + from tflite.VarHandleOptions import VarHandleOptions + + opts = self._get_builtin_options(op, VarHandleOptions) + if hasattr(opts, "Container"): + container = self._decode_tflite_string(opts.Container()) + if hasattr(opts, "SharedName"): + shared_name = self._decode_tflite_string(opts.SharedName()) + except (ImportError, ModuleNotFoundError): + pass + + if container or shared_name: + return (container, shared_name) + if fallback_tensor is not None: + return ("", get_tensor_name(self.subgraph, fallback_tensor.tensor_idx)) + raise tvm.error.OpNotImplemented("VAR_HANDLE requires VarHandleOptions") + + def _get_resource_key_for_handle(self, tensor, op_name): + tensor_name = get_tensor_name(self.subgraph, tensor.tensor_idx) + if tensor_name not in self.resource_handles: + raise tvm.error.OpNotImplemented( + f"{op_name} requires a VAR_HANDLE in the same TFLite subgraph" + ) + return self.resource_handles[tensor_name] + + def convert_var_handle(self, op): + """Convert a TFLite VAR_HANDLE into an importer-local resource handle.""" + input_tensors = self.get_input_tensors(op) + output_tensors = self.get_output_tensors(op) + if len(input_tensors) != 0 or len(output_tensors) != 1: + raise tvm.error.OpNotImplemented("VAR_HANDLE expects no inputs and one output") + + resource_key = self._get_var_handle_resource_key(op, output_tensors[0]) + resource_tensor_name = get_tensor_name(self.subgraph, output_tensors[0].tensor_idx) + self.resource_handles[resource_tensor_name] = resource_key + return None + + def convert_assign_variable(self, op): + """Convert the CALL_ONCE initialization subset of ASSIGN_VARIABLE.""" + if not self.conversion_state["in_call_once_init"]: + raise tvm.error.OpNotImplemented( + "ASSIGN_VARIABLE outside CALL_ONCE initialization is not supported by the " + "Relax TFLite frontend yet because it requires mutable resource state modeling." + ) + + input_tensors = self.get_input_tensors(op) + output_tensors = self.get_output_tensors(op) + if len(input_tensors) != 2 or len(output_tensors) != 0: + raise tvm.error.OpNotImplemented( + "ASSIGN_VARIABLE expects a resource handle and value input with no outputs" + ) + + resource_key = self._get_resource_key_for_handle(input_tensors[0], "ASSIGN_VARIABLE") + self.conversion_state["resource_values"][resource_key] = self.get_tensor_expr( + input_tensors[1] + ) + return None + + def convert_read_variable(self, op): + """Convert READ_VARIABLE for resources initialized by CALL_ONCE.""" + input_tensors = self.get_input_tensors(op) + output_tensors = self.get_output_tensors(op) + if len(input_tensors) != 1 or len(output_tensors) != 1: + raise tvm.error.OpNotImplemented("READ_VARIABLE expects one input and one output") + + resource_key = self._get_resource_key_for_handle(input_tensors[0], "READ_VARIABLE") + resource_values = self.conversion_state["resource_values"] + if resource_key not in resource_values: + raise tvm.error.OpNotImplemented( + "READ_VARIABLE requires a resource initialized by a supported CALL_ONCE subgraph" + ) + return resource_values[resource_key] + + def _is_tflite_string_type(self, tensor_type): + from tflite.TensorType import TensorType + + return hasattr(TensorType, "STRING") and tensor_type == TensorType.STRING + + def _is_supported_hashtable_type_pair(self, key_dtype, value_dtype): + from tflite.TensorType import TensorType + + return (key_dtype == TensorType.INT64 and self._is_tflite_string_type(value_dtype)) or ( + self._is_tflite_string_type(key_dtype) and value_dtype == TensorType.INT64 + ) + + def _get_hashtable_key(self, op, fallback_tensor=None): + """Return a stable key and TFLite dtype pair for a HASHTABLE resource.""" + table_id = None + key_dtype = None + value_dtype = None + if op.BuiltinOptions() is not None: + try: + from tflite.HashtableOptions import HashtableOptions + + opts = self._get_builtin_options(op, HashtableOptions) + table_id = int(opts.TableId()) + key_dtype = int(opts.KeyDtype()) + value_dtype = int(opts.ValueDtype()) + except (ImportError, ModuleNotFoundError): + pass + + if key_dtype is None or value_dtype is None: + raise tvm.error.OpNotImplemented("HASHTABLE requires HashtableOptions") + if not self._is_supported_hashtable_type_pair(key_dtype, value_dtype): + raise tvm.error.OpNotImplemented( + "TFLite HASHTABLE only supports int64/string or string/int64 tables" + ) + + if table_id is not None: + return table_id, key_dtype, value_dtype + if fallback_tensor is not None: + return ( + get_tensor_name(self.subgraph, fallback_tensor.tensor_idx), + key_dtype, + value_dtype, + ) + raise tvm.error.OpNotImplemented("HASHTABLE requires HashtableOptions") + + def _get_hashtable_info_for_handle(self, tensor, op_name): + tensor_name = get_tensor_name(self.subgraph, tensor.tensor_idx) + if tensor_name not in self.hashtable_handles: + raise tvm.error.OpNotImplemented( + f"{op_name} requires a HASHTABLE in the same TFLite subgraph" + ) + return self.hashtable_handles[tensor_name] + + @staticmethod + def _get_tensor_shape_tuple(tensor_wrapper): + if tensor_wrapper.tensor.ShapeLength() == 0: + return () + return tuple(int(dim) for dim in tensor_wrapper.tensor.ShapeAsNumpy()) + + @staticmethod + def _has_tensor_buffer_data(tensor_wrapper): + return ( + tensor_wrapper.buffer is not None + and hasattr(tensor_wrapper.buffer, "DataLength") + and tensor_wrapper.buffer.DataLength() > 0 + ) + + def convert_hashtable(self, op): + """Convert a TFLite HASHTABLE into an importer-local table handle.""" + input_tensors = self.get_input_tensors(op) + output_tensors = self.get_output_tensors(op) + if len(input_tensors) != 0 or len(output_tensors) != 1: + raise tvm.error.OpNotImplemented("HASHTABLE expects no inputs and one output") + + table_key, key_dtype, value_dtype = self._get_hashtable_key(op, output_tensors[0]) + table_tensor_name = get_tensor_name(self.subgraph, output_tensors[0].tensor_idx) + self.hashtable_handles[table_tensor_name] = { + "table_key": table_key, + "key_dtype": key_dtype, + "value_dtype": value_dtype, + } + return None + + def convert_hashtable_import(self, op): + """Convert static metadata for the CALL_ONCE HASHTABLE_IMPORT subset.""" + if not self.conversion_state["in_call_once_init"]: + raise tvm.error.OpNotImplemented( + "HASHTABLE_IMPORT outside CALL_ONCE initialization is not supported by the " + "Relax TFLite frontend yet because it requires mutable resource state modeling." + ) + + input_tensors = self.get_input_tensors(op) + output_tensors = self.get_output_tensors(op) + if len(input_tensors) != 3 or len(output_tensors) != 0: + raise tvm.error.OpNotImplemented( + "HASHTABLE_IMPORT expects table, keys, and values inputs with no outputs" + ) + + table_info = self._get_hashtable_info_for_handle(input_tensors[0], "HASHTABLE_IMPORT") + key_tensor = input_tensors[1] + value_tensor = input_tensors[2] + if ( + key_tensor.tensor.Type() != table_info["key_dtype"] + or value_tensor.tensor.Type() != table_info["value_dtype"] + ): + raise tvm.error.OpNotImplemented("HASHTABLE_IMPORT key/value dtypes mismatch") + key_shape = self._get_tensor_shape_tuple(key_tensor) + value_shape = self._get_tensor_shape_tuple(value_tensor) + if key_shape != value_shape: + raise tvm.error.OpNotImplemented("HASHTABLE_IMPORT requires keys and values same shape") + if not self._has_tensor_buffer_data(key_tensor) or not self._has_tensor_buffer_data( + value_tensor + ): + raise tvm.error.OpNotImplemented("HASHTABLE_IMPORT requires constant keys and values") + + hashtable_values = self.conversion_state["hashtable_values"] + table_key = table_info["table_key"] + if table_key not in hashtable_values: + hashtable_values[table_key] = { + "size": math.prod(key_shape) if key_shape else 1, + "key_dtype": table_info["key_dtype"], + "value_dtype": table_info["value_dtype"], + } + return None + + def convert_hashtable_find(self, op): + """Reject HASHTABLE_FIND until Relax can represent TFLite string tensors.""" + raise tvm.error.OpNotImplemented( + "HASHTABLE_FIND requires TensorType.STRING support in Relax TFLite frontend" + ) + + def convert_hashtable_size(self, op): + """Convert HASHTABLE_SIZE for a statically imported TFLite hashtable.""" + input_tensors = self.get_input_tensors(op) + output_tensors = self.get_output_tensors(op) + if len(input_tensors) != 1 or len(output_tensors) != 1: + raise tvm.error.OpNotImplemented("HASHTABLE_SIZE expects one input and one output") + + from tflite.TensorType import TensorType + + if output_tensors[0].tensor.Type() != TensorType.INT64: + raise tvm.error.OpNotImplemented("HASHTABLE_SIZE output must be int64") + table_info = self._get_hashtable_info_for_handle(input_tensors[0], "HASHTABLE_SIZE") + table_key = table_info["table_key"] + hashtable_values = self.conversion_state["hashtable_values"] + if table_key not in hashtable_values: + raise tvm.error.OpNotImplemented( + "HASHTABLE_SIZE requires a table initialized by a supported CALL_ONCE subgraph" + ) + return relax.const(np.array([hashtable_values[table_key]["size"]], dtype=np.int64), "int64") + def get_op_code_str(self, op): """Get TFLite ops string representation""" @@ -2288,13 +2541,7 @@ def convert_while(self, op): return relax.Call(loop_gv, args) def convert_call_once(self, op): - """Convert the no-op subset of TFLite CALL_ONCE. - - Non-empty CALL_ONCE init subgraphs are used for resource initialization - side effects in TFLite. The Relax TFLite frontend does not yet support - TFLite resource variable operators, so only the empty no-op form is safe - to import. - """ + """Convert TFLite CALL_ONCE for no-op and resource-variable initialization subsets.""" from tflite.CallOnceOptions import CallOnceOptions opts = self._get_builtin_options(op, CallOnceOptions) @@ -2310,11 +2557,36 @@ def convert_call_once(self, op): "CALL_ONCE with non-empty init subgraph I/O is not supported" ) if init_subgraph.OperatorsLength() != 0: - raise tvm.error.OpNotImplemented( - "CALL_ONCE with non-empty init subgraphs is not supported" - ) + self._convert_call_once_init_subgraph(init_subgraph) return None + def _convert_call_once_init_subgraph(self, init_subgraph): + """Convert the resource-variable initialization subset of a CALL_ONCE subgraph.""" + supported_init_ops = {"VAR_HANDLE", "ASSIGN_VARIABLE", "HASHTABLE", "HASHTABLE_IMPORT"} + for op_idx in range(init_subgraph.OperatorsLength()): + op_name = self.get_op_code_str(init_subgraph.Operators(op_idx)) + if op_name not in supported_init_ops: + raise tvm.error.OpNotImplemented( + f"CALL_ONCE init subgraph operator {op_name} is not supported" + ) + + old_in_call_once_init = self.conversion_state["in_call_once_init"] + self.conversion_state["in_call_once_init"] = True + try: + # The supported init ops below only update importer state and return None. + # If future CALL_ONCE ops emit Relax bindings, revisit sharing the parent builder. + subgraph_converter = type(self)( + self.model, + init_subgraph, + ExprTable(), + self.bb, + self.conversion_state, + ) + subgraph_converter.check_unsupported_ops() + subgraph_converter.convert_op_to_relax() + finally: + self.conversion_state["in_call_once_init"] = old_in_call_once_init + def _convert_stablehlo_convert(self, op): """Convert STABLEHLO_CONVERT to Relax (astype). diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index f1abacec27da..feae04abf685 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -3941,6 +3941,78 @@ def _build_call_once_options(builder, init_subgraph_index): return _tfl_call_once_options.CallOnceOptionsEnd(builder) +def _get_builtin_options_type(options_name): + if not hasattr(_tfl_builtin_options, options_name): + pytest.skip(f"TFLite schema does not provide BuiltinOptions.{options_name}") + return getattr(_tfl_builtin_options, options_name) + + +def _get_resource_tensor_type(): + if not hasattr(_tfl_tensor_type, "RESOURCE"): + pytest.skip("TFLite schema does not provide TensorType.RESOURCE") + return getattr(_tfl_tensor_type, "RESOURCE") + + +def _get_string_tensor_type(): + if not hasattr(_tfl_tensor_type, "STRING"): + pytest.skip("TFLite schema does not provide TensorType.STRING") + return getattr(_tfl_tensor_type, "STRING") + + +def _build_tflite_string_buffer(values): + encoded = [value.encode("utf-8") for value in values] + offsets = [] + cursor = 4 * (len(encoded) + 2) + for value in encoded: + offsets.append(cursor) + cursor += len(value) + offsets.append(cursor) + header = np.array([len(encoded), *offsets], dtype=np.int32).tobytes() + return header + b"".join(encoded) + + +def _build_var_handle_options(builder, shared_name="resource_var", container=""): + try: + var_handle_options = _get_tflite_schema_module("VarHandleOptions") + except ModuleNotFoundError: + pytest.skip("TFLite schema does not provide VarHandleOptions") + container_offset = builder.CreateString(container) + shared_name_offset = builder.CreateString(shared_name) + var_handle_options.VarHandleOptionsStart(builder) + var_handle_options.VarHandleOptionsAddContainer(builder, container_offset) + var_handle_options.VarHandleOptionsAddSharedName(builder, shared_name_offset) + return var_handle_options.VarHandleOptionsEnd(builder) + + +def _build_empty_builtin_options(builder, options_name): + try: + options_module = _get_tflite_schema_module(options_name) + except ModuleNotFoundError: + pytest.skip(f"TFLite schema does not provide {options_name}") + getattr(options_module, f"{options_name}Start")(builder) + return getattr(options_module, f"{options_name}End")(builder) + + +def _build_hashtable_options( + builder, + table_id=0, + key_dtype=None, + value_dtype=None, +): + try: + hashtable_options = _get_tflite_schema_module("HashtableOptions") + except ModuleNotFoundError: + pytest.skip("TFLite schema does not provide HashtableOptions") + + key_dtype = _tfl_tensor_type.INT64 if key_dtype is None else key_dtype + value_dtype = _get_string_tensor_type() if value_dtype is None else value_dtype + hashtable_options.HashtableOptionsStart(builder) + hashtable_options.HashtableOptionsAddTableId(builder, table_id) + hashtable_options.HashtableOptionsAddKeyDtype(builder, key_dtype) + hashtable_options.HashtableOptionsAddValueDtype(builder, value_dtype) + return hashtable_options.HashtableOptionsEnd(builder) + + def _load_model_from_buffer(model_bytes): if hasattr(tflite.Model, "Model"): tflite_model = tflite.Model.Model.GetRootAsModel(model_bytes, 0) @@ -5273,6 +5345,545 @@ def test_call_once_invalid_index_unsupported(): _load_model_from_buffer(_build_tflite_call_once_model(init_subgraph_index=2)) +def _build_tflite_resource_variable_model(): + """Build a model that initializes a resource variable in CALL_ONCE and reads it.""" + builder = flatbuffers.Builder(1024) + resource_type = _get_resource_tensor_type() + initial_value = np.array([1.0, 2.0], dtype=np.float32) + + call_once_options = _build_call_once_options(builder, 1) + main_var_handle_options = _build_var_handle_options(builder) + main_read_options = _build_empty_builtin_options(builder, "ReadVariableOptions") + init_var_handle_options = _build_var_handle_options(builder) + init_assign_options = _build_empty_builtin_options(builder, "AssignVariableOptions") + + resource_tensor = _build_tensor(builder, 0, [], tensor_type=resource_type) + main_output_tensor = _build_tensor(builder, 0, [2]) + main_call_once = _build_operator( + builder, + 0, + [], + [], + builtin_options_type=_get_builtin_options_type("CallOnceOptions"), + builtin_options=call_once_options, + ) + main_var_handle = _build_operator( + builder, + 1, + [], + [0], + builtin_options_type=_get_builtin_options_type("VarHandleOptions"), + builtin_options=main_var_handle_options, + ) + main_read = _build_operator( + builder, + 2, + [0], + [1], + builtin_options_type=_get_builtin_options_type("ReadVariableOptions"), + builtin_options=main_read_options, + ) + main_subgraph = _build_subgraph( + builder, + tensors=[resource_tensor, main_output_tensor], + operators=[main_call_once, main_var_handle, main_read], + inputs=[], + outputs=[1], + ) + + init_resource_tensor = _build_tensor(builder, 0, [], tensor_type=resource_type) + init_value_tensor = _build_tensor(builder, 1, [2]) + init_var_handle = _build_operator( + builder, + 1, + [], + [0], + builtin_options_type=_get_builtin_options_type("VarHandleOptions"), + builtin_options=init_var_handle_options, + ) + init_assign = _build_operator( + builder, + 3, + [0, 1], + [], + builtin_options_type=_get_builtin_options_type("AssignVariableOptions"), + builtin_options=init_assign_options, + ) + init_subgraph = _build_subgraph( + builder, + tensors=[init_resource_tensor, init_value_tensor], + operators=[init_var_handle, init_assign], + inputs=[], + outputs=[], + ) + + operator_codes = [ + _build_operator_code(builder, _get_builtin_operator("CALL_ONCE")), + _build_operator_code(builder, _get_builtin_operator("VAR_HANDLE")), + _build_operator_code(builder, _get_builtin_operator("READ_VARIABLE")), + _build_operator_code(builder, _get_builtin_operator("ASSIGN_VARIABLE")), + ] + buffers = [_build_buffer(builder), _build_buffer(builder, initial_value.tobytes())] + return _finish_tflite_model( + builder, + subgraph=main_subgraph, + extra_subgraphs=[init_subgraph], + operator_codes=operator_codes, + buffers=buffers, + ) + + +def _build_tflite_resource_assign_in_main_model(): + """Build a model that attempts to assign a resource variable in the main subgraph.""" + builder = flatbuffers.Builder(1024) + resource_type = _get_resource_tensor_type() + value = np.array([1.0, 2.0], dtype=np.float32) + + var_handle_options = _build_var_handle_options(builder) + assign_options = _build_empty_builtin_options(builder, "AssignVariableOptions") + resource_tensor = _build_tensor(builder, 0, [], tensor_type=resource_type) + value_tensor = _build_tensor(builder, 1, [2]) + var_handle = _build_operator( + builder, + 0, + [], + [0], + builtin_options_type=_get_builtin_options_type("VarHandleOptions"), + builtin_options=var_handle_options, + ) + assign = _build_operator( + builder, + 1, + [0, 1], + [], + builtin_options_type=_get_builtin_options_type("AssignVariableOptions"), + builtin_options=assign_options, + ) + main_subgraph = _build_subgraph( + builder, + tensors=[resource_tensor, value_tensor], + operators=[var_handle, assign], + inputs=[], + outputs=[1], + ) + operator_codes = [ + _build_operator_code(builder, _get_builtin_operator("VAR_HANDLE")), + _build_operator_code(builder, _get_builtin_operator("ASSIGN_VARIABLE")), + ] + buffers = [_build_buffer(builder), _build_buffer(builder, value.tobytes())] + return _finish_tflite_model( + builder, + subgraph=main_subgraph, + operator_codes=operator_codes, + buffers=buffers, + ) + + +def _build_tflite_resource_read_uninitialized_model(): + """Build a model that reads a resource variable without CALL_ONCE initialization.""" + builder = flatbuffers.Builder(1024) + resource_type = _get_resource_tensor_type() + + var_handle_options = _build_var_handle_options(builder) + read_options = _build_empty_builtin_options(builder, "ReadVariableOptions") + resource_tensor = _build_tensor(builder, 0, [], tensor_type=resource_type) + output_tensor = _build_tensor(builder, 0, [2]) + var_handle = _build_operator( + builder, + 0, + [], + [0], + builtin_options_type=_get_builtin_options_type("VarHandleOptions"), + builtin_options=var_handle_options, + ) + read = _build_operator( + builder, + 1, + [0], + [1], + builtin_options_type=_get_builtin_options_type("ReadVariableOptions"), + builtin_options=read_options, + ) + main_subgraph = _build_subgraph( + builder, + tensors=[resource_tensor, output_tensor], + operators=[var_handle, read], + inputs=[], + outputs=[1], + ) + operator_codes = [ + _build_operator_code(builder, _get_builtin_operator("VAR_HANDLE")), + _build_operator_code(builder, _get_builtin_operator("READ_VARIABLE")), + ] + return _finish_tflite_model( + builder, + subgraph=main_subgraph, + operator_codes=operator_codes, + buffers=[_build_buffer(builder)], + ) + + +def _build_tflite_hashtable_find_model(): + """Build a model that imports a static hashtable and finds runtime query keys.""" + builder = flatbuffers.Builder(1024) + resource_type = _get_resource_tensor_type() + string_type = _get_string_tensor_type() + table_keys = np.array([10, 20], dtype=np.int64) + table_values = _build_tflite_string_buffer(["one hundred", "two hundred"]) + default_value = _build_tflite_string_buffer(["missing"]) + + call_once_options = _build_call_once_options(builder, 1) + main_table_options = _build_hashtable_options(builder, table_id=0) + find_options = _build_empty_builtin_options(builder, "HashtableFindOptions") + init_table_options = _build_hashtable_options(builder, table_id=0) + import_options = _build_empty_builtin_options(builder, "HashtableImportOptions") + + query_tensor = _build_tensor(builder, 0, [3], tensor_type=_tfl_tensor_type.INT64) + table_tensor = _build_tensor(builder, 0, [1], tensor_type=resource_type) + default_tensor = _build_tensor(builder, 1, [], tensor_type=string_type) + output_tensor = _build_tensor(builder, 0, [3], tensor_type=string_type) + main_call_once = _build_operator( + builder, + 0, + [], + [], + builtin_options_type=_get_builtin_options_type("CallOnceOptions"), + builtin_options=call_once_options, + ) + main_hashtable = _build_operator( + builder, + 1, + [], + [1], + builtin_options_type=_get_builtin_options_type("HashtableOptions"), + builtin_options=main_table_options, + ) + main_find = _build_operator( + builder, + 2, + [1, 0, 2], + [3], + builtin_options_type=_get_builtin_options_type("HashtableFindOptions"), + builtin_options=find_options, + ) + main_subgraph = _build_subgraph( + builder, + tensors=[query_tensor, table_tensor, default_tensor, output_tensor], + operators=[main_call_once, main_hashtable, main_find], + inputs=[0], + outputs=[3], + ) + + init_table_tensor = _build_tensor(builder, 0, [1], tensor_type=resource_type) + init_keys_tensor = _build_tensor(builder, 2, [2], tensor_type=_tfl_tensor_type.INT64) + init_values_tensor = _build_tensor( + builder, + 3, + [2], + tensor_type=string_type, + ) + init_hashtable = _build_operator( + builder, + 1, + [], + [0], + builtin_options_type=_get_builtin_options_type("HashtableOptions"), + builtin_options=init_table_options, + ) + init_import = _build_operator( + builder, + 3, + [0, 1, 2], + [], + builtin_options_type=_get_builtin_options_type("HashtableImportOptions"), + builtin_options=import_options, + ) + init_subgraph = _build_subgraph( + builder, + tensors=[init_table_tensor, init_keys_tensor, init_values_tensor], + operators=[init_hashtable, init_import], + inputs=[], + outputs=[], + ) + + operator_codes = [ + _build_operator_code(builder, _get_builtin_operator("CALL_ONCE")), + _build_operator_code(builder, _get_builtin_operator("HASHTABLE")), + _build_operator_code(builder, _get_builtin_operator("HASHTABLE_FIND")), + _build_operator_code(builder, _get_builtin_operator("HASHTABLE_IMPORT")), + ] + buffers = [ + _build_buffer(builder), + _build_buffer(builder, default_value), + _build_buffer(builder, table_keys.tobytes()), + _build_buffer(builder, table_values), + ] + return _finish_tflite_model( + builder, + subgraph=main_subgraph, + extra_subgraphs=[init_subgraph], + operator_codes=operator_codes, + buffers=buffers, + ) + + +def _build_tflite_hashtable_size_model(): + """Build a model that imports a static hashtable and returns its size.""" + builder = flatbuffers.Builder(1024) + resource_type = _get_resource_tensor_type() + string_type = _get_string_tensor_type() + table_keys = np.array([10, 20], dtype=np.int64) + table_values = _build_tflite_string_buffer(["one hundred", "two hundred"]) + + call_once_options = _build_call_once_options(builder, 1) + main_table_options = _build_hashtable_options(builder, table_id=0) + size_options = _build_empty_builtin_options(builder, "HashtableSizeOptions") + init_table_options = _build_hashtable_options(builder, table_id=0) + import_options = _build_empty_builtin_options(builder, "HashtableImportOptions") + + table_tensor = _build_tensor(builder, 0, [1], tensor_type=resource_type) + size_tensor = _build_tensor(builder, 0, [1], tensor_type=_tfl_tensor_type.INT64) + main_call_once = _build_operator( + builder, + 0, + [], + [], + builtin_options_type=_get_builtin_options_type("CallOnceOptions"), + builtin_options=call_once_options, + ) + main_hashtable = _build_operator( + builder, + 1, + [], + [0], + builtin_options_type=_get_builtin_options_type("HashtableOptions"), + builtin_options=main_table_options, + ) + main_size = _build_operator( + builder, + 2, + [0], + [1], + builtin_options_type=_get_builtin_options_type("HashtableSizeOptions"), + builtin_options=size_options, + ) + main_subgraph = _build_subgraph( + builder, + tensors=[table_tensor, size_tensor], + operators=[main_call_once, main_hashtable, main_size], + inputs=[], + outputs=[1], + ) + + init_table_tensor = _build_tensor(builder, 0, [1], tensor_type=resource_type) + init_keys_tensor = _build_tensor(builder, 1, [2], tensor_type=_tfl_tensor_type.INT64) + init_values_tensor = _build_tensor(builder, 2, [2], tensor_type=string_type) + init_hashtable = _build_operator( + builder, + 1, + [], + [0], + builtin_options_type=_get_builtin_options_type("HashtableOptions"), + builtin_options=init_table_options, + ) + init_import = _build_operator( + builder, + 3, + [0, 1, 2], + [], + builtin_options_type=_get_builtin_options_type("HashtableImportOptions"), + builtin_options=import_options, + ) + init_subgraph = _build_subgraph( + builder, + tensors=[init_table_tensor, init_keys_tensor, init_values_tensor], + operators=[init_hashtable, init_import], + inputs=[], + outputs=[], + ) + + operator_codes = [ + _build_operator_code(builder, _get_builtin_operator("CALL_ONCE")), + _build_operator_code(builder, _get_builtin_operator("HASHTABLE")), + _build_operator_code(builder, _get_builtin_operator("HASHTABLE_SIZE")), + _build_operator_code(builder, _get_builtin_operator("HASHTABLE_IMPORT")), + ] + buffers = [ + _build_buffer(builder), + _build_buffer(builder, table_keys.tobytes()), + _build_buffer(builder, table_values), + ] + return _finish_tflite_model( + builder, + subgraph=main_subgraph, + extra_subgraphs=[init_subgraph], + operator_codes=operator_codes, + buffers=buffers, + ) + + +def _build_tflite_hashtable_import_in_main_model(): + """Build a model that attempts to import hashtable values in the main subgraph.""" + builder = flatbuffers.Builder(1024) + resource_type = _get_resource_tensor_type() + string_type = _get_string_tensor_type() + table_keys = np.array([10, 20], dtype=np.int64) + table_values = _build_tflite_string_buffer(["one hundred", "two hundred"]) + + table_options = _build_hashtable_options(builder, table_id=0) + import_options = _build_empty_builtin_options(builder, "HashtableImportOptions") + + table_tensor = _build_tensor(builder, 0, [1], tensor_type=resource_type) + keys_tensor = _build_tensor(builder, 1, [2], tensor_type=_tfl_tensor_type.INT64) + values_tensor = _build_tensor(builder, 2, [2], tensor_type=string_type) + hashtable = _build_operator( + builder, + 0, + [], + [0], + builtin_options_type=_get_builtin_options_type("HashtableOptions"), + builtin_options=table_options, + ) + hashtable_import = _build_operator( + builder, + 1, + [0, 1, 2], + [], + builtin_options_type=_get_builtin_options_type("HashtableImportOptions"), + builtin_options=import_options, + ) + main_subgraph = _build_subgraph( + builder, + tensors=[table_tensor, keys_tensor, values_tensor], + operators=[hashtable, hashtable_import], + inputs=[], + outputs=[2], + ) + operator_codes = [ + _build_operator_code(builder, _get_builtin_operator("HASHTABLE")), + _build_operator_code(builder, _get_builtin_operator("HASHTABLE_IMPORT")), + ] + buffers = [ + _build_buffer(builder), + _build_buffer(builder, table_keys.tobytes()), + _build_buffer(builder, table_values), + ] + return _finish_tflite_model( + builder, + subgraph=main_subgraph, + operator_codes=operator_codes, + buffers=buffers, + ) + + +def _build_tflite_hashtable_size_uninitialized_model(): + """Build a model that queries the size of a hashtable without importing values.""" + builder = flatbuffers.Builder(1024) + resource_type = _get_resource_tensor_type() + + table_options = _build_hashtable_options(builder, table_id=0) + size_options = _build_empty_builtin_options(builder, "HashtableSizeOptions") + table_tensor = _build_tensor(builder, 0, [1], tensor_type=resource_type) + size_tensor = _build_tensor(builder, 0, [1], tensor_type=_tfl_tensor_type.INT64) + hashtable = _build_operator( + builder, + 0, + [], + [0], + builtin_options_type=_get_builtin_options_type("HashtableOptions"), + builtin_options=table_options, + ) + hashtable_size = _build_operator( + builder, + 1, + [0], + [1], + builtin_options_type=_get_builtin_options_type("HashtableSizeOptions"), + builtin_options=size_options, + ) + main_subgraph = _build_subgraph( + builder, + tensors=[table_tensor, size_tensor], + operators=[hashtable, hashtable_size], + inputs=[], + outputs=[1], + ) + operator_codes = [ + _build_operator_code(builder, _get_builtin_operator("HASHTABLE")), + _build_operator_code(builder, _get_builtin_operator("HASHTABLE_SIZE")), + ] + return _finish_tflite_model( + builder, + subgraph=main_subgraph, + operator_codes=operator_codes, + buffers=[_build_buffer(builder)], + ) + + +def test_resource_variable_call_once_init_read(): + """Test reading a resource variable initialized by a supported CALL_ONCE subgraph.""" + mod = _load_model_from_buffer(_build_tflite_resource_variable_model()) + + @I.ir_module + class Expected: + @R.function + def main() -> R.Tensor((2,), dtype="float32"): + R.func_attr({"num_input": 0}) + with R.dataflow(): + gv: R.Tensor((2,), dtype="float32") = R.const([1.0, 2.0], "float32") + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_assign_variable_main_subgraph_unsupported(): + """Test ASSIGN_VARIABLE remains unsupported outside CALL_ONCE initialization.""" + with pytest.raises(tvm.error.OpNotImplemented, match="ASSIGN_VARIABLE outside CALL_ONCE"): + _load_model_from_buffer(_build_tflite_resource_assign_in_main_model()) + + +def test_read_variable_uninitialized_unsupported(): + """Test READ_VARIABLE rejects resource handles without supported initialization.""" + with pytest.raises(tvm.error.OpNotImplemented, match="READ_VARIABLE requires a resource"): + _load_model_from_buffer(_build_tflite_resource_read_uninitialized_model()) + + +def test_hashtable_call_once_import_find_unsupported(): + """Test HASHTABLE_FIND remains unsupported until TFLite string tensors are supported.""" + with pytest.raises(tvm.error.OpNotImplemented, match="TensorType.STRING"): + _load_model_from_buffer(_build_tflite_hashtable_find_model()) + + +def test_hashtable_call_once_import_size(): + """Test HASHTABLE_SIZE for a table initialized by a supported CALL_ONCE subgraph.""" + mod = _load_model_from_buffer(_build_tflite_hashtable_size_model()) + + @I.ir_module + class Expected: + @R.function + def main() -> R.Tensor((1,), dtype="int64"): + R.func_attr({"num_input": 0}) + with R.dataflow(): + gv: R.Tensor((1,), dtype="int64") = R.const([2], "int64") + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_hashtable_import_main_subgraph_unsupported(): + """Test HASHTABLE_IMPORT remains unsupported outside CALL_ONCE initialization.""" + with pytest.raises(tvm.error.OpNotImplemented, match="HASHTABLE_IMPORT outside CALL_ONCE"): + _load_model_from_buffer(_build_tflite_hashtable_import_in_main_model()) + + +def test_hashtable_size_uninitialized_unsupported(): + """Test HASHTABLE_SIZE rejects tables without supported initialization.""" + with pytest.raises(tvm.error.OpNotImplemented, match="HASHTABLE_SIZE requires a table"): + _load_model_from_buffer(_build_tflite_hashtable_size_uninitialized_model()) + + def _get_stablehlo_builtin_operator(builtin_name): if not hasattr(_tfl_builtin_operator, builtin_name): pytest.skip(f"TFLite schema does not provide BuiltinOperator.{builtin_name}")