diff --git a/CHANGES.md b/CHANGES.md index 5499cb066476..bb8cfe884411 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -69,6 +69,7 @@ ## New Features / Improvements * Added support for large pipeline options via a file (Python) ([#37370](https://github.com/apache/beam/issues/37370)). +* Added support for inferring schemas from dataclasses (Python) ([#22085](https://github.com/apache/beam/issues/22085)). The default coder for non-frozen dataclasses used as type hints (or set via with_output_type) changed to RowCoder. To preserve the old behavior (fast primitives coder), explicitly register the type with FastPrimitivesCoder. ## Breaking Changes diff --git a/sdks/python/apache_beam/coders/coder_impl.py b/sdks/python/apache_beam/coders/coder_impl.py index b3e45bc7f35c..1270b98f9bc4 100644 --- a/sdks/python/apache_beam/coders/coder_impl.py +++ b/sdks/python/apache_beam/coders/coder_impl.py @@ -493,7 +493,7 @@ def encode_special_deterministic(self, value, stream): stream.write_byte(PROTO_TYPE) self.encode_type(type(value), stream) stream.write(value.SerializePartialToString(deterministic=True), True) - elif dataclasses and dataclasses.is_dataclass(value): + elif dataclasses.is_dataclass(value): if not type(value).__dataclass_params__.frozen: raise TypeError( "Unable to deterministically encode non-frozen '%s' of type '%s' " diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility.py b/sdks/python/apache_beam/typehints/native_type_compatibility.py index 886b1505ffec..a8a3ad293254 100644 --- a/sdks/python/apache_beam/typehints/native_type_compatibility.py +++ b/sdks/python/apache_beam/typehints/native_type_compatibility.py @@ -176,8 +176,50 @@ def match_is_named_tuple(user_type): hasattr(user_type, '__annotations__') and hasattr(user_type, '_fields')) -def match_is_dataclass(user_type): - return dataclasses.is_dataclass(user_type) and isinstance(user_type, type) +def match_dataclass_for_row(user_type): + """Match whether the type is a dataclass handled by row coder. 
+ + For frozen dataclasses, only true when explicitly registered with row coder: + + beam.coders.typecoders.registry.register_coder( + MyDataClass, beam.coders.RowCoder) + + (for backward-compatibility reasons). + + For non-frozen dataclasses, defaults to true unless explicitly registered + with a coder other than the row coder. + """ + + if not dataclasses.is_dataclass(user_type): + return False + + # pylint: disable=wrong-import-position + try: + from apache_beam.options.pipeline_options_context import get_pipeline_options # pylint: disable=line-too-long + except ImportError: + pass + else: + opts = get_pipeline_options() + if opts and opts.is_compat_version_prior_to("2.73.0"): + return False + + is_frozen = user_type.__dataclass_params__.frozen + # avoid circular import + try: + from apache_beam.coders.typecoders import registry as coders_registry + from apache_beam.coders import RowCoder + except ImportError: + # coder registry not yet initialized, so treat the type as unregistered + return not is_frozen + + if is_frozen: + return ( + user_type in coders_registry._coders and + coders_registry._coders[user_type] == RowCoder) + else: + return ( + user_type not in coders_registry._coders or + coders_registry._coders[user_type] == RowCoder) def _match_is_optional(user_type): diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility_test.py b/sdks/python/apache_beam/typehints/native_type_compatibility_test.py index e9ce732d2e9b..01f40a29945f 100644 --- a/sdks/python/apache_beam/typehints/native_type_compatibility_test.py +++ b/sdks/python/apache_beam/typehints/native_type_compatibility_test.py @@ -20,11 +20,17 @@ # pytype: skip-file import collections.abc +import dataclasses import enum import re import typing import unittest +from parameterized import param +from parameterized import parameterized + +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options_context import scoped_pipeline_options from 
apache_beam.typehints import typehints from apache_beam.typehints.native_type_compatibility import convert_builtin_to_typing from apache_beam.typehints.native_type_compatibility import convert_to_beam_type @@ -33,6 +39,7 @@ from apache_beam.typehints.native_type_compatibility import convert_to_python_types from apache_beam.typehints.native_type_compatibility import convert_typing_to_builtin from apache_beam.typehints.native_type_compatibility import is_any +from apache_beam.typehints.native_type_compatibility import match_dataclass_for_row _TestNamedTuple = typing.NamedTuple( '_TestNamedTuple', [('age', int), ('name', bytes)]) @@ -509,6 +516,58 @@ def test_type_alias_type_unwrapped(self): self.assertEqual( typehints.Tuple[int, ...], convert_to_beam_type(AliasTuple)) + def test_dataclass_default(self): + @dataclasses.dataclass(frozen=True) + class FrozenDC: + foo: int + + @dataclasses.dataclass + class NonFrozenDC: + foo: int + + self.assertFalse(match_dataclass_for_row(FrozenDC)) + self.assertTrue(match_dataclass_for_row(NonFrozenDC)) + + def test_dataclass_registered(self): + @dataclasses.dataclass(frozen=True) + class FrozenRegisteredDC: + foo: int + + @dataclasses.dataclass + class NonFrozenRegisteredDC: + foo: int + + # pylint: disable=wrong-import-position + from apache_beam.coders import RowCoder + from apache_beam.coders import typecoders + from apache_beam.coders.coders import FastPrimitivesCoder + + typecoders.registry.register_coder(FrozenRegisteredDC, RowCoder) + typecoders.registry.register_coder( + NonFrozenRegisteredDC, FastPrimitivesCoder) + + self.assertTrue(match_dataclass_for_row(FrozenRegisteredDC)) + self.assertFalse(match_dataclass_for_row(NonFrozenRegisteredDC)) + + @parameterized.expand([ + param(compat_version="2.72.0"), + param(compat_version="2.73.0"), + ]) + def test_dataclass_update_compatibility(self, compat_version): + @dataclasses.dataclass(frozen=True) + class FrozenDC: + foo: int + + @dataclasses.dataclass + class NonFrozenDC: + 
foo: int + + with scoped_pipeline_options( + PipelineOptions(update_compatibility_version=compat_version)): + self.assertFalse(match_dataclass_for_row(FrozenDC)) + self.assertEqual( + compat_version == "2.73.0", match_dataclass_for_row(NonFrozenDC)) + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/typehints/row_type.py b/sdks/python/apache_beam/typehints/row_type.py index 6f96f6f64e32..0697581cb435 100644 --- a/sdks/python/apache_beam/typehints/row_type.py +++ b/sdks/python/apache_beam/typehints/row_type.py @@ -27,7 +27,7 @@ from typing import Tuple from apache_beam.typehints import typehints -from apache_beam.typehints.native_type_compatibility import match_is_dataclass +from apache_beam.typehints.native_type_compatibility import match_dataclass_for_row from apache_beam.typehints.native_type_compatibility import match_is_named_tuple from apache_beam.typehints.schema_registry import SchemaTypeRegistry @@ -91,6 +91,9 @@ def __init__( # Currently registration happens when converting to schema protos, in # apache_beam.typehints.schemas self._schema_id = getattr(self._user_type, _BEAM_SCHEMA_ID, None) + if self._schema_id and _BEAM_SCHEMA_ID not in self._user_type.__dict__: + # schema id does not inherit. 
Unset if schema id is from base class + self._schema_id = None self._schema_options = schema_options or [] self._field_options = field_options or {} @@ -105,7 +108,7 @@ def from_user_type( if match_is_named_tuple(user_type): fields = [(name, user_type.__annotations__[name]) for name in user_type._fields] - elif match_is_dataclass(user_type): + elif match_dataclass_for_row(user_type): fields = [(field.name, field.type) for field in dataclasses.fields(user_type)] else: diff --git a/sdks/python/apache_beam/typehints/row_type_test.py b/sdks/python/apache_beam/typehints/row_type_test.py index 73d76fee49ce..97012d9561d7 100644 --- a/sdks/python/apache_beam/typehints/row_type_test.py +++ b/sdks/python/apache_beam/typehints/row_type_test.py @@ -26,6 +26,7 @@ from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to from apache_beam.typehints import row_type +from apache_beam.typehints import schemas class RowTypeTest(unittest.TestCase): @@ -85,6 +86,94 @@ def generate(num: int): | 'Count Elements' >> beam.Map(self._check_key_type_and_count)) assert_that(result, equal_to([10] * 100)) + def test_group_by_key_namedtuple_union(self): + Tuple1 = typing.NamedTuple("Tuple1", [("id", int)]) + + Tuple2 = typing.NamedTuple("Tuple2", [("id", int), ("name", str)]) + + def generate(num: int): + for i in range(2): + yield (Tuple1(i), num) + yield (Tuple2(i, 'a'), num) + + pipeline = TestPipeline(is_integration_test=False) + + with pipeline as p: + result = ( + p + | 'Create' >> beam.Create([i for i in range(2)]) + | 'Generate' >> beam.ParDo(generate).with_output_types( + tuple[(Tuple1 | Tuple2), int]) + | 'GBK' >> beam.GroupByKey() + | 'Count' >> beam.Map(lambda x: len(x[1]))) + assert_that(result, equal_to([2] * 4)) + + # Union of dataclasses as type hint currently result in FastPrimitiveCoder + # fails at GBK + @unittest.skip("https://github.com/apache/beam/issues/22085") + def test_group_by_key_inherited_dataclass_union(self): + @dataclass + class 
DataClassInt: + id: int + + @dataclass + class DataClassStr(DataClassInt): + name: str + + beam.coders.typecoders.registry.register_coder( + DataClassInt, beam.coders.RowCoder) + beam.coders.typecoders.registry.register_coder( + DataClassStr, beam.coders.RowCoder) + + def generate(num: int): + for i in range(10): + yield (DataClassInt(i), num) + yield (DataClassStr(i, 'a'), num) + + pipeline = TestPipeline(is_integration_test=False) + + with pipeline as p: + result = ( + p + | 'Create' >> beam.Create([i for i in range(2)]) + | 'Generate' >> beam.ParDo(generate).with_output_types( + tuple[(DataClassInt | DataClassStr), int]) + | 'GBK' >> beam.GroupByKey() + | 'Count Elements' >> beam.Map(self._check_key_type_and_count)) + assert_that(result, equal_to([2] * 4)) + + def test_derived_dataclass_schema_id(self): + @dataclass + class BaseDataClass: + id: int + + @dataclass + class DerivedDataClass(BaseDataClass): + name: str + + self.assertFalse(hasattr(BaseDataClass, row_type._BEAM_SCHEMA_ID)) + schema_for_base = schemas.schema_from_element_type(BaseDataClass) + self.assertTrue(hasattr(BaseDataClass, row_type._BEAM_SCHEMA_ID)) + self.assertEqual( + schema_for_base.id, getattr(BaseDataClass, row_type._BEAM_SCHEMA_ID)) + + # Getting the schema for BaseDataClass sets the _beam_schema_id + schemas.typing_to_runner_api( + BaseDataClass, schema_registry=schemas.SchemaTypeRegistry()) + + # We create a RowTypeConstraint from DerivedDataClass. + # It should not inherit the _beam_schema_id from BaseDataClass! 
+ derived_row_type = row_type.RowTypeConstraint.from_user_type( + DerivedDataClass) + self.assertIsNone(derived_row_type._schema_id) + + schema_for_derived = schemas.schema_from_element_type(DerivedDataClass) + self.assertTrue(hasattr(DerivedDataClass, row_type._BEAM_SCHEMA_ID)) + self.assertEqual( + schema_for_derived.id, + getattr(DerivedDataClass, row_type._BEAM_SCHEMA_ID)) + self.assertNotEqual(schema_for_derived.id, schema_for_base.id) + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/typehints/schemas.py b/sdks/python/apache_beam/typehints/schemas.py index 5dd8ff290c48..d2c4db8cabca 100644 --- a/sdks/python/apache_beam/typehints/schemas.py +++ b/sdks/python/apache_beam/typehints/schemas.py @@ -96,7 +96,7 @@ from apache_beam.typehints.native_type_compatibility import _safe_issubclass from apache_beam.typehints.native_type_compatibility import convert_to_python_type from apache_beam.typehints.native_type_compatibility import extract_optional_type -from apache_beam.typehints.native_type_compatibility import match_is_dataclass +from apache_beam.typehints.native_type_compatibility import match_dataclass_for_row from apache_beam.typehints.native_type_compatibility import match_is_named_tuple from apache_beam.typehints.schema_registry import SCHEMA_REGISTRY from apache_beam.typehints.schema_registry import SchemaTypeRegistry @@ -335,9 +335,11 @@ def typing_to_runner_api(self, type_: type) -> schema_pb2.FieldType: atomic_type=PRIMITIVE_TO_ATOMIC_TYPE[int]))) elif _safe_issubclass(type_, Sequence) and not _safe_issubclass(type_, str): - element_type = self.typing_to_runner_api(_get_args(type_)[0]) - return schema_pb2.FieldType( - array_type=schema_pb2.ArrayType(element_type=element_type)) + arg_types = _get_args(type_) + if len(arg_types) > 0: + element_type = self.typing_to_runner_api(arg_types[0]) + return schema_pb2.FieldType( + array_type=schema_pb2.ArrayType(element_type=element_type)) elif _safe_issubclass(type_, Mapping): 
key_type, value_type = map(self.typing_to_runner_api, _get_args(type_)) @@ -345,9 +347,11 @@ def typing_to_runner_api(self, type_: type) -> schema_pb2.FieldType: map_type=schema_pb2.MapType(key_type=key_type, value_type=value_type)) elif _safe_issubclass(type_, Iterable) and not _safe_issubclass(type_, str): - element_type = self.typing_to_runner_api(_get_args(type_)[0]) - return schema_pb2.FieldType( - array_type=schema_pb2.ArrayType(element_type=element_type)) + arg_types = _get_args(type_) + if len(arg_types) > 0: + element_type = self.typing_to_runner_api(arg_types[0]) + return schema_pb2.FieldType( + array_type=schema_pb2.ArrayType(element_type=element_type)) try: if LogicalType.is_known_logical_type(type_): @@ -630,8 +634,10 @@ def schema_from_element_type(element_type: type) -> schema_pb2.Schema: Returns schema as a list of (name, python_type) tuples""" if isinstance(element_type, row_type.RowTypeConstraint): return named_fields_to_schema(element_type._fields) - elif match_is_named_tuple(element_type) or match_is_dataclass(element_type): - if hasattr(element_type, row_type._BEAM_SCHEMA_ID): + elif match_is_named_tuple(element_type) or match_dataclass_for_row( + element_type): + # schema id does not inherit from base classes + if row_type._BEAM_SCHEMA_ID in element_type.__dict__: # if the named tuple's schema is in registry, we just use it instead of # regenerating one. schema_id = getattr(element_type, row_type._BEAM_SCHEMA_ID) @@ -657,8 +663,15 @@ def union_schema_type(element_types): element_types must be a set of schema-aware types whose fields have the same naming and ordering. 
""" + named_fields_and_types = [] + for t in element_types: + n = named_fields_from_element_type(t) + if named_fields_and_types and len(named_fields_and_types[-1]) != len(n): + raise TypeError("element types has different number of fields") + named_fields_and_types.append(n) + union_fields_and_types = [] - for field in zip(*[named_fields_from_element_type(t) for t in element_types]): + for field in zip(*named_fields_and_types): names, types = zip(*field) name_set = set(names) if len(name_set) != 1: