[Relax][Frontend][TFLite] Support sequence LSTM and RNN operators #19634
LudovicoYIN wants to merge 1 commit
Code Review
This pull request adds support for TFLite's BIDIRECTIONAL_SEQUENCE_LSTM and BIDIRECTIONAL_SEQUENCE_RNN operators, and enables the previously commented-out UNIDIRECTIONAL_SEQUENCE_LSTM operator in the TVM Relax frontend. It also adds corresponding unit tests. The feedback highlights performance and correctness issues in the LSTM conversion functions: specifically, calling relax.op.permute_dims and creating the constant 1.0 inside the timestep loops creates redundant nodes in the Relax AST, which can bloat the graph and slow down compilation. Additionally, the fused activation function in convert_unidirectional_sequence_lstm should be applied to the hidden state at each step inside the loop to ensure correct recurrent state activation.
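To illustrate the point about loop-invariant nodes, here is a minimal sketch that counts emitted nodes with and without hoisting. `ToyGraph`, `unroll_naive`, and `unroll_hoisted` are hypothetical stand-ins written for this illustration, not the Relax builder API, and the per-step arithmetic is only a placeholder for the real gate equations.

```python
class ToyGraph:
    """Hypothetical stand-in for a graph builder: records every node it emits."""

    def __init__(self):
        self.nodes = []

    def emit(self, op, *args):
        self.nodes.append((op, args))
        return len(self.nodes) - 1  # node id used as a handle


def unroll_naive(g, x_steps, w):
    """Re-emits the transposed weight and the constant at every timestep."""
    h = g.emit("zeros")
    for x_t in x_steps:
        w_t = g.emit("permute_dims", w)  # loop-invariant, but re-emitted
        one = g.emit("const", 1.0)       # loop-invariant, but re-emitted
        z = g.emit("add", g.emit("matmul", x_t, w_t), h)
        h = g.emit("subtract", one, z)
    return h


def unroll_hoisted(g, x_steps, w):
    """Emits the loop-invariant nodes once, before the timestep loop."""
    h = g.emit("zeros")
    w_t = g.emit("permute_dims", w)
    one = g.emit("const", 1.0)
    for x_t in x_steps:
        z = g.emit("add", g.emit("matmul", x_t, w_t), h)
        h = g.emit("subtract", one, z)
    return h


steps = [f"x{t}" for t in range(8)]
naive, hoisted = ToyGraph(), ToyGraph()
unroll_naive(naive, steps, "w")
unroll_hoisted(hoisted, steps, "w")
print(len(naive.nodes), len(hoisted.nodes))
```

For eight timesteps the toy graph holds 41 nodes in the naive version and 27 in the hoisted one; the gap grows linearly with sequence length and with the number of weight matrices transposed per step.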
tlopex left a comment
Thanks for adding these sequence recurrent converters. I think there are still several correctness issues before this should land.
- The LSTM fused activation semantics do not match TFLite.
In TFLite LSTM, FusedActivationFunction is used as the internal cell activation: it is applied to the cell update gate and to the cell state when computing the output state. It is not a post-activation applied to the final hidden state.
The current converters always use tanh for the cell gate and tanh(c_new), then apply convert_fused_activation_function to h_new. This is incorrect for both NONE and TANH.
- time_major=True returns the wrong layout.
The converters normalize time-major input to batch-major for unrolling, but always return [batch, time, ...]. TFLite preserves the sequence layout, so when time_major=True, the output should be [time, batch, ...].
This affects UNIDIRECTIONAL_SEQUENCE_LSTM, BIDIRECTIONAL_SEQUENCE_RNN, and BIDIRECTIONAL_SEQUENCE_LSTM.
- Unsupported optional inputs are silently ignored.
The converters say peephole/projection/layer-norm/aux inputs are unsupported, but most of these are not rejected. If those optional tensors are present, the converter will produce incorrect IR instead of raising OpNotImplemented.
- The new tests are too weak and partly non-representative.
The bidirectional RNN/LSTM test builders use shortened input lists, while the TFLite kernels expect 12 inputs for BIDIRECTIONAL_SEQUENCE_RNN and 48 for BIDIRECTIONAL_SEQUENCE_LSTM. The tests mostly check shapes, so they do not catch the numerical semantics above.
Please add numerical tests against TFLite, or at least structural tests that verify the actual equations, time_major=True, and rejection of unsupported optional inputs.
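The points above can be sketched as a small NumPy reference, which could also serve as a baseline for a structural or numerical test. This is a sketch under stated assumptions, not the converter code: it assumes the full (non-CIFG) gate set without peephole, projection, or layer norm; the dict-of-gates weight layout and the helper names `lstm_step`, `run_sequence_lstm`, and `reject_present_optionals` are hypothetical; `NotImplementedError` stands in for TVM's OpNotImplemented; and the slot numbers passed to the rejection helper are placeholders. The only TFLite convention relied on is that an absent optional tensor has index -1.

```python
import numpy as np

ACTIVATIONS = {
    "NONE": lambda v: v,
    "RELU": lambda v: np.maximum(v, 0.0),
    "TANH": np.tanh,
}


def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))


def lstm_step(x, h, c, W, R, b, activation="TANH", cell_clip=0.0):
    """One timestep. W, R, b are dicts keyed by gate: 'i', 'f', 'c', 'o'."""
    act = ACTIVATIONS[activation]
    pre = {k: x @ W[k].T + h @ R[k].T + b[k] for k in "ifco"}
    i, f, o = sigmoid(pre["i"]), sigmoid(pre["f"]), sigmoid(pre["o"])
    g = act(pre["c"])              # fused activation on the cell update gate
    c_new = f * c + i * g
    if cell_clip > 0.0:
        c_new = np.clip(c_new, -cell_clip, cell_clip)
    h_new = o * act(c_new)         # fused activation on the cell state
    return h_new, c_new            # no extra activation applied to h_new


def run_sequence_lstm(x, h0, c0, W, R, b, activation="TANH", time_major=False):
    """Unrolls over time and returns the output in the same layout as the input."""
    seq = x if time_major else np.swapaxes(x, 0, 1)   # [time, batch, input]
    h, c, outputs = h0, c0, []
    for x_t in seq:
        h, c = lstm_step(x_t, h, c, W, R, b, activation)
        outputs.append(h)
    out = np.stack(outputs, axis=0)                   # [time, batch, units]
    return out if time_major else np.swapaxes(out, 0, 1)


def reject_present_optionals(input_indices, unsupported_slots):
    """Raises if any unsupported optional slot holds a real tensor (index != -1)."""
    present = [s for s in unsupported_slots
               if s < len(input_indices) and input_indices[s] != -1]
    if present:
        raise NotImplementedError(f"unsupported optional inputs at slots {present}")


rng = np.random.default_rng(0)
batch, time, n_in, n_units = 2, 3, 4, 5
W = {k: rng.standard_normal((n_units, n_in)) for k in "ifco"}
R = {k: rng.standard_normal((n_units, n_units)) for k in "ifco"}
b = {k: rng.standard_normal(n_units) for k in "ifco"}
h0 = np.zeros((batch, n_units))
c0 = np.zeros((batch, n_units))
x = rng.standard_normal((batch, time, n_in))

out_bm = run_sequence_lstm(x, h0, c0, W, R, b)
out_tm = run_sequence_lstm(np.swapaxes(x, 0, 1), h0, c0, W, R, b, time_major=True)
out_none = run_sequence_lstm(x, h0, c0, W, R, b, activation="NONE")
print(out_bm.shape, out_tm.shape)
```

Comparing `out_none` against `out_bm` in a test would catch a converter that hard-codes tanh, and comparing `out_tm` against the transposed `out_bm` would catch a layout that is not preserved.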
Thanks for the detailed review. I updated the sequence recurrent converters to address the correctness issues you pointed out.
I also strengthened the tests:
Could you resolve the conflict so that we can merge it in? Thanks.
81e4340 to 40a2232
Hi, the conflicts are resolved now.
Summary
Add three TFLite sequence recurrent operators to the Relax frontend, all with
coupled input-forget gate (FULL kernel) and float32-only support.
From #19519.
Changes
time and stacks per-step hidden states. Supports time_major, cell_clip, proj_clip,
and fused activation.
reverse. Supports merge_outputs (concat fw + bw) and split outputs via Tuple.
the same input tensor. States at indices 35-38.
OpNotImplemented).
Testing
- test_unidirectional_sequence_lstm_none_activation: output shape [batch, time, num_units]
- test_bidirectional_sequence_rnn_none_activation: merge_outputs=True, shape [batch, time, 2*num_units]
- test_bidirectional_sequence_lstm_none_activation: merge_outputs=True, shape [batch, time, 2*num_units]

python -m pytest tests/python/relax/test_frontend_tflite.py -k "sequence_lstm or sequence_rnn" -v
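For reference, the merged output shape checked by the bidirectional tests can be reproduced with a small NumPy model. This is a sketch under assumptions rather than the converter or the TFLite kernel: `rnn_pass` and `bidirectional_rnn` are hypothetical helpers, the input is batch-major, the backward output is assumed to be stored at the time index of the input it consumed, and the activation is a plain parameter.

```python
import numpy as np


def rnn_pass(x, h0, W, R, b, act=np.tanh, reverse=False):
    """Simple RNN over x: [batch, time, input] -> [batch, time, units]."""
    batch, time, _ = x.shape
    order = range(time - 1, -1, -1) if reverse else range(time)
    out = np.empty((batch, time, W.shape[0]), dtype=x.dtype)
    h = h0
    for t in order:
        h = act(x[:, t] @ W.T + h @ R.T + b)
        out[:, t] = h          # stored at the input's own time index
    return out


def bidirectional_rnn(x, fw, bw, merge_outputs=True):
    """fw and bw are (h0, W, R, b) tuples for the two directions."""
    fw_out = rnn_pass(x, *fw)
    bw_out = rnn_pass(x, *bw, reverse=True)
    if merge_outputs:
        # Concatenate along the feature axis: [batch, time, 2 * units].
        return np.concatenate([fw_out, bw_out], axis=-1)
    return fw_out, bw_out


rng = np.random.default_rng(1)
batch, time, n_in, n_units = 2, 3, 4, 5


def make_params():
    return (
        np.zeros((batch, n_units)),
        rng.standard_normal((n_units, n_in)),
        rng.standard_normal((n_units, n_units)),
        rng.standard_normal(n_units),
    )


fw, bw = make_params(), make_params()
x = rng.standard_normal((batch, time, n_in))
merged = bidirectional_rnn(x, fw, bw, merge_outputs=True)
fw_out, bw_out = bidirectional_rnn(x, fw, bw, merge_outputs=False)
print(merged.shape)
```

With `merge_outputs=False` the two directions come back as a pair, which mirrors the split-output Tuple case described under Changes.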