diff --git a/docs/examples/geo/04_choropleth.py b/docs/examples/geo/04_choropleth.py new file mode 100644 index 000000000..89b8ffeef --- /dev/null +++ b/docs/examples/geo/04_choropleth.py @@ -0,0 +1,55 @@ +""" +Simple choropleth +================= + +Color country-level values directly on a geographic axes. + +Why UltraPlot here? +------------------- +UltraPlot now exposes :meth:`~ultraplot.axes.GeoAxes.choropleth`, so you can +draw country-level thematic maps from plain ISO-style identifiers while using +the same concise colorbar and formatting API used elsewhere in the library. + +Key functions: :py:func:`ultraplot.subplots`, :py:meth:`ultraplot.axes.GeoAxes.choropleth`. + +See also +-------- +* :doc:`Geographic projections ` +""" + +import numpy as np + +import ultraplot as uplt + +country_values = { + "AUS": 1.2, + "BRA": 2.6, + "IND": 3.4, + "ZAF": np.nan, +} + +fig, ax = uplt.subplots(proj="robin", refwidth=4.6) + +ax.choropleth( + country_values, + country=True, + cmap="Fire", + edgecolor="white", + linewidth=0.6, + colorbar="r", + colorbar_kw={"label": "Index value"}, + missing_kw={"facecolor": "gray8", "hatch": "//", "edgecolor": "white"}, +) + +ax.format( + title="Country choropleth", + ocean=True, + oceancolor="ocean blue", + coast=True, + borders=True, + lonlines=60, + latlines=30, + labels=False, +) + +fig.show() diff --git a/docs/projections.py b/docs/projections.py index 582125cd6..24e285c76 100644 --- a/docs/projections.py +++ b/docs/projections.py @@ -325,6 +325,35 @@ ) +# %% +import shapely.geometry as sgeom + +fig, ax = uplt.subplots(proj="cyl", refwidth=3.5) +ax.choropleth( + [ + sgeom.box(-20, -10, -5, 5), + sgeom.box(0, -5, 15, 10), + sgeom.box(20, -8, 35, 8), + ], + [1.2, 2.4, 0.7], + cmap="Blues", + edgecolor="white", + linewidth=0.8, + colorbar="r", + colorbar_kw={"label": "value"}, +) +ax.format( + title="Polygon choropleth", + land=True, + coast=True, + lonlim=(-30, 40), + latlim=(-20, 20), + labels=True, + lonlines=10, + latlines=10, +) + + # %% [raw] raw_mimetype="text/restructuredtext" # .. _ug_geoformat: # diff --git a/ultraplot/axes/geo.py b/ultraplot/axes/geo.py index ce13b41cd..e7721e404 100644 --- a/ultraplot/axes/geo.py +++ b/ultraplot/axes/geo.py @@ -15,10 +15,12 @@ except ImportError: # From Python 3.5 from typing_extensions import override -from collections.abc import Iterator, MutableMapping, Sequence +from collections.abc import Iterator, Mapping, MutableMapping, Sequence from typing import Any, Optional, Protocol import matplotlib.axis as maxis +import matplotlib.collections as mcollections +import matplotlib.patches as mpatches import matplotlib.path as mpath import matplotlib.text as mtext import matplotlib.ticker as mticker @@ -31,6 +33,8 @@ from ..config import rc from ..internals import ( _not_none, + _pop_params, + _pop_props, _pop_rc, _version_cartopy, docstring, @@ -224,6 +228,56 @@ """ docstring._snippet_manager["geo.format"] = _format_docstring +_choropleth_docstring = """ +Draw polygon geometries colored by numeric values. + +Parameters +---------- +geometries + Sequence of polygon-like shapely geometries. Typical inputs include + GeoPandas ``geometry`` arrays or lists of shapely polygons in + longitude-latitude coordinates. When `country=True`, this can also + be a sequence of country codes/names or a mapping of country + identifiers to values. +values + Numeric values mapped to colors. Must have the same length as + `geometries`. Optional when `country=True` and `geometries` is a + mapping of country identifiers to values. +transform : cartopy CRS, optional + The input coordinate system for `geometries`. By default, cartopy + backends assume `~cartopy.crs.PlateCarree` and basemap backends + assume longitude-latitude input. +country : bool, optional + Interpret `geometries` as country identifiers and resolve them to + Natural Earth polygons before plotting. +country_reso : {'110m', '50m', '10m'}, optional + The Natural Earth country resolution used when `country=True`. + Defaults to :rc:`geo.choropleth.country_reso`. +country_territories : bool, optional + Whether to keep distant territories for multi-part country + geometries when `country=True`. Defaults to + :rc:`geo.choropleth.country_territories`. +colorbar, colorbar_kw + Passed to `~ultraplot.axes.Axes.colorbar`. +missing_kw : dict-like, optional + Style applied to geometries whose values are missing or non-finite. + If omitted, missing geometries are skipped. + +Other parameters +---------------- +cmap, cmap_kw, norm, norm_kw, vmin, vmax, levels, values + Standard UltraPlot colormap arguments. +edgecolor, linewidth, alpha, hatch, rasterized, zorder, label, ... + Collection styling arguments passed to the polygon collection. + +Returns +------- +matplotlib.collections.PatchCollection + The scalar-mappable collection for finite-valued polygons. +""" + +docstring._snippet_manager["geo.choropleth"] = _choropleth_docstring + class _GeoLabel(object): """ @@ -2209,6 +2263,159 @@ def format( # Parent format method super().format(rc_kw=rc_kw, rc_mode=rc_mode, **kwargs) + @docstring._snippet_manager + def choropleth( + self, + geometries: Sequence[Any], + values: Sequence[Any] | None = None, + *, + transform: Any = None, + country: bool = False, + country_reso: str | None = None, + country_territories: bool | None = None, + colorbar: Any = None, + colorbar_kw: MutableMapping[str, Any] | None = None, + missing_kw: MutableMapping[str, Any] | None = None, + **kwargs: Any, + ) -> mcollections.PatchCollection: + """ + %(geo.choropleth)s + """ + country_reso = _not_none( + country_reso, + rc.find("geo.choropleth.country_reso", context=True), + ) + country_territories = _not_none( + country_territories, + rc.find("geo.choropleth.country_territories", context=True), + ) + if country: + geometries, values, transform = _choropleth_country_inputs( + geometries, + values, + transform=transform, + resolution=country_reso, + include_far=country_territories, + ) + elif values is None: + raise ValueError( + "choropleth() requires values unless country=True and geometries " + "is a mapping of country identifiers to values." + ) + + geometries = list(geometries) + values_arr = np.ma.masked_invalid(np.asarray(values, dtype=float).ravel()) + if values_arr.ndim != 1: + raise ValueError("choropleth() values must be one-dimensional.") + if len(geometries) != values_arr.size: + raise ValueError( + "choropleth() geometries and values must have the same length. " + f"Got {len(geometries)} geometries and {values_arr.size} values." + ) + + kw = kwargs.copy() + kw.update(_pop_props(kw, "collection")) + center_levels = kw.pop("center_levels", None) + explicit_zorder = "zorder" in kwargs + zorder = _not_none( + kw.get("zorder", None), + rc.find("geo.choropleth.zorder", context=True), + rc["land.zorder"] + 0.1, + ) + kw["zorder"] = zorder + + invalid_face_keys = ("color", "colors", "facecolor", "facecolors") + ignored = {key: kw.pop(key) for key in invalid_face_keys if key in kw} + if ignored: + warnings._warn_ultraplot( + "choropleth() colors polygons from numeric values, so " + f"facecolor/color args are ignored: {tuple(ignored)}. " + "Use cmap=... or missing_kw=... instead." + ) + + valid_patches = [] + valid_values = [] + missing_patches = [] + valid_mask = ~np.ma.getmaskarray(values_arr) + for geometry, value, is_valid in zip(geometries, values_arr.data, valid_mask): + path = _choropleth_geometry_path(self, geometry, transform=transform) + if path is None: + continue + patch = mpatches.PathPatch(path) + if is_valid: + valid_patches.append(patch) + valid_values.append(float(value)) + else: + missing_patches.append(patch) + + if not valid_patches: + raise ValueError("choropleth() produced no polygon patches to draw.") + valid_values = np.asarray(valid_values, dtype=float) + + kw = self._parse_cmap( + valid_values, + default_discrete=True, + center_levels=center_levels, + **kw, + ) + cmap, norm = kw.pop("cmap"), kw.pop("norm") + guide_kw = _pop_params(kw, self._update_guide) + label = kw.pop("label", None) + + collection = mcollections.PatchCollection( + valid_patches, + cmap=cmap, + norm=norm, + label=label, + match_original=False, + ) + collection.set_array(valid_values) + collection.update(kw) + self.add_collection(collection) + edge_kw = _choropleth_edge_collection_kw( + kw, + zorder=collection.get_zorder(), + explicit_zorder=explicit_zorder, + ) + if edge_kw is not None: + edge_collection = mcollections.PatchCollection( + valid_patches, + match_original=False, + ) + edge_collection.update(edge_kw) + self.add_collection(edge_collection) + + if missing_patches and missing_kw is not None: + miss_kw = dict(missing_kw) + miss_kw.update(_pop_props(miss_kw, "collection")) + missing_explicit_zorder = "zorder" in missing_kw + if not any(key in miss_kw for key in invalid_face_keys): + miss_kw["facecolor"] = "none" + missing = mcollections.PatchCollection( + missing_patches, + match_original=False, + ) + missing.update(miss_kw) + self.add_collection(missing) + miss_edge_kw = _choropleth_edge_collection_kw( + miss_kw, + zorder=missing.get_zorder(), + explicit_zorder=missing_explicit_zorder, + ) + if miss_edge_kw is not None: + missing_edge = mcollections.PatchCollection( + missing_patches, + match_original=False, + ) + missing_edge.update(miss_edge_kw) + self.add_collection(missing_edge) + + self.autoscale_view() + self._update_guide(collection, queue_colorbar=False, **guide_kw) + if colorbar: + self.colorbar(collection, loc=colorbar, **(colorbar_kw or {})) + return collection + def _add_geoticks(self, x_or_y: str, itick: Any, ticklen: Any) -> None: """ Add tick marks to the geographic axes. @@ -3432,6 +3639,191 @@ def _update_minor_gridlines( axis.isDefault_minloc = True +def _is_platecarree_crs(transform: Any) -> bool: + """ + Return whether `transform` represents plain longitude-latitude coordinates. + """ + if transform is None: + return True + name = getattr(getattr(transform, "__class__", None), "__name__", "") + return name == "PlateCarree" + + +def _choropleth_close_path(vertices: Any) -> mpath.Path | None: + """ + Convert a single polygon ring into a closed path. + """ + vertices = np.asarray(vertices, dtype=float) + if vertices.ndim != 2 or vertices.shape[0] < 3: + return None + vertices = vertices[:, :2] + if not np.allclose(vertices[0], vertices[-1], equal_nan=True): + vertices = np.vstack((vertices, vertices[0])) + codes = np.full(vertices.shape[0], mpath.Path.LINETO, dtype=np.uint8) + codes[0] = mpath.Path.MOVETO + codes[-1] = mpath.Path.CLOSEPOLY + return mpath.Path(vertices, codes) + + +def _choropleth_iter_rings(geometry: Any) -> Iterator[Any]: + """ + Yield polygon rings from shapely-like polygon geometries. + """ + if geometry is None or getattr(geometry, "is_empty", False): + return + geom_type = getattr(geometry, "geom_type", None) + if geom_type == "Polygon": + yield geometry.exterior.coords + for ring in geometry.interiors: + yield ring.coords + return + if geom_type in ("MultiPolygon", "GeometryCollection"): + for part in getattr(geometry, "geoms", ()): + yield from _choropleth_iter_rings(part) + return + raise TypeError( + "choropleth() geometries must be polygon-like shapely objects. " + f"Got {type(geometry).__name__}." + ) + + +def _choropleth_project_vertices( + ax: GeoAxes, + vertices: Any, + *, + transform: Any = None, +) -> np.ndarray: + """ + Project polygon-ring vertices into the target map coordinate system. + """ + vertices = np.asarray(vertices, dtype=float) + xy = vertices[:, :2] + if ax._name == "cartopy": + src = transform + if src is None: + if ccrs is None: + raise RuntimeError("choropleth() requires cartopy for cartopy GeoAxes.") + src = ccrs.PlateCarree() + out = ax.projection.transform_points(src, xy[:, 0], xy[:, 1]) + return np.asarray(out[:, :2], dtype=float) + + if transform is not None and not _is_platecarree_crs(transform): + raise ValueError( + "Basemap choropleth() only supports longitude-latitude input " + "coordinates. Use transform=None or cartopy.crs.PlateCarree()." + ) + x, y = ax.projection(xy[:, 0], xy[:, 1]) + return np.column_stack((np.asarray(x, dtype=float), np.asarray(y, dtype=float))) + + +def _choropleth_geometry_path( + ax: GeoAxes, + geometry: Any, + *, + transform: Any = None, +) -> mpath.Path | None: + """ + Convert a polygon geometry to a projected matplotlib path. + """ + paths = [] + for ring in _choropleth_iter_rings(geometry): + projected = _choropleth_project_vertices(ax, ring, transform=transform) + path = _choropleth_close_path(projected) + if path is not None: + paths.append(path) + if not paths: + return None + return mpath.Path.make_compound_path(*paths) + + +def _choropleth_country_inputs( + geometries: Any, + values: Any, + *, + transform: Any = None, + resolution: str = "110m", + include_far: bool = False, +) -> tuple[list[Any], Any, Any]: + """ + Resolve country identifiers into polygon geometries. + """ + from .. import legend as plegend + + if values is None: + if not isinstance(geometries, Mapping): + raise ValueError( + "choropleth(country=True) requires either values=... or a " + "mapping of country identifiers to numeric values." + ) + keys = list(geometries.keys()) + values = list(geometries.values()) + else: + if isinstance(geometries, Mapping): + raise ValueError( + "choropleth(country=True) does not accept both a mapping input " + "and an explicit values=... argument." + ) + keys = list(geometries) + + if transform is not None and not _is_platecarree_crs(transform): + raise ValueError( + "choropleth(country=True) uses Natural Earth lon/lat geometries, so " + "transform must be None or cartopy.crs.PlateCarree()." + ) + + resolution = plegend._normalize_country_resolution(resolution) + geometries = [ + plegend._resolve_country_geometry( + str(key), + resolution=resolution, + include_far=include_far, + ) + for key in keys + ] + return geometries, values, transform + + +def _choropleth_edge_collection_kw( + kw: Mapping[str, Any], + *, + zorder: float, + explicit_zorder: bool = False, +) -> dict[str, Any] | None: + """ + Return edge-only collection settings when polygon outlines should overlay features. + """ + edge_keys = ( + "edgecolor", + "edgecolors", + "linewidth", + "linewidths", + "linestyle", + "linestyles", + ) + if not any(key in kw for key in edge_keys): + return None + edge_kw = { + key: value + for key, value in kw.items() + if key not in ("color", "colors", "facecolor", "facecolors", "hatch", "label") + } + if explicit_zorder: + edge_kw["zorder"] = zorder + else: + edge_kw["zorder"] = ( + max( + zorder, + *( + rc.find(f"{name}.zorder", context=True) + for name in ("coast", "rivers", "borders", "innerborders") + ), + ) + + 0.1 + ) + edge_kw["facecolor"] = "none" + return edge_kw + + # Apply signature obfuscation after storing previous signature GeoAxes._format_signatures[GeoAxes] = inspect.signature(GeoAxes.format) GeoAxes.format = docstring._obfuscate_kwargs(GeoAxes.format) diff --git a/ultraplot/internals/rcsetup.py b/ultraplot/internals/rcsetup.py index 9f580154b..3b94e8ec0 100644 --- a/ultraplot/internals/rcsetup.py +++ b/ultraplot/internals/rcsetup.py @@ -1732,6 +1732,24 @@ def _validator_accepts(validator, value): "If ``True`` (the default), polar `~ultraplot.axes.GeoAxes` like ``'npstere'`` " "and ``'spstere'`` are bounded with circles rather than squares.", ), + "geo.choropleth.country_reso": ( + "110m", + _validate_belongs("10m", "50m", "110m"), + "Default Natural Earth resolution used by `GeoAxes.choropleth` when " + "country identifiers are resolved to polygons.", + ), + "geo.choropleth.country_territories": ( + False, + _validate_bool, + "Whether `GeoAxes.choropleth` keeps distant territories when resolving " + "country identifiers into Natural Earth geometries.", + ), + "geo.choropleth.zorder": ( + None, + _validate_or_none(_validate_float), + "Default z-order for `GeoAxes.choropleth`. When ``None``, the choropleth " + "is drawn just above the land feature.", + ), # Graphs "graph.draw_nodes": ( True, diff --git a/ultraplot/tests/test_geographic.py b/ultraplot/tests/test_geographic.py index 7d363f9d9..802670d9e 100644 --- a/ultraplot/tests/test_geographic.py +++ b/ultraplot/tests/test_geographic.py @@ -929,6 +929,193 @@ def test_rasterize_feature(): uplt.close(fig) +@pytest.mark.parametrize("backend", ["cartopy", "basemap"]) +def test_choropleth_draws_patch_collection_and_missing_polygons(backend): + if backend == "cartopy": + pytest.importorskip("cartopy.crs") + sgeom = pytest.importorskip("shapely.geometry") + from matplotlib import collections as mcollections + + fig, ax = uplt.subplots(proj="cyl", backend=backend) + geo = ax[0] + coll = geo.choropleth( + [ + sgeom.box(-20, -10, -5, 5), + sgeom.box(0, -5, 15, 10), + sgeom.box(20, -5, 35, 10), + ], + [1.0, np.nan, 3.0], + edgecolor="k", + linewidth=0.5, + colorbar="r", + missing_kw={"facecolor": "gray8", "hatch": "//"}, + ) + fig.canvas.draw() + + assert isinstance(coll, mcollections.PatchCollection) + assert np.allclose(np.asarray(coll.get_array()), [1.0, 3.0]) + assert len(coll.get_paths()) == 2 + missing = [other for other in geo.collections if other.get_hatch() == "//"] + assert len(missing) == 1 + assert len(fig.axes) == 2 + uplt.close(fig) + + +@pytest.mark.parametrize("backend", ["cartopy", "basemap"]) +def test_choropleth_country_mapping_resolves_codes(monkeypatch, backend): + sgeom = pytest.importorskip("shapely.geometry") + from ultraplot import legend as plegend + + calls = [] + country_geoms = { + "AUS": sgeom.box(110, -45, 155, -10), + "NZL": sgeom.box(166, -48, 179, -34), + } + + def _fake_country(code, resolution="110m", include_far=False): + key = str(code).upper() + calls.append((key, resolution, bool(include_far))) + return country_geoms[key] + + monkeypatch.setattr(plegend, "_resolve_country_geometry", _fake_country) + + fig, ax = uplt.subplots(proj="cyl", backend=backend) + coll = ax[0].choropleth( + {"AUS": 1.0, "NZL": 2.0}, + country=True, + country_reso="50m", + country_territories=True, + ) + fig.canvas.draw() + + assert np.allclose(np.asarray(coll.get_array()), [1.0, 2.0]) + assert len(coll.get_paths()) == 2 + assert calls == [("AUS", "50m", True), ("NZL", "50m", True)] + uplt.close(fig) + + +@pytest.mark.parametrize("backend", ["cartopy", "basemap"]) +def test_choropleth_country_defaults_respect_rc(monkeypatch, backend): + sgeom = pytest.importorskip("shapely.geometry") + from ultraplot import legend as plegend + + calls = [] + + def _fake_country(code, resolution="110m", include_far=False): + calls.append((str(code).upper(), resolution, bool(include_far))) + return sgeom.box(110, -45, 155, -10) + + monkeypatch.setattr(plegend, "_resolve_country_geometry", _fake_country) + + with uplt.rc.context( + { + "geo.choropleth.country_reso": "50m", + "geo.choropleth.country_territories": True, + } + ): + fig, ax = uplt.subplots(proj="cyl", backend=backend) + coll = ax[0].choropleth({"AUS": 1.0}, country=True) + fig.canvas.draw() + + assert np.allclose(np.asarray(coll.get_array()), [1.0]) + assert calls == [("AUS", "50m", True)] + uplt.close(fig) + + +@pytest.mark.parametrize("backend", ["cartopy", "basemap"]) +def test_choropleth_default_zorder_above_land(backend): + sgeom = pytest.importorskip("shapely.geometry") + + fig, ax = uplt.subplots(proj="cyl", backend=backend) + geo = ax[0] + coll = geo.choropleth([sgeom.box(-20, -10, 20, 10)], [1.0]) + geo.format(land=True) + fig.canvas.draw() + + land = getattr(geo, "_land_feature") + if isinstance(land, (tuple, list)): + land_zorder = land[0].get_zorder() + else: + land_zorder = land.get_zorder() + assert coll.get_zorder() > land_zorder + uplt.close(fig) + + +@pytest.mark.parametrize("backend", ["cartopy", "basemap"]) +def test_choropleth_edgecolor_overlays_borders(backend): + sgeom = pytest.importorskip("shapely.geometry") + from matplotlib import colors as mcolors + + fig, ax = uplt.subplots(proj="cyl", backend=backend) + geo = ax[0] + coll = geo.choropleth( + [sgeom.box(-20, -10, 20, 10)], + [1.0], + edgecolor="red", + linewidth=2, + ) + geo.format(borders=True) + fig.canvas.draw() + + borders = getattr(geo, "_borders_feature") + if isinstance(borders, (tuple, list)): + borders_zorder = borders[0].get_zorder() + else: + borders_zorder = borders.get_zorder() + edge = next( + other + for other in geo.collections + if other is not coll + and len(other.get_paths()) == len(coll.get_paths()) + and np.allclose(np.asarray(other.get_edgecolor())[0], mcolors.to_rgba("red")) + ) + assert edge.get_zorder() > borders_zorder + assert np.allclose(np.asarray(edge.get_edgecolor())[0], mcolors.to_rgba("red")) + uplt.close(fig) + + +def test_choropleth_zorder_respects_rc(): + sgeom = pytest.importorskip("shapely.geometry") + + with uplt.rc.context({"geo.choropleth.zorder": 5.5}): + fig, ax = uplt.subplots(proj="cyl") + coll = ax[0].choropleth([sgeom.box(-20, -10, 20, 10)], [1.0]) + fig.canvas.draw() + + assert coll.get_zorder() == pytest.approx(5.5) + uplt.close(fig) + + +def test_choropleth_length_mismatch_raises(): + sgeom = pytest.importorskip("shapely.geometry") + + fig, ax = uplt.subplots(proj="cyl") + with pytest.raises(ValueError, match="same length"): + ax[0].choropleth([sgeom.box(-10, -10, 10, 10)], [1.0, 2.0]) + uplt.close(fig) + + +def test_choropleth_basemap_rejects_non_platecarree_transform(): + ccrs = pytest.importorskip("cartopy.crs") + sgeom = pytest.importorskip("shapely.geometry") + + fig, ax = uplt.subplots(proj="cyl", backend="basemap") + with pytest.raises(ValueError, match="Basemap choropleth"): + ax[0].choropleth( + [sgeom.box(-10, -10, 10, 10)], + [1.0], + transform=ccrs.Mercator(), + ) + uplt.close(fig) + + +def test_choropleth_country_mapping_with_explicit_values_raises(): + fig, ax = uplt.subplots(proj="cyl") + with pytest.raises(ValueError, match="does not accept both a mapping input"): + ax[0].choropleth({"AUS": 1.0}, [1.0], country=True) + uplt.close(fig) + + def test_check_tricontourf(): """ Ensure transform defaults are applied only when appropriate for tri-plots.