Source code for arviz_plots.plots.khat_plot

"""Plot Pareto tail indices."""

from collections.abc import Mapping, Sequence
from importlib import import_module
from typing import Any, Literal

import numpy as np
import xarray as xr
from arviz_base import rcParams
from arviz_base.labels import BaseLabeller

from arviz_plots.plot_collection import PlotCollection
from arviz_plots.plots.utils import (
    calculate_khat_bin_edges,
    enable_hover_labels,
    filter_aes,
    format_coords_as_labels,
    get_visual_kwargs,
    set_wrap_layout,
)
from arviz_plots.visuals import (
    annotate_xy,
    hline,
    labelled_title,
    labelled_x,
    labelled_y,
    scatter_xy,
    set_xlim,
    set_xticks,
)


[docs] def plot_khat( elpd_data, threshold=None, show_hlines=False, show_bins=False, hover_label=False, hover_format="{index}: {label}", xlabels=False, legend=None, color=None, hline_values=None, bin_format="{pct:.1f}%", plot_collection=None, backend=None, labeller=None, aes_by_visuals: Mapping[ Literal[ "khat", "hlines", "bin_text", "threshold_text", "title", "xlabel", "ylabel", "ticks", ], Sequence[str], ] = None, visuals: Mapping[ Literal[ "khat", "hlines", "bin_text", "threshold_text", "title", "xlabel", "ylabel", "legend", "ticks", ], Mapping[str, Any] | bool, ] = None, **pc_kwargs, ): r"""Plot Pareto tail indices for diagnosing convergence in PSIS-LOO-CV. The Generalized Pareto distribution (GPD) is fitted to the largest importance ratios to diagnose convergence rates. The shape parameter :math:`\hat{k}` estimates the pre-asymptotic convergence rate based on the fractional number of finite moments. Values :math:`\hat{k} > 0.7` indicate impractically low convergence rates and unreliable estimates. Details are presented in [1]_ and [2]_. Parameters ---------- elpd_data : ELPDData ELPD data object returned by :func:`arviz_stats.loo` containing Pareto k diagnostics. threshold : float, optional Highlight khat values above this threshold with annotations. If None, no points are highlighted. show_hlines : bool, default False Show horizontal reference lines at diagnostic thresholds. show_bins : bool, default False Show the percentage of khat values falling in each bin delimited by reference lines. hover_label : bool, default False Enable interactive hover annotations when using an interactive backend. hover_format : str, default ``"{index}: {label}"`` Format string for hover annotations. Supports ``{index}``, ``{label}``, and ``{value}``. xlabels : bool, default False Show coordinate labels as x tick labels. legend : bool, optional Whether to display a legend when color aesthetics are active. If None, a legend is shown when a color mapping is available. color : color spec or str, optional Color for scatter points when no aesthetic mapping supplies one. If the value matches a dimension name, that dimension is mapped to the color aesthetic. hline_values : sequence of float, optional Custom horizontal line positions. Defaults to [0.0, 0.7, 1.0]. bin_format : str, default ``"{pct:.1f}%"`` Format string for bin percentages. Supports ``{count}`` and ``{pct}`` placeholders. plot_collection : PlotCollection, optional backend : {"matplotlib", "bokeh", "plotly"}, optional Plotting backend to use. Defaults to ``rcParams["plot.backend"]``. labeller : labeller, optional aes_by_visuals : mapping of {str : sequence of str or False}, optional Mapping of visuals to aesthetics that should use their mapping in `plot_collection` when plotted. Valid keys are the same as for `visuals`. By default: * khat -> uses all available aesthetic mappings * hlines -> uses no aesthetic mappings * bin_text -> uses no aesthetic mappings * threshold_text -> uses no aesthetic mappings * title -> uses no aesthetic mappings * xlabel -> uses no aesthetic mappings * ylabel -> uses no aesthetic mappings * ticks -> uses no aesthetic mappings visuals : mapping of {str : mapping or bool}, optional Valid keys are: * khat -> passed to :func:`~arviz_plots.visuals.scatter_xy` * hlines -> passed to :func:`~arviz_plots.visuals.hline` * bin_text -> passed to :func:`~arviz_plots.visuals.annotate_xy` * threshold_text -> passed to :func:`~arviz_plots.visuals.annotate_xy` * title -> passed to :func:`~arviz_plots.visuals.labelled_title` defaults to False * xlabel -> passed to :func:`~arviz_plots.visuals.labelled_x` * ylabel -> passed to :func:`~arviz_plots.visuals.labelled_y` * legend -> passed to :class:`arviz_plots.PlotCollection.add_legend` * ticks -> passed to :func:`~arviz_plots.visuals.set_xticks` **pc_kwargs Passed to :class:`arviz_plots.PlotCollection.wrap`. Returns ------- PlotCollection Warnings -------- When using custom markers via the ``visuals`` dict, ensure the marker type is compatible with your chosen backend. Not all marker types support separate facecolor and edgecolor across different backends. Examples -------- The most basic usage plots the Pareto k values from a LOO-CV computation. Each point represents one observation, with higher k values indicating less reliable importance sampling for that observation. .. plot:: :context: close-figs >>> from arviz_plots import plot_khat, style >>> style.use("arviz-variat") >>> from arviz_base import load_arviz_data >>> from arviz_stats import loo >>> dt = load_arviz_data("radon") >>> elpd_data = loo(dt, pointwise=True) >>> plot_khat(elpd_data, figure_kwargs={"figsize": (10, 5)}) We can highlight problematic observations by setting a ``threshold`` and add reference lines with ``show_hlines=True`` to visualize the diagnostic boundaries. Using ``show_bins=True`` displays the percentage of observations falling into each diagnostic category. Note that the ``hline_values`` parameter is independent of the ``threshold`` parameter. To draw a horizontal line at your custom threshold, you must set both parameters explicitly. .. plot:: :context: close-figs >>> plot_khat(elpd_data, >>> threshold=0.4, >>> show_hlines=True, >>> show_bins=True, >>> hline_values=[0.0, 0.4, 1.0], >>> visuals={"hlines": {"color":"B1"}}, >>> figure_kwargs={"figsize": (10, 5)} >>> ) .. minigallery:: plot_khat References ---------- .. [1] Vehtari et al. *Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC*. Statistics and Computing. 27(5) (2017). https://doi.org/10.1007/s11222-016-9696-4. arXiv preprint https://arxiv.org/abs/1507.04544. .. [2] Vehtari et al. *Pareto Smoothed Importance Sampling*. Journal of Machine Learning Research, 25(72) (2024) https://jmlr.org/papers/v25/19-556.html arXiv preprint https://arxiv.org/abs/1507.02646 """ if hline_values is None: good_k = getattr(elpd_data, "good_k", 0.7) hline_values = [0.0, good_k, 1.0] else: hline_values = list(hline_values) if visuals is None: visuals = {} else: visuals = visuals.copy() visuals.setdefault("title", False) if backend is None: if plot_collection is None: backend = rcParams["plot.backend"] else: backend = plot_collection.backend if labeller is None: labeller = BaseLabeller() if aes_by_visuals is None: aes_by_visuals = {} else: aes_by_visuals = aes_by_visuals.copy() pc_kwargs = dict(pc_kwargs) if not hasattr(elpd_data, "pareto_k") or elpd_data.pareto_k is None: raise ValueError( "Could not find 'pareto_k' in the ELPDData object. " "Please ensure the LOO computation includes Pareto k diagnostics." ) khat_data = elpd_data.pareto_k distribution = khat_data.to_dataset(name="pareto_k") n_data_points = khat_data.size khat_dims = list(khat_data.dims) flat_coord_labels = format_coords_as_labels(khat_data) coord_map = {dim: khat_data.coords[dim] for dim in khat_dims if dim in khat_data.coords} if n_data_points: x_positions = np.arange(n_data_points).reshape(khat_data.shape) else: x_positions = np.zeros(khat_data.shape, dtype=float) xdata = xr.DataArray(x_positions, dims=khat_dims, coords=coord_map, name="pareto_k") x_dataset = xr.Dataset({"pareto_k": xdata}) khat_dataset = xr.concat([x_dataset, distribution], dim="plot_axis").assign_coords( plot_axis=["x", "y"] ) khat_values = np.asarray(khat_data.values).reshape(-1) x_flat = np.asarray(x_positions).reshape(-1) y_flat = np.asarray(khat_data.values).reshape(-1) x_min = x_flat.min() if x_flat.size else 0.0 x_max = x_flat.max() if x_flat.size else 0.0 good_k_threshold = hline_values[1] if len(hline_values) > 1 else 0.7 plot_bknd = import_module(f".backend.{backend}", package="arviz_plots") scalar_ds = xr.Dataset({"pareto_k": xr.DataArray(0)}) if plot_collection is None: pc_kwargs["figure_kwargs"] = pc_kwargs.get("figure_kwargs", {}).copy() pc_kwargs["figure_kwargs"].setdefault("sharex", True) pc_kwargs["figure_kwargs"].setdefault("sharey", True) pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy() if isinstance(color, str) and color in distribution.dims: pc_kwargs["aes"]["color"] = [color] color = None elif "model" in distribution.dims and "color" not in pc_kwargs["aes"]: pc_kwargs["aes"]["color"] = ["model"] pc_kwargs.setdefault("cols", []) pc_kwargs = set_wrap_layout(pc_kwargs, plot_bknd, distribution) plot_collection = PlotCollection.wrap( distribution, backend=backend, **pc_kwargs, ) aes_by_visuals.setdefault("khat", plot_collection.aes_set) aes_by_visuals.setdefault("hlines", []) aes_by_visuals.setdefault("bin_text", []) aes_by_visuals.setdefault("threshold_text", []) aes_by_visuals.setdefault("title", []) aes_by_visuals.setdefault("xlabel", []) aes_by_visuals.setdefault("ylabel", []) aes_by_visuals.setdefault("ticks", []) new_xlim = None khat_kwargs = get_visual_kwargs(visuals, "khat") if khat_kwargs is not False: _, khat_aes, khat_ignore = filter_aes(plot_collection, aes_by_visuals, "khat", []) default_color = khat_kwargs.get("color", color) if default_color is None and "color" not in khat_aes: default_color = "C0" if "color" not in khat_aes and default_color is not None: khat_kwargs.setdefault("color", default_color) plot_collection.map( scatter_xy, "khat", data=khat_dataset, ignore_aes=khat_ignore, **khat_kwargs, ) if show_hlines and hline_values: hlines_kwargs = get_visual_kwargs(visuals, "hlines") if hlines_kwargs is not False: _, hlines_aes, _ = filter_aes(plot_collection, aes_by_visuals, "hlines", []) for idx, value in enumerate(hline_values): h_kwargs = hlines_kwargs.copy() if "linestyle" not in hlines_aes: h_kwargs.setdefault("linestyle", f"C{idx}") if "color" not in hlines_aes: h_kwargs.setdefault("color", f"C{idx + 1}") if "alpha" not in hlines_aes: h_kwargs.setdefault("alpha", 0.7) h_ds = xr.Dataset({"pareto_k": xr.DataArray(value)}) plot_collection.map( hline, f"hline_{idx}", data=h_ds, ignore_aes="all", **h_kwargs, ) if show_bins: bin_text_kwargs = get_visual_kwargs(visuals, "bin_text") if bin_text_kwargs is not False: _, bin_text_aes, _ = filter_aes(plot_collection, aes_by_visuals, "bin_text", []) if "color" not in bin_text_aes: bin_text_kwargs.setdefault("color", "B1") bin_text_kwargs.setdefault("horizontal_align", "center") bin_edges = calculate_khat_bin_edges(khat_values, [good_k_threshold, 1.0]) if bin_edges is not None and n_data_points: counts, edges = np.histogram(khat_values, bins=bin_edges) span = max(1.0, x_max - x_min) x_margin = max(0.5, 0.05 * span) x_text = x_max + x_margin new_xlim = (x_min, x_text + x_margin) for bin_idx, count in enumerate(counts): if count == 0: continue lower = edges[bin_idx] upper = edges[bin_idx + 1] if np.isnan(lower) or np.isnan(upper): continue pct = (count / n_data_points * 100) if n_data_points else 0.0 label = bin_format.format(count=count, pct=pct) y_pos = 0.5 * (lower + upper) plot_collection.map( annotate_xy, f"bin_{bin_idx}", data=scalar_ds, x=x_text, y=y_pos, text=label, ignore_aes="all", **bin_text_kwargs, ) if threshold is not None and n_data_points: threshold_text_kwargs = get_visual_kwargs(visuals, "threshold_text") if threshold_text_kwargs is not False: _, _, threshold_text_ignore = filter_aes( plot_collection, aes_by_visuals, "threshold_text", [] ) threshold_text_kwargs.setdefault("color", "B1") threshold_text_kwargs.setdefault("vertical_align", "bottom") threshold_text_kwargs.setdefault("horizontal_align", "center") mask = np.asarray(khat_data > threshold).reshape(-1) indices = np.flatnonzero(mask) for flat_idx in indices: label_text = str(flat_coord_labels[flat_idx]) plot_collection.map( annotate_xy, f"threshold_{flat_idx}", data=scalar_ds, x=x_flat[flat_idx], y=y_flat[flat_idx], text=label_text, ignore_aes=threshold_text_ignore, **threshold_text_kwargs, ) if xlabels and x_flat.size and flat_coord_labels.size: ticks_kwargs = get_visual_kwargs(visuals, "ticks") if ticks_kwargs is not False: if "rotation" not in ticks_kwargs: ticks_kwargs.setdefault("rotation", 45) plot_collection.map( set_xticks, "ticks", data=scalar_ds, values=x_flat.tolist(), labels=[str(label) for label in flat_coord_labels], ignore_aes="all", store_artist=False, **ticks_kwargs, ) title_kwargs = get_visual_kwargs(visuals, "title") if title_kwargs is not False: _, title_aes, title_ignore = filter_aes(plot_collection, aes_by_visuals, "title", []) if "color" not in title_aes: title_kwargs.setdefault("color", "B1") plot_collection.map( labelled_title, "title", ignore_aes=title_ignore, subset_info=True, labeller=labeller, **title_kwargs, ) xlabel_kwargs = get_visual_kwargs(visuals, "xlabel") if xlabel_kwargs is not False: _, xlabel_aes, xlabel_ignore = filter_aes(plot_collection, aes_by_visuals, "xlabel", []) if "color" not in xlabel_aes: xlabel_kwargs.setdefault("color", "B1") xlabel_kwargs.setdefault("text", "Data Point") plot_collection.map( labelled_x, "xlabel", ignore_aes=xlabel_ignore, subset_info=True, **xlabel_kwargs, ) ylabel_kwargs = get_visual_kwargs(visuals, "ylabel") if ylabel_kwargs is not False: _, ylabel_aes, ylabel_ignore = filter_aes(plot_collection, aes_by_visuals, "ylabel", []) if "color" not in ylabel_aes: ylabel_kwargs.setdefault("color", "B1") ylabel_kwargs.setdefault("text", "Shape parameter k") plot_collection.map( labelled_y, "ylabel", ignore_aes=ylabel_ignore, subset_info=True, **ylabel_kwargs, ) legend_kwargs = get_visual_kwargs(visuals, "legend", default=None) if legend is False: legend_kwargs = False elif legend_kwargs is None: legend_kwargs = {} if legend_kwargs is not False and "color" in plot_collection.aes.children: color_mapping = plot_collection.aes["color"].data_vars.get("mapping") legend_dims = list(color_mapping.dims) if color_mapping is not None else [] if legend_kwargs is not False: legend_kwargs.setdefault("dim", legend_dims or ["color"]) if legend_kwargs is not False and (legend is None or legend): plot_collection.add_legend(**legend_kwargs) if new_xlim is not None: plot_collection.map( set_xlim, "xlim", data=scalar_ds, ignore_aes="all", store_artist=False, limits=new_xlim, ) if hover_label and n_data_points: labels_for_hover = [str(label) for label in flat_coord_labels] enable_hover_labels( backend, plot_collection, hover_format, labels_for_hover, None, y_flat, ) return plot_collection