diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/optimize/CorrelVariableNormalizerShuttle.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/optimize/CorrelVariableNormalizerShuttle.java new file mode 100644 index 0000000000000..8148941b5fd61 --- /dev/null +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/optimize/CorrelVariableNormalizerShuttle.java @@ -0,0 +1,199 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.planner.plan.optimize; + +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.RelShuttleImpl; +import org.apache.calcite.rel.core.CorrelationId; +import org.apache.calcite.rel.logical.LogicalCorrelate; +import org.apache.calcite.rel.logical.LogicalFilter; +import org.apache.calcite.rel.logical.LogicalJoin; +import org.apache.calcite.rel.logical.LogicalProject; +import org.apache.calcite.rel.logical.LogicalTableFunctionScan; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexCorrelVariable; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexShuttle; +import org.apache.calcite.rex.RexSubQuery; + +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Optional; + +/** + * Normalizes correlation variable ids in a RelNode tree to make equivalent subplans digest-match. + */ +public final class CorrelVariableNormalizerShuttle extends RelShuttleImpl { + + private final Map<Integer, Integer> idMap = new LinkedHashMap<>(); + + private final RexBuilder rexBuilder; + private final RexShuttle rexCorrelNormalizer; + + public CorrelVariableNormalizerShuttle(RexBuilder rexBuilder) { + this.rexBuilder = rexBuilder; + rexCorrelNormalizer = new RexCorrelNormalizer(); + } + + @Override + public RelNode visit(LogicalCorrelate correlate) { + var adjustedId = adjustCorrelationId(correlate.getCorrelationId()); + if (adjustedId.isPresent()) { + var left = correlate.getLeft().accept(this); + var right = correlate.getRight().accept(this); + return correlate.copy( + correlate.getTraitSet(), + left, + right, + adjustedId.get(), + correlate.getRequiredColumns(), + correlate.getJoinType()); + } + + return super.visit(correlate); + } + + @Override + public RelNode visit(RelNode relNode) { + if (relNode instanceof LogicalTableFunctionScan && relNode.getInputs().isEmpty()) { + // visitChild applies the RexShuttle while walking RelNode inputs.
A zero-input table + // function scan is a leaf, but unlike a regular TableScan it can still contain RexNodes + // (e.g., UNNEST over a correl variable), so rewrite it explicitly. + return relNode.accept(rexCorrelNormalizer); + } + + return super.visit(relNode); + } + + @Override + protected RelNode visitChild(RelNode parent, int i, RelNode child) { + if (i == 0) { + parent = parent.accept(rexCorrelNormalizer); + parent = remapVariablesSet(parent); + } + + return super.visitChild(parent, i, child); + } + + /** + * Filter, Project, and Join carry a {@link CorrelationId} set alongside their RexNodes. {@code + * RelNode.accept(RexShuttle)} only rewrites the RexNodes and preserves the old {@code + * variablesSet} via {@code copy()}, so ids we just adjusted in the condition/projects are still + * advertised under their old names. To overcome that, we need to rebuild that variable with the + * adjusted ids set as well. + */ + private RelNode remapVariablesSet(RelNode relNode) { + var oldSet = relNode.getVariablesSet(); + if (oldSet.isEmpty()) { + return relNode; + } + + var builder = com.google.common.collect.ImmutableSet.<CorrelationId>builder(); + boolean changed = false; + for (var id : oldSet) { + var adjusted = adjustCorrelationId(id); + if (adjusted.isPresent()) { + builder.add(adjusted.get()); + changed = true; + } else { + builder.add(id); + } + } + + if (!changed) { + return relNode; + } + + var newSet = builder.build(); + if (relNode instanceof LogicalFilter) { + var filter = (LogicalFilter) relNode; + return new LogicalFilter( + filter.getCluster(), + filter.getTraitSet(), + filter.getHints(), + filter.getInput(), + filter.getCondition(), + newSet); + } + + if (relNode instanceof LogicalProject) { + var project = (LogicalProject) relNode; + return new LogicalProject( + project.getCluster(), + project.getTraitSet(), + project.getHints(), + project.getInput(), + project.getProjects(), + project.getRowType(), + newSet); + } + + if (relNode instanceof LogicalJoin) { + var join = 
(LogicalJoin) relNode; + return new LogicalJoin( + join.getCluster(), + join.getTraitSet(), + join.getHints(), + join.getLeft(), + join.getRight(), + join.getCondition(), + newSet, + join.getJoinType(), + join.isSemiJoinDone(), + com.google.common.collect.ImmutableList.copyOf(join.getSystemFieldList())); + } + + return relNode; + } + + private Optional<CorrelationId> adjustCorrelationId(CorrelationId correlationId) { + if (correlationId.getName().startsWith(CorrelationId.CORREL_PREFIX)) { + int oldId = correlationId.getId(); + int newId = idMap.computeIfAbsent(oldId, k -> idMap.size() + 1); + if (newId != oldId) { + return Optional.of(new CorrelationId(newId)); + } + } + + return Optional.empty(); + } + + private final class RexCorrelNormalizer extends RexShuttle { + + @Override + public RexNode visitCorrelVariable(RexCorrelVariable variable) { + var adjustedId = adjustCorrelationId(variable.id); + if (adjustedId.isPresent()) { + return rexBuilder.makeCorrel(variable.getType(), adjustedId.get()); + } else { + return super.visitCorrelVariable(variable); + } + } + + @Override + public RexNode visitSubQuery(RexSubQuery subQuery) { + // Let the base shuttle rewrite the RexSubQuery's operands first, so any + // RexCorrelVariables they carry (e.g., the LHS of IN/SOME) are also adjusted. + var withOperands = (RexSubQuery) super.visitSubQuery(subQuery); + var rewritten = withOperands.rel.accept(CorrelVariableNormalizerShuttle.this); + + return rewritten == withOperands.rel ? 
withOperands : withOperands.clone(rewritten); + } + } +} diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/optimize/RelNodeBlock.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/optimize/RelNodeBlock.scala index 0a61c191f390f..3d4f771245f65 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/optimize/RelNodeBlock.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/optimize/RelNodeBlock.scala @@ -412,10 +412,26 @@ object RelNodeBlockPlanBuilder { return relNodes } - // reuse sub-plan with same digest in input RelNode trees. + // The reuse lookup uses the original trees, while the rewrite runs on normalized + // trees. This keeps existing reuse unchanged: normalization + // does not change subtrees without correlation variables, so they still reuse as before. Subtrees with + // correlation variables (e.g., CROSS JOIN UNNEST or decorrelated sub-queries) + // used to have different digests in each view expansion, so they were not reused. + // + // Reusing those newly matching correlated subtrees is not safe yet. If such a + // subtree is shared, it can become a separate RelNodeBlock and be optimized without + // seeing its parent operators. During that local optimization, ROWTIME output fields + // may be converted to regular TIMESTAMP_LTZ fields. The parents still refer to the + // old ROWTIME-typed fields, so replacing the child can fail validation with a + // ROWTIME/plain timestamp mismatch. val context = new SubplanReuseContext(true, relNodes: _*) val reuseShuttle = new SubplanReuseShuttle(context) - relNodes.map(_.accept(reuseShuttle)) + + relNodes + // Normalize correlation variable ids per node so structurally equivalent + // subplans will share digests. 
+ .map(n => n.accept(new CorrelVariableNormalizerShuttle(n.getCluster.getRexBuilder))) + .map(_.accept(reuseShuttle)) } } diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/optimize/CorrelVariableNormalizerShuttleTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/optimize/CorrelVariableNormalizerShuttleTest.java new file mode 100644 index 0000000000000..4bf765958b574 --- /dev/null +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/optimize/CorrelVariableNormalizerShuttleTest.java @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.planner.plan.optimize; + +import org.apache.flink.table.planner.calcite.FlinkRelBuilder; +import org.apache.flink.table.planner.utils.PlannerMocks; + +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.CorrelationId; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.LogicalFilter; +import org.apache.calcite.rel.logical.LogicalJoin; +import org.apache.calcite.rel.logical.LogicalProject; +import org.apache.calcite.rel.logical.LogicalValues; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexSubQuery; +import org.apache.calcite.sql.SqlOperatorBinding; +import org.apache.calcite.sql.fun.SqlCollectionTableOperator; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.sql.validate.SqlModality; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Set; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Tests for {@link CorrelVariableNormalizerShuttle}. 
*/ +class CorrelVariableNormalizerShuttleTest { + + private FlinkRelBuilder relBuilder; + private RexBuilder rexBuilder; + private RelOptCluster cluster; + + @BeforeEach + void before() { + relBuilder = PlannerMocks.create().getPlannerContext().createRelBuilder(); + rexBuilder = relBuilder.getRexBuilder(); + cluster = relBuilder.getCluster(); + } + + @Test + void testNormalizesIdsInEncounterOrderWithoutNegativeIds() { + RelNode input = oneRow(); + RelDataType correlType = singleFieldType("a"); + RelNode project = + LogicalProject.create( + input, + List.of(), + List.of( + correlField(correlType, new CorrelationId(5)), + correlField(correlType, new CorrelationId(2))), + List.of("a", "b"), + Set.of(new CorrelationId(5), new CorrelationId(2))); + + RelNode normalized = normalize(project); + String plan = RelOptUtil.toString(normalized); + + assertThat(plan).contains("$cor1", "$cor2"); + assertThat(plan).doesNotContain("$cor5", "$cor-", "$cor0"); + } + + @Test + void testNormalizesZeroInputTableFunctionScanRexNodes() { + CorrelationId oldId = new CorrelationId(5); + RelDataType inputType = singleFieldType("a"); + RelDataType outputType = singleFieldType("b"); + SqlCollectionTableOperator tableOperator = + new SqlCollectionTableOperator("TABLE", SqlModality.RELATION) { + @Override + public RelDataType inferReturnType(SqlOperatorBinding opBinding) { + return outputType; + } + }; + + RelNode relNode = + relBuilder + .push(oneRow()) + .functionScan(tableOperator, 0, correlField(inputType, oldId)) + .correlate(JoinRelType.INNER, oldId) + .build(); + + String plan = RelOptUtil.toString(normalize(relNode)); + + assertThat(plan).contains("$cor1"); + assertThat(plan).doesNotContain("$cor5"); + } + + @Test + void testNormalizesRexSubQueryOperandsAndRelTree() { + CorrelationId oldId = new CorrelationId(5); + RelDataType correlType = singleFieldType("a"); + RelNode subQueryRel = + LogicalProject.create( + oneRow(), + List.of(), + List.of(correlField(correlType, oldId)), + 
List.of("a"), + Set.of(oldId)); + RexSubQuery subQuery = + RexSubQuery.in( + subQueryRel, + com.google.common.collect.ImmutableList.of(correlField(correlType, oldId))); + RelNode filter = LogicalFilter.create(oneRow(), subQuery); + + String plan = RelOptUtil.toString(normalize(filter)); + + assertThat(plan).contains("$cor1"); + assertThat(plan).doesNotContain("$cor5"); + } + + @Test + void testRemapsVariablesSetForFilterProjectAndJoin() { + CorrelationId oldId = new CorrelationId(5); + CorrelationId newId = new CorrelationId(1); + RelNode input = oneRow(); + + RelNode filter = + LogicalFilter.create( + input, + rexBuilder.makeLiteral(true), + com.google.common.collect.ImmutableSet.of(oldId)); + assertThat(normalize(filter).getVariablesSet()).containsExactly(newId); + + RelNode project = + LogicalProject.create( + input, + List.of(), + List.of(rexBuilder.makeInputRef(input, 0)), + List.of("ZERO"), + Set.of(oldId)); + assertThat(normalize(project).getVariablesSet()).containsExactly(newId); + + RelNode join = + LogicalJoin.create( + oneRow(), + oneRow(), + List.of(), + rexBuilder.makeLiteral(true), + Set.of(oldId), + JoinRelType.INNER); + assertThat(normalize(join).getVariablesSet()).containsExactly(newId); + } + + private RelNode normalize(RelNode relNode) { + return relNode.accept(new CorrelVariableNormalizerShuttle(rexBuilder)); + } + + private RelNode oneRow() { + return LogicalValues.createOneRow(cluster); + } + + private RexNode correlField(RelDataType correlType, CorrelationId correlationId) { + return rexBuilder.makeFieldAccess(rexBuilder.makeCorrel(correlType, correlationId), 0); + } + + private RelDataType singleFieldType(String fieldName) { + return relBuilder + .getTypeFactory() + .createStructType( + List.of(relBuilder.getTypeFactory().createSqlType(SqlTypeName.BIGINT)), + List.of(fieldName)); + } +} diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/SubplanReuseTest.xml 
b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/SubplanReuseTest.xml index 5bf4359078a33..08678375c8c07 100644 --- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/SubplanReuseTest.xml +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/SubplanReuseTest.xml @@ -902,6 +902,125 @@ Union(all=[true], union=[a, b]) : +- TableSourceScan(table=[[default_catalog, default_database, x]], fields=[a, b, c]) +- Calc(select=[a, (b * 2) AS b], where=[(b < 10)]) +- Reused(reference_id=[1]) +]]> + + + + + + + + diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/SubplanReuseTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/SubplanReuseTest.scala index 02e3255e68f21..2c7c66d462b74 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/SubplanReuseTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/SubplanReuseTest.scala @@ -442,6 +442,98 @@ class SubplanReuseTest extends TableTestBase { util.verifyExecPlan(stmtSet) } + @Test + def testSubplanReuseOnTemporalJoinWithUnnest(): Unit = { + util.tableEnv.getConfig.set( + OptimizerConfigOptions.TABLE_OPTIMIZER_REUSE_OPTIMIZE_BLOCK_WITH_DIGEST_ENABLED, + Boolean.box(true)) + + util.addTable(""" + |CREATE TABLE Versioned ( + | k BIGINT NOT NULL, + | v BIGINT NOT NULL, + | ts TIMESTAMP(3) WITH LOCAL TIME ZONE NOT NULL, + | WATERMARK FOR ts AS ts + |) WITH ('connector' = 'values') + """.stripMargin) + + util.addTable(""" + |CREATE TABLE Probe ( + | k BIGINT NOT NULL, + | arr ARRAY<ROW<x BIGINT NOT NULL> NOT NULL> NOT NULL, + | ts TIMESTAMP(3) WITH LOCAL TIME ZONE NOT NULL, + | WATERMARK FOR ts AS ts + |) WITH ('connector' = 'values') + """.stripMargin) + + util.addTable( + """ + |CREATE VIEW
Dedup AS + |SELECT k, v, ts FROM ( + | SELECT *, ROW_NUMBER() OVER (PARTITION BY k ORDER BY ts DESC) AS rn FROM Versioned + |) WHERE rn = 1 + """.stripMargin) + + util.addTable(""" + |CREATE VIEW Joined AS + |SELECT p.k, p.ts AS p_ts, e.x, d.v, d.ts AS d_ts + |FROM Probe AS p + | CROSS JOIN UNNEST(p.arr) AS e(x) + | INNER JOIN Dedup FOR SYSTEM_TIME AS OF p.ts AS d ON p.k = d.k + """.stripMargin) + + util.addTable(""" + |CREATE VIEW Out1 AS SELECT k, v FROM Joined + """.stripMargin) + util.addTable(""" + |CREATE VIEW Out2 AS SELECT x, v FROM Joined + """.stripMargin) + + util.addTable(""" + |CREATE TABLE Sink1 (k BIGINT, v BIGINT) WITH ('connector' = 'values') + """.stripMargin) + util.addTable(""" + |CREATE TABLE Sink2 (x BIGINT, v BIGINT) WITH ('connector' = 'values') + """.stripMargin) + util.addTable(""" + |CREATE TABLE Sink3 ( + | k BIGINT, + | p_ts TIMESTAMP(3) WITH LOCAL TIME ZONE, + | x BIGINT, + | v BIGINT, + | d_ts TIMESTAMP(3) WITH LOCAL TIME ZONE + |) WITH ('connector' = 'values') + """.stripMargin) + util.addTable(""" + |CREATE TABLE Sink4 ( + | k BIGINT, + | p_ts TIMESTAMP(3) WITH LOCAL TIME ZONE, + | x BIGINT, + | v BIGINT, + | d_ts TIMESTAMP(3) WITH LOCAL TIME ZONE + |) WITH ('connector' = 'values') + """.stripMargin) + + val stmtSet = util.tableEnv.createStatementSet() + stmtSet.addInsertSql("INSERT INTO Sink1 SELECT * FROM Out1") + stmtSet.addInsertSql("INSERT INTO Sink2 SELECT * FROM Out2") + stmtSet.addInsertSql(""" + |INSERT INTO Sink3 SELECT k, + | CAST(p_ts AS TIMESTAMP(3) WITH LOCAL TIME ZONE), + | x, v, + | CAST(d_ts AS TIMESTAMP(3) WITH LOCAL TIME ZONE) + |FROM Joined + """.stripMargin) + stmtSet.addInsertSql(""" + |INSERT INTO Sink4 SELECT k, + | CAST(p_ts AS TIMESTAMP(3) WITH LOCAL TIME ZONE), + | x, v, + | CAST(d_ts AS TIMESTAMP(3) WITH LOCAL TIME ZONE) + |FROM Joined + """.stripMargin) + util.verifyExecPlan(stmtSet) + } + @Test def testSourceReuseWithEmptyFilterCondAndIgnoreEmptyFilter(): Unit = { util.addTable(s"""