Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
## New Features / Improvements

* Added support for large pipeline options via a file (Python) ([#37370](https://github.com/apache/beam/issues/37370)).
* Added support for inferring schemas from dataclasses (Python) ([#22085](https://github.com/apache/beam/issues/22085)). The default coder for non-frozen dataclasses used as type hints (or set via `with_output_type`) changed to RowCoder. To preserve the old behavior (the fast primitives coder), explicitly register the type with `FastPrimitivesCoder`.

## Breaking Changes

Expand Down
2 changes: 1 addition & 1 deletion sdks/python/apache_beam/coders/coder_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ def encode_special_deterministic(self, value, stream):
stream.write_byte(PROTO_TYPE)
self.encode_type(type(value), stream)
stream.write(value.SerializePartialToString(deterministic=True), True)
elif dataclasses and dataclasses.is_dataclass(value):
elif dataclasses.is_dataclass(value):
if not type(value).__dataclass_params__.frozen:
raise TypeError(
"Unable to deterministically encode non-frozen '%s' of type '%s' "
Expand Down
46 changes: 44 additions & 2 deletions sdks/python/apache_beam/typehints/native_type_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,50 @@ def match_is_named_tuple(user_type):
hasattr(user_type, '__annotations__') and hasattr(user_type, '_fields'))


def match_is_dataclass(user_type):
return dataclasses.is_dataclass(user_type) and isinstance(user_type, type)
def match_dataclass_for_row(user_type):
  """Match whether the type is a dataclass handled by row coder.

  For frozen dataclasses, only true when explicitly registered with row coder:

    beam.coders.typecoders.registry.register_coder(
        MyDataClass, beam.coders.RowCoder)

  (for backward-compatibility reason).

  For non-frozen dataclasses, default to true otherwise explicitly registered
  with a coder other than the row coder.
  """

  # Only dataclass *types* qualify. dataclasses.is_dataclass() is also true
  # for dataclass instances; excluding them mirrors the guard the previous
  # match_is_dataclass() had, and avoids a TypeError below when probing the
  # coder registry with an unhashable (non-frozen) instance.
  if not (dataclasses.is_dataclass(user_type) and
          isinstance(user_type, type)):
    return False

  # pylint: disable=wrong-import-position
  try:
    from apache_beam.options.pipeline_options_context import get_pipeline_options  # pylint: disable=line-too-long
  except (ImportError, AttributeError):
    # The import can fail while apache_beam is still initializing (circular
    # import); `from X import Y` raises ImportError in that case, not only
    # AttributeError. Fall through and use the current (non-compat) behavior.
    pass
  else:
    opts = get_pipeline_options()
    if opts and opts.is_compat_version_prior_to("2.73.0"):
      # Update-compatibility mode: before 2.73.0 dataclasses never defaulted
      # to the row coder.
      return False

  is_frozen = user_type.__dataclass_params__.frozen
  # avoid circular import
  try:
    from apache_beam.coders.typecoders import registry as coders_registry
    from apache_beam.coders import RowCoder
  except (ImportError, AttributeError):
    # Coder registry not yet initialized, so no explicit registration can
    # exist; apply the frozen/non-frozen defaults.
    return not is_frozen

  if is_frozen:
    # Frozen dataclasses are opt-in: row coder only when explicitly
    # registered (backward compatibility with the legacy default coder).
    return (
        user_type in coders_registry._coders and
        coders_registry._coders[user_type] == RowCoder)
  else:
    # Non-frozen dataclasses are opt-out: row coder unless explicitly
    # registered with some other coder.
    return (
        user_type not in coders_registry._coders or
        coders_registry._coders[user_type] == RowCoder)


def _match_is_optional(user_type):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,17 @@
# pytype: skip-file

import collections.abc
import dataclasses
import enum
import re
import typing
import unittest

from parameterized import param
from parameterized import parameterized

from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options_context import scoped_pipeline_options
from apache_beam.typehints import typehints
from apache_beam.typehints.native_type_compatibility import convert_builtin_to_typing
from apache_beam.typehints.native_type_compatibility import convert_to_beam_type
Expand All @@ -33,6 +39,7 @@
from apache_beam.typehints.native_type_compatibility import convert_to_python_types
from apache_beam.typehints.native_type_compatibility import convert_typing_to_builtin
from apache_beam.typehints.native_type_compatibility import is_any
from apache_beam.typehints.native_type_compatibility import match_dataclass_for_row

_TestNamedTuple = typing.NamedTuple(
'_TestNamedTuple', [('age', int), ('name', bytes)])
Expand Down Expand Up @@ -509,6 +516,58 @@ def test_type_alias_type_unwrapped(self):
self.assertEqual(
typehints.Tuple[int, ...], convert_to_beam_type(AliasTuple))

def test_dataclass_default(self):
  """Without registration, only non-frozen dataclasses match row coder."""
  @dataclasses.dataclass(frozen=True)
  class FrozenDC:
    foo: int

  @dataclasses.dataclass
  class NonFrozenDC:
    foo: int

  observed = (
      match_dataclass_for_row(FrozenDC), match_dataclass_for_row(NonFrozenDC))
  self.assertEqual((False, True), observed)

def test_dataclass_registered(self):
  """Explicit coder registration overrides the frozen/non-frozen defaults."""
  @dataclasses.dataclass(frozen=True)
  class FrozenRegisteredDC:
    foo: int

  @dataclasses.dataclass
  class NonFrozenRegisteredDC:
    foo: int

  # pylint: disable=wrong-import-position
  from apache_beam.coders import RowCoder
  from apache_beam.coders import typecoders
  from apache_beam.coders.coders import FastPrimitivesCoder

  registry = typecoders.registry
  registry.register_coder(FrozenRegisteredDC, RowCoder)
  registry.register_coder(NonFrozenRegisteredDC, FastPrimitivesCoder)

  # Registration flips the default outcome in both directions.
  self.assertTrue(match_dataclass_for_row(FrozenRegisteredDC))
  self.assertFalse(match_dataclass_for_row(NonFrozenRegisteredDC))

@parameterized.expand([
    param(compat_version="2.72.0"),
    param(compat_version="2.73.0"),
])
def test_dataclass_update_compatibility(self, compat_version):
  """Compat versions before 2.73.0 disable row coder for dataclasses."""
  @dataclasses.dataclass(frozen=True)
  class FrozenDC:
    foo: int

  @dataclasses.dataclass
  class NonFrozenDC:
    foo: int

  opts = PipelineOptions(update_compatibility_version=compat_version)
  with scoped_pipeline_options(opts):
    # Frozen dataclasses never match by default, regardless of version.
    self.assertFalse(match_dataclass_for_row(FrozenDC))
    # Non-frozen dataclasses match only from 2.73.0 onward.
    expect_match = compat_version == "2.73.0"
    self.assertEqual(expect_match, match_dataclass_for_row(NonFrozenDC))


# Run the tests in this module when executed directly.
if __name__ == '__main__':
  unittest.main()
7 changes: 5 additions & 2 deletions sdks/python/apache_beam/typehints/row_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from typing import Tuple

from apache_beam.typehints import typehints
from apache_beam.typehints.native_type_compatibility import match_is_dataclass
from apache_beam.typehints.native_type_compatibility import match_dataclass_for_row
from apache_beam.typehints.native_type_compatibility import match_is_named_tuple
from apache_beam.typehints.schema_registry import SchemaTypeRegistry

Expand Down Expand Up @@ -91,6 +91,9 @@ def __init__(
# Currently registration happens when converting to schema protos, in
# apache_beam.typehints.schemas
self._schema_id = getattr(self._user_type, _BEAM_SCHEMA_ID, None)
if self._schema_id and _BEAM_SCHEMA_ID not in self._user_type.__dict__:
# schema id does not inherit. Unset if schema id is from base class
self._schema_id = None

self._schema_options = schema_options or []
self._field_options = field_options or {}
Expand All @@ -105,7 +108,7 @@ def from_user_type(
if match_is_named_tuple(user_type):
fields = [(name, user_type.__annotations__[name])
for name in user_type._fields]
elif match_is_dataclass(user_type):
elif match_dataclass_for_row(user_type):
fields = [(field.name, field.type)
for field in dataclasses.fields(user_type)]
else:
Expand Down
89 changes: 89 additions & 0 deletions sdks/python/apache_beam/typehints/row_type_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.typehints import row_type
from apache_beam.typehints import schemas


class RowTypeTest(unittest.TestCase):
Expand Down Expand Up @@ -85,6 +86,94 @@ def generate(num: int):
| 'Count Elements' >> beam.Map(self._check_key_type_and_count))
assert_that(result, equal_to([10] * 100))

def test_group_by_key_namedtuple_union(self):
  """GroupByKey works when the key type hint is a union of NamedTuples."""
  Tuple1 = typing.NamedTuple("Tuple1", [("id", int)])
  Tuple2 = typing.NamedTuple("Tuple2", [("id", int), ("name", str)])

  def emit_pairs(num: int):
    # Two keys of each NamedTuple flavor -> four distinct key groups.
    for i in range(2):
      yield (Tuple1(i), num)
      yield (Tuple2(i, 'a'), num)

  pipeline = TestPipeline(is_integration_test=False)

  with pipeline as p:
    counts = (
        p
        | 'Create' >> beam.Create(list(range(2)))
        | 'Generate' >> beam.ParDo(emit_pairs).with_output_types(
            tuple[(Tuple1 | Tuple2), int])
        | 'GBK' >> beam.GroupByKey()
        | 'Count' >> beam.Map(lambda kv: len(kv[1])))
    # Each of the 4 keys appears once per input element -> 2 values per key.
    assert_that(counts, equal_to([2] * 4))

# Union of dataclasses as a type hint currently results in
# FastPrimitivesCoder, which fails at GBK.
@unittest.skip("https://github.com/apache/beam/issues/22085")
def test_group_by_key_inherited_dataclass_union(self):
  @dataclass
  class DataClassInt:
    id: int

  @dataclass
  class DataClassStr(DataClassInt):
    name: str

  registry = beam.coders.typecoders.registry
  registry.register_coder(DataClassInt, beam.coders.RowCoder)
  registry.register_coder(DataClassStr, beam.coders.RowCoder)

  def emit_pairs(num: int):
    for i in range(10):
      yield (DataClassInt(i), num)
      yield (DataClassStr(i, 'a'), num)

  pipeline = TestPipeline(is_integration_test=False)

  with pipeline as p:
    result = (
        p
        | 'Create' >> beam.Create(list(range(2)))
        | 'Generate' >> beam.ParDo(emit_pairs).with_output_types(
            tuple[(DataClassInt | DataClassStr), int])
        | 'GBK' >> beam.GroupByKey()
        | 'Count Elements' >> beam.Map(self._check_key_type_and_count))
    # NOTE(review): emit_pairs iterates range(10), yet the expectation is
    # [2] * 4, which matches the NamedTuple variant's range(2) -- confirm
    # the intended range when un-skipping this test.
    assert_that(result, equal_to([2] * 4))

def test_derived_dataclass_schema_id(self):
  """A derived dataclass must not inherit _beam_schema_id from its base."""
  @dataclass
  class BaseDataClass:
    id: int

  @dataclass
  class DerivedDataClass(BaseDataClass):
    name: str

  # The schema-id attribute only appears after a schema has been generated.
  self.assertFalse(hasattr(BaseDataClass, row_type._BEAM_SCHEMA_ID))
  base_schema = schemas.schema_from_element_type(BaseDataClass)
  self.assertTrue(hasattr(BaseDataClass, row_type._BEAM_SCHEMA_ID))
  self.assertEqual(
      base_schema.id, getattr(BaseDataClass, row_type._BEAM_SCHEMA_ID))

  # Getting the schema for BaseDataClass sets the _beam_schema_id
  schemas.typing_to_runner_api(
      BaseDataClass, schema_registry=schemas.SchemaTypeRegistry())

  # A RowTypeConstraint built from the subclass must not pick up the
  # schema id that was stamped on the base class.
  constraint = row_type.RowTypeConstraint.from_user_type(DerivedDataClass)
  self.assertIsNone(constraint._schema_id)

  derived_schema = schemas.schema_from_element_type(DerivedDataClass)
  self.assertTrue(hasattr(DerivedDataClass, row_type._BEAM_SCHEMA_ID))
  self.assertEqual(
      derived_schema.id,
      getattr(DerivedDataClass, row_type._BEAM_SCHEMA_ID))
  self.assertNotEqual(derived_schema.id, base_schema.id)


# Run the tests in this module when executed directly.
if __name__ == '__main__':
  unittest.main()
33 changes: 23 additions & 10 deletions sdks/python/apache_beam/typehints/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@
from apache_beam.typehints.native_type_compatibility import _safe_issubclass
from apache_beam.typehints.native_type_compatibility import convert_to_python_type
from apache_beam.typehints.native_type_compatibility import extract_optional_type
from apache_beam.typehints.native_type_compatibility import match_is_dataclass
from apache_beam.typehints.native_type_compatibility import match_dataclass_for_row
from apache_beam.typehints.native_type_compatibility import match_is_named_tuple
from apache_beam.typehints.schema_registry import SCHEMA_REGISTRY
from apache_beam.typehints.schema_registry import SchemaTypeRegistry
Expand Down Expand Up @@ -335,19 +335,23 @@ def typing_to_runner_api(self, type_: type) -> schema_pb2.FieldType:
atomic_type=PRIMITIVE_TO_ATOMIC_TYPE[int])))

elif _safe_issubclass(type_, Sequence) and not _safe_issubclass(type_, str):
element_type = self.typing_to_runner_api(_get_args(type_)[0])
return schema_pb2.FieldType(
array_type=schema_pb2.ArrayType(element_type=element_type))
arg_types = _get_args(type_)
if len(arg_types) > 0:
element_type = self.typing_to_runner_api(arg_types[0])
return schema_pb2.FieldType(
array_type=schema_pb2.ArrayType(element_type=element_type))

elif _safe_issubclass(type_, Mapping):
key_type, value_type = map(self.typing_to_runner_api, _get_args(type_))
return schema_pb2.FieldType(
map_type=schema_pb2.MapType(key_type=key_type, value_type=value_type))

elif _safe_issubclass(type_, Iterable) and not _safe_issubclass(type_, str):
element_type = self.typing_to_runner_api(_get_args(type_)[0])
return schema_pb2.FieldType(
array_type=schema_pb2.ArrayType(element_type=element_type))
arg_types = _get_args(type_)
if len(arg_types) > 0:
element_type = self.typing_to_runner_api(arg_types[0])
return schema_pb2.FieldType(
array_type=schema_pb2.ArrayType(element_type=element_type))

try:
if LogicalType.is_known_logical_type(type_):
Expand Down Expand Up @@ -630,8 +634,10 @@ def schema_from_element_type(element_type: type) -> schema_pb2.Schema:
Returns schema as a list of (name, python_type) tuples"""
if isinstance(element_type, row_type.RowTypeConstraint):
return named_fields_to_schema(element_type._fields)
elif match_is_named_tuple(element_type) or match_is_dataclass(element_type):
if hasattr(element_type, row_type._BEAM_SCHEMA_ID):
elif match_is_named_tuple(element_type) or match_dataclass_for_row(
element_type):
# schema id does not inherit from base classes
if row_type._BEAM_SCHEMA_ID in element_type.__dict__:
# if the named tuple's schema is in registry, we just use it instead of
# regenerating one.
schema_id = getattr(element_type, row_type._BEAM_SCHEMA_ID)
Expand All @@ -657,8 +663,15 @@ def union_schema_type(element_types):
element_types must be a set of schema-aware types whose fields have the
same naming and ordering.
"""
named_fields_and_types = []
for t in element_types:
n = named_fields_from_element_type(t)
if named_fields_and_types and len(named_fields_and_types[-1]) != len(n):
raise TypeError("element types has different number of fields")
named_fields_and_types.append(n)

union_fields_and_types = []
for field in zip(*[named_fields_from_element_type(t) for t in element_types]):
for field in zip(*named_fields_and_types):
names, types = zip(*field)
name_set = set(names)
if len(name_set) != 1:
Expand Down
Loading