From cb7765bd64e65a0669d68e669bbb8bd36265382b Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Mon, 18 May 2026 04:04:30 +0000 Subject: [PATCH 01/16] [TESTS][CONNECT] Expand Connect-specific tests for DataFrame column resolution Add parity tests in test_connect_column.py and layered programs in test_parity_dataframe.py to lock in known Connect/Classic behavior differences in DataFrame column resolution. Tests cover shadowing, pass-through, aggregation, pivot, set ops, self-join, subquery-as-table patterns under both strict and non-strict modes of spark.sql.analyzer.strictDataFrameColumnResolution, plus three mixed-surface layered programs that combine filters, joins, aggregations, set ops, window functions, UDFs and temp views. Generated-by: Claude Code (Anthropic), claude-opus-4-7 --- .../sql/tests/connect/test_connect_column.py | 240 ++++++++++++++++++ .../tests/connect/test_parity_dataframe.py | 101 ++++++++ 2 files changed, 341 insertions(+) diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py index 6fa6c4686c527..d6afad6162162 100644 --- a/python/pyspark/sql/tests/connect/test_connect_column.py +++ b/python/pyspark/sql/tests/connect/test_connect_column.py @@ -167,6 +167,246 @@ def test_select_column_replaced_by_withcolumn(self): with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): df.withColumn("c", CF.col("c").cast("string")).select(df["c"]).collect() + # --- Connect DataFrame column-resolution divergence parity tests ---------- + # + # These tests pin Connect-specific column-resolution behavior so future + # tightening that aligns Connect with Spark Classic cannot silently regress + # patterns that customer workflows depend on. 
Each test exercises both modes + # of `spark.sql.analyzer.strictDataFrameColumnResolution`: + # + # * strict=true (default): plan-id-based resolution only; tagged + # `df["c"]` references that point at an attribute no longer in the + # current plan fail with CANNOT_RESOLVE_DATAFRAME_COLUMN. + # * strict=false: when plan-id-based resolution fails, the analyzer + # falls back to name-based resolution against the current child output. + # + # The fallback behavior is what AT&T-style workflows historically depended + # on; deletions of these tests should be reviewed as a behavioral change. + + def test_resolve_after_chained_withcolumn_shadow(self): + # Two consecutive withColumn calls each shadow `c` with a new + # attribute carrying the same name. Under non-strict the tagged + # df["c"] falls through to name-based resolution and matches the final + # projected `c`; under strict the tagged ancestor's attribute is no + # longer in the plan and analysis fails. + from pyspark.errors.exceptions.connect import AnalysisException + + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): + df = self.connect.sql("SELECT 1 AS c") + df.withColumn("c", CF.col("c").cast("string")).withColumn( + "c", CF.col("c").cast("int") + ).select(df["c"]).collect() + + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): + df = self.connect.sql("SELECT 1 AS c") + with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): + df.withColumn("c", CF.col("c").cast("string")).withColumn( + "c", CF.col("c").cast("int") + ).select(df["c"]).collect() + + def test_resolve_after_select_alias_shadow(self): + # Same shadowing shape as withColumn but expressed through a select + # with alias. The original `c` attribute is dropped and a new `c` is + # projected with the same name; tagged df["c"] only resolves under + # non-strict via name-based fallback. 
+ from pyspark.errors.exceptions.connect import AnalysisException + + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): + df = self.connect.sql("SELECT 1 AS c") + df.select(df["c"].cast("string").alias("c")).select(df["c"]).collect() + + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): + df = self.connect.sql("SELECT 1 AS c") + with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): + df.select(df["c"].cast("string").alias("c")).select(df["c"]).collect() + + def test_resolve_after_withcolumnrenamed(self): + # withColumnRenamed drops the original `c` attribute and projects it + # as `c2`. The tagged df["c"] cannot resolve in either mode: under + # strict the plan-id ancestor's attribute is gone, and under + # non-strict the name-based fallback also fails because the current + # child output no longer contains a column named `c`. + from pyspark.errors.exceptions.connect import AnalysisException + + for strict in (True, False): + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): + df = self.connect.sql("SELECT 1 AS c") + with self.assertRaises(AnalysisException): + df.withColumnRenamed("c", "c2").select(df["c"]).collect() + + def test_resolve_after_drop(self): + # drop("c") removes the column entirely. Tagged df["c"] cannot resolve + # under either mode: plan-id resolution misses (attribute gone) and + # name-based fallback misses (name gone). 
+ from pyspark.errors.exceptions.connect import AnalysisException + + for strict in (True, False): + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): + df = self.connect.sql("SELECT 1 AS c, 2 AS d") + with self.assertRaises(AnalysisException): + df.drop("c").select(df["c"]).collect() + + def test_resolve_through_filter(self): + # filter is a pass-through operator: the child Project's attributes + # flow through unchanged, so plan-id-based resolution finds the + # original tagged attribute in both modes. + for strict in (True, False): + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): + df = self.connect.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") + rows = df.filter(df["c"] > 0).select(df["c"]).collect() + self.assertEqual(sorted(r.c for r in rows), [1, 2]) + + def test_resolve_through_sort(self): + # sort is also a pass-through operator. + for strict in (True, False): + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): + df = self.connect.sql("SELECT 2 AS c UNION ALL SELECT 1 AS c") + rows = df.sort(df["c"]).select(df["c"]).collect() + self.assertEqual([r.c for r in rows], [1, 2]) + + def test_resolve_through_distinct(self): + # distinct is also a pass-through operator from the perspective of + # attribute identity. + for strict in (True, False): + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): + df = self.connect.sql("SELECT 1 AS c UNION ALL SELECT 1 AS c") + rows = df.distinct().select(df["c"]).collect() + self.assertEqual([r.c for r in rows], [1]) + + def test_resolve_after_groupby_count(self): + # groupBy("c").count() preserves the grouping key by name but emits + # a new aggregate output schema. Under non-strict the tagged df["c"] + # falls back to name-based resolution and matches the grouping key. + # Under strict the tagged attribute id is not in the aggregate output + # and analysis fails. 
+ from pyspark.errors.exceptions.connect import AnalysisException + + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): + df = self.connect.sql("SELECT 1 AS c UNION ALL SELECT 1 AS c UNION ALL SELECT 2 AS c") + rows = df.groupBy("c").count().select(df["c"]).collect() + self.assertEqual(sorted(r.c for r in rows), [1, 2]) + + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): + df = self.connect.sql("SELECT 1 AS c UNION ALL SELECT 1 AS c UNION ALL SELECT 2 AS c") + with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): + df.groupBy("c").count().select(df["c"]).collect() + + def test_resolve_after_agg_alias_shadow(self): + # An aggregate output named `c` via alias() collides with the source + # `c` by name. The tagged df["c"] still references the source + # attribute (which has been aggregated away); non-strict mode falls + # back to name-based resolution and matches the aliased aggregate. + from pyspark.errors.exceptions.connect import AnalysisException + + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): + df = self.connect.sql("SELECT 1 AS x") + df.groupBy().agg(CF.sum("x").alias("c")).select(df["c"]).collect() + + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): + df = self.connect.sql("SELECT 1 AS x") + with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): + df.groupBy().agg(CF.sum("x").alias("c")).select(df["c"]).collect() + + def test_resolve_after_pivot(self): + # pivot emits a new schema whose columns are the pivot values, with + # the grouping key preserved by name. Tagged references to the + # original grouping key resolve under non-strict via name-based + # fallback and fail under strict. 
+ from pyspark.errors.exceptions.connect import AnalysisException + + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): + df = self.connect.sql( + "SELECT 1 AS c, 'a' AS k, 10 AS v UNION ALL SELECT 2 AS c, 'b' AS k, 20 AS v" + ) + df.groupBy("c").pivot("k").sum("v").select(df["c"]).collect() + + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): + df = self.connect.sql( + "SELECT 1 AS c, 'a' AS k, 10 AS v UNION ALL SELECT 2 AS c, 'b' AS k, 20 AS v" + ) + with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): + df.groupBy("c").pivot("k").sum("v").select(df["c"]).collect() + + def test_resolve_after_union(self): + # Union emits new attribute ids that differ from either input's + # attributes. The tagged df1["c"] reference cannot match the union + # output by plan id under strict mode; under non-strict it resolves + # by name to the union's `c` output. + from pyspark.errors.exceptions.connect import AnalysisException + + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): + df1 = self.connect.sql("SELECT 1 AS c") + df2 = self.connect.sql("SELECT 2 AS c") + df1.union(df2).select(df1["c"]).collect() + + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): + df1 = self.connect.sql("SELECT 1 AS c") + df2 = self.connect.sql("SELECT 2 AS c") + with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): + df1.union(df2).select(df1["c"]).collect() + + def test_resolve_after_intersect(self): + # intersect, like union, emits new attribute ids. Tagged df1["c"] + # only matches by name under non-strict. 
+ from pyspark.errors.exceptions.connect import AnalysisException + + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): + df1 = self.connect.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") + df2 = self.connect.sql("SELECT 2 AS c UNION ALL SELECT 3 AS c") + df1.intersect(df2).select(df1["c"]).collect() + + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): + df1 = self.connect.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") + df2 = self.connect.sql("SELECT 2 AS c UNION ALL SELECT 3 AS c") + with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): + df1.intersect(df2).select(df1["c"]).collect() + + def test_resolve_self_join_alias(self): + # In a self-join, both sides originate from the same plan-id-tagged + # ancestor. Connect resolves aliased self-joins by attaching distinct + # plan ids to the aliased DataFrames; the original df["c"] reference + # is ambiguous under strict mode (the plan-id ancestor matches both + # sides). Under non-strict, name-based fallback still hits the + # ambiguity. + from pyspark.errors.exceptions.connect import AnalysisException + + for strict in (True, False): + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): + df = self.connect.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") + a = df.alias("a") + b = df.alias("b") + with self.assertRaises(AnalysisException): + a.join(b, a["c"] == b["c"]).select(df["c"]).collect() + + def test_resolve_after_subquery_view(self): + # Persisting the original DataFrame as a temp view and reading it + # back via spark.table() produces a new plan with new attribute ids. + # The tagged df["c"] reference targets the original plan id, which + # is not an ancestor of the new DataFrame. Under non-strict the + # name-based fallback succeeds; under strict it does not. 
+ import uuid + + from pyspark.errors.exceptions.connect import AnalysisException + + view_name = f"v_{uuid.uuid4().hex}" + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): + df = self.connect.sql("SELECT 1 AS c") + df.createOrReplaceTempView(view_name) + try: + self.connect.table(view_name).select(df["c"]).collect() + finally: + self.connect.sql(f"DROP VIEW IF EXISTS {view_name}") + + view_name = f"v_{uuid.uuid4().hex}" + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): + df = self.connect.sql("SELECT 1 AS c") + df.createOrReplaceTempView(view_name) + try: + with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): + self.connect.table(view_name).select(df["c"]).collect() + finally: + self.connect.sql(f"DROP VIEW IF EXISTS {view_name}") + def test_column_with_null(self): # SPARK-41751: test isNull, isNotNull, eqNullSafe diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py b/python/pyspark/sql/tests/connect/test_parity_dataframe.py index decb027f8e21e..2baf7cf2c4af8 100644 --- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py +++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py @@ -16,6 +16,7 @@ # import unittest +import uuid from pyspark.sql.tests.test_dataframe import DataFrameTestsMixin from pyspark.testing.connectutils import ReusedConnectTestCase @@ -38,6 +39,106 @@ def test_query_execution_unsupported_in_classic(self): def test_to_json(self): pass + # --- Connect-only layered DataFrame regression programs ------------------ + # + # These tests exercise mixed-surface DataFrame pipelines (4-6 chained + # operators combining filters, joins, aggregations, set ops, window + # functions, UDFs and temporary views) end-to-end. The intent is to catch + # regressions in Connect's plan-id propagation through analyzer rules + # that single-operator tests miss when rules interact. 
+ # + # Each program runs under the non-strict mode of + # `spark.sql.analyzer.strictDataFrameColumnResolution` (the historical + # Connect contract that customer workflows depend on) and asserts the + # expected output. Under strict mode the same program either succeeds + # identically (when no tagged reference crosses a shadowing boundary) or + # fails at the first divergence. + # + # Tests live here intentionally: they must not be moved to Classic-shared + # suites where they'd be removed as "diverging from Classic" during + # routine cleanup. + + def test_layered_filter_join_agg_shadow(self): + from pyspark.sql import functions as sf + + with self.sql_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": "false"}): + df = self.spark.createDataFrame( + [(1, 10), (1, 20), (2, 30), (2, 40), (3, 50)], ["c", "v"] + ) + # filter -> self-join on c -> groupBy/agg shadowing c -> withColumn + # shadowing c again -> select(df["c"]) which routes through name + # fallback in lenient mode. + result = ( + df.filter(df["v"] > 10) + .alias("a") + .join(df.alias("b"), sf.col("a.c") == sf.col("b.c")) + .groupBy(sf.col("a.c").alias("c")) + .agg(sf.sum(sf.col("b.v")).alias("s")) + .withColumn("c", sf.col("c").cast("string")) + .select(df["c"], "s") + .collect() + ) + self.assertEqual( + sorted((r.c, r.s) for r in result), + [("1", 60), ("2", 140)], + ) + + def test_layered_temp_view_subquery_udf(self): + from pyspark.sql import functions as sf + from pyspark.sql.types import IntegerType + + with self.sql_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": "false"}): + view_name = f"layered_view_{uuid.uuid4().hex}" + df = self.spark.createDataFrame([(1, 10), (2, 20), (3, 30)], ["c", "v"]) + df.createOrReplaceTempView(view_name) + try: + # createOrReplaceTempView -> SQL with subquery referencing the + # view -> join back to the original DataFrame -> apply a UDF + # over a shadowed column -> select via the tagged df["c"] + # (resolved by name fallback since the view+SQL 
path emits new + # attribute ids). + double_udf = sf.udf(lambda x: x * 2 if x is not None else None, IntegerType()) + sub = self.spark.sql( + f"SELECT c, v FROM {view_name} WHERE v IN (SELECT v FROM {view_name} WHERE c > 1)" + ) + result = ( + sub.join(df, sub["c"] == df["c"]) + .withColumn("v", double_udf(sub["v"])) + .select(df["c"], "v") + .collect() + ) + self.assertEqual( + sorted((r.c, r.v) for r in result), + [(2, 40), (3, 60)], + ) + finally: + self.spark.sql(f"DROP VIEW IF EXISTS {view_name}") + + def test_layered_union_window_pivot_shadow(self): + from pyspark.sql import functions as sf + from pyspark.sql.window import Window + + with self.sql_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": "false"}): + df1 = self.spark.createDataFrame([(1, "a", 10), (1, "b", 20)], ["c", "k", "v"]) + df2 = self.spark.createDataFrame([(2, "a", 30), (2, "b", 40)], ["c", "k", "v"]) + unioned = df1.unionByName(df2) + # union -> window aggregation -> pivot -> withColumn shadow -> + # select via original df1["c"] tagged reference. + w = Window.partitionBy("c") + result = ( + unioned.withColumn("rank_v", sf.row_number().over(w.orderBy("v"))) + .groupBy("c") + .pivot("k", ["a", "b"]) + .sum("v") + .withColumn("c", sf.col("c").cast("string")) + .select(df1["c"], "a", "b") + .collect() + ) + self.assertEqual( + sorted((r.c, r.a, r.b) for r in result), + [("1", 10, 20), ("2", 30, 40)], + ) + if __name__ == "__main__": from pyspark.testing import main From 2a0af6516818377e52c682a769908a47a5de4635 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Mon, 18 May 2026 04:33:35 +0000 Subject: [PATCH 02/16] Group divergence tests in SparkConnectColumnResolutionTests with Classic baselines Move the focused parity tests and the layered programs into a dedicated SparkConnectColumnResolutionTests class in test_connect_column.py. 
Each test now starts with a Classic baseline against self.spark to document how Spark Classic handles the same pattern, followed by Connect strict and Connect lenient blocks. test_parity_dataframe.py is reverted to its prior state. Generated-by: Claude Code (Anthropic), claude-opus-4-7 --- .../sql/tests/connect/test_connect_column.py | 774 ++++++++++++------ .../tests/connect/test_parity_dataframe.py | 101 --- 2 files changed, 534 insertions(+), 341 deletions(-) diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py index d6afad6162162..0e3c8bd9d39c6 100644 --- a/python/pyspark/sql/tests/connect/test_connect_column.py +++ b/python/pyspark/sql/tests/connect/test_connect_column.py @@ -167,246 +167,6 @@ def test_select_column_replaced_by_withcolumn(self): with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): df.withColumn("c", CF.col("c").cast("string")).select(df["c"]).collect() - # --- Connect DataFrame column-resolution divergence parity tests ---------- - # - # These tests pin Connect-specific column-resolution behavior so future - # tightening that aligns Connect with Spark Classic cannot silently regress - # patterns that customer workflows depend on. Each test exercises both modes - # of `spark.sql.analyzer.strictDataFrameColumnResolution`: - # - # * strict=true (default): plan-id-based resolution only; tagged - # `df["c"]` references that point at an attribute no longer in the - # current plan fail with CANNOT_RESOLVE_DATAFRAME_COLUMN. - # * strict=false: when plan-id-based resolution fails, the analyzer - # falls back to name-based resolution against the current child output. - # - # The fallback behavior is what AT&T-style workflows historically depended - # on; deletions of these tests should be reviewed as a behavioral change. 
- - def test_resolve_after_chained_withcolumn_shadow(self): - # Two consecutive withColumn calls each shadow `c` with a new - # attribute carrying the same name. Under non-strict the tagged - # df["c"] falls through to name-based resolution and matches the final - # projected `c`; under strict the tagged ancestor's attribute is no - # longer in the plan and analysis fails. - from pyspark.errors.exceptions.connect import AnalysisException - - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): - df = self.connect.sql("SELECT 1 AS c") - df.withColumn("c", CF.col("c").cast("string")).withColumn( - "c", CF.col("c").cast("int") - ).select(df["c"]).collect() - - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): - df = self.connect.sql("SELECT 1 AS c") - with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): - df.withColumn("c", CF.col("c").cast("string")).withColumn( - "c", CF.col("c").cast("int") - ).select(df["c"]).collect() - - def test_resolve_after_select_alias_shadow(self): - # Same shadowing shape as withColumn but expressed through a select - # with alias. The original `c` attribute is dropped and a new `c` is - # projected with the same name; tagged df["c"] only resolves under - # non-strict via name-based fallback. 
- from pyspark.errors.exceptions.connect import AnalysisException - - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): - df = self.connect.sql("SELECT 1 AS c") - df.select(df["c"].cast("string").alias("c")).select(df["c"]).collect() - - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): - df = self.connect.sql("SELECT 1 AS c") - with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): - df.select(df["c"].cast("string").alias("c")).select(df["c"]).collect() - - def test_resolve_after_withcolumnrenamed(self): - # withColumnRenamed drops the original `c` attribute and projects it - # as `c2`. The tagged df["c"] cannot resolve in either mode: under - # strict the plan-id ancestor's attribute is gone, and under - # non-strict the name-based fallback also fails because the current - # child output no longer contains a column named `c`. - from pyspark.errors.exceptions.connect import AnalysisException - - for strict in (True, False): - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): - df = self.connect.sql("SELECT 1 AS c") - with self.assertRaises(AnalysisException): - df.withColumnRenamed("c", "c2").select(df["c"]).collect() - - def test_resolve_after_drop(self): - # drop("c") removes the column entirely. Tagged df["c"] cannot resolve - # under either mode: plan-id resolution misses (attribute gone) and - # name-based fallback misses (name gone). 
- from pyspark.errors.exceptions.connect import AnalysisException - - for strict in (True, False): - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): - df = self.connect.sql("SELECT 1 AS c, 2 AS d") - with self.assertRaises(AnalysisException): - df.drop("c").select(df["c"]).collect() - - def test_resolve_through_filter(self): - # filter is a pass-through operator: the child Project's attributes - # flow through unchanged, so plan-id-based resolution finds the - # original tagged attribute in both modes. - for strict in (True, False): - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): - df = self.connect.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") - rows = df.filter(df["c"] > 0).select(df["c"]).collect() - self.assertEqual(sorted(r.c for r in rows), [1, 2]) - - def test_resolve_through_sort(self): - # sort is also a pass-through operator. - for strict in (True, False): - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): - df = self.connect.sql("SELECT 2 AS c UNION ALL SELECT 1 AS c") - rows = df.sort(df["c"]).select(df["c"]).collect() - self.assertEqual([r.c for r in rows], [1, 2]) - - def test_resolve_through_distinct(self): - # distinct is also a pass-through operator from the perspective of - # attribute identity. - for strict in (True, False): - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): - df = self.connect.sql("SELECT 1 AS c UNION ALL SELECT 1 AS c") - rows = df.distinct().select(df["c"]).collect() - self.assertEqual([r.c for r in rows], [1]) - - def test_resolve_after_groupby_count(self): - # groupBy("c").count() preserves the grouping key by name but emits - # a new aggregate output schema. Under non-strict the tagged df["c"] - # falls back to name-based resolution and matches the grouping key. - # Under strict the tagged attribute id is not in the aggregate output - # and analysis fails. 
- from pyspark.errors.exceptions.connect import AnalysisException - - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): - df = self.connect.sql("SELECT 1 AS c UNION ALL SELECT 1 AS c UNION ALL SELECT 2 AS c") - rows = df.groupBy("c").count().select(df["c"]).collect() - self.assertEqual(sorted(r.c for r in rows), [1, 2]) - - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): - df = self.connect.sql("SELECT 1 AS c UNION ALL SELECT 1 AS c UNION ALL SELECT 2 AS c") - with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): - df.groupBy("c").count().select(df["c"]).collect() - - def test_resolve_after_agg_alias_shadow(self): - # An aggregate output named `c` via alias() collides with the source - # `c` by name. The tagged df["c"] still references the source - # attribute (which has been aggregated away); non-strict mode falls - # back to name-based resolution and matches the aliased aggregate. - from pyspark.errors.exceptions.connect import AnalysisException - - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): - df = self.connect.sql("SELECT 1 AS x") - df.groupBy().agg(CF.sum("x").alias("c")).select(df["c"]).collect() - - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): - df = self.connect.sql("SELECT 1 AS x") - with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): - df.groupBy().agg(CF.sum("x").alias("c")).select(df["c"]).collect() - - def test_resolve_after_pivot(self): - # pivot emits a new schema whose columns are the pivot values, with - # the grouping key preserved by name. Tagged references to the - # original grouping key resolve under non-strict via name-based - # fallback and fail under strict. 
- from pyspark.errors.exceptions.connect import AnalysisException - - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): - df = self.connect.sql( - "SELECT 1 AS c, 'a' AS k, 10 AS v UNION ALL SELECT 2 AS c, 'b' AS k, 20 AS v" - ) - df.groupBy("c").pivot("k").sum("v").select(df["c"]).collect() - - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): - df = self.connect.sql( - "SELECT 1 AS c, 'a' AS k, 10 AS v UNION ALL SELECT 2 AS c, 'b' AS k, 20 AS v" - ) - with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): - df.groupBy("c").pivot("k").sum("v").select(df["c"]).collect() - - def test_resolve_after_union(self): - # Union emits new attribute ids that differ from either input's - # attributes. The tagged df1["c"] reference cannot match the union - # output by plan id under strict mode; under non-strict it resolves - # by name to the union's `c` output. - from pyspark.errors.exceptions.connect import AnalysisException - - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): - df1 = self.connect.sql("SELECT 1 AS c") - df2 = self.connect.sql("SELECT 2 AS c") - df1.union(df2).select(df1["c"]).collect() - - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): - df1 = self.connect.sql("SELECT 1 AS c") - df2 = self.connect.sql("SELECT 2 AS c") - with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): - df1.union(df2).select(df1["c"]).collect() - - def test_resolve_after_intersect(self): - # intersect, like union, emits new attribute ids. Tagged df1["c"] - # only matches by name under non-strict. 
- from pyspark.errors.exceptions.connect import AnalysisException - - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): - df1 = self.connect.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") - df2 = self.connect.sql("SELECT 2 AS c UNION ALL SELECT 3 AS c") - df1.intersect(df2).select(df1["c"]).collect() - - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): - df1 = self.connect.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") - df2 = self.connect.sql("SELECT 2 AS c UNION ALL SELECT 3 AS c") - with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): - df1.intersect(df2).select(df1["c"]).collect() - - def test_resolve_self_join_alias(self): - # In a self-join, both sides originate from the same plan-id-tagged - # ancestor. Connect resolves aliased self-joins by attaching distinct - # plan ids to the aliased DataFrames; the original df["c"] reference - # is ambiguous under strict mode (the plan-id ancestor matches both - # sides). Under non-strict, name-based fallback still hits the - # ambiguity. - from pyspark.errors.exceptions.connect import AnalysisException - - for strict in (True, False): - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): - df = self.connect.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") - a = df.alias("a") - b = df.alias("b") - with self.assertRaises(AnalysisException): - a.join(b, a["c"] == b["c"]).select(df["c"]).collect() - - def test_resolve_after_subquery_view(self): - # Persisting the original DataFrame as a temp view and reading it - # back via spark.table() produces a new plan with new attribute ids. - # The tagged df["c"] reference targets the original plan id, which - # is not an ancestor of the new DataFrame. Under non-strict the - # name-based fallback succeeds; under strict it does not. 
- import uuid - - from pyspark.errors.exceptions.connect import AnalysisException - - view_name = f"v_{uuid.uuid4().hex}" - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): - df = self.connect.sql("SELECT 1 AS c") - df.createOrReplaceTempView(view_name) - try: - self.connect.table(view_name).select(df["c"]).collect() - finally: - self.connect.sql(f"DROP VIEW IF EXISTS {view_name}") - - view_name = f"v_{uuid.uuid4().hex}" - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): - df = self.connect.sql("SELECT 1 AS c") - df.createOrReplaceTempView(view_name) - try: - with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): - self.connect.table(view_name).select(df["c"]).collect() - finally: - self.connect.sql(f"DROP VIEW IF EXISTS {view_name}") - def test_column_with_null(self): # SPARK-41751: test isNull, isNotNull, eqNullSafe @@ -1343,6 +1103,540 @@ def test_transform(self): ) +class SparkConnectColumnResolutionTests(ReusedMixedTestCase): + """Connect-only tests pinning known Connect/Classic divergences in DataFrame + column resolution. + + For each pattern, the test runs the equivalent program against: + + * Spark Classic (``self.spark``) - one-block baseline that anchors the + Classic behavior (always-failed shadowing patterns vs. pass-through + operators that work in both worlds). + * Spark Connect strict mode (default, + ``spark.sql.analyzer.strictDataFrameColumnResolution=true``) - plan-id- + based resolution only. Tagged ``df["c"]`` references whose ancestor's + attribute is gone fail with ``CANNOT_RESOLVE_DATAFRAME_COLUMN``. + * Spark Connect lenient mode + (``spark.sql.analyzer.strictDataFrameColumnResolution=false``) - if + plan-id-based resolution fails the analyzer also tries name-based + resolution against the current child output. This is the Connect- + specific contract that customer workflows depend on. 
+ + Tests live here intentionally: they must NOT be moved to Classic-shared + suites where they'd be removed as "diverging from Classic" during routine + cleanup. + """ + + def test_resolve_after_chained_withcolumn_shadow(self): + # Two consecutive withColumn calls each shadow `c` with a new + # attribute carrying the same name; the original `c` is no longer in + # the projection. + from pyspark.errors import AnalysisException + from pyspark.errors.exceptions.connect import ( + AnalysisException as ConnectAnalysisException, + ) + + # Classic: fails (the AT&T-style root cause). + sdf = self.spark.sql("SELECT 1 AS c") + with self.assertRaises(AnalysisException): + sdf.withColumn("c", SF.col("c").cast("string")).withColumn( + "c", SF.col("c").cast("int") + ).select(sdf["c"]).collect() + + # Connect strict: same root cause, different error class. + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): + df = self.connect.sql("SELECT 1 AS c") + with self.assertRaisesRegex( + ConnectAnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN" + ): + df.withColumn("c", CF.col("c").cast("string")).withColumn( + "c", CF.col("c").cast("int") + ).select(df["c"]).collect() + + # Connect lenient: succeeds via name-based fallback. + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): + df = self.connect.sql("SELECT 1 AS c") + df.withColumn("c", CF.col("c").cast("string")).withColumn( + "c", CF.col("c").cast("int") + ).select(df["c"]).collect() + + def test_resolve_after_select_alias_shadow(self): + # Same shadowing shape as withColumn but expressed through a select + # with alias. + from pyspark.errors import AnalysisException + from pyspark.errors.exceptions.connect import ( + AnalysisException as ConnectAnalysisException, + ) + + # Classic: fails. 
+ sdf = self.spark.sql("SELECT 1 AS c") + with self.assertRaises(AnalysisException): + sdf.select(sdf["c"].cast("string").alias("c")).select(sdf["c"]).collect() + + # Connect strict: fails. + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): + df = self.connect.sql("SELECT 1 AS c") + with self.assertRaisesRegex( + ConnectAnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN" + ): + df.select(df["c"].cast("string").alias("c")).select(df["c"]).collect() + + # Connect lenient: succeeds via name-based fallback. + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): + df = self.connect.sql("SELECT 1 AS c") + df.select(df["c"].cast("string").alias("c")).select(df["c"]).collect() + + def test_resolve_after_withcolumnrenamed(self): + # withColumnRenamed drops the original `c` attribute and projects it + # as `c2`. The tagged df["c"] cannot resolve under any mode because + # neither the original attribute nor a column named `c` is in the + # current child output. + from pyspark.errors import AnalysisException + from pyspark.errors.exceptions.connect import ( + AnalysisException as ConnectAnalysisException, + ) + + # Classic: fails. + sdf = self.spark.sql("SELECT 1 AS c") + with self.assertRaises(AnalysisException): + sdf.withColumnRenamed("c", "c2").select(sdf["c"]).collect() + + # Connect: fails in both modes. + for strict in (True, False): + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): + df = self.connect.sql("SELECT 1 AS c") + with self.assertRaises(ConnectAnalysisException): + df.withColumnRenamed("c", "c2").select(df["c"]).collect() + + def test_resolve_after_drop(self): + # drop("c") removes the column entirely. Tagged df["c"] cannot resolve + # under any mode. + from pyspark.errors import AnalysisException + from pyspark.errors.exceptions.connect import ( + AnalysisException as ConnectAnalysisException, + ) + + # Classic: fails. 
+ sdf = self.spark.sql("SELECT 1 AS c, 2 AS d") + with self.assertRaises(AnalysisException): + sdf.drop("c").select(sdf["c"]).collect() + + # Connect: fails in both modes. + for strict in (True, False): + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): + df = self.connect.sql("SELECT 1 AS c, 2 AS d") + with self.assertRaises(ConnectAnalysisException): + df.drop("c").select(df["c"]).collect() + + def test_resolve_through_filter(self): + # filter is a pass-through operator: the child Project's attributes + # flow through unchanged, so the tagged reference resolves in both + # worlds. + expected = [1, 2] + + # Classic: succeeds. + sdf = self.spark.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") + srows = sdf.filter(sdf["c"] > 0).select(sdf["c"]).collect() + self.assertEqual(sorted(r.c for r in srows), expected) + + # Connect: succeeds in both modes. + for strict in (True, False): + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): + df = self.connect.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") + rows = df.filter(df["c"] > 0).select(df["c"]).collect() + self.assertEqual(sorted(r.c for r in rows), expected) + + def test_resolve_through_sort(self): + # sort is also a pass-through operator. + expected = [1, 2] + + # Classic: succeeds. + sdf = self.spark.sql("SELECT 2 AS c UNION ALL SELECT 1 AS c") + srows = sdf.sort(sdf["c"]).select(sdf["c"]).collect() + self.assertEqual([r.c for r in srows], expected) + + # Connect: succeeds in both modes. + for strict in (True, False): + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): + df = self.connect.sql("SELECT 2 AS c UNION ALL SELECT 1 AS c") + rows = df.sort(df["c"]).select(df["c"]).collect() + self.assertEqual([r.c for r in rows], expected) + + def test_resolve_through_distinct(self): + # distinct preserves attribute identity from the perspective of + # column resolution. + expected = [1] + + # Classic: succeeds. 
+ sdf = self.spark.sql("SELECT 1 AS c UNION ALL SELECT 1 AS c") + srows = sdf.distinct().select(sdf["c"]).collect() + self.assertEqual([r.c for r in srows], expected) + + # Connect: succeeds in both modes. + for strict in (True, False): + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): + df = self.connect.sql("SELECT 1 AS c UNION ALL SELECT 1 AS c") + rows = df.distinct().select(df["c"]).collect() + self.assertEqual([r.c for r in rows], expected) + + def test_resolve_after_groupby_count(self): + # groupBy("c").count() emits a new aggregate output schema. Connect + # strict cannot resolve the original tagged attribute; Connect lenient + # falls back to name-based resolution on the grouping key. + from pyspark.errors import AnalysisException + from pyspark.errors.exceptions.connect import ( + AnalysisException as ConnectAnalysisException, + ) + + # Classic: fails (the source attribute is consumed by Aggregate). + sdf = self.spark.sql("SELECT 1 AS c UNION ALL SELECT 1 AS c UNION ALL SELECT 2 AS c") + with self.assertRaises(AnalysisException): + sdf.groupBy("c").count().select(sdf["c"]).collect() + + # Connect strict: fails. + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): + df = self.connect.sql("SELECT 1 AS c UNION ALL SELECT 1 AS c UNION ALL SELECT 2 AS c") + with self.assertRaisesRegex( + ConnectAnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN" + ): + df.groupBy("c").count().select(df["c"]).collect() + + # Connect lenient: succeeds via name-based fallback. + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): + df = self.connect.sql("SELECT 1 AS c UNION ALL SELECT 1 AS c UNION ALL SELECT 2 AS c") + rows = df.groupBy("c").count().select(df["c"]).collect() + self.assertEqual(sorted(r.c for r in rows), [1, 2]) + + def test_resolve_after_agg_alias_shadow(self): + # An aggregate output named `c` via alias() collides by name with + # the source `c`. 
The tagged df["c"] still references the source + # attribute that has been aggregated away. + from pyspark.errors import AnalysisException + from pyspark.errors.exceptions.connect import ( + AnalysisException as ConnectAnalysisException, + ) + + # Classic: fails. + sdf = self.spark.sql("SELECT 1 AS x") + with self.assertRaises(AnalysisException): + sdf.groupBy().agg(SF.sum("x").alias("c")).select(sdf["c"]).collect() + + # Connect strict: fails. + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): + df = self.connect.sql("SELECT 1 AS x") + with self.assertRaisesRegex( + ConnectAnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN" + ): + df.groupBy().agg(CF.sum("x").alias("c")).select(df["c"]).collect() + + # Connect lenient: succeeds via name-based fallback. + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): + df = self.connect.sql("SELECT 1 AS x") + df.groupBy().agg(CF.sum("x").alias("c")).select(df["c"]).collect() + + def test_resolve_after_pivot(self): + # pivot emits a new schema. Tagged references to the original + # grouping key resolve under Connect lenient via name-based fallback + # and fail under Connect strict. + from pyspark.errors import AnalysisException + from pyspark.errors.exceptions.connect import ( + AnalysisException as ConnectAnalysisException, + ) + + query = "SELECT 1 AS c, 'a' AS k, 10 AS v UNION ALL SELECT 2 AS c, 'b' AS k, 20 AS v" + + # Classic: fails. + sdf = self.spark.sql(query) + with self.assertRaises(AnalysisException): + sdf.groupBy("c").pivot("k").sum("v").select(sdf["c"]).collect() + + # Connect strict: fails. + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): + df = self.connect.sql(query) + with self.assertRaisesRegex( + ConnectAnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN" + ): + df.groupBy("c").pivot("k").sum("v").select(df["c"]).collect() + + # Connect lenient: succeeds via name-based fallback. 
+ with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): + df = self.connect.sql(query) + df.groupBy("c").pivot("k").sum("v").select(df["c"]).collect() + + def test_resolve_after_union(self): + # Union emits new attribute ids that differ from either input's + # attributes. Tagged df1["c"] only resolves under Connect lenient by + # name-based fallback. + from pyspark.errors import AnalysisException + from pyspark.errors.exceptions.connect import ( + AnalysisException as ConnectAnalysisException, + ) + + # Classic: fails. + sdf1 = self.spark.sql("SELECT 1 AS c") + sdf2 = self.spark.sql("SELECT 2 AS c") + with self.assertRaises(AnalysisException): + sdf1.union(sdf2).select(sdf1["c"]).collect() + + # Connect strict: fails. + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): + df1 = self.connect.sql("SELECT 1 AS c") + df2 = self.connect.sql("SELECT 2 AS c") + with self.assertRaisesRegex( + ConnectAnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN" + ): + df1.union(df2).select(df1["c"]).collect() + + # Connect lenient: succeeds via name-based fallback. + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): + df1 = self.connect.sql("SELECT 1 AS c") + df2 = self.connect.sql("SELECT 2 AS c") + df1.union(df2).select(df1["c"]).collect() + + def test_resolve_after_intersect(self): + # intersect, like union, emits new attribute ids. + from pyspark.errors import AnalysisException + from pyspark.errors.exceptions.connect import ( + AnalysisException as ConnectAnalysisException, + ) + + # Classic: fails. + sdf1 = self.spark.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") + sdf2 = self.spark.sql("SELECT 2 AS c UNION ALL SELECT 3 AS c") + with self.assertRaises(AnalysisException): + sdf1.intersect(sdf2).select(sdf1["c"]).collect() + + # Connect strict: fails. 
+ with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): + df1 = self.connect.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") + df2 = self.connect.sql("SELECT 2 AS c UNION ALL SELECT 3 AS c") + with self.assertRaisesRegex( + ConnectAnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN" + ): + df1.intersect(df2).select(df1["c"]).collect() + + # Connect lenient: succeeds via name-based fallback. + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): + df1 = self.connect.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") + df2 = self.connect.sql("SELECT 2 AS c UNION ALL SELECT 3 AS c") + df1.intersect(df2).select(df1["c"]).collect() + + def test_resolve_self_join_alias(self): + # In a self-join, both sides originate from the same plan-id-tagged + # ancestor. The tagged df["c"] is ambiguous because two output + # attributes match by name. + from pyspark.errors import AnalysisException + from pyspark.errors.exceptions.connect import ( + AnalysisException as ConnectAnalysisException, + ) + + # Classic: fails (ambiguous reference). + sdf = self.spark.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") + a, b = sdf.alias("a"), sdf.alias("b") + with self.assertRaises(AnalysisException): + a.join(b, a["c"] == b["c"]).select(sdf["c"]).collect() + + # Connect: fails in both modes. + for strict in (True, False): + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): + df = self.connect.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") + a, b = df.alias("a"), df.alias("b") + with self.assertRaises(ConnectAnalysisException): + a.join(b, a["c"] == b["c"]).select(df["c"]).collect() + + def test_resolve_after_subquery_view(self): + # Persisting the original DataFrame as a temp view and reading it + # back via spark.table() produces a new plan with new attribute ids. + # The tagged reference targets the original plan id, which is not an + # ancestor of the new DataFrame. 
+ import uuid + + from pyspark.errors import AnalysisException + from pyspark.errors.exceptions.connect import ( + AnalysisException as ConnectAnalysisException, + ) + + # Classic: fails. + view = f"v_{uuid.uuid4().hex}" + sdf = self.spark.sql("SELECT 1 AS c") + sdf.createOrReplaceTempView(view) + try: + with self.assertRaises(AnalysisException): + self.spark.table(view).select(sdf["c"]).collect() + finally: + self.spark.sql(f"DROP VIEW IF EXISTS {view}") + + # Connect strict: fails. + view = f"v_{uuid.uuid4().hex}" + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): + df = self.connect.sql("SELECT 1 AS c") + df.createOrReplaceTempView(view) + try: + with self.assertRaisesRegex( + ConnectAnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN" + ): + self.connect.table(view).select(df["c"]).collect() + finally: + self.connect.sql(f"DROP VIEW IF EXISTS {view}") + + # Connect lenient: succeeds via name-based fallback. + view = f"v_{uuid.uuid4().hex}" + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): + df = self.connect.sql("SELECT 1 AS c") + df.createOrReplaceTempView(view) + try: + self.connect.table(view).select(df["c"]).collect() + finally: + self.connect.sql(f"DROP VIEW IF EXISTS {view}") + + # --- Mixed-surface layered programs -------------------------------------- + # + # These programs combine 4-6 chained operators (filters, joins, + # aggregations, set ops, window functions, UDFs and temp views) in a + # single pipeline. The intent is to catch regressions in Connect's + # plan-id propagation through analyzer rules that single-operator tests + # miss when rules interact. Each runs under Connect lenient mode (the + # historical Connect contract) and asserts the expected output, with a + # Classic baseline that documents how Spark Classic handles the same + # pipeline. 
+ + def test_layered_filter_join_agg_shadow(self): + # filter -> self-join -> groupBy/agg shadowing c -> withColumn + # shadowing c -> select(df["c"]). + from pyspark.errors import AnalysisException + + data = [(1, 10), (1, 20), (2, 30), (2, 40), (3, 50)] + + # Classic: fails at the final shadowing select. + sdf = self.spark.createDataFrame(data, ["c", "v"]) + with self.assertRaises(AnalysisException): + ( + sdf.filter(sdf["v"] > 10) + .alias("a") + .join(sdf.alias("b"), SF.col("a.c") == SF.col("b.c")) + .groupBy(SF.col("a.c").alias("c")) + .agg(SF.sum(SF.col("b.v")).alias("s")) + .withColumn("c", SF.col("c").cast("string")) + .select(sdf["c"], "s") + .collect() + ) + + # Connect lenient: succeeds end-to-end. + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): + df = self.connect.createDataFrame(data, ["c", "v"]) + result = ( + df.filter(df["v"] > 10) + .alias("a") + .join(df.alias("b"), CF.col("a.c") == CF.col("b.c")) + .groupBy(CF.col("a.c").alias("c")) + .agg(CF.sum(CF.col("b.v")).alias("s")) + .withColumn("c", CF.col("c").cast("string")) + .select(df["c"], "s") + .collect() + ) + self.assertEqual( + sorted((r.c, r.s) for r in result), + [("1", 60), ("2", 140)], + ) + + def test_layered_temp_view_subquery_udf(self): + # createOrReplaceTempView -> SQL with subquery referencing the view + # -> join back to the original DataFrame -> apply a UDF -> select via + # the tagged df["c"] (resolved by name fallback since the view+SQL + # path emits new attribute ids). + import uuid + + from pyspark.errors import AnalysisException + from pyspark.sql.types import IntegerType + + data = [(1, 10), (2, 20), (3, 30)] + + # Classic baseline: same shape - the cross-plan tagged reference is + # never in the join output, so analysis fails. 
+ view = f"layered_view_{uuid.uuid4().hex}" + sdf = self.spark.createDataFrame(data, ["c", "v"]) + sdf.createOrReplaceTempView(view) + try: + with self.assertRaises(AnalysisException): + ssub = self.spark.sql( + f"SELECT c, v FROM {view} WHERE v IN (SELECT v FROM {view} WHERE c > 1)" + ) + double_udf = SF.udf(lambda x: x * 2 if x is not None else None, IntegerType()) + ssub.join(sdf, ssub["c"] == sdf["c"]).withColumn("v", double_udf(ssub["v"])).select( + sdf["c"], "v" + ).collect() + finally: + self.spark.sql(f"DROP VIEW IF EXISTS {view}") + + # Connect lenient: succeeds end-to-end. + view = f"layered_view_{uuid.uuid4().hex}" + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): + df = self.connect.createDataFrame(data, ["c", "v"]) + df.createOrReplaceTempView(view) + try: + double_udf = CF.udf(lambda x: x * 2 if x is not None else None, IntegerType()) + sub = self.connect.sql( + f"SELECT c, v FROM {view} WHERE v IN (SELECT v FROM {view} WHERE c > 1)" + ) + result = ( + sub.join(df, sub["c"] == df["c"]) + .withColumn("v", double_udf(sub["v"])) + .select(df["c"], "v") + .collect() + ) + self.assertEqual( + sorted((r.c, r.v) for r in result), + [(2, 40), (3, 60)], + ) + finally: + self.connect.sql(f"DROP VIEW IF EXISTS {view}") + + def test_layered_union_window_pivot_shadow(self): + # union -> window aggregation -> pivot -> withColumn shadow -> + # select via original df1["c"] tagged reference. + from pyspark.errors import AnalysisException + from pyspark.sql.window import Window + + data1 = [(1, "a", 10), (1, "b", 20)] + data2 = [(2, "a", 30), (2, "b", 40)] + + # Classic: fails. 
+ sdf1 = self.spark.createDataFrame(data1, ["c", "k", "v"]) + sdf2 = self.spark.createDataFrame(data2, ["c", "k", "v"]) + w = Window.partitionBy("c") + with self.assertRaises(AnalysisException): + ( + sdf1.unionByName(sdf2) + .withColumn("rank_v", SF.row_number().over(w.orderBy("v"))) + .groupBy("c") + .pivot("k", ["a", "b"]) + .sum("v") + .withColumn("c", SF.col("c").cast("string")) + .select(sdf1["c"], "a", "b") + .collect() + ) + + # Connect lenient: succeeds end-to-end. + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): + df1 = self.connect.createDataFrame(data1, ["c", "k", "v"]) + df2 = self.connect.createDataFrame(data2, ["c", "k", "v"]) + cw = Window.partitionBy("c") + result = ( + df1.unionByName(df2) + .withColumn("rank_v", CF.row_number().over(cw.orderBy("v"))) + .groupBy("c") + .pivot("k", ["a", "b"]) + .sum("v") + .withColumn("c", CF.col("c").cast("string")) + .select(df1["c"], "a", "b") + .collect() + ) + self.assertEqual( + sorted((r.c, r.a, r.b) for r in result), + [("1", 10, 20), ("2", 30, 40)], + ) + + if __name__ == "__main__": from pyspark.testing import main diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py b/python/pyspark/sql/tests/connect/test_parity_dataframe.py index 2baf7cf2c4af8..decb027f8e21e 100644 --- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py +++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py @@ -16,7 +16,6 @@ # import unittest -import uuid from pyspark.sql.tests.test_dataframe import DataFrameTestsMixin from pyspark.testing.connectutils import ReusedConnectTestCase @@ -39,106 +38,6 @@ def test_query_execution_unsupported_in_classic(self): def test_to_json(self): pass - # --- Connect-only layered DataFrame regression programs ------------------ - # - # These tests exercise mixed-surface DataFrame pipelines (4-6 chained - # operators combining filters, joins, aggregations, set ops, window - # functions, UDFs and temporary views) end-to-end. 
The intent is to catch - # regressions in Connect's plan-id propagation through analyzer rules - # that single-operator tests miss when rules interact. - # - # Each program runs under the non-strict mode of - # `spark.sql.analyzer.strictDataFrameColumnResolution` (the historical - # Connect contract that customer workflows depend on) and asserts the - # expected output. Under strict mode the same program either succeeds - # identically (when no tagged reference crosses a shadowing boundary) or - # fails at the first divergence. - # - # Tests live here intentionally: they must not be moved to Classic-shared - # suites where they'd be removed as "diverging from Classic" during - # routine cleanup. - - def test_layered_filter_join_agg_shadow(self): - from pyspark.sql import functions as sf - - with self.sql_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": "false"}): - df = self.spark.createDataFrame( - [(1, 10), (1, 20), (2, 30), (2, 40), (3, 50)], ["c", "v"] - ) - # filter -> self-join on c -> groupBy/agg shadowing c -> withColumn - # shadowing c again -> select(df["c"]) which routes through name - # fallback in lenient mode. 
- result = ( - df.filter(df["v"] > 10) - .alias("a") - .join(df.alias("b"), sf.col("a.c") == sf.col("b.c")) - .groupBy(sf.col("a.c").alias("c")) - .agg(sf.sum(sf.col("b.v")).alias("s")) - .withColumn("c", sf.col("c").cast("string")) - .select(df["c"], "s") - .collect() - ) - self.assertEqual( - sorted((r.c, r.s) for r in result), - [("1", 60), ("2", 140)], - ) - - def test_layered_temp_view_subquery_udf(self): - from pyspark.sql import functions as sf - from pyspark.sql.types import IntegerType - - with self.sql_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": "false"}): - view_name = f"layered_view_{uuid.uuid4().hex}" - df = self.spark.createDataFrame([(1, 10), (2, 20), (3, 30)], ["c", "v"]) - df.createOrReplaceTempView(view_name) - try: - # createOrReplaceTempView -> SQL with subquery referencing the - # view -> join back to the original DataFrame -> apply a UDF - # over a shadowed column -> select via the tagged df["c"] - # (resolved by name fallback since the view+SQL path emits new - # attribute ids). 
- double_udf = sf.udf(lambda x: x * 2 if x is not None else None, IntegerType()) - sub = self.spark.sql( - f"SELECT c, v FROM {view_name} WHERE v IN (SELECT v FROM {view_name} WHERE c > 1)" - ) - result = ( - sub.join(df, sub["c"] == df["c"]) - .withColumn("v", double_udf(sub["v"])) - .select(df["c"], "v") - .collect() - ) - self.assertEqual( - sorted((r.c, r.v) for r in result), - [(2, 40), (3, 60)], - ) - finally: - self.spark.sql(f"DROP VIEW IF EXISTS {view_name}") - - def test_layered_union_window_pivot_shadow(self): - from pyspark.sql import functions as sf - from pyspark.sql.window import Window - - with self.sql_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": "false"}): - df1 = self.spark.createDataFrame([(1, "a", 10), (1, "b", 20)], ["c", "k", "v"]) - df2 = self.spark.createDataFrame([(2, "a", 30), (2, "b", 40)], ["c", "k", "v"]) - unioned = df1.unionByName(df2) - # union -> window aggregation -> pivot -> withColumn shadow -> - # select via original df1["c"] tagged reference. 
- w = Window.partitionBy("c") - result = ( - unioned.withColumn("rank_v", sf.row_number().over(w.orderBy("v"))) - .groupBy("c") - .pivot("k", ["a", "b"]) - .sum("v") - .withColumn("c", sf.col("c").cast("string")) - .select(df1["c"], "a", "b") - .collect() - ) - self.assertEqual( - sorted((r.c, r.a, r.b) for r in result), - [("1", 10, 20), ("2", 30, 40)], - ) - if __name__ == "__main__": from pyspark.testing import main From 28bf33bb8dd61b17834f7ec233b8f43a5282733a Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Mon, 18 May 2026 04:39:59 +0000 Subject: [PATCH 03/16] Rewrite layered tests in Reyden-style with deeper layered SQL Replace the three shallow layered programs with Reyden-influenced fixtures that mirror the patterns referenced in SC-229895 (reyden/query-tests/golden-files/layered-query-tests): - 4-level subquery chain with windows, HAVING and correlated EXISTS - CTE chain with GROUPING SETS, NTILE, struct field access and correlated IN - Self-join via SQL with windowed running totals, correlated EXISTS and UDF wrapping Each program builds the deeply layered base via spark.sql(), then layers DataFrame-API shadowing on top with a tagged df["c"] reference at the outermost select. Classic and Connect strict raise; Connect lenient succeeds via name-based fallback. 
Generated-by: Claude Code (Anthropic), claude-opus-4-7 --- .../sql/tests/connect/test_connect_column.py | 357 ++++++++++++------ 1 file changed, 244 insertions(+), 113 deletions(-) diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py index 0e3c8bd9d39c6..4cd797cf6aaac 100644 --- a/python/pyspark/sql/tests/connect/test_connect_column.py +++ b/python/pyspark/sql/tests/connect/test_connect_column.py @@ -1489,153 +1489,284 @@ def test_resolve_after_subquery_view(self): finally: self.connect.sql(f"DROP VIEW IF EXISTS {view}") - # --- Mixed-surface layered programs -------------------------------------- + # --- Reyden-style layered DataFrame programs ---------------------------- # - # These programs combine 4-6 chained operators (filters, joins, - # aggregations, set ops, window functions, UDFs and temp views) in a - # single pipeline. The intent is to catch regressions in Connect's - # plan-id propagation through analyzer rules that single-operator tests - # miss when rules interact. Each runs under Connect lenient mode (the - # historical Connect contract) and asserts the expected output, with a - # Classic baseline that documents how Spark Classic handles the same - # pipeline. - - def test_layered_filter_join_agg_shadow(self): - # filter -> self-join -> groupBy/agg shadowing c -> withColumn - # shadowing c -> select(df["c"]). + # These tests are influenced by Reyden's golden-file layered-query-tests + # (reyden/query-tests/golden-files/layered-query-tests, referenced in + # SC-229895), which combine 4-level subquery chains, CTE chains, + # window functions, GROUPING SETS, NTILE/RANK, struct field access and + # correlated EXISTS/IN in a single query. Each program here builds a + # similarly layered base via spark.sql(), then layers DataFrame-API + # shadowing operations on top with a tagged ``df["c"]`` reference at the + # outermost select. 
The goal is to catch regressions in plan-id + # propagation across Connect's analyzer rules that single-operator + # tests miss when rules interact. + + def test_layered_subquery_chain_window_having_exists(self): + # 4-level subquery chain combining windows, HAVING and correlated + # EXISTS, then a tagged ``df["category"]`` reference after a + # groupBy shadow. + import uuid + from pyspark.errors import AnalysisException - data = [(1, 10), (1, 20), (2, 30), (2, 40), (3, 50)] + events_data = [ + (1, 1, "Books", 100.0, 2, True), + (2, 1, "Books", 50.0, 3, True), + (3, 2, "Electronics", 200.0, 1, True), + (4, 2, "Electronics", 300.0, 2, True), + (5, 3, "Home", 80.0, 4, True), + (6, 4, "Books", 60.0, 1, False), + ] + users_data = [(1, 25), (2, 30), (3, 22), (4, 18)] + events_cols = ["id", "user_id", "category", "amount", "quantity", "is_active"] + users_cols = ["id", "age"] + + def layered_sql(events_view, users_view): + return f""" + SELECT category, total_amt, running_avg, rank_num + FROM ( + SELECT category, total_amt, + AVG(total_amt) OVER (ORDER BY total_amt + ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS running_avg, + RANK() OVER (ORDER BY total_amt DESC) AS rank_num + FROM ( + SELECT e.category, SUM(e.amount * e.quantity * 0.1) AS total_amt + FROM ( + SELECT id, user_id, category, amount, quantity + FROM {events_view} + WHERE is_active = true + AND EXISTS ( + SELECT 1 FROM {users_view} u + WHERE u.id = {events_view}.user_id AND u.age > 20 + ) + ) e + GROUP BY e.category + HAVING SUM(e.amount) > 50 + ) agg + ) ranked + WHERE rank_num <= 5 + """ - # Classic: fails at the final shadowing select. 
- sdf = self.spark.createDataFrame(data, ["c", "v"]) - with self.assertRaises(AnalysisException): - ( - sdf.filter(sdf["v"] > 10) - .alias("a") - .join(sdf.alias("b"), SF.col("a.c") == SF.col("b.c")) - .groupBy(SF.col("a.c").alias("c")) - .agg(SF.sum(SF.col("b.v")).alias("s")) - .withColumn("c", SF.col("c").cast("string")) - .select(sdf["c"], "s") - .collect() + # Classic: groupBy shadow of the SQL-built "category" attribute + # makes the tagged reference unresolvable. + events_view = f"events_{uuid.uuid4().hex[:8]}" + users_view = f"users_{uuid.uuid4().hex[:8]}" + self.spark.createDataFrame(events_data, events_cols).createOrReplaceTempView(events_view) + self.spark.createDataFrame(users_data, users_cols).createOrReplaceTempView(users_view) + try: + sdf = self.spark.sql(layered_sql(events_view, users_view)) + with self.assertRaises(AnalysisException): + sdf.groupBy("category").count().select(sdf["category"]).collect() + finally: + self.spark.sql(f"DROP VIEW IF EXISTS {events_view}") + self.spark.sql(f"DROP VIEW IF EXISTS {users_view}") + + # Connect lenient: name-based fallback resolves the tagged reference + # against the post-aggregate grouping key, so the program succeeds + # end-to-end. 
+ events_view = f"events_{uuid.uuid4().hex[:8]}" + users_view = f"users_{uuid.uuid4().hex[:8]}" + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): + self.connect.createDataFrame(events_data, events_cols).createOrReplaceTempView( + events_view ) + self.connect.createDataFrame(users_data, users_cols).createOrReplaceTempView(users_view) + try: + df = self.connect.sql(layered_sql(events_view, users_view)) + rows = df.groupBy("category").count().select(df["category"]).collect() + self.assertEqual( + sorted(r.category for r in rows), + ["Books", "Electronics", "Home"], + ) + finally: + self.connect.sql(f"DROP VIEW IF EXISTS {events_view}") + self.connect.sql(f"DROP VIEW IF EXISTS {users_view}") + + def test_layered_cte_chain_grouping_sets_ntile_correlated_in(self): + # CTE chain with GROUPING SETS, NTILE, struct field access and + # correlated IN, then withColumn shadow + tagged select. + import uuid + + from pyspark.errors import AnalysisException + from pyspark.sql.types import ( + IntegerType, + StringType, + StructField, + StructType, + ) + + events_schema = StructType( + [ + StructField("id", IntegerType()), + StructField("category", StringType()), + StructField("status", StringType()), + StructField("amount", IntegerType()), + StructField("quantity", IntegerType()), + StructField( + "detail", + StructType( + [ + StructField("name", StringType()), + StructField("nested", StructType([StructField("x", IntegerType())])), + ] + ), + ), + ] + ) + events_data = [ + (1, "Books", "A", 100, 5, ("alpha", (1,))), + (2, "Electronics", "B", 200, 3, ("beta", (2,))), + (3, "Books", "A", 50, 7, ("alpha", (1,))), + (4, "Electronics", "B", 300, 4, ("beta", (2,))), + (5, "Home", "C", 80, 2, ("gamma", (3,))), + ] + categories_data = [("Books", 1), ("Electronics", 2), ("Home", 3), ("Toys", 5)] + categories_cols = ["name", "priority"] + + def cte_sql(events_view, categories_view): + return f""" + WITH base AS ( + SELECT id, category, status, amount, 
+ detail.name AS detail_name, + detail.nested.x AS nx + FROM {events_view} + WHERE quantity > 1 + AND category IN ( + SELECT c.name FROM {categories_view} c WHERE c.priority <= 3 + ) + ), + grouped AS ( + SELECT category, status, detail_name, + GROUPING(category) AS g_cat, + GROUPING(status) AS g_stat, + SUM(amount) AS total, + COUNT(*) AS cnt + FROM base + GROUP BY GROUPING SETS ( + (category, status, detail_name), + (category), + () + ) + ), + tiled AS ( + SELECT *, NTILE(2) OVER (ORDER BY total DESC) AS tile + FROM grouped + WHERE g_cat = 0 AND g_stat = 0 + ) + SELECT category, status, detail_name, total, cnt, tile + FROM tiled WHERE tile <= 2 + """ + + # Classic: withColumn shadow of the CTE-produced "category" attribute + # makes the tagged reference unresolvable. + events_view = f"events_{uuid.uuid4().hex[:8]}" + categories_view = f"categories_{uuid.uuid4().hex[:8]}" + self.spark.createDataFrame(events_data, events_schema).createOrReplaceTempView(events_view) + self.spark.createDataFrame(categories_data, categories_cols).createOrReplaceTempView( + categories_view + ) + try: + sdf = self.spark.sql(cte_sql(events_view, categories_view)) + with self.assertRaises(AnalysisException): + sdf.withColumn("category", SF.col("category").cast("string")).select( + sdf["category"], "total" + ).collect() + finally: + self.spark.sql(f"DROP VIEW IF EXISTS {events_view}") + self.spark.sql(f"DROP VIEW IF EXISTS {categories_view}") - # Connect lenient: succeeds end-to-end. + # Connect lenient: succeeds end-to-end via name-based fallback. 
+ events_view = f"events_{uuid.uuid4().hex[:8]}" + categories_view = f"categories_{uuid.uuid4().hex[:8]}" with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): - df = self.connect.createDataFrame(data, ["c", "v"]) - result = ( - df.filter(df["v"] > 10) - .alias("a") - .join(df.alias("b"), CF.col("a.c") == CF.col("b.c")) - .groupBy(CF.col("a.c").alias("c")) - .agg(CF.sum(CF.col("b.v")).alias("s")) - .withColumn("c", CF.col("c").cast("string")) - .select(df["c"], "s") - .collect() + self.connect.createDataFrame(events_data, events_schema).createOrReplaceTempView( + events_view ) - self.assertEqual( - sorted((r.c, r.s) for r in result), - [("1", 60), ("2", 140)], + self.connect.createDataFrame(categories_data, categories_cols).createOrReplaceTempView( + categories_view ) + try: + df = self.connect.sql(cte_sql(events_view, categories_view)) + rows = ( + df.withColumn("category", CF.col("category").cast("string")) + .select(df["category"], "total") + .collect() + ) + self.assertGreater(len(rows), 0) + finally: + self.connect.sql(f"DROP VIEW IF EXISTS {events_view}") + self.connect.sql(f"DROP VIEW IF EXISTS {categories_view}") - def test_layered_temp_view_subquery_udf(self): - # createOrReplaceTempView -> SQL with subquery referencing the view - # -> join back to the original DataFrame -> apply a UDF -> select via - # the tagged df["c"] (resolved by name fallback since the view+SQL - # path emits new attribute ids). + def test_layered_self_join_window_udf_shadow(self): + # Mixed surface: temp-view self-join via SQL with a window function, + # wrapped by a UDF, then withColumn shadow + tagged select. 
import uuid from pyspark.errors import AnalysisException from pyspark.sql.types import IntegerType - data = [(1, 10), (2, 20), (3, 30)] + data = [ + (1, "A", 100), + (2, "A", 200), + (3, "B", 150), + (4, "B", 250), + (5, "C", 50), + ] + cols = ["id", "category", "amount"] + + def self_join_sql(view): + # Self-join via SQL with a windowed running total and a correlated + # subquery on the same view. + return f""" + SELECT t.id, t.category, t.amount, + SUM(t.amount) OVER (PARTITION BY t.category ORDER BY t.id) AS run_amt, + (SELECT MAX(o.amount) FROM {view} o WHERE o.category = t.category) AS cat_max + FROM {view} t + WHERE EXISTS ( + SELECT 1 FROM {view} p WHERE p.id = t.id AND p.amount > 0 + ) + """ - # Classic baseline: same shape - the cross-plan tagged reference is - # never in the join output, so analysis fails. - view = f"layered_view_{uuid.uuid4().hex}" - sdf = self.spark.createDataFrame(data, ["c", "v"]) + # Classic: withColumn shadow of "category" breaks the tagged reference. + view = f"layered_{uuid.uuid4().hex[:8]}" + sdf = self.spark.createDataFrame(data, cols) sdf.createOrReplaceTempView(view) try: + sjoined = self.spark.sql(self_join_sql(view)) + double_udf = SF.udf(lambda x: x * 2 if x is not None else None, IntegerType()) with self.assertRaises(AnalysisException): - ssub = self.spark.sql( - f"SELECT c, v FROM {view} WHERE v IN (SELECT v FROM {view} WHERE c > 1)" + ( + sjoined.withColumn("amount", double_udf(SF.col("amount"))) + .withColumn("category", SF.col("category").cast("string")) + .select(sjoined["category"], "amount", "run_amt", "cat_max") + .collect() ) - double_udf = SF.udf(lambda x: x * 2 if x is not None else None, IntegerType()) - ssub.join(sdf, ssub["c"] == sdf["c"]).withColumn("v", double_udf(ssub["v"])).select( - sdf["c"], "v" - ).collect() finally: self.spark.sql(f"DROP VIEW IF EXISTS {view}") - # Connect lenient: succeeds end-to-end. 
- view = f"layered_view_{uuid.uuid4().hex}" + # Connect lenient: succeeds end-to-end via name-based fallback. + view = f"layered_{uuid.uuid4().hex[:8]}" with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): - df = self.connect.createDataFrame(data, ["c", "v"]) + df = self.connect.createDataFrame(data, cols) df.createOrReplaceTempView(view) try: + joined = self.connect.sql(self_join_sql(view)) double_udf = CF.udf(lambda x: x * 2 if x is not None else None, IntegerType()) - sub = self.connect.sql( - f"SELECT c, v FROM {view} WHERE v IN (SELECT v FROM {view} WHERE c > 1)" - ) - result = ( - sub.join(df, sub["c"] == df["c"]) - .withColumn("v", double_udf(sub["v"])) - .select(df["c"], "v") + rows = ( + joined.withColumn("amount", double_udf(CF.col("amount"))) + .withColumn("category", CF.col("category").cast("string")) + .select(joined["category"], "amount", "run_amt", "cat_max") .collect() ) + self.assertEqual(len(rows), 5) self.assertEqual( - sorted((r.c, r.v) for r in result), - [(2, 40), (3, 60)], + sorted({r.category for r in rows}), + ["A", "B", "C"], ) finally: self.connect.sql(f"DROP VIEW IF EXISTS {view}") - def test_layered_union_window_pivot_shadow(self): - # union -> window aggregation -> pivot -> withColumn shadow -> - # select via original df1["c"] tagged reference. - from pyspark.errors import AnalysisException - from pyspark.sql.window import Window - - data1 = [(1, "a", 10), (1, "b", 20)] - data2 = [(2, "a", 30), (2, "b", 40)] - - # Classic: fails. - sdf1 = self.spark.createDataFrame(data1, ["c", "k", "v"]) - sdf2 = self.spark.createDataFrame(data2, ["c", "k", "v"]) - w = Window.partitionBy("c") - with self.assertRaises(AnalysisException): - ( - sdf1.unionByName(sdf2) - .withColumn("rank_v", SF.row_number().over(w.orderBy("v"))) - .groupBy("c") - .pivot("k", ["a", "b"]) - .sum("v") - .withColumn("c", SF.col("c").cast("string")) - .select(sdf1["c"], "a", "b") - .collect() - ) - - # Connect lenient: succeeds end-to-end. 
- with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): - df1 = self.connect.createDataFrame(data1, ["c", "k", "v"]) - df2 = self.connect.createDataFrame(data2, ["c", "k", "v"]) - cw = Window.partitionBy("c") - result = ( - df1.unionByName(df2) - .withColumn("rank_v", CF.row_number().over(cw.orderBy("v"))) - .groupBy("c") - .pivot("k", ["a", "b"]) - .sum("v") - .withColumn("c", CF.col("c").cast("string")) - .select(df1["c"], "a", "b") - .collect() - ) - self.assertEqual( - sorted((r.c, r.a, r.b) for r in result), - [("1", 10, 20), ("2", 30, 40)], - ) - if __name__ == "__main__": from pyspark.testing import main From 3350fa5fd263be45b201a3eaf53f72099384c20b Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Mon, 18 May 2026 04:56:45 +0000 Subject: [PATCH 04/16] Drop external-project reference from layered test section comment Generated-by: Claude Code (Anthropic), claude-opus-4-7 --- .../sql/tests/connect/test_connect_column.py | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py index 4cd797cf6aaac..4814b0c108598 100644 --- a/python/pyspark/sql/tests/connect/test_connect_column.py +++ b/python/pyspark/sql/tests/connect/test_connect_column.py @@ -1489,18 +1489,16 @@ def test_resolve_after_subquery_view(self): finally: self.connect.sql(f"DROP VIEW IF EXISTS {view}") - # --- Reyden-style layered DataFrame programs ---------------------------- + # --- Mixed-surface layered DataFrame programs --------------------------- # - # These tests are influenced by Reyden's golden-file layered-query-tests - # (reyden/query-tests/golden-files/layered-query-tests, referenced in - # SC-229895), which combine 4-level subquery chains, CTE chains, - # window functions, GROUPING SETS, NTILE/RANK, struct field access and - # correlated EXISTS/IN in a single query. 
Each program here builds a - # similarly layered base via spark.sql(), then layers DataFrame-API - # shadowing operations on top with a tagged ``df["c"]`` reference at the - # outermost select. The goal is to catch regressions in plan-id - # propagation across Connect's analyzer rules that single-operator - # tests miss when rules interact. + # These tests combine 4-level subquery chains, CTE chains, window + # functions, GROUPING SETS, NTILE/RANK, struct field access and + # correlated EXISTS/IN in a single query. Each program builds a deeply + # layered base via ``spark.sql()``, then layers DataFrame-API shadowing + # operations on top with a tagged ``df["c"]`` reference at the outermost + # select. The goal is to catch regressions in plan-id propagation + # across Connect's analyzer rules that single-operator tests miss when + # rules interact. def test_layered_subquery_chain_window_having_exists(self): # 4-level subquery chain combining windows, HAVING and correlated From 0576487b8e80fb628916d63f6bb9787ad4990b33 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Mon, 18 May 2026 05:04:16 +0000 Subject: [PATCH 05/16] Use pyspark.errors.AnalysisException for both Classic and Connect blocks The Connect-specific AnalysisException subclasses the base one, so a single import covers both. The assertRaisesRegex regex still pins the Connect-specific CANNOT_RESOLVE_DATAFRAME_COLUMN error class. 
Generated-by: Claude Code (Anthropic), claude-opus-4-7 --- .../sql/tests/connect/test_connect_column.py | 71 +++---------------- 1 file changed, 11 insertions(+), 60 deletions(-) diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py index 4814b0c108598..28222b17729f5 100644 --- a/python/pyspark/sql/tests/connect/test_connect_column.py +++ b/python/pyspark/sql/tests/connect/test_connect_column.py @@ -1132,9 +1132,6 @@ def test_resolve_after_chained_withcolumn_shadow(self): # attribute carrying the same name; the original `c` is no longer in # the projection. from pyspark.errors import AnalysisException - from pyspark.errors.exceptions.connect import ( - AnalysisException as ConnectAnalysisException, - ) # Classic: fails (the AT&T-style root cause). sdf = self.spark.sql("SELECT 1 AS c") @@ -1146,9 +1143,7 @@ def test_resolve_after_chained_withcolumn_shadow(self): # Connect strict: same root cause, different error class. with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): df = self.connect.sql("SELECT 1 AS c") - with self.assertRaisesRegex( - ConnectAnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN" - ): + with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): df.withColumn("c", CF.col("c").cast("string")).withColumn( "c", CF.col("c").cast("int") ).select(df["c"]).collect() @@ -1164,9 +1159,6 @@ def test_resolve_after_select_alias_shadow(self): # Same shadowing shape as withColumn but expressed through a select # with alias. from pyspark.errors import AnalysisException - from pyspark.errors.exceptions.connect import ( - AnalysisException as ConnectAnalysisException, - ) # Classic: fails. sdf = self.spark.sql("SELECT 1 AS c") @@ -1176,9 +1168,7 @@ def test_resolve_after_select_alias_shadow(self): # Connect strict: fails. 
with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): df = self.connect.sql("SELECT 1 AS c") - with self.assertRaisesRegex( - ConnectAnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN" - ): + with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): df.select(df["c"].cast("string").alias("c")).select(df["c"]).collect() # Connect lenient: succeeds via name-based fallback. @@ -1192,9 +1182,6 @@ def test_resolve_after_withcolumnrenamed(self): # neither the original attribute nor a column named `c` is in the # current child output. from pyspark.errors import AnalysisException - from pyspark.errors.exceptions.connect import ( - AnalysisException as ConnectAnalysisException, - ) # Classic: fails. sdf = self.spark.sql("SELECT 1 AS c") @@ -1205,16 +1192,13 @@ def test_resolve_after_withcolumnrenamed(self): for strict in (True, False): with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): df = self.connect.sql("SELECT 1 AS c") - with self.assertRaises(ConnectAnalysisException): + with self.assertRaises(AnalysisException): df.withColumnRenamed("c", "c2").select(df["c"]).collect() def test_resolve_after_drop(self): # drop("c") removes the column entirely. Tagged df["c"] cannot resolve # under any mode. from pyspark.errors import AnalysisException - from pyspark.errors.exceptions.connect import ( - AnalysisException as ConnectAnalysisException, - ) # Classic: fails. 
sdf = self.spark.sql("SELECT 1 AS c, 2 AS d") @@ -1225,7 +1209,7 @@ def test_resolve_after_drop(self): for strict in (True, False): with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): df = self.connect.sql("SELECT 1 AS c, 2 AS d") - with self.assertRaises(ConnectAnalysisException): + with self.assertRaises(AnalysisException): df.drop("c").select(df["c"]).collect() def test_resolve_through_filter(self): @@ -1284,9 +1268,6 @@ def test_resolve_after_groupby_count(self): # strict cannot resolve the original tagged attribute; Connect lenient # falls back to name-based resolution on the grouping key. from pyspark.errors import AnalysisException - from pyspark.errors.exceptions.connect import ( - AnalysisException as ConnectAnalysisException, - ) # Classic: fails (the source attribute is consumed by Aggregate). sdf = self.spark.sql("SELECT 1 AS c UNION ALL SELECT 1 AS c UNION ALL SELECT 2 AS c") @@ -1296,9 +1277,7 @@ def test_resolve_after_groupby_count(self): # Connect strict: fails. with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): df = self.connect.sql("SELECT 1 AS c UNION ALL SELECT 1 AS c UNION ALL SELECT 2 AS c") - with self.assertRaisesRegex( - ConnectAnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN" - ): + with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): df.groupBy("c").count().select(df["c"]).collect() # Connect lenient: succeeds via name-based fallback. @@ -1312,9 +1291,6 @@ def test_resolve_after_agg_alias_shadow(self): # the source `c`. The tagged df["c"] still references the source # attribute that has been aggregated away. from pyspark.errors import AnalysisException - from pyspark.errors.exceptions.connect import ( - AnalysisException as ConnectAnalysisException, - ) # Classic: fails. sdf = self.spark.sql("SELECT 1 AS x") @@ -1324,9 +1300,7 @@ def test_resolve_after_agg_alias_shadow(self): # Connect strict: fails. 
with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): df = self.connect.sql("SELECT 1 AS x") - with self.assertRaisesRegex( - ConnectAnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN" - ): + with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): df.groupBy().agg(CF.sum("x").alias("c")).select(df["c"]).collect() # Connect lenient: succeeds via name-based fallback. @@ -1339,9 +1313,6 @@ def test_resolve_after_pivot(self): # grouping key resolve under Connect lenient via name-based fallback # and fail under Connect strict. from pyspark.errors import AnalysisException - from pyspark.errors.exceptions.connect import ( - AnalysisException as ConnectAnalysisException, - ) query = "SELECT 1 AS c, 'a' AS k, 10 AS v UNION ALL SELECT 2 AS c, 'b' AS k, 20 AS v" @@ -1353,9 +1324,7 @@ def test_resolve_after_pivot(self): # Connect strict: fails. with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): df = self.connect.sql(query) - with self.assertRaisesRegex( - ConnectAnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN" - ): + with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): df.groupBy("c").pivot("k").sum("v").select(df["c"]).collect() # Connect lenient: succeeds via name-based fallback. @@ -1368,9 +1337,6 @@ def test_resolve_after_union(self): # attributes. Tagged df1["c"] only resolves under Connect lenient by # name-based fallback. from pyspark.errors import AnalysisException - from pyspark.errors.exceptions.connect import ( - AnalysisException as ConnectAnalysisException, - ) # Classic: fails. 
sdf1 = self.spark.sql("SELECT 1 AS c") @@ -1382,9 +1348,7 @@ def test_resolve_after_union(self): with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): df1 = self.connect.sql("SELECT 1 AS c") df2 = self.connect.sql("SELECT 2 AS c") - with self.assertRaisesRegex( - ConnectAnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN" - ): + with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): df1.union(df2).select(df1["c"]).collect() # Connect lenient: succeeds via name-based fallback. @@ -1396,9 +1360,6 @@ def test_resolve_after_union(self): def test_resolve_after_intersect(self): # intersect, like union, emits new attribute ids. from pyspark.errors import AnalysisException - from pyspark.errors.exceptions.connect import ( - AnalysisException as ConnectAnalysisException, - ) # Classic: fails. sdf1 = self.spark.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") @@ -1410,9 +1371,7 @@ def test_resolve_after_intersect(self): with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): df1 = self.connect.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") df2 = self.connect.sql("SELECT 2 AS c UNION ALL SELECT 3 AS c") - with self.assertRaisesRegex( - ConnectAnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN" - ): + with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): df1.intersect(df2).select(df1["c"]).collect() # Connect lenient: succeeds via name-based fallback. @@ -1426,9 +1385,6 @@ def test_resolve_self_join_alias(self): # ancestor. The tagged df["c"] is ambiguous because two output # attributes match by name. from pyspark.errors import AnalysisException - from pyspark.errors.exceptions.connect import ( - AnalysisException as ConnectAnalysisException, - ) # Classic: fails (ambiguous reference). 
sdf = self.spark.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") @@ -1441,7 +1397,7 @@ def test_resolve_self_join_alias(self): with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): df = self.connect.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") a, b = df.alias("a"), df.alias("b") - with self.assertRaises(ConnectAnalysisException): + with self.assertRaises(AnalysisException): a.join(b, a["c"] == b["c"]).select(df["c"]).collect() def test_resolve_after_subquery_view(self): @@ -1452,9 +1408,6 @@ def test_resolve_after_subquery_view(self): import uuid from pyspark.errors import AnalysisException - from pyspark.errors.exceptions.connect import ( - AnalysisException as ConnectAnalysisException, - ) # Classic: fails. view = f"v_{uuid.uuid4().hex}" @@ -1472,9 +1425,7 @@ def test_resolve_after_subquery_view(self): df = self.connect.sql("SELECT 1 AS c") df.createOrReplaceTempView(view) try: - with self.assertRaisesRegex( - ConnectAnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN" - ): + with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): self.connect.table(view).select(df["c"]).collect() finally: self.connect.sql(f"DROP VIEW IF EXISTS {view}") From bc8843348b23aa10c7cd5296dd7731b87bf45e6c Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Mon, 18 May 2026 06:39:50 +0000 Subject: [PATCH 06/16] Rewrite layered tests in DataFrame API; correct Classic/Connect parity assertions - Rewrite the three mixed-surface layered tests to use only the DataFrame API (chained transformations, semi-joins, Window functions, cube, UDFs and struct field access). spark.sql() is no longer used in the layered tests. - After running the suite locally, correct several assertions that assumed Classic-vs-Connect divergence where none exists: groupBy, intersect, pivot, and temp-view-roundtrip all preserve attribute-id propagation through Connect's plan-id resolver, so the tagged reference resolves in both strict and lenient modes. 
- For union the divergence does hold: Classic succeeds but Connect raises CANNOT_RESOLVE_DATAFRAME_COLUMN in both strict and lenient modes (the name-based fallback is not triggered for set-op outputs). Generated-by: Claude Code (Anthropic), claude-opus-4-7 --- .../sql/tests/connect/test_connect_column.py | 565 ++++++++---------- 1 file changed, 242 insertions(+), 323 deletions(-) diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py index 28222b17729f5..5d9e1754215bb 100644 --- a/python/pyspark/sql/tests/connect/test_connect_column.py +++ b/python/pyspark/sql/tests/connect/test_connect_column.py @@ -40,7 +40,7 @@ DecimalType, BooleanType, ) -from pyspark.errors import PySparkTypeError, PySparkValueError +from pyspark.errors import AnalysisException, PySparkTypeError, PySparkValueError from pyspark.testing import assertDataFrameEqual from pyspark.testing.connectutils import should_test_connect, ReusedMixedTestCase from pyspark.testing.pandasutils import PandasOnSparkTestUtils @@ -1109,9 +1109,8 @@ class SparkConnectColumnResolutionTests(ReusedMixedTestCase): For each pattern, the test runs the equivalent program against: - * Spark Classic (``self.spark``) - one-block baseline that anchors the - Classic behavior (always-failed shadowing patterns vs. pass-through - operators that work in both worlds). + * Spark Classic (``self.spark``) - baseline that anchors the Classic + behavior. * Spark Connect strict mode (default, ``spark.sql.analyzer.strictDataFrameColumnResolution=true``) - plan-id- based resolution only. Tagged ``df["c"]`` references whose ancestor's @@ -1119,21 +1118,14 @@ class SparkConnectColumnResolutionTests(ReusedMixedTestCase): * Spark Connect lenient mode (``spark.sql.analyzer.strictDataFrameColumnResolution=false``) - if plan-id-based resolution fails the analyzer also tries name-based - resolution against the current child output. 
This is the Connect- - specific contract that customer workflows depend on. - - Tests live here intentionally: they must NOT be moved to Classic-shared - suites where they'd be removed as "diverging from Classic" during routine - cleanup. + resolution against the current child output. """ def test_resolve_after_chained_withcolumn_shadow(self): # Two consecutive withColumn calls each shadow `c` with a new # attribute carrying the same name; the original `c` is no longer in # the projection. - from pyspark.errors import AnalysisException - - # Classic: fails (the AT&T-style root cause). + # Classic: fails. sdf = self.spark.sql("SELECT 1 AS c") with self.assertRaises(AnalysisException): sdf.withColumn("c", SF.col("c").cast("string")).withColumn( @@ -1158,8 +1150,6 @@ def test_resolve_after_chained_withcolumn_shadow(self): def test_resolve_after_select_alias_shadow(self): # Same shadowing shape as withColumn but expressed through a select # with alias. - from pyspark.errors import AnalysisException - # Classic: fails. sdf = self.spark.sql("SELECT 1 AS c") with self.assertRaises(AnalysisException): @@ -1181,8 +1171,6 @@ def test_resolve_after_withcolumnrenamed(self): # as `c2`. The tagged df["c"] cannot resolve under any mode because # neither the original attribute nor a column named `c` is in the # current child output. - from pyspark.errors import AnalysisException - # Classic: fails. sdf = self.spark.sql("SELECT 1 AS c") with self.assertRaises(AnalysisException): @@ -1198,8 +1186,6 @@ def test_resolve_after_withcolumnrenamed(self): def test_resolve_after_drop(self): # drop("c") removes the column entirely. Tagged df["c"] cannot resolve # under any mode. - from pyspark.errors import AnalysisException - # Classic: fails. 
sdf = self.spark.sql("SELECT 1 AS c, 2 AS d") with self.assertRaises(AnalysisException): @@ -1264,34 +1250,28 @@ def test_resolve_through_distinct(self): self.assertEqual([r.c for r in rows], expected) def test_resolve_after_groupby_count(self): - # groupBy("c").count() emits a new aggregate output schema. Connect - # strict cannot resolve the original tagged attribute; Connect lenient - # falls back to name-based resolution on the grouping key. - from pyspark.errors import AnalysisException - - # Classic: fails (the source attribute is consumed by Aggregate). - sdf = self.spark.sql("SELECT 1 AS c UNION ALL SELECT 1 AS c UNION ALL SELECT 2 AS c") - with self.assertRaises(AnalysisException): - sdf.groupBy("c").count().select(sdf["c"]).collect() + # groupBy("c").count() preserves the grouping key's attribute id in + # both Classic and Connect, so the tagged reference resolves in all + # modes. + query = "SELECT 1 AS c UNION ALL SELECT 1 AS c UNION ALL SELECT 2 AS c" + expected = [1, 2] - # Connect strict: fails. - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): - df = self.connect.sql("SELECT 1 AS c UNION ALL SELECT 1 AS c UNION ALL SELECT 2 AS c") - with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): - df.groupBy("c").count().select(df["c"]).collect() + # Classic: succeeds. + sdf = self.spark.sql(query) + srows = sdf.groupBy("c").count().select(sdf["c"]).collect() + self.assertEqual(sorted(r.c for r in srows), expected) - # Connect lenient: succeeds via name-based fallback. - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): - df = self.connect.sql("SELECT 1 AS c UNION ALL SELECT 1 AS c UNION ALL SELECT 2 AS c") - rows = df.groupBy("c").count().select(df["c"]).collect() - self.assertEqual(sorted(r.c for r in rows), [1, 2]) + # Connect: succeeds in both modes. 
+ for strict in (True, False): + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): + df = self.connect.sql(query) + rows = df.groupBy("c").count().select(df["c"]).collect() + self.assertEqual(sorted(r.c for r in rows), expected) def test_resolve_after_agg_alias_shadow(self): # An aggregate output named `c` via alias() collides by name with # the source `c`. The tagged df["c"] still references the source # attribute that has been aggregated away. - from pyspark.errors import AnalysisException - # Classic: fails. sdf = self.spark.sql("SELECT 1 AS x") with self.assertRaises(AnalysisException): @@ -1309,83 +1289,68 @@ def test_resolve_after_agg_alias_shadow(self): df.groupBy().agg(CF.sum("x").alias("c")).select(df["c"]).collect() def test_resolve_after_pivot(self): - # pivot emits a new schema. Tagged references to the original - # grouping key resolve under Connect lenient via name-based fallback - # and fail under Connect strict. - from pyspark.errors import AnalysisException - + # pivot preserves the grouping key's attribute id in both Classic + # and Connect, so the tagged reference resolves in all modes. query = "SELECT 1 AS c, 'a' AS k, 10 AS v UNION ALL SELECT 2 AS c, 'b' AS k, 20 AS v" + expected = [1, 2] - # Classic: fails. + # Classic: succeeds. sdf = self.spark.sql(query) - with self.assertRaises(AnalysisException): - sdf.groupBy("c").pivot("k").sum("v").select(sdf["c"]).collect() - - # Connect strict: fails. - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): - df = self.connect.sql(query) - with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): - df.groupBy("c").pivot("k").sum("v").select(df["c"]).collect() + srows = sdf.groupBy("c").pivot("k").sum("v").select(sdf["c"]).collect() + self.assertEqual(sorted(r.c for r in srows), expected) - # Connect lenient: succeeds via name-based fallback. 
- with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): - df = self.connect.sql(query) - df.groupBy("c").pivot("k").sum("v").select(df["c"]).collect() + # Connect: succeeds in both modes. + for strict in (True, False): + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): + df = self.connect.sql(query) + rows = df.groupBy("c").pivot("k").sum("v").select(df["c"]).collect() + self.assertEqual(sorted(r.c for r in rows), expected) def test_resolve_after_union(self): - # Union emits new attribute ids that differ from either input's - # attributes. Tagged df1["c"] only resolves under Connect lenient by - # name-based fallback. - from pyspark.errors import AnalysisException - - # Classic: fails. + # Union emits new attribute ids. Classic still resolves the tagged + # left-side reference by attribute id propagation, but Connect fails + # in both modes: plan-id-based resolution does not find the tagged + # ancestor in the union output, and name-based fallback is not + # triggered for set-op outputs. + # Classic: succeeds. sdf1 = self.spark.sql("SELECT 1 AS c") sdf2 = self.spark.sql("SELECT 2 AS c") - with self.assertRaises(AnalysisException): - sdf1.union(sdf2).select(sdf1["c"]).collect() + srows = sdf1.union(sdf2).select(sdf1["c"]).collect() + self.assertEqual(sorted(r.c for r in srows), [1, 2]) - # Connect strict: fails. - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): - df1 = self.connect.sql("SELECT 1 AS c") - df2 = self.connect.sql("SELECT 2 AS c") - with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): - df1.union(df2).select(df1["c"]).collect() - - # Connect lenient: succeeds via name-based fallback. 
- with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): - df1 = self.connect.sql("SELECT 1 AS c") - df2 = self.connect.sql("SELECT 2 AS c") - df1.union(df2).select(df1["c"]).collect() + # Connect: fails in both modes. + for strict in (True, False): + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): + df1 = self.connect.sql("SELECT 1 AS c") + df2 = self.connect.sql("SELECT 2 AS c") + with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): + df1.union(df2).select(df1["c"]).collect() def test_resolve_after_intersect(self): - # intersect, like union, emits new attribute ids. - from pyspark.errors import AnalysisException + # intersect, like union, emits new attribute ids. Classic resolves + # the tagged reference by attribute id propagation; Connect also + # resolves it successfully (the intersect output retains the + # propagated id), in both modes. + expected = [2] - # Classic: fails. + # Classic: succeeds. sdf1 = self.spark.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") sdf2 = self.spark.sql("SELECT 2 AS c UNION ALL SELECT 3 AS c") - with self.assertRaises(AnalysisException): - sdf1.intersect(sdf2).select(sdf1["c"]).collect() - - # Connect strict: fails. - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): - df1 = self.connect.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") - df2 = self.connect.sql("SELECT 2 AS c UNION ALL SELECT 3 AS c") - with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): - df1.intersect(df2).select(df1["c"]).collect() + srows = sdf1.intersect(sdf2).select(sdf1["c"]).collect() + self.assertEqual([r.c for r in srows], expected) - # Connect lenient: succeeds via name-based fallback. 
- with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): - df1 = self.connect.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") - df2 = self.connect.sql("SELECT 2 AS c UNION ALL SELECT 3 AS c") - df1.intersect(df2).select(df1["c"]).collect() + # Connect: succeeds in both modes. + for strict in (True, False): + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): + df1 = self.connect.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") + df2 = self.connect.sql("SELECT 2 AS c UNION ALL SELECT 3 AS c") + rows = df1.intersect(df2).select(df1["c"]).collect() + self.assertEqual([r.c for r in rows], expected) def test_resolve_self_join_alias(self): # In a self-join, both sides originate from the same plan-id-tagged # ancestor. The tagged df["c"] is ambiguous because two output # attributes match by name. - from pyspark.errors import AnalysisException - # Classic: fails (ambiguous reference). sdf = self.spark.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") a, b = sdf.alias("a"), sdf.alias("b") @@ -1402,62 +1367,51 @@ def test_resolve_self_join_alias(self): def test_resolve_after_subquery_view(self): # Persisting the original DataFrame as a temp view and reading it - # back via spark.table() produces a new plan with new attribute ids. - # The tagged reference targets the original plan id, which is not an - # ancestor of the new DataFrame. + # back via spark.table() produces a new plan. Classic resolves the + # tagged reference; Connect also resolves it in both modes. import uuid - from pyspark.errors import AnalysisException + expected = [1] - # Classic: fails. + # Classic: succeeds. 
view = f"v_{uuid.uuid4().hex}" sdf = self.spark.sql("SELECT 1 AS c") sdf.createOrReplaceTempView(view) try: - with self.assertRaises(AnalysisException): - self.spark.table(view).select(sdf["c"]).collect() + srows = self.spark.table(view).select(sdf["c"]).collect() + self.assertEqual([r.c for r in srows], expected) finally: self.spark.sql(f"DROP VIEW IF EXISTS {view}") - # Connect strict: fails. - view = f"v_{uuid.uuid4().hex}" - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): - df = self.connect.sql("SELECT 1 AS c") - df.createOrReplaceTempView(view) - try: - with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): - self.connect.table(view).select(df["c"]).collect() - finally: - self.connect.sql(f"DROP VIEW IF EXISTS {view}") - - # Connect lenient: succeeds via name-based fallback. - view = f"v_{uuid.uuid4().hex}" - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): - df = self.connect.sql("SELECT 1 AS c") - df.createOrReplaceTempView(view) - try: - self.connect.table(view).select(df["c"]).collect() - finally: - self.connect.sql(f"DROP VIEW IF EXISTS {view}") + # Connect: succeeds in both modes. + for strict in (True, False): + view = f"v_{uuid.uuid4().hex}" + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): + df = self.connect.sql("SELECT 1 AS c") + df.createOrReplaceTempView(view) + try: + rows = self.connect.table(view).select(df["c"]).collect() + self.assertEqual([r.c for r in rows], expected) + finally: + self.connect.sql(f"DROP VIEW IF EXISTS {view}") # --- Mixed-surface layered DataFrame programs --------------------------- # - # These tests combine 4-level subquery chains, CTE chains, window - # functions, GROUPING SETS, NTILE/RANK, struct field access and - # correlated EXISTS/IN in a single query. 
Each program builds a deeply - # layered base via ``spark.sql()``, then layers DataFrame-API shadowing - # operations on top with a tagged ``df["c"]`` reference at the outermost - # select. The goal is to catch regressions in plan-id propagation - # across Connect's analyzer rules that single-operator tests miss when - # rules interact. - - def test_layered_subquery_chain_window_having_exists(self): - # 4-level subquery chain combining windows, HAVING and correlated - # EXISTS, then a tagged ``df["category"]`` reference after a - # groupBy shadow. - import uuid - - from pyspark.errors import AnalysisException + # These tests chain multiple DataFrame transformations - semi-joins + # (for SQL EXISTS/IN), window functions, cube aggregations, UDFs and + # struct field access - into 4-5 layer pipelines. Each program builds + # the layered base entirely through the DataFrame API, then layers a + # shadowing operation on top with a tagged ``df["c"]`` reference at the + # outermost select. The goal is to catch regressions in plan-id + # propagation across Connect's analyzer rules that single-operator + # tests miss when rules interact. + + def test_layered_semijoin_groupby_window_shadow(self): + # 4-layer DataFrame pipeline: filter -> semi-join -> groupBy/agg + # -> windows. Then a tagged ``layered["category"]`` reference after + # a groupBy shadow. 
+ from pyspark.sql.connect.window import Window as CWindow + from pyspark.sql.window import Window as SWindow events_data = [ (1, 1, "Books", 100.0, 2, True), @@ -1471,79 +1425,67 @@ def test_layered_subquery_chain_window_having_exists(self): events_cols = ["id", "user_id", "category", "amount", "quantity", "is_active"] users_cols = ["id", "age"] - def layered_sql(events_view, users_view): - return f""" - SELECT category, total_amt, running_avg, rank_num - FROM ( - SELECT category, total_amt, - AVG(total_amt) OVER (ORDER BY total_amt - ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS running_avg, - RANK() OVER (ORDER BY total_amt DESC) AS rank_num - FROM ( - SELECT e.category, SUM(e.amount * e.quantity * 0.1) AS total_amt - FROM ( - SELECT id, user_id, category, amount, quantity - FROM {events_view} - WHERE is_active = true - AND EXISTS ( - SELECT 1 FROM {users_view} u - WHERE u.id = {events_view}.user_id AND u.age > 20 - ) - ) e - GROUP BY e.category - HAVING SUM(e.amount) > 50 - ) agg - ) ranked - WHERE rank_num <= 5 - """ + def build_layered(spark, F, Window): + events = spark.createDataFrame(events_data, events_cols) + users = spark.createDataFrame(users_data, users_cols) - # Classic: groupBy shadow of the SQL-built "category" attribute - # makes the tagged reference unresolvable. 
- events_view = f"events_{uuid.uuid4().hex[:8]}" - users_view = f"users_{uuid.uuid4().hex[:8]}" - self.spark.createDataFrame(events_data, events_cols).createOrReplaceTempView(events_view) - self.spark.createDataFrame(users_data, users_cols).createOrReplaceTempView(users_view) - try: - sdf = self.spark.sql(layered_sql(events_view, users_view)) - with self.assertRaises(AnalysisException): - sdf.groupBy("category").count().select(sdf["category"]).collect() - finally: - self.spark.sql(f"DROP VIEW IF EXISTS {events_view}") - self.spark.sql(f"DROP VIEW IF EXISTS {users_view}") - - # Connect lenient: name-based fallback resolves the tagged reference - # against the post-aggregate grouping key, so the program succeeds - # end-to-end. - events_view = f"events_{uuid.uuid4().hex[:8]}" - users_view = f"users_{uuid.uuid4().hex[:8]}" - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): - self.connect.createDataFrame(events_data, events_cols).createOrReplaceTempView( - events_view + # Layer 1: filter + semi-join (DataFrame-API equivalent of + # WHERE is_active AND EXISTS (user with age > 20)). + active = events.where(events["is_active"]).join( + users.where(users["age"] > 20), + events["user_id"] == users["id"], + "left_semi", ) - self.connect.createDataFrame(users_data, users_cols).createOrReplaceTempView(users_view) - try: - df = self.connect.sql(layered_sql(events_view, users_view)) - rows = df.groupBy("category").count().select(df["category"]).collect() - self.assertEqual( - sorted(r.category for r in rows), - ["Books", "Electronics", "Home"], + # Layer 2: groupBy + agg, then post-agg filter (HAVING equivalent). 
+ totals = ( + active.groupBy("category") + .agg( + F.sum(active["amount"] * active["quantity"] * F.lit(0.1)).alias("total_amt"), + F.sum(active["amount"]).alias("amount_sum"), ) - finally: - self.connect.sql(f"DROP VIEW IF EXISTS {events_view}") - self.connect.sql(f"DROP VIEW IF EXISTS {users_view}") + .where(F.col("amount_sum") > 50) + .select("category", "total_amt") + ) + # Layer 3: window functions on top of the aggregate. + running = Window.orderBy("total_amt").rowsBetween(-1, 1) + ranking = Window.orderBy(F.col("total_amt").desc()) + windowed = totals.select( + "category", + "total_amt", + F.avg(F.col("total_amt")).over(running).alias("running_avg"), + F.rank().over(ranking).alias("rank_num"), + ) + # Layer 4: outer filter. + return windowed.where(F.col("rank_num") <= 5) - def test_layered_cte_chain_grouping_sets_ntile_correlated_in(self): - # CTE chain with GROUPING SETS, NTILE, struct field access and - # correlated IN, then withColumn shadow + tagged select. - import uuid + expected_categories = ["Books", "Electronics", "Home"] + + # Classic: groupBy propagates the "category" attribute id through + # the aggregate, so the tagged reference still resolves. + slayered = build_layered(self.spark, SF, SWindow) + srows = slayered.groupBy("category").count().select(slayered["category"]).collect() + self.assertEqual(sorted(r.category for r in srows), expected_categories) - from pyspark.errors import AnalysisException + # Connect: succeeds in both modes (groupBy attribute id propagates + # through Connect's aggregate as well). 
+ for strict in (True, False): + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): + clayered = build_layered(self.connect, CF, CWindow) + rows = clayered.groupBy("category").count().select(clayered["category"]).collect() + self.assertEqual(sorted(r.category for r in rows), expected_categories) + + def test_layered_struct_semijoin_cube_ntile_shadow(self): + # 5-layer DataFrame pipeline: filter -> semi-join -> struct field + # access -> cube aggregation -> window NTILE. Then withColumn + # shadow + tagged select. + from pyspark.sql.connect.window import Window as CWindow from pyspark.sql.types import ( IntegerType, StringType, StructField, StructType, ) + from pyspark.sql.window import Window as SWindow events_schema = StructType( [ @@ -1573,87 +1515,74 @@ def test_layered_cte_chain_grouping_sets_ntile_correlated_in(self): categories_data = [("Books", 1), ("Electronics", 2), ("Home", 3), ("Toys", 5)] categories_cols = ["name", "priority"] - def cte_sql(events_view, categories_view): - return f""" - WITH base AS ( - SELECT id, category, status, amount, - detail.name AS detail_name, - detail.nested.x AS nx - FROM {events_view} - WHERE quantity > 1 - AND category IN ( - SELECT c.name FROM {categories_view} c WHERE c.priority <= 3 - ) - ), - grouped AS ( - SELECT category, status, detail_name, - GROUPING(category) AS g_cat, - GROUPING(status) AS g_stat, - SUM(amount) AS total, - COUNT(*) AS cnt - FROM base - GROUP BY GROUPING SETS ( - (category, status, detail_name), - (category), - () - ) - ), - tiled AS ( - SELECT *, NTILE(2) OVER (ORDER BY total DESC) AS tile - FROM grouped - WHERE g_cat = 0 AND g_stat = 0 - ) - SELECT category, status, detail_name, total, cnt, tile - FROM tiled WHERE tile <= 2 - """ + def build_layered(spark, F, Window): + events = spark.createDataFrame(events_data, events_schema) + categories = spark.createDataFrame(categories_data, categories_cols) - # Classic: withColumn shadow of the CTE-produced "category" 
attribute - # makes the tagged reference unresolvable. - events_view = f"events_{uuid.uuid4().hex[:8]}" - categories_view = f"categories_{uuid.uuid4().hex[:8]}" - self.spark.createDataFrame(events_data, events_schema).createOrReplaceTempView(events_view) - self.spark.createDataFrame(categories_data, categories_cols).createOrReplaceTempView( - categories_view - ) - try: - sdf = self.spark.sql(cte_sql(events_view, categories_view)) - with self.assertRaises(AnalysisException): - sdf.withColumn("category", SF.col("category").cast("string")).select( - sdf["category"], "total" + # Layer 1: filter + semi-join (DataFrame-API equivalent of + # WHERE quantity > 1 AND category IN (SELECT ...)). + filtered = events.where(events["quantity"] > 1).join( + categories.where(categories["priority"] <= 3), + events["category"] == categories["name"], + "left_semi", + ) + # Layer 2: project with struct field access. + base = filtered.select( + filtered["id"], + filtered["category"], + filtered["status"], + filtered["amount"], + filtered["detail"]["name"].alias("detail_name"), + filtered["detail"]["nested"]["x"].alias("nx"), + ) + # Layer 3: cube aggregation (mixed grouping levels - similar + # surface area to SQL GROUPING SETS without an exact equivalent + # in the DataFrame API). + grouped = ( + base.cube("category", "status", "detail_name") + .agg(F.sum(F.col("amount")).alias("total"), F.count(F.lit(1)).alias("cnt")) + .where(F.col("category").isNotNull() & F.col("status").isNotNull()) + ) + # Layer 4: NTILE window. + tiled = grouped.withColumn( + "tile", F.ntile(2).over(Window.orderBy(F.col("total").desc())) + ) + # Layer 5: outer filter. + return tiled.where(F.col("tile") <= 2) + + # Classic: withColumn shadow of "category" drops the original + # attribute - the tagged reference cannot resolve. 
+ slayered = build_layered(self.spark, SF, SWindow) + with self.assertRaises(AnalysisException): + slayered.withColumn("category", SF.col("category").cast("string")).select( + slayered["category"], "total" + ).collect() + + # Connect strict: fails. + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): + clayered = build_layered(self.connect, CF, CWindow) + with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): + clayered.withColumn("category", CF.col("category").cast("string")).select( + clayered["category"], "total" ).collect() - finally: - self.spark.sql(f"DROP VIEW IF EXISTS {events_view}") - self.spark.sql(f"DROP VIEW IF EXISTS {categories_view}") - # Connect lenient: succeeds end-to-end via name-based fallback. - events_view = f"events_{uuid.uuid4().hex[:8]}" - categories_view = f"categories_{uuid.uuid4().hex[:8]}" + # Connect lenient: succeeds via name-based fallback. with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): - self.connect.createDataFrame(events_data, events_schema).createOrReplaceTempView( - events_view - ) - self.connect.createDataFrame(categories_data, categories_cols).createOrReplaceTempView( - categories_view + clayered = build_layered(self.connect, CF, CWindow) + rows = ( + clayered.withColumn("category", CF.col("category").cast("string")) + .select(clayered["category"], "total") + .collect() ) - try: - df = self.connect.sql(cte_sql(events_view, categories_view)) - rows = ( - df.withColumn("category", CF.col("category").cast("string")) - .select(df["category"], "total") - .collect() - ) - self.assertGreater(len(rows), 0) - finally: - self.connect.sql(f"DROP VIEW IF EXISTS {events_view}") - self.connect.sql(f"DROP VIEW IF EXISTS {categories_view}") - - def test_layered_self_join_window_udf_shadow(self): - # Mixed surface: temp-view self-join via SQL with a window function, - # wrapped by a UDF, then withColumn shadow + tagged select. 
- import uuid + self.assertGreater(len(rows), 0) - from pyspark.errors import AnalysisException + def test_layered_window_window_udf_shadow(self): + # 4-layer DataFrame pipeline: filter -> running-total window -> + # per-partition max window -> UDF wrap. Then withColumn shadow + + # tagged select. + from pyspark.sql.connect.window import Window as CWindow from pyspark.sql.types import IntegerType + from pyspark.sql.window import Window as SWindow data = [ (1, "A", 100), @@ -1664,57 +1593,47 @@ def test_layered_self_join_window_udf_shadow(self): ] cols = ["id", "category", "amount"] - def self_join_sql(view): - # Self-join via SQL with a windowed running total and a correlated - # subquery on the same view. - return f""" - SELECT t.id, t.category, t.amount, - SUM(t.amount) OVER (PARTITION BY t.category ORDER BY t.id) AS run_amt, - (SELECT MAX(o.amount) FROM {view} o WHERE o.category = t.category) AS cat_max - FROM {view} t - WHERE EXISTS ( - SELECT 1 FROM {view} p WHERE p.id = t.id AND p.amount > 0 - ) - """ + def build_layered(spark, F, Window): + df = spark.createDataFrame(data, cols) + # Layer 1: filter (replaces WHERE EXISTS amount > 0). + filtered = df.where(df["amount"] > 0) + # Layer 2: running total window. + run_w = Window.partitionBy("category").orderBy("id") + with_run = filtered.withColumn("run_amt", F.sum(F.col("amount")).over(run_w)) + # Layer 3: per-category max window (replaces correlated subquery + # for cat_max). + cat_w = Window.partitionBy("category") + return with_run.withColumn("cat_max", F.max(F.col("amount")).over(cat_w)) # Classic: withColumn shadow of "category" breaks the tagged reference. 
- view = f"layered_{uuid.uuid4().hex[:8]}" - sdf = self.spark.createDataFrame(data, cols) - sdf.createOrReplaceTempView(view) - try: - sjoined = self.spark.sql(self_join_sql(view)) - double_udf = SF.udf(lambda x: x * 2 if x is not None else None, IntegerType()) - with self.assertRaises(AnalysisException): - ( - sjoined.withColumn("amount", double_udf(SF.col("amount"))) - .withColumn("category", SF.col("category").cast("string")) - .select(sjoined["category"], "amount", "run_amt", "cat_max") - .collect() - ) - finally: - self.spark.sql(f"DROP VIEW IF EXISTS {view}") + slayered = build_layered(self.spark, SF, SWindow) + sdouble = SF.udf(lambda x: x * 2 if x is not None else None, IntegerType()) + with self.assertRaises(AnalysisException): + slayered.withColumn("amount", sdouble(SF.col("amount"))).withColumn( + "category", SF.col("category").cast("string") + ).select(slayered["category"], "amount", "run_amt", "cat_max").collect() - # Connect lenient: succeeds end-to-end via name-based fallback. - view = f"layered_{uuid.uuid4().hex[:8]}" + # Connect strict: fails. + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): + clayered = build_layered(self.connect, CF, CWindow) + cdouble = CF.udf(lambda x: x * 2 if x is not None else None, IntegerType()) + with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): + clayered.withColumn("amount", cdouble(CF.col("amount"))).withColumn( + "category", CF.col("category").cast("string") + ).select(clayered["category"], "amount", "run_amt", "cat_max").collect() + + # Connect lenient: succeeds via name-based fallback. 
with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): - df = self.connect.createDataFrame(data, cols) - df.createOrReplaceTempView(view) - try: - joined = self.connect.sql(self_join_sql(view)) - double_udf = CF.udf(lambda x: x * 2 if x is not None else None, IntegerType()) - rows = ( - joined.withColumn("amount", double_udf(CF.col("amount"))) - .withColumn("category", CF.col("category").cast("string")) - .select(joined["category"], "amount", "run_amt", "cat_max") - .collect() - ) - self.assertEqual(len(rows), 5) - self.assertEqual( - sorted({r.category for r in rows}), - ["A", "B", "C"], - ) - finally: - self.connect.sql(f"DROP VIEW IF EXISTS {view}") + clayered = build_layered(self.connect, CF, CWindow) + cdouble = CF.udf(lambda x: x * 2 if x is not None else None, IntegerType()) + rows = ( + clayered.withColumn("amount", cdouble(CF.col("amount"))) + .withColumn("category", CF.col("category").cast("string")) + .select(clayered["category"], "amount", "run_amt", "cat_max") + .collect() + ) + self.assertEqual(len(rows), 5) + self.assertEqual(sorted({r.category for r in rows}), ["A", "B", "C"]) if __name__ == "__main__": From 060e9a6ce6cdc34378d45d260e556f70f4cdd6a0 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Mon, 18 May 2026 09:52:04 +0000 Subject: [PATCH 07/16] Rename Connect-side df variables to cdf; drop redundant type imports Rename df / df1 / df2 variables that refer to Connect DataFrames in SparkConnectColumnResolutionTests to cdf / cdf1 / cdf2, matching the existing sdf convention for Classic. Drop the per-test pyspark.sql.types imports from the layered tests; the needed types are already imported at the top of the file. 
Generated-by: Claude Code (Anthropic), claude-opus-4-7 --- .../sql/tests/connect/test_connect_column.py | 97 +++++++++---------- 1 file changed, 45 insertions(+), 52 deletions(-) diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py index 5d9e1754215bb..1109fcc459c95 100644 --- a/python/pyspark/sql/tests/connect/test_connect_column.py +++ b/python/pyspark/sql/tests/connect/test_connect_column.py @@ -1134,18 +1134,18 @@ def test_resolve_after_chained_withcolumn_shadow(self): # Connect strict: same root cause, different error class. with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): - df = self.connect.sql("SELECT 1 AS c") + cdf = self.connect.sql("SELECT 1 AS c") with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): - df.withColumn("c", CF.col("c").cast("string")).withColumn( + cdf.withColumn("c", CF.col("c").cast("string")).withColumn( "c", CF.col("c").cast("int") - ).select(df["c"]).collect() + ).select(cdf["c"]).collect() # Connect lenient: succeeds via name-based fallback. with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): - df = self.connect.sql("SELECT 1 AS c") - df.withColumn("c", CF.col("c").cast("string")).withColumn( + cdf = self.connect.sql("SELECT 1 AS c") + cdf.withColumn("c", CF.col("c").cast("string")).withColumn( "c", CF.col("c").cast("int") - ).select(df["c"]).collect() + ).select(cdf["c"]).collect() def test_resolve_after_select_alias_shadow(self): # Same shadowing shape as withColumn but expressed through a select @@ -1157,18 +1157,18 @@ def test_resolve_after_select_alias_shadow(self): # Connect strict: fails. 
with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): - df = self.connect.sql("SELECT 1 AS c") + cdf = self.connect.sql("SELECT 1 AS c") with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): - df.select(df["c"].cast("string").alias("c")).select(df["c"]).collect() + cdf.select(cdf["c"].cast("string").alias("c")).select(cdf["c"]).collect() # Connect lenient: succeeds via name-based fallback. with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): - df = self.connect.sql("SELECT 1 AS c") - df.select(df["c"].cast("string").alias("c")).select(df["c"]).collect() + cdf = self.connect.sql("SELECT 1 AS c") + cdf.select(cdf["c"].cast("string").alias("c")).select(cdf["c"]).collect() def test_resolve_after_withcolumnrenamed(self): # withColumnRenamed drops the original `c` attribute and projects it - # as `c2`. The tagged df["c"] cannot resolve under any mode because + # as `c2`. The tagged cdf["c"] cannot resolve under any mode because # neither the original attribute nor a column named `c` is in the # current child output. # Classic: fails. @@ -1179,12 +1179,12 @@ def test_resolve_after_withcolumnrenamed(self): # Connect: fails in both modes. for strict in (True, False): with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): - df = self.connect.sql("SELECT 1 AS c") + cdf = self.connect.sql("SELECT 1 AS c") with self.assertRaises(AnalysisException): - df.withColumnRenamed("c", "c2").select(df["c"]).collect() + cdf.withColumnRenamed("c", "c2").select(cdf["c"]).collect() def test_resolve_after_drop(self): - # drop("c") removes the column entirely. Tagged df["c"] cannot resolve + # drop("c") removes the column entirely. Tagged cdf["c"] cannot resolve # under any mode. # Classic: fails. sdf = self.spark.sql("SELECT 1 AS c, 2 AS d") @@ -1194,9 +1194,9 @@ def test_resolve_after_drop(self): # Connect: fails in both modes. 
for strict in (True, False): with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): - df = self.connect.sql("SELECT 1 AS c, 2 AS d") + cdf = self.connect.sql("SELECT 1 AS c, 2 AS d") with self.assertRaises(AnalysisException): - df.drop("c").select(df["c"]).collect() + cdf.drop("c").select(cdf["c"]).collect() def test_resolve_through_filter(self): # filter is a pass-through operator: the child Project's attributes @@ -1212,8 +1212,8 @@ def test_resolve_through_filter(self): # Connect: succeeds in both modes. for strict in (True, False): with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): - df = self.connect.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") - rows = df.filter(df["c"] > 0).select(df["c"]).collect() + cdf = self.connect.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") + rows = cdf.filter(cdf["c"] > 0).select(cdf["c"]).collect() self.assertEqual(sorted(r.c for r in rows), expected) def test_resolve_through_sort(self): @@ -1228,8 +1228,8 @@ def test_resolve_through_sort(self): # Connect: succeeds in both modes. for strict in (True, False): with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): - df = self.connect.sql("SELECT 2 AS c UNION ALL SELECT 1 AS c") - rows = df.sort(df["c"]).select(df["c"]).collect() + cdf = self.connect.sql("SELECT 2 AS c UNION ALL SELECT 1 AS c") + rows = cdf.sort(cdf["c"]).select(cdf["c"]).collect() self.assertEqual([r.c for r in rows], expected) def test_resolve_through_distinct(self): @@ -1245,8 +1245,8 @@ def test_resolve_through_distinct(self): # Connect: succeeds in both modes. 
for strict in (True, False): with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): - df = self.connect.sql("SELECT 1 AS c UNION ALL SELECT 1 AS c") - rows = df.distinct().select(df["c"]).collect() + cdf = self.connect.sql("SELECT 1 AS c UNION ALL SELECT 1 AS c") + rows = cdf.distinct().select(cdf["c"]).collect() self.assertEqual([r.c for r in rows], expected) def test_resolve_after_groupby_count(self): @@ -1264,13 +1264,13 @@ def test_resolve_after_groupby_count(self): # Connect: succeeds in both modes. for strict in (True, False): with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): - df = self.connect.sql(query) - rows = df.groupBy("c").count().select(df["c"]).collect() + cdf = self.connect.sql(query) + rows = cdf.groupBy("c").count().select(cdf["c"]).collect() self.assertEqual(sorted(r.c for r in rows), expected) def test_resolve_after_agg_alias_shadow(self): # An aggregate output named `c` via alias() collides by name with - # the source `c`. The tagged df["c"] still references the source + # the source `c`. The tagged cdf["c"] still references the source # attribute that has been aggregated away. # Classic: fails. sdf = self.spark.sql("SELECT 1 AS x") @@ -1279,14 +1279,14 @@ def test_resolve_after_agg_alias_shadow(self): # Connect strict: fails. with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): - df = self.connect.sql("SELECT 1 AS x") + cdf = self.connect.sql("SELECT 1 AS x") with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): - df.groupBy().agg(CF.sum("x").alias("c")).select(df["c"]).collect() + cdf.groupBy().agg(CF.sum("x").alias("c")).select(cdf["c"]).collect() # Connect lenient: succeeds via name-based fallback. 
with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): - df = self.connect.sql("SELECT 1 AS x") - df.groupBy().agg(CF.sum("x").alias("c")).select(df["c"]).collect() + cdf = self.connect.sql("SELECT 1 AS x") + cdf.groupBy().agg(CF.sum("x").alias("c")).select(cdf["c"]).collect() def test_resolve_after_pivot(self): # pivot preserves the grouping key's attribute id in both Classic @@ -1302,8 +1302,8 @@ def test_resolve_after_pivot(self): # Connect: succeeds in both modes. for strict in (True, False): with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): - df = self.connect.sql(query) - rows = df.groupBy("c").pivot("k").sum("v").select(df["c"]).collect() + cdf = self.connect.sql(query) + rows = cdf.groupBy("c").pivot("k").sum("v").select(cdf["c"]).collect() self.assertEqual(sorted(r.c for r in rows), expected) def test_resolve_after_union(self): @@ -1321,10 +1321,10 @@ def test_resolve_after_union(self): # Connect: fails in both modes. for strict in (True, False): with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): - df1 = self.connect.sql("SELECT 1 AS c") - df2 = self.connect.sql("SELECT 2 AS c") + cdf1 = self.connect.sql("SELECT 1 AS c") + cdf2 = self.connect.sql("SELECT 2 AS c") with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): - df1.union(df2).select(df1["c"]).collect() + cdf1.union(cdf2).select(cdf1["c"]).collect() def test_resolve_after_intersect(self): # intersect, like union, emits new attribute ids. Classic resolves @@ -1342,14 +1342,14 @@ def test_resolve_after_intersect(self): # Connect: succeeds in both modes. 
for strict in (True, False): with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): - df1 = self.connect.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") - df2 = self.connect.sql("SELECT 2 AS c UNION ALL SELECT 3 AS c") - rows = df1.intersect(df2).select(df1["c"]).collect() + cdf1 = self.connect.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") + cdf2 = self.connect.sql("SELECT 2 AS c UNION ALL SELECT 3 AS c") + rows = cdf1.intersect(cdf2).select(cdf1["c"]).collect() self.assertEqual([r.c for r in rows], expected) def test_resolve_self_join_alias(self): # In a self-join, both sides originate from the same plan-id-tagged - # ancestor. The tagged df["c"] is ambiguous because two output + # ancestor. The tagged cdf["c"] is ambiguous because two output # attributes match by name. # Classic: fails (ambiguous reference). sdf = self.spark.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") @@ -1360,10 +1360,10 @@ def test_resolve_self_join_alias(self): # Connect: fails in both modes. 
for strict in (True, False): with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): - df = self.connect.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") - a, b = df.alias("a"), df.alias("b") + cdf = self.connect.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") + a, b = cdf.alias("a"), cdf.alias("b") with self.assertRaises(AnalysisException): - a.join(b, a["c"] == b["c"]).select(df["c"]).collect() + a.join(b, a["c"] == b["c"]).select(cdf["c"]).collect() def test_resolve_after_subquery_view(self): # Persisting the original DataFrame as a temp view and reading it @@ -1387,10 +1387,10 @@ def test_resolve_after_subquery_view(self): for strict in (True, False): view = f"v_{uuid.uuid4().hex}" with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): - df = self.connect.sql("SELECT 1 AS c") - df.createOrReplaceTempView(view) + cdf = self.connect.sql("SELECT 1 AS c") + cdf.createOrReplaceTempView(view) try: - rows = self.connect.table(view).select(df["c"]).collect() + rows = self.connect.table(view).select(cdf["c"]).collect() self.assertEqual([r.c for r in rows], expected) finally: self.connect.sql(f"DROP VIEW IF EXISTS {view}") @@ -1401,7 +1401,7 @@ def test_resolve_after_subquery_view(self): # (for SQL EXISTS/IN), window functions, cube aggregations, UDFs and # struct field access - into 4-5 layer pipelines. Each program builds # the layered base entirely through the DataFrame API, then layers a - # shadowing operation on top with a tagged ``df["c"]`` reference at the + # shadowing operation on top with a tagged ``cdf["c"]`` reference at the # outermost select. The goal is to catch regressions in plan-id # propagation across Connect's analyzer rules that single-operator # tests miss when rules interact. @@ -1479,12 +1479,6 @@ def test_layered_struct_semijoin_cube_ntile_shadow(self): # access -> cube aggregation -> window NTILE. Then withColumn # shadow + tagged select. 
from pyspark.sql.connect.window import Window as CWindow - from pyspark.sql.types import ( - IntegerType, - StringType, - StructField, - StructType, - ) from pyspark.sql.window import Window as SWindow events_schema = StructType( @@ -1581,7 +1575,6 @@ def test_layered_window_window_udf_shadow(self): # per-partition max window -> UDF wrap. Then withColumn shadow + # tagged select. from pyspark.sql.connect.window import Window as CWindow - from pyspark.sql.types import IntegerType from pyspark.sql.window import Window as SWindow data = [ From 1d586fac370daea7a09b30d667e97c072f1ae7dc Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Mon, 18 May 2026 10:23:03 +0000 Subject: [PATCH 08/16] Add tests for cases documented in ColumnResolutionHelper.scala Based on the resolution logic in sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala add four tests covering behaviors not previously exercised: - test_resolve_cross_dataframe_illegal_reference: the documented df1.select(df2.a) case where df2's plan id is not in df1's plan tree; fails with CANNOT_RESOLVE_DATAFRAME_COLUMN before any name-based fallback can run. - test_resolve_df_star: plan-id-tagged star expansion via UnresolvedDataFrameStar; succeeds in both Classic and Connect modes. - test_resolve_self_join_withcolumnrenamed: the documented self-join disambiguation example where the right-side candidate is filtered out by the rename projection above it. - test_resolve_sort_missing_attr_recovery: the documented df.select(df.v).sort(df.id) case where Sort's reference is recovered by resolveExprsAndAddMissingAttrs adding id back to the upstream projection, in both strict and lenient modes. Tighten comments on two existing tests: - test_resolve_after_union: cite that Union is treated as a leaf node when walking the plan tree for plan-id resolution. - test_resolve_self_join_alias: cite AMBIGUOUS_COLUMN_REFERENCE as the specific failure mode. 
Generated-by: Claude Code (Anthropic), claude-opus-4-7 --- .../sql/tests/connect/test_connect_column.py | 107 +++++++++++++++++- 1 file changed, 102 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py index 1109fcc459c95..2a5db7047a95a 100644 --- a/python/pyspark/sql/tests/connect/test_connect_column.py +++ b/python/pyspark/sql/tests/connect/test_connect_column.py @@ -1309,9 +1309,11 @@ def test_resolve_after_pivot(self): def test_resolve_after_union(self): # Union emits new attribute ids. Classic still resolves the tagged # left-side reference by attribute id propagation, but Connect fails - # in both modes: plan-id-based resolution does not find the tagged - # ancestor in the union output, and name-based fallback is not - # triggered for set-op outputs. + # in both modes: per ColumnResolutionHelper, Union is treated as a + # leaf node when walking the plan tree for plan-id resolution + # (children are not searched), so cdf1's plan id is never found and + # CANNOT_RESOLVE_DATAFRAME_COLUMN is thrown before the lenient + # name-based fallback can run. # Classic: succeeds. sdf1 = self.spark.sql("SELECT 1 AS c") sdf2 = self.spark.sql("SELECT 2 AS c") @@ -1349,8 +1351,11 @@ def test_resolve_after_intersect(self): def test_resolve_self_join_alias(self): # In a self-join, both sides originate from the same plan-id-tagged - # ancestor. The tagged cdf["c"] is ambiguous because two output - # attributes match by name. + # ancestor. Plan-id resolution finds two candidates of equal depth + # that share the same attribute id; the disambiguation in + # ColumnResolutionHelper.resolveDataFrameColumn cannot tiebreak by + # depth and raises AMBIGUOUS_COLUMN_REFERENCE. Classic raises an + # ambiguous-reference AnalysisException for the same reason. # Classic: fails (ambiguous reference). 
sdf = self.spark.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") a, b = sdf.alias("a"), sdf.alias("b") @@ -1395,6 +1400,98 @@ def test_resolve_after_subquery_view(self): finally: self.connect.sql(f"DROP VIEW IF EXISTS {view}") + def test_resolve_cross_dataframe_illegal_reference(self): + # Per ColumnResolutionHelper.resolveDataFrameColumn (the documented + # `df1.select(df2.a)` case): referencing a column from a DataFrame + # whose plan id is not an ancestor in the target plan tree fails + # with CANNOT_RESOLVE_DATAFRAME_COLUMN. The strict / lenient switch + # does not gate this throw, so the failure is identical in both + # Connect modes. + # Classic: fails (the attribute id of sdf2.id is not in sdf1's plan). + sdf1 = self.spark.range(3) + sdf2 = self.spark.range(5) + with self.assertRaises(AnalysisException): + sdf1.select(sdf2["id"]).collect() + + # Connect: fails in both modes. + for strict in (True, False): + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): + cdf1 = self.connect.range(3) + cdf2 = self.connect.range(5) + with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): + cdf1.select(cdf2["id"]).collect() + + def test_resolve_df_star(self): + # `cdf["*"]` is an UnresolvedDataFrameStar carrying cdf's plan id. + # The analyzer expands it to the output of the matched plan node. + # This works in both Classic and Connect, in both Connect modes. + query = "SELECT 'Books' AS c, 100 AS v UNION ALL SELECT 'Electronics' AS c, 200 AS v" + expected = [("Books", 100), ("Electronics", 200)] + + # Classic: succeeds. + sdf = self.spark.sql(query) + srows = sdf.select(sdf["*"]).collect() + self.assertEqual(sorted((r.c, r.v) for r in srows), expected) + + # Connect: succeeds in both modes.
+ for strict in (True, False): + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): + cdf = self.connect.sql(query) + rows = cdf.select(cdf["*"]).collect() + self.assertEqual(sorted((r.c, r.v) for r in rows), expected) + + def test_resolve_self_join_withcolumnrenamed(self): + # Documented example from ColumnResolutionHelper.scala (adjusted to + # produce one row per id rather than a 10x10 cross match): + # + # df1 = spark.range(10).withColumn("a", sf.col("id")) + # df2 = df1.withColumnRenamed("a", "b") + # df1.join(df2, df1["a"] == df2["b"]) + # + # When resolving the column reference df1.a, the target node with + # df1's plan id can be found on both sides of the Join. The + # candidate from the right side is filtered out because its `a` + # attribute is not in the output of the renaming Project above it. + # Disambiguation succeeds. + # Classic: succeeds. + sdf1 = self.spark.range(10).withColumn("a", SF.col("id")) + sdf2 = sdf1.withColumnRenamed("a", "b") + srows = sdf1.join(sdf2, sdf1["a"] == sdf2["b"]).select(sdf1["a"], sdf2["b"]).collect() + self.assertEqual(len(srows), 10) + + # Connect: succeeds in both modes. + for strict in (True, False): + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): + cdf1 = self.connect.range(10).withColumn("a", CF.col("id")) + cdf2 = cdf1.withColumnRenamed("a", "b") + rows = ( + cdf1.join(cdf2, cdf1["a"] == cdf2["b"]).select(cdf1["a"], cdf2["b"]).collect() + ) + self.assertEqual(len(rows), 10) + + def test_resolve_sort_missing_attr_recovery(self): + # Documented example from ColumnResolutionHelper.scala: + # + # df = spark.range(10).withColumn("v", sf.col("id") + 1) + # df.select(df.v).sort(df.id) + # + # Sort references df.id which is not in the upstream select's + # output. The analyzer's resolveExprsAndAddMissingAttrs descends + # through the Project, resolves df.id via plan-id at the source, + # and adds it back to the upstream projection. 
This works in both + # Classic and Connect, in both Connect modes. + # Classic: succeeds. + sdf = self.spark.range(10).withColumn("v", SF.col("id") + 1) + srows = sdf.select(sdf["v"]).sort(sdf["id"]).collect() + self.assertEqual(len(srows), 10) + + # Connect: succeeds in both modes. + for strict in (True, False): + with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): + cdf = self.connect.range(10).withColumn("v", CF.col("id") + 1) + rows = cdf.select(cdf["v"]).sort(cdf["id"]).collect() + self.assertEqual(len(rows), 10) + # --- Mixed-surface layered DataFrame programs --------------------------- # # These tests chain multiple DataFrame transformations - semi-joins From 4b50bb1fff705949a12972ae09ac8c7154e11161 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Thu, 28 May 2026 09:57:17 +0000 Subject: [PATCH 09/16] Move layered tests to ColumnTestsMixin; drop outer shadow Reframe the 3 layered DataFrame tests around ``layered[col]`` usage in filter and select at the outermost surface (no synthetic shadow on top), and move them into the shared ``ColumnTestsMixin`` so they run under Classic (``ColumnTests``), Connect strict (``ColumnParityTests``), and Connect lenient (``ColumnParityTestsWithNonStrictDFColResolution``) via the existing parity framework instead of building separate Classic and Connect arms in one test body. 
Generated-by: Claude Code (Anthropic), claude-opus-4-7 --- .../sql/tests/connect/test_connect_column.py | 233 ------------------ python/pyspark/sql/tests/test_column.py | 215 +++++++++++++++- 2 files changed, 214 insertions(+), 234 deletions(-) diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py index 2a5db7047a95a..f50044a94d8c7 100644 --- a/python/pyspark/sql/tests/connect/test_connect_column.py +++ b/python/pyspark/sql/tests/connect/test_connect_column.py @@ -1492,239 +1492,6 @@ def test_resolve_sort_missing_attr_recovery(self): rows = cdf.select(cdf["v"]).sort(cdf["id"]).collect() self.assertEqual(len(rows), 10) - # --- Mixed-surface layered DataFrame programs --------------------------- - # - # These tests chain multiple DataFrame transformations - semi-joins - # (for SQL EXISTS/IN), window functions, cube aggregations, UDFs and - # struct field access - into 4-5 layer pipelines. Each program builds - # the layered base entirely through the DataFrame API, then layers a - # shadowing operation on top with a tagged ``cdf["c"]`` reference at the - # outermost select. The goal is to catch regressions in plan-id - # propagation across Connect's analyzer rules that single-operator - # tests miss when rules interact. - - def test_layered_semijoin_groupby_window_shadow(self): - # 4-layer DataFrame pipeline: filter -> semi-join -> groupBy/agg - # -> windows. Then a tagged ``layered["category"]`` reference after - # a groupBy shadow. 
- from pyspark.sql.connect.window import Window as CWindow - from pyspark.sql.window import Window as SWindow - - events_data = [ - (1, 1, "Books", 100.0, 2, True), - (2, 1, "Books", 50.0, 3, True), - (3, 2, "Electronics", 200.0, 1, True), - (4, 2, "Electronics", 300.0, 2, True), - (5, 3, "Home", 80.0, 4, True), - (6, 4, "Books", 60.0, 1, False), - ] - users_data = [(1, 25), (2, 30), (3, 22), (4, 18)] - events_cols = ["id", "user_id", "category", "amount", "quantity", "is_active"] - users_cols = ["id", "age"] - - def build_layered(spark, F, Window): - events = spark.createDataFrame(events_data, events_cols) - users = spark.createDataFrame(users_data, users_cols) - - # Layer 1: filter + semi-join (DataFrame-API equivalent of - # WHERE is_active AND EXISTS (user with age > 20)). - active = events.where(events["is_active"]).join( - users.where(users["age"] > 20), - events["user_id"] == users["id"], - "left_semi", - ) - # Layer 2: groupBy + agg, then post-agg filter (HAVING equivalent). - totals = ( - active.groupBy("category") - .agg( - F.sum(active["amount"] * active["quantity"] * F.lit(0.1)).alias("total_amt"), - F.sum(active["amount"]).alias("amount_sum"), - ) - .where(F.col("amount_sum") > 50) - .select("category", "total_amt") - ) - # Layer 3: window functions on top of the aggregate. - running = Window.orderBy("total_amt").rowsBetween(-1, 1) - ranking = Window.orderBy(F.col("total_amt").desc()) - windowed = totals.select( - "category", - "total_amt", - F.avg(F.col("total_amt")).over(running).alias("running_avg"), - F.rank().over(ranking).alias("rank_num"), - ) - # Layer 4: outer filter. - return windowed.where(F.col("rank_num") <= 5) - - expected_categories = ["Books", "Electronics", "Home"] - - # Classic: groupBy propagates the "category" attribute id through - # the aggregate, so the tagged reference still resolves. 
- slayered = build_layered(self.spark, SF, SWindow) - srows = slayered.groupBy("category").count().select(slayered["category"]).collect() - self.assertEqual(sorted(r.category for r in srows), expected_categories) - - # Connect: succeeds in both modes (groupBy attribute id propagates - # through Connect's aggregate as well). - for strict in (True, False): - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): - clayered = build_layered(self.connect, CF, CWindow) - rows = clayered.groupBy("category").count().select(clayered["category"]).collect() - self.assertEqual(sorted(r.category for r in rows), expected_categories) - - def test_layered_struct_semijoin_cube_ntile_shadow(self): - # 5-layer DataFrame pipeline: filter -> semi-join -> struct field - # access -> cube aggregation -> window NTILE. Then withColumn - # shadow + tagged select. - from pyspark.sql.connect.window import Window as CWindow - from pyspark.sql.window import Window as SWindow - - events_schema = StructType( - [ - StructField("id", IntegerType()), - StructField("category", StringType()), - StructField("status", StringType()), - StructField("amount", IntegerType()), - StructField("quantity", IntegerType()), - StructField( - "detail", - StructType( - [ - StructField("name", StringType()), - StructField("nested", StructType([StructField("x", IntegerType())])), - ] - ), - ), - ] - ) - events_data = [ - (1, "Books", "A", 100, 5, ("alpha", (1,))), - (2, "Electronics", "B", 200, 3, ("beta", (2,))), - (3, "Books", "A", 50, 7, ("alpha", (1,))), - (4, "Electronics", "B", 300, 4, ("beta", (2,))), - (5, "Home", "C", 80, 2, ("gamma", (3,))), - ] - categories_data = [("Books", 1), ("Electronics", 2), ("Home", 3), ("Toys", 5)] - categories_cols = ["name", "priority"] - - def build_layered(spark, F, Window): - events = spark.createDataFrame(events_data, events_schema) - categories = spark.createDataFrame(categories_data, categories_cols) - - # Layer 1: filter + semi-join 
(DataFrame-API equivalent of - # WHERE quantity > 1 AND category IN (SELECT ...)). - filtered = events.where(events["quantity"] > 1).join( - categories.where(categories["priority"] <= 3), - events["category"] == categories["name"], - "left_semi", - ) - # Layer 2: project with struct field access. - base = filtered.select( - filtered["id"], - filtered["category"], - filtered["status"], - filtered["amount"], - filtered["detail"]["name"].alias("detail_name"), - filtered["detail"]["nested"]["x"].alias("nx"), - ) - # Layer 3: cube aggregation (mixed grouping levels - similar - # surface area to SQL GROUPING SETS without an exact equivalent - # in the DataFrame API). - grouped = ( - base.cube("category", "status", "detail_name") - .agg(F.sum(F.col("amount")).alias("total"), F.count(F.lit(1)).alias("cnt")) - .where(F.col("category").isNotNull() & F.col("status").isNotNull()) - ) - # Layer 4: NTILE window. - tiled = grouped.withColumn( - "tile", F.ntile(2).over(Window.orderBy(F.col("total").desc())) - ) - # Layer 5: outer filter. - return tiled.where(F.col("tile") <= 2) - - # Classic: withColumn shadow of "category" drops the original - # attribute - the tagged reference cannot resolve. - slayered = build_layered(self.spark, SF, SWindow) - with self.assertRaises(AnalysisException): - slayered.withColumn("category", SF.col("category").cast("string")).select( - slayered["category"], "total" - ).collect() - - # Connect strict: fails. - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): - clayered = build_layered(self.connect, CF, CWindow) - with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): - clayered.withColumn("category", CF.col("category").cast("string")).select( - clayered["category"], "total" - ).collect() - - # Connect lenient: succeeds via name-based fallback. 
- with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): - clayered = build_layered(self.connect, CF, CWindow) - rows = ( - clayered.withColumn("category", CF.col("category").cast("string")) - .select(clayered["category"], "total") - .collect() - ) - self.assertGreater(len(rows), 0) - - def test_layered_window_window_udf_shadow(self): - # 4-layer DataFrame pipeline: filter -> running-total window -> - # per-partition max window -> UDF wrap. Then withColumn shadow + - # tagged select. - from pyspark.sql.connect.window import Window as CWindow - from pyspark.sql.window import Window as SWindow - - data = [ - (1, "A", 100), - (2, "A", 200), - (3, "B", 150), - (4, "B", 250), - (5, "C", 50), - ] - cols = ["id", "category", "amount"] - - def build_layered(spark, F, Window): - df = spark.createDataFrame(data, cols) - # Layer 1: filter (replaces WHERE EXISTS amount > 0). - filtered = df.where(df["amount"] > 0) - # Layer 2: running total window. - run_w = Window.partitionBy("category").orderBy("id") - with_run = filtered.withColumn("run_amt", F.sum(F.col("amount")).over(run_w)) - # Layer 3: per-category max window (replaces correlated subquery - # for cat_max). - cat_w = Window.partitionBy("category") - return with_run.withColumn("cat_max", F.max(F.col("amount")).over(cat_w)) - - # Classic: withColumn shadow of "category" breaks the tagged reference. - slayered = build_layered(self.spark, SF, SWindow) - sdouble = SF.udf(lambda x: x * 2 if x is not None else None, IntegerType()) - with self.assertRaises(AnalysisException): - slayered.withColumn("amount", sdouble(SF.col("amount"))).withColumn( - "category", SF.col("category").cast("string") - ).select(slayered["category"], "amount", "run_amt", "cat_max").collect() - - # Connect strict: fails. 
- with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): - clayered = build_layered(self.connect, CF, CWindow) - cdouble = CF.udf(lambda x: x * 2 if x is not None else None, IntegerType()) - with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): - clayered.withColumn("amount", cdouble(CF.col("amount"))).withColumn( - "category", CF.col("category").cast("string") - ).select(clayered["category"], "amount", "run_amt", "cat_max").collect() - - # Connect lenient: succeeds via name-based fallback. - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): - clayered = build_layered(self.connect, CF, CWindow) - cdouble = CF.udf(lambda x: x * 2 if x is not None else None, IntegerType()) - rows = ( - clayered.withColumn("amount", cdouble(CF.col("amount"))) - .withColumn("category", CF.col("category").cast("string")) - .select(clayered["category"], "amount", "run_amt", "cat_max") - .collect() - ) - self.assertEqual(len(rows), 5) - self.assertEqual(sorted({r.category for r in rows}), ["A", "B", "C"]) - if __name__ == "__main__": from pyspark.testing import main diff --git a/python/pyspark/sql/tests/test_column.py b/python/pyspark/sql/tests/test_column.py index 74a7746b154de..b87abadcfd475 100644 --- a/python/pyspark/sql/tests/test_column.py +++ b/python/pyspark/sql/tests/test_column.py @@ -23,7 +23,7 @@ from pyspark.sql import Column, Row from pyspark.sql import functions as sf -from pyspark.sql.types import StructType, StructField, IntegerType, LongType +from pyspark.sql.types import StructType, StructField, IntegerType, LongType, StringType from pyspark.errors import AnalysisException, PySparkTypeError, PySparkValueError from pyspark.testing.sqlutils import ReusedSQLTestCase from pyspark.testing.utils import have_pandas, pandas_requirement_message @@ -605,6 +605,219 @@ def test_drop_notexistent_col(self): self.assertEqual(df4.columns, ["colA", "colB", "colC", "colC", "colD", "colE"]) 
self.assertEqual(df4.count(), 1) + # --- Mixed-surface layered DataFrame programs --------------------------- + # + # These tests chain multiple DataFrame transformations - semi-joins + # (for SQL EXISTS/IN), window functions, cube aggregations, UDFs and + # struct field access - into 4-5 layer pipelines, then reference the + # final layered DataFrame's columns via ``layered[col]`` in both filter + # and select at the outermost surface. The goal is to catch regressions + # in plan-id propagation across analyzer rules that single-operator + # tests miss when rules interact. + + def test_layered_semijoin_groupby_window(self): + # 4-layer DataFrame pipeline: filter -> semi-join -> groupBy/agg + # -> window functions. ``layered[col]`` references appear in both + # filter and select at the outermost surface. + from pyspark.sql.window import Window + + events_data = [ + (1, 1, "Books", 100.0, 2, True), + (2, 1, "Books", 50.0, 3, True), + (3, 2, "Electronics", 200.0, 1, True), + (4, 2, "Electronics", 300.0, 2, True), + (5, 3, "Home", 80.0, 4, True), + (6, 4, "Books", 60.0, 1, False), + ] + users_data = [(1, 25), (2, 30), (3, 22), (4, 18)] + events_cols = ["id", "user_id", "category", "amount", "quantity", "is_active"] + users_cols = ["id", "age"] + + events = self.spark.createDataFrame(events_data, events_cols) + users = self.spark.createDataFrame(users_data, users_cols) + # Layer 1: filter + semi-join (DataFrame-API equivalent of + # WHERE is_active AND EXISTS (user with age > 20)). + active = events.where(events["is_active"]).join( + users.where(users["age"] > 20), + events["user_id"] == users["id"], + "left_semi", + ) + # Layer 2: groupBy + agg, then post-agg filter (HAVING equivalent). 
+ totals = ( + active.groupBy("category") + .agg( + sf.sum(active["amount"] * active["quantity"] * sf.lit(0.1)).alias("total_amt"), + sf.sum(active["amount"]).alias("amount_sum"), + ) + .where(sf.col("amount_sum") > 50) + .select("category", "total_amt") + ) + # Layer 3: window functions on top of the aggregate. + running = Window.orderBy("total_amt").rowsBetween(-1, 1) + ranking = Window.orderBy(sf.col("total_amt").desc()) + windowed = totals.select( + "category", + "total_amt", + sf.avg(sf.col("total_amt")).over(running).alias("running_avg"), + sf.rank().over(ranking).alias("rank_num"), + ) + # Layer 4: outer filter. + layered = windowed.where(sf.col("rank_num") <= 5) + + rows = ( + layered.filter(layered["rank_num"] <= 3) + .select( + layered["category"], + layered["total_amt"], + layered["running_avg"], + layered["rank_num"], + ) + .collect() + ) + result = sorted((r.category, r.rank_num) for r in rows) + self.assertEqual(result, [("Books", 2), ("Electronics", 1), ("Home", 3)]) + + def test_layered_struct_semijoin_cube_ntile(self): + # 5-layer DataFrame pipeline: filter -> semi-join -> struct field + # access -> cube aggregation -> window NTILE. ``layered[col]`` + # references appear in both filter and select at the outermost + # surface. 
+ from pyspark.sql.window import Window + + events_schema = StructType( + [ + StructField("id", IntegerType()), + StructField("category", StringType()), + StructField("status", StringType()), + StructField("amount", IntegerType()), + StructField("quantity", IntegerType()), + StructField( + "detail", + StructType( + [ + StructField("name", StringType()), + StructField("nested", StructType([StructField("x", IntegerType())])), + ] + ), + ), + ] + ) + events_data = [ + (1, "Books", "A", 100, 5, ("alpha", (1,))), + (2, "Electronics", "B", 200, 3, ("beta", (2,))), + (3, "Books", "A", 50, 7, ("alpha", (1,))), + (4, "Electronics", "B", 300, 4, ("beta", (2,))), + (5, "Home", "C", 80, 2, ("gamma", (3,))), + ] + categories_data = [("Books", 1), ("Electronics", 2), ("Home", 3), ("Toys", 5)] + categories_cols = ["name", "priority"] + + events = self.spark.createDataFrame(events_data, events_schema) + categories = self.spark.createDataFrame(categories_data, categories_cols) + # Layer 1: filter + semi-join (DataFrame-API equivalent of + # WHERE quantity > 1 AND category IN (SELECT ...)). + filtered = events.where(events["quantity"] > 1).join( + categories.where(categories["priority"] <= 3), + events["category"] == categories["name"], + "left_semi", + ) + # Layer 2: project with struct field access. + base = filtered.select( + filtered["id"], + filtered["category"], + filtered["status"], + filtered["amount"], + filtered["detail"]["name"].alias("detail_name"), + filtered["detail"]["nested"]["x"].alias("nx"), + ) + # Layer 3: cube aggregation (mixed grouping levels - similar + # surface area to SQL GROUPING SETS without an exact equivalent + # in the DataFrame API). + grouped = ( + base.cube("category", "status", "detail_name") + .agg(sf.sum(sf.col("amount")).alias("total"), sf.count(sf.lit(1)).alias("cnt")) + .where(sf.col("category").isNotNull() & sf.col("status").isNotNull()) + ) + # Layer 4: NTILE window. 
+ tiled = grouped.withColumn("tile", sf.ntile(2).over(Window.orderBy(sf.col("total").desc()))) + # Layer 5: outer filter. + layered = tiled.where(sf.col("tile") <= 2) + + rows = ( + layered.filter(layered["tile"] >= 1) + .select( + layered["category"], + layered["status"], + layered["detail_name"], + layered["total"], + layered["cnt"], + layered["tile"], + ) + .collect() + ) + # Cube emits one (category, status, detail_name) group per distinct + # combination plus one (category, status, NULL) rollup per distinct + # (category, status) pair. The where filter keeps both. + self.assertEqual(len(rows), 6) + self.assertEqual({r.category for r in rows}, {"Books", "Electronics", "Home"}) + self.assertEqual({r.total for r in rows}, {80, 150, 500}) + self.assertEqual({r.tile for r in rows}, {1, 2}) + + def test_layered_window_window_udf(self): + # 4-layer DataFrame pipeline: filter -> running-total window -> + # per-partition max window -> UDF wrap. ``layered[col]`` references + # appear in both filter and select at the outermost surface. + from pyspark.sql.window import Window + + data = [ + (1, "A", 100), + (2, "A", 200), + (3, "B", 150), + (4, "B", 250), + (5, "C", 50), + ] + cols = ["id", "category", "amount"] + + df = self.spark.createDataFrame(data, cols) + # Layer 1: filter (replaces WHERE EXISTS amount > 0). + filtered = df.where(df["amount"] > 0) + # Layer 2: running total window. + run_w = Window.partitionBy("category").orderBy("id") + with_run = filtered.withColumn("run_amt", sf.sum(sf.col("amount")).over(run_w)) + # Layer 3: per-category max window (replaces correlated subquery + # for cat_max). + cat_w = Window.partitionBy("category") + with_max = with_run.withColumn("cat_max", sf.max(sf.col("amount")).over(cat_w)) + # Layer 4: UDF. 
+ double = sf.udf(lambda x: x * 2 if x is not None else None, IntegerType()) + layered = with_max.withColumn("doubled_amt", double(sf.col("amount"))) + + rows = ( + layered.filter(layered["amount"] > 0) + .select( + layered["id"], + layered["category"], + layered["amount"], + layered["run_amt"], + layered["cat_max"], + layered["doubled_amt"], + ) + .collect() + ) + result = sorted( + (r.id, r.category, r.amount, r.run_amt, r.cat_max, r.doubled_amt) for r in rows + ) + self.assertEqual( + result, + [ + (1, "A", 100, 100, 200, 200), + (2, "A", 200, 300, 200, 400), + (3, "B", 150, 150, 250, 300), + (4, "B", 250, 400, 250, 500), + (5, "C", 50, 50, 50, 100), + ], + ) + class ColumnTests(ColumnTestsMixin, ReusedSQLTestCase): pass From 061d37beeccfa55e9ddc9acc6c212e3069237d6a Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Fri, 29 May 2026 09:22:35 +0000 Subject: [PATCH 10/16] Use df.col_name getattr form in layered tests Reference DataFrame columns via the ``df.col_name`` getattr form instead of ``sf.col(...)`` / ``df["col"]`` in the three layered tests, binding the post-agg intermediates to variables so a DataFrame is in scope for the HAVING-style filters. Struct subfields keep bracket access because ``detail.name`` would resolve to ``Column.name``. 
Generated-by: Claude Code (Anthropic), claude-opus-4-8 --- python/pyspark/sql/tests/test_column.py | 110 ++++++++++++------------ 1 file changed, 53 insertions(+), 57 deletions(-) diff --git a/python/pyspark/sql/tests/test_column.py b/python/pyspark/sql/tests/test_column.py index b87abadcfd475..f1781722ee099 100644 --- a/python/pyspark/sql/tests/test_column.py +++ b/python/pyspark/sql/tests/test_column.py @@ -610,14 +610,14 @@ def test_drop_notexistent_col(self): # These tests chain multiple DataFrame transformations - semi-joins # (for SQL EXISTS/IN), window functions, cube aggregations, UDFs and # struct field access - into 4-5 layer pipelines, then reference the - # final layered DataFrame's columns via ``layered[col]`` in both filter + # final layered DataFrame's columns via ``layered.col`` in both filter # and select at the outermost surface. The goal is to catch regressions # in plan-id propagation across analyzer rules that single-operator # tests miss when rules interact. def test_layered_semijoin_groupby_window(self): # 4-layer DataFrame pipeline: filter -> semi-join -> groupBy/agg - # -> window functions. ``layered[col]`` references appear in both + # -> window functions. ``layered.col`` references appear in both # filter and select at the outermost surface. from pyspark.sql.window import Window @@ -637,40 +637,36 @@ def test_layered_semijoin_groupby_window(self): users = self.spark.createDataFrame(users_data, users_cols) # Layer 1: filter + semi-join (DataFrame-API equivalent of # WHERE is_active AND EXISTS (user with age > 20)). - active = events.where(events["is_active"]).join( - users.where(users["age"] > 20), - events["user_id"] == users["id"], + active = events.where(events.is_active).join( + users.where(users.age > 20), + events.user_id == users.id, "left_semi", ) # Layer 2: groupBy + agg, then post-agg filter (HAVING equivalent). 
- totals = ( - active.groupBy("category") - .agg( - sf.sum(active["amount"] * active["quantity"] * sf.lit(0.1)).alias("total_amt"), - sf.sum(active["amount"]).alias("amount_sum"), - ) - .where(sf.col("amount_sum") > 50) - .select("category", "total_amt") + agg = active.groupBy("category").agg( + sf.sum(active.amount * active.quantity * sf.lit(0.1)).alias("total_amt"), + sf.sum(active.amount).alias("amount_sum"), ) + totals = agg.where(agg.amount_sum > 50).select("category", "total_amt") # Layer 3: window functions on top of the aggregate. running = Window.orderBy("total_amt").rowsBetween(-1, 1) - ranking = Window.orderBy(sf.col("total_amt").desc()) + ranking = Window.orderBy(totals.total_amt.desc()) windowed = totals.select( "category", "total_amt", - sf.avg(sf.col("total_amt")).over(running).alias("running_avg"), + sf.avg(totals.total_amt).over(running).alias("running_avg"), sf.rank().over(ranking).alias("rank_num"), ) # Layer 4: outer filter. - layered = windowed.where(sf.col("rank_num") <= 5) + layered = windowed.where(windowed.rank_num <= 5) rows = ( - layered.filter(layered["rank_num"] <= 3) + layered.filter(layered.rank_num <= 3) .select( - layered["category"], - layered["total_amt"], - layered["running_avg"], - layered["rank_num"], + layered.category, + layered.total_amt, + layered.running_avg, + layered.rank_num, ) .collect() ) @@ -679,7 +675,7 @@ def test_layered_semijoin_groupby_window(self): def test_layered_struct_semijoin_cube_ntile(self): # 5-layer DataFrame pipeline: filter -> semi-join -> struct field - # access -> cube aggregation -> window NTILE. ``layered[col]`` + # access -> cube aggregation -> window NTILE. ``layered.col`` # references appear in both filter and select at the outermost # surface. 
from pyspark.sql.window import Window @@ -716,42 +712,42 @@ def test_layered_struct_semijoin_cube_ntile(self): categories = self.spark.createDataFrame(categories_data, categories_cols) # Layer 1: filter + semi-join (DataFrame-API equivalent of # WHERE quantity > 1 AND category IN (SELECT ...)). - filtered = events.where(events["quantity"] > 1).join( - categories.where(categories["priority"] <= 3), - events["category"] == categories["name"], + filtered = events.where(events.quantity > 1).join( + categories.where(categories.priority <= 3), + events.category == categories.name, "left_semi", ) - # Layer 2: project with struct field access. + # Layer 2: project with struct field access (struct subfields use + # bracket access since ``detail.name`` would hit ``Column.name``). base = filtered.select( - filtered["id"], - filtered["category"], - filtered["status"], - filtered["amount"], - filtered["detail"]["name"].alias("detail_name"), - filtered["detail"]["nested"]["x"].alias("nx"), + filtered.id, + filtered.category, + filtered.status, + filtered.amount, + filtered.detail["name"].alias("detail_name"), + filtered.detail["nested"]["x"].alias("nx"), ) # Layer 3: cube aggregation (mixed grouping levels - similar # surface area to SQL GROUPING SETS without an exact equivalent # in the DataFrame API). - grouped = ( - base.cube("category", "status", "detail_name") - .agg(sf.sum(sf.col("amount")).alias("total"), sf.count(sf.lit(1)).alias("cnt")) - .where(sf.col("category").isNotNull() & sf.col("status").isNotNull()) + agg = base.cube("category", "status", "detail_name").agg( + sf.sum(base.amount).alias("total"), sf.count(sf.lit(1)).alias("cnt") ) + grouped = agg.where(agg.category.isNotNull() & agg.status.isNotNull()) # Layer 4: NTILE window. - tiled = grouped.withColumn("tile", sf.ntile(2).over(Window.orderBy(sf.col("total").desc()))) + tiled = grouped.withColumn("tile", sf.ntile(2).over(Window.orderBy(grouped.total.desc()))) # Layer 5: outer filter. 
- layered = tiled.where(sf.col("tile") <= 2) + layered = tiled.where(tiled.tile <= 2) rows = ( - layered.filter(layered["tile"] >= 1) + layered.filter(layered.tile >= 1) .select( - layered["category"], - layered["status"], - layered["detail_name"], - layered["total"], - layered["cnt"], - layered["tile"], + layered.category, + layered.status, + layered.detail_name, + layered.total, + layered.cnt, + layered.tile, ) .collect() ) @@ -765,7 +761,7 @@ def test_layered_struct_semijoin_cube_ntile(self): def test_layered_window_window_udf(self): # 4-layer DataFrame pipeline: filter -> running-total window -> - # per-partition max window -> UDF wrap. ``layered[col]`` references + # per-partition max window -> UDF wrap. ``layered.col`` references # appear in both filter and select at the outermost surface. from pyspark.sql.window import Window @@ -780,27 +776,27 @@ def test_layered_window_window_udf(self): df = self.spark.createDataFrame(data, cols) # Layer 1: filter (replaces WHERE EXISTS amount > 0). - filtered = df.where(df["amount"] > 0) + filtered = df.where(df.amount > 0) # Layer 2: running total window. run_w = Window.partitionBy("category").orderBy("id") - with_run = filtered.withColumn("run_amt", sf.sum(sf.col("amount")).over(run_w)) + with_run = filtered.withColumn("run_amt", sf.sum(filtered.amount).over(run_w)) # Layer 3: per-category max window (replaces correlated subquery # for cat_max). cat_w = Window.partitionBy("category") - with_max = with_run.withColumn("cat_max", sf.max(sf.col("amount")).over(cat_w)) + with_max = with_run.withColumn("cat_max", sf.max(with_run.amount).over(cat_w)) # Layer 4: UDF. 
double = sf.udf(lambda x: x * 2 if x is not None else None, IntegerType()) - layered = with_max.withColumn("doubled_amt", double(sf.col("amount"))) + layered = with_max.withColumn("doubled_amt", double(with_max.amount)) rows = ( - layered.filter(layered["amount"] > 0) + layered.filter(layered.amount > 0) .select( - layered["id"], - layered["category"], - layered["amount"], - layered["run_amt"], - layered["cat_max"], - layered["doubled_amt"], + layered.id, + layered.category, + layered.amount, + layered.run_amt, + layered.cat_max, + layered.doubled_amt, ) .collect() ) From 29f715b7c1db4f4d5b80ec2c77f8bd65df435269 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Fri, 29 May 2026 11:00:29 +0000 Subject: [PATCH 11/16] Move resolution tests to ColumnTestsMixin; override divergences Move the tagged DataFrame column-resolution tests out of the Connect-only SparkConnectColumnResolutionTests and into the shared ColumnTestsMixin, so each runs under Classic (ColumnTests), Connect strict (ColumnParityTests), and Connect lenient (ColumnParityTestsWithNonStrictDFColResolution). The mixin body asserts the behavior shared by Classic and Connect strict. The diverging cases are overridden in the parity suites: - union: ColumnParityTests asserts Connect raises CANNOT_RESOLVE_DATAFRAME_COLUMN in both modes (Classic resolves via attribute-id propagation in the mixin); - the shadowing trio: the lenient suite asserts name-based fallback succeeds. Tagged references use the df.col getattr form; df["*"] keeps bracket access and untagged transform expressions keep sf.col. Drops the now-unused SparkConnectColumnResolutionTests and its AnalysisException import. 
Generated-by: Claude Code (Anthropic), claude-opus-4-8 --- .../sql/tests/connect/test_connect_column.py | 392 +----------------- .../sql/tests/connect/test_parity_column.py | 36 ++ python/pyspark/sql/tests/test_column.py | 169 ++++++++ 3 files changed, 206 insertions(+), 391 deletions(-) diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py index f50044a94d8c7..6fa6c4686c527 100644 --- a/python/pyspark/sql/tests/connect/test_connect_column.py +++ b/python/pyspark/sql/tests/connect/test_connect_column.py @@ -40,7 +40,7 @@ DecimalType, BooleanType, ) -from pyspark.errors import AnalysisException, PySparkTypeError, PySparkValueError +from pyspark.errors import PySparkTypeError, PySparkValueError from pyspark.testing import assertDataFrameEqual from pyspark.testing.connectutils import should_test_connect, ReusedMixedTestCase from pyspark.testing.pandasutils import PandasOnSparkTestUtils @@ -1103,396 +1103,6 @@ def test_transform(self): ) -class SparkConnectColumnResolutionTests(ReusedMixedTestCase): - """Connect-only tests pinning known Connect/Classic divergences in DataFrame - column resolution. - - For each pattern, the test runs the equivalent program against: - - * Spark Classic (``self.spark``) - baseline that anchors the Classic - behavior. - * Spark Connect strict mode (default, - ``spark.sql.analyzer.strictDataFrameColumnResolution=true``) - plan-id- - based resolution only. Tagged ``df["c"]`` references whose ancestor's - attribute is gone fail with ``CANNOT_RESOLVE_DATAFRAME_COLUMN``. - * Spark Connect lenient mode - (``spark.sql.analyzer.strictDataFrameColumnResolution=false``) - if - plan-id-based resolution fails the analyzer also tries name-based - resolution against the current child output. 
- """ - - def test_resolve_after_chained_withcolumn_shadow(self): - # Two consecutive withColumn calls each shadow `c` with a new - # attribute carrying the same name; the original `c` is no longer in - # the projection. - # Classic: fails. - sdf = self.spark.sql("SELECT 1 AS c") - with self.assertRaises(AnalysisException): - sdf.withColumn("c", SF.col("c").cast("string")).withColumn( - "c", SF.col("c").cast("int") - ).select(sdf["c"]).collect() - - # Connect strict: same root cause, different error class. - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): - cdf = self.connect.sql("SELECT 1 AS c") - with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): - cdf.withColumn("c", CF.col("c").cast("string")).withColumn( - "c", CF.col("c").cast("int") - ).select(cdf["c"]).collect() - - # Connect lenient: succeeds via name-based fallback. - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): - cdf = self.connect.sql("SELECT 1 AS c") - cdf.withColumn("c", CF.col("c").cast("string")).withColumn( - "c", CF.col("c").cast("int") - ).select(cdf["c"]).collect() - - def test_resolve_after_select_alias_shadow(self): - # Same shadowing shape as withColumn but expressed through a select - # with alias. - # Classic: fails. - sdf = self.spark.sql("SELECT 1 AS c") - with self.assertRaises(AnalysisException): - sdf.select(sdf["c"].cast("string").alias("c")).select(sdf["c"]).collect() - - # Connect strict: fails. - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): - cdf = self.connect.sql("SELECT 1 AS c") - with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): - cdf.select(cdf["c"].cast("string").alias("c")).select(cdf["c"]).collect() - - # Connect lenient: succeeds via name-based fallback. 
- with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): - cdf = self.connect.sql("SELECT 1 AS c") - cdf.select(cdf["c"].cast("string").alias("c")).select(cdf["c"]).collect() - - def test_resolve_after_withcolumnrenamed(self): - # withColumnRenamed drops the original `c` attribute and projects it - # as `c2`. The tagged cdf["c"] cannot resolve under any mode because - # neither the original attribute nor a column named `c` is in the - # current child output. - # Classic: fails. - sdf = self.spark.sql("SELECT 1 AS c") - with self.assertRaises(AnalysisException): - sdf.withColumnRenamed("c", "c2").select(sdf["c"]).collect() - - # Connect: fails in both modes. - for strict in (True, False): - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): - cdf = self.connect.sql("SELECT 1 AS c") - with self.assertRaises(AnalysisException): - cdf.withColumnRenamed("c", "c2").select(cdf["c"]).collect() - - def test_resolve_after_drop(self): - # drop("c") removes the column entirely. Tagged cdf["c"] cannot resolve - # under any mode. - # Classic: fails. - sdf = self.spark.sql("SELECT 1 AS c, 2 AS d") - with self.assertRaises(AnalysisException): - sdf.drop("c").select(sdf["c"]).collect() - - # Connect: fails in both modes. - for strict in (True, False): - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): - cdf = self.connect.sql("SELECT 1 AS c, 2 AS d") - with self.assertRaises(AnalysisException): - cdf.drop("c").select(cdf["c"]).collect() - - def test_resolve_through_filter(self): - # filter is a pass-through operator: the child Project's attributes - # flow through unchanged, so the tagged reference resolves in both - # worlds. - expected = [1, 2] - - # Classic: succeeds. 
- sdf = self.spark.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") - srows = sdf.filter(sdf["c"] > 0).select(sdf["c"]).collect() - self.assertEqual(sorted(r.c for r in srows), expected) - - # Connect: succeeds in both modes. - for strict in (True, False): - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): - cdf = self.connect.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") - rows = cdf.filter(cdf["c"] > 0).select(cdf["c"]).collect() - self.assertEqual(sorted(r.c for r in rows), expected) - - def test_resolve_through_sort(self): - # sort is also a pass-through operator. - expected = [1, 2] - - # Classic: succeeds. - sdf = self.spark.sql("SELECT 2 AS c UNION ALL SELECT 1 AS c") - srows = sdf.sort(sdf["c"]).select(sdf["c"]).collect() - self.assertEqual([r.c for r in srows], expected) - - # Connect: succeeds in both modes. - for strict in (True, False): - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): - cdf = self.connect.sql("SELECT 2 AS c UNION ALL SELECT 1 AS c") - rows = cdf.sort(cdf["c"]).select(cdf["c"]).collect() - self.assertEqual([r.c for r in rows], expected) - - def test_resolve_through_distinct(self): - # distinct preserves attribute identity from the perspective of - # column resolution. - expected = [1] - - # Classic: succeeds. - sdf = self.spark.sql("SELECT 1 AS c UNION ALL SELECT 1 AS c") - srows = sdf.distinct().select(sdf["c"]).collect() - self.assertEqual([r.c for r in srows], expected) - - # Connect: succeeds in both modes. 
- for strict in (True, False): - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): - cdf = self.connect.sql("SELECT 1 AS c UNION ALL SELECT 1 AS c") - rows = cdf.distinct().select(cdf["c"]).collect() - self.assertEqual([r.c for r in rows], expected) - - def test_resolve_after_groupby_count(self): - # groupBy("c").count() preserves the grouping key's attribute id in - # both Classic and Connect, so the tagged reference resolves in all - # modes. - query = "SELECT 1 AS c UNION ALL SELECT 1 AS c UNION ALL SELECT 2 AS c" - expected = [1, 2] - - # Classic: succeeds. - sdf = self.spark.sql(query) - srows = sdf.groupBy("c").count().select(sdf["c"]).collect() - self.assertEqual(sorted(r.c for r in srows), expected) - - # Connect: succeeds in both modes. - for strict in (True, False): - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): - cdf = self.connect.sql(query) - rows = cdf.groupBy("c").count().select(cdf["c"]).collect() - self.assertEqual(sorted(r.c for r in rows), expected) - - def test_resolve_after_agg_alias_shadow(self): - # An aggregate output named `c` via alias() collides by name with - # the source `c`. The tagged cdf["c"] still references the source - # attribute that has been aggregated away. - # Classic: fails. - sdf = self.spark.sql("SELECT 1 AS x") - with self.assertRaises(AnalysisException): - sdf.groupBy().agg(SF.sum("x").alias("c")).select(sdf["c"]).collect() - - # Connect strict: fails. - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": True}): - cdf = self.connect.sql("SELECT 1 AS x") - with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): - cdf.groupBy().agg(CF.sum("x").alias("c")).select(cdf["c"]).collect() - - # Connect lenient: succeeds via name-based fallback. 
- with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": False}): - cdf = self.connect.sql("SELECT 1 AS x") - cdf.groupBy().agg(CF.sum("x").alias("c")).select(cdf["c"]).collect() - - def test_resolve_after_pivot(self): - # pivot preserves the grouping key's attribute id in both Classic - # and Connect, so the tagged reference resolves in all modes. - query = "SELECT 1 AS c, 'a' AS k, 10 AS v UNION ALL SELECT 2 AS c, 'b' AS k, 20 AS v" - expected = [1, 2] - - # Classic: succeeds. - sdf = self.spark.sql(query) - srows = sdf.groupBy("c").pivot("k").sum("v").select(sdf["c"]).collect() - self.assertEqual(sorted(r.c for r in srows), expected) - - # Connect: succeeds in both modes. - for strict in (True, False): - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): - cdf = self.connect.sql(query) - rows = cdf.groupBy("c").pivot("k").sum("v").select(cdf["c"]).collect() - self.assertEqual(sorted(r.c for r in rows), expected) - - def test_resolve_after_union(self): - # Union emits new attribute ids. Classic still resolves the tagged - # left-side reference by attribute id propagation, but Connect fails - # in both modes: per ColumnResolutionHelper, Union is treated as a - # leaf node when walking the plan tree for plan-id resolution - # (children are not searched), so cdf1's plan id is never found and - # CANNOT_RESOLVE_DATAFRAME_COLUMN is thrown before the lenient - # name-based fallback can run. - # Classic: succeeds. - sdf1 = self.spark.sql("SELECT 1 AS c") - sdf2 = self.spark.sql("SELECT 2 AS c") - srows = sdf1.union(sdf2).select(sdf1["c"]).collect() - self.assertEqual(sorted(r.c for r in srows), [1, 2]) - - # Connect: fails in both modes. 
- for strict in (True, False): - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): - cdf1 = self.connect.sql("SELECT 1 AS c") - cdf2 = self.connect.sql("SELECT 2 AS c") - with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): - cdf1.union(cdf2).select(cdf1["c"]).collect() - - def test_resolve_after_intersect(self): - # intersect, like union, emits new attribute ids. Classic resolves - # the tagged reference by attribute id propagation; Connect also - # resolves it successfully (the intersect output retains the - # propagated id), in both modes. - expected = [2] - - # Classic: succeeds. - sdf1 = self.spark.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") - sdf2 = self.spark.sql("SELECT 2 AS c UNION ALL SELECT 3 AS c") - srows = sdf1.intersect(sdf2).select(sdf1["c"]).collect() - self.assertEqual([r.c for r in srows], expected) - - # Connect: succeeds in both modes. - for strict in (True, False): - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): - cdf1 = self.connect.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") - cdf2 = self.connect.sql("SELECT 2 AS c UNION ALL SELECT 3 AS c") - rows = cdf1.intersect(cdf2).select(cdf1["c"]).collect() - self.assertEqual([r.c for r in rows], expected) - - def test_resolve_self_join_alias(self): - # In a self-join, both sides originate from the same plan-id-tagged - # ancestor. Plan-id resolution finds two candidates of equal depth - # that share the same attribute id; the disambiguation in - # ColumnResolutionHelper.resolveDataFrameColumn cannot tiebreak by - # depth and raises AMBIGUOUS_COLUMN_REFERENCE. Classic raises an - # ambiguous-reference AnalysisException for the same reason. - # Classic: fails (ambiguous reference). 
- sdf = self.spark.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") - a, b = sdf.alias("a"), sdf.alias("b") - with self.assertRaises(AnalysisException): - a.join(b, a["c"] == b["c"]).select(sdf["c"]).collect() - - # Connect: fails in both modes. - for strict in (True, False): - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): - cdf = self.connect.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") - a, b = cdf.alias("a"), cdf.alias("b") - with self.assertRaises(AnalysisException): - a.join(b, a["c"] == b["c"]).select(cdf["c"]).collect() - - def test_resolve_after_subquery_view(self): - # Persisting the original DataFrame as a temp view and reading it - # back via spark.table() produces a new plan. Classic resolves the - # tagged reference; Connect also resolves it in both modes. - import uuid - - expected = [1] - - # Classic: succeeds. - view = f"v_{uuid.uuid4().hex}" - sdf = self.spark.sql("SELECT 1 AS c") - sdf.createOrReplaceTempView(view) - try: - srows = self.spark.table(view).select(sdf["c"]).collect() - self.assertEqual([r.c for r in srows], expected) - finally: - self.spark.sql(f"DROP VIEW IF EXISTS {view}") - - # Connect: succeeds in both modes. - for strict in (True, False): - view = f"v_{uuid.uuid4().hex}" - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): - cdf = self.connect.sql("SELECT 1 AS c") - cdf.createOrReplaceTempView(view) - try: - rows = self.connect.table(view).select(cdf["c"]).collect() - self.assertEqual([r.c for r in rows], expected) - finally: - self.connect.sql(f"DROP VIEW IF EXISTS {view}") - - def test_resolve_cross_dataframe_illegal_reference(self): - # Per ColumnResolutionHelper.resolveDataFrameColumn (the documented - # `df1.select(df2.a)` case): referencing a column from a DataFrame - # whose plan id is not an ancestor in the target plan tree fails - # with CANNOT_RESOLVE_DATAFRAME_COLUMN. 
The strict / lenient switch - # does not gate this throw, so the failure is identical in both - # Connect modes. - # Classic: fails (the attribute id of cdf2.id is not in cdf1's plan). - sdf1 = self.spark.range(3) - sdf2 = self.spark.range(5) - with self.assertRaises(AnalysisException): - sdf1.select(sdf2["id"]).collect() - - # Connect: fails in both modes. - for strict in (True, False): - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): - cdf1 = self.connect.range(3) - cdf2 = self.connect.range(5) - with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): - cdf1.select(cdf2["id"]).collect() - - def test_resolve_df_star(self): - # `cdf["*"]` is an UnresolvedDataFrameStar carrying cdf's plan id. - # The analyzer expands it to the output of the matched plan node. - # This works in both Classic and Connect, in both Connect modes. - query = "SELECT 'Books' AS c, 100 AS v UNION ALL SELECT 'Electronics' AS c, 200 AS v" - expected = [("Books", 100), ("Electronics", 200)] - - # Classic: succeeds. - sdf = self.spark.sql(query) - srows = sdf.select(sdf["*"]).collect() - self.assertEqual(sorted((r.c, r.v) for r in srows), expected) - - # Connect: succeeds in both modes. - for strict in (True, False): - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): - cdf = self.connect.sql(query) - rows = cdf.select(cdf["*"]).collect() - self.assertEqual(sorted((r.c, r.v) for r in rows), expected) - - def test_resolve_self_join_withcolumnrenamed(self): - # Documented example from ColumnResolutionHelper.scala (adjusted to - # produce one row per id rather than a 10x10 cross match): - # - # df1 = spark.range(10).withColumn("a", sf.col("id")) - # df2 = df1.withColumnRenamed("a", "b") - # df1.join(df2, df1["a"] == df2["b"]) - # - # When resolving the column reference df1.a, the target node with - # df1's plan id can be found on both sides of the Join. 
The - # candidate from the right side is filtered out because its `a` - # attribute is not in the output of the renaming Project above it. - # Disambiguation succeeds. - # Classic: succeeds. - sdf1 = self.spark.range(10).withColumn("a", SF.col("id")) - sdf2 = sdf1.withColumnRenamed("a", "b") - srows = sdf1.join(sdf2, sdf1["a"] == sdf2["b"]).select(sdf1["a"], sdf2["b"]).collect() - self.assertEqual(len(srows), 10) - - # Connect: succeeds in both modes. - for strict in (True, False): - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): - cdf1 = self.connect.range(10).withColumn("a", CF.col("id")) - cdf2 = cdf1.withColumnRenamed("a", "b") - rows = ( - cdf1.join(cdf2, cdf1["a"] == cdf2["b"]).select(cdf1["a"], cdf2["b"]).collect() - ) - self.assertEqual(len(rows), 10) - - def test_resolve_sort_missing_attr_recovery(self): - # Documented example from ColumnResolutionHelper.scala: - # - # df = spark.range(10).withColumn("v", sf.col("id") + 1) - # df.select(df.v).sort(df.id) - # - # Sort references df.id which is not in the upstream select's - # output. The analyzer's resolveExprsAndAddMissingAttrs descends - # through the Project, resolves df.id via plan-id at the source, - # and adds it back to the upstream projection. This works in both - # Classic and Connect, in both Connect modes. - # Classic: succeeds. - sdf = self.spark.range(10).withColumn("v", SF.col("id") + 1) - srows = sdf.select(sdf["v"]).sort(sdf["id"]).collect() - self.assertEqual(len(srows), 10) - - # Connect: succeeds in both modes. 
- for strict in (True, False): - with self.connect_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": strict}): - cdf = self.connect.range(10).withColumn("v", CF.col("id") + 1) - rows = cdf.select(cdf["v"]).sort(cdf["id"]).collect() - self.assertEqual(len(rows), 10) - - if __name__ == "__main__": from pyspark.testing import main diff --git a/python/pyspark/sql/tests/connect/test_parity_column.py b/python/pyspark/sql/tests/connect/test_parity_column.py index 3903bb57a3750..ad9b4b3ba7b6e 100644 --- a/python/pyspark/sql/tests/connect/test_parity_column.py +++ b/python/pyspark/sql/tests/connect/test_parity_column.py @@ -17,6 +17,8 @@ import unittest +from pyspark.errors import AnalysisException +from pyspark.sql import functions as sf from pyspark.sql.tests.test_column import ColumnTestsMixin from pyspark.testing.connectutils import ReusedConnectTestCase @@ -38,6 +40,16 @@ def tearDownClass(cls): def test_validate_column_types(self): super().test_validate_column_types() + def test_resolve_after_union(self): + # Connect diverges from Classic here: Union is treated as a leaf when + # walking the plan tree for plan-id resolution, so the left-side plan + # id is never found and CANNOT_RESOLVE_DATAFRAME_COLUMN is thrown + # before any name-based fallback - in both strict and lenient modes. + df1 = self.spark.sql("SELECT 1 AS c") + df2 = self.spark.sql("SELECT 2 AS c") + with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): + df1.union(df2).select(df1.c).collect() + def test_df_col_resolution_mode(self): self.assertEqual( self.spark.conf.get("spark.sql.analyzer.strictDataFrameColumnResolution"), @@ -68,6 +80,30 @@ def test_df_col_resolution_mode(self): "false", ) + # The shadowing trio diverges in lenient mode: where Classic and Connect + # strict raise, lenient resolves the tagged reference by name against the + # current (shadowed) output. 
+ + def test_resolve_after_chained_withcolumn_shadow(self): + df = self.spark.sql("SELECT 1 AS c") + rows = ( + df.withColumn("c", sf.col("c").cast("string")) + .withColumn("c", sf.col("c").cast("int")) + .select(df.c) + .collect() + ) + self.assertEqual([r.c for r in rows], [1]) + + def test_resolve_after_select_alias_shadow(self): + df = self.spark.sql("SELECT 1 AS c") + rows = df.select(df.c.cast("string").alias("c")).select(df.c).collect() + self.assertEqual([r.c for r in rows], ["1"]) + + def test_resolve_after_agg_alias_shadow(self): + df = self.spark.sql("SELECT 1 AS x") + rows = df.groupBy().agg(sf.sum("x").alias("c")).select(df.c).collect() + self.assertEqual([r.c for r in rows], [1]) + if __name__ == "__main__": from pyspark.testing import main diff --git a/python/pyspark/sql/tests/test_column.py b/python/pyspark/sql/tests/test_column.py index f1781722ee099..cb9716c26adb0 100644 --- a/python/pyspark/sql/tests/test_column.py +++ b/python/pyspark/sql/tests/test_column.py @@ -20,6 +20,7 @@ from itertools import chain import datetime import unittest +import uuid from pyspark.sql import Column, Row from pyspark.sql import functions as sf @@ -814,6 +815,174 @@ def test_layered_window_window_udf(self): ], ) + # --- Tagged DataFrame column resolution -------------------------------- + # + # ``df.col`` / ``df["col"]`` carries the source DataFrame's plan id. These + # tests pin how that tagged reference resolves after assorted operators. + # The behavior is shared across Spark Classic and Spark Connect (both + # ``spark.sql.analyzer.strictDataFrameColumnResolution`` modes) except for + # a few diverging cases, which are overridden in the Connect parity suites + # (``ColumnParityTests`` / ``...WithNonStrictDFColResolution``): + # + # * the shadowing trio - Classic and Connect strict raise, Connect + # lenient resolves the shadowed name via name-based fallback; + # * union - Classic resolves via attribute-id propagation, Connect + # raises in both modes. 
+ + def test_resolve_after_chained_withcolumn_shadow(self): + # Two consecutive withColumn calls each shadow `c` with a new + # attribute of the same name, so the original `c` leaves the + # projection and the tagged `df.c` cannot resolve. + df = self.spark.sql("SELECT 1 AS c") + with self.assertRaises(AnalysisException): + df.withColumn("c", sf.col("c").cast("string")).withColumn( + "c", sf.col("c").cast("int") + ).select(df.c).collect() + + def test_resolve_after_select_alias_shadow(self): + # Same shadowing shape as withColumn but via select + alias. + df = self.spark.sql("SELECT 1 AS c") + with self.assertRaises(AnalysisException): + df.select(df.c.cast("string").alias("c")).select(df.c).collect() + + def test_resolve_after_withcolumnrenamed(self): + # withColumnRenamed drops the original `c` attribute and projects it + # as `c2`; the tagged `df.c` matches neither the original attribute + # nor a current column named `c`, so all modes raise. + df = self.spark.sql("SELECT 1 AS c") + with self.assertRaises(AnalysisException): + df.withColumnRenamed("c", "c2").select(df.c).collect() + + def test_resolve_after_drop(self): + # drop("c") removes the column entirely; the tagged `df.c` cannot + # resolve under any mode. + df = self.spark.sql("SELECT 1 AS c, 2 AS d") + with self.assertRaises(AnalysisException): + df.drop("c").select(df.c).collect() + + def test_resolve_through_filter(self): + # filter is a pass-through operator: the child Project's attributes + # flow through unchanged, so the tagged reference resolves. + df = self.spark.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") + rows = df.filter(df.c > 0).select(df.c).collect() + self.assertEqual(sorted(r.c for r in rows), [1, 2]) + + def test_resolve_through_sort(self): + # sort is also a pass-through operator. 
+ df = self.spark.sql("SELECT 2 AS c UNION ALL SELECT 1 AS c") + rows = df.sort(df.c).select(df.c).collect() + self.assertEqual([r.c for r in rows], [1, 2]) + + def test_resolve_through_distinct(self): + # distinct preserves attribute identity for column resolution. + df = self.spark.sql("SELECT 1 AS c UNION ALL SELECT 1 AS c") + rows = df.distinct().select(df.c).collect() + self.assertEqual([r.c for r in rows], [1]) + + def test_resolve_after_groupby_count(self): + # groupBy("c").count() preserves the grouping key's attribute id, so + # the tagged reference resolves. + df = self.spark.sql("SELECT 1 AS c UNION ALL SELECT 1 AS c UNION ALL SELECT 2 AS c") + rows = df.groupBy("c").count().select(df.c).collect() + self.assertEqual(sorted(r.c for r in rows), [1, 2]) + + def test_resolve_after_agg_alias_shadow(self): + # An aggregate output aliased `c` collides by name with the source + # `c`, but the tagged `df.c` still references the aggregated-away + # source attribute, so it cannot resolve. + df = self.spark.sql("SELECT 1 AS x") + with self.assertRaises(AnalysisException): + df.groupBy().agg(sf.sum("x").alias("c")).select(df.c).collect() + + def test_resolve_after_pivot(self): + # pivot preserves the grouping key's attribute id, so the tagged + # reference resolves. + df = self.spark.sql( + "SELECT 1 AS c, 'a' AS k, 10 AS v UNION ALL SELECT 2 AS c, 'b' AS k, 20 AS v" + ) + rows = df.groupBy("c").pivot("k").sum("v").select(df.c).collect() + self.assertEqual(sorted(r.c for r in rows), [1, 2]) + + def test_resolve_after_union(self): + # Union emits new attribute ids. Classic resolves the tagged + # left-side reference by attribute-id propagation and succeeds; + # Connect treats Union as a leaf when walking the plan tree for + # plan-id resolution and raises in both modes (overridden there). 
+ df1 = self.spark.sql("SELECT 1 AS c") + df2 = self.spark.sql("SELECT 2 AS c") + rows = df1.union(df2).select(df1.c).collect() + self.assertEqual(sorted(r.c for r in rows), [1, 2]) + + def test_resolve_after_intersect(self): + # intersect, like union, emits new attribute ids, but the propagated + # id is retained in the output, so the tagged reference resolves in + # all modes. + df1 = self.spark.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") + df2 = self.spark.sql("SELECT 2 AS c UNION ALL SELECT 3 AS c") + rows = df1.intersect(df2).select(df1.c).collect() + self.assertEqual([r.c for r in rows], [2]) + + def test_resolve_self_join_alias(self): + # Both self-join sides originate from the same plan-id-tagged + # ancestor, yielding two equal-depth candidates with the same + # attribute id. Disambiguation cannot tiebreak and all modes raise + # an ambiguous-reference error. + df = self.spark.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") + a, b = df.alias("a"), df.alias("b") + with self.assertRaises(AnalysisException): + a.join(b, a.c == b.c).select(df.c).collect() + + def test_resolve_after_subquery_view(self): + # Persisting the DataFrame as a temp view and reading it back via + # table() produces a new plan; the tagged reference still resolves in + # all modes. + view = f"v_{uuid.uuid4().hex}" + df = self.spark.sql("SELECT 1 AS c") + df.createOrReplaceTempView(view) + try: + rows = self.spark.table(view).select(df.c).collect() + self.assertEqual([r.c for r in rows], [1]) + finally: + self.spark.sql(f"DROP VIEW IF EXISTS {view}") + + def test_resolve_cross_dataframe_illegal_reference(self): + # Referencing a column from a DataFrame whose plan id is not an + # ancestor of the target plan (`df1.select(df2.id)`) fails in all + # modes; the strict / lenient switch does not gate this throw. 
+ df1 = self.spark.range(3) + df2 = self.spark.range(5) + with self.assertRaises(AnalysisException): + df1.select(df2.id).collect() + + def test_resolve_df_star(self): + # `df["*"]` is an UnresolvedDataFrameStar carrying df's plan id; the + # analyzer expands it to the matched node's output in all modes. + df = self.spark.sql( + "SELECT 'Books' AS c, 100 AS v UNION ALL SELECT 'Electronics' AS c, 200 AS v" + ) + rows = df.select(df["*"]).collect() + self.assertEqual(sorted((r.c, r.v) for r in rows), [("Books", 100), ("Electronics", 200)]) + + def test_resolve_self_join_withcolumnrenamed(self): + # Documented ColumnResolutionHelper case: df1 = range(10) + col `a`; + # df2 = df1 renamed `a` -> `b`; df1.join(df2, df1.a == df2.b). The + # node with df1's plan id is found on both Join sides; the right + # candidate is filtered out because its `a` is not in the renaming + # Project's output, so disambiguation succeeds in all modes. + df1 = self.spark.range(10).withColumn("a", sf.col("id")) + df2 = df1.withColumnRenamed("a", "b") + rows = df1.join(df2, df1.a == df2.b).select(df1.a, df2.b).collect() + self.assertEqual(len(rows), 10) + + def test_resolve_sort_missing_attr_recovery(self): + # Documented ColumnResolutionHelper case: df.select(df.v).sort(df.id) + # where df.id is not in the select's output. The analyzer descends + # through the Project, resolves df.id via plan id at the source, and + # adds it back to the upstream projection. Works in all modes. 
+ df = self.spark.range(10).withColumn("v", sf.col("id") + 1) + rows = df.select(df.v).sort(df.id).collect() + self.assertEqual(len(rows), 10) + class ColumnTests(ColumnTestsMixin, ReusedSQLTestCase): pass From 59d2d3bdc8ba040767bd64993975e112447828b4 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Fri, 29 May 2026 12:32:50 +0000 Subject: [PATCH 12/16] Hoist Window import to module head; note Connect divergences Move ``from pyspark.sql.window import Window`` out of the three layered test methods to the module-level imports. Add a comment on each shadowing resolution test noting that Connect lenient mode diverges (name-based fallback succeeds, overridden in the lenient parity suite). Generated-by: Claude Code (Anthropic), claude-opus-4-8 --- python/pyspark/sql/tests/test_column.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/tests/test_column.py b/python/pyspark/sql/tests/test_column.py index cb9716c26adb0..b4a20d9e33186 100644 --- a/python/pyspark/sql/tests/test_column.py +++ b/python/pyspark/sql/tests/test_column.py @@ -24,6 +24,7 @@ from pyspark.sql import Column, Row from pyspark.sql import functions as sf +from pyspark.sql.window import Window from pyspark.sql.types import StructType, StructField, IntegerType, LongType, StringType from pyspark.errors import AnalysisException, PySparkTypeError, PySparkValueError from pyspark.testing.sqlutils import ReusedSQLTestCase @@ -620,8 +621,6 @@ def test_layered_semijoin_groupby_window(self): # 4-layer DataFrame pipeline: filter -> semi-join -> groupBy/agg # -> window functions. ``layered.col`` references appear in both # filter and select at the outermost surface. - from pyspark.sql.window import Window - events_data = [ (1, 1, "Books", 100.0, 2, True), (2, 1, "Books", 50.0, 3, True), @@ -679,8 +678,6 @@ def test_layered_struct_semijoin_cube_ntile(self): # access -> cube aggregation -> window NTILE. 
``layered.col`` # references appear in both filter and select at the outermost # surface. - from pyspark.sql.window import Window - events_schema = StructType( [ StructField("id", IntegerType()), @@ -764,8 +761,6 @@ def test_layered_window_window_udf(self): # 4-layer DataFrame pipeline: filter -> running-total window -> # per-partition max window -> UDF wrap. ``layered.col`` references # appear in both filter and select at the outermost surface. - from pyspark.sql.window import Window - data = [ (1, "A", 100), (2, "A", 200), @@ -833,6 +828,8 @@ def test_resolve_after_chained_withcolumn_shadow(self): # Two consecutive withColumn calls each shadow `c` with a new # attribute of the same name, so the original `c` leaves the # projection and the tagged `df.c` cannot resolve. + # Connect lenient diverges: name-based fallback resolves the + # shadowed name (overridden in the lenient parity suite). df = self.spark.sql("SELECT 1 AS c") with self.assertRaises(AnalysisException): df.withColumn("c", sf.col("c").cast("string")).withColumn( @@ -841,6 +838,8 @@ def test_resolve_after_chained_withcolumn_shadow(self): def test_resolve_after_select_alias_shadow(self): # Same shadowing shape as withColumn but via select + alias. + # Connect lenient diverges: name-based fallback resolves the + # shadowed name (overridden in the lenient parity suite). df = self.spark.sql("SELECT 1 AS c") with self.assertRaises(AnalysisException): df.select(df.c.cast("string").alias("c")).select(df.c).collect() @@ -890,6 +889,8 @@ def test_resolve_after_agg_alias_shadow(self): # An aggregate output aliased `c` collides by name with the source # `c`, but the tagged `df.c` still references the aggregated-away # source attribute, so it cannot resolve. + # Connect lenient diverges: name-based fallback resolves the + # aliased name (overridden in the lenient parity suite). 
df = self.spark.sql("SELECT 1 AS x") with self.assertRaises(AnalysisException): df.groupBy().agg(sf.sum("x").alias("c")).select(df.c).collect() From e2ff1aeea2986da5142aca363d117b352446c8ee Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Sat, 30 May 2026 01:36:15 +0000 Subject: [PATCH 13/16] Fix agg_alias_shadow: use df["c"] since c is not on source df df.c on a DataFrame with only column `x` raises PySparkAttributeError before reaching the analyzer. Use df["c"] so the tagged UnresolvedAttribute is built eagerly and the AnalysisException is thrown at analysis time, which is what this test pins. Same fix applied in the lenient parity override. Generated-by: Claude Code (Anthropic), claude-opus-4-8 --- python/pyspark/sql/tests/connect/test_parity_column.py | 2 +- python/pyspark/sql/tests/test_column.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests/connect/test_parity_column.py b/python/pyspark/sql/tests/connect/test_parity_column.py index ad9b4b3ba7b6e..22c145edf1269 100644 --- a/python/pyspark/sql/tests/connect/test_parity_column.py +++ b/python/pyspark/sql/tests/connect/test_parity_column.py @@ -101,7 +101,7 @@ def test_resolve_after_select_alias_shadow(self): def test_resolve_after_agg_alias_shadow(self): df = self.spark.sql("SELECT 1 AS x") - rows = df.groupBy().agg(sf.sum("x").alias("c")).select(df.c).collect() + rows = df.groupBy().agg(sf.sum("x").alias("c")).select(df["c"]).collect() self.assertEqual([r.c for r in rows], [1]) diff --git a/python/pyspark/sql/tests/test_column.py b/python/pyspark/sql/tests/test_column.py index b4a20d9e33186..5c6e030ee3fa1 100644 --- a/python/pyspark/sql/tests/test_column.py +++ b/python/pyspark/sql/tests/test_column.py @@ -893,7 +893,7 @@ def test_resolve_after_agg_alias_shadow(self): # aliased name (overridden in the lenient parity suite). 
df = self.spark.sql("SELECT 1 AS x") with self.assertRaises(AnalysisException): - df.groupBy().agg(sf.sum("x").alias("c")).select(df.c).collect() + df.groupBy().agg(sf.sum("x").alias("c")).select(df["c"]).collect() def test_resolve_after_pivot(self): # pivot preserves the grouping key's attribute id, so the tagged From 64e5d39c4e3031b0a816f894ff374935baec4a82 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Sat, 30 May 2026 02:06:54 +0000 Subject: [PATCH 14/16] Fix agg_alias_shadow: use df with column c so df.c works SELECT 1 AS x gave a df with no column c, so df.c raised PySparkAttributeError before Spark analysis. Change the source to SELECT 1 AS c and shadow via agg(sum("c").alias("c")) so df.c is valid and the tagged reference is aggregated away - same semantics, correct form. Generated-by: Claude Code (Anthropic), claude-opus-4-8 --- python/pyspark/sql/tests/connect/test_parity_column.py | 4 ++-- python/pyspark/sql/tests/test_column.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/tests/connect/test_parity_column.py b/python/pyspark/sql/tests/connect/test_parity_column.py index 22c145edf1269..a2b00d7955eee 100644 --- a/python/pyspark/sql/tests/connect/test_parity_column.py +++ b/python/pyspark/sql/tests/connect/test_parity_column.py @@ -100,8 +100,8 @@ def test_resolve_after_select_alias_shadow(self): self.assertEqual([r.c for r in rows], ["1"]) def test_resolve_after_agg_alias_shadow(self): - df = self.spark.sql("SELECT 1 AS x") - rows = df.groupBy().agg(sf.sum("x").alias("c")).select(df["c"]).collect() + df = self.spark.sql("SELECT 1 AS c") + rows = df.groupBy().agg(sf.sum("c").alias("c")).select(df.c).collect() self.assertEqual([r.c for r in rows], [1]) diff --git a/python/pyspark/sql/tests/test_column.py b/python/pyspark/sql/tests/test_column.py index 5c6e030ee3fa1..14298027de7f8 100644 --- a/python/pyspark/sql/tests/test_column.py +++ b/python/pyspark/sql/tests/test_column.py @@ -891,9 +891,9 @@ def 
test_resolve_after_agg_alias_shadow(self): # source attribute, so it cannot resolve. # Connect lenient diverges: name-based fallback resolves the # aliased name (overridden in the lenient parity suite). - df = self.spark.sql("SELECT 1 AS x") + df = self.spark.sql("SELECT 1 AS c") with self.assertRaises(AnalysisException): - df.groupBy().agg(sf.sum("x").alias("c")).select(df["c"]).collect() + df.groupBy().agg(sf.sum("c").alias("c")).select(df.c).collect() def test_resolve_after_pivot(self): # pivot preserves the grouping key's attribute id, so the tagged From 6e5088aeb2ab4ad74199210144621de31e172a87 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 31 May 2026 09:14:03 +0000 Subject: [PATCH 15/16] Correct the union/intersect divergence comments Both Union and Intersect preserve the left child's exprId (Union/Intersect.mergeChildOutputs), so neither "emits new attribute ids". The actual Connect divergence is that Union is treated as a leaf during plan-id resolution (ColumnResolutionHelper) while Intersect is not. Co-authored-by: Isaac --- python/pyspark/sql/tests/test_column.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql/tests/test_column.py b/python/pyspark/sql/tests/test_column.py index 14298027de7f8..b75be8889084a 100644 --- a/python/pyspark/sql/tests/test_column.py +++ b/python/pyspark/sql/tests/test_column.py @@ -905,19 +905,24 @@ def test_resolve_after_pivot(self): self.assertEqual(sorted(r.c for r in rows), [1, 2]) def test_resolve_after_union(self): - # Union emits new attribute ids. Classic resolves the tagged - # left-side reference by attribute-id propagation and succeeds; - # Connect treats Union as a leaf when walking the plan tree for - # plan-id resolution and raises in both modes (overridden there). + # Union's output keeps the left child's attribute ids + # (Union.mergeChildOutputs), so Classic resolves the tagged + # left-side reference directly against that output and succeeds. 
+ # Connect resolves by walking the plan tree for the plan id but + # treats Union as a leaf (ColumnResolutionHelper), so the id below + # the Union is never found and it raises in both modes (overridden + # there). df1 = self.spark.sql("SELECT 1 AS c") df2 = self.spark.sql("SELECT 2 AS c") rows = df1.union(df2).select(df1.c).collect() self.assertEqual(sorted(r.c for r in rows), [1, 2]) def test_resolve_after_intersect(self): - # intersect, like union, emits new attribute ids, but the propagated - # id is retained in the output, so the tagged reference resolves in - # all modes. + # Intersect's output also keeps the left child's attribute ids + # (Intersect.mergeChildOutputs). Unlike Union, it is not treated as + # a leaf during plan-id resolution, so Connect's tree walk descends + # into the left child, finds the tagged node and resolves it; all + # modes succeed. df1 = self.spark.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") df2 = self.spark.sql("SELECT 2 AS c UNION ALL SELECT 3 AS c") rows = df1.intersect(df2).select(df1.c).collect() From 2f08dc1aaf4f48484f2f883053f9dadad0dfab21 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 31 May 2026 10:11:29 +0000 Subject: [PATCH 16/16] Reword cube comment: 'rollup' -> 'subtotal' for the (cat, status, NULL) group CUBE produces subtotal groups; ROLLUP is a distinct grouping operator, so 'subtotal' describes the (category, status, NULL) group more precisely. 
Co-authored-by: Isaac --- python/pyspark/sql/tests/test_column.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests/test_column.py b/python/pyspark/sql/tests/test_column.py index b75be8889084a..6a99c7de1a52d 100644 --- a/python/pyspark/sql/tests/test_column.py +++ b/python/pyspark/sql/tests/test_column.py @@ -750,7 +750,7 @@ def test_layered_struct_semijoin_cube_ntile(self): .collect() ) # Cube emits one (category, status, detail_name) group per distinct - # combination plus one (category, status, NULL) rollup per distinct + # combination plus one (category, status, NULL) subtotal per distinct # (category, status) pair. The where filter keeps both. self.assertEqual(len(rows), 6) self.assertEqual({r.category for r in rows}, {"Books", "Electronics", "Home"})