diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala index 526ff843a1496..178e028006d96 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala @@ -518,6 +518,19 @@ case class MergeRowsExec( private val notMatchedBySourceInstructions: Seq[InstructionExec]) extends Iterator[InternalRow] { + // Resolve each metric at most once per partition, on first use; longMetric(name) is a map + // lookup. See SPARK-56933. + private lazy val numTargetRowsCopied = longMetric("numTargetRowsCopied") + private lazy val numTargetRowsInserted = longMetric("numTargetRowsInserted") + private lazy val numTargetRowsDeleted = longMetric("numTargetRowsDeleted") + private lazy val numTargetRowsUpdated = longMetric("numTargetRowsUpdated") + private lazy val numTargetRowsMatchedUpdated = longMetric("numTargetRowsMatchedUpdated") + private lazy val numTargetRowsMatchedDeleted = longMetric("numTargetRowsMatchedDeleted") + private lazy val numTargetRowsNotMatchedBySourceUpdated = + longMetric("numTargetRowsNotMatchedBySourceUpdated") + private lazy val numTargetRowsNotMatchedBySourceDeleted = + longMetric("numTargetRowsNotMatchedBySourceDeleted") + var cachedExtraRow: InternalRow = _ override def hasNext: Boolean = cachedExtraRow != null || rowIterator.hasNext @@ -579,28 +592,27 @@ case class MergeRowsExec( null } - } - // For group based merge, copy is inserted if row matches no other case - private def incrementCopyMetric(): Unit = longMetric("numTargetRowsCopied") += 1 + private def incrementCopyMetric(): Unit = numTargetRowsCopied += 1 - private def incrementInsertMetric(): Unit = longMetric("numTargetRowsInserted") += 1 + private def incrementInsertMetric(): Unit = numTargetRowsInserted += 1 - private def incrementDeleteMetric(sourcePresent: Boolean): Unit = { - longMetric("numTargetRowsDeleted") += 1 - if (sourcePresent) { - longMetric("numTargetRowsMatchedDeleted") += 1 - } else { - longMetric("numTargetRowsNotMatchedBySourceDeleted") += 1 + private def incrementDeleteMetric(sourcePresent: Boolean): Unit = { + numTargetRowsDeleted += 1 + if (sourcePresent) { + numTargetRowsMatchedDeleted += 1 + } else { + numTargetRowsNotMatchedBySourceDeleted += 1 + } } - } - private def incrementUpdateMetric(sourcePresent: Boolean): Unit = { - longMetric("numTargetRowsUpdated") += 1 - if (sourcePresent) { - longMetric("numTargetRowsMatchedUpdated") += 1 - } else { - longMetric("numTargetRowsNotMatchedBySourceUpdated") += 1 + private def incrementUpdateMetric(sourcePresent: Boolean): Unit = { + numTargetRowsUpdated += 1 + if (sourcePresent) { + numTargetRowsMatchedUpdated += 1 + } else { + numTargetRowsNotMatchedBySourceUpdated += 1 + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MergeRowsExecBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MergeRowsExecBenchmark.scala index 8ddbca46b7396..0fcac326d923d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MergeRowsExecBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MergeRowsExecBenchmark.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.benchmark +import scala.concurrent.duration._ + import org.apache.spark.sql.DataFrame import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GreaterThan, IsNotNull, Literal} import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral @@ -43,6 +45,18 @@ object MergeRowsExecBenchmark extends SqlBasedBenchmark with ClassicConversions private val N = 20 << 20 + /** Longer warm-up and timed window for stable interpreted (whole-stage off) results. */ + private def mergeRowsBenchmark(name: String, cardinality: Long)(f: => Unit): Unit = { + codegenBenchmark( + name, + cardinality, + warmupTime = 7.seconds, + minTime = 7.seconds, + minNumIters = 3, + wholestageOffNumIters = 0, + wholestageOnNumIters = 0)(f) + } + /** * Creates a DataFrame simulating the join output from a MERGE operation. * @@ -110,7 +124,7 @@ object MergeRowsExecBenchmark extends SqlBasedBenchmark with ClassicConversions a(0), a(5), a(6), a(3) ))) - codegenBenchmark("merge - matched update only", N) { + mergeRowsBenchmark("merge - matched update only", N) { val df = buildMergeRowsDF(inputDF, matchedInstr) assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[MergeRowsExec])) df.noop() @@ -126,7 +140,7 @@ object MergeRowsExecBenchmark extends SqlBasedBenchmark with ClassicConversions a(4), a(5), a(6), a(7) ))) - codegenBenchmark("merge - not matched insert only", N) { + mergeRowsBenchmark("merge - not matched insert only", N) { val df = buildMergeRowsDF(inputDF, Seq.empty, notMatchedInstr) assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[MergeRowsExec])) df.noop() @@ -144,7 +158,7 @@ object MergeRowsExecBenchmark extends SqlBasedBenchmark with ClassicConversions a(4), a(5), a(6), a(7) ))) - codegenBenchmark("merge - matched update + not matched insert", N) { + mergeRowsBenchmark("merge - matched update + not matched insert", N) { val df = buildMergeRowsDF(inputDF, matchedInstr, notMatchedInstr) assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[MergeRowsExec])) df.noop() @@ -156,7 +170,7 @@ object MergeRowsExecBenchmark extends SqlBasedBenchmark with ClassicConversions val matchedInstr = Seq(Discard(TrueLiteral)) - codegenBenchmark("merge - matched delete", N) { + mergeRowsBenchmark("merge - matched delete", N) { val df = buildMergeRowsDF(inputDF, matchedInstr) assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[MergeRowsExec])) df.noop() @@ -177,7 +191,7 @@ object MergeRowsExecBenchmark extends SqlBasedBenchmark with ClassicConversions Keep(Insert, GreaterThan(a(5), Literal(500)), Seq(a(4), a(5), a(6), a(7))) ) - codegenBenchmark("merge - conditional clauses", N) { + mergeRowsBenchmark("merge - conditional clauses", N) { val df = buildMergeRowsDF(inputDF, matchedInstr, notMatchedInstr) assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[MergeRowsExec])) df.noop() @@ -199,7 +213,7 @@ object MergeRowsExecBenchmark extends SqlBasedBenchmark with ClassicConversions ))) val notMatchedBySourceInstr = Seq(Discard(TrueLiteral)) - codegenBenchmark("merge - matched + not matched + not matched by source", N) { + mergeRowsBenchmark("merge - matched + not matched + not matched by source", N) { val df = buildMergeRowsDF(inputDF, matchedInstr, notMatchedInstr, notMatchedBySourceInstr) assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[MergeRowsExec])) df.noop() @@ -216,7 +230,7 @@ object MergeRowsExecBenchmark extends SqlBasedBenchmark with ClassicConversions Seq(a(0), a(5), a(6), a(3)) )) - codegenBenchmark("merge - split update (delete + insert)", N) { + mergeRowsBenchmark("merge - split update (delete + insert)", N) { val df = buildMergeRowsDF(inputDF, matchedInstr) assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[MergeRowsExec])) df.noop() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SqlBasedBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SqlBasedBenchmark.scala index 78d6b01580355..6c60721599bbb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SqlBasedBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SqlBasedBenchmark.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.benchmark +import scala.concurrent.duration._ + import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} import org.apache.spark.internal.config.MAX_RESULT_SIZE import org.apache.spark.internal.config.UI.UI_ENABLED @@ -46,17 +48,42 @@ trait SqlBasedBenchmark extends BenchmarkBase with SQLHelper { .getOrCreate() } - /** Runs function `f` with whole stage codegen on and off. */ - final def codegenBenchmark(name: String, cardinality: Long)(f: => Unit): Unit = { - val benchmark = new Benchmark(name, cardinality, output = output) + /** + * Runs function `f` with whole stage codegen on and off. + * + * @param minNumIters minimum timed iterations per case when the corresponding + * `wholestageOffNumIters` or `wholestageOnNumIters` is zero. + * @param warmupTime JIT warm-up duration per case before timed iterations. + * @param minTime minimum total timed duration per case when the corresponding + * `wholestageOffNumIters` or `wholestageOnNumIters` is zero. + * @param wholestageOffNumIters if non-zero, run exactly this many timed iterations + * for the wholestage-off case; otherwise use `minNumIters` and `minTime`. + * @param wholestageOnNumIters if non-zero, run exactly this many timed iterations + * for the wholestage-on case; otherwise use `minNumIters` and `minTime`. + */ + final def codegenBenchmark( + name: String, + cardinality: Long, + minNumIters: Int = 2, + warmupTime: FiniteDuration = 2.seconds, + minTime: FiniteDuration = 2.seconds, + wholestageOffNumIters: Int = 2, + wholestageOnNumIters: Int = 5)(f: => Unit): Unit = { + val benchmark = new Benchmark( + name, + cardinality, + minNumIters = minNumIters, + warmupTime = warmupTime, + minTime = minTime, + output = output) - benchmark.addCase(s"$name wholestage off", numIters = 2) { _ => + benchmark.addCase(s"$name wholestage off", numIters = wholestageOffNumIters) { _ => withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { f } } - benchmark.addCase(s"$name wholestage on", numIters = 5) { _ => + benchmark.addCase(s"$name wholestage on", numIters = wholestageOnNumIters) { _ => withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true") { f }