Skip to content
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,19 @@ case class MergeRowsExec(
private val notMatchedBySourceInstructions: Seq[InstructionExec])
extends Iterator[InternalRow] {

// Hold each metric in a lazy val so that it is resolved at most once per partition, on first
// use, instead of on every increment; longMetric(name) is a map lookup. A metric that is never
// incremented is never resolved. See SPARK-56933.
private lazy val numTargetRowsCopied = longMetric("numTargetRowsCopied")
private lazy val numTargetRowsInserted = longMetric("numTargetRowsInserted")
private lazy val numTargetRowsDeleted = longMetric("numTargetRowsDeleted")
private lazy val numTargetRowsUpdated = longMetric("numTargetRowsUpdated")
private lazy val numTargetRowsMatchedUpdated = longMetric("numTargetRowsMatchedUpdated")
private lazy val numTargetRowsMatchedDeleted = longMetric("numTargetRowsMatchedDeleted")
private lazy val numTargetRowsNotMatchedBySourceUpdated =
longMetric("numTargetRowsNotMatchedBySourceUpdated")
private lazy val numTargetRowsNotMatchedBySourceDeleted =
longMetric("numTargetRowsNotMatchedBySourceDeleted")

var cachedExtraRow: InternalRow = _

// A row is available when cachedExtraRow is set (non-null) or the underlying rowIterator
// still has input; the cached row is checked first.
override def hasNext: Boolean = cachedExtraRow != null || rowIterator.hasNext
Expand Down Expand Up @@ -579,28 +592,27 @@ case class MergeRowsExec(

null
}
}

// For group based merge, copy is inserted if row matches no other case
private def incrementCopyMetric(): Unit = longMetric("numTargetRowsCopied") += 1
private def incrementCopyMetric(): Unit = numTargetRowsCopied += 1
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm now wondering if these += 1 can also be optimized further. Doing += on a SQLMetric is technically a function call to SQLMetric.add, which does some more work in addition to simply adding the value. Maybe the JVM is smart enough to optimize it; I'm not sure. Do you think it's worth exploring?

We could replace the new private lazy val SQLMetric fields above with simple integers, all initially set to 0, and increment only those. Then, at the end of applyInstructions, we can look up and increment all metrics for which the value is > 0.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's an interesting idea. Yeah, I can play around with it as well. How about a separate PR, so that we can checkpoint the progress?


private def incrementInsertMetric(): Unit = longMetric("numTargetRowsInserted") += 1
private def incrementInsertMetric(): Unit = numTargetRowsInserted += 1

private def incrementDeleteMetric(sourcePresent: Boolean): Unit = {
longMetric("numTargetRowsDeleted") += 1
if (sourcePresent) {
longMetric("numTargetRowsMatchedDeleted") += 1
} else {
longMetric("numTargetRowsNotMatchedBySourceDeleted") += 1
private def incrementDeleteMetric(sourcePresent: Boolean): Unit = {
numTargetRowsDeleted += 1
if (sourcePresent) {
numTargetRowsMatchedDeleted += 1
} else {
numTargetRowsNotMatchedBySourceDeleted += 1
}
}
}

private def incrementUpdateMetric(sourcePresent: Boolean): Unit = {
longMetric("numTargetRowsUpdated") += 1
if (sourcePresent) {
longMetric("numTargetRowsMatchedUpdated") += 1
} else {
longMetric("numTargetRowsNotMatchedBySourceUpdated") += 1
private def incrementUpdateMetric(sourcePresent: Boolean): Unit = {
numTargetRowsUpdated += 1
if (sourcePresent) {
numTargetRowsMatchedUpdated += 1
} else {
numTargetRowsNotMatchedBySourceUpdated += 1
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.execution.benchmark

import scala.concurrent.duration._

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GreaterThan, IsNotNull, Literal}
import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
Expand All @@ -43,6 +45,18 @@ object MergeRowsExecBenchmark extends SqlBasedBenchmark with ClassicConversions

private val N = 20 << 20

/**
 * Runs `f` through `codegenBenchmark` with a longer warm-up and timed window than the defaults,
 * for more stable results.
 *
 * Both fixed iteration counts are passed as zero, so each case is timed using `minNumIters`
 * and `minTime` rather than an exact number of iterations.
 *
 * @param name benchmark name, forwarded to `codegenBenchmark`
 * @param cardinality number of rows processed per iteration, forwarded to `codegenBenchmark`
 * @param f the workload to measure
 */
private def mergeRowsBenchmark(name: String, cardinality: Long)(f: => Unit): Unit = {
  // The same 7-second duration is used for the warm-up and for the minimum timed window.
  val window = 7.seconds
  codegenBenchmark(
    name,
    cardinality,
    minNumIters = 3,
    warmupTime = window,
    minTime = window,
    wholestageOffNumIters = 0,
    wholestageOnNumIters = 0)(f)
}

/**
* Creates a DataFrame simulating the join output from a MERGE operation.
*
Expand Down Expand Up @@ -110,7 +124,7 @@ object MergeRowsExecBenchmark extends SqlBasedBenchmark with ClassicConversions
a(0), a(5), a(6), a(3)
)))

codegenBenchmark("merge - matched update only", N) {
mergeRowsBenchmark("merge - matched update only", N) {
val df = buildMergeRowsDF(inputDF, matchedInstr)
assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[MergeRowsExec]))
df.noop()
Expand All @@ -126,7 +140,7 @@ object MergeRowsExecBenchmark extends SqlBasedBenchmark with ClassicConversions
a(4), a(5), a(6), a(7)
)))

codegenBenchmark("merge - not matched insert only", N) {
mergeRowsBenchmark("merge - not matched insert only", N) {
val df = buildMergeRowsDF(inputDF, Seq.empty, notMatchedInstr)
assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[MergeRowsExec]))
df.noop()
Expand All @@ -144,7 +158,7 @@ object MergeRowsExecBenchmark extends SqlBasedBenchmark with ClassicConversions
a(4), a(5), a(6), a(7)
)))

codegenBenchmark("merge - matched update + not matched insert", N) {
mergeRowsBenchmark("merge - matched update + not matched insert", N) {
val df = buildMergeRowsDF(inputDF, matchedInstr, notMatchedInstr)
assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[MergeRowsExec]))
df.noop()
Expand All @@ -156,7 +170,7 @@ object MergeRowsExecBenchmark extends SqlBasedBenchmark with ClassicConversions

val matchedInstr = Seq(Discard(TrueLiteral))

codegenBenchmark("merge - matched delete", N) {
mergeRowsBenchmark("merge - matched delete", N) {
val df = buildMergeRowsDF(inputDF, matchedInstr)
assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[MergeRowsExec]))
df.noop()
Expand All @@ -177,7 +191,7 @@ object MergeRowsExecBenchmark extends SqlBasedBenchmark with ClassicConversions
Keep(Insert, GreaterThan(a(5), Literal(500)), Seq(a(4), a(5), a(6), a(7)))
)

codegenBenchmark("merge - conditional clauses", N) {
mergeRowsBenchmark("merge - conditional clauses", N) {
val df = buildMergeRowsDF(inputDF, matchedInstr, notMatchedInstr)
assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[MergeRowsExec]))
df.noop()
Expand All @@ -199,7 +213,7 @@ object MergeRowsExecBenchmark extends SqlBasedBenchmark with ClassicConversions
)))
val notMatchedBySourceInstr = Seq(Discard(TrueLiteral))

codegenBenchmark("merge - matched + not matched + not matched by source", N) {
mergeRowsBenchmark("merge - matched + not matched + not matched by source", N) {
val df = buildMergeRowsDF(inputDF, matchedInstr, notMatchedInstr, notMatchedBySourceInstr)
assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[MergeRowsExec]))
df.noop()
Expand All @@ -216,7 +230,7 @@ object MergeRowsExecBenchmark extends SqlBasedBenchmark with ClassicConversions
Seq(a(0), a(5), a(6), a(3))
))

codegenBenchmark("merge - split update (delete + insert)", N) {
mergeRowsBenchmark("merge - split update (delete + insert)", N) {
val df = buildMergeRowsDF(inputDF, matchedInstr)
assert(df.queryExecution.sparkPlan.exists(_.isInstanceOf[MergeRowsExec]))
df.noop()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.execution.benchmark

import scala.concurrent.duration._

import org.apache.spark.benchmark.{Benchmark, BenchmarkBase}
import org.apache.spark.internal.config.MAX_RESULT_SIZE
import org.apache.spark.internal.config.UI.UI_ENABLED
Expand Down Expand Up @@ -46,17 +48,42 @@ trait SqlBasedBenchmark extends BenchmarkBase with SQLHelper {
.getOrCreate()
}

/** Runs function `f` with whole stage codegen on and off. */
final def codegenBenchmark(name: String, cardinality: Long)(f: => Unit): Unit = {
val benchmark = new Benchmark(name, cardinality, output = output)
/**
* Runs function `f` with whole stage codegen on and off.
*
* @param minNumIters minimum timed iterations per case when the corresponding
* `wholestageOffNumIters` or `wholestageOnNumIters` is zero.
* @param warmupTime JIT warm-up duration per case before timed iterations.
* @param minTime minimum total timed duration per case when the corresponding
* `wholestageOffNumIters` or `wholestageOnNumIters` is zero.
* @param wholestageOffNumIters if non-zero, run exactly this many timed iterations
* for the wholestage-off case; otherwise use `minNumIters` and `minTime`.
* @param wholestageOnNumIters if non-zero, run exactly this many timed iterations
* for the wholestage-on case; otherwise use `minNumIters` and `minTime`.
*/
final def codegenBenchmark(
name: String,
cardinality: Long,
minNumIters: Int = 2,
warmupTime: FiniteDuration = 2.seconds,
minTime: FiniteDuration = 2.seconds,
wholestageOffNumIters: Int = 2,
wholestageOnNumIters: Int = 5)(f: => Unit): Unit = {
val benchmark = new Benchmark(
name,
cardinality,
minNumIters = minNumIters,
warmupTime = warmupTime,
minTime = minTime,
output = output)

benchmark.addCase(s"$name wholestage off", numIters = 2) { _ =>
benchmark.addCase(s"$name wholestage off", numIters = wholestageOffNumIters) { _ =>
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
f
}
}

benchmark.addCase(s"$name wholestage on", numIters = 5) { _ =>
benchmark.addCase(s"$name wholestage on", numIters = wholestageOnNumIters) { _ =>
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true") {
f
}
Expand Down