Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions mlx/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3947,6 +3947,7 @@ array cumsum(
int axis,
bool reverse /* = false*/,
bool inclusive /* = true*/,
std::optional<Dtype> dtype /* = std::nullopt*/,
StreamOrDevice s /* = {}*/) {
int ndim = a.ndim();
if (axis >= ndim || axis < -ndim) {
Expand All @@ -3956,28 +3957,32 @@ array cumsum(
throw std::invalid_argument(msg.str());
}
axis = (axis + a.ndim()) % a.ndim();
auto out_type = a.dtype() == bool_ ? int32 : a.dtype();
auto x = dtype ? astype(a, *dtype, s) : a;
auto out_type = x.dtype() == bool_ ? int32 : x.dtype();
return array(
a.shape(),
x.shape(),
out_type,
std::make_shared<Scan>(
to_stream(s), Scan::ReduceType::Sum, axis, reverse, inclusive),
{a});
{x});
}

array cumsum(
const array& a,
bool reverse /* = false*/,
bool inclusive /* = true*/,
std::optional<Dtype> dtype /* = std::nullopt*/,
StreamOrDevice s /* = {}*/) {
return cumsum(flatten(a, to_stream(s)), 0, reverse, inclusive, to_stream(s));
return cumsum(
flatten(a, to_stream(s)), 0, reverse, inclusive, dtype, to_stream(s));
}

array cumprod(
const array& a,
int axis,
bool reverse /* = false*/,
bool inclusive /* = true*/,
std::optional<Dtype> dtype /* = std::nullopt*/,
StreamOrDevice s /* = {}*/) {
int ndim = a.ndim();
if (axis >= ndim || axis < -ndim) {
Expand All @@ -3987,20 +3992,22 @@ array cumprod(
throw std::invalid_argument(msg.str());
}
axis = (axis + a.ndim()) % a.ndim();
auto x = dtype ? astype(a, *dtype, s) : a;
return array(
a.shape(),
a.dtype(),
x.shape(),
x.dtype(),
std::make_shared<Scan>(
to_stream(s), Scan::ReduceType::Prod, axis, reverse, inclusive),
{a});
{x});
}

array cumprod(
const array& a,
bool reverse /* = false*/,
bool inclusive /* = true*/,
std::optional<Dtype> dtype /* = std::nullopt*/,
StreamOrDevice s /* = {}*/) {
return cumprod(flatten(a, s), 0, reverse, inclusive, s);
return cumprod(flatten(a, s), 0, reverse, inclusive, dtype, s);
}

array cummax(
Expand Down
28 changes: 28 additions & 0 deletions mlx/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -1333,30 +1333,58 @@ MLX_API array cumsum(
const array& a,
bool reverse = false,
bool inclusive = true,
std::optional<Dtype> dtype = std::nullopt,
StreamOrDevice s = {});
inline array
cumsum(const array& a, bool reverse, bool inclusive, StreamOrDevice s) {
return cumsum(a, reverse, inclusive, std::nullopt, s);
}

/** Cumulative sum of an array along the given axis. */
MLX_API array cumsum(
const array& a,
int axis,
bool reverse = false,
bool inclusive = true,
std::optional<Dtype> dtype = std::nullopt,
StreamOrDevice s = {});
inline array cumsum(
const array& a,
int axis,
bool reverse,
bool inclusive,
StreamOrDevice s) {
return cumsum(a, axis, reverse, inclusive, std::nullopt, s);
}

/** Cumulative product of an array. */
MLX_API array cumprod(
const array& a,
bool reverse = false,
bool inclusive = true,
std::optional<Dtype> dtype = std::nullopt,
StreamOrDevice s = {});
inline array
cumprod(const array& a, bool reverse, bool inclusive, StreamOrDevice s) {
return cumprod(a, reverse, inclusive, std::nullopt, s);
}

/** Cumulative product of an array along the given axis. */
MLX_API array cumprod(
const array& a,
int axis,
bool reverse = false,
bool inclusive = true,
std::optional<Dtype> dtype = std::nullopt,
StreamOrDevice s = {});
inline array cumprod(
const array& a,
int axis,
bool reverse,
bool inclusive,
StreamOrDevice s) {
return cumprod(a, axis, reverse, inclusive, std::nullopt, s);
}

/** Cumulative max of an array. */
MLX_API array cummax(
Expand Down
30 changes: 18 additions & 12 deletions python/src/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3447,32 +3447,34 @@ void init_ops(nb::module_& m) {
std::optional<int> axis,
bool reverse,
bool inclusive,
std::optional<mx::Dtype> dtype,
mx::StreamOrDevice s) {
if (axis) {
return mx::cumsum(a, *axis, reverse, inclusive, s);
} else {
return mx::cumsum(mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);
return mx::cumsum(a, *axis, reverse, inclusive, dtype, s);
}
return mx::cumsum(a, reverse, inclusive, dtype, s);
},
nb::arg(),
"axis"_a = nb::none(),
nb::kw_only(),
"reverse"_a = false,
"inclusive"_a = true,
"dtype"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def cumsum(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array"),
"def cumsum(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, dtype: Optional[Dtype] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Return the cumulative sum of the elements along the given axis.

Args:
a (array): Input array
a (array): Input array.
axis (int, optional): Optional axis to compute the cumulative sum
over. If unspecified the cumulative sum of the flattened array is
returned.
reverse (bool): Perform the cumulative sum in reverse.
inclusive (bool): The i-th element of the output includes the i-th
element of the input.
dtype (Dtype, optional): Cast the input to this type before summing.

Returns:
array: The output array.
Expand All @@ -3483,32 +3485,34 @@ void init_ops(nb::module_& m) {
std::optional<int> axis,
bool reverse,
bool inclusive,
std::optional<mx::Dtype> dtype,
mx::StreamOrDevice s) {
if (axis) {
return mx::cumprod(a, *axis, reverse, inclusive, s);
} else {
return mx::cumprod(mx::reshape(a, {-1}, s), 0, reverse, inclusive, s);
return mx::cumprod(a, *axis, reverse, inclusive, dtype, s);
}
return mx::cumprod(a, reverse, inclusive, dtype, s);
},
nb::arg(),
"axis"_a = nb::none(),
nb::kw_only(),
"reverse"_a = false,
"inclusive"_a = true,
"dtype"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def cumprod(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array"),
"def cumprod(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, dtype: Optional[Dtype] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Return the cumulative product of the elements along the given axis.

Args:
a (array): Input array
a (array): Input array.
axis (int, optional): Optional axis to compute the cumulative product
over. If unspecified the cumulative product of the flattened array is
returned.
over. If unspecified the cumulative product of the flattened array
is returned.
reverse (bool): Perform the cumulative product in reverse.
inclusive (bool): The i-th element of the output includes the i-th
element of the input.
dtype (Dtype, optional): Cast the input to this type before multiplying.

Returns:
array: The output array.
Expand Down Expand Up @@ -5914,6 +5918,8 @@ void init_ops(nb::module_& m) {
m.attr("atan2") = m.attr("arctan2");
m.attr("bitwise_left_shift") = m.attr("left_shift");
m.attr("bitwise_right_shift") = m.attr("right_shift");
m.attr("cumulative_prod") = m.attr("cumprod");
m.attr("cumulative_sum") = m.attr("cumsum");
m.attr("empty") = m.attr("zeros");
m.attr("empty_like") = m.attr("zeros_like");
m.attr("matrix_transpose") = m.attr("transpose");
Expand Down
Loading