-
-
Notifications
You must be signed in to change notification settings - Fork 19.4k
BUG: ensure to always return new objects in Index set operations (avoid metadata mutation) #63174
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
f8d9770
3e9186d
309af11
9d1c00e
52d01f0
e215763
05c0ace
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -727,7 +727,7 @@ def test_intersection(self, index, sort): | |
|
|
||
| # Corner cases | ||
| inter = first.intersection(first, sort=sort) | ||
| assert inter is first | ||
| assert inter is not first | ||
|
|
||
| @pytest.mark.parametrize( | ||
| "index2_name,keeps_name", | ||
|
|
@@ -812,16 +812,16 @@ def test_union_identity(self, index, sort): | |
| first = index[5:20] | ||
|
|
||
| union = first.union(first, sort=sort) | ||
| # i.e. identity is not preserved when sort is True | ||
| assert (union is first) is (not sort) | ||
| # GH#63169 - identity is not preserved to prevent shared mutable state | ||
| assert union is not first | ||
|
|
||
| # This should no longer be the same object, since [] is not consistent, | ||
| # both objects will be recast to dtype('O') | ||
| union = first.union(Index([], dtype=first.dtype), sort=sort) | ||
| assert (union is first) is (not sort) | ||
| assert union is not first | ||
|
|
||
| union = Index([], dtype=first.dtype).union(first, sort=sort) | ||
| assert (union is first) is (not sort) | ||
| assert union is not first | ||
|
|
||
| @pytest.mark.parametrize("index", ["string"], indirect=True) | ||
| @pytest.mark.parametrize("second_name,expected", [(None, None), ("name", "name")]) | ||
|
|
@@ -984,3 +984,80 @@ def test_union_pyarrow_timestamp(self): | |
| res = left.union(right) | ||
| expected = Index(["2020-01-01", "2020-01-02"], dtype=left.dtype) | ||
| tm.assert_index_equal(res, expected) | ||
|
|
||
|
|
||
| class TestSetOpsMutation: | ||
|
||
| def test_intersection_mutation_safety(self): | ||
| # GH#63169 | ||
| index1 = Index([0, 1], name="original") | ||
| index2 = Index([0, 1], name="original") | ||
|
|
||
| result = index1.intersection(index2) | ||
|
|
||
| assert result is not index1 | ||
| assert result is not index2 | ||
|
|
||
| tm.assert_index_equal(result, index1) | ||
| assert result.name == "original" | ||
|
|
||
| index1.name = "changed" | ||
|
|
||
| assert result.name == "original" | ||
| assert index1.name == "changed" | ||
|
|
||
| def test_union_mutation_safety(self): | ||
| # GH#63169 | ||
| index1 = Index([0, 1], name="original") | ||
| index2 = Index([0, 1], name="original") | ||
|
|
||
| result = index1.union(index2) | ||
|
|
||
| assert result is not index1 | ||
| assert result is not index2 | ||
|
|
||
| tm.assert_index_equal(result, index1) | ||
| assert result.name == "original" | ||
|
|
||
| index1.name = "changed" | ||
|
|
||
| assert result.name == "original" | ||
| assert index1.name == "changed" | ||
|
|
||
| def test_union_mutation_safety_other(self): | ||
| # GH#63169 | ||
| index1 = Index([0, 1], name="original") | ||
| index2 = Index([0, 1], name="original") | ||
|
|
||
| result = index1.union(index2) | ||
|
|
||
| assert result is not index2 | ||
|
|
||
| tm.assert_index_equal(result, index2) | ||
| assert result.name == "original" | ||
|
|
||
| index2.name = "changed" | ||
|
|
||
| assert result.name == "original" | ||
| assert index2.name == "changed" | ||
|
|
||
| def test_multiindex_intersection_mutation_safety(self): | ||
| # GH#63169 | ||
| mi1 = MultiIndex.from_tuples([("a", 1), ("b", 2)], names=["x", "y"]) | ||
| mi2 = MultiIndex.from_tuples([("a", 1), ("b", 2)], names=["x", "y"]) | ||
|
|
||
| result = mi1.intersection(mi2) | ||
| assert result is not mi1 | ||
|
|
||
| mi1.names = ["changed1", "changed2"] | ||
| assert result.names == ["x", "y"] | ||
|
|
||
| def test_multiindex_union_mutation_safety(self): | ||
| # GH#63169 | ||
| mi1 = MultiIndex.from_tuples([("a", 1), ("b", 2)], names=["x", "y"]) | ||
| mi2 = MultiIndex.from_tuples([("a", 1), ("b", 2)], names=["x", "y"]) | ||
|
|
||
| result = mi1.union(mi2) | ||
| assert result is not mi1 | ||
|
|
||
| mi1.names = ["changed1", "changed2"] | ||
| assert result.names == ["x", "y"] | ||
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -114,7 +114,7 @@ def test_intersection_bug_1708(self): | |||||||
|
|
||||||||
| def test_intersection_equal(self, sort): | ||||||||
| # GH 24471 Test intersection outcome given the sort keyword | ||||||||
| # for equal indices intersection should return the original index | ||||||||
| # GH#63169 intersection returns a copy to prevent shared mutable state | ||||||||
| first = timedelta_range("1 day", periods=4, freq="h") | ||||||||
| second = timedelta_range("1 day", periods=4, freq="h") | ||||||||
| intersect = first.intersection(second, sort=sort) | ||||||||
|
|
@@ -124,7 +124,7 @@ def test_intersection_equal(self, sort): | |||||||
|
|
||||||||
| # Corner cases | ||||||||
| inter = first.intersection(first, sort=sort) | ||||||||
| assert inter is first | ||||||||
| assert inter is not first | ||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Just to ensure it are still the same values |
||||||||
|
|
||||||||
| @pytest.mark.parametrize("period_1, period_2", [(0, 4), (4, 0)]) | ||||||||
| def test_intersection_zero_length(self, period_1, period_2, sort): | ||||||||
|
|
||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I know that is the default, but for code readability I might add it explicitly (given the default for Series/DataFrame is True, one might expect that here as well; I was at least confused about when first looking at this)