Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions python/tvm/relax/frontend/tflite/tflite_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1630,6 +1630,9 @@ def convert_gather_nd(self, op):

indices_dims = len(self._infer_shape(indices))
indices_t = relax.op.permute_dims(indices, axes=[-1] + list(range(indices_dims - 1)))
if indices_type == TensorType.INT32:
# Relax gather_nd requires int64 indices.
indices_t = relax.op.astype(indices_t, "int64")

out = relax.op.gather_nd(data, indices_t)
return out
Expand Down
58 changes: 58 additions & 0 deletions tests/python/relax/test_frontend_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -1451,6 +1451,64 @@ def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float3

verify(ReverseV2, Expected)


def test_gather():
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This test covers gather with int64 indices. For completeness, it would be beneficial to also add a test case for int32 indices, as TFLite's GATHER op supports both int32 and int64 for indices.

class Gather(tf.Module):
@tf.function(
input_signature=[
tf.TensorSpec(shape=(2, 3, 4), dtype=tf.float32),
tf.TensorSpec(shape=(2,), dtype=tf.int64),
]
)
def func(self, x, indices):
return tf.gather(x, indices, axis=1)

@I.ir_module
class Expected:
@R.function
def main(
x: R.Tensor((2, 3, 4), dtype="float32"),
indices: R.Tensor((2,), dtype="int64"),
) -> R.Tensor((2, 2, 4), dtype="float32"):
R.func_attr({"num_input": 2})
with R.dataflow():
lv: R.Tensor((2,), dtype="int32") = R.astype(indices, dtype="int32")
gv: R.Tensor((2, 2, 4), dtype="float32") = R.take(x, lv, axis=1, mode="fast")
R.output(gv)
return gv

verify(Gather, Expected)


def test_gather_nd():
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This test covers gather_nd with int32 indices, which is great for verifying the change in this PR. To make the test coverage for this operator more complete, could you also add a test case for int64 indices? This would ensure the int64 path is not broken by future changes and remains tested.

class GatherND(tf.Module):
@tf.function(
input_signature=[
tf.TensorSpec(shape=(2, 3, 4), dtype=tf.float32),
tf.TensorSpec(shape=(2, 2), dtype=tf.int32),
]
)
def func(self, x, indices):
return tf.gather_nd(x, indices)

@I.ir_module
class Expected:
@R.function
def main(
x: R.Tensor((2, 3, 4), dtype="float32"),
indices: R.Tensor((2, 2), dtype="int32"),
) -> R.Tensor((2, 4), dtype="float32"):
R.func_attr({"num_input": 2})
with R.dataflow():
lv: R.Tensor((2, 2), dtype="int32") = R.permute_dims(indices, axes=[-1, 0])
lv1: R.Tensor((2, 2), dtype="int64") = R.astype(lv, dtype="int64")
gv: R.Tensor((2, 4), dtype="float32") = R.gather_nd(x, lv1, batch_dims=0)
R.output(gv)
return gv

verify(GatherND, Expected)


def _make_conv2d_module(data_shape, kernel_shape, data_format, strides, padding):
class Conv2DModule(tf.Module):
@tf.function(
Expand Down