Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
260 changes: 168 additions & 92 deletions ultraplot/constructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import re
from functools import partial
from numbers import Number
from typing import Callable, Iterator, TypeVar

import cycler
import matplotlib.colors as mcolors
Expand Down Expand Up @@ -68,60 +69,177 @@
DEFAULT_CYCLE_SAMPLES = 10
DEFAULT_CYCLE_LUMINANCE = 90

_RegistryValue = TypeVar("_RegistryValue")


class _RefreshingRegistry(dict[str, _RegistryValue]):
"""
Dictionary-like registry that rebuilds itself before reads.

This keeps constructor registries aligned with modules that may be reloaded
in-place during tests or interactive use.
"""

def __init__(self, factory: Callable[[], dict[str, _RegistryValue]]) -> None:
self._factory = factory
super().__init__(factory())

def _refresh(self) -> None:
super().clear()
super().update(self._factory())

def __contains__(self, key: object) -> bool:
self._refresh()
return super().__contains__(key)

def __getitem__(self, key: str) -> _RegistryValue:
self._refresh()
return super().__getitem__(key)

def __iter__(self) -> Iterator[str]:
self._refresh()
return super().__iter__()

def __len__(self) -> int:
self._refresh()
return super().__len__()

def get(
self, key: str, default: _RegistryValue | None = None
) -> _RegistryValue | None:
self._refresh()
return super().get(key, default)

def items(self): # type: ignore[override]
self._refresh()
return super().items()

def keys(self): # type: ignore[override]
self._refresh()
return super().keys()

def values(self): # type: ignore[override]
self._refresh()
return super().values()

def copy(self) -> dict[str, _RegistryValue]:
self._refresh()
return dict(super().items())


def _build_norm_registry() -> dict[str, type[mcolors.Normalize]]:
registry: dict[str, type[mcolors.Normalize]] = {
"none": mcolors.NoNorm,
"null": mcolors.NoNorm,
"div": pcolors.DivergingNorm,
"diverging": pcolors.DivergingNorm,
"segmented": pcolors.SegmentedNorm,
"segments": pcolors.SegmentedNorm,
"log": mcolors.LogNorm,
"linear": mcolors.Normalize,
"power": mcolors.PowerNorm,
"symlog": mcolors.SymLogNorm,
}
if hasattr(mcolors, "TwoSlopeNorm"):
registry["twoslope"] = mcolors.TwoSlopeNorm
return registry


def _build_locator_registry() -> dict[str, object]:
registry = {
"none": mticker.NullLocator,
"null": mticker.NullLocator,
"auto": mticker.AutoLocator,
"log": mticker.LogLocator,
"maxn": mticker.MaxNLocator,
"linear": mticker.LinearLocator,
"multiple": mticker.MultipleLocator,
"fixed": mticker.FixedLocator,
"index": pticker.IndexLocator,
"discrete": pticker.DiscreteLocator,
"discreteminor": partial(pticker.DiscreteLocator, minor=True),
"symlog": mticker.SymmetricalLogLocator,
"logit": mticker.LogitLocator,
"minor": mticker.AutoMinorLocator,
"date": mdates.AutoDateLocator,
"microsecond": mdates.MicrosecondLocator,
"second": mdates.SecondLocator,
"minute": mdates.MinuteLocator,
"hour": mdates.HourLocator,
"day": mdates.DayLocator,
"weekday": mdates.WeekdayLocator,
"month": mdates.MonthLocator,
"year": mdates.YearLocator,
"lon": partial(pticker.LongitudeLocator, dms=False),
"lat": partial(pticker.LatitudeLocator, dms=False),
"deglon": partial(pticker.LongitudeLocator, dms=False),
"deglat": partial(pticker.LatitudeLocator, dms=False),
}
if hasattr(mpolar, "ThetaLocator"):
registry["theta"] = mpolar.ThetaLocator
if _version_cartopy >= "0.18":
registry["dms"] = partial(pticker.DegreeLocator, dms=True)
registry["dmslon"] = partial(pticker.LongitudeLocator, dms=True)
registry["dmslat"] = partial(pticker.LatitudeLocator, dms=True)
return registry


def _build_formatter_registry() -> dict[str, object]:
registry = { # note default LogFormatter uses ugly e+00 notation
"none": mticker.NullFormatter,
"null": mticker.NullFormatter,
"auto": pticker.AutoFormatter,
"date": mdates.AutoDateFormatter,
"scalar": mticker.ScalarFormatter,
"simple": pticker.SimpleFormatter,
"fixed": mticker.FixedLocator,
"index": pticker.IndexFormatter,
"sci": pticker.SciFormatter,
"sigfig": pticker.SigFigFormatter,
"frac": pticker.FracFormatter,
"func": mticker.FuncFormatter,
"strmethod": mticker.StrMethodFormatter,
"formatstr": mticker.FormatStrFormatter,
"datestr": mdates.DateFormatter,
"log": mticker.LogFormatterSciNotation,
"logit": mticker.LogitFormatter,
"eng": mticker.EngFormatter,
"percent": mticker.PercentFormatter,
"e": partial(pticker.FracFormatter, symbol=r"$e$", number=np.e),
"pi": partial(pticker.FracFormatter, symbol=r"$\pi$", number=np.pi),
"tau": partial(pticker.FracFormatter, symbol=r"$\tau$", number=2 * np.pi),
"lat": partial(pticker.SimpleFormatter, negpos="SN"),
"lon": partial(pticker.SimpleFormatter, negpos="WE", wraprange=(-180, 180)),
"deg": partial(pticker.SimpleFormatter, suffix="\N{DEGREE SIGN}"),
"deglat": partial(
pticker.SimpleFormatter, suffix="\N{DEGREE SIGN}", negpos="SN"
),
"deglon": partial(
pticker.SimpleFormatter,
suffix="\N{DEGREE SIGN}",
negpos="WE",
wraprange=(-180, 180),
),
"math": mticker.LogFormatterMathtext,
}
if hasattr(mpolar, "ThetaFormatter"):
registry["theta"] = mpolar.ThetaFormatter
if hasattr(mdates, "ConciseDateFormatter"):
registry["concise"] = mdates.ConciseDateFormatter
if _version_cartopy >= "0.18":
registry["dms"] = partial(pticker.DegreeFormatter, dms=True)
registry["dmslon"] = partial(pticker.LongitudeFormatter, dms=True)
registry["dmslat"] = partial(pticker.LatitudeFormatter, dms=True)
return registry


# Normalizer registry
NORMS = {
"none": mcolors.NoNorm,
"null": mcolors.NoNorm,
"div": pcolors.DivergingNorm,
"diverging": pcolors.DivergingNorm,
"segmented": pcolors.SegmentedNorm,
"segments": pcolors.SegmentedNorm,
"log": mcolors.LogNorm,
"linear": mcolors.Normalize,
"power": mcolors.PowerNorm,
"symlog": mcolors.SymLogNorm,
}
if hasattr(mcolors, "TwoSlopeNorm"):
NORMS["twoslope"] = mcolors.TwoSlopeNorm
NORMS = _RefreshingRegistry(_build_norm_registry)

# Locator registry
# NOTE: Will raise error when you try to use degree-minute-second
# locators with cartopy < 0.18.
LOCATORS = {
"none": mticker.NullLocator,
"null": mticker.NullLocator,
"auto": mticker.AutoLocator,
"log": mticker.LogLocator,
"maxn": mticker.MaxNLocator,
"linear": mticker.LinearLocator,
"multiple": mticker.MultipleLocator,
"fixed": mticker.FixedLocator,
"index": pticker.IndexLocator,
"discrete": pticker.DiscreteLocator,
"discreteminor": partial(pticker.DiscreteLocator, minor=True),
"symlog": mticker.SymmetricalLogLocator,
"logit": mticker.LogitLocator,
"minor": mticker.AutoMinorLocator,
"date": mdates.AutoDateLocator,
"microsecond": mdates.MicrosecondLocator,
"second": mdates.SecondLocator,
"minute": mdates.MinuteLocator,
"hour": mdates.HourLocator,
"day": mdates.DayLocator,
"weekday": mdates.WeekdayLocator,
"month": mdates.MonthLocator,
"year": mdates.YearLocator,
"lon": partial(pticker.LongitudeLocator, dms=False),
"lat": partial(pticker.LatitudeLocator, dms=False),
"deglon": partial(pticker.LongitudeLocator, dms=False),
"deglat": partial(pticker.LatitudeLocator, dms=False),
}
if hasattr(mpolar, "ThetaLocator"):
LOCATORS["theta"] = mpolar.ThetaLocator
if _version_cartopy >= "0.18":
LOCATORS["dms"] = partial(pticker.DegreeLocator, dms=True)
LOCATORS["dmslon"] = partial(pticker.LongitudeLocator, dms=True)
LOCATORS["dmslat"] = partial(pticker.LatitudeLocator, dms=True)
LOCATORS = _RefreshingRegistry(_build_locator_registry)

# Formatter registry
# NOTE: Critical to use SimpleFormatter for cardinal formatters rather than
Expand All @@ -130,49 +248,7 @@
# is their distinguishing feature relative to ultraplot formatter.
# NOTE: Will raise error when you try to use degree-minute-second
# formatters with cartopy < 0.18.
FORMATTERS = { # note default LogFormatter uses ugly e+00 notation
"none": mticker.NullFormatter,
"null": mticker.NullFormatter,
"auto": pticker.AutoFormatter,
"date": mdates.AutoDateFormatter,
"scalar": mticker.ScalarFormatter,
"simple": pticker.SimpleFormatter,
"fixed": mticker.FixedLocator,
"index": pticker.IndexFormatter,
"sci": pticker.SciFormatter,
"sigfig": pticker.SigFigFormatter,
"frac": pticker.FracFormatter,
"func": mticker.FuncFormatter,
"strmethod": mticker.StrMethodFormatter,
"formatstr": mticker.FormatStrFormatter,
"datestr": mdates.DateFormatter,
"log": mticker.LogFormatterSciNotation, # NOTE: this is subclass of Mathtext class
"logit": mticker.LogitFormatter,
"eng": mticker.EngFormatter,
"percent": mticker.PercentFormatter,
"e": partial(pticker.FracFormatter, symbol=r"$e$", number=np.e),
"pi": partial(pticker.FracFormatter, symbol=r"$\pi$", number=np.pi),
"tau": partial(pticker.FracFormatter, symbol=r"$\tau$", number=2 * np.pi),
"lat": partial(pticker.SimpleFormatter, negpos="SN"),
"lon": partial(pticker.SimpleFormatter, negpos="WE", wraprange=(-180, 180)),
"deg": partial(pticker.SimpleFormatter, suffix="\N{DEGREE SIGN}"),
"deglat": partial(pticker.SimpleFormatter, suffix="\N{DEGREE SIGN}", negpos="SN"),
"deglon": partial(
pticker.SimpleFormatter,
suffix="\N{DEGREE SIGN}",
negpos="WE",
wraprange=(-180, 180),
), # noqa: E501
"math": mticker.LogFormatterMathtext, # deprecated (use SciNotation subclass)
}
if hasattr(mpolar, "ThetaFormatter"):
FORMATTERS["theta"] = mpolar.ThetaFormatter
if hasattr(mdates, "ConciseDateFormatter"):
FORMATTERS["concise"] = mdates.ConciseDateFormatter
if _version_cartopy >= "0.18":
FORMATTERS["dms"] = partial(pticker.DegreeFormatter, dms=True)
FORMATTERS["dmslon"] = partial(pticker.LongitudeFormatter, dms=True)
FORMATTERS["dmslat"] = partial(pticker.LatitudeFormatter, dms=True)
FORMATTERS = _RefreshingRegistry(_build_formatter_registry)

# Scale registry and presets
SCALES = mscale._scale_mapping
Expand Down
12 changes: 12 additions & 0 deletions ultraplot/tests/test_constructor_helpers_extra.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#!/usr/bin/env python3
"""Additional branch coverage for constructor helpers."""

import importlib

import cycler
import matplotlib.colors as mcolors
import matplotlib.dates as mdates
Expand Down Expand Up @@ -167,6 +169,16 @@ def test_norm_locator_formatter_and_scale_branches():
constructor.Scale(object())


def test_formatter_registry_refreshes_after_ticker_reload():
import ultraplot.ticker

importlib.reload(ultraplot.ticker)

assert constructor.FORMATTERS["sigfig"] is pticker.SigFigFormatter
formatter = constructor.Formatter(("sigfig", 3))
assert isinstance(formatter, pticker.SigFigFormatter)


def test_proj_constructor_branches():
ccrs = pytest.importorskip("cartopy.crs")

Expand Down
Loading