From 5f5e6fc09dfd0434e0c3abb21b7763b040859c2a Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Mon, 27 Apr 2026 15:18:01 -0700 Subject: [PATCH] Replace _StorageReference with ExternalStorageReference proto --- temporalio/converter/_data_converter.py | 33 +++--- temporalio/converter/_extstore.py | 68 ++++++++---- tests/test_extstore.py | 141 +++++++++++++++++++++--- tests/worker/test_extstore.py | 55 ++++----- 4 files changed, 209 insertions(+), 88 deletions(-) diff --git a/temporalio/converter/_data_converter.py b/temporalio/converter/_data_converter.py index 0323466e7..13b48e695 100644 --- a/temporalio/converter/_data_converter.py +++ b/temporalio/converter/_data_converter.py @@ -14,11 +14,11 @@ import temporalio.api.common.v1 import temporalio.api.failure.v1 import temporalio.common +from temporalio.api.sdk.v1.external_storage_pb2 import ExternalStorageReference from temporalio.converter._extstore import ( _REFERENCE_ENCODING, ExternalStorage, StorageDriverStoreContext, - StorageWarning, ) from temporalio.converter._failure_converter import ( FailureConverter, @@ -41,6 +41,17 @@ WithSerializationContext, ) +_REFERENCE_MESSAGE_TYPE = ExternalStorageReference.DESCRIPTOR.full_name.encode() + + +def _is_reference_payload(p: temporalio.api.common.v1.Payload) -> bool: + """Return True if *p* is an external-storage reference payload.""" + return p.metadata.get("encoding") == _REFERENCE_ENCODING or ( + p.metadata.get("encoding") == b"json/protobuf" + and p.metadata.get("messageType") == _REFERENCE_MESSAGE_TYPE + ) + + # Import defaults from public API to avoid pydoctor cross-reference issues if TYPE_CHECKING: from temporalio.converter import DefaultFailureConverter, DefaultPayloadConverter @@ -307,13 +318,9 @@ async def _transform_inbound_payloads( if self.external_storage: await self.external_storage._retrieve_payloads(payloads) else: - if any( - p.metadata.get("encoding") == _REFERENCE_ENCODING - for p in payloads.payloads - ): - warnings.warn( - "[TMPRL1105] Detected externally stored payload(s) but external storage is not configured.", - StorageWarning, + if any(_is_reference_payload(p) for p in payloads.payloads): + raise RuntimeError( + "[TMPRL1105] Detected externally stored payload(s) but external storage is not configured." ) if self.payload_codec: await self.payload_codec.decode_wrapper(payloads) @@ -348,13 +355,9 @@ async def _external_retrieve_payload_sequence( retrieved_payloads ) else: - if any( - p.metadata.get("encoding") == _REFERENCE_ENCODING - for p in retrieved_payloads - ): - warnings.warn( - "[TMPRL1105] Detected externally stored payload(s) but external storage is not configured.", - StorageWarning, + if any(_is_reference_payload(p) for p in retrieved_payloads): + raise RuntimeError( + "[TMPRL1105] Detected externally stored payload(s) but external storage is not configured." ) return retrieved_payloads diff --git a/temporalio/converter/_extstore.py b/temporalio/converter/_extstore.py index 55b1686bf..c31424acf 100644 --- a/temporalio/converter/_extstore.py +++ b/temporalio/converter/_extstore.py @@ -18,7 +18,11 @@ from typing_extensions import Self from temporalio.api.common.v1 import Payload, Payloads -from temporalio.converter._payload_converter import JSONPlainPayloadConverter +from temporalio.api.sdk.v1.external_storage_pb2 import ExternalStorageReference +from temporalio.converter._payload_converter import ( + JSONPlainPayloadConverter, + JSONProtoPayloadConverter, +) _T = TypeVar("_T") @@ -225,6 +229,11 @@ class StorageWarning(RuntimeWarning): @dataclass(frozen=True) class _StorageReference: + """Legacy external storage reference used only on the retrieval path as a + fallback for in-flight workflows that were written before the + ExternalStorageReference proto was introduced. + """ + driver_name: str driver_claim: StorageDriverClaim @@ -278,8 +287,9 @@ class ExternalStorage: ) """Store context bound to this instance via :meth:`_with_store_context`.""" - _claim_converter: ClassVar[JSONPlainPayloadConverter] = JSONPlainPayloadConverter( - encoding=_REFERENCE_ENCODING.decode() + _claim_converter: ClassVar[JSONProtoPayloadConverter] = JSONProtoPayloadConverter() + _legacy_claim_converter: ClassVar[JSONPlainPayloadConverter] = ( + JSONPlainPayloadConverter(encoding=_REFERENCE_ENCODING.decode()) ) def __post_init__(self) -> None: @@ -357,9 +367,9 @@ async def _store_payload(self, payload: Payload) -> Payload: self._validate_claim_length(claims, expected=1, driver=driver) external_size = payload.ByteSize() - reference = _StorageReference( + reference = ExternalStorageReference( driver_name=driver.name(), - driver_claim=claims[0], + claim_data=claims[0].claim_data, ) reference_payload = self._claim_converter.to_payload(reference) if reference_payload is None: @@ -421,9 +431,9 @@ async def _store_payload_sequence( self._validate_claim_length(claims, expected=len(indices), driver=driver) for i, claim in enumerate(claims): - reference = _StorageReference( + reference = ExternalStorageReference( driver_name=driver.name(), - driver_claim=claim, + claim_data=claim.claim_data, ) reference_payload = self._claim_converter.to_payload(reference) if reference_payload is None: @@ -443,20 +453,35 @@ async def _store_payload_sequence( return results - async def _retrieve_payload(self, payload: Payload) -> Payload: + def _decode_reference(self, payload: Payload) -> ExternalStorageReference | None: + """Decode an external storage reference from a payload.""" if len(payload.external_payloads) == 0: - return payload - - start_time = time.monotonic() + return None + encoding = payload.metadata.get("encoding", b"") + if encoding == _REFERENCE_ENCODING: + legacy = self._legacy_claim_converter.from_payload( + payload, _StorageReference + ) + if not isinstance(legacy, _StorageReference): + return None + return ExternalStorageReference( + driver_name=legacy.driver_name, + claim_data=legacy.driver_claim.claim_data, + ) + ref = self._claim_converter.from_payload(payload, ExternalStorageReference) + return ref if isinstance(ref, ExternalStorageReference) else None - reference = self._claim_converter.from_payload(payload, _StorageReference) - if not isinstance(reference, _StorageReference): + async def _retrieve_payload(self, payload: Payload) -> Payload: + ref = self._decode_reference(payload) + if ref is None: return payload - driver = self._get_driver_by_name(reference.driver_name) + start_time = time.monotonic() + driver = self._get_driver_by_name(ref.driver_name) context = StorageDriverRetrieveContext() + claim = StorageDriverClaim(claim_data=dict(ref.claim_data)) - stored_payloads = await driver.retrieve(context, [reference.driver_claim]) + stored_payloads = await driver.retrieve(context, [claim]) self._validate_payload_length(stored_payloads, expected=1, driver=driver) @@ -486,15 +511,12 @@ async def _retrieve_payload_sequence( driver_claims: dict[StorageDriver, list[tuple[int, StorageDriverClaim]]] = {} for index, payload in enumerate(payloads): - if len(payload.external_payloads) == 0: + ref = self._decode_reference(payload) + if ref is None: continue - - reference = self._claim_converter.from_payload(payload, _StorageReference) - if not isinstance(reference, _StorageReference): - continue - - driver = self._get_driver_by_name(reference.driver_name) - driver_claims.setdefault(driver, []).append((index, reference.driver_claim)) + driver = self._get_driver_by_name(ref.driver_name) + claim = StorageDriverClaim(claim_data=dict(ref.claim_data)) + driver_claims.setdefault(driver, []).append((index, claim)) if not driver_claims: return results diff --git a/tests/test_extstore.py b/tests/test_extstore.py index 1771778a7..196632042 100644 --- a/tests/test_extstore.py +++ b/tests/test_extstore.py @@ -6,6 +6,7 @@ import pytest from temporalio.api.common.v1 import Payload +from temporalio.api.sdk.v1.external_storage_pb2 import ExternalStorageReference from temporalio.converter import ( DataConverter, ExternalStorage, @@ -16,9 +17,26 @@ StorageDriverRetrieveContext, StorageDriverStoreContext, ) -from temporalio.converter._extstore import _StorageReference +from temporalio.converter._extstore import _REFERENCE_ENCODING, _StorageReference +from temporalio.converter._payload_converter import JSONProtoPayloadConverter from temporalio.exceptions import ApplicationError +_legacy_ref_converter = JSONPlainPayloadConverter(encoding=_REFERENCE_ENCODING.decode()) + + +def _make_legacy_payload( + driver_name: str, claim_data: dict[str, str], size_bytes: int +) -> Payload: + """Build a reference payload in the legacy ``json/external-storage-reference`` format.""" + ref = _StorageReference( + driver_name=driver_name, + driver_claim=StorageDriverClaim(claim_data=claim_data), + ) + payload = _legacy_ref_converter.to_payload(ref) + assert payload is not None + payload.external_payloads.add().size_bytes = size_bytes + return payload + class InMemoryTestDriver(StorageDriver): """In-memory storage driver for testing.""" @@ -115,7 +133,7 @@ async def test_extstore_encode_decode(self): assert driver._retrieve_calls == 1 async def test_extstore_reference_structure(self): - """Test that external storage creates proper reference structure.""" + """Externalized payloads are written as ExternalStorageReference proto (json/protobuf encoding).""" converter = DataConverter( external_storage=ExternalStorage( drivers=[InMemoryTestDriver("test-driver")], @@ -123,25 +141,19 @@ async def test_extstore_reference_structure(self): ) ) - # Create large payload large_value = "x" * 100 encoded = await converter.encode([large_value]) - # Verify reference structure reference_payload = encoded[0] assert len(reference_payload.external_payloads) > 0 + assert reference_payload.metadata.get("encoding") == b"json/protobuf" - # The payload should contain a serialized _ExternalStorageReference - # Deserialize it to verify structure using the same encoding - claim_converter = JSONPlainPayloadConverter( - encoding="json/external-storage-reference" + reference = JSONProtoPayloadConverter().from_payload( + reference_payload, ExternalStorageReference ) - reference = claim_converter.from_payload(reference_payload, _StorageReference) - - assert isinstance(reference, _StorageReference) - assert "test-driver" == reference.driver_name - assert isinstance(reference.driver_claim, StorageDriverClaim) - assert "key" in reference.driver_claim.claim_data + assert isinstance(reference, ExternalStorageReference) + assert reference.driver_name == "test-driver" + assert "key" in reference.claim_data async def test_extstore_composite_conditional(self): """Test using multiple drivers based on size.""" @@ -482,9 +494,10 @@ async def test_selector_always_first_driver_handles_all_stores(self): assert second._store_calls == 0 # The reference in history names the first driver. - ref = JSONPlainPayloadConverter( - encoding="json/external-storage-reference" - ).from_payload(encoded[0], _StorageReference) + ref = JSONProtoPayloadConverter().from_payload( + encoded[0], ExternalStorageReference + ) + assert isinstance(ref, ExternalStorageReference) assert ref.driver_name == "driver-first" # Retrieval also goes to the first driver. @@ -694,5 +707,99 @@ def test_negative_payload_size_threshold_raises(self, threshold: int): ) +class TestBackwardCompat: + """Tests that the retrieval path handles the legacy ``json/external-storage-reference`` + format for in-flight workflows written before the ExternalStorageReference proto.""" + + async def test_legacy_format_single_payload_decode(self): + """A single payload in the legacy reference format is retrieved correctly.""" + driver = InMemoryTestDriver() + + inner_payload = (await DataConverter().encode(["x" * 200]))[0] + stored_key = "payload-0" + driver._storage[stored_key] = inner_payload.SerializeToString() + + legacy_payload = _make_legacy_payload( + driver_name=driver.name(), + claim_data={"key": stored_key}, + size_bytes=inner_payload.ByteSize(), + ) + + converter = DataConverter( + external_storage=ExternalStorage( + drivers=[driver], + payload_size_threshold=100, + ) + ) + decoded = await converter.decode([legacy_payload], [str]) + assert decoded[0] == "x" * 200 + assert driver._retrieve_calls == 1 + + async def test_legacy_and_new_format_mixed_batch_decode(self): + """A batch containing legacy-format, new proto-format, and inline payloads + all decode correctly in a single call.""" + driver = InMemoryTestDriver() + converter = DataConverter( + external_storage=ExternalStorage( + drivers=[driver], + payload_size_threshold=50, + ) + ) + + new_value = "new-format-value" * 20 + inline_value = "small" + encoded = await converter.encode([new_value, inline_value]) + new_format_payload = encoded[0] + inline_payload = encoded[1] + assert driver._store_calls == 1 + + legacy_value = "legacy-format-value" * 20 + legacy_inner = (await DataConverter().encode([legacy_value]))[0] + stored_key = f"payload-{len(driver._storage)}" + driver._storage[stored_key] = legacy_inner.SerializeToString() + legacy_payload = _make_legacy_payload( + driver_name=driver.name(), + claim_data={"key": stored_key}, + size_bytes=legacy_inner.ByteSize(), + ) + + decoded = await converter.decode( + [legacy_payload, new_format_payload, inline_payload], [str, str, str] + ) + assert decoded[0] == legacy_value + assert decoded[1] == new_value + assert decoded[2] == inline_value + # Both external payloads share the same driver and are batched into one retrieve call. + assert driver._retrieve_calls == 1 + + async def test_new_format_encode_round_trips(self): + """Payloads written with the new ExternalStorageReference format round-trip + correctly and carry the expected proto encoding.""" + driver = InMemoryTestDriver() + converter = DataConverter( + external_storage=ExternalStorage( + drivers=[driver], + payload_size_threshold=50, + ) + ) + + value = "round-trip-value" * 20 + encoded = await converter.encode([value]) + ref_payload = encoded[0] + + assert ref_payload.metadata.get("encoding") == b"json/protobuf" + assert len(ref_payload.external_payloads) > 0 + + ref = JSONProtoPayloadConverter().from_payload( + ref_payload, ExternalStorageReference + ) + assert isinstance(ref, ExternalStorageReference) + assert ref.driver_name == driver.name() + assert "key" in ref.claim_data + + decoded = await converter.decode(encoded, [str]) + assert decoded[0] == value + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tests/worker/test_extstore.py b/tests/worker/test_extstore.py index 2265ed8ee..e186f4e67 100644 --- a/tests/worker/test_extstore.py +++ b/tests/worker/test_extstore.py @@ -27,7 +27,6 @@ StorageDriverRetrieveContext, StorageDriverStoreContext, StorageDriverWorkflowInfo, - StorageWarning, ) from temporalio.exceptions import ActivityError, ApplicationError from temporalio.testing._workflow import WorkflowEnvironment @@ -406,23 +405,18 @@ async def test_replay_extstore_history_fails_without_extstore( ) history = await handle.fetch_history() - # Replay without external storage — the reference payload cannot be decoded. - # The middleware emits a StorageWarning when it encounters a reference payload - # with no driver configured. - with pytest.warns( - StorageWarning, - match=r"^\[TMPRL1105\] Detected externally stored payload\(s\) but external storage is not configured\.$", - ): - result = await Replayer(workflows=[ExtStoreWorkflow]).replay_workflow( - history, raise_on_replay_failure=False - ) - # Must be a task-failure RuntimeError, not a NondeterminismError — external - # storage decode failures are distinct from workflow code changes. + # Replay without external storage: decode_activation raises when it + # encounters a reference payload with no driver configured, producing a + # task failure (not a NondeterminismError). + result = await Replayer(workflows=[ExtStoreWorkflow]).replay_workflow( + history, raise_on_replay_failure=False + ) assert isinstance(result.replay_failure, RuntimeError) assert not isinstance(result.replay_failure, workflow.NondeterminismError) - # The message is the full activation-completion failure string; the - # "Failed decoding arguments" text from _convert_payloads is embedded in it. - assert "Failed decoding arguments" in result.replay_failure.args[0] + assert ( + "[TMPRL1105] Detected externally stored payload(s) but external storage is not configured." + in result.replay_failure.args[0] + ) async def test_replay_extstore_history_succeeds_with_correct_extstore( @@ -483,9 +477,9 @@ async def test_replay_extstore_history_fails_with_empty_driver( async def test_replay_extstore_activity_result_fails_without_extstore( env: WorkflowEnvironment, ) -> None: - """A history where only the activity result was stored externally (the - workflow input is small enough to be inline) also fails to replay without - external storage — verifying that mid-workflow decode failures are caught.""" + """A history where only the activity result was stored externally also fails + to replay without external storage, verifying that mid-workflow reference + payloads are caught regardless of whether the workflow uses the result.""" driver = InMemoryTestDriver() handle = await _run_extstore_workflow_and_fetch_history( env, @@ -496,22 +490,17 @@ async def test_replay_extstore_activity_result_fails_without_extstore( history = await handle.fetch_history() # Replay without external storage. The workflow input decodes fine, but - # when the ActivityTaskCompleted result is delivered back to the workflow - # coroutine it cannot be decoded. - with pytest.warns( - StorageWarning, - match=r"^\[TMPRL1105\] Detected externally stored payload\(s\) but external storage is not configured\.$", - ): - result = await Replayer(workflows=[ExtStoreWorkflow]).replay_workflow( - history, raise_on_replay_failure=False - ) - # Mid-workflow decode failure is still a task failure (RuntimeError), not - # nondeterminism. + # decode_activation raises when the ActivityTaskCompleted reference payload + # is encountered, producing a task failure (not a NondeterminismError). + result = await Replayer(workflows=[ExtStoreWorkflow]).replay_workflow( + history, raise_on_replay_failure=False + ) assert isinstance(result.replay_failure, RuntimeError) assert not isinstance(result.replay_failure, workflow.NondeterminismError) - # The message is the full activation-completion failure string; the - # "Failed decoding arguments" text from _convert_payloads is embedded in it. - assert "Failed decoding arguments" in result.replay_failure.args[0] + assert ( + "[TMPRL1105] Detected externally stored payload(s) but external storage is not configured." + in result.replay_failure.args[0] + ) async def test_extstore_chained_activities(