Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions python/tvm/relax/frontend/tflite/tflite_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None):
"POW": functools.partial(self._convert_elemwise, relax_op=_op.power),
"PRELU": self.convert_prelu,
"RANGE": self.convert_range,
"RNN": self.convert_rnn,
"QUANTIZE": self.convert_quantize,
"RANDOM_STANDARD_NORMAL": self.convert_random_standard_normal,
"RANDOM_UNIFORM": self.convert_random_uniform,
Expand Down Expand Up @@ -4878,6 +4879,84 @@ def convert_unpack(self, op):

return squeezed

def convert_rnn(self, op):
"""Convert TFLite RNN.

Single-step RNN cell.

Inputs (5 tensors):
[0] input [batch, input_size]
[1] input_weights [num_units, input_size]
[2] recurrent_weights [num_units, num_units]
[3] bias [num_units]
[4] hidden_state [batch, num_units] (variable, zero-initialised)

Output:
[0] output [batch, num_units]

Cell equation:
h = fused_activation(x @ W.T + h @ Wr.T + b)
"""
from tflite.BuiltinOptions import BuiltinOptions
from tflite.RNNOptions import RNNOptions

if self.is_quantized(op):
raise tvm.error.OpNotImplemented("TFLite quantized RNN is not supported yet.")

input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 5, "input tensors length should be 5"

input_tensor = input_tensors[0]
weights_tensor = input_tensors[1]
recurrent_tensor = input_tensors[2]
bias_tensor = input_tensors[3]
hidden_state_tensor = input_tensors[4]

output_tensors = self.get_output_tensors(op)
assert len(output_tensors) >= 1, "output tensors length should be at least 1"

assert op.BuiltinOptionsType() == BuiltinOptions.RNNOptions
op_options = op.BuiltinOptions()
rnn_options = RNNOptions()
rnn_options.Init(op_options.Bytes, op_options.Pos)
fused_activation_fn = rnn_options.FusedActivationFunction()

# Constant weight/bias expressions.
weights_expr = self.get_tensor_expr(weights_tensor) # [num_units, input_size]
recurrent_expr = self.get_tensor_expr(recurrent_tensor) # [num_units, num_units]
bias_expr = self.get_tensor_expr(bias_tensor) # [num_units]

# Transpose to [input_size, num_units] and [num_units, num_units] for x @ W.T.
w_t = relax.op.permute_dims(weights_expr)
wr_t = relax.op.permute_dims(recurrent_expr)

# Resolve the input expression.
in_expr = self.get_tensor_expr(input_tensor)

# Initial hidden state: use the model's tensor value when available (non-zero init or
# graph input), otherwise fall back to zeros for the common variable-tensor case.
h_dtype = self.get_tensor_type_str(hidden_state_tensor.tensor.Type())
if self.has_expr(hidden_state_tensor.tensor_idx) or (
hidden_state_tensor.buffer is not None and hidden_state_tensor.buffer.DataLength() > 0
):
h = self.get_tensor_expr(hidden_state_tensor)
else:
h_shape = tuple(to_int_list(self.get_tensor_shape(hidden_state_tensor)))
h = relax.op.zeros(h_shape, dtype=h_dtype)

gates = relax.op.add(
relax.op.add(relax.op.matmul(in_expr, w_t), relax.op.matmul(h, wr_t)),
bias_expr,
)
h = self.convert_fused_activation_function(gates, fused_activation_fn)

self.exp_tab.set_expr(
get_tensor_name(self.subgraph, hidden_state_tensor.tensor_idx),
h,
force_override=True,
)
return h

def convert_unidirectional_sequence_rnn(self, op):
"""Convert TFLite UNIDIRECTIONAL_SEQUENCE_RNN.

Expand Down
290 changes: 290 additions & 0 deletions tests/python/relax/test_frontend_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -3720,6 +3720,7 @@ def _get_tflite_schema_enum(enum_name):
_tfl_sparse_index_vector = _get_tflite_schema_enum("SparseIndexVector")
_tfl_tensor_type = _get_tflite_schema_enum("TensorType")

_tfl_rnn_options = _get_tflite_schema_module("RNNOptions")
_tfl_sequence_rnn_options = _get_tflite_schema_module("SequenceRNNOptions")

_DENSIFY_TEST_VALUES = np.array([1.0, 2.0], dtype=np.float32)
Expand Down Expand Up @@ -9721,6 +9722,295 @@ def main(
tvm.ir.assert_structural_equal(mod, Expected)


# ── RNN ────────────────────────────────────────────────────────────────────────


def _build_rnn_model(batch, input_size, num_units, weights, recurrent_weights, bias, activation):
"""Build a minimal TFLite flatbuffer model containing one RNN op.

Tensor layout (indices 0-5):
0 - input [batch, input_size]
1 - input_weights [num_units, input_size] (constant)
2 - recurrent_weights [num_units, num_units] (constant)
3 - bias [num_units] (constant)
4 - hidden_state [batch, num_units] (variable, zero-initialised)
5 - output [batch, num_units]
"""
builder = flatbuffers.Builder(4096)

_tfl_rnn_options.RNNOptionsStart(builder)
_tfl_rnn_options.RNNOptionsAddFusedActivationFunction(builder, activation)
rnn_opts = _tfl_rnn_options.RNNOptionsEnd(builder)

rnn_op_code = _build_operator_code(builder, _tfl_builtin_operator.RNN)

def _t(buf_idx, shape, is_variable=False):
shape_vec = _tflite_shape(builder, shape)
_tfl_tensor.TensorStart(builder)
_tfl_tensor.TensorAddBuffer(builder, buf_idx)
_tfl_tensor.TensorAddHasRank(builder, True)
_tfl_tensor.TensorAddIsVariable(builder, is_variable)
_tfl_tensor.TensorAddShape(builder, shape_vec)
_tfl_tensor.TensorAddType(builder, _tfl_tensor_type.FLOAT32)
return _tfl_tensor.TensorEnd(builder)

tensors = [
_t(0, [batch, input_size]),
_t(1, [num_units, input_size]),
_t(2, [num_units, num_units]),
_t(3, [num_units]),
_t(4, [batch, num_units], is_variable=True),
_t(5, [batch, num_units]),
]

rnn_op = _build_operator(
builder,
0,
[0, 1, 2, 3, 4],
[5],
builtin_options_type=_tfl_builtin_options.RNNOptions,
builtin_options=rnn_opts,
)

subgraph = _build_subgraph(
builder,
tensors=tensors,
operators=[rnn_op],
inputs=[0],
outputs=[5],
)

buffers = [
_build_buffer(builder),
_build_buffer(builder, weights.tobytes()),
_build_buffer(builder, recurrent_weights.tobytes()),
_build_buffer(builder, bias.tobytes()),
_build_buffer(builder),
_build_buffer(builder),
]

return _finish_tflite_model(
builder,
subgraph=subgraph,
operator_codes=[rnn_op_code],
buffers=buffers,
)


def _build_two_step_shared_state_rnn_model(
batch, input_size, num_units, weights, recurrent_weights, bias, activation
):
"""Build a TFLite model with two RNN ops sharing the same hidden-state tensor."""
builder = flatbuffers.Builder(4096)

_tfl_rnn_options.RNNOptionsStart(builder)
_tfl_rnn_options.RNNOptionsAddFusedActivationFunction(builder, activation)
rnn_opts = _tfl_rnn_options.RNNOptionsEnd(builder)

rnn_op_code = _build_operator_code(builder, _tfl_builtin_operator.RNN)

def _t(buf_idx, shape, is_variable=False):
shape_vec = _tflite_shape(builder, shape)
_tfl_tensor.TensorStart(builder)
_tfl_tensor.TensorAddBuffer(builder, buf_idx)
_tfl_tensor.TensorAddHasRank(builder, True)
_tfl_tensor.TensorAddIsVariable(builder, is_variable)
_tfl_tensor.TensorAddShape(builder, shape_vec)
_tfl_tensor.TensorAddType(builder, _tfl_tensor_type.FLOAT32)
return _tfl_tensor.TensorEnd(builder)

tensors = [
_t(0, [batch, input_size]),
_t(1, [num_units, input_size]),
_t(2, [num_units, num_units]),
_t(3, [num_units]),
_t(4, [batch, num_units], is_variable=True),
_t(0, [batch, input_size]),
_t(0, [batch, num_units]),
_t(0, [batch, num_units]),
]

first_rnn_op = _build_operator(
builder,
0,
[0, 1, 2, 3, 4],
[6],
builtin_options_type=_tfl_builtin_options.RNNOptions,
builtin_options=rnn_opts,
)
second_rnn_op = _build_operator(
builder,
0,
[5, 1, 2, 3, 4],
[7],
builtin_options_type=_tfl_builtin_options.RNNOptions,
builtin_options=rnn_opts,
)

subgraph = _build_subgraph(
builder,
tensors=tensors,
operators=[first_rnn_op, second_rnn_op],
inputs=[0, 5],
outputs=[7],
)

buffers = [
_build_buffer(builder),
_build_buffer(builder, weights.tobytes()),
_build_buffer(builder, recurrent_weights.tobytes()),
_build_buffer(builder, bias.tobytes()),
_build_buffer(builder),
]

return _finish_tflite_model(
builder,
subgraph=subgraph,
operator_codes=[rnn_op_code],
buffers=buffers,
)


def test_rnn_none_activation():
"""RNN with NONE activation lowers to matmul/add.

Cell equation: h = x @ W.T + h @ Wr.T + b (no activation for NONE)
"""
from tflite.ActivationFunctionType import ActivationFunctionType

batch, input_size, num_units = 2, 2, 2
weights = np.eye(num_units, input_size, dtype=np.float32)
recurrent_weights = np.eye(num_units, dtype=np.float32)
bias = np.zeros(num_units, dtype=np.float32)

mod = _load_model_from_buffer(
_build_rnn_model(
batch,
input_size,
num_units,
weights,
recurrent_weights,
bias,
ActivationFunctionType.NONE,
)
)

@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor((2, 2), dtype="float32")) -> R.Tensor((2, 2), dtype="float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
lv: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
R.const(np.eye(2, dtype=np.float32)), axes=None
)
lv1: R.Tensor((2, 2), dtype="float32") = R.matmul(x, lv, out_dtype="void")
lv2: R.Tensor((2, 2), dtype="float32") = R.zeros(R.shape([2, 2]), dtype="float32")
lv3: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
R.const(np.eye(2, dtype=np.float32)), axes=None
)
lv4: R.Tensor((2, 2), dtype="float32") = R.matmul(lv2, lv3, out_dtype="void")
lv5: R.Tensor((2, 2), dtype="float32") = R.add(lv1, lv4)
gv: R.Tensor((2, 2), dtype="float32") = R.add(
lv5, R.const(np.zeros(2, dtype=np.float32))
)
R.output(gv)
return gv

tvm.ir.assert_structural_equal(mod, Expected)


def test_rnn_relu_activation():
"""RNN with RELU activation and random weights."""
from tflite.ActivationFunctionType import ActivationFunctionType

batch, input_size, num_units = 2, 4, 8
np.random.seed(42)
weights = np.random.randn(num_units, input_size).astype(np.float32)
recurrent_weights = np.random.randn(num_units, num_units).astype(np.float32)
bias = np.random.randn(num_units).astype(np.float32)

mod = _load_model_from_buffer(
_build_rnn_model(
batch,
input_size,
num_units,
weights,
recurrent_weights,
bias,
ActivationFunctionType.RELU,
)
)

fn = mod["main"]
assert len(fn.params) == 1, "only the input should be a graph input"
in_shape = fn.params[0].struct_info.shape
assert tuple(int(d) for d in in_shape) == (batch, input_size)
out_shape = fn.ret_struct_info.shape
assert tuple(int(d) for d in out_shape) == (batch, num_units)


def test_rnn_shared_hidden_state_updates_exp_tab():
"""Two consecutive RNN ops sharing hidden_state should use the updated state."""
from tflite.ActivationFunctionType import ActivationFunctionType

batch, input_size, num_units = 2, 2, 2
weights = np.eye(num_units, input_size, dtype=np.float32)
recurrent_weights = np.eye(num_units, dtype=np.float32)
bias = np.zeros(num_units, dtype=np.float32)

mod = _load_model_from_buffer(
_build_two_step_shared_state_rnn_model(
batch,
input_size,
num_units,
weights,
recurrent_weights,
bias,
ActivationFunctionType.NONE,
)
)

@I.ir_module
class Expected:
@R.function
def main(
x0: R.Tensor((2, 2), dtype="float32"),
x1: R.Tensor((2, 2), dtype="float32"),
) -> R.Tensor((2, 2), dtype="float32"):
R.func_attr({"num_input": 2})
with R.dataflow():
lv: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
R.const(np.eye(2, dtype=np.float32)), axes=None
)
lv1: R.Tensor((2, 2), dtype="float32") = R.matmul(x0, lv, out_dtype="void")
lv2: R.Tensor((2, 2), dtype="float32") = R.zeros(R.shape([2, 2]), dtype="float32")
lv3: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
R.const(np.eye(2, dtype=np.float32)), axes=None
)
lv4: R.Tensor((2, 2), dtype="float32") = R.matmul(lv2, lv3, out_dtype="void")
lv5: R.Tensor((2, 2), dtype="float32") = R.add(lv1, lv4)
lv6: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
R.const(np.eye(2, dtype=np.float32)), axes=None
)
lv7: R.Tensor((2, 2), dtype="float32") = R.matmul(x1, lv6, out_dtype="void")
lv8: R.Tensor((2, 2), dtype="float32") = R.add(
lv5, R.const(np.zeros(2, dtype=np.float32))
)
lv9: R.Tensor((2, 2), dtype="float32") = R.permute_dims(
R.const(np.eye(2, dtype=np.float32)), axes=None
)
lv10: R.Tensor((2, 2), dtype="float32") = R.matmul(lv8, lv9, out_dtype="void")
lv11: R.Tensor((2, 2), dtype="float32") = R.add(lv7, lv10)
gv: R.Tensor((2, 2), dtype="float32") = R.add(
lv11, R.const(np.zeros(2, dtype=np.float32))
)
R.output(gv)
return gv

tvm.ir.assert_structural_equal(mod, Expected)


# ── UNIDIRECTIONAL_SEQUENCE_RNN ───────────────────────────────────────────────


Expand Down
Loading