From 9f6f6fd0d088fd22fa20f70dcce53c2487116c86 Mon Sep 17 00:00:00 2001 From: Aharrypotter Date: Thu, 28 May 2026 17:21:10 +0800 Subject: [PATCH 1/3] [Frontend][TFLite] Support resource variable initialization --- .../relax/frontend/tflite/tflite_frontend.py | 132 +++++++++- tests/python/relax/test_frontend_tflite.py | 241 ++++++++++++++++++ 2 files changed, 363 insertions(+), 10 deletions(-) diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 8183f64f7305..d2746b47eede 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -175,10 +175,15 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None): "lowered_while_functions": {}, "lowering_stack": [], "module_builder": ctx, + "resource_values": {}, + "in_call_once_init": False, } else: conversion_state.setdefault("module_builder", ctx) + conversion_state.setdefault("resource_values", {}) + conversion_state.setdefault("in_call_once_init", False) self.conversion_state = conversion_state + self.resource_handles = {} # Add more operators self.convert_map = { @@ -187,6 +192,7 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None): "ADD_N": self.convert_add_n, "ARG_MAX": functools.partial(self._convert_arg_min_max, relax_op=_op.argmax), "ARG_MIN": functools.partial(self._convert_arg_min_max, relax_op=_op.argmin), + "ASSIGN_VARIABLE": self.convert_assign_variable, "ATAN2": functools.partial(self._convert_elemwise, relax_op=_op.atan2), "AVERAGE_POOL_2D": functools.partial(self.convert_pool2d, pool_type="average"), "BATCH_TO_SPACE_ND": self.convert_batch_to_space_nd, @@ -276,6 +282,7 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None): "QUANTIZE": self.convert_quantize, "RANDOM_STANDARD_NORMAL": self.convert_random_standard_normal, "RANDOM_UNIFORM": self.convert_random_uniform, + "READ_VARIABLE": self.convert_read_variable, "REDUCE_ALL": functools.partial(self._convert_reduce_bool, relax_op=_op.min), "REDUCE_ANY": functools.partial(self._convert_reduce_bool, relax_op=_op.max), "REDUCE_MAX": functools.partial(self._convert_reduce, relax_op=_op.max), @@ -389,6 +396,7 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None): self._convert_segment_op, op_name="UNSORTED_SEGMENT_PROD", reduction="mul" ), # "UNIDIRECTIONAL_SEQUENCE_LSTM": self.convert_unidirectional_sequence_lstm, + "VAR_HANDLE": self.convert_var_handle, "WHERE": self.convert_select, "WHILE": self.convert_while, "ZEROS_LIKE": self.convert_zeros_like, @@ -516,6 +524,93 @@ def convert_op_to_relax(self): get_tensor_name(self.subgraph, output_tensor.tensor_idx), ret[idx] ) + @staticmethod + def _decode_tflite_string(value): + """Decode a TFLite string field.""" + if value is None: + return "" + if isinstance(value, bytes | bytearray): + return value.decode("utf-8") + return str(value) + + def _get_var_handle_resource_key(self, op, fallback_tensor=None): + """Return a stable resource key for a VAR_HANDLE op.""" + container = "" + shared_name = "" + if op.BuiltinOptions() is not None: + try: + from tflite.VarHandleOptions import VarHandleOptions + + opts = self._get_builtin_options(op, VarHandleOptions) + if hasattr(opts, "Container"): + container = self._decode_tflite_string(opts.Container()) + if hasattr(opts, "SharedName"): + shared_name = self._decode_tflite_string(opts.SharedName()) + except (ImportError, ModuleNotFoundError): + pass + + if container or shared_name: + return (container, shared_name) + if fallback_tensor is not None: + return ("", get_tensor_name(self.subgraph, fallback_tensor.tensor_idx)) + raise tvm.error.OpNotImplemented("VAR_HANDLE requires VarHandleOptions") + + def _get_resource_key_for_handle(self, tensor, op_name): + tensor_name = get_tensor_name(self.subgraph, tensor.tensor_idx) + if tensor_name not in self.resource_handles: + raise tvm.error.OpNotImplemented( + f"{op_name} requires a VAR_HANDLE in the same TFLite subgraph" + ) + return self.resource_handles[tensor_name] + + def convert_var_handle(self, op): + """Convert a TFLite VAR_HANDLE into an importer-local resource handle.""" + input_tensors = self.get_input_tensors(op) + output_tensors = self.get_output_tensors(op) + if len(input_tensors) != 0 or len(output_tensors) != 1: + raise tvm.error.OpNotImplemented("VAR_HANDLE expects no inputs and one output") + + resource_key = self._get_var_handle_resource_key(op, output_tensors[0]) + resource_tensor_name = get_tensor_name(self.subgraph, output_tensors[0].tensor_idx) + self.resource_handles[resource_tensor_name] = resource_key + return None + + def convert_assign_variable(self, op): + """Convert the CALL_ONCE initialization subset of ASSIGN_VARIABLE.""" + if not self.conversion_state["in_call_once_init"]: + raise tvm.error.OpNotImplemented( + "ASSIGN_VARIABLE outside CALL_ONCE initialization is not supported by the " + "Relax TFLite frontend yet because it requires mutable resource state modeling." + ) + + input_tensors = self.get_input_tensors(op) + output_tensors = self.get_output_tensors(op) + if len(input_tensors) != 2 or len(output_tensors) != 0: + raise tvm.error.OpNotImplemented( + "ASSIGN_VARIABLE expects a resource handle and value input with no outputs" + ) + + resource_key = self._get_resource_key_for_handle(input_tensors[0], "ASSIGN_VARIABLE") + self.conversion_state["resource_values"][resource_key] = self.get_tensor_expr( + input_tensors[1] + ) + return None + + def convert_read_variable(self, op): + """Convert READ_VARIABLE for resources initialized by CALL_ONCE.""" + input_tensors = self.get_input_tensors(op) + output_tensors = self.get_output_tensors(op) + if len(input_tensors) != 1 or len(output_tensors) != 1: + raise tvm.error.OpNotImplemented("READ_VARIABLE expects one input and one output") + + resource_key = self._get_resource_key_for_handle(input_tensors[0], "READ_VARIABLE") + resource_values = self.conversion_state["resource_values"] + if resource_key not in resource_values: + raise tvm.error.OpNotImplemented( + "READ_VARIABLE requires a resource initialized by a supported CALL_ONCE subgraph" + ) + return resource_values[resource_key] + def get_op_code_str(self, op): """Get TFLite ops string representation""" @@ -2288,13 +2383,7 @@ def convert_while(self, op): return relax.Call(loop_gv, args) def convert_call_once(self, op): - """Convert the no-op subset of TFLite CALL_ONCE. - - Non-empty CALL_ONCE init subgraphs are used for resource initialization - side effects in TFLite. The Relax TFLite frontend does not yet support - TFLite resource variable operators, so only the empty no-op form is safe - to import. - """ + """Convert TFLite CALL_ONCE for no-op and resource-variable initialization subsets.""" from tflite.CallOnceOptions import CallOnceOptions opts = self._get_builtin_options(op, CallOnceOptions) @@ -2310,11 +2399,34 @@ def convert_call_once(self, op): "CALL_ONCE with non-empty init subgraph I/O is not supported" ) if init_subgraph.OperatorsLength() != 0: - raise tvm.error.OpNotImplemented( - "CALL_ONCE with non-empty init subgraphs is not supported" - ) + self._convert_call_once_init_subgraph(init_subgraph) return None + def _convert_call_once_init_subgraph(self, init_subgraph): + """Convert the resource-variable initialization subset of a CALL_ONCE subgraph.""" + supported_init_ops = {"VAR_HANDLE", "ASSIGN_VARIABLE"} + for op_idx in range(init_subgraph.OperatorsLength()): + op_name = self.get_op_code_str(init_subgraph.Operators(op_idx)) + if op_name not in supported_init_ops: + raise tvm.error.OpNotImplemented( + f"CALL_ONCE init subgraph operator {op_name} is not supported" + ) + + old_in_call_once_init = self.conversion_state["in_call_once_init"] + self.conversion_state["in_call_once_init"] = True + try: + subgraph_converter = type(self)( + self.model, + init_subgraph, + ExprTable(), + self.bb, + self.conversion_state, + ) + subgraph_converter.check_unsupported_ops() + subgraph_converter.convert_op_to_relax() + finally: + self.conversion_state["in_call_once_init"] = old_in_call_once_init + def _convert_stablehlo_convert(self, op): """Convert STABLEHLO_CONVERT to Relax (astype). diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index f1abacec27da..f771e0a2a63a 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -3941,6 +3941,40 @@ def _build_call_once_options(builder, init_subgraph_index): return _tfl_call_once_options.CallOnceOptionsEnd(builder) +def _get_builtin_options_type(options_name): + if not hasattr(_tfl_builtin_options, options_name): + pytest.skip(f"TFLite schema does not provide BuiltinOptions.{options_name}") + return getattr(_tfl_builtin_options, options_name) + + +def _get_resource_tensor_type(): + if not hasattr(_tfl_tensor_type, "RESOURCE"): + pytest.skip("TFLite schema does not provide TensorType.RESOURCE") + return getattr(_tfl_tensor_type, "RESOURCE") + + +def _build_var_handle_options(builder, shared_name="resource_var", container=""): + try: + var_handle_options = _get_tflite_schema_module("VarHandleOptions") + except ModuleNotFoundError: + pytest.skip("TFLite schema does not provide VarHandleOptions") + container_offset = builder.CreateString(container) + shared_name_offset = builder.CreateString(shared_name) + var_handle_options.VarHandleOptionsStart(builder) + var_handle_options.VarHandleOptionsAddContainer(builder, container_offset) + var_handle_options.VarHandleOptionsAddSharedName(builder, shared_name_offset) + return var_handle_options.VarHandleOptionsEnd(builder) + + +def _build_empty_builtin_options(builder, options_name): + try: + options_module = _get_tflite_schema_module(options_name) + except ModuleNotFoundError: + pytest.skip(f"TFLite schema does not provide {options_name}") + getattr(options_module, f"{options_name}Start")(builder) + return getattr(options_module, f"{options_name}End")(builder) + + def _load_model_from_buffer(model_bytes): if hasattr(tflite.Model, "Model"): tflite_model = tflite.Model.Model.GetRootAsModel(model_bytes, 0) @@ -5273,6 +5307,213 @@ def test_call_once_invalid_index_unsupported(): _load_model_from_buffer(_build_tflite_call_once_model(init_subgraph_index=2)) +def _build_tflite_resource_variable_model(): + """Build a model that initializes a resource variable in CALL_ONCE and reads it.""" + builder = flatbuffers.Builder(1024) + resource_type = _get_resource_tensor_type() + initial_value = np.array([1.0, 2.0], dtype=np.float32) + + call_once_options = _build_call_once_options(builder, 1) + main_var_handle_options = _build_var_handle_options(builder) + main_read_options = _build_empty_builtin_options(builder, "ReadVariableOptions") + init_var_handle_options = _build_var_handle_options(builder) + init_assign_options = _build_empty_builtin_options(builder, "AssignVariableOptions") + + resource_tensor = _build_tensor(builder, 0, [], tensor_type=resource_type) + main_output_tensor = _build_tensor(builder, 0, [2]) + main_call_once = _build_operator( + builder, + 0, + [], + [], + builtin_options_type=_get_builtin_options_type("CallOnceOptions"), + builtin_options=call_once_options, + ) + main_var_handle = _build_operator( + builder, + 1, + [], + [0], + builtin_options_type=_get_builtin_options_type("VarHandleOptions"), + builtin_options=main_var_handle_options, + ) + main_read = _build_operator( + builder, + 2, + [0], + [1], + builtin_options_type=_get_builtin_options_type("ReadVariableOptions"), + builtin_options=main_read_options, + ) + main_subgraph = _build_subgraph( + builder, + tensors=[resource_tensor, main_output_tensor], + operators=[main_call_once, main_var_handle, main_read], + inputs=[], + outputs=[1], + ) + + init_resource_tensor = _build_tensor(builder, 0, [], tensor_type=resource_type) + init_value_tensor = _build_tensor(builder, 1, [2]) + init_var_handle = _build_operator( + builder, + 1, + [], + [0], + builtin_options_type=_get_builtin_options_type("VarHandleOptions"), + builtin_options=init_var_handle_options, + ) + init_assign = _build_operator( + builder, + 3, + [0, 1], + [], + builtin_options_type=_get_builtin_options_type("AssignVariableOptions"), + builtin_options=init_assign_options, + ) + init_subgraph = _build_subgraph( + builder, + tensors=[init_resource_tensor, init_value_tensor], + operators=[init_var_handle, init_assign], + inputs=[], + outputs=[], + ) + + operator_codes = [ + _build_operator_code(builder, _get_builtin_operator("CALL_ONCE")), + _build_operator_code(builder, _get_builtin_operator("VAR_HANDLE")), + _build_operator_code(builder, _get_builtin_operator("READ_VARIABLE")), + _build_operator_code(builder, _get_builtin_operator("ASSIGN_VARIABLE")), + ] + buffers = [_build_buffer(builder), _build_buffer(builder, initial_value.tobytes())] + return _finish_tflite_model( + builder, + subgraph=main_subgraph, + extra_subgraphs=[init_subgraph], + operator_codes=operator_codes, + buffers=buffers, + ) + + +def _build_tflite_resource_assign_in_main_model(): + """Build a model that attempts to assign a resource variable in the main subgraph.""" + builder = flatbuffers.Builder(1024) + resource_type = _get_resource_tensor_type() + value = np.array([1.0, 2.0], dtype=np.float32) + + var_handle_options = _build_var_handle_options(builder) + assign_options = _build_empty_builtin_options(builder, "AssignVariableOptions") + resource_tensor = _build_tensor(builder, 0, [], tensor_type=resource_type) + value_tensor = _build_tensor(builder, 1, [2]) + var_handle = _build_operator( + builder, + 0, + [], + [0], + builtin_options_type=_get_builtin_options_type("VarHandleOptions"), + builtin_options=var_handle_options, + ) + assign = _build_operator( + builder, + 1, + [0, 1], + [], + builtin_options_type=_get_builtin_options_type("AssignVariableOptions"), + builtin_options=assign_options, + ) + main_subgraph = _build_subgraph( + builder, + tensors=[resource_tensor, value_tensor], + operators=[var_handle, assign], + inputs=[], + outputs=[1], + ) + operator_codes = [ + _build_operator_code(builder, _get_builtin_operator("VAR_HANDLE")), + _build_operator_code(builder, _get_builtin_operator("ASSIGN_VARIABLE")), + ] + buffers = [_build_buffer(builder), _build_buffer(builder, value.tobytes())] + return _finish_tflite_model( + builder, + subgraph=main_subgraph, + operator_codes=operator_codes, + buffers=buffers, + ) + + +def _build_tflite_resource_read_uninitialized_model(): + """Build a model that reads a resource variable without CALL_ONCE initialization.""" + builder = flatbuffers.Builder(1024) + resource_type = _get_resource_tensor_type() + + var_handle_options = _build_var_handle_options(builder) + read_options = _build_empty_builtin_options(builder, "ReadVariableOptions") + resource_tensor = _build_tensor(builder, 0, [], tensor_type=resource_type) + output_tensor = _build_tensor(builder, 0, [2]) + var_handle = _build_operator( + builder, + 0, + [], + [0], + builtin_options_type=_get_builtin_options_type("VarHandleOptions"), + builtin_options=var_handle_options, + ) + read = _build_operator( + builder, + 1, + [0], + [1], + builtin_options_type=_get_builtin_options_type("ReadVariableOptions"), + builtin_options=read_options, + ) + main_subgraph = _build_subgraph( + builder, + tensors=[resource_tensor, output_tensor], + operators=[var_handle, read], + inputs=[], + outputs=[1], + ) + operator_codes = [ + _build_operator_code(builder, _get_builtin_operator("VAR_HANDLE")), + _build_operator_code(builder, _get_builtin_operator("READ_VARIABLE")), + ] + return _finish_tflite_model( + builder, + subgraph=main_subgraph, + operator_codes=operator_codes, + buffers=[_build_buffer(builder)], + ) + + +def test_resource_variable_call_once_init_read(): + """Test reading a resource variable initialized by a supported CALL_ONCE subgraph.""" + mod = _load_model_from_buffer(_build_tflite_resource_variable_model()) + + @I.ir_module + class Expected: + @R.function + def main() -> R.Tensor((2,), dtype="float32"): + R.func_attr({"num_input": 0}) + with R.dataflow(): + gv: R.Tensor((2,), dtype="float32") = R.const([1.0, 2.0], "float32") + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_assign_variable_main_subgraph_unsupported(): + """Test ASSIGN_VARIABLE remains unsupported outside CALL_ONCE initialization.""" + with pytest.raises(tvm.error.OpNotImplemented, match="ASSIGN_VARIABLE outside CALL_ONCE"): + _load_model_from_buffer(_build_tflite_resource_assign_in_main_model()) + + +def test_read_variable_uninitialized_unsupported(): + """Test READ_VARIABLE rejects resource handles without supported initialization.""" + with pytest.raises(tvm.error.OpNotImplemented, match="READ_VARIABLE requires a resource"): + _load_model_from_buffer(_build_tflite_resource_read_uninitialized_model()) + + def _get_stablehlo_builtin_operator(builtin_name): if not hasattr(_tfl_builtin_operator, builtin_name): pytest.skip(f"TFLite schema does not provide BuiltinOperator.{builtin_name}") From 81aa1be95deb2df600b57fe88851f6b7d9449cff Mon Sep 17 00:00:00 2001 From: Aharrypotter Date: Thu, 28 May 2026 17:29:04 +0800 Subject: [PATCH 2/3] [Frontend][TFLite] Support static hashtable initialization --- .../relax/frontend/tflite/tflite_frontend.py | 167 +++++- tests/python/relax/test_frontend_tflite.py | 540 ++++++++++++++++++ 2 files changed, 706 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index d2746b47eede..9962cb71d635 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -176,14 +176,17 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None): "lowering_stack": [], "module_builder": ctx, "resource_values": {}, + "hashtable_values": {}, "in_call_once_init": False, } else: conversion_state.setdefault("module_builder", ctx) conversion_state.setdefault("resource_values", {}) + conversion_state.setdefault("hashtable_values", {}) conversion_state.setdefault("in_call_once_init", False) self.conversion_state = conversion_state self.resource_handles = {} + self.hashtable_handles = {} # Add more operators self.convert_map = { @@ -240,6 +243,10 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None): ), "GELU": self.convert_gelu, "HARD_SWISH": self.convert_hard_swish, + "HASHTABLE": self.convert_hashtable, + "HASHTABLE_FIND": self.convert_hashtable_find, + "HASHTABLE_IMPORT": self.convert_hashtable_import, + "HASHTABLE_SIZE": self.convert_hashtable_size, "IF": self.convert_if, "L2_NORMALIZATION": self.convert_l2_normalization, "L2_POOL_2D": functools.partial(self.convert_pool2d, pool_type="l2"), @@ -611,6 +618,162 @@ def convert_read_variable(self, op): ) return resource_values[resource_key] + def _get_hashtable_key(self, op, fallback_tensor=None): + """Return a stable key for a TFLite HASHTABLE resource.""" + table_id = None + if op.BuiltinOptions() is not None: + try: + from tflite.HashtableOptions import HashtableOptions + + opts = self._get_builtin_options(op, HashtableOptions) + table_id = int(opts.TableId()) + except (ImportError, ModuleNotFoundError): + pass + + if table_id is not None: + return table_id + if fallback_tensor is not None: + return get_tensor_name(self.subgraph, fallback_tensor.tensor_idx) + raise tvm.error.OpNotImplemented("HASHTABLE requires HashtableOptions") + + def _get_hashtable_key_for_handle(self, tensor, op_name): + tensor_name = get_tensor_name(self.subgraph, tensor.tensor_idx) + if tensor_name not in self.hashtable_handles: + raise tvm.error.OpNotImplemented( + f"{op_name} requires a HASHTABLE in the same TFLite subgraph" + ) + return self.hashtable_handles[tensor_name] + + def convert_hashtable(self, op): + """Convert a TFLite HASHTABLE into an importer-local table handle.""" + input_tensors = self.get_input_tensors(op) + output_tensors = self.get_output_tensors(op) + if len(input_tensors) != 0 or len(output_tensors) != 1: + raise tvm.error.OpNotImplemented("HASHTABLE expects no inputs and one output") + + table_key = self._get_hashtable_key(op, output_tensors[0]) + table_tensor_name = get_tensor_name(self.subgraph, output_tensors[0].tensor_idx) + self.hashtable_handles[table_tensor_name] = table_key + return None + + @staticmethod + def _has_tensor_buffer_data(tensor_wrapper): + return ( + tensor_wrapper.buffer is not None + and hasattr(tensor_wrapper.buffer, "DataLength") + and tensor_wrapper.buffer.DataLength() > 0 + ) + + def convert_hashtable_import(self, op): + """Convert the CALL_ONCE initialization subset of HASHTABLE_IMPORT.""" + if not self.conversion_state["in_call_once_init"]: + raise tvm.error.OpNotImplemented( + "HASHTABLE_IMPORT outside CALL_ONCE initialization is not supported by the " + "Relax TFLite frontend yet because it requires mutable resource state modeling." + ) + + input_tensors = self.get_input_tensors(op) + output_tensors = self.get_output_tensors(op) + if len(input_tensors) != 3 or len(output_tensors) != 0: + raise tvm.error.OpNotImplemented( + "HASHTABLE_IMPORT expects table, keys, and values inputs with no outputs" + ) + + table_key = self._get_hashtable_key_for_handle(input_tensors[0], "HASHTABLE_IMPORT") + if not self._has_tensor_buffer_data(input_tensors[1]) or not self._has_tensor_buffer_data( + input_tensors[2] + ): + raise tvm.error.OpNotImplemented("HASHTABLE_IMPORT requires constant keys and values") + keys = self.get_tensor_value(input_tensors[1]) + values = self.get_tensor_value(input_tensors[2]) + if keys.ndim != 1 or values.ndim < 1: + raise tvm.error.OpNotImplemented( + "HASHTABLE_IMPORT requires one-dimensional keys and at least one-dimensional values" + ) + if keys.shape[0] != values.shape[0]: + raise tvm.error.OpNotImplemented("HASHTABLE_IMPORT keys and values size mismatch") + + self.conversion_state["hashtable_values"][table_key] = { + "keys": self.get_tensor_expr(input_tensors[1]), + "values": self.get_tensor_expr(input_tensors[2]), + "size": int(keys.shape[0]), + "value_shape": tuple(int(dim) for dim in values.shape[1:]), + } + return None + + def convert_hashtable_find(self, op): + """Convert HASHTABLE_FIND for static tables initialized by CALL_ONCE.""" + input_tensors = self.get_input_tensors(op) + output_tensors = self.get_output_tensors(op) + if len(input_tensors) != 3 or len(output_tensors) != 1: + raise tvm.error.OpNotImplemented( + "HASHTABLE_FIND expects table, keys, and default value inputs with one output" + ) + + table_key = self._get_hashtable_key_for_handle(input_tensors[0], "HASHTABLE_FIND") + hashtable_values = self.conversion_state["hashtable_values"] + if table_key not in hashtable_values: + raise tvm.error.OpNotImplemented( + "HASHTABLE_FIND requires a table initialized by a supported CALL_ONCE subgraph" + ) + + table = hashtable_values[table_key] + output_shape = ( + tuple(output_tensors[0].tensor.ShapeAsNumpy()) + if output_tensors[0].tensor.ShapeLength() > 0 + else () + ) + value_shape = table["value_shape"] + if value_shape and ( + len(output_shape) < len(value_shape) + or tuple(output_shape[-len(value_shape) :]) != value_shape + ): + raise tvm.error.OpNotImplemented( + "HASHTABLE_FIND output shape must append the imported value shape" + ) + query_keys = self.get_tensor_expr(input_tensors[1]) + table_keys = table["keys"] + table_values = table["values"] + default_value = self.get_tensor_expr(input_tensors[2]) + + if input_tensors[1].tensor.ShapeLength() == 0: + matches = relax.op.equal(query_keys, table_keys) + else: + matches = relax.op.equal(relax.op.expand_dims(query_keys, axis=-1), table_keys) + matches_int = relax.op.astype(matches, "int32") + matched_indices = relax.op.argmax(matches_int, axis=-1) + selected_values = relax.op.take(table_values, matched_indices, axis=0, mode="fast") + has_match = relax.op.greater( + relax.op.max(matches_int, axis=-1), + relax.const(0, dtype="int32"), + ) + if value_shape: + for _ in value_shape: + has_match = relax.op.expand_dims(has_match, axis=-1) + has_match = relax.op.broadcast_to( + has_match, + output_shape, + ) + default_values = relax.op.broadcast_to(default_value, output_shape) + return relax.op.where(has_match, selected_values, default_values) + + def convert_hashtable_size(self, op): + """Convert HASHTABLE_SIZE for static tables initialized by CALL_ONCE.""" + input_tensors = self.get_input_tensors(op) + output_tensors = self.get_output_tensors(op) + if len(input_tensors) != 1 or len(output_tensors) != 1: + raise tvm.error.OpNotImplemented("HASHTABLE_SIZE expects one input and one output") + + table_key = self._get_hashtable_key_for_handle(input_tensors[0], "HASHTABLE_SIZE") + hashtable_values = self.conversion_state["hashtable_values"] + if table_key not in hashtable_values: + raise tvm.error.OpNotImplemented( + "HASHTABLE_SIZE requires a table initialized by a supported CALL_ONCE subgraph" + ) + + output_dtype = self.get_tensor_type_str(output_tensors[0].tensor.Type()) + return relax.const(hashtable_values[table_key]["size"], dtype=output_dtype) + def get_op_code_str(self, op): """Get TFLite ops string representation""" @@ -2404,7 +2567,7 @@ def convert_call_once(self, op): def _convert_call_once_init_subgraph(self, init_subgraph): """Convert the resource-variable initialization subset of a CALL_ONCE subgraph.""" - supported_init_ops = {"VAR_HANDLE", "ASSIGN_VARIABLE"} + supported_init_ops = {"VAR_HANDLE", "ASSIGN_VARIABLE", "HASHTABLE", "HASHTABLE_IMPORT"} for op_idx in range(init_subgraph.OperatorsLength()): op_name = self.get_op_code_str(init_subgraph.Operators(op_idx)) if op_name not in supported_init_ops: @@ -2415,6 +2578,8 @@ def _convert_call_once_init_subgraph(self, init_subgraph): old_in_call_once_init = self.conversion_state["in_call_once_init"] self.conversion_state["in_call_once_init"] = True try: + # The supported init ops below only update importer state and return None. + # If future CALL_ONCE ops emit Relax bindings, revisit sharing the parent builder. subgraph_converter = type(self)( self.model, init_subgraph, diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index f771e0a2a63a..2c02f6394a8f 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -3975,6 +3975,26 @@ def _build_empty_builtin_options(builder, options_name): return getattr(options_module, f"{options_name}End")(builder) +def _build_hashtable_options( + builder, + table_id=0, + key_dtype=None, + value_dtype=None, +): + try: + hashtable_options = _get_tflite_schema_module("HashtableOptions") + except ModuleNotFoundError: + pytest.skip("TFLite schema does not provide HashtableOptions") + + key_dtype = _tfl_tensor_type.INT32 if key_dtype is None else key_dtype + value_dtype = _tfl_tensor_type.INT32 if value_dtype is None else value_dtype + hashtable_options.HashtableOptionsStart(builder) + hashtable_options.HashtableOptionsAddTableId(builder, table_id) + hashtable_options.HashtableOptionsAddKeyDtype(builder, key_dtype) + hashtable_options.HashtableOptionsAddValueDtype(builder, value_dtype) + return hashtable_options.HashtableOptionsEnd(builder) + + def _load_model_from_buffer(model_bytes): if hasattr(tflite.Model, "Model"): tflite_model = tflite.Model.Model.GetRootAsModel(model_bytes, 0) @@ -5485,6 +5505,405 @@ def _build_tflite_resource_read_uninitialized_model(): ) +def _build_tflite_hashtable_find_model(vector_values=False, scalar_query=False): + """Build a model that imports a static hashtable and finds runtime query keys.""" + builder = flatbuffers.Builder(1024) + resource_type = _get_resource_tensor_type() + table_keys = np.array([10, 20], dtype=np.int32) + if vector_values: + table_values = np.array([[100, 101], [200, 201]], dtype=np.int32) + default_value = np.array([-1, -2], dtype=np.int32) + default_shape = [2] + output_shape = [2] if scalar_query else [3, 2] + else: + table_values = np.array([100, 200], dtype=np.int32) + default_value = np.array(-1, dtype=np.int32) + default_shape = [] + output_shape = [] if scalar_query else [3] + + call_once_options = _build_call_once_options(builder, 1) + main_table_options = _build_hashtable_options(builder, table_id=0) + find_options = _build_empty_builtin_options(builder, "HashtableFindOptions") + init_table_options = _build_hashtable_options(builder, table_id=0) + import_options = _build_empty_builtin_options(builder, "HashtableImportOptions") + + query_tensor = _build_tensor( + builder, + 0, + [] if scalar_query else [3], + tensor_type=_tfl_tensor_type.INT32, + ) + table_tensor = _build_tensor(builder, 0, [], tensor_type=resource_type) + default_tensor = _build_tensor(builder, 1, default_shape, tensor_type=_tfl_tensor_type.INT32) + output_tensor = _build_tensor(builder, 0, output_shape, tensor_type=_tfl_tensor_type.INT32) + main_call_once = _build_operator( + builder, + 0, + [], + [], + builtin_options_type=_get_builtin_options_type("CallOnceOptions"), + builtin_options=call_once_options, + ) + main_hashtable = _build_operator( + builder, + 1, + [], + [1], + builtin_options_type=_get_builtin_options_type("HashtableOptions"), + builtin_options=main_table_options, + ) + main_find = _build_operator( + builder, + 2, + [1, 0, 2], + [3], + builtin_options_type=_get_builtin_options_type("HashtableFindOptions"), + builtin_options=find_options, + ) + main_subgraph = _build_subgraph( + builder, + tensors=[query_tensor, table_tensor, default_tensor, output_tensor], + operators=[main_call_once, main_hashtable, main_find], + inputs=[0], + outputs=[3], + ) + + init_table_tensor = _build_tensor(builder, 0, [], tensor_type=resource_type) + init_keys_tensor = _build_tensor(builder, 2, [2], tensor_type=_tfl_tensor_type.INT32) + init_values_tensor = _build_tensor( + builder, + 3, + list(table_values.shape), + tensor_type=_tfl_tensor_type.INT32, + ) + init_hashtable = _build_operator( + builder, + 1, + [], + [0], + builtin_options_type=_get_builtin_options_type("HashtableOptions"), + builtin_options=init_table_options, + ) + init_import = _build_operator( + builder, + 3, + [0, 1, 2], + [], + builtin_options_type=_get_builtin_options_type("HashtableImportOptions"), + builtin_options=import_options, + ) + init_subgraph = _build_subgraph( + builder, + tensors=[init_table_tensor, init_keys_tensor, init_values_tensor], + operators=[init_hashtable, init_import], + inputs=[], + outputs=[], + ) + + operator_codes = [ + _build_operator_code(builder, _get_builtin_operator("CALL_ONCE")), + _build_operator_code(builder, _get_builtin_operator("HASHTABLE")), + _build_operator_code(builder, _get_builtin_operator("HASHTABLE_FIND")), + _build_operator_code(builder, _get_builtin_operator("HASHTABLE_IMPORT")), + ] + buffers = [ + _build_buffer(builder), + _build_buffer(builder, default_value.tobytes()), + _build_buffer(builder, table_keys.tobytes()), + _build_buffer(builder, table_values.tobytes()), + ] + return _finish_tflite_model( + builder, + subgraph=main_subgraph, + extra_subgraphs=[init_subgraph], + operator_codes=operator_codes, + buffers=buffers, + ) + + +def _build_tflite_hashtable_size_model(): + """Build a model that imports a static hashtable and returns its size.""" + builder = flatbuffers.Builder(1024) + resource_type = _get_resource_tensor_type() + table_keys = np.array([10, 20], dtype=np.int32) + table_values = np.array([100, 200], dtype=np.int32) + + call_once_options = _build_call_once_options(builder, 1) + main_table_options = _build_hashtable_options(builder, table_id=0) + size_options = _build_empty_builtin_options(builder, "HashtableSizeOptions") + init_table_options = _build_hashtable_options(builder, table_id=0) + import_options = _build_empty_builtin_options(builder, "HashtableImportOptions") + + table_tensor = _build_tensor(builder, 0, [], tensor_type=resource_type) + size_tensor = _build_tensor(builder, 0, [], tensor_type=_tfl_tensor_type.INT32) + main_call_once = _build_operator( + builder, + 0, + [], + [], + builtin_options_type=_get_builtin_options_type("CallOnceOptions"), + builtin_options=call_once_options, + ) + main_hashtable = _build_operator( + builder, + 1, + [], + [0], + builtin_options_type=_get_builtin_options_type("HashtableOptions"), + builtin_options=main_table_options, + ) + main_size = _build_operator( + builder, + 2, + [0], + [1], + builtin_options_type=_get_builtin_options_type("HashtableSizeOptions"), + builtin_options=size_options, + ) + main_subgraph = _build_subgraph( + builder, + tensors=[table_tensor, size_tensor], + operators=[main_call_once, main_hashtable, main_size], + inputs=[], + outputs=[1], + ) + + init_table_tensor = _build_tensor(builder, 0, [], tensor_type=resource_type) + init_keys_tensor = _build_tensor(builder, 1, [2], tensor_type=_tfl_tensor_type.INT32) + init_values_tensor = _build_tensor(builder, 2, [2], tensor_type=_tfl_tensor_type.INT32) + init_hashtable = _build_operator( + builder, + 1, + [], + [0], + builtin_options_type=_get_builtin_options_type("HashtableOptions"), + builtin_options=init_table_options, + ) + init_import = _build_operator( + builder, + 3, + [0, 1, 2], + [], + builtin_options_type=_get_builtin_options_type("HashtableImportOptions"), + builtin_options=import_options, + ) + init_subgraph = _build_subgraph( + builder, + tensors=[init_table_tensor, init_keys_tensor, init_values_tensor], + operators=[init_hashtable, init_import], + inputs=[], + outputs=[], + ) + + operator_codes = [ + _build_operator_code(builder, _get_builtin_operator("CALL_ONCE")), + _build_operator_code(builder, _get_builtin_operator("HASHTABLE")), + _build_operator_code(builder, _get_builtin_operator("HASHTABLE_SIZE")), + _build_operator_code(builder, _get_builtin_operator("HASHTABLE_IMPORT")), + ] + buffers = [ + _build_buffer(builder), + _build_buffer(builder, table_keys.tobytes()), + _build_buffer(builder, table_values.tobytes()), + ] + return _finish_tflite_model( + builder, + subgraph=main_subgraph, + extra_subgraphs=[init_subgraph], + operator_codes=operator_codes, + buffers=buffers, + ) + + +def _build_tflite_hashtable_import_in_main_model(): + """Build a model that attempts to import hashtable values in the main subgraph.""" + builder = flatbuffers.Builder(1024) + resource_type = _get_resource_tensor_type() + table_keys = np.array([10, 20], dtype=np.int32) + table_values = np.array([100, 200], dtype=np.int32) + + table_options = _build_hashtable_options(builder, table_id=0) + import_options = _build_empty_builtin_options(builder, "HashtableImportOptions") + + table_tensor = _build_tensor(builder, 0, [], tensor_type=resource_type) + keys_tensor = _build_tensor(builder, 1, [2], tensor_type=_tfl_tensor_type.INT32) + values_tensor = _build_tensor(builder, 2, [2], tensor_type=_tfl_tensor_type.INT32) + hashtable = _build_operator( + builder, + 0, + [], + [0], + builtin_options_type=_get_builtin_options_type("HashtableOptions"), + builtin_options=table_options, + ) + hashtable_import = _build_operator( + builder, + 1, + [0, 1, 2], + [], + builtin_options_type=_get_builtin_options_type("HashtableImportOptions"), + builtin_options=import_options, + ) + main_subgraph = _build_subgraph( + builder, + tensors=[table_tensor, keys_tensor, values_tensor], + operators=[hashtable, hashtable_import], + inputs=[], + outputs=[2], + ) + operator_codes = [ + _build_operator_code(builder, _get_builtin_operator("HASHTABLE")), + _build_operator_code(builder, _get_builtin_operator("HASHTABLE_IMPORT")), + ] + buffers = [ + _build_buffer(builder), + _build_buffer(builder, table_keys.tobytes()), + _build_buffer(builder, table_values.tobytes()), + ] + return _finish_tflite_model( + builder, + subgraph=main_subgraph, + operator_codes=operator_codes, + buffers=buffers, + ) + + +def _build_tflite_hashtable_import_nonconstant_model(): + """Build a model that imports hashtable values from a non-constant keys tensor.""" + builder = flatbuffers.Builder(1024) + resource_type = _get_resource_tensor_type() + table_values = np.array([100, 200], dtype=np.int32) + + call_once_options = _build_call_once_options(builder, 1) + main_table_options = _build_hashtable_options(builder, table_id=0) + size_options = _build_empty_builtin_options(builder, "HashtableSizeOptions") + init_table_options = _build_hashtable_options(builder, table_id=0) + import_options = _build_empty_builtin_options(builder, "HashtableImportOptions") + + table_tensor = _build_tensor(builder, 0, [], tensor_type=resource_type) + size_tensor = _build_tensor(builder, 0, [], tensor_type=_tfl_tensor_type.INT32) + main_call_once = _build_operator( + builder, + 0, + [], + [], + builtin_options_type=_get_builtin_options_type("CallOnceOptions"), + builtin_options=call_once_options, + ) + main_hashtable = _build_operator( + builder, + 1, + [], + [0], + builtin_options_type=_get_builtin_options_type("HashtableOptions"), + builtin_options=main_table_options, + ) + main_size = _build_operator( + builder, + 2, + [0], + [1], + builtin_options_type=_get_builtin_options_type("HashtableSizeOptions"), + builtin_options=size_options, + ) + main_subgraph = _build_subgraph( + builder, + tensors=[table_tensor, size_tensor], + operators=[main_call_once, main_hashtable, main_size], + inputs=[], + outputs=[1], + ) + + init_table_tensor = _build_tensor(builder, 0, [], tensor_type=resource_type) + init_keys_tensor = _build_tensor(builder, 0, [2], tensor_type=_tfl_tensor_type.INT32) + init_values_tensor = _build_tensor(builder, 1, [2], tensor_type=_tfl_tensor_type.INT32) + init_hashtable = _build_operator( + builder, + 1, + [], + [0], + builtin_options_type=_get_builtin_options_type("HashtableOptions"), + builtin_options=init_table_options, + ) + init_import = _build_operator( + builder, + 3, + [0, 1, 2], + [], + builtin_options_type=_get_builtin_options_type("HashtableImportOptions"), + builtin_options=import_options, + ) + init_subgraph = _build_subgraph( + builder, + tensors=[init_table_tensor, init_keys_tensor, init_values_tensor], + operators=[init_hashtable, init_import], + inputs=[], + outputs=[], + ) + + operator_codes = [ + _build_operator_code(builder, _get_builtin_operator("CALL_ONCE")), + _build_operator_code(builder, _get_builtin_operator("HASHTABLE")), + _build_operator_code(builder, _get_builtin_operator("HASHTABLE_SIZE")), + _build_operator_code(builder, _get_builtin_operator("HASHTABLE_IMPORT")), + ] + buffers = [ + _build_buffer(builder), + _build_buffer(builder, table_values.tobytes()), + ] + return _finish_tflite_model( + builder, + subgraph=main_subgraph, + extra_subgraphs=[init_subgraph], + operator_codes=operator_codes, + buffers=buffers, + ) + + +def _build_tflite_hashtable_size_uninitialized_model(): + """Build a model that queries the size of a hashtable without importing values.""" + builder = flatbuffers.Builder(1024) + resource_type = _get_resource_tensor_type() + + table_options = _build_hashtable_options(builder, table_id=0) + size_options = _build_empty_builtin_options(builder, "HashtableSizeOptions") + table_tensor = _build_tensor(builder, 0, [], tensor_type=resource_type) + size_tensor = _build_tensor(builder, 0, [], tensor_type=_tfl_tensor_type.INT32) + hashtable = _build_operator( + builder, + 0, + [], + [0], + builtin_options_type=_get_builtin_options_type("HashtableOptions"), + builtin_options=table_options, + ) + hashtable_size = _build_operator( + builder, + 1, + [0], + [1], + builtin_options_type=_get_builtin_options_type("HashtableSizeOptions"), + builtin_options=size_options, + ) + main_subgraph = _build_subgraph( + builder, + tensors=[table_tensor, size_tensor], + operators=[hashtable, hashtable_size], + inputs=[], + outputs=[1], + ) + operator_codes = [ + _build_operator_code(builder, _get_builtin_operator("HASHTABLE")), + _build_operator_code(builder, _get_builtin_operator("HASHTABLE_SIZE")), + ] + return _finish_tflite_model( + builder, + subgraph=main_subgraph, + operator_codes=operator_codes, + buffers=[_build_buffer(builder)], + ) + + def test_resource_variable_call_once_init_read(): """Test reading a resource variable initialized by a supported CALL_ONCE subgraph.""" mod = _load_model_from_buffer(_build_tflite_resource_variable_model()) @@ -5514,6 +5933,127 @@ def test_read_variable_uninitialized_unsupported(): _load_model_from_buffer(_build_tflite_resource_read_uninitialized_model()) +def test_hashtable_call_once_import_find(): + """Test finding values in a hashtable initialized by a supported CALL_ONCE subgraph.""" + mod = _load_model_from_buffer(_build_tflite_hashtable_find_model()) + + @I.ir_module + class Expected: + @R.function + def main(tvmgen_tensor_0: R.Tensor((3,), dtype="int32")) -> R.Tensor((3,), dtype="int32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + lv: R.Tensor((3, 1), dtype="int32") = R.expand_dims(tvmgen_tensor_0, axis=[-1]) + lv1: R.Tensor((3, 2), dtype="bool") = R.equal(lv, R.const([10, 20], "int32")) + lv2: R.Tensor((3, 2), dtype="int32") = R.astype(lv1, dtype="int32") + lv3: R.Tensor((3,), dtype="int32") = R.max(lv2, axis=[-1], keepdims=False) + lv4: R.Tensor((3,), dtype="bool") = R.greater(lv3, R.const(0, "int32")) + lv5: R.Tensor((3,), dtype="int64") = R.argmax(lv2, axis=-1, keepdims=False) + lv6: R.Tensor((3,), dtype="int32") = R.take( + R.const([100, 200], "int32"), lv5, axis=0, mode="fast" + ) + lv7: R.Tensor((3,), dtype="int32") = R.broadcast_to( + R.const(-1, "int32"), R.shape([3]) + ) + gv: R.Tensor((3,), dtype="int32") = R.where(lv4, lv6, lv7) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_hashtable_call_once_import_find_all_mismatch(): + """Test HASHTABLE_FIND returns the default value for keys absent from a static table.""" + mod = _load_model_from_buffer(_build_tflite_hashtable_find_model()) + + ex = tvm.compile(mod, tvm.target.Target("llvm")) + vm = relax.VirtualMachine(ex, tvm.cpu()) + vm.set_input("main", np.array([30, 40, 50], dtype=np.int32)) + vm.invoke_stateful("main") + tvm_output = vm.get_outputs("main") + + np.testing.assert_array_equal(tvm_output.numpy(), np.array([-1, -1, -1], dtype=np.int32)) + + +def test_hashtable_call_once_import_find_scalar_query(): + """Test HASHTABLE_FIND supports scalar query keys.""" + mod = _load_model_from_buffer(_build_tflite_hashtable_find_model(scalar_query=True)) + + ex = tvm.compile(mod, tvm.target.Target("llvm")) + vm = relax.VirtualMachine(ex, tvm.cpu()) + vm.set_input("main", np.array(20, dtype=np.int32)) + vm.invoke_stateful("main") + tvm_output = vm.get_outputs("main") + + np.testing.assert_array_equal(tvm_output.numpy(), np.array(200, dtype=np.int32)) + + +def test_hashtable_call_once_import_find_vector_values(): + """Test finding vector values in a hashtable initialized by CALL_ONCE.""" + mod = _load_model_from_buffer(_build_tflite_hashtable_find_model(vector_values=True)) + + @I.ir_module + class Expected: + @R.function + def main(tvmgen_tensor_0: R.Tensor((3,), dtype="int32")) -> R.Tensor((3, 2), dtype="int32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + lv: R.Tensor((3, 1), dtype="int32") = R.expand_dims(tvmgen_tensor_0, axis=[-1]) + lv1: R.Tensor((3, 2), dtype="bool") = R.equal(lv, R.const([10, 20], "int32")) + lv2: R.Tensor((3, 2), dtype="int32") = R.astype(lv1, dtype="int32") + lv3: R.Tensor((3,), dtype="int32") = R.max(lv2, axis=[-1], keepdims=False) + lv4: R.Tensor((3,), dtype="bool") = R.greater(lv3, R.const(0, "int32")) + lv5: R.Tensor((3, 1), dtype="bool") = R.expand_dims(lv4, axis=[-1]) + lv6: R.Tensor((3, 2), dtype="bool") = R.broadcast_to(lv5, R.shape([3, 2])) + lv7: R.Tensor((3,), dtype="int64") = R.argmax(lv2, axis=-1, keepdims=False) + lv8: R.Tensor((3, 2), dtype="int32") = R.take( + R.const([[100, 101], [200, 201]], "int32"), lv7, axis=0, mode="fast" + ) + lv9: R.Tensor((3, 2), dtype="int32") = R.broadcast_to( + R.const([-1, -2], "int32"), R.shape([3, 2]) + ) + gv: R.Tensor((3, 2), dtype="int32") = R.where(lv6, lv8, lv9) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_hashtable_call_once_import_size(): + """Test HASHTABLE_SIZE for a table initialized by a supported CALL_ONCE subgraph.""" + mod = _load_model_from_buffer(_build_tflite_hashtable_size_model()) + + @I.ir_module + class Expected: + @R.function + def main() -> R.Tensor((), dtype="int32"): + R.func_attr({"num_input": 0}) + with R.dataflow(): + gv: R.Tensor((), dtype="int32") = R.const(2, "int32") + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_hashtable_import_main_subgraph_unsupported(): + """Test HASHTABLE_IMPORT remains unsupported outside CALL_ONCE initialization.""" + with pytest.raises(tvm.error.OpNotImplemented, match="HASHTABLE_IMPORT outside CALL_ONCE"): + _load_model_from_buffer(_build_tflite_hashtable_import_in_main_model()) + + +def test_hashtable_import_nonconstant_unsupported(): + """Test HASHTABLE_IMPORT rejects non-constant keys or values.""" + with pytest.raises(tvm.error.OpNotImplemented, match="requires constant keys and values"): + _load_model_from_buffer(_build_tflite_hashtable_import_nonconstant_model()) + + +def test_hashtable_size_uninitialized_unsupported(): + """Test HASHTABLE_SIZE rejects tables without supported initialization.""" + with pytest.raises(tvm.error.OpNotImplemented, match="HASHTABLE_SIZE requires a table"): + _load_model_from_buffer(_build_tflite_hashtable_size_uninitialized_model()) + + def _get_stablehlo_builtin_operator(builtin_name): if not hasattr(_tfl_builtin_operator, builtin_name): pytest.skip(f"TFLite schema does not provide BuiltinOperator.{builtin_name}") From 43d5042e6f8a7e6cc78cf8327576eb601a08d135 Mon Sep 17 00:00:00 2001 From: Aharrypotter Date: Fri, 29 May 2026 18:49:40 +0800 Subject: [PATCH 3/3] [Frontend][TFLite] Refine hashtable import support --- .../relax/frontend/tflite/tflite_frontend.py | 175 ++++++----- tests/python/relax/test_frontend_tflite.py | 288 ++++-------------- 2 files changed, 144 insertions(+), 319 deletions(-) diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 9962cb71d635..b3614ba89242 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -618,25 +618,52 @@ def convert_read_variable(self, op): ) return resource_values[resource_key] + def _is_tflite_string_type(self, tensor_type): + from tflite.TensorType import TensorType + + return hasattr(TensorType, "STRING") and tensor_type == TensorType.STRING + + def _is_supported_hashtable_type_pair(self, key_dtype, value_dtype): + from tflite.TensorType import TensorType + + return (key_dtype == TensorType.INT64 and self._is_tflite_string_type(value_dtype)) or ( + self._is_tflite_string_type(key_dtype) and value_dtype == TensorType.INT64 + ) + def _get_hashtable_key(self, op, fallback_tensor=None): - """Return a stable key for a TFLite HASHTABLE resource.""" + """Return a stable key and TFLite dtype pair for a HASHTABLE resource.""" table_id = None + key_dtype = None + value_dtype = None if op.BuiltinOptions() is not None: try: from tflite.HashtableOptions import HashtableOptions opts = self._get_builtin_options(op, HashtableOptions) table_id = int(opts.TableId()) + key_dtype = int(opts.KeyDtype()) + value_dtype = int(opts.ValueDtype()) except (ImportError, ModuleNotFoundError): pass + if key_dtype is None or value_dtype is None: + raise tvm.error.OpNotImplemented("HASHTABLE requires HashtableOptions") + if not self._is_supported_hashtable_type_pair(key_dtype, value_dtype): + raise tvm.error.OpNotImplemented( + "TFLite HASHTABLE only supports int64/string or string/int64 tables" + ) + if table_id is not None: - return table_id + return table_id, key_dtype, value_dtype if fallback_tensor is not None: - return get_tensor_name(self.subgraph, fallback_tensor.tensor_idx) + return ( + get_tensor_name(self.subgraph, fallback_tensor.tensor_idx), + key_dtype, + value_dtype, + ) raise tvm.error.OpNotImplemented("HASHTABLE requires HashtableOptions") - def _get_hashtable_key_for_handle(self, tensor, op_name): + def _get_hashtable_info_for_handle(self, tensor, op_name): tensor_name = get_tensor_name(self.subgraph, tensor.tensor_idx) if tensor_name not in self.hashtable_handles: raise tvm.error.OpNotImplemented( @@ -644,6 +671,20 @@ def _get_hashtable_key_for_handle(self, tensor, op_name): ) return self.hashtable_handles[tensor_name] + @staticmethod + def _get_tensor_shape_tuple(tensor_wrapper): + if tensor_wrapper.tensor.ShapeLength() == 0: + return () + return tuple(int(dim) for dim in tensor_wrapper.tensor.ShapeAsNumpy()) + + @staticmethod + def _has_tensor_buffer_data(tensor_wrapper): + return ( + tensor_wrapper.buffer is not None + and hasattr(tensor_wrapper.buffer, "DataLength") + and tensor_wrapper.buffer.DataLength() > 0 + ) + def convert_hashtable(self, op): """Convert a TFLite HASHTABLE into an importer-local table handle.""" input_tensors = self.get_input_tensors(op) @@ -651,21 +692,17 @@ def convert_hashtable(self, op): if len(input_tensors) != 0 or len(output_tensors) != 1: raise tvm.error.OpNotImplemented("HASHTABLE expects no inputs and one output") - table_key = self._get_hashtable_key(op, output_tensors[0]) + table_key, key_dtype, value_dtype = self._get_hashtable_key(op, output_tensors[0]) table_tensor_name = get_tensor_name(self.subgraph, output_tensors[0].tensor_idx) - self.hashtable_handles[table_tensor_name] = table_key + self.hashtable_handles[table_tensor_name] = { + "table_key": table_key, + "key_dtype": key_dtype, + "value_dtype": value_dtype, + } return None - @staticmethod - def _has_tensor_buffer_data(tensor_wrapper): - return ( - tensor_wrapper.buffer is not None - and hasattr(tensor_wrapper.buffer, "DataLength") - and tensor_wrapper.buffer.DataLength() > 0 - ) - def convert_hashtable_import(self, op): - """Convert the CALL_ONCE initialization subset of HASHTABLE_IMPORT.""" + """Convert static metadata for the CALL_ONCE HASHTABLE_IMPORT subset.""" if not self.conversion_state["in_call_once_init"]: raise tvm.error.OpNotImplemented( "HASHTABLE_IMPORT outside CALL_ONCE initialization is not supported by the " @@ -679,100 +716,58 @@ def convert_hashtable_import(self, op): "HASHTABLE_IMPORT expects table, keys, and values inputs with no outputs" ) - table_key = self._get_hashtable_key_for_handle(input_tensors[0], "HASHTABLE_IMPORT") - if not self._has_tensor_buffer_data(input_tensors[1]) or not self._has_tensor_buffer_data( - input_tensors[2] + table_info = self._get_hashtable_info_for_handle(input_tensors[0], "HASHTABLE_IMPORT") + key_tensor = input_tensors[1] + value_tensor = input_tensors[2] + if ( + key_tensor.tensor.Type() != table_info["key_dtype"] + or value_tensor.tensor.Type() != table_info["value_dtype"] + ): + raise tvm.error.OpNotImplemented("HASHTABLE_IMPORT key/value dtypes mismatch") + key_shape = self._get_tensor_shape_tuple(key_tensor) + value_shape = self._get_tensor_shape_tuple(value_tensor) + if key_shape != value_shape: + raise tvm.error.OpNotImplemented("HASHTABLE_IMPORT requires keys and values same shape") + if not self._has_tensor_buffer_data(key_tensor) or not self._has_tensor_buffer_data( + value_tensor ): raise tvm.error.OpNotImplemented("HASHTABLE_IMPORT requires constant keys and values") - keys = self.get_tensor_value(input_tensors[1]) - values = self.get_tensor_value(input_tensors[2]) - if keys.ndim != 1 or values.ndim < 1: - raise tvm.error.OpNotImplemented( - "HASHTABLE_IMPORT requires one-dimensional keys and at least one-dimensional values" - ) - if keys.shape[0] != values.shape[0]: - raise tvm.error.OpNotImplemented("HASHTABLE_IMPORT keys and values size mismatch") - self.conversion_state["hashtable_values"][table_key] = { - "keys": self.get_tensor_expr(input_tensors[1]), - "values": self.get_tensor_expr(input_tensors[2]), - "size": int(keys.shape[0]), - "value_shape": tuple(int(dim) for dim in values.shape[1:]), - } - return None - - def convert_hashtable_find(self, op): - """Convert HASHTABLE_FIND for static tables initialized by CALL_ONCE.""" - input_tensors = self.get_input_tensors(op) - output_tensors = self.get_output_tensors(op) - if len(input_tensors) != 3 or len(output_tensors) != 1: - raise tvm.error.OpNotImplemented( - "HASHTABLE_FIND expects table, keys, and default value inputs with one output" - ) - - table_key = self._get_hashtable_key_for_handle(input_tensors[0], "HASHTABLE_FIND") hashtable_values = self.conversion_state["hashtable_values"] + table_key = table_info["table_key"] if table_key not in hashtable_values: - raise tvm.error.OpNotImplemented( - "HASHTABLE_FIND requires a table initialized by a supported CALL_ONCE subgraph" - ) - - table = hashtable_values[table_key] - output_shape = ( - tuple(output_tensors[0].tensor.ShapeAsNumpy()) - if output_tensors[0].tensor.ShapeLength() > 0 - else () - ) - value_shape = table["value_shape"] - if value_shape and ( - len(output_shape) < len(value_shape) - or tuple(output_shape[-len(value_shape) :]) != value_shape - ): - raise tvm.error.OpNotImplemented( - "HASHTABLE_FIND output shape must append the imported value shape" - ) - query_keys = self.get_tensor_expr(input_tensors[1]) - table_keys = table["keys"] - table_values = table["values"] - default_value = self.get_tensor_expr(input_tensors[2]) + hashtable_values[table_key] = { + "size": math.prod(key_shape) if key_shape else 1, + "key_dtype": table_info["key_dtype"], + "value_dtype": table_info["value_dtype"], + } + return None - if input_tensors[1].tensor.ShapeLength() == 0: - matches = relax.op.equal(query_keys, table_keys) - else: - matches = relax.op.equal(relax.op.expand_dims(query_keys, axis=-1), table_keys) - matches_int = relax.op.astype(matches, "int32") - matched_indices = relax.op.argmax(matches_int, axis=-1) - selected_values = relax.op.take(table_values, matched_indices, axis=0, mode="fast") - has_match = relax.op.greater( - relax.op.max(matches_int, axis=-1), - relax.const(0, dtype="int32"), + def convert_hashtable_find(self, op): + """Reject HASHTABLE_FIND until Relax can represent TFLite string tensors.""" + raise tvm.error.OpNotImplemented( + "HASHTABLE_FIND requires TensorType.STRING support in Relax TFLite frontend" ) - if value_shape: - for _ in value_shape: - has_match = relax.op.expand_dims(has_match, axis=-1) - has_match = relax.op.broadcast_to( - has_match, - output_shape, - ) - default_values = relax.op.broadcast_to(default_value, output_shape) - return relax.op.where(has_match, selected_values, default_values) def convert_hashtable_size(self, op): - """Convert HASHTABLE_SIZE for static tables initialized by CALL_ONCE.""" + """Convert HASHTABLE_SIZE for a statically imported TFLite hashtable.""" input_tensors = self.get_input_tensors(op) output_tensors = self.get_output_tensors(op) if len(input_tensors) != 1 or len(output_tensors) != 1: raise tvm.error.OpNotImplemented("HASHTABLE_SIZE expects one input and one output") - table_key = self._get_hashtable_key_for_handle(input_tensors[0], "HASHTABLE_SIZE") + from tflite.TensorType import TensorType + + if output_tensors[0].tensor.Type() != TensorType.INT64: + raise tvm.error.OpNotImplemented("HASHTABLE_SIZE output must be int64") + table_info = self._get_hashtable_info_for_handle(input_tensors[0], "HASHTABLE_SIZE") + table_key = table_info["table_key"] hashtable_values = self.conversion_state["hashtable_values"] if table_key not in hashtable_values: raise tvm.error.OpNotImplemented( "HASHTABLE_SIZE requires a table initialized by a supported CALL_ONCE subgraph" ) - - output_dtype = self.get_tensor_type_str(output_tensors[0].tensor.Type()) - return relax.const(hashtable_values[table_key]["size"], dtype=output_dtype) + return relax.const(np.array([hashtable_values[table_key]["size"]], dtype=np.int64), "int64") def get_op_code_str(self, op): """Get TFLite ops string representation""" diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index 2c02f6394a8f..feae04abf685 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -3953,6 +3953,24 @@ def _get_resource_tensor_type(): return getattr(_tfl_tensor_type, "RESOURCE") +def _get_string_tensor_type(): + if not hasattr(_tfl_tensor_type, "STRING"): + pytest.skip("TFLite schema does not provide TensorType.STRING") + return getattr(_tfl_tensor_type, "STRING") + + +def _build_tflite_string_buffer(values): + encoded = [value.encode("utf-8") for value in values] + offsets = [] + cursor = 4 * (len(encoded) + 2) + for value in encoded: + offsets.append(cursor) + cursor += len(value) + offsets.append(cursor) + header = np.array([len(encoded), *offsets], dtype=np.int32).tobytes() + return header + b"".join(encoded) + + def _build_var_handle_options(builder, shared_name="resource_var", container=""): try: var_handle_options = _get_tflite_schema_module("VarHandleOptions") @@ -3986,8 +4004,8 @@ def _build_hashtable_options( except ModuleNotFoundError: pytest.skip("TFLite schema does not provide HashtableOptions") - key_dtype = _tfl_tensor_type.INT32 if key_dtype is None else key_dtype - value_dtype = _tfl_tensor_type.INT32 if value_dtype is None else value_dtype + key_dtype = _tfl_tensor_type.INT64 if key_dtype is None else key_dtype + value_dtype = _get_string_tensor_type() if value_dtype is None else value_dtype hashtable_options.HashtableOptionsStart(builder) hashtable_options.HashtableOptionsAddTableId(builder, table_id) hashtable_options.HashtableOptionsAddKeyDtype(builder, key_dtype) @@ -5505,21 +5523,14 @@ def _build_tflite_resource_read_uninitialized_model(): ) -def _build_tflite_hashtable_find_model(vector_values=False, scalar_query=False): +def _build_tflite_hashtable_find_model(): """Build a model that imports a static hashtable and finds runtime query keys.""" builder = flatbuffers.Builder(1024) resource_type = _get_resource_tensor_type() - table_keys = np.array([10, 20], dtype=np.int32) - if vector_values: - table_values = np.array([[100, 101], [200, 201]], dtype=np.int32) - default_value = np.array([-1, -2], dtype=np.int32) - default_shape = [2] - output_shape = [2] if scalar_query else [3, 2] - else: - table_values = np.array([100, 200], dtype=np.int32) - default_value = np.array(-1, dtype=np.int32) - default_shape = [] - output_shape = [] if scalar_query else [3] + string_type = _get_string_tensor_type() + table_keys = np.array([10, 20], dtype=np.int64) + table_values = _build_tflite_string_buffer(["one hundred", "two hundred"]) + default_value = _build_tflite_string_buffer(["missing"]) call_once_options = _build_call_once_options(builder, 1) main_table_options = _build_hashtable_options(builder, table_id=0) @@ -5527,15 +5538,10 @@ def _build_tflite_hashtable_find_model(vector_values=False, scalar_query=False): init_table_options = _build_hashtable_options(builder, table_id=0) import_options = _build_empty_builtin_options(builder, "HashtableImportOptions") - query_tensor = _build_tensor( - builder, - 0, - [] if scalar_query else [3], - tensor_type=_tfl_tensor_type.INT32, - ) - table_tensor = _build_tensor(builder, 0, [], tensor_type=resource_type) - default_tensor = _build_tensor(builder, 1, default_shape, tensor_type=_tfl_tensor_type.INT32) - output_tensor = _build_tensor(builder, 0, output_shape, tensor_type=_tfl_tensor_type.INT32) + query_tensor = _build_tensor(builder, 0, [3], tensor_type=_tfl_tensor_type.INT64) + table_tensor = _build_tensor(builder, 0, [1], tensor_type=resource_type) + default_tensor = _build_tensor(builder, 1, [], tensor_type=string_type) + output_tensor = _build_tensor(builder, 0, [3], tensor_type=string_type) main_call_once = _build_operator( builder, 0, @@ -5568,13 +5574,13 @@ def _build_tflite_hashtable_find_model(vector_values=False, scalar_query=False): outputs=[3], ) - init_table_tensor = _build_tensor(builder, 0, [], tensor_type=resource_type) - init_keys_tensor = _build_tensor(builder, 2, [2], tensor_type=_tfl_tensor_type.INT32) + init_table_tensor = _build_tensor(builder, 0, [1], tensor_type=resource_type) + init_keys_tensor = _build_tensor(builder, 2, [2], tensor_type=_tfl_tensor_type.INT64) init_values_tensor = _build_tensor( builder, 3, - list(table_values.shape), - tensor_type=_tfl_tensor_type.INT32, + [2], + tensor_type=string_type, ) init_hashtable = _build_operator( builder, @@ -5608,9 +5614,9 @@ def _build_tflite_hashtable_find_model(vector_values=False, scalar_query=False): ] buffers = [ _build_buffer(builder), - _build_buffer(builder, default_value.tobytes()), + _build_buffer(builder, default_value), _build_buffer(builder, table_keys.tobytes()), - _build_buffer(builder, table_values.tobytes()), + _build_buffer(builder, table_values), ] return _finish_tflite_model( builder, @@ -5625,8 +5631,9 @@ def _build_tflite_hashtable_size_model(): """Build a model that imports a static hashtable and returns its size.""" builder = flatbuffers.Builder(1024) resource_type = _get_resource_tensor_type() - table_keys = np.array([10, 20], dtype=np.int32) - table_values = np.array([100, 200], dtype=np.int32) + string_type = _get_string_tensor_type() + table_keys = np.array([10, 20], dtype=np.int64) + table_values = _build_tflite_string_buffer(["one hundred", "two hundred"]) call_once_options = _build_call_once_options(builder, 1) main_table_options = _build_hashtable_options(builder, table_id=0) @@ -5634,8 +5641,8 @@ def _build_tflite_hashtable_size_model(): init_table_options = _build_hashtable_options(builder, table_id=0) import_options = _build_empty_builtin_options(builder, "HashtableImportOptions") - table_tensor = _build_tensor(builder, 0, [], tensor_type=resource_type) - size_tensor = _build_tensor(builder, 0, [], tensor_type=_tfl_tensor_type.INT32) + table_tensor = _build_tensor(builder, 0, [1], tensor_type=resource_type) + size_tensor = _build_tensor(builder, 0, [1], tensor_type=_tfl_tensor_type.INT64) main_call_once = _build_operator( builder, 0, @@ -5668,9 +5675,9 @@ def _build_tflite_hashtable_size_model(): outputs=[1], ) - init_table_tensor = _build_tensor(builder, 0, [], tensor_type=resource_type) - init_keys_tensor = _build_tensor(builder, 1, [2], tensor_type=_tfl_tensor_type.INT32) - init_values_tensor = _build_tensor(builder, 2, [2], tensor_type=_tfl_tensor_type.INT32) + init_table_tensor = _build_tensor(builder, 0, [1], tensor_type=resource_type) + init_keys_tensor = _build_tensor(builder, 1, [2], tensor_type=_tfl_tensor_type.INT64) + init_values_tensor = _build_tensor(builder, 2, [2], tensor_type=string_type) init_hashtable = _build_operator( builder, 1, @@ -5704,7 +5711,7 @@ def _build_tflite_hashtable_size_model(): buffers = [ _build_buffer(builder), _build_buffer(builder, table_keys.tobytes()), - _build_buffer(builder, table_values.tobytes()), + _build_buffer(builder, table_values), ] return _finish_tflite_model( builder, @@ -5719,15 +5726,16 @@ def _build_tflite_hashtable_import_in_main_model(): """Build a model that attempts to import hashtable values in the main subgraph.""" builder = flatbuffers.Builder(1024) resource_type = _get_resource_tensor_type() - table_keys = np.array([10, 20], dtype=np.int32) - table_values = np.array([100, 200], dtype=np.int32) + string_type = _get_string_tensor_type() + table_keys = np.array([10, 20], dtype=np.int64) + table_values = _build_tflite_string_buffer(["one hundred", "two hundred"]) table_options = _build_hashtable_options(builder, table_id=0) import_options = _build_empty_builtin_options(builder, "HashtableImportOptions") - table_tensor = _build_tensor(builder, 0, [], tensor_type=resource_type) - keys_tensor = _build_tensor(builder, 1, [2], tensor_type=_tfl_tensor_type.INT32) - values_tensor = _build_tensor(builder, 2, [2], tensor_type=_tfl_tensor_type.INT32) + table_tensor = _build_tensor(builder, 0, [1], tensor_type=resource_type) + keys_tensor = _build_tensor(builder, 1, [2], tensor_type=_tfl_tensor_type.INT64) + values_tensor = _build_tensor(builder, 2, [2], tensor_type=string_type) hashtable = _build_operator( builder, 0, @@ -5758,103 +5766,11 @@ def _build_tflite_hashtable_import_in_main_model(): buffers = [ _build_buffer(builder), _build_buffer(builder, table_keys.tobytes()), - _build_buffer(builder, table_values.tobytes()), - ] - return _finish_tflite_model( - builder, - subgraph=main_subgraph, - operator_codes=operator_codes, - buffers=buffers, - ) - - -def _build_tflite_hashtable_import_nonconstant_model(): - """Build a model that imports hashtable values from a non-constant keys tensor.""" - builder = flatbuffers.Builder(1024) - resource_type = _get_resource_tensor_type() - table_values = np.array([100, 200], dtype=np.int32) - - call_once_options = _build_call_once_options(builder, 1) - main_table_options = _build_hashtable_options(builder, table_id=0) - size_options = _build_empty_builtin_options(builder, "HashtableSizeOptions") - init_table_options = _build_hashtable_options(builder, table_id=0) - import_options = _build_empty_builtin_options(builder, "HashtableImportOptions") - - table_tensor = _build_tensor(builder, 0, [], tensor_type=resource_type) - size_tensor = _build_tensor(builder, 0, [], tensor_type=_tfl_tensor_type.INT32) - main_call_once = _build_operator( - builder, - 0, - [], - [], - builtin_options_type=_get_builtin_options_type("CallOnceOptions"), - builtin_options=call_once_options, - ) - main_hashtable = _build_operator( - builder, - 1, - [], - [0], - builtin_options_type=_get_builtin_options_type("HashtableOptions"), - builtin_options=main_table_options, - ) - main_size = _build_operator( - builder, - 2, - [0], - [1], - builtin_options_type=_get_builtin_options_type("HashtableSizeOptions"), - builtin_options=size_options, - ) - main_subgraph = _build_subgraph( - builder, - tensors=[table_tensor, size_tensor], - operators=[main_call_once, main_hashtable, main_size], - inputs=[], - outputs=[1], - ) - - init_table_tensor = _build_tensor(builder, 0, [], tensor_type=resource_type) - init_keys_tensor = _build_tensor(builder, 0, [2], tensor_type=_tfl_tensor_type.INT32) - init_values_tensor = _build_tensor(builder, 1, [2], tensor_type=_tfl_tensor_type.INT32) - init_hashtable = _build_operator( - builder, - 1, - [], - [0], - builtin_options_type=_get_builtin_options_type("HashtableOptions"), - builtin_options=init_table_options, - ) - init_import = _build_operator( - builder, - 3, - [0, 1, 2], - [], - builtin_options_type=_get_builtin_options_type("HashtableImportOptions"), - builtin_options=import_options, - ) - init_subgraph = _build_subgraph( - builder, - tensors=[init_table_tensor, init_keys_tensor, init_values_tensor], - operators=[init_hashtable, init_import], - inputs=[], - outputs=[], - ) - - operator_codes = [ - _build_operator_code(builder, _get_builtin_operator("CALL_ONCE")), - _build_operator_code(builder, _get_builtin_operator("HASHTABLE")), - _build_operator_code(builder, _get_builtin_operator("HASHTABLE_SIZE")), - _build_operator_code(builder, _get_builtin_operator("HASHTABLE_IMPORT")), - ] - buffers = [ - _build_buffer(builder), - _build_buffer(builder, table_values.tobytes()), + _build_buffer(builder, table_values), ] return _finish_tflite_model( builder, subgraph=main_subgraph, - extra_subgraphs=[init_subgraph], operator_codes=operator_codes, buffers=buffers, ) @@ -5867,8 +5783,8 @@ def _build_tflite_hashtable_size_uninitialized_model(): table_options = _build_hashtable_options(builder, table_id=0) size_options = _build_empty_builtin_options(builder, "HashtableSizeOptions") - table_tensor = _build_tensor(builder, 0, [], tensor_type=resource_type) - size_tensor = _build_tensor(builder, 0, [], tensor_type=_tfl_tensor_type.INT32) + table_tensor = _build_tensor(builder, 0, [1], tensor_type=resource_type) + size_tensor = _build_tensor(builder, 0, [1], tensor_type=_tfl_tensor_type.INT64) hashtable = _build_operator( builder, 0, @@ -5933,90 +5849,10 @@ def test_read_variable_uninitialized_unsupported(): _load_model_from_buffer(_build_tflite_resource_read_uninitialized_model()) -def test_hashtable_call_once_import_find(): - """Test finding values in a hashtable initialized by a supported CALL_ONCE subgraph.""" - mod = _load_model_from_buffer(_build_tflite_hashtable_find_model()) - - @I.ir_module - class Expected: - @R.function - def main(tvmgen_tensor_0: R.Tensor((3,), dtype="int32")) -> R.Tensor((3,), dtype="int32"): - R.func_attr({"num_input": 1}) - with R.dataflow(): - lv: R.Tensor((3, 1), dtype="int32") = R.expand_dims(tvmgen_tensor_0, axis=[-1]) - lv1: R.Tensor((3, 2), dtype="bool") = R.equal(lv, R.const([10, 20], "int32")) - lv2: R.Tensor((3, 2), dtype="int32") = R.astype(lv1, dtype="int32") - lv3: R.Tensor((3,), dtype="int32") = R.max(lv2, axis=[-1], keepdims=False) - lv4: R.Tensor((3,), dtype="bool") = R.greater(lv3, R.const(0, "int32")) - lv5: R.Tensor((3,), dtype="int64") = R.argmax(lv2, axis=-1, keepdims=False) - lv6: R.Tensor((3,), dtype="int32") = R.take( - R.const([100, 200], "int32"), lv5, axis=0, mode="fast" - ) - lv7: R.Tensor((3,), dtype="int32") = R.broadcast_to( - R.const(-1, "int32"), R.shape([3]) - ) - gv: R.Tensor((3,), dtype="int32") = R.where(lv4, lv6, lv7) - R.output(gv) - return gv - - tvm.ir.assert_structural_equal(mod, Expected) - - -def test_hashtable_call_once_import_find_all_mismatch(): - """Test HASHTABLE_FIND returns the default value for keys absent from a static table.""" - mod = _load_model_from_buffer(_build_tflite_hashtable_find_model()) - - ex = tvm.compile(mod, tvm.target.Target("llvm")) - vm = relax.VirtualMachine(ex, tvm.cpu()) - vm.set_input("main", np.array([30, 40, 50], dtype=np.int32)) - vm.invoke_stateful("main") - tvm_output = vm.get_outputs("main") - - np.testing.assert_array_equal(tvm_output.numpy(), np.array([-1, -1, -1], dtype=np.int32)) - - -def test_hashtable_call_once_import_find_scalar_query(): - """Test HASHTABLE_FIND supports scalar query keys.""" - mod = _load_model_from_buffer(_build_tflite_hashtable_find_model(scalar_query=True)) - - ex = tvm.compile(mod, tvm.target.Target("llvm")) - vm = relax.VirtualMachine(ex, tvm.cpu()) - vm.set_input("main", np.array(20, dtype=np.int32)) - vm.invoke_stateful("main") - tvm_output = vm.get_outputs("main") - - np.testing.assert_array_equal(tvm_output.numpy(), np.array(200, dtype=np.int32)) - - -def test_hashtable_call_once_import_find_vector_values(): - """Test finding vector values in a hashtable initialized by CALL_ONCE.""" - mod = _load_model_from_buffer(_build_tflite_hashtable_find_model(vector_values=True)) - - @I.ir_module - class Expected: - @R.function - def main(tvmgen_tensor_0: R.Tensor((3,), dtype="int32")) -> R.Tensor((3, 2), dtype="int32"): - R.func_attr({"num_input": 1}) - with R.dataflow(): - lv: R.Tensor((3, 1), dtype="int32") = R.expand_dims(tvmgen_tensor_0, axis=[-1]) - lv1: R.Tensor((3, 2), dtype="bool") = R.equal(lv, R.const([10, 20], "int32")) - lv2: R.Tensor((3, 2), dtype="int32") = R.astype(lv1, dtype="int32") - lv3: R.Tensor((3,), dtype="int32") = R.max(lv2, axis=[-1], keepdims=False) - lv4: R.Tensor((3,), dtype="bool") = R.greater(lv3, R.const(0, "int32")) - lv5: R.Tensor((3, 1), dtype="bool") = R.expand_dims(lv4, axis=[-1]) - lv6: R.Tensor((3, 2), dtype="bool") = R.broadcast_to(lv5, R.shape([3, 2])) - lv7: R.Tensor((3,), dtype="int64") = R.argmax(lv2, axis=-1, keepdims=False) - lv8: R.Tensor((3, 2), dtype="int32") = R.take( - R.const([[100, 101], [200, 201]], "int32"), lv7, axis=0, mode="fast" - ) - lv9: R.Tensor((3, 2), dtype="int32") = R.broadcast_to( - R.const([-1, -2], "int32"), R.shape([3, 2]) - ) - gv: R.Tensor((3, 2), dtype="int32") = R.where(lv6, lv8, lv9) - R.output(gv) - return gv - - tvm.ir.assert_structural_equal(mod, Expected) +def test_hashtable_call_once_import_find_unsupported(): + """Test HASHTABLE_FIND remains unsupported until TFLite string tensors are supported.""" + with pytest.raises(tvm.error.OpNotImplemented, match="TensorType.STRING"): + _load_model_from_buffer(_build_tflite_hashtable_find_model()) def test_hashtable_call_once_import_size(): @@ -6026,10 +5862,10 @@ def test_hashtable_call_once_import_size(): @I.ir_module class Expected: @R.function - def main() -> R.Tensor((), dtype="int32"): + def main() -> R.Tensor((1,), dtype="int64"): R.func_attr({"num_input": 0}) with R.dataflow(): - gv: R.Tensor((), dtype="int32") = R.const(2, "int32") + gv: R.Tensor((1,), dtype="int64") = R.const([2], "int64") R.output(gv) return gv @@ -6042,12 +5878,6 @@ def test_hashtable_import_main_subgraph_unsupported(): _load_model_from_buffer(_build_tflite_hashtable_import_in_main_model()) -def test_hashtable_import_nonconstant_unsupported(): - """Test HASHTABLE_IMPORT rejects non-constant keys or values.""" - with pytest.raises(tvm.error.OpNotImplemented, match="requires constant keys and values"): - _load_model_from_buffer(_build_tflite_hashtable_import_nonconstant_model()) - - def test_hashtable_size_uninitialized_unsupported(): """Test HASHTABLE_SIZE rejects tables without supported initialization.""" with pytest.raises(tvm.error.OpNotImplemented, match="HASHTABLE_SIZE requires a table"):