From a0d898f94bc78a753686b873a2d0b25d682265b2 Mon Sep 17 00:00:00 2001 From: LudovicoYIN Date: Fri, 29 May 2026 06:36:56 +0000 Subject: [PATCH] [Relax][Frontend][TFLite] Rebase LSTM and SVDF onto main --- .../relax/frontend/tflite/tflite_frontend.py | 451 ++++--- tests/python/relax/test_frontend_tflite.py | 1170 +++++++++-------- 2 files changed, 830 insertions(+), 791 deletions(-) diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 87f0f12b1bbd..87697dc6addf 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -253,6 +253,7 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None): "LOGICAL_NOT": self.convert_logical_not, "LOGICAL_OR": functools.partial(self._convert_logical_binary, relax_op=_op.logical_or), "LOGISTIC": self.convert_logistic, + "LSTM": self.convert_lstm, "MATRIX_DIAG": self.convert_matrix_diag, "MATRIX_SET_DIAG": self.convert_matrix_set_diag, "MAX_POOL_2D": functools.partial(self.convert_pool2d, pool_type="max"), @@ -273,7 +274,6 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None): "POW": functools.partial(self._convert_elemwise, relax_op=_op.power), "PRELU": self.convert_prelu, "RANGE": self.convert_range, - "RNN": self.convert_rnn, "QUANTIZE": self.convert_quantize, "RANDOM_STANDARD_NORMAL": self.convert_random_standard_normal, "RANDOM_UNIFORM": self.convert_random_uniform, @@ -282,7 +282,6 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None): "REDUCE_MAX": functools.partial(self._convert_reduce, relax_op=_op.max), "REDUCE_MIN": functools.partial(self._convert_reduce, relax_op=_op.min), "REDUCE_PROD": functools.partial(self._convert_reduce, relax_op=_op.prod), - "REDUCE_WINDOW": self.convert_reduce_window, "RELU": self.convert_relu, "RELU6": self.convert_relu6, "RELU_N1_TO_1": self.convert_relu_n1_to_1, @@ -376,6 +375,7 @@ def 
__init__(self, model, subgraph, exp_tab, ctx, conversion_state=None): "STRIDED_SLICE": self.convert_strided_slice, "SUB": functools.partial(self._convert_elemwise, relax_op=_op.subtract), "SUM": functools.partial(self._convert_reduce, relax_op=_op.sum), + "SVDF": self.convert_svdf, "TAN": functools.partial(self._convert_unary_elemwise, relax_op=_op.tan), "TANH": self.convert_tanh, "TILE": self.convert_tile, @@ -3447,171 +3447,6 @@ def _convert_reduce(self, relax_op, op): return out - def convert_reduce_window(self, op): - """Convert TFLite REDUCE_WINDOW.""" - - from tflite.BuiltinOptions2 import BuiltinOptions2 - from tflite.ReduceWindowFunction import ReduceWindowFunction - from tflite.ReduceWindowOptions import ReduceWindowOptions - - input_tensors = self.get_input_tensors(op) - output_tensors = self.get_output_tensors(op) - if len(input_tensors) != 5: - raise tvm.error.OpAttributeUnImplemented( - "TFLite REDUCE_WINDOW requires 5 input tensors." - ) - if len(output_tensors) != 1: - raise tvm.error.OpAttributeUnImplemented( - "TFLite REDUCE_WINDOW requires 1 output tensor." - ) - - if op.BuiltinOptions2Type() != BuiltinOptions2.ReduceWindowOptions: - raise tvm.error.OpAttributeUnImplemented( - "TFLite REDUCE_WINDOW requires ReduceWindowOptions." - ) - - ( - input_tensor, - init_tensor, - window_shape_tensor, - window_strides_tensor, - window_dilations_tensor, - ) = input_tensors - output_tensor = output_tensors[0] - - if any( - self.has_expr(tensor.tensor_idx) - for tensor in [window_shape_tensor, window_strides_tensor, window_dilations_tensor] - ): - raise tvm.error.OpNotImplemented( - "TFLite REDUCE_WINDOW requires constant window_shape, " - "window_strides, and window_dilations." 
- ) - - input_shape = to_int_list(self.get_tensor_shape(input_tensor)) - output_shape = to_int_list(self.get_tensor_shape(output_tensor)) - input_dtype = self.get_tensor_type_str(input_tensor.tensor.Type()) - output_dtype = self.get_tensor_type_str(output_tensor.tensor.Type()) - - if input_tensor.qnn_params or output_tensor.qnn_params: - raise tvm.error.OpNotImplemented( - "Quantized TFLite REDUCE_WINDOW is not yet supported in the Relax frontend." - ) - - if input_dtype != output_dtype: - raise tvm.error.OpAttributeUnImplemented( - "TFLite REDUCE_WINDOW requires input and output dtypes to match." - ) - - init_shape = to_int_list(self.get_tensor_shape(init_tensor)) - if math.prod(init_shape) != 1: - raise tvm.error.OpNotImplemented( - "TFLite REDUCE_WINDOW requires init_value to contain exactly one element." - ) - - options = ReduceWindowOptions() - op_options = op.BuiltinOptions2() - options.Init(op_options.Bytes, op_options.Pos) - reduce_function = options.ReduceFunction() - - if reduce_function == ReduceWindowFunction.UNSUPPORTED: - raise tvm.error.OpNotImplemented( - "TFLite REDUCE_WINDOW with UNSUPPORTED reduce_function is not supported." - ) - - window_shape = to_int_list(self.get_tensor_value(window_shape_tensor)) - window_strides = to_int_list(self.get_tensor_value(window_strides_tensor)) - window_dilations = to_int_list(self.get_tensor_value(window_dilations_tensor)) - rank = len(input_shape) - - if not (len(window_shape) == len(window_strides) == len(window_dilations) == rank): - raise tvm.error.OpAttributeUnImplemented( - "TFLite REDUCE_WINDOW window_shape, window_strides, and window_dilations " - "must match input rank." - ) - - if any(value <= 0 for value in window_shape + window_strides + window_dilations): - raise tvm.error.OpAttributeUnImplemented( - "TFLite REDUCE_WINDOW window dimensions, strides, and dilations must be positive." 
- ) - - dilated_window_shape = [ - (window_dim - 1) * dilation + 1 - for window_dim, dilation in zip(window_shape, window_dilations) - ] - expected_output_shape = [ - 0 if input_dim < dilated_dim else (input_dim - dilated_dim) // stride + 1 - for input_dim, dilated_dim, stride in zip( - input_shape, dilated_window_shape, window_strides - ) - ] - - numeric_reduce_functions = { - ReduceWindowFunction.ADD: (relax.op.sum, relax.op.add), - ReduceWindowFunction.MUL: (relax.op.prod, relax.op.multiply), - ReduceWindowFunction.MINIMUM: (relax.op.min, relax.op.minimum), - ReduceWindowFunction.MAXIMUM: (relax.op.max, relax.op.maximum), - } - bool_reduce_functions = { - ReduceWindowFunction.ALL: (relax.op.min, relax.op.logical_and), - ReduceWindowFunction.ANY: (relax.op.max, relax.op.logical_or), - } - - if reduce_function in numeric_reduce_functions and input_dtype == "bool": - raise tvm.error.OpAttributeUnImplemented( - "TFLite REDUCE_WINDOW numeric reductions expect numeric input." - ) - if reduce_function in bool_reduce_functions and input_dtype != "bool": - raise tvm.error.OpAttributeUnImplemented( - "TFLite REDUCE_WINDOW boolean reductions expect bool input." - ) - - if output_shape != expected_output_shape: - raise tvm.error.OpAttributeUnImplemented( - "TFLite REDUCE_WINDOW output shape does not match input/window parameters." 
- ) - - if any(output_dim == 0 for output_dim in output_shape): - return relax.op.zeros(output_shape, output_dtype) - - data = self.get_tensor_expr(input_tensor) - init_value = self.get_tensor_expr(init_tensor) - if len(init_shape) != 0: - init_value = relax.op.reshape(init_value, []) - - windowed = relax.op.call_dps_packed( - "topi.sliding_window", - ( - data, - 0, - relax.ShapeExpr(dilated_window_shape), - relax.ShapeExpr(window_strides), - ), - out_sinfo=relax.TensorStructInfo(output_shape + dilated_window_shape, input_dtype), - ) - - if any(dilation != 1 for dilation in window_dilations): - windowed = relax.op.strided_slice( - windowed, - axes=list(range(rank, 2 * rank)), - begin=[0] * rank, - end=dilated_window_shape, - strides=window_dilations, - ) - - reduce_axes = list(range(rank, 2 * rank)) - if reduce_function in numeric_reduce_functions: - reduce_op, combine_op = numeric_reduce_functions[reduce_function] - return combine_op(reduce_op(windowed, axis=reduce_axes), init_value) - if reduce_function in bool_reduce_functions: - reduce_op, combine_op = bool_reduce_functions[reduce_function] - reduced = reduce_op(relax.op.astype(windowed, "int8"), axis=reduce_axes) - return combine_op(relax.op.astype(reduced, "bool"), init_value) - - raise tvm.error.OpNotImplemented( - f"TFLite REDUCE_WINDOW reduce_function {reduce_function} is not supported." - ) - def _convert_reduce_bool(self, relax_op, op): """Convert TFLite REDUCE_ANY / REDUCE_ALL (bool-only ops). @@ -5045,83 +4880,263 @@ def convert_unpack(self, op): return squeezed - def convert_rnn(self, op): - """Convert TFLite RNN. - - Single-step RNN cell. - - Inputs (5 tensors): - [0] input [batch, input_size] - [1] input_weights [num_units, input_size] - [2] recurrent_weights [num_units, num_units] - [3] bias [num_units] - [4] hidden_state [batch, num_units] (variable, zero-initialised) + def convert_lstm(self, op): + """Convert TFLite LSTM (single-step). 
+ + Standard LSTM cell with FULL kernel and coupled input-forget gate. + Peephole, projection, and layer norm are not supported. + + Inputs (24 tensors, many optional): + [0] input [batch, input_size] + [1] input_to_input_weights (optional, -1 => coupled) + [2] input_to_forget_weights [num_units, input_size] + [3] input_to_cell_weights [num_units, input_size] + [4] input_to_output_weights [num_units, input_size] + [5] recurrent_to_input_weights (optional) + [6] recurrent_to_forget_weights [num_units, num_units] + [7] recurrent_to_cell_weights [num_units, num_units] + [8] recurrent_to_output_weights [num_units, num_units] + [9-11] cell_to_*_weights (optional, not supported) + [12] input_gate_bias (optional) + [13] forget_gate_bias [num_units] + [14] cell_bias [num_units] + [15] output_gate_bias [num_units] + [16-17] projection_weights/bias (optional, not supported) + [18] output_state [batch, num_units] + [19] cell_state [batch, num_units] + [20-23] layer_norm (optional, not supported) Output: [0] output [batch, num_units] - Cell equation: - h = fused_activation(x @ W.T + h @ Wr.T + b) + Cell (coupled input-forget): + f = sigmoid(x @ W_f.T + h @ R_f.T + b_f) + i = 1 - f + g = tanh(x @ W_c.T + h @ R_c.T + b_c) + o = sigmoid(x @ W_o.T + h @ R_o.T + b_o) + c_new = f * c_prev + i * g + h_new = o * fused_activation(c_new) """ from tflite.BuiltinOptions import BuiltinOptions - from tflite.RNNOptions import RNNOptions + from tflite.LSTMOptions import LSTMOptions if self.is_quantized(op): - raise tvm.error.OpNotImplemented("TFLite quantized RNN is not supported yet.") + raise tvm.error.OpNotImplemented("TFLite quantized LSTM is not supported yet.") input_tensors = self.get_input_tensors(op) - assert len(input_tensors) == 5, "input tensors length should be 5" - - input_tensor = input_tensors[0] - weights_tensor = input_tensors[1] - recurrent_tensor = input_tensors[2] - bias_tensor = input_tensors[3] - hidden_state_tensor = input_tensors[4] + assert len(input_tensors) 
== 24, ( + f"input tensors length should be 24, got {len(input_tensors)}" + ) output_tensors = self.get_output_tensors(op) assert len(output_tensors) >= 1, "output tensors length should be at least 1" - assert op.BuiltinOptionsType() == BuiltinOptions.RNNOptions + assert op.BuiltinOptionsType() == BuiltinOptions.LSTMOptions op_options = op.BuiltinOptions() - rnn_options = RNNOptions() - rnn_options.Init(op_options.Bytes, op_options.Pos) - fused_activation_fn = rnn_options.FusedActivationFunction() + lstm_opts = LSTMOptions() + lstm_opts.Init(op_options.Bytes, op_options.Pos) - # Constant weight/bias expressions. - weights_expr = self.get_tensor_expr(weights_tensor) # [num_units, input_size] - recurrent_expr = self.get_tensor_expr(recurrent_tensor) # [num_units, num_units] - bias_expr = self.get_tensor_expr(bias_tensor) # [num_units] + fused_activation_fn = lstm_opts.FusedActivationFunction() + cell_clip = lstm_opts.CellClip() + proj_clip = lstm_opts.ProjClip() - # Transpose to [input_size, num_units] and [num_units, num_units] for x @ W.T. - w_t = relax.op.permute_dims(weights_expr) - wr_t = relax.op.permute_dims(recurrent_expr) + in_expr = self.get_tensor_expr(input_tensors[0]) - # Resolve the input expression. - in_expr = self.get_tensor_expr(input_tensor) + # Only coupled input-forget gate is supported. + if input_tensors[1].tensor_idx != -1 or input_tensors[5].tensor_idx != -1: + raise tvm.error.OpNotImplemented("Only coupled input-forget LSTM is supported.") - # Initial hidden state: use the model's tensor value when available (non-zero init or - # graph input), otherwise fall back to zeros for the common variable-tensor case. - h_dtype = self.get_tensor_type_str(hidden_state_tensor.tensor.Type()) - if self.has_expr(hidden_state_tensor.tensor_idx) or ( - hidden_state_tensor.buffer is not None and hidden_state_tensor.buffer.DataLength() > 0 + # Peephole, projection, and layer norm are not modeled yet. 
+ if ( + any(t.tensor_idx != -1 for t in input_tensors[9:12]) + or any(t.tensor_idx != -1 for t in input_tensors[16:18]) + or any(t.tensor_idx != -1 for t in input_tensors[20:24]) ): - h = self.get_tensor_expr(hidden_state_tensor) - else: - h_shape = tuple(to_int_list(self.get_tensor_shape(hidden_state_tensor))) - h = relax.op.zeros(h_shape, dtype=h_dtype) + raise tvm.error.OpNotImplemented( + "Peephole, projection, and layer norm LSTM are not supported yet." + ) + + # Weights. + w_f = self.get_tensor_expr(input_tensors[2]) + w_c = self.get_tensor_expr(input_tensors[3]) + w_o = self.get_tensor_expr(input_tensors[4]) + + r_f = self.get_tensor_expr(input_tensors[6]) + r_c = self.get_tensor_expr(input_tensors[7]) + r_o = self.get_tensor_expr(input_tensors[8]) + + # Biases. + b_f = self.get_tensor_expr(input_tensors[13]) + b_c = self.get_tensor_expr(input_tensors[14]) + b_o = self.get_tensor_expr(input_tensors[15]) + + # State inputs. + h_prev = self.get_tensor_expr(input_tensors[18]) + c_prev = self.get_tensor_expr(input_tensors[19]) - gates = relax.op.add( - relax.op.add(relax.op.matmul(in_expr, w_t), relax.op.matmul(h, wr_t)), - bias_expr, + # Coupled input-forget gate. + f = relax.op.sigmoid( + relax.op.add( + relax.op.add( + relax.op.matmul(in_expr, relax.op.permute_dims(w_f)), + relax.op.matmul(h_prev, relax.op.permute_dims(r_f)), + ), + b_f, + ) + ) + i = relax.op.subtract( + relax.const(1.0, "float32"), + f, + ) + + # Cell candidate. + g = relax.op.tanh( + relax.op.add( + relax.op.add( + relax.op.matmul(in_expr, relax.op.permute_dims(w_c)), + relax.op.matmul(h_prev, relax.op.permute_dims(r_c)), + ), + b_c, + ) ) - h = self.convert_fused_activation_function(gates, fused_activation_fn) + # Output gate. + o = relax.op.sigmoid( + relax.op.add( + relax.op.add( + relax.op.matmul(in_expr, relax.op.permute_dims(w_o)), + relax.op.matmul(h_prev, relax.op.permute_dims(r_o)), + ), + b_o, + ) + ) + + # Cell state update with optional clipping. 
+ c_new = relax.op.add( + relax.op.multiply(f, c_prev), + relax.op.multiply(i, g), + ) + if cell_clip > 0: + c_new = relax.op.clip(c_new, -cell_clip, cell_clip) + + # Hidden state. + # TFLite applies the fused activation to the cell state before the + # output gate multiply. + h_new = relax.op.multiply( + o, self.convert_fused_activation_function(c_new, fused_activation_fn) + ) + if proj_clip > 0: + h_new = relax.op.clip(h_new, -proj_clip, proj_clip) + + # Update state tensors in the expression table for subsequent ops. + self.exp_tab.set_expr( + get_tensor_name(self.subgraph, input_tensors[18].tensor_idx), + h_new, + force_override=True, + ) + self.exp_tab.set_expr( + get_tensor_name(self.subgraph, input_tensors[19].tensor_idx), + c_new, + force_override=True, + ) + + return h_new + + def convert_svdf(self, op): + """Convert TFLite SVDF (single-step). + + Singular Value Decomposition Filter for keyword spotting. + + Inputs (5 tensors): + [0] input [batch, input_size] + [1] feature_weights [num_filters, input_size] + [2] time_weights [num_filters, memory_size] + [3] bias [num_filters] (optional) + [4] state [batch, num_filters * memory_size] (variable) + + Output: + [0] output [batch, num_units] + + Computation: + feat = x @ W_feat.T # feature projection + state_r = reshape(state, [B, F, memory_size]) # ring buffer + time = sum(state_r * time_weights, axis=-1) # time filtering + out = activation(sum(reshape(time, [B, U, rank]), axis=-1) + bias) + """ + from tflite.BuiltinOptions import BuiltinOptions + from tflite.SVDFOptions import SVDFOptions + + if self.is_quantized(op): + raise tvm.error.OpNotImplemented("TFLite quantized SVDF is not supported yet.") + + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 5, ( + f"input tensors length should be 5, got {len(input_tensors)}" + ) + + output_tensors = self.get_output_tensors(op) + assert len(output_tensors) >= 1, "output tensors length should be at least 1" + + assert 
op.BuiltinOptionsType() == BuiltinOptions.SVDFOptions + op_options = op.BuiltinOptions() + svdf_opts = SVDFOptions() + svdf_opts.Init(op_options.Bytes, op_options.Pos) + + rank = svdf_opts.Rank() + fused_activation_fn = svdf_opts.FusedActivationFunction() + + in_expr = self.get_tensor_expr(input_tensors[0]) + feat_weights = self.get_tensor_expr(input_tensors[1]) + time_weights = self.get_tensor_expr(input_tensors[2]) + + batch_size = self.get_tensor_shape(input_tensors[0])[0] + if isinstance(batch_size, np.integer | int): + batch_size = int(batch_size) + num_filters = to_int_list(self.get_tensor_shape(input_tensors[1]))[0] + if num_filters % rank != 0: + raise tvm.error.OpNotImplemented("SVDF num_filters must be divisible by rank.") + num_units = num_filters // rank + memory_size = to_int_list(self.get_tensor_shape(input_tensors[2]))[1] + + # Feature projection: [batch, input_size] @ [input_size, num_filters] + feat = relax.op.matmul(in_expr, relax.op.permute_dims(feat_weights)) + + # Time filtering: reshape state -> weight -> reduce. + state_expr = self.get_tensor_expr(input_tensors[4]) + state_3d = relax.op.reshape(state_expr, (batch_size, num_filters, memory_size)) + + # time_weights: [num_filters, memory_size], broadcast to [1, num_filters, memory_size] + tw_3d = relax.op.reshape(time_weights, (1, num_filters, memory_size)) + time_weighted = relax.op.multiply(state_3d, tw_3d) + time_output = relax.op.sum(time_weighted, axis=-1, keepdims=False) + reduced = relax.op.reshape(time_output, (batch_size, num_units, rank)) + result = relax.op.sum(reduced, axis=-1, keepdims=False) + + # Add bias if present + if input_tensors[3].tensor_idx != -1: + bias_expr = self.get_tensor_expr(input_tensors[3]) + result = relax.op.add(result, bias_expr) + + result = self.convert_fused_activation_function(result, fused_activation_fn) + + # Update state tensor in the expression table for subsequent steps. + # SVDF state is a FIFO ring-buffer: shift left by 1, append new feat. 
+ feat_3d = relax.op.expand_dims(feat, axis=-1) + if memory_size > 1: + shifted_state = relax.op.strided_slice( + state_3d, axes=[2], begin=[1], end=[int(memory_size)] + ) + new_state_3d = relax.op.concat([shifted_state, feat_3d], axis=2) + else: + new_state_3d = feat_3d + new_state = relax.op.reshape(new_state_3d, (batch_size, num_filters * memory_size)) self.exp_tab.set_expr( - get_tensor_name(self.subgraph, hidden_state_tensor.tensor_idx), - h, + get_tensor_name(self.subgraph, input_tensors[4].tensor_idx), + new_state, force_override=True, ) - return h + + return result def convert_unidirectional_sequence_rnn(self, op): """Convert TFLite UNIDIRECTIONAL_SEQUENCE_RNN. diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index 7c5951d631ea..263943ad6ae0 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -85,7 +85,7 @@ def verify(TestClass, expected=None): tf_output = cf(*tf_inputs) # TVM Run - tgt = tvm.target.Target("llvm") + tgt = tvm.target.Target("c") ex = tvm.compile(mod, tgt) vm = relax.VirtualMachine(ex, tvm.cpu()) vm.set_input("main", *tvm_inputs) @@ -110,7 +110,7 @@ def _verify_random_with_inputs(cfunc, inputs): tf_output = cfunc(*tf_inputs) - tgt = tvm.target.Target("llvm") + tgt = tvm.target.Target("c") ex = tvm.compile(mod, tgt) vm = relax.VirtualMachine(ex, tvm.cpu()) @@ -3705,7 +3705,6 @@ def _get_tflite_schema_enum(enum_name): _tfl_operator = _get_tflite_schema_module("Operator") _tfl_operator_code = _get_tflite_schema_module("OperatorCode") _tfl_quantization_parameters = _get_tflite_schema_module("QuantizationParameters") -_tfl_reduce_window_options = _get_tflite_schema_module("ReduceWindowOptions") _tfl_sparsity_parameters = _get_tflite_schema_module("SparsityParameters") _tfl_subgraph = _get_tflite_schema_module("SubGraph") _tfl_tensor = _get_tflite_schema_module("Tensor") @@ -3718,12 +3717,12 @@ def _get_tflite_schema_enum(enum_name): 
_tfl_dimension_type = _get_tflite_schema_enum("DimensionType") _tfl_fc_weights_format = _get_tflite_schema_enum("FullyConnectedOptionsWeightsFormat") _tfl_padding = _get_tflite_schema_enum("Padding") -_tfl_reduce_window_function = _get_tflite_schema_enum("ReduceWindowFunction") _tfl_sparse_index_vector = _get_tflite_schema_enum("SparseIndexVector") _tfl_tensor_type = _get_tflite_schema_enum("TensorType") -_tfl_rnn_options = _get_tflite_schema_module("RNNOptions") +_tfl_lstm_options = _get_tflite_schema_module("LSTMOptions") _tfl_sequence_rnn_options = _get_tflite_schema_module("SequenceRNNOptions") +_tfl_svdf_options = _get_tflite_schema_module("SVDFOptions") _DENSIFY_TEST_VALUES = np.array([1.0, 2.0], dtype=np.float32) _DENSIFY_TEST_DENSE = np.array([[1.0, 0.0], [0.0, 2.0]], dtype=np.float32) @@ -3954,410 +3953,6 @@ def _load_model_from_buffer(model_bytes): return mod -def _build_reduce_window_options(builder, reduce_function): - _tfl_reduce_window_options.ReduceWindowOptionsStart(builder) - _tfl_reduce_window_options.ReduceWindowOptionsAddReduceFunction(builder, reduce_function) - return _tfl_reduce_window_options.ReduceWindowOptionsEnd(builder) - - -def _reduce_window_output_shape(input_shape, window_shape, window_strides, window_dilations): - output_shape = [] - for input_dim, window_dim, stride, dilation in zip( - input_shape, window_shape, window_strides, window_dilations - ): - dilated_window = (window_dim - 1) * dilation + 1 - if stride <= 0: - output_shape.append(0) - elif input_dim < dilated_window: - output_shape.append(0) - else: - output_shape.append((input_dim - dilated_window) // stride + 1) - return tuple(output_shape) - - -def _build_reduce_window_model( - *, - input_shape, - init_value, - init_shape=(), - window_shape, - window_strides, - window_dilations, - output_shape=None, - reduce_function, - tensor_type=None, - value_dtype=np.float32, -): - builder = flatbuffers.Builder(1024) - if tensor_type is None: - tensor_type = _tfl_tensor_type.FLOAT32 
- - input_tensor_idx = 0 - init_tensor_idx = 1 - window_shape_tensor_idx = 2 - window_strides_tensor_idx = 3 - window_dilations_tensor_idx = 4 - output_tensor_idx = 5 - - if output_shape is None: - output_shape = _reduce_window_output_shape( - input_shape, window_shape, window_strides, window_dilations - ) - - input_tensor = _build_tensor(builder, 1, input_shape, tensor_type=tensor_type) - init_tensor = _build_tensor(builder, 2, init_shape, tensor_type=tensor_type) - window_shape_tensor = _build_tensor( - builder, 3, [len(window_shape)], tensor_type=_tfl_tensor_type.INT64 - ) - window_strides_tensor = _build_tensor( - builder, 4, [len(window_strides)], tensor_type=_tfl_tensor_type.INT64 - ) - window_dilations_tensor = _build_tensor( - builder, 5, [len(window_dilations)], tensor_type=_tfl_tensor_type.INT64 - ) - output_tensor = _build_tensor(builder, 6, output_shape, tensor_type=tensor_type) - - reduce_window_opts = _build_reduce_window_options(builder, reduce_function) - reduce_window_op = _build_operator( - builder, - 0, - [ - input_tensor_idx, - init_tensor_idx, - window_shape_tensor_idx, - window_strides_tensor_idx, - window_dilations_tensor_idx, - ], - [output_tensor_idx], - builtin_options2_type=_tfl_builtin_options2.ReduceWindowOptions, - builtin_options2=reduce_window_opts, - ) - - subgraph = _build_subgraph( - builder, - tensors=[ - input_tensor, - init_tensor, - window_shape_tensor, - window_strides_tensor, - window_dilations_tensor, - output_tensor, - ], - operators=[reduce_window_op], - inputs=[input_tensor_idx], - outputs=[output_tensor_idx], - ) - operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.REDUCE_WINDOW)] - - buffers = [ - _build_buffer(builder), - _build_buffer(builder), - _build_buffer(builder, np.asarray([init_value], dtype=value_dtype).tobytes()), - _build_buffer(builder, np.asarray(window_shape, dtype=np.int64).tobytes()), - _build_buffer(builder, np.asarray(window_strides, dtype=np.int64).tobytes()), - 
_build_buffer(builder, np.asarray(window_dilations, dtype=np.int64).tobytes()), - _build_buffer(builder), - ] - - return _finish_tflite_model( - builder, subgraph=subgraph, operator_codes=operator_codes, buffers=buffers - ) - - -def _from_reduce_window_model(**kwargs): - return _load_model_from_buffer(_build_reduce_window_model(**kwargs)) - - -def _reduce_window_dilated_shape(window_shape, window_dilations): - return [ - (window_dim - 1) * dilation + 1 - for window_dim, dilation in zip(window_shape, window_dilations) - ] - - -def _make_reduce_window_numeric_expected( - *, - input_shape, - init_value, - init_shape=(), - window_shape, - window_strides, - window_dilations, - reduce_op, - combine_op, - dtype="float32", -): - output_shape = _reduce_window_output_shape( - input_shape, window_shape, window_strides, window_dilations - ) - dilated_window_shape = _reduce_window_dilated_shape(window_shape, window_dilations) - rank = len(input_shape) - - bb = relax.BlockBuilder() - x = relax.Var("tvmgen_tensor_0", relax.TensorStructInfo(input_shape, dtype)) - with bb.function("main", [x]): - with bb.dataflow(): - windowed = bb.emit( - relax.op.call_dps_packed( - "topi.sliding_window", - ( - x, - 0, - relax.ShapeExpr(dilated_window_shape), - relax.ShapeExpr(window_strides), - ), - out_sinfo=relax.TensorStructInfo( - output_shape + tuple(dilated_window_shape), dtype - ), - ) - ) - if any(dilation != 1 for dilation in window_dilations): - windowed = bb.emit( - relax.op.strided_slice( - windowed, - axes=list(range(rank, 2 * rank)), - begin=[0] * rank, - end=dilated_window_shape, - strides=window_dilations, - ) - ) - reduced = bb.emit(reduce_op(windowed, axis=list(range(rank, 2 * rank)))) - init = relax.const(np.asarray([init_value], dtype=dtype).reshape(init_shape), dtype) - if len(init_shape) != 0: - init = relax.op.reshape(init, []) - gv = bb.emit_output(combine_op(reduced, init)) - bb.emit_func_output(gv) - - mod = bb.get() - mod["main"] = mod["main"].with_attr("num_input", 1) 
- return mod - - -def _make_reduce_window_bool_expected( - *, - input_shape, - init_value, - window_shape, - window_strides, - window_dilations, - reduce_op, - combine_op, -): - output_shape = _reduce_window_output_shape( - input_shape, window_shape, window_strides, window_dilations - ) - dilated_window_shape = _reduce_window_dilated_shape(window_shape, window_dilations) - rank = len(input_shape) - - bb = relax.BlockBuilder() - x = relax.Var("tvmgen_tensor_0", relax.TensorStructInfo(input_shape, "bool")) - with bb.function("main", [x]): - with bb.dataflow(): - windowed = bb.emit( - relax.op.call_dps_packed( - "topi.sliding_window", - ( - x, - 0, - relax.ShapeExpr(dilated_window_shape), - relax.ShapeExpr(window_strides), - ), - out_sinfo=relax.TensorStructInfo( - output_shape + tuple(dilated_window_shape), "bool" - ), - ) - ) - cast_windowed = bb.emit(relax.op.astype(windowed, "int8")) - reduced = bb.emit(reduce_op(cast_windowed, axis=list(range(rank, 2 * rank)))) - reduced_bool = bb.emit(relax.op.astype(reduced, "bool")) - gv = bb.emit_output(combine_op(reduced_bool, relax.const(init_value, "bool"))) - bb.emit_func_output(gv) - - mod = bb.get() - mod["main"] = mod["main"].with_attr("num_input", 1) - return mod - - -def _make_reduce_window_empty_expected(*, input_shape, output_shape, dtype="float32"): - bb = relax.BlockBuilder() - x = relax.Var("tvmgen_tensor_0", relax.TensorStructInfo(input_shape, dtype)) - with bb.function("main", [x]): - with bb.dataflow(): - gv = bb.emit_output(relax.op.zeros(output_shape, dtype)) - bb.emit_func_output(gv) - - mod = bb.get() - mod["main"] = mod["main"].with_attr("num_input", 1) - return mod - - -def test_reduce_window_unsupported_function(): - with pytest.raises(tvm.error.OpNotImplemented, match="UNSUPPORTED reduce_function"): - _from_reduce_window_model( - input_shape=(4,), - init_value=0.0, - window_shape=[2], - window_strides=[1], - window_dilations=[1], - reduce_function=_tfl_reduce_window_function.UNSUPPORTED, - ) - - 
-@pytest.mark.parametrize( - "reduce_function, reduce_op, combine_op", - [ - (_tfl_reduce_window_function.ADD, relax.op.sum, relax.op.add), - (_tfl_reduce_window_function.MUL, relax.op.prod, relax.op.multiply), - (_tfl_reduce_window_function.MINIMUM, relax.op.min, relax.op.minimum), - (_tfl_reduce_window_function.MAXIMUM, relax.op.max, relax.op.maximum), - ], -) -def test_reduce_window_numeric_modes(reduce_function, reduce_op, combine_op): - input_shape = (4, 5) - init_value = 1.0 - window_shape = [2, 2] - window_strides = [1, 2] - window_dilations = [2, 1] - mod = _from_reduce_window_model( - input_shape=input_shape, - init_value=init_value, - window_shape=window_shape, - window_strides=window_strides, - window_dilations=window_dilations, - reduce_function=reduce_function, - ) - expected = _make_reduce_window_numeric_expected( - input_shape=input_shape, - init_value=init_value, - window_shape=window_shape, - window_strides=window_strides, - window_dilations=window_dilations, - reduce_op=reduce_op, - combine_op=combine_op, - ) - tvm.ir.assert_structural_equal(mod, expected) - - -def test_reduce_window_one_element_init_tensor(): - input_shape = (4,) - init_value = 1.0 - init_shape = (1,) - window_shape = [2] - window_strides = [1] - window_dilations = [1] - mod = _from_reduce_window_model( - input_shape=input_shape, - init_value=init_value, - init_shape=init_shape, - window_shape=window_shape, - window_strides=window_strides, - window_dilations=window_dilations, - reduce_function=_tfl_reduce_window_function.ADD, - ) - expected = _make_reduce_window_numeric_expected( - input_shape=input_shape, - init_value=init_value, - init_shape=init_shape, - window_shape=window_shape, - window_strides=window_strides, - window_dilations=window_dilations, - reduce_op=relax.op.sum, - combine_op=relax.op.add, - ) - tvm.ir.assert_structural_equal(mod, expected) - - -@pytest.mark.parametrize( - "reduce_function, reduce_op, combine_op, init_value", - [ - (_tfl_reduce_window_function.ALL, 
relax.op.min, relax.op.logical_and, True), - (_tfl_reduce_window_function.ANY, relax.op.max, relax.op.logical_or, False), - ], -) -def test_reduce_window_bool_modes(reduce_function, reduce_op, combine_op, init_value): - input_shape = (5,) - window_shape = [3] - window_strides = [2] - window_dilations = [1] - mod = _from_reduce_window_model( - input_shape=input_shape, - init_value=init_value, - window_shape=window_shape, - window_strides=window_strides, - window_dilations=window_dilations, - reduce_function=reduce_function, - tensor_type=_tfl_tensor_type.BOOL, - value_dtype=np.bool_, - ) - expected = _make_reduce_window_bool_expected( - input_shape=input_shape, - init_value=init_value, - window_shape=window_shape, - window_strides=window_strides, - window_dilations=window_dilations, - reduce_op=reduce_op, - combine_op=combine_op, - ) - tvm.ir.assert_structural_equal(mod, expected) - - -def test_reduce_window_empty_output_dimension(): - input_shape = (2,) - window_shape = [3] - window_strides = [1] - window_dilations = [1] - mod = _from_reduce_window_model( - input_shape=input_shape, - init_value=0.0, - window_shape=window_shape, - window_strides=window_strides, - window_dilations=window_dilations, - reduce_function=_tfl_reduce_window_function.ADD, - ) - expected = _make_reduce_window_empty_expected( - input_shape=input_shape, - output_shape=(0,), - ) - tvm.ir.assert_structural_equal(mod, expected) - - -def test_reduce_window_mismatched_window_rank(): - with pytest.raises(tvm.error.OpAttributeUnImplemented, match="must match input rank"): - _from_reduce_window_model( - input_shape=(4, 5), - init_value=0.0, - window_shape=[2], - window_strides=[1], - window_dilations=[1], - reduce_function=_tfl_reduce_window_function.ADD, - ) - - -def test_reduce_window_non_positive_stride(): - with pytest.raises(tvm.error.OpAttributeUnImplemented, match="must be positive"): - _from_reduce_window_model( - input_shape=(4,), - init_value=0.0, - window_shape=[2], - window_strides=[0], - 
window_dilations=[1], - reduce_function=_tfl_reduce_window_function.ADD, - ) - - -def test_reduce_window_inconsistent_output_shape(): - with pytest.raises(tvm.error.OpAttributeUnImplemented, match="output shape"): - _from_reduce_window_model( - input_shape=(5,), - init_value=0.0, - window_shape=[2], - window_strides=[1], - window_dilations=[1], - output_shape=(3,), - reduce_function=_tfl_reduce_window_function.ADD, - ) - - def _get_builtin_operator(builtin_name): if not hasattr(_tfl_builtin_operator, builtin_name): pytest.skip(f"TFLite schema does not provide BuiltinOperator.{builtin_name}") @@ -10128,251 +9723,661 @@ def main( tvm.ir.assert_structural_equal(mod, Expected) -# ── RNN ──────────────────────────────────────────────────────────────────────── +# ── LSTM ────────────────────────────────────────────────────────────────────── -def _build_rnn_model(batch, input_size, num_units, weights, recurrent_weights, bias, activation): - """Build a minimal TFLite flatbuffer model containing one RNN op. - - Tensor layout (indices 0-5): - 0 - input [batch, input_size] - 1 - input_weights [num_units, input_size] (constant) - 2 - recurrent_weights [num_units, num_units] (constant) - 3 - bias [num_units] (constant) - 4 - hidden_state [batch, num_units] (variable, zero-initialised) - 5 - output [batch, num_units] +def _build_lstm_model( + batch, + input_size, + num_units, + input_to_forget_weights, + input_to_cell_weights, + input_to_output_weights, + recurrent_to_forget_weights, + recurrent_to_cell_weights, + recurrent_to_output_weights, + forget_gate_bias, + cell_bias, + output_gate_bias, + activation, + *, + cell_clip=0.0, + proj_clip=0.0, + include_unsupported=False, +): + """Build a minimal TFLite flatbuffer model with one LSTM op (coupled input-forget). 
+ + Tensor indices: + 0 - input [batch, input_size] + 1 - input_to_forget_weights [num_units, input_size] (constant) + 2 - input_to_cell_weights [num_units, input_size] (constant) + 3 - input_to_output_weights [num_units, input_size] (constant) + 4 - recurrent_to_forget_weights [num_units, num_units] (constant) + 5 - recurrent_to_cell_weights [num_units, num_units] (constant) + 6 - recurrent_to_output_weights [num_units, num_units] (constant) + 7 - forget_gate_bias [num_units] (constant) + 8 - cell_bias [num_units] (constant) + 9 - output_gate_bias [num_units] (constant) + 10 - output_state [batch, num_units] (input) + 11 - cell_state [batch, num_units] (input) + 12 - output [batch, num_units] + + Operator input indices (24 entries, -1 for absent): + [0, -1, 1, 2, 3, -1, 4, 5, 6, -1, -1, -1, -1, 7, 8, 9, -1, -1, 10, 11, -1, -1, -1, -1] """ builder = flatbuffers.Builder(4096) - _tfl_rnn_options.RNNOptionsStart(builder) - _tfl_rnn_options.RNNOptionsAddFusedActivationFunction(builder, activation) - rnn_opts = _tfl_rnn_options.RNNOptionsEnd(builder) + _tfl_lstm_options.LSTMOptionsStart(builder) + _tfl_lstm_options.LSTMOptionsAddFusedActivationFunction(builder, activation) + _tfl_lstm_options.LSTMOptionsAddCellClip(builder, cell_clip) + _tfl_lstm_options.LSTMOptionsAddProjClip(builder, proj_clip) + lstm_opts = _tfl_lstm_options.LSTMOptionsEnd(builder) - rnn_op_code = _build_operator_code(builder, _tfl_builtin_operator.RNN) + lstm_op_code = _build_operator_code(builder, _tfl_builtin_operator.LSTM) - def _t(buf_idx, shape, is_variable=False): + def _t(buf_idx, shape): shape_vec = _tflite_shape(builder, shape) _tfl_tensor.TensorStart(builder) _tfl_tensor.TensorAddBuffer(builder, buf_idx) _tfl_tensor.TensorAddHasRank(builder, True) - _tfl_tensor.TensorAddIsVariable(builder, is_variable) + _tfl_tensor.TensorAddIsVariable(builder, False) _tfl_tensor.TensorAddShape(builder, shape_vec) _tfl_tensor.TensorAddType(builder, _tfl_tensor_type.FLOAT32) return 
_tfl_tensor.TensorEnd(builder) tensors = [ + # 0: input _t(0, [batch, input_size]), + # 1: input_to_forget_weights (coupled) _t(1, [num_units, input_size]), - _t(2, [num_units, num_units]), - _t(3, [num_units]), - _t(4, [batch, num_units], is_variable=True), - _t(5, [batch, num_units]), + # 2: input_to_cell_weights + _t(2, [num_units, input_size]), + # 3: input_to_output_weights + _t(3, [num_units, input_size]), + # 4: recurrent_to_forget_weights (coupled) + _t(4, [num_units, num_units]), + # 5: recurrent_to_cell_weights + _t(5, [num_units, num_units]), + # 6: recurrent_to_output_weights + _t(6, [num_units, num_units]), + # 7: forget_gate_bias (coupled) + _t(7, [num_units]), + # 8: cell_bias + _t(8, [num_units]), + # 9: output_gate_bias + _t(9, [num_units]), + # 10: output_state (input) + _t(0, [batch, num_units]), + # 11: cell_state (input) + _t(0, [batch, num_units]), + # 12: output + _t(0, [batch, num_units]), ] - rnn_op = _build_operator( + if include_unsupported: + tensors.extend( + [ + _t(0, [num_units]), + _t(0, [num_units]), + _t(0, [num_units]), + _t(0, [num_units, num_units]), + _t(0, [num_units]), + _t(0, [num_units]), + _t(0, [num_units]), + _t(0, [num_units]), + _t(0, [num_units]), + ] + ) + + # Operator input indices: -1 for absent optional inputs + lstm_inputs = [ + 0, + -1, + 1, + 2, + 3, + -1, + 4, + 5, + 6, + 13 if include_unsupported else -1, + 14 if include_unsupported else -1, + 15 if include_unsupported else -1, + -1, + 7, + 8, + 9, + 16 if include_unsupported else -1, + 17 if include_unsupported else -1, + 10, + 11, + 18 if include_unsupported else -1, + 19 if include_unsupported else -1, + 20 if include_unsupported else -1, + 21 if include_unsupported else -1, + ] + + lstm_op = _build_operator( builder, 0, - [0, 1, 2, 3, 4], - [5], - builtin_options_type=_tfl_builtin_options.RNNOptions, - builtin_options=rnn_opts, + lstm_inputs, + [12], + builtin_options_type=_tfl_builtin_options.LSTMOptions, + builtin_options=lstm_opts, ) subgraph = 
_build_subgraph( builder, tensors=tensors, - operators=[rnn_op], - inputs=[0], - outputs=[5], + operators=[lstm_op], + inputs=[0, 10, 11], + outputs=[12], ) buffers = [ - _build_buffer(builder), - _build_buffer(builder, weights.tobytes()), - _build_buffer(builder, recurrent_weights.tobytes()), - _build_buffer(builder, bias.tobytes()), - _build_buffer(builder), - _build_buffer(builder), + _build_buffer(builder), # 0: empty + _build_buffer(builder, input_to_forget_weights.tobytes()), # 1 + _build_buffer(builder, input_to_cell_weights.tobytes()), # 2 + _build_buffer(builder, input_to_output_weights.tobytes()), # 3 + _build_buffer(builder, recurrent_to_forget_weights.tobytes()), # 4 + _build_buffer(builder, recurrent_to_cell_weights.tobytes()), # 5 + _build_buffer(builder, recurrent_to_output_weights.tobytes()), # 6 + _build_buffer(builder, forget_gate_bias.tobytes()), # 7 + _build_buffer(builder, cell_bias.tobytes()), # 8 + _build_buffer(builder, output_gate_bias.tobytes()), # 9 ] + if include_unsupported: + buffers.extend([_build_buffer(builder) for _ in range(9)]) + return _finish_tflite_model( builder, subgraph=subgraph, - operator_codes=[rnn_op_code], + operator_codes=[lstm_op_code], buffers=buffers, ) -def _build_two_step_shared_state_rnn_model( - batch, input_size, num_units, weights, recurrent_weights, bias, activation +def test_lstm_none_activation(): + """LSTM with NONE activation uses the cell state before the output gate multiply.""" + from tflite.ActivationFunctionType import ActivationFunctionType + + batch, input_size, num_units = 2, 2, 2 + w_f = np.eye(num_units, input_size, dtype=np.float32) + w_c = np.eye(num_units, input_size, dtype=np.float32) + w_o = np.eye(num_units, input_size, dtype=np.float32) + r_f = np.eye(num_units, dtype=np.float32) + r_c = np.eye(num_units, dtype=np.float32) + r_o = np.eye(num_units, dtype=np.float32) + b_f = np.zeros(num_units, dtype=np.float32) + b_c = np.zeros(num_units, dtype=np.float32) + b_o = np.zeros(num_units, 
dtype=np.float32) + + mod = _load_model_from_buffer( + _build_lstm_model( + batch, + input_size, + num_units, + w_f, + w_c, + w_o, + r_f, + r_c, + r_o, + b_f, + b_c, + b_o, + ActivationFunctionType.NONE, + ) + ) + + @I.ir_module + class Expected: + @R.function + def main( + tvmgen_tensor_0: R.Tensor((2, 2), dtype="float32"), + tvmgen_tensor_10: R.Tensor((2, 2), dtype="float32"), + tvmgen_tensor_11: R.Tensor((2, 2), dtype="float32"), + ) -> R.Tensor((2, 2), dtype="float32"): + R.func_attr({"num_input": 3}) + with R.dataflow(): + lv: R.Tensor((2, 2), dtype="float32") = R.permute_dims( + R.const(np.eye(2, dtype=np.float32)), axes=None + ) + lv1: R.Tensor((2, 2), dtype="float32") = R.matmul( + tvmgen_tensor_0, lv, out_dtype="void" + ) + lv2: R.Tensor((2, 2), dtype="float32") = R.permute_dims( + R.const(np.eye(2, dtype=np.float32)), axes=None + ) + lv3: R.Tensor((2, 2), dtype="float32") = R.matmul( + tvmgen_tensor_10, lv2, out_dtype="void" + ) + lv4: R.Tensor((2, 2), dtype="float32") = R.add(lv1, lv3) + lv5: R.Tensor((2, 2), dtype="float32") = R.add( + lv4, R.const(np.zeros(2, dtype=np.float32)) + ) + lv6: R.Tensor((2, 2), dtype="float32") = R.sigmoid(lv5) + lv7: R.Tensor((2, 2), dtype="float32") = R.permute_dims( + R.const(np.eye(2, dtype=np.float32)), axes=None + ) + lv8: R.Tensor((2, 2), dtype="float32") = R.matmul( + tvmgen_tensor_0, lv7, out_dtype="void" + ) + lv9: R.Tensor((2, 2), dtype="float32") = R.permute_dims( + R.const(np.eye(2, dtype=np.float32)), axes=None + ) + lv10: R.Tensor((2, 2), dtype="float32") = R.matmul( + tvmgen_tensor_10, lv9, out_dtype="void" + ) + lv11: R.Tensor((2, 2), dtype="float32") = R.add(lv8, lv10) + lv12: R.Tensor((2, 2), dtype="float32") = R.add( + lv11, R.const(np.zeros(2, dtype=np.float32)) + ) + lv13: R.Tensor((2, 2), dtype="float32") = R.sigmoid(lv12) + lv14: R.Tensor((2, 2), dtype="float32") = R.multiply(lv13, tvmgen_tensor_11) + lv15: R.Tensor((2, 2), dtype="float32") = R.subtract(R.const(1.0, "float32"), lv13) + lv16: 
R.Tensor((2, 2), dtype="float32") = R.permute_dims( + R.const(np.eye(2, dtype=np.float32)), axes=None + ) + lv17: R.Tensor((2, 2), dtype="float32") = R.matmul( + tvmgen_tensor_0, lv16, out_dtype="void" + ) + lv18: R.Tensor((2, 2), dtype="float32") = R.permute_dims( + R.const(np.eye(2, dtype=np.float32)), axes=None + ) + lv19: R.Tensor((2, 2), dtype="float32") = R.matmul( + tvmgen_tensor_10, lv18, out_dtype="void" + ) + lv20: R.Tensor((2, 2), dtype="float32") = R.add(lv17, lv19) + lv21: R.Tensor((2, 2), dtype="float32") = R.add( + lv20, R.const(np.zeros(2, dtype=np.float32)) + ) + lv22: R.Tensor((2, 2), dtype="float32") = R.tanh(lv21) + lv23: R.Tensor((2, 2), dtype="float32") = R.multiply(lv15, lv22) + lv24: R.Tensor((2, 2), dtype="float32") = R.add(lv14, lv23) + gv: R.Tensor((2, 2), dtype="float32") = R.multiply(lv6, lv24) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_lstm_tanh_activation(): + """LSTM with TANH activation applies tanh before the output gate multiply.""" + from tflite.ActivationFunctionType import ActivationFunctionType + + batch, input_size, num_units = 2, 2, 2 + w_f = np.eye(num_units, input_size, dtype=np.float32) + w_c = np.eye(num_units, input_size, dtype=np.float32) + w_o = np.eye(num_units, input_size, dtype=np.float32) + r_f = np.eye(num_units, dtype=np.float32) + r_c = np.eye(num_units, dtype=np.float32) + r_o = np.eye(num_units, dtype=np.float32) + b_f = np.zeros(num_units, dtype=np.float32) + b_c = np.zeros(num_units, dtype=np.float32) + b_o = np.zeros(num_units, dtype=np.float32) + + mod = _load_model_from_buffer( + _build_lstm_model( + batch, + input_size, + num_units, + w_f, + w_c, + w_o, + r_f, + r_c, + r_o, + b_f, + b_c, + b_o, + ActivationFunctionType.TANH, + ) + ) + + @I.ir_module + class Expected: + @R.function + def main( + tvmgen_tensor_0: R.Tensor((2, 2), dtype="float32"), + tvmgen_tensor_10: R.Tensor((2, 2), dtype="float32"), + tvmgen_tensor_11: R.Tensor((2, 2), dtype="float32"), + 
) -> R.Tensor((2, 2), dtype="float32"): + R.func_attr({"num_input": 3}) + with R.dataflow(): + lv: R.Tensor((2, 2), dtype="float32") = R.permute_dims( + R.const(np.eye(2, dtype=np.float32)), axes=None + ) + lv1: R.Tensor((2, 2), dtype="float32") = R.matmul( + tvmgen_tensor_0, lv, out_dtype="void" + ) + lv2: R.Tensor((2, 2), dtype="float32") = R.permute_dims( + R.const(np.eye(2, dtype=np.float32)), axes=None + ) + lv3: R.Tensor((2, 2), dtype="float32") = R.matmul( + tvmgen_tensor_10, lv2, out_dtype="void" + ) + lv4: R.Tensor((2, 2), dtype="float32") = R.add(lv1, lv3) + lv5: R.Tensor((2, 2), dtype="float32") = R.add( + lv4, R.const(np.zeros(2, dtype=np.float32)) + ) + lv6: R.Tensor((2, 2), dtype="float32") = R.sigmoid(lv5) + lv7: R.Tensor((2, 2), dtype="float32") = R.permute_dims( + R.const(np.eye(2, dtype=np.float32)), axes=None + ) + lv8: R.Tensor((2, 2), dtype="float32") = R.matmul( + tvmgen_tensor_0, lv7, out_dtype="void" + ) + lv9: R.Tensor((2, 2), dtype="float32") = R.permute_dims( + R.const(np.eye(2, dtype=np.float32)), axes=None + ) + lv10: R.Tensor((2, 2), dtype="float32") = R.matmul( + tvmgen_tensor_10, lv9, out_dtype="void" + ) + lv11: R.Tensor((2, 2), dtype="float32") = R.add(lv8, lv10) + lv12: R.Tensor((2, 2), dtype="float32") = R.add( + lv11, R.const(np.zeros(2, dtype=np.float32)) + ) + lv13: R.Tensor((2, 2), dtype="float32") = R.sigmoid(lv12) + lv14: R.Tensor((2, 2), dtype="float32") = R.multiply(lv13, tvmgen_tensor_11) + lv15: R.Tensor((2, 2), dtype="float32") = R.subtract(R.const(1.0, "float32"), lv13) + lv16: R.Tensor((2, 2), dtype="float32") = R.permute_dims( + R.const(np.eye(2, dtype=np.float32)), axes=None + ) + lv17: R.Tensor((2, 2), dtype="float32") = R.matmul( + tvmgen_tensor_0, lv16, out_dtype="void" + ) + lv18: R.Tensor((2, 2), dtype="float32") = R.permute_dims( + R.const(np.eye(2, dtype=np.float32)), axes=None + ) + lv19: R.Tensor((2, 2), dtype="float32") = R.matmul( + tvmgen_tensor_10, lv18, out_dtype="void" + ) + lv20: R.Tensor((2, 2), 
dtype="float32") = R.add(lv17, lv19) + lv21: R.Tensor((2, 2), dtype="float32") = R.add( + lv20, R.const(np.zeros(2, dtype=np.float32)) + ) + lv22: R.Tensor((2, 2), dtype="float32") = R.tanh(lv21) + lv23: R.Tensor((2, 2), dtype="float32") = R.multiply(lv15, lv22) + lv24: R.Tensor((2, 2), dtype="float32") = R.add(lv14, lv23) + lv25: R.Tensor((2, 2), dtype="float32") = R.tanh(lv24) + gv: R.Tensor((2, 2), dtype="float32") = R.multiply(lv6, lv25) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_lstm_rejects_unsupported_features(): + """LSTM with peephole/projection/layer norm tensors should be rejected.""" + from tflite.ActivationFunctionType import ActivationFunctionType + + batch, input_size, num_units = 2, 2, 2 + zeros_w = np.zeros((num_units, input_size), dtype=np.float32) + zeros_r = np.zeros((num_units, num_units), dtype=np.float32) + zeros_b = np.zeros(num_units, dtype=np.float32) + + with pytest.raises(tvm.error.OpNotImplemented, match="not supported yet"): + _load_model_from_buffer( + _build_lstm_model( + batch, + input_size, + num_units, + zeros_w, + zeros_w, + zeros_w, + zeros_r, + zeros_r, + zeros_r, + zeros_b, + zeros_b, + zeros_b, + ActivationFunctionType.NONE, + include_unsupported=True, + ) + ) + + +# ── SVDF ────────────────────────────────────────────────────────────────────── + + +def _build_svdf_model( + batch, + input_size, + num_units, + rank, + memory_size, + num_filters, + feat_weights, + time_weights, + bias, + activation, ): - """Build a TFLite model with two RNN ops sharing the same hidden-state tensor.""" + """Build a minimal TFLite flatbuffer model containing one SVDF op. 
+ + Tensor indices: + 0 - input [batch, input_size] (model input) + 1 - feature_weights [num_filters, input_size] (constant) + 2 - time_weights [num_filters, memory_size] (constant) + 3 - bias [num_units] (constant) + 4 - state [batch, num_filters * memory_size] (variable, model input) + 5 - output [batch, num_units] + """ builder = flatbuffers.Builder(4096) - _tfl_rnn_options.RNNOptionsStart(builder) - _tfl_rnn_options.RNNOptionsAddFusedActivationFunction(builder, activation) - rnn_opts = _tfl_rnn_options.RNNOptionsEnd(builder) + _tfl_svdf_options.SVDFOptionsStart(builder) + _tfl_svdf_options.SVDFOptionsAddRank(builder, rank) + _tfl_svdf_options.SVDFOptionsAddFusedActivationFunction(builder, activation) + svdf_opts = _tfl_svdf_options.SVDFOptionsEnd(builder) - rnn_op_code = _build_operator_code(builder, _tfl_builtin_operator.RNN) + svdf_op_code = _build_operator_code(builder, _tfl_builtin_operator.SVDF) - def _t(buf_idx, shape, is_variable=False): + def _t(buf_idx, shape): shape_vec = _tflite_shape(builder, shape) _tfl_tensor.TensorStart(builder) _tfl_tensor.TensorAddBuffer(builder, buf_idx) _tfl_tensor.TensorAddHasRank(builder, True) - _tfl_tensor.TensorAddIsVariable(builder, is_variable) + _tfl_tensor.TensorAddIsVariable(builder, False) _tfl_tensor.TensorAddShape(builder, shape_vec) _tfl_tensor.TensorAddType(builder, _tfl_tensor_type.FLOAT32) return _tfl_tensor.TensorEnd(builder) tensors = [ - _t(0, [batch, input_size]), - _t(1, [num_units, input_size]), - _t(2, [num_units, num_units]), - _t(3, [num_units]), - _t(4, [batch, num_units], is_variable=True), - _t(0, [batch, input_size]), - _t(0, [batch, num_units]), - _t(0, [batch, num_units]), + _t(0, [batch, input_size]), # 0: input + _t(1, [num_filters, input_size]), # 1: feature_weights + _t(2, [num_filters, memory_size]), # 2: time_weights + _t(3, [num_units]), # 3: bias + _t(0, [batch, num_filters * memory_size]), # 4: state (variable, zero-filled) + _t(0, [batch, num_units]), # 5: output ] - first_rnn_op = 
_build_operator( + svdf_op = _build_operator( builder, 0, [0, 1, 2, 3, 4], - [6], - builtin_options_type=_tfl_builtin_options.RNNOptions, - builtin_options=rnn_opts, - ) - second_rnn_op = _build_operator( - builder, - 0, - [5, 1, 2, 3, 4], - [7], - builtin_options_type=_tfl_builtin_options.RNNOptions, - builtin_options=rnn_opts, + [5], + builtin_options_type=_tfl_builtin_options.SVDFOptions, + builtin_options=svdf_opts, ) subgraph = _build_subgraph( builder, tensors=tensors, - operators=[first_rnn_op, second_rnn_op], - inputs=[0, 5], - outputs=[7], + operators=[svdf_op], + inputs=[0, 4], + outputs=[5], ) buffers = [ - _build_buffer(builder), - _build_buffer(builder, weights.tobytes()), - _build_buffer(builder, recurrent_weights.tobytes()), - _build_buffer(builder, bias.tobytes()), - _build_buffer(builder), + _build_buffer(builder), # 0: empty + _build_buffer(builder, feat_weights.tobytes()), # 1 + _build_buffer(builder, time_weights.tobytes()), # 2 + _build_buffer(builder, bias.tobytes()), # 3 ] return _finish_tflite_model( builder, subgraph=subgraph, - operator_codes=[rnn_op_code], + operator_codes=[svdf_op_code], buffers=buffers, ) -def test_rnn_none_activation(): - """RNN with NONE activation lowers to matmul/add. 
- - Cell equation: h = x @ W.T + h @ Wr.T + b (no activation for NONE) - """ +def test_svdf_none_activation(): + """SVDF with NONE activation, verifying output shape and params.""" from tflite.ActivationFunctionType import ActivationFunctionType - batch, input_size, num_units = 2, 2, 2 - weights = np.eye(num_units, input_size, dtype=np.float32) - recurrent_weights = np.eye(num_units, dtype=np.float32) + batch, input_size, num_units, rank, memory_size = 2, 3, 2, 2, 3 + num_filters = num_units * rank + np.random.seed(42) + feat_weights = np.random.randn(num_filters, input_size).astype(np.float32) + time_weights = np.random.randn(num_filters, memory_size).astype(np.float32) bias = np.zeros(num_units, dtype=np.float32) mod = _load_model_from_buffer( - _build_rnn_model( + _build_svdf_model( batch, input_size, num_units, - weights, - recurrent_weights, + rank, + memory_size, + num_filters, + feat_weights, + time_weights, bias, ActivationFunctionType.NONE, ) ) - @I.ir_module - class Expected: - @R.function - def main(x: R.Tensor((2, 2), dtype="float32")) -> R.Tensor((2, 2), dtype="float32"): - R.func_attr({"num_input": 1}) - with R.dataflow(): - lv: R.Tensor((2, 2), dtype="float32") = R.permute_dims( - R.const(np.eye(2, dtype=np.float32)), axes=None - ) - lv1: R.Tensor((2, 2), dtype="float32") = R.matmul(x, lv, out_dtype="void") - lv2: R.Tensor((2, 2), dtype="float32") = R.zeros(R.shape([2, 2]), dtype="float32") - lv3: R.Tensor((2, 2), dtype="float32") = R.permute_dims( - R.const(np.eye(2, dtype=np.float32)), axes=None - ) - lv4: R.Tensor((2, 2), dtype="float32") = R.matmul(lv2, lv3, out_dtype="void") - lv5: R.Tensor((2, 2), dtype="float32") = R.add(lv1, lv4) - gv: R.Tensor((2, 2), dtype="float32") = R.add( - lv5, R.const(np.zeros(2, dtype=np.float32)) - ) - R.output(gv) - return gv + fn = mod["main"] + assert len(fn.params) == 2, f"expected 2 params (input, state), got {len(fn.params)}" + in_shape = fn.params[0].struct_info.shape + assert tuple(int(d) for d in in_shape) 
== (batch, input_size) + state_shape = fn.params[1].struct_info.shape + assert tuple(int(d) for d in state_shape) == (batch, num_filters * memory_size) + out_shape = fn.ret_struct_info.shape + assert tuple(int(d) for d in out_shape) == (batch, num_units) - tvm.ir.assert_structural_equal(mod, Expected) +def _build_two_step_shared_state_svdf_model( + batch, + input_size, + num_units, + rank, + memory_size, + feat_weights_0, + time_weights_0, + bias_0, + feat_weights_1, + time_weights_1, + bias_1, + activation, +): + """Build two consecutive SVDF ops sharing a single state tensor.""" + builder = flatbuffers.Builder(4096) + num_filters = num_units * rank + + _tfl_svdf_options.SVDFOptionsStart(builder) + _tfl_svdf_options.SVDFOptionsAddRank(builder, rank) + _tfl_svdf_options.SVDFOptionsAddFusedActivationFunction(builder, activation) + svdf_opts = _tfl_svdf_options.SVDFOptionsEnd(builder) -def test_rnn_relu_activation(): - """RNN with RELU activation and random weights.""" - from tflite.ActivationFunctionType import ActivationFunctionType + svdf_op_code = _build_operator_code(builder, _tfl_builtin_operator.SVDF) - batch, input_size, num_units = 2, 4, 8 - np.random.seed(42) - weights = np.random.randn(num_units, input_size).astype(np.float32) - recurrent_weights = np.random.randn(num_units, num_units).astype(np.float32) - bias = np.random.randn(num_units).astype(np.float32) + def _t(buf_idx, shape): + shape_vec = _tflite_shape(builder, shape) + _tfl_tensor.TensorStart(builder) + _tfl_tensor.TensorAddBuffer(builder, buf_idx) + _tfl_tensor.TensorAddHasRank(builder, True) + _tfl_tensor.TensorAddIsVariable(builder, False) + _tfl_tensor.TensorAddShape(builder, shape_vec) + _tfl_tensor.TensorAddType(builder, _tfl_tensor_type.FLOAT32) + return _tfl_tensor.TensorEnd(builder) - mod = _load_model_from_buffer( - _build_rnn_model( - batch, - input_size, - num_units, - weights, - recurrent_weights, - bias, - ActivationFunctionType.RELU, - ) + tensors = [ + _t(0, [batch, input_size]), 
# 0 input_0 + _t(1, [num_filters, input_size]), # 1 feat_weights_0 + _t(2, [num_filters, memory_size]), # 2 time_weights_0 + _t(3, [num_units]), # 3 bias_0 + _t(0, [batch, num_filters * memory_size]), # 4 shared state + _t(0, [batch, num_units]), # 5 output_0 + _t(0, [batch, input_size]), # 6 input_1 + _t(4, [num_filters, input_size]), # 7 feat_weights_1 + _t(5, [num_filters, memory_size]), # 8 time_weights_1 + _t(6, [num_units]), # 9 bias_1 + _t(0, [batch, num_units]), # 10 output_1 + ] + + svdf_op_0 = _build_operator( + builder, + 0, + [0, 1, 2, 3, 4], + [5], + builtin_options_type=_tfl_builtin_options.SVDFOptions, + builtin_options=svdf_opts, + ) + svdf_op_1 = _build_operator( + builder, + 0, + [6, 7, 8, 9, 4], + [10], + builtin_options_type=_tfl_builtin_options.SVDFOptions, + builtin_options=svdf_opts, ) - fn = mod["main"] - assert len(fn.params) == 1, "only the input should be a graph input" - in_shape = fn.params[0].struct_info.shape - assert tuple(int(d) for d in in_shape) == (batch, input_size) - out_shape = fn.ret_struct_info.shape - assert tuple(int(d) for d in out_shape) == (batch, num_units) + subgraph = _build_subgraph( + builder, + tensors=tensors, + operators=[svdf_op_0, svdf_op_1], + inputs=[0, 6, 4], + outputs=[10], + ) + + buffers = [ + _build_buffer(builder), + _build_buffer(builder, feat_weights_0.tobytes()), + _build_buffer(builder, time_weights_0.tobytes()), + _build_buffer(builder, bias_0.tobytes()), + _build_buffer(builder, feat_weights_1.tobytes()), + _build_buffer(builder, time_weights_1.tobytes()), + _build_buffer(builder, bias_1.tobytes()), + ] + + return _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=[svdf_op_code], + buffers=buffers, + ) -def test_rnn_shared_hidden_state_updates_exp_tab(): - """Two consecutive RNN ops sharing hidden_state should use the updated state.""" +def test_svdf_shared_state_updates_exp_tab(): + """Two SVDF ops sharing state should use the updated FIFO state in the second step.""" from 
tflite.ActivationFunctionType import ActivationFunctionType - batch, input_size, num_units = 2, 2, 2 - weights = np.eye(num_units, input_size, dtype=np.float32) - recurrent_weights = np.eye(num_units, dtype=np.float32) - bias = np.zeros(num_units, dtype=np.float32) + batch, input_size, num_units, rank, memory_size = 1, 1, 1, 2, 3 + feat_weights_0 = np.array([[1.0], [2.0]], dtype=np.float32) + time_weights_0 = np.array([[1.0, 3.0, 5.0], [2.0, 4.0, 6.0]], dtype=np.float32) + bias_0 = np.zeros(num_units, dtype=np.float32) + + feat_weights_1 = np.array([[7.0], [11.0]], dtype=np.float32) + time_weights_1 = np.array([[13.0, 17.0, 19.0], [23.0, 29.0, 31.0]], dtype=np.float32) + bias_1 = np.zeros(num_units, dtype=np.float32) mod = _load_model_from_buffer( - _build_two_step_shared_state_rnn_model( + _build_two_step_shared_state_svdf_model( batch, input_size, num_units, - weights, - recurrent_weights, - bias, + rank, + memory_size, + feat_weights_0, + time_weights_0, + bias_0, + feat_weights_1, + time_weights_1, + bias_1, ActivationFunctionType.NONE, ) ) @@ -10381,35 +10386,54 @@ def test_rnn_shared_hidden_state_updates_exp_tab(): class Expected: @R.function def main( - x0: R.Tensor((2, 2), dtype="float32"), - x1: R.Tensor((2, 2), dtype="float32"), - ) -> R.Tensor((2, 2), dtype="float32"): - R.func_attr({"num_input": 2}) + tvmgen_tensor_0: R.Tensor((1, 1), dtype="float32"), + tvmgen_tensor_6: R.Tensor((1, 1), dtype="float32"), + tvmgen_tensor_4: R.Tensor((1, 6), dtype="float32"), + ) -> R.Tensor((1, 1), dtype="float32"): + R.func_attr({"num_input": 3}) with R.dataflow(): - lv: R.Tensor((2, 2), dtype="float32") = R.permute_dims( - R.const(np.eye(2, dtype=np.float32)), axes=None + lv: R.Tensor((1, 2, 3), dtype="float32") = R.reshape( + tvmgen_tensor_4, R.shape([1, 2, 3]) ) - lv1: R.Tensor((2, 2), dtype="float32") = R.matmul(x0, lv, out_dtype="void") - lv2: R.Tensor((2, 2), dtype="float32") = R.zeros(R.shape([2, 2]), dtype="float32") - lv3: R.Tensor((2, 2), dtype="float32") = 
R.permute_dims( - R.const(np.eye(2, dtype=np.float32)), axes=None + lv1: R.Tensor((1, 2, 3), dtype="float32") = R.reshape( + R.const(np.array([[1.0, 3.0, 5.0], [2.0, 4.0, 6.0]], dtype=np.float32)), + R.shape([1, 2, 3]), ) - lv4: R.Tensor((2, 2), dtype="float32") = R.matmul(lv2, lv3, out_dtype="void") - lv5: R.Tensor((2, 2), dtype="float32") = R.add(lv1, lv4) - lv6: R.Tensor((2, 2), dtype="float32") = R.permute_dims( - R.const(np.eye(2, dtype=np.float32)), axes=None + lv2: R.Tensor((1, 2, 3), dtype="float32") = R.multiply(lv, lv1) + lv3: R.Tensor((1, 2), dtype="float32") = R.sum(lv2, axis=[-1], keepdims=False) + lv4: R.Tensor((1, 1, 2), dtype="float32") = R.reshape(lv3, R.shape([1, 1, 2])) + lv5: R.Tensor((1, 1), dtype="float32") = R.sum( # noqa: F841 + lv4, axis=[-1], keepdims=False ) - lv7: R.Tensor((2, 2), dtype="float32") = R.matmul(x1, lv6, out_dtype="void") - lv8: R.Tensor((2, 2), dtype="float32") = R.add( - lv5, R.const(np.zeros(2, dtype=np.float32)) + lv6: R.Tensor((1, 2, 2), dtype="float32") = R.strided_slice( + lv, + (R.prim_value(2),), + (R.prim_value(1),), + (R.prim_value(3),), + assume_inbound=False, ) - lv9: R.Tensor((2, 2), dtype="float32") = R.permute_dims( - R.const(np.eye(2, dtype=np.float32)), axes=None + lv7: R.Tensor((1, 2), dtype="float32") = R.permute_dims( + R.const(np.array([[1.0], [2.0]], dtype=np.float32)), axes=None ) - lv10: R.Tensor((2, 2), dtype="float32") = R.matmul(lv8, lv9, out_dtype="void") - lv11: R.Tensor((2, 2), dtype="float32") = R.add(lv7, lv10) - gv: R.Tensor((2, 2), dtype="float32") = R.add( - lv11, R.const(np.zeros(2, dtype=np.float32)) + lv8: R.Tensor((1, 2), dtype="float32") = R.matmul( + tvmgen_tensor_0, + lv7, + out_dtype="void", + ) + lv9: R.Tensor((1, 2, 1), dtype="float32") = R.expand_dims(lv8, axis=[-1]) + lv10: R.Tensor((1, 2, 3), dtype="float32") = R.concat((lv6, lv9), axis=2) + lv11: R.Tensor((1, 6), dtype="float32") = R.reshape(lv10, R.shape([1, 6])) + lv12: R.Tensor((1, 2, 3), dtype="float32") = R.reshape(lv11, 
R.shape([1, 2, 3])) + lv13: R.Tensor((1, 2, 3), dtype="float32") = R.reshape( + R.const(np.array([[13.0, 17.0, 19.0], [23.0, 29.0, 31.0]], dtype=np.float32)), + R.shape([1, 2, 3]), + ) + lv14: R.Tensor((1, 2, 3), dtype="float32") = R.multiply(lv12, lv13) + lv15: R.Tensor((1, 2), dtype="float32") = R.sum(lv14, axis=[-1], keepdims=False) + lv16: R.Tensor((1, 1, 2), dtype="float32") = R.reshape(lv15, R.shape([1, 1, 2])) + lv17: R.Tensor((1, 1), dtype="float32") = R.sum(lv16, axis=[-1], keepdims=False) + gv: R.Tensor((1, 1), dtype="float32") = R.add( + lv17, R.const(np.zeros(1, dtype=np.float32)) ) R.output(gv) return gv