Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 18 additions & 15 deletions temporalio/converter/_data_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
import temporalio.api.common.v1
import temporalio.api.failure.v1
import temporalio.common
from temporalio.api.sdk.v1.external_storage_pb2 import ExternalStorageReference
from temporalio.converter._extstore import (
_REFERENCE_ENCODING,
ExternalStorage,
StorageDriverStoreContext,
StorageWarning,
)
from temporalio.converter._failure_converter import (
FailureConverter,
Expand All @@ -41,6 +41,17 @@
WithSerializationContext,
)

_REFERENCE_MESSAGE_TYPE = ExternalStorageReference.DESCRIPTOR.full_name.encode()


def _is_reference_payload(p: temporalio.api.common.v1.Payload) -> bool:
    """Return True if *p* is an external-storage reference payload.

    A payload counts as a reference in either of two forms: the legacy
    dedicated reference encoding, or a ``json/protobuf`` payload whose
    ``messageType`` metadata names the ``ExternalStorageReference`` proto.
    """
    encoding = p.metadata.get("encoding")
    if encoding == _REFERENCE_ENCODING:
        return True
    return (
        encoding == b"json/protobuf"
        and p.metadata.get("messageType") == _REFERENCE_MESSAGE_TYPE
    )


# Import defaults from public API to avoid pydoctor cross-reference issues
if TYPE_CHECKING:
from temporalio.converter import DefaultFailureConverter, DefaultPayloadConverter
Expand Down Expand Up @@ -307,13 +318,9 @@ async def _transform_inbound_payloads(
if self.external_storage:
await self.external_storage._retrieve_payloads(payloads)
else:
if any(
p.metadata.get("encoding") == _REFERENCE_ENCODING
for p in payloads.payloads
):
warnings.warn(
"[TMPRL1105] Detected externally stored payload(s) but external storage is not configured.",
StorageWarning,
if any(_is_reference_payload(p) for p in payloads.payloads):
raise RuntimeError(
"[TMPRL1105] Detected externally stored payload(s) but external storage is not configured."
)
if self.payload_codec:
await self.payload_codec.decode_wrapper(payloads)
Expand Down Expand Up @@ -348,13 +355,9 @@ async def _external_retrieve_payload_sequence(
retrieved_payloads
)
else:
if any(
p.metadata.get("encoding") == _REFERENCE_ENCODING
for p in retrieved_payloads
):
warnings.warn(
"[TMPRL1105] Detected externally stored payload(s) but external storage is not configured.",
StorageWarning,
if any(_is_reference_payload(p) for p in retrieved_payloads):
raise RuntimeError(
"[TMPRL1105] Detected externally stored payload(s) but external storage is not configured."
)
return retrieved_payloads

Expand Down
68 changes: 45 additions & 23 deletions temporalio/converter/_extstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
from typing_extensions import Self

from temporalio.api.common.v1 import Payload, Payloads
from temporalio.converter._payload_converter import JSONPlainPayloadConverter
from temporalio.api.sdk.v1.external_storage_pb2 import ExternalStorageReference
from temporalio.converter._payload_converter import (
JSONPlainPayloadConverter,
JSONProtoPayloadConverter,
)

_T = TypeVar("_T")

Expand Down Expand Up @@ -225,6 +229,11 @@ class StorageWarning(RuntimeWarning):

@dataclass(frozen=True)
class _StorageReference:
"""Legacy external storage reference used only on the retrieval path as a
fallback for in-flight workflows that were written before the
ExternalStorageReference proto was introduced.
"""

driver_name: str
driver_claim: StorageDriverClaim

Expand Down Expand Up @@ -278,8 +287,9 @@ class ExternalStorage:
)
"""Store context bound to this instance via :meth:`_with_store_context`."""

_claim_converter: ClassVar[JSONPlainPayloadConverter] = JSONPlainPayloadConverter(
encoding=_REFERENCE_ENCODING.decode()
_claim_converter: ClassVar[JSONProtoPayloadConverter] = JSONProtoPayloadConverter()
_legacy_claim_converter: ClassVar[JSONPlainPayloadConverter] = (
JSONPlainPayloadConverter(encoding=_REFERENCE_ENCODING.decode())
)

def __post_init__(self) -> None:
Expand Down Expand Up @@ -357,9 +367,9 @@ async def _store_payload(self, payload: Payload) -> Payload:
self._validate_claim_length(claims, expected=1, driver=driver)

external_size = payload.ByteSize()
reference = _StorageReference(
reference = ExternalStorageReference(
driver_name=driver.name(),
driver_claim=claims[0],
claim_data=claims[0].claim_data,
Comment thread
jmaeagle99 marked this conversation as resolved.
)
reference_payload = self._claim_converter.to_payload(reference)
if reference_payload is None:
Expand Down Expand Up @@ -421,9 +431,9 @@ async def _store_payload_sequence(
self._validate_claim_length(claims, expected=len(indices), driver=driver)

for i, claim in enumerate(claims):
reference = _StorageReference(
reference = ExternalStorageReference(
driver_name=driver.name(),
driver_claim=claim,
claim_data=claim.claim_data,
)
reference_payload = self._claim_converter.to_payload(reference)
if reference_payload is None:
Expand All @@ -443,20 +453,35 @@ async def _store_payload_sequence(

return results

async def _retrieve_payload(self, payload: Payload) -> Payload:
def _decode_reference(self, payload: Payload) -> ExternalStorageReference | None:
"""Decode an external storage reference from a payload."""
if len(payload.external_payloads) == 0:
return payload

start_time = time.monotonic()
return None
encoding = payload.metadata.get("encoding", b"")
if encoding == _REFERENCE_ENCODING:
legacy = self._legacy_claim_converter.from_payload(
payload, _StorageReference
)
if not isinstance(legacy, _StorageReference):
return None
return ExternalStorageReference(
driver_name=legacy.driver_name,
claim_data=legacy.driver_claim.claim_data,
)
ref = self._claim_converter.from_payload(payload, ExternalStorageReference)
return ref if isinstance(ref, ExternalStorageReference) else None

reference = self._claim_converter.from_payload(payload, _StorageReference)
if not isinstance(reference, _StorageReference):
async def _retrieve_payload(self, payload: Payload) -> Payload:
ref = self._decode_reference(payload)
if ref is None:
return payload

driver = self._get_driver_by_name(reference.driver_name)
start_time = time.monotonic()
Comment thread
jmaeagle99 marked this conversation as resolved.
driver = self._get_driver_by_name(ref.driver_name)
context = StorageDriverRetrieveContext()
claim = StorageDriverClaim(claim_data=dict(ref.claim_data))

stored_payloads = await driver.retrieve(context, [reference.driver_claim])
stored_payloads = await driver.retrieve(context, [claim])

self._validate_payload_length(stored_payloads, expected=1, driver=driver)

Expand Down Expand Up @@ -486,15 +511,12 @@ async def _retrieve_payload_sequence(

driver_claims: dict[StorageDriver, list[tuple[int, StorageDriverClaim]]] = {}
for index, payload in enumerate(payloads):
if len(payload.external_payloads) == 0:
ref = self._decode_reference(payload)
if ref is None:
continue

reference = self._claim_converter.from_payload(payload, _StorageReference)
if not isinstance(reference, _StorageReference):
continue

driver = self._get_driver_by_name(reference.driver_name)
driver_claims.setdefault(driver, []).append((index, reference.driver_claim))
driver = self._get_driver_by_name(ref.driver_name)
claim = StorageDriverClaim(claim_data=dict(ref.claim_data))
driver_claims.setdefault(driver, []).append((index, claim))

if not driver_claims:
return results
Expand Down
141 changes: 124 additions & 17 deletions tests/test_extstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

from temporalio.api.common.v1 import Payload
from temporalio.api.sdk.v1.external_storage_pb2 import ExternalStorageReference
from temporalio.converter import (
DataConverter,
ExternalStorage,
Expand All @@ -16,9 +17,26 @@
StorageDriverRetrieveContext,
StorageDriverStoreContext,
)
from temporalio.converter._extstore import _StorageReference
from temporalio.converter._extstore import _REFERENCE_ENCODING, _StorageReference
from temporalio.converter._payload_converter import JSONProtoPayloadConverter
from temporalio.exceptions import ApplicationError

_legacy_ref_converter = JSONPlainPayloadConverter(encoding=_REFERENCE_ENCODING.decode())


def _make_legacy_payload(
    driver_name: str, claim_data: dict[str, str], size_bytes: int
) -> Payload:
    """Build a reference payload in the legacy ``json/external-storage-reference`` format."""
    # Assemble the legacy dataclass-based reference and serialize it with the
    # legacy converter, exactly as pre-proto SDK versions did.
    reference = _StorageReference(
        driver_name=driver_name,
        driver_claim=StorageDriverClaim(claim_data=claim_data),
    )
    result = _legacy_ref_converter.to_payload(reference)
    assert result is not None
    # Mark the payload as externally stored by recording the external size.
    entry = result.external_payloads.add()
    entry.size_bytes = size_bytes
    return result


class InMemoryTestDriver(StorageDriver):
"""In-memory storage driver for testing."""
Expand Down Expand Up @@ -115,33 +133,27 @@ async def test_extstore_encode_decode(self):
assert driver._retrieve_calls == 1

async def test_extstore_reference_structure(self):
"""Test that external storage creates proper reference structure."""
"""Externalized payloads are written as ExternalStorageReference proto (json/protobuf encoding)."""
converter = DataConverter(
external_storage=ExternalStorage(
drivers=[InMemoryTestDriver("test-driver")],
payload_size_threshold=50,
)
)

# Create large payload
large_value = "x" * 100
encoded = await converter.encode([large_value])

# Verify reference structure
reference_payload = encoded[0]
assert len(reference_payload.external_payloads) > 0
assert reference_payload.metadata.get("encoding") == b"json/protobuf"

# The payload should contain a serialized _ExternalStorageReference
# Deserialize it to verify structure using the same encoding
claim_converter = JSONPlainPayloadConverter(
encoding="json/external-storage-reference"
reference = JSONProtoPayloadConverter().from_payload(
reference_payload, ExternalStorageReference
)
reference = claim_converter.from_payload(reference_payload, _StorageReference)

assert isinstance(reference, _StorageReference)
assert "test-driver" == reference.driver_name
assert isinstance(reference.driver_claim, StorageDriverClaim)
assert "key" in reference.driver_claim.claim_data
assert isinstance(reference, ExternalStorageReference)
assert reference.driver_name == "test-driver"
assert "key" in reference.claim_data

async def test_extstore_composite_conditional(self):
"""Test using multiple drivers based on size."""
Expand Down Expand Up @@ -482,9 +494,10 @@ async def test_selector_always_first_driver_handles_all_stores(self):
assert second._store_calls == 0

# The reference in history names the first driver.
ref = JSONPlainPayloadConverter(
encoding="json/external-storage-reference"
).from_payload(encoded[0], _StorageReference)
ref = JSONProtoPayloadConverter().from_payload(
encoded[0], ExternalStorageReference
)
assert isinstance(ref, ExternalStorageReference)
assert ref.driver_name == "driver-first"

# Retrieval also goes to the first driver.
Expand Down Expand Up @@ -694,5 +707,99 @@ def test_negative_payload_size_threshold_raises(self, threshold: int):
)


class TestBackwardCompat:
    """Tests that the retrieval path handles the legacy ``json/external-storage-reference``
    format for in-flight workflows written before the ExternalStorageReference proto."""

    async def test_legacy_format_single_payload_decode(self):
        """A single payload in the legacy reference format is retrieved correctly."""
        driver = InMemoryTestDriver()

        # Seed the driver's backing store directly, simulating a payload that
        # was externalized by an older SDK using the legacy reference format.
        inner_payload = (await DataConverter().encode(["x" * 200]))[0]
        stored_key = "payload-0"
        driver._storage[stored_key] = inner_payload.SerializeToString()

        legacy_payload = _make_legacy_payload(
            driver_name=driver.name(),
            claim_data={"key": stored_key},
            size_bytes=inner_payload.ByteSize(),
        )

        converter = DataConverter(
            external_storage=ExternalStorage(
                drivers=[driver],
                payload_size_threshold=100,
            )
        )
        decoded = await converter.decode([legacy_payload], [str])
        assert decoded[0] == "x" * 200
        # Exactly one retrieve call proves the legacy reference was recognized.
        assert driver._retrieve_calls == 1

    async def test_legacy_and_new_format_mixed_batch_decode(self):
        """A batch containing legacy-format, new proto-format, and inline payloads
        all decode correctly in a single call."""
        driver = InMemoryTestDriver()
        converter = DataConverter(
            external_storage=ExternalStorage(
                drivers=[driver],
                payload_size_threshold=50,
            )
        )

        # Encode through the converter so the large value is externalized in
        # the new proto format while the small value stays inline.
        new_value = "new-format-value" * 20
        inline_value = "small"
        encoded = await converter.encode([new_value, inline_value])
        new_format_payload = encoded[0]
        inline_payload = encoded[1]
        assert driver._store_calls == 1

        # Hand-craft a legacy-format reference alongside the new-format one.
        legacy_value = "legacy-format-value" * 20
        legacy_inner = (await DataConverter().encode([legacy_value]))[0]
        stored_key = f"payload-{len(driver._storage)}"
        driver._storage[stored_key] = legacy_inner.SerializeToString()
        legacy_payload = _make_legacy_payload(
            driver_name=driver.name(),
            claim_data={"key": stored_key},
            size_bytes=legacy_inner.ByteSize(),
        )

        decoded = await converter.decode(
            [legacy_payload, new_format_payload, inline_payload], [str, str, str]
        )
        assert decoded[0] == legacy_value
        assert decoded[1] == new_value
        assert decoded[2] == inline_value
        # Both external payloads share the same driver and are batched into one retrieve call.
        assert driver._retrieve_calls == 1

    async def test_new_format_encode_round_trips(self):
        """Payloads written with the new ExternalStorageReference format round-trip
        correctly and carry the expected proto encoding."""
        driver = InMemoryTestDriver()
        converter = DataConverter(
            external_storage=ExternalStorage(
                drivers=[driver],
                payload_size_threshold=50,
            )
        )

        value = "round-trip-value" * 20
        encoded = await converter.encode([value])
        ref_payload = encoded[0]

        # New-format references use the generic proto-JSON encoding plus an
        # external_payloads entry marking the payload as externally stored.
        assert ref_payload.metadata.get("encoding") == b"json/protobuf"
        assert len(ref_payload.external_payloads) > 0

        ref = JSONProtoPayloadConverter().from_payload(
            ref_payload, ExternalStorageReference
        )
        assert isinstance(ref, ExternalStorageReference)
        assert ref.driver_name == driver.name()
        assert "key" in ref.claim_data

        decoded = await converter.decode(encoded, [str])
        assert decoded[0] == value


if __name__ == "__main__":
    # Allow running this test module directly, e.g. ``python test_extstore.py``.
    pytest.main([__file__, "-v"])
Loading
Loading