Add Ascertainment model for JointAscertainment #802
cdc-mitzimorris wants to merge 51 commits into main
damonbayer left a comment
I think we should discuss the interface for the time-varying ascertainment rate.
…new into mem_777_joint_ascertainment
Agreed. I made the other changes you suggested, but after further consideration, I propose the following additional
This preserves the useful scale-sharing idea in the production HEW model while
Instead, we should:
sample correlated baseline ascertainment logits:
For each signal:
For HEW this gives:
So hospital and ED baselines are related, but ED temporal deviations do not

Interface Sketch

JointTimeVaryingAscertainment(
    name="he_ascertainment",
    signals=("hospital", "ed_visits"),
    baseline_loc=jnp.array([logit(0.01), logit(0.02)]),
    baseline_scale_tril=jnp.array(
        [
            [0.7, 0.0],
            [0.35, 0.606],
        ]
    ),
    processes={
        "ed_visits": WeeklyTemporalProcess(
            AR1(autoreg=0.8, innovation_sd=0.25),
            start_dow=MMWR_WEEK,
        ),
    },
)

Return value:

{
    "hospital": scalar_ihr,
    "ed_visits": iedr_trajectory,
}

Signals listed in

Relationship To Existing Components

Validation

The component should validate:

Note

This component keeps PyRenew's current count-observation semantics:
Can you be more precise with these descriptions? What about My instinct is that even the
I don't see how "hospital and ED baselines are related". Is it because
@damonbayer asks good questions and I think that this PR should drop the TimeVaryingAscertainment component and just do scalar/joint ascertainment.

The current TimeVaryingAscertainment API makes a free temporal process the main abstraction:

But a free temporal ascertainment process is identifiability-expensive in a renewal model: observed counts identify something close to latent_incidence_t * ascertainment_t, so a flexible z_j(t) can compete directly with latent infections / Rt unless it is strongly motivated and constrained. We also already model day-of-week effects in the observation process, so time-varying ascertainment should not be used to absorb reporting periodicity.

If the use case is for ED visits, it needs justification - do we have any? Plausible changes in IEDR are not slow latent trends; they are known calendar/event/regime effects such as holidays, disasters, data-feed disruptions, coding changes, or policy changes. Those should probably be explicit covariates or regime effects, not inferred by a generic AR process that can compete with latent infections/Rt.
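The identifiability point can be made concrete with a toy calculation. This is a minimal sketch rather than PyRenew code: the infection trajectory, the 1% rate, and the drift multipliers are made-up numbers, and delays and observation noise are ignored. It shows two different (infections, ascertainment) explanations that produce the same expected counts.

```python
# Toy illustration (hypothetical numbers): expected counts depend only on
# the product latent_incidence_t * ascertainment_t.
infections_a = [1000.0, 1200.0, 1500.0, 1800.0]  # faster epidemic growth
ascertainment_a = [0.010, 0.010, 0.010, 0.010]   # constant ascertainment

# Alternative explanation: slower growth, but rising ascertainment.
drift = [1.0, 1.1, 1.25, 1.5]
infections_b = [i / d for i, d in zip(infections_a, drift)]
ascertainment_b = [a * d for a, d in zip(ascertainment_a, drift)]

expected_a = [i * a for i, a in zip(infections_a, ascertainment_a)]
expected_b = [i * a for i, a in zip(infections_b, ascertainment_b)]

# Largest difference between the two sets of expected counts.
max_gap = max(abs(x - y) for x, y in zip(expected_a, expected_b))
```

Because `max_gap` is zero up to floating-point error, count data alone cannot separate the two explanations; only priors or external structure on the ascertainment trajectory can.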
@damonbayer - removed TimeVaryingAscertainment (for now).
Ascertainment is the probability that a latent infection appears in an observed signal, e.g., a hospitalization or visit to the emergency department.
For hospital admissions this probability is often called an infection-hospitalization ratio (IHR).

Suggested change:
- For hospital admissions this probability is often called an infection-hospitalization ratio (IHR).
+ For hospital admissions this probability is often called an infection-to-hospitalization rate (IHR).

Ascertainment is the probability that a latent infection appears in an observed signal, e.g., a hospitalization or visit to the emergency department.
For hospital admissions this probability is often called an infection-hospitalization ratio (IHR).
For emergency department visits it might be called an infection-ED-visit ratio (IEDR).

Suggested change:
- For emergency department visits it might be called an infection-ED-visit ratio (IEDR).
+ For emergency department visits it might be called an infection-to-emergency-department rate (IEDR).

independent_ihr = dist.Beta(2, 198).sample(key_ihr, (n_draws,))
independent_iedr = dist.Beta(2, 98).sample(key_iedr, (n_draws,))

independent_df = pd.DataFrame(

Add a comment indicating this will be used for visualization later?
Otherwise a bit unclear what it's doing here.

Each observation process receives the appropriate signal-specific accessor from `for_signal()`.

```python
builder = PyrenewBuilder()
```

Add an example that creates the ascertainment object?

Signal names in `for_signal()` must match the names used when the ascertainment model was created.
They do not have to match observation names, but matching them usually makes model code and posterior outputs easier to read.

When running a model that uses weekly observations or day-of-week effects, pass `obs_start_date` so PyRenew can align the model axis to the calendar.

Not sure we need this line.

Suggested change:
- When running a model that uses weekly observations or day-of-week effects, pass `obs_start_date` so PyRenew can align the model axis to the calendar.

- Use ordinary `RandomVariable` objects for independent scalar ascertainment rates.
- Use `JointAscertainment` for correlated scalar rates across signals.
- Ascertainment and latent infection scale are weakly identified without informative priors or external information.

This is true but we didn't really demonstrate it above. Remove or justify/discuss more above.

if not isinstance(ascertainment_name, str) or len(ascertainment_name) == 0:
    raise ValueError(
        "ascertainment_name must be a non-empty string. "
        f"Got {type(ascertainment_name).__name__}: {ascertainment_name!r}"
    )
if not isinstance(signal_name, str) or len(signal_name) == 0:
    raise ValueError(
        "signal_name must be a non-empty string. "
        f"Got {type(signal_name).__name__}: {signal_name!r}"

Separate PR: add a check helper for non-empty string names

Note that there is one already in context.py; could be shared throughout codebase.
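
A shared helper along these lines would collapse the two checks in the diff into one call each. This is only a sketch: the name `validate_non_empty_str` and its signature are hypothetical, and the existing helper in context.py may look different.

```python
def validate_non_empty_str(value: object, param_name: str) -> str:
    """Return ``value`` unchanged if it is a non-empty string, else raise.

    Hypothetical helper; the error message mirrors the one in the diff.
    """
    if not isinstance(value, str) or len(value) == 0:
        raise ValueError(
            f"{param_name} must be a non-empty string. "
            f"Got {type(value).__name__}: {value!r}"
        )
    return value


# Usage: each repeated block in the diff becomes a single call.
ascertainment_name = validate_non_empty_str("he_ascertainment", "ascertainment_name")
signal_name = validate_non_empty_str("hospital", "signal_name")
```

Returning the value lets callers validate and assign in one statement.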

infection-to-observation ratios, while still being correlated because both
depend on care-seeking behavior, testing practices, or reporting systems.

The model samples one multivariate normal vector around natural-scale

Suggested change:
- The model samples one multivariate normal vector around natural-scale
+ The model samples one multivariate logit-normal vector given natural-scale

depend on care-seeking behavior, testing practices, or reporting systems.

The model samples one multivariate normal vector around natural-scale
baseline ascertainment rates, which are converted to the logit scale

Suggested change:
- baseline ascertainment rates, which are converted to the logit scale
+ baseline ascertainment rates:

self.baseline_logits = jnp.log(self.baseline_rates) - jnp.log1p(
    -self.baseline_rates
)

Suggested change:
- self.baseline_logits = jnp.log(self.baseline_rates) - jnp.log1p(
-     -self.baseline_rates
- )
+ self.baseline_logits = logit(self.baseline_rates)
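
The suggestion is behavior-preserving because the hand-written expression is the logit function. The sketch below checks that numerically with the standard library only, so it runs without JAX; `logit_manual` and `logit_direct` are illustrative names standing in for the diff's expression and for a library `logit`, not functions from the PR.

```python
import math


def logit_manual(p: float) -> float:
    # Form used in the original diff: log(p) - log1p(-p).
    return math.log(p) - math.log1p(-p)


def logit_direct(p: float) -> float:
    # Textbook definition of the logit: log(p / (1 - p)).
    return math.log(p / (1.0 - p))


rates = [0.01, 0.02, 0.5]
# Largest absolute disagreement between the two forms over the test rates.
max_diff = max(abs(logit_manual(p) - logit_direct(p)) for p in rates)
```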

from collections.abc import Mapping

import jax.nn as jnn
import jax.numpy as jnp

Suggested change:
- import jax.numpy as jnp
+ import jax.numpy as jnp
+ from jax.scipy.special import logit, expit

    f"{self.name}_eta",
    self._distribution(),
)
rates = jnn.sigmoid(eta)

Suggested change:
- rates = jnn.sigmoid(eta)
+ rates = expit(eta)
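
For context, the sampling step under review draws a latent vector on the logit scale and maps it back to rates. The sketch below is a pure-Python stand-in for that idea, not the PR's implementation: it borrows the illustrative baseline rates and Cholesky factor from the interface sketch earlier in this conversation, and it uses the `random` module where the real code uses a NumPyro sample site.

```python
import math
import random


def logit(p: float) -> float:
    return math.log(p) - math.log1p(-p)


def expit(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


baseline_rates = [0.01, 0.02]             # illustrative IHR and IEDR baselines
scale_tril = [[0.7, 0.0], [0.35, 0.606]]  # lower-triangular Cholesky factor

rng = random.Random(0)


def sample_rates() -> list[float]:
    # z ~ N(0, I); eta = logit(baseline) + L @ z; rates = expit(eta).
    z = [rng.gauss(0.0, 1.0) for _ in baseline_rates]
    eta = [
        logit(rate) + sum(scale_tril[i][j] * z[j] for j in range(i + 1))
        for i, rate in enumerate(baseline_rates)
    ]
    return [expit(e) for e in eta]


draws = [sample_rates() for _ in range(1000)]
```

Every draw lies strictly between 0 and 1, and the nonzero off-diagonal entry of `scale_tril` is what couples the two rates.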

from collections.abc import Mapping

import jax.nn as jnn

Suggested change:
- import jax.nn as jnn

self.baseline_rates = jnp.asarray(baseline_rates)
self.scale_tril = self._optional_array(scale_tril)
self.covariance_matrix = self._optional_array(covariance_matrix)
self.precision_matrix = self._optional_array(precision_matrix)

Why not just duck-type this?

)

@staticmethod
def _optional_array(value: ArrayLike | None) -> ArrayLike | None:

See above question about duck typing.

        f"{name} must have shape {matrix_shape}, got shape {param.shape}."
    )

def _distribution(self) -> dist.MultivariateNormal:

Why not instantiate this at construction, without the wrapper?

Wouldn't this also enable deferring validation of the various possible covariance args to the MultivariateNormal constructor?

Notably, it raises exactly the error raised here, e.g.:
ValueError: Exactly one of ['covariance_matrix', 'precision_matrix', 'scale_tril'] must be specified; got ['precision_matrix', 'scale_tril'].
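
For readers following along without NumPyro installed, the check being discussed amounts to something like the sketch below. It is a stand-in written for illustration, with a hypothetical function name, and is not NumPyro's implementation; it only reproduces the message format quoted above.

```python
def require_exactly_one(**candidates: object) -> str:
    """Return the name of the single keyword argument that is not None."""
    names = sorted(candidates)
    given = [name for name in names if candidates[name] is not None]
    if len(given) != 1:
        raise ValueError(
            f"Exactly one of {names} must be specified; got {given}."
        )
    return given[0]


# Usage: only scale_tril is supplied, so the check passes.
chosen = require_exactly_one(
    covariance_matrix=None,
    precision_matrix=None,
    scale_tril=[[0.7, 0.0], [0.35, 0.606]],
)
```

If the distribution constructor already performs this check, duplicating it in the component adds code without changing what the user sees.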
Summary
This PR adds model-level ascertainment support so multiple observation signals can share structured ascertainment rates instead of each observation independently sampling its own rate.
The main motivation is to support H+E-style multisignal models where hospital admissions and ED visits have distinct but related observation probabilities. The new abstraction lets those probabilities be sampled once at the model level, then passed into observation processes through signal-specific accessors without creating duplicate NumPyro sample sites.
Additions / Changes
- Added a new `pyrenew.ascertainment` package with:
  - `AscertainmentModel`, the base interface for shared ascertainment models.
  - `AscertainmentSignal`, a signal-specific accessor used by observation processes.
  - `JointAscertainment`, which samples correlated scalar ascertainment rates across signals on the logit scale.
- Extended `PyrenewBuilder` with `add_ascertainment(...)`, allowing shared ascertainment models to be registered alongside latent and observation components.
- Extended `MultiSignalModel` so registered ascertainment models are sampled once before observation processes run. Observation processes then retrieve their signal-specific values from the active ascertainment context.
- Added unit coverage for validation, sampling behavior, context safety, duplicate-site avoidance, builder registration, and direct `MultiSignalModel` validation.
- Added integration coverage for mixed-cadence hospital + ED models:
  - `JointAscertainment`
Suggested Review Order
- `pyrenew/ascertainment/base.py` and `pyrenew/ascertainment/context.py` to understand the core abstraction and how accessors retrieve already-sampled values.
- `pyrenew/ascertainment/joint.py`
- `pyrenew/model/pyrenew_builder.py` and `pyrenew/model/multisignal_model.py`.
- `docs/tutorials/ascertainment.py`
- `test/test_ascertainment.py` and the new builder tests in `test/test_pyrenew_builder.py`.
- `test/integration/conftest.py` and the two new H+E integration test files, since those show the intended end-to-end usage.