From b728cf34eb01ca71d77a3fd264faad9266705ffd Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Tue, 19 May 2026 07:58:28 +0000 Subject: [PATCH] fix: raise on column count mismatch in verify_arrow_result positional branch --- .../tests/arrow/test_arrow_cogrouped_map.py | 28 ++++++++++++++++ .../sql/tests/arrow/test_arrow_grouped_map.py | 33 +++++++++++++++++++ python/pyspark/worker.py | 8 +++++ 3 files changed, 69 insertions(+) diff --git a/python/pyspark/sql/tests/arrow/test_arrow_cogrouped_map.py b/python/pyspark/sql/tests/arrow/test_arrow_cogrouped_map.py index cfeba6cbc3162..5b272f89bb5d9 100644 --- a/python/pyspark/sql/tests/arrow/test_arrow_cogrouped_map.py +++ b/python/pyspark/sql/tests/arrow/test_arrow_cogrouped_map.py @@ -230,6 +230,34 @@ def stats(key, left, right): # stats returns three columns while here we set schema with two columns self.cogrouped.applyInArrow(stats, schema="id long, m double").collect() + def test_apply_in_arrow_returning_wrong_column_count_positional_assignment(self): + def too_many_cols(key, left, right): + return pa.Table.from_pydict( + { + "a": [key[0].as_py()], + "b": [pc.mean(left.column("v")).as_py()], + "c": [pc.mean(right.column("v")).as_py()], + } + ) + + def too_few_cols(key, left, right): + return pa.Table.from_pydict({"a": [key[0].as_py()]}) + + with self.sql_conf( + {"spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName": False} + ): + with self.quiet(): + for func, expected, actual in [ + (too_many_cols, 2, 3), + (too_few_cols, 2, 1), + ]: + with self.subTest(func=func.__name__): + with self.assertRaisesRegex( + PythonException, + rf"Expected: {expected}.*Actual: {actual}", + ): + self.cogrouped.applyInArrow(func, schema="a long, b double").collect() + def test_apply_in_arrow_returning_empty_dataframe(self): def odd_means(key, left, right): if key[0].as_py() == 0: diff --git a/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py b/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py index cefce8d0cf656..e0d40cfebe59e 100644 --- a/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py +++ b/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py @@ -254,6 +254,39 @@ def stats(key, table): func_variation, schema="id long, m double" ).collect() + def test_apply_in_arrow_returning_wrong_column_count_positional_assignment(self): + df = self.data + + def too_many_cols(key, table): + return pa.Table.from_pydict( + { + "a": [key[0].as_py()], + "b": [pc.mean(table.column("v")).as_py()], + "c": [pc.stddev(table.column("v")).as_py()], + } + ) + + def too_few_cols(key, table): + return pa.Table.from_pydict({"a": [key[0].as_py()]}) + + with self.sql_conf( + {"spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName": False} + ): + with self.quiet(): + for func, expected, actual in [ + (too_many_cols, 2, 3), + (too_few_cols, 2, 1), + ]: + with self.subTest(func=func.__name__): + for func_variation in function_variations(func): + with self.assertRaisesRegex( + PythonException, + rf"Expected: {expected}.*Actual: {actual}", + ): + df.groupby("id").applyInArrow( + func_variation, schema="a long, b double" + ).collect() + def test_apply_in_arrow_returning_empty_dataframe(self): df = self.data diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 2d877565f55cd..2accab06e2d6e 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -547,6 +547,14 @@ def verify_arrow_result(result, assign_cols_by_name, expected_cols_and_types): actual_cols_and_types = [ (name, dataType) for name, dataType in zip(result.schema.names, result.schema.types) ] + if len(actual_cols_and_types) != len(expected_cols_and_types): + raise PySparkRuntimeError( + errorClass="RESULT_COLUMN_SCHEMA_MISMATCH", + messageParameters={ + "expected": str(len(expected_cols_and_types)), + "actual": str(len(actual_cols_and_types)), + }, + ) column_types = [ (expected_name, expected_type, actual_type) for (expected_name, expected_type), (actual_name, actual_type) in zip(