Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2975,13 +2975,12 @@ def __bool__(self) -> NoReturn:
def _get_reconciled_name_object(self, other):
"""
If the result of a set operation will be self,
return self, unless the name changes, in which
case make a shallow copy of self.
return a shallow copy of self.
"""
name = get_op_result_name(self, other)
if self.name is not name:
return self.rename(name)
return self
return self.copy()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return self.copy()
return self.copy(deep=False)

I know that is the default, but for code readability I might add it explicitly (given that the default for Series/DataFrame is True, one might expect the same here; I, at least, was confused about this when first looking at it).


@final
def _validate_sort_keyword(self, sort) -> None:
Expand Down
5 changes: 2 additions & 3 deletions pandas/core/indexes/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4082,13 +4082,12 @@ def _is_comparable_dtype(self, dtype: DtypeObj) -> bool:
def _get_reconciled_name_object(self, other) -> MultiIndex:
"""
If the result of a set operation will be self,
return self, unless the names change, in which
case make a shallow copy of self.
return a shallow copy of self.
"""
names = self._maybe_match_names(other)
if self.names != names:
return self.rename(names)
return self
return self.copy()

def _maybe_match_names(self, other):
"""
Expand Down
87 changes: 82 additions & 5 deletions pandas/tests/indexes/test_setops.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,7 @@ def test_intersection(self, index, sort):

# Corner cases
inter = first.intersection(first, sort=sort)
assert inter is first
assert inter is not first

@pytest.mark.parametrize(
"index2_name,keeps_name",
Expand Down Expand Up @@ -812,16 +812,16 @@ def test_union_identity(self, index, sort):
first = index[5:20]

union = first.union(first, sort=sort)
# i.e. identity is not preserved when sort is True
assert (union is first) is (not sort)
# GH#63169 - identity is not preserved to prevent shared mutable state
assert union is not first

# This should no longer be the same object, since [] is not consistent,
# both objects will be recast to dtype('O')
union = first.union(Index([], dtype=first.dtype), sort=sort)
assert (union is first) is (not sort)
assert union is not first

union = Index([], dtype=first.dtype).union(first, sort=sort)
assert (union is first) is (not sort)
assert union is not first

@pytest.mark.parametrize("index", ["string"], indirect=True)
@pytest.mark.parametrize("second_name,expected", [(None, None), ("name", "name")])
Expand Down Expand Up @@ -984,3 +984,80 @@ def test_union_pyarrow_timestamp(self):
res = left.union(right)
expected = Index(["2020-01-01", "2020-01-02"], dtype=left.dtype)
tm.assert_index_equal(res, expected)


class TestSetOpsMutation:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'd like to move away from having test classes in cases where the class itself isn't useful. Can you just make these functions?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it. I'll convert them to functions shortly.

def test_intersection_mutation_safety(self):
# GH#63169
# Intersecting two equal, identically named indexes must return a new
# object rather than either of the inputs.
index1 = Index([0, 1], name="original")
index2 = Index([0, 1], name="original")

result = index1.intersection(index2)

assert result is not index1
assert result is not index2

# The result still carries the same values and name as the input.
tm.assert_index_equal(result, index1)
assert result.name == "original"

# Renaming the left operand afterwards must not change the result's name.
index1.name = "changed"

assert result.name == "original"
assert index1.name == "changed"

def test_union_mutation_safety(self):
# GH#63169
# The union of two equal, identically named indexes must return a new
# object rather than either of the inputs.
index1 = Index([0, 1], name="original")
index2 = Index([0, 1], name="original")

result = index1.union(index2)

assert result is not index1
assert result is not index2

# The result still carries the same values and name as the input.
tm.assert_index_equal(result, index1)
assert result.name == "original"

# Renaming the left operand afterwards must not change the result's name.
index1.name = "changed"

assert result.name == "original"
assert index1.name == "changed"

def test_union_mutation_safety_other(self):
# GH#63169
# Same scenario as the union test on the left operand, but checking the
# right-hand operand: the result must not be ``index2`` either.
index1 = Index([0, 1], name="original")
index2 = Index([0, 1], name="original")

result = index1.union(index2)

assert result is not index2

tm.assert_index_equal(result, index2)
assert result.name == "original"

# Renaming the right operand afterwards must not change the result's name.
index2.name = "changed"

assert result.name == "original"
assert index2.name == "changed"

def test_multiindex_intersection_mutation_safety(self):
# GH#63169
# Two equal MultiIndexes with identical level names.
mi1 = MultiIndex.from_tuples([("a", 1), ("b", 2)], names=["x", "y"])
mi2 = MultiIndex.from_tuples([("a", 1), ("b", 2)], names=["x", "y"])

result = mi1.intersection(mi2)
assert result is not mi1

# Renaming the levels of the left operand must not change the result's names.
mi1.names = ["changed1", "changed2"]
assert result.names == ["x", "y"]

def test_multiindex_union_mutation_safety(self):
# GH#63169
# Two equal MultiIndexes with identical level names.
mi1 = MultiIndex.from_tuples([("a", 1), ("b", 2)], names=["x", "y"])
mi2 = MultiIndex.from_tuples([("a", 1), ("b", 2)], names=["x", "y"])

result = mi1.union(mi2)
assert result is not mi1

# Renaming the levels of the left operand must not change the result's names.
mi1.names = ["changed1", "changed2"]
assert result.names == ["x", "y"]
4 changes: 2 additions & 2 deletions pandas/tests/indexes/timedeltas/test_setops.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def test_intersection_bug_1708(self):

def test_intersection_equal(self, sort):
# GH 24471 Test intersection outcome given the sort keyword
# for equal indices intersection should return the original index
# GH#63169 intersection returns a copy to prevent shared mutable state
first = timedelta_range("1 day", periods=4, freq="h")
second = timedelta_range("1 day", periods=4, freq="h")
intersect = first.intersection(second, sort=sort)
Expand All @@ -124,7 +124,7 @@ def test_intersection_equal(self, sort):

# Corner cases
inter = first.intersection(first, sort=sort)
assert inter is first
assert inter is not first
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
assert inter is not first
assert inter is not first
tm.assert_index_equal(inter, first)

Just to ensure the values are still the same.


@pytest.mark.parametrize("period_1, period_2", [(0, 4), (4, 0)])
def test_intersection_zero_length(self, period_1, period_2, sort):
Expand Down
Loading