Add an optional `dtype` argument to cumsum / cumprod.

* mlx/ops.cpp: when `dtype` is set, the input is cast with astype() before
  the Scan primitive is built; the flattened overloads forward `dtype`.
* mlx/ops.h: the new parameter is declared before `StreamOrDevice s`; inline
  overloads with the previous parameter list forward std::nullopt.
* python/src/ops.cpp: exposes a keyword-only `dtype`, routes the axis=None
  case through the flattened C++ overload, and registers the aliases
  `cumulative_sum` / `cumulative_prod`.

NOTE(review): line breaks, indentation and the template arguments
(std::optional<Dtype>, std::optional<mx::Dtype>, std::optional<int>,
std::make_shared<Scan>) were reconstructed from a whitespace-collapsed copy
of this patch. Hunk line counts match the hunk headers, but confirm the
exact whitespace and types against the target tree before applying.

diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index d56ed7ffa4..f2e3304cab 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -3947,6 +3947,7 @@ array cumsum(
     int axis,
     bool reverse /* = false*/,
     bool inclusive /* = true*/,
+    std::optional<Dtype> dtype /* = std::nullopt*/,
     StreamOrDevice s /* = {}*/) {
   int ndim = a.ndim();
   if (axis >= ndim || axis < -ndim) {
@@ -3956,21 +3957,24 @@ array cumsum(
     throw std::invalid_argument(msg.str());
   }
   axis = (axis + a.ndim()) % a.ndim();
-  auto out_type = a.dtype() == bool_ ? int32 : a.dtype();
+  auto x = dtype ? astype(a, *dtype, s) : a;
+  auto out_type = x.dtype() == bool_ ? int32 : x.dtype();
   return array(
-      a.shape(),
+      x.shape(),
       out_type,
       std::make_shared<Scan>(
           to_stream(s), Scan::ReduceType::Sum, axis, reverse, inclusive),
-      {a});
+      {x});
 }
 
 array cumsum(
     const array& a,
     bool reverse /* = false*/,
     bool inclusive /* = true*/,
+    std::optional<Dtype> dtype /* = std::nullopt*/,
     StreamOrDevice s /* = {}*/) {
-  return cumsum(flatten(a, to_stream(s)), 0, reverse, inclusive, to_stream(s));
+  return cumsum(
+      flatten(a, to_stream(s)), 0, reverse, inclusive, dtype, to_stream(s));
 }
 
 array cumprod(
@@ -3978,6 +3982,7 @@ array cumprod(
     int axis,
     bool reverse /* = false*/,
     bool inclusive /* = true*/,
+    std::optional<Dtype> dtype /* = std::nullopt*/,
     StreamOrDevice s /* = {}*/) {
   int ndim = a.ndim();
   if (axis >= ndim || axis < -ndim) {
@@ -3987,20 +3992,22 @@ array cumprod(
     throw std::invalid_argument(msg.str());
   }
   axis = (axis + a.ndim()) % a.ndim();
+  auto x = dtype ? astype(a, *dtype, s) : a;
   return array(
-      a.shape(),
-      a.dtype(),
+      x.shape(),
+      x.dtype(),
       std::make_shared<Scan>(
           to_stream(s), Scan::ReduceType::Prod, axis, reverse, inclusive),
-      {a});
+      {x});
 }
 
 array cumprod(
     const array& a,
     bool reverse /* = false*/,
     bool inclusive /* = true*/,
+    std::optional<Dtype> dtype /* = std::nullopt*/,
     StreamOrDevice s /* = {}*/) {
-  return cumprod(flatten(a, s), 0, reverse, inclusive, s);
+  return cumprod(flatten(a, s), 0, reverse, inclusive, dtype, s);
 }
 
 array cummax(
diff --git a/mlx/ops.h b/mlx/ops.h
index 97f06eb6e3..cc6deef38f 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -1333,7 +1333,12 @@ MLX_API array cumsum(
     const array& a,
     bool reverse = false,
     bool inclusive = true,
+    std::optional<Dtype> dtype = std::nullopt,
     StreamOrDevice s = {});
+inline array
+cumsum(const array& a, bool reverse, bool inclusive, StreamOrDevice s) {
+  return cumsum(a, reverse, inclusive, std::nullopt, s);
+}
 
 /** Cumulative sum of an array along the given axis. */
 MLX_API array cumsum(
@@ -1341,14 +1346,28 @@ MLX_API array cumsum(
     int axis,
     bool reverse = false,
     bool inclusive = true,
+    std::optional<Dtype> dtype = std::nullopt,
     StreamOrDevice s = {});
+inline array cumsum(
+    const array& a,
+    int axis,
+    bool reverse,
+    bool inclusive,
+    StreamOrDevice s) {
+  return cumsum(a, axis, reverse, inclusive, std::nullopt, s);
+}
 
 /** Cumulative product of an array. */
 MLX_API array cumprod(
     const array& a,
     bool reverse = false,
     bool inclusive = true,
+    std::optional<Dtype> dtype = std::nullopt,
     StreamOrDevice s = {});
+inline array
+cumprod(const array& a, bool reverse, bool inclusive, StreamOrDevice s) {
+  return cumprod(a, reverse, inclusive, std::nullopt, s);
+}
 
 /** Cumulative product of an array along the given axis. */
 MLX_API array cumprod(
@@ -1356,7 +1375,16 @@ MLX_API array cumprod(
     int axis,
     bool reverse = false,
     bool inclusive = true,
+    std::optional<Dtype> dtype = std::nullopt,
     StreamOrDevice s = {});
+inline array cumprod(
+    const array& a,
+    int axis,
+    bool reverse,
+    bool inclusive,
+    StreamOrDevice s) {
+  return cumprod(a, axis, reverse, inclusive, std::nullopt, s);
+}
 
 /** Cumulative max of an array. */
 MLX_API array cummax(
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index f11f98427d..4f515f5331 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -3447,32 +3447,34 @@ void init_ops(nb::module_& m) {
          std::optional<int> axis,
          bool reverse,
          bool inclusive,
+         std::optional<mx::Dtype> dtype,
          mx::StreamOrDevice s) {
         if (axis) {
-          return mx::cumsum(a, *axis, reverse, inclusive, s);
-        } else {
-          return mx::cumsum(mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);
+          return mx::cumsum(a, *axis, reverse, inclusive, dtype, s);
         }
+        return mx::cumsum(a, reverse, inclusive, dtype, s);
       },
       nb::arg(),
       "axis"_a = nb::none(),
       nb::kw_only(),
       "reverse"_a = false,
       "inclusive"_a = true,
+      "dtype"_a = nb::none(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def cumsum(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array"),
+          "def cumsum(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, dtype: Optional[Dtype] = None, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Return the cumulative sum of the elements along the given axis.
 
         Args:
-          a (array): Input array
+          a (array): Input array.
           axis (int, optional): Optional axis to compute the cumulative sum
             over. If unspecified the cumulative sum of the flattened array is
             returned.
           reverse (bool): Perform the cumulative sum in reverse.
           inclusive (bool): The i-th element of the output includes the i-th
             element of the input.
+          dtype (Dtype, optional): Cast the input to this type before summing.
 
         Returns:
           array: The output array.
@@ -3483,32 +3485,34 @@ void init_ops(nb::module_& m) {
          std::optional<int> axis,
          bool reverse,
          bool inclusive,
+         std::optional<mx::Dtype> dtype,
          mx::StreamOrDevice s) {
         if (axis) {
-          return mx::cumprod(a, *axis, reverse, inclusive, s);
-        } else {
-          return mx::cumprod(mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);
+          return mx::cumprod(a, *axis, reverse, inclusive, dtype, s);
         }
+        return mx::cumprod(a, reverse, inclusive, dtype, s);
       },
       nb::arg(),
       "axis"_a = nb::none(),
       nb::kw_only(),
       "reverse"_a = false,
       "inclusive"_a = true,
+      "dtype"_a = nb::none(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def cumprod(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array"),
+          "def cumprod(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, dtype: Optional[Dtype] = None, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Return the cumulative product of the elements along the given axis.
 
         Args:
-          a (array): Input array
+          a (array): Input array.
           axis (int, optional): Optional axis to compute the cumulative product
-            over. If unspecified the cumulative product of the flattened array is
-            returned.
+            over. If unspecified the cumulative product of the flattened array
+            is returned.
           reverse (bool): Perform the cumulative product in reverse.
           inclusive (bool): The i-th element of the output includes the i-th
             element of the input.
+          dtype (Dtype, optional): Cast the input to this type before multiplying.
 
         Returns:
           array: The output array.
@@ -5914,6 +5918,8 @@ void init_ops(nb::module_& m) {
   m.attr("atan2") = m.attr("arctan2");
   m.attr("bitwise_left_shift") = m.attr("left_shift");
   m.attr("bitwise_right_shift") = m.attr("right_shift");
+  m.attr("cumulative_prod") = m.attr("cumprod");
+  m.attr("cumulative_sum") = m.attr("cumsum");
   m.attr("empty") = m.attr("zeros");
   m.attr("empty_like") = m.attr("zeros_like");
   m.attr("matrix_transpose") = m.attr("transpose");