From 40a2232584695d1cdfceb4da2c7ff34156e242fe Mon Sep 17 00:00:00 2001 From: LudovicoYIN Date: Sat, 30 May 2026 04:42:38 +0000 Subject: [PATCH] [Relax][Frontend][TFLite] Rebase sequence recurrent ops onto main --- .../relax/frontend/tflite/tflite_frontend.py | 670 ++++++++++--- tests/python/relax/test_frontend_tflite.py | 892 ++++++++++++++++++ 2 files changed, 1425 insertions(+), 137 deletions(-) diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index c479ec83c179..7046e43bbe68 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -200,6 +200,8 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None): "AVERAGE_POOL_2D": functools.partial(self.convert_pool2d, pool_type="average"), "BATCH_TO_SPACE_ND": self.convert_batch_to_space_nd, "BATCH_MATMUL": self.convert_batch_matmul, + "BIDIRECTIONAL_SEQUENCE_LSTM": self.convert_bidirectional_sequence_lstm, + "BIDIRECTIONAL_SEQUENCE_RNN": self.convert_bidirectional_sequence_rnn, "BITCAST": self.convert_bitcast, "BROADCAST_TO": self.convert_broadcast_to, "BROADCAST_ARGS": self.convert_broadcast_args, @@ -404,7 +406,7 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None): "UNSORTED_SEGMENT_PROD": functools.partial( self._convert_segment_op, op_name="UNSORTED_SEGMENT_PROD", reduction="mul" ), - # "UNIDIRECTIONAL_SEQUENCE_LSTM": self.convert_unidirectional_sequence_lstm, + "UNIDIRECTIONAL_SEQUENCE_LSTM": self.convert_unidirectional_sequence_lstm, "VAR_HANDLE": self.convert_var_handle, "WHERE": self.convert_select, "WHILE": self.convert_while, @@ -5510,153 +5512,547 @@ def convert_unidirectional_sequence_rnn(self, op): # Stack timestep outputs: [batch, time, num_units]. return relax.op.stack(outputs, axis=1) - """ def convert_unidirectional_sequence_lstm(self, op): - ### Long Short Term Memory for TFLite implementation. ### + """Convert TFLite UNIDIRECTIONAL_SEQUENCE_LSTM. + + Inputs (24 tensors, same layout as single-step LSTM): + [0] input [batch, time, input_size] + [1] input_to_input_weights [num_units, input_size] (optional) + [2] input_to_forget_weights [num_units, input_size] + [3] input_to_cell_weights [num_units, input_size] + [4] input_to_output_weights [num_units, input_size] + [5] recurrent_to_input_weights [num_units, num_units] (optional) + [6] recurrent_to_forget_weights [num_units, num_units] + [7] recurrent_to_cell_weights [num_units, num_units] + [8] recurrent_to_output_weights [num_units, num_units] + [9] cell_to_input_weights [num_units] (optional) + [10] cell_to_forget_weights [num_units] (optional) + [11] cell_to_output_weights [num_units] (optional) + [12] input_gate_bias [num_units] (optional) + [13] forget_gate_bias [num_units] + [14] cell_gate_bias [num_units] + [15] output_gate_bias [num_units] + [16] projection_weights [num_units, num_units] (optional) + [17] projection_bias [num_units] (optional) + [18] output_state [batch, num_units] (variable) + [19] cell_state [batch, num_units] (variable) + [20-23] optional layer norm weights + + Output: + [0] output [batch, time, num_units] + + Uses coupled input-forget gate (i = 1 - f) for the FULL kernel. + """ + from tflite.BuiltinOptions import BuiltinOptions + from tflite.UnidirectionalSequenceLSTMOptions import UnidirectionalSequenceLSTMOptions + if self.is_quantized(op): raise tvm.error.OpNotImplemented( - "TFlite quantized UNIDIRECTIONALSEQUENCELSTM operator is not supported yet." + "TFLite quantized UNIDIRECTIONAL_SEQUENCE_LSTM is not supported yet." ) input_tensors = self.get_input_tensors(op) - assert len(input_tensors) == 24, "input tensors length should be == 24" + assert len(input_tensors) == 24, ( + f"input tensors length should be 24, got {len(input_tensors)}" + ) - # Extract input tensor from saved model - input_tensor = input_tensors[0] + output_tensors = self.get_output_tensors(op) + assert len(output_tensors) >= 1, "output tensors length should be at least 1" + + assert op.BuiltinOptionsType() == BuiltinOptions.UnidirectionalSequenceLSTMOptions + op_options = op.BuiltinOptions() + lstm_opts = UnidirectionalSequenceLSTMOptions() + lstm_opts.Init(op_options.Bytes, op_options.Pos) + time_major = lstm_opts.TimeMajor() + fused_activation_fn = lstm_opts.FusedActivationFunction() + cell_clip = lstm_opts.CellClip() + proj_clip = lstm_opts.ProjClip() + + # Only coupled input-forget gate is supported. + if input_tensors[1].tensor_idx != -1 or input_tensors[5].tensor_idx != -1: + raise tvm.error.OpNotImplemented("Only coupled input-forget LSTM is supported.") + if any(input_tensors[idx].tensor_idx != -1 for idx in [9, 10, 11]): + raise tvm.error.OpNotImplemented("TFLite peephole LSTM is not supported yet.") + if any(input_tensors[idx].tensor_idx != -1 for idx in [16, 17]): + raise tvm.error.OpNotImplemented("TFLite projection LSTM is not supported yet.") + if any(input_tensors[idx].tensor_idx != -1 for idx in [20, 21, 22, 23]): + raise tvm.error.OpNotImplemented("TFLite layer-norm LSTM is not supported yet.") + + # Weights (transposed once outside the loop). + w_f_t = relax.op.permute_dims(self.get_tensor_expr(input_tensors[2])) + w_c_t = relax.op.permute_dims(self.get_tensor_expr(input_tensors[3])) + w_o_t = relax.op.permute_dims(self.get_tensor_expr(input_tensors[4])) + r_f_t = relax.op.permute_dims(self.get_tensor_expr(input_tensors[6])) + r_c_t = relax.op.permute_dims(self.get_tensor_expr(input_tensors[7])) + r_o_t = relax.op.permute_dims(self.get_tensor_expr(input_tensors[8])) + + # Biases. + b_f = self.get_tensor_expr(input_tensors[13]) + b_c = self.get_tensor_expr(input_tensors[14]) + b_o = self.get_tensor_expr(input_tensors[15]) + + # Initial states. + h = self.get_tensor_expr(input_tensors[18]) + c = self.get_tensor_expr(input_tensors[19]) + + # Resolve the input expression; normalise to batch-major [batch, time, input_size]. + in_expr = self.get_tensor_expr(input_tensors[0]) + in_shape = self.get_tensor_shape(input_tensors[0]) + if time_major: + in_expr = relax.op.permute_dims(in_expr, [1, 0, 2]) + num_steps = int(in_shape[0]) + else: + num_steps = int(in_shape[1]) + + # Unroll over the time axis. + if num_steps == 1: + steps = [relax.op.squeeze(in_expr, axis=[1])] + else: + splits = relax.op.split(in_expr, num_steps, axis=1) + steps = [relax.op.squeeze(splits[i], axis=[1]) for i in range(num_steps)] + + one = relax.const(1.0, "float32") + outputs = [] + for x_t in steps: + f = relax.op.sigmoid( + relax.op.add( + relax.op.add( + relax.op.matmul(x_t, w_f_t), + relax.op.matmul(h, r_f_t), + ), + b_f, + ) + ) + i = relax.op.subtract(one, f) + g = self.convert_fused_activation_function( + relax.op.add( + relax.op.add(relax.op.matmul(x_t, w_c_t), relax.op.matmul(h, r_c_t)), + b_c, + ), + fused_activation_fn, + ) + o = relax.op.sigmoid( + relax.op.add( + relax.op.add( + relax.op.matmul(x_t, w_o_t), + relax.op.matmul(h, r_o_t), + ), + b_o, + ) + ) + + c_new = relax.op.add(relax.op.multiply(f, c), relax.op.multiply(i, g)) + if cell_clip > 0.0: + c_new = relax.op.clip(c_new, -cell_clip, cell_clip) + + h_new = relax.op.multiply( + o, self.convert_fused_activation_function(c_new, fused_activation_fn) + ) + if proj_clip > 0.0: + h_new = relax.op.clip(h_new, -proj_clip, proj_clip) + outputs.append(h_new) + h, c = h_new, c_new + + h_out = relax.op.stack(outputs, axis=1) + if time_major: + h_out = relax.op.permute_dims(h_out, [1, 0, 2]) + + # Update state tensors in the expression table for subsequent ops. + self.exp_tab.set_expr( + get_tensor_name(self.subgraph, input_tensors[18].tensor_idx), + h, + force_override=True, + ) + self.exp_tab.set_expr( + get_tensor_name(self.subgraph, input_tensors[19].tensor_idx), + c, + force_override=True, + ) + + return h_out + + def convert_bidirectional_sequence_rnn(self, op): + """Convert TFLite BIDIRECTIONAL_SEQUENCE_RNN. + + Inputs (9 tensors, aux_input not supported): + [0] input [batch, time, input_size] + [1] fw_weights [num_units, input_size] + [2] fw_recurrent_weights [num_units, num_units] + [3] fw_bias [num_units] + [4] fw_hidden_state [batch, num_units] (variable) + [5] bw_weights [num_units, input_size] + [6] bw_recurrent_weights [num_units, num_units] + [7] bw_bias [num_units] + [8] bw_hidden_state [batch, num_units] (variable) + + Output (merge_outputs=True): + [0] output [batch, time, 2 * num_units] (fw and bw concatenated) + + Output (merge_outputs=False): + [0] fw_output [batch, time, num_units] + [1] bw_output [batch, time, num_units] + """ + from tflite.BidirectionalSequenceRNNOptions import BidirectionalSequenceRNNOptions + from tflite.BuiltinOptions import BuiltinOptions + + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + "TFLite quantized BIDIRECTIONAL_SEQUENCE_RNN is not supported yet." + ) + + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 12, ( + f"input tensors length should be 12, got {len(input_tensors)}" + ) - # Extract tensors from input tensors from saved model - # Input weights - input_input_weights = input_tensors[1] - input_forget_weights = input_tensors[2] - input_cell_weights = input_tensors[3] - input_output_weights = input_tensors[4] - # Recurrent weights - recurrent_input_weights = input_tensors[5] - recurrent_forget_weights = input_tensors[6] - recurrent_cell_weights = input_tensors[7] - recurrent_output_weights = input_tensors[8] - # inputs 9, 10, 11, 16, 17, 20, 21, 22, 23 are not occupied - # there locations are -1 in the flatbuffer - # Bias weights - input_gate_bias = input_tensors[12] - forget_gate_bias = input_tensors[13] - cell_gate_bias = input_tensors[14] - output_gate_bias = input_tensors[15] - - # State input - output_state_in = input_tensors[18] - cell_state_in = input_tensors[19] - - # Extract output tensor from saved model output_tensors = self.get_output_tensors(op) - assert len(output_tensors) == 1, "output tensors length should be 1" - X_steps = self.unbind(input_tensor, axis=1) - weights_dict = {} - - # hidden_state_weights is equivalent to output_state_in in tflite model - out_state_in_shape = tuple(self.get_tensor_shape(output_state_in)) - out_state_in_dtype = self.get_tensor_type_str(output_state_in.tensor.Type()) - out_state_in_expr = relax.op.zeros(out_state_in_shape, dtype=out_state_in_dtype) - weights_dict["hidden_state"] = relax.op.split(out_state_in_expr, 1)[0] - - # cell_state_weights is equivalent to output_state_in tflite model - cell_state_in_shape = tuple(self.get_tensor_shape(cell_state_in)) - cell_state_in_dtype = self.get_tensor_type_str(cell_state_in.tensor.Type()) - cell_state_in_expr = relax.op.zeros(cell_state_in_shape, dtype=cell_state_in_dtype) - weights_dict["cell_state"] = relax.op.split(cell_state_in_expr, 1)[0] - - # Process weight matrix of input: w_inp - # Concatenate of [input_input_weight, input_forget_weights, - # input_cell_weights, input_output_weights] - input_input_weights_default_values = self.get_tensor_value(input_input_weights) - input_input_weights_op = relax.op.split( - relax.op.const(input_input_weights_default_values.tolist()), 1 - ) - input_output_weights_default_values = self.get_tensor_value(input_output_weights) - input_output_weights_op = relax.op.split( - relax.op.const(input_output_weights_default_values.tolist()), 1 - ) - input_forget_weights_default_values = self.get_tensor_value(input_forget_weights) - input_forget_weights_op = relax.op.split( - relax.op.const(input_forget_weights_default_values.tolist()), 1 - ) - input_cell_weights_default_values = self.get_tensor_value(input_cell_weights) - input_cell_weights_op = relax.op.split( - _op.const(input_cell_weights_default_values.tolist()), 1 - ) - weights_dict["w_inp"] = relax.op.concat( - [ - relax.op.squeeze(input_input_weights_op[0]), - relax.op.squeeze(input_forget_weights_op[0]), - relax.op.squeeze(input_cell_weights_op[0]), - relax.op.squeeze(input_output_weights_op[0]), - ], - axis=0, - ) - - # Process weight matrix of hidden state: - # w_hid to support lstm_cell function. Not used in tflite - recurrent_input_weights_values = self.get_tensor_value(recurrent_input_weights) - recurrent_input_weights_op = relax.op.split( - relax.op.const(recurrent_input_weights_values.tolist()), 1 - ) - recurrent_output_weights_values = self.get_tensor_value(recurrent_output_weights) - recurrent_output_weights_op = relax.op.split( - relax.op.const(recurrent_output_weights_values.tolist()), 1 - ) - recurrent_forget_weights_values = self.get_tensor_value(recurrent_forget_weights) - recurrent_forget_weights_op = relax.op.split( - relax.op.const(recurrent_forget_weights_values.tolist()), 1 - ) - recurrent_cell_weights_values = self.get_tensor_value(recurrent_cell_weights) - recurrent_cell_weights_op = relax.op.split( - _op.const(recurrent_cell_weights_values.tolist()), 1 - ) - weights_dict["w_hid"] = relax.op.concat( - [ - recurrent_input_weights_op[0], - recurrent_forget_weights_op[0], - recurrent_cell_weights_op[0], - recurrent_output_weights_op[0], - ], - axis=0, - ) - - # Process weight matrix of bias: b_inp - input_gate_bias_values = self.get_tensor_value(input_gate_bias) - input_gate_bias_op = relax.op.split(_op.const(input_gate_bias_values.tolist()), 1) - output_gate_bias_values = self.get_tensor_value(output_gate_bias) - output_gate_bias_op = relax.op.split(_op.const(output_gate_bias_values.tolist()), 1) - forget_gate_bias_values = self.get_tensor_value(forget_gate_bias) - forget_gate_bias_op = relax.op.split(_op.const(forget_gate_bias_values.tolist()), 1) - cell_gate_bias_values = self.get_tensor_value(cell_gate_bias) - cell_gate_bias_op = relax.op.split(_op.const(cell_gate_bias_values.tolist()), 1) - weights_dict["b_inp"] = relax.op.concat( - [ - input_gate_bias_op[0], - forget_gate_bias_op[0], - cell_gate_bias_op[0], - output_gate_bias_op[0], - ], - axis=0, - ) - - # Process weight matrix of hidden bias: - # b_hid (with the same shape as b_inp) - gate_bias_dtype = self.get_tensor_type_str(input_gate_bias.tensor.Type()) - weights_dict["b_hid"] = relax.op.split( - relax.op.const( - np.zeros(self._infer_shape(weights_dict["b_inp"]), dtype=gate_bias_dtype), - dtype=gate_bias_dtype, - ), - 1, - )[0] + assert len(output_tensors) >= 1, "output tensors length should be at least 1" + + assert op.BuiltinOptionsType() == BuiltinOptions.BidirectionalSequenceRNNOptions + op_options = op.BuiltinOptions() + rnn_opts = BidirectionalSequenceRNNOptions() + rnn_opts.Init(op_options.Bytes, op_options.Pos) + time_major = rnn_opts.TimeMajor() + fused_activation_fn = rnn_opts.FusedActivationFunction() + merge_outputs = rnn_opts.MergeOutputs() + if any(input_tensors[idx].tensor_idx != -1 for idx in [9, 10, 11]): + raise tvm.error.OpNotImplemented( + "TFLite BIDIRECTIONAL_SEQUENCE_RNN aux input is not supported yet." + ) - outputs, _, _ = lstm_cell(input_seqs=X_steps, **weights_dict) + # Forward weights and biases. + fw_weights_expr = self.get_tensor_expr(input_tensors[1]) + fw_recurrent_expr = self.get_tensor_expr(input_tensors[2]) + fw_bias_expr = self.get_tensor_expr(input_tensors[3]) + fw_w_t = relax.op.permute_dims(fw_weights_expr) + fw_wr_t = relax.op.permute_dims(fw_recurrent_expr) - output = relax.op.stack(outputs, axis=1) - return output - """ + # Backward weights and biases. + bw_weights_expr = self.get_tensor_expr(input_tensors[5]) + bw_recurrent_expr = self.get_tensor_expr(input_tensors[6]) + bw_bias_expr = self.get_tensor_expr(input_tensors[7]) + bw_w_t = relax.op.permute_dims(bw_weights_expr) + bw_wr_t = relax.op.permute_dims(bw_recurrent_expr) + + # Resolve the input expression; normalise to batch-major [batch, time, input_size]. + in_expr = self.get_tensor_expr(input_tensors[0]) + in_shape = self.get_tensor_shape(input_tensors[0]) + if time_major: + in_expr = relax.op.permute_dims(in_expr, [1, 0, 2]) + num_steps = int(in_shape[0]) + else: + num_steps = int(in_shape[1]) + + # Initial hidden states. + def _get_hidden_state(tensor): + if self.has_expr(tensor.tensor_idx) or ( + tensor.buffer is not None and tensor.buffer.DataLength() > 0 + ): + return self.get_tensor_expr(tensor) + dtype = self.get_tensor_type_str(tensor.tensor.Type()) + h_shape = tuple(to_int_list(self.get_tensor_shape(tensor))) + return relax.op.zeros(h_shape, dtype=dtype) + + fw_h = _get_hidden_state(input_tensors[4]) + bw_h = _get_hidden_state(input_tensors[8]) + + # Unroll over the time axis. + if num_steps == 1: + steps = [relax.op.squeeze(in_expr, axis=[1])] + else: + splits = relax.op.split(in_expr, num_steps, axis=1) + steps = [relax.op.squeeze(splits[i], axis=[1]) for i in range(num_steps)] + + # Forward pass. + fw_outputs = [] + for x_t in steps: + gates = relax.op.add( + relax.op.add(relax.op.matmul(x_t, fw_w_t), relax.op.matmul(fw_h, fw_wr_t)), + fw_bias_expr, + ) + fw_h = self.convert_fused_activation_function(gates, fused_activation_fn) + fw_outputs.append(fw_h) + + # Backward pass (process steps in reverse). + bw_outputs = [] + for x_t in reversed(steps): + gates = relax.op.add( + relax.op.add(relax.op.matmul(x_t, bw_w_t), relax.op.matmul(bw_h, bw_wr_t)), + bw_bias_expr, + ) + bw_h = self.convert_fused_activation_function(gates, fused_activation_fn) + bw_outputs.append(bw_h) + bw_outputs.reverse() + + fw_stacked = relax.op.stack(fw_outputs, axis=1) # [batch, time, num_units] + bw_stacked = relax.op.stack(bw_outputs, axis=1) # [batch, time, num_units] + if time_major: + fw_stacked = relax.op.permute_dims(fw_stacked, [1, 0, 2]) + bw_stacked = relax.op.permute_dims(bw_stacked, [1, 0, 2]) + + # Update state tensors in the expression table for subsequent ops. + self.exp_tab.set_expr( + get_tensor_name(self.subgraph, input_tensors[4].tensor_idx), + fw_h, + force_override=True, + ) + self.exp_tab.set_expr( + get_tensor_name(self.subgraph, input_tensors[8].tensor_idx), + bw_h, + force_override=True, + ) + + if merge_outputs: + return relax.op.concat([fw_stacked, bw_stacked], axis=-1) + else: + return relax.Tuple([fw_stacked, bw_stacked]) + + def convert_bidirectional_sequence_lstm(self, op): + """Convert TFLite BIDIRECTIONAL_SEQUENCE_LSTM. + + Inputs (48 tensors, indices 0-17 forward LSTM, 18-34 backward LSTM, 35-38 states, + 39-47 optional aux inputs, which are not supported): + + Forward LSTM cell (indices 0-17, same layout as single-step LSTM): + [0] input (shared) [batch, time, input_size] + [1] fw_input_to_input_weights (optional) + [2] fw_input_to_forget_weights + [3] fw_input_to_cell_weights + [4] fw_input_to_output_weights + [5] fw_recurrent_to_input_wts (optional) + [6] fw_recurrent_to_forget_wts + [7] fw_recurrent_to_cell_wts + [8] fw_recurrent_to_output_wts + [9-11] fw cell_to_*_weights (optional, not supported) + [12] fw_input_gate_bias (optional) + [13] fw_forget_gate_bias + [14] fw_cell_gate_bias + [15] fw_output_gate_bias + [16] fw_projection_weights (optional, not supported) + [17] fw_projection_bias (optional, not supported) + + Backward LSTM cell (indices 18-34, same layout as fw): + [19] bw_input_to_forget_weights + [20] bw_input_to_cell_weights + [21] bw_input_to_output_weights + [23] bw_recurrent_to_forget_wts + [24] bw_recurrent_to_cell_wts + [25] bw_recurrent_to_output_wts + [30] bw_forget_gate_bias + [31] bw_cell_gate_bias + [32] bw_output_gate_bias + + State tensors: + [35] fw_activation_state [batch, num_units] + [36] fw_cell_state [batch, num_units] + [37] bw_activation_state [batch, num_units] + [38] bw_cell_state [batch, num_units] + + Output (merge_outputs=True): + [0] output [batch, time, 2 * num_units] + + Output (merge_outputs=False): + [0] fw_output [batch, time, num_units] + [1] bw_output [batch, time, num_units] + """ + from tflite.BidirectionalSequenceLSTMOptions import BidirectionalSequenceLSTMOptions + from tflite.BuiltinOptions import BuiltinOptions + + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + "TFLite quantized BIDIRECTIONAL_SEQUENCE_LSTM is not supported yet." + ) + + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 48, ( + f"input tensors length should be 48, got {len(input_tensors)}" + ) + + output_tensors = self.get_output_tensors(op) + assert len(output_tensors) >= 1, "output tensors length should be at least 1" + + assert op.BuiltinOptionsType() == BuiltinOptions.BidirectionalSequenceLSTMOptions + op_options = op.BuiltinOptions() + lstm_opts = BidirectionalSequenceLSTMOptions() + lstm_opts.Init(op_options.Bytes, op_options.Pos) + time_major = lstm_opts.TimeMajor() + fused_activation_fn = lstm_opts.FusedActivationFunction() + merge_outputs = lstm_opts.MergeOutputs() + cell_clip = lstm_opts.CellClip() + proj_clip = lstm_opts.ProjClip() + + # ── Forward LSTM weights (transposed once outside the loop) ── + if input_tensors[1].tensor_idx != -1 or input_tensors[5].tensor_idx != -1: + raise tvm.error.OpNotImplemented("Only coupled input-forget LSTM is supported.") + if any(input_tensors[idx].tensor_idx != -1 for idx in [9, 10, 11]): + raise tvm.error.OpNotImplemented("TFLite peephole LSTM is not supported yet.") + if any(input_tensors[idx].tensor_idx != -1 for idx in [16, 17]): + raise tvm.error.OpNotImplemented("TFLite projection LSTM is not supported yet.") + + fw_w_f_t = relax.op.permute_dims(self.get_tensor_expr(input_tensors[2])) + fw_w_c_t = relax.op.permute_dims(self.get_tensor_expr(input_tensors[3])) + fw_w_o_t = relax.op.permute_dims(self.get_tensor_expr(input_tensors[4])) + fw_r_f_t = relax.op.permute_dims(self.get_tensor_expr(input_tensors[6])) + fw_r_c_t = relax.op.permute_dims(self.get_tensor_expr(input_tensors[7])) + fw_r_o_t = relax.op.permute_dims(self.get_tensor_expr(input_tensors[8])) + fw_b_f = self.get_tensor_expr(input_tensors[13]) + fw_b_c = self.get_tensor_expr(input_tensors[14]) + fw_b_o = self.get_tensor_expr(input_tensors[15]) + + # ── Backward LSTM weights (transposed once outside the loop) ── + if input_tensors[18].tensor_idx != -1 or input_tensors[22].tensor_idx != -1: + raise tvm.error.OpNotImplemented("Only coupled input-forget LSTM is supported.") + if any(input_tensors[idx].tensor_idx != -1 for idx in [26, 27, 28]): + raise tvm.error.OpNotImplemented("TFLite peephole LSTM is not supported yet.") + if any(input_tensors[idx].tensor_idx != -1 for idx in [33, 34]): + raise tvm.error.OpNotImplemented("TFLite projection LSTM is not supported yet.") + if any(input_tensors[idx].tensor_idx != -1 for idx in range(39, 48)): + raise tvm.error.OpNotImplemented( + "TFLite BIDIRECTIONAL_SEQUENCE_LSTM aux input is not supported yet." + ) + + bw_w_f_t = relax.op.permute_dims(self.get_tensor_expr(input_tensors[19])) + bw_w_c_t = relax.op.permute_dims(self.get_tensor_expr(input_tensors[20])) + bw_w_o_t = relax.op.permute_dims(self.get_tensor_expr(input_tensors[21])) + bw_r_f_t = relax.op.permute_dims(self.get_tensor_expr(input_tensors[23])) + bw_r_c_t = relax.op.permute_dims(self.get_tensor_expr(input_tensors[24])) + bw_r_o_t = relax.op.permute_dims(self.get_tensor_expr(input_tensors[25])) + bw_b_f = self.get_tensor_expr(input_tensors[30]) + bw_b_c = self.get_tensor_expr(input_tensors[31]) + bw_b_o = self.get_tensor_expr(input_tensors[32]) + + # ── Initial states ── + fw_h = self.get_tensor_expr(input_tensors[35]) + fw_c = self.get_tensor_expr(input_tensors[36]) + bw_h = self.get_tensor_expr(input_tensors[37]) + bw_c = self.get_tensor_expr(input_tensors[38]) + + # ── Unroll input ── + in_expr = self.get_tensor_expr(input_tensors[0]) + in_shape = self.get_tensor_shape(input_tensors[0]) + if time_major: + in_expr = relax.op.permute_dims(in_expr, [1, 0, 2]) + num_steps = int(in_shape[0]) + else: + num_steps = int(in_shape[1]) + + if num_steps == 1: + steps = [relax.op.squeeze(in_expr, axis=[1])] + else: + splits = relax.op.split(in_expr, num_steps, axis=1) + steps = [relax.op.squeeze(splits[i], axis=[1]) for i in range(num_steps)] + + one = relax.const(1.0, "float32") + + def _lstm_step(x_t, h, c, w_f_t, w_c_t, w_o_t, r_f_t, r_c_t, r_o_t, b_f, b_c, b_o): + """Single LSTM step with coupled input-forget gate.""" + f = relax.op.sigmoid( + relax.op.add( + relax.op.add( + relax.op.matmul(x_t, w_f_t), + relax.op.matmul(h, r_f_t), + ), + b_f, + ) + ) + i = relax.op.subtract(one, f) + g = self.convert_fused_activation_function( + relax.op.add( + relax.op.add(relax.op.matmul(x_t, w_c_t), relax.op.matmul(h, r_c_t)), + b_c, + ), + fused_activation_fn, + ) + o = relax.op.sigmoid( + relax.op.add( + relax.op.add( + relax.op.matmul(x_t, w_o_t), + relax.op.matmul(h, r_o_t), + ), + b_o, + ) + ) + c_new = relax.op.add(relax.op.multiply(f, c), relax.op.multiply(i, g)) + if cell_clip > 0.0: + c_new = relax.op.clip(c_new, -cell_clip, cell_clip) + h_new = relax.op.multiply( + o, self.convert_fused_activation_function(c_new, fused_activation_fn) + ) + if proj_clip > 0.0: + h_new = relax.op.clip(h_new, -proj_clip, proj_clip) + return h_new, c_new + + # ── Forward pass ── + fw_outputs = [] + for x_t in steps: + fw_h, fw_c = _lstm_step( + x_t, + fw_h, + fw_c, + fw_w_f_t, + fw_w_c_t, + fw_w_o_t, + fw_r_f_t, + fw_r_c_t, + fw_r_o_t, + fw_b_f, + fw_b_c, + fw_b_o, + ) + fw_outputs.append(fw_h) + + # ── Backward pass ── + bw_outputs = [] + for x_t in reversed(steps): + bw_h, bw_c = _lstm_step( + x_t, + bw_h, + bw_c, + bw_w_f_t, + bw_w_c_t, + bw_w_o_t, + bw_r_f_t, + bw_r_c_t, + bw_r_o_t, + bw_b_f, + bw_b_c, + bw_b_o, + ) + bw_outputs.append(bw_h) + bw_outputs.reverse() + + fw_stacked = relax.op.stack(fw_outputs, axis=1) + bw_stacked = relax.op.stack(bw_outputs, axis=1) + if time_major: + fw_stacked = relax.op.permute_dims(fw_stacked, [1, 0, 2]) + bw_stacked = relax.op.permute_dims(bw_stacked, [1, 0, 2]) + + # Update state tensors in the expression table for subsequent ops. + self.exp_tab.set_expr( + get_tensor_name(self.subgraph, input_tensors[35].tensor_idx), + fw_h, + force_override=True, + ) + self.exp_tab.set_expr( + get_tensor_name(self.subgraph, input_tensors[36].tensor_idx), + fw_c, + force_override=True, + ) + self.exp_tab.set_expr( + get_tensor_name(self.subgraph, input_tensors[37].tensor_idx), + bw_h, + force_override=True, + ) + self.exp_tab.set_expr( + get_tensor_name(self.subgraph, input_tensors[38].tensor_idx), + bw_c, + force_override=True, + ) + + if merge_outputs: + return relax.op.concat([fw_stacked, bw_stacked], axis=-1) + else: + return relax.Tuple([fw_stacked, bw_stacked]) def convert_batch_to_space_nd(self, op): """batch_to_space_nd implementation.""" diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index e9ccea7ad150..05a6c1e5e5fa 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -3723,6 +3723,15 @@ def _get_tflite_schema_enum(enum_name): _tfl_lstm_options = _get_tflite_schema_module("LSTMOptions") _tfl_sequence_rnn_options = _get_tflite_schema_module("SequenceRNNOptions") _tfl_svdf_options = _get_tflite_schema_module("SVDFOptions") +_tfl_unidirectional_sequence_lstm_options = _get_tflite_schema_module( + "UnidirectionalSequenceLSTMOptions" +) +_tfl_bidirectional_sequence_rnn_options = _get_tflite_schema_module( + "BidirectionalSequenceRNNOptions" +) +_tfl_bidirectional_sequence_lstm_options = _get_tflite_schema_module( + "BidirectionalSequenceLSTMOptions" +) _DENSIFY_TEST_VALUES = np.array([1.0, 2.0], dtype=np.float32) _DENSIFY_TEST_DENSE = np.array([[1.0, 0.0], [0.0, 2.0]], dtype=np.float32) @@ -11052,6 +11061,889 @@ def main( tvm.ir.assert_structural_equal(mod, Expected) +# ── UNIDIRECTIONAL_SEQUENCE_LSTM ───────────────────────────────────────────── + + +def _build_unidirectional_sequence_lstm_model( + batch, + time, + input_size, + num_units, + input_to_forget_weights, + input_to_cell_weights, + input_to_output_weights, + recurrent_to_forget_weights, + recurrent_to_cell_weights, + recurrent_to_output_weights, + forget_gate_bias, + cell_bias, + output_gate_bias, + activation, + *, + time_major=False, + cell_clip=0.0, + proj_clip=0.0, + projection_weights=None, +): + """Build a TFLite flatbuffer model with one UNIDIRECTIONAL_SEQUENCE_LSTM op. + + Tensor indices (same layout as single-step LSTM, but input is 3D): + 0 - input [batch, time, input_size] + 1 - input_to_forget_weights [num_units, input_size] + 2 - input_to_cell_weights [num_units, input_size] + 3 - input_to_output_weights [num_units, input_size] + 4 - recurrent_to_forget_weights [num_units, num_units] + 5 - recurrent_to_cell_weights [num_units, num_units] + 6 - recurrent_to_output_weights [num_units, num_units] + 7 - forget_gate_bias [num_units] + 8 - cell_bias [num_units] + 9 - output_gate_bias [num_units] + 10 - output_state [batch, num_units] (model input) + 11 - cell_state [batch, num_units] (model input) + 12 - output [batch, time, num_units] or [time, batch, num_units] + """ + builder = flatbuffers.Builder(4096) + + _tfl_unidirectional_sequence_lstm_options.UnidirectionalSequenceLSTMOptionsStart(builder) + _tfl_unidirectional_sequence_lstm_options.UnidirectionalSequenceLSTMOptionsAddFusedActivationFunction( + builder, activation + ) + _tfl_unidirectional_sequence_lstm_options.UnidirectionalSequenceLSTMOptionsAddTimeMajor( + builder, time_major + ) + _tfl_unidirectional_sequence_lstm_options.UnidirectionalSequenceLSTMOptionsAddCellClip( + builder, cell_clip + ) + _tfl_unidirectional_sequence_lstm_options.UnidirectionalSequenceLSTMOptionsAddProjClip( + builder, proj_clip + ) + lstm_opts = _tfl_unidirectional_sequence_lstm_options.UnidirectionalSequenceLSTMOptionsEnd( + builder + ) + + lstm_op_code = _build_operator_code(builder, _tfl_builtin_operator.UNIDIRECTIONAL_SEQUENCE_LSTM) + + def _t(buf_idx, shape): + shape_vec = _tflite_shape(builder, shape) + _tfl_tensor.TensorStart(builder) + _tfl_tensor.TensorAddBuffer(builder, buf_idx) + _tfl_tensor.TensorAddHasRank(builder, True) + _tfl_tensor.TensorAddIsVariable(builder, False) + _tfl_tensor.TensorAddShape(builder, shape_vec) + _tfl_tensor.TensorAddType(builder, _tfl_tensor_type.FLOAT32) + return _tfl_tensor.TensorEnd(builder) + + input_shape = [time, batch, input_size] if time_major else [batch, time, input_size] + output_shape = [time, batch, num_units] if time_major else [batch, time, num_units] + tensors = [ + _t(0, input_shape), # 0: input + _t(1, [num_units, input_size]), # 1: input_to_forget_weights + _t(2, [num_units, input_size]), # 2: input_to_cell_weights + _t(3, [num_units, input_size]), # 3: input_to_output_weights + _t(4, [num_units, num_units]), # 4: recurrent_to_forget_weights + _t(5, [num_units, num_units]), # 5: recurrent_to_cell_weights + _t(6, [num_units, num_units]), # 6: recurrent_to_output_weights + _t(7, [num_units]), # 7: forget_gate_bias + _t(8, [num_units]), # 8: cell_bias + _t(9, [num_units]), # 9: output_gate_bias + _t(0, [batch, num_units]), # 10: output_state (model input) + _t(0, [batch, num_units]), # 11: cell_state (model input) + _t(0, output_shape), # 12: output + ] + + # 24 operator inputs, -1 for absent. + lstm_inputs = [ + 0, + -1, + 1, + 2, + 3, + -1, + 4, + 5, + 6, + -1, + -1, + -1, + -1, + 7, + 8, + 9, + -1, + -1, + 10, + 11, + -1, + -1, + -1, + -1, + ] + buffers = [ + _build_buffer(builder), # 0: empty + _build_buffer(builder, input_to_forget_weights.tobytes()), # 1 + _build_buffer(builder, input_to_cell_weights.tobytes()), # 2 + _build_buffer(builder, input_to_output_weights.tobytes()), # 3 + _build_buffer(builder, recurrent_to_forget_weights.tobytes()), # 4 + _build_buffer(builder, recurrent_to_cell_weights.tobytes()), # 5 + _build_buffer(builder, recurrent_to_output_weights.tobytes()), # 6 + _build_buffer(builder, forget_gate_bias.tobytes()), # 7 + _build_buffer(builder, cell_bias.tobytes()), # 8 + _build_buffer(builder, output_gate_bias.tobytes()), # 9 + ] + if projection_weights is not None: + tensors.append(_t(len(buffers), [num_units, num_units])) + lstm_inputs[16] = len(tensors) - 1 + buffers.append(_build_buffer(builder, projection_weights.tobytes())) + + lstm_op = _build_operator( + builder, + 0, + lstm_inputs, + [12], + builtin_options_type=_tfl_builtin_options.UnidirectionalSequenceLSTMOptions, + builtin_options=lstm_opts, + ) + + subgraph = _build_subgraph( + builder, + tensors=tensors, + operators=[lstm_op], + inputs=[0, 10, 11], + outputs=[12], + ) + + return _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=[lstm_op_code], + buffers=buffers, + ) + + +def test_unidirectional_sequence_lstm_none_activation(): + """UNIDIRECTIONAL_SEQUENCE_LSTM with NONE activation keeps cell activation linear.""" + from tflite.ActivationFunctionType import ActivationFunctionType + + batch, time, input_size, num_units = 2, 1, 2, 2 + w_f = np.eye(num_units, input_size, dtype=np.float32) + w_c = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) + w_o = np.array([[0.5, -0.25], [0.75, 0.5]], dtype=np.float32) + r_f = np.eye(num_units, dtype=np.float32) + r_c = np.array([[0.5, 0.0], [0.0, 0.25]], dtype=np.float32) + r_o = np.array([[0.1, 0.0], [0.0, 0.2]], dtype=np.float32) + b_f = np.zeros(num_units, dtype=np.float32) + b_c = np.zeros(num_units, dtype=np.float32) + b_o = np.zeros(num_units, dtype=np.float32) + + mod = _load_model_from_buffer( + _build_unidirectional_sequence_lstm_model( + batch, + time, + input_size, + num_units, + w_f, + w_c, + w_o, + r_f, + r_c, + r_o, + b_f, + b_c, + b_o, + ActivationFunctionType.NONE, + ) + ) + + script = mod.script(show_meta=True) + assert script.count("R.sigmoid") == 2 + assert "R.tanh" not in script + assert "R.multiply" in script + + +def test_unidirectional_sequence_lstm_tanh_activation(): + """UNIDIRECTIONAL_SEQUENCE_LSTM with TANH activation applies it inside the cell.""" + from tflite.ActivationFunctionType import ActivationFunctionType + + batch, time, input_size, num_units = 2, 1, 2, 2 + w_f = np.eye(num_units, input_size, dtype=np.float32) + w_c = np.array([[1.0, -1.0], [0.25, 0.5]], dtype=np.float32) + w_o = np.array([[0.5, 0.5], [-0.5, 1.0]], dtype=np.float32) + r_f = np.eye(num_units, dtype=np.float32) + r_c = np.array([[0.0, 0.1], [0.2, 0.0]], dtype=np.float32) + r_o = np.array([[0.3, 0.0], [0.0, 0.4]], dtype=np.float32) + b_f = np.zeros(num_units, dtype=np.float32) + b_c = np.zeros(num_units, dtype=np.float32) + b_o = np.zeros(num_units, dtype=np.float32) + + mod = _load_model_from_buffer( + _build_unidirectional_sequence_lstm_model( + batch, + time, + input_size, + num_units, + w_f, + w_c, + w_o, + r_f, + r_c, + r_o, + b_f, + b_c, + b_o, + ActivationFunctionType.TANH, + ) + ) + + script = mod.script(show_meta=True) + assert script.count("R.sigmoid") == 2 + assert script.count("R.tanh") == 2 + assert "R.multiply" in script + + +def test_unidirectional_sequence_lstm_time_major(): + """UNIDIRECTIONAL_SEQUENCE_LSTM preserves time-major output layout.""" + from tflite.ActivationFunctionType import ActivationFunctionType + + batch, time, input_size, num_units = 2, 3, 2, 2 + weights = np.eye(num_units, input_size, dtype=np.float32) + recurrent = np.eye(num_units, dtype=np.float32) + bias = np.zeros(num_units, dtype=np.float32) + + mod = _load_model_from_buffer( + _build_unidirectional_sequence_lstm_model( + batch, + time, + input_size, + num_units, + weights, + weights, + weights, + recurrent, + recurrent, + recurrent, + bias, + bias, + bias, + ActivationFunctionType.NONE, + time_major=True, + ) + ) + + fn = mod["main"] + assert tuple(int(d) for d in fn.params[0].struct_info.shape) == (time, batch, input_size) + assert tuple(int(d) for d in fn.ret_struct_info.shape) == (time, batch, num_units) + + +def test_unidirectional_sequence_lstm_rejects_projection(): + """UNIDIRECTIONAL_SEQUENCE_LSTM rejects unsupported projection inputs.""" + from tflite.ActivationFunctionType import ActivationFunctionType + + batch, time, input_size, num_units = 2, 2, 2, 2 + weights = np.eye(num_units, input_size, dtype=np.float32) + recurrent = np.eye(num_units, dtype=np.float32) + bias = np.zeros(num_units, dtype=np.float32) + + with pytest.raises(tvm.error.OpNotImplemented, match="projection LSTM"): + _load_model_from_buffer( + _build_unidirectional_sequence_lstm_model( + batch, + time, + input_size, + num_units, + weights, + weights, + weights, + recurrent, + recurrent, + recurrent, + bias, + bias, + bias, + ActivationFunctionType.NONE, + projection_weights=np.eye(num_units, dtype=np.float32), + ) + ) + + +# ── BIDIRECTIONAL_SEQUENCE_RNN ─────────────────────────────────────────────── + + +def _build_bidirectional_sequence_rnn_model( + batch, + time, + input_size, + num_units, + fw_weights, + fw_recurrent_weights, + fw_bias, + bw_weights, + bw_recurrent_weights, + bw_bias, + activation, + *, + time_major=False, + merge_outputs=True, + with_aux_input=False, +): + """Build a TFLite flatbuffer model with one BIDIRECTIONAL_SEQUENCE_RNN op. + + Tensor indices: + 0 - input [batch, time, input_size] + 1 - fw_weights [num_units, input_size] + 2 - fw_recurrent_weights [num_units, num_units] + 3 - fw_bias [num_units] + 4 - fw_hidden_state [batch, num_units] (model input) + 5 - bw_weights [num_units, input_size] + 6 - bw_recurrent_weights [num_units, num_units] + 7 - bw_bias [num_units] + 8 - bw_hidden_state [batch, num_units] (model input) + 9 - aux_input (optional) + 10 - fw_aux_weights (optional) + 11 - bw_aux_weights (optional) + 12 - output (or fw_output if merge_outputs=False) + 13 - bw_output (only if merge_outputs=False) + """ + builder = flatbuffers.Builder(4096) + + _tfl_bidirectional_sequence_rnn_options.BidirectionalSequenceRNNOptionsStart(builder) + _tfl_bidirectional_sequence_rnn_options.BidirectionalSequenceRNNOptionsAddTimeMajor( + builder, time_major + ) + _tfl_bidirectional_sequence_rnn_options.BidirectionalSequenceRNNOptionsAddFusedActivationFunction( + builder, activation + ) + _tfl_bidirectional_sequence_rnn_options.BidirectionalSequenceRNNOptionsAddMergeOutputs( + builder, merge_outputs + ) + rnn_opts = _tfl_bidirectional_sequence_rnn_options.BidirectionalSequenceRNNOptionsEnd(builder) + + rnn_op_code = _build_operator_code(builder, _tfl_builtin_operator.BIDIRECTIONAL_SEQUENCE_RNN) + + def _t(buf_idx, shape): + shape_vec = _tflite_shape(builder, shape) + _tfl_tensor.TensorStart(builder) + _tfl_tensor.TensorAddBuffer(builder, buf_idx) + _tfl_tensor.TensorAddHasRank(builder, True) + _tfl_tensor.TensorAddIsVariable(builder, False) + _tfl_tensor.TensorAddShape(builder, shape_vec) + _tfl_tensor.TensorAddType(builder, _tfl_tensor_type.FLOAT32) + return _tfl_tensor.TensorEnd(builder) + + input_shape = [time, batch, input_size] if time_major else [batch, time, input_size] + output_prefix = [time, batch] if time_major else [batch, time] + output_shape = output_prefix + ([num_units * 2] if merge_outputs else [num_units]) + + tensors = [ + _t(0, input_shape), # 0: input + _t(1, [num_units, input_size]), # 1: fw_weights + _t(2, [num_units, num_units]), # 2: fw_recurrent_weights + _t(3, [num_units]), # 3: fw_bias + _t(0, [batch, num_units]), # 4: fw_hidden_state (model input) + _t(4, [num_units, input_size]), # 5: bw_weights + _t(5, [num_units, num_units]), # 6: bw_recurrent_weights + _t(6, [num_units]), # 7: bw_bias + _t(0, [batch, num_units]), # 8: bw_hidden_state (model input) + ] + buffers = [ + _build_buffer(builder), # 0: empty + _build_buffer(builder, fw_weights.tobytes()), # 1 + _build_buffer(builder, fw_recurrent_weights.tobytes()), # 2 + _build_buffer(builder, fw_bias.tobytes()), # 3 + _build_buffer(builder, bw_weights.tobytes()), # 4 + _build_buffer(builder, bw_recurrent_weights.tobytes()), # 5 + _build_buffer(builder, bw_bias.tobytes()), # 6 + ] + rnn_inputs = [*list(range(9)), -1, -1, -1] + if with_aux_input: + tensors.extend( + [ + _t(len(buffers), input_shape), + _t(len(buffers) + 1, [num_units, input_size]), + _t(len(buffers) + 2, [num_units, input_size]), + ] + ) + rnn_inputs[9:12] = [len(tensors) - 3, len(tensors) - 2, len(tensors) - 1] + buffers.extend( + [ + _build_buffer(builder, np.zeros(input_shape, dtype=np.float32).tobytes()), + _build_buffer( + builder, np.zeros((num_units, input_size), dtype=np.float32).tobytes() + ), + _build_buffer( + builder, np.zeros((num_units, input_size), dtype=np.float32).tobytes() + ), + ] + ) + + if merge_outputs: + tensors.append(_t(0, output_shape)) + outputs = [len(tensors) - 1] + else: + tensors.extend([_t(0, output_shape), _t(0, output_shape)]) + outputs = [len(tensors) - 2, len(tensors) - 1] + + rnn_op = _build_operator( + builder, + 0, + rnn_inputs, + outputs, + builtin_options_type=_tfl_builtin_options.BidirectionalSequenceRNNOptions, + builtin_options=rnn_opts, + ) + + subgraph = _build_subgraph( + builder, + tensors=tensors, + operators=[rnn_op], + inputs=[0, 4, 8], + outputs=outputs, + ) + + return _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=[rnn_op_code], + buffers=buffers, + ) + + +def test_bidirectional_sequence_rnn_none_activation(): + """BIDIRECTIONAL_SEQUENCE_RNN with NONE activation lowers the expected equations.""" + from tflite.ActivationFunctionType import ActivationFunctionType + + batch, time, input_size, num_units = 2, 1, 2, 2 + fw_w = np.array([[1.0, 0.0], [0.5, -1.0]], dtype=np.float32) + fw_r = np.array([[0.25, 0.0], [0.0, 0.5]], dtype=np.float32) + fw_b = np.zeros(num_units, dtype=np.float32) + bw_w = np.array([[0.0, 1.0], [-0.5, 0.75]], dtype=np.float32) + bw_r = np.array([[0.1, 0.0], [0.0, 0.2]], dtype=np.float32) + bw_b = np.zeros(num_units, dtype=np.float32) + + mod = _load_model_from_buffer( + _build_bidirectional_sequence_rnn_model( + batch, + time, + input_size, + num_units, + fw_w, + fw_r, + fw_b, + bw_w, + bw_r, + bw_b, + ActivationFunctionType.NONE, + ) + ) + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 1, 2), dtype="float32"), + fw_h: R.Tensor((2, 2), dtype="float32"), + bw_h: R.Tensor((2, 2), dtype="float32"), + ) -> R.Tensor((2, 1, 4), dtype="float32"): + R.func_attr({"num_input": 3}) + with R.dataflow(): + x_t: R.Tensor((2, 2), dtype="float32") = R.squeeze(x, axis=[1]) + fw_w_t: R.Tensor((2, 2), dtype="float32") = R.permute_dims(R.const(fw_w), axes=None) + fw_x: R.Tensor((2, 2), dtype="float32") = R.matmul(x_t, fw_w_t, out_dtype="void") + fw_r_t: R.Tensor((2, 2), dtype="float32") = R.permute_dims(R.const(fw_r), axes=None) + fw_h_proj: R.Tensor((2, 2), dtype="float32") = R.matmul( + fw_h, fw_r_t, out_dtype="void" + ) + fw_out: R.Tensor((2, 2), dtype="float32") = R.add( + R.add(fw_x, fw_h_proj), R.const(fw_b) + ) + fw_stacked: R.Tensor((2, 1, 2), dtype="float32") = R.stack((fw_out,), axis=1) + bw_w_t: R.Tensor((2, 2), dtype="float32") = R.permute_dims(R.const(bw_w), axes=None) + bw_x: R.Tensor((2, 2), dtype="float32") = R.matmul(x_t, bw_w_t, out_dtype="void") + bw_r_t: R.Tensor((2, 2), dtype="float32") = R.permute_dims(R.const(bw_r), axes=None) + bw_h_proj: R.Tensor((2, 2), dtype="float32") = R.matmul( + bw_h, bw_r_t, out_dtype="void" + ) + bw_out: R.Tensor((2, 2), dtype="float32") = R.add( + R.add(bw_x, bw_h_proj), R.const(bw_b) + ) + bw_stacked: R.Tensor((2, 1, 2), dtype="float32") = R.stack((bw_out,), axis=1) + gv: R.Tensor((2, 1, 4), dtype="float32") = R.concat( + (fw_stacked, bw_stacked), axis=-1 + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_bidirectional_sequence_rnn_time_major(): + """BIDIRECTIONAL_SEQUENCE_RNN preserves time-major output layout.""" + from tflite.ActivationFunctionType import ActivationFunctionType + + batch, time, input_size, num_units = 2, 3, 2, 2 + weights = np.eye(num_units, input_size, dtype=np.float32) + recurrent = np.eye(num_units, dtype=np.float32) + bias = np.zeros(num_units, dtype=np.float32) + + mod = _load_model_from_buffer( + _build_bidirectional_sequence_rnn_model( + batch, + time, + input_size, + num_units, + weights, + recurrent, + bias, + weights, + recurrent, + bias, + ActivationFunctionType.NONE, + time_major=True, + ) + ) + + fn = mod["main"] + assert tuple(int(d) for d in fn.params[0].struct_info.shape) == (time, batch, input_size) + assert tuple(int(d) for d in fn.ret_struct_info.shape) == (time, batch, num_units * 2) + + +def test_bidirectional_sequence_rnn_rejects_aux_input(): + """BIDIRECTIONAL_SEQUENCE_RNN rejects unsupported auxiliary input tensors.""" + from tflite.ActivationFunctionType import ActivationFunctionType + + batch, time, input_size, num_units = 2, 2, 2, 2 + weights = np.eye(num_units, input_size, dtype=np.float32) + recurrent = np.eye(num_units, dtype=np.float32) + bias = np.zeros(num_units, dtype=np.float32) + + with pytest.raises(tvm.error.OpNotImplemented, match="aux input"): + _load_model_from_buffer( + _build_bidirectional_sequence_rnn_model( + batch, + time, + input_size, + num_units, + weights, + recurrent, + bias, + weights, + recurrent, + bias, + ActivationFunctionType.NONE, + with_aux_input=True, + ) + ) + + +# ── BIDIRECTIONAL_SEQUENCE_LSTM ────────────────────────────────────────────── + + +def _build_bidirectional_sequence_lstm_model( + batch, + time, + input_size, + num_units, + fw_w_f, + fw_w_c, + fw_w_o, + fw_r_f, + fw_r_c, + fw_r_o, + fw_b_f, + fw_b_c, + fw_b_o, + bw_w_f, + bw_w_c, + bw_w_o, + bw_r_f, + bw_r_c, + bw_r_o, + bw_b_f, + bw_b_c, + bw_b_o, + activation, + *, + time_major=False, + merge_outputs=True, + cell_clip=0.0, + proj_clip=0.0, + with_aux_input=False, +): + """Build a TFLite flatbuffer model with one BIDIRECTIONAL_SEQUENCE_LSTM op. + + 48 operator inputs. Forward LSTM: indices 0-17, Backward LSTM: indices 18-34, + States: indices 35-38. + """ + builder = flatbuffers.Builder(8192) + + _tfl_bidirectional_sequence_lstm_options.BidirectionalSequenceLSTMOptionsStart(builder) + _tfl_bidirectional_sequence_lstm_options.BidirectionalSequenceLSTMOptionsAddFusedActivationFunction( + builder, activation + ) + _tfl_bidirectional_sequence_lstm_options.BidirectionalSequenceLSTMOptionsAddTimeMajor( + builder, time_major + ) + _tfl_bidirectional_sequence_lstm_options.BidirectionalSequenceLSTMOptionsAddMergeOutputs( + builder, merge_outputs + ) + _tfl_bidirectional_sequence_lstm_options.BidirectionalSequenceLSTMOptionsAddCellClip( + builder, cell_clip + ) + _tfl_bidirectional_sequence_lstm_options.BidirectionalSequenceLSTMOptionsAddProjClip( + builder, proj_clip + ) + lstm_opts = _tfl_bidirectional_sequence_lstm_options.BidirectionalSequenceLSTMOptionsEnd( + builder + ) + + lstm_op_code = _build_operator_code(builder, _tfl_builtin_operator.BIDIRECTIONAL_SEQUENCE_LSTM) + + def _t(buf_idx, shape, is_variable=False): + shape_vec = _tflite_shape(builder, shape) + _tfl_tensor.TensorStart(builder) + _tfl_tensor.TensorAddBuffer(builder, buf_idx) + _tfl_tensor.TensorAddHasRank(builder, True) + _tfl_tensor.TensorAddIsVariable(builder, is_variable) + _tfl_tensor.TensorAddShape(builder, shape_vec) + _tfl_tensor.TensorAddType(builder, _tfl_tensor_type.FLOAT32) + return _tfl_tensor.TensorEnd(builder) + + input_shape = [time, batch, input_size] if time_major else [batch, time, input_size] + output_size = num_units * 2 if merge_outputs else num_units + output_shape = ([time, batch] if time_major else [batch, time]) + [output_size] + + tensors = [ + _t(0, input_shape), # 0: input + _t(1, [num_units, input_size]), # 1: fw_w_f + _t(2, [num_units, input_size]), # 2: fw_w_c + _t(3, [num_units, input_size]), # 3: fw_w_o + _t(4, [num_units, num_units]), # 4: fw_r_f + _t(5, [num_units, num_units]), # 5: fw_r_c + _t(6, [num_units, num_units]), # 6: fw_r_o + _t(7, [num_units]), # 7: fw_b_f + _t(8, [num_units]), # 8: fw_b_c + _t(9, [num_units]), # 9: fw_b_o + _t(10, [num_units, input_size]), # 10: bw_w_f + _t(11, [num_units, input_size]), # 11: bw_w_c + _t(12, [num_units, input_size]), # 12: bw_w_o + _t(13, [num_units, num_units]), # 13: bw_r_f + _t(14, [num_units, num_units]), # 14: bw_r_c + _t(15, [num_units, num_units]), # 15: bw_r_o + _t(16, [num_units]), # 16: bw_b_f + _t(17, [num_units]), # 17: bw_b_c + _t(18, [num_units]), # 18: bw_b_o + _t(0, [batch, num_units]), # 19: fw_activation_state (model input) + _t(0, [batch, num_units]), # 20: fw_cell_state (model input) + _t(0, [batch, num_units]), # 21: bw_activation_state (model input) + _t(0, [batch, num_units]), # 22: bw_cell_state (model input) + _t(0, output_shape), # 23: output + ] + + # Build operator inputs: 48 total, with unsupported optional inputs set to -1. + fw_inputs = [0, -1, 1, 2, 3, -1, 4, 5, 6, -1, -1, -1, -1, 7, 8, 9, -1, -1] + bw_inputs = [-1, 10, 11, 12, -1, 13, 14, 15, -1, -1, -1, -1, 16, 17, 18, -1, -1] + states = [19, 20, 21, 22] + aux_inputs = [-1] * 9 + if with_aux_input: + tensors.append(_t(0, input_shape)) + aux_inputs[0] = len(tensors) - 1 + lstm_inputs = fw_inputs + bw_inputs + states + aux_inputs + + lstm_op = _build_operator( + builder, + 0, + lstm_inputs, + [23], + builtin_options_type=_tfl_builtin_options.BidirectionalSequenceLSTMOptions, + builtin_options=lstm_opts, + ) + + subgraph = _build_subgraph( + builder, + tensors=tensors, + operators=[lstm_op], + inputs=[0, 19, 20, 21, 22], + outputs=[23], + ) + + buffers = [ + _build_buffer(builder), # 0: empty + _build_buffer(builder, fw_w_f.tobytes()), # 1 + _build_buffer(builder, fw_w_c.tobytes()), # 2 + _build_buffer(builder, fw_w_o.tobytes()), # 3 + _build_buffer(builder, fw_r_f.tobytes()), # 4 + _build_buffer(builder, fw_r_c.tobytes()), # 5 + _build_buffer(builder, fw_r_o.tobytes()), # 6 + _build_buffer(builder, fw_b_f.tobytes()), # 7 + _build_buffer(builder, fw_b_c.tobytes()), # 8 + _build_buffer(builder, fw_b_o.tobytes()), # 9 + _build_buffer(builder, bw_w_f.tobytes()), # 10 + _build_buffer(builder, bw_w_c.tobytes()), # 11 + _build_buffer(builder, bw_w_o.tobytes()), # 12 + _build_buffer(builder, bw_r_f.tobytes()), # 13 + _build_buffer(builder, bw_r_c.tobytes()), # 14 + _build_buffer(builder, bw_r_o.tobytes()), # 15 + _build_buffer(builder, bw_b_f.tobytes()), # 16 + _build_buffer(builder, bw_b_c.tobytes()), # 17 + _build_buffer(builder, bw_b_o.tobytes()), # 18 + ] + + return _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=[lstm_op_code], + buffers=buffers, + ) + + +def test_bidirectional_sequence_lstm_none_activation(): + """BIDIRECTIONAL_SEQUENCE_LSTM with NONE activation keeps both cell activations linear.""" + from tflite.ActivationFunctionType import ActivationFunctionType + + batch, time, input_size, num_units = 2, 1, 2, 2 + + def _eye_or_randn(m, n): + if m == n: + return np.eye(m, dtype=np.float32) + return np.arange(m * n, dtype=np.float32).reshape(m, n) / 10.0 + + fw_w_f = _eye_or_randn(num_units, input_size) + fw_w_c = np.array([[1.0, -0.5], [0.25, 0.75]], dtype=np.float32) + fw_w_o = np.array([[0.5, 0.25], [-0.25, 1.0]], dtype=np.float32) + fw_r_f = _eye_or_randn(num_units, num_units) + fw_r_c = np.array([[0.2, 0.0], [0.0, 0.3]], dtype=np.float32) + fw_r_o = np.array([[0.1, 0.0], [0.0, 0.2]], dtype=np.float32) + fw_b_f = np.zeros(num_units, dtype=np.float32) + fw_b_c = np.zeros(num_units, dtype=np.float32) + fw_b_o = np.zeros(num_units, dtype=np.float32) + + bw_w_f = np.array([[1.0, 0.0], [0.0, 1.0]], dtype=np.float32) + bw_w_c = np.array([[0.5, 0.5], [-0.5, 1.0]], dtype=np.float32) + bw_w_o = np.array([[0.25, -0.25], [0.75, 0.5]], dtype=np.float32) + bw_r_f = np.array([[0.4, 0.0], [0.0, 0.6]], dtype=np.float32) + bw_r_c = np.array([[0.3, 0.0], [0.0, 0.2]], dtype=np.float32) + bw_r_o = np.array([[0.2, 0.0], [0.0, 0.1]], dtype=np.float32) + bw_b_f = np.zeros(num_units, dtype=np.float32) + bw_b_c = np.zeros(num_units, dtype=np.float32) + bw_b_o = np.zeros(num_units, dtype=np.float32) + + mod = _load_model_from_buffer( + _build_bidirectional_sequence_lstm_model( + batch, + time, + input_size, + num_units, + fw_w_f, + fw_w_c, + fw_w_o, + fw_r_f, + fw_r_c, + fw_r_o, + fw_b_f, + fw_b_c, + fw_b_o, + bw_w_f, + bw_w_c, + bw_w_o, + bw_r_f, + bw_r_c, + bw_r_o, + bw_b_f, + bw_b_c, + bw_b_o, + ActivationFunctionType.NONE, + ) + ) + + script = mod.script(show_meta=True) + assert script.count("R.sigmoid") == 4 + assert "R.tanh" not in script + assert script.count("R.stack") == 2 + assert "R.concat" in script + + +def test_bidirectional_sequence_lstm_time_major(): + """BIDIRECTIONAL_SEQUENCE_LSTM preserves time-major output layout.""" + from tflite.ActivationFunctionType import ActivationFunctionType + + batch, time, input_size, num_units = 2, 3, 2, 2 + weights = np.eye(num_units, input_size, dtype=np.float32) + recurrent = np.eye(num_units, dtype=np.float32) + bias = np.zeros(num_units, dtype=np.float32) + + mod = _load_model_from_buffer( + _build_bidirectional_sequence_lstm_model( + batch, + time, + input_size, + num_units, + weights, + weights, + weights, + recurrent, + recurrent, + recurrent, + bias, + bias, + bias, + weights, + weights, + weights, + recurrent, + recurrent, + recurrent, + bias, + bias, + bias, + ActivationFunctionType.NONE, + time_major=True, + ) + ) + + fn = mod["main"] + assert tuple(int(d) for d in fn.params[0].struct_info.shape) == (time, batch, input_size) + assert tuple(int(d) for d in fn.ret_struct_info.shape) == (time, batch, num_units * 2) + + +def test_bidirectional_sequence_lstm_rejects_aux_input(): + """BIDIRECTIONAL_SEQUENCE_LSTM rejects unsupported auxiliary inputs.""" + from tflite.ActivationFunctionType import ActivationFunctionType + + batch, time, input_size, num_units = 2, 2, 2, 2 + weights = np.eye(num_units, input_size, dtype=np.float32) + recurrent = np.eye(num_units, dtype=np.float32) + bias = np.zeros(num_units, dtype=np.float32) + + with pytest.raises(tvm.error.OpNotImplemented, match="aux input"): + _load_model_from_buffer( + _build_bidirectional_sequence_lstm_model( + batch, + time, + input_size, + num_units, + weights, + weights, + weights, + recurrent, + recurrent, + recurrent, + bias, + bias, + bias, + weights, + weights, + weights, + recurrent, + recurrent, + recurrent, + bias, + bias, + bias, + ActivationFunctionType.NONE, + with_aux_input=True, + ) + ) + + # ── UNIDIRECTIONAL_SEQUENCE_RNN ───────────────────────────────────────────────