Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion dataframely/_base_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,25 @@ def __new__(
result = Metadata()
for base in bases:
result.update(mcs._get_metadata_recursively(base))
result.update(mcs._get_metadata(namespace))
# Before merging the child namespace, remove any parent columns that the
# child explicitly overrides (same attribute name). This allows subclasses to
# redefine inherited columns while still detecting genuine alias conflicts.
namespace_metadata = mcs._get_metadata(namespace)
for attr, value in namespace.items():
if not isinstance(value, Column):
continue
# Walk all parent MROs to find if this attribute was a Column in any
# parent class. In multiple-inheritance scenarios, the same attribute
# name may appear in more than one base with different aliases.
keys_to_remove: set[str] = set()
for base in bases:
for parent_cls in base.__mro__:
parent_col = parent_cls.__dict__.get(attr)
if parent_col is not None and isinstance(parent_col, Column):
keys_to_remove.add(parent_col.alias or attr)
for parent_key in keys_to_remove:
result.columns.pop(parent_key, None)
result.update(namespace_metadata)
namespace[_COLUMN_ATTR] = result.columns
cls = super().__new__(mcs, name, bases, namespace, *args, **kwargs)

Expand Down
19 changes: 19 additions & 0 deletions tests/schema/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,22 @@ def test_user_error_polars_datatype_type() -> None:
class MySchemaWithPolarsDataTypeType(dy.Schema):
a = dy.Int32(nullable=False)
b = pl.String # User error: Used pl.String instead of dy.String()


def test_override() -> None:
    """Check that a subclass can redefine a column it inherits.

    ``SecondSchema`` re-declares ``x`` (this time with ``nullable=True``) on
    top of ``FirstSchema``. Both schemas must expose exactly one column
    named ``x``; the parent's column reports ``nullable is False``, the
    child's reports ``nullable is True``, and both columns share the same
    column type.
    """

    class FirstSchema(dy.Schema):
        x = dy.Int64()

    class SecondSchema(FirstSchema):
        x = dy.Int64(nullable=True)

    parent_cols = FirstSchema.columns()
    child_cols = SecondSchema.columns()

    # Neither schema gains or loses column names through the override.
    assert set(parent_cols) == {"x"}
    assert set(child_cols) == {"x"}

    # The parent's definition is untouched; the child's redefinition wins.
    assert parent_cols["x"].nullable is False
    assert child_cols["x"].nullable is True

    # Only the nullability differs -- the column class itself is the same.
    assert type(child_cols["x"]) is type(parent_cols["x"])
Loading