diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 8183f64f7305..44633d231664 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -273,6 +273,7 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None): "POW": functools.partial(self._convert_elemwise, relax_op=_op.power), "PRELU": self.convert_prelu, "RANGE": self.convert_range, + "RNN": self.convert_rnn, "QUANTIZE": self.convert_quantize, "RANDOM_STANDARD_NORMAL": self.convert_random_standard_normal, "RANDOM_UNIFORM": self.convert_random_uniform, @@ -4878,6 +4879,84 @@ def convert_unpack(self, op): return squeezed + def convert_rnn(self, op): + """Convert TFLite RNN. + + Single-step RNN cell. + + Inputs (5 tensors): + [0] input [batch, input_size] + [1] input_weights [num_units, input_size] + [2] recurrent_weights [num_units, num_units] + [3] bias [num_units] + [4] hidden_state [batch, num_units] (variable, zero-initialised) + + Output: + [0] output [batch, num_units] + + Cell equation: + h = fused_activation(x @ W.T + h @ Wr.T + b) + """ + from tflite.BuiltinOptions import BuiltinOptions + from tflite.RNNOptions import RNNOptions + + if self.is_quantized(op): + raise tvm.error.OpNotImplemented("TFLite quantized RNN is not supported yet.") + + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 5, "input tensors length should be 5" + + input_tensor = input_tensors[0] + weights_tensor = input_tensors[1] + recurrent_tensor = input_tensors[2] + bias_tensor = input_tensors[3] + hidden_state_tensor = input_tensors[4] + + output_tensors = self.get_output_tensors(op) + assert len(output_tensors) >= 1, "output tensors length should be at least 1" + + assert op.BuiltinOptionsType() == BuiltinOptions.RNNOptions + op_options = op.BuiltinOptions() + rnn_options = RNNOptions() + rnn_options.Init(op_options.Bytes, op_options.Pos) + fused_activation_fn = rnn_options.FusedActivationFunction() + + # Constant weight/bias expressions. + weights_expr = self.get_tensor_expr(weights_tensor) # [num_units, input_size] + recurrent_expr = self.get_tensor_expr(recurrent_tensor) # [num_units, num_units] + bias_expr = self.get_tensor_expr(bias_tensor) # [num_units] + + # Transpose to [input_size, num_units] and [num_units, num_units] for x @ W.T. + w_t = relax.op.permute_dims(weights_expr) + wr_t = relax.op.permute_dims(recurrent_expr) + + # Resolve the input expression. + in_expr = self.get_tensor_expr(input_tensor) + + # Initial hidden state: use the model's tensor value when available (non-zero init or + # graph input), otherwise fall back to zeros for the common variable-tensor case. + h_dtype = self.get_tensor_type_str(hidden_state_tensor.tensor.Type()) + if self.has_expr(hidden_state_tensor.tensor_idx) or ( + hidden_state_tensor.buffer is not None and hidden_state_tensor.buffer.DataLength() > 0 + ): + h = self.get_tensor_expr(hidden_state_tensor) + else: + h_shape = tuple(to_int_list(self.get_tensor_shape(hidden_state_tensor))) + h = relax.op.zeros(h_shape, dtype=h_dtype) + + gates = relax.op.add( + relax.op.add(relax.op.matmul(in_expr, w_t), relax.op.matmul(h, wr_t)), + bias_expr, + ) + h = self.convert_fused_activation_function(gates, fused_activation_fn) + + self.exp_tab.set_expr( + get_tensor_name(self.subgraph, hidden_state_tensor.tensor_idx), + h, + force_override=True, + ) + return h + def convert_unidirectional_sequence_rnn(self, op): """Convert TFLite UNIDIRECTIONAL_SEQUENCE_RNN. diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index f1abacec27da..3037cab333f2 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -3720,6 +3720,7 @@ def _get_tflite_schema_enum(enum_name): _tfl_sparse_index_vector = _get_tflite_schema_enum("SparseIndexVector") _tfl_tensor_type = _get_tflite_schema_enum("TensorType") +_tfl_rnn_options = _get_tflite_schema_module("RNNOptions") _tfl_sequence_rnn_options = _get_tflite_schema_module("SequenceRNNOptions") _DENSIFY_TEST_VALUES = np.array([1.0, 2.0], dtype=np.float32) @@ -9721,6 +9722,295 @@ def main( tvm.ir.assert_structural_equal(mod, Expected) +# ── RNN ──────────────────────────────────────────────────────────────────────── + + +def _build_rnn_model(batch, input_size, num_units, weights, recurrent_weights, bias, activation): + """Build a minimal TFLite flatbuffer model containing one RNN op. + + Tensor layout (indices 0-5): + 0 - input [batch, input_size] + 1 - input_weights [num_units, input_size] (constant) + 2 - recurrent_weights [num_units, num_units] (constant) + 3 - bias [num_units] (constant) + 4 - hidden_state [batch, num_units] (variable, zero-initialised) + 5 - output [batch, num_units] + """ + builder = flatbuffers.Builder(4096) + + _tfl_rnn_options.RNNOptionsStart(builder) + _tfl_rnn_options.RNNOptionsAddFusedActivationFunction(builder, activation) + rnn_opts = _tfl_rnn_options.RNNOptionsEnd(builder) + + rnn_op_code = _build_operator_code(builder, _tfl_builtin_operator.RNN) + + def _t(buf_idx, shape, is_variable=False): + shape_vec = _tflite_shape(builder, shape) + _tfl_tensor.TensorStart(builder) + _tfl_tensor.TensorAddBuffer(builder, buf_idx) + _tfl_tensor.TensorAddHasRank(builder, True) + _tfl_tensor.TensorAddIsVariable(builder, is_variable) + _tfl_tensor.TensorAddShape(builder, shape_vec) + _tfl_tensor.TensorAddType(builder, _tfl_tensor_type.FLOAT32) + return _tfl_tensor.TensorEnd(builder) + + tensors = [ + _t(0, [batch, input_size]), + _t(1, [num_units, input_size]), + _t(2, [num_units, num_units]), + _t(3, [num_units]), + _t(4, [batch, num_units], is_variable=True), + _t(5, [batch, num_units]), + ] + + rnn_op = _build_operator( + builder, + 0, + [0, 1, 2, 3, 4], + [5], + builtin_options_type=_tfl_builtin_options.RNNOptions, + builtin_options=rnn_opts, + ) + + subgraph = _build_subgraph( + builder, + tensors=tensors, + operators=[rnn_op], + inputs=[0], + outputs=[5], + ) + + buffers = [ + _build_buffer(builder), + _build_buffer(builder, weights.tobytes()), + _build_buffer(builder, recurrent_weights.tobytes()), + _build_buffer(builder, bias.tobytes()), + _build_buffer(builder), + _build_buffer(builder), + ] + + return _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=[rnn_op_code], + buffers=buffers, + ) + + +def _build_two_step_shared_state_rnn_model( + batch, input_size, num_units, weights, recurrent_weights, bias, activation +): + """Build a TFLite model with two RNN ops sharing the same hidden-state tensor.""" + builder = flatbuffers.Builder(4096) + + _tfl_rnn_options.RNNOptionsStart(builder) + _tfl_rnn_options.RNNOptionsAddFusedActivationFunction(builder, activation) + rnn_opts = _tfl_rnn_options.RNNOptionsEnd(builder) + + rnn_op_code = _build_operator_code(builder, _tfl_builtin_operator.RNN) + + def _t(buf_idx, shape, is_variable=False): + shape_vec = _tflite_shape(builder, shape) + _tfl_tensor.TensorStart(builder) + _tfl_tensor.TensorAddBuffer(builder, buf_idx) + _tfl_tensor.TensorAddHasRank(builder, True) + _tfl_tensor.TensorAddIsVariable(builder, is_variable) + _tfl_tensor.TensorAddShape(builder, shape_vec) + _tfl_tensor.TensorAddType(builder, _tfl_tensor_type.FLOAT32) + return _tfl_tensor.TensorEnd(builder) + + tensors = [ + _t(0, [batch, input_size]), + _t(1, [num_units, input_size]), + _t(2, [num_units, num_units]), + _t(3, [num_units]), + _t(4, [batch, num_units], is_variable=True), + _t(0, [batch, input_size]), + _t(0, [batch, num_units]), + _t(0, [batch, num_units]), + ] + + first_rnn_op = _build_operator( + builder, + 0, + [0, 1, 2, 3, 4], + [6], + builtin_options_type=_tfl_builtin_options.RNNOptions, + builtin_options=rnn_opts, + ) + second_rnn_op = _build_operator( + builder, + 0, + [5, 1, 2, 3, 4], + [7], + builtin_options_type=_tfl_builtin_options.RNNOptions, + builtin_options=rnn_opts, + ) + + subgraph = _build_subgraph( + builder, + tensors=tensors, + operators=[first_rnn_op, second_rnn_op], + inputs=[0, 5], + outputs=[7], + ) + + buffers = [ + _build_buffer(builder), + _build_buffer(builder, weights.tobytes()), + _build_buffer(builder, recurrent_weights.tobytes()), + _build_buffer(builder, bias.tobytes()), + _build_buffer(builder), + ] + + return _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=[rnn_op_code], + buffers=buffers, + ) + + +def test_rnn_none_activation(): + """RNN with NONE activation lowers to matmul/add. + + Cell equation: h = x @ W.T + h @ Wr.T + b (no activation for NONE) + """ + from tflite.ActivationFunctionType import ActivationFunctionType + + batch, input_size, num_units = 2, 2, 2 + weights = np.eye(num_units, input_size, dtype=np.float32) + recurrent_weights = np.eye(num_units, dtype=np.float32) + bias = np.zeros(num_units, dtype=np.float32) + + mod = _load_model_from_buffer( + _build_rnn_model( + batch, + input_size, + num_units, + weights, + recurrent_weights, + bias, + ActivationFunctionType.NONE, + ) + ) + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 2), dtype="float32")) -> R.Tensor((2, 2), dtype="float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + lv: R.Tensor((2, 2), dtype="float32") = R.permute_dims( + R.const(np.eye(2, dtype=np.float32)), axes=None + ) + lv1: R.Tensor((2, 2), dtype="float32") = R.matmul(x, lv, out_dtype="void") + lv2: R.Tensor((2, 2), dtype="float32") = R.zeros(R.shape([2, 2]), dtype="float32") + lv3: R.Tensor((2, 2), dtype="float32") = R.permute_dims( + R.const(np.eye(2, dtype=np.float32)), axes=None + ) + lv4: R.Tensor((2, 2), dtype="float32") = R.matmul(lv2, lv3, out_dtype="void") + lv5: R.Tensor((2, 2), dtype="float32") = R.add(lv1, lv4) + gv: R.Tensor((2, 2), dtype="float32") = R.add( + lv5, R.const(np.zeros(2, dtype=np.float32)) + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_rnn_relu_activation(): + """RNN with RELU activation and random weights.""" + from tflite.ActivationFunctionType import ActivationFunctionType + + batch, input_size, num_units = 2, 4, 8 + np.random.seed(42) + weights = np.random.randn(num_units, input_size).astype(np.float32) + recurrent_weights = np.random.randn(num_units, num_units).astype(np.float32) + bias = np.random.randn(num_units).astype(np.float32) + + mod = _load_model_from_buffer( + _build_rnn_model( + batch, + input_size, + num_units, + weights, + recurrent_weights, + bias, + ActivationFunctionType.RELU, + ) + ) + + fn = mod["main"] + assert len(fn.params) == 1, "only the input should be a graph input" + in_shape = fn.params[0].struct_info.shape + assert tuple(int(d) for d in in_shape) == (batch, input_size) + out_shape = fn.ret_struct_info.shape + assert tuple(int(d) for d in out_shape) == (batch, num_units) + + +def test_rnn_shared_hidden_state_updates_exp_tab(): + """Two consecutive RNN ops sharing hidden_state should use the updated state.""" + from tflite.ActivationFunctionType import ActivationFunctionType + + batch, input_size, num_units = 2, 2, 2 + weights = np.eye(num_units, input_size, dtype=np.float32) + recurrent_weights = np.eye(num_units, dtype=np.float32) + bias = np.zeros(num_units, dtype=np.float32) + + mod = _load_model_from_buffer( + _build_two_step_shared_state_rnn_model( + batch, + input_size, + num_units, + weights, + recurrent_weights, + bias, + ActivationFunctionType.NONE, + ) + ) + + @I.ir_module + class Expected: + @R.function + def main( + x0: R.Tensor((2, 2), dtype="float32"), + x1: R.Tensor((2, 2), dtype="float32"), + ) -> R.Tensor((2, 2), dtype="float32"): + R.func_attr({"num_input": 2}) + with R.dataflow(): + lv: R.Tensor((2, 2), dtype="float32") = R.permute_dims( + R.const(np.eye(2, dtype=np.float32)), axes=None + ) + lv1: R.Tensor((2, 2), dtype="float32") = R.matmul(x0, lv, out_dtype="void") + lv2: R.Tensor((2, 2), dtype="float32") = R.zeros(R.shape([2, 2]), dtype="float32") + lv3: R.Tensor((2, 2), dtype="float32") = R.permute_dims( + R.const(np.eye(2, dtype=np.float32)), axes=None + ) + lv4: R.Tensor((2, 2), dtype="float32") = R.matmul(lv2, lv3, out_dtype="void") + lv5: R.Tensor((2, 2), dtype="float32") = R.add(lv1, lv4) + lv6: R.Tensor((2, 2), dtype="float32") = R.permute_dims( + R.const(np.eye(2, dtype=np.float32)), axes=None + ) + lv7: R.Tensor((2, 2), dtype="float32") = R.matmul(x1, lv6, out_dtype="void") + lv8: R.Tensor((2, 2), dtype="float32") = R.add( + lv5, R.const(np.zeros(2, dtype=np.float32)) + ) + lv9: R.Tensor((2, 2), dtype="float32") = R.permute_dims( + R.const(np.eye(2, dtype=np.float32)), axes=None + ) + lv10: R.Tensor((2, 2), dtype="float32") = R.matmul(lv8, lv9, out_dtype="void") + lv11: R.Tensor((2, 2), dtype="float32") = R.add(lv7, lv10) + gv: R.Tensor((2, 2), dtype="float32") = R.add( + lv11, R.const(np.zeros(2, dtype=np.float32)) + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + # ── UNIDIRECTIONAL_SEQUENCE_RNN ───────────────────────────────────────────────