Source code for arviz_plots.plots.ppc_censored_plot

"""Posterior predictive check for survival/censored data using Kaplan-Meier curves."""

from importlib import import_module
from typing import Any, Literal, Mapping, Sequence

from arviz_base import rcParams
from arviz_base.labels import BaseLabeller
from arviz_stats.survival import generate_survival_curves, kaplan_meier

from arviz_plots.plot_collection import PlotCollection
from arviz_plots.plots.utils import (
    filter_aes,
    get_visual_kwargs,
    process_group_variables_coords,
    set_wrap_layout,
)
from arviz_plots.visuals import ecdf_line, labelled_title, labelled_x, labelled_y


[docs] def plot_ppc_censored( dt, var_names=None, filter_vars=None, group="posterior_predictive", coords=None, sample_dims=None, num_samples=100, truncation_factor=1.2, plot_collection=None, backend=None, labeller=None, aes_by_visuals: Mapping[ Literal[ "observed_km", "predictive", "xlabel", "ylabel", "title", ], Sequence[str], ] = None, visuals: Mapping[ Literal[ "observed_km", "predictive", "xlabel", "ylabel", "title", ], Mapping[str, Any] | bool, ] = None, **pc_kwargs, ): """Plot Kaplan-Meier survival curve [1]_ vs predictive draws. Instead of plotting the raw data observation and predictions, as is common in posterior predictive checks, this function computes the Kaplan-Meier survival curves for observed and for predictive data computes the truncated survival probabilities. The truncation is done as a factor of the maximum observed data to avoid extending the survival curves too far beyond the range of observed data. Parameters ---------- dt : DataTree Input data containing the predictive samples and observed data. Should contain groups specified by `group` and "observed_data", optionally including a censoring status variable in "constant_data". This censoring variable should be binary where 1 indicates an event occurred and 0 indicates censoring. var_names : str or list of str, optional One or more variables to be plotted. filter_vars : {None, "like", "regex"}, optional, default=None If None (default), interpret var_names as the real variables names. If "like", interpret var_names as substrings of the real variables names. If "regex", interpret var_names as regular expressions on the real variables names. group : str, default "posterior_predictive" Group to be plotted. Can be "posterior_predictive" or "prior_predictive". coords : dict, optional Coordinates to subset the data. sample_dims : str or sequence of hashable, optional Dimensions to reduce unless mapped to an aesthetic. Defaults to ``rcParams["data.sample_dims"]`` num_samples : int, optional Number of samples to plot. Defaults to 100. truncation_factor : float, default 1.2 Factor by which to truncate the survival curves beyond the maximum observed time. Set to `None` to show all posterior predictive draws. plot_collection : PlotCollection, optional Existing plot collection to add to. backend : {"matplotlib", "bokeh", "plotly"}, optional Plotting backend to use. labeller : labeller, optional Labeller for plot titles and axes. aes_by_visuals : mapping of {str : sequence of str}, optional Mapping of visuals to aesthetics that should use their mapping in `plot_collection` when plotted. Valid keys are the same as for `visuals`. visuals : mapping of {str : mapping or bool}, optional Valid keys are: * observed_km -> passed to :func:`~arviz_plots.visuals.ecdf_line` * predictive -> passed to :func:`~arviz_plots.visuals.ecdf_line` * xlabel -> passed to :func:`~arviz_plots.visuals.labelled_x` * ylabel -> passed to :func:`~arviz_plots.visuals.labelled_y` * title -> passed to :func:`~arviz_plots.visuals.labelled_title` **pc_kwargs Additional arguments passed to PlotCollection. Returns ------- PlotCollection The plot collection containing the survival curve plot. Examples -------- Plot Kaplan-Meier curves for posterior predictive checking: .. plot:: :context: close-figs >>> from arviz_plots import plot_ppc_censored, style >>> style.use("arviz-variat") >>> from arviz_base import load_arviz_data >>> dt = load_arviz_data('censored_cats') >>> plot_ppc_censored(dt) .. minigallery:: plot_ppc_censored References ---------- .. [1] Kaplan, E. L., & Meier, P. Nonparametric estimation from incomplete observations. JASA, 53(282). (1958) https://doi.org/10.1080/01621459.1958.10501452 """ if sample_dims is None: sample_dims = rcParams["data.sample_dims"] if isinstance(sample_dims, str): sample_dims = [sample_dims] sample_dims = list(sample_dims) if visuals is None: visuals = {} else: visuals = visuals.copy() if aes_by_visuals is None: aes_by_visuals = {} else: aes_by_visuals = aes_by_visuals.copy() if backend is None: if plot_collection is None: backend = rcParams["plot.backend"] else: backend = plot_collection.backend if labeller is None: labeller = BaseLabeller() # Get predictive data predictive_dist = process_group_variables_coords( dt, group=group, var_names=var_names, filter_vars=filter_vars, coords=coords ) # Get observed data if "observed_data" in dt: observed_dist = process_group_variables_coords( dt, group="observed_data", var_names=var_names, filter_vars=filter_vars, coords=coords, ) km_ds = kaplan_meier(dt, var_names=observed_dist.data_vars) else: observed_dist = None predictive_ds = generate_survival_curves( dt, var_names=predictive_dist.data_vars, group=group, num_samples=num_samples, truncation_factor=truncation_factor, ) plot_bknd = import_module(f".backend.{backend}", package="arviz_plots") if plot_collection is None: pc_kwargs["figure_kwargs"] = pc_kwargs.get("figure_kwargs", {}).copy() pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy() pc_kwargs["aes"].setdefault("overlay_ppc", ["sample"]) pc_kwargs.setdefault("cols", "__variable__") pc_kwargs = set_wrap_layout(pc_kwargs, plot_bknd, predictive_ds) plot_collection = PlotCollection.wrap( predictive_ds, backend=backend, **pc_kwargs, ) aes_by_visuals.setdefault("predictive", ["overlay_ppc"]) aes_by_visuals.setdefault("observed_km", plot_collection.aes_set) # Plot predictive survival curves predictive_kwargs = get_visual_kwargs(visuals, "predictive") if predictive_kwargs is not False: _, predictive_aes, predictive_ignore = filter_aes( plot_collection, aes_by_visuals, "predictive", sample_dims ) if "color" not in predictive_aes: predictive_kwargs.setdefault("color", "C0") predictive_kwargs.setdefault("alpha", 0.7) plot_collection.map( ecdf_line, "predictive", data=predictive_ds, ignore_aes=predictive_ignore, **predictive_kwargs, ) # Plot Kaplan-Meier curve observed_km_kwargs = get_visual_kwargs( visuals, "observed_km", False if group == "prior_predictive" or observed_dist is None else None, ) if observed_km_kwargs is not False: _, observed_aes, observed_ignore = filter_aes( plot_collection, aes_by_visuals, "observed_km", sample_dims ) if "color" not in observed_aes: observed_km_kwargs.setdefault("color", "B1") observed_km_kwargs.setdefault("linestyle", "C1") plot_collection.map( ecdf_line, "observed_km", data=km_ds, ignore_aes=observed_ignore, **observed_km_kwargs, ) # Add labels xlabel_kwargs = get_visual_kwargs(visuals, "xlabel") if xlabel_kwargs is not False: xlabel_kwargs.setdefault("color", "B1") plot_collection.map( labelled_x, "xlabel", data=km_ds, subset_info=True, labeller=labeller, ignore_aes=plot_collection.aes_set, **xlabel_kwargs, ) ylabel_kwargs = get_visual_kwargs(visuals, "ylabel") if ylabel_kwargs is not False: ylabel_kwargs.setdefault("text", "Survival Probability") ylabel_kwargs.setdefault("color", "B1") plot_collection.map( labelled_y, "ylabel", ignore_aes=plot_collection.aes_set, **ylabel_kwargs, ) # Add title title_kwargs = get_visual_kwargs(visuals, "title", False) if title_kwargs is not False: title_kwargs.setdefault("color", "B1") plot_collection.map( labelled_title, "title", ignore_aes=plot_collection.aes_set, subset_info=True, labeller=labeller, **title_kwargs, ) return plot_collection