diff --git a/pyiceberg/expressions/__init__.py b/pyiceberg/expressions/__init__.py index 3910a146..ef4cb250 100644 --- a/pyiceberg/expressions/__init__.py +++ b/pyiceberg/expressions/__init__.py @@ -17,6 +17,7 @@ from __future__ import annotations +import copy from abc import ABC, abstractmethod from collections.abc import Callable, Iterable, Sequence from functools import cached_property @@ -339,6 +340,10 @@ def __invert__(self) -> BooleanExpression: # De Morgan's law: not (A and B) = (not A) or (not B) return Or(~self.left, ~self.right) + def __deepcopy__(self, memo: dict[int, Any]) -> And: + """Return a deep copy of the And expression.""" + return And(copy.deepcopy(self.left, memo), copy.deepcopy(self.right, memo)) + def __getnewargs__(self) -> tuple[BooleanExpression, BooleanExpression]: """Pickle the And class.""" return (self.left, self.right) @@ -386,6 +391,10 @@ def __invert__(self) -> BooleanExpression: # De Morgan's law: not (A or B) = (not A) and (not B) return And(~self.left, ~self.right) + def __deepcopy__(self, memo: dict[int, Any]) -> Or: + """Return a deep copy of the Or expression.""" + return Or(copy.deepcopy(self.left, memo), copy.deepcopy(self.right, memo)) + def __getnewargs__(self) -> tuple[BooleanExpression, BooleanExpression]: """Pickle the Or class.""" return (self.left, self.right) @@ -428,6 +437,10 @@ def __invert__(self) -> BooleanExpression: """Transform the Expression into its negated version.""" return self.child + def __deepcopy__(self, memo: dict[int, Any]) -> Not: + """Return a deep copy of the Not expression.""" + return Not(copy.deepcopy(self.child, memo)) + def __getnewargs__(self) -> tuple[BooleanExpression]: """Pickle the Not class.""" return (self.child,) diff --git a/tests/expressions/test_expressions.py b/tests/expressions/test_expressions.py index 157c1ada..fde69a46 100644 --- a/tests/expressions/test_expressions.py +++ b/tests/expressions/test_expressions.py @@ -16,6 +16,7 @@ # under the License. # pylint:disable=redefined-outer-name,eval-used +import copy import pickle import uuid from decimal import Decimal @@ -1292,6 +1293,78 @@ def test_bind_ambiguous_name() -> None: assert "Invalid schema, multiple fields for name foo.bar: 2 and 3" in str(exc_info) +# --- deepcopy tests --- + + +def test_deepcopy_and() -> None: + expr = And(EqualTo("x", 1), EqualTo("y", 2)) + copied = copy.deepcopy(expr) + assert copied == expr + assert copied is not expr + + +def test_deepcopy_or() -> None: + expr = Or(EqualTo("x", 1), EqualTo("y", 2)) + copied = copy.deepcopy(expr) + assert copied == expr + assert copied is not expr + + +def test_deepcopy_not() -> None: + expr = Not(EqualTo("x", 1)) + copied = copy.deepcopy(expr) + assert copied == expr + assert copied is not expr + + +def test_deepcopy_equal_to() -> None: + expr = EqualTo("x", 1) + copied = copy.deepcopy(expr) + assert copied == expr + assert copied is not expr + + +def test_deepcopy_always_true() -> None: + copied = copy.deepcopy(AlwaysTrue()) + assert copied is AlwaysTrue() + + +def test_deepcopy_always_false() -> None: + copied = copy.deepcopy(AlwaysFalse()) + assert copied is AlwaysFalse() + + +def test_deepcopy_always_true_then_pickle() -> None: + copied = copy.deepcopy(AlwaysTrue()) + restored = pickle.loads(pickle.dumps(copied)) + assert restored is AlwaysTrue() + + +def test_deepcopy_balanced_and() -> None: + expr = And(EqualTo("a", 1), EqualTo("b", 2), EqualTo("c", 3), EqualTo("d", 4)) + copied = copy.deepcopy(expr) + assert copied == expr + + +def test_deepcopy_balanced_or() -> None: + expr = Or(EqualTo("a", 1), EqualTo("b", 2), EqualTo("c", 3), EqualTo("d", 4)) + copied = copy.deepcopy(expr) + assert copied == expr + + +def test_deepcopy_nested_expression() -> None: + expr = And(Or(EqualTo("a", 1), EqualTo("b", 2)), Not(EqualTo("c", 3))) + copied = copy.deepcopy(expr) + assert copied == expr + + +def test_deepcopy_then_pickle() -> None: + expr = And(EqualTo("x", 1), EqualTo("y", 2)) + copied = copy.deepcopy(expr) + restored = pickle.loads(pickle.dumps(copied)) + assert restored == expr + + # __ __ ___ # | \/ |_ _| _ \_ _ # | |\/| | || | _/ || |