Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion python/pyspark/sql/tests/coercion/test_pandas_udf_return_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def test_data(self):
np.arange(1, 3).astype("complex128"),
[np.array([1, 2, 3], dtype=np.int32), np.array([1, 2, 3], dtype=np.int32)],
pd.date_range("19700101", periods=2).values,
pd.date_range("19700101", periods=2, tz="US/Eastern").values,
pd.date_range("19700101", periods=2, tz="America/New_York").values,
[pd.Timedelta("1 day"), pd.Timedelta("2 days")],
pd.Categorical(["A", "B"]),
pd.DataFrame({"_1": [1, 2]}),
Expand Down Expand Up @@ -160,6 +160,49 @@ def _compare_or_generate_golden(self, golden_file, test_name):
golden = None
if not generating:
golden = self.load_golden_csv(golden_csv)
# The golden file was generated under pandas 2; patch the loaded
# copy in memory so the same file works under pandas >= 3.0, where
# the defaults differ: datetime64 ndarrays use [us] instead of [ns],
# Categorical categories use str instead of object, and the same
# casts return microseconds instead of nanoseconds.
if LooseVersion(pd.__version__) >= LooseVersion("3.0.0"):
rename = {}
scale_cols = []
for value in self.test_data:
new_key = self.repr_value(value)
if isinstance(value, np.ndarray) and value.dtype.kind == "M":
old_key = self.repr_value(value.astype("datetime64[ns]"))
if old_key != new_key:
rename[old_key] = new_key
scale_cols.append(new_key)
elif isinstance(value, pd.Categorical) and value.categories.dtype != object:
old_key = self.repr_value(
pd.Categorical(
value.tolist(),
categories=pd.Index(value.categories.tolist(), dtype=object),
)
)
if old_key != new_key:
rename[old_key] = new_key
elif isinstance(value, list) and value and isinstance(value[0], pd.Timedelta):
scale_cols.append(new_key)

if rename:
golden.rename(columns=rename, inplace=True)

for col in scale_cols:
golden[col] = golden[col].str.replace(
r"\d{13,}",
lambda m: str(int(m.group()) // 1000),
regex=True,
)

# Pandas 3 succeeds at coercing string list -> Decimal where
# pandas 2 errored, so the corresponding cell flips from "X".
decimal_idx = self.repr_type(DecimalType(10, 0))
decimal_col = self.repr_value(["12", "34"])
if decimal_idx in golden.index and decimal_col in golden.columns:
golden.loc[decimal_idx, decimal_col] = "[Decimal('12'), Decimal('34')]"

def work(arg):
spark_type, value = arg
Expand Down