diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index e66dff8356c8..f5b88b0c6ad5 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -1630,6 +1630,9 @@ def convert_gather_nd(self, op): indices_dims = len(self._infer_shape(indices)) indices_t = relax.op.permute_dims(indices, axes=[-1] + list(range(indices_dims - 1))) + if indices_type == TensorType.INT32: + # Relax gather_nd requires int64 indices. + indices_t = relax.op.astype(indices_t, "int64") out = relax.op.gather_nd(data, indices_t) return out diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index 69e9b290fd32..e4c237887e6e 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -1451,6 +1451,64 @@ def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float3 verify(ReverseV2, Expected) + +def test_gather(): + class Gather(tf.Module): + @tf.function( + input_signature=[ + tf.TensorSpec(shape=(2, 3, 4), dtype=tf.float32), + tf.TensorSpec(shape=(2,), dtype=tf.int64), + ] + ) + def func(self, x, indices): + return tf.gather(x, indices, axis=1) + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 4), dtype="float32"), + indices: R.Tensor((2,), dtype="int64"), + ) -> R.Tensor((2, 2, 4), dtype="float32"): + R.func_attr({"num_input": 2}) + with R.dataflow(): + lv: R.Tensor((2,), dtype="int32") = R.astype(indices, dtype="int32") + gv: R.Tensor((2, 2, 4), dtype="float32") = R.take(x, lv, axis=1, mode="fast") + R.output(gv) + return gv + + verify(Gather, Expected) + + +def test_gather_nd(): + class GatherND(tf.Module): + @tf.function( + input_signature=[ + tf.TensorSpec(shape=(2, 3, 4), dtype=tf.float32), + tf.TensorSpec(shape=(2, 2), dtype=tf.int32), + ] + ) + def func(self, x, indices): + return tf.gather_nd(x, indices) + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 4), dtype="float32"), + indices: R.Tensor((2, 2), dtype="int32"), + ) -> R.Tensor((2, 4), dtype="float32"): + R.func_attr({"num_input": 2}) + with R.dataflow(): + lv: R.Tensor((2, 2), dtype="int32") = R.permute_dims(indices, axes=[-1, 0]) + lv1: R.Tensor((2, 2), dtype="int64") = R.astype(lv, dtype="int64") + gv: R.Tensor((2, 4), dtype="float32") = R.gather_nd(x, lv1, batch_dims=0) + R.output(gv) + return gv + + verify(GatherND, Expected) + + def _make_conv2d_module(data_shape, kernel_shape, data_format, strides, padding): class Conv2DModule(tf.Module): @tf.function(