diff --git a/python/pyspark/sql/tests/connect/test_parity_column.py b/python/pyspark/sql/tests/connect/test_parity_column.py index 3903bb57a3750..a2b00d7955eee 100644 --- a/python/pyspark/sql/tests/connect/test_parity_column.py +++ b/python/pyspark/sql/tests/connect/test_parity_column.py @@ -17,6 +17,8 @@ import unittest +from pyspark.errors import AnalysisException +from pyspark.sql import functions as sf from pyspark.sql.tests.test_column import ColumnTestsMixin from pyspark.testing.connectutils import ReusedConnectTestCase @@ -38,6 +40,16 @@ def tearDownClass(cls): def test_validate_column_types(self): super().test_validate_column_types() + def test_resolve_after_union(self): + # Connect diverges from Classic here: Union is treated as a leaf when + # walking the plan tree for plan-id resolution, so the left-side plan + # id is never found and CANNOT_RESOLVE_DATAFRAME_COLUMN is thrown + # before any name-based fallback - in both strict and lenient modes. + df1 = self.spark.sql("SELECT 1 AS c") + df2 = self.spark.sql("SELECT 2 AS c") + with self.assertRaisesRegex(AnalysisException, "CANNOT_RESOLVE_DATAFRAME_COLUMN"): + df1.union(df2).select(df1.c).collect() + def test_df_col_resolution_mode(self): self.assertEqual( self.spark.conf.get("spark.sql.analyzer.strictDataFrameColumnResolution"), @@ -68,6 +80,30 @@ def test_df_col_resolution_mode(self): "false", ) + # The shadowing trio diverges in lenient mode: where Classic and Connect + # strict raise, lenient resolves the tagged reference by name against the + # current (shadowed) output. 
+ + def test_resolve_after_chained_withcolumn_shadow(self): + df = self.spark.sql("SELECT 1 AS c") + rows = ( + df.withColumn("c", sf.col("c").cast("string")) + .withColumn("c", sf.col("c").cast("int")) + .select(df.c) + .collect() + ) + self.assertEqual([r.c for r in rows], [1]) + + def test_resolve_after_select_alias_shadow(self): + df = self.spark.sql("SELECT 1 AS c") + rows = df.select(df.c.cast("string").alias("c")).select(df.c).collect() + self.assertEqual([r.c for r in rows], ["1"]) + + def test_resolve_after_agg_alias_shadow(self): + df = self.spark.sql("SELECT 1 AS c") + rows = df.groupBy().agg(sf.sum("c").alias("c")).select(df.c).collect() + self.assertEqual([r.c for r in rows], [1]) + if __name__ == "__main__": from pyspark.testing import main diff --git a/python/pyspark/sql/tests/test_column.py b/python/pyspark/sql/tests/test_column.py index 74a7746b154de..6a99c7de1a52d 100644 --- a/python/pyspark/sql/tests/test_column.py +++ b/python/pyspark/sql/tests/test_column.py @@ -20,10 +20,12 @@ from itertools import chain import datetime import unittest +import uuid from pyspark.sql import Column, Row from pyspark.sql import functions as sf -from pyspark.sql.types import StructType, StructField, IntegerType, LongType +from pyspark.sql.window import Window +from pyspark.sql.types import StructType, StructField, IntegerType, LongType, StringType from pyspark.errors import AnalysisException, PySparkTypeError, PySparkValueError from pyspark.testing.sqlutils import ReusedSQLTestCase from pyspark.testing.utils import have_pandas, pandas_requirement_message @@ -605,6 +607,388 @@ def test_drop_notexistent_col(self): self.assertEqual(df4.columns, ["colA", "colB", "colC", "colC", "colD", "colE"]) self.assertEqual(df4.count(), 1) + # --- Mixed-surface layered DataFrame programs --------------------------- + # + # These tests chain multiple DataFrame transformations - semi-joins + # (for SQL EXISTS/IN), window functions, cube aggregations, UDFs and + # struct field access 
- into 4-5 layer pipelines, then reference the + # final layered DataFrame's columns via ``layered.col`` in both filter + # and select at the outermost surface. The goal is to catch regressions + # in plan-id propagation across analyzer rules that single-operator + # tests miss when rules interact. + + def test_layered_semijoin_groupby_window(self): + # 4-layer DataFrame pipeline: filter -> semi-join -> groupBy/agg + # -> window functions. ``layered.col`` references appear in both + # filter and select at the outermost surface. + events_data = [ + (1, 1, "Books", 100.0, 2, True), + (2, 1, "Books", 50.0, 3, True), + (3, 2, "Electronics", 200.0, 1, True), + (4, 2, "Electronics", 300.0, 2, True), + (5, 3, "Home", 80.0, 4, True), + (6, 4, "Books", 60.0, 1, False), + ] + users_data = [(1, 25), (2, 30), (3, 22), (4, 18)] + events_cols = ["id", "user_id", "category", "amount", "quantity", "is_active"] + users_cols = ["id", "age"] + + events = self.spark.createDataFrame(events_data, events_cols) + users = self.spark.createDataFrame(users_data, users_cols) + # Layer 1: filter + semi-join (DataFrame-API equivalent of + # WHERE is_active AND EXISTS (user with age > 20)). + active = events.where(events.is_active).join( + users.where(users.age > 20), + events.user_id == users.id, + "left_semi", + ) + # Layer 2: groupBy + agg, then post-agg filter (HAVING equivalent). + agg = active.groupBy("category").agg( + sf.sum(active.amount * active.quantity * sf.lit(0.1)).alias("total_amt"), + sf.sum(active.amount).alias("amount_sum"), + ) + totals = agg.where(agg.amount_sum > 50).select("category", "total_amt") + # Layer 3: window functions on top of the aggregate. + running = Window.orderBy("total_amt").rowsBetween(-1, 1) + ranking = Window.orderBy(totals.total_amt.desc()) + windowed = totals.select( + "category", + "total_amt", + sf.avg(totals.total_amt).over(running).alias("running_avg"), + sf.rank().over(ranking).alias("rank_num"), + ) + # Layer 4: outer filter. 
+ layered = windowed.where(windowed.rank_num <= 5) + + rows = ( + layered.filter(layered.rank_num <= 3) + .select( + layered.category, + layered.total_amt, + layered.running_avg, + layered.rank_num, + ) + .collect() + ) + result = sorted((r.category, r.rank_num) for r in rows) + self.assertEqual(result, [("Books", 2), ("Electronics", 1), ("Home", 3)]) + + def test_layered_struct_semijoin_cube_ntile(self): + # 5-layer DataFrame pipeline: filter -> semi-join -> struct field + # access -> cube aggregation -> window NTILE. ``layered.col`` + # references appear in both filter and select at the outermost + # surface. + events_schema = StructType( + [ + StructField("id", IntegerType()), + StructField("category", StringType()), + StructField("status", StringType()), + StructField("amount", IntegerType()), + StructField("quantity", IntegerType()), + StructField( + "detail", + StructType( + [ + StructField("name", StringType()), + StructField("nested", StructType([StructField("x", IntegerType())])), + ] + ), + ), + ] + ) + events_data = [ + (1, "Books", "A", 100, 5, ("alpha", (1,))), + (2, "Electronics", "B", 200, 3, ("beta", (2,))), + (3, "Books", "A", 50, 7, ("alpha", (1,))), + (4, "Electronics", "B", 300, 4, ("beta", (2,))), + (5, "Home", "C", 80, 2, ("gamma", (3,))), + ] + categories_data = [("Books", 1), ("Electronics", 2), ("Home", 3), ("Toys", 5)] + categories_cols = ["name", "priority"] + + events = self.spark.createDataFrame(events_data, events_schema) + categories = self.spark.createDataFrame(categories_data, categories_cols) + # Layer 1: filter + semi-join (DataFrame-API equivalent of + # WHERE quantity > 1 AND category IN (SELECT ...)). + filtered = events.where(events.quantity > 1).join( + categories.where(categories.priority <= 3), + events.category == categories.name, + "left_semi", + ) + # Layer 2: project with struct field access (struct subfields use + # bracket access since ``detail.name`` would hit ``Column.name``). 
+ base = filtered.select( + filtered.id, + filtered.category, + filtered.status, + filtered.amount, + filtered.detail["name"].alias("detail_name"), + filtered.detail["nested"]["x"].alias("nx"), + ) + # Layer 3: cube aggregation (mixed grouping levels - similar + # surface area to SQL GROUPING SETS; ``DataFrame.groupingSets`` is + # the direct DataFrame-API counterpart). + agg = base.cube("category", "status", "detail_name").agg( + sf.sum(base.amount).alias("total"), sf.count(sf.lit(1)).alias("cnt") + ) + grouped = agg.where(agg.category.isNotNull() & agg.status.isNotNull()) + # Layer 4: NTILE window. + tiled = grouped.withColumn("tile", sf.ntile(2).over(Window.orderBy(grouped.total.desc()))) + # Layer 5: outer filter. + layered = tiled.where(tiled.tile <= 2) + + rows = ( + layered.filter(layered.tile >= 1) + .select( + layered.category, + layered.status, + layered.detail_name, + layered.total, + layered.cnt, + layered.tile, + ) + .collect() + ) + # Cube emits every grouping set; the isNotNull filter keeps only the + # (category, status, detail_name) groups and the (category, status, NULL) + # subtotals: 3 distinct combinations plus 3 distinct pairs, i.e. 6 rows. + self.assertEqual(len(rows), 6) + self.assertEqual({r.category for r in rows}, {"Books", "Electronics", "Home"}) + self.assertEqual({r.total for r in rows}, {80, 150, 500}) + self.assertEqual({r.tile for r in rows}, {1, 2}) + + def test_layered_window_window_udf(self): + # 4-layer DataFrame pipeline: filter -> running-total window -> + # per-partition max window -> UDF wrap. ``layered.col`` references + # appear in both filter and select at the outermost surface. + data = [ + (1, "A", 100), + (2, "A", 200), + (3, "B", 150), + (4, "B", 250), + (5, "C", 50), + ] + cols = ["id", "category", "amount"] + + df = self.spark.createDataFrame(data, cols) + # Layer 1: filter (replaces WHERE EXISTS amount > 0). + filtered = df.where(df.amount > 0) + # Layer 2: running total window.
+ run_w = Window.partitionBy("category").orderBy("id") + with_run = filtered.withColumn("run_amt", sf.sum(filtered.amount).over(run_w)) + # Layer 3: per-category max window (replaces correlated subquery + # for cat_max). + cat_w = Window.partitionBy("category") + with_max = with_run.withColumn("cat_max", sf.max(with_run.amount).over(cat_w)) + # Layer 4: UDF. + double = sf.udf(lambda x: x * 2 if x is not None else None, IntegerType()) + layered = with_max.withColumn("doubled_amt", double(with_max.amount)) + + rows = ( + layered.filter(layered.amount > 0) + .select( + layered.id, + layered.category, + layered.amount, + layered.run_amt, + layered.cat_max, + layered.doubled_amt, + ) + .collect() + ) + result = sorted( + (r.id, r.category, r.amount, r.run_amt, r.cat_max, r.doubled_amt) for r in rows + ) + self.assertEqual( + result, + [ + (1, "A", 100, 100, 200, 200), + (2, "A", 200, 300, 200, 400), + (3, "B", 150, 150, 250, 300), + (4, "B", 250, 400, 250, 500), + (5, "C", 50, 50, 50, 100), + ], + ) + + # --- Tagged DataFrame column resolution -------------------------------- + # + # ``df.col`` / ``df["col"]`` carries the source DataFrame's plan id. These + # tests pin how that tagged reference resolves after assorted operators. + # The behavior is shared across Spark Classic and Spark Connect (both + # ``spark.sql.analyzer.strictDataFrameColumnResolution`` modes) except for + # a few diverging cases, which are overridden in the Connect parity suites + # (``ColumnParityTests`` / ``...WithNonStrictDFColResolution``): + # + # * the shadowing trio - Classic and Connect strict raise, Connect + # lenient resolves the shadowed name via name-based fallback; + # * union - Classic resolves via attribute-id propagation, Connect + # raises in both modes. 
+ + def test_resolve_after_chained_withcolumn_shadow(self): + # Two consecutive withColumn calls each shadow `c` with a new + # attribute of the same name, so the original `c` leaves the + # projection and the tagged `df.c` cannot resolve. + # Connect lenient diverges: name-based fallback resolves the + # shadowed name (overridden in the lenient parity suite). + df = self.spark.sql("SELECT 1 AS c") + with self.assertRaises(AnalysisException): + df.withColumn("c", sf.col("c").cast("string")).withColumn( + "c", sf.col("c").cast("int") + ).select(df.c).collect() + + def test_resolve_after_select_alias_shadow(self): + # Same shadowing shape as withColumn but via select + alias. + # Connect lenient diverges: name-based fallback resolves the + # shadowed name (overridden in the lenient parity suite). + df = self.spark.sql("SELECT 1 AS c") + with self.assertRaises(AnalysisException): + df.select(df.c.cast("string").alias("c")).select(df.c).collect() + + def test_resolve_after_withcolumnrenamed(self): + # withColumnRenamed drops the original `c` attribute and projects it + # as `c2`; the tagged `df.c` matches neither the original attribute + # nor a current column named `c`, so all modes raise. + df = self.spark.sql("SELECT 1 AS c") + with self.assertRaises(AnalysisException): + df.withColumnRenamed("c", "c2").select(df.c).collect() + + def test_resolve_after_drop(self): + # drop("c") removes the column entirely; the tagged `df.c` cannot + # resolve under any mode. + df = self.spark.sql("SELECT 1 AS c, 2 AS d") + with self.assertRaises(AnalysisException): + df.drop("c").select(df.c).collect() + + def test_resolve_through_filter(self): + # filter is a pass-through operator: the child Project's attributes + # flow through unchanged, so the tagged reference resolves. 
+ df = self.spark.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") + rows = df.filter(df.c > 0).select(df.c).collect() + self.assertEqual(sorted(r.c for r in rows), [1, 2]) + + def test_resolve_through_sort(self): + # sort is also a pass-through operator. + df = self.spark.sql("SELECT 2 AS c UNION ALL SELECT 1 AS c") + rows = df.sort(df.c).select(df.c).collect() + self.assertEqual([r.c for r in rows], [1, 2]) + + def test_resolve_through_distinct(self): + # distinct preserves attribute identity for column resolution. + df = self.spark.sql("SELECT 1 AS c UNION ALL SELECT 1 AS c") + rows = df.distinct().select(df.c).collect() + self.assertEqual([r.c for r in rows], [1]) + + def test_resolve_after_groupby_count(self): + # groupBy("c").count() preserves the grouping key's attribute id, so + # the tagged reference resolves. + df = self.spark.sql("SELECT 1 AS c UNION ALL SELECT 1 AS c UNION ALL SELECT 2 AS c") + rows = df.groupBy("c").count().select(df.c).collect() + self.assertEqual(sorted(r.c for r in rows), [1, 2]) + + def test_resolve_after_agg_alias_shadow(self): + # An aggregate output aliased `c` collides by name with the source + # `c`, but the tagged `df.c` still references the aggregated-away + # source attribute, so it cannot resolve. + # Connect lenient diverges: name-based fallback resolves the + # aliased name (overridden in the lenient parity suite). + df = self.spark.sql("SELECT 1 AS c") + with self.assertRaises(AnalysisException): + df.groupBy().agg(sf.sum("c").alias("c")).select(df.c).collect() + + def test_resolve_after_pivot(self): + # pivot preserves the grouping key's attribute id, so the tagged + # reference resolves. 
+ df = self.spark.sql( + "SELECT 1 AS c, 'a' AS k, 10 AS v UNION ALL SELECT 2 AS c, 'b' AS k, 20 AS v" + ) + rows = df.groupBy("c").pivot("k").sum("v").select(df.c).collect() + self.assertEqual(sorted(r.c for r in rows), [1, 2]) + + def test_resolve_after_union(self): + # Union's output keeps the left child's attribute ids + # (Union.mergeChildOutputs), so Classic resolves the tagged + # left-side reference directly against that output and succeeds. + # Connect resolves by walking the plan tree for the plan id but + # treats Union as a leaf (ColumnResolutionHelper), so the id below + # the Union is never found and it raises in both modes (overridden + # there). + df1 = self.spark.sql("SELECT 1 AS c") + df2 = self.spark.sql("SELECT 2 AS c") + rows = df1.union(df2).select(df1.c).collect() + self.assertEqual(sorted(r.c for r in rows), [1, 2]) + + def test_resolve_after_intersect(self): + # Intersect's output also keeps the left child's attribute ids + # (Intersect.mergeChildOutputs). Unlike Union, it is not treated as + # a leaf during plan-id resolution, so Connect's tree walk descends + # into the left child, finds the tagged node and resolves it; all + # modes succeed. + df1 = self.spark.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") + df2 = self.spark.sql("SELECT 2 AS c UNION ALL SELECT 3 AS c") + rows = df1.intersect(df2).select(df1.c).collect() + self.assertEqual([r.c for r in rows], [2]) + + def test_resolve_self_join_alias(self): + # Both self-join sides originate from the same plan-id-tagged + # ancestor, yielding two equal-depth candidates with the same + # attribute id. Disambiguation cannot tiebreak and all modes raise + # an ambiguous-reference error. 
+ df = self.spark.sql("SELECT 1 AS c UNION ALL SELECT 2 AS c") + a, b = df.alias("a"), df.alias("b") + with self.assertRaises(AnalysisException): + a.join(b, a.c == b.c).select(df.c).collect() + + def test_resolve_after_subquery_view(self): + # Registering the DataFrame as a temp view and reading it back via + # table() produces a new plan; the tagged reference still resolves in + # all modes. + view = f"v_{uuid.uuid4().hex}" + df = self.spark.sql("SELECT 1 AS c") + df.createOrReplaceTempView(view) + try: + rows = self.spark.table(view).select(df.c).collect() + self.assertEqual([r.c for r in rows], [1]) + finally: + self.spark.sql(f"DROP VIEW IF EXISTS {view}") + + def test_resolve_cross_dataframe_illegal_reference(self): + # Referencing a column from a DataFrame whose plan id is not an + # ancestor of the target plan (`df1.select(df2.id)`) fails in all + # modes; the strict / lenient switch does not gate this throw. + df1 = self.spark.range(3) + df2 = self.spark.range(5) + with self.assertRaises(AnalysisException): + df1.select(df2.id).collect() + + def test_resolve_df_star(self): + # `df["*"]` is an UnresolvedDataFrameStar carrying df's plan id; the + # analyzer expands it to the matched node's output in all modes. + df = self.spark.sql( + "SELECT 'Books' AS c, 100 AS v UNION ALL SELECT 'Electronics' AS c, 200 AS v" + ) + rows = df.select(df["*"]).collect() + self.assertEqual(sorted((r.c, r.v) for r in rows), [("Books", 100), ("Electronics", 200)]) + + def test_resolve_self_join_withcolumnrenamed(self): + # Documented ColumnResolutionHelper case: df1 = range(10) + col `a`; + # df2 = df1 renamed `a` -> `b`; df1.join(df2, df1.a == df2.b). The + # node with df1's plan id is found on both Join sides; the right + # candidate is filtered out because its `a` is not in the renaming + # Project's output, so disambiguation succeeds in all modes.
+ df1 = self.spark.range(10).withColumn("a", sf.col("id")) + df2 = df1.withColumnRenamed("a", "b") + rows = df1.join(df2, df1.a == df2.b).select(df1.a, df2.b).collect() + self.assertEqual(len(rows), 10) + + def test_resolve_sort_missing_attr_recovery(self): + # Documented ColumnResolutionHelper case: df.select(df.v).sort(df.id) + # where df.id is not in the select's output. The analyzer descends + # through the Project, resolves df.id via plan id at the source, and + # adds it back to the upstream projection. Works in all modes. + df = self.spark.range(10).withColumn("v", sf.col("id") + 1) + rows = df.select(df.v).sort(df.id).collect() + self.assertEqual(len(rows), 10) + class ColumnTests(ColumnTestsMixin, ReusedSQLTestCase): pass