Skip to content

Add Ascertainment model for JointAscertainment#802

Open
cdc-mitzimorris wants to merge 51 commits intomainfrom
mem_777_joint_ascertainment
Open

Add Ascertainment model for JointAscertainment#802
cdc-mitzimorris wants to merge 51 commits intomainfrom
mem_777_joint_ascertainment

Conversation

@cdc-mitzimorris
Copy link
Copy Markdown
Collaborator

@cdc-mitzimorris cdc-mitzimorris commented Apr 28, 2026

Summary

This PR adds model-level ascertainment support so multiple observation signals can share structured ascertainment rates instead of each observation independently sampling its own rate.

The main motivation is to support H+E-style multisignal models where hospital admissions and ED visits have distinct but related observation probabilities. The new abstraction lets those probabilities be sampled once at the model level, then passed into observation processes through signal-specific accessors without creating duplicate NumPyro sample sites.

Additions / Changes

  • Added a new pyrenew.ascertainment package with:

    • AscertainmentModel, the base interface for shared ascertainment models.
    • AscertainmentSignal, a signal-specific accessor used by observation processes.
    • JointAscertainment, which samples correlated scalar ascertainment rates across signals on the logit scale.
    • A context layer that makes sampled ascertainment values available to observation processes during model execution.
  • Extended PyrenewBuilder with add_ascertainment(...), allowing shared ascertainment models to be registered alongside latent and observation components.

  • Extended MultiSignalModel so registered ascertainment models are sampled once before observation processes run. Observation processes then retrieve their signal-specific values from the active ascertainment context.

  • Added unit coverage for validation, sampling behavior, context safety, duplicate-site avoidance, builder registration, and direct MultiSignalModel validation.

  • Added integration coverage for mixed-cadence hospital + ED models:

    • weekly hospital admissions plus daily ED visits with JointAscertainment

Suggested Review Order

  1. Start with pyrenew/ascertainment/base.py and pyrenew/ascertainment/context.py to understand the core abstraction and how accessors retrieve already-sampled values.
  2. Review pyrenew/ascertainment/joint.py
  3. Review the integration points in pyrenew/model/pyrenew_builder.py and pyrenew/model/multisignal_model.py.
  4. Review tutorial docs/tutorials/ascertainment.py
  5. Review test/test_ascertainment.py and the new builder tests in test/test_pyrenew_builder.py.
  6. Finish with the integration fixtures and tests in test/integration/conftest.py and the two new H+E integration test files, since those show the intended end-to-end usage.

Comment thread pyrenew/ascertainment/timevarying.py Outdated
Comment thread pyrenew/ascertainment/timevarying.py Outdated
Copy link
Copy Markdown
Collaborator

@damonbayer damonbayer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should discuss the interface for the time-varying ascertainment rate.

@cdc-mitzimorris
Copy link
Copy Markdown
Collaborator Author

cdc-mitzimorris commented Apr 30, 2026

@damonbayer

I think we should discuss the interface for the time-varying ascertainment rate.

Agreed.

I made the other changes you suggested, but I after further consideration, I think that TimeVaryingAscertainment is problematic in that it gives the model too much freedom. More importantly, the current production HE model is doing something different, therefore we need to do something more or different.

I propose the following additional AscertainmentModel class: JointTimeVaryingAscertainment which corresponds to the HE model use case:

  • hospital and ED baseline ascertainment rates should share information;
  • ED ascertainment may vary over time;
  • hospital ascertainment can remain scalar;
  • hospital IHR should not be tied to the first realized ED ascertainment value.

This preserves the useful scale-sharing idea in the production HEW model while
avoiding the stronger coupling:

IHR = IEDR(0) * ratio

Instead, we should: sample correlated baseline ascertainment logits:

mu ~ MultivariateNormal(...)

For each signal:

if signal has no temporal process:
    alpha_j = sigmoid(mu_j)

if signal has a temporal process:
    alpha_j(t) = sigmoid(mu_j + z_j(t))

For HEW this gives:

IHR = sigmoid(mu_h)

z_ed(t) ~ weekly AR(1)
IEDR(t) = sigmoid(mu_ed + z_ed(t))

So hospital and ED baselines are related, but ED temporal deviations do not
force hospital ascertainment to move.

Interface Sketch

JointTimeVaryingAscertainment(
    name="he_ascertainment",
    signals=("hospital", "ed_visits"),
    baseline_loc=jnp.array([logit(0.01), logit(0.02)]),
    baseline_scale_tril=jnp.array(
        [
            [0.7, 0.0],
            [0.35, 0.606],
        ]
    ),
    processes={
        "ed_visits": WeeklyTemporalProcess(
            AR1(autoreg=0.8, innovation_sd=0.25),
            start_dow=MMWR_WEEK,
        ),
    },
)

Return value:

{
    "hospital": scalar_ihr,
    "ed_visits": iedr_trajectory,
}

Signals listed in signals but absent from processes return scalar rates.
Signals present in processes return trajectories on the model time axis.

Relationship To Existing Components

JointAscertainment
    Correlated scalar rates only.

TimeVaryingAscertainment
    Independent fixed baselines plus independent temporal deviations.

JointTimeVaryingAscertainment
    Correlated sampled baselines plus optional signal-specific temporal
    deviations.

Validation

The component should validate:

  • signals is a non-empty tuple of unique signal names;
  • baseline_loc has shape (n_signals,);
  • exactly one baseline covariance parameter is supplied;
  • baseline covariance parameter has shape (n_signals, n_signals);
  • every key in processes is one of signals;
  • every process satisfies the TemporalProcess protocol.

Note

This component keeps PyRenew's current count-observation semantics:
time-varying ascertainment is applied before delay convolution, so trajectories
are interpreted as infection-time ascertainment probabilities.

@damonbayer
Copy link
Copy Markdown
Collaborator

damonbayer commented Apr 30, 2026

JointAscertainment
Correlated scalar rates only.

TimeVaryingAscertainment
Independent fixed baselines plus independent temporal deviations.

JointTimeVaryingAscertainment
Correlated sampled baselines plus optional signal-specific temporal
deviations.

Can you be more precise with these descriptions?

What about optional signal-specific temporal deviations is optional? Is there a conceptual difference between independent temporal deviations and signal-specific temporal deviations?

My instinct is that even the JointTimeVaryingAscertainment ties the parameters together too weakly, but I'm not totally sure.

IHR = sigmoid(mu_h)

z_ed(t) ~ weekly AR(1)
IEDR(t) = sigmoid(mu_ed + z_ed(t))

So hospital and ED baselines are related, but ED temporal deviations do not
force hospital ascertainment to move.

I don't see how "hospital and ED baselines are related". Is it because mu_h and mu_ed have a bivariate normal distribution (not written)?

@cdc-mitzimorris
Copy link
Copy Markdown
Collaborator Author

@damonbayer asks good questions and I think that this PR should drop the TimeVaryingAscertainment component and just do scalar/joint ascertainment.

The current TimeVaryingAscertainment API makes a free temporal process the main abstraction:

logit(rate_j(t)) = logit(baseline_rate_j) + z_j(t)

But a free temporal ascertainment process is identifiability-expensive in a renewal model: observed counts identify something close to latent_incidence_t * ascertainment_t, so a flexible z_j(t) can compete directly with latent infections / Rt unless it is strongly motivated and constrained. We also already model day-of-week effects in the observation process, so time-varying ascertainment should not be used to absorb reporting periodicity.

If the use case is for ED visits, it needs justification - do we have any? Plausible changes in IEDR are not slow latent trends; they are known calendar/event/regime effects such as holidays, disasters, data-feed disruptions, coding changes, or policy changes. Those should probably be explicit covariates or regime effects, not inferred by a generic AR process that can compete with latent infections/Rt.

@cdc-mitzimorris cdc-mitzimorris changed the title Add Ascertainment model for JointAscertainment and TimeVaryingAscertainment Add Ascertainment model for JointAscertainment May 1, 2026
@cdc-mitzimorris
Copy link
Copy Markdown
Collaborator Author

@damonbayer - removed TimeVaryingAscertainment (for now).

```

Ascertainment is the probability that a latent infection appears in an observed signal, e.g., a hospitalization or visit to the emergency department.
For hospital admissions this probability is often called an infection-hospitalization ratio (IHR).
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
For hospital admissions this probability is often called an infection-hospitalization ratio (IHR).
For hospital admissions this probability is often called an infection-to-hospitalization rate (IHR).


Ascertainment is the probability that a latent infection appears in an observed signal, e.g., a hospitalization or visit to the emergency department.
For hospital admissions this probability is often called an infection-hospitalization ratio (IHR).
For emergency department visits it might be called an infection-ED-visit ratio (IEDR).
Copy link
Copy Markdown
Collaborator

@dylanhmorris dylanhmorris May 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
For emergency department visits it might be called an infection-ED-visit ratio (IEDR).
For emergency department visits it might be called an infection-to-emergency-department rate (IEDR).

independent_ihr = dist.Beta(2, 198).sample(key_ihr, (n_draws,))
independent_iedr = dist.Beta(2, 98).sample(key_iedr, (n_draws,))

independent_df = pd.DataFrame(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment indicating this will be used for visualization later?

Otherwise a bit unclear what it's doing here.

Each observation process receives the appropriate signal-specific accessor from `for_signal()`.

```python
builder = PyrenewBuilder()
Copy link
Copy Markdown
Collaborator

@dylanhmorris dylanhmorris May 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add an example creation of the object ascertainment?

Signal names in `for_signal()` must match the names used when the ascertainment model was created.
They do not have to match observation names, but matching them usually makes model code and posterior outputs easier to read.

When running a model that uses weekly observations or day-of-week effects, pass `obs_start_date` so PyRenew can align the model axis to the calendar.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure we need this line.

Suggested change
When running a model that uses weekly observations or day-of-week effects, pass `obs_start_date` so PyRenew can align the model axis to the calendar.


- Use ordinary `RandomVariable` objects for independent scalar ascertainment rates.
- Use `JointAscertainment` for correlated scalar rates across signals.
- Ascertainment and latent infection scale are weakly identified without informative priors or external information.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is true but we didn't really demonstrate it above. Remove or justify/discuss more above.

Comment on lines +44 to +52
if not isinstance(ascertainment_name, str) or len(ascertainment_name) == 0:
raise ValueError(
"ascertainment_name must be a non-empty string. "
f"Got {type(ascertainment_name).__name__}: {ascertainment_name!r}"
)
if not isinstance(signal_name, str) or len(signal_name) == 0:
raise ValueError(
"signal_name must be a non-empty string. "
f"Got {type(signal_name).__name__}: {signal_name!r}"
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Separate PR: add a check helper for non-empty string names

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that there is one already in context.py; could be shared throughout codebase.

infection-to-observation ratios, while still being correlated because both
depend on care-seeking behavior, testing practices, or reporting systems.

The model samples one multivariate normal vector around natural-scale
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
The model samples one multivariate normal vector around natural-scale
The model samples one multivariate logit-normal vector given natural-scale

depend on care-seeking behavior, testing practices, or reporting systems.

The model samples one multivariate normal vector around natural-scale
baseline ascertainment rates, which are converted to the logit scale
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
baseline ascertainment rates, which are converted to the logit scale
baseline ascertainment rates:

Comment on lines +88 to +90
self.baseline_logits = jnp.log(self.baseline_rates) - jnp.log1p(
-self.baseline_rates
)
Copy link
Copy Markdown
Collaborator

@dylanhmorris dylanhmorris May 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.baseline_logits = jnp.log(self.baseline_rates) - jnp.log1p(
-self.baseline_rates
)
self.baseline_logits = logit(self.baseline_rates)

from collections.abc import Mapping

import jax.nn as jnn
import jax.numpy as jnp
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
import jax.numpy as jnp
import jax.numpy as jnp
from jax.scipy.special import logit, expit

f"{self.name}_eta",
self._distribution(),
)
rates = jnn.sigmoid(eta)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
rates = jnn.sigmoid(eta)
rates = expit(eta)


from collections.abc import Mapping

import jax.nn as jnn
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
import jax.nn as jnn

Comment on lines +83 to +86
self.baseline_rates = jnp.asarray(baseline_rates)
self.scale_tril = self._optional_array(scale_tril)
self.covariance_matrix = self._optional_array(covariance_matrix)
self.precision_matrix = self._optional_array(precision_matrix)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not just duck-type this?

)

@staticmethod
def _optional_array(value: ArrayLike | None) -> ArrayLike | None:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See above question about duck typing.

f"{name} must have shape {matrix_shape}, got shape {param.shape}."
)

def _distribution(self) -> dist.MultivariateNormal:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not instantiate this at construction, without the wrapper?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't this also enable deferring validation of the various possible covariance args to the MultivariateNormal constructor?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Notably, it raises exactly the error raised here, e.g.:

ValueError: Exactly one of ['covariance_matrix', 'precision_matrix', 'scale_tril'] must be specified; got ['precision_matrix', 'scale_tril'].

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants