Skip to content

Normalize served prev_action_chunk + reserved-key sample_kwargs passthrough#22

Open
jiabinq wants to merge 1 commit intoOpenDriveLab:mainfrom
jiabinq:awbc-inference-fixes
Open

Normalize served prev_action_chunk + reserved-key sample_kwargs passthrough#22
jiabinq wants to merge 1 commit intoOpenDriveLab:mainfrom
jiabinq:awbc-inference-fixes

Conversation

@jiabinq
Copy link
Copy Markdown

@jiabinq jiabinq commented Apr 20, 2026

Two related fixes to served Policy.infer. Both are backward-compatible.

Summary

1. RTC prev_action_chunk is normalized to model space before reaching sample_actions.

Pi0RTC.sample_actions(...) consumes prev_action_chunk in model space (post-Normalize), but Policy.infer was forwarding obs["prev_action_chunk"] from the wire raw. The Agilex inference clients under train_deploy_alignment/inference/agilex/ send a raw deploy-space slice of their execution buffer, so the guidance term operates on un-normalized inputs — a silent train-deploy contract break. The bug is masked because Agilex action norm-stats are close to unit-variance, so the magnitude error in the guidance term is small, but the contract is broken.

The fix adds _normalize_and_pad_prev_chunk that delegates to the same transforms.Normalize instance the serving pipeline uses (so use_quantile_norm is honored) and pads to the model's action_horizon. Wired from the loaded checkpoint via three new optional Policy constructor params (norm_stats, use_quantile_norm, action_horizon); policy_config.create_trained_policy sets them automatically — call sites don't change.

Also guards against silent d=0 cheap-path activation when a client sends prev_action_chunk without inference_delay (would otherwise run the eager loop with no prefix conditioning).

2. Reserved-key obs["_sample_kwargs"] allowlist for transport-layer sample_kwargs overrides (currently: noise).

Policy.infer(obs, *, noise=...) accepts noise for in-process callers, but the websocket protocol drops it — WebsocketPolicyServer._handler calls self._policy.infer(obs) only. This makes deterministic evaluation of a served checkpoint impossible.

The fix extracts an optional obs["_sample_kwargs"]["noise"] into the noise kwarg path. The reserved-key namespace (leading underscore) avoids collision with any future model that legitimately uses an observation field named noise. The explicit noise= kwarg takes precedence over the obs-supplied noise (in-process callers behave unchanged).

Backward compatibility

  • Callers that don't include _sample_kwargs see no behavior change.
  • Non-RTC Policy(...) callers see no behavior change — the new constructor params are optional with safe defaults.
  • Existing RTC clients that were sending raw deploy-space prev_action_chunk will now receive correctly-normalized chunks. This is the bug fix. No client API change required.

Test plan

  • Send the same observation twice with obs["_sample_kwargs"]["noise"] = <fixed> and verify the returned actions are bit-identical.
  • Compare a served Policy.infer call with prev_action_chunk against a direct Normalize({"actions": stats})({"actions": raw})["actions"] and confirm the values forwarded into sample_kwargs match.
  • Confirm an existing non-RTC checkpoint still serves correctly with no client change (smoke).

Audit context

Found during a downstream parity audit of OpenDriveLab/kai0 9d93078 deploy stack. Audit deliverables (downstream fork): notes/awbc_inference_dagger_parity_gaps.md, reference/awbc_inference_dagger_upstream_review.md.

…hrough

Two related fixes to served Policy.infer:

1. RTC prev_action_chunk is now normalized to model space before reaching
   sample_actions. Pi0RTC.sample_actions() consumes prev_action_chunk in
   model space (post-Normalize), but Policy.infer was forwarding
   obs["prev_action_chunk"] from the wire raw. Agilex inference clients
   send a raw deploy-space slice of their execution buffer, so the
   guidance term was operating on un-normalized inputs — a silent
   train-deploy contract break (masked because Agilex action norm-stats
   are close to unit-variance, so the magnitude error is small).

   The fix adds a _normalize_and_pad_prev_chunk helper that delegates to
   the same transforms.Normalize the serving pipeline uses (so
   use_quantile_norm is honored), pads to action_horizon, and is wired
   from the loaded checkpoint via three new optional Policy params
   (norm_stats, use_quantile_norm, action_horizon). policy_config wires
   them automatically — call sites unchanged.

   Also guards against silent d=0 cheap-path activation when a client
   sends prev_action_chunk without inference_delay.

2. Reserved-key obs["_sample_kwargs"] allowlist for transport-layer
   sample_kwargs overrides (currently: noise). Previously the websocket
   protocol dropped the noise= kwarg — making deterministic served eval
   impossible. The reserved-key namespace (leading underscore) avoids
   collision with future models that legitimately use observation field
   names like "noise". Explicit noise= kwarg (in-process callers) takes
   precedence.

Both fixes are backward-compatible: existing callers see no behavior
change. Existing RTC clients that previously sent raw deploy-space
chunks will now receive the correct normalized chunk — this is the
bug fix.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant