From a105e989818f05065ede5a612a38d809328e26e9 Mon Sep 17 00:00:00 2001
From: Pablosinyores
Date: Mon, 29 Jun 2026 16:00:53 +0530
Subject: [PATCH] Fix BatchNorm docstring to cover 4D input and test the NHWC path

The BatchNorm summary line said it applies over a 2D or 3D input, but the
layer also accepts 4D NHWC tensors: the __call__ guard allows 2-4 dims
(and its error message says "2, 3 or 4 dimensions"), and the docstring
body already describes the four-dimensional NHWC shape. Correct the
summary to match, and add the missing 4D coverage to test_batch_norm,
asserting that an NHWC input is normalized per channel across N, H and
W.
---
 python/mlx/nn/layers/normalization.py |  2 +-
 python/tests/test_nn.py               | 10 ++++++++++
 2 files changed, 11 insertions(+), 1 deletion(-)

diff --git a/python/mlx/nn/layers/normalization.py b/python/mlx/nn/layers/normalization.py
index 01e83f4d53..f7d2d78c3b 100644
--- a/python/mlx/nn/layers/normalization.py
+++ b/python/mlx/nn/layers/normalization.py
@@ -256,7 +256,7 @@ def __call__(self, x):
 
 
 class BatchNorm(Module):
-    r"""Applies Batch Normalization over a 2D or 3D input.
+    r"""Applies Batch Normalization over a 2D, 3D or 4D input.
 
     Computes
 
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index a50e3e1c81..79974aac48 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -770,6 +770,16 @@ def test_batch_norm(self):
         self.assertIn("weight", bn_trainable)
         self.assertIn("bias", bn_trainable)
 
+        # test with 4D input (NHWC)
+        mx.random.seed(42)
+        x = mx.random.normal((2, 3, 3, 6), dtype=mx.float32)
+        bn = nn.BatchNorm(num_features=6, affine=True)
+        y = bn(x)
+        self.assertTrue(x.shape == y.shape)
+        # batch norm over an NHWC input normalizes each channel across N, H, W
+        self.assertTrue(mx.allclose(y.mean(axis=(0, 1, 2)), mx.zeros((6,)), atol=1e-5))
+        self.assertTrue(mx.allclose(y.var(axis=(0, 1, 2)), mx.ones((6,)), atol=1e-2))
+
     def test_batch_norm_stats(self):
         batch_size = 2
         num_features = 4