diff --git a/dataframely/_base_schema.py b/dataframely/_base_schema.py
index b875e6b..068943c 100644
--- a/dataframely/_base_schema.py
+++ b/dataframely/_base_schema.py
@@ -119,7 +119,28 @@ def __new__(
         result = Metadata()
         for base in bases:
             result.update(mcs._get_metadata_recursively(base))
-        result.update(mcs._get_metadata(namespace))
+        # Before merging the child namespace, remove any parent columns that the
+        # child explicitly overrides (same attribute name). This allows subclasses to
+        # redefine inherited columns while still detecting genuine alias conflicts.
+        namespace_metadata = mcs._get_metadata(namespace)
+        for attr, value in namespace.items():
+            if not isinstance(value, Column):
+                continue
+            # Walk all parent MROs to find if this attribute was a Column in any
+            # parent class. In multiple-inheritance scenarios, the same attribute
+            # name may appear in more than one base with different aliases.
+            # NOTE(review): definitions shadowed further up an MRO are collected as
+            # well, so their keys are dropped too -- confirm this is intended.
+            keys_to_remove: set[str] = set()
+            for base in bases:
+                for parent_cls in base.__mro__:
+                    parent_col = parent_cls.__dict__.get(attr)
+                    # A missing attribute yields None, which is never a Column.
+                    if isinstance(parent_col, Column):
+                        keys_to_remove.add(parent_col.alias or attr)
+            for parent_key in keys_to_remove:
+                result.columns.pop(parent_key, None)
+        result.update(namespace_metadata)
         namespace[_COLUMN_ATTR] = result.columns
         cls = super().__new__(mcs, name, bases, namespace, *args, **kwargs)
 
diff --git a/tests/schema/test_base.py b/tests/schema/test_base.py
index 6eb2084..eb86ca3 100644
--- a/tests/schema/test_base.py
+++ b/tests/schema/test_base.py
@@ -141,3 +141,23 @@ def test_user_error_polars_datatype_type() -> None:
         class MySchemaWithPolarsDataTypeType(dy.Schema):
             a = dy.Int32(nullable=False)
             b = pl.String  # User error: Used pl.String instead of dy.String()
+
+
+def test_override() -> None:
+    # A subclass may redefine an inherited column under the same attribute name.
+    class FirstSchema(dy.Schema):
+        x = dy.Int64()
+
+    class SecondSchema(FirstSchema):
+        x = dy.Int64(nullable=True)
+
+    first_columns = FirstSchema.columns()
+    second_columns = SecondSchema.columns()
+
+    assert set(first_columns) == {"x"}
+    assert set(second_columns) == {"x"}
+
+    assert first_columns["x"].nullable is False
+    assert second_columns["x"].nullable is True
+
+    assert type(second_columns["x"]) is type(first_columns["x"])