-
Notifications
You must be signed in to change notification settings - Fork 9
Add Ascertainment model for JointAscertainment #802
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
680bb1e
2cb876b
60db8df
32a5314
d6213f2
96f27c9
1cb6fa2
f62e1e4
0c6785d
1ee62b9
0629461
efeadee
371ba98
0304bed
ffeea65
50e7261
dae6af8
5cb3097
1d80ccc
e73b401
b1473b5
0b929b5
3ee00a7
307982a
b862bc6
2c665a5
60d6458
ec8c464
c018bf7
d0207dd
f3c706a
684c6c5
ca2454f
0f38afc
d8e7a57
7e9b5fe
e1d8014
83ddbf0
69ea4ea
de851b9
2bde8fd
393e279
ec95c5a
6fcaa11
bd0fd60
594da80
c639fe9
bf2efc6
4b3646a
c3d9449
2485638
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,3 +7,4 @@ nav: | |
| - observation_processes_measurements.md | ||
| - right_truncation.md | ||
| - day_of_week_effects.md | ||
| - ascertainment.md | ||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,264 @@ | ||||||
| --- | ||||||
| title: Joint ascertainment | ||||||
| format: | ||||||
| gfm: | ||||||
| code-fold: true | ||||||
| engine: jupyter | ||||||
| jupyter: | ||||||
| jupytext: | ||||||
| text_representation: | ||||||
| extension: .qmd | ||||||
| format_name: quarto | ||||||
| format_version: '1.0' | ||||||
| jupytext_version: 1.19.1 | ||||||
| kernelspec: | ||||||
| display_name: Python 3 (ipykernel) | ||||||
| language: python | ||||||
| name: python3 | ||||||
| --- | ||||||
|
|
||||||
| ```{python} | ||||||
| # | label: setup | ||||||
| # | output: false | ||||||
| import jax.nn as jnn | ||||||
| import jax.numpy as jnp | ||||||
| import jax.random as random | ||||||
| import numpy as np | ||||||
| import numpyro | ||||||
| import numpyro.distributions as dist | ||||||
| import pandas as pd | ||||||
| import plotnine as p9 | ||||||
| import warnings | ||||||
| from plotnine.exceptions import PlotnineWarning | ||||||
| from _tutorial_theme import theme_tutorial | ||||||
|
|
||||||
| from pyrenew.ascertainment import JointAscertainment | ||||||
| from pyrenew.deterministic import DeterministicPMF | ||||||
| from pyrenew.model import PyrenewBuilder | ||||||
| from pyrenew.observation import NegativeBinomialNoise, PopulationCounts | ||||||
| from pyrenew.randomvariable import DistributionalVariable | ||||||
| from pyrenew.time import MMWR_WEEK | ||||||
|
|
||||||
| warnings.filterwarnings("ignore", category=PlotnineWarning) | ||||||
| ``` | ||||||
|
|
||||||
| Ascertainment is the probability that a latent infection appears in an observed signal, e.g., a hospitalization or visit to the emergency department. | ||||||
| For hospital admissions this probability is often called an infection-hospitalization ratio (IHR). | ||||||
| For emergency department visits it might be called an infection-ED-visit ratio (IEDR). | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
|
||||||
| PyRenew count observations accept an `ascertainment_rate_rv`. | ||||||
| For simple models this can be any ordinary `RandomVariable`. | ||||||
|
|
||||||
| For multi-signal models, PyRenew also provides model-level ascertainment components that let multiple observation processes share related ascertainment structure. | ||||||
| These components are intended for count signals whose observation probabilities are logically related because they are different event streams generated from the same latent | ||||||
| infections. Hospital admissions and ED visits are a natural example: they have different observation probabilities, but both depend on clinical care-seeking and reporting. | ||||||
|
|
||||||
| ## Independent scalar ascertainment | ||||||
|
|
||||||
| Before considering multi-signal models with related ascertainment structure, | ||||||
| we show how to specify a multi-signal model where the signals are modeled independently. | ||||||
| In this model, each observation process has its own scalar ascertainment prior. | ||||||
|
|
||||||
| ```{python} | ||||||
| # | label: independent-observation-processes | ||||||
| hosp_delay_pmf = jnp.array([0.05, 0.10, 0.15, 0.15, 0.20, 0.15, 0.15, 0.05]) | ||||||
| hosp_delay_rv = DeterministicPMF("inf_to_hosp_delay", hosp_delay_pmf) | ||||||
|
|
||||||
| ed_delay_pmf = jnp.array([0.05, 0.15, 0.30, 0.30, 0.15, 0.05]) | ||||||
| ed_delay_rv = DeterministicPMF("inf_to_ed_delay", ed_delay_pmf) | ||||||
|
|
||||||
| hospital_obs_independent = PopulationCounts( | ||||||
| name="hospital", | ||||||
| ascertainment_rate_rv=DistributionalVariable("ihr", dist.Beta(2, 198)), | ||||||
| delay_distribution_rv=hosp_delay_rv, | ||||||
| noise=NegativeBinomialNoise( | ||||||
| DistributionalVariable( | ||||||
| "hospital_concentration", dist.LogNormal(4.0, 0.5) | ||||||
| ) | ||||||
| ), | ||||||
| ) | ||||||
|
|
||||||
| ed_obs_independent = PopulationCounts( | ||||||
| name="ed_visits", | ||||||
| ascertainment_rate_rv=DistributionalVariable("iedr", dist.Beta(2, 98)), | ||||||
| delay_distribution_rv=ed_delay_rv, | ||||||
| noise=NegativeBinomialNoise( | ||||||
| DistributionalVariable("ed_concentration", dist.LogNormal(4.0, 0.5)) | ||||||
| ), | ||||||
| ) | ||||||
| ``` | ||||||
|
|
||||||
| ```{python} | ||||||
| # | label: independent-prior-draws | ||||||
| n_draws = 1500 | ||||||
| key_ihr, key_iedr = random.split(random.PRNGKey(1)) | ||||||
|
|
||||||
| independent_ihr = dist.Beta(2, 198).sample(key_ihr, (n_draws,)) | ||||||
| independent_iedr = dist.Beta(2, 98).sample(key_iedr, (n_draws,)) | ||||||
|
|
||||||
| independent_df = pd.DataFrame( | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add a comment indicating this will be used for visualization later? Otherwise a bit unclear what it's doing here. |
||||||
| { | ||||||
| "hospital_ihr": np.array(independent_ihr), | ||||||
| "ed_iedr": np.array(independent_iedr), | ||||||
| } | ||||||
| ) | ||||||
| ``` | ||||||
|
|
||||||
| ## Model-level ascertainment | ||||||
|
|
||||||
| To share ascertainment structure across count signals, define an ascertainment model once and register it with the builder. | ||||||
| Each observation process receives the appropriate signal-specific accessor from `for_signal()`. | ||||||
|
|
||||||
| ```python | ||||||
| builder = PyrenewBuilder() | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add an example creation of the object |
||||||
| builder.add_ascertainment(ascertainment) | ||||||
|
|
||||||
| PopulationCounts( | ||||||
| name="hospital", | ||||||
| ascertainment_rate_rv=ascertainment.for_signal("hospital"), | ||||||
| ... | ||||||
| ) | ||||||
| ``` | ||||||
|
|
||||||
| The model samples the ascertainment component once per model execution. | ||||||
| The accessors passed to the observation processes read the sampled values. | ||||||
|
|
||||||
| ## Joint scalar ascertainment | ||||||
|
|
||||||
| Use `JointAscertainment` when each signal has a scalar ascertainment rate, but the rates should be correlated across signals. | ||||||
| The model samples the rates jointly on the logit scale and returns one probability for each signal. | ||||||
| These rates are constant over the model time axis. | ||||||
|
|
||||||
| ```{python} | ||||||
| # | label: joint-ascertainment-object | ||||||
| joint_ascertainment = JointAscertainment( | ||||||
| name="he_ascertainment", | ||||||
| signals=("hospital", "ed_visits"), | ||||||
| baseline_rates=jnp.array([0.01, 0.02]), | ||||||
| scale_tril=jnp.array( | ||||||
| [ | ||||||
| [0.7, 0.0], | ||||||
| [0.35, 0.606], | ||||||
| ] | ||||||
| ), | ||||||
| ) | ||||||
| ``` | ||||||
|
|
||||||
| The order of the specified `signals` determines how arguments `baseline_rates` and `scale_tril` are interpreted. | ||||||
| In the above example, the first entry is hospitalizations and the second is ED visits. | ||||||
| In a NumPyro trace, this object creates one sample site, `he_ascertainment_eta`, and deterministic signal-specific rates such as `he_ascertainment_hospital` and `he_ascertainment_ed_visits`. | ||||||
|
|
||||||
| The next block samples directly from the same logit-normal prior used by JointAscertainment. This is only for visualization; in a fitted PyRenew model, JointAscertainment handles this sampling internally. | ||||||
|
|
||||||
| ```{python} | ||||||
| # | label: joint-prior-draws | ||||||
| joint_eta = dist.MultivariateNormal( | ||||||
| loc=joint_ascertainment.baseline_logits, | ||||||
| scale_tril=joint_ascertainment.scale_tril, | ||||||
| ).sample(random.PRNGKey(2), (n_draws,)) | ||||||
| joint_rates = jnn.sigmoid(joint_eta) | ||||||
|
|
||||||
| joint_df = pd.DataFrame( | ||||||
| { | ||||||
| "hospital_ihr": np.array(joint_rates[:, 0]), | ||||||
| "ed_iedr": np.array(joint_rates[:, 1]), | ||||||
| } | ||||||
| ) | ||||||
|
|
||||||
| compare_df = pd.concat( | ||||||
| [ | ||||||
| independent_df.assign(prior="Independent scalar priors"), | ||||||
| joint_df.assign(prior="Joint scalar prior"), | ||||||
| ], | ||||||
| ignore_index=True, | ||||||
| ) | ||||||
| compare_df["prior"] = pd.Categorical( | ||||||
| compare_df["prior"], | ||||||
| categories=["Independent scalar priors", "Joint scalar prior"], | ||||||
| ordered=True, | ||||||
| ) | ||||||
| ``` | ||||||
|
|
||||||
| The following plot compares the samples drawn from independent scalar priors and draws from the correlated logit-normal prior used by `JointAscertainment`. | ||||||
| The dashed line marks the prior-center ratio, IEDR = 2 × IHR. | ||||||
|
|
||||||
| The prior draws from independent priors show that high IHR draws do not imply high IEDR draws. | ||||||
| The joint prior keeps the same approximate scale while inducing positive dependence between the two rates. | ||||||
|
|
||||||
| ```{python} | ||||||
| # | label: plot-joint-priors | ||||||
| # | fig-cap: | | ||||||
| # | Joint ascertainment induces correlation between scalar signal rates. | ||||||
| # | The dashed line marks the prior-center ratio, IEDR = 2 × IHR. | ||||||
| # | The joint prior keeps the same approximate scale while inducing | ||||||
| # | positive dependence between the two rates. | ||||||
|
|
||||||
| ( | ||||||
| p9.ggplot(compare_df, p9.aes(x="hospital_ihr", y="ed_iedr")) | ||||||
| + p9.geom_point(alpha=0.2, size=1.0, color="steelblue") | ||||||
| + p9.geom_abline(intercept=0, slope=2, linetype="dashed", color="gray") | ||||||
| + p9.facet_wrap("~prior", nrow=1) | ||||||
| + p9.labs( | ||||||
| x="Hospital ascertainment rate (IHR)", | ||||||
| y="ED ascertainment rate (IEDR)", | ||||||
| title="Independent vs. joint scalar ascertainment", | ||||||
| ) | ||||||
| + theme_tutorial | ||||||
| ) | ||||||
| ``` | ||||||
|
|
||||||
|
|
||||||
| ## Using ascertainment with a builder | ||||||
|
|
||||||
| The main API pattern is: | ||||||
|
|
||||||
| ```python | ||||||
| builder = PyrenewBuilder() | ||||||
| builder.configure_latent(...) | ||||||
|
|
||||||
| joint_ascertainment = JointAscertainment( | ||||||
| name="he_ascertainment", | ||||||
| signals=("hospital", "ed_visits"), | ||||||
| baseline_rates=jnp.array([0.01, 0.02]), | ||||||
| scale_tril=jnp.array( | ||||||
| [ | ||||||
| [0.7, 0.0], | ||||||
| [0.35, 0.606], | ||||||
| ] | ||||||
| ), | ||||||
| ) | ||||||
| builder.add_ascertainment(joint_ascertainment) | ||||||
|
|
||||||
| hospital_obs = PopulationCounts( | ||||||
| name="hospital", | ||||||
| ascertainment_rate_rv=joint_ascertainment.for_signal("hospital"), | ||||||
| delay_distribution_rv=hosp_delay_rv, | ||||||
| noise=hosp_noise_rv, | ||||||
| aggregation="weekly", | ||||||
| reporting_schedule="regular", | ||||||
| start_dow=MMWR_WEEK, | ||||||
| ) | ||||||
| builder.add_observation(hospital_obs) | ||||||
|
|
||||||
| ed_obs = PopulationCounts( | ||||||
| name="ed_visits", | ||||||
| ascertainment_rate_rv=joint_ascertainment.for_signal("ed_visits"), | ||||||
| delay_distribution_rv=ed_delay_rv, | ||||||
| noise=ed_noise_rv, | ||||||
| day_of_week_rv=DeterministicVariable("ed_dow", ed_day_of_week_effects), | ||||||
| ) | ||||||
| builder.add_observation(ed_obs) | ||||||
|
|
||||||
| model = builder.build() | ||||||
| ``` | ||||||
|
|
||||||
| Signal names in `for_signal()` must match the names used when the ascertainment model was created. | ||||||
| They do not have to match observation names, but matching them usually makes model code and posterior outputs easier to read. | ||||||
|
|
||||||
| When running a model that uses weekly observations or day-of-week effects, pass `obs_start_date` so PyRenew can align the model axis to the calendar. | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure we need this line.
Suggested change
|
||||||
|
|
||||||
| ## Summary | ||||||
|
|
||||||
| - Use ordinary `RandomVariable` objects for independent scalar ascertainment rates. | ||||||
| - Use `JointAscertainment` for correlated scalar rates across signals. | ||||||
| - Ascertainment and latent infection scale are weakly identified without informative priors or external information. | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is true but we didn't really demonstrate it above. Remove or justify/discuss more above. |
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,13 @@ | ||
| # numpydoc ignore=GL08 | ||
| """ | ||
| Ascertainment models for shared observation-rate structure. | ||
| """ | ||
|
|
||
| from pyrenew.ascertainment.base import AscertainmentModel, AscertainmentSignal | ||
| from pyrenew.ascertainment.joint import JointAscertainment | ||
|
|
||
| __all__ = [ | ||
| "AscertainmentModel", | ||
| "AscertainmentSignal", | ||
| "JointAscertainment", | ||
| ] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.