[Relax][Frontend][TFLite] Add RNN converter #19632
Code Review
This pull request adds support for converting the TFLite RNN and UNIDIRECTIONAL_SEQUENCE_RNN operators to TVM Relax, along with corresponding unit tests. The review feedback highlights two critical issues, one about stateful execution and one about layout consistency. First, the initial hidden state should be resolved from the model inputs when available, rather than always being zero-initialized. Second, the output of UNIDIRECTIONAL_SEQUENCE_RNN must respect the time_major layout attribute: when time_major is true, the stacked outputs must be transposed back to [time, batch, num_units], which also requires updating the test assertions and the model builder shapes.
python/tvm/relax/frontend/tflite/tflite_frontend.py (4536-4539)
If the hidden_state tensor is passed as a model input (i.e., it exists in exp_tab), we should use its resolved expression instead of always overwriting it with relax.op.zeros. This is crucial for supporting stateful RNN execution where the state is maintained across invocations.
# Zero-initialised hidden state [batch, num_units] if not provided as an input.
if self.has_expr(hidden_state_tensor.tensor_idx):
    h = self.get_expr(hidden_state_tensor.tensor_idx)
else:
    h_shape = tuple(to_int_list(self.get_tensor_shape(hidden_state_tensor)))
    h_dtype = self.get_tensor_type_str(hidden_state_tensor.tensor.Type())
    h = relax.op.zeros(h_shape, dtype=h_dtype)
python/tvm/relax/frontend/tflite/tflite_frontend.py (4608-4611)
Similar to convert_rnn, if the initial hidden_state is provided as a model input, we should use its resolved expression instead of always zero-initializing it.
# Zero-initialised hidden state [batch, num_units] if not provided as an input.
if self.has_expr(hidden_state_tensor.tensor_idx):
    h = self.get_expr(hidden_state_tensor.tensor_idx)
else:
    h_shape = tuple(to_int_list(self.get_tensor_shape(hidden_state_tensor)))
    h_dtype = self.get_tensor_type_str(hidden_state_tensor.tensor.Type())
    h = relax.op.zeros(h_shape, dtype=h_dtype)
python/tvm/relax/frontend/tflite/tflite_frontend.py (4630-4631)
According to the TFLite specification, the output of UNIDIRECTIONAL_SEQUENCE_RNN should have the same layout (time-major or batch-major) as the input. If time_major is True, the output must be transposed back to [time, batch, num_units] to prevent shape mismatches in downstream operators.
# Stack timestep outputs: [batch, time, num_units].
stacked = relax.op.stack(outputs, axis=1)
if time_major:
    return relax.op.permute_dims(stacked, [1, 0, 2])
return stacked
tests/python/relax/test_frontend_tflite.py (6939-6946)
Update the test model's output tensor shape to [time, batch, num_units] when time_major is True, to match the TFLite specification.
output_shape = [time, batch, num_units] if time_major else [batch, time, num_units]
tensors = [
    _t(0, input_shape),
    _t(1, [num_units, input_size]),
    _t(2, [num_units, num_units]),
    _t(3, [num_units]),
    _t(4, [batch, num_units], is_variable=True),
    _t(5, output_shape),
]
tests/python/relax/test_frontend_tflite.py (7093-7096)
Update the assertion to verify that the output is correctly time-major [time, batch, num_units] when time_major is True.
assert tuple(int(d) for d in in_shape) == (time, batch, input_size)
# Output is time-major [time, batch, num_units] when time_major=True.
out_shape = fn.ret_struct_info.shape
assert tuple(int(d) for d in out_shape) == (time, batch, num_units)
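To make the layout point in this review concrete, here is a small sketch in which NumPy stands in for the Relax ops. The sizes are arbitrary assumptions chosen for illustration. It shows that stacking per-timestep outputs on axis 1 yields batch-major data, and that a [1, 0, 2] permutation restores the time-major layout.

```python
import numpy as np

# Hypothetical sizes, chosen only for illustration.
time, batch, num_units = 3, 2, 4

# One [batch, num_units] output per timestep, as an unrolled loop would produce.
# Each array is filled with its timestep index so the layout is easy to check.
outputs = [np.full((batch, num_units), t, dtype="float32") for t in range(time)]

# Stacking on axis 1 gives batch-major data: [batch, time, num_units].
stacked = np.stack(outputs, axis=1)

# Swapping the first two axes restores time-major layout: [time, batch, num_units].
time_major_out = np.transpose(stacked, (1, 0, 2))

print(stacked.shape, time_major_out.shape)
```

After the permutation, indexing the first axis selects a timestep again, which is what a downstream operator expecting time-major input would rely on.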
…ut in RNN converter
tlopex left a comment
Thanks for adding the RNN converter.
I think there is one statefulness issue here: TFLite RNN updates the variable hidden_state after computing the activated output. In this converter, we return the new hidden state as the op output, but the original hidden_state_tensor binding in exp_tab is not updated. That means if a later op reads the same hidden-state tensor, it will still see the initial state rather than the state produced by this RNN step.
Could we either update the expression table entry for hidden_state_tensor.tensor_idx to the activated result, using override semantics, or explicitly reject stateful RNN graphs if Relax cannot model this mutation yet?
A regression test with two consecutive RNN ops sharing the same hidden_state would be useful here, since the second op should consume the first op’s updated state.
Thanks for the review. I updated the RNN converter to write the activated hidden state back to the original hidden_state tensor entry in exp_tab with force_override=True, so later ops observe the updated state. I also added a regression test with two consecutive RNN ops sharing the same hidden_state tensor to verify that the second RNN consumes the updated state from the first one.
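The write-back discussed in this thread can be pictured with a toy expression table. This is only a sketch: `ToyExprTable`, its `set_expr`/`get_expr` methods, `rnn_step`, and the tensor index `STATE` are hypothetical stand-ins rather than the TVM frontend classes, and NumPy arrays stand in for Relax expressions. It assumes a NONE activation and identity weights so the values are easy to follow.

```python
import numpy as np


class ToyExprTable:
    """Hypothetical stand-in for the frontend's expression table."""

    def __init__(self):
        self.exprs = {}

    def set_expr(self, idx, value, force_override=False):
        # An existing binding is kept unless the caller asks to override it.
        if idx not in self.exprs or force_override:
            self.exprs[idx] = value

    def get_expr(self, idx):
        return self.exprs[idx]


def rnn_step(tab, x, w, r, b, state_idx):
    """One RNN step (NONE activation) that writes the new state back."""
    h = tab.get_expr(state_idx)
    h_new = x @ w.T + h @ r.T + b
    # Rebind the shared hidden-state entry so later ops see the update.
    tab.set_expr(state_idx, h_new, force_override=True)
    return h_new


tab = ToyExprTable()
STATE = 4  # hypothetical tensor index of the shared hidden state
tab.set_expr(STATE, np.zeros((1, 2), dtype="float32"))

x = np.array([[1.0, 2.0]], dtype="float32")
eye = np.eye(2, dtype="float32")
bias = np.zeros(2, dtype="float32")

first = rnn_step(tab, x, eye, eye, bias, STATE)   # state was zero, so first == x
second = rnn_step(tab, x, eye, eye, bias, STATE)  # reads the updated state
```

Without the override, the second step would read the initial zero state and return the same value as the first, which is the behaviour the regression test is meant to rule out.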
Summary
Add Relax TFLite frontend support for RNN (BuiltinOperator 23), claimed in #19519 Group A.
Single-step RNN cell:
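The cell update can be sketched in NumPy as follows. This is an illustration only, assuming the op follows the usual TFLite basic-RNN formulation, h_new = act(x · Wᵀ + h · Rᵀ + b), with the weight shapes listed in this PR. `rnn_cell` is a hypothetical helper, not the converter itself, and the sizes are arbitrary.

```python
import numpy as np


def rnn_cell(x, input_weights, recurrent_weights, bias, h, activation=None):
    """Single RNN step: act(x @ W^T + h @ R^T + b).

    x: [batch, input_size], input_weights: [num_units, input_size],
    recurrent_weights: [num_units, num_units], bias: [num_units],
    h: [batch, num_units]. Returns the new state, shape [batch, num_units].
    """
    pre = x @ input_weights.T + h @ recurrent_weights.T + bias
    return pre if activation is None else activation(pre)


# Hypothetical sizes; input_size == num_units so identity weights are square.
batch, input_size, num_units = 2, 3, 3
x = np.arange(batch * input_size, dtype="float32").reshape(batch, input_size)
w = np.eye(num_units, input_size, dtype="float32")
r = np.eye(num_units, dtype="float32")
b = np.zeros(num_units, dtype="float32")
h0 = np.zeros((batch, num_units), dtype="float32")

# Identity weights, zero bias, zero state, no activation: output equals input.
out = rnn_cell(x, w, r, b, h0)

# RELU clamps the negative pre-activations to zero.
relu_out = rnn_cell(x - 2.0, w, r, b, h0, activation=lambda v: np.maximum(v, 0.0))
```

With identity weights and a zero initial state the cell reduces to the identity, which makes a NONE-activation case simple to check by hand.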
Changes
convert_rnn registered in convert_map (alphabetical, after RANGE)
input [batch, input_size], input_weights [num_units, input_size], recurrent_weights [num_units, num_units], bias [num_units], hidden_state [batch, num_units] (variable, zero-initialised)
[batch, num_units]
convert_fused_activation_function
OpNotImplemented
Testing
Two tests added to tests/python/relax/test_frontend_tflite.py:
test_rnn_none_activation: tvm.ir.assert_structural_equal with identity weights, NONE activation
test_rnn_relu_activation: shape check, random weights, RELU activation
References