diff --git a/docs/index.md b/docs/index.md index ed973294..a4e5c08a 100644 --- a/docs/index.md +++ b/docs/index.md @@ -103,6 +103,7 @@ pip install git+https://github.com/CDCgov/PyRenew@main - [Latent subpopulation infections](tutorials/latent_subpopulation_infections.md) -- modeling latent infections with subpopulation structure. - [Observation processes: count data](tutorials/observation_processes_counts.md) -- connecting latent infections to observed counts. - [Observation processes: measurements](tutorials/observation_processes_measurements.md) -- connecting latent infections to continuous measurements. +- [Joint ascertainment](tutorials/ascertainment.md) -- sharing ascertainment structure across count signals. ## Resources diff --git a/docs/tutorials/.pages b/docs/tutorials/.pages index 9780c954..fd693b8d 100644 --- a/docs/tutorials/.pages +++ b/docs/tutorials/.pages @@ -7,3 +7,4 @@ nav: - observation_processes_measurements.md - right_truncation.md - day_of_week_effects.md + - ascertainment.md diff --git a/docs/tutorials/ascertainment.qmd b/docs/tutorials/ascertainment.qmd new file mode 100644 index 00000000..65db49b6 --- /dev/null +++ b/docs/tutorials/ascertainment.qmd @@ -0,0 +1,271 @@ +--- +title: Joint ascertainment +format: + gfm: + code-fold: true +engine: jupyter +jupyter: + jupytext: + text_representation: + extension: .qmd + format_name: quarto + format_version: "1.0" + jupytext_version: 1.19.1 + kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +--- + +```{python} +# | label: setup +# | output: false +import jax.nn as jnn +import jax.numpy as jnp +import jax.random as random +import numpy as np +import numpyro +import numpyro.distributions as dist +import pandas as pd +import plotnine as p9 +import warnings +from plotnine.exceptions import PlotnineWarning +from _tutorial_theme import theme_tutorial + +from pyrenew.ascertainment import JointAscertainment +from pyrenew.deterministic import DeterministicPMF +from 
pyrenew.model import PyrenewBuilder +from pyrenew.observation import NegativeBinomialNoise, PopulationCounts +from pyrenew.randomvariable import DistributionalVariable +from pyrenew.time import MMWR_WEEK + +warnings.filterwarnings("ignore", category=PlotnineWarning) +``` + +Ascertainment is the probability that a latent infection appears in an observed signal, e.g., a hospitalization or visit to the emergency department. +For hospital admissions this probability is often called an infection-hospitalization rate (IHR). +For emergency department visits it might be called an infection-ED-visit rate (IEDR). + +PyRenew count observations accept an `ascertainment_rate_rv`. +For simple models this can be any ordinary `RandomVariable`. + +For multi-signal models, PyRenew also provides model-level ascertainment components that let multiple observation processes share related ascertainment structure. +These components are intended for count signals whose observation probabilities are logically related because they are different event streams generated from the same latent infections. +Hospital admissions and ED visits are a natural example: they have different observation probabilities, but both depend on clinical care-seeking and reporting. + +## Independent scalar ascertainment + +Before considering multi-signal models with related ascertainment structure, we show how to specify a multi-signal model where the signals are modeled independently. +In this model, each observation process has its own scalar ascertainment prior. 
+ +```{python} +# | label: independent-observation-processes +hosp_delay_pmf = jnp.array([0.05, 0.10, 0.15, 0.15, 0.20, 0.15, 0.15, 0.05]) +hosp_delay_rv = DeterministicPMF("inf_to_hosp_delay", hosp_delay_pmf) + +ed_delay_pmf = jnp.array([0.05, 0.15, 0.30, 0.30, 0.15, 0.05]) +ed_delay_rv = DeterministicPMF("inf_to_ed_delay", ed_delay_pmf) + +hospital_obs_independent = PopulationCounts( + name="hospital", + ascertainment_rate_rv=DistributionalVariable("ihr", dist.Beta(2, 198)), + delay_distribution_rv=hosp_delay_rv, + noise=NegativeBinomialNoise( + DistributionalVariable( + "hospital_concentration", dist.LogNormal(4.0, 0.5) + ) + ), +) + +ed_obs_independent = PopulationCounts( + name="ed_visits", + ascertainment_rate_rv=DistributionalVariable("iedr", dist.Beta(2, 98)), + delay_distribution_rv=ed_delay_rv, + noise=NegativeBinomialNoise( + DistributionalVariable("ed_concentration", dist.LogNormal(4.0, 0.5)) + ), +) +``` + +```{python} +# | label: independent-prior-draws +n_draws = 1500 +key_ihr, key_iedr = random.split(random.PRNGKey(1)) + +independent_ihr = dist.Beta(2, 198).sample(key_ihr, (n_draws,)) +independent_iedr = dist.Beta(2, 98).sample(key_iedr, (n_draws,)) +``` + +## Model-level ascertainment + +To share ascertainment structure across count signals, define an ascertainment model once and register it with the builder. +Each observation process receives the appropriate signal-specific accessor from `for_signal()`. + +```python +builder = PyrenewBuilder() + +joint_ascertainment = JointAscertainment( + name="he_ascertainment", + signals=("hospital", "ed_visits"), + baseline_rates=jnp.array([0.01, 0.02]), + scale_tril=jnp.array( + [[0.7, 0.0], + [0.35, 0.606],]), +) + +builder.add_ascertainment(joint_ascertainment) + +PopulationCounts( + name="hospital", + ascertainment_rate_rv=joint_ascertainment.for_signal("hospital"), + ... +) + +PopulationCounts( + name="ed_visits", + ascertainment_rate_rv=joint_ascertainment.for_signal("ed_visits"), + ... 
+) +``` + +The model samples the ascertainment component once per model execution. +The accessors passed to the observation processes read the sampled values. + +## Joint scalar ascertainment + +Use `JointAscertainment` when each signal has a scalar ascertainment rate, but the rates should be correlated across signals. +The model samples the rates jointly on the logit scale and returns one probability for each signal. +These rates are constant over the model time axis. + +```{python} +# | label: joint-ascertainment-object +joint_ascertainment = JointAscertainment( + name="he_ascertainment", + signals=("hospital", "ed_visits"), + baseline_rates=jnp.array([0.01, 0.02]), + scale_tril=jnp.array( + [ + [0.7, 0.0], + [0.35, 0.606], + ] + ), +) +``` + +The order of the specified `signals` determines how the arguments `baseline_rates` and `scale_tril` are interpreted. +In the above example, the first entry is hospitalizations and the second is ED visits. +In a NumPyro trace, this object creates one sample site, `he_ascertainment_eta`, and deterministic signal-specific rates such as `he_ascertainment_hospital` and `he_ascertainment_ed_visits`. + +The next block samples directly from the same logit-normal prior used by `JointAscertainment`. +This is only for visualization; in a fitted PyRenew model, `JointAscertainment` handles this sampling internally. 
+ +```{python} +# | label: joint-prior-draws +joint_eta = dist.MultivariateNormal( + loc=joint_ascertainment.baseline_logits, + scale_tril=joint_ascertainment.scale_tril, +).sample(random.PRNGKey(2), (n_draws,)) +joint_rates = jnn.sigmoid(joint_eta) + +independent_df = pd.DataFrame( + { + "hospital_ihr": np.array(independent_ihr), + "ed_iedr": np.array(independent_iedr), + } +) + +joint_df = pd.DataFrame( + { + "hospital_ihr": np.array(joint_rates[:, 0]), + "ed_iedr": np.array(joint_rates[:, 1]), + } +) + +compare_df = pd.concat( + [ + independent_df.assign(prior="Independent scalar priors"), + joint_df.assign(prior="Joint scalar prior"), + ], + ignore_index=True, +) +compare_df["prior"] = pd.Categorical( + compare_df["prior"], + categories=["Independent scalar priors", "Joint scalar prior"], + ordered=True, +) +``` + +The following plot compares the samples drawn from independent scalar priors and draws from the correlated logit-normal prior used by `JointAscertainment`. +The dashed line marks the prior-center ratio, IEDR = 2 × IHR. + +The prior draws from independent priors show that high IHR draws do not imply high IEDR draws. +The joint prior keeps the same approximate scale while inducing positive dependence between the two rates. + +```{python} +# | label: plot-joint-priors +# | fig-cap: | +# | Joint ascertainment induces correlation between scalar signal rates. +# | The dashed line marks the prior-center ratio, IEDR = 2 × IHR. +# | The joint prior keeps the same approximate scale while inducing +# | positive dependence between the two rates. + +( + p9.ggplot(compare_df, p9.aes(x="hospital_ihr", y="ed_iedr")) + + p9.geom_point(alpha=0.2, size=1.0, color="steelblue") + + p9.geom_abline(intercept=0, slope=2, linetype="dashed", color="gray") + + p9.facet_wrap("~prior", nrow=1) + + p9.labs( + x="Hospital ascertainment rate (IHR)", + y="ED ascertainment rate (IEDR)", + title="Independent vs. 
joint scalar ascertainment", + ) + + theme_tutorial +) +``` + +## Using ascertainment with a builder + +The main API pattern is: + +```python +builder = PyrenewBuilder() +builder.configure_latent(...) + +joint_ascertainment = JointAscertainment( + name="he_ascertainment", + signals=("hospital", "ed_visits"), + baseline_rates=jnp.array([0.01, 0.02]), + scale_tril=jnp.array( + [ + [0.7, 0.0], + [0.35, 0.606], + ] + ), +) +builder.add_ascertainment(joint_ascertainment) + +hospital_obs = PopulationCounts( + name="hospital", + ascertainment_rate_rv=joint_ascertainment.for_signal("hospital"), + delay_distribution_rv=hosp_delay_rv, + noise=hosp_noise_rv, + aggregation="weekly", + reporting_schedule="regular", + start_dow=MMWR_WEEK, +) +builder.add_observation(hospital_obs) + +ed_obs = PopulationCounts( + name="ed_visits", + ascertainment_rate_rv=joint_ascertainment.for_signal("ed_visits"), + delay_distribution_rv=ed_delay_rv, + noise=ed_noise_rv, + day_of_week_rv=DeterministicVariable("ed_dow", ed_day_of_week_effects), +) +builder.add_observation(ed_obs) + +model = builder.build() +``` + +Signal names in `for_signal()` must match the names used when the ascertainment model was created. +They do not have to match observation names, but matching them usually makes model code and posterior outputs easier to read. diff --git a/pyrenew/ascertainment/__init__.py b/pyrenew/ascertainment/__init__.py new file mode 100644 index 00000000..244261a2 --- /dev/null +++ b/pyrenew/ascertainment/__init__.py @@ -0,0 +1,13 @@ +# numpydoc ignore=GL08 +""" +Ascertainment models for shared observation-rate structure. 
+""" + +from pyrenew.ascertainment.base import AscertainmentModel, AscertainmentSignal +from pyrenew.ascertainment.joint import JointAscertainment + +__all__ = [ + "AscertainmentModel", + "AscertainmentSignal", + "JointAscertainment", +] diff --git a/pyrenew/ascertainment/base.py b/pyrenew/ascertainment/base.py new file mode 100644 index 00000000..9c0c823f --- /dev/null +++ b/pyrenew/ascertainment/base.py @@ -0,0 +1,187 @@ +# numpydoc ignore=GL08 +""" +Base classes for ascertainment models. +""" + +from __future__ import annotations + +from abc import ABCMeta, abstractmethod +from collections.abc import Mapping + +from jax.typing import ArrayLike + +from pyrenew.ascertainment.context import get_ascertainment_value +from pyrenew.metaclass import RandomVariable + + +class AscertainmentSignal(RandomVariable): + """ + Accessor for one signal's ascertainment value. + + Users usually do not instantiate this class directly. It is returned by + ``AscertainmentModel.for_signal(...)`` and passed to an observation process + as ``ascertainment_rate_rv``. During model execution, the parent + ``AscertainmentModel`` samples the actual rate once, and this accessor + retrieves the signal-specific value without creating additional NumPyro + sample sites. + """ + + def __init__( + self, + ascertainment_name: str, + signal_name: str, + ) -> None: + """ + Initialize a signal-specific ascertainment accessor. + + Parameters + ---------- + ascertainment_name + Name of the parent ascertainment model. + signal_name + Name of the signal to retrieve. + """ + if not isinstance(ascertainment_name, str) or len(ascertainment_name) == 0: + raise ValueError( + "ascertainment_name must be a non-empty string. " + f"Got {type(ascertainment_name).__name__}: {ascertainment_name!r}" + ) + if not isinstance(signal_name, str) or len(signal_name) == 0: + raise ValueError( + "signal_name must be a non-empty string. 
" + f"Got {type(signal_name).__name__}: {signal_name!r}" + ) + super().__init__(name=f"{ascertainment_name}_{signal_name}") + self.ascertainment_name = ascertainment_name + self.signal_name = signal_name + + def sample(self, **kwargs: object) -> ArrayLike: + """ + Return the sampled ascertainment value for this signal. + + Parameters + ---------- + **kwargs + Additional keyword arguments, ignored. + + Returns + ------- + ArrayLike + Signal-specific ascertainment value from the active context. + """ + return get_ascertainment_value( + ascertainment_name=self.ascertainment_name, + signal_name=self.signal_name, + ) + + +class AscertainmentModel(metaclass=ABCMeta): + """ + Base class for shared ascertainment structure. + + An ascertainment rate is the probability that latent incidence is observed + in a particular data stream. Examples include an infection-hospitalization + ratio for hospital admissions or an infection-ED-visit ratio for emergency + department visits. + + ``AscertainmentModel`` objects make shared structure explicit in a model + specification. A user defines the shared model once, registers it with + ``PyrenewBuilder.add_ascertainment(...)``, and passes signal-specific + accessors into observation processes: + + ```python + ascertainment = JointAscertainment(...) + builder.add_ascertainment(ascertainment) + + PopulationCounts( + name="hospital", + ascertainment_rate_rv=ascertainment.for_signal("hospital"), + ... + ) + ``` + + Subclasses own any NumPyro sites needed for the shared structure. + Accessors returned by ``for_signal()`` read the sampled values from the + active model context and do not sample independently. + """ + + def __init__( + self, + name: str, + signals: tuple[str, ...], + ) -> None: + """ + Initialize an ascertainment model. + + Parameters + ---------- + name + A non-empty string identifying the ascertainment model. + signals + Unique signal names produced by this model. 
+ """ + if not isinstance(name, str) or len(name) == 0: + raise ValueError( + f"name must be a non-empty string. Got {type(name).__name__}: {name!r}" + ) + if not isinstance(signals, tuple) or len(signals) == 0: + raise ValueError("signals must be a non-empty tuple of strings.") + if any(not isinstance(signal, str) or len(signal) == 0 for signal in signals): + raise ValueError("all signals must be non-empty strings.") + if len(set(signals)) != len(signals): + raise ValueError("signals must be unique.") + + self.name = name + self.signals = signals + + def for_signal(self, signal_name: str) -> AscertainmentSignal: + """ + Return an observation-process accessor for one signal. + + Parameters + ---------- + signal_name + Name of the signal produced by this ascertainment model. This name + should match the signal name used when the ascertainment model was + constructed. It does not have to match the observation process name, + but using the same name usually makes model specifications easier + to read. + + Returns + ------- + AscertainmentSignal + RandomVariable-compatible accessor for the signal's sampled + ascertainment rate. + + Raises + ------ + ValueError + If ``signal_name`` is not produced by this model. + """ + if signal_name not in self.signals: + raise ValueError( + f"Unknown signal {signal_name!r} for ascertainment model " + f"{self.name!r}. Available signals: {self.signals}." + ) + return AscertainmentSignal( + ascertainment_name=self.name, + signal_name=signal_name, + ) + + @abstractmethod + def sample(self, **kwargs: object) -> Mapping[str, ArrayLike]: + """ + Sample all signal-specific ascertainment values owned by this model. + + Parameters + ---------- + **kwargs + Additional model-context arguments supplied by ``MultiSignalModel``. + Subclasses may ignore unused values. + + Returns + ------- + Mapping[str, ArrayLike] + Mapping from signal name to sampled ascertainment rate. 
+ """ + pass # pragma: no cover diff --git a/pyrenew/ascertainment/context.py b/pyrenew/ascertainment/context.py new file mode 100644 index 00000000..7986d7f4 --- /dev/null +++ b/pyrenew/ascertainment/context.py @@ -0,0 +1,129 @@ +# numpydoc ignore=GL08 +""" +Execution context for sampled ascertainment values. + +Ascertainment models sample shared structure once per model execution. Their +signal-specific accessors are evaluated later by observation processes. This +module provides the context used to pass those sampled values from the model +level to the observation level without creating duplicate NumPyro sites. +""" + +from __future__ import annotations + +from collections.abc import Iterator, Mapping +from contextlib import contextmanager +from contextvars import ContextVar + +from jax.typing import ArrayLike + +_AscertainmentValues = Mapping[str, Mapping[str, ArrayLike]] + +_current_ascertainment_values: ContextVar[_AscertainmentValues | None] = ContextVar( + "current_ascertainment_values", + default=None, +) + + +def _validate_name(name: str, parameter: str) -> None: + """ + Validate a context lookup or mapping key name. + """ + if not isinstance(name, str) or len(name) == 0: + raise ValueError( + f"{parameter} must be a non-empty string. " + f"Got {type(name).__name__}: {name!r}" + ) + + +def _validate_ascertainment_values(values: _AscertainmentValues) -> None: + """ + Validate sampled values before making them available to accessors. + """ + if not isinstance(values, Mapping): + raise TypeError( + "ascertainment context values must be a mapping from " + "ascertainment model names to signal-value mappings." + ) + + for ascertainment_name, signal_values in values.items(): + _validate_name(ascertainment_name, "ascertainment model name") + if not isinstance(signal_values, Mapping): + raise TypeError( + "ascertainment context values must map each ascertainment " + "model name to a signal-value mapping." 
+ ) + for signal_name in signal_values: + _validate_name(signal_name, "signal name") + + +@contextmanager +def ascertainment_context(values: _AscertainmentValues) -> Iterator[None]: + """ + Make sampled ascertainment values available to signal accessors. + + The context is entered by ``MultiSignalModel`` after all registered + ascertainment models have been sampled. Observation processes can then call + their ``ascertainment_rate_rv`` accessors and retrieve the corresponding + sampled value. + + Parameters + ---------- + values + Mapping from ascertainment model name to signal-specific values. + + Yields + ------ + None + The context in which signal accessors can retrieve sampled values. + """ + _validate_ascertainment_values(values) + token = _current_ascertainment_values.set(values) + try: + yield + finally: + _current_ascertainment_values.reset(token) + + +def get_ascertainment_value( + ascertainment_name: str, + signal_name: str, +) -> ArrayLike: + """ + Retrieve one signal's sampled ascertainment value from context. + + Parameters + ---------- + ascertainment_name + Name of the ascertainment model. + signal_name + Name of the signal. + + Returns + ------- + ArrayLike + The sampled ascertainment value for the requested signal. + + Raises + ------ + RuntimeError + If no ascertainment context is active, or if the requested model or + signal was not sampled in the active context. + """ + _validate_name(ascertainment_name, "ascertainment_name") + _validate_name(signal_name, "signal_name") + + values = _current_ascertainment_values.get() + if values is None: + raise RuntimeError( + f"Ascertainment signal {signal_name!r} from model " + f"{ascertainment_name!r} was requested before ascertainment " + "values were sampled." + ) + + try: + return values[ascertainment_name][signal_name] + except KeyError as exc: + raise RuntimeError( + f"Ascertainment signal {signal_name!r} from model " + f"{ascertainment_name!r} is not available in the current context." 
+ ) from exc diff --git a/pyrenew/ascertainment/joint.py b/pyrenew/ascertainment/joint.py new file mode 100644 index 00000000..7b4a6e2a --- /dev/null +++ b/pyrenew/ascertainment/joint.py @@ -0,0 +1,146 @@ +# numpydoc ignore=GL08 +""" +Joint ascertainment models. +""" + +from __future__ import annotations + +from collections.abc import Mapping + +import jax.numpy as jnp +import numpyro +import numpyro.distributions as dist +from jax import Array +from jax.scipy.special import expit, logit +from jax.typing import ArrayLike + +from pyrenew.ascertainment.base import AscertainmentModel + + +class JointAscertainment(AscertainmentModel): + """ + Joint prior for scalar ascertainment rates across multiple signals. + + This model is useful when multiple observation streams have distinct but + related probabilities of observing latent incidence. For example, hospital + admissions and emergency department visits may have different + infection-to-observation ratios, while still being correlated because both + depend on care-seeking behavior, testing practices, or reporting systems. + + The model samples one logit multivariate normal vector given natural-scale + baseline ascertainment rates. + + ```text + eta ~ MultivariateNormal(logit(baseline_rates), covariance) + ascertainment_rate_j = sigmoid(eta_j) + ``` + + Each returned rate is scalar and constant over the model time axis. + """ + + def __init__( + self, + name: str, + signals: tuple[str, ...], + baseline_rates: ArrayLike, + scale_tril: ArrayLike | None = None, + covariance_matrix: ArrayLike | None = None, + precision_matrix: ArrayLike | None = None, + ) -> None: + """ + Initialize a joint scalar ascertainment model. + + Parameters + ---------- + name + Name of the ascertainment model. + signals + Unique signal names, such as ``("hospital", "ed_visits")``. The + order corresponds to entries in ``baseline_rates`` and the covariance + parameter. + baseline_rates + Natural-scale baseline ascertainment rates. 
Shape ``(n_signals,)``. + Values must be probabilities in ``(0, 1)``. A value of ``0.01`` + centers the corresponding ascertainment rate near 1 percent before + accounting for covariance. + scale_tril + Lower-triangular scale matrix for the multivariate normal on the + logit scale. Exactly one covariance parameter must be supplied. + covariance_matrix + Covariance matrix for the multivariate normal on the logit scale. + Exactly one covariance parameter must be supplied. + precision_matrix + Precision matrix for the multivariate normal on the logit scale. + Exactly one covariance parameter must be supplied. + """ + super().__init__(name=name, signals=signals) + self.baseline_rates: Array = jnp.asarray(baseline_rates) + self.scale_tril: Array | None = self._optional_array(scale_tril) + self.covariance_matrix: Array | None = self._optional_array(covariance_matrix) + self.precision_matrix: Array | None = self._optional_array(precision_matrix) + self._validate_parameters() + self.baseline_logits: Array = logit(self.baseline_rates) + self.distribution: dist.MultivariateNormal = dist.MultivariateNormal( + loc=self.baseline_logits, + scale_tril=self.scale_tril, + covariance_matrix=self.covariance_matrix, + precision_matrix=self.precision_matrix, + ) + + @staticmethod + def _optional_array(value: ArrayLike | None) -> Array | None: + """ + Convert optional array-like values to JAX arrays. + + Returns + ------- + Array | None + ``None`` if ``value`` is ``None``; otherwise ``value`` converted + to a JAX array. + """ + if value is None: + return None + return jnp.asarray(value) + + def _validate_parameters(self) -> None: + """ + Validate constructor parameters. + """ + n_signals = len(self.signals) + if self.baseline_rates.shape != (n_signals,): + raise ValueError( + "baseline_rates must have shape " + f"({n_signals},), got shape {self.baseline_rates.shape}." 
+ ) + if jnp.any(self.baseline_rates <= 0) or jnp.any(self.baseline_rates >= 1): + raise ValueError( + "baseline_rates must contain probabilities in (0, 1), " + f"got {self.baseline_rates}." + ) + + def sample(self, **kwargs: object) -> Mapping[str, ArrayLike]: + """ + Sample jointly distributed scalar ascertainment rates. + + Parameters + ---------- + **kwargs + Additional model-context arguments, ignored. + + Returns + ------- + Mapping[str, ArrayLike] + Mapping from signal name to sampled scalar ascertainment rate. + """ + eta = numpyro.sample( + f"{self.name}_eta", + self.distribution, + ) + rates = expit(eta) + + result = {} + for signal, rate in zip(self.signals, rates): + numpyro.deterministic(f"{self.name}_{signal}", rate) + result[signal] = rate + + return result diff --git a/pyrenew/latent/base.py b/pyrenew/latent/base.py index 18742b12..f8840d96 100644 --- a/pyrenew/latent/base.py +++ b/pyrenew/latent/base.py @@ -82,13 +82,13 @@ class BaseLatentInfectionProcess(RandomVariable): gen_int_rv Generation interval PMF n_initialization_points - Number of initialization days before the first observation day. - Latent and observation arrays use a shared padded time axis with - element 0 at the start of this initialization period. In observation - natural coordinates, day 0 is the first observed data day; on the - shared padded axis, that same day is index ``n_initialization_points``. - Must be at least ``len(gen_int_rv())`` to provide enough history for - the renewal equation convolution. + Number of initialization days before the first observation day. + Latent and observation arrays use a shared padded time axis with + element 0 at the start of this initialization period. In observation + natural coordinates, day 0 is the first observed data day; on the + shared padded axis, that same day is index ``n_initialization_points``. + Must be at least ``len(gen_int_rv())`` to provide enough history for + the renewal equation convolution. 
Notes ----- diff --git a/pyrenew/model/multisignal_model.py b/pyrenew/model/multisignal_model.py index 733402cf..1a0f4b88 100644 --- a/pyrenew/model/multisignal_model.py +++ b/pyrenew/model/multisignal_model.py @@ -14,6 +14,8 @@ import numpyro.handlers from jax.typing import ArrayLike +from pyrenew.ascertainment import AscertainmentModel +from pyrenew.ascertainment.context import ascertainment_context from pyrenew.latent.base import BaseLatentInfectionProcess from pyrenew.metaclass import Model from pyrenew.observation.base import BaseObservationProcess @@ -40,6 +42,10 @@ class MultiSignalModel(Model): observations Dictionary mapping names to observation process instances. Names are used when passing observation data to sample(). + ascertainment_models + Optional dictionary mapping names to ascertainment model instances. + Each ascertainment model is sampled once per model execution before + observation processes run. Notes ----- @@ -53,6 +59,7 @@ def __init__( self, latent_process: BaseLatentInfectionProcess, observations: dict[str, BaseObservationProcess], + ascertainment_models: dict[str, AscertainmentModel] | None = None, ) -> None: """ Initialize multi-signal model. @@ -63,6 +70,8 @@ def __init__( Configured latent infection process observations Dictionary mapping observation names to observation process instances + ascertainment_models + Optional dictionary mapping names to ascertainment model instances Raises ------ @@ -72,6 +81,9 @@ def __init__( """ self.latent = latent_process self.observations = observations + if ascertainment_models is None: + ascertainment_models = {} + self.ascertainment_models = ascertainment_models self.validate() _SUPPORTED_RESOLUTIONS = {"aggregate", "subpop"} @@ -100,6 +112,17 @@ def validate(self) -> None: f"Observation '{name}' returned invalid infection_resolution " f"'{resolution}'. Expected one of {self._SUPPORTED_RESOLUTIONS}." 
) + for name, ascertainment_model in self.ascertainment_models.items(): + if not isinstance(ascertainment_model, AscertainmentModel): + raise TypeError( + f"Ascertainment model '{name}' must be an AscertainmentModel, " + f"got {type(ascertainment_model).__name__}." + ) + if ascertainment_model.name != name: + raise ValueError( + f"Ascertainment model dictionary key {name!r} must match " + f"the model name {ascertainment_model.name!r}." + ) def pad_observations( self, @@ -365,25 +388,32 @@ def sample( "subpop": inf_all, } - # Apply each observation process - for name, obs_process in self.observations.items(): - # Get the appropriate latent infections based on observation type - resolution = obs_process.infection_resolution() - if resolution not in latent_map: - raise ValueError( - f"Observation '{name}' returned invalid infection_resolution " - f"'{resolution}'. Expected one of {self._SUPPORTED_RESOLUTIONS}." - ) - latent_infections = latent_map[resolution] - - # Get observation-specific data - obs_data = observation_data.get(name, {}) + ascertainment_values = { + name: ascertainment_model.sample() + for name, ascertainment_model in self.ascertainment_models.items() + } - # Sample from observation process - obs_process.sample( - infections=latent_infections, - first_day_dow=first_day_dow, - **obs_data, - ) + with ascertainment_context(ascertainment_values): + # Apply each observation process + for name, obs_process in self.observations.items(): + # Get the appropriate latent infections based on observation type + resolution = obs_process.infection_resolution() + if resolution not in latent_map: + raise ValueError( + f"Observation '{name}' returned invalid infection_resolution " + f"'{resolution}'. Expected one of " + f"{self._SUPPORTED_RESOLUTIONS}." 
+ ) + latent_infections = latent_map[resolution] + + # Get observation-specific data + obs_data = observation_data.get(name, {}) + + # Sample from observation process + obs_process.sample( + infections=latent_infections, + first_day_dow=first_day_dow, + **obs_data, + ) return None diff --git a/pyrenew/model/pyrenew_builder.py b/pyrenew/model/pyrenew_builder.py index 02fc1e3f..43801373 100644 --- a/pyrenew/model/pyrenew_builder.py +++ b/pyrenew/model/pyrenew_builder.py @@ -9,6 +9,7 @@ from typing import Any +from pyrenew.ascertainment import AscertainmentModel from pyrenew.latent.base import BaseLatentInfectionProcess from pyrenew.model.multisignal_model import MultiSignalModel from pyrenew.observation.base import BaseObservationProcess @@ -49,6 +50,7 @@ def __init__(self) -> None: self.latent_class: type[BaseLatentInfectionProcess] | None = None self.latent_params: dict[str, Any] = {} self.observations: dict[str, BaseObservationProcess] = {} + self.ascertainment_models: dict[str, AscertainmentModel] = {} def configure_latent( self, @@ -147,6 +149,77 @@ def add_observation( self.observations[name] = obs_process return self + def add_ascertainment( + self, + ascertainment_model: AscertainmentModel, + ) -> PyrenewBuilder: + """ + Add shared ascertainment structure to the model. + + Use this method when observation probabilities are related across + signals. Independent scalar ascertainment rates do not require an + ascertainment model; those can be passed directly to an observation + process as ordinary ``RandomVariable`` objects. + + A registered ascertainment model is sampled once per model execution, + before observation processes run. 
Observation processes receive + signal-specific accessors from ``ascertainment_model.for_signal(...)``: + + ```python + ascertainment = JointAscertainment( + name="he_ascertainment", + signals=("hospital", "ed_visits"), + baseline_rates=..., + scale_tril=..., + ) + builder.add_ascertainment(ascertainment) + + builder.add_observation( + PopulationCounts( + name="hospital", + ascertainment_rate_rv=ascertainment.for_signal("hospital"), + ... + ) + ) + ``` + + The ascertainment model's ``name`` attribute is used as the unique + identifier in the built ``MultiSignalModel``. + + Parameters + ---------- + ascertainment_model + Configured ascertainment model instance, such as + ``JointAscertainment``. + + Returns + ------- + PyrenewBuilder + Self, for method chaining. + + Raises + ------ + TypeError + If ``ascertainment_model`` is not an ``AscertainmentModel``. + ValueError + If an ascertainment model with this name already exists. + """ + if not isinstance(ascertainment_model, AscertainmentModel): + raise TypeError( + "ascertainment_model must be an AscertainmentModel, " + f"got {type(ascertainment_model).__name__}." + ) + + name = ascertainment_model.name + if name in self.ascertainment_models: + raise ValueError( + f"Ascertainment model '{name}' already added. " + "Each ascertainment model must have a unique name." + ) + + self.ascertainment_models[name] = ascertainment_model + return self + def compute_n_initialization_points(self) -> int: """ Compute required n_initialization_points from all components. 
@@ -233,6 +306,7 @@ def build(self) -> MultiSignalModel: model = MultiSignalModel( latent_process=latent_process, observations=self.observations, + ascertainment_models=self.ascertainment_models, ) return model diff --git a/test/integration/conftest.py b/test/integration/conftest.py index 8a81e9a6..c4ce19df 100644 --- a/test/integration/conftest.py +++ b/test/integration/conftest.py @@ -12,6 +12,7 @@ import polars as pl import pytest +from pyrenew.ascertainment import JointAscertainment from pyrenew.datasets import ( load_example_infection_admission_interval, load_synthetic_daily_ed_visits, @@ -344,3 +345,90 @@ def he_weekly_model( builder.add_observation(ed_obs) return builder.build() + + +@pytest.fixture(scope="module") +def he_weekly_joint_ascertainment_model( + true_params: dict, + hosp_delay_pmf: jnp.ndarray, + ed_delay_pmf: jnp.ndarray, + ed_day_of_week_effects: jnp.ndarray, +) -> MultiSignalModel: + """ + Build a weekly-hospital + daily-ED model with joint ascertainment. + + The hospital observation is aggregated to MMWR epiweeks, the ED visit + observation stays daily, and both signal-specific ascertainment rates are + sampled once from a shared ``JointAscertainment`` model. This is + structurally comparable to the pyrenew-multisignal H+E model while keeping + PyRenew's scalar ascertainment-rate interface. + + Parameters + ---------- + true_params : dict + Ground-truth parameter dictionary used to center the prior. + hosp_delay_pmf : jnp.ndarray + Infection-to-hospitalization delay PMF. + ed_delay_pmf : jnp.ndarray + Infection-to-ED-visit delay PMF. + ed_day_of_week_effects : jnp.ndarray + Day-of-week multipliers used in synthetic ED generation. + + Returns + ------- + MultiSignalModel + Built model ready for fitting. 
+ """ + gen_int_pmf = jnp.array( + [0.6326975, 0.2327564, 0.0856263, 0.03150015, 0.01158826, 0.00426308, 0.0015683] + ) + + true_ihr = true_params["hospitalizations"]["ihr"] + true_iedr = true_params["ed_visits"]["iedr"] + ascertainment = JointAscertainment( + name="he_ascertainment", + signals=("hospital", "ed_visits"), + baseline_rates=jnp.array([true_ihr, true_iedr]), + scale_tril=jnp.array( + [ + [0.7, 0.0], + [0.35, 0.606], + ] + ), + ) + + builder = PyrenewBuilder() + builder.configure_latent( + PopulationInfections, + gen_int_rv=DeterministicPMF("gen_int", gen_int_pmf), + I0_rv=DistributionalVariable("I0", dist.Beta(1, 10)), + log_rt_time_0_rv=DistributionalVariable("log_rt_time_0", dist.Normal(0.0, 0.5)), + single_rt_process=AR1(autoreg=0.9, innovation_sd=0.05), + ) + builder.add_ascertainment(ascertainment) + + hospital_obs = PopulationCounts( + name="hospital", + ascertainment_rate_rv=ascertainment.for_signal("hospital"), + delay_distribution_rv=DeterministicPMF("hosp_delay", hosp_delay_pmf), + noise=NegativeBinomialNoise( + DistributionalVariable("hosp_conc", dist.LogNormal(5.0, 1.0)) + ), + aggregation="weekly", + reporting_schedule="regular", + start_dow=MMWR_WEEK, + ) + builder.add_observation(hospital_obs) + + ed_obs = PopulationCounts( + name="ed_visits", + ascertainment_rate_rv=ascertainment.for_signal("ed_visits"), + delay_distribution_rv=DeterministicPMF("ed_delay", ed_delay_pmf), + noise=NegativeBinomialNoise( + DistributionalVariable("ed_conc", dist.LogNormal(4.0, 1.0)) + ), + day_of_week_rv=DeterministicVariable("ed_dow", ed_day_of_week_effects), + ) + builder.add_observation(ed_obs) + + return builder.build() diff --git a/test/integration/test_population_infections_he_weekly_joint_ascertainment.py b/test/integration/test_population_infections_he_weekly_joint_ascertainment.py new file mode 100644 index 00000000..a00204b9 --- /dev/null +++ b/test/integration/test_population_infections_he_weekly_joint_ascertainment.py @@ -0,0 +1,386 @@ +""" 
+Integration test: weekly hospital + daily ED model with joint ascertainment. + +This exercises the mixed-cadence H+E structure with a ``JointAscertainment`` +model shared by the hospital and ED visit observation processes. +""" + +from __future__ import annotations + +from datetime import date + +import arviz as az +import jax +import jax.numpy as jnp +import jax.random as random +import numpyro +import polars as pl +import pytest + +from pyrenew.ascertainment import AscertainmentSignal +from pyrenew.model import MultiSignalModel +from pyrenew.time import MMWR_WEEK + +pytestmark = pytest.mark.integration + + +N_DAYS_FIT = 126 +NUM_WARMUP = 500 +NUM_SAMPLES = 500 +NUM_CHAINS = 4 +# First observation day of the synthetic data. 2023-11-05 is a Sunday. +OBS_START_DATE = date(2023, 11, 5) + + +def _build_hospital_obs_on_period_grid( + model: MultiSignalModel, + weekly_values: jnp.ndarray, + first_day_dow: int, +) -> jnp.ndarray: + """ + Build a dense weekly-observation array on the model's period grid. + + Parameters + ---------- + model : MultiSignalModel + Built model exposing ``latent.n_initialization_points``. + weekly_values : jnp.ndarray + Observed weekly hospital admissions, one per MMWR epiweek. + first_day_dow : int + Day-of-week index of element 0 of the shared daily axis. + + Returns + ------- + jnp.ndarray + Dense array with NaN for unobserved pre-data periods. 
+ """ + hosp = model.observations["hospital"] + n_init = model.latent.n_initialization_points + n_total = n_init + N_DAYS_FIT + offset = hosp._compute_period_offset(first_day_dow, hosp.start_dow) + n_periods = (n_total - offset) // hosp.aggregation_period + n_pre = n_periods - len(weekly_values) + return jnp.concatenate([jnp.full(n_pre, jnp.nan, dtype=jnp.float32), weekly_values]) + + +class TestJointStructure: + """Check that the fixture has the intended H+E joint structure.""" + + def test_observation_cadences( + self, + he_weekly_joint_ascertainment_model: MultiSignalModel, + ) -> None: + """ + Verify hospital is weekly MMWR and ED visits remain daily. + + Parameters + ---------- + he_weekly_joint_ascertainment_model : MultiSignalModel + Built model with joint ascertainment. + """ + model = he_weekly_joint_ascertainment_model + + hospital = model.observations["hospital"] + assert hospital.aggregation == "weekly" + assert hospital.reporting_schedule == "regular" + assert hospital.start_dow == MMWR_WEEK + + ed_visits = model.observations["ed_visits"] + assert ed_visits.aggregation == "daily" + assert ed_visits.day_of_week_rv is not None + + def test_joint_ascertainment_is_registered( + self, + he_weekly_joint_ascertainment_model: MultiSignalModel, + ) -> None: + """ + Verify both count observations use accessors from the same model. + + Parameters + ---------- + he_weekly_joint_ascertainment_model : MultiSignalModel + Built model with joint ascertainment. 
+ """ + model = he_weekly_joint_ascertainment_model + assert set(model.ascertainment_models) == {"he_ascertainment"} + + ascertainment = model.ascertainment_models["he_ascertainment"] + assert ascertainment.signals == ("hospital", "ed_visits") + + hospital_rate = model.observations["hospital"].ascertainment_rate_rv + ed_rate = model.observations["ed_visits"].ascertainment_rate_rv + assert isinstance(hospital_rate, AscertainmentSignal) + assert isinstance(ed_rate, AscertainmentSignal) + assert hospital_rate.ascertainment_name == "he_ascertainment" + assert hospital_rate.signal_name == "hospital" + assert ed_rate.ascertainment_name == "he_ascertainment" + assert ed_rate.signal_name == "ed_visits" + + def test_weekly_obs_alignment( + self, + he_weekly_joint_ascertainment_model: MultiSignalModel, + weekly_hosp: pl.DataFrame, + ) -> None: + """ + Verify weekly hospital observations align to the dense period grid. + + Parameters + ---------- + he_weekly_joint_ascertainment_model : MultiSignalModel + Built model with joint ascertainment. + weekly_hosp : pl.DataFrame + Weekly hospital admissions. + """ + model = he_weekly_joint_ascertainment_model + first_day_dow = model._resolve_first_day_dow(OBS_START_DATE) + weekly_values = jnp.array( + weekly_hosp["weekly_hosp_admits"].to_numpy(), dtype=jnp.float32 + ) + hosp_obs = _build_hospital_obs_on_period_grid( + model, weekly_values, first_day_dow + ) + + assert int((~jnp.isnan(hosp_obs)).sum()) == len(weekly_hosp) + assert jnp.isnan(hosp_obs[0]) + assert not jnp.isnan(hosp_obs[-1]) + + +class TestPriorPredictiveStructure: + """Check the NumPyro graph for joint ascertainment and mixed cadence.""" + + def test_joint_ascertainment_sites_are_sampled_once( + self, + he_weekly_joint_ascertainment_model: MultiSignalModel, + weekly_hosp: pl.DataFrame, + daily_ed: pl.DataFrame, + ) -> None: + """ + Single model execution exposes one joint latent vector and two rates. 
+ + Parameters + ---------- + he_weekly_joint_ascertainment_model : MultiSignalModel + Built model with joint ascertainment. + weekly_hosp : pl.DataFrame + Weekly hospital admissions. + daily_ed : pl.DataFrame + Daily ED visits. + """ + model = he_weekly_joint_ascertainment_model + first_day_dow = model._resolve_first_day_dow(OBS_START_DATE) + weekly_values = jnp.array( + weekly_hosp["weekly_hosp_admits"].to_numpy(), dtype=jnp.float32 + ) + hosp_obs = _build_hospital_obs_on_period_grid( + model, weekly_values, first_day_dow + ) + ed_obs = model.pad_observations( + jnp.array(daily_ed["ed_visits"].to_numpy(), dtype=jnp.float32) + ) + population_size = float(weekly_hosp["pop"][0]) + + with numpyro.handlers.seed(rng_seed=0): + with numpyro.handlers.trace() as trace: + model.sample( + n_days_post_init=N_DAYS_FIT, + population_size=population_size, + obs_start_date=OBS_START_DATE, + hospital={"obs": hosp_obs}, + ed_visits={"obs": ed_obs}, + ) + + n_total = model.latent.n_initialization_points + N_DAYS_FIT + assert trace["he_ascertainment_eta"]["type"] == "sample" + assert trace["he_ascertainment_eta"]["value"].shape == (2,) + assert trace["he_ascertainment_hospital"]["type"] == "deterministic" + assert trace["he_ascertainment_ed_visits"]["type"] == "deterministic" + assert trace["hospital_predicted_daily"]["value"].shape == (n_total,) + assert trace["ed_visits_predicted"]["value"].shape == (n_total,) + + +class TestModelFit: + """Fit the joint-ascertainment H+E model and check core outputs.""" + + @pytest.fixture(scope="class") + def fitted_model( + self, + he_weekly_joint_ascertainment_model: MultiSignalModel, + weekly_hosp: pl.DataFrame, + daily_ed: pl.DataFrame, + ) -> MultiSignalModel: + """ + Fit the mixed-cadence joint-ascertainment H+E model. + + Parameters + ---------- + he_weekly_joint_ascertainment_model : MultiSignalModel + Built model with joint ascertainment. + weekly_hosp : pl.DataFrame + Weekly hospital admissions. 
+ daily_ed : pl.DataFrame + Daily ED visits. + + Returns + ------- + MultiSignalModel + Model with MCMC results attached. + """ + model = he_weekly_joint_ascertainment_model + first_day_dow = model._resolve_first_day_dow(OBS_START_DATE) + + weekly_values = jnp.array( + weekly_hosp["weekly_hosp_admits"].to_numpy(), dtype=jnp.float32 + ) + hosp_obs = _build_hospital_obs_on_period_grid( + model, weekly_values, first_day_dow + ) + ed_obs = model.pad_observations( + jnp.array(daily_ed["ed_visits"].to_numpy(), dtype=jnp.float32) + ) + + model.run( + num_warmup=NUM_WARMUP, + num_samples=NUM_SAMPLES, + rng_key=random.PRNGKey(42), + mcmc_args={"num_chains": NUM_CHAINS, "progress_bar": False}, + n_days_post_init=N_DAYS_FIT, + population_size=float(weekly_hosp["pop"][0]), + obs_start_date=OBS_START_DATE, + hospital={"obs": hosp_obs}, + ed_visits={"obs": ed_obs}, + ) + + samples = model.mcmc.get_samples() + jax.block_until_ready(samples) + return model + + @pytest.fixture(scope="class") + def posterior_dt( + self, + fitted_model: MultiSignalModel, + ): + """ + Convert MCMC samples to an ArviZ DataTree, trimming init days. + + Parameters + ---------- + fitted_model : MultiSignalModel + Model with MCMC results. + + Returns + ------- + xarray.DataTree + ArviZ DataTree with posterior group. + """ + n_init = fitted_model.latent.n_initialization_points + dt = az.from_numpyro( + fitted_model.mcmc, + dims={ + "he_ascertainment_eta": ["signal"], + "latent_infections": ["time"], + "PopulationInfections::infections_aggregate": ["time"], + "PopulationInfections::log_rt_single": ["time", "dummy"], + "PopulationInfections::rt_single": ["time", "dummy"], + "hospital_predicted_daily": ["time"], + "hospital_predicted": ["week"], + "ed_visits_predicted": ["time"], + }, + ) + + def trim_init(ds): + """ + Trim the initialization period from daily-time variables. + + Parameters + ---------- + ds + Dataset to trim. 
+ + Returns + ------- + xarray.Dataset + Dataset with ``time`` sliced to ``[n_init:]``. + """ + if "time" in ds.dims: + ds = ds.isel(time=slice(n_init, None)) + ds = ds.assign_coords(time=range(ds.sizes["time"])) + return ds + + return dt.map_over_datasets(trim_init) + + def test_mcmc_convergence( + self, + posterior_dt, + ) -> None: + """ + Check that core scalar parameters have acceptable Rhat and ESS. + + Parameters + ---------- + posterior_dt : xarray.DataTree + ArviZ DataTree with posterior group. + """ + summary = az.summary( + posterior_dt, + var_names=["I0", "log_rt_time_0", "he_ascertainment_eta"], + ) + rhat = summary["r_hat"].astype(float) + ess = summary["ess_bulk"].astype(float) + assert (rhat < 1.05).all(), f"Rhat exceeded 1.05:\n{summary[rhat >= 1.05]}" + assert (ess > 100).all(), f"ESS_bulk below 100:\n{summary[ess <= 100]}" + + def test_joint_rates_are_in_posterior( + self, + posterior_dt, + true_params: dict, + ) -> None: + """ + Check signal-specific rates are recorded and have plausible scale. + + Parameters + ---------- + posterior_dt : xarray.DataTree + ArviZ DataTree with posterior group. + true_params : dict + Ground-truth parameter dictionary. + """ + posterior = posterior_dt.posterior + assert "he_ascertainment_hospital" in posterior + assert "he_ascertainment_ed_visits" in posterior + + true_ihr = true_params["hospitalizations"]["ihr"] + true_iedr = true_params["ed_visits"]["iedr"] + ihr_median = float( + posterior["he_ascertainment_hospital"].median(dim=["chain", "draw"]).values + ) + iedr_median = float( + posterior["he_ascertainment_ed_visits"].median(dim=["chain", "draw"]).values + ) + + assert true_ihr / 5 <= ihr_median <= true_ihr * 5 + assert true_iedr / 5 <= iedr_median <= true_iedr * 5 + + def test_prediction_shapes( + self, + posterior_dt, + weekly_hosp: pl.DataFrame, + ) -> None: + """ + Check predictions live on weekly hospital and daily ED grids. 
+ + Parameters + ---------- + posterior_dt : xarray.DataTree + ArviZ DataTree with posterior group. + weekly_hosp : pl.DataFrame + Weekly hospital admissions. + """ + posterior = posterior_dt.posterior + assert posterior["latent_infections"].sizes["time"] == N_DAYS_FIT + assert posterior["hospital_predicted_daily"].sizes["time"] == N_DAYS_FIT + assert posterior["hospital_predicted"].sizes["week"] >= len(weekly_hosp) + assert posterior["ed_visits_predicted"].sizes["time"] == N_DAYS_FIT + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-m", "integration"]) diff --git a/test/test_ascertainment.py b/test/test_ascertainment.py new file mode 100644 index 00000000..cec816ee --- /dev/null +++ b/test/test_ascertainment.py @@ -0,0 +1,338 @@ +""" +Tests for ascertainment models. +""" + +import jax.numpy as jnp +import numpyro +import pytest + +from pyrenew.ascertainment import ( + AscertainmentSignal, + JointAscertainment, +) +from pyrenew.ascertainment.context import ( + ascertainment_context, + get_ascertainment_value, +) + + +class TestJointAscertainmentValidation: + """Test JointAscertainment constructor validation.""" + + @pytest.mark.parametrize("name", ["", None]) + def test_requires_non_empty_name(self, name): + """Test that ascertainment model names must be non-empty strings.""" + with pytest.raises(ValueError, match="name must be a non-empty string"): + JointAscertainment( + name=name, + signals=("hospital", "ed"), + baseline_rates=jnp.full(2, 0.5), + scale_tril=jnp.eye(2), + ) + + @pytest.mark.parametrize( + "signals", + [ + (), + ["hospital", "ed"], + ("hospital", ""), + ("hospital", None), + ], + ) + def test_requires_non_empty_tuple_of_string_signals(self, signals): + """Test that signals must be a non-empty tuple of non-empty strings.""" + with pytest.raises(ValueError, match="signals|all signals"): + JointAscertainment( + name="he_ascertainment", + signals=signals, + baseline_rates=jnp.full(2, 0.5), + scale_tril=jnp.eye(2), + ) + + def 
test_requires_unique_signals(self): + """Test that signal names must be unique.""" + with pytest.raises(ValueError, match="signals must be unique"): + JointAscertainment( + name="he_ascertainment", + signals=("hospital", "hospital"), + baseline_rates=jnp.full(2, 0.5), + scale_tril=jnp.eye(2), + ) + + def test_rejects_unknown_signal(self): + """Test that for_signal rejects unknown signals.""" + ascertainment = JointAscertainment( + name="he_ascertainment", + signals=("hospital", "ed"), + baseline_rates=jnp.full(2, 0.5), + scale_tril=jnp.eye(2), + ) + + with pytest.raises(ValueError, match="Unknown signal"): + ascertainment.for_signal("wastewater") + + def test_requires_baseline_rates_shape_to_match_signals(self): + """Test that baseline_rates must have one entry per signal.""" + with pytest.raises(ValueError, match="baseline_rates must have shape"): + JointAscertainment( + name="he_ascertainment", + signals=("hospital", "ed"), + baseline_rates=jnp.full(3, 0.5), + scale_tril=jnp.eye(2), + ) + + @pytest.mark.parametrize("baseline_rates", [[0.0, 0.5], [1.0, 0.5], [-0.1, 0.5]]) + def test_requires_baseline_rates_in_open_unit_interval(self, baseline_rates): + """Test that baseline_rates must be natural-scale probabilities.""" + with pytest.raises(ValueError, match="baseline_rates must contain"): + JointAscertainment( + name="he_ascertainment", + signals=("hospital", "ed"), + baseline_rates=jnp.array(baseline_rates), + scale_tril=jnp.eye(2), + ) + + def test_requires_exactly_one_covariance_parameter(self): + """Test that exactly one multivariate normal matrix parameter is set.""" + with pytest.raises(ValueError, match="Exactly one"): + JointAscertainment( + name="he_ascertainment", + signals=("hospital", "ed"), + baseline_rates=jnp.full(2, 0.5), + ) + + with pytest.raises(ValueError, match="Exactly one"): + JointAscertainment( + name="he_ascertainment", + signals=("hospital", "ed"), + baseline_rates=jnp.full(2, 0.5), + scale_tril=jnp.eye(2), + 
covariance_matrix=jnp.eye(2), + ) + + def test_requires_matrix_shape_to_match_signals(self): + """Test that the covariance parameter must match signal count.""" + with pytest.raises(ValueError, match="Incompatible shapes"): + JointAscertainment( + name="he_ascertainment", + signals=("hospital", "ed"), + baseline_rates=jnp.full(2, 0.5), + scale_tril=jnp.eye(3), + ) + + def test_accepts_covariance_matrix(self): + """Test that covariance_matrix is accepted as the covariance parameter.""" + ascertainment = JointAscertainment( + name="he_ascertainment", + signals=("hospital", "ed"), + baseline_rates=jnp.full(2, 0.5), + covariance_matrix=jnp.eye(2), + ) + + assert ascertainment.covariance_matrix.shape == (2, 2) + + def test_accepts_precision_matrix(self): + """Test that precision_matrix is accepted as the covariance parameter.""" + ascertainment = JointAscertainment( + name="he_ascertainment", + signals=("hospital", "ed"), + baseline_rates=jnp.full(2, 0.5), + precision_matrix=jnp.eye(2), + ) + + assert ascertainment.precision_matrix.shape == (2, 2) + + +class TestAscertainmentSignalValidation: + """Test AscertainmentSignal constructor validation.""" + + @pytest.mark.parametrize("ascertainment_name", ["", None]) + def test_requires_non_empty_ascertainment_name(self, ascertainment_name): + """Test that ascertainment_name must be a non-empty string.""" + with pytest.raises(ValueError, match="ascertainment_name"): + AscertainmentSignal( + ascertainment_name=ascertainment_name, + signal_name="hospital", + ) + + @pytest.mark.parametrize("signal_name", ["", None]) + def test_requires_non_empty_signal_name(self, signal_name): + """Test that signal_name must be a non-empty string.""" + with pytest.raises(ValueError, match="signal_name"): + AscertainmentSignal( + ascertainment_name="he_ascertainment", + signal_name=signal_name, + ) + + +class TestJointAscertainmentSampling: + """Test JointAscertainment sampling behavior.""" + + def 
test_sample_creates_one_joint_sample_site_and_signal_deterministics(self): + """Test expected NumPyro sites and returned signal values.""" + ascertainment = JointAscertainment( + name="he_ascertainment", + signals=("hospital", "ed"), + baseline_rates=jnp.full(2, 0.5), + scale_tril=jnp.eye(2), + ) + + with numpyro.handlers.seed(rng_seed=42): + with numpyro.handlers.trace() as trace: + values = ascertainment.sample() + + assert set(values) == {"hospital", "ed"} + assert trace["he_ascertainment_eta"]["type"] == "sample" + assert trace["he_ascertainment_eta"]["value"].shape == (2,) + assert trace["he_ascertainment_hospital"]["type"] == "deterministic" + assert trace["he_ascertainment_ed"]["type"] == "deterministic" + assert jnp.array_equal( + values["hospital"], + trace["he_ascertainment_hospital"]["value"], + ) + assert jnp.array_equal( + values["ed"], + trace["he_ascertainment_ed"]["value"], + ) + + def test_sample_accepts_covariance_matrix_parameterization(self): + """Test joint ascertainment sampling with a covariance matrix.""" + ascertainment = JointAscertainment( + name="he_ascertainment", + signals=("hospital", "ed"), + baseline_rates=jnp.full(2, 0.5), + covariance_matrix=jnp.eye(2), + ) + + with numpyro.handlers.seed(rng_seed=42): + values = ascertainment.sample() + + assert set(values) == {"hospital", "ed"} + + def test_sample_accepts_precision_matrix_parameterization(self): + """Test joint ascertainment sampling with a precision matrix.""" + ascertainment = JointAscertainment( + name="he_ascertainment", + signals=("hospital", "ed"), + baseline_rates=jnp.full(2, 0.5), + precision_matrix=jnp.eye(2), + ) + + with numpyro.handlers.seed(rng_seed=42): + values = ascertainment.sample() + + assert set(values) == {"hospital", "ed"} + + def test_signal_accessor_reads_context_without_creating_sites(self): + """Test that signal accessors read context values and create no sites.""" + ascertainment = JointAscertainment( + name="he_ascertainment", + signals=("hospital", 
"ed"), + baseline_rates=jnp.full(2, 0.5), + scale_tril=jnp.eye(2), + ) + hospital = ascertainment.for_signal("hospital") + + with numpyro.handlers.trace() as trace: + with ascertainment_context( + {"he_ascertainment": {"hospital": jnp.array(0.25)}} + ): + value = hospital() + + assert value == jnp.array(0.25) + assert trace == {} + + def test_reused_signal_accessor_creates_no_duplicate_sites(self): + """Test repeated accessor calls still create no NumPyro sites.""" + ascertainment = JointAscertainment( + name="he_ascertainment", + signals=("hospital", "ed"), + baseline_rates=jnp.full(2, 0.5), + scale_tril=jnp.eye(2), + ) + hospital = ascertainment.for_signal("hospital") + + with numpyro.handlers.trace() as trace: + with ascertainment_context( + {"he_ascertainment": {"hospital": jnp.array(0.25)}} + ): + first = hospital() + second = hospital() + + assert first == jnp.array(0.25) + assert second == jnp.array(0.25) + assert trace == {} + + def test_signal_accessor_requires_active_context(self): + """Test that signal accessors fail clearly outside model context.""" + ascertainment = JointAscertainment( + name="he_ascertainment", + signals=("hospital", "ed"), + baseline_rates=jnp.full(2, 0.5), + scale_tril=jnp.eye(2), + ) + + with pytest.raises(RuntimeError, match="before ascertainment values"): + ascertainment.for_signal("hospital")() + + +class TestAscertainmentContextSafety: + """Test ascertainment context lifecycle and validation.""" + + @pytest.mark.parametrize( + "values, error_type", + [ + (None, TypeError), + ({"he_ascertainment": None}, TypeError), + ({"": {"hospital": jnp.array(0.1)}}, ValueError), + ({"he_ascertainment": {"": jnp.array(0.1)}}, ValueError), + ], + ) + def test_context_rejects_invalid_values(self, values, error_type): + """Test that malformed context payloads fail at context entry.""" + with pytest.raises(error_type): + with ascertainment_context(values): + pass + + @pytest.mark.parametrize( + "ascertainment_name, signal_name", + [ + ("", 
"hospital"), + ("he_ascertainment", ""), + ], + ) + def test_get_ascertainment_value_validates_lookup_names( + self, + ascertainment_name, + signal_name, + ): + """Test that context lookup names must be non-empty strings.""" + with pytest.raises(ValueError, match="must be a non-empty string"): + get_ascertainment_value(ascertainment_name, signal_name) + + def test_context_restores_outer_context_after_nested_context(self): + """Test nested contexts restore previous values on exit.""" + with ascertainment_context({"he_ascertainment": {"hospital": jnp.array(0.1)}}): + assert get_ascertainment_value("he_ascertainment", "hospital") == 0.1 + with ascertainment_context( + {"he_ascertainment": {"hospital": jnp.array(0.2)}} + ): + assert get_ascertainment_value("he_ascertainment", "hospital") == 0.2 + assert get_ascertainment_value("he_ascertainment", "hospital") == 0.1 + + def test_context_clears_after_exception(self): + """Test context is cleared even when an exception is raised.""" + with pytest.raises(RuntimeError, match="boom"): + with ascertainment_context( + {"he_ascertainment": {"hospital": jnp.array(0.1)}} + ): + raise RuntimeError("boom") + + with pytest.raises(RuntimeError, match="before ascertainment values"): + get_ascertainment_value("he_ascertainment", "hospital") + + def test_missing_context_value_raises_clear_error(self): + """Test unavailable context keys raise a clear RuntimeError.""" + with ascertainment_context({"he_ascertainment": {"hospital": jnp.array(0.1)}}): + with pytest.raises(RuntimeError, match="not available"): + get_ascertainment_value("he_ascertainment", "ed") diff --git a/test/test_pyrenew_builder.py b/test/test_pyrenew_builder.py index a67e69c3..00e4b52e 100644 --- a/test/test_pyrenew_builder.py +++ b/test/test_pyrenew_builder.py @@ -3,11 +3,13 @@ """ from datetime import date, timedelta +from types import SimpleNamespace import jax.numpy as jnp import numpyro import pytest +from pyrenew.ascertainment import JointAscertainment from 
pyrenew.deterministic import DeterministicPMF, DeterministicVariable from pyrenew.latent import ( AR1, @@ -209,6 +211,42 @@ def test_rejects_duplicate_observation_name(self, simple_builder): with pytest.raises(ValueError, match="already added"): simple_builder.add_observation(obs) + def test_add_ascertainment_registers_model(self): + """Test that add_ascertainment stores ascertainment models by name.""" + builder = PyrenewBuilder() + ascertainment = JointAscertainment( + name="he_ascertainment", + signals=("hospital", "ed"), + baseline_rates=jnp.full(2, 0.5), + scale_tril=jnp.eye(2), + ) + + result = builder.add_ascertainment(ascertainment) + + assert result is builder + assert builder.ascertainment_models["he_ascertainment"] is ascertainment + + def test_add_ascertainment_rejects_duplicate_name(self): + """Test that duplicate ascertainment model names are rejected.""" + builder = PyrenewBuilder() + ascertainment = JointAscertainment( + name="he_ascertainment", + signals=("hospital", "ed"), + baseline_rates=jnp.full(2, 0.5), + scale_tril=jnp.eye(2), + ) + + builder.add_ascertainment(ascertainment) + with pytest.raises(ValueError, match="already added"): + builder.add_ascertainment(ascertainment) + + def test_add_ascertainment_rejects_wrong_type(self): + """Test that add_ascertainment requires an AscertainmentModel.""" + builder = PyrenewBuilder() + + with pytest.raises(TypeError, match="AscertainmentModel"): + builder.add_ascertainment(object()) + def test_build_creates_model(self, simple_builder): """Test that build() creates a MultiSignalModel.""" model = simple_builder.build() @@ -305,6 +343,163 @@ def test_prior_predictive_multi_signal(self, simple_builder): # All prior predictive infections should be positive assert jnp.all(prior_samples["latent_infections"] > 0) + def test_prior_predictive_with_joint_ascertainment(self): + """Test model-scoped sampling for shared joint ascertainment.""" + import jax.random + from numpyro.infer import Predictive + + builder = 
PyrenewBuilder() + gen_int = DeterministicPMF("gen_int", jnp.array([0.2, 0.5, 0.3])) + builder.configure_latent( + SubpopulationInfections, + gen_int_rv=gen_int, + I0_rv=DeterministicVariable("I0", 0.001), + log_rt_time_0_rv=DeterministicVariable("initial_log_rt", 0.0), + baseline_rt_process=RandomWalk(), + subpop_rt_deviation_process=RandomWalk(), + ) + + ascertainment = JointAscertainment( + name="he_ascertainment", + signals=("hospital", "ed"), + baseline_rates=jnp.full(2, 0.5), + scale_tril=jnp.eye(2), + ) + builder.add_ascertainment(ascertainment) + + delay = DeterministicPMF("delay", jnp.array([0.1, 0.3, 0.4, 0.2])) + builder.add_observation( + PopulationCounts( + name="hospital", + ascertainment_rate_rv=ascertainment.for_signal("hospital"), + delay_distribution_rv=delay, + noise=NegativeBinomialNoise(DeterministicVariable("hosp_conc", 10.0)), + ) + ) + builder.add_observation( + PopulationCounts( + name="ed", + ascertainment_rate_rv=ascertainment.for_signal("ed"), + delay_distribution_rv=delay, + noise=NegativeBinomialNoise(DeterministicVariable("ed_conc", 10.0)), + ) + ) + + model = builder.build() + assert model.ascertainment_models["he_ascertainment"] is ascertainment + + predictive = Predictive(model.sample, num_samples=3) + prior_samples = predictive( + jax.random.PRNGKey(42), + n_days_post_init=10, + population_size=1_000_000, + subpop_fractions=SUBPOP_FRACTIONS, + hospital={"obs": None}, + ed={"obs": None}, + ) + + assert prior_samples["he_ascertainment_eta"].shape == (3, 2) + assert prior_samples["he_ascertainment_hospital"].shape == (3,) + assert prior_samples["he_ascertainment_ed"].shape == (3,) + assert "hospital_predicted" in prior_samples + assert "ed_predicted" in prior_samples + + with pytest.raises(RuntimeError, match="before ascertainment values"): + ascertainment.for_signal("hospital")() + + def test_prior_predictive_reuses_same_ascertainment_signal(self): + """Test two observations can reuse one signal accessor without site conflicts.""" + 
import jax.random + from numpyro.infer import Predictive + + builder = PyrenewBuilder() + gen_int = DeterministicPMF("gen_int", jnp.array([0.2, 0.5, 0.3])) + builder.configure_latent( + SubpopulationInfections, + gen_int_rv=gen_int, + I0_rv=DeterministicVariable("I0", 0.001), + log_rt_time_0_rv=DeterministicVariable("initial_log_rt", 0.0), + baseline_rt_process=RandomWalk(), + subpop_rt_deviation_process=RandomWalk(), + ) + + ascertainment = JointAscertainment( + name="shared_ascertainment", + signals=("hospital", "ed"), + baseline_rates=jnp.full(2, 0.5), + scale_tril=jnp.eye(2), + ) + builder.add_ascertainment(ascertainment) + hospital_rate = ascertainment.for_signal("hospital") + + delay = DeterministicPMF("delay", jnp.array([0.1, 0.3, 0.4, 0.2])) + builder.add_observation( + PopulationCounts( + name="hospital_a", + ascertainment_rate_rv=hospital_rate, + delay_distribution_rv=delay, + noise=NegativeBinomialNoise(DeterministicVariable("conc_a", 10.0)), + ) + ) + builder.add_observation( + PopulationCounts( + name="hospital_b", + ascertainment_rate_rv=hospital_rate, + delay_distribution_rv=delay, + noise=NegativeBinomialNoise(DeterministicVariable("conc_b", 12.0)), + ) + ) + + model = builder.build() + predictive = Predictive(model.sample, num_samples=2) + prior_samples = predictive( + jax.random.PRNGKey(42), + n_days_post_init=10, + population_size=1_000_000, + subpop_fractions=SUBPOP_FRACTIONS, + hospital_a={"obs": None}, + hospital_b={"obs": None}, + ) + + assert prior_samples["shared_ascertainment_eta"].shape == (2, 2) + assert prior_samples["shared_ascertainment_hospital"].shape == (2,) + assert prior_samples["hospital_a_predicted"].shape == ( + 2, + model.latent.n_initialization_points + 10, + ) + assert prior_samples["hospital_b_predicted"].shape == ( + 2, + model.latent.n_initialization_points + 10, + ) + + def test_manual_model_rejects_invalid_ascertainment_model(self, simple_builder): + """Test direct MultiSignalModel construction validates ascertainment 
models.""" + model = simple_builder.build() + + with pytest.raises(TypeError, match="AscertainmentModel"): + MultiSignalModel( + latent_process=model.latent, + observations=model.observations, + ascertainment_models={"bad": object()}, + ) + + def test_manual_model_rejects_mismatched_ascertainment_key(self, simple_builder): + """Test ascertainment model dictionary keys must match model names.""" + model = simple_builder.build() + ascertainment = JointAscertainment( + name="he_ascertainment", + signals=("hospital", "ed"), + baseline_rates=jnp.full(2, 0.5), + scale_tril=jnp.eye(2), + ) + + with pytest.raises(ValueError, match="dictionary key"): + MultiSignalModel( + latent_process=model.latent, + observations=model.observations, + ascertainment_models={"wrong_name": ascertainment}, + ) + def test_first_day_dow_reaches_calendar_aligned_latent_process(self): """MultiSignalModel forwards model-axis day of week to the latent process.""" latent = PopulationInfections( @@ -501,6 +696,44 @@ def test_validate_method_calls_internal_validate(self, simple_builder): # Should not raise model.validate() + def test_sample_rejects_observation_resolution_that_changes_after_validation( + self, + ): + """Test sample-time infection_resolution validation.""" + + class Latent: # numpydoc ignore=GL08 + n_initialization_points = 0 + + def requires_calendar_anchor(self): # numpydoc ignore=GL08 + return False + + def sample(self, **kwargs): # numpydoc ignore=GL08 + return SimpleNamespace( + aggregate=jnp.ones(2), + all_subpops=jnp.ones((2, 1)), + ) + + class Observation: # numpydoc ignore=GL08 + def __init__(self): + self.calls = 0 + + def infection_resolution(self): # numpydoc ignore=GL08 + self.calls += 1 + if self.calls == 1: + return "aggregate" + return "invalid" + + def sample(self, **kwargs): # numpydoc ignore=GL08 + raise AssertionError("sample should not be called") + + model = MultiSignalModel( + latent_process=Latent(), + observations={"unstable": Observation()}, + ) + + with 
pytest.raises(ValueError, match="invalid infection_resolution"): + model.sample(n_days_post_init=2, population_size=1.0) + def test_validate_data_rejects_negative_subpop_indices(self, validation_builder): """Test that negative subpop_indices raises error.""" model = validation_builder.build()