Skip to content

Commit a635d86

Browse files
authored
Fix multi-coordinate indexes dropped in _replace_maybe_drop_dims (#11286)
1 parent ef078b5 commit a635d86

5 files changed

Lines changed: 121 additions & 5 deletions

File tree

doc/whats-new.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,14 @@ Deprecations
112112
Bug Fixes
113113
~~~~~~~~~
114114

115+
- Fix multi-coordinate indexes being dropped in :py:meth:`DataArray._replace_maybe_drop_dims`
116+
(e.g. after reducing over an unrelated dimension) and in :py:meth:`Dataset._copy_listed`
117+
(e.g. when subsetting a Dataset by variable names). Both paths now consult
118+
:py:meth:`Index.should_add_coord_to_array`, consistent with
119+
:py:meth:`Dataset._construct_dataarray`. Also simplify :py:meth:`Dataset.to_dataarray`
120+
to keep all coordinates and indexes directly: because the variables are broadcast,
121+
every coordinate is retained in the result (:issue:`11215`, :pull:`11286`).
122+
By `Rich Signell <https://github.com/rsignell>`_.
115123
- Allow writing ``StringDType`` variables to netCDF files (:issue:`11199`).
116124
By `Kristian Kollsgård <https://github.com/kkollsga>`_.
117125
- Fix ``Source`` link in api docs (:pull:`11187`)

xarray/core/dataarray.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -538,9 +538,13 @@ def _replace_maybe_drop_dims(
538538
indexes = filter_indexes_from_coords(self._indexes, set(coords))
539539
else:
540540
allowed_dims = set(variable.dims)
541-
coords = {
542-
k: v for k, v in self._coords.items() if set(v.dims) <= allowed_dims
543-
}
541+
coords = {}
542+
for k, v in self._coords.items():
543+
if k in self._indexes:
544+
if self._indexes[k].should_add_coord_to_array(k, v, allowed_dims):
545+
coords[k] = v
546+
elif set(v.dims) <= allowed_dims:
547+
coords[k] = v
544548
indexes = filter_indexes_from_coords(self._indexes, set(coords))
545549
return self._replace(variable, coords, name, indexes=indexes)
546550

xarray/core/dataset.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1226,7 +1226,13 @@ def _copy_listed(self, names: Iterable[Hashable]) -> Self:
12261226
if k not in self._coord_names:
12271227
continue
12281228

1229-
if set(self.variables[k].dims) <= needed_dims:
1229+
if k in self._indexes:
1230+
if self._indexes[k].should_add_coord_to_array(
1231+
k, self._variables[k], set(needed_dims)
1232+
):
1233+
variables[k] = self._variables[k]
1234+
coord_names.add(k)
1235+
elif set(self.variables[k].dims) <= needed_dims:
12301236
variables[k] = self._variables[k]
12311237
coord_names.add(k)
12321238

@@ -7168,7 +7174,7 @@ def to_dataarray(
71687174
variable = Variable(dims, data, self.attrs, fastpath=True)
71697175

71707176
coords = {k: v.variable for k, v in self.coords.items()}
7171-
indexes = filter_indexes_from_coords(self._indexes, set(coords))
7177+
indexes = dict(self._indexes)
71727178
new_dim_index = PandasIndex(list(self.data_vars), dim)
71737179
indexes[dim] = new_dim_index
71747180
coords.update(new_dim_index.create_variables())

xarray/tests/test_dataarray.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,39 @@ def should_add_coord_to_array(self, name, var, dims):
555555
assert_identical(actual.coords, coords, check_default_indexes=False)
556556
assert "x_bnds" not in actual.dims
557557

558+
def test_replace_maybe_drop_dims_preserves_multi_coord_index(self) -> None:
559+
# Regression test for https://github.com/pydata/xarray/issues/11215
560+
# Multi-coordinate indexes spanning multiple dims should be preserved
561+
# after reducing over an unrelated dimension.
562+
class MultiDimIndex(Index):
563+
def should_add_coord_to_array(self, name, var, dims):
564+
return True
565+
566+
idx = MultiDimIndex()
567+
coords = Coordinates(
568+
coords={
569+
"node_x": ("nodes", [0.0, 1.0, 2.0]),
570+
"node_y": ("nodes", [0.0, 0.0, 1.0]),
571+
"face_x": ("faces", [0.5, 1.5]),
572+
"face_y": ("faces", [0.5, 0.5]),
573+
},
574+
indexes=dict.fromkeys(["node_x", "node_y", "face_x", "face_y"], idx),
575+
)
576+
node_da = DataArray(
577+
np.random.rand(3, 4), dims=("nodes", "extra"), coords=coords
578+
)
579+
face_da = DataArray(
580+
np.random.rand(2, 4), dims=("faces", "extra"), coords=coords
581+
)
582+
583+
reduced_node = node_da.mean("extra")
584+
reduced_face = face_da.mean("extra")
585+
586+
for da in [reduced_node, reduced_face]:
587+
for name in ["node_x", "node_y", "face_x", "face_y"]:
588+
assert name in da.coords
589+
assert isinstance(da.xindexes[name], MultiDimIndex)
590+
558591
def test_equals_and_identical(self) -> None:
559592
orig = DataArray(np.arange(5.0), {"a": 42}, dims="x")
560593

xarray/tests/test_dataset.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4561,6 +4561,71 @@ def should_add_coord_to_array(self, name, var, dims):
45614561
assert_identical(actual.coords, coords, check_default_indexes=False)
45624562
assert "x_bnds" not in actual.dims
45634563

4564+
def test_copy_listed_preserves_multi_coord_index(self) -> None:
4565+
# Regression test for https://github.com/pydata/xarray/issues/11215
4566+
# Multi-coordinate indexes spanning multiple dims should be preserved
4567+
# when subsetting a Dataset by variable names via ds[["var"]].
4568+
class MultiDimIndex(Index):
4569+
def should_add_coord_to_array(self, name, var, dims):
4570+
return True
4571+
4572+
idx = MultiDimIndex()
4573+
coords = Coordinates(
4574+
coords={
4575+
"node_x": ("nodes", [0.0, 1.0, 2.0]),
4576+
"node_y": ("nodes", [0.0, 0.0, 1.0]),
4577+
"face_x": ("faces", [0.5, 1.5]),
4578+
"face_y": ("faces", [0.5, 0.5]),
4579+
},
4580+
indexes=dict.fromkeys(["node_x", "node_y", "face_x", "face_y"], idx),
4581+
)
4582+
ds = Dataset(
4583+
{
4584+
"node_data": (("nodes",), [1.0, 2.0, 3.0]),
4585+
"face_data": (("faces",), [10.0, 20.0]),
4586+
},
4587+
coords=coords,
4588+
)
4589+
4590+
node_subset = ds[["node_data"]]
4591+
face_subset = ds[["face_data"]]
4592+
4593+
for ds_sub in [node_subset, face_subset]:
4594+
for name in ["node_x", "node_y", "face_x", "face_y"]:
4595+
assert name in ds_sub.coords
4596+
assert isinstance(ds_sub.xindexes[name], MultiDimIndex)
4597+
4598+
def test_to_dataarray_preserves_multi_coord_index(self) -> None:
4599+
# Regression test for https://github.com/pydata/xarray/issues/11215
4600+
# Multi-coordinate indexes spanning multiple dims should be preserved
4601+
# when converting a Dataset to a DataArray via to_dataarray().
4602+
class MultiDimIndex(Index):
4603+
def should_add_coord_to_array(self, name, var, dims):
4604+
return True
4605+
4606+
idx = MultiDimIndex()
4607+
coords = Coordinates(
4608+
coords={
4609+
"node_x": ("nodes", [0.0, 1.0, 2.0]),
4610+
"node_y": ("nodes", [0.0, 0.0, 1.0]),
4611+
"face_x": ("faces", [0.5, 1.5]),
4612+
"face_y": ("faces", [0.5, 0.5]),
4613+
},
4614+
indexes=dict.fromkeys(["node_x", "node_y", "face_x", "face_y"], idx),
4615+
)
4616+
ds = Dataset(
4617+
{
4618+
"node_data": (("nodes",), [1.0, 2.0, 3.0]),
4619+
},
4620+
coords=coords,
4621+
)
4622+
4623+
da = ds.to_dataarray()
4624+
4625+
for name in ["node_x", "node_y", "face_x", "face_y"]:
4626+
assert name in da.coords
4627+
assert isinstance(da.xindexes[name], MultiDimIndex)
4628+
45644629
def test_virtual_variables_default_coords(self) -> None:
45654630
dataset = Dataset({"foo": ("x", range(10))})
45664631
expected1 = DataArray(range(10), dims="x", name="x")

0 commit comments

Comments
 (0)