Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ This release is compatible with NumPy 2.5.
* Improved performance of `dpnp.fft` functions for complex strided input by avoiding oversized allocations and extra copies [#2939](https://github.com/IntelPython/dpnp/pull/2939)
* Refreshed `dpnp` documentation styling with the Furo theme [#2934](https://github.com/IntelPython/dpnp/pull/2934)
* Updated Python Array API specification version supported to `2025.12` [#2899](https://github.com/IntelPython/dpnp/pull/2899)
* Aligned the signature of `dpnp.tensor.expand_dims` with the Python array API by making `axis` a required argument [#2988](https://github.com/IntelPython/dpnp/pull/2988)

### Deprecated

Expand Down
19 changes: 10 additions & 9 deletions dpnp/dpnp_iface_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1758,7 +1758,7 @@ def dstack(tup):
return dpnp.concatenate(arrs, axis=2)


def expand_dims(a, axis):
def expand_dims(a, /, axis):
"""
Expand the shape of an array.

Expand All @@ -1782,14 +1782,15 @@ def expand_dims(a, axis):

Notes
-----
If `a` has rank (i.e, number of dimensions) `N`, a valid `axis` must reside
in the closed-interval `[-N-1, N]`.
If provided a negative `axis`, the `axis` position at which to insert a
singleton dimension is computed as `N + axis + 1`.
Hence, if provided `-1`, the resolved axis position is `N` (i.e.,
a singleton dimension must be appended to the input array `a`).
If provided `-N-1`, the resolved axis position is `0` (i.e., a
singleton dimension is added to the input array `a`).
If `a` has rank (i.e, number of dimensions) `N`, a valid `axis` value must
reside on the half-open interval `[-M, M)`, where `M = N + len(axis)` (with
`len(axis)` equal to ``1`` when `axis` is an integer).
If provided a negative `axis`, the position at which to insert a singleton
dimension is computed as ``M + axis``.
Hence, if provided ``-1``, the resolved axis position is ``M - 1`` (i.e.,
a singleton dimension is appended to the input array `a`).
If provided ``-M``, the resolved axis position is ``0`` (i.e., a singleton
dimension is prepended to the input array `a`).

See Also
--------
Expand Down
38 changes: 25 additions & 13 deletions dpnp/tensor/_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,36 +355,48 @@ def concat(arrays, /, *, axis=0):
return res


def expand_dims(X, /, *, axis=0):
def expand_dims(X, /, axis):
"""expand_dims(x, axis)

Expands the shape of an array by inserting a new axis (dimension)
of size one at the position specified by axis.
of size one at the position (or positions) specified by axis.

Args:
x (usm_ndarray):
input array
axis (Union[int, Tuple[int]]):
axis position in the expanded axes (zero-based). If `x` has rank
(i.e, number of dimensions) `N`, a valid `axis` must reside
in the closed-interval `[-N-1, N]`. If provided a negative
`axis`, the `axis` position at which to insert a singleton
dimension is computed as `N + axis + 1`. Hence, if
provided `-1`, the resolved axis position is `N` (i.e.,
a singleton dimension must be appended to the input array `x`).
If provided `-N-1`, the resolved axis position is `0` (i.e., a
singleton dimension is prepended to the input array `x`).
axis (Union[int, Tuple[int, ...]]):
axis position(s) (zero-based). If ``axis`` is an integer, ``axis``
**must** be equivalent to the tuple ``(axis,)``. If ``axis`` is
a tuple,

- a valid axis position **must** reside on the half-open interval
``[-M, M)``, where ``M = N + len(axis)`` and ``N`` is the number
of dimensions in ``x``.
- if the i-th entry is a negative integer, the axis position of the
inserted singleton dimension in the output array **must** be
computed as ``M + axis[i]``.
- each entry of ``axis`` must resolve to a unique positive axis
position.
- for each entry of ``axis``, the corresponding dimension in the
expanded output array **must** be a singleton dimension.
- for the remaining dimensions of the expanded output array, the
output array dimensions **must** correspond to the dimensions of
``x`` in order.

Returns:
usm_ndarray:
Returns a view, if possible, and a copy otherwise with the number
of dimensions increased.
The expanded array has the same data type as the input array `x`.
If ``axis`` is an integer, the output array must have ``N + 1``
dimensions. If ``axis`` is a tuple, the output array must have
``N + len(axis)`` dimensions.
The expanded array is located on the same device as the input
array, and has the same USM allocation type.

Raises:
IndexError: if `axis` value is invalid.
AxisError: if an `axis` value is out of range.
ValueError: if `axis` contains a repeated value.
"""
if not isinstance(X, dpt.usm_ndarray):
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
Expand Down
17 changes: 16 additions & 1 deletion dpnp/tests/tensor/test_usm_ndarray_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def test_permute_dims_2d_3d(shapes):
def test_expand_dims_incorrect_type():
X_list = [1, 2, 3, 4, 5]
with pytest.raises(TypeError):
dpt.permute_dims(X_list, axis=1)
dpt.expand_dims(X_list, axis=1)


def test_expand_dims_0d():
Expand Down Expand Up @@ -154,6 +154,21 @@ def test_expand_dims_tuple(axes):
assert_array_equal(Ynp, dpt.asnumpy(Y))


def test_expand_dims_positional_axis():
q = get_queue_or_skip()

Xnp = np.empty((3, 3, 3), dtype="u1")
X = dpt.asarray(Xnp, sycl_queue=q)

Y = dpt.expand_dims(X, 1) # `axis` is a positional-or-keyword argument
Ynp = np.expand_dims(Xnp, 1)
assert_array_equal(Ynp, dpt.asnumpy(Y))

# `axis` has no default value
with pytest.raises(TypeError):
dpt.expand_dims(X)


def test_expand_dims_incorrect_tuple():
try:
X = dpt.empty((3, 3, 3), dtype="i4")
Expand Down
Loading