Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion burr/core/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,16 @@
import textwrap
import types
import typing

# types.UnionType was added in Python 3.10 to represent PEP 604 `X | Y`
# syntax. Burr supports Python >= 3.9, so on 3.9 we fall back to a sentinel
# type that keeps Union annotations well-formed. No PEP 604 union can ever
# exist on 3.9, so isinstance() checks against it simply never match.
if sys.version_info >= (3, 10):
_UnionType = types.UnionType
else: # pragma: no cover - exercised on Python 3.9 CI only
class _UnionType: # type: ignore[no-redef]
"""Placeholder for ``types.UnionType`` on Python < 3.10."""
from collections.abc import AsyncIterator
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -1511,7 +1521,7 @@ def pydantic(
writes: List[str],
state_input_type: Type["BaseModel"],
state_output_type: Type["BaseModel"],
stream_type: Union[Type["BaseModel"], Type[dict]],
stream_type: Union[Type["BaseModel"], Type[dict], _UnionType],
tags: Optional[List[str]] = None,
) -> Callable:
"""Creates a streaming action that uses pydantic models.
Expand Down
5 changes: 3 additions & 2 deletions burr/integrations/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from burr.core.action import (
FunctionBasedAction,
FunctionBasedStreamingAction,
_UnionType,
bind,
derive_inputs_from_fn,
)
Expand Down Expand Up @@ -269,7 +270,7 @@ async def async_action_function(state: State, **kwargs) -> State:
return decorator


PartialType = Union[Type[pydantic.BaseModel], Type[dict]]
PartialType = Union[Type[pydantic.BaseModel], Type[dict], _UnionType]

PydanticStreamingActionFunctionSync = Callable[
..., Generator[Tuple[Union[pydantic.BaseModel, dict], Optional[pydantic.BaseModel]], None, None]
Expand All @@ -290,7 +291,7 @@ async def async_action_function(state: State, **kwargs) -> State:

def _validate_and_extract_signature_types_streaming(
fn: PydanticStreamingActionFunction,
stream_type: Optional[Union[Type[pydantic.BaseModel], Type[dict]]],
stream_type: Optional[Union[Type[pydantic.BaseModel], Type[dict], _UnionType]],
state_input_type: Optional[Type[pydantic.BaseModel]] = None,
state_output_type: Optional[Type[pydantic.BaseModel]] = None,
) -> Tuple[
Expand Down
Loading