diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index c146cf6c00e3..7296f73e9a0f 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1645,12 +1645,72 @@ def _sum(self, node: fx.Node) -> relax.Var:
         return self.block_builder.emit(relax.op.sum(x, dim, keepdims=keepdim))
 
     def _var(self, node: fx.Node) -> relax.Var:
+        # `aten.var.correction` (and decomposed `aten.std.*`) carries an
+        # optional `correction` kwarg whose `None` default means 1 (Bessel).
+        # Legacy fx `tensor.var(...)` calls go through the original path
+        # below to keep this fix narrowly scoped.
+        target = node.target
+        if getattr(target, "_overloadname", None) == "correction" or getattr(
+            target, "overload_name", None
+        ) == "correction":
+            return self._var_correction(node)
         args = self.retrieve_args(node)
         x = args[0]
         dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
         keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
         return self.block_builder.emit(relax.op.variance(x, dim, keepdims=keepdim))
 
+    def _var_correction(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+        keepdim = node.kwargs.get("keepdim", False)
+        correction = node.kwargs.get("correction", None)
+        if correction is None:
+            correction = 1
+        var = self.block_builder.emit(relax.op.variance(x, dim, keepdims=keepdim))
+        if correction == 0:
+            return var
+        n = self._reduction_size(x, dim)
+        if n is None:
+            raise NotImplementedError(
+                "var/std with non-zero correction requires statically known "
+                "reduction-axis sizes."
+            )
+        # PyTorch returns NaN (with a warning) when `n - correction <= 0`;
+        # mirror that semantics rather than failing the import.
+        if n - correction <= 0:
+            scale = float("nan")
+        else:
+            scale = float(n) / float(n - correction)
+        return self.block_builder.emit(
+            relax.op.multiply(var, relax.const(scale, x.struct_info.dtype))
+        )
+
+    @staticmethod
+    def _reduction_size(x: relax.Expr, dim) -> "int | None":
+        """Static product of reduced-axis sizes; None if any axis is dynamic."""
+        shape = x.struct_info.shape
+        if shape is None:
+            return None
+        rank = len(shape)
+        if dim is None:
+            axes = list(range(rank))
+        elif isinstance(dim, int):
+            axes = [dim]
+        elif isinstance(dim, (list, tuple)) and all(isinstance(a, int) for a in dim):
+            axes = list(dim)
+        else:
+            return None
+        n = 1
+        for ax in axes:
+            ax = ax + rank if ax < 0 else ax
+            s = shape[ax]
+            if not isinstance(s, tvm.tir.IntImm):
+                return None
+            n *= int(s.value)
+        return n
+
     def _any(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         x = args[0]
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 602949937247..cb0e4a80a8fe 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -7492,6 +7492,7 @@ def main(
 
 
 def test_std():
+    # torch.std(x) defaults to correction=1 (Bessel); decomposes to var.correction + sqrt.
     class Std(Module):
         def forward(self, x):
             return torch.std(x)
@@ -7504,8 +7505,9 @@ def main(
         ) -> R.Tuple(R.Tensor((), dtype="float32")):
             with R.dataflow():
                 lv: R.Tensor((), dtype="float32") = R.variance(x, axis=None, keepdims=False)
-                lv1: R.Tensor((), dtype="float32") = R.sqrt(lv)
-                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv1,)
+                lv1: R.Tensor((), dtype="float32") = R.multiply(lv, R.const(15.0 / 14.0, "float32"))
+                lv2: R.Tensor((), dtype="float32") = R.sqrt(lv1)
+                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv2,)
                 R.output(gv)
             return gv
 
@@ -7514,6 +7516,7 @@ def main(
 
 
 def test_var():
+    # torch.var(x) defaults to correction=1 (Bessel).
     class Var(Module):
         def forward(self, x):
             return torch.var(x)
@@ -7526,7 +7529,8 @@ def main(
         ) -> R.Tuple(R.Tensor((), dtype="float32")):
             with R.dataflow():
                 lv: R.Tensor((), dtype="float32") = R.variance(x, axis=None, keepdims=False)
-                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
+                lv1: R.Tensor((), dtype="float32") = R.multiply(lv, R.const(15.0 / 14.0, "float32"))
+                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv1,)
                 R.output(gv)
             return gv
 
@@ -7534,6 +7538,45 @@ def main(
     verify_model(Var(), example_args, {}, Expected)
 
 
+def test_var_correction():
+    class VarCorrection2(Module):
+        def forward(self, x):
+            return torch.var(x, dim=-1, correction=2)
+
+    class VarCorrection0(Module):
+        def forward(self, x):
+            return torch.var(x, dim=1, correction=0)
+
+    @tvm.script.ir_module
+    class Expected2:
+        @R.function
+        def main(
+            x: R.Tensor((2, 5), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((2,), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((2,), dtype="float32") = R.variance(x, axis=[-1], keepdims=False)
+                lv1: R.Tensor((2,), dtype="float32") = R.multiply(lv, R.const(5.0 / 3.0, "float32"))
+                gv: R.Tuple(R.Tensor((2,), dtype="float32")) = (lv1,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected0:
+        @R.function
+        def main(
+            x: R.Tensor((2, 5), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((2,), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((2,), dtype="float32") = R.variance(x, axis=[1], keepdims=False)
+                gv: R.Tuple(R.Tensor((2,), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(2, 5, dtype=torch.float32),)
+    verify_model(VarCorrection2(), example_args, {}, Expected2)
+    verify_model(VarCorrection0(), example_args, {}, Expected0)
+
+
 def test_prod():
     class Prod(Module):
         def forward(self, x):