diff --git a/docs/api-assorted.md b/docs/api-assorted.md index c8101f0e..5c522932 100644 --- a/docs/api-assorted.md +++ b/docs/api-assorted.md @@ -27,4 +27,5 @@ setdiff1d sinc union1d + unravel_index ``` diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py index d8f70b11..aa79f027 100644 --- a/src/array_api_extra/__init__.py +++ b/src/array_api_extra/__init__.py @@ -19,6 +19,7 @@ setdiff1d, sinc, union1d, + unravel_index, ) from ._lib._at import at from ._lib._funcs import ( @@ -58,4 +59,5 @@ "sinc", "testing", "union1d", + "unravel_index", ] diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 97dec674..71219fc1 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -31,6 +31,7 @@ "pad", "searchsorted", "sinc", + "unravel_index", ] @@ -1307,3 +1308,56 @@ def union1d(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array: return xp.union1d(a, b) return _funcs.union1d(a, b, xp=xp) + + +def unravel_index( + ind: Array, + shape: tuple[int, ...], + /, + *, + xp: ModuleType | None = None, +) -> tuple[Array, ...]: + """ + Convert a flat index or array of flat indices into a tuple of coordinate arrays. + + Parameters + ---------- + ind : array + An integer array whose elements are indices into the flattened version + of an array of dimensions `shape`. + + shape : tuple of ints + The shape to use for unraveling `indices`. + + xp : array_namespace, optional + The standard-compatible namespace for `x`. Default: infer. + + Returns + ------- + tuple of array + A tuple of unraveled indices. Each array in the tuple has the same shape + as the `indices` array. + + Examples + -------- + >>> import array_api_extra as xpx + >>> import array_api_strict as xp + >>> xpx.unravel_index(xp.asarray([1, 2, 3, 4, 5]), (4, 3)) + ( + Array([0, 0, 1, 1, 1], dtype=array_api_strict.int64), + Array([1, 2, 0, 1, 2], dtype=array_api_strict.int64), + ) + """ + if xp is None: + xp = array_namespace(ind) + + if ( + is_numpy_namespace(xp) + or is_cupy_namespace(xp) + or is_dask_namespace(xp) + or is_jax_namespace(xp) + or is_torch_namespace(xp) + ): + return xp.unravel_index(ind, shape) + + return _funcs.unravel_index(ind, shape) diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index 4e3b8753..ce995a4a 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -757,3 +757,13 @@ def angle(z: Array, /, *, deg: bool = False, xp: ModuleType | None = None) -> Ar if deg: a = a * 180 / xp.pi return a + + +def unravel_index(ind: Array, shape: tuple[int, ...], /) -> tuple[Array, ...]: + # numpydoc ignore=PR01,RT01 + """See docstring in `array_api_extra._delegation.py`.""" + coords: list[Array] = [] + for dim in reversed(shape): + coords.append(ind % dim) + ind = ind // dim + return tuple(reversed(coords)) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index c212129d..aeeda6b2 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -34,6 +34,7 @@ setdiff1d, sinc, union1d, + unravel_index, ) from array_api_extra import ( searchsorted as xpx_searchsorted, @@ -1981,3 +1982,51 @@ def test_2d(self, xp: ModuleType): def test_device(self, xp: ModuleType, device: Device): a = xp.asarray([1 + 1j], device=device) assert get_device(angle(a)) == device + + +class TestUnravelIndex: + def test_simple(self, xp: ModuleType): + ind = xp.asarray([22, 41, 37]) + shape = (7, 6) + expected = (xp.asarray([3, 6, 6]), xp.asarray([4, 5, 1])) + res = unravel_index(ind, shape) + for res_arr, exp_arr in zip(res, expected, strict=True): + assert_equal(res_arr, exp_arr) + + ind = xp.asarray([0, 1, 2, 3, 4, 5]) + shape = (3, 2) + expected = ( + xp.asarray([0, 0, 1, 1, 2, 2]), + xp.asarray([0, 1, 0, 1, 0, 1]), + ) + res = unravel_index(ind, shape) + for res_arr, exp_arr in zip(res, expected, strict=True): + assert_equal(res_arr, exp_arr) + + def test_indices_scalar(self, xp: ModuleType): + ind = xp.asarray(1621) + shape = (6, 7, 8, 9) + expected = (xp.asarray(3), xp.asarray(1), xp.asarray(4), xp.asarray(1)) + res = unravel_index(ind, shape) + # a tuple of integers is expected + assert res == expected + + def test_indices_2d(self, xp: ModuleType): + ind = xp.asarray([[1234], [5678]]) + shape = (10, 10, 10, 10) + expected = ( + xp.asarray([[1], [5]]), + xp.asarray([[2], [6]]), + xp.asarray([[3], [7]]), + xp.asarray([[4], [8]]), + ) + res = unravel_index(ind, shape) + for res_arr, exp_arr in zip(res, expected, strict=True): + assert_equal(res_arr, exp_arr) + + def test_device(self, xp: ModuleType, device: Device): + ind = xp.asarray([4, 1], device=device) + shape = (3, 2) + res = unravel_index(ind, shape) + for res_arr in res: + assert get_device(res_arr) == device