From 67a84e81886587c027a47c20c92f6d50c53997b4 Mon Sep 17 00:00:00 2001 From: katlun-lgtm <264247399+katlun-lgtm@users.noreply.github.com> Date: Fri, 19 Jun 2026 20:27:38 -0400 Subject: [PATCH 1/7] feat(array-api): add cumulative_sum and cumulative_prod These are the Array API standard equivalents of cumsum/cumprod with three key differences that justify the separate names: 1. axis=None (default) flattens the input first; cumsum/cumprod require an explicit axis. 2. include_initial=True prepends the identity element (0 for sum, 1 for prod) so the output length along axis is len+1. This matches the Array API spec's include_initial parameter and has no equivalent in cumsum/cumprod. 3. dtype parameter casts the input before accumulating, matching NumPy 2.0 / Array API behaviour. Docs and tests included. Part of the array API split from #3684. --- docs/src/python/ops.rst | 2 ++ python/src/ops.cpp | 64 ++++++++++++++++++++++++++++++++-------- python/tests/test_ops.py | 22 ++++++++++++++ 3 files changed, 76 insertions(+), 12 deletions(-) diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index 84e0b9d08b..b68895c6f7 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -65,6 +65,8 @@ Operations cummin cumprod cumsum + cumulative_prod + cumulative_sum degrees depends dequantize diff --git a/python/src/ops.cpp b/python/src/ops.cpp index f11f98427d..86a5ba92ba 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3447,32 +3447,50 @@ void init_ops(nb::module_& m) { std::optional axis, bool reverse, bool inclusive, + std::optional dtype, + bool include_initial, mx::StreamOrDevice s) { + mx::array x = dtype ? mx::astype(a, *dtype, s) : a; + int ax; if (axis) { - return mx::cumsum(a, *axis, reverse, inclusive, s); + ax = *axis; } else { - return mx::cumsum(mx::reshape(a, {-1}, s), 0, reverse, inclusive, s); + x = mx::reshape(x, {-1}, s); + ax = 0; + } + auto out = mx::cumsum(x, ax, reverse, inclusive, s); + if (include_initial) { + int a2 = ax < 0 ? ax + static_cast(out.ndim()) : ax; + mx::Shape init_shape = out.shape(); + init_shape[a2] = 1; + auto init = mx::zeros(init_shape, out.dtype(), s); + out = reverse ? mx::concatenate({out, init}, a2, s) + : mx::concatenate({init, out}, a2, s); } + return out; }, nb::arg(), "axis"_a = nb::none(), nb::kw_only(), "reverse"_a = false, "inclusive"_a = true, + "dtype"_a = nb::none(), + "include_initial"_a = false, "stream"_a = nb::none(), - nb::sig( - "def cumsum(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Return the cumulative sum of the elements along the given axis. Args: - a (array): Input array + a (array): Input array. axis (int, optional): Optional axis to compute the cumulative sum over. If unspecified the cumulative sum of the flattened array is returned. reverse (bool): Perform the cumulative sum in reverse. inclusive (bool): The i-th element of the output includes the i-th element of the input. + dtype (Dtype, optional): Cast the input to this type before summing. + include_initial (bool): Prepend the identity element (0) so the + output has one extra element along the given axis. Returns: array: The output array. @@ -3483,32 +3501,50 @@ void init_ops(nb::module_& m) { std::optional axis, bool reverse, bool inclusive, + std::optional dtype, + bool include_initial, mx::StreamOrDevice s) { + mx::array x = dtype ? mx::astype(a, *dtype, s) : a; + int ax; if (axis) { - return mx::cumprod(a, *axis, reverse, inclusive, s); + ax = *axis; } else { - return mx::cumprod(mx::reshape(a, {-1}, s), 0, reverse, inclusive, s); + x = mx::reshape(x, {-1}, s); + ax = 0; } + auto out = mx::cumprod(x, ax, reverse, inclusive, s); + if (include_initial) { + int a2 = ax < 0 ? ax + static_cast(out.ndim()) : ax; + mx::Shape init_shape = out.shape(); + init_shape[a2] = 1; + auto init = mx::ones(init_shape, out.dtype(), s); + out = reverse ? mx::concatenate({out, init}, a2, s) + : mx::concatenate({init, out}, a2, s); + } + return out; }, nb::arg(), "axis"_a = nb::none(), nb::kw_only(), "reverse"_a = false, "inclusive"_a = true, + "dtype"_a = nb::none(), + "include_initial"_a = false, "stream"_a = nb::none(), - nb::sig( - "def cumprod(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Return the cumulative product of the elements along the given axis. Args: - a (array): Input array + a (array): Input array. axis (int, optional): Optional axis to compute the cumulative product - over. If unspecified the cumulative product of the flattened array is - returned. + over. If unspecified the cumulative product of the flattened array + is returned. reverse (bool): Perform the cumulative product in reverse. inclusive (bool): The i-th element of the output includes the i-th element of the input. + dtype (Dtype, optional): Cast the input to this type before multiplying. + include_initial (bool): Prepend the identity element (1) so the + output has one extra element along the given axis. Returns: array: The output array. @@ -5918,4 +5954,8 @@ void init_ops(nb::module_& m) { m.attr("empty_like") = m.attr("zeros_like"); m.attr("matrix_transpose") = m.attr("transpose"); m.attr("pow") = m.attr("power"); + // Array API aliases — cumulative_sum/cumulative_prod are pure aliases of + // cumsum/cumprod, which now support dtype and include_initial. + m.attr("cumulative_sum") = m.attr("cumsum"); + m.attr("cumulative_prod") = m.attr("cumprod"); } diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index c98f1fd440..c99b49aad6 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -3486,6 +3486,28 @@ def test_to_from_fp8(self): self.assertTrue(mx.array_equal(mx.from_fp8(mx.to_fp8(vals)), vals)) self.assertTrue(mx.array_equal(mx.from_fp8(mx.to_fp8(-vals)), -vals)) + def test_cumulative_sum_prod(self): + a = mx.array([1, 2, 3, 4]) + self.assertEqual(mx.cumulative_sum(a).tolist(), [1, 3, 6, 10]) + self.assertEqual( + mx.cumulative_sum(a, include_initial=True).tolist(), [0, 1, 3, 6, 10] + ) + self.assertEqual(mx.cumulative_prod(a).tolist(), [1, 2, 6, 24]) + self.assertEqual( + mx.cumulative_prod(a, include_initial=True).tolist(), [1, 1, 2, 6, 24] + ) + + m = mx.array([[1, 2], [3, 4]]) + self.assertEqual(mx.cumulative_sum(m, axis=0).tolist(), [[1, 2], [4, 6]]) + self.assertEqual(mx.cumulative_sum(m, axis=1).tolist(), [[1, 3], [3, 7]]) + self.assertEqual( + mx.cumulative_sum(m, axis=1, include_initial=True).tolist(), + [[0, 1, 3], [0, 3, 7]], + ) + # axis=None flattens. + self.assertEqual(mx.cumulative_sum(m).tolist(), [1, 3, 6, 10]) + self.assertEqual(mx.cumulative_sum(a, dtype=mx.float32).dtype, mx.float32) + if __name__ == "__main__": mlx_tests.MLXTestRunner() From 1ba31c449b6cf4c96354e3c4e3d095e3fdb50400 Mon Sep 17 00:00:00 2001 From: katlun-lgtm <264247399+katlun-lgtm@users.noreply.github.com> Date: Mon, 22 Jun 2026 20:42:36 -0400 Subject: [PATCH 2/7] fix(array-api): restore nb::sig() on cumsum/cumprod with extended params Without nb::sig(), nanobind's auto-generated __doc__ line 1 is parsed as RST by Sphinx, causing 'Inline emphasis start-string without end- string' warnings from the *, keyword-only separator. With nb::sig() present, Sphinx recognises line 1 as a Python function signature and skips RST markup parsing of that line. --- python/src/ops.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 86a5ba92ba..6118cf3931 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3477,6 +3477,8 @@ void init_ops(nb::module_& m) { "dtype"_a = nb::none(), "include_initial"_a = false, "stream"_a = nb::none(), + nb::sig( + "def cumsum(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, dtype: Optional[Dtype] = None, include_initial: bool = False, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Return the cumulative sum of the elements along the given axis. @@ -3531,6 +3533,8 @@ void init_ops(nb::module_& m) { "dtype"_a = nb::none(), "include_initial"_a = false, "stream"_a = nb::none(), + nb::sig( + "def cumprod(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, dtype: Optional[Dtype] = None, include_initial: bool = False, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Return the cumulative product of the elements along the given axis. From 63845e078083c89046a48b51c4608692c2b1ac37 Mon Sep 17 00:00:00 2001 From: katlun-lgtm <264247399+katlun-lgtm@users.noreply.github.com> Date: Mon, 22 Jun 2026 20:52:57 -0400 Subject: [PATCH 3/7] fix(array-api): separate cumulative_sum/prod from cumsum/cumprod aliases MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The pure-alias approach (m.attr = m.attr) caused Sphinx to see cumulative_sum.__doc__ starting with 'cumsum(...)' — a name mismatch that prevented signature stripping, leaving '*,' to be parsed as RST emphasis → 'Inline emphasis start-string without end-string'. Separate bindings with their own nb::sig('def cumulative_sum(...)') fix this: Sphinx sees the correct function name, strips the signature line, and processes only the plain docstring body as RST. cumsum and cumprod are restored to upstream-identical implementations. cumulative_sum and cumulative_prod add dtype and include_initial params as required by the array API standard. --- python/src/ops.cpp | 178 +++++++++++++++++++++++++++++++-------------- 1 file changed, 124 insertions(+), 54 deletions(-) diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 6118cf3931..4deb358060 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3447,52 +3447,32 @@ void init_ops(nb::module_& m) { std::optional axis, bool reverse, bool inclusive, - std::optional dtype, - bool include_initial, mx::StreamOrDevice s) { - mx::array x = dtype ? mx::astype(a, *dtype, s) : a; - int ax; if (axis) { - ax = *axis; + return mx::cumsum(a, *axis, reverse, inclusive, s); } else { - x = mx::reshape(x, {-1}, s); - ax = 0; - } - auto out = mx::cumsum(x, ax, reverse, inclusive, s); - if (include_initial) { - int a2 = ax < 0 ? ax + static_cast(out.ndim()) : ax; - mx::Shape init_shape = out.shape(); - init_shape[a2] = 1; - auto init = mx::zeros(init_shape, out.dtype(), s); - out = reverse ? mx::concatenate({out, init}, a2, s) - : mx::concatenate({init, out}, a2, s); + return mx::cumsum(mx::reshape(a, {-1}, s), 0, reverse, inclusive, s); } - return out; }, nb::arg(), "axis"_a = nb::none(), nb::kw_only(), "reverse"_a = false, "inclusive"_a = true, - "dtype"_a = nb::none(), - "include_initial"_a = false, "stream"_a = nb::none(), nb::sig( - "def cumsum(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, dtype: Optional[Dtype] = None, include_initial: bool = False, stream: Union[None, Stream, Device] = None) -> array"), + "def cumsum(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Return the cumulative sum of the elements along the given axis. Args: - a (array): Input array. + a (array): Input array axis (int, optional): Optional axis to compute the cumulative sum over. If unspecified the cumulative sum of the flattened array is returned. reverse (bool): Perform the cumulative sum in reverse. inclusive (bool): The i-th element of the output includes the i-th element of the input. - dtype (Dtype, optional): Cast the input to this type before summing. - include_initial (bool): Prepend the identity element (0) so the - output has one extra element along the given axis. Returns: array: The output array. @@ -3503,52 +3483,32 @@ void init_ops(nb::module_& m) { std::optional axis, bool reverse, bool inclusive, - std::optional dtype, - bool include_initial, mx::StreamOrDevice s) { - mx::array x = dtype ? mx::astype(a, *dtype, s) : a; - int ax; if (axis) { - ax = *axis; + return mx::cumprod(a, *axis, reverse, inclusive, s); } else { - x = mx::reshape(x, {-1}, s); - ax = 0; - } - auto out = mx::cumprod(x, ax, reverse, inclusive, s); - if (include_initial) { - int a2 = ax < 0 ? ax + static_cast(out.ndim()) : ax; - mx::Shape init_shape = out.shape(); - init_shape[a2] = 1; - auto init = mx::ones(init_shape, out.dtype(), s); - out = reverse ? mx::concatenate({out, init}, a2, s) - : mx::concatenate({init, out}, a2, s); + return mx::cumprod(mx::reshape(a, {-1}, s), 0, reverse, inclusive, s); } - return out; }, nb::arg(), "axis"_a = nb::none(), nb::kw_only(), "reverse"_a = false, "inclusive"_a = true, - "dtype"_a = nb::none(), - "include_initial"_a = false, "stream"_a = nb::none(), nb::sig( - "def cumprod(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, dtype: Optional[Dtype] = None, include_initial: bool = False, stream: Union[None, Stream, Device] = None) -> array"), + "def cumprod(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Return the cumulative product of the elements along the given axis. Args: - a (array): Input array. + a (array): Input array axis (int, optional): Optional axis to compute the cumulative product - over. If unspecified the cumulative product of the flattened array - is returned. + over. If unspecified the cumulative product of the flattened array is + returned. reverse (bool): Perform the cumulative product in reverse. inclusive (bool): The i-th element of the output includes the i-th element of the input. - dtype (Dtype, optional): Cast the input to this type before multiplying. - include_initial (bool): Prepend the identity element (1) so the - output has one extra element along the given axis. Returns: array: The output array. @@ -5958,8 +5918,118 @@ void init_ops(nb::module_& m) { m.attr("empty_like") = m.attr("zeros_like"); m.attr("matrix_transpose") = m.attr("transpose"); m.attr("pow") = m.attr("power"); - // Array API aliases — cumulative_sum/cumulative_prod are pure aliases of - // cumsum/cumprod, which now support dtype and include_initial. - m.attr("cumulative_sum") = m.attr("cumsum"); - m.attr("cumulative_prod") = m.attr("cumprod"); + // Array API: cumulative_sum and cumulative_prod extend cumsum/cumprod with + // dtype and include_initial as specified by the array API standard. + m.def( + "cumulative_sum", + [](const mx::array& a, + std::optional axis, + bool reverse, + bool inclusive, + std::optional dtype, + bool include_initial, + mx::StreamOrDevice s) { + mx::array x = dtype ? mx::astype(a, *dtype, s) : a; + int ax; + if (axis) { + ax = *axis; + } else { + x = mx::reshape(x, {-1}, s); + ax = 0; + } + auto out = mx::cumsum(x, ax, reverse, inclusive, s); + if (include_initial) { + int a2 = ax < 0 ? ax + static_cast(out.ndim()) : ax; + mx::Shape init_shape = out.shape(); + init_shape[a2] = 1; + auto init = mx::zeros(init_shape, out.dtype(), s); + out = reverse ? mx::concatenate({out, init}, a2, s) + : mx::concatenate({init, out}, a2, s); + } + return out; + }, + nb::arg(), + "axis"_a = nb::none(), + nb::kw_only(), + "reverse"_a = false, + "inclusive"_a = true, + "dtype"_a = nb::none(), + "include_initial"_a = false, + "stream"_a = nb::none(), + nb::sig( + "def cumulative_sum(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, dtype: Optional[Dtype] = None, include_initial: bool = False, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Return the cumulative sum of the elements along the given axis. + + Args: + a (array): Input array. + axis (int, optional): Optional axis to compute the cumulative sum + over. If unspecified the cumulative sum of the flattened array is + returned. + reverse (bool): Perform the cumulative sum in reverse. + inclusive (bool): The i-th element of the output includes the i-th + element of the input. + dtype (Dtype, optional): Cast the input to this type before summing. + include_initial (bool): Prepend the identity element (0) so the + output has one extra element along the given axis. + + Returns: + array: The output array. + )pbdoc"); + m.def( + "cumulative_prod", + [](const mx::array& a, + std::optional axis, + bool reverse, + bool inclusive, + std::optional dtype, + bool include_initial, + mx::StreamOrDevice s) { + mx::array x = dtype ? mx::astype(a, *dtype, s) : a; + int ax; + if (axis) { + ax = *axis; + } else { + x = mx::reshape(x, {-1}, s); + ax = 0; + } + auto out = mx::cumprod(x, ax, reverse, inclusive, s); + if (include_initial) { + int a2 = ax < 0 ? ax + static_cast(out.ndim()) : ax; + mx::Shape init_shape = out.shape(); + init_shape[a2] = 1; + auto init = mx::ones(init_shape, out.dtype(), s); + out = reverse ? mx::concatenate({out, init}, a2, s) + : mx::concatenate({init, out}, a2, s); + } + return out; + }, + nb::arg(), + "axis"_a = nb::none(), + nb::kw_only(), + "reverse"_a = false, + "inclusive"_a = true, + "dtype"_a = nb::none(), + "include_initial"_a = false, + "stream"_a = nb::none(), + nb::sig( + "def cumulative_prod(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, dtype: Optional[Dtype] = None, include_initial: bool = False, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Return the cumulative product of the elements along the given axis. + + Args: + a (array): Input array. + axis (int, optional): Optional axis to compute the cumulative product + over. If unspecified the cumulative product of the flattened array + is returned. + reverse (bool): Perform the cumulative product in reverse. + inclusive (bool): The i-th element of the output includes the i-th + element of the input. + dtype (Dtype, optional): Cast the input to this type before multiplying. + include_initial (bool): Prepend the identity element (1) so the + output has one extra element along the given axis. + + Returns: + array: The output array. + )pbdoc"); } From 381923fe1c6d7f6cf4739694b70a210f1e0c7712 Mon Sep 17 00:00:00 2001 From: katlun-lgtm <264247399+katlun-lgtm@users.noreply.github.com> Date: Mon, 22 Jun 2026 21:38:57 -0400 Subject: [PATCH 4/7] fix(array-api): extend cumsum/cumprod with dtype+include_initial, pure aliases MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per zcbenz review: extend cumsum and cumprod with dtype and include_initial params, then expose cumulative_sum and cumulative_prod as pure aliases (m.attr = m.attr). Also remove cumulative_sum/cumulative_prod from ops.rst autosummary — all other pure aliases (empty, pow, matrix_transpose) are intentionally absent from ops.rst; listing them caused Sphinx to encounter a name mismatch in the shared __doc__ (__doc__ starts with 'cumsum(...)' but Sphinx is documenting 'cumulative_sum') which prevented signature stripping and triggered an RST emphasis warning on '*, '. --- docs/src/python/ops.rst | 2 - python/src/ops.cpp | 176 ++++++++++++---------------------------- 2 files changed, 52 insertions(+), 126 deletions(-) diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index b68895c6f7..84e0b9d08b 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -65,8 +65,6 @@ Operations cummin cumprod cumsum - cumulative_prod - cumulative_sum degrees depends dequantize diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 4deb358060..23ff2859af 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3447,32 +3447,52 @@ void init_ops(nb::module_& m) { std::optional axis, bool reverse, bool inclusive, + std::optional dtype, + bool include_initial, mx::StreamOrDevice s) { + mx::array x = dtype ? mx::astype(a, *dtype, s) : a; + int ax; if (axis) { - return mx::cumsum(a, *axis, reverse, inclusive, s); + ax = *axis; } else { - return mx::cumsum(mx::reshape(a, {-1}, s), 0, reverse, inclusive, s); + x = mx::reshape(x, {-1}, s); + ax = 0; + } + auto out = mx::cumsum(x, ax, reverse, inclusive, s); + if (include_initial) { + int a2 = ax < 0 ? ax + static_cast(out.ndim()) : ax; + mx::Shape init_shape = out.shape(); + init_shape[a2] = 1; + auto init = mx::zeros(init_shape, out.dtype(), s); + out = reverse ? mx::concatenate({out, init}, a2, s) + : mx::concatenate({init, out}, a2, s); } + return out; }, nb::arg(), "axis"_a = nb::none(), nb::kw_only(), "reverse"_a = false, "inclusive"_a = true, + "dtype"_a = nb::none(), + "include_initial"_a = false, "stream"_a = nb::none(), nb::sig( - "def cumsum(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array"), + "def cumsum(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, dtype: Optional[Dtype] = None, include_initial: bool = False, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Return the cumulative sum of the elements along the given axis. Args: - a (array): Input array + a (array): Input array. axis (int, optional): Optional axis to compute the cumulative sum over. If unspecified the cumulative sum of the flattened array is returned. reverse (bool): Perform the cumulative sum in reverse. inclusive (bool): The i-th element of the output includes the i-th element of the input. + dtype (Dtype, optional): Cast the input to this type before summing. + include_initial (bool): Prepend the identity element (0) so the + output has one extra element along the given axis. Returns: array: The output array. @@ -3483,32 +3503,52 @@ void init_ops(nb::module_& m) { std::optional axis, bool reverse, bool inclusive, + std::optional dtype, + bool include_initial, mx::StreamOrDevice s) { + mx::array x = dtype ? mx::astype(a, *dtype, s) : a; + int ax; if (axis) { - return mx::cumprod(a, *axis, reverse, inclusive, s); + ax = *axis; } else { - return mx::cumprod(mx::reshape(a, {-1}, s), 0, reverse, inclusive, s); + x = mx::reshape(x, {-1}, s); + ax = 0; + } + auto out = mx::cumprod(x, ax, reverse, inclusive, s); + if (include_initial) { + int a2 = ax < 0 ? ax + static_cast(out.ndim()) : ax; + mx::Shape init_shape = out.shape(); + init_shape[a2] = 1; + auto init = mx::ones(init_shape, out.dtype(), s); + out = reverse ? mx::concatenate({out, init}, a2, s) + : mx::concatenate({init, out}, a2, s); } + return out; }, nb::arg(), "axis"_a = nb::none(), nb::kw_only(), "reverse"_a = false, "inclusive"_a = true, + "dtype"_a = nb::none(), + "include_initial"_a = false, "stream"_a = nb::none(), nb::sig( - "def cumprod(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array"), + "def cumprod(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, dtype: Optional[Dtype] = None, include_initial: bool = False, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Return the cumulative product of the elements along the given axis. Args: - a (array): Input array + a (array): Input array. axis (int, optional): Optional axis to compute the cumulative product - over. If unspecified the cumulative product of the flattened array is - returned. + over. If unspecified the cumulative product of the flattened array + is returned. reverse (bool): Perform the cumulative product in reverse. inclusive (bool): The i-th element of the output includes the i-th element of the input. + dtype (Dtype, optional): Cast the input to this type before multiplying. + include_initial (bool): Prepend the identity element (1) so the + output has one extra element along the given axis. Returns: array: The output array. @@ -5918,118 +5958,6 @@ void init_ops(nb::module_& m) { m.attr("empty_like") = m.attr("zeros_like"); m.attr("matrix_transpose") = m.attr("transpose"); m.attr("pow") = m.attr("power"); - // Array API: cumulative_sum and cumulative_prod extend cumsum/cumprod with - // dtype and include_initial as specified by the array API standard. - m.def( - "cumulative_sum", - [](const mx::array& a, - std::optional axis, - bool reverse, - bool inclusive, - std::optional dtype, - bool include_initial, - mx::StreamOrDevice s) { - mx::array x = dtype ? mx::astype(a, *dtype, s) : a; - int ax; - if (axis) { - ax = *axis; - } else { - x = mx::reshape(x, {-1}, s); - ax = 0; - } - auto out = mx::cumsum(x, ax, reverse, inclusive, s); - if (include_initial) { - int a2 = ax < 0 ? ax + static_cast(out.ndim()) : ax; - mx::Shape init_shape = out.shape(); - init_shape[a2] = 1; - auto init = mx::zeros(init_shape, out.dtype(), s); - out = reverse ? mx::concatenate({out, init}, a2, s) - : mx::concatenate({init, out}, a2, s); - } - return out; - }, - nb::arg(), - "axis"_a = nb::none(), - nb::kw_only(), - "reverse"_a = false, - "inclusive"_a = true, - "dtype"_a = nb::none(), - "include_initial"_a = false, - "stream"_a = nb::none(), - nb::sig( - "def cumulative_sum(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, dtype: Optional[Dtype] = None, include_initial: bool = False, stream: Union[None, Stream, Device] = None) -> array"), - R"pbdoc( - Return the cumulative sum of the elements along the given axis. - - Args: - a (array): Input array. - axis (int, optional): Optional axis to compute the cumulative sum - over. If unspecified the cumulative sum of the flattened array is - returned. - reverse (bool): Perform the cumulative sum in reverse. - inclusive (bool): The i-th element of the output includes the i-th - element of the input. - dtype (Dtype, optional): Cast the input to this type before summing. - include_initial (bool): Prepend the identity element (0) so the - output has one extra element along the given axis. - - Returns: - array: The output array. - )pbdoc"); - m.def( - "cumulative_prod", - [](const mx::array& a, - std::optional axis, - bool reverse, - bool inclusive, - std::optional dtype, - bool include_initial, - mx::StreamOrDevice s) { - mx::array x = dtype ? mx::astype(a, *dtype, s) : a; - int ax; - if (axis) { - ax = *axis; - } else { - x = mx::reshape(x, {-1}, s); - ax = 0; - } - auto out = mx::cumprod(x, ax, reverse, inclusive, s); - if (include_initial) { - int a2 = ax < 0 ? ax + static_cast(out.ndim()) : ax; - mx::Shape init_shape = out.shape(); - init_shape[a2] = 1; - auto init = mx::ones(init_shape, out.dtype(), s); - out = reverse ? mx::concatenate({out, init}, a2, s) - : mx::concatenate({init, out}, a2, s); - } - return out; - }, - nb::arg(), - "axis"_a = nb::none(), - nb::kw_only(), - "reverse"_a = false, - "inclusive"_a = true, - "dtype"_a = nb::none(), - "include_initial"_a = false, - "stream"_a = nb::none(), - nb::sig( - "def cumulative_prod(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, dtype: Optional[Dtype] = None, include_initial: bool = False, stream: Union[None, Stream, Device] = None) -> array"), - R"pbdoc( - Return the cumulative product of the elements along the given axis. - - Args: - a (array): Input array. - axis (int, optional): Optional axis to compute the cumulative product - over. If unspecified the cumulative product of the flattened array - is returned. - reverse (bool): Perform the cumulative product in reverse. - inclusive (bool): The i-th element of the output includes the i-th - element of the input. - dtype (Dtype, optional): Cast the input to this type before multiplying. - include_initial (bool): Prepend the identity element (1) so the - output has one extra element along the given axis. - - Returns: - array: The output array. - )pbdoc"); + m.attr("cumulative_sum") = m.attr("cumsum"); + m.attr("cumulative_prod") = m.attr("cumprod"); } From b96b56b5a550d50270328b8cade2a45fc30e6d53 Mon Sep 17 00:00:00 2001 From: katlun-lgtm <264247399+katlun-lgtm@users.noreply.github.com> Date: Tue, 23 Jun 2026 10:13:01 -0400 Subject: [PATCH 5/7] fix(array-api): move cumsum/cumprod dtype+include_initial logic to C++ --- mlx/ops.cpp | 47 ++++++++++++++++++++++++++++++++++++---------- mlx/ops.h | 8 ++++++++ python/src/ops.cpp | 36 ++++------------------------------- 3 files changed, 49 insertions(+), 42 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index d56ed7ffa4..3fd6e48bf4 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3947,6 +3947,8 @@ array cumsum( int axis, bool reverse /* = false*/, bool inclusive /* = true*/, + std::optional dtype /* = std::nullopt*/, + bool include_initial /* = false*/, StreamOrDevice s /* = {}*/) { int ndim = a.ndim(); if (axis >= ndim || axis < -ndim) { @@ -3956,21 +3958,33 @@ array cumsum( throw std::invalid_argument(msg.str()); } axis = (axis + a.ndim()) % a.ndim(); - auto out_type = a.dtype() == bool_ ? int32 : a.dtype(); - return array( - a.shape(), + auto x = dtype ? astype(a, *dtype, s) : a; + auto out_type = x.dtype() == bool_ ? int32 : x.dtype(); + auto out = array( + x.shape(), out_type, std::make_shared( to_stream(s), Scan::ReduceType::Sum, axis, reverse, inclusive), - {a}); + {x}); + if (include_initial) { + Shape init_shape = out.shape(); + init_shape[axis] = 1; + auto init = zeros(init_shape, out.dtype(), s); + out = reverse ? concatenate({out, init}, axis, s) + : concatenate({init, out}, axis, s); + } + return out; } array cumsum( const array& a, bool reverse /* = false*/, bool inclusive /* = true*/, + std::optional dtype /* = std::nullopt*/, + bool include_initial /* = false*/, StreamOrDevice s /* = {}*/) { - return cumsum(flatten(a, to_stream(s)), 0, reverse, inclusive, to_stream(s)); + return cumsum( + flatten(a, to_stream(s)), 0, reverse, inclusive, dtype, include_initial, to_stream(s)); } array cumprod( @@ -3978,6 +3992,8 @@ array cumprod( int axis, bool reverse /* = false*/, bool inclusive /* = true*/, + std::optional dtype /* = std::nullopt*/, + bool include_initial /* = false*/, StreamOrDevice s /* = {}*/) { int ndim = a.ndim(); if (axis >= ndim || axis < -ndim) { @@ -3987,20 +4003,31 @@ array cumprod( throw std::invalid_argument(msg.str()); } axis = (axis + a.ndim()) % a.ndim(); - return array( - a.shape(), - a.dtype(), + auto x = dtype ? astype(a, *dtype, s) : a; + auto out = array( + x.shape(), + x.dtype(), std::make_shared( to_stream(s), Scan::ReduceType::Prod, axis, reverse, inclusive), - {a}); + {x}); + if (include_initial) { + Shape init_shape = out.shape(); + init_shape[axis] = 1; + auto init = ones(init_shape, out.dtype(), s); + out = reverse ? concatenate({out, init}, axis, s) + : concatenate({init, out}, axis, s); + } + return out; } array cumprod( const array& a, bool reverse /* = false*/, bool inclusive /* = true*/, + std::optional dtype /* = std::nullopt*/, + bool include_initial /* = false*/, StreamOrDevice s /* = {}*/) { - return cumprod(flatten(a, s), 0, reverse, inclusive, s); + return cumprod(flatten(a, s), 0, reverse, inclusive, dtype, include_initial, s); } array cummax( diff --git a/mlx/ops.h b/mlx/ops.h index 97f06eb6e3..238e34d155 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1333,6 +1333,8 @@ MLX_API array cumsum( const array& a, bool reverse = false, bool inclusive = true, + std::optional dtype = std::nullopt, + bool include_initial = false, StreamOrDevice s = {}); /** Cumulative sum of an array along the given axis. */ @@ -1341,6 +1343,8 @@ MLX_API array cumsum( int axis, bool reverse = false, bool inclusive = true, + std::optional dtype = std::nullopt, + bool include_initial = false, StreamOrDevice s = {}); /** Cumulative product of an array. */ @@ -1348,6 +1352,8 @@ MLX_API array cumprod( const array& a, bool reverse = false, bool inclusive = true, + std::optional dtype = std::nullopt, + bool include_initial = false, StreamOrDevice s = {}); /** Cumulative product of an array along the given axis. */ @@ -1356,6 +1362,8 @@ MLX_API array cumprod( int axis, bool reverse = false, bool inclusive = true, + std::optional dtype = std::nullopt, + bool include_initial = false, StreamOrDevice s = {}); /** Cumulative max of an array. */ diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 23ff2859af..c015f757e2 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3450,24 +3450,10 @@ void init_ops(nb::module_& m) { std::optional dtype, bool include_initial, mx::StreamOrDevice s) { - mx::array x = dtype ? mx::astype(a, *dtype, s) : a; - int ax; if (axis) { - ax = *axis; - } else { - x = mx::reshape(x, {-1}, s); - ax = 0; - } - auto out = mx::cumsum(x, ax, reverse, inclusive, s); - if (include_initial) { - int a2 = ax < 0 ? ax + static_cast(out.ndim()) : ax; - mx::Shape init_shape = out.shape(); - init_shape[a2] = 1; - auto init = mx::zeros(init_shape, out.dtype(), s); - out = reverse ? mx::concatenate({out, init}, a2, s) - : mx::concatenate({init, out}, a2, s); + return mx::cumsum(a, *axis, reverse, inclusive, dtype, include_initial, s); } - return out; + return mx::cumsum(a, reverse, inclusive, dtype, include_initial, s); }, nb::arg(), "axis"_a = nb::none(), @@ -3506,24 +3492,10 @@ void init_ops(nb::module_& m) { std::optional dtype, bool include_initial, mx::StreamOrDevice s) { - mx::array x = dtype ? mx::astype(a, *dtype, s) : a; - int ax; if (axis) { - ax = *axis; - } else { - x = mx::reshape(x, {-1}, s); - ax = 0; - } - auto out = mx::cumprod(x, ax, reverse, inclusive, s); - if (include_initial) { - int a2 = ax < 0 ? ax + static_cast(out.ndim()) : ax; - mx::Shape init_shape = out.shape(); - init_shape[a2] = 1; - auto init = mx::ones(init_shape, out.dtype(), s); - out = reverse ? mx::concatenate({out, init}, a2, s) - : mx::concatenate({init, out}, a2, s); + return mx::cumprod(a, *axis, reverse, inclusive, dtype, include_initial, s); } - return out; + return mx::cumprod(a, reverse, inclusive, dtype, include_initial, s); }, nb::arg(), "axis"_a = nb::none(), From 153bd96ce9a8f39ff015da987205a4c0b3531904 Mon Sep 17 00:00:00 2001 From: katlun-lgtm <264247399+katlun-lgtm@users.noreply.github.com> Date: Fri, 26 Jun 2026 07:33:53 -0400 Subject: [PATCH 6/7] Address review: drop include_initial from cumsum/cumprod Per maintainer request (#3731), remove the include_initial arg. It needed an extra concatenate to be correct, which is the inefficient pattern angeloskath flagged. cumsum/cumprod keep the efficient dtype arg only; cumulative_sum and cumulative_prod remain pure aliases of them. --- mlx/ops.cpp | 28 ++++------------------------ mlx/ops.h | 4 ---- python/src/ops.cpp | 20 ++++++-------------- python/tests/test_ops.py | 10 ---------- 4 files changed, 10 insertions(+), 52 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 3fd6e48bf4..f2e3304cab 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3948,7 +3948,6 @@ array cumsum( bool reverse /* = false*/, bool inclusive /* = true*/, std::optional dtype /* = std::nullopt*/, - bool include_initial /* = false*/, StreamOrDevice s /* = {}*/) { int ndim = a.ndim(); if (axis >= ndim || axis < -ndim) { @@ -3960,20 +3959,12 @@ array cumsum( axis = (axis + a.ndim()) % a.ndim(); auto x = dtype ? astype(a, *dtype, s) : a; auto out_type = x.dtype() == bool_ ? int32 : x.dtype(); - auto out = array( + return array( x.shape(), out_type, std::make_shared( to_stream(s), Scan::ReduceType::Sum, axis, reverse, inclusive), {x}); - if (include_initial) { - Shape init_shape = out.shape(); - init_shape[axis] = 1; - auto init = zeros(init_shape, out.dtype(), s); - out = reverse ? concatenate({out, init}, axis, s) - : concatenate({init, out}, axis, s); - } - return out; } array cumsum( @@ -3981,10 +3972,9 @@ array cumsum( bool reverse /* = false*/, bool inclusive /* = true*/, std::optional dtype /* = std::nullopt*/, - bool include_initial /* = false*/, StreamOrDevice s /* = {}*/) { return cumsum( - flatten(a, to_stream(s)), 0, reverse, inclusive, dtype, include_initial, to_stream(s)); + flatten(a, to_stream(s)), 0, reverse, inclusive, dtype, to_stream(s)); } array cumprod( @@ -3993,7 +3983,6 @@ array cumprod( bool reverse /* = false*/, bool inclusive /* = true*/, std::optional dtype /* = std::nullopt*/, - bool include_initial /* = false*/, StreamOrDevice s /* = {}*/) { int ndim = a.ndim(); if (axis >= ndim || axis < -ndim) { @@ -4004,20 +3993,12 @@ array cumprod( } axis = (axis + a.ndim()) % a.ndim(); auto x = dtype ? astype(a, *dtype, s) : a; - auto out = array( + return array( x.shape(), x.dtype(), std::make_shared( to_stream(s), Scan::ReduceType::Prod, axis, reverse, inclusive), {x}); - if (include_initial) { - Shape init_shape = out.shape(); - init_shape[axis] = 1; - auto init = ones(init_shape, out.dtype(), s); - out = reverse ? concatenate({out, init}, axis, s) - : concatenate({init, out}, axis, s); - } - return out; } array cumprod( @@ -4025,9 +4006,8 @@ array cumprod( bool reverse /* = false*/, bool inclusive /* = true*/, std::optional dtype /* = std::nullopt*/, - bool include_initial /* = false*/, StreamOrDevice s /* = {}*/) { - return cumprod(flatten(a, s), 0, reverse, inclusive, dtype, include_initial, s); + return cumprod(flatten(a, s), 0, reverse, inclusive, dtype, s); } array cummax( diff --git a/mlx/ops.h b/mlx/ops.h index 238e34d155..d8cbe683fc 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1334,7 +1334,6 @@ MLX_API array cumsum( bool reverse = false, bool inclusive = true, std::optional dtype = std::nullopt, - bool include_initial = false, StreamOrDevice s = {}); /** Cumulative sum of an array along the given axis. */ @@ -1344,7 +1343,6 @@ MLX_API array cumsum( bool reverse = false, bool inclusive = true, std::optional dtype = std::nullopt, - bool include_initial = false, StreamOrDevice s = {}); /** Cumulative product of an array. */ @@ -1353,7 +1351,6 @@ MLX_API array cumprod( bool reverse = false, bool inclusive = true, std::optional dtype = std::nullopt, - bool include_initial = false, StreamOrDevice s = {}); /** Cumulative product of an array along the given axis. */ @@ -1363,7 +1360,6 @@ MLX_API array cumprod( bool reverse = false, bool inclusive = true, std::optional dtype = std::nullopt, - bool include_initial = false, StreamOrDevice s = {}); /** Cumulative max of an array. */ diff --git a/python/src/ops.cpp b/python/src/ops.cpp index c015f757e2..4b7bc85fb7 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3448,12 +3448,11 @@ void init_ops(nb::module_& m) { bool reverse, bool inclusive, std::optional dtype, - bool include_initial, mx::StreamOrDevice s) { if (axis) { - return mx::cumsum(a, *axis, reverse, inclusive, dtype, include_initial, s); + return mx::cumsum(a, *axis, reverse, inclusive, dtype, s); } - return mx::cumsum(a, reverse, inclusive, dtype, include_initial, s); + return mx::cumsum(a, reverse, inclusive, dtype, s); }, nb::arg(), "axis"_a = nb::none(), @@ -3461,10 +3460,9 @@ void init_ops(nb::module_& m) { "reverse"_a = false, "inclusive"_a = true, "dtype"_a = nb::none(), - "include_initial"_a = false, "stream"_a = nb::none(), nb::sig( - "def cumsum(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, dtype: Optional[Dtype] = None, include_initial: bool = False, stream: Union[None, Stream, Device] = None) -> array"), + "def cumsum(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, dtype: Optional[Dtype] = None, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Return the cumulative sum of the elements along the given axis. @@ -3477,8 +3475,6 @@ void init_ops(nb::module_& m) { inclusive (bool): The i-th element of the output includes the i-th element of the input. dtype (Dtype, optional): Cast the input to this type before summing. - include_initial (bool): Prepend the identity element (0) so the - output has one extra element along the given axis. Returns: array: The output array. @@ -3490,12 +3486,11 @@ void init_ops(nb::module_& m) { bool reverse, bool inclusive, std::optional dtype, - bool include_initial, mx::StreamOrDevice s) { if (axis) { - return mx::cumprod(a, *axis, reverse, inclusive, dtype, include_initial, s); + return mx::cumprod(a, *axis, reverse, inclusive, dtype, s); } - return mx::cumprod(a, reverse, inclusive, dtype, include_initial, s); + return mx::cumprod(a, reverse, inclusive, dtype, s); }, nb::arg(), "axis"_a = nb::none(), @@ -3503,10 +3498,9 @@ void init_ops(nb::module_& m) { "reverse"_a = false, "inclusive"_a = true, "dtype"_a = nb::none(), - "include_initial"_a = false, "stream"_a = nb::none(), nb::sig( - "def cumprod(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, dtype: Optional[Dtype] = None, include_initial: bool = False, stream: Union[None, Stream, Device] = None) -> array"), + "def cumprod(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, dtype: Optional[Dtype] = None, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Return the cumulative product of the elements along the given axis. @@ -3519,8 +3513,6 @@ void init_ops(nb::module_& m) { inclusive (bool): The i-th element of the output includes the i-th element of the input. dtype (Dtype, optional): Cast the input to this type before multiplying. - include_initial (bool): Prepend the identity element (1) so the - output has one extra element along the given axis. Returns: array: The output array. diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index c99b49aad6..879290b6e5 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -3489,21 +3489,11 @@ def test_to_from_fp8(self): def test_cumulative_sum_prod(self): a = mx.array([1, 2, 3, 4]) self.assertEqual(mx.cumulative_sum(a).tolist(), [1, 3, 6, 10]) - self.assertEqual( - mx.cumulative_sum(a, include_initial=True).tolist(), [0, 1, 3, 6, 10] - ) self.assertEqual(mx.cumulative_prod(a).tolist(), [1, 2, 6, 24]) - self.assertEqual( - mx.cumulative_prod(a, include_initial=True).tolist(), [1, 1, 2, 6, 24] - ) m = mx.array([[1, 2], [3, 4]]) self.assertEqual(mx.cumulative_sum(m, axis=0).tolist(), [[1, 2], [4, 6]]) self.assertEqual(mx.cumulative_sum(m, axis=1).tolist(), [[1, 3], [3, 7]]) - self.assertEqual( - mx.cumulative_sum(m, axis=1, include_initial=True).tolist(), - [[0, 1, 3], [0, 3, 7]], - ) # axis=None flattens. self.assertEqual(mx.cumulative_sum(m).tolist(), [1, 3, 6, 10]) self.assertEqual(mx.cumulative_sum(a, dtype=mx.float32).dtype, mx.float32) From b5230f6946bd090d53c9c396a24f391098cd0800 Mon Sep 17 00:00:00 2001 From: Cheng Date: Sun, 28 Jun 2026 15:46:13 +0900 Subject: [PATCH 7/7] Fix compilation --- mlx/ops.h | 24 ++++++++++++++++++++++++ python/src/ops.cpp | 4 ++-- python/tests/test_ops.py | 12 ------------ 3 files changed, 26 insertions(+), 14 deletions(-) diff --git a/mlx/ops.h b/mlx/ops.h index d8cbe683fc..cc6deef38f 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1335,6 +1335,10 @@ MLX_API array cumsum( bool inclusive = true, std::optional dtype = std::nullopt, StreamOrDevice s = {}); +inline array +cumsum(const array& a, bool reverse, bool inclusive, StreamOrDevice s) { + return cumsum(a, reverse, inclusive, std::nullopt, s); +} /** Cumulative sum of an array along the given axis. */ MLX_API array cumsum( @@ -1344,6 +1348,14 @@ MLX_API array cumsum( bool inclusive = true, std::optional dtype = std::nullopt, StreamOrDevice s = {}); +inline array cumsum( + const array& a, + int axis, + bool reverse, + bool inclusive, + StreamOrDevice s) { + return cumsum(a, axis, reverse, inclusive, std::nullopt, s); +} /** Cumulative product of an array. */ MLX_API array cumprod( @@ -1352,6 +1364,10 @@ MLX_API array cumprod( bool inclusive = true, std::optional dtype = std::nullopt, StreamOrDevice s = {}); +inline array +cumprod(const array& a, bool reverse, bool inclusive, StreamOrDevice s) { + return cumprod(a, reverse, inclusive, std::nullopt, s); +} /** Cumulative product of an array along the given axis. */ MLX_API array cumprod( @@ -1361,6 +1377,14 @@ MLX_API array cumprod( bool inclusive = true, std::optional dtype = std::nullopt, StreamOrDevice s = {}); +inline array cumprod( + const array& a, + int axis, + bool reverse, + bool inclusive, + StreamOrDevice s) { + return cumprod(a, axis, reverse, inclusive, std::nullopt, s); +} /** Cumulative max of an array. */ MLX_API array cummax( diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 4b7bc85fb7..4f515f5331 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -5918,10 +5918,10 @@ void init_ops(nb::module_& m) { m.attr("atan2") = m.attr("arctan2"); m.attr("bitwise_left_shift") = m.attr("left_shift"); m.attr("bitwise_right_shift") = m.attr("right_shift"); + m.attr("cumulative_prod") = m.attr("cumprod"); + m.attr("cumulative_sum") = m.attr("cumsum"); m.attr("empty") = m.attr("zeros"); m.attr("empty_like") = m.attr("zeros_like"); m.attr("matrix_transpose") = m.attr("transpose"); m.attr("pow") = m.attr("power"); - m.attr("cumulative_sum") = m.attr("cumsum"); - m.attr("cumulative_prod") = m.attr("cumprod"); } diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 879290b6e5..c98f1fd440 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -3486,18 +3486,6 @@ def test_to_from_fp8(self): self.assertTrue(mx.array_equal(mx.from_fp8(mx.to_fp8(vals)), vals)) self.assertTrue(mx.array_equal(mx.from_fp8(mx.to_fp8(-vals)), -vals)) - def test_cumulative_sum_prod(self): - a = mx.array([1, 2, 3, 4]) - self.assertEqual(mx.cumulative_sum(a).tolist(), [1, 3, 6, 10]) - self.assertEqual(mx.cumulative_prod(a).tolist(), [1, 2, 6, 24]) - - m = mx.array([[1, 2], [3, 4]]) - self.assertEqual(mx.cumulative_sum(m, axis=0).tolist(), [[1, 2], [4, 6]]) - self.assertEqual(mx.cumulative_sum(m, axis=1).tolist(), [[1, 3], [3, 7]]) - # axis=None flattens. - self.assertEqual(mx.cumulative_sum(m).tolist(), [1, 3, 6, 10]) - self.assertEqual(mx.cumulative_sum(a, dtype=mx.float32).dtype, mx.float32) - if __name__ == "__main__": mlx_tests.MLXTestRunner()