Skip to content

Commit 41011f1

Browse files
authored
Cleanup xarray provider (#2224)
* Cleanup xarray provider * Respond to PR feedback * Fix xarray tests
1 parent 99446d5 commit 41011f1

3 files changed

Lines changed: 70 additions & 53 deletions

File tree

pygeoapi/provider/xarray_.py

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,11 @@ def get_fields(self):
126126
elif dtype.name.startswith('str'):
127127
dtype = 'string'
128128

129+
if value.attrs.get('units') is None:
130+
msg = f'Field {key} missing units, will be skipped'
131+
LOGGER.warning(msg)
132+
continue
133+
129134
self._fields[key] = {
130135
'type': dtype,
131136
'title': value.attrs.get('long_name'),
@@ -249,19 +254,21 @@ def query(self, properties=[], subsets={}, bbox=[], bbox_crs=4326,
249254
data.coords[self.x_field].values[-1],
250255
data.coords[self.y_field].values[-1]
251256
],
252-
"driver": "xarray",
253-
"height": data.sizes[self.y_field],
254-
"width": data.sizes[self.x_field],
255-
"variables": {var_name: var.attrs
256-
for var_name, var in data.variables.items()}
257+
'driver': 'xarray',
258+
'height': data.sizes[self.y_field],
259+
'width': data.sizes[self.x_field],
260+
'variables': {
261+
var_name: var.attrs
262+
for var_name, var in data.variables.items()
263+
}
257264
}
258265

259266
if self.time_field is not None:
260267
out_meta['time'] = [
261268
_to_datetime_string(data.coords[self.time_field].values[0]),
262-
_to_datetime_string(data.coords[self.time_field].values[-1]),
269+
_to_datetime_string(data.coords[self.time_field].values[-1])
263270
]
264-
out_meta["time_steps"] = data.sizes[self.time_field]
271+
out_meta['time_steps'] = data.sizes[self.time_field]
265272

266273
LOGGER.debug('Serializing data in memory')
267274
if format_ == 'json':
@@ -395,25 +402,30 @@ def gen_covjson(self, metadata, data, fields):
395402
try:
396403
for key, value in selected_fields.items():
397404
LOGGER.debug(f'Adding range {key}')
398-
cj['ranges'][key] = {
405+
range = {
399406
'type': 'NdArray',
400407
'dataType': value['type'],
401408
'axisNames': [
402409
'y', 'x'
403410
],
404-
'shape': [metadata['height'],
405-
metadata['width']]
411+
'shape': [
412+
metadata['height'], metadata['width']
413+
],
414+
'values': [
415+
None if np.isnan(v) else v
416+
for v in data[key].values.flatten()
417+
]
406418
}
407-
cj['ranges'][key]['values'] = [
408-
None if np.isnan(v) else v
409-
for v in data[key].values.flatten()
410-
]
411419

412420
if self.time_field is not None:
413-
cj['ranges'][key]['axisNames'].append('t')
414-
cj['ranges'][key]['shape'].append(metadata['time_steps'])
421+
LOGGER.debug(f'Adding time axis to range {key}')
422+
range['axisNames'].insert(0, 't')
423+
range['shape'].insert(0, metadata['time_steps'])
424+
425+
cj['ranges'][key] = range
426+
415427
except IndexError as err:
416-
LOGGER.warning(err)
428+
LOGGER.error(err)
417429
raise ProviderQueryError('Invalid query parameter')
418430

419431
LOGGER.debug('Returning data')
@@ -684,11 +696,11 @@ def _get_zarr_data(data):
684696

685697
def _convert_float32_to_float64(data):
686698
"""
687-
Converts DataArray values of float32 to float64
688-
:param data: Xarray dataset of coverage data
699+
Converts DataArray values of float32 to float64
700+
:param data: Xarray dataset of coverage data
689701
690-
:returns: Xarray dataset of coverage data
691-
"""
702+
:returns: Xarray dataset of coverage data
703+
"""
692704

693705
for var_name in data.variables:
694706
if data[var_name].dtype == 'float32':

pygeoapi/provider/xarray_edr.py

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from pygeoapi.provider.xarray_ import (
3737
_to_datetime_string,
3838
_convert_float32_to_float64,
39-
XarrayProvider,
39+
XarrayProvider
4040
)
4141

4242
LOGGER = logging.getLogger(__name__)
@@ -73,10 +73,9 @@ def position(self, **kwargs):
7373

7474
query_params = {}
7575

76+
LOGGER.debug('Query type: position')
7677
LOGGER.debug(f'Query parameters: {kwargs}')
7778

78-
LOGGER.debug(f"Query type: {kwargs.get('query_type')}")
79-
8079
wkt = kwargs.get('wkt')
8180
if wkt is not None:
8281
LOGGER.debug('Processing WKT')
@@ -115,7 +114,10 @@ def position(self, **kwargs):
115114

116115
try:
117116
if select_properties:
118-
self._fields = {k: v for k, v in self._fields.items() if k in select_properties} # noqa
117+
self._fields = {
118+
k: v for k, v in self._fields.items()
119+
if k in select_properties
120+
}
119121
data = self._data[[*select_properties]]
120122
else:
121123
data = self._data
@@ -156,12 +158,12 @@ def position(self, **kwargs):
156158
bbox = wkt.bounds
157159
out_meta = {
158160
'bbox': [bbox[0], bbox[1], bbox[2], bbox[3]],
159-
"time": time,
160-
"driver": "xarray",
161-
"height": height,
162-
"width": width,
163-
"time_steps": time_steps,
164-
"variables": {var_name: var.attrs
161+
'time': time,
162+
'driver': 'xarray',
163+
'height': height,
164+
'width': width,
165+
'time_steps': time_steps,
166+
'variables': {var_name: var.attrs
165167
for var_name, var in data.variables.items()}
166168
}
167169

@@ -183,12 +185,11 @@ def cube(self, **kwargs):
183185

184186
query_params = {}
185187

188+
LOGGER.debug('Query type: cube')
186189
LOGGER.debug(f'Query parameters: {kwargs}')
187190

188-
LOGGER.debug(f"Query type: {kwargs.get('query_type')}")
189-
190191
bbox = kwargs.get('bbox')
191-
xmin, ymin, xmax, ymax = self._configure_bbox(bbox)
192+
xmin, ymin, xmax, ymax = self._configure_bbox()
192193

193194
if len(bbox) == 4:
194195
query_params[self.x_field] = slice(bbox[xmin], bbox[xmax])
@@ -208,15 +209,17 @@ def cube(self, **kwargs):
208209
if datetime_ is not None:
209210
query_params[self.time_field] = self._make_datetime(datetime_)
210211

212+
fields = {
213+
field: self.fields[field]
214+
for field in select_properties
215+
if field in self.fields
216+
} if select_properties else self.fields
217+
211218
LOGGER.debug(f'query parameters: {query_params}')
212219
try:
213-
if select_properties:
214-
self._fields = {k: v for k, v in self._fields.items() if k in select_properties} # noqa
215-
data = self._data[[*select_properties]]
216-
else:
217-
data = self._data
218-
data = data.sel(query_params)
219-
data = _convert_float32_to_float64(data)
220+
data = _convert_float32_to_float64(
221+
self._data[[*fields]].sel(query_params)
222+
)
220223
except KeyError:
221224
raise ProviderNoDataError()
222225

@@ -231,16 +234,18 @@ def cube(self, **kwargs):
231234
data.coords[self.x_field].values[-1],
232235
data.coords[self.y_field].values[-1]
233236
],
234-
"time": time,
235-
"driver": "xarray",
236-
"height": height,
237-
"width": width,
238-
"time_steps": time_steps,
239-
"variables": {var_name: var.attrs
240-
for var_name, var in data.variables.items()}
237+
'time': time,
238+
'driver': 'xarray',
239+
'height': height,
240+
'width': width,
241+
'time_steps': time_steps,
242+
'variables': {
243+
var_name: var.attrs
244+
for var_name, var in data.variables.items()
245+
}
241246
}
242247

243-
return self.gen_covjson(out_meta, data, self.fields)
248+
return self.gen_covjson(out_meta, data, fields)
244249

245250
def _make_datetime(self, datetime_):
246251
"""
@@ -300,7 +305,7 @@ def _parse_time_metadata(self, data, kwargs):
300305
time_steps = kwargs.get('limit')
301306
return time, time_steps
302307

303-
def _configure_bbox(self, bbox):
308+
def _configure_bbox(self):
304309
xmin, ymin, xmax, ymax = 0, 1, 2, 3
305310
if self._data[self.x_field][0] > self._data[self.x_field][-1]:
306311
xmin, xmax = xmax, xmin

tests/provider/test_xarray_zarr_provider.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def config_no_time(tmp_path):
7373
def test_provider(config):
7474
p = XarrayProvider(config)
7575

76-
assert len(p.fields) == 4
76+
assert len(p.fields) == 3
7777
assert len(p.axes) == 3
7878
assert p.axes == ['lon', 'lat', 'time']
7979

@@ -82,7 +82,7 @@ def test_schema(config):
8282
p = XarrayProvider(config)
8383

8484
assert isinstance(p.fields, dict)
85-
assert len(p.fields) == 4
85+
assert len(p.fields) == 3
8686
assert p.fields['analysed_sst']['title'] == 'analysed sea surface temperature' # noqa
8787

8888

@@ -107,7 +107,7 @@ def test_numpy_json_serial():
107107
def test_no_time(config_no_time):
108108
p = XarrayProvider(config_no_time)
109109

110-
assert len(p.fields) == 4
110+
assert len(p.fields) == 3
111111
assert p.axes == ['lon', 'lat']
112112

113113
coverage = p.query(format='json')

0 commit comments

Comments
 (0)