From a0dcfb6c80214d3c28b0279b6b37c467e40a11f1 Mon Sep 17 00:00:00 2001 From: Anton Volkov Date: Thu, 2 Jul 2026 14:46:15 +0200 Subject: [PATCH 1/7] Align signature and axis docstring --- dpnp/tensor/_manipulation_functions.py | 40 +++++++++++++++++--------- 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/dpnp/tensor/_manipulation_functions.py b/dpnp/tensor/_manipulation_functions.py index 7347f62de115..1d9e7b4733d8 100644 --- a/dpnp/tensor/_manipulation_functions.py +++ b/dpnp/tensor/_manipulation_functions.py @@ -355,36 +355,50 @@ def concat(arrays, /, *, axis=0): return res -def expand_dims(X, /, *, axis=0): +def expand_dims(X, /, axis): """expand_dims(x, axis) Expands the shape of an array by inserting a new axis (dimension) - of size one at the position specified by axis. + of size one at the position (or positions) specified by axis. Args: x (usm_ndarray): input array - axis (Union[int, Tuple[int]]): - axis position in the expanded axes (zero-based). If `x` has rank - (i.e, number of dimensions) `N`, a valid `axis` must reside - in the closed-interval `[-N-1, N]`. If provided a negative - `axis`, the `axis` position at which to insert a singleton - dimension is computed as `N + axis + 1`. Hence, if - provided `-1`, the resolved axis position is `N` (i.e., - a singleton dimension must be appended to the input array `x`). - If provided `-N-1`, the resolved axis position is `0` (i.e., a - singleton dimension is prepended to the input array `x`). + axis (Union[int, Tuple[int, ...]]): + axis position(s) (zero-based). If ``axis`` is an integer, ``axis`` + **must** be equivalent to the tuple ``(axis,)``. If ``axis`` is + a tuple, + + - a valid axis position **must** reside on the half-open interval + ``[-M, M)``, where ``M = N + len(axis)`` and ``N`` is the number + of dimensions in ``x``. + - if the i-th entry is a negative integer, the axis position of the + inserted singleton dimension in the output array **must** be + computed as ``M + axis[i]``. + - each entry of ``axis`` must resolve to a unique positive axis + position. + - for each entry of ``axis``, the corresponding dimension in the + expanded output array **must** be a singleton dimension. + - for the remaining dimensions of the expanded output array, the + output array dimensions **must** correspond to the dimensions of + ``x`` in order. + - if provided an invalid axis position, the function **must** raise + an exception. Returns: usm_ndarray: Returns a view, if possible, and a copy otherwise with the number of dimensions increased. The expanded array has the same data type as the input array `x`. + If ``axis`` is an integer, the output array must have ``N + 1`` + dimensions. If ``axis`` is a tuple, the output array must have + ``N + len(axis)`` dimensions. The expanded array is located on the same device as the input array, and has the same USM allocation type. Raises: - IndexError: if `axis` value is invalid. + AxisError: if an `axis` value is out of range. + ValueError: if `axis` contains a repeated value. """ if not isinstance(X, dpt.usm_ndarray): raise TypeError(f"Expected usm_ndarray type, got {type(X)}.") From bcf120e41c0d90d35e16a4c7b5759ec4351990f7 Mon Sep 17 00:00:00 2001 From: Anton Volkov Date: Thu, 2 Jul 2026 14:50:05 +0200 Subject: [PATCH 2/7] Removed duplicating axis note --- dpnp/tensor/_manipulation_functions.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/dpnp/tensor/_manipulation_functions.py b/dpnp/tensor/_manipulation_functions.py index 1d9e7b4733d8..d12422fdefa4 100644 --- a/dpnp/tensor/_manipulation_functions.py +++ b/dpnp/tensor/_manipulation_functions.py @@ -382,8 +382,6 @@ def expand_dims(X, /, axis): - for the remaining dimensions of the expanded output array, the output array dimensions **must** correspond to the dimensions of ``x`` in order. - - if provided an invalid axis position, the function **must** raise - an exception. Returns: usm_ndarray: From fc70add8f7034b42e317d5964ce80c08507d95cc Mon Sep 17 00:00:00 2001 From: Anton Volkov Date: Thu, 2 Jul 2026 14:53:20 +0200 Subject: [PATCH 3/7] Fix typo in test_expand_dims_incorrect_type --- dpnp/tests/tensor/test_usm_ndarray_manipulation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dpnp/tests/tensor/test_usm_ndarray_manipulation.py b/dpnp/tests/tensor/test_usm_ndarray_manipulation.py index 0375bb446370..c6a7ef2b6b44 100644 --- a/dpnp/tests/tensor/test_usm_ndarray_manipulation.py +++ b/dpnp/tests/tensor/test_usm_ndarray_manipulation.py @@ -102,7 +102,7 @@ def test_permute_dims_2d_3d(shapes): def test_expand_dims_incorrect_type(): X_list = [1, 2, 3, 4, 5] with pytest.raises(TypeError): - dpt.permute_dims(X_list, axis=1) + dpt.expand_dims(X_list, axis=1) def test_expand_dims_0d(): From 00e8bc60c51b223aa45719646dd2d56f696beea6 Mon Sep 17 00:00:00 2001 From: Anton Volkov Date: Thu, 2 Jul 2026 15:01:35 +0200 Subject: [PATCH 4/7] Add dedicated test to ensure axis might be a positional argument and has no default value --- .../tests/tensor/test_usm_ndarray_manipulation.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/dpnp/tests/tensor/test_usm_ndarray_manipulation.py b/dpnp/tests/tensor/test_usm_ndarray_manipulation.py index c6a7ef2b6b44..bb0a99a537ff 100644 --- a/dpnp/tests/tensor/test_usm_ndarray_manipulation.py +++ b/dpnp/tests/tensor/test_usm_ndarray_manipulation.py @@ -154,6 +154,21 @@ def test_expand_dims_tuple(axes): assert_array_equal(Ynp, dpt.asnumpy(Y)) +def test_expand_dims_positional_axis(): + q = get_queue_or_skip() + + Xnp = np.empty((3, 3, 3), dtype="u1") + X = dpt.asarray(Xnp, sycl_queue=q) + + Y = dpt.expand_dims(X, 1) # `axis` is a positional-or-keyword argument + Ynp = np.expand_dims(Xnp, 1) + assert_array_equal(Ynp, dpt.asnumpy(Y)) + + # `axis` has no default value + with pytest.raises(TypeError): + dpt.expand_dims(X) + + def test_expand_dims_incorrect_tuple(): try: X = dpt.empty((3, 3, 3), dtype="i4") From 10437781cabb43a7bfadd90e0d73daadfe470022 Mon Sep 17 00:00:00 2001 From: Anton Volkov Date: Thu, 2 Jul 2026 15:13:56 +0200 Subject: [PATCH 5/7] Update notes in dpnp.expand_dims --- dpnp/dpnp_iface_manipulation.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/dpnp/dpnp_iface_manipulation.py b/dpnp/dpnp_iface_manipulation.py index 559884acb049..fd0cfd64031d 100644 --- a/dpnp/dpnp_iface_manipulation.py +++ b/dpnp/dpnp_iface_manipulation.py @@ -1758,7 +1758,7 @@ def dstack(tup): return dpnp.concatenate(arrs, axis=2) -def expand_dims(a, axis): +def expand_dims(a, /, axis): """ Expand the shape of an array. @@ -1782,14 +1782,15 @@ def expand_dims(a, axis): Notes ----- - If `a` has rank (i.e, number of dimensions) `N`, a valid `axis` must reside - in the closed-interval `[-N-1, N]`. - If provided a negative `axis`, the `axis` position at which to insert a - singleton dimension is computed as `N + axis + 1`. - Hence, if provided `-1`, the resolved axis position is `N` (i.e., - a singleton dimension must be appended to the input array `a`). - If provided `-N-1`, the resolved axis position is `0` (i.e., a - singleton dimension is added to the input array `a`). + If `a` has rank (i.e, number of dimensions) `N`, a valid `axis` value must + reside on the half-open interval `[-M, M)`, where `M = N + len(axis)` (with + `len(axis)` equal to ``1`` when `axis` is an integer). + If provided a negative `axis`, the position at which to insert a singleton + dimension is computed as ``M + axis``. + Hence, if provided ``-1``, the resolved axis position is ``M - 1`` (i.e., + a singleton dimension is appended to the input array `a`). + If provided ``-M``, the resolved axis position is ``0`` (i.e., a singleton + dimension is prepended to the input array `a`). See Also -------- From d19fdc9c5cafd94458c5549fa50ffb6827f0840a Mon Sep 17 00:00:00 2001 From: Anton Volkov Date: Thu, 2 Jul 2026 15:15:23 +0200 Subject: [PATCH 6/7] Update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 70bf2a61a103..a9585e180e3a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,7 @@ This release is compatible with NumPy 2.5. * Improved performance of `dpnp.fft` functions for complex strided input by avoiding oversized allocations and extra copies [#2939](https://github.com/IntelPython/dpnp/pull/2939) * Refreshed `dpnp` documentation styling with the Furo theme [#2934](https://github.com/IntelPython/dpnp/pull/2934) * Updated Python Array API specification version supported to `2025.12` [#2899](https://github.com/IntelPython/dpnp/pull/2899) +* Aligned the signature of `dpnp.tensor.expand_dims` with the Python array API by making `axis` a required argument [#2986](https://github.com/IntelPython/dpnp/pull/2986) ### Deprecated From e556d3e5328c284c9db6be472af18c93028107d5 Mon Sep 17 00:00:00 2001 From: Anton Volkov Date: Thu, 2 Jul 2026 15:21:18 +0200 Subject: [PATCH 7/7] Correct PR number in changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a9585e180e3a..3fb8c4af3f5e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,7 +32,7 @@ This release is compatible with NumPy 2.5. * Improved performance of `dpnp.fft` functions for complex strided input by avoiding oversized allocations and extra copies [#2939](https://github.com/IntelPython/dpnp/pull/2939) * Refreshed `dpnp` documentation styling with the Furo theme [#2934](https://github.com/IntelPython/dpnp/pull/2934) * Updated Python Array API specification version supported to `2025.12` [#2899](https://github.com/IntelPython/dpnp/pull/2899) -* Aligned the signature of `dpnp.tensor.expand_dims` with the Python array API by making `axis` a required argument [#2986](https://github.com/IntelPython/dpnp/pull/2986) +* Aligned the signature of `dpnp.tensor.expand_dims` with the Python array API by making `axis` a required argument [#2988](https://github.com/IntelPython/dpnp/pull/2988) ### Deprecated