diff --git a/python/pyspark/sql/tests/coercion/test_pandas_udf_return_type.py b/python/pyspark/sql/tests/coercion/test_pandas_udf_return_type.py index 454fe726f95cd..f1ba3cd847239 100644 --- a/python/pyspark/sql/tests/coercion/test_pandas_udf_return_type.py +++ b/python/pyspark/sql/tests/coercion/test_pandas_udf_return_type.py @@ -104,7 +104,7 @@ def test_data(self): np.arange(1, 3).astype("complex128"), [np.array([1, 2, 3], dtype=np.int32), np.array([1, 2, 3], dtype=np.int32)], pd.date_range("19700101", periods=2).values, - pd.date_range("19700101", periods=2, tz="US/Eastern").values, + pd.date_range("19700101", periods=2, tz="America/New_York").values, [pd.Timedelta("1 day"), pd.Timedelta("2 days")], pd.Categorical(["A", "B"]), pd.DataFrame({"_1": [1, 2]}), @@ -160,6 +160,49 @@ def _compare_or_generate_golden(self, golden_file, test_name): golden = None if not generating: golden = self.load_golden_csv(golden_csv) + # The golden file was generated under pandas 2; patch the loaded + # copy in memory so the same file works under pandas >= 3.0, where + # the defaults differ: datetime64 ndarrays use [us] instead of [ns], + # Categorical categories use str instead of object, and the same + # casts return microseconds instead of nanoseconds. + if LooseVersion(pd.__version__) >= LooseVersion("3.0.0"): + rename = {} + scale_cols = [] + for value in self.test_data: + new_key = self.repr_value(value) + if isinstance(value, np.ndarray) and value.dtype.kind == "M": + old_key = self.repr_value(value.astype("datetime64[ns]")) + if old_key != new_key: + rename[old_key] = new_key + scale_cols.append(new_key) + elif isinstance(value, pd.Categorical) and value.categories.dtype != object: + old_key = self.repr_value( + pd.Categorical( + value.tolist(), + categories=pd.Index(value.categories.tolist(), dtype=object), + ) + ) + if old_key != new_key: + rename[old_key] = new_key + elif isinstance(value, list) and value and isinstance(value[0], pd.Timedelta): + scale_cols.append(new_key) + + if rename: + golden.rename(columns=rename, inplace=True) + + for col in scale_cols: + golden[col] = golden[col].str.replace( + r"\d{13,}", + lambda m: str(int(m.group()) // 1000), + regex=True, + ) + + # Pandas 3 succeeds at coercing string list -> Decimal where + # pandas 2 errored, so the corresponding cell flips from "X". + decimal_idx = self.repr_type(DecimalType(10, 0)) + decimal_col = self.repr_value(["12", "34"]) + if decimal_idx in golden.index and decimal_col in golden.columns: + golden.loc[decimal_idx, decimal_col] = "[Decimal('12'), Decimal('34')]" def work(arg): spark_type, value = arg