Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 103 additions & 12 deletions src/openpi/policies/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,53 @@
from openpi.shared import array_typing as at
from openpi.shared import nnx_utils

# Module-level logger, named after this module per stdlib logging convention.
logger = logging.getLogger(__name__)

# Public alias so callers can type against the policy interface without importing _base_policy.
BasePolicy: TypeAlias = _base_policy.BasePolicy

# Reserved transport-layer key in the observation dict for served clients to override
# sample_kwargs (e.g. pass a deterministic noise sample). Leading underscore signals
# "transport-layer field, not a model observation input" — avoids collisions with future
# models that legitimately use observation field names like "noise".
_RESERVED_SAMPLE_KWARGS_KEY = "_sample_kwargs"
# Allowlist of sample_kwargs names a remote client may set through the reserved key;
# anything outside this set is rejected rather than silently forwarded to the model.
_ALLOWED_TRANSPORT_SAMPLE_KWARGS = frozenset({"noise"})


def _normalize_and_pad_prev_chunk(
    raw: np.ndarray,
    *,
    norm_stats: dict[str, _transforms.NormStats],
    use_quantile_norm: bool,
    action_horizon: int,
) -> np.ndarray:
    """Normalize a client-supplied ``prev_action_chunk`` into model space and pad to ``action_horizon``.

    The model's RTC ``sample_actions`` consumes ``prev_action_chunk`` in **model space**
    (post-Normalize), but websocket clients send a raw ``(d, state_dim)`` slice of their
    deploy-space execution buffer. Without this helper the guidance term operates on
    un-normalized inputs — a silent train-deploy contract break.

    Delegates to the same ``transforms.Normalize`` instance the serving pipeline uses so the
    formula (z-score vs quantile) cannot drift. Pads the chunk to the model's
    ``action_horizon`` because the JAX/PyTorch RTC implementations require that shape.

    Args:
        raw: Client-supplied chunk of shape ``(d, state_dim)``.
        norm_stats: Serving-pipeline normalization stats; must contain an ``"actions"`` entry.
        use_quantile_norm: Selects quantile normalization instead of z-score, matching the
            serving ``Normalize`` transform's configuration.
        action_horizon: Target first-axis length the model expects.

    Returns:
        A float32 array of shape ``(action_horizon, state_dim)`` in model space.

    Raises:
        ValueError: If ``raw`` is not 2-D, or if its last axis is wider than the
            checkpoint's action stats.
    """
    # Enforce the documented (d, state_dim) contract up front: a 1-D payload would
    # otherwise fail deep inside np.concatenate with an opaque shape error (short chunk)
    # or pass through mis-shaped (long chunk).
    if raw.ndim != 2:
        raise ValueError(
            f"prev_action_chunk must be 2-D (d, state_dim); got shape {raw.shape}."
        )
    state_dim = raw.shape[-1]
    action_stats = norm_stats["actions"]
    if state_dim > action_stats.mean.shape[-1]:
        raise ValueError(
            f"prev_action_chunk state_dim={state_dim} exceeds norm_stats['actions'] width "
            f"{action_stats.mean.shape[-1]}; client is sending more joints than the checkpoint knows about."
        )
    # Reuse the exact serving Normalize transform so z-score vs quantile math cannot drift.
    normalizer = _transforms.Normalize({"actions": action_stats}, use_quantiles=use_quantile_norm)
    normalized = normalizer({"actions": raw})["actions"]
    d = normalized.shape[0]
    if d < action_horizon:
        # Zero-pad in model space up to the horizon the RTC implementations require.
        pad = np.zeros((action_horizon - d, state_dim), dtype=np.float32)
        normalized = np.concatenate([normalized, pad], axis=0)
    elif d > action_horizon:
        # Truncate an over-long chunk; only the first action_horizon steps are consumed.
        normalized = normalized[:action_horizon]
    # copy=False avoids a redundant copy when the array is already float32.
    return normalized.astype(np.float32, copy=False)


class Policy(BasePolicy):
def __init__(
Expand All @@ -33,6 +78,9 @@ def __init__(
metadata: dict[str, Any] | None = None,
pytorch_device: str = "cpu",
is_pytorch: bool = False,
norm_stats: dict[str, _transforms.NormStats] | None = None,
use_quantile_norm: bool = False,
action_horizon: int | None = None,
):
"""Initialize the Policy.

Expand All @@ -54,6 +102,15 @@ def __init__(
self._metadata = metadata or {}
self._is_pytorch_model = is_pytorch
self._pytorch_device = pytorch_device
if norm_stats is not None and action_horizon is None:
raise ValueError(
"Policy(norm_stats=...) requires action_horizon to also be provided; "
"without it, server-side prev_action_chunk normalization cannot pad to the model's horizon."
)
self._norm_stats = norm_stats
self._use_quantile_norm = use_quantile_norm
self._action_horizon = action_horizon
self._rtc_log_emitted = False

if self._is_pytorch_model:
self._model = self._model.to(pytorch_device)
Expand Down Expand Up @@ -81,21 +138,55 @@ def infer(self, obs: dict, *, noise: np.ndarray | None = None) -> dict: # type:
# Prepare kwargs for sample_actions
sample_kwargs = dict(self._sample_kwargs)

# TODO: For RTC passthrough: allow client to provide delay/prev chunk/horizon for realtime_action-capable models
if "prev_action_chunk" in obs:
sample_kwargs["prev_action_chunk"] = obs["prev_action_chunk"]
if "inference_delay" in obs:
# RTC cheap-path guidance. Client sends a raw (d, state_dim) slice of its blended queue head
# along with inference_delay; we normalize with the same stats+mode as the serving Normalize
# transform and pad horizon to the model's action_horizon before forwarding. Both fields must
# be present together — forwarding prev_action_chunk without inference_delay would silently
# trip the cheap-path gate with d=0, running the eager loop with no prefix conditioning.
has_prev = "prev_action_chunk" in obs
has_delay = "inference_delay" in obs
if has_prev and not has_delay:
logger.warning(
"[rtc_cheap_path] obs has prev_action_chunk but not inference_delay; skipping cheap-path "
"forwarding to avoid silent d=0 activation. Client must send both fields together."
)
elif has_prev and has_delay:
raw_prev = np.asarray(obs["prev_action_chunk"], dtype=np.float32)
if self._norm_stats is not None and self._action_horizon is not None:
prev_chunk = _normalize_and_pad_prev_chunk(
raw_prev,
norm_stats=self._norm_stats,
use_quantile_norm=self._use_quantile_norm,
action_horizon=self._action_horizon,
)
sample_kwargs["prev_action_chunk"] = prev_chunk
log_fn = logger.info if not self._rtc_log_emitted else logger.debug
log_fn(
"[rtc] forwarded prev_action_chunk d=%d ah=%d quantile=%s",
raw_prev.shape[0], self._action_horizon, self._use_quantile_norm,
)
self._rtc_log_emitted = True
else:
sample_kwargs["prev_action_chunk"] = raw_prev
sample_kwargs["inference_delay"] = obs["inference_delay"]
if "execute_horizon" in obs:
sample_kwargs["execute_horizon"] = obs["execute_horizon"]
# if "enable_rtc" in obs:
# sample_kwargs["enable_rtc"] = obs["enable_rtc"]
# if "mask_prefix_delay" in obs:
# sample_kwargs["mask_prefix_delay"] = obs["mask_prefix_delay"]
# if "prefix_attention_schedule" in obs:
# sample_kwargs["prefix_attention_schedule"] = obs["prefix_attention_schedule"]
# if "max_guidance_weight" in obs:
# sample_kwargs["max_guidance_weight"] = obs["max_guidance_weight"]
# Reserved-key transport for sample_kwargs overrides (currently: noise).
# Explicit `noise=` kwarg (in-process callers) takes precedence over obs-supplied noise.
sample_kwargs_override = obs.get(_RESERVED_SAMPLE_KWARGS_KEY) or {}
if not isinstance(sample_kwargs_override, dict):
raise TypeError(
f"obs[{_RESERVED_SAMPLE_KWARGS_KEY!r}] must be a dict, "
f"got {type(sample_kwargs_override).__name__}"
)
unknown = set(sample_kwargs_override) - _ALLOWED_TRANSPORT_SAMPLE_KWARGS
if unknown:
raise ValueError(
f"obs[{_RESERVED_SAMPLE_KWARGS_KEY!r}] contains unsupported keys: {sorted(unknown)}; "
f"allowlist: {sorted(_ALLOWED_TRANSPORT_SAMPLE_KWARGS)}"
)
if "noise" in sample_kwargs_override and noise is None:
noise = np.asarray(sample_kwargs_override["noise"])
if noise is not None:
noise = torch.from_numpy(noise).to(self._pytorch_device) if self._is_pytorch_model else jnp.asarray(noise)

Expand Down
5 changes: 5 additions & 0 deletions src/openpi/policies/policy_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,4 +91,9 @@ def create_trained_policy(
metadata=train_config.policy_metadata,
is_pytorch=is_pytorch,
pytorch_device=pytorch_device if is_pytorch else None,
# Wire RTC normalization params from the loaded checkpoint so served prev_action_chunk
# is normalized into model space before reaching Pi0RTC.sample_actions().
norm_stats=norm_stats,
use_quantile_norm=data_config.use_quantile_norm,
action_horizon=train_config.model.action_horizon,
)