Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions python/pyspark/sql/tests/arrow/test_arrow_cogrouped_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,34 @@ def stats(key, left, right):
# stats returns three columns while here we set schema with two columns
self.cogrouped.applyInArrow(stats, schema="id long, m double").collect()

def test_apply_in_arrow_returning_wrong_column_count_positional_assignment(self):
def too_many_cols(key, left, right):
return pa.Table.from_pydict(
{
"a": [key[0].as_py()],
"b": [pc.mean(left.column("v")).as_py()],
"c": [pc.mean(right.column("v")).as_py()],
}
)

def too_few_cols(key, left, right):
return pa.Table.from_pydict({"a": [key[0].as_py()]})

with self.sql_conf(
{"spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName": False}
):
with self.quiet():
for func, expected, actual in [
(too_many_cols, 2, 3),
(too_few_cols, 2, 1),
]:
with self.subTest(func=func.__name__):
with self.assertRaisesRegex(
PythonException,
rf"Expected: {expected}.*Actual: {actual}",
):
self.cogrouped.applyInArrow(func, schema="a long, b double").collect()

def test_apply_in_arrow_returning_empty_dataframe(self):
def odd_means(key, left, right):
if key[0].as_py() == 0:
Expand Down
33 changes: 33 additions & 0 deletions python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,39 @@ def stats(key, table):
func_variation, schema="id long, m double"
).collect()

def test_apply_in_arrow_returning_wrong_column_count_positional_assignment(self):
df = self.data

def too_many_cols(key, table):
return pa.Table.from_pydict(
{
"a": [key[0].as_py()],
"b": [pc.mean(table.column("v")).as_py()],
"c": [pc.stddev(table.column("v")).as_py()],
}
)

def too_few_cols(key, table):
return pa.Table.from_pydict({"a": [key[0].as_py()]})

with self.sql_conf(
{"spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName": False}
):
with self.quiet():
for func, expected, actual in [
(too_many_cols, 2, 3),
(too_few_cols, 2, 1),
]:
with self.subTest(func=func.__name__):
for func_variation in function_variations(func):
with self.assertRaisesRegex(
PythonException,
rf"Expected: {expected}.*Actual: {actual}",
):
df.groupby("id").applyInArrow(
func_variation, schema="a long, b double"
).collect()

def test_apply_in_arrow_returning_empty_dataframe(self):
df = self.data

Expand Down
8 changes: 8 additions & 0 deletions python/pyspark/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,14 @@ def verify_arrow_result(result, assign_cols_by_name, expected_cols_and_types):
actual_cols_and_types = [
(name, dataType) for name, dataType in zip(result.schema.names, result.schema.types)
]
if len(actual_cols_and_types) != len(expected_cols_and_types):
raise PySparkRuntimeError(
errorClass="RESULT_COLUMN_SCHEMA_MISMATCH",
messageParameters={
"expected": str(len(expected_cols_and_types)),
"actual": str(len(actual_cols_and_types)),
},
)
column_types = [
(expected_name, expected_type, actual_type)
for (expected_name, expected_type), (actual_name, actual_type) in zip(
Expand Down