[Relax][Frontend][TFLite] Add RNN converter #19632

Merged
tlopex merged 4 commits into apache:main from LudovicoYIN:relax/tflite-rnn on May 29, 2026

@LudovicoYIN
Contributor

Summary

Add Relax TFLite frontend support for RNN (BuiltinOperator 23), claimed in #19519 Group A.

Single-step RNN cell:

h = fused_activation(x @ W.T + h @ Wr.T + b)
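
As a rough numerical illustration of the formula above (not the converter code itself), the cell can be sketched in NumPy. The function name `rnn_step`, the example shapes, and the use of NumPy are assumptions made for this sketch; the `h` on the right-hand side is the previous hidden state and the result is the new one.

```python
import numpy as np


def rnn_step(x, w, wr, b, h, activation=None):
    """One RNN cell step: activation(x @ w.T + h @ wr.T + b).

    Shapes follow the PR description:
      x  [batch, input_size], w [num_units, input_size],
      wr [num_units, num_units], b [num_units], h [batch, num_units].
    """
    pre = x @ w.T + h @ wr.T + b
    return pre if activation is None else activation(pre)


def relu(v):
    return np.maximum(v, 0.0)


batch, input_size, num_units = 2, 3, 3
x = np.arange(batch * input_size, dtype="float32").reshape(batch, input_size)
w = np.eye(num_units, input_size, dtype="float32")  # identity input weights
wr = np.eye(num_units, dtype="float32")             # identity recurrent weights
b = np.zeros(num_units, dtype="float32")
h0 = np.zeros((batch, num_units), dtype="float32")  # zero-initialised state

# NONE activation with identity weights and zero state: output equals input.
h1 = rnn_step(x, w, wr, b, h0)

# RELU activation clamps negative pre-activations to zero.
h1_relu = rnn_step(x - 2.0, w, wr, b, h0, activation=relu)
```

With identity weights, zero bias, and a zero initial state, the NONE-activation result reduces to the input itself, which is presumably why such a configuration is convenient for a structural test.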

Changes

  • Handler: convert_rnn registered in convert_map (alphabetical, after RANGE)
  • Inputs (5): input [batch, input_size], input_weights [num_units, input_size], recurrent_weights [num_units, num_units], bias [num_units], hidden_state [batch, num_units] (variable, zero-initialised)
  • Output: [batch, num_units]
  • Activations: all fused activations via convert_fused_activation_function
  • Quantized: raises OpNotImplemented

Testing

Two tests added to tests/python/relax/test_frontend_tflite.py:

  • test_rnn_none_activation: tvm.ir.assert_structural_equal with identity weights, NONE activation
  • test_rnn_relu_activation: shape check, random weights, RELU activation

python -m pytest tests/python/relax/test_frontend_tflite.py -k rnn -v

References

  • Issue #19519 Group A: Sequence / recurrent model operators

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request adds support for converting the TFLite RNN and UNIDIRECTIONAL_SEQUENCE_RNN operators to TVM Relax, along with corresponding unit tests. The review feedback highlights two critical issues concerning stateful execution and layout consistency. First, the initial hidden state should be resolved from the model inputs when available, rather than always being zero-initialized. Second, the output of UNIDIRECTIONAL_SEQUENCE_RNN must respect the time_major layout attribute: when time_major is true, the stacked outputs should be transposed back to [time, batch, num_units], which also requires updating the test assertions and model builder shapes.

I am having trouble creating individual review comments, so my feedback is listed below.

python/tvm/relax/frontend/tflite/tflite_frontend.py (4536-4539)

high

If the hidden_state tensor is passed as a model input (i.e., it exists in exp_tab), we should use its resolved expression instead of always overwriting it with relax.op.zeros. This is crucial for supporting stateful RNN execution where the state is maintained across invocations.

        # Zero-initialised hidden state [batch, num_units] if not provided as an input.
        if self.has_expr(hidden_state_tensor.tensor_idx):
            h = self.get_expr(hidden_state_tensor.tensor_idx)
        else:
            h_shape = tuple(to_int_list(self.get_tensor_shape(hidden_state_tensor)))
            h_dtype = self.get_tensor_type_str(hidden_state_tensor.tensor.Type())
            h = relax.op.zeros(h_shape, dtype=h_dtype)

python/tvm/relax/frontend/tflite/tflite_frontend.py (4608-4611)

high

Similar to convert_rnn, if the initial hidden_state is provided as a model input, we should use its resolved expression instead of always zero-initializing it.

        # Zero-initialised hidden state [batch, num_units] if not provided as an input.
        if self.has_expr(hidden_state_tensor.tensor_idx):
            h = self.get_expr(hidden_state_tensor.tensor_idx)
        else:
            h_shape = tuple(to_int_list(self.get_tensor_shape(hidden_state_tensor)))
            h_dtype = self.get_tensor_type_str(hidden_state_tensor.tensor.Type())
            h = relax.op.zeros(h_shape, dtype=h_dtype)

python/tvm/relax/frontend/tflite/tflite_frontend.py (4630-4631)

high

According to the TFLite specification, the output of UNIDIRECTIONAL_SEQUENCE_RNN should have the same layout (time-major or batch-major) as the input. If time_major is True, the output must be transposed back to [time, batch, num_units] to prevent shape mismatches in downstream operators.

        # Stack timestep outputs: [batch, time, num_units].
        stacked = relax.op.stack(outputs, axis=1)
        if time_major:
            return relax.op.permute_dims(stacked, [1, 0, 2])
        return stacked
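
The layout point can be checked numerically with a small NumPy sketch. This is an illustration only, under stated assumptions: the function name `unidirectional_sequence_rnn`, the tanh activation, and the random test data are hypothetical and are not taken from the converter.

```python
import numpy as np


def unidirectional_sequence_rnn(x, w, wr, b, h0, time_major):
    """Unrolled RNN over a sequence; the output layout matches the input layout."""
    # Normalise to batch-major [batch, time, input_size] for the loop.
    x_bm = np.transpose(x, (1, 0, 2)) if time_major else x
    h = h0
    outputs = []
    for t in range(x_bm.shape[1]):
        # tanh stands in for the fused activation (an assumption of this sketch).
        h = np.tanh(x_bm[:, t, :] @ w.T + h @ wr.T + b)
        outputs.append(h)
    stacked = np.stack(outputs, axis=1)  # [batch, time, num_units]
    # Transpose back so a time-major input yields a time-major output.
    return np.transpose(stacked, (1, 0, 2)) if time_major else stacked


rng = np.random.default_rng(0)
time, batch, input_size, num_units = 4, 2, 3, 5
x_bm = rng.standard_normal((batch, time, input_size)).astype("float32")
x_tm = np.transpose(x_bm, (1, 0, 2))  # same data, time-major layout
w = rng.standard_normal((num_units, input_size)).astype("float32")
wr = rng.standard_normal((num_units, num_units)).astype("float32")
b = np.zeros(num_units, dtype="float32")
h0 = np.zeros((batch, num_units), dtype="float32")

out_bm = unidirectional_sequence_rnn(x_bm, w, wr, b, h0, time_major=False)
out_tm = unidirectional_sequence_rnn(x_tm, w, wr, b, h0, time_major=True)
```

Here `out_bm` has shape [batch, time, num_units] and `out_tm` has shape [time, batch, num_units]; the two hold the same values up to the axis swap.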

tests/python/relax/test_frontend_tflite.py (6939-6946)

medium

Update the test model's output tensor shape to be [time, batch, num_units] when time_major is True to align with the correct TFLite specification.

    output_shape = [time, batch, num_units] if time_major else [batch, time, num_units]
    tensors = [
        _t(0, input_shape),
        _t(1, [num_units, input_size]),
        _t(2, [num_units, num_units]),
        _t(3, [num_units]),
        _t(4, [batch, num_units], is_variable=True),
        _t(5, output_shape),
    ]

tests/python/relax/test_frontend_tflite.py (7093-7096)

medium

Update the assertion to verify that the output is correctly time-major [time, batch, num_units] when time_major is True.

    assert tuple(int(d) for d in in_shape) == (time, batch, input_size)
    # Output is time-major [time, batch, num_units] when time_major=True.
    out_shape = fn.ret_struct_info.shape
    assert tuple(int(d) for d in out_shape) == (time, batch, num_units)

Member

@tlopex tlopex left a comment

Thanks for adding the RNN converter.

I think there is one statefulness issue here: TFLite RNN updates the variable hidden_state after computing the activated output. In this converter, we return the new hidden state as the op output, but the original hidden_state_tensor binding in exp_tab is not updated. That means if a later op reads the same hidden-state tensor, it will still see the initial state rather than the state produced by this RNN step.

Could we either update the expression table entry for hidden_state_tensor.tensor_idx to the activated result, using override semantics, or explicitly reject stateful RNN graphs if Relax cannot model this mutation yet?

A regression test with two consecutive RNN ops sharing the same hidden_state would be useful here, since the second op should consume the first op’s updated state.

@LudovicoYIN
Contributor Author

Thanks for the review. I updated the RNN converter to write the activated hidden state back to the original hidden_state tensor entry in exp_tab with force_override=True, so later ops observe the updated state.

I also added a regression test with two consecutive RNN ops sharing the same hidden_state tensor to verify the second RNN consumes the updated state from the first one.
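
The write-back behaviour discussed above can be pictured with a toy model. Everything in this sketch is hypothetical: `ToyExprTable`, `convert_rnn_toy`, and the string "expressions" are stand-ins invented for illustration and are not TVM's actual expression table or converter. Only the flag name `force_override` is borrowed from the discussion, and its keep-the-first-binding default is an assumption.

```python
class ToyExprTable:
    """Hypothetical stand-in for an expression table keyed by tensor index."""

    def __init__(self):
        self._exprs = {}

    def has_expr(self, key):
        return key in self._exprs

    def get_expr(self, key):
        return self._exprs[key]

    def set_expr(self, key, expr, force_override=False):
        # Assumed semantics: keep the first binding unless overriding is requested.
        if key in self._exprs and not force_override:
            return
        self._exprs[key] = expr


def convert_rnn_toy(table, input_idx, hidden_idx, output_idx):
    """Build a symbolic RNN step and write the new state back to hidden_idx."""
    x = table.get_expr(input_idx)
    # Fall back to a zero state only when no expression is bound yet.
    h = table.get_expr(hidden_idx) if table.has_expr(hidden_idx) else "zeros"
    out = f"rnn({x}, {h})"
    table.set_expr(output_idx, out)
    # Later readers of the hidden-state tensor must see the updated state.
    table.set_expr(hidden_idx, out, force_override=True)
    return out


table = ToyExprTable()
table.set_expr(0, "x0")
table.set_expr(1, "x1")

# Two consecutive RNN ops sharing hidden-state tensor index 4.
first = convert_rnn_toy(table, input_idx=0, hidden_idx=4, output_idx=5)
second = convert_rnn_toy(table, input_idx=1, hidden_idx=4, output_idx=6)

print(first)   # rnn(x0, zeros)
print(second)  # rnn(x1, rnn(x0, zeros))
```

Without the overriding write-back, the second step in this toy would not see the first step's result through the shared hidden-state index, which is the failure mode the regression test is meant to catch.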

Member

@tlopex tlopex left a comment

LGTM, thanks.

@tlopex tlopex merged commit e89570f into apache:main May 29, 2026
10 checks passed