Skip to content

Commit 8d6bef5

Browse files
#1742 allowing Choropleth key_on to traverse through array (#1772)
* #1742 allowing Choropleth key_on to traverse through array * #1742 allowing Choropleth key_on to traverse through array * adding test case for get_by_key * fixing pre commit check * fixes * avoid repeat code, simplify test * mypy --------- Co-authored-by: amrutha1098 <38883175+amrutha1098@users.noreply.github.com> Co-authored-by: Frank <33519926+Conengmo@users.noreply.github.com>
1 parent 40afeb1 commit 8d6bef5

2 files changed

Lines changed: 40 additions & 11 deletions

File tree

folium/features.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1552,17 +1552,8 @@ def __init__(
15521552

15531553
key_on = key_on[8:] if key_on.startswith("feature.") else key_on
15541554

1555-
def get_by_key(obj, key):
1556-
return (
1557-
obj.get(key, None)
1558-
if len(key.split(".")) <= 1
1559-
else get_by_key(
1560-
obj.get(key.split(".")[0], None), ".".join(key.split(".")[1:])
1561-
)
1562-
)
1563-
15641555
def color_scale_fun(x):
1565-
key_of_x = get_by_key(x, key_on)
1556+
key_of_x = self._get_by_key(x, key_on)
15661557
if key_of_x is None:
15671558
raise ValueError(f"key_on `{key_on!r}` not found in GeoJSON.")
15681559

@@ -1623,6 +1614,20 @@ def highlight_function(x):
16231614
if self.color_scale:
16241615
self.add_child(self.color_scale)
16251616

1617+
@classmethod
1618+
def _get_by_key(cls, obj: Union[dict, list], key: str) -> Union[float, str, None]:
1619+
key_parts = key.split(".")
1620+
first_key_part = key_parts[0]
1621+
if first_key_part.isdigit():
1622+
value = obj[int(first_key_part)]
1623+
else:
1624+
value = obj.get(first_key_part, None) # type: ignore
1625+
if len(key_parts) > 1:
1626+
new_key = ".".join(key_parts[1:])
1627+
return cls._get_by_key(value, new_key)
1628+
else:
1629+
return value
1630+
16261631
def render(self, **kwargs) -> None:
16271632
"""Render the GeoJson/TopoJson and color scale objects."""
16281633
if self.color_scale:

tests/test_features.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from branca.element import Element
1313

1414
import folium
15-
from folium import ClickForMarker, GeoJson, Map, Popup
15+
from folium import Choropleth, ClickForMarker, GeoJson, Map, Popup
1616

1717

1818
@pytest.fixture
@@ -302,3 +302,27 @@ def test_geometry_collection_get_bounds():
302302
"type": "GeometryCollection",
303303
}
304304
assert folium.GeoJson(geojson_data).get_bounds() == [[0, -3], [4, 2]]
305+
306+
307+
def test_choropleth_get_by_key():
308+
geojson_data = {
309+
"id": "0",
310+
"type": "Feature",
311+
"properties": {"idx": 0, "value": 78.0},
312+
"geometry": {
313+
"type": "Polygon",
314+
"coordinates": [
315+
[
316+
[1, 2],
317+
[3, 4],
318+
]
319+
],
320+
},
321+
}
322+
323+
# Test with string path in key_on
324+
assert Choropleth._get_by_key(geojson_data, "properties.idx") == 0
325+
assert Choropleth._get_by_key(geojson_data, "properties.value") == 78.0
326+
327+
# Test with combined string path and numerical index in key_on
328+
assert Choropleth._get_by_key(geojson_data, "geometry.coordinates.0.0") == [1, 2]

0 commit comments

Comments
 (0)