From 4d22554b591884950b65b35037724ea54f9e6c1e Mon Sep 17 00:00:00 2001 From: Tobias Kaymak Date: Fri, 17 Apr 2026 16:01:37 +0200 Subject: [PATCH 01/14] build: add Spark 4.0.2 version property and Scala 2.13 support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add spark4_version (4.0.2) to BeamModulePlugin alongside the existing spark3_version. Update spark_runner.gradle to conditionally select the correct Scala library (2.13 vs 2.12), Jackson module, Kafka test dependency, and require Java 17 when building against Spark 4. Register the new :runners:spark:4 module in settings.gradle.kts. These changes are purely additive — all conditionals gate on spark_version.startsWith("4") or spark_scala_version == '2.13', leaving the Spark 3 build path untouched. Co-Authored-By: Claude Sonnet 4.6 --- .../apache/beam/gradle/BeamModulePlugin.groovy | 3 +++ runners/spark/spark_runner.gradle | 18 +++++++++++++----- settings.gradle.kts | 1 + 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy index aed7418b5ff9..5e15c386e0c5 100644 --- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy +++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy @@ -649,6 +649,7 @@ class BeamModulePlugin implements Plugin { def solace_version = "10.21.0" def spark2_version = "2.4.8" def spark3_version = "3.5.0" + def spark4_version = "4.0.2" def spotbugs_version = "4.8.3" def testcontainers_version = "1.21.4" // [bomupgrader] determined by: org.apache.arrow:arrow-memory-core, consistent with: google_cloud_platform_libraries_bom @@ -658,6 +659,7 @@ class BeamModulePlugin implements Plugin { // Export Spark versions, so they are defined in a single place only project.ext.spark3_version = spark3_version + project.ext.spark4_version = spark4_version // version for BigQueryMetastore catalog (used by sdks:java:io:iceberg:bqms) // TODO: remove this and download the jar normally when the catalog gets // open-sourced (https://github.com/apache/iceberg/pull/11039) @@ -820,6 +822,7 @@ class BeamModulePlugin implements Plugin { jackson_datatype_jsr310 : "com.fasterxml.jackson.datatype:jackson-datatype-jsr310:$jackson_version", jackson_module_scala_2_11 : "com.fasterxml.jackson.module:jackson-module-scala_2.11:$jackson_version", jackson_module_scala_2_12 : "com.fasterxml.jackson.module:jackson-module-scala_2.12:$jackson_version", + jackson_module_scala_2_13 : "com.fasterxml.jackson.module:jackson-module-scala_2.13:$jackson_version", jamm : 'com.github.jbellis:jamm:0.4.0', jaxb_api : "jakarta.xml.bind:jakarta.xml.bind-api:$jaxb_api_version", jaxb_impl : "com.sun.xml.bind:jaxb-impl:$jaxb_api_version", diff --git a/runners/spark/spark_runner.gradle b/runners/spark/spark_runner.gradle index ecdfc8f0f697..a0e118412319 100644 --- a/runners/spark/spark_runner.gradle +++ b/runners/spark/spark_runner.gradle @@ -21,6 +21,7 @@ import groovy.json.JsonOutput apply plugin: 'org.apache.beam.module' applyJavaNature( enableStrictDependencies: true, + requireJavaVersion: (spark_version.startsWith("4") ? org.gradle.api.JavaVersion.VERSION_17 : null), automaticModuleName: 'org.apache.beam.runners.spark', archivesBaseName: (project.hasProperty('archives_base_name') ? archives_base_name : archivesBaseName), exportJavadoc: (project.hasProperty('exportJavadoc') ? 
exportJavadoc : true), @@ -182,10 +183,15 @@ dependencies { } permitUnusedDeclared "org.apache.spark:spark-network-common_$spark_scala_version:$spark_version" implementation "io.dropwizard.metrics:metrics-core:4.1.1" // version used by Spark 3.1 - compileOnly "org.scala-lang:scala-library:2.12.15" - runtimeOnly library.java.jackson_module_scala_2_12 - // Force paranamer 2.8 to avoid issues when using Scala 2.12 - runtimeOnly "com.thoughtworks.paranamer:paranamer:2.8" + if (spark_scala_version == '2.13') { + compileOnly "org.scala-lang:scala-library:2.13.15" + runtimeOnly library.java.jackson_module_scala_2_13 + } else { + compileOnly "org.scala-lang:scala-library:2.12.15" + runtimeOnly library.java.jackson_module_scala_2_12 + // Force paranamer 2.8 to avoid issues when using Scala 2.12 + runtimeOnly "com.thoughtworks.paranamer:paranamer:2.8" + } provided "org.apache.hadoop:hadoop-client-api:3.3.1" provided library.java.commons_io provided library.java.hamcrest @@ -200,7 +206,9 @@ dependencies { testImplementation project(path: ":sdks:java:extensions:avro", configuration: "testRuntimeMigration") testImplementation project(":sdks:java:harness") testImplementation library.java.avro - testImplementation "org.apache.kafka:kafka_$spark_scala_version:2.4.1" + // kafka_2.13 artifacts were first published in 2.5.0; use a later version for Scala 2.13 + def kafka_version = (spark_scala_version == '2.13') ? '2.8.0' : '2.4.1' + testImplementation "org.apache.kafka:kafka_$spark_scala_version:$kafka_version" testImplementation library.java.kafka_clients testImplementation library.java.junit testImplementation library.java.mockito_core diff --git a/settings.gradle.kts b/settings.gradle.kts index c001a1add446..66c99a2c796c 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -150,6 +150,7 @@ include(":runners:prism:java") include(":runners:spark:3") include(":runners:spark:3:job-server") include(":runners:spark:3:job-server:container") +include(":runners:spark:4") include(":runners:samza") include(":runners:samza:job-server") include(":sdks:go") From fb95a668c701bc0fe7507ac221246fbbb1ca6d65 Mon Sep 17 00:00:00 2001 From: Tobias Kaymak Date: Fri, 17 Apr 2026 16:01:37 +0200 Subject: [PATCH 02/14] refactor: make shared Spark source compatible with Scala 2.12 and 2.13 Co-Authored-By: Claude Opus 4.6 --- .../spark/coders/SparkRunnerKryoRegistrator.java | 12 ++++++++++-- .../org/apache/beam/runners/spark/io/SourceRDD.java | 10 +++++----- .../beam/runners/spark/io/SparkUnboundedSource.java | 2 +- .../stateful/SparkGroupAlsoByWindowViaWindowSet.java | 10 ++++++---- .../SparkStreamingPortablePipelineTranslator.java | 3 ++- .../translation/streaming/ParDoStateUpdateFn.java | 4 ++-- .../streaming/StreamingTransformTranslator.java | 2 +- 7 files changed, 27 insertions(+), 16 deletions(-) diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/SparkRunnerKryoRegistrator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/SparkRunnerKryoRegistrator.java index 68c602ff7f59..ba8c0812c9e5 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/SparkRunnerKryoRegistrator.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/SparkRunnerKryoRegistrator.java @@ -30,7 +30,6 @@ import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.HashBasedTable; import org.apache.spark.serializer.KryoRegistrator; -import scala.collection.mutable.WrappedArray; /** * Custom {@link 
KryoRegistrator}s for Beam's Spark runner needs and registering used class in spark @@ -61,7 +60,16 @@ public void registerClasses(Kryo kryo) { kryo.register(PaneInfo.class); kryo.register(StateAndTimers.class); kryo.register(TupleTag.class); - kryo.register(WrappedArray.ofRef.class); + // Scala 2.12 uses WrappedArray$ofRef, Scala 2.13 renamed it to ArraySeq$ofRef + try { + kryo.register(Class.forName("scala.collection.mutable.ArraySeq$ofRef")); + } catch (ClassNotFoundException e) { + try { + kryo.register(Class.forName("scala.collection.mutable.WrappedArray$ofRef")); + } catch (ClassNotFoundException ignored) { + // Neither class found; skip registration + } + } try { kryo.register( diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceRDD.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceRDD.java index e65dccd23f24..56a2219933b6 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceRDD.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceRDD.java @@ -50,7 +50,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import scala.Option; -import scala.collection.JavaConversions; +import scala.collection.JavaConverters; /** Classes implementing Beam {@link Source} {@link RDD}s. */ @SuppressWarnings({ @@ -75,7 +75,7 @@ public static class Bounded extends RDD> { // to satisfy Scala API. private static final scala.collection.immutable.Seq> NIL = - JavaConversions.asScalaBuffer(Collections.>emptyList()).toList(); + JavaConverters.asScalaBuffer(Collections.>emptyList()).toList(); public Bounded( SparkContext sc, @@ -148,7 +148,7 @@ public scala.collection.Iterator> compute( final Iterator> readerIterator = new ReaderToIteratorAdapter<>(metricsContainer, reader); - return new InterruptibleIterator<>(context, JavaConversions.asScalaIterator(readerIterator)); + return new InterruptibleIterator<>(context, JavaConverters.asScalaIterator(readerIterator)); } /** @@ -299,7 +299,7 @@ public static class Unbounded> NIL = - JavaConversions.asScalaBuffer(Collections.>emptyList()).toList(); + JavaConverters.asScalaBuffer(Collections.>emptyList()).toList(); public Unbounded( SparkContext sc, @@ -344,7 +344,7 @@ public scala.collection.Iterator, CheckpointMarkT>> compu (CheckpointableSourcePartition) split; scala.Tuple2, CheckpointMarkT> tuple2 = new scala.Tuple2<>(partition.getSource(), partition.checkpointMark); - return JavaConversions.asScalaIterator(Collections.singleton(tuple2).iterator()); + return JavaConverters.asScalaIterator(Collections.singleton(tuple2).iterator()); } } diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java index bea1557a7103..3f1fb103e47c 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java @@ -186,7 +186,7 @@ public Duration slideDuration() { @Override public scala.collection.immutable.List> dependencies() { - return scala.collection.JavaConversions.asScalaBuffer( + return scala.collection.JavaConverters.asScalaBuffer( Collections.>singletonList(parent)) .toList(); } diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java index 
2c54f90badbe..be69ee78e51c 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java @@ -73,7 +73,7 @@ import scala.Tuple2; import scala.Tuple3; import scala.collection.Iterator; -import scala.collection.JavaConversions; +import scala.collection.JavaConverters; import scala.collection.Seq; import scala.runtime.AbstractFunction1; @@ -238,7 +238,7 @@ private Collection filterTimersEligibleForProcessing( // new input for key. try { final Iterable> elements = - FluentIterable.from(JavaConversions.asJavaIterable(encodedElements)) + FluentIterable.from(JavaConverters.asJavaIterable(encodedElements)) .transform(bytes -> CoderHelpers.fromByteArray(bytes, wvCoder)); LOG.trace("{}: input elements: {}", logPrefix, elements); @@ -410,7 +410,7 @@ private Collection filterTimersEligibleForProcessing( droppedDueToClosedWindow.inc(-droppedDueToClosedWindow.getCumulative()); } - return scala.collection.JavaConversions.asScalaIterator( + return JavaConverters.asScalaIterator( new UpdateStateByKeyOutputIterator(input, reduceFn, droppedDueToLateness)); } } @@ -522,7 +522,9 @@ JavaDStream>>> groupByKeyAndWindow( Tuple2>>*/ List>>> firedStream = pairDStream.updateStateByKey( - updateFunc, + // Raw cast to AbstractFunction1 suppresses Scala 2.12 (collection.Seq) vs + // Scala 2.13 (immutable.Seq) type difference — safe at runtime due to erasure. + (scala.runtime.AbstractFunction1) updateFunc, pairDStream.defaultPartitioner(pairDStream.defaultPartitioner$default$1()), true, JavaSparkContext$.MODULE$.fakeClassTag()); diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkStreamingPortablePipelineTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkStreamingPortablePipelineTranslator.java index 4850f886241b..9975c81b56a4 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkStreamingPortablePipelineTranslator.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkStreamingPortablePipelineTranslator.java @@ -330,7 +330,8 @@ private static void translateFlatten( } } // Unify streams into a single stream. 
- unifiedStreams = context.getStreamingContext().union(JavaConverters.asScalaBuffer(dStreams)); + unifiedStreams = + context.getStreamingContext().union(JavaConverters.asScalaBuffer(dStreams).toList()); } context.pushDataset( diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/ParDoStateUpdateFn.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/ParDoStateUpdateFn.java index 909624c23239..ed9299db4ee7 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/ParDoStateUpdateFn.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/ParDoStateUpdateFn.java @@ -19,6 +19,7 @@ import java.io.Serializable; import java.util.Collection; +import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Map; @@ -62,7 +63,6 @@ import org.checkerframework.checker.nullness.qual.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.sparkproject.guava.collect.Iterators; import scala.Option; import scala.Tuple2; import scala.runtime.AbstractFunction3; @@ -236,7 +236,7 @@ public TimerInternals timerInternals() { final byte[] byteValue = serializedValue.get(); @Nullable WindowedValue windowedValue; @Nullable WindowedValue> keyedWindowedValue; - Iterator>> iterator = Iterators.emptyIterator(); + Iterator>> iterator = Collections.emptyIterator(); if (byteValue.length > 0) { windowedValue = CoderHelpers.fromByteArray(byteValue, this.wvCoder); keyedWindowedValue = windowedValue.withValue(KV.of(key, windowedValue.getValue())); diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java index 48697a3dbafc..4a96edceba31 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java @@ -306,7 +306,7 @@ public void evaluate(Flatten.PCollections transform, EvaluationContext contex } // start by unifying streams into a single stream. JavaDStream> unifiedStreams = - context.getStreamingContext().union(JavaConverters.asScalaBuffer(dStreams)); + context.getStreamingContext().union(JavaConverters.asScalaBuffer(dStreams).toList()); context.putDataset(transform, new UnboundedDataset<>(unifiedStreams, streamingSources)); } From 46e2dba21d0b1b5838b3c63113e73d1202732ea2 Mon Sep 17 00:00:00 2001 From: Tobias Kaymak Date: Fri, 17 Apr 2026 16:01:37 +0200 Subject: [PATCH 03/14] build: add runners/spark/4/ build configuration MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add the Gradle build file for the Spark 4 structured streaming runner. The module mirrors runners/spark/3/ — it inherits the shared RDD-base source from runners/spark/src/ via copySourceBase and adds its own Structured Streaming implementation in src/main/java. Key differences from the Spark 3 build: - Uses spark4_version (4.0.2) with Scala 2.13. - Excludes DStream-based streaming tests (Spark 4 supports only structured streaming batch). - Unconditionally adds --add-opens JVM flags required by Kryo on Java 17 (Spark 4's minimum). - Binds Spark driver to 127.0.0.1 for macOS compatibility. 
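As a minimal sketch of what the driver-binding bullet above amounts to (the class name and app name here are invented for illustration and are not part of this patch): pinning spark.driver.host and spark.driver.bindAddress to 127.0.0.1 means a local session starts regardless of whether the machine's hostname resolves to a bindable address.

import org.apache.spark.sql.SparkSession;

/** Illustrative sketch only: what the spark.driver.* test settings amount to. */
public class LocalDriverBindingSketch {
  public static void main(String[] args) {
    SparkSession session =
        SparkSession.builder()
            .master("local[2]")
            .appName("spark4-driver-binding-sketch") // invented name, for illustration only
            .config("spark.driver.host", "127.0.0.1")
            .config("spark.driver.bindAddress", "127.0.0.1")
            .getOrCreate();
    try {
      // With the driver pinned to 127.0.0.1, startup does not depend on the local
      // hostname (e.g. mac.lan) resolving to a bindable address.
      System.out.println("Started Spark " + session.version() + " bound to 127.0.0.1");
    } finally {
      session.stop();
    }
  }
}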
Co-Authored-By: Claude Sonnet 4.6 --- runners/spark/4/build.gradle | 89 ++++++++++++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) create mode 100644 runners/spark/4/build.gradle diff --git a/runners/spark/4/build.gradle b/runners/spark/4/build.gradle new file mode 100644 index 000000000000..908606218b85 --- /dev/null +++ b/runners/spark/4/build.gradle @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +def basePath = '..' +/* All properties required for loading the Spark build script */ +project.ext { + // Spark 4 version as defined in BeamModulePlugin; requires Scala 2.13 and Java 17 + spark_version = spark4_version + spark_scala_version = '2.13' + copySourceBase = true // copy shared base into build dir; Spark 3 remains the primary dev version + archives_base_name = 'beam-runners-spark-4' +} + +// Load the main build script which contains all build logic. +apply from: "$basePath/spark_runner.gradle" + +// Force Spark to bind to 127.0.0.1 so tests pass on machines where the hostname +// doesn't resolve to a bindable address (e.g. mac.lan in macOS VPN environments). +// Spark 4 always requires Java 17, so unconditionally add the --add-opens flags +// required by Kryo and other libraries that use reflection on JDK internals. +test { + systemProperty "spark.driver.host", "127.0.0.1" + systemProperty "spark.driver.bindAddress", "127.0.0.1" + jvmArgs "--add-opens=java.base/sun.nio.ch=ALL-UNNAMED", + "--add-opens=java.base/java.nio=ALL-UNNAMED", + "--add-opens=java.base/java.util=ALL-UNNAMED", + "--add-opens=java.base/java.lang.invoke=ALL-UNNAMED" +} + +// Add Spark 4 structured streaming source on top of the copied shared base. +sourceSets.main.java.srcDirs += "src/main/java" +sourceSets.test.java.srcDirs += "src/test/java" + +// Exclude DStream-based streaming tests from the shared base copy: the Spark 4 module +// supports only structured streaming (batch) and does not include legacy DStream support. +// Streaming test utilities also depend on kafka.server.KafkaServerStartable which was +// removed in Kafka 2.8.0 (the first Kafka version with a _2.13 artifact). +tasks.named("copyTestJava") { + exclude "**/translation/streaming/**" +} + +// Additional supported Spark 4.x versions for compatibility tests. +// Can be expanded as new patch releases are published. 
+def sparkVersions = [ + // "402": "4.0.2", // primary version; already tested via the default build +] + +sparkVersions.each { kv -> + configurations.create("sparkVersion$kv.key") + configurations."sparkVersion$kv.key" { + resolutionStrategy { + spark.components.each { component -> force "$component:$kv.value" } + } + } + + dependencies { + spark.components.each { component -> "sparkVersion$kv.key" "$component:$kv.value" } + } + + tasks.register("sparkVersion${kv.key}Test", Test) { + group = "Verification" + description = "Verifies code compatibility with Spark $kv.value" + classpath = configurations."sparkVersion$kv.key" + sourceSets.test.runtimeClasspath + systemProperties test.systemProperties + + include "**/*.class" + maxParallelForks 4 + } +} + +tasks.register("sparkVersionsTest") { + group = "Verification" + dependsOn sparkVersions.collect{k,v -> "sparkVersion${k}Test"} +} From 83451d792b451ac1a2c60bcb890b71388d4d8c19 Mon Sep 17 00:00:00 2001 From: Tobias Kaymak Date: Fri, 17 Apr 2026 16:01:37 +0200 Subject: [PATCH 04/14] feat: add Spark 4 structured streaming runner source MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add the Spark 4 structured streaming runner implementation and tests. Most files are adapted from the Spark 3 structured streaming runner with targeted changes for Spark 4 / Scala 2.13 API compatibility. Key Spark 4-specific changes (diff against runners/spark/3/src/): EncoderFactory — Replaced the direct ExpressionEncoder constructor (removed in Spark 4) with BeamAgnosticEncoder, a named class implementing both AgnosticExpressionPathEncoder (for expression delegation via toCatalyst/fromCatalyst) and AgnosticEncoders .StructEncoder (so Dataset.select(TypedColumn) creates an N-attribute plan, preventing FIELD_NUMBER_MISMATCH). The toCatalyst/fromCatalyst methods substitute the provided input expression via transformUp, enabling correct nesting inside composite encoders like Encoders.tuple(). EncoderHelpers — Added toExpressionEncoder() helper to handle Spark 4 built-in encoders that are AgnosticEncoder subclasses rather than ExpressionEncoder. GroupByKeyTranslatorBatch — Migrated from internal catalyst Expression API (CreateNamedStruct, Literal$) to public Column API (struct(), lit(), array()), as required by Spark 4. BoundedDatasetFactory — Use classic.Dataset$.MODULE$.ofRows() as Dataset moved to org.apache.spark.sql.classic in Spark 4. ScalaInterop — Replace WrappedArray.ofRef (removed in Scala 2.13) with JavaConverters.asScalaBuffer().toList() in seqOf(). GroupByKeyHelpers, CombinePerKeyTranslatorBatch — Replace TraversableOnce with IterableOnce (Scala 2.13 rename). SparkStructuredStreamingPipelineResult — Replace sparkproject.guava with Beam's vendored Guava. 
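As a minimal sketch of the ScalaInterop conversion described above: JavaConverters is available on both Scala 2.12 and 2.13, and asScalaBuffer(...).toList() yields a scala.collection.immutable.List, which satisfies APIs that switched from collection.Seq (2.12) to immutable.Seq (2.13). The class below is invented for illustration and is not the actual seqOf() implementation:

import java.util.Arrays;
import java.util.List;
import scala.collection.JavaConverters;

/** Illustrative sketch of the Java-to-Scala sequence conversion pattern used in this series. */
public class ScalaSeqSketch {
  public static void main(String[] args) {
    List<String> javaList = Arrays.asList("a", "b", "c");
    // Works on both Scala 2.12 and 2.13 (JavaConverters is deprecated in 2.13 but still present);
    // toList() returns a scala.collection.immutable.List.
    scala.collection.immutable.List<String> scalaList =
        JavaConverters.asScalaBuffer(javaList).toList();
    System.out.println(scalaList.mkString(",")); // prints a,b,c
  }
}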
Co-Authored-By: Claude Sonnet 4.6 --- ...arkStructuredStreamingPipelineOptions.java | 42 ++ ...parkStructuredStreamingPipelineResult.java | 131 ++++ .../SparkStructuredStreamingRunner.java | 227 +++++++ ...arkStructuredStreamingRunnerRegistrar.java | 54 ++ .../examples/WordCount.java | 132 ++++ .../io/BoundedDatasetFactory.java | 330 ++++++++++ .../structuredstreaming/io/package-info.java | 20 + .../metrics/BeamMetricSet.java | 60 ++ .../metrics/MetricsAccumulator.java | 133 ++++ .../metrics/SparkBeamMetric.java | 113 ++++ .../metrics/SparkBeamMetricSource.java | 46 ++ .../metrics/WithMetricsSupport.java | 91 +++ .../metrics/package-info.java | 20 + .../metrics/sink/CodahaleCsvSink.java | 85 +++ .../metrics/sink/CodahaleGraphiteSink.java | 88 +++ .../metrics/sink/package-info.java | 20 + .../structuredstreaming/package-info.java | 20 + .../translation/EvaluationContext.java | 117 ++++ .../translation/PipelineTranslator.java | 516 +++++++++++++++ .../translation/SparkSessionFactory.java | 295 +++++++++ .../translation/SparkTransformOverrides.java | 56 ++ .../translation/TransformTranslator.java | 230 +++++++ .../translation/batch/Aggregators.java | 600 +++++++++++++++++ .../batch/CombineGloballyTranslatorBatch.java | 126 ++++ .../CombineGroupedValuesTranslatorBatch.java | 79 +++ .../batch/CombinePerKeyTranslatorBatch.java | 162 +++++ .../batch/DoFnPartitionIteratorFactory.java | 198 ++++++ .../translation/batch/DoFnRunnerFactory.java | 300 +++++++++ .../batch/DoFnRunnerWithMetrics.java | 116 ++++ .../batch/FlattenTranslatorBatch.java | 67 ++ .../translation/batch/GroupByKeyHelpers.java | 107 +++ .../batch/GroupByKeyTranslatorBatch.java | 298 +++++++++ .../batch/ImpulseTranslatorBatch.java | 47 ++ .../batch/ParDoTranslatorBatch.java | 274 ++++++++ .../batch/PipelineTranslatorBatch.java | 92 +++ .../batch/ReadSourceTranslatorBatch.java | 60 ++ .../batch/ReshuffleTranslatorBatch.java | 58 ++ .../batch/WindowAssignTranslatorBatch.java | 104 +++ .../functions/CachedSideInputReader.java | 178 +++++ .../GroupAlsoByWindowViaOutputBufferFn.java | 157 +++++ .../batch/functions/NoOpStepContext.java | 36 ++ .../batch/functions/SideInputValues.java | 189 ++++++ .../batch/functions/SparkSideInputReader.java | 147 +++++ .../batch/functions/package-info.java | 20 + .../translation/batch/package-info.java | 20 + .../translation/helpers/CoderHelpers.java | 59 ++ .../translation/helpers/EncoderFactory.java | 310 +++++++++ .../translation/helpers/EncoderHelpers.java | 610 ++++++++++++++++++ .../translation/helpers/EncoderProvider.java | 58 ++ .../translation/helpers/package-info.java | 20 + .../translation/package-info.java | 20 + .../translation/utils/ScalaInterop.java | 114 ++++ .../translation/utils/package-info.java | 20 + .../structuredstreaming/SparkSessionRule.java | 108 ++++ ...tructuredStreamingRunnerRegistrarTest.java | 70 ++ .../StructuredStreamingPipelineStateTest.java | 225 +++++++ .../metrics/sink/InMemoryMetrics.java | 84 +++ .../metrics/sink/InMemoryMetricsSinkRule.java | 28 + .../metrics/sink/SparkMetricsSinkTest.java | 73 +++ .../metrics/SparkBeamMetricTest.java | 59 ++ .../translation/batch/AggregatorsTest.java | 371 +++++++++++ .../batch/CombineGloballyTest.java | 149 +++++ .../batch/CombineGroupedValuesTest.java | 64 ++ .../translation/batch/CombinePerKeyTest.java | 174 +++++ .../translation/batch/ComplexSourceTest.java | 84 +++ .../translation/batch/FlattenTest.java | 55 ++ .../translation/batch/GroupByKeyTest.java | 205 ++++++ .../translation/batch/ParDoTest.java | 214 ++++++ 
.../translation/batch/SimpleSourceTest.java | 47 ++ .../translation/batch/WindowAssignTest.java | 63 ++ .../batch/functions/SideInputValuesTest.java | 130 ++++ .../helpers/EncoderHelpersTest.java | 298 +++++++++ .../streaming/SimpleSourceTest.java | 57 ++ 73 files changed, 10030 insertions(+) create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingPipelineOptions.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingPipelineResult.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingRunner.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingRunnerRegistrar.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/examples/WordCount.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/BoundedDatasetFactory.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/package-info.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/BeamMetricSet.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/MetricsAccumulator.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/SparkBeamMetric.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/SparkBeamMetricSource.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/WithMetricsSupport.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/package-info.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/sink/CodahaleCsvSink.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/sink/CodahaleGraphiteSink.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/sink/package-info.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/package-info.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/EvaluationContext.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/SparkSessionFactory.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/SparkTransformOverrides.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TransformTranslator.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/Aggregators.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGloballyTranslatorBatch.java create mode 100644 
runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGroupedValuesTranslatorBatch.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTranslatorBatch.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnPartitionIteratorFactory.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnRunnerFactory.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnRunnerWithMetrics.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/FlattenTranslatorBatch.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyHelpers.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ImpulseTranslatorBatch.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/PipelineTranslatorBatch.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ReadSourceTranslatorBatch.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ReshuffleTranslatorBatch.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/WindowAssignTranslatorBatch.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/functions/CachedSideInputReader.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/functions/GroupAlsoByWindowViaOutputBufferFn.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/functions/NoOpStepContext.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/functions/SideInputValues.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/functions/SparkSideInputReader.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/functions/package-info.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/package-info.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/CoderHelpers.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java create mode 100644 
runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderProvider.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/package-info.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/package-info.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/utils/ScalaInterop.java create mode 100644 runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/utils/package-info.java create mode 100644 runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/SparkSessionRule.java create mode 100644 runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingRunnerRegistrarTest.java create mode 100644 runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/StructuredStreamingPipelineStateTest.java create mode 100644 runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/aggregators/metrics/sink/InMemoryMetrics.java create mode 100644 runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/aggregators/metrics/sink/InMemoryMetricsSinkRule.java create mode 100644 runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/aggregators/metrics/sink/SparkMetricsSinkTest.java create mode 100644 runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/metrics/SparkBeamMetricTest.java create mode 100644 runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorsTest.java create mode 100644 runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGloballyTest.java create mode 100644 runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGroupedValuesTest.java create mode 100644 runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTest.java create mode 100644 runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ComplexSourceTest.java create mode 100644 runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/FlattenTest.java create mode 100644 runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTest.java create mode 100644 runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTest.java create mode 100644 runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/SimpleSourceTest.java create mode 100644 runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/WindowAssignTest.java create mode 100644 runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/functions/SideInputValuesTest.java create mode 100644 runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpersTest.java create mode 100644 runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/streaming/SimpleSourceTest.java diff --git 
a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingPipelineOptions.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingPipelineOptions.java new file mode 100644 index 000000000000..3371a403b2c9 --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingPipelineOptions.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming; + +import org.apache.beam.runners.spark.SparkCommonPipelineOptions; +import org.apache.beam.sdk.options.Default; +import org.apache.beam.sdk.options.Description; +import org.apache.beam.sdk.options.PipelineOptions; + +/** + * Spark runner {@link PipelineOptions} handles Spark execution-related configurations, such as the + * master address, and other user-related knobs. + */ +public interface SparkStructuredStreamingPipelineOptions extends SparkCommonPipelineOptions { + + /** Set to true to run the job in test mode. */ + @Default.Boolean(false) + boolean getTestMode(); + + void setTestMode(boolean testMode); + + @Description("Enable if the runner should use the currently active Spark session.") + @Default.Boolean(false) + boolean getUseActiveSparkSession(); + + void setUseActiveSparkSession(boolean value); +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingPipelineResult.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingPipelineResult.java new file mode 100644 index 000000000000..806d838d9bff --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingPipelineResult.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark.structuredstreaming; + +import static org.apache.beam.runners.core.metrics.MetricsContainerStepMap.asAttemptedOnlyMetricResults; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects.firstNonNull; + +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import javax.annotation.Nullable; +import org.apache.beam.runners.spark.structuredstreaming.metrics.MetricsAccumulator; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.PipelineResult; +import org.apache.beam.sdk.metrics.MetricResults; +import org.apache.beam.sdk.util.UserCodeException; +import org.apache.spark.SparkException; +import org.joda.time.Duration; + +public class SparkStructuredStreamingPipelineResult implements PipelineResult { + + private final Future pipelineExecution; + private final MetricsAccumulator metrics; + private @Nullable final Runnable onTerminalState; + private PipelineResult.State state; + + SparkStructuredStreamingPipelineResult( + Future pipelineExecution, + MetricsAccumulator metrics, + @Nullable final Runnable onTerminalState) { + this.pipelineExecution = pipelineExecution; + this.metrics = metrics; + this.onTerminalState = onTerminalState; + // pipelineExecution is expected to have started executing eagerly. + this.state = State.RUNNING; + } + + private static RuntimeException runtimeExceptionFrom(final Throwable e) { + return (e instanceof RuntimeException) ? (RuntimeException) e : new RuntimeException(e); + } + + /** + * Unwrap cause of SparkException or UserCodeException as PipelineExecutionException. Otherwise, + * return {@code exception} as RuntimeException. + */ + private static RuntimeException unwrapCause(Throwable exception) { + Throwable next = exception; + while (next != null && (next instanceof SparkException || next instanceof UserCodeException)) { + exception = next; + next = next.getCause(); + } + return exception == next + ? runtimeExceptionFrom(exception) + : new Pipeline.PipelineExecutionException(firstNonNull(next, exception)); + } + + private State awaitTermination(Duration duration) + throws TimeoutException, ExecutionException, InterruptedException { + pipelineExecution.get(duration.getMillis(), TimeUnit.MILLISECONDS); + // Throws an exception if the job is not finished successfully in the given time. + return PipelineResult.State.DONE; + } + + @Override + public PipelineResult.State getState() { + return state; + } + + @Override + public PipelineResult.State waitUntilFinish() { + return waitUntilFinish(Duration.millis(Long.MAX_VALUE)); + } + + @Override + public State waitUntilFinish(final Duration duration) { + try { + State finishState = awaitTermination(duration); + offerNewState(finishState); + } catch (final TimeoutException e) { + // ignore. 
+ } catch (final ExecutionException e) { + offerNewState(PipelineResult.State.FAILED); + throw unwrapCause(firstNonNull(e.getCause(), e)); + } catch (final Exception e) { + offerNewState(PipelineResult.State.FAILED); + throw unwrapCause(e); + } + + return state; + } + + @Override + public MetricResults metrics() { + return asAttemptedOnlyMetricResults(metrics.value()); + } + + @Override + public PipelineResult.State cancel() throws IOException { + offerNewState(PipelineResult.State.CANCELLED); + return state; + } + + private void offerNewState(State newState) { + State oldState = this.state; + this.state = newState; + if (!oldState.isTerminal() && newState.isTerminal() && onTerminalState != null) { + try { + onTerminalState.run(); + } catch (Exception e) { + throw unwrapCause(e); + } + } + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingRunner.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingRunner.java new file mode 100644 index 000000000000..96717f29e87f --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingRunner.java @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark.structuredstreaming; + +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; + +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ThreadFactory; +import javax.annotation.Nullable; +import org.apache.beam.runners.core.metrics.MetricsPusher; +import org.apache.beam.runners.core.metrics.NoOpMetricsSink; +import org.apache.beam.runners.spark.structuredstreaming.metrics.MetricsAccumulator; +import org.apache.beam.runners.spark.structuredstreaming.metrics.SparkBeamMetricSource; +import org.apache.beam.runners.spark.structuredstreaming.translation.EvaluationContext; +import org.apache.beam.runners.spark.structuredstreaming.translation.PipelineTranslator; +import org.apache.beam.runners.spark.structuredstreaming.translation.SparkSessionFactory; +import org.apache.beam.runners.spark.structuredstreaming.translation.batch.PipelineTranslatorBatch; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.PipelineRunner; +import org.apache.beam.sdk.metrics.MetricsEnvironment; +import org.apache.beam.sdk.metrics.MetricsOptions; +import org.apache.beam.sdk.options.ExperimentalOptions; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.options.PipelineOptionsValidator; +import org.apache.beam.sdk.util.construction.SplittableParDo; +import org.apache.beam.sdk.util.construction.graph.ProjectionPushdownOptimizer; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; +import org.apache.spark.SparkEnv$; +import org.apache.spark.metrics.MetricsSystem; +import org.apache.spark.sql.SparkSession; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A Spark runner build on top of Spark's SQL Engine (Structured + * Streaming framework). + * + *

<p>This runner is experimental, its coverage of the Beam model is still partial. Due to + * limitations of the Structured Streaming framework (e.g. lack of support for multiple stateful + * operators), streaming mode is not yet supported by this runner. + * + *

<p>The runner translates transforms defined on a Beam pipeline to Spark `Dataset` transformations (leveraging the high level Dataset API) and then submits these to Spark to be executed. + * + *

<p>To run a Beam pipeline with the default options using Spark's local mode, we would do the + * following: + * + *

<pre>{@code
+ * Pipeline p = [logic for pipeline creation]
+ * PipelineResult result = p.run();
+ * }</pre>
+ * + *

<p>To create a pipeline runner to run against a different spark cluster, with a custom master url + * we would do the following: + * + *

<pre>{@code
+ * Pipeline p = [logic for pipeline creation]
+ * SparkCommonPipelineOptions options = p.getOptions.as(SparkCommonPipelineOptions.class);
+ * options.setSparkMaster("spark://host:port");
+ * PipelineResult result = p.run();
+ * }</pre>
+ */ +@SuppressWarnings({ + "nullness" // TODO(https://github.com/apache/beam/issues/20497) +}) +public final class SparkStructuredStreamingRunner + extends PipelineRunner { + + private static final Logger LOG = LoggerFactory.getLogger(SparkStructuredStreamingRunner.class); + + /** Options used in this pipeline runner. */ + private final SparkStructuredStreamingPipelineOptions options; + + /** + * Creates and returns a new SparkStructuredStreamingRunner with default options. In particular, + * against a spark instance running in local mode. + * + * @return A pipeline runner with default options. + */ + public static SparkStructuredStreamingRunner create() { + SparkStructuredStreamingPipelineOptions options = + PipelineOptionsFactory.as(SparkStructuredStreamingPipelineOptions.class); + options.setRunner(SparkStructuredStreamingRunner.class); + return new SparkStructuredStreamingRunner(options); + } + + /** + * Creates and returns a new SparkStructuredStreamingRunner with specified options. + * + * @param options The SparkStructuredStreamingPipelineOptions to use when executing the job. + * @return A pipeline runner that will execute with specified options. + */ + public static SparkStructuredStreamingRunner create( + SparkStructuredStreamingPipelineOptions options) { + return new SparkStructuredStreamingRunner(options); + } + + /** + * Creates and returns a new SparkStructuredStreamingRunner with specified options. + * + * @param options The PipelineOptions to use when executing the job. + * @return A pipeline runner that will execute with specified options. + */ + public static SparkStructuredStreamingRunner fromOptions(PipelineOptions options) { + return new SparkStructuredStreamingRunner( + PipelineOptionsValidator.validate(SparkStructuredStreamingPipelineOptions.class, options)); + } + + /** + * No parameter constructor defaults to running this pipeline in Spark's local mode, in a single + * thread. + */ + private SparkStructuredStreamingRunner(SparkStructuredStreamingPipelineOptions options) { + this.options = options; + } + + @Override + public SparkStructuredStreamingPipelineResult run(final Pipeline pipeline) { + MetricsEnvironment.setMetricsSupported(true); + MetricsAccumulator.clear(); + + LOG.info( + "*** SparkStructuredStreamingRunner is based on spark structured streaming framework and is no more \n" + + " based on RDD/DStream API. See\n" + + " https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html\n" + + " It is still experimental, its coverage of the Beam model is partial. 
***"); + + PipelineTranslator.detectStreamingMode(pipeline, options); + checkArgument(!options.isStreaming(), "Streaming is not supported."); + + final SparkSession sparkSession = SparkSessionFactory.getOrCreateSession(options); + final MetricsAccumulator metrics = MetricsAccumulator.getInstance(sparkSession); + + final Future submissionFuture = + runAsync(() -> translatePipeline(sparkSession, pipeline).evaluate()); + + final SparkStructuredStreamingPipelineResult result = + new SparkStructuredStreamingPipelineResult( + submissionFuture, + metrics, + sparkStopFn(sparkSession, options.getUseActiveSparkSession())); + + if (options.getEnableSparkMetricSinks()) { + registerMetricsSource(options.getAppName(), metrics); + } + startMetricsPusher(result, metrics); + + if (options.getTestMode()) { + result.waitUntilFinish(); + } + + return result; + } + + private EvaluationContext translatePipeline(SparkSession sparkSession, Pipeline pipeline) { + // Default to using the primitive versions of Read.Bounded and Read.Unbounded for non-portable + // execution. + // TODO(https://github.com/apache/beam/issues/20530): Use SDF read as default when we address + // performance issue. + if (!ExperimentalOptions.hasExperiment(pipeline.getOptions(), "beam_fn_api")) { + SplittableParDo.convertReadBasedSplittableDoFnsToPrimitiveReadsIfNecessary(pipeline); + } + + if (!ExperimentalOptions.hasExperiment(options, "disable_projection_pushdown")) { + ProjectionPushdownOptimizer.optimize(pipeline); + } + + PipelineTranslator.replaceTransforms(pipeline, options); + + PipelineTranslator pipelineTranslator = new PipelineTranslatorBatch(); + return pipelineTranslator.translate(pipeline, sparkSession, options); + } + + private void registerMetricsSource(String appName, MetricsAccumulator metrics) { + final MetricsSystem metricsSystem = SparkEnv$.MODULE$.get().metricsSystem(); + final SparkBeamMetricSource metricsSource = + new SparkBeamMetricSource(appName + ".Beam", metrics); + // re-register the metrics in case of context re-use + metricsSystem.removeSource(metricsSource); + metricsSystem.registerSource(metricsSource); + } + + /** Start {@link MetricsPusher} if sink is set. */ + private void startMetricsPusher( + SparkStructuredStreamingPipelineResult result, MetricsAccumulator metrics) { + MetricsOptions metricsOpts = options.as(MetricsOptions.class); + Class metricsSink = metricsOpts.getMetricsSink(); + if (metricsSink != null && !metricsSink.equals(NoOpMetricsSink.class)) { + new MetricsPusher(metrics.value(), metricsOpts, result).start(); + } + } + + private static Future runAsync(Runnable task) { + ThreadFactory factory = + new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat("SparkStructuredStreamingRunner-thread") + .build(); + ExecutorService execService = Executors.newSingleThreadExecutor(factory); + Future future = execService.submit(task); + execService.shutdown(); + return future; + } + + private static @Nullable Runnable sparkStopFn(SparkSession session, boolean isProvided) { + return !isProvided ? 
() -> session.stop() : null; + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingRunnerRegistrar.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingRunnerRegistrar.java new file mode 100644 index 000000000000..a1dc3ad3a9be --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingRunnerRegistrar.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming; + +import com.google.auto.service.AutoService; +import org.apache.beam.sdk.PipelineRunner; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsRegistrar; +import org.apache.beam.sdk.runners.PipelineRunnerRegistrar; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; + +/** + * Contains the {@link PipelineRunnerRegistrar} and {@link PipelineOptionsRegistrar} for the {@link + * SparkStructuredStreamingRunner}. + * + *

{@link AutoService} will register Spark's implementations of the {@link PipelineRunner} and + * {@link PipelineOptions} as available pipeline runner services. + */ +public final class SparkStructuredStreamingRunnerRegistrar { + private SparkStructuredStreamingRunnerRegistrar() {} + + /** Registers the {@link SparkStructuredStreamingRunner}. */ + @AutoService(PipelineRunnerRegistrar.class) + public static class Runner implements PipelineRunnerRegistrar { + @Override + public Iterable>> getPipelineRunners() { + return ImmutableList.of(SparkStructuredStreamingRunner.class); + } + } + + /** Registers the {@link SparkStructuredStreamingPipelineOptions}. */ + @AutoService(PipelineOptionsRegistrar.class) + public static class Options implements PipelineOptionsRegistrar { + @Override + public Iterable> getPipelineOptions() { + return ImmutableList.of(SparkStructuredStreamingPipelineOptions.class); + } + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/examples/WordCount.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/examples/WordCount.java new file mode 100644 index 000000000000..dca43581f669 --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/examples/WordCount.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.examples; + +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.io.TextIO; +import org.apache.beam.sdk.metrics.Counter; +import org.apache.beam.sdk.metrics.Metrics; +import org.apache.beam.sdk.options.Default; +import org.apache.beam.sdk.options.Description; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.transforms.Count; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.SimpleFunction; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; + +/** Duplicated from beam-examples-java to avoid dependency. */ +public class WordCount { + + /** + * Concept #2: You can make your pipeline code less verbose by defining your DoFns statically out- + * of-line. This DoFn tokenizes lines of text into individual words; we pass it to a ParDo in the + * pipeline. 
+ */ + @SuppressWarnings("StringSplitter") + static class ExtractWordsFn extends DoFn { + private final Counter emptyLines = Metrics.counter(ExtractWordsFn.class, "emptyLines"); + + @ProcessElement + public void processElement(ProcessContext c) { + if (c.element().trim().isEmpty()) { + emptyLines.inc(); + } + + // Split the line into words. + String[] words = c.element().split("[^\\p{L}]+"); + + // Output each word encountered into the output PCollection. + for (String word : words) { + if (!word.isEmpty()) { + c.output(word); + } + } + } + } + + /** A SimpleFunction that converts a Word and Count into a printable string. */ + public static class FormatAsTextFn extends SimpleFunction, String> { + @Override + public String apply(KV input) { + return input.getKey() + ": " + input.getValue(); + } + } + + /** + * A PTransform that converts a PCollection containing lines of text into a PCollection of + * formatted word counts. + * + *

Concept #3: This is a custom composite transform that bundles two transforms (ParDo and + * Count) as a reusable PTransform subclass. Using composite transforms allows for easy reuse, + * modular testing, and an improved monitoring experience. + */ + public static class CountWords + extends PTransform, PCollection>> { + @Override + public PCollection> expand(PCollection lines) { + + // Convert lines of text into individual words. + PCollection words = lines.apply(ParDo.of(new ExtractWordsFn())); + + // Count the number of times each word occurs. + return words.apply(Count.perElement()); + } + } + + /** + * Options supported by {@link WordCount}. + * + *

Concept #4: Defining your own configuration options. Here, you can add your own arguments to + * be processed by the command-line parser, and specify default values for them. You can then + * access the options values in your pipeline code. + * + *
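A short sketch of how these custom options might be supplied on the command line; the output path is a placeholder, and the option names come from the WordCountOptions interface defined below:

    // Sketch only: reuses the same PipelineOptionsFactory calls as main() below.
    WordCountOptions options =
        PipelineOptionsFactory.fromArgs(
                "--runner=SparkStructuredStreamingRunner",
                "--inputFile=gs://beam-samples/shakespeare/kinglear.txt",
                "--output=/tmp/wordcount/out")
            .withValidation()
            .as(WordCountOptions.class);
    String input = options.getInputFile(); // resolved from the flag or the @Default.String value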

Inherits standard configuration options. + */ + public interface WordCountOptions extends PipelineOptions { + @Description("Path of the file to read from") + @Default.String("gs://beam-samples/shakespeare/kinglear.txt") + String getInputFile(); + + void setInputFile(String value); + + @Description("Path of the file to write to") + String getOutput(); + + void setOutput(String value); + } + + public static void main(String[] args) { + WordCountOptions options = + PipelineOptionsFactory.fromArgs(args).withValidation().as(WordCountOptions.class); + Pipeline p = Pipeline.create(options); + + // Concepts #2 and #3: Our pipeline applies the composite CountWords transform, and passes the + // static FormatAsTextFn() to the ParDo transform. + p.apply("ReadLines", TextIO.read().from(options.getInputFile())) + .apply(new CountWords()) + .apply(MapElements.via(new FormatAsTextFn())) + .apply("WriteCounts", TextIO.write().to(options.getOutput())); + + p.run().waitUntilFinish(); + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/BoundedDatasetFactory.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/BoundedDatasetFactory.java new file mode 100644 index 000000000000..c00b2d3594c0 --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/BoundedDatasetFactory.java @@ -0,0 +1,330 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark.structuredstreaming.io; + +import static java.util.stream.Collectors.toList; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.emptyList; +import static org.apache.beam.sdk.values.WindowedValues.timestampedValueInGlobalWindow; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; +import static scala.collection.JavaConverters.asScalaIterator; + +import java.io.Closeable; +import java.io.IOException; +import java.io.Serializable; +import java.util.List; +import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.IntSupplier; +import java.util.function.Supplier; +import javax.annotation.CheckForNull; +import javax.annotation.Nullable; +import org.apache.beam.sdk.io.BoundedSource; +import org.apache.beam.sdk.io.BoundedSource.BoundedReader; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.AbstractIterator; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; +import org.apache.spark.InterruptibleIterator; +import org.apache.spark.Partition; +import org.apache.spark.SparkContext; +import org.apache.spark.TaskContext; +import org.apache.spark.rdd.RDD; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder; +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Serializer; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.classic.Dataset$; +import org.apache.spark.sql.connector.catalog.SupportsRead; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.catalog.TableCapability; +import org.apache.spark.sql.connector.read.Batch; +import org.apache.spark.sql.connector.read.InputPartition; +import org.apache.spark.sql.connector.read.PartitionReader; +import org.apache.spark.sql.connector.read.PartitionReaderFactory; +import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import scala.Option; +import scala.collection.Iterator; +import scala.reflect.ClassTag; + +public class BoundedDatasetFactory { + private BoundedDatasetFactory() {} + + /** + * Create a {@link Dataset} for a {@link BoundedSource} via a Spark {@link Table}. + * + *

Unfortunately tables are expected to return an {@link InternalRow}, requiring serialization. + * This makes this approach at the time being significantly less performant than creating a + * dataset from an RDD. + */ + public static Dataset> createDatasetFromRows( + SparkSession session, + BoundedSource source, + Supplier options, + Encoder> encoder) { + Params params = new Params<>(encoder, options, session.sparkContext().defaultParallelism()); + BeamTable table = new BeamTable<>(source, params); + LogicalPlan logicalPlan = DataSourceV2Relation.create(table, Option.empty(), Option.empty()); + // In Spark 4.0+, Dataset$ moved to org.apache.spark.sql.classic; cast session accordingly. + return (Dataset>) + Dataset$.MODULE$ + .ofRows((org.apache.spark.sql.classic.SparkSession) session, logicalPlan) + .as(encoder); + } + + /** + * Create a {@link Dataset} for a {@link BoundedSource} via a Spark {@link RDD}. + * + *
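A rough usage sketch of the two factory methods; the source, options supplier, and encoder variables are assumed to be provided by the surrounding translation code:

    // Table-backed variant: goes through Spark's DataSource V2 path and InternalRow serialization.
    Dataset<WindowedValue<String>> viaTable =
        BoundedDatasetFactory.createDatasetFromRows(session, source, optionsSupplier, encoder);
    // RDD-backed variant: wraps the source in a BoundedRDD and avoids that serialization step.
    Dataset<WindowedValue<String>> viaRdd =
        BoundedDatasetFactory.createDatasetFromRDD(session, source, optionsSupplier, encoder);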

This is currently the most efficient approach as it avoid any serialization overhead. + */ + public static Dataset> createDatasetFromRDD( + SparkSession session, + BoundedSource source, + Supplier options, + Encoder> encoder) { + Params params = new Params<>(encoder, options, session.sparkContext().defaultParallelism()); + RDD> rdd = new BoundedRDD<>(session.sparkContext(), source, params); + return session.createDataset(rdd, encoder); + } + + /** An {@link RDD} for a bounded Beam source. */ + private static class BoundedRDD extends RDD> { + final BoundedSource source; + final Params params; + + public BoundedRDD(SparkContext sc, BoundedSource source, Params params) { + super(sc, emptyList(), ClassTag.apply(WindowedValue.class)); + this.source = source; + this.params = params; + } + + @Override + public Iterator> compute(Partition split, TaskContext context) { + return new InterruptibleIterator<>( + context, + asScalaIterator(new SourcePartitionIterator<>((SourcePartition) split, params))); + } + + @Override + public Partition[] getPartitions() { + return SourcePartition.partitionsOf(source, params).toArray(new Partition[0]); + } + } + + /** A Spark {@link Table} for a bounded Beam source supporting batch reads only. */ + private static class BeamTable implements Table, SupportsRead { + final BoundedSource source; + final Params params; + + BeamTable(BoundedSource source, Params params) { + this.source = source; + this.params = params; + } + + public Encoder> getEncoder() { + return params.encoder; + } + + @Override + public ScanBuilder newScanBuilder(CaseInsensitiveStringMap ignored) { + return () -> + new Scan() { + @Override + public StructType readSchema() { + return params.encoder.schema(); + } + + @Override + public Batch toBatch() { + return new BeamBatch<>(source, params); + } + }; + } + + @Override + public String name() { + return "BeamSource<" + source.getClass().getName() + ">"; + } + + @Override + public StructType schema() { + return params.encoder.schema(); + } + + @Override + public Set capabilities() { + return ImmutableSet.of(TableCapability.BATCH_READ); + } + + private static class BeamBatch implements Batch, Serializable { + final BoundedSource source; + final Params params; + + private BeamBatch(BoundedSource source, Params params) { + this.source = source; + this.params = params; + } + + @Override + public InputPartition[] planInputPartitions() { + return SourcePartition.partitionsOf(source, params).toArray(new InputPartition[0]); + } + + @Override + public PartitionReaderFactory createReaderFactory() { + return p -> new BeamPartitionReader<>(((SourcePartition) p), params); + } + } + + private static class BeamPartitionReader implements PartitionReader { + final SourcePartitionIterator iterator; + final Serializer> serializer; + transient @Nullable InternalRow next; + + BeamPartitionReader(SourcePartition partition, Params params) { + iterator = new SourcePartitionIterator<>(partition, params); + serializer = ((ExpressionEncoder>) params.encoder).createSerializer(); + } + + @Override + public boolean next() throws IOException { + if (iterator.hasNext()) { + next = serializer.apply(iterator.next()); + return true; + } + return false; + } + + @Override + public InternalRow get() { + if (next == null) { + throw new IllegalStateException("Next not available"); + } + return next; + } + + @Override + public void close() throws IOException { + next = null; + iterator.close(); + } + } + } + + /** A Spark partition wrapping the partitioned Beam {@link BoundedSource}. 
*/ + private static class SourcePartition implements Partition, InputPartition { + final BoundedSource source; + final int index; + + SourcePartition(BoundedSource source, IntSupplier idxSupplier) { + this.source = source; + this.index = idxSupplier.getAsInt(); + } + + static List> partitionsOf(BoundedSource source, Params params) { + try { + PipelineOptions options = params.options.get(); + long desiredSize = source.getEstimatedSizeBytes(options) / params.numPartitions; + List> split = (List>) source.split(desiredSize, options); + IntSupplier idxSupplier = new AtomicInteger(0)::getAndIncrement; + return split.stream().map(s -> new SourcePartition<>(s, idxSupplier)).collect(toList()); + } catch (Exception e) { + throw new RuntimeException( + "Error splitting BoundedSource " + source.getClass().getCanonicalName(), e); + } + } + + @Override + public int index() { + return index; + } + + @Override + public int hashCode() { + return index; + } + } + + /** A partition iterator on a partitioned Beam {@link BoundedSource}. */ + private static class SourcePartitionIterator extends AbstractIterator> + implements Closeable { + BoundedReader reader; + boolean started = false; + + public SourcePartitionIterator(SourcePartition partition, Params params) { + try { + reader = partition.source.createReader(params.options.get()); + } catch (IOException e) { + throw new RuntimeException("Failed to create reader from a BoundedSource.", e); + } + } + + @Override + @SuppressWarnings("nullness") // ok, reader not used any longer + public void close() throws IOException { + if (reader != null) { + endOfData(); + try { + reader.close(); + } finally { + reader = null; + } + } + } + + @Override + protected @CheckForNull WindowedValue computeNext() { + try { + if (started ? reader.advance() : start()) { + return timestampedValueInGlobalWindow(reader.getCurrent(), reader.getCurrentTimestamp()); + } else { + close(); + return endOfData(); + } + } catch (IOException e) { + throw new RuntimeException("Failed to start or advance reader.", e); + } + } + + private boolean start() throws IOException { + started = true; + return reader.start(); + } + } + + /** Shared parameters. */ + private static class Params implements Serializable { + final Encoder> encoder; + final Supplier options; + final int numPartitions; + + Params( + Encoder> encoder, Supplier options, int numPartitions) { + checkArgument(numPartitions > 0, "Number of partitions must be greater than zero."); + this.encoder = encoder; + this.options = options; + this.numPartitions = numPartitions; + } + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/package-info.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/package-info.java new file mode 100644 index 000000000000..23de70c705b3 --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/package-info.java @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Spark-specific transforms for I/O. */ +package org.apache.beam.runners.spark.structuredstreaming.io; diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/BeamMetricSet.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/BeamMetricSet.java new file mode 100644 index 000000000000..7095036f28a3 --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/BeamMetricSet.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.metrics; + +import com.codahale.metrics.Gauge; +import com.codahale.metrics.MetricFilter; +import java.util.Map; +import org.apache.beam.runners.spark.metrics.WithMetricsSupport; + +/** + * {@link BeamMetricSet} is a {@link Gauge} that returns a map of multiple metrics which get + * flattened in {@link WithMetricsSupport#getGauges()} for usage in {@link + * org.apache.spark.metrics.sink.Sink Spark metric sinks}. + * + *

Note: Recent versions of Dropwizard {@link com.codahale.metrics.MetricRegistry MetricRegistry} + * do not allow registering arbitrary implementations of {@link com.codahale.metrics.Metric Metrics} + * and require usage of {@link Gauge} here. + */ +// TODO: turn into MetricRegistry https://github.com/apache/beam/issues/22384 +abstract class BeamMetricSet implements Gauge>> { + + @Override + public final Map> getValue() { + return getValue("", MetricFilter.ALL); + } + + protected abstract Map> getValue(String prefix, MetricFilter filter); + + protected Gauge staticGauge(Number number) { + return new ConstantGauge(number.doubleValue()); + } + + private static class ConstantGauge implements Gauge { + private final double value; + + ConstantGauge(double value) { + this.value = value; + } + + @Override + public Double getValue() { + return value; + } + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/MetricsAccumulator.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/MetricsAccumulator.java new file mode 100644 index 000000000000..63407b9f14d8 --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/MetricsAccumulator.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.metrics; + +import org.apache.beam.runners.core.metrics.MetricsContainerStepMap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.util.AccumulatorV2; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * {@link AccumulatorV2} for Beam metrics captured in {@link MetricsContainerStepMap}. 
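As a rough sketch of the intended driver-side usage (the SparkSession variable is assumed):

    // Registers the singleton accumulator with the SparkContext on first use.
    MetricsAccumulator acc = MetricsAccumulator.getInstance(session);
    // Executors contribute MetricsContainerStepMap updates via add()/merge(); the driver
    // can then expose the aggregated state as Beam metric results:
    MetricResults results =
        MetricsContainerStepMap.asAttemptedOnlyMetricResults(acc.value());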
+ * + * @see accumulatorsV2 + */ +public class MetricsAccumulator + extends AccumulatorV2 { + private static final Logger LOG = LoggerFactory.getLogger(MetricsAccumulator.class); + private static final MetricsContainerStepMap EMPTY = new SparkMetricsContainerStepMap(); + private static final String ACCUMULATOR_NAME = "Beam.Metrics"; + + private static volatile @Nullable MetricsAccumulator instance = null; + + private MetricsContainerStepMap value; + + public MetricsAccumulator() { + value = new SparkMetricsContainerStepMap(); + } + + private MetricsAccumulator(MetricsContainerStepMap value) { + this.value = value; + } + + @Override + public boolean isZero() { + return value.equals(EMPTY); + } + + @Override + public MetricsAccumulator copy() { + MetricsContainerStepMap newContainer = new SparkMetricsContainerStepMap(); + newContainer.updateAll(value); + return new MetricsAccumulator(newContainer); + } + + @Override + public void reset() { + value = new SparkMetricsContainerStepMap(); + } + + @Override + public void add(MetricsContainerStepMap other) { + value.updateAll(other); + } + + @Override + public void merge(AccumulatorV2 other) { + value.updateAll(other.value()); + } + + @Override + public MetricsContainerStepMap value() { + return value; + } + + /** + * Get the {@link MetricsAccumulator} on this driver. If there's no such accumulator yet, it will + * be created and registered using the provided {@link SparkSession}. + */ + public static MetricsAccumulator getInstance(SparkSession session) { + MetricsAccumulator current = instance; + if (current != null) { + return current; + } + synchronized (MetricsAccumulator.class) { + MetricsAccumulator accumulator = instance; + if (accumulator == null) { + accumulator = new MetricsAccumulator(); + session.sparkContext().register(accumulator, ACCUMULATOR_NAME); + instance = accumulator; + LOG.info("Instantiated metrics accumulator: {}", instance.value()); + } + return accumulator; + } + } + + @VisibleForTesting + public static void clear() { + synchronized (MetricsAccumulator.class) { + instance = null; + } + } + + /** + * Sole purpose of this class is to override {@link #toString()} of {@link + * MetricsContainerStepMap} in order to show meaningful metrics in Spark Web Interface. + */ + private static class SparkMetricsContainerStepMap extends MetricsContainerStepMap { + + @Override + public String toString() { + return asAttemptedOnlyMetricResults(this).toString(); + } + + @Override + public boolean equals(@Nullable Object o) { + return super.equals(o); + } + + @Override + public int hashCode() { + return super.hashCode(); + } + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/SparkBeamMetric.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/SparkBeamMetric.java new file mode 100644 index 000000000000..19ba92956e70 --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/SparkBeamMetric.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.metrics; + +import static org.apache.beam.runners.core.metrics.MetricsContainerStepMap.asAttemptedOnlyMetricResults; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Predicates.not; + +import com.codahale.metrics.Gauge; +import com.codahale.metrics.Metric; +import com.codahale.metrics.MetricFilter; +import java.util.HashMap; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import javax.annotation.Nullable; +import org.apache.beam.runners.core.metrics.MetricsContainerStepMap; +import org.apache.beam.sdk.metrics.DistributionResult; +import org.apache.beam.sdk.metrics.GaugeResult; +import org.apache.beam.sdk.metrics.MetricKey; +import org.apache.beam.sdk.metrics.MetricName; +import org.apache.beam.sdk.metrics.MetricQueryResults; +import org.apache.beam.sdk.metrics.MetricResult; +import org.apache.beam.sdk.metrics.MetricResults; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Streams; + +/** + * An adapter between the {@link MetricsContainerStepMap} and the Dropwizard {@link Metric} + * interface. + */ +class SparkBeamMetric extends BeamMetricSet { + + private static final String ILLEGAL_CHARACTERS = "[^A-Za-z0-9-]"; + + private final MetricsAccumulator metrics; + + SparkBeamMetric(MetricsAccumulator metrics) { + this.metrics = metrics; + } + + @Override + public Map> getValue(String prefix, MetricFilter filter) { + MetricResults metricResults = asAttemptedOnlyMetricResults(metrics.value()); + Map> metrics = new HashMap<>(); + MetricQueryResults allMetrics = metricResults.allMetrics(); + for (MetricResult metricResult : allMetrics.getCounters()) { + putFiltered(metrics, filter, renderName(prefix, metricResult), metricResult.getAttempted()); + } + for (MetricResult metricResult : allMetrics.getDistributions()) { + DistributionResult result = metricResult.getAttempted(); + String baseName = renderName(prefix, metricResult); + putFiltered(metrics, filter, baseName + ".count", result.getCount()); + putFiltered(metrics, filter, baseName + ".sum", result.getSum()); + putFiltered(metrics, filter, baseName + ".min", result.getMin()); + putFiltered(metrics, filter, baseName + ".max", result.getMax()); + putFiltered(metrics, filter, baseName + ".mean", result.getMean()); + } + for (MetricResult metricResult : allMetrics.getGauges()) { + putFiltered( + metrics, + filter, + renderName(prefix, metricResult), + metricResult.getAttempted().getValue()); + } + return metrics; + } + + @VisibleForTesting + @SuppressWarnings("nullness") // ok to have nullable elements on stream + static String renderName(String prefix, MetricResult metricResult) { + MetricKey key = metricResult.getKey(); + MetricName name = key.metricName(); + String step = key.stepName(); + return Streams.concat( + Stream.of(prefix), // prefix is not cleaned, should it be? 
+ Stream.of(stripSuffix(normalizePart(step))), + Stream.of(name.getNamespace(), name.getName()).map(SparkBeamMetric::normalizePart)) + .filter(not(Strings::isNullOrEmpty)) + .collect(Collectors.joining(".")); + } + + private static @Nullable String normalizePart(@Nullable String str) { + return str != null ? str.replaceAll(ILLEGAL_CHARACTERS, "_") : null; + } + + private static @Nullable String stripSuffix(@Nullable String str) { + return str != null && str.endsWith("_") ? str.substring(0, str.length() - 1) : str; + } + + private void putFiltered( + Map> metrics, MetricFilter filter, String name, Number value) { + Gauge metric = staticGauge(value); + if (filter.matches(name, metric)) { + metrics.put(name, metric); + } + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/SparkBeamMetricSource.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/SparkBeamMetricSource.java new file mode 100644 index 000000000000..8a1e980ae0c5 --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/SparkBeamMetricSource.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.metrics; + +import com.codahale.metrics.MetricRegistry; +import org.apache.spark.metrics.source.Source; + +/** + * A Spark {@link Source} that is tailored to expose a {@link SparkBeamMetric}, wrapping an + * underlying {@link org.apache.beam.sdk.metrics.MetricResults} instance. + */ +public class SparkBeamMetricSource implements Source { + private final String name; + + private final MetricRegistry metricRegistry = new MetricRegistry(); + + public SparkBeamMetricSource(String name, MetricsAccumulator metrics) { + this.name = name; + metricRegistry.register(name, new SparkBeamMetric(metrics)); + } + + @Override + public String sourceName() { + return name; + } + + @Override + public MetricRegistry metricRegistry() { + return metricRegistry; + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/WithMetricsSupport.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/WithMetricsSupport.java new file mode 100644 index 000000000000..f632f7a6aa1a --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/WithMetricsSupport.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.metrics; + +import com.codahale.metrics.Counter; +import com.codahale.metrics.Gauge; +import com.codahale.metrics.Histogram; +import com.codahale.metrics.Meter; +import com.codahale.metrics.MetricFilter; +import com.codahale.metrics.MetricRegistry; +import com.codahale.metrics.Timer; +import java.util.Map; +import java.util.SortedMap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSortedMap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Ordering; + +/** + * A {@link MetricRegistry} decorator-like that supports {@link BeamMetricSet}s as {@link Gauge + * Gauges}. + * + *

{@link MetricRegistry} is not an interface, so this is not a by-the-book decorator. That said, + * it delegates all metric related getters to the "decorated" instance. + */ +public class WithMetricsSupport extends MetricRegistry { + + private final MetricRegistry internalMetricRegistry; + + private WithMetricsSupport(final MetricRegistry internalMetricRegistry) { + this.internalMetricRegistry = internalMetricRegistry; + } + + public static WithMetricsSupport forRegistry(final MetricRegistry metricRegistry) { + return new WithMetricsSupport(metricRegistry); + } + + @Override + public SortedMap getTimers(final MetricFilter filter) { + return internalMetricRegistry.getTimers(filter); + } + + @Override + public SortedMap getMeters(final MetricFilter filter) { + return internalMetricRegistry.getMeters(filter); + } + + @Override + public SortedMap getHistograms(final MetricFilter filter) { + return internalMetricRegistry.getHistograms(filter); + } + + @Override + public SortedMap getCounters(final MetricFilter filter) { + return internalMetricRegistry.getCounters(filter); + } + + @Override + @SuppressWarnings({"rawtypes"}) // required by interface + public SortedMap getGauges(final MetricFilter filter) { + ImmutableSortedMap.Builder builder = + new ImmutableSortedMap.Builder<>(Ordering.from(String.CASE_INSENSITIVE_ORDER)); + + Map gauges = + internalMetricRegistry.getGauges( + (n, m) -> filter.matches(n, m) || m instanceof BeamMetricSet); + + for (Map.Entry entry : gauges.entrySet()) { + Gauge gauge = entry.getValue(); + if (gauge instanceof BeamMetricSet) { + builder.putAll(((BeamMetricSet) gauge).getValue(entry.getKey(), filter)); + } else { + builder.put(entry.getKey(), gauge); + } + } + return builder.build(); + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/package-info.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/package-info.java new file mode 100644 index 000000000000..16a1a956e8e8 --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/package-info.java @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Provides internal utilities for implementing Beam metrics using Spark accumulators. 
*/ +package org.apache.beam.runners.spark.structuredstreaming.metrics; diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/sink/CodahaleCsvSink.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/sink/CodahaleCsvSink.java new file mode 100644 index 000000000000..dd23d5040464 --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/sink/CodahaleCsvSink.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.metrics.sink; + +import com.codahale.metrics.MetricRegistry; +import java.util.Properties; +import org.apache.beam.runners.spark.structuredstreaming.metrics.WithMetricsSupport; +import org.apache.spark.SecurityManager; +import org.apache.spark.metrics.sink.Sink; + +/** + * A {@link Sink} for Spark's + * metric system reporting metrics (including Beam step metrics) to a CSV file. + * + *

The sink is configured using Spark configuration parameters, for example: + * + *

{@code
+ * "spark.metrics.conf.*.sink.csv.class"="org.apache.beam.runners.spark.structuredstreaming.metrics.sink.CodahaleCsvSink"
+ * "spark.metrics.conf.*.sink.csv.directory"=""
+ * "spark.metrics.conf.*.sink.csv.period"=10
+ * "spark.metrics.conf.*.sink.csv.unit"=seconds
+ * }
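The same wiring can presumably also be supplied programmatically through SparkConf rather than a metrics.properties file; a sketch, with a placeholder output directory:

    SparkConf conf = new SparkConf()
        .set("spark.metrics.conf.*.sink.csv.class",
            "org.apache.beam.runners.spark.structuredstreaming.metrics.sink.CodahaleCsvSink")
        .set("spark.metrics.conf.*.sink.csv.directory", "/tmp/beam-metrics") // placeholder path
        .set("spark.metrics.conf.*.sink.csv.period", "10")
        .set("spark.metrics.conf.*.sink.csv.unit", "seconds");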
+ */ +public class CodahaleCsvSink implements Sink { + + // Initialized reflectively as done by Spark's MetricsSystem + private final org.apache.spark.metrics.sink.CsvSink delegate; + + /** Constructor for Spark 3.1.x and earlier. */ + public CodahaleCsvSink( + final Properties properties, + final MetricRegistry metricRegistry, + final SecurityManager securityMgr) { + try { + delegate = + org.apache.spark.metrics.sink.CsvSink.class + .getConstructor(Properties.class, MetricRegistry.class, SecurityManager.class) + .newInstance(properties, WithMetricsSupport.forRegistry(metricRegistry), securityMgr); + } catch (ReflectiveOperationException ex) { + throw new RuntimeException(ex); + } + } + + /** Constructor for Spark 3.2.x and later. */ + public CodahaleCsvSink(final Properties properties, final MetricRegistry metricRegistry) { + try { + delegate = + org.apache.spark.metrics.sink.CsvSink.class + .getConstructor(Properties.class, MetricRegistry.class) + .newInstance(properties, WithMetricsSupport.forRegistry(metricRegistry)); + } catch (ReflectiveOperationException ex) { + throw new RuntimeException(ex); + } + } + + @Override + public void start() { + delegate.start(); + } + + @Override + public void stop() { + delegate.stop(); + } + + @Override + public void report() { + delegate.report(); + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/sink/CodahaleGraphiteSink.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/sink/CodahaleGraphiteSink.java new file mode 100644 index 000000000000..fe709ad81ab7 --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/sink/CodahaleGraphiteSink.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.metrics.sink; + +import com.codahale.metrics.MetricRegistry; +import java.util.Properties; +import org.apache.beam.runners.spark.structuredstreaming.metrics.WithMetricsSupport; +import org.apache.spark.SecurityManager; +import org.apache.spark.metrics.sink.Sink; + +/** + * A {@link Sink} for Spark's + * metric system reporting metrics (including Beam step metrics) to Graphite. + * + *

The sink is configured using Spark configuration parameters, for example: + * + *

{@code
+ * "spark.metrics.conf.*.sink.graphite.class"="org.apache.beam.runners.spark.structuredstreaming.metrics.sink.CodahaleGraphiteSink"
+ * "spark.metrics.conf.*.sink.graphite.host"=""
+ * "spark.metrics.conf.*.sink.graphite.port"=
+ * "spark.metrics.conf.*.sink.graphite.period"=10
+ * "spark.metrics.conf.*.sink.graphite.unit"=seconds
+ * "spark.metrics.conf.*.sink.graphite.prefix"=""
+ * "spark.metrics.conf.*.sink.graphite.regex"=""
+ * }
+ */ +public class CodahaleGraphiteSink implements Sink { + + // Initialized reflectively as done by Spark's MetricsSystem + private final org.apache.spark.metrics.sink.GraphiteSink delegate; + + /** Constructor for Spark 3.1.x and earlier. */ + public CodahaleGraphiteSink( + final Properties properties, + final MetricRegistry metricRegistry, + final org.apache.spark.SecurityManager securityMgr) { + try { + delegate = + org.apache.spark.metrics.sink.GraphiteSink.class + .getConstructor(Properties.class, MetricRegistry.class, SecurityManager.class) + .newInstance(properties, WithMetricsSupport.forRegistry(metricRegistry), securityMgr); + } catch (ReflectiveOperationException ex) { + throw new RuntimeException(ex); + } + } + + /** Constructor for Spark 3.2.x and later. */ + public CodahaleGraphiteSink(final Properties properties, final MetricRegistry metricRegistry) { + try { + delegate = + org.apache.spark.metrics.sink.GraphiteSink.class + .getConstructor(Properties.class, MetricRegistry.class) + .newInstance(properties, WithMetricsSupport.forRegistry(metricRegistry)); + } catch (ReflectiveOperationException ex) { + throw new RuntimeException(ex); + } + } + + @Override + public void start() { + delegate.start(); + } + + @Override + public void stop() { + delegate.stop(); + } + + @Override + public void report() { + delegate.report(); + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/sink/package-info.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/sink/package-info.java new file mode 100644 index 000000000000..427e5441c579 --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/sink/package-info.java @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Spark sinks that supports beam metrics and aggregators. */ +package org.apache.beam.runners.spark.structuredstreaming.metrics.sink; diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/package-info.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/package-info.java new file mode 100644 index 000000000000..aefeb282f8f4 --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/package-info.java @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Internal implementation of the Beam runner for Apache Spark. */ +package org.apache.beam.runners.spark.structuredstreaming; diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/EvaluationContext.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/EvaluationContext.java new file mode 100644 index 000000000000..55c4bbaedd3c --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/EvaluationContext.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.translation; + +import java.util.Collection; +import javax.annotation.Nullable; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Throwables; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.execution.ExplainMode; +import org.apache.spark.util.Utils; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * The {@link EvaluationContext} is the result of a pipeline {@link PipelineTranslator#translate + * translation} and can be used to evaluate / run the pipeline. + * + *

However, in some cases pipeline translation involves the early evaluation of some parts of the + * pipeline. For example, this is necessary to materialize side-inputs. The {@link + * EvaluationContext} won't re-evaluate such datasets. + */ +@SuppressWarnings("Slf4jDoNotLogMessageOfExceptionExplicitly") +@Internal +public final class EvaluationContext { + private static final Logger LOG = LoggerFactory.getLogger(EvaluationContext.class); + + interface NamedDataset { + String name(); + + @Nullable + Dataset> dataset(); + } + + private final Collection> leaves; + private final SparkSession session; + + EvaluationContext(Collection> leaves, SparkSession session) { + this.leaves = leaves; + this.session = session; + } + + /** Trigger evaluation of all leaf datasets. */ + public void evaluate() { + for (NamedDataset ds : leaves) { + final Dataset dataset = ds.dataset(); + if (dataset == null) { + continue; + } + if (LOG.isDebugEnabled()) { + ExplainMode explainMode = ExplainMode.fromString("simple"); + String execPlan = dataset.queryExecution().explainString(explainMode); + LOG.debug("Evaluating dataset {}:\n{}", ds.name(), execPlan); + } + // force evaluation using a dummy foreach action + evaluate(ds.name(), dataset); + } + } + + /** + * The purpose of this utility is to mark the evaluation of Spark actions, both during Pipeline + * translation, when evaluation is required, and when finally evaluating the pipeline. + */ + public static void evaluate(String name, Dataset ds) { + long startMs = System.currentTimeMillis(); + try { + // force computation using noop format + ds.write().mode("overwrite").format("noop").save(); + LOG.info("Evaluated dataset {} in {}", name, durationSince(startMs)); + } catch (RuntimeException e) { + LOG.error("Failed to evaluate dataset {}: {}", name, Throwables.getRootCause(e).getMessage()); + throw new RuntimeException(e); + } + } + + /** + * The purpose of this utility is to mark the evaluation of Spark actions, both during Pipeline + * translation, when evaluation is required, and when finally evaluating the pipeline. + */ + public static T[] collect(String name, Dataset ds) { + long startMs = System.currentTimeMillis(); + try { + T[] res = (T[]) ds.collect(); + LOG.info("Collected dataset {} in {} [size: {}]", name, durationSince(startMs), res.length); + return res; + } catch (Exception e) { + LOG.error("Failed to collect dataset {}: {}", name, Throwables.getRootCause(e).getMessage()); + throw new RuntimeException(e); + } + } + + public SparkSession getSparkSession() { + return session; + } + + private static String durationSince(long startMs) { + return Utils.msDurationToString(System.currentTimeMillis() - startMs); + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java new file mode 100644 index 000000000000..a681dea2fde5 --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java @@ -0,0 +1,516 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.translation; + +import static org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior.DO_NOT_ENTER_TRANSFORM; +import static org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior.ENTER_TRANSFORM; +import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; +import static org.apache.beam.sdk.values.PCollection.IsBounded.UNBOUNDED; + +import java.io.IOException; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Supplier; +import javax.annotation.Nullable; +import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import org.apache.beam.runners.spark.SparkCommonPipelineOptions; +import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.SideInputValues; +import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderProvider; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.Pipeline.PipelineVisitor; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.StreamingOptions; +import org.apache.beam.sdk.runners.AppliedPTransform; +import org.apache.beam.sdk.runners.TransformHierarchy.Node; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.View; +import org.apache.beam.sdk.util.construction.PTransformTranslation; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PInput; +import org.apache.beam.sdk.values.POutput; +import org.apache.beam.sdk.values.PValue; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.spark.broadcast.Broadcast; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.storage.StorageLevel; +import org.checkerframework.checker.nullness.qual.MonotonicNonNull; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import scala.reflect.ClassTag; + +/** + * The pipeline translator translates a Beam {@link Pipeline} into a Spark correspondence, that can + * then be evaluated. + * + *

The translation involves traversing the hierarchy of a pipeline multiple times: + * + *

    + *
  1. Detect if {@link StreamingOptions#setStreaming streaming} mode is required. + *
  2. Identify datasets that are repeatedly used as input and should be cached. + *
  3. And finally, translate each primitive or composite {@link PTransform} that is {@link + * #getTransformTranslator known} and {@link TransformTranslator#canTranslate supported} into + * its Spark correspondence. If a composite is not supported, it will be expanded further into + * its parts and then translated. + *
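Putting the passes listed above together, a rough sketch of how a concrete PipelineTranslator subclass might be driven; the translator, pipeline, session, and options variables are assumed, and only methods visible in this patch are used:

    // Pass 1: decide whether streaming mode is required and update the options.
    PipelineTranslator.detectStreamingMode(pipeline, options.as(StreamingOptions.class));
    // Apply the matching transform overrides for batch or streaming.
    PipelineTranslator.replaceTransforms(pipeline, options.as(StreamingOptions.class));
    // Passes 2 and 3: analyse dependencies, then translate transforms into Spark datasets.
    EvaluationContext ctx = translator.translate(pipeline, session, options);
    // Finally, force evaluation of all leaf datasets.
    ctx.evaluate();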
+ */ +@Internal +public abstract class PipelineTranslator { + private static final Logger LOG = LoggerFactory.getLogger(PipelineTranslator.class); + + // Threshold to limit query plan complexity to avoid unnecessary planning overhead. Currently this + // is fairly low, Catalyst won't be able to optimize beyond ParDos anyways. Until there's + // dedicated support for schema transforms, there's little value of allowing more complex plans at + // this point. + private static final int PLAN_COMPLEXITY_THRESHOLD = 6; + + public static void replaceTransforms(Pipeline pipeline, StreamingOptions options) { + pipeline.replaceAll(SparkTransformOverrides.getDefaultOverrides(options.isStreaming())); + } + + /** + * Analyse the pipeline to determine if we have to switch to streaming mode for the pipeline + * translation and update {@link StreamingOptions} accordingly. + */ + public static void detectStreamingMode(Pipeline pipeline, StreamingOptions options) { + StreamingModeDetector detector = new StreamingModeDetector(options.isStreaming()); + pipeline.traverseTopologically(detector); + options.setStreaming(detector.streaming); + } + + /** Returns a {@link TransformTranslator} for the given {@link PTransform} if known. */ + protected abstract @Nullable < + InT extends PInput, OutT extends POutput, TransformT extends PTransform> + TransformTranslator getTransformTranslator(TransformT transform); + + /** + * Translates a Beam pipeline into its Spark correspondence using the Spark SQL / Dataset API. + * + *

Note, in some cases this involves the early evaluation of some parts of the pipeline. For + * example, in order to use a side-input {@link org.apache.beam.sdk.values.PCollectionView + * PCollectionView} in a translation the corresponding Spark {@link + * org.apache.beam.runners.spark.translation.Dataset Dataset} might have to be collected and + * broadcasted to be able to continue with the translation. + * + * @return The result of the translation is an {@link EvaluationContext} that can trigger the + * evaluation of the Spark pipeline. + */ + public EvaluationContext translate( + Pipeline pipeline, SparkSession session, SparkCommonPipelineOptions options) { + LOG.debug("starting translation of the pipeline using {}", getClass().getName()); + DependencyVisitor dependencies = new DependencyVisitor(); + pipeline.traverseTopologically(dependencies); + + TranslatingVisitor translator = new TranslatingVisitor(session, options, dependencies.results); + pipeline.traverseTopologically(translator); + + return new EvaluationContext(translator.leaves, session); + } + + /** + * The correspondence of a {@link PCollection} as result of translating a {@link PTransform} + * including additional metadata (such as name and dependents). + */ + private static final class TranslationResult + implements EvaluationContext.NamedDataset { + private final String name; + private final float complexityFactor; + private float planComplexity = 0; + + private @MonotonicNonNull Dataset> dataset = null; + private @MonotonicNonNull Broadcast> sideInputBroadcast = null; + private @Nullable UnresolvedTranslation unresolved = null; + + // dependent downstream transforms (if empty this is a leaf) + private final Set> dependentTransforms = new HashSet<>(); + // upstream dependencies (required inputs) + private final List> dependencies; + + private TranslationResult( + PCollection pCol, float complexityFactor, List> dependencies) { + this.name = pCol.getName(); + this.complexityFactor = complexityFactor; + this.dependencies = dependencies; + } + + @Override + public String name() { + return name; + } + + @Override + public @Nullable Dataset> dataset() { + return dataset; + } + + private boolean isLeaf() { + return dependentTransforms.isEmpty(); + } + + private int usages() { + return dependentTransforms.size(); + } + + private void resetPlanComplexity() { + planComplexity = 1; + } + + /** Estimate complexity of query plan by multiplying complexities of all dependencies. */ + private float estimatePlanComplexity() { + if (planComplexity > 0) { + return planComplexity; + } + float complexity = 1 + complexityFactor; + for (TranslationResult result : dependencies) { + complexity *= result.estimatePlanComplexity(); + } + return (planComplexity = complexity); + } + } + + /** + * Unresolved translation, allowing to optimize the generated Spark DAG. + * + *

An unresolved translation can - in certain cases - be fused together with following + * transforms. Currently this is only the case for ParDos with linear linage. + */ + public interface UnresolvedTranslation { + PCollection getInput(); + + UnresolvedTranslation fuse(UnresolvedTranslation next); + + Dataset> resolve( + Supplier options, Dataset> input); + } + + /** Shared, mutable state during the translation of a pipeline and omitted afterwards. */ + public interface TranslationState extends EncoderProvider { + Dataset> getDataset(PCollection pCollection); + + boolean isLeaf(PCollection pCollection); + + void putUnresolved( + PCollection out, UnresolvedTranslation unresolved); + + void putDataset( + PCollection pCollection, Dataset> dataset, boolean cache); + + default void putDataset(PCollection pCollection, Dataset> dataset) { + putDataset(pCollection, dataset, true); + } + + Broadcast> getSideInputBroadcast( + PCollection pCollection, SideInputValues.Loader loader); + + Supplier getOptionsSupplier(); + + PipelineOptions getOptions(); + + SparkSession getSparkSession(); + } + + /** + * {@link PTransformVisitor} that translates supported {@link PTransform PTransforms} into their + * Spark correspondence. + * + *

Note, in some cases this involves the early evaluation of some parts of the pipeline. For + * example, in order to use a side-input {@link org.apache.beam.sdk.values.PCollectionView + * PCollectionView} in a translation the corresponding Spark {@link + * org.apache.beam.runners.spark.translation.Dataset Dataset} might have to be collected and + * broadcasted. + */ + private class TranslatingVisitor extends PTransformVisitor implements TranslationState { + private final Map, TranslationResult> translationResults; + private final Map, Encoder> encoders; + private final SparkSession sparkSession; + private final PipelineOptions options; + private final Supplier optionsSupplier; + private final StorageLevel storageLevel; + + private final Set> leaves; + + public TranslatingVisitor( + SparkSession sparkSession, + SparkCommonPipelineOptions options, + Map, TranslationResult> translationResults) { + this.sparkSession = sparkSession; + this.translationResults = translationResults; + this.options = options; + this.optionsSupplier = new BroadcastOptions(sparkSession, options); + this.storageLevel = StorageLevel.fromString(options.getStorageLevel()); + this.encoders = new HashMap<>(); + this.leaves = new HashSet<>(); + } + + @Override + void visit( + Node node, + PTransform transform, + TransformTranslator> translator) { + + AppliedPTransform> appliedTransform = + (AppliedPTransform) node.toAppliedPTransform(getPipeline()); + try { + LOG.info( + "Translating {}: {}", + node.isCompositeNode() ? "composite" : "primitive", + node.getFullName()); + translator.translate(transform, appliedTransform, this); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public Encoder encoderOf(Coder coder, Factory factory) { + // computeIfAbsent fails with Java 11 on recursive factory + Encoder enc = (Encoder) encoders.get(coder); + if (enc == null) { + enc = factory.apply(coder); + encoders.put(coder, enc); + } + return enc; + } + + private TranslationResult getResult(PCollection pCollection) { + return (TranslationResult) checkStateNotNull(translationResults.get(pCollection)); + } + + @Override + public Dataset> getDataset(PCollection pCollection) { + return getOrResolve(getResult(pCollection)); + } + + @Override + public void putDataset( + PCollection pCollection, Dataset> dataset, boolean cache) { + TranslationResult result = getResult(pCollection); + result.dataset = dataset; + + if (cache && result.usages() > 1) { + LOG.info("Dataset {} will be cached for reuse.", result.name); + dataset.persist(storageLevel); // use NONE to disable + } + + if (result.estimatePlanComplexity() > PLAN_COMPLEXITY_THRESHOLD) { + // Break linage of dataset to limit planning overhead for complex query plans. 
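+ // Recreating the dataset from its RDD below drops the accumulated logical plan while keeping the computed data.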
+ LOG.info("Breaking linage of dataset {} to limit complexity of query plan.", result.name); + result.dataset = sparkSession.createDataset(dataset.rdd(), dataset.encoder()); + result.resetPlanComplexity(); + } + + if (result.isLeaf()) { + leaves.add(result); + } + } + + private Dataset> getOrResolve(TranslationResult result) { + UnresolvedTranslation unresolved = result.unresolved; + if (unresolved != null) { + result.dataset = unresolved.resolve(optionsSupplier, getDataset(unresolved.getInput())); + result.unresolved = null; + } + return checkStateNotNull(result.dataset); + } + + @Override + public void putUnresolved( + PCollection out, UnresolvedTranslation unresolved) { + // For simplicity, pretend InT is the same + TranslationResult translIn = getResult(unresolved.getInput()); + TranslationResult translOut = getResult(out); + // Fuse with previous unresolved translation if necessary + UnresolvedTranslation unresolvedIn = translIn.unresolved; + translOut.unresolved = unresolvedIn != null ? unresolvedIn.fuse(unresolved) : unresolved; + translIn.unresolved = null; + // Resolve dataset immediately in case of leaf or when there are multiple downstreams + if (translOut.usages() != 1) { + putDataset(out, getOrResolve(translOut)); + } + } + + @Override + public boolean isLeaf(PCollection pCollection) { + return getResult(pCollection).isLeaf(); + } + + @Override + public Broadcast> getSideInputBroadcast( + PCollection pCollection, SideInputValues.Loader loader) { + TranslationResult result = getResult(pCollection); + if (result.sideInputBroadcast == null) { + SideInputValues sideInputValues = loader.apply(getOrResolve(result)); + result.sideInputBroadcast = broadcast(sparkSession, sideInputValues); + } + return result.sideInputBroadcast; + } + + @Override + public Supplier getOptionsSupplier() { + return optionsSupplier; + } + + @Override + public PipelineOptions getOptions() { + return options; + } + + @Override + public SparkSession getSparkSession() { + return sparkSession; + } + } + + /** + * Supplier wrapping broadcasted {@link PipelineOptions} to avoid repeatedly serializing those as + * part of the task closures. + */ + private static class BroadcastOptions implements Supplier, Serializable { + private final Broadcast broadcast; + + private BroadcastOptions(SparkSession session, PipelineOptions options) { + this.broadcast = broadcast(session, new SerializablePipelineOptions(options)); + } + + @Override + public PipelineOptions get() { + return broadcast.value().get(); + } + } + + private static Broadcast broadcast(SparkSession session, T t) { + return session.sparkContext().broadcast(t, (ClassTag) ClassTag.AnyRef()); + } + + /** + * {@link PTransformVisitor} that analyses dependencies of supported {@link PTransform + * PTransforms} to help identify cache candidates. + * + *
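+ * A {@link PCollection} with more than one dependent transform is such a candidate; its dataset may be persisted for reuse when translated (see {@code putDataset} above).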

The visitor may throw if a {@link PTransform} is observed that uses unsupported features. + */ + private class DependencyVisitor extends PTransformVisitor { + private final Map, TranslationResult> results = new HashMap<>(); + + @Override + void visit( + Node node, + PTransform transform, + TransformTranslator> translator) { + // Track `transform` as downstream dependency of every input and reversely + // every input is a dependency of each output of `transform`. + List> dependencies = new ArrayList<>(node.getInputs().size()); + for (Map.Entry, PCollection> entry : node.getInputs().entrySet()) { + TranslationResult input = checkStateNotNull(results.get(entry.getValue())); + dependencies.add(input); + input.dependentTransforms.add(transform); + } + // add new translation result for every output of `transform` + for (PCollection pOut : node.getOutputs().values()) { + results.put(pOut, new TranslationResult<>(pOut, translator.complexityFactor, dependencies)); + } + } + } + + /** + * An abstract {@link PipelineVisitor} that visits all translatable {@link PTransform} pipeline + * nodes of a pipeline with the respective {@link TransformTranslator}. + * + *

The visitor may throw if a {@link PTransform} is observed that uses unsupported features. + */ + private abstract class PTransformVisitor extends PipelineVisitor.Defaults { + + /** Visit the {@link PTransform} with its respective {@link TransformTranslator}. */ + abstract void visit( + Node node, + PTransform transform, + TransformTranslator> translator); + + @Override + public final CompositeBehavior enterCompositeTransform(Node node) { + PTransform transform = (PTransform) node.getTransform(); + TransformTranslator> translator = + getSupportedTranslator(transform); + if (transform != null && translator != null) { + visit(node, transform, translator); + return DO_NOT_ENTER_TRANSFORM; + } else { + return ENTER_TRANSFORM; + } + } + + @Override + public final void visitPrimitiveTransform(Node node) { + PTransform transform = (PTransform) node.getTransform(); + if (transform == null || transform.getClass().equals(View.CreatePCollectionView.class)) { + return; // ignore, nothing to be translated here, views are handled on the consumer side + } + TransformTranslator> translator = + getSupportedTranslator(transform); + if (translator == null) { + String urn = PTransformTranslation.urnForTransform(transform); + throw new UnsupportedOperationException("Transform " + urn + " is not supported."); + } + visit(node, transform, translator); + } + + /** {@link TransformTranslator} for {@link PTransform} if translation is known and supported. */ + private @Nullable TransformTranslator> + getSupportedTranslator(@Nullable PTransform transform) { + if (transform == null) { + return null; + } + TransformTranslator> translator = + getTransformTranslator(transform); + return translator != null && translator.canTranslate(transform) ? translator : null; + } + } + + /** + * Traverse the pipeline to check for unbounded {@link PCollection PCollections} that would + * require streaming mode unless streaming mode is already enabled. + */ + private static class StreamingModeDetector extends PipelineVisitor.Defaults { + private boolean streaming; + + StreamingModeDetector(boolean streaming) { + this.streaming = streaming; + } + + @Override + public CompositeBehavior enterCompositeTransform(Node node) { + return streaming ? DO_NOT_ENTER_TRANSFORM : ENTER_TRANSFORM; // stop if in streaming mode + } + + @Override + public void visitValue(PValue value, Node producer) { + if (value instanceof PCollection && ((PCollection) value).isBounded() == UNBOUNDED) { + LOG.info("Found unbounded PCollection {}, switching to streaming mode.", value.getName()); + streaming = true; + } + } + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/SparkSessionFactory.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/SparkSessionFactory.java new file mode 100644 index 000000000000..5b9e5b6fae86 --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/SparkSessionFactory.java @@ -0,0 +1,295 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.translation; + +import static org.apache.commons.lang3.ArrayUtils.EMPTY_STRING_ARRAY; +import static org.apache.commons.lang3.StringUtils.substringBetween; +import static org.apache.commons.lang3.math.NumberUtils.toInt; + +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.serializers.JavaSerializer; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import javax.annotation.Nullable; +import org.apache.beam.repackaged.core.org.apache.commons.lang3.ArrayUtils; +import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingPipelineOptions; +import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.SideInputValues; +import org.apache.beam.sdk.coders.BigDecimalCoder; +import org.apache.beam.sdk.coders.BigEndianIntegerCoder; +import org.apache.beam.sdk.coders.BigEndianLongCoder; +import org.apache.beam.sdk.coders.BigEndianShortCoder; +import org.apache.beam.sdk.coders.BigIntegerCoder; +import org.apache.beam.sdk.coders.BitSetCoder; +import org.apache.beam.sdk.coders.BooleanCoder; +import org.apache.beam.sdk.coders.ByteArrayCoder; +import org.apache.beam.sdk.coders.ByteCoder; +import org.apache.beam.sdk.coders.CollectionCoder; +import org.apache.beam.sdk.coders.DelegateCoder; +import org.apache.beam.sdk.coders.DequeCoder; +import org.apache.beam.sdk.coders.DoubleCoder; +import org.apache.beam.sdk.coders.DurationCoder; +import org.apache.beam.sdk.coders.FloatCoder; +import org.apache.beam.sdk.coders.InstantCoder; +import org.apache.beam.sdk.coders.IterableCoder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.LengthPrefixCoder; +import org.apache.beam.sdk.coders.ListCoder; +import org.apache.beam.sdk.coders.MapCoder; +import org.apache.beam.sdk.coders.NullableCoder; +import org.apache.beam.sdk.coders.RowCoder; +import org.apache.beam.sdk.coders.SerializableCoder; +import org.apache.beam.sdk.coders.SetCoder; +import org.apache.beam.sdk.coders.ShardedKeyCoder; +import org.apache.beam.sdk.coders.SnappyCoder; +import org.apache.beam.sdk.coders.SortedMapCoder; +import org.apache.beam.sdk.coders.StringDelegateCoder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.coders.TextualIntegerCoder; +import org.apache.beam.sdk.coders.TimestampPrefixingWindowCoder; +import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.sdk.coders.VarLongCoder; +import org.apache.beam.sdk.coders.VoidCoder; +import org.apache.beam.sdk.io.FileBasedSink; +import org.apache.beam.sdk.io.range.OffsetRange; +import org.apache.beam.sdk.transforms.join.CoGbkResult; +import org.apache.beam.sdk.transforms.join.CoGbkResultSchema; +import org.apache.beam.sdk.transforms.join.UnionCoder; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.transforms.windowing.IntervalWindow; +import 
org.apache.beam.sdk.util.construction.resources.PipelineResources; +import org.apache.beam.sdk.values.PCollectionViews; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TupleTagList; +import org.apache.beam.sdk.values.WindowedValues; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Collections2; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.apache.spark.SparkConf; +import org.apache.spark.serializer.KryoRegistrator; +import org.apache.spark.serializer.KryoSerializer; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.execution.datasources.v2.DataWritingSparkTaskResult; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class SparkSessionFactory { + + private static final Logger LOG = LoggerFactory.getLogger(SparkSessionFactory.class); + + // Patterns to exclude local JRE and certain artifact (groups) in Maven and Gradle cache. + private static final Collection SPARK_JAR_EXCLUDES = + Lists.newArrayList( + "jre/lib/ext/", + "/org/slf4j/", + "/org.slf4j/", + "/log4j/", + "/io/dropwizard/metrics/", + "/io.dropwizard.metrics/", + "/org/apache/spark/", + "/org.apache.spark/", + "/org/apache/hadoop/", + "/org.apache.hadoop/", + "/org/scala-lang/", + "/org.scala-lang/", + "/com.esotericsoftware/kryo-shaded", + "/com/esotericsoftware/kryo-shaded"); + + /** + * Gets active {@link SparkSession} or creates one using {@link + * SparkStructuredStreamingPipelineOptions}. + */ + public static SparkSession getOrCreateSession(SparkStructuredStreamingPipelineOptions options) { + if (options.getUseActiveSparkSession()) { + return SparkSession.active(); + } + return sessionBuilder(options.getSparkMaster(), options).getOrCreate(); + } + + /** Creates Spark session builder with some optimizations for local mode, e.g. in tests. */ + public static SparkSession.Builder sessionBuilder(String master) { + return sessionBuilder(master, null); + } + + private static SparkSession.Builder sessionBuilder( + String master, @Nullable SparkStructuredStreamingPipelineOptions options) { + + SparkConf sparkConf = new SparkConf().setIfMissing("spark.master", master); + master = sparkConf.get("spark.master"); // use effective master in the remainder of this method + + if (options != null) { + if (options.getAppName() != null) { + sparkConf.setAppName(options.getAppName()); + } + + if (options.getFilesToStage() != null && !options.getFilesToStage().isEmpty()) { + // Append the files to stage provided by the user to `spark.jars`. + PipelineResources.prepareFilesForStaging(options); + String[] filesToStage = filterFilesToStage(options, Collections.emptyList()); + String[] jars = getSparkJars(sparkConf); + sparkConf.setJars(jars.length > 0 ? ArrayUtils.addAll(jars, filesToStage) : filesToStage); + } else if (!sparkConf.contains("spark.jars") && !master.startsWith("local[")) { + // Stage classpath if `spark.jars` not set and not in local mode. + PipelineResources.prepareFilesForStaging(options); + // Set `spark.jars`, exclude JRE libs and jars causing conflicts using `userClassPathFirst`. + sparkConf.setJars(filterFilesToStage(options, SPARK_JAR_EXCLUDES)); + // Enable `userClassPathFirst` to prevent issues with guava, jackson and others. 
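+ // With `userClassPathFirst` the staged jars take precedence over the versions bundled with Spark on the executors.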
+ sparkConf.setIfMissing("spark.executor.userClassPathFirst", "true"); + } + } + + // Set to 'org.apache.spark.serializer.JavaSerializer' via system property to disable Kryo + String serializer = sparkConf.get("spark.serializer", KryoSerializer.class.getName()); + if (serializer.equals(KryoSerializer.class.getName())) { + // Set to 'false' via system property to disable usage of Kryo unsafe + boolean unsafe = sparkConf.getBoolean("spark.kryo.unsafe", true); + sparkConf.set("spark.serializer", serializer); + sparkConf.set("spark.kryo.unsafe", Boolean.toString(unsafe)); + sparkConf.set("spark.kryo.registrator", SparkKryoRegistrator.class.getName()); + LOG.info("Configured `spark.serializer` to use KryoSerializer [unsafe={}]", unsafe); + } + + // By default, Spark defines 200 as a number of sql partitions. This seems too much for local + // mode, so try to align with value of "sparkMaster" option in this case. + // We should not overwrite this value (or any user-defined spark configuration value) if the + // user has already configured it. + int partitions = localNumPartitions(master); + if (partitions > 0) { + sparkConf.setIfMissing("spark.sql.shuffle.partitions", Integer.toString(partitions)); + } + + return SparkSession.builder().config(sparkConf); + } + + @SuppressWarnings({"return", "toarray.nullable.elements", "methodref.receiver"}) // safe to ignore + private static String[] filterFilesToStage( + SparkStructuredStreamingPipelineOptions opts, Collection excludes) { + Collection files = opts.getFilesToStage(); + if (files == null || files.isEmpty()) { + return EMPTY_STRING_ARRAY; + } + if (!excludes.isEmpty()) { + files = Collections2.filter(files, f -> !excludes.stream().anyMatch(f::contains)); + } + return files.toArray(EMPTY_STRING_ARRAY); + } + + private static String[] getSparkJars(SparkConf conf) { + return conf.contains("spark.jars") ? conf.get("spark.jars").split(",") : EMPTY_STRING_ARRAY; + } + + private static int localNumPartitions(String master) { + return master.startsWith("local[") ? toInt(substringBetween(master, "local[", "]")) : 0; + } + + /** + * {@link KryoRegistrator} for Spark to serialize broadcast variables used for side-inputs. + * + *
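+ * Registering classes upfront lets Kryo write compact numeric class ids instead of fully qualified class names.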

Note, this registrator must be public to be accessible for Kryo. + * + * @see SideInputValues + */ + public static class SparkKryoRegistrator implements KryoRegistrator { + @Override + public void registerClasses(Kryo kryo) { + kryo.register(InternalRow.class); + kryo.register(InternalRow[].class); + kryo.register(byte[][].class); + kryo.register(HashMap.class); + kryo.register(ArrayList.class); + + // support writing noop format + kryo.register(DataWritingSparkTaskResult.class); + + // TODO find more efficient ways + kryo.register(SerializablePipelineOptions.class, new JavaSerializer()); + + // side input values (spark runner specific) + kryo.register(SideInputValues.ByWindow.class); + kryo.register(SideInputValues.Global.class); + + // avro coders + tryToRegister(kryo, "org.apache.beam.sdk.extensions.avro.coders.AvroCoder"); + tryToRegister(kryo, "org.apache.beam.sdk.extensions.avro.coders.AvroGenericCoder"); + + // standard coders of org.apache.beam.sdk.coders + kryo.register(BigDecimalCoder.class); + kryo.register(BigEndianIntegerCoder.class); + kryo.register(BigEndianLongCoder.class); + kryo.register(BigEndianShortCoder.class); + kryo.register(BigIntegerCoder.class); + kryo.register(BitSetCoder.class); + kryo.register(BooleanCoder.class); + kryo.register(ByteArrayCoder.class); + kryo.register(ByteCoder.class); + kryo.register(CollectionCoder.class); + kryo.register(DelegateCoder.class); + kryo.register(DequeCoder.class); + kryo.register(DoubleCoder.class); + kryo.register(DurationCoder.class); + kryo.register(FloatCoder.class); + kryo.register(InstantCoder.class); + kryo.register(IterableCoder.class); + kryo.register(KvCoder.class); + kryo.register(LengthPrefixCoder.class); + kryo.register(ListCoder.class); + kryo.register(MapCoder.class); + kryo.register(NullableCoder.class); + kryo.register(RowCoder.class); + kryo.register(SerializableCoder.class); + kryo.register(SetCoder.class); + kryo.register(ShardedKeyCoder.class); + kryo.register(SnappyCoder.class); + kryo.register(SortedMapCoder.class); + kryo.register(StringDelegateCoder.class); + kryo.register(StringUtf8Coder.class); + kryo.register(TextualIntegerCoder.class); + kryo.register(TimestampPrefixingWindowCoder.class); + kryo.register(VarIntCoder.class); + kryo.register(VarLongCoder.class); + kryo.register(VoidCoder.class); + + // bounded windows and windowed value coders + kryo.register(GlobalWindow.Coder.class); + kryo.register(IntervalWindow.IntervalWindowCoder.class); + kryo.register(WindowedValues.FullWindowedValueCoder.class); + kryo.register(WindowedValues.ParamWindowedValueCoder.class); + kryo.register(WindowedValues.ValueOnlyWindowedValueCoder.class); + + // various others + kryo.register(OffsetRange.Coder.class); + kryo.register(UnionCoder.class); + kryo.register(PCollectionViews.ValueOrMetadataCoder.class); + kryo.register(FileBasedSink.FileResultCoder.class); + kryo.register(CoGbkResult.CoGbkResultCoder.class); + kryo.register(CoGbkResultSchema.class); + kryo.register(TupleTag.class); + kryo.register(TupleTagList.class); + } + + private void tryToRegister(Kryo kryo, String className) { + try { + kryo.register(Class.forName(className)); + } catch (ClassNotFoundException e) { + LOG.info("Class {}} was not found on classpath", className); + } + } + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/SparkTransformOverrides.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/SparkTransformOverrides.java new file mode 
100644 index 000000000000..a60aa59efd61 --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/SparkTransformOverrides.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.translation; + +import java.util.List; +import org.apache.beam.sdk.runners.PTransformOverride; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.util.construction.PTransformMatchers; +import org.apache.beam.sdk.util.construction.PTransformTranslation; +import org.apache.beam.sdk.util.construction.SplittableParDo; +import org.apache.beam.sdk.util.construction.SplittableParDoNaiveBounded; +import org.apache.beam.sdk.util.construction.UnsupportedOverrideFactory; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; + +/** {@link PTransform} overrides for Spark runner. */ +@SuppressWarnings({ + "rawtypes" // TODO(https://github.com/apache/beam/issues/20447) +}) +class SparkTransformOverrides { + public static List getDefaultOverrides(boolean streaming) { + ImmutableList.Builder builder = ImmutableList.builder(); + // TODO: [https://github.com/apache/beam/issues/19107] Support @RequiresStableInput on Spark + // runner + builder.add( + PTransformOverride.of( + PTransformMatchers.requiresStableInputParDoMulti(), + UnsupportedOverrideFactory.withMessage( + "Spark runner currently doesn't support @RequiresStableInput annotation."))); + if (!streaming) { + builder + .add( + PTransformOverride.of( + PTransformMatchers.splittableParDo(), new SplittableParDo.OverrideFactory())) + .add( + PTransformOverride.of( + PTransformMatchers.urnEqualTo(PTransformTranslation.SPLITTABLE_PROCESS_KEYED_URN), + new SplittableParDoNaiveBounded.OverrideFactory())); + } + return builder.build(); + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TransformTranslator.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TransformTranslator.java new file mode 100644 index 000000000000..24783040704e --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TransformTranslator.java @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.translation; + +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.windowedValueEncoder; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables.getOnlyElement; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.function.Supplier; +import org.apache.beam.runners.spark.structuredstreaming.translation.PipelineTranslator.TranslationState; +import org.apache.beam.runners.spark.structuredstreaming.translation.PipelineTranslator.UnresolvedTranslation; +import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.SideInputValues; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.runners.AppliedPTransform; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.util.construction.TransformInputs; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PInput; +import org.apache.beam.sdk.values.POutput; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.spark.broadcast.Broadcast; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.SparkSession; +import org.checkerframework.checker.nullness.qual.MonotonicNonNull; +import scala.Tuple2; +import scala.reflect.ClassTag; + +/** + * A {@link TransformTranslator} provides the capability to translate a specific primitive or + * composite {@link PTransform} into its Spark correspondence. + * + *
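+ * Concrete translators are looked up by the {@link PipelineTranslator} implementation via {@code getTransformTranslator} and invoked during pipeline traversal.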

WARNING: {@link TransformTranslator TransformTranslators} should never be serializable! This + * could easily hide situations where unnecessary references leak into Spark closures. + */ +@Internal +public abstract class TransformTranslator< + InT extends PInput, OutT extends POutput, TransformT extends PTransform> { + + // Factor to help estimate the complexity of the Spark execution plan. This is used to limit + // complexity by breaking lineage where necessary to avoid overly large plans. Such plans can become + // very expensive during planning in the Catalyst optimizer. + protected final float complexityFactor; + + protected TransformTranslator(float complexityFactor) { + this.complexityFactor = complexityFactor; + } + + protected abstract void translate(TransformT transform, Context cxt) throws IOException; + + final void translate( + TransformT transform, + AppliedPTransform appliedTransform, + TranslationState translationState) + throws IOException { + translate(transform, new Context(appliedTransform, translationState)); + } + + /** + * Checks if a composite / primitive transform can be translated. Composites that cannot be + * translated as is will be exploded further for translation of their parts. + * + *

This returns {@code true} by default and should be overridden where necessary. + * + * @throws RuntimeException If a transform uses unsupported features, an exception shall be thrown + * to give early feedback before any part of the pipeline is run. + */ + protected boolean canTranslate(TransformT transform) { + return true; + } + + /** + * Available mutable context to translate a {@link PTransform}. The context is backed by the + * shared {@link TranslationState} of the {@link PipelineTranslator}. + */ + protected class Context implements TranslationState { + private final AppliedPTransform transform; + private final TranslationState state; + + private @MonotonicNonNull InT pIn = null; + private @MonotonicNonNull OutT pOut = null; + + private Context(AppliedPTransform transform, TranslationState state) { + this.transform = transform; + this.state = state; + } + + public InT getInput() { + if (pIn == null) { + pIn = (InT) getOnlyElement(TransformInputs.nonAdditionalInputs(transform)); + } + return pIn; + } + + public Map, PCollection> getInputs() { + return transform.getInputs(); + } + + public Map, PCollection> getOutputs() { + return transform.getOutputs(); + } + + public OutT getOutput() { + if (pOut == null) { + pOut = (OutT) getOnlyElement(transform.getOutputs().values()); + } + return pOut; + } + + public PCollection getOutput(TupleTag tag) { + PCollection pc = (PCollection) transform.getOutputs().get(tag); + if (pc == null) { + throw new IllegalStateException("No output for tag " + tag); + } + return pc; + } + + public AppliedPTransform getCurrentTransform() { + return transform; + } + + @Override + public Dataset> getDataset(PCollection pCollection) { + return state.getDataset(pCollection); + } + + @Override + public Broadcast> getSideInputBroadcast( + PCollection pCollection, SideInputValues.Loader loader) { + return state.getSideInputBroadcast(pCollection, loader); + } + + @Override + public void putDataset( + PCollection pCollection, Dataset> dataset, boolean cache) { + state.putDataset(pCollection, dataset, cache); + } + + @Override + public void putUnresolved( + PCollection out, UnresolvedTranslation unresolved) { + state.putUnresolved(out, unresolved); + } + + @Override + public boolean isLeaf(PCollection pCollection) { + return state.isLeaf(pCollection); + } + + @Override + public Supplier getOptionsSupplier() { + return state.getOptionsSupplier(); + } + + @Override + public PipelineOptions getOptions() { + return state.getOptions(); + } + + public Dataset> createDataset( + List> data, Encoder> enc) { + return data.isEmpty() + ? 
getSparkSession().emptyDataset(enc) + : getSparkSession().createDataset(data, enc); + } + + public Broadcast broadcast(T value) { + return getSparkSession().sparkContext().broadcast(value, (ClassTag) ClassTag.AnyRef()); + } + + @Override + public SparkSession getSparkSession() { + return state.getSparkSession(); + } + + @Override + public Encoder encoderOf(Coder coder, Factory factory) { + return state.encoderOf(coder, factory); + } + + public Encoder> tupleEncoder(Encoder e1, Encoder e2) { + return Encoders.tuple(e1, e2); + } + + public Encoder> windowedEncoder(Coder coder) { + return windowedValueEncoder(encoderOf(coder), windowEncoder()); + } + + public Encoder> windowedEncoder(Encoder enc) { + return windowedValueEncoder(enc, windowEncoder()); + } + + public Encoder> windowedEncoder( + Coder coder, Coder windowCoder) { + return windowedValueEncoder(encoderOf(coder), encoderOf(windowCoder)); + } + + public Encoder windowEncoder() { + checkState(!getInputs().isEmpty(), "Transform has no inputs, cannot get windowCoder!"); + return encoderOf(windowCoder((PCollection) getInput())); + } + } + + protected Coder windowCoder(PCollection pc) { + return (Coder) pc.getWindowingStrategy().getWindowFn().windowCoder(); + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/Aggregators.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/Aggregators.java new file mode 100644 index 000000000000..183445642a0b --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/Aggregators.java @@ -0,0 +1,600 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch; + +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.collectionEncoder; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.encoderOf; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.mapEncoder; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.mutablePairEncoder; +import static org.apache.beam.sdk.transforms.windowing.PaneInfo.NO_FIRING; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterators.peekingIterator; + +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; +import java.util.TreeMap; +import java.util.function.BiFunction; +import java.util.function.BinaryOperator; +import java.util.function.Function; +import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.Fun1; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.transforms.Combine.CombineFn; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.IntervalWindow; +import org.apache.beam.sdk.transforms.windowing.Sessions; +import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; +import org.apache.beam.sdk.transforms.windowing.WindowFn; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.sdk.values.WindowedValues; +import org.apache.beam.sdk.values.WindowingStrategy; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Collections2; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.PeekingIterator; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.expressions.Aggregator; +import org.apache.spark.util.MutablePair; +import org.checkerframework.checker.nullness.qual.MonotonicNonNull; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.checkerframework.checker.nullness.qual.PolyNull; +import org.joda.time.Instant; + +@Internal +class Aggregators { + + /** + * Creates simple value {@link Aggregator} that is not window aware. + * + * @param {@link CombineFn} input type + * @param {@link CombineFn} accumulator type + * @param {@link CombineFn} / {@link Aggregator} result type + * @param {@link Aggregator} input type + */ + static Aggregator value( + CombineFn fn, + Fun1 valueFn, + Encoder accEnc, + Encoder outEnc) { + return new ValueAggregator<>(fn, valueFn, accEnc, outEnc); + } + + /** + * Creates windowed Spark {@link Aggregator} depending on the provided Beam {@link WindowFn}s. + * + *

Specialised implementations are provided for: + *

  • {@link Sessions} + *
  • Non merging window functions + *
  • Merging window functions + * + * @param {@link CombineFn} input type + * @param {@link CombineFn} accumulator type + * @param {@link CombineFn} / {@link Aggregator} result type + * @param {@link Aggregator} input type + */ + static + Aggregator, ?, Collection>> windowedValue( + CombineFn fn, + Fun1, ValT> valueFn, + WindowingStrategy windowing, + Encoder windowEnc, + Encoder accEnc, + Encoder> outEnc) { + if (!windowing.needsMerge()) { + return new NonMergingWindowedAggregator<>(fn, valueFn, windowing, windowEnc, accEnc, outEnc); + } else if (windowing.getWindowFn().getClass().equals(Sessions.class)) { + return new SessionsAggregator<>(fn, valueFn, windowing, (Encoder) windowEnc, accEnc, outEnc); + } + return new MergingWindowedAggregator<>(fn, valueFn, windowing, windowEnc, accEnc, outEnc); + } + + /** + * Simple value {@link Aggregator} that is not window aware. + * + * @param {@link CombineFn} input type + * @param {@link CombineFn} accumulator type + * @param {@link CombineFn} / {@link Aggregator} result type + * @param {@link Aggregator} input type + */ + private static class ValueAggregator + extends CombineFnAggregator { + + public ValueAggregator( + CombineFn fn, + Fun1 valueFn, + Encoder accEnc, + Encoder outEnc) { + super(fn, valueFn, accEnc, outEnc); + } + + @Override + public AccT zero() { + return emptyAcc(); + } + + @Override + public AccT reduce(AccT buff, InT in) { + return addToAcc(buff, value(in)); + } + + @Override + public AccT merge(AccT b1, AccT b2) { + return mergeAccs(b1, b2); + } + + @Override + public ResT finish(AccT buff) { + return extract(buff); + } + } + + /** + * Specialized windowed Spark {@link Aggregator} for Beam {@link WindowFn}s of type {@link + * Sessions}. The aggregator uses a {@link TreeMap} as buffer to maintain ordering of the {@link + * IntervalWindow}s and merge these more efficiently. + * + *

    For efficiency, this aggregator re-implements {@link + * Sessions#mergeWindows(WindowFn.MergeContext)} to leverage the already sorted buffer. + * + * @param {@link CombineFn} input type + * @param {@link CombineFn} accumulator type + * @param {@link CombineFn} / {@link Aggregator} result type + * @param {@link Aggregator} input type + */ + private static class SessionsAggregator + extends WindowedAggregator< + ValT, + AccT, + ResT, + InT, + IntervalWindow, + TreeMap>> { + + SessionsAggregator( + CombineFn combineFn, + Fun1, ValT> valueFn, + WindowingStrategy windowing, + Encoder windowEnc, + Encoder accEnc, + Encoder> outEnc) { + super(combineFn, valueFn, windowing, windowEnc, accEnc, outEnc, (Class) TreeMap.class); + checkArgument(windowing.getWindowFn().getClass().equals(Sessions.class)); + } + + @Override + public final TreeMap> zero() { + return new TreeMap<>(); + } + + @Override + @SuppressWarnings("keyfor") + public TreeMap> reduce( + TreeMap> buff, WindowedValue input) { + for (IntervalWindow window : (Collection) input.getWindows()) { + @MonotonicNonNull MutablePair acc = null; + @MonotonicNonNull IntervalWindow first = null, last = null; + // start with window before or equal to new window (if exists) + @Nullable Entry> lower = buff.floorEntry(window); + if (lower != null && window.intersects(lower.getKey())) { + // if intersecting, init accumulator and extend window to span both + acc = lower.getValue(); + window = window.span(lower.getKey()); + first = last = lower.getKey(); + } + // merge following windows in order if they intersect, then stop + for (Entry> entry : + buff.tailMap(window, false).entrySet()) { + MutablePair entryAcc = entry.getValue(); + IntervalWindow entryWindow = entry.getKey(); + if (window.intersects(entryWindow)) { + // extend window and merge accumulators + window = window.span(entryWindow); + acc = acc == null ? entryAcc : mergeAccs(window, acc, entryAcc); + if (first == null) { + // there was no previous (lower) window intersecting the input window + first = last = entryWindow; + } else { + last = entryWindow; + } + } else { + break; // stop, later windows won't intersect either + } + } + if (first != null && last != null) { + // remove entire subset from from first to last after it got merged into acc + buff.navigableKeySet().subSet(first, true, last, true).clear(); + } + // add input and get accumulator for new (potentially merged) window + buff.put(window, addToAcc(window, acc, value(input), input.getTimestamp())); + } + return buff; + } + + @Override + public TreeMap> merge( + TreeMap> b1, + TreeMap> b2) { + if (b1.isEmpty()) { + return b2; + } else if (b2.isEmpty()) { + return b1; + } + // Init new tree map to merge both buffers + TreeMap> res = zero(); + PeekingIterator>> it1 = + peekingIterator(b1.entrySet().iterator()); + PeekingIterator>> it2 = + peekingIterator(b2.entrySet().iterator()); + + @Nullable MutablePair acc = null; + @Nullable IntervalWindow window = null; + while (it1.hasNext() || it2.hasNext()) { + // pick iterator with the smallest window ahead and forward it + Entry> nextMin = + (it1.hasNext() && it2.hasNext()) + ? it1.peek().getKey().compareTo(it2.peek().getKey()) <= 0 ? it1.next() : it2.next() + : it1.hasNext() ? 
it1.next() : it2.next(); + if (window != null && window.intersects(nextMin.getKey())) { + // extend window and merge accumulators if intersecting + window = window.span(nextMin.getKey()); + acc = mergeAccs(window, acc, nextMin.getValue()); + } else { + // store window / accumulator if necessary and continue with next minimum + if (window != null && acc != null) { + res.put(window, acc); + } + acc = nextMin.getValue(); + window = nextMin.getKey(); + } + } + if (window != null && acc != null) { + res.put(window, acc); + } + return res; + } + } + + /** + * Merging windowed Spark {@link Aggregator} using a Map of {@link BoundedWindow}s as aggregation + * buffer. When reducing new input, a windowed accumulator is created for each new window of the + * input that doesn't overlap with existing windows. Otherwise, if the window is known or + * overlaps, the window is extended accordingly and accumulators are merged. + * + * @param {@link CombineFn} input type + * @param {@link CombineFn} accumulator type + * @param {@link CombineFn} / {@link Aggregator} result type + * @param {@link Aggregator} input type + */ + private static class MergingWindowedAggregator + extends NonMergingWindowedAggregator { + + private final WindowFn windowFn; + + public MergingWindowedAggregator( + CombineFn combineFn, + Fun1, ValT> valueFn, + WindowingStrategy windowing, + Encoder windowEnc, + Encoder accEnc, + Encoder> outEnc) { + super(combineFn, valueFn, windowing, windowEnc, accEnc, outEnc); + windowFn = (WindowFn) windowing.getWindowFn(); + } + + @Override + protected Map> reduce( + Map> buff, + Collection windows, + ValT value, + Instant timestamp) { + if (buff.isEmpty()) { + // no windows yet to be merged, use the non-merging behavior of super + return super.reduce(buff, windows, value, timestamp); + } + // Merge multiple windows into one target window using the reducer function if the window + // already exists. Otherwise, the input value is added to the accumulator. Merged windows are + // removed from the accumulator map. + Function> accFn = + target -> + (acc, w) -> { + MutablePair accW = buff.remove(w); + return (accW != null) + ? mergeAccs(w, acc, accW) + : addToAcc(w, acc, value, timestamp); + }; + Set unmerged = mergeWindows(buff, ImmutableSet.copyOf(windows), accFn); + if (!unmerged.isEmpty()) { + // remaining windows don't have to be merged + return super.reduce(buff, unmerged, value, timestamp); + } + return buff; + } + + @Override + public Map> merge( + Map> b1, + Map> b2) { + // Merge multiple windows into one target window using the reducer function. Merged windows + // are removed from both accumulator maps + Function> reduceFn = + target -> (acc, w) -> mergeAccs(w, mergeAccs(w, acc, b1.remove(w)), b2.remove(w)); + + Set unmerged = b2.keySet(); + unmerged = mergeWindows(b1, unmerged, reduceFn); + if (!unmerged.isEmpty()) { + // keep only unmerged windows in 2nd accumulator map, continue using "non-merging" merge + b2.keySet().retainAll(unmerged); + return super.merge(b1, b2); + } + return b1; + } + + /** Reduce function to merge multiple windowed accumulator values into one target window. */ + private interface ReduceFn + extends BiFunction< + @Nullable MutablePair, + BoundedWindow, + @Nullable MutablePair> {} + + /** + * Attempt to merge windows of accumulator map with additional windows using the reducer + * function. The reducer function must support {@code null} as zero value. + * + * @return The subset of additional windows that don't require a merge. 
+ */ + private Set mergeWindows( + Map> buff, + Set newWindows, + Function> reduceFn) { + try { + Set newUnmerged = new HashSet<>(newWindows); + windowFn.mergeWindows( + windowFn.new MergeContext() { + @Override + public Collection windows() { + return Sets.union(buff.keySet(), newWindows); + } + + @Override + public void merge(Collection merges, BoundedWindow target) { + @Nullable + MutablePair merged = + merges.stream().reduce(null, reduceFn.apply(target), combiner(target)); + if (merged != null) { + buff.put(target, merged); + } + newUnmerged.removeAll(merges); + } + }); + return newUnmerged; + } catch (Exception e) { + throw new RuntimeException("Unable to merge accumulators windows", e); + } + } + } + + /** + * Non-merging windowed Spark {@link Aggregator} using a Map of {@link BoundedWindow}s as + * aggregation buffer. When reducing new input, a windowed accumulator is created for each new + * window of the input. Otherwise, if the window is known, the accumulators are merged. + * + * @param {@link CombineFn} input type + * @param {@link CombineFn} accumulator type + * @param {@link CombineFn} / {@link Aggregator} result type + * @param {@link Aggregator} input type + */ + private static class NonMergingWindowedAggregator + extends WindowedAggregator< + ValT, AccT, ResT, InT, BoundedWindow, Map>> { + + public NonMergingWindowedAggregator( + CombineFn combineFn, + Fun1, ValT> valueFn, + WindowingStrategy windowing, + Encoder windowEnc, + Encoder accEnc, + Encoder> outEnc) { + super(combineFn, valueFn, windowing, windowEnc, accEnc, outEnc, (Class) Map.class); + } + + @Override + public Map> zero() { + return new HashMap<>(); + } + + @Override + public final Map> reduce( + Map> buff, WindowedValue input) { + Collection windows = (Collection) input.getWindows(); + return reduce(buff, windows, value(input), input.getTimestamp()); + } + + protected Map> reduce( + Map> buff, + Collection windows, + ValT value, + Instant timestamp) { + // for each window add the value to the accumulator + for (BoundedWindow window : windows) { + buff.compute(window, (w, acc) -> addToAcc(w, acc, value, timestamp)); + } + return buff; + } + + @Override + public Map> merge( + Map> b1, + Map> b2) { + if (b1.isEmpty()) { + return b2; + } else if (b2.isEmpty()) { + return b1; + } + if (b2.size() > b1.size()) { + return merge(b2, b1); + } + // merge entries of (smaller) 2nd agg buffer map into first by merging the accumulators + b2.forEach((w, acc) -> b1.merge(w, acc, combiner(w))); + return b1; + } + } + + /** + * Abstract base of a Spark {@link Aggregator} on {@link WindowedValue}s using a Map of {@link W} + * as aggregation buffer. 
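+ * Each buffer entry maps a window to a {@link MutablePair} of the combined output timestamp and the accumulator for that window.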
+ * + * @param {@link CombineFn} input type + * @param {@link CombineFn} accumulator type + * @param {@link CombineFn} / {@link Aggregator} result type + * @param {@link Aggregator} input type + * @param bounded window type + * @param aggregation buffer {@link W} + */ + private abstract static class WindowedAggregator< + ValT, + AccT, + ResT, + InT, + W extends @NonNull BoundedWindow, + MapT extends Map>> + extends CombineFnAggregator< + ValT, AccT, ResT, WindowedValue, MapT, Collection>> { + private final TimestampCombiner tsCombiner; + + public WindowedAggregator( + CombineFn combineFn, + Fun1, ValT> valueFn, + WindowingStrategy windowing, + Encoder windowEnc, + Encoder accEnc, + Encoder> outEnc, + Class clazz) { + super( + combineFn, + valueFn, + mapEncoder(windowEnc, mutablePairEncoder(encoderOf(Instant.class), accEnc), clazz), + collectionEncoder(outEnc)); + tsCombiner = windowing.getTimestampCombiner(); + } + + protected final Instant resolveTimestamp(BoundedWindow w, Instant t1, Instant t2) { + return tsCombiner.merge(w, t1, t2); + } + + /** Init accumulator with initial input value and timestamp. */ + protected final MutablePair initAcc(ValT value, Instant timestamp) { + return new MutablePair<>(timestamp, addToAcc(emptyAcc(), value)); + } + + /** Merge timestamped accumulators. */ + protected final > @PolyNull T mergeAccs( + W window, @PolyNull T a1, @PolyNull T a2) { + if (a1 == null || a2 == null) { + return a1 == null ? a2 : a1; + } + return (T) a1.update(resolveTimestamp(window, a1._1, a2._1), mergeAccs(a1._2, a2._2)); + } + + protected BinaryOperator<@Nullable MutablePair> combiner(W target) { + return (a1, a2) -> mergeAccs(target, a1, a2); + } + + /** Add an input value to a nullable accumulator. */ + protected final MutablePair addToAcc( + W window, @Nullable MutablePair acc, ValT val, Instant ts) { + if (acc == null) { + return initAcc(val, ts); + } + return acc.update(resolveTimestamp(window, acc._1, ts), addToAcc(acc._2, val)); + } + + @Override + @SuppressWarnings("nullness") // entries are non null + public final Collection> finish(MapT buffer) { + return Collections2.transform(buffer.entrySet(), this::windowedValue); + } + + private WindowedValue windowedValue(Entry> e) { + return WindowedValues.of(extract(e.getValue()._2), e.getValue()._1, e.getKey(), NO_FIRING); + } + } + + /** + * Abstract base of Spark {@link Aggregator}s using a Beam {@link CombineFn}. 
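+ * The {@link Aggregator} callbacks delegate to the corresponding {@link CombineFn} methods: {@code createAccumulator}, {@code addInput}, {@code mergeAccumulators} and {@code extractOutput}.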
+ * + * @param {@link CombineFn} input type + * @param {@link CombineFn} accumulator type + * @param {@link CombineFn} result type + * @param {@link Aggregator} input type + * @param {@link Aggregator} buffer type + * @param {@link Aggregator} output type + */ + private abstract static class CombineFnAggregator + extends Aggregator { + private final CombineFn fn; + private final Fun1 valueFn; + private final Encoder bufferEnc; + private final Encoder outputEnc; + + public CombineFnAggregator( + CombineFn fn, + Fun1 valueFn, + Encoder bufferEnc, + Encoder outputEnc) { + this.fn = fn; + this.valueFn = valueFn; + this.bufferEnc = bufferEnc; + this.outputEnc = outputEnc; + } + + protected final ValT value(InT in) { + return valueFn.apply(in); + } + + protected final AccT emptyAcc() { + return fn.createAccumulator(); + } + + protected final AccT mergeAccs(AccT a1, AccT a2) { + return fn.mergeAccumulators(ImmutableList.of(a1, a2)); + } + + protected final AccT addToAcc(AccT acc, ValT val) { + return fn.addInput(acc, val); + } + + protected final ResT extract(AccT acc) { + return fn.extractOutput(acc); + } + + @Override + public Encoder bufferEncoder() { + return bufferEnc; + } + + @Override + public Encoder outputEncoder() { + return outputEnc; + } + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGloballyTranslatorBatch.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGloballyTranslatorBatch.java new file mode 100644 index 000000000000..c77637ab5b91 --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGloballyTranslatorBatch.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch; + +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.value; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun1; +import static scala.collection.Iterator.single; + +import java.util.Collection; +import java.util.Map; +import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator; +import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop; +import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.Fun1; +import org.apache.beam.sdk.coders.CannotProvideCoderException; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.Combine.CombineFn; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.sdk.values.WindowedValues; +import org.apache.beam.sdk.values.WindowingStrategy; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.expressions.Aggregator; +import scala.collection.Iterator; + +/** + * Translator for {@link Combine.Globally} using a Spark {@link Aggregator}. + * + *

+ * To minimize the amount of data shuffled, this first reduces the data per partition using
+ * {@link Aggregator#reduce}, gathers the partial results (using {@code coalesce(1)}) and finally
+ * merges these using {@link Aggregator#merge}.
+ *
+ * TODOs:
+ *
  • any missing features? + */ +class CombineGloballyTranslatorBatch + extends TransformTranslator, PCollection, Combine.Globally> { + + CombineGloballyTranslatorBatch() { + super(0.2f); + } + + @Override + protected void translate(Combine.Globally transform, Context cxt) { + WindowingStrategy windowing = cxt.getInput().getWindowingStrategy(); + CombineFn combineFn = (CombineFn) transform.getFn(); + + Coder inputCoder = cxt.getInput().getCoder(); + Coder outputCoder = cxt.getOutput().getCoder(); + Coder accumCoder = accumulatorCoder(combineFn, inputCoder, cxt); + + Encoder outEnc = cxt.encoderOf(outputCoder); + Encoder accEnc = cxt.encoderOf(accumCoder); + Encoder> wvOutEnc = cxt.windowedEncoder(outEnc); + + Dataset> dataset = cxt.getDataset(cxt.getInput()); + + final Dataset> result; + if (GroupByKeyHelpers.eligibleForGlobalGroupBy(windowing, true)) { + Aggregator agg = Aggregators.value(combineFn, v -> v, accEnc, outEnc); + + // Drop window and restore afterwards, produces single global aggregation result + result = aggregate(dataset, agg, value(), windowedValue(), wvOutEnc); + } else { + Aggregator, ?, Collection>> agg = + Aggregators.windowedValue( + combineFn, value(), windowing, cxt.windowEncoder(), accEnc, wvOutEnc); + + // Produces aggregation result per window + result = + aggregate(dataset, agg, v -> v, fun1(out -> ScalaInterop.scalaIterator(out)), wvOutEnc); + } + cxt.putDataset(cxt.getOutput(), result); + } + + /** + * Aggregate dataset globally without using key. + * + *

    There is no global, typed version of {@link Dataset#agg(Map)} on datasets. This reduces all + * partitions first, and then merges them to receive the final result. + */ + private static Dataset> aggregate( + Dataset> ds, + Aggregator agg, + Fun1, AggInT> valueFn, + Fun1>> finishFn, + Encoder> enc) { + // reduce partition using aggregator + Fun1>, Iterator> reduce = + fun1(it -> single(it.map(valueFn).foldLeft(agg.zero(), agg::reduce))); + // merge reduced partitions using aggregator + Fun1, Iterator>> merge = + fun1(it -> finishFn.apply(agg.finish(it.hasNext() ? it.reduce(agg::merge) : agg.zero()))); + + return ds.mapPartitions(reduce, agg.bufferEncoder()).coalesce(1).mapPartitions(merge, enc); + } + + private Coder accumulatorCoder( + CombineFn fn, Coder valueCoder, Context cxt) { + try { + return fn.getAccumulatorCoder(cxt.getInput().getPipeline().getCoderRegistry(), valueCoder); + } catch (CannotProvideCoderException e) { + throw new RuntimeException(e); + } + } + + private static Fun1>> windowedValue() { + return v -> single(WindowedValues.valueInGlobalWindow(v)); + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGroupedValuesTranslatorBatch.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGroupedValuesTranslatorBatch.java new file mode 100644 index 000000000000..fa59cdf2452c --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGroupedValuesTranslatorBatch.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch; + +import java.io.IOException; +import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator; +import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.Fun1; +import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.Combine.CombineFn; +import org.apache.beam.sdk.transforms.CombineWithContext; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.expressions.Aggregator; + +/** + * Translator for {@link Combine.GroupedValues} if the {@link CombineFn} doesn't require context / + * side-inputs. + * + *

    This doesn't require a Spark {@link Aggregator}. Instead it can directly use the respective + * {@link CombineFn} to reduce each iterable of values into an aggregated output value. + */ +class CombineGroupedValuesTranslatorBatch + extends TransformTranslator< + PCollection>>, + PCollection>, + Combine.GroupedValues> { + + CombineGroupedValuesTranslatorBatch() { + super(0.2f); + } + + @Override + protected void translate(Combine.GroupedValues transform, Context cxt) + throws IOException { + CombineFn combineFn = (CombineFn) transform.getFn(); + + Encoder>> enc = cxt.windowedEncoder(cxt.getOutput().getCoder()); + Dataset>>> inputDs = (Dataset) cxt.getDataset(cxt.getInput()); + + cxt.putDataset(cxt.getOutput(), inputDs.map(reduce(combineFn), enc)); + } + + @Override + public boolean canTranslate(Combine.GroupedValues transform) { + return !(transform.getFn() instanceof CombineWithContext); + } + + private static + Fun1>>, WindowedValue>> reduce( + CombineFn fn) { + return wv -> { + KV> kv = wv.getValue(); + AccT acc = null; + for (InT in : kv.getValue()) { + acc = fn.addInput(acc != null ? acc : fn.createAccumulator(), in); + } + OutT res = acc != null ? fn.extractOutput(acc) : fn.defaultValue(); + return wv.withValue(KV.of(kv.getKey(), res)); + }; + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTranslatorBatch.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTranslatorBatch.java new file mode 100644 index 000000000000..47ba1be730e7 --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTranslatorBatch.java @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch; + +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.eligibleForGlobalGroupBy; +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.eligibleForGroupByWindow; +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.explodeWindowedKey; +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.value; +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.valueKey; +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.valueValue; +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.windowedKV; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun1; + +import java.util.Collection; +import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator; +import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop; +import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.Fun1; +import org.apache.beam.sdk.coders.CannotProvideCoderException; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderRegistry; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.Combine.CombineFn; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.sdk.values.WindowedValues; +import org.apache.beam.sdk.values.WindowingStrategy; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.expressions.Aggregator; +import scala.Tuple2; +import scala.collection.IterableOnce; + +/** + * Translator for {@link Combine.PerKey} using {@link Dataset#groupByKey} with a Spark {@link + * Aggregator}. + * + *

+ *   • When using the default global window, window information is dropped and restored after the
+ *     aggregation.
+ *   • For non-merging windows, windows are exploded and moved into a composite key for better
+ *     distribution. After the aggregation, windowed values are restored from the composite key.
+ *   • All other cases use an aggregator on windowed values that is optimized for the current
+ *     windowing strategy.
+ *
+ * TODOs:
+ *
+ *   • combine with context (CombineFnWithContext)?
+ *   • combine with sideInputs?
+ *
  • other there other missing features? + */ +class CombinePerKeyTranslatorBatch + extends TransformTranslator< + PCollection>, PCollection>, Combine.PerKey> { + + CombinePerKeyTranslatorBatch() { + super(0.2f); + } + + @Override + public void translate(Combine.PerKey transform, Context cxt) { + WindowingStrategy windowing = cxt.getInput().getWindowingStrategy(); + CombineFn combineFn = (CombineFn) transform.getFn(); + + KvCoder inputCoder = (KvCoder) cxt.getInput().getCoder(); + KvCoder outputCoder = (KvCoder) cxt.getOutput().getCoder(); + + Encoder keyEnc = cxt.keyEncoderOf(inputCoder); + Encoder> inputEnc = cxt.encoderOf(inputCoder); + Encoder>> wvOutputEnc = cxt.windowedEncoder(outputCoder); + Encoder accumEnc = accumEncoder(combineFn, inputCoder.getValueCoder(), cxt); + + final Dataset>> result; + + boolean globalGroupBy = eligibleForGlobalGroupBy(windowing, true); + boolean groupByWindow = eligibleForGroupByWindow(windowing, true); + + if (globalGroupBy || groupByWindow) { + Aggregator, ?, OutT> valueAgg = + Aggregators.value(combineFn, KV::getValue, accumEnc, cxt.valueEncoderOf(outputCoder)); + + if (globalGroupBy) { + // Drop window and group by key globally to run the aggregation (combineFn), afterwards the + // global window is restored + result = + cxt.getDataset(cxt.getInput()) + .groupByKey(valueKey(), keyEnc) + .mapValues(value(), inputEnc) + .agg(valueAgg.toColumn()) + .map(globalKV(), wvOutputEnc); + } else { + Encoder> windowedKeyEnc = + cxt.tupleEncoder(cxt.windowEncoder(), keyEnc); + + // Group by window and key to run the aggregation (combineFn) + result = + cxt.getDataset(cxt.getInput()) + .flatMap(explodeWindowedKey(value()), cxt.tupleEncoder(windowedKeyEnc, inputEnc)) + .groupByKey(fun1(Tuple2::_1), windowedKeyEnc) + .mapValues(fun1(Tuple2::_2), inputEnc) + .agg(valueAgg.toColumn()) + .map(windowedKV(), wvOutputEnc); + } + } else { + // Optimized aggregator for non-merging and session window functions, all others depend on + // windowFn.mergeWindows + Aggregator>, ?, Collection>> aggregator = + Aggregators.windowedValue( + combineFn, + valueValue(), + windowing, + cxt.windowEncoder(), + accumEnc, + cxt.windowedEncoder(outputCoder.getValueCoder())); + result = + cxt.getDataset(cxt.getInput()) + .groupByKey(valueKey(), keyEnc) + .agg(aggregator.toColumn()) + .flatMap(explodeWindows(), wvOutputEnc); + } + + cxt.putDataset(cxt.getOutput(), result); + } + + private static + Fun1>>, IterableOnce>>> + explodeWindows() { + return t -> + ScalaInterop.scalaIterator(t._2).map(wv -> wv.withValue(KV.of(t._1, wv.getValue()))); + } + + private static Fun1, WindowedValue>> globalKV() { + return t -> WindowedValues.valueInGlobalWindow(KV.of(t._1, t._2)); + } + + private Encoder accumEncoder( + CombineFn fn, Coder valueCoder, Context cxt) { + try { + CoderRegistry registry = cxt.getInput().getPipeline().getCoderRegistry(); + return cxt.encoderOf(fn.getAccumulatorCoder(registry, valueCoder)); + } catch (CannotProvideCoderException e) { + throw new RuntimeException(e); + } + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnPartitionIteratorFactory.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnPartitionIteratorFactory.java new file mode 100644 index 000000000000..1696a5c81cb1 --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnPartitionIteratorFactory.java @@ -0,0 +1,198 @@ +/* + * 
Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch; + +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.scalaIterator; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.tuple; + +import java.io.Serializable; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.Map; +import java.util.function.Supplier; +import javax.annotation.CheckForNull; +import org.apache.beam.runners.spark.structuredstreaming.metrics.MetricsAccumulator; +import org.apache.beam.runners.spark.structuredstreaming.translation.batch.DoFnRunnerFactory.DoFnRunnerWithTeardown; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.util.WindowedValueMultiReceiver; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.AbstractIterator; +import org.checkerframework.checker.nullness.qual.NonNull; +import scala.Function1; +import scala.Tuple2; +import scala.collection.Iterator; + +/** + * Abstract factory to create a {@link DoFnPartitionIt DoFn partition iterator} using a customizable + * {@link WindowedValueMultiReceiver}. + */ +abstract class DoFnPartitionIteratorFactory + implements Function1>, Iterator>, Serializable { + + protected final DoFnRunnerFactory factory; + protected final Supplier options; + private final MetricsAccumulator metrics; + + private DoFnPartitionIteratorFactory( + Supplier options, + MetricsAccumulator metrics, + DoFnRunnerFactory factory) { + this.options = options; + this.metrics = metrics; + this.factory = factory; + } + + /** + * {@link DoFnPartitionIteratorFactory} emitting a single output of type {@link WindowedValue} of + * {@link OutT}. + */ + static DoFnPartitionIteratorFactory> singleOutput( + Supplier options, + MetricsAccumulator metrics, + DoFnRunnerFactory factory) { + return new SingleOut<>(options, metrics, factory); + } + + /** + * {@link DoFnPartitionIteratorFactory} emitting multiple outputs encoded as tuple of column index + * and {@link WindowedValue} of {@link OutT}, where column index corresponds to the index of a + * {@link TupleTag#getId()} in {@code tagColIdx}. + */ + static + DoFnPartitionIteratorFactory>> multiOutput( + Supplier options, + MetricsAccumulator metrics, + DoFnRunnerFactory factory, + Map tagColIdx) { + return new MultiOut<>(options, metrics, factory, tagColIdx); + } + + @Override + public Iterator apply(Iterator> it) { + return it.hasNext() + ? scalaIterator(new DoFnPartitionIt(it)) + : (Iterator) Iterator.empty(); + } + + /** Output manager emitting outputs of type {@link OutT} to the buffer. 
*/ + abstract WindowedValueMultiReceiver outputManager(Deque buffer); + + /** + * {@link DoFnPartitionIteratorFactory} emitting a single output of type {@link WindowedValue} of + * {@link OutT}. + */ + private static class SingleOut + extends DoFnPartitionIteratorFactory> { + private SingleOut( + Supplier options, + MetricsAccumulator metrics, + DoFnRunnerFactory factory) { + super(options, metrics, factory); + } + + @Override + WindowedValueMultiReceiver outputManager(Deque> buffer) { + return new WindowedValueMultiReceiver() { + @Override + public void output(TupleTag tag, WindowedValue output) { + buffer.add((WindowedValue) output); + } + }; + } + } + + /** + * {@link DoFnPartitionIteratorFactory} emitting multiple outputs encoded as tuple of column index + * and {@link WindowedValue} of {@link OutT}, where column index corresponds to the index of a + * {@link TupleTag#getId()} in {@link #tagColIdx}. + */ + private static class MultiOut + extends DoFnPartitionIteratorFactory>> { + private final Map tagColIdx; + + public MultiOut( + Supplier options, + MetricsAccumulator metrics, + DoFnRunnerFactory factory, + Map tagColIdx) { + super(options, metrics, factory); + this.tagColIdx = tagColIdx; + } + + @Override + WindowedValueMultiReceiver outputManager(Deque>> buffer) { + return new WindowedValueMultiReceiver() { + @Override + public void output(TupleTag tag, WindowedValue output) { + // Additional unused outputs can be skipped here. In that case columnIdx is null. + Integer columnIdx = tagColIdx.get(tag.getId()); + if (columnIdx != null) { + buffer.add(tuple(columnIdx, (WindowedValue) output)); + } + } + }; + } + } + + // FIXME Add support for TimerInternals.TimerData + /** + * Partition iterator that lazily processes each element from the (input) iterator on demand + * producing zero, one or more output elements as output (via an internal buffer). + * + *

    When initializing the iterator for a partition {@code setup} followed by {@code startBundle} + * is called. + */ + private class DoFnPartitionIt extends AbstractIterator { + private final Deque buffer = new ArrayDeque<>(); + private final DoFnRunnerWithTeardown doFnRunner; + private final Iterator> partitionIt; + private boolean isBundleFinished; + + private DoFnPartitionIt(Iterator> partitionIt) { + this.partitionIt = partitionIt; + this.doFnRunner = factory.create(options.get(), metrics, outputManager(buffer)); + } + + @Override + protected @CheckForNull OutT computeNext() { + try { + while (true) { + if (!buffer.isEmpty()) { + return buffer.remove(); + } + if (partitionIt.hasNext()) { + // grab the next element and process it. + doFnRunner.processElement(partitionIt.next()); + } else { + if (!isBundleFinished) { + isBundleFinished = true; + doFnRunner.finishBundle(); + continue; // finishBundle can produce more output + } + doFnRunner.teardown(); + return endOfData(); + } + } + } catch (RuntimeException re) { + doFnRunner.teardown(); + throw re; + } + } + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnRunnerFactory.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnRunnerFactory.java new file mode 100644 index 000000000000..99ce3dc69889 --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnRunnerFactory.java @@ -0,0 +1,300 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import org.apache.beam.runners.core.DoFnRunner; +import org.apache.beam.runners.core.DoFnRunners; +import org.apache.beam.runners.core.SideInputReader; +import org.apache.beam.runners.spark.structuredstreaming.metrics.MetricsAccumulator; +import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.CachedSideInputReader; +import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.NoOpStepContext; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.runners.AppliedPTransform; +import org.apache.beam.sdk.state.TimeDomain; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFnSchemaInformation; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.reflect.DoFnInvokers; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.util.WindowedValueMultiReceiver; +import org.apache.beam.sdk.util.WindowedValueReceiver; +import org.apache.beam.sdk.util.construction.ParDoTranslation; +import org.apache.beam.sdk.values.CausedByDrain; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.sdk.values.WindowingStrategy; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; +import org.joda.time.Instant; + +/** + * Factory to create a {@link DoFnRunner}. The factory supports fusing multiple {@link DoFnRunner + * runners} into a single one. + */ +abstract class DoFnRunnerFactory implements Serializable { + + interface DoFnRunnerWithTeardown extends DoFnRunner { + void teardown(); + } + + /** + * Creates a runner that is ready to process elements. + * + *

    Both, {@link org.apache.beam.sdk.transforms.reflect.DoFnInvoker#invokeSetup setup} and + * {@link DoFnRunner#startBundle()} are already invoked by the factory. + */ + abstract DoFnRunnerWithTeardown create( + PipelineOptions options, MetricsAccumulator metrics, WindowedValueMultiReceiver output); + + /** + * Fuses the factory for the following {@link DoFnRunner} into a single factory that processes + * both DoFns in a single step. + */ + abstract DoFnRunnerFactory fuse(DoFnRunnerFactory next); + + static DoFnRunnerFactory simple( + AppliedPTransform, ?, ParDo.MultiOutput> appliedPT, + PCollection input, + SideInputReader sideInputReader, + boolean filterMainOutput) { + return new SimpleRunnerFactory<>(appliedPT, input, sideInputReader, filterMainOutput); + } + + /** + * Factory creating a {@link org.apache.beam.runners.core.SimpleDoFnRunner SimpleRunner} with + * metrics support. + */ + private static class SimpleRunnerFactory extends DoFnRunnerFactory { + private final String stepName; + private final DoFn doFn; + private final DoFnSchemaInformation doFnSchema; + private final Coder coder; + private final WindowingStrategy windowingStrategy; + private final TupleTag mainOutput; + private final List> additionalOutputs; + private final Map, Coder> outputCoders; + private final Map> sideInputs; + private final SideInputReader sideInputReader; + private final boolean filterMainOutput; + + SimpleRunnerFactory( + AppliedPTransform, ?, ParDo.MultiOutput> appliedPT, + PCollection input, + SideInputReader sideInputReader, + boolean filterMainOutput) { + this.stepName = appliedPT.getFullName(); + this.doFn = appliedPT.getTransform().getFn(); + this.doFnSchema = ParDoTranslation.getSchemaInformation(appliedPT); + this.coder = input.getCoder(); + this.windowingStrategy = input.getWindowingStrategy(); + this.mainOutput = appliedPT.getTransform().getMainOutputTag(); + this.additionalOutputs = additionalOutputs(appliedPT.getTransform()); + this.outputCoders = coders(appliedPT.getOutputs(), mainOutput); + this.sideInputs = appliedPT.getTransform().getSideInputs(); + this.sideInputReader = sideInputReader; + this.filterMainOutput = filterMainOutput; + } + + @Override + DoFnRunnerFactory fuse(DoFnRunnerFactory next) { + return new FusedRunnerFactory<>(Lists.newArrayList(this, next)); + } + + @Override + DoFnRunnerWithTeardown create( + PipelineOptions options, MetricsAccumulator metrics, WindowedValueMultiReceiver output) { + DoFnRunner simpleRunner = + DoFnRunners.simpleRunner( + options, + doFn, + CachedSideInputReader.of(sideInputReader, sideInputs.values()), + filterMainOutput ? new FilteredOutput<>(output, mainOutput) : output, + mainOutput, + additionalOutputs, + new NoOpStepContext(), + coder, + outputCoders, + windowingStrategy, + doFnSchema, + sideInputs); + DoFnRunnerWithTeardown runner = + new DoFnRunnerWithMetrics<>(stepName, simpleRunner, metrics); + // Invoke setup and then startBundle before returning the runner + DoFnInvokers.tryInvokeSetupFor(doFn, options); + try { + runner.startBundle(); + } catch (RuntimeException re) { + DoFnInvokers.invokerFor(doFn).invokeTeardown(); + throw re; + } + return runner; + } + + /** + * Delegate {@link WindowedValueMultiReceiver} that only forwards outputs matching the provided + * tag. This is used in cases where unused, obsolete outputs get dropped to avoid unnecessary + * caching. 
+ */ + private static class FilteredOutput implements WindowedValueMultiReceiver { + final WindowedValueMultiReceiver outputManager; + final TupleTag tupleTag; + + FilteredOutput(WindowedValueMultiReceiver outputManager, TupleTag tupleTag) { + this.outputManager = outputManager; + this.tupleTag = tupleTag; + } + + @Override + public void output(TupleTag tag, WindowedValue value) { + if (this.tupleTag.equals(tag)) { + outputManager.output(tag, value); + } + } + } + + private static Map, Coder> coders( + Map, PCollection> pCols, TupleTag main) { + if (pCols.size() == 1) { + return Collections.singletonMap(main, Iterables.getOnlyElement(pCols.values()).getCoder()); + } + Map, Coder> coders = Maps.newHashMapWithExpectedSize(pCols.size()); + for (Map.Entry, PCollection> e : pCols.entrySet()) { + coders.put(e.getKey(), e.getValue().getCoder()); + } + return coders; + } + + private static List> additionalOutputs(ParDo.MultiOutput transform) { + List> tags = transform.getAdditionalOutputTags().getAll(); + return tags.isEmpty() ? Collections.emptyList() : new ArrayList<>(tags); + } + } + + /** + * Factory that produces a fused runner consisting of multiple chained {@link DoFn DoFns}. Outputs + * are directly forwarded to the next runner without buffering inbetween. + */ + private static class FusedRunnerFactory extends DoFnRunnerFactory { + private final List> factories; + + FusedRunnerFactory(List> factories) { + this.factories = factories; + } + + @Override + DoFnRunnerWithTeardown create( + PipelineOptions options, MetricsAccumulator metrics, WindowedValueMultiReceiver output) { + return new FusedRunner<>(options, metrics, output, factories); + } + + @Override + DoFnRunnerFactory fuse(DoFnRunnerFactory next) { + factories.add(next); + return (DoFnRunnerFactory) this; + } + + private static class FusedRunner implements DoFnRunnerWithTeardown { + final DoFnRunnerWithTeardown[] runners; + + FusedRunner( + PipelineOptions options, + MetricsAccumulator metrics, + WindowedValueMultiReceiver output, + List> factories) { + runners = new DoFnRunnerWithTeardown[factories.size()]; + runners[runners.length - 1] = + factories.get(runners.length - 1).create(options, metrics, output); + for (int i = runners.length - 2; i >= 0; i--) { + runners[i] = factories.get(i).create(options, metrics, new FusedOutput(runners[i + 1])); + } + } + + /** {@link WindowedValueReceiver} that forwards output directly to the next runner. 
*/ + private static class FusedOutput implements WindowedValueMultiReceiver { + final DoFnRunnerWithTeardown runner; + + FusedOutput(DoFnRunnerWithTeardown runner) { + this.runner = runner; + } + + @Override + public void output(TupleTag tag, WindowedValue output) { + runner.processElement((WindowedValue) output); + } + } + + @Override + public void startBundle() { + for (int i = 0; i < runners.length; i++) { + runners[i].startBundle(); + } + } + + @Override + public void processElement(WindowedValue elem) { + runners[0].processElement((WindowedValue) elem); + } + + @Override + public void onTimer( + String timerId, + String timerFamilyId, + KeyT key, + BoundedWindow window, + Instant timestamp, + Instant outputTimestamp, + TimeDomain timeDomain, + CausedByDrain causedByDrain) { + throw new UnsupportedOperationException(); + } + + @Override + public void onWindowExpiration(BoundedWindow window, Instant timestamp, KeyT key) { + throw new UnsupportedOperationException(); + } + + @Override + public void finishBundle() { + for (int i = 0; i < runners.length; i++) { + runners[i].finishBundle(); + } + } + + @Override + public DoFn getFn() { + throw new UnsupportedOperationException(); + } + + @Override + public void teardown() { + for (int i = 0; i < runners.length; i++) { + runners[i].teardown(); + } + } + } + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnRunnerWithMetrics.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnRunnerWithMetrics.java new file mode 100644 index 000000000000..28dbf44cb8fe --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnRunnerWithMetrics.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch; + +import java.io.Closeable; +import java.io.IOException; +import org.apache.beam.runners.core.DoFnRunner; +import org.apache.beam.runners.spark.structuredstreaming.metrics.MetricsAccumulator; +import org.apache.beam.runners.spark.structuredstreaming.translation.batch.DoFnRunnerFactory.DoFnRunnerWithTeardown; +import org.apache.beam.sdk.metrics.MetricsContainer; +import org.apache.beam.sdk.metrics.MetricsEnvironment; +import org.apache.beam.sdk.state.TimeDomain; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.reflect.DoFnInvokers; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.values.CausedByDrain; +import org.apache.beam.sdk.values.WindowedValue; +import org.joda.time.Instant; + +/** DoFnRunner decorator which registers {@link MetricsContainer}. 
*/ +class DoFnRunnerWithMetrics implements DoFnRunnerWithTeardown { + private final DoFnRunner delegate; + private final MetricsContainer metrics; + + DoFnRunnerWithMetrics( + String stepName, DoFnRunner delegate, MetricsAccumulator metricsAccum) { + this(delegate, metricsAccum.value().getContainer(stepName)); + } + + private DoFnRunnerWithMetrics(DoFnRunner delegate, MetricsContainer metrics) { + this.delegate = delegate; + this.metrics = metrics; + } + + @Override + public DoFn getFn() { + return delegate.getFn(); + } + + @Override + public void startBundle() { + try (Closeable ignored = MetricsEnvironment.scopedMetricsContainer(metrics)) { + delegate.startBundle(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public void processElement(final WindowedValue elem) { + try (Closeable ignored = MetricsEnvironment.scopedMetricsContainer(metrics)) { + delegate.processElement(elem); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public void onTimer( + final String timerId, + final String timerFamilyId, + KeyT key, + final BoundedWindow window, + final Instant timestamp, + final Instant outputTimestamp, + final TimeDomain timeDomain, + CausedByDrain causedByDrain) { + try (Closeable ignored = MetricsEnvironment.scopedMetricsContainer(metrics)) { + delegate.onTimer( + timerId, + timerFamilyId, + key, + window, + timestamp, + outputTimestamp, + timeDomain, + causedByDrain); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public void finishBundle() { + try (Closeable ignored = MetricsEnvironment.scopedMetricsContainer(metrics)) { + delegate.finishBundle(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public void onWindowExpiration(BoundedWindow window, Instant timestamp, KeyT key) { + delegate.onWindowExpiration(window, timestamp, key); + } + + @Override + public void teardown() { + DoFnInvokers.invokerFor(delegate.getFn()).invokeTeardown(); + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/FlattenTranslatorBatch.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/FlattenTranslatorBatch.java new file mode 100644 index 000000000000..63786829bd53 --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/FlattenTranslatorBatch.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch; + +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun1; + +import java.util.Collection; +import java.util.Iterator; +import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.transforms.Flatten; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionList; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; + +class FlattenTranslatorBatch + extends TransformTranslator, PCollection, Flatten.PCollections> { + + FlattenTranslatorBatch() { + super(0.1f); + } + + @Override + public void translate(Flatten.PCollections transform, Context cxt) { + Collection> pCollections = cxt.getInputs().values(); + Coder outputCoder = cxt.getOutput().getCoder(); + Encoder> outputEnc = + cxt.windowedEncoder(outputCoder, windowCoder(cxt.getOutput())); + + Dataset> result; + Iterator> pcIt = (Iterator) pCollections.iterator(); + if (pcIt.hasNext()) { + result = getDataset(pcIt.next(), outputCoder, outputEnc, cxt); + while (pcIt.hasNext()) { + result = result.union(getDataset(pcIt.next(), outputCoder, outputEnc, cxt)); + } + } else { + result = cxt.createDataset(ImmutableList.of(), outputEnc); + } + cxt.putDataset(cxt.getOutput(), result); + } + + private Dataset> getDataset( + PCollection pc, Coder coder, Encoder> enc, Context cxt) { + Dataset> current = cxt.getDataset(pc); + // if coders don't match, map using identity function to replace encoder + return pc.getCoder().equals(coder) ? current : current.map(fun1(v -> v), enc); + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyHelpers.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyHelpers.java new file mode 100644 index 000000000000..3cb400759745 --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyHelpers.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch; + +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.tuple; +import static org.apache.beam.sdk.transforms.windowing.PaneInfo.NO_FIRING; +import static org.apache.beam.sdk.transforms.windowing.TimestampCombiner.END_OF_WINDOW; + +import java.util.Collection; +import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop; +import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.Fun1; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindows; +import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.sdk.values.WindowedValues; +import org.apache.beam.sdk.values.WindowingStrategy; +import scala.Tuple2; +import scala.collection.IterableOnce; + +/** + * Package private helpers to support translating grouping transforms using `groupByKey` such as + * {@link GroupByKeyTranslatorBatch} or {@link CombinePerKeyTranslatorBatch}. + */ +class GroupByKeyHelpers { + + private GroupByKeyHelpers() {} + + /** + * Checks if it's possible to use an optimized `groupByKey` that also moves the window into the + * key. + * + * @param windowing The windowing strategy + * @param endOfWindowOnly Flag if to limit this optimization to {@link + * TimestampCombiner#END_OF_WINDOW}. + */ + static boolean eligibleForGroupByWindow( + WindowingStrategy windowing, boolean endOfWindowOnly) { + return !windowing.needsMerge() + && (!endOfWindowOnly || windowing.getTimestampCombiner() == END_OF_WINDOW) + && windowing.getWindowFn().windowCoder().consistentWithEquals(); + } + + /** + * Checks if it's possible to use an optimized `groupByKey` for the global window. + * + * @param windowing The windowing strategy + * @param endOfWindowOnly Flag if to limit this optimization to {@link + * TimestampCombiner#END_OF_WINDOW}. + */ + static boolean eligibleForGlobalGroupBy( + WindowingStrategy windowing, boolean endOfWindowOnly) { + return windowing.getWindowFn() instanceof GlobalWindows + && (!endOfWindowOnly || windowing.getTimestampCombiner() == END_OF_WINDOW); + } + + /** + * Explodes a windowed {@link KV} assigned to potentially multiple {@link BoundedWindow}s to a + * traversable of composite keys {@code (BoundedWindow, Key)} and value. 
+ */ + static + Fun1>, IterableOnce, T>>> + explodeWindowedKey(Fun1>, T> valueFn) { + return v -> { + T value = valueFn.apply(v); + K key = v.getValue().getKey(); + Collection windows = (Collection) v.getWindows(); + return ScalaInterop.scalaIterator(windows).map(w -> tuple(tuple(w, key), value)); + }; + } + + static Fun1, V>, WindowedValue>> windowedKV() { + return t -> windowedKV(t._1, t._2); + } + + static WindowedValue> windowedKV(Tuple2 key, V value) { + return WindowedValues.of(KV.of(key._2, value), key._1.maxTimestamp(), key._1, NO_FIRING); + } + + static Fun1, V> value() { + return v -> v.getValue(); + } + + static Fun1>, V> valueValue() { + return v -> v.getValue().getValue(); + } + + static Fun1>, K> valueKey() { + return v -> v.getValue().getKey(); + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java new file mode 100644 index 000000000000..6a6d42c31cf3 --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch; + +import static org.apache.beam.repackaged.core.org.apache.commons.lang3.ArrayUtils.EMPTY_BYTE_ARRAY; +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.eligibleForGlobalGroupBy; +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.eligibleForGroupByWindow; +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.explodeWindowedKey; +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.valueKey; +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.valueValue; +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.windowedKV; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.CoderHelpers.toByteArray; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.collectionEncoder; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.encoderOf; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.kvEncoder; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.windowedValueEncoder; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.concat; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun1; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun2; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.javaIterator; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.seqOf; +import static org.apache.beam.sdk.transforms.windowing.PaneInfo.NO_FIRING; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; +import static org.apache.spark.sql.functions.col; +import static org.apache.spark.sql.functions.collect_list; +import static org.apache.spark.sql.functions.explode; +import static org.apache.spark.sql.functions.max; +import static org.apache.spark.sql.functions.min; +import static org.apache.spark.sql.functions.struct; + +import java.io.Serializable; +import org.apache.beam.runners.core.InMemoryStateInternals; +import org.apache.beam.runners.core.ReduceFnRunner; +import org.apache.beam.runners.core.StateInternalsFactory; +import org.apache.beam.runners.core.SystemReduceFn; +import org.apache.beam.runners.spark.SparkCommonPipelineOptions; +import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator; +import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.GroupAlsoByWindowViaOutputBufferFn; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.transforms.GroupByKey; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.transforms.windowing.PaneInfo.PaneInfoCoder; +import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.WindowedValue; +import 
org.apache.beam.sdk.values.WindowedValues;
+import org.apache.beam.sdk.values.WindowingStrategy;
+import org.apache.spark.sql.Column;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.TypedColumn;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import scala.Tuple2;
+import scala.collection.Iterator;
+import scala.collection.JavaConverters;
+import scala.collection.immutable.List;
+
+/**
+ * Translator for {@link GroupByKey} using {@link Dataset#groupByKey} with the built-in aggregation
+ * function {@code collect_list} when applicable.
+ *

+ * Note: Using {@code collect_list} isn't any worse than using {@link ReduceFnRunner}. In the
+ * latter case the entire group (iterator) has to be loaded into memory as well. Either way there's
+ * a risk of OOM errors. When enabling {@link
+ * SparkCommonPipelineOptions#getPreferGroupByKeyToHandleHugeValues()}, a more memory-sensitive
+ * iterable is used that can be traversed just once. Attempting to traverse the iterable again will
+ * throw.
+ *
+ *   • When using the default global window, window information is dropped and restored after the
+ *     aggregation.
+ *   • For non-merging windows, windows are exploded and moved into a composite key for better
+ *     distribution. Though, to keep the amount of shuffled data low, this is only done if values
+ *     are assigned to a single window or if there are only few keys and distributing data is
+ *     important. After the aggregation, windowed values are restored from the composite key.
+ *   • All other cases are implemented using the SDK {@link ReduceFnRunner}.
+ *
    + */ +class GroupByKeyTranslatorBatch + extends TransformTranslator< + PCollection>, PCollection>>, GroupByKey> { + + /** Literal of binary encoded Pane info. */ + private static final Column PANE_NO_FIRING = lit(toByteArray(NO_FIRING, PaneInfoCoder.of())); + + /** Defaults for value in single global window. */ + private static final List GLOBAL_WINDOW_DETAILS = + windowDetails(lit(new byte[][] {EMPTY_BYTE_ARRAY})); + + GroupByKeyTranslatorBatch() { + super(0.2f); + } + + @Override + public void translate(GroupByKey transform, Context cxt) { + WindowingStrategy windowing = cxt.getInput().getWindowingStrategy(); + TimestampCombiner tsCombiner = windowing.getTimestampCombiner(); + + Dataset>> input = cxt.getDataset(cxt.getInput()); + + KvCoder inputCoder = (KvCoder) cxt.getInput().getCoder(); + KvCoder> outputCoder = (KvCoder>) cxt.getOutput().getCoder(); + + Encoder valueEnc = cxt.valueEncoderOf(inputCoder); + Encoder keyEnc = cxt.keyEncoderOf(inputCoder); + + // In batch we can ignore triggering and allowed lateness parameters + final Dataset>>> result; + + boolean useCollectList = + !cxt.getOptions() + .as(SparkCommonPipelineOptions.class) + .getPreferGroupByKeyToHandleHugeValues(); + if (useCollectList && eligibleForGlobalGroupBy(windowing, false)) { + // Collects all values per key in memory. This might be problematic if there's + // few keys only + // or some highly skewed distribution. + result = + input + .groupBy(col("value.key").as("key")) + .agg(collect_list(col("value.value")).as("values"), timestampAggregator(tsCombiner)) + .select( + inGlobalWindow( + keyValue(col("key").as(keyEnc), col("values").as(iterableEnc(valueEnc))), + windowTimestamp(tsCombiner))); + + } else if (eligibleForGlobalGroupBy(windowing, true)) { + // Produces an iterable that can be traversed exactly once. However, on the plus + // side, data is + // not collected in memory until serialized or done by the user. + result = + cxt.getDataset(cxt.getInput()) + .groupByKey(valueKey(), keyEnc) + .mapValues(valueValue(), cxt.valueEncoderOf(inputCoder)) + .mapGroups(fun2((k, it) -> KV.of(k, iterableOnce(it))), cxt.kvEncoderOf(outputCoder)) + .map(fun1(WindowedValues::valueInGlobalWindow), cxt.windowedEncoder(outputCoder)); + + } else if (useCollectList + && eligibleForGroupByWindow(windowing, false) + && (windowing.getWindowFn().assignsToOneWindow() || transform.fewKeys())) { + // Using the window as part of the key should help to better distribute the + // data. However, if + // values are assigned to multiple windows, more data would be shuffled around. + // If there's few + // keys only, this is still valuable. + // Collects all values per key & window in memory. + result = + input + .select(explode(col("windows")).as("window"), col("value"), col("timestamp")) + .groupBy(col("value.key").as("key"), col("window")) + .agg(collect_list(col("value.value")).as("values"), timestampAggregator(tsCombiner)) + .select( + inSingleWindow( + keyValue(col("key").as(keyEnc), col("values").as(iterableEnc(valueEnc))), + col("window").as(cxt.windowEncoder()), + windowTimestamp(tsCombiner))); + + } else if (eligibleForGroupByWindow(windowing, true) + && (windowing.getWindowFn().assignsToOneWindow() || transform.fewKeys())) { + // Using the window as part of the key should help to better distribute the + // data. However, if + // values are assigned to multiple windows, more data would be shuffled around. + // If there's few + // keys only, this is still valuable. + // Produces an iterable that can be traversed exactly once. 
However, on the plus + // side, data is + // not collected in memory until serialized or done by the user. + Encoder> windowedKeyEnc = + cxt.tupleEncoder(cxt.windowEncoder(), keyEnc); + result = + cxt.getDataset(cxt.getInput()) + .flatMap(explodeWindowedKey(valueValue()), cxt.tupleEncoder(windowedKeyEnc, valueEnc)) + .groupByKey(fun1(t -> t._1()), windowedKeyEnc) + .mapValues(fun1(t -> t._2()), valueEnc) + .mapGroups( + fun2((wKey, it) -> windowedKV(wKey, iterableOnce((Iterator) it))), + cxt.windowedEncoder(outputCoder)); + + } else { + // Collects all values per key in memory. This might be problematic if there's + // few keys only + // or some highly skewed distribution. + + // FIXME Revisit this case, implementation is far from ideal: + // - iterator traversed at least twice, forcing materialization in memory + + // group by key, then by windows + result = + input + .groupByKey(valueKey(), keyEnc) + .flatMapGroups( + new GroupAlsoByWindowViaOutputBufferFn<>( + windowing, + (SerStateInternalsFactory) key -> InMemoryStateInternals.forKey(key), + SystemReduceFn.buffering(inputCoder.getValueCoder()), + cxt.getOptionsSupplier()), + cxt.windowedEncoder(outputCoder)); + } + + cxt.putDataset(cxt.getOutput(), result); + } + + /** Serializable In-memory state internals factory. */ + private interface SerStateInternalsFactory extends StateInternalsFactory, Serializable {} + + private Encoder> iterableEnc(Encoder enc) { + // safe to use list encoder with collect list + return (Encoder) collectionEncoder(enc); + } + + private static Column[] timestampAggregator(TimestampCombiner tsCombiner) { + if (tsCombiner.equals(TimestampCombiner.END_OF_WINDOW)) { + return new Column[0]; // no aggregation needed + } + Column agg = + tsCombiner.equals(TimestampCombiner.EARLIEST) + ? min(col("timestamp")) + : max(col("timestamp")); + return new Column[] {agg.as("timestamp")}; + } + + private static Column windowTimestamp(TimestampCombiner tsCombiner) { + if (tsCombiner.equals(TimestampCombiner.END_OF_WINDOW)) { + // null will be set to END_OF_WINDOW by the respective deserializer + return litNull(DataTypes.LongType); + } + return col("timestamp"); + } + + /** + * Java {@link Iterable} from Scala {@link Iterator} that can be iterated just once so that we + * don't have to load all data into memory. 
+ */ + private static Iterable iterableOnce(Iterator it) { + return () -> { + checkState(!it.isEmpty(), "Iterator on values can only be consumed once!"); + return javaIterator(it); + }; + } + + private TypedColumn> keyValue(TypedColumn key, TypedColumn value) { + return struct(key.as("key"), value.as("value")).as(kvEncoder(key.encoder(), value.encoder())); + } + + private static TypedColumn> inGlobalWindow( + TypedColumn value, Column ts) { + List fields = concat(timestampedValue(value, ts), GLOBAL_WINDOW_DETAILS); + Encoder> enc = + windowedValueEncoder(value.encoder(), encoderOf(GlobalWindow.class)); + return (TypedColumn>) + struct(JavaConverters.asJavaCollection(fields).toArray(new Column[0])).as(enc); + } + + public static TypedColumn> inSingleWindow( + TypedColumn value, TypedColumn window, Column ts) { + Column windows = org.apache.spark.sql.functions.array(window); + List fields = concat(timestampedValue(value, ts), windowDetails(windows)); + Encoder> enc = windowedValueEncoder(value.encoder(), window.encoder()); + return (TypedColumn>) + struct(JavaConverters.asJavaCollection(fields).toArray(new Column[0])).as(enc); + } + + private static List timestampedValue(Column value, Column ts) { + return seqOf(value.as("value"), ts.as("timestamp")).toList(); + } + + private static List windowDetails(Column windows) { + return seqOf(windows.as("windows"), PANE_NO_FIRING.as("paneInfo")).toList(); + } + + private static Column lit(T t) { + return org.apache.spark.sql.functions.lit(t); + } + + @SuppressWarnings("nullness") // NULL literal + private static Column litNull(DataType dataType) { + return org.apache.spark.sql.functions.lit(null).cast(dataType); + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ImpulseTranslatorBatch.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ImpulseTranslatorBatch.java new file mode 100644 index 000000000000..78afdfa5451e --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ImpulseTranslatorBatch.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch; + +import static org.apache.beam.repackaged.core.org.apache.commons.lang3.ArrayUtils.EMPTY_BYTE_ARRAY; + +import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator; +import org.apache.beam.sdk.coders.ByteArrayCoder; +import org.apache.beam.sdk.transforms.Impulse; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.values.PBegin; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.sdk.values.WindowedValues; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.spark.sql.Dataset; + +class ImpulseTranslatorBatch extends TransformTranslator, Impulse> { + + ImpulseTranslatorBatch() { + super(0); + } + + @Override + public void translate(Impulse transform, Context cxt) { + Dataset> dataset = + cxt.createDataset( + ImmutableList.of(WindowedValues.valueInGlobalWindow(EMPTY_BYTE_ARRAY)), + cxt.windowedEncoder(ByteArrayCoder.of(), GlobalWindow.Coder.INSTANCE)); + cxt.putDataset(cxt.getOutput(), dataset); + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java new file mode 100644 index 000000000000..0f43f329b0df --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java @@ -0,0 +1,274 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch; + +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.DoFnRunnerFactory.simple; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.oneOfEncoder; +import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; +import static org.apache.spark.sql.functions.col; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.function.Supplier; +import org.apache.beam.runners.core.DoFnRunners; +import org.apache.beam.runners.core.SideInputReader; +import org.apache.beam.runners.spark.SparkCommonPipelineOptions; +import org.apache.beam.runners.spark.structuredstreaming.metrics.MetricsAccumulator; +import org.apache.beam.runners.spark.structuredstreaming.translation.PipelineTranslator.UnresolvedTranslation; +import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator; +import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.SideInputValues; +import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.SparkSideInputReader; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature; +import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TupleTagList; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; +import org.apache.spark.broadcast.Broadcast; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.TypedColumn; +import org.apache.spark.storage.StorageLevel; +import scala.Tuple2; + +/** + * Translator for {@link ParDo.MultiOutput} based on {@link DoFnRunners#simpleRunner}. + * + *

    Each tag is encoded as an individual column, each with its own schema and encoder. + * + *

    TODO: + *

  • Add support for state and timers. + *
  • Add support for SplittableDoFn + */ +class ParDoTranslatorBatch + extends TransformTranslator< + PCollection, PCollectionTuple, ParDo.MultiOutput> { + + ParDoTranslatorBatch() { + super(0); + } + + @Override + public boolean canTranslate(ParDo.MultiOutput transform) { + DoFn doFn = transform.getFn(); + DoFnSignature signature = DoFnSignatures.signatureForDoFn(doFn); + + // TODO: add support of Splittable DoFn + checkState( + !signature.processElement().isSplittable(), + "Not expected to directly translate splittable DoFn, should have been overridden: %s", + doFn); + + // TODO: add support of states and timers + checkState( + !signature.usesState() && !signature.usesTimers(), + "States and timers are not supported for the moment."); + + checkState( + signature.onWindowExpiration() == null, "onWindowExpiration is not supported: %s", doFn); + + checkState( + !signature.processElement().requiresTimeSortedInput(), + "@RequiresTimeSortedInput is not supported for the moment"); + + SparkSideInputReader.validateMaterializations(transform.getSideInputs().values()); + return true; + } + + @Override + public void translate(ParDo.MultiOutput transform, Context cxt) + throws IOException { + + PCollection input = (PCollection) cxt.getInput(); + SideInputReader sideInputReader = + createSideInputReader(transform.getSideInputs().values(), cxt); + MetricsAccumulator metrics = MetricsAccumulator.getInstance(cxt.getSparkSession()); + + TupleTag mainOut = transform.getMainOutputTag(); + + // Filter out obsolete PCollections to only cache when absolutely necessary + Map, PCollection> outputs = + skipUnconsumedOutputs(cxt.getOutputs(), mainOut, transform.getAdditionalOutputTags(), cxt); + + if (outputs.size() > 1) { + // In case of multiple outputs / tags, map each tag to a column by index. + // At the end split the result into multiple datasets selecting one column each. + Map tagColIdx = tagsColumnIndex((Collection>) outputs.keySet()); + List>> encoders = createEncoders(outputs, tagColIdx, cxt); + + DoFnPartitionIteratorFactory>> doFnMapper = + DoFnPartitionIteratorFactory.multiOutput( + cxt.getOptionsSupplier(), + metrics, + simple(cxt.getCurrentTransform(), input, sideInputReader, false), + tagColIdx); + + // FIXME What's the strategy to unpersist Datasets / RDDs? + + SparkCommonPipelineOptions opts = cxt.getOptions().as(SparkCommonPipelineOptions.class); + StorageLevel storageLevel = StorageLevel.fromString(opts.getStorageLevel()); + + // Persist as wide rows with one column per TupleTag to support different schemas + Dataset>> allTagsDS = + cxt.getDataset(input).mapPartitions(doFnMapper, oneOfEncoder(encoders)); + allTagsDS.persist(storageLevel); + + // divide into separate output datasets per tag + for (TupleTag tag : outputs.keySet()) { + int colIdx = checkStateNotNull(tagColIdx.get(tag.getId()), "Unknown tag"); + // Resolve specific column matching the tuple tag (by id) + TypedColumn>, WindowedValue> col = + (TypedColumn) col(Integer.toString(colIdx)).as(encoders.get(colIdx)); + + // Caching of the returned outputs is disabled to avoid caching the same data twice. 
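+ // Rows produced for other tags carry a null in this column, so filtering on isNotNull()
+ // and selecting the single column recovers a per-tag Dataset<WindowedValue<T>> from the
+ // persisted wide-row dataset.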
+ cxt.putDataset( + cxt.getOutput((TupleTag) tag), allTagsDS.filter(col.isNotNull()).select(col), false); + } + } else { + PCollection output = cxt.getOutput(mainOut); + // Obsolete outputs might have to be filtered out + boolean filterMainOutput = cxt.getOutputs().size() > 1; + // Provide unresolved translation so that can be fused if possible + UnresolvedParDo unresolvedParDo = + new UnresolvedParDo<>( + input, + simple(cxt.getCurrentTransform(), input, sideInputReader, filterMainOutput), + () -> cxt.windowedEncoder(output.getCoder())); + cxt.putUnresolved(output, unresolvedParDo); + } + } + + /** + * An unresolved {@link ParDo} translation that can be fused with previous / following ParDos for + * better performance. + */ + private static class UnresolvedParDo implements UnresolvedTranslation { + private final PCollection input; + private final DoFnRunnerFactory doFnFact; + private final Supplier>> encoder; + + UnresolvedParDo( + PCollection input, + DoFnRunnerFactory doFnFact, + Supplier>> encoder) { + this.input = input; + this.doFnFact = doFnFact; + this.encoder = encoder; + } + + @Override + public PCollection getInput() { + return input; + } + + @Override + public UnresolvedTranslation fuse(UnresolvedTranslation next) { + UnresolvedParDo nextParDo = (UnresolvedParDo) next; + return new UnresolvedParDo<>(input, doFnFact.fuse(nextParDo.doFnFact), nextParDo.encoder); + } + + @Override + public Dataset> resolve( + Supplier options, Dataset> input) { + MetricsAccumulator metrics = MetricsAccumulator.getInstance(input.sparkSession()); + DoFnPartitionIteratorFactory> doFnMapper = + DoFnPartitionIteratorFactory.singleOutput(options, metrics, doFnFact); + return input.mapPartitions(doFnMapper, encoder.get()); + } + } + + /** + * Filter out output tags which are not consumed by any transform, except for {@code mainTag}. + * + *

    This can help to avoid unnecessary caching in case of multiple outputs if only {@code + * mainTag} is consumed. + */ + private Map, PCollection> skipUnconsumedOutputs( + Map, PCollection> outputs, + TupleTag mainTag, + TupleTagList otherTags, + Context cxt) { + switch (outputs.size()) { + case 1: + return outputs; // always keep main output + case 2: + TupleTag otherTag = otherTags.get(0); + return cxt.isLeaf(checkStateNotNull(outputs.get(otherTag))) + ? Collections.singletonMap(mainTag, checkStateNotNull(outputs.get(mainTag))) + : outputs; + default: + Map, PCollection> filtered = Maps.newHashMapWithExpectedSize(outputs.size()); + for (Map.Entry, PCollection> e : outputs.entrySet()) { + if (e.getKey().equals(mainTag) || !cxt.isLeaf(e.getValue())) { + filtered.put(e.getKey(), e.getValue()); + } + } + return filtered; + } + } + + private Map tagsColumnIndex(Collection> tags) { + Map index = Maps.newHashMapWithExpectedSize(tags.size()); + for (TupleTag tag : tags) { + index.put(tag.getId(), index.size()); + } + return index; + } + + /** List of encoders matching the order of tagIds. */ + private List>> createEncoders( + Map, PCollection> outputs, Map tagIdColIdx, Context ctx) { + ArrayList>> encoders = new ArrayList<>(outputs.size()); + for (Entry, PCollection> e : outputs.entrySet()) { + Encoder> enc = ctx.windowedEncoder((Coder) e.getValue().getCoder()); + int colIdx = checkStateNotNull(tagIdColIdx.get(e.getKey().getId())); + encoders.add(colIdx, enc); + } + return encoders; + } + + private SideInputReader createSideInputReader( + Collection> views, Context cxt) { + if (views.isEmpty()) { + return SparkSideInputReader.empty(); + } + Map>> broadcasts = + Maps.newHashMapWithExpectedSize(views.size()); + for (PCollectionView view : views) { + PCollection pCol = checkStateNotNull((PCollection) view.getPCollection()); + // get broadcasted SideInputValues for pCol, if not available use loader function + Broadcast> broadcast = + cxt.getSideInputBroadcast(pCol, SideInputValues.loader(pCol)); + broadcasts.put(view.getTagInternal().getId(), (Broadcast) broadcast); + } + return SparkSideInputReader.create(broadcasts); + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/PipelineTranslatorBatch.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/PipelineTranslatorBatch.java new file mode 100644 index 000000000000..c4a18801ccba --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/PipelineTranslatorBatch.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch; + +import java.util.HashMap; +import java.util.Map; +import org.apache.beam.runners.spark.structuredstreaming.translation.PipelineTranslator; +import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.Flatten; +import org.apache.beam.sdk.transforms.GroupByKey; +import org.apache.beam.sdk.transforms.Impulse; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.Reshuffle; +import org.apache.beam.sdk.transforms.windowing.Window; +import org.apache.beam.sdk.util.construction.SplittableParDo; +import org.apache.beam.sdk.values.PInput; +import org.apache.beam.sdk.values.POutput; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * {@link PipelineTranslator} for executing a {@link Pipeline} in Spark in batch mode. This contains + * only the components specific to batch: registry of batch {@link TransformTranslator} and registry + * lookup code. + */ +@Internal +public class PipelineTranslatorBatch extends PipelineTranslator { + + // -------------------------------------------------------------------------------------------- + // Transform Translator Registry + // -------------------------------------------------------------------------------------------- + + @SuppressWarnings("rawtypes") + private static final Map, TransformTranslator> TRANSFORM_TRANSLATORS = + new HashMap<>(); + + // TODO the ability to have more than one TransformTranslator per URN + // that could be dynamically chosen by a predicated that evaluates based on PCollection + // obtainable though node.getInputs.getValue() + // See + // https://github.com/seznam/euphoria/blob/master/euphoria-spark/src/main/java/cz/seznam/euphoria/spark/SparkFlowTranslator.java#L83 + // And + // https://github.com/seznam/euphoria/blob/master/euphoria-spark/src/main/java/cz/seznam/euphoria/spark/SparkFlowTranslator.java#L106 + + static { + TRANSFORM_TRANSLATORS.put(Impulse.class, new ImpulseTranslatorBatch()); + TRANSFORM_TRANSLATORS.put(Combine.PerKey.class, new CombinePerKeyTranslatorBatch<>()); + TRANSFORM_TRANSLATORS.put(Combine.Globally.class, new CombineGloballyTranslatorBatch<>()); + TRANSFORM_TRANSLATORS.put( + Combine.GroupedValues.class, new CombineGroupedValuesTranslatorBatch<>()); + TRANSFORM_TRANSLATORS.put(GroupByKey.class, new GroupByKeyTranslatorBatch<>()); + + TRANSFORM_TRANSLATORS.put(Reshuffle.class, new ReshuffleTranslatorBatch<>()); + TRANSFORM_TRANSLATORS.put( + Reshuffle.ViaRandomKey.class, new ReshuffleTranslatorBatch.ViaRandomKey<>()); + + TRANSFORM_TRANSLATORS.put(Flatten.PCollections.class, new FlattenTranslatorBatch<>()); + + TRANSFORM_TRANSLATORS.put(Window.Assign.class, new WindowAssignTranslatorBatch<>()); + + TRANSFORM_TRANSLATORS.put(ParDo.MultiOutput.class, new ParDoTranslatorBatch<>()); + + TRANSFORM_TRANSLATORS.put( + SplittableParDo.PrimitiveBoundedRead.class, new ReadSourceTranslatorBatch<>()); + } + + /** Returns a {@link TransformTranslator} for the given {@link PTransform} if known. 
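+ * Lookup is keyed on the transform's concrete class; transforms without a registered entry
+ * yield {@code null} and are handled by the surrounding {@link PipelineTranslator} instead.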
*/ + @Override + @Nullable + protected > + TransformTranslator getTransformTranslator(TransformT transform) { + return TRANSFORM_TRANSLATORS.get(transform.getClass()); + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ReadSourceTranslatorBatch.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ReadSourceTranslatorBatch.java new file mode 100644 index 000000000000..e83ada473d0c --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ReadSourceTranslatorBatch.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch; + +import java.io.IOException; +import java.util.function.Supplier; +import org.apache.beam.runners.spark.structuredstreaming.io.BoundedDatasetFactory; +import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator; +import org.apache.beam.sdk.io.BoundedSource; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.util.construction.SplittableParDo; +import org.apache.beam.sdk.values.PBegin; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.SparkSession; + +/** + * Translator for a {@link SplittableParDo.PrimitiveBoundedRead} that creates a Dataset via an RDD + * to avoid an additional serialization roundtrip. 
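+ * The resulting rows use the source's output coder and are assigned to the global window;
+ * the dataset is registered without caching ({@code putDataset(..., false)} below).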
+ */ +class ReadSourceTranslatorBatch + extends TransformTranslator, SplittableParDo.PrimitiveBoundedRead> { + + ReadSourceTranslatorBatch() { + super(0.05f); + } + + @Override + public void translate(SplittableParDo.PrimitiveBoundedRead transform, Context cxt) + throws IOException { + SparkSession session = cxt.getSparkSession(); + BoundedSource source = transform.getSource(); + Supplier options = cxt.getOptionsSupplier(); + + Encoder> encoder = + cxt.windowedEncoder(source.getOutputCoder(), GlobalWindow.Coder.INSTANCE); + + cxt.putDataset( + cxt.getOutput(), + BoundedDatasetFactory.createDatasetFromRDD(session, source, options, encoder), + false); + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ReshuffleTranslatorBatch.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ReshuffleTranslatorBatch.java new file mode 100644 index 000000000000..2c541ba4ae43 --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ReshuffleTranslatorBatch.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch; + +import static org.apache.spark.sql.functions.col; + +import java.io.IOException; +import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator; +import org.apache.beam.sdk.transforms.Reshuffle; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.internal.SQLConf; + +class ReshuffleTranslatorBatch + extends TransformTranslator>, PCollection>, Reshuffle> { + + ReshuffleTranslatorBatch() { + super(0.1f); + } + + @Override + protected void translate(Reshuffle transform, Context cxt) throws IOException { + Dataset>> input = cxt.getDataset(cxt.getInput()); + cxt.putDataset(cxt.getOutput(), input.repartition(col("value.key"))); + } + + static class ViaRandomKey + extends TransformTranslator, PCollection, Reshuffle.ViaRandomKey> { + + ViaRandomKey() { + super(0.1f); + } + + @Override + protected void translate(Reshuffle.ViaRandomKey transform, Context cxt) throws IOException { + Dataset> input = cxt.getDataset(cxt.getInput()); + // Reshuffle randomly + cxt.putDataset(cxt.getOutput(), input.repartition(SQLConf.get().numShufflePartitions())); + } + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/WindowAssignTranslatorBatch.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/WindowAssignTranslatorBatch.java new file mode 100644 index 000000000000..25e08cd9de99 --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/WindowAssignTranslatorBatch.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch; + +import java.util.Collection; +import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindows; +import org.apache.beam.sdk.transforms.windowing.Window; +import org.apache.beam.sdk.transforms.windowing.WindowFn; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.sdk.values.WindowedValues; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.apache.spark.api.java.function.MapFunction; +import org.apache.spark.sql.Dataset; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.joda.time.Instant; + +class WindowAssignTranslatorBatch + extends TransformTranslator, PCollection, Window.Assign> { + + WindowAssignTranslatorBatch() { + super(0.05f); + } + + @Override + public void translate(Window.Assign transform, Context cxt) { + WindowFn windowFn = transform.getWindowFn(); + PCollection input = cxt.getInput(); + Dataset> inputDataset = cxt.getDataset(input); + + if (windowFn == null || skipAssignWindows(windowFn, input)) { + cxt.putDataset(cxt.getOutput(), inputDataset); + } else { + Dataset> outputDataset = + inputDataset.map( + assignWindows(windowFn), + cxt.windowedEncoder(input.getCoder(), windowFn.windowCoder())); + + cxt.putDataset(cxt.getOutput(), outputDataset); + } + } + + /** + * Checks if the window transformation should be applied or skipped. + * + *

    Avoid running assign windows if both source and destination are global window or if the user + * has not specified the WindowFn (meaning they are just messing with triggering or allowed + * lateness). + */ + private boolean skipAssignWindows(WindowFn newFn, PCollection input) { + WindowFn currentFn = input.getWindowingStrategy().getWindowFn(); + return currentFn instanceof GlobalWindows && newFn instanceof GlobalWindows; + } + + private static + MapFunction, WindowedValue> assignWindows(WindowFn windowFn) { + return value -> { + final BoundedWindow window = getOnlyWindow(value); + final T element = value.getValue(); + final Instant timestamp = value.getTimestamp(); + Collection windows = + windowFn.assignWindows( + windowFn.new AssignContext() { + + @Override + public T element() { + return element; + } + + @Override + public @NonNull Instant timestamp() { + return timestamp; + } + + @Override + public @NonNull BoundedWindow window() { + return window; + } + }); + return WindowedValues.of(element, timestamp, windows, value.getPaneInfo()); + }; + } + + private static BoundedWindow getOnlyWindow(WindowedValue wv) { + return Iterables.getOnlyElement((Iterable) wv.getWindows()); + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/functions/CachedSideInputReader.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/functions/CachedSideInputReader.java new file mode 100644 index 000000000000..1db08d935fca --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/functions/CachedSideInputReader.java @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions; + +import static org.apache.beam.sdk.transforms.Materializations.MULTIMAP_MATERIALIZATION_URN; + +import java.util.Collection; +import java.util.Map; +import java.util.concurrent.Callable; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ExecutionException; +import org.apache.beam.runners.core.SideInputReader; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.transforms.Materialization; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindows; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.Cache; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheBuilder; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheStats; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * SideInputReader that caches results for costly {@link Materialization Materializations}. + * Concurrent access is not expected, but it won't impact correctness. + */ +@Internal +public class CachedSideInputReader implements SideInputReader { + private final SideInputReader reader; + private final Map, Cache> caches; + + /** + * Creates a SideInputReader that caches results for costly {@link Materialization + * Materializations} if present, otherwise the SideInputReader is returned as is. Concurrent + * access is not expected, but it won't impact correctness. + */ + public static SideInputReader of(SideInputReader reader, Collection> views) { + Map, Cache> caches = initCaches(views, 1000); + return caches.isEmpty() ? reader : new CachedSideInputReader(reader, caches); + } + + private CachedSideInputReader( + SideInputReader reader, Map, Cache> caches) { + this.reader = reader; + this.caches = caches; + } + + /** + * Init caches based on {@link #shouldCache} using {@link SingletonCache} if using global window, + * and otherwise a Guava LRU cache. + */ + private static Map, Cache> initCaches( + Iterable> views, int maxSize) { + ImmutableMap.Builder, Cache> builder = + ImmutableMap.builder(); + for (PCollectionView view : views) { + if (shouldCache(view)) { + boolean isGlobal = + view.getWindowingStrategyInternal().getWindowFn() instanceof GlobalWindows; + builder.put(view, isGlobal ? 
new SingletonCache<>() : lruCache(maxSize)); + } + } + return builder.build(); + } + + private static boolean shouldCache(PCollectionView view) { + // only cache expensive multimap views + return MULTIMAP_MATERIALIZATION_URN.equals(view.getViewFn().getMaterialization().getUrn()); + } + + private static Cache lruCache(int maxSize) { + // no concurrent access expected, using separate instance per partition iterator + return CacheBuilder.newBuilder().concurrencyLevel(1).maximumSize(maxSize).build(); + } + + @Override + public @Nullable T get(PCollectionView view, BoundedWindow window) { + Cache cache = caches.get(view); + if (cache != null) { + Object result = cache.getIfPresent(window); + if (result == null) { + result = reader.get(view, window); + if (result != null) { + cache.put(window, result); + } + return (T) result; + } + } + return reader.get(view, window); + } + + @Override + public boolean contains(PCollectionView view) { + return reader.contains(view); + } + + @Override + public boolean isEmpty() { + return reader.isEmpty(); + } + + /** Caching a singleton value, ignoring any key. */ + private static class SingletonCache + implements Cache { + private @Nullable V value; + + @Override + public @Nullable V getIfPresent(Object o) { + return value; + } + + @Override + public void put(K k, V v) { + value = v; + } + + @Override + public long size() { + return value != null ? 1 : 0; + } + + @Override + public V get(K k, Callable callable) throws ExecutionException { + throw new UnsupportedOperationException(); + } + + @Override + public ImmutableMap getAllPresent(Iterable iterable) { + throw new UnsupportedOperationException(); + } + + @Override + public void putAll(Map map) { + throw new UnsupportedOperationException(); + } + + @Override + public void invalidate(Object o) {} + + @Override + public void invalidateAll(Iterable iterable) {} + + @Override + public void invalidateAll() {} + + @Override + public CacheStats stats() { + throw new UnsupportedOperationException(); + } + + @Override + public ConcurrentMap asMap() { + throw new UnsupportedOperationException(); + } + + @Override + public void cleanUp() { + value = null; + } + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/functions/GroupAlsoByWindowViaOutputBufferFn.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/functions/GroupAlsoByWindowViaOutputBufferFn.java new file mode 100644 index 000000000000..ea436c24634b --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/functions/GroupAlsoByWindowViaOutputBufferFn.java @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.function.Supplier; +import org.apache.beam.runners.core.InMemoryTimerInternals; +import org.apache.beam.runners.core.ReduceFnRunner; +import org.apache.beam.runners.core.StateInternals; +import org.apache.beam.runners.core.StateInternalsFactory; +import org.apache.beam.runners.core.SystemReduceFn; +import org.apache.beam.runners.core.TimerInternals; +import org.apache.beam.runners.core.UnsupportedSideInputReader; +import org.apache.beam.runners.core.triggers.ExecutableTriggerStateMachine; +import org.apache.beam.runners.core.triggers.TriggerStateMachines; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.util.WindowedValueReceiver; +import org.apache.beam.sdk.util.construction.TriggerTranslation; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.sdk.values.WindowingStrategy; +import org.apache.spark.api.java.function.FlatMapGroupsFunction; +import org.joda.time.Instant; + +/** A FlatMap function that groups by windows in batch mode using {@link ReduceFnRunner}. */ +public class GroupAlsoByWindowViaOutputBufferFn + implements FlatMapGroupsFunction< + K, WindowedValue>, WindowedValue>>> { + + private final WindowingStrategy windowingStrategy; + private final StateInternalsFactory stateInternalsFactory; + private final SystemReduceFn, Iterable, W> reduceFn; + private final Supplier options; + + public GroupAlsoByWindowViaOutputBufferFn( + WindowingStrategy windowingStrategy, + StateInternalsFactory stateInternalsFactory, + SystemReduceFn, Iterable, W> reduceFn, + Supplier options) { + this.windowingStrategy = windowingStrategy; + this.stateInternalsFactory = stateInternalsFactory; + this.reduceFn = reduceFn; + this.options = options; + } + + @Override + public Iterator>>> call( + K key, Iterator>> iterator) throws Exception { + + // we have to materialize the Iterator because ReduceFnRunner.processElements expects + // to have all elements to merge the windows between each other. + // possible OOM even though the spark framework spills to disk if a given group is too large to + // fit in memory. + ArrayList> values = new ArrayList<>(); + while (iterator.hasNext()) { + WindowedValue> wv = iterator.next(); + values.add(wv.withValue(wv.getValue().getValue())); + } + + // ------ based on GroupAlsoByWindowsViaOutputBufferDoFn ------// + + // Used with Batch, we know that all the data is available for this key. We can't use the + // timer manager from the context because it doesn't exist. So we create one and emulate the + // watermark, knowing that we have all data and it is in timestamp order. 
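+ // In short: start processing time at "now", feed every buffered element into the
+ // ReduceFnRunner, then advance the input watermark and processing time to the end of time
+ // so that all remaining windows and timers fire before the collected output is returned.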
+ InMemoryTimerInternals timerInternals = new InMemoryTimerInternals(); + timerInternals.advanceProcessingTime(Instant.now()); + timerInternals.advanceSynchronizedProcessingTime(Instant.now()); + StateInternals stateInternals = stateInternalsFactory.stateInternalsForKey(key); + GABWWindowedValueReceiver outputter = new GABWWindowedValueReceiver<>(); + + ReduceFnRunner, W> reduceFnRunner = + new ReduceFnRunner<>( + key, + windowingStrategy, + ExecutableTriggerStateMachine.create( + TriggerStateMachines.stateMachineForTrigger( + TriggerTranslation.toProto(windowingStrategy.getTrigger()))), + stateInternals, + timerInternals, + outputter, + new UnsupportedSideInputReader("GroupAlsoByWindow"), + reduceFn, + options.get()); + + // Process the grouped values. + reduceFnRunner.processElements(values); + + // Finish any pending windows by advancing the input watermark to infinity. + timerInternals.advanceInputWatermark(BoundedWindow.TIMESTAMP_MAX_VALUE); + + // Finally, advance the processing time to infinity to fire any timers. + timerInternals.advanceProcessingTime(BoundedWindow.TIMESTAMP_MAX_VALUE); + timerInternals.advanceSynchronizedProcessingTime(BoundedWindow.TIMESTAMP_MAX_VALUE); + + fireEligibleTimers(timerInternals, reduceFnRunner); + + reduceFnRunner.persist(); + + return outputter.getOutputs().iterator(); + } + + private void fireEligibleTimers( + InMemoryTimerInternals timerInternals, + ReduceFnRunner, W> reduceFnRunner) + throws Exception { + List timers = new ArrayList<>(); + while (true) { + TimerInternals.TimerData timer; + while ((timer = timerInternals.removeNextEventTimer()) != null) { + timers.add(timer); + } + while ((timer = timerInternals.removeNextProcessingTimer()) != null) { + timers.add(timer); + } + while ((timer = timerInternals.removeNextSynchronizedProcessingTimer()) != null) { + timers.add(timer); + } + if (timers.isEmpty()) { + break; + } + reduceFnRunner.onTimers(timers); + timers.clear(); + } + } + + private static class GABWWindowedValueReceiver + implements WindowedValueReceiver>> { + private final List>>> outputs = new ArrayList<>(); + + @Override + public void output(WindowedValue>> value) { + outputs.add(value); + } + + Iterable>>> getOutputs() { + return outputs; + } + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/functions/NoOpStepContext.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/functions/NoOpStepContext.java new file mode 100644 index 000000000000..25e6f112a3f3 --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/functions/NoOpStepContext.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions; + +import org.apache.beam.runners.core.StateInternals; +import org.apache.beam.runners.core.StepContext; +import org.apache.beam.runners.core.TimerInternals; + +/** A {@link StepContext} for Spark Batch Runner execution. */ +public class NoOpStepContext implements StepContext { + + @Override + public StateInternals stateInternals() { + throw new UnsupportedOperationException("stateInternals is not supported"); + } + + @Override + public TimerInternals timerInternals() { + throw new UnsupportedOperationException("timerInternals is not supported"); + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/functions/SideInputValues.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/functions/SideInputValues.java new file mode 100644 index 000000000000..23c8d49c3091 --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/functions/SideInputValues.java @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions; + +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun1; +import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; +import static org.apache.beam.sdk.values.WindowedValues.getFullCoder; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; +import static org.apache.spark.sql.Encoders.BINARY; + +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.KryoSerializable; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import javax.annotation.Nullable; +import org.apache.beam.runners.spark.structuredstreaming.translation.EvaluationContext; +import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.CoderHelpers; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindows; +import org.apache.beam.sdk.transforms.windowing.WindowFn; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; +import org.apache.spark.sql.Dataset; +import org.checkerframework.checker.nullness.qual.MonotonicNonNull; +import org.checkerframework.checker.nullness.qual.NonNull; +import scala.Function1; + +/** + * {@link SideInputValues} serves as a Kryo serializable container that contains a materialized view + * of side inputs. Once the materialized view is build, the container is broadcasted for use in the + * {@link SparkSideInputReader}. This happens during translation time of the pipeline. + * + *
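+ * Two implementations follow: {@link Global} for side inputs windowed into the global
+ * window, and {@link ByWindow} for side inputs using any other {@link BoundedWindow}.
+ *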

    If Kryo serialization is disabled in Spark, Java serialization will be used instead and some + * optimizations will not be available. + */ +@Internal +public interface SideInputValues extends Serializable, KryoSerializable { + /** Factory function for load {@link SideInputValues} from a {@link Dataset}. */ + interface Loader extends Function>, SideInputValues> {} + + @Nullable + List get(BoundedWindow window); + + /** + * Factory to load {@link SideInputValues} from a {@link Dataset} based on the window strategy. + */ + static Loader loader(PCollection pCol) { + WindowFn fn = pCol.getWindowingStrategy().getWindowFn(); + return fn instanceof GlobalWindows + ? ds -> new Global<>(pCol.getName(), pCol.getCoder(), ds) + : ds -> new ByWindow<>(pCol.getName(), getFullCoder(pCol.getCoder(), fn.windowCoder()), ds); + } + + /** + * Specialized {@link SideInputValues} for use with the {@link GlobalWindow} in two possible + * states. + *

  • Initially it contains the binary values to be broadcast. + *
  • On the receiver / executor side the binary values are deserialized once. The binary values + * are dropped to minimize memory usage. + */ + class Global extends BaseSideInputValues, T> { + @VisibleForTesting + Global(String name, Coder coder, Dataset> data) { + super(coder, EvaluationContext.collect(name, binaryDataset(data, coder))); + } + + @Override + public @Nullable List get(BoundedWindow window) { + checkArgument(window instanceof GlobalWindow, "Expected GlobalWindow"); + return getValues(); + } + + @Override + List deserialize(byte[][] binaryValues, Coder coder) { + List values = new ArrayList<>(binaryValues.length); + for (byte[] binaryValue : binaryValues) { + values.add(CoderHelpers.fromByteArray(binaryValue, coder)); + } + return values; + } + + private static Dataset binaryDataset(Dataset> ds, Coder coder) { + return ds.map(bytes(coder), BINARY()); // prevents checker crash + } + + private static Function1, byte[]> bytes(Coder coder) { + return fun1(t -> CoderHelpers.toByteArray(t.getValue(), coder)); + } + } + + /** + * General {@link SideInputValues} for {@link BoundedWindow BoundedWindows} in two possible + * states. + *
  • Initially it contains the binary values to be broadcast. + *
  • On the receiver / executor side the binary values are deserialized once. The binary values + * are dropped to minimize memory usage. + */ + class ByWindow extends BaseSideInputValues, Map>, T> { + @VisibleForTesting + ByWindow(String name, Coder> coder, Dataset> ds) { + super(coder, EvaluationContext.collect(name, binaryDataset(ds, coder))); + } + + @Override + public @Nullable List get(BoundedWindow window) { + return getValues().get(window); + } + + @Override + Map> deserialize(byte[][] binaryValues, Coder> coder) { + Map> values = new HashMap<>(); + for (byte[] binaryValue : binaryValues) { + WindowedValue value = CoderHelpers.fromByteArray(binaryValue, coder); + for (BoundedWindow window : value.getWindows()) { + List list = values.computeIfAbsent(window, w -> new ArrayList<>()); + list.add(value.getValue()); + } + } + return values; + } + + private static Dataset binaryDataset( + Dataset> ds, Coder> coder) { + return ds.map(bytes(coder), BINARY()); // prevents checker crash + } + + private static Function1, byte[]> bytes(Coder> coder) { + return fun1(t -> CoderHelpers.toByteArray(t, coder)); + } + } + + abstract class BaseSideInputValues + implements SideInputValues { + private Coder coder; + private @Nullable byte[][] binaryValues; + private transient @MonotonicNonNull ValuesT values = null; + + private BaseSideInputValues(Coder coder, @Nullable byte[][] binary) { + this.coder = coder; + this.binaryValues = binary; + } + + abstract ValuesT deserialize(byte[][] binaryValues, Coder coder); + + final ValuesT getValues() { + if (values == null) { + values = deserialize(checkStateNotNull(binaryValues), coder); + } + return values; + } + + @Override + public void write(Kryo kryo, Output output) { + kryo.writeClassAndObject(output, coder); + kryo.writeObject(output, checkStateNotNull(binaryValues)); + } + + @Override + public void read(Kryo kryo, Input input) { + coder = (Coder) kryo.readClassAndObject(input); + values = deserialize(checkStateNotNull(kryo.readObject(input, byte[][].class)), coder); + } + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/functions/SparkSideInputReader.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/functions/SparkSideInputReader.java new file mode 100644 index 000000000000..50c0b8a50b14 --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/functions/SparkSideInputReader.java @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions; + +import static org.apache.beam.sdk.transforms.Materializations.ITERABLE_MATERIALIZATION_URN; +import static org.apache.beam.sdk.transforms.Materializations.MULTIMAP_MATERIALIZATION_URN; +import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; + +import java.io.Serializable; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import org.apache.beam.runners.core.InMemoryMultimapSideInputView; +import org.apache.beam.runners.core.SideInputReader; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.transforms.Materializations.IterableView; +import org.apache.beam.sdk.transforms.Materializations.MultimapView; +import org.apache.beam.sdk.transforms.ViewFn; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; +import org.apache.spark.broadcast.Broadcast; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** SideInputReader using broadcasted {@link SideInputValues}. */ +public class SparkSideInputReader implements SideInputReader, Serializable { + private static final SideInputReader EMPTY_READER = new EmptyReader(); + + private static final Set SUPPORTED_MATERIALIZATIONS = + ImmutableSet.of(ITERABLE_MATERIALIZATION_URN, MULTIMAP_MATERIALIZATION_URN); + + // Map of PCollectionView tag id to broadcasted SideInputValues + private final Map>> sideInputs; + + public static SideInputReader empty() { + return EMPTY_READER; + } + + /** + * Creates a {@link SideInputReader} for Spark from a map of PCollectionView {@link + * TupleTag#getId() tag ids} and the corresponding broadcasted {@link SideInputValues}. + * + *

    Note, the materialization of respective {@link PCollectionView PCollectionViews} should be + * validated ahead of time before any costly creation and broadcast of {@link SideInputValues}. + */ + public static SideInputReader create(Map>> sideInputs) { + return sideInputs.isEmpty() ? empty() : new SparkSideInputReader(sideInputs); + } + + public static void validateMaterializations(Iterable> views) { + for (PCollectionView view : views) { + String viewUrn = view.getViewFn().getMaterialization().getUrn(); + checkArgument( + SUPPORTED_MATERIALIZATIONS.contains(viewUrn), + "This handler is only capable of dealing with %s materializations " + + "but was asked to handle %s for PCollectionView with tag %s.", + SUPPORTED_MATERIALIZATIONS, + viewUrn, + view.getTagInternal().getId()); + } + } + + private SparkSideInputReader(Map>> sideInputs) { + this.sideInputs = sideInputs; + } + + private static T iterableView( + ViewFn, T> viewFn, @Nullable List values) { + return values != null ? viewFn.apply(() -> values) : viewFn.apply(Collections::emptyList); + } + + private static T multimapView( + ViewFn, T> viewFn, Coder keyCoder, @Nullable List> values) { + return values != null && !values.isEmpty() + ? viewFn.apply(InMemoryMultimapSideInputView.fromIterable(keyCoder, values)) + : viewFn.apply(InMemoryMultimapSideInputView.empty()); + } + + @Override + @SuppressWarnings("unchecked") // + public @Nullable T get(PCollectionView view, BoundedWindow window) { + Broadcast> broadcast = + checkStateNotNull( + sideInputs.get(view.getTagInternal().getId()), "View %s not available.", view); + + @Nullable List values = broadcast.value().get(window); + switch (view.getViewFn().getMaterialization().getUrn()) { + case ITERABLE_MATERIALIZATION_URN: + return (T) iterableView((ViewFn) view.getViewFn(), values); + case MULTIMAP_MATERIALIZATION_URN: + Coder keyCoder = ((KvCoder) view.getCoderInternal()).getKeyCoder(); + return (T) multimapView((ViewFn) view.getViewFn(), keyCoder, (List) values); + default: + throw new IllegalStateException( + String.format( + "Unknown materialization urn '%s'", + view.getViewFn().getMaterialization().getUrn())); + } + } + + @Override + public boolean contains(PCollectionView view) { + return sideInputs.containsKey(view.getTagInternal().getId()); + } + + @Override + public boolean isEmpty() { + return sideInputs.isEmpty(); + } + + private static class EmptyReader implements SideInputReader, Serializable { + @Override + public @Nullable T get(PCollectionView view, BoundedWindow window) { + throw new IllegalArgumentException("Cannot get view from empty SideInputReader"); + } + + @Override + public boolean contains(PCollectionView view) { + return false; + } + + @Override + public boolean isEmpty() { + return true; + } + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/functions/package-info.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/functions/package-info.java new file mode 100644 index 000000000000..1f03bac21240 --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/functions/package-info.java @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Internal implementation of the Beam runner for Apache Spark. */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions; diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/package-info.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/package-info.java new file mode 100644 index 000000000000..6d3ce5aa723f --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/package-info.java @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Internal utilities to translate Beam pipelines to Spark batching. */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch; diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/CoderHelpers.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/CoderHelpers.java new file mode 100644 index 000000000000..f8c63bc34f14 --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/CoderHelpers.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark.structuredstreaming.translation.helpers; + +import java.io.IOException; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.util.CoderUtils; + +/** Serialization utility class. */ +public final class CoderHelpers { + private CoderHelpers() {} + + /** + * Utility method for serializing an object using the specified coder. + * + * @param value Value to serialize. + * @param coder Coder to serialize with. + * @param type of value that is serialized + * @return Byte array representing serialized object. + */ + public static byte[] toByteArray(T value, Coder coder) { + try { + return CoderUtils.encodeToByteArray(coder, value); + } catch (IOException e) { + throw new IllegalStateException("Error encoding value: " + value, e); + } + } + + /** + * Utility method for deserializing a byte array using the specified coder. + * + * @param serialized bytearray to be deserialized. + * @param coder Coder to deserialize with. + * @param Type of object to be returned. + * @return Deserialized object. + */ + public static T fromByteArray(byte[] serialized, Coder coder) { + try { + return CoderUtils.decodeFromByteArray(coder, serialized); + } catch (IOException e) { + throw new IllegalStateException("Error decoding bytes for coder: " + coder, e); + } + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java new file mode 100644 index 000000000000..087421ab1c64 --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java @@ -0,0 +1,310 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
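As a quick illustration of the CoderHelpers utility just above: the following minimal sketch round-trips a value through a Beam coder. StringUtf8Coder is only an example choice, and the class assumes CoderHelpers is visible on the classpath (e.g. same package).

    import org.apache.beam.sdk.coders.StringUtf8Coder;

    // Minimal round-trip sketch for the CoderHelpers utility above.
    // Any Beam Coder<T> works the same way; StringUtf8Coder is just an example.
    public class CoderHelpersRoundTrip {
      public static void main(String[] args) {
        StringUtf8Coder coder = StringUtf8Coder.of();
        byte[] bytes = CoderHelpers.toByteArray("hello", coder);
        String restored = CoderHelpers.fromByteArray(bytes, coder);
        System.out.println(restored); // prints "hello"
      }
    }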
+ */ +package org.apache.beam.runners.spark.structuredstreaming.translation.helpers; + +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.emptyList; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.replace; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.seqOf; + +import java.lang.reflect.Constructor; +import java.util.ArrayList; +import java.util.List; +import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal; +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder; +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders; +import org.apache.spark.sql.catalyst.encoders.AgnosticExpressionPathEncoder; +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder; +import org.apache.spark.sql.catalyst.expressions.BoundReference; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.objects.Invoke; +import org.apache.spark.sql.catalyst.expressions.objects.NewInstance; +import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import scala.Option; +import scala.collection.Iterator; +import scala.collection.immutable.Seq; +import scala.reflect.ClassTag; + +public class EncoderFactory { + // default constructor to reflectively create static invoke expressions + private static final Constructor STATIC_INVOKE_CONSTRUCTOR = + (Constructor) StaticInvoke.class.getConstructors()[0]; + + private static final Constructor INVOKE_CONSTRUCTOR = + (Constructor) Invoke.class.getConstructors()[0]; + + private static final Constructor NEW_INSTANCE_CONSTRUCTOR = + (Constructor) NewInstance.class.getConstructors()[0]; + + @SuppressWarnings({"nullness", "unchecked"}) + static ExpressionEncoder create( + Expression serializer, Expression deserializer, Class clazz) { + AgnosticEncoder agnosticEncoder = new BeamAgnosticEncoder<>(serializer, deserializer, clazz); + return ExpressionEncoder.apply(agnosticEncoder, serializer, deserializer); + } + + /** + * An {@link AgnosticEncoder} that implements both {@link AgnosticExpressionPathEncoder} (so that + * {@code SerializerBuildHelper} / {@code DeserializerBuildHelper} delegate to our pre-built + * expressions) and {@link AgnosticEncoders.StructEncoder} (so that {@code + * Dataset.select(TypedColumn)} creates an N-attribute plan instead of a 1-attribute wrapped plan, + * preventing {@code FIELD_NUMBER_MISMATCH} errors). + * + *

    The {@code toCatalyst} / {@code fromCatalyst} methods substitute the {@code input} + * expression into the pre-built serializer / deserializer via {@code transformUp}, so that when + * this encoder is nested inside a composite encoder (e.g. {@code Encoders.tuple}) the correct + * field-level expression is used in place of the root {@code BoundReference} / {@code + * GetColumnByOrdinal}. + */ + @SuppressWarnings({"nullness", "unchecked", "deprecation"}) + private static final class BeamAgnosticEncoder + implements AgnosticExpressionPathEncoder, AgnosticEncoders.StructEncoder { + + private final Expression serializer; + private final Expression deserializer; + private final Class clazz; + private final Seq encoderFields; + + BeamAgnosticEncoder(Expression serializer, Expression deserializer, Class clazz) { + this.serializer = serializer; + this.deserializer = deserializer; + this.clazz = clazz; + this.encoderFields = buildFields(serializer.dataType()); + } + + private static Seq buildFields(DataType dt) { + if (dt instanceof StructType) { + StructField[] structFields = ((StructType) dt).fields(); + List fields = new ArrayList<>(structFields.length); + for (StructField sf : structFields) { + fields.add( + new AgnosticEncoders.EncoderField( + sf.name(), + new FieldEncoder<>(sf.dataType(), sf.nullable()), + sf.nullable(), + sf.metadata(), + Option.empty(), + Option.empty())); + } + return seqOf(fields.toArray(new AgnosticEncoders.EncoderField[0])); + } else { + // Non-struct: wrap in a single "value" field so StructEncoder sees one field. + return seqOf( + new AgnosticEncoders.EncoderField( + "value", + new FieldEncoder<>(dt, true), + true, + Metadata.empty(), + Option.empty(), + Option.empty())); + } + } + + // --- AgnosticExpressionPathEncoder --- + + @Override + public Expression toCatalyst(Expression input) { + return serializer.transformUp(replace(BoundReference.class, input)); + } + + @Override + public Expression fromCatalyst(Expression input) { + return deserializer.transformUp(replace(GetColumnByOrdinal.class, input)); + } + + // --- AgnosticEncoders.StructEncoder --- + + @Override + public Seq fields() { + return encoderFields; + } + + @Override + public boolean isStruct() { + return true; + } + + @Override + public void + org$apache$spark$sql$catalyst$encoders$AgnosticEncoders$StructEncoder$_setter_$isStruct_$eq( + boolean v) { + // no-op: isStruct() is implemented directly above + } + + // --- AgnosticEncoder / Encoder (explicit to resolve default-method ambiguity) --- + + @Override + public boolean isPrimitive() { + return false; + } + + @Override + public StructType schema() { + // Build StructType from fields — mirrors the StructEncoder.schema() default. + List sfs = new ArrayList<>(encoderFields.size()); + Iterator it = encoderFields.iterator(); + while (it.hasNext()) { + sfs.add(it.next().structField()); + } + return new StructType(sfs.toArray(new StructField[0])); + } + + @Override + public DataType dataType() { + return schema(); + } + + @Override + public ClassTag clsTag() { + return (ClassTag) ClassTag.apply(clazz); + } + } + + /** + * Minimal {@link AgnosticEncoder} stub used to carry per-field {@link DataType} metadata inside + * {@link AgnosticEncoders.EncoderField}. The actual serialization / deserialization is handled by + * {@link BeamAgnosticEncoder#toCatalyst} and {@link BeamAgnosticEncoder#fromCatalyst}. 
+ */ + @SuppressWarnings({"nullness", "unchecked"}) + private static final class FieldEncoder implements AgnosticEncoder { + private final DataType fieldDataType; + private final boolean fieldNullable; + + FieldEncoder(DataType dataType, boolean nullable) { + this.fieldDataType = dataType; + this.fieldNullable = nullable; + } + + @Override + public boolean isPrimitive() { + return false; + } + + @Override + public DataType dataType() { + return fieldDataType; + } + + @Override + public StructType schema() { + return new StructType().add("value", fieldDataType, fieldNullable); + } + + @Override + public boolean nullable() { + return fieldNullable; + } + + @Override + public ClassTag clsTag() { + return (ClassTag) ClassTag.apply(Object.class); + } + } + + /** + * Invoke method {@code fun} on Class {@code cls}, immediately propagating {@code null} if any + * input arg is {@code null}. + */ + static Expression invokeIfNotNull(Class cls, String fun, DataType type, Expression... args) { + return invoke(cls, fun, type, true, args); + } + + /** Invoke method {@code fun} on Class {@code cls}. */ + static Expression invoke(Class cls, String fun, DataType type, Expression... args) { + return invoke(cls, fun, type, false, args); + } + + private static Expression invoke( + Class cls, String fun, DataType type, boolean propagateNull, Expression... args) { + try { + // To address breaking interfaces between various versions of Spark, expressions are + // created reflectively. This is fine as it's just needed once to create the query plan. + switch (STATIC_INVOKE_CONSTRUCTOR.getParameterCount()) { + case 6: + // Spark 3.1.x + return STATIC_INVOKE_CONSTRUCTOR.newInstance( + cls, type, fun, seqOf(args), propagateNull, true); + case 7: + // Spark 3.2.0 + return STATIC_INVOKE_CONSTRUCTOR.newInstance( + cls, type, fun, seqOf(args), emptyList(), propagateNull, true); + case 8: + // Spark 3.2.x, 3.3.x + return STATIC_INVOKE_CONSTRUCTOR.newInstance( + cls, type, fun, seqOf(args), emptyList(), propagateNull, true, true); + case 9: + // Spark 4.0.x: added Option> parameter + return STATIC_INVOKE_CONSTRUCTOR.newInstance( + cls, type, fun, seqOf(args), emptyList(), propagateNull, true, true, Option.empty()); + default: + throw new RuntimeException("Unsupported version of Spark"); + } + } catch (IllegalArgumentException | ReflectiveOperationException ex) { + throw new RuntimeException(ex); + } + } + + /** Invoke method {@code fun} on {@code obj} with provided {@code args}. */ + static Expression invoke( + Expression obj, String fun, DataType type, boolean nullable, Expression... args) { + try { + // To address breaking interfaces between various versions of Spark, expressions are + // created reflectively. This is fine as it's just needed once to create the query plan. 
+ switch (STATIC_INVOKE_CONSTRUCTOR.getParameterCount()) { + case 6: + // Spark 3.1.x + return INVOKE_CONSTRUCTOR.newInstance(obj, fun, type, seqOf(args), false, nullable); + case 7: + // Spark 3.2.0 + return INVOKE_CONSTRUCTOR.newInstance( + obj, fun, type, seqOf(args), emptyList(), false, nullable); + case 8: + case 9: + // Spark 3.2.x, 3.3.x, 4.0.x: Invoke constructor is 8 params in all these versions + return INVOKE_CONSTRUCTOR.newInstance( + obj, fun, type, seqOf(args), emptyList(), false, nullable, true); + default: + throw new RuntimeException("Unsupported version of Spark"); + } + } catch (IllegalArgumentException | ReflectiveOperationException ex) { + throw new RuntimeException(ex); + } + } + + static Expression newInstance(Class cls, DataType type, Expression... args) { + try { + // To address breaking interfaces between various versions of Spark, expressions are + // created reflectively. This is fine as it's just needed once to create the query plan. + switch (NEW_INSTANCE_CONSTRUCTOR.getParameterCount()) { + case 5: + return NEW_INSTANCE_CONSTRUCTOR.newInstance(cls, seqOf(args), true, type, Option.empty()); + case 6: + // Spark 3.2.x, 3.3.x, 4.0.x: added immutable.Seq parameter + return NEW_INSTANCE_CONSTRUCTOR.newInstance( + cls, seqOf(args), emptyList(), true, type, Option.empty()); + default: + throw new RuntimeException("Unsupported version of Spark"); + } + } catch (IllegalArgumentException | ReflectiveOperationException ex) { + throw new RuntimeException(ex); + } + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java new file mode 100644 index 000000000000..41e3a7fe1590 --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java @@ -0,0 +1,610 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
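The reflective constructor handling in EncoderFactory above exists because StaticInvoke, Invoke and NewInstance changed constructor arity across Spark releases. A self-contained sketch of that arity-dispatch pattern (not Spark code; the class and parameter names below are made up) could look like this:

    import java.lang.reflect.Constructor;

    // Sketch of the arity-dispatch trick used by EncoderFactory: look up the one
    // available constructor reflectively and pick the argument list that matches
    // however many parameters this "version" declares.
    public class ArityDispatchSketch {

      // Stand-in for an expression class whose constructor grew an extra parameter
      // between versions (e.g. StaticInvoke between Spark releases).
      static class FakeStaticInvoke {
        FakeStaticInvoke(Class<?> cls, String fun, Object[] args) {}               // older shape
        FakeStaticInvoke(Class<?> cls, String fun, Object[] args, Object extra) {} // newer shape
      }

      static Object newStaticInvoke(Class<?> cls, String fun, Object... args)
          throws ReflectiveOperationException {
        Constructor<?> ctor = FakeStaticInvoke.class.getDeclaredConstructors()[0];
        switch (ctor.getParameterCount()) {
          case 3:
            return ctor.newInstance(cls, fun, args);
          case 4:
            return ctor.newInstance(cls, fun, args, null); // newer shape: pass a default
          default:
            throw new IllegalStateException("Unsupported constructor arity");
        }
      }

      public static void main(String[] a) throws Exception {
        System.out.println(newStaticInvoke(String.class, "valueOf", 42).getClass().getSimpleName());
      }
    }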
+ */ +package org.apache.beam.runners.spark.structuredstreaming.translation.helpers; + +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderFactory.invoke; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderFactory.invokeIfNotNull; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.match; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.replace; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.seqOf; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.tuple; +import static org.apache.spark.sql.types.DataTypes.BinaryType; +import static org.apache.spark.sql.types.DataTypes.IntegerType; +import static org.apache.spark.sql.types.DataTypes.LongType; + +import java.math.BigDecimal; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.TreeMap; +import java.util.concurrent.ConcurrentHashMap; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.transforms.windowing.IntervalWindow; +import org.apache.beam.sdk.transforms.windowing.IntervalWindow.IntervalWindowCoder; +import org.apache.beam.sdk.transforms.windowing.PaneInfo; +import org.apache.beam.sdk.transforms.windowing.PaneInfo.PaneInfoCoder; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.sdk.values.WindowedValues; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.catalyst.SerializerBuildHelper; +import org.apache.spark.sql.catalyst.SerializerBuildHelper.MapElementInformation; +import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal; +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder; +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder; +import org.apache.spark.sql.catalyst.expressions.BoundReference; +import org.apache.spark.sql.catalyst.expressions.Coalesce; +import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct; +import org.apache.spark.sql.catalyst.expressions.EqualTo; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.GetStructField; +import org.apache.spark.sql.catalyst.expressions.If; +import org.apache.spark.sql.catalyst.expressions.IsNotNull; +import org.apache.spark.sql.catalyst.expressions.IsNull; +import org.apache.spark.sql.catalyst.expressions.Literal; +import org.apache.spark.sql.catalyst.expressions.Literal$; +import org.apache.spark.sql.catalyst.expressions.MapKeys; +import org.apache.spark.sql.catalyst.expressions.MapValues; +import org.apache.spark.sql.catalyst.expressions.objects.MapObjects$; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.DataType; +import 
org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.ObjectType; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.util.MutablePair; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Instant; +import scala.Option; +import scala.Some; +import scala.Tuple2; +import scala.collection.IndexedSeq; +import scala.collection.JavaConverters; +import scala.collection.Seq; + +/** {@link Encoders} utility class. */ +public class EncoderHelpers { + private static final DataType OBJECT_TYPE = new ObjectType(Object.class); + private static final DataType TUPLE2_TYPE = new ObjectType(Tuple2.class); + private static final DataType WINDOWED_VALUE = new ObjectType(WindowedValue.class); + private static final DataType KV_TYPE = new ObjectType(KV.class); + private static final DataType MUTABLE_PAIR_TYPE = new ObjectType(MutablePair.class); + private static final DataType LIST_TYPE = new ObjectType(List.class); + + // Collections / maps of these types can be (de)serialized without (de)serializing each member + private static final Set> PRIMITIV_TYPES = + ImmutableSet.of( + Boolean.class, + Byte.class, + Short.class, + Integer.class, + Long.class, + Float.class, + Double.class); + + // Default encoders by class + private static final Map, Encoder> DEFAULT_ENCODERS = new ConcurrentHashMap<>(); + + // Factory for default encoders by class + private static @Nullable Encoder encoderFactory(Class cls) { + if (cls.equals(PaneInfo.class)) { + return paneInfoEncoder(); + } else if (cls.equals(GlobalWindow.class)) { + return binaryEncoder(GlobalWindow.Coder.INSTANCE, false); + } else if (cls.equals(IntervalWindow.class)) { + return binaryEncoder(IntervalWindowCoder.of(), false); + } else if (cls.equals(Instant.class)) { + return instantEncoder(); + } else if (cls.equals(String.class)) { + return Encoders.STRING(); + } else if (cls.equals(Boolean.class)) { + return Encoders.BOOLEAN(); + } else if (cls.equals(Integer.class)) { + return Encoders.INT(); + } else if (cls.equals(Long.class)) { + return Encoders.LONG(); + } else if (cls.equals(Float.class)) { + return Encoders.FLOAT(); + } else if (cls.equals(Double.class)) { + return Encoders.DOUBLE(); + } else if (cls.equals(BigDecimal.class)) { + return Encoders.DECIMAL(); + } else if (cls.equals(byte[].class)) { + return Encoders.BINARY(); + } else if (cls.equals(Byte.class)) { + return Encoders.BYTE(); + } else if (cls.equals(Short.class)) { + return Encoders.SHORT(); + } + return null; + } + + @SuppressWarnings({"nullness", "methodref.return"}) // computeIfAbsent allows null returns + private static @Nullable Encoder getOrCreateDefaultEncoder(Class cls) { + return (Encoder) DEFAULT_ENCODERS.computeIfAbsent(cls, EncoderHelpers::encoderFactory); + } + + /** Gets or creates a default {@link Encoder} for {@link T}. */ + public static Encoder encoderOf(Class cls) { + Encoder enc = getOrCreateDefaultEncoder(cls); + if (enc == null) { + throw new IllegalArgumentException("No default coder available for class " + cls); + } + return enc; + } + + /** + * Creates a Spark {@link Encoder} for {@link T} of {@link DataTypes#BinaryType BinaryType} + * delegating to a Beam {@link Coder} underneath. + * + *

    Note: For common types, if available, default Spark {@link Encoder}s are used instead. + * + * @param coder Beam {@link Coder} + */ + public static Encoder encoderFor(Coder coder) { + Encoder enc = getOrCreateDefaultEncoder(coder.getEncodedTypeDescriptor().getRawType()); + return enc != null ? enc : binaryEncoder(coder, true); + } + + /** + * Creates a Spark {@link Encoder} for {@link T} of {@link StructType} with fields {@code value}, + * {@code timestamp}, {@code window} and {@code pane}. + * + * @param value {@link Encoder} to encode field `{@code value}`. + * @param window {@link Encoder} to encode individual windows in field `{@code window}` + */ + public static Encoder> windowedValueEncoder( + Encoder value, Encoder window) { + Encoder timestamp = encoderOf(Instant.class); + Encoder paneInfo = encoderOf(PaneInfo.class); + Encoder> windows = collectionEncoder(window); + Expression serializer = + serializeWindowedValue(rootRef(WINDOWED_VALUE, true), value, timestamp, windows, paneInfo); + Expression deserializer = + deserializeWindowedValue( + rootCol(serializer.dataType()), value, timestamp, windows, paneInfo); + return EncoderFactory.create(serializer, deserializer, WindowedValue.class); + } + + /** + * Creates a one-of Spark {@link Encoder} of {@link StructType} where each alternative is + * represented as colum / field named by its index with a separate {@link Encoder} each. + * + *

    Externally this is represented as tuple {@code (index, data)} where an index corresponds to + * an {@link Encoder} in the provided list. + * + * @param encoders {@link Encoder}s for each alternative. + */ + public static Encoder> oneOfEncoder(List> encoders) { + Expression serializer = serializeOneOf(rootRef(TUPLE2_TYPE, true), encoders); + Expression deserializer = deserializeOneOf(rootCol(serializer.dataType()), encoders); + return EncoderFactory.create(serializer, deserializer, Tuple2.class); + } + + /** + * Creates a Spark {@link Encoder} for {@link KV} of {@link StructType} with fields {@code key} + * and {@code value}. + * + * @param key {@link Encoder} to encode field `{@code key}`. + * @param value {@link Encoder} to encode field `{@code value}` + */ + public static Encoder> kvEncoder(Encoder key, Encoder value) { + Expression serializer = serializeKV(rootRef(KV_TYPE, true), key, value); + Expression deserializer = deserializeKV(rootCol(serializer.dataType()), key, value); + return EncoderFactory.create(serializer, deserializer, KV.class); + } + + /** + * Creates a Spark {@link Encoder} of {@link ArrayType} for Java {@link Collection}s with nullable + * elements. + * + * @param enc {@link Encoder} to encode collection elements + */ + public static Encoder> collectionEncoder(Encoder enc) { + return collectionEncoder(enc, true); + } + + /** + * Creates a Spark {@link Encoder} of {@link ArrayType} for Java {@link Collection}s. + * + * @param enc {@link Encoder} to encode collection elements + * @param nullable Allow nullable collection elements + */ + public static Encoder> collectionEncoder(Encoder enc, boolean nullable) { + DataType type = new ObjectType(Collection.class); + Expression serializer = serializeSeq(rootRef(type, true), enc, nullable); + Expression deserializer = deserializeSeq(rootCol(serializer.dataType()), enc, nullable, true); + return EncoderFactory.create(serializer, deserializer, Collection.class); + } + + /** + * Creates a Spark {@link Encoder} of {@link MapType} that deserializes to {@link MapT}. + * + * @param key {@link Encoder} to encode keys + * @param value {@link Encoder} to encode values + * @param cls Specific class to use, supported are {@link HashMap} and {@link TreeMap} + */ + public static , K, V> Encoder mapEncoder( + Encoder key, Encoder value, Class cls) { + Expression serializer = mapSerializer(rootRef(new ObjectType(cls), true), key, value); + Expression deserializer = mapDeserializer(rootCol(serializer.dataType()), key, value, cls); + return EncoderFactory.create(serializer, deserializer, cls); + } + + /** + * Creates a Spark {@link Encoder} for Spark's {@link MutablePair} of {@link StructType} with + * fields `{@code _1}` and `{@code _2}`. + * + *

    This is intended to be used in places such as aggregators. + * + * @param enc1 {@link Encoder} to encode `{@code _1}` + * @param enc2 {@link Encoder} to encode `{@code _2}` + */ + public static Encoder> mutablePairEncoder( + Encoder enc1, Encoder enc2) { + Expression serializer = serializeMutablePair(rootRef(MUTABLE_PAIR_TYPE, true), enc1, enc2); + Expression deserializer = deserializeMutablePair(rootCol(serializer.dataType()), enc1, enc2); + return EncoderFactory.create(serializer, deserializer, MutablePair.class); + } + + /** + * Creates a Spark {@link Encoder} for {@link PaneInfo} of {@link DataTypes#BinaryType + * BinaryType}. + */ + private static Encoder paneInfoEncoder() { + DataType type = new ObjectType(PaneInfo.class); + return EncoderFactory.create( + invokeIfNotNull(Utils.class, "paneInfoToBytes", BinaryType, rootRef(type, false)), + invokeIfNotNull(Utils.class, "paneInfoFromBytes", type, rootCol(BinaryType)), + PaneInfo.class); + } + + /** + * Creates a Spark {@link Encoder} for Joda {@link Instant} of {@link DataTypes#LongType + * LongType}. + */ + private static Encoder instantEncoder() { + DataType type = new ObjectType(Instant.class); + Expression instant = rootRef(type, true); + Expression millis = rootCol(LongType); + return EncoderFactory.create( + nullSafe(instant, invoke(instant, "getMillis", LongType, false)), + nullSafe(millis, invoke(Instant.class, "ofEpochMilli", type, millis)), + Instant.class); + } + + /** + * Creates a Spark {@link Encoder} for {@link T} of {@link DataTypes#BinaryType BinaryType} + * delegating to a Beam {@link Coder} underneath. + * + * @param coder Beam {@link Coder} + * @param nullable If to allow nullable items + */ + private static Encoder binaryEncoder(Coder coder, boolean nullable) { + Literal litCoder = lit(coder, Coder.class); + // T could be private, use OBJECT_TYPE for code generation to not risk an IllegalAccessError + return EncoderFactory.create( + invokeIfNotNull( + CoderHelpers.class, + "toByteArray", + BinaryType, + rootRef(OBJECT_TYPE, nullable), + litCoder), + invokeIfNotNull( + CoderHelpers.class, "fromByteArray", OBJECT_TYPE, rootCol(BinaryType), litCoder), + coder.getEncodedTypeDescriptor().getRawType()); + } + + private static Expression serializeWindowedValue( + Expression in, + Encoder valueEnc, + Encoder timestampEnc, + Encoder> windowsEnc, + Encoder paneEnc) { + return serializerObject( + in, + tuple("value", serializeField(in, valueEnc, "getValue")), + tuple("timestamp", serializeField(in, timestampEnc, "getTimestamp")), + tuple("windows", serializeField(in, windowsEnc, "getWindows")), + tuple("paneInfo", serializeField(in, paneEnc, "getPaneInfo"))); + } + + private static Expression serializerObject(Expression in, Tuple2... 
fields) { + return SerializerBuildHelper.createSerializerForObject(in, seqOf(fields)); + } + + private static Expression deserializeWindowedValue( + Expression in, + Encoder valueEnc, + Encoder timestampEnc, + Encoder> windowsEnc, + Encoder paneEnc) { + Expression value = deserializeField(in, valueEnc, 0, "value"); + Expression windows = deserializeField(in, windowsEnc, 2, "windows"); + Expression timestamp = deserializeField(in, timestampEnc, 1, "timestamp"); + Expression paneInfo = deserializeField(in, paneEnc, 3, "paneInfo"); + // set timestamp to end of window (maxTimestamp) if null + timestamp = + ifNotNull(timestamp, invoke(Utils.class, "maxTimestamp", timestamp.dataType(), windows)); + Expression[] fields = new Expression[] {value, timestamp, windows, paneInfo}; + + return nullSafe(paneInfo, invoke(WindowedValues.class, "of", WINDOWED_VALUE, fields)); + } + + private static Expression serializeMutablePair( + Expression in, Encoder enc1, Encoder enc2) { + return serializerObject( + in, + tuple("_1", serializeField(in, enc1, "_1")), + tuple("_2", serializeField(in, enc2, "_2"))); + } + + private static Expression deserializeMutablePair( + Expression in, Encoder enc1, Encoder enc2) { + Expression field1 = deserializeField(in, enc1, 0, "_1"); + Expression field2 = deserializeField(in, enc2, 1, "_2"); + return invoke(MutablePair.class, "apply", MUTABLE_PAIR_TYPE, field1, field2); + } + + private static Expression serializeKV( + Expression in, Encoder keyEnc, Encoder valueEnc) { + return serializerObject( + in, + tuple("key", serializeField(in, keyEnc, "getKey")), + tuple("value", serializeField(in, valueEnc, "getValue"))); + } + + private static Expression deserializeKV( + Expression in, Encoder keyEnc, Encoder valueEnc) { + Expression key = deserializeField(in, keyEnc, 0, "key"); + Expression value = deserializeField(in, valueEnc, 1, "value"); + return invoke(KV.class, "of", KV_TYPE, key, value); + } + + public static Expression serializeOneOf(Expression in, List> encoders) { + Expression type = invoke(in, "_1", IntegerType, false); + Expression[] args = new Expression[encoders.size() * 2]; + for (int i = 0; i < encoders.size(); i++) { + args[i * 2] = lit(String.valueOf(i)); + args[i * 2 + 1] = serializeOneOfField(in, type, encoders.get(i), i); + } + return new CreateNamedStruct(seqOf(args)); + } + + public static Expression deserializeOneOf(Expression in, List> encoders) { + Expression[] args = new Expression[encoders.size()]; + for (int i = 0; i < encoders.size(); i++) { + args[i] = deserializeOneOfField(in, encoders.get(i), i); + } + return new Coalesce(seqOf(args)); + } + + private static Expression serializeOneOfField( + Expression in, Expression type, Encoder enc, int typeIdx) { + Expression litNull = lit(null, serializedType(enc)); + Expression value = invoke(in, "_2", deserializedType(enc), false); + return new If(new EqualTo(type, lit(typeIdx)), serialize(value, enc), litNull); + } + + private static Expression deserializeOneOfField(Expression in, Encoder enc, int idx) { + GetStructField field = new GetStructField(in, idx, Option.empty()); + Expression litNull = lit(null, TUPLE2_TYPE); + Expression newTuple = + EncoderFactory.newInstance(Tuple2.class, TUPLE2_TYPE, lit(idx), deserialize(field, enc)); + return new If(new IsNull(field), litNull, newTuple); + } + + private static Expression serializeField(Expression in, Encoder enc, String getterName) { + Expression ref = serializer(enc).collect(match(BoundReference.class)).head(); + return serialize(invoke(in, getterName, 
ref.dataType(), ref.nullable()), enc); + } + + private static Expression deserializeField( + Expression in, Encoder enc, int idx, String name) { + return deserialize(new GetStructField(in, idx, new Some<>(name)), enc); + } + + // Note: Currently this doesn't support nullable primitive values + private static Expression mapSerializer(Expression map, Encoder key, Encoder value) { + DataType keyType = deserializedType(key); + DataType valueType = deserializedType(value); + return SerializerBuildHelper.createSerializerForMap( + map, + new MapElementInformation(keyType, false, e -> serialize(e, key)), + new MapElementInformation(valueType, false, e -> serialize(e, value))); + } + + private static , K, V> Expression mapDeserializer( + Expression in, Encoder key, Encoder value, Class cls) { + Preconditions.checkArgument(cls.isAssignableFrom(HashMap.class) || cls.equals(TreeMap.class)); + Expression keys = deserializeSeq(new MapKeys(in), key, false, false); + Expression values = deserializeSeq(new MapValues(in), value, false, false); + String fn = cls.equals(TreeMap.class) ? "toTreeMap" : "toMap"; + return invoke( + Utils.class, fn, new ObjectType(cls), keys, values, mapItemType(key), mapItemType(value)); + } + + // serialized type for primitive types (avoid boxing!), otherwise the deserialized type + private static Literal mapItemType(Encoder enc) { + return lit(isPrimitiveEnc(enc) ? serializedType(enc) : deserializedType(enc), DataType.class); + } + + private static Expression serializeSeq(Expression in, Encoder enc, boolean nullable) { + if (isPrimitiveEnc(enc)) { + Expression array = invoke(in, "toArray", new ObjectType(Object[].class), false); + return SerializerBuildHelper.createSerializerForGenericArray( + array, serializedType(enc), nullable); + } + Expression seq = invoke(Utils.class, "toSeq", new ObjectType(Seq.class), in); + return MapObjects$.MODULE$.apply( + exp -> serialize(exp, enc), seq, deserializedType(enc), nullable, Option.empty()); + } + + private static Expression deserializeSeq( + Expression in, Encoder enc, boolean nullable, boolean exposeAsJava) { + DataType type = serializedType(enc); // input type is the serializer result type + if (isPrimitiveEnc(enc)) { + // Spark may reuse unsafe array data, if directly exposed it must be copied before + return exposeAsJava + ? invoke(Utils.class, "copyToList", LIST_TYPE, in, lit(type, DataType.class)) + : in; + } + Option> optCls = exposeAsJava ? Option.apply(List.class) : Option.empty(); + // MapObjects will always copy + return MapObjects$.MODULE$.apply(exp -> deserialize(exp, enc), in, type, nullable, optCls); + } + + private static boolean isPrimitiveEnc(Encoder enc) { + return PRIMITIV_TYPES.contains(enc.clsTag().runtimeClass()); + } + + private static Expression serialize(Expression input, Encoder enc) { + return serializer(enc).transformUp(replace(BoundReference.class, input)); + } + + private static Expression deserialize(Expression input, Encoder enc) { + return deserializer(enc).transformUp(replace(GetColumnByOrdinal.class, input)); + } + + /** + * Wraps an {@link Encoder} as an {@link ExpressionEncoder}. In Spark 4.x, built-in encoders (e.g. + * {@code Encoders.INT()}) are {@link AgnosticEncoder} subclasses rather than {@link + * ExpressionEncoder}s, so we convert them on demand. 
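For context, the public factories of EncoderHelpers shown earlier are meant to be composed. A minimal sketch (the coder choices are arbitrary, and it assumes the helpers package is on the classpath) might look like:

    import org.apache.beam.sdk.coders.StringUtf8Coder;
    import org.apache.beam.sdk.coders.VarLongCoder;
    import org.apache.beam.sdk.values.KV;
    import org.apache.spark.sql.Encoder;

    // Compose EncoderHelpers factories: derive Spark encoders from Beam coders and
    // combine them into a KV encoder with "key" / "value" struct fields.
    public class EncoderHelpersSketch {
      public static void main(String[] args) {
        Encoder<String> keyEnc = EncoderHelpers.encoderFor(StringUtf8Coder.of());
        Encoder<Long> valueEnc = EncoderHelpers.encoderFor(VarLongCoder.of());
        Encoder<KV<String, Long>> kvEnc = EncoderHelpers.kvEncoder(keyEnc, valueEnc);
        System.out.println(kvEnc.schema()); // struct with fields "key" and "value"
      }
    }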
+ */ + @SuppressWarnings("unchecked") + private static ExpressionEncoder toExpressionEncoder(Encoder enc) { + if (enc instanceof ExpressionEncoder) { + return (ExpressionEncoder) enc; + } else if (enc instanceof AgnosticEncoder) { + return ExpressionEncoder.apply((AgnosticEncoder) enc); + } + throw new IllegalArgumentException("Unsupported encoder type: " + enc.getClass()); + } + + private static Expression serializer(Encoder enc) { + return toExpressionEncoder(enc).objSerializer(); + } + + private static Expression deserializer(Encoder enc) { + return toExpressionEncoder(enc).objDeserializer(); + } + + private static DataType serializedType(Encoder enc) { + return toExpressionEncoder(enc).objSerializer().dataType(); + } + + private static DataType deserializedType(Encoder enc) { + return toExpressionEncoder(enc).objDeserializer().dataType(); + } + + private static Expression rootRef(DataType dt, boolean nullable) { + return new BoundReference(0, dt, nullable); + } + + private static Expression rootCol(DataType dt) { + return new GetColumnByOrdinal(0, dt); + } + + private static Expression nullSafe(Expression in, Expression out) { + return new If(new IsNull(in), lit(null, out.dataType()), out); + } + + private static Expression ifNotNull(Expression expr, Expression otherwise) { + return new If(new IsNotNull(expr), expr, otherwise); + } + + private static Expression lit(T t) { + return Literal$.MODULE$.apply(t); + } + + @SuppressWarnings("nullness") // literal NULL is allowed + private static Expression lit(@Nullable T t, DataType dataType) { + return new Literal(t, dataType); + } + + private static Literal lit(T obj, Class cls) { + return Literal.fromObject(obj, new ObjectType(cls)); + } + + /** Encoder / expression utils that are called from generated code. */ + public static class Utils { + + public static PaneInfo paneInfoFromBytes(byte[] bytes) { + return CoderHelpers.fromByteArray(bytes, PaneInfoCoder.of()); + } + + public static byte[] paneInfoToBytes(PaneInfo paneInfo) { + return CoderHelpers.toByteArray(paneInfo, PaneInfoCoder.of()); + } + + /** The end of the only window (max timestamp). 
*/ + public static Instant maxTimestamp(Iterable windows) { + return Iterables.getOnlyElement(windows).maxTimestamp(); + } + + public static List copyToList(ArrayData arrayData, DataType type) { + // Note, this could be optimized for primitive arrays (if elements are not nullable) using + // Ints.asList(arrayData.toIntArray()) and similar + return Arrays.asList(arrayData.toObjectArray(type)); + } + + public static Seq toSeq(ArrayData arrayData) { + return arrayData.toSeq(OBJECT_TYPE); + } + + public static Seq toSeq(Collection col) { + if (col instanceof List) { + return JavaConverters.asScalaBuffer((List) col); + } + return JavaConverters.collectionAsScalaIterable(col).toSeq(); + } + + public static TreeMap toTreeMap( + ArrayData keys, ArrayData values, DataType keyType, DataType valueType) { + return toMap(new TreeMap<>(), keys, values, keyType, valueType); + } + + public static HashMap toMap( + ArrayData keys, ArrayData values, DataType keyType, DataType valueType) { + HashMap map = Maps.newHashMapWithExpectedSize(keys.numElements()); + return toMap(map, keys, values, keyType, valueType); + } + + private static > MapT toMap( + MapT map, ArrayData keys, ArrayData values, DataType keyType, DataType valueType) { + IndexedSeq keysSeq = keys.toSeq(keyType); + IndexedSeq valuesSeq = values.toSeq(valueType); + for (int i = 0; i < keysSeq.size(); i++) { + map.put(keysSeq.apply(i), valuesSeq.apply(i)); + } + return map; + } + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderProvider.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderProvider.java new file mode 100644 index 000000000000..f749f1439409 --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderProvider.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.translation.helpers; + +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.kvEncoder; + +import java.util.function.Function; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.values.KV; +import org.apache.spark.sql.Encoder; + +@Internal +public interface EncoderProvider { + interface Factory extends Function, Encoder> { + Factory INSTANCE = EncoderHelpers::encoderFor; + } + + Encoder encoderOf(Coder coder, Factory factory); + + default Encoder encoderOf(Coder coder) { + return coder instanceof KvCoder + ? 
(Encoder) kvEncoderOf((KvCoder) coder) + : encoderOf(coder, encoderFactory()); + } + + default Encoder> kvEncoderOf(KvCoder coder) { + return encoderOf(coder, c -> kvEncoder(keyEncoderOf(coder), valueEncoderOf(coder))); + } + + default Encoder keyEncoderOf(KvCoder coder) { + return encoderOf(coder.getKeyCoder(), encoderFactory()); + } + + default Encoder valueEncoderOf(KvCoder coder) { + return encoderOf(coder.getValueCoder(), encoderFactory()); + } + + default Factory encoderFactory() { + return (Factory) Factory.INSTANCE; + } +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/package-info.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/package-info.java new file mode 100644 index 000000000000..7079eadfbe26 --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/package-info.java @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Internal helpers to translate Beam pipelines to Spark streaming. */ +package org.apache.beam.runners.spark.structuredstreaming.translation.helpers; diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/package-info.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/package-info.java new file mode 100644 index 000000000000..2754ac500039 --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/package-info.java @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Internal translators for running Beam pipelines on Spark. 
*/ +package org.apache.beam.runners.spark.structuredstreaming.translation; diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/utils/ScalaInterop.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/utils/ScalaInterop.java new file mode 100644 index 000000000000..175e144d6506 --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/utils/ScalaInterop.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.translation.utils; + +import java.io.Serializable; +import org.checkerframework.checker.nullness.qual.NonNull; +import scala.Function1; +import scala.Function2; +import scala.PartialFunction; +import scala.Tuple2; +import scala.collection.Iterator; +import scala.collection.JavaConverters; +import scala.collection.Seq; +import scala.collection.immutable.List; +import scala.collection.immutable.Nil$; + +/** Utilities for easier interoperability with the Spark Scala API. */ +public class ScalaInterop { + private ScalaInterop() {} + + public static scala.collection.immutable.Seq seqOf(T... t) { + return (scala.collection.immutable.Seq) + JavaConverters.asScalaBuffer(java.util.Arrays.asList(t)).toList(); + } + + public static List concat(List a, List b) { + return b.$colon$colon$colon(a); + } + + public static Seq listOf(T t) { + return emptyList().$colon$colon(t); + } + + public static List emptyList() { + return (List) Nil$.MODULE$; + } + + /** Scala {@link Iterator} of Java {@link Iterable}. */ + public static Iterator scalaIterator(Iterable iterable) { + return scalaIterator(iterable.iterator()); + } + + /** Scala {@link Iterator} of Java {@link java.util.Iterator}. */ + public static Iterator scalaIterator(java.util.Iterator it) { + return JavaConverters.asScalaIterator(it); + } + + /** Java {@link java.util.Iterator} of Scala {@link Iterator}. 
*/ + public static java.util.Iterator javaIterator(Iterator it) { + return JavaConverters.asJavaIterator(it); + } + + public static Tuple2 tuple(T1 t1, T2 t2) { + return new Tuple2<>(t1, t2); + } + + public static PartialFunction replace( + Class clazz, T replace) { + return new PartialFunction() { + + @Override + public boolean isDefinedAt(T x) { + return clazz.isAssignableFrom(x.getClass()); + } + + @Override + public T apply(T x) { + return replace; + } + }; + } + + public static PartialFunction match(Class clazz) { + return new PartialFunction() { + + @Override + public boolean isDefinedAt(T x) { + return clazz.isAssignableFrom(x.getClass()); + } + + @Override + public V apply(T x) { + return (V) x; + } + }; + } + + public static Fun1 fun1(Fun1 fun) { + return fun; + } + + public static Fun2 fun2(Fun2 fun) { + return fun; + } + + public interface Fun1 extends Function1, Serializable {} + + public interface Fun2 extends Function2, Serializable {} +} diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/utils/package-info.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/utils/package-info.java new file mode 100644 index 000000000000..470bef88fb4b --- /dev/null +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/utils/package-info.java @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Internal utils to translate Beam pipelines to Spark streaming. */ +package org.apache.beam.runners.spark.structuredstreaming.translation.utils; diff --git a/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/SparkSessionRule.java b/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/SparkSessionRule.java new file mode 100644 index 000000000000..278fd012d77e --- /dev/null +++ b/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/SparkSessionRule.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
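A small usage sketch for the ScalaInterop helpers defined above (purely illustrative): build an immutable Scala Seq from varargs and iterate it again from Java.

    import scala.collection.immutable.Seq;

    // Convert Java varargs to a Scala Seq and walk it again with a Java iterator.
    public class ScalaInteropSketch {
      public static void main(String[] args) {
        Seq<String> seq = ScalaInterop.seqOf("a", "b", "c");
        java.util.Iterator<String> it = ScalaInterop.javaIterator(seq.iterator());
        while (it.hasNext()) {
          System.out.println(it.next());
        }
      }
    }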
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming; + +import static java.util.stream.Collectors.toMap; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun1; + +import java.io.Serializable; +import java.util.Arrays; +import java.util.Map; +import javax.annotation.Nullable; +import org.apache.beam.runners.spark.structuredstreaming.translation.SparkSessionFactory; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.values.KV; +import org.apache.spark.sql.SparkSession; +import org.junit.rules.ExternalResource; +import org.junit.rules.TestRule; +import org.junit.runner.Description; +import org.junit.runners.model.Statement; + +public class SparkSessionRule extends ExternalResource implements Serializable { + private transient SparkSession.Builder builder; + private transient @Nullable SparkSession session = null; + + public SparkSessionRule(String sparkMaster, Map sparkConfig) { + builder = SparkSessionFactory.sessionBuilder(sparkMaster); + sparkConfig.forEach(builder::config); + } + + public SparkSessionRule(KV... sparkConfig) { + this("local[2]", sparkConfig); + } + + public SparkSessionRule(String sparkMaster, KV... sparkConfig) { + this(sparkMaster, Arrays.stream(sparkConfig).collect(toMap(KV::getKey, KV::getValue))); + } + + public SparkSession getSession() { + if (session == null) { + throw new IllegalStateException("SparkSession not available"); + } + return session; + } + + public PipelineOptions createPipelineOptions() { + return configure(TestPipeline.testingPipelineOptions()); + } + + public PipelineOptions configure(PipelineOptions options) { + SparkStructuredStreamingPipelineOptions opts = + options.as(SparkStructuredStreamingPipelineOptions.class); + opts.setUseActiveSparkSession(true); + opts.setRunner(SparkStructuredStreamingRunner.class); + opts.setTestMode(true); + return opts; + } + + /** {@code true} if sessions contains cached Datasets or RDDs. 
*/ + public boolean hasCachedData() { + return !session.sharedState().cacheManager().isEmpty() + || !session.sparkContext().getPersistentRDDs().isEmpty(); + } + + public TestRule clearCache() { + return new ExternalResource() { + @Override + protected void after() { + // clear cached datasets + session.sharedState().cacheManager().clearCache(); + // clear cached RDDs + session.sparkContext().getPersistentRDDs().foreach(fun1(t -> t._2.unpersist(true))); + } + }; + } + + @Override + public Statement apply(Statement base, Description description) { + builder.appName(description.getDisplayName()); + return super.apply(base, description); + } + + @Override + protected void before() throws Throwable { + session = builder.getOrCreate(); + } + + @Override + protected void after() { + getSession().stop(); + session = null; + } +} diff --git a/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingRunnerRegistrarTest.java b/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingRunnerRegistrarTest.java new file mode 100644 index 000000000000..bc7b561eea2d --- /dev/null +++ b/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingRunnerRegistrarTest.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +import java.util.ServiceLoader; +import org.apache.beam.sdk.options.PipelineOptionsRegistrar; +import org.apache.beam.sdk.runners.PipelineRunnerRegistrar; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Test {@link SparkStructuredStreamingRunnerRegistrar}. 
*/ +@RunWith(JUnit4.class) +public class SparkStructuredStreamingRunnerRegistrarTest { + @Test + public void testOptions() { + assertEquals( + ImmutableList.of(SparkStructuredStreamingPipelineOptions.class), + new SparkStructuredStreamingRunnerRegistrar.Options().getPipelineOptions()); + } + + @Test + public void testRunners() { + assertEquals( + ImmutableList.of(SparkStructuredStreamingRunner.class), + new SparkStructuredStreamingRunnerRegistrar.Runner().getPipelineRunners()); + } + + @Test + public void testServiceLoaderForOptions() { + for (PipelineOptionsRegistrar registrar : + Lists.newArrayList(ServiceLoader.load(PipelineOptionsRegistrar.class).iterator())) { + if (registrar instanceof SparkStructuredStreamingRunnerRegistrar.Options) { + return; + } + } + fail("Expected to find " + SparkStructuredStreamingRunnerRegistrar.Options.class); + } + + @Test + public void testServiceLoaderForRunner() { + for (PipelineRunnerRegistrar registrar : + Lists.newArrayList(ServiceLoader.load(PipelineRunnerRegistrar.class).iterator())) { + if (registrar instanceof SparkStructuredStreamingRunnerRegistrar.Runner) { + return; + } + } + fail("Expected to find " + SparkStructuredStreamingRunnerRegistrar.Runner.class); + } +} diff --git a/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/StructuredStreamingPipelineStateTest.java b/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/StructuredStreamingPipelineStateTest.java new file mode 100644 index 000000000000..b44df7bf101b --- /dev/null +++ b/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/StructuredStreamingPipelineStateTest.java @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
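For orientation, the SparkSessionRule defined earlier in this patch is typically wired into a test as a JUnit @ClassRule. The sketch below is hypothetical (the config key and test body are arbitrary examples) and assumes the Spark runner and its dependencies are on the test classpath.

    import org.apache.beam.sdk.values.KV;
    import org.junit.Assert;
    import org.junit.ClassRule;
    import org.junit.Test;

    // Hypothetical test using SparkSessionRule with its default local[2] master.
    public class ExampleSparkSessionTest {
      @ClassRule
      public static final SparkSessionRule SESSION =
          new SparkSessionRule(KV.of("spark.ui.enabled", "false"));

      @Test
      public void sessionIsAvailable() {
        Assert.assertNotNull(SESSION.getSession());
      }
    }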
+ */ +package org.apache.beam.runners.spark.structuredstreaming; + +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.fail; + +import java.io.Serializable; +import org.apache.beam.runners.spark.io.CreateStream; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.PipelineResult; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.SimpleFunction; +import org.apache.beam.sdk.values.PBegin; +import org.apache.beam.sdk.values.PCollection; +import org.joda.time.Duration; +import org.junit.Ignore; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestName; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** This suite tests that various scenarios result in proper states of the pipeline. */ +@RunWith(JUnit4.class) +public class StructuredStreamingPipelineStateTest implements Serializable { + + private static class MyCustomException extends RuntimeException { + + MyCustomException(final String message) { + super(message); + } + } + + private final transient SparkStructuredStreamingPipelineOptions options = + PipelineOptionsFactory.create().as(SparkStructuredStreamingPipelineOptions.class); + + @Rule public transient TestName testName = new TestName(); + + private static final String FAILED_THE_BATCH_INTENTIONALLY = "Failed the batch intentionally"; + + private ParDo.SingleOutput printParDo(final String prefix) { + return ParDo.of( + new DoFn() { + + @ProcessElement + public void processElement(final ProcessContext c) { + System.out.println(prefix + " " + c.element()); + } + }); + } + + private PTransform> getValues( + final SparkStructuredStreamingPipelineOptions options) { + final boolean doNotSyncWithWatermark = false; + return options.isStreaming() + ? CreateStream.of(StringUtf8Coder.of(), Duration.millis(1), doNotSyncWithWatermark) + .nextBatch("one", "two") + : Create.of("one", "two"); + } + + private SparkStructuredStreamingPipelineOptions getStreamingOptions() { + options.setRunner(SparkStructuredStreamingRunner.class); + options.setStreaming(true); + return options; + } + + private SparkStructuredStreamingPipelineOptions getBatchOptions() { + options.setRunner(SparkStructuredStreamingRunner.class); + options.setStreaming(false); // explicit because options is reused throughout the test. 
+ return options; + } + + private Pipeline getPipeline(final SparkStructuredStreamingPipelineOptions options) { + + final Pipeline pipeline = Pipeline.create(options); + final String name = testName.getMethodName() + "(isStreaming=" + options.isStreaming() + ")"; + + pipeline.apply(getValues(options)).setCoder(StringUtf8Coder.of()).apply(printParDo(name)); + + return pipeline; + } + + private void testFailedPipeline(final SparkStructuredStreamingPipelineOptions options) + throws Exception { + + SparkStructuredStreamingPipelineResult result = null; + + try { + final Pipeline pipeline = Pipeline.create(options); + pipeline + .apply(getValues(options)) + .setCoder(StringUtf8Coder.of()) + .apply( + MapElements.via( + new SimpleFunction() { + + @Override + public String apply(final String input) { + throw new MyCustomException(FAILED_THE_BATCH_INTENTIONALLY); + } + })); + + result = (SparkStructuredStreamingPipelineResult) pipeline.run(); + result.waitUntilFinish(); + } catch (final Exception e) { + assertThat(e, instanceOf(Pipeline.PipelineExecutionException.class)); + assertThat(e.getCause(), instanceOf(MyCustomException.class)); + assertThat(e.getCause().getMessage(), is(FAILED_THE_BATCH_INTENTIONALLY)); + assertThat(result.getState(), is(PipelineResult.State.FAILED)); + result.cancel(); + return; + } + + fail("An injected failure did not affect the pipeline as expected."); + } + + private void testTimeoutPipeline(final SparkStructuredStreamingPipelineOptions options) + throws Exception { + + final Pipeline pipeline = getPipeline(options); + + final SparkStructuredStreamingPipelineResult result = + (SparkStructuredStreamingPipelineResult) pipeline.run(); + + result.waitUntilFinish(Duration.millis(1)); + + assertThat(result.getState(), is(PipelineResult.State.RUNNING)); + + result.cancel(); + } + + private void testCanceledPipeline(final SparkStructuredStreamingPipelineOptions options) + throws Exception { + + final Pipeline pipeline = getPipeline(options); + + final SparkStructuredStreamingPipelineResult result = + (SparkStructuredStreamingPipelineResult) pipeline.run(); + + result.cancel(); + + assertThat(result.getState(), is(PipelineResult.State.CANCELLED)); + } + + private void testRunningPipeline(final SparkStructuredStreamingPipelineOptions options) + throws Exception { + + final Pipeline pipeline = getPipeline(options); + + final SparkStructuredStreamingPipelineResult result = + (SparkStructuredStreamingPipelineResult) pipeline.run(); + + assertThat(result.getState(), is(PipelineResult.State.RUNNING)); + + result.cancel(); + } + + @Ignore("TODO: Reactivate with streaming.") + @Test + public void testStreamingPipelineRunningState() throws Exception { + testRunningPipeline(getStreamingOptions()); + } + + @Test + public void testBatchPipelineRunningState() throws Exception { + testRunningPipeline(getBatchOptions()); + } + + @Ignore("TODO: Reactivate with streaming.") + @Test + public void testStreamingPipelineCanceledState() throws Exception { + testCanceledPipeline(getStreamingOptions()); + } + + @Test + public void testBatchPipelineCanceledState() throws Exception { + testCanceledPipeline(getBatchOptions()); + } + + @Ignore("TODO: Reactivate with streaming.") + @Test + public void testStreamingPipelineFailedState() throws Exception { + testFailedPipeline(getStreamingOptions()); + } + + @Test + public void testBatchPipelineFailedState() throws Exception { + testFailedPipeline(getBatchOptions()); + } + + @Ignore("TODO: Reactivate with streaming.") + @Test + public void 
testStreamingPipelineTimeoutState() throws Exception { + testTimeoutPipeline(getStreamingOptions()); + } + + @Test + public void testBatchPipelineTimeoutState() throws Exception { + testTimeoutPipeline(getBatchOptions()); + } +} diff --git a/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/aggregators/metrics/sink/InMemoryMetrics.java b/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/aggregators/metrics/sink/InMemoryMetrics.java new file mode 100644 index 000000000000..69df5768e5d0 --- /dev/null +++ b/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/aggregators/metrics/sink/InMemoryMetrics.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.aggregators.metrics.sink; + +import com.codahale.metrics.Gauge; +import com.codahale.metrics.MetricFilter; +import com.codahale.metrics.MetricRegistry; +import java.util.Collection; +import java.util.Properties; +import java.util.concurrent.atomic.AtomicReference; +import org.apache.beam.runners.spark.structuredstreaming.metrics.WithMetricsSupport; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.apache.spark.metrics.sink.Sink; + +/** An in-memory {@link Sink} implementation for tests. */ +public class InMemoryMetrics implements Sink { + + private static final AtomicReference extendedMetricsRegistry = + new AtomicReference<>(); + private static final AtomicReference internalMetricRegistry = + new AtomicReference<>(); + + // Constructor for Spark 3.1 + @SuppressWarnings("UnusedParameters") + public InMemoryMetrics( + final Properties properties, + final MetricRegistry metricRegistry, + final org.apache.spark.SecurityManager securityMgr) { + extendedMetricsRegistry.set(WithMetricsSupport.forRegistry(metricRegistry)); + internalMetricRegistry.set(metricRegistry); + } + + // Constructor for Spark >= 3.2 + @SuppressWarnings("UnusedParameters") + public InMemoryMetrics(final Properties properties, final MetricRegistry metricRegistry) { + extendedMetricsRegistry.set(WithMetricsSupport.forRegistry(metricRegistry)); + internalMetricRegistry.set(metricRegistry); + } + + @SuppressWarnings({"TypeParameterUnusedInFormals", "rawtypes"}) // because of getGauges + public static T valueOf(final String name) { + // this might fail in case we have multiple aggregators with the same suffix after + // the last dot, but it should be good enough for tests. + WithMetricsSupport extended = extendedMetricsRegistry.get(); + if (extended != null) { + Collection matches = extended.getGauges((n, m) -> n.endsWith(name)).values(); + return matches.isEmpty() ? 
null : (T) Iterables.getOnlyElement(matches).getValue(); + } else { + return null; + } + } + + @SuppressWarnings("WeakerAccess") + public static void clearAll() { + MetricRegistry internal = internalMetricRegistry.get(); + if (internal != null) { + internal.removeMatching(MetricFilter.ALL); + } + } + + @Override + public void start() {} + + @Override + public void stop() {} + + @Override + public void report() {} +} diff --git a/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/aggregators/metrics/sink/InMemoryMetricsSinkRule.java b/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/aggregators/metrics/sink/InMemoryMetricsSinkRule.java new file mode 100644 index 000000000000..f1b996eaf21b --- /dev/null +++ b/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/aggregators/metrics/sink/InMemoryMetricsSinkRule.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.aggregators.metrics.sink; + +import org.junit.rules.ExternalResource; + +/** A rule that clears the {@link InMemoryMetrics} after the tests have finished. */ +class InMemoryMetricsSinkRule extends ExternalResource { + @Override + protected void before() throws Throwable { + InMemoryMetrics.clearAll(); + } +} diff --git a/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/aggregators/metrics/sink/SparkMetricsSinkTest.java b/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/aggregators/metrics/sink/SparkMetricsSinkTest.java new file mode 100644 index 000000000000..603221de078b --- /dev/null +++ b/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/aggregators/metrics/sink/SparkMetricsSinkTest.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark.structuredstreaming.aggregators.metrics.sink; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; + +import org.apache.beam.runners.spark.structuredstreaming.SparkSessionRule; +import org.apache.beam.runners.spark.structuredstreaming.examples.WordCount; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExternalResource; + +/** A test that verifies Beam metrics are reported to Spark's metrics sink in batch mode. */ +public class SparkMetricsSinkTest { + + @ClassRule + public static final SparkSessionRule SESSION = + new SparkSessionRule( + KV.of("spark.metrics.conf.*.sink.memory.class", InMemoryMetrics.class.getName())); + + @Rule public final ExternalResource inMemoryMetricsSink = new InMemoryMetricsSinkRule(); + + @Rule + public final TestPipeline pipeline = TestPipeline.fromOptions(SESSION.createPipelineOptions()); + + private static final ImmutableList WORDS = + ImmutableList.of("hi there", "hi", "hi sue bob", "hi sue", "", "bob hi"); + private static final ImmutableSet EXPECTED_COUNTS = + ImmutableSet.of("hi: 5", "there: 1", "sue: 2", "bob: 2"); + + @Test + public void testInBatchMode() throws Exception { + assertThat(InMemoryMetrics.valueOf("emptyLines"), is(nullValue())); + + final PCollection output = + pipeline + .apply(Create.of(WORDS).withCoder(StringUtf8Coder.of())) + .apply(new WordCount.CountWords()) + .apply(MapElements.via(new WordCount.FormatAsTextFn())); + + PAssert.that(output).containsInAnyOrder(EXPECTED_COUNTS); + pipeline.run(); + + assertThat(InMemoryMetrics.valueOf("emptyLines"), is(1d)); + } +} diff --git a/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/metrics/SparkBeamMetricTest.java b/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/metrics/SparkBeamMetricTest.java new file mode 100644 index 000000000000..fd0aa35e5c8d --- /dev/null +++ b/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/metrics/SparkBeamMetricTest.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark.structuredstreaming.metrics; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; + +import org.apache.beam.sdk.metrics.MetricKey; +import org.apache.beam.sdk.metrics.MetricName; +import org.apache.beam.sdk.metrics.MetricResult; +import org.junit.Test; + +/** Test BeamMetric. */ +public class SparkBeamMetricTest { + @Test + public void testRenderName() { + MetricResult metricResult = + MetricResult.create( + MetricKey.create( + "myStep.one.two(three)", MetricName.named("myNameSpace//", "myName()")), + 123, + 456); + String renderedName = SparkBeamMetric.renderName("", metricResult); + assertThat( + "Metric name was not rendered correctly", + renderedName, + equalTo("myStep_one_two_three.myNameSpace__.myName__")); + } + + @Test + public void testRenderNameWithPrefix() { + MetricResult metricResult = + MetricResult.create( + MetricKey.create( + "myStep.one.two(three)", MetricName.named("myNameSpace//", "myName()")), + 123, + 456); + String renderedName = SparkBeamMetric.renderName("prefix", metricResult); + assertThat( + "Metric name was not rendered correctly", + renderedName, + equalTo("prefix.myStep_one_two_three.myNameSpace__.myName__")); + } +} diff --git a/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorsTest.java b/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorsTest.java new file mode 100644 index 000000000000..fa5312684fc1 --- /dev/null +++ b/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorsTest.java @@ -0,0 +1,371 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch; + +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.encoderFor; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.windowedValueEncoder; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.joda.time.Duration.standardMinutes; + +import java.util.Arrays; +import java.util.Collection; +import java.util.Map; +import java.util.TreeMap; +import java.util.stream.Collectors; +import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.IntervalWindow; +import org.apache.beam.sdk.transforms.windowing.PaneInfo; +import org.apache.beam.sdk.transforms.windowing.Sessions; +import org.apache.beam.sdk.transforms.windowing.SlidingWindows; +import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; +import org.apache.beam.sdk.transforms.windowing.WindowFn; +import org.apache.beam.sdk.transforms.windowing.WindowMappingFn; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.sdk.values.WindowedValues; +import org.apache.beam.sdk.values.WindowingStrategy; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Streams; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.expressions.Aggregator; +import org.apache.spark.util.MutablePair; +import org.hamcrest.Matcher; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.experimental.runners.Enclosed; +import org.junit.runner.RunWith; + +@RunWith(Enclosed.class) +public class AggregatorsTest { + + // just something easy readable + private static final Instant NOW = Instant.parse("2000-01-01T00:00Z"); + + /** Tests for NonMergingWindowedAggregator in {@link Aggregators}. 
*/ + public static class NonMergingWindowedAggregatorTest { + + private SlidingWindows sliding = + SlidingWindows.of(standardMinutes(15)).every(standardMinutes(5)); + + private Aggregator< + WindowedValue, + Map>, + Collection>> + agg = windowedAgg(sliding); + + @Test + public void testReduce() { + Map> acc; + + acc = agg.reduce(agg.zero(), windowedValue(1, at(10))); + assertThat( + acc, + equalsToMap( + KV.of(intervalWindow(0, 15), pair(at(10), 1)), + KV.of(intervalWindow(5, 20), pair(at(10), 1)), + KV.of(intervalWindow(10, 25), pair(at(10), 1)))); + + acc = agg.reduce(acc, windowedValue(2, at(16))); + assertThat( + acc, + equalsToMap( + KV.of(intervalWindow(0, 15), pair(at(10), 1)), + KV.of(intervalWindow(5, 20), pair(at(16), 3)), + KV.of(intervalWindow(10, 25), pair(at(16), 3)), + KV.of(intervalWindow(15, 30), pair(at(16), 2)))); + } + + @Test + public void testMerge() { + Map> acc; + + assertThat(agg.merge(agg.zero(), agg.zero()), equalTo(agg.zero())); + + acc = mapOf(KV.of(intervalWindow(0, 15), pair(at(0), 1))); + + assertThat(agg.merge(acc, agg.zero()), equalTo(acc)); + assertThat(agg.merge(agg.zero(), acc), equalTo(acc)); + + acc = agg.merge(acc, acc); + assertThat(acc, equalsToMap(KV.of(intervalWindow(0, 15), pair(at(0), 1 + 1)))); + + acc = agg.merge(acc, mapOf(KV.of(intervalWindow(5, 20), pair(at(5), 3)))); + assertThat( + acc, + equalsToMap( + KV.of(intervalWindow(0, 15), pair(at(0), 1 + 1)), + KV.of(intervalWindow(5, 20), pair(at(5), 3)))); + + acc = agg.merge(mapOf(KV.of(intervalWindow(10, 25), pair(at(10), 4))), acc); + assertThat( + acc, + equalsToMap( + KV.of(intervalWindow(0, 15), pair(at(0), 1 + 1)), + KV.of(intervalWindow(5, 20), pair(at(5), 3)), + KV.of(intervalWindow(10, 25), pair(at(10), 4)))); + } + + private WindowedValue windowedValue(Integer value, Instant ts) { + return WindowedValues.of(value, ts, sliding.assignWindows(ts), PaneInfo.NO_FIRING); + } + } + + /** + * Shared implementation of tests for SessionsAggregator and MergingWindowedAggregator in {@link + * Aggregators}. + */ + public abstract static class AbstractSessionsTest< + AccT extends Map>> { + + static final Duration SESSIONS_GAP = standardMinutes(15); + + final Aggregator, AccT, Collection>> agg; + + AbstractSessionsTest(WindowFn windowFn) { + agg = windowedAgg(windowFn); + } + + abstract AccT accOf(KV>... 
entries); + + @Test + public void testReduce() { + AccT acc; + + acc = agg.reduce(agg.zero(), sessionValue(10, at(0))); + assertThat(acc, equalsToMap(KV.of(sessionWindow(0), pair(at(0), 10)))); + + // 2nd session after 1st + acc = agg.reduce(acc, sessionValue(7, at(20))); + assertThat( + acc, + equalsToMap( + KV.of(sessionWindow(0), pair(at(0), 10)), KV.of(sessionWindow(20), pair(at(20), 7)))); + + // merge into 2nd session + acc = agg.reduce(acc, sessionValue(6, at(18))); + assertThat( + acc, + equalsToMap( + KV.of(sessionWindow(0), pair(at(0), 10)), + KV.of(sessionWindow(18, 35), pair(at(20), 7 + 6)))); + + // merge into 2nd session + acc = agg.reduce(acc, sessionValue(5, at(21))); + assertThat( + acc, + equalsToMap( + KV.of(sessionWindow(0), pair(at(0), 10)), + KV.of(sessionWindow(18, 36), pair(at(21), 7 + 6 + 5)))); + + // 3rd session after 2nd + acc = agg.reduce(acc, sessionValue(2, NOW.plus(standardMinutes(45)))); + assertThat( + acc, + equalsToMap( + KV.of(sessionWindow(0), pair(at(0), 10)), + KV.of(sessionWindow(18, 36), pair(at(21), 7 + 6 + 5)), + KV.of(sessionWindow(45), pair(at(45), 2)))); + + // merge with 1st and 2nd + acc = agg.reduce(acc, sessionValue(1, at(10))); + assertThat( + acc, + equalsToMap( + KV.of(sessionWindow(0, 36), pair(at(21), 10 + 7 + 6 + 5 + 1)), + KV.of(sessionWindow(45), pair(at(45), 2)))); + } + + @Test + public void testMerge() { + AccT acc; + + assertThat(agg.merge(agg.zero(), agg.zero()), equalTo(agg.zero())); + + acc = accOf(KV.of(sessionWindow(0), pair(at(0), 1))); + + assertThat(agg.merge(acc, agg.zero()), equalTo(acc)); + assertThat(agg.merge(agg.zero(), acc), equalTo(acc)); + + acc = agg.merge(acc, acc); + assertThat(acc, equalsToMap(KV.of(sessionWindow(0), pair(at(0), 1 + 1)))); + + acc = agg.merge(acc, accOf(KV.of(sessionWindow(20), pair(at(20), 2)))); + assertThat( + acc, + equalsToMap( + KV.of(sessionWindow(0), pair(at(0), 1 + 1)), + KV.of(sessionWindow(20), pair(at(20), 2)))); + + acc = agg.merge(accOf(KV.of(sessionWindow(40), pair(at(40), 3))), acc); + assertThat( + acc, + equalsToMap( + KV.of(sessionWindow(0), pair(at(0), 1 + 1)), + KV.of(sessionWindow(20), pair(at(20), 2)), + KV.of(sessionWindow(40), pair(at(40), 3)))); + + acc = agg.merge(acc, accOf(KV.of(sessionWindow(10), pair(at(10), 4)))); + assertThat( + acc, + equalsToMap( + KV.of(sessionWindow(0, 35), pair(at(20), 1 + 1 + 2 + 4)), + KV.of(sessionWindow(40), pair(at(40), 3)))); + + acc = agg.merge(accOf(KV.of(sessionWindow(5, 45), pair(at(30), 5))), acc); + assertThat( + acc, equalsToMap(KV.of(sessionWindow(0, 55), pair(at(40), 1 + 1 + 2 + 4 + 3 + 5)))); + } + + private WindowedValue sessionValue(Integer value, Instant ts) { + return WindowedValues.of(value, ts, new IntervalWindow(ts, SESSIONS_GAP), PaneInfo.NO_FIRING); + } + + private IntervalWindow sessionWindow(int fromMinutes) { + return new IntervalWindow(at(fromMinutes), SESSIONS_GAP); + } + + private static IntervalWindow sessionWindow(int fromMinutes, int toMinutes) { + return intervalWindow(fromMinutes, toMinutes); + } + } + + /** Tests for specialized SessionsAggregator in {@link Aggregators}. */ + public static class SessionsAggregatorTest + extends AbstractSessionsTest>> { + + public SessionsAggregatorTest() { + super(Sessions.withGapDuration(SESSIONS_GAP)); + } + + @Override + TreeMap> accOf( + KV>... entries) { + return new TreeMap<>(mapOf(entries)); + } + } + + /** Tests for MergingWindowedAggregator in {@link Aggregators}. 
*/ + public static class MergingWindowedAggregatorTest + extends AbstractSessionsTest>> { + + public MergingWindowedAggregatorTest() { + super(new CustomSessions<>()); + } + + @Override + Map> accOf( + KV>... entries) { + return mapOf(entries); + } + + /** Wrapper around {@link Sessions} to test the MergingWindowedAggregator. */ + private static class CustomSessions extends WindowFn { + private final Sessions sessions = Sessions.withGapDuration(SESSIONS_GAP); + + @Override + public Collection assignWindows(WindowFn.AssignContext c) { + return sessions.assignWindows((WindowFn.AssignContext) c); + } + + @Override + public void mergeWindows(WindowFn.MergeContext c) throws Exception { + sessions.mergeWindows((WindowFn.MergeContext) c); + } + + @Override + public boolean isCompatible(WindowFn other) { + return sessions.isCompatible(other); + } + + @Override + public Coder windowCoder() { + return sessions.windowCoder(); + } + + @Override + public WindowMappingFn getDefaultWindowMappingFn() { + return sessions.getDefaultWindowMappingFn(); + } + } + } + + private static IntervalWindow intervalWindow(int fromMinutes, int toMinutes) { + return new IntervalWindow(at(fromMinutes), at(toMinutes)); + } + + private static Instant at(int minutes) { + return NOW.plus(standardMinutes(minutes)); + } + + private static Matcher>> equalsToMap( + KV>... entries) { + return equalTo(mapOf(entries)); + } + + private static Map> mapOf( + KV>... entries) { + return Arrays.asList(entries).stream().collect(Collectors.toMap(KV::getKey, KV::getValue)); + } + + private static MutablePair pair(Instant ts, int value) { + return new MutablePair<>(ts, value); + } + + private static + Aggregator, AccT, Collection>> windowedAgg( + WindowFn windowFn) { + Encoder intEnc = EncoderHelpers.encoderOf(Integer.class); + Encoder windowEnc = encoderFor((Coder) IntervalWindow.getCoder()); + Encoder> outputEnc = windowedValueEncoder(intEnc, windowEnc); + + WindowingStrategy windowing = + WindowingStrategy.of(windowFn).withTimestampCombiner(TimestampCombiner.LATEST); + + Aggregator, ?, Collection>> agg = + Aggregators.windowedValue( + new SimpleSum(), WindowedValue::getValue, windowing, windowEnc, intEnc, outputEnc); + return (Aggregator) agg; + } + + private static class SimpleSum extends Combine.CombineFn { + + @Override + public Integer createAccumulator() { + return 0; + } + + @Override + public Integer addInput(Integer acc, Integer input) { + return acc + input; + } + + @Override + public Integer mergeAccumulators(Iterable accs) { + return Streams.stream(accs.iterator()).reduce((a, b) -> a + b).orElseGet(() -> 0); + } + + @Override + public Integer extractOutput(Integer acc) { + return acc; + } + } +} diff --git a/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGloballyTest.java b/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGloballyTest.java new file mode 100644 index 000000000000..cca192df9de3 --- /dev/null +++ b/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGloballyTest.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch; + +import java.io.Serializable; +import org.apache.beam.runners.spark.structuredstreaming.SparkSessionRule; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.Combine.BinaryCombineFn; +import org.apache.beam.sdk.transforms.CombineFnBase; +import org.apache.beam.sdk.transforms.Count; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.Sum; +import org.apache.beam.sdk.transforms.windowing.FixedWindows; +import org.apache.beam.sdk.transforms.windowing.Sessions; +import org.apache.beam.sdk.transforms.windowing.SlidingWindows; +import org.apache.beam.sdk.transforms.windowing.Window; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TimestampedValue; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Test class for beam to spark {@link Combine#globally(CombineFnBase.GlobalCombineFn)} translation. 
+ */ +@RunWith(JUnit4.class) +public class CombineGloballyTest implements Serializable { + @ClassRule public static final SparkSessionRule SESSION = new SparkSessionRule(); + + @Rule + public transient TestPipeline pipeline = + TestPipeline.fromOptions(SESSION.createPipelineOptions()); + + @Test + public void testCombineGlobally() { + PCollection input = + pipeline.apply(Create.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)).apply(Sum.integersGlobally()); + PAssert.that(input).containsInAnyOrder(55); + // uses combine per key + pipeline.run(); + } + + @Test + public void testCombineGloballyPreservesWindowing() { + PCollection input = + pipeline + .apply( + Create.timestamped( + TimestampedValue.of(1, new Instant(1)), + TimestampedValue.of(2, new Instant(2)), + TimestampedValue.of(3, new Instant(11)), + TimestampedValue.of(4, new Instant(3)), + TimestampedValue.of(5, new Instant(11)), + TimestampedValue.of(6, new Instant(12)))) + .apply(Window.into(FixedWindows.of(Duration.millis(10)))) + .apply(Sum.integersGlobally().withoutDefaults()); + PAssert.that(input).containsInAnyOrder(7, 14); + pipeline.run(); + } + + @Test + public void testCombineGloballyWithSlidingWindows() { + PCollection input = + pipeline + .apply( + Create.timestamped( + TimestampedValue.of(1, new Instant(1)), + TimestampedValue.of(3, new Instant(2)), + TimestampedValue.of(5, new Instant(3)), + TimestampedValue.of(2, new Instant(1)), + TimestampedValue.of(4, new Instant(2)), + TimestampedValue.of(6, new Instant(3)))) + .apply(Window.into(SlidingWindows.of(Duration.millis(3)).every(Duration.millis(1)))) + .apply(Sum.integersGlobally().withoutDefaults()); + PAssert.that(input) + .containsInAnyOrder(1 + 2, 1 + 2 + 3 + 4, 1 + 3 + 5 + 2 + 4 + 6, 3 + 4 + 5 + 6, 5 + 6); + pipeline.run(); + } + + @Test + public void testCombineGloballyWithMergingWindows() { + PCollection input = + pipeline + .apply( + Create.timestamped( + TimestampedValue.of(2, new Instant(5)), + TimestampedValue.of(4, new Instant(11)), + TimestampedValue.of(6, new Instant(12)))) + .apply(Window.into(Sessions.withGapDuration(Duration.millis(5)))) + .apply(Sum.integersGlobally().withoutDefaults()); + + PAssert.that(input).containsInAnyOrder(2 /*window [5-10)*/, 10 /*window [11-17)*/); + pipeline.run(); + } + + @Test + public void testCountGloballyWithSlidingWindows() { + PCollection input = + pipeline + .apply( + Create.timestamped( + TimestampedValue.of("a", new Instant(1)), + TimestampedValue.of("a", new Instant(2)), + TimestampedValue.of("a", new Instant(2)))) + .apply(Window.into(SlidingWindows.of(Duration.millis(2)).every(Duration.millis(1)))); + PCollection output = + input.apply(Combine.globally(Count.combineFn()).withoutDefaults()); + PAssert.that(output).containsInAnyOrder(1L, 3L, 2L); + pipeline.run(); + } + + @Test + public void testBinaryCombineWithSlidingWindows() { + PCollection input = + pipeline + .apply( + Create.timestamped( + TimestampedValue.of(1, new Instant(1)), + TimestampedValue.of(3, new Instant(2)), + TimestampedValue.of(5, new Instant(3)))) + .apply(Window.into(SlidingWindows.of(Duration.millis(3)).every(Duration.millis(1)))) + .apply( + Combine.globally(BinaryCombineFn.of((i1, i2) -> i1 > i2 ? 
i1 : i2)) + .withoutDefaults()); + PAssert.that(input).containsInAnyOrder(1, 3, 5, 5, 5); + pipeline.run(); + } +} diff --git a/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGroupedValuesTest.java b/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGroupedValuesTest.java new file mode 100644 index 000000000000..5b23a6ac9ea5 --- /dev/null +++ b/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGroupedValuesTest.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch; + +import java.io.Serializable; +import org.apache.beam.runners.spark.structuredstreaming.SparkSessionRule; +import org.apache.beam.sdk.coders.IterableCoder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.Sum; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Test class for beam to spark {@link Combine#groupedValues} translation. 
*/ +@RunWith(JUnit4.class) +public class CombineGroupedValuesTest implements Serializable { + @ClassRule public static final SparkSessionRule SESSION = new SparkSessionRule(); + + @Rule + public transient TestPipeline pipeline = + TestPipeline.fromOptions(SESSION.createPipelineOptions()); + + @Test + public void testCombineGroupedValues() { + PCollection> input = + pipeline + .apply( + Create.>>of( + KV.of("a", ImmutableList.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)), + KV.of("b", ImmutableList.of())) + .withCoder( + KvCoder.of(StringUtf8Coder.of(), IterableCoder.of(VarIntCoder.of())))) + .apply(Combine.groupedValues(Sum.ofIntegers())); + + PAssert.that(input).containsInAnyOrder(KV.of("a", 55), KV.of("b", 0)); + pipeline.run(); + } +} diff --git a/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTest.java b/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTest.java new file mode 100644 index 000000000000..41c032cd85be --- /dev/null +++ b/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTest.java @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; +import org.apache.beam.runners.spark.structuredstreaming.SparkSessionRule; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.CombineFnBase; +import org.apache.beam.sdk.transforms.Count; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.Distinct; +import org.apache.beam.sdk.transforms.Sum; +import org.apache.beam.sdk.transforms.windowing.FixedWindows; +import org.apache.beam.sdk.transforms.windowing.Sessions; +import org.apache.beam.sdk.transforms.windowing.SlidingWindows; +import org.apache.beam.sdk.transforms.windowing.Window; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TimestampedValue; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Test class for beam to spark {@link + * org.apache.beam.sdk.transforms.Combine#perKey(CombineFnBase.GlobalCombineFn)} translation. 
+ */ +@RunWith(JUnit4.class) +public class CombinePerKeyTest implements Serializable { + @ClassRule public static final SparkSessionRule SESSION = new SparkSessionRule(); + + @Rule + public transient TestPipeline pipeline = + TestPipeline.fromOptions(SESSION.createPipelineOptions()); + + @Test + public void testCombinePerKey() { + List> elems = new ArrayList<>(); + elems.add(KV.of(1, 1)); + elems.add(KV.of(1, 3)); + elems.add(KV.of(1, 5)); + elems.add(KV.of(2, 2)); + elems.add(KV.of(2, 4)); + elems.add(KV.of(2, 6)); + + PCollection> input = + pipeline.apply(Create.of(elems)).apply(Sum.integersPerKey()); + PAssert.that(input).containsInAnyOrder(KV.of(1, 9), KV.of(2, 12)); + pipeline.run(); + } + + @Test + public void testDistinctViaCombinePerKey() { + List elems = Lists.newArrayList(1, 2, 3, 3, 4, 4, 4, 4, 5, 5); + + // Distinct is implemented in terms of CombinePerKey + PCollection result = pipeline.apply(Create.of(elems)).apply(Distinct.create()); + + PAssert.that(result).containsInAnyOrder(1, 2, 3, 4, 5); + pipeline.run(); + } + + @Test + public void testCombinePerKeyPreservesWindowing() { + PCollection> input = + pipeline + .apply( + Create.timestamped( + TimestampedValue.of(KV.of(1, 1), new Instant(1)), + TimestampedValue.of(KV.of(1, 3), new Instant(2)), + TimestampedValue.of(KV.of(1, 5), new Instant(11)), + TimestampedValue.of(KV.of(2, 2), new Instant(3)), + TimestampedValue.of(KV.of(2, 4), new Instant(11)), + TimestampedValue.of(KV.of(2, 6), new Instant(12)))) + .apply(Window.into(FixedWindows.of(Duration.millis(10)))) + .apply(Sum.integersPerKey()); + PAssert.that(input).containsInAnyOrder(KV.of(1, 4), KV.of(1, 5), KV.of(2, 2), KV.of(2, 10)); + pipeline.run(); + } + + @Test + public void testCombinePerKeyWithSlidingWindows() { + PCollection> input = + pipeline + .apply( + Create.timestamped( + TimestampedValue.of(KV.of(1, 1), new Instant(1)), + TimestampedValue.of(KV.of(1, 3), new Instant(2)), + TimestampedValue.of(KV.of(1, 5), new Instant(3)), + TimestampedValue.of(KV.of(1, 2), new Instant(1)), + TimestampedValue.of(KV.of(1, 4), new Instant(2)), + TimestampedValue.of(KV.of(1, 6), new Instant(3)))) + .apply(Window.into(SlidingWindows.of(Duration.millis(3)).every(Duration.millis(1)))) + .apply(Sum.integersPerKey()); + PAssert.that(input) + .containsInAnyOrder( + KV.of(1, 1 + 2), + KV.of(1, 1 + 2 + 3 + 4), + KV.of(1, 1 + 3 + 5 + 2 + 4 + 6), + KV.of(1, 3 + 4 + 5 + 6), + KV.of(1, 5 + 6)); + pipeline.run(); + } + + @Test + public void testCombineByKeyWithMergingWindows() { + PCollection> input = + pipeline + .apply( + Create.timestamped( + TimestampedValue.of(KV.of(1, 1), new Instant(5)), + TimestampedValue.of(KV.of(1, 3), new Instant(7)), + TimestampedValue.of(KV.of(1, 5), new Instant(11)), + TimestampedValue.of(KV.of(2, 2), new Instant(5)), + TimestampedValue.of(KV.of(2, 4), new Instant(11)), + TimestampedValue.of(KV.of(2, 6), new Instant(12)))) + .apply(Window.into(Sessions.withGapDuration(Duration.millis(5)))) + .apply(Sum.integersPerKey()); + + PAssert.that(input) + .containsInAnyOrder( + KV.of(1, 9), // window [5-16) + KV.of(2, 2), // window [5-10) + KV.of(2, 10) // window [11-17) + ); + pipeline.run(); + } + + @Test + public void testCountPerElementWithSlidingWindows() { + PCollection input = + pipeline + .apply( + Create.timestamped( + TimestampedValue.of("a", new Instant(1)), + TimestampedValue.of("a", new Instant(2)), + TimestampedValue.of("b", new Instant(3)), + TimestampedValue.of("b", new Instant(4)))) + 
.apply(Window.into(SlidingWindows.of(Duration.millis(2)).every(Duration.millis(1)))); + PCollection> output = input.apply(Count.perElement()); + PAssert.that(output) + .containsInAnyOrder( + KV.of("a", 1L), + KV.of("a", 2L), + KV.of("a", 1L), + KV.of("b", 1L), + KV.of("b", 2L), + KV.of("b", 1L)); + pipeline.run(); + } +} diff --git a/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ComplexSourceTest.java b/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ComplexSourceTest.java new file mode 100644 index 000000000000..4ba356f6ce75 --- /dev/null +++ b/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ComplexSourceTest.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.io.PrintStream; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; +import org.apache.beam.runners.spark.structuredstreaming.SparkSessionRule; +import org.apache.beam.sdk.io.TextIO; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.values.PCollection; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Test class for beam to spark source translation. 
*/ +@RunWith(JUnit4.class) +public class ComplexSourceTest implements Serializable { + @ClassRule public static final TemporaryFolder TEMPORARY_FOLDER = new TemporaryFolder(); + private static File file; + private static List lines = createLines(30); + + @ClassRule public static final SparkSessionRule SESSION = new SparkSessionRule(); + + @Rule + public transient TestPipeline pipeline = + TestPipeline.fromOptions(SESSION.createPipelineOptions()); + + @BeforeClass + public static void beforeClass() throws IOException { + file = createFile(lines); + } + + @Test + public void testBoundedSource() { + PCollection input = pipeline.apply(TextIO.read().from(file.getPath())); + PAssert.that(input).containsInAnyOrder(lines); + pipeline.run(); + } + + private static File createFile(List lines) throws IOException { + File file = TEMPORARY_FOLDER.newFile(); + OutputStream outputStream = new FileOutputStream(file); + try (PrintStream writer = new PrintStream(outputStream)) { + for (String line : lines) { + writer.println(line); + } + } + return file; + } + + private static List createLines(int size) { + List lines = new ArrayList<>(); + for (int i = 0; i < size; ++i) { + lines.add("word" + i); + } + return lines; + } +} diff --git a/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/FlattenTest.java b/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/FlattenTest.java new file mode 100644 index 000000000000..bf3774ba29ec --- /dev/null +++ b/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/FlattenTest.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch; + +import java.io.Serializable; +import org.apache.beam.runners.spark.structuredstreaming.SparkSessionRule; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.Flatten; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionList; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Test class for beam to spark flatten translation. 
*/ +@RunWith(JUnit4.class) +public class FlattenTest implements Serializable { + @ClassRule public static final SparkSessionRule SESSION = new SparkSessionRule(); + + @Rule + public transient TestPipeline pipeline = + TestPipeline.fromOptions(SESSION.createPipelineOptions()); + + @Test + public void testFlatten() { + PCollection input1 = + pipeline.apply("input1", Create.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)); + PCollection input2 = + pipeline.apply("input2", Create.of(11, 12, 13, 14, 15, 16, 17, 18, 19, 20)); + PCollectionList pcs = PCollectionList.of(input1).and(input2); + PCollection input = pcs.apply(Flatten.pCollections()); + PAssert.that(input) + .containsInAnyOrder(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20); + pipeline.run(); + } +} diff --git a/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTest.java b/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTest.java new file mode 100644 index 000000000000..6569b7b20cfc --- /dev/null +++ b/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTest.java @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch; + +import static java.util.Arrays.stream; +import static java.util.stream.Collectors.groupingBy; +import static java.util.stream.Collectors.mapping; +import static java.util.stream.Collectors.toList; +import static org.apache.beam.sdk.testing.SerializableMatchers.containsInAnyOrder; +import static org.hamcrest.MatcherAssert.assertThat; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import org.apache.beam.runners.spark.SparkCommonPipelineOptions; +import org.apache.beam.runners.spark.structuredstreaming.SparkSessionRule; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.SerializableMatcher; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.GroupByKey; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.windowing.FixedWindows; +import org.apache.beam.sdk.transforms.windowing.Sessions; +import org.apache.beam.sdk.transforms.windowing.SlidingWindows; +import org.apache.beam.sdk.transforms.windowing.Window; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TimestampedValue; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +/** Test class for beam to spark {@link GroupByKey} translation. 
*/ +@RunWith(Parameterized.class) +public class GroupByKeyTest implements Serializable { + @ClassRule public static final SparkSessionRule SESSION = new SparkSessionRule(); + + @Parameterized.Parameter public boolean preferGroupByKeyToHandleHugeValues; + + @Parameterized.Parameters(name = "Test with preferGroupByKeyToHandleHugeValues={0}") + public static Collection preferGroupByKeyToHandleHugeValues() { + return Arrays.asList(new Object[][] {{true}, {false}}); + } + + @Rule + public transient TestPipeline pipeline = + TestPipeline.fromOptions(SESSION.createPipelineOptions()); + + @Before + public void updatePipelineOptions() { + pipeline + .getOptions() + .as(SparkCommonPipelineOptions.class) + .setPreferGroupByKeyToHandleHugeValues(preferGroupByKeyToHandleHugeValues); + } + + @Test + public void testGroupByKeyPreservesWindowing() { + pipeline + .apply( + Create.timestamped( + shuffleRandomly( + TimestampedValue.of(KV.of(1, 1), new Instant(1)), + TimestampedValue.of(KV.of(1, 3), new Instant(2)), + TimestampedValue.of(KV.of(1, 5), new Instant(11)), + TimestampedValue.of(KV.of(2, 2), new Instant(3)), + TimestampedValue.of(KV.of(2, 4), new Instant(11)), + TimestampedValue.of(KV.of(2, 6), new Instant(12))))) + .apply(Window.into(FixedWindows.of(Duration.millis(10)))) + .apply(GroupByKey.create()) + // PAssert does not support multiple KVs with the same key (because of multiple windows) + .apply( + ParDo.of( + new AssertContains<>( + KV.of(1, containsInAnyOrder(1, 3)), // window [0-10) + KV.of(1, containsInAnyOrder(5)), // window [10-20) + KV.of(2, containsInAnyOrder(4, 6)), // window [10-20) + KV.of(2, containsInAnyOrder(2)) // window [0-10) + ))); + pipeline.run(); + } + + @Test + public void testGroupByKeyExplodesMultipleWindows() { + pipeline + .apply( + Create.timestamped( + shuffleRandomly( + TimestampedValue.of(KV.of(1, 1), new Instant(5)), + TimestampedValue.of(KV.of(1, 3), new Instant(7)), + TimestampedValue.of(KV.of(1, 5), new Instant(11)), + TimestampedValue.of(KV.of(2, 2), new Instant(5)), + TimestampedValue.of(KV.of(2, 4), new Instant(11)), + TimestampedValue.of(KV.of(2, 6), new Instant(12))))) + .apply(Window.into(SlidingWindows.of(Duration.millis(10)).every(Duration.millis(5)))) + .apply(GroupByKey.create()) + // PAssert does not support multiple KVs with the same key (because of multiple windows) + .apply( + ParDo.of( + new AssertContains<>( + KV.of(1, containsInAnyOrder(1, 3)), // window [0-10) + KV.of(1, containsInAnyOrder(1, 3, 5)), // window [5-15) + KV.of(1, containsInAnyOrder(5)), // window [10-20) + KV.of(2, containsInAnyOrder(2)), // window [0-10) + KV.of(2, containsInAnyOrder(2, 4, 6)), // window [5-15) + KV.of(2, containsInAnyOrder(4, 6)) // window [10-20) + ))); + pipeline.run(); + } + + @Test + public void testGroupByKeyWithMergingWindows() { + pipeline + .apply( + Create.timestamped( + shuffleRandomly( + TimestampedValue.of(KV.of(1, 1), new Instant(5)), + TimestampedValue.of(KV.of(1, 3), new Instant(7)), + TimestampedValue.of(KV.of(1, 5), new Instant(11)), + TimestampedValue.of(KV.of(2, 2), new Instant(5)), + TimestampedValue.of(KV.of(2, 4), new Instant(11)), + TimestampedValue.of(KV.of(2, 6), new Instant(12))))) + .apply(Window.into(Sessions.withGapDuration(Duration.millis(5)))) + .apply(GroupByKey.create()) + // PAssert does not support multiple KVs with the same key (because of multiple windows) + .apply( + ParDo.of( + new AssertContains<>( + KV.of(1, containsInAnyOrder(1, 3, 5)), // window [5-16) + KV.of(2, containsInAnyOrder(2)), // window [5-10) + KV.of(2, containsInAnyOrder(4, 6)) // 
window [11-17) + ))); + pipeline.run(); + } + + @Test + public void testGroupByKey() { + List> elems = + shuffleRandomly( + KV.of(1, 1), KV.of(1, 3), KV.of(1, 5), KV.of(2, 2), KV.of(2, 4), KV.of(2, 6)); + + PCollection>> input = + pipeline.apply(Create.of(elems)).apply(GroupByKey.create()); + + PAssert.thatMap(input) + .satisfies( + results -> { + assertThat(results.get(1), containsInAnyOrder(1, 3, 5)); + assertThat(results.get(2), containsInAnyOrder(2, 4, 6)); + return null; + }); + pipeline.run(); + } + + static class AssertContains extends DoFn>, Void> { + private final Map>>> byKey; + + public AssertContains(KV>>... matchers) { + byKey = stream(matchers).collect(groupingBy(KV::getKey, mapping(KV::getValue, toList()))); + } + + @ProcessElement + public void processElement(@Element KV> elem) { + assertThat("Unexpected key: " + elem.getKey(), byKey.containsKey(elem.getKey())); + List values = ImmutableList.copyOf(elem.getValue()); + assertThat( + "Unexpected values " + values + " for key " + elem.getKey(), + byKey.get(elem.getKey()).stream().anyMatch(m -> m.matches(values))); + } + } + + private List shuffleRandomly(T... elems) { + ArrayList list = Lists.newArrayList(elems); + Collections.shuffle(list); + return list; + } +} diff --git a/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTest.java b/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTest.java new file mode 100644 index 000000000000..672a2db4fe1e --- /dev/null +++ b/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTest.java @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch; + +import static org.junit.Assert.assertTrue; + +import java.io.Serializable; +import java.util.List; +import java.util.Map; +import org.apache.beam.runners.spark.SparkCommonPipelineOptions; +import org.apache.beam.runners.spark.structuredstreaming.SparkSessionRule; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.View; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TupleTagList; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Test class for beam to spark {@link ParDo} translation. */ +@RunWith(JUnit4.class) +public class ParDoTest implements Serializable { + @ClassRule public static final SparkSessionRule SESSION = new SparkSessionRule(); + + @Rule + public transient TestPipeline pipeline = + TestPipeline.fromOptions(SESSION.createPipelineOptions()); + + @Rule public transient TestRule clearCache = SESSION.clearCache(); + + @Test + public void testPardo() { + PCollection input = + pipeline.apply(Create.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)).apply(ParDo.of(PLUS_ONE_DOFN)); + PAssert.that(input).containsInAnyOrder(2, 3, 4, 5, 6, 7, 8, 9, 10, 11); + pipeline.run(); + + assertTrue("No usage of cache expected", !SESSION.hasCachedData()); + } + + @Test + public void testPardoWithOutputTagsCachedRDD() { + pardoWithOutputTags("MEMORY_ONLY", true); + assertTrue("Expected cached data", SESSION.hasCachedData()); + } + + @Test + public void testPardoWithOutputTagsCachedDataset() { + pardoWithOutputTags("MEMORY_AND_DISK", true); + assertTrue("Expected cached data", SESSION.hasCachedData()); + } + + @Test + public void testPardoWithUnusedOutputTags() { + pardoWithOutputTags("MEMORY_AND_DISK", false); + assertTrue("No usage of cache expected", !SESSION.hasCachedData()); + } + + private void pardoWithOutputTags(String storageLevel, boolean evaluateAdditionalOutputs) { + pipeline.getOptions().as(SparkCommonPipelineOptions.class).setStorageLevel(storageLevel); + + TupleTag mainTag = new TupleTag() {}; + TupleTag additionalUnevenTag = new TupleTag() {}; + + DoFn doFn = + new DoFn() { + @ProcessElement + public void processElement(@Element Integer i, MultiOutputReceiver out) { + if (i % 2 == 0) { + out.get(mainTag).output(i); + } else { + out.get(additionalUnevenTag).output(i.toString()); + } + } + }; + + PCollectionTuple outputs = + pipeline + .apply(Create.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)) + .apply(ParDo.of(doFn).withOutputTags(mainTag, TupleTagList.of(additionalUnevenTag))); + + PAssert.that(outputs.get(mainTag)).containsInAnyOrder(2, 4, 6, 8, 10); + if (evaluateAdditionalOutputs) { + PAssert.that(outputs.get(additionalUnevenTag)).containsInAnyOrder("1", "3", "5", "7", "9"); + } + pipeline.run(); + } + + @Test + public void testTwoPardoInRow() { + PCollection input = + pipeline + .apply(Create.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)) + .apply("Plus 1 (1st)", ParDo.of(PLUS_ONE_DOFN)) + .apply("Plus 1 (2nd)", ParDo.of(PLUS_ONE_DOFN)); + 
PAssert.that(input).containsInAnyOrder(3, 4, 5, 6, 7, 8, 9, 10, 11, 12); + pipeline.run(); + + assertTrue("No usage of cache expected", !SESSION.hasCachedData()); + } + + @Test + public void testSideInputAsList() { + PCollectionView> sideInputView = + pipeline.apply("Create sideInput", Create.of(1, 2, 3)).apply(View.asList()); + PCollection input = + pipeline + .apply("Create input", Create.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)) + .apply( + ParDo.of( + new DoFn() { + @ProcessElement + public void processElement(ProcessContext c) { + List sideInputValue = c.sideInput(sideInputView); + if (!sideInputValue.contains(c.element())) { + c.output(c.element()); + } + } + }) + .withSideInputs(sideInputView)); + PAssert.that(input).containsInAnyOrder(4, 5, 6, 7, 8, 9, 10); + pipeline.run(); + + assertTrue("No usage of cache expected", !SESSION.hasCachedData()); + } + + @Test + public void testSideInputAsSingleton() { + PCollectionView sideInputView = + pipeline.apply("Create sideInput", Create.of(1)).apply(View.asSingleton()); + + PCollection input = + pipeline + .apply("Create input", Create.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)) + .apply( + ParDo.of( + new DoFn() { + @ProcessElement + public void processElement(ProcessContext c) { + Integer sideInputValue = c.sideInput(sideInputView); + if (!sideInputValue.equals(c.element())) { + c.output(c.element()); + } + } + }) + .withSideInputs(sideInputView)); + + PAssert.that(input).containsInAnyOrder(2, 3, 4, 5, 6, 7, 8, 9, 10); + pipeline.run(); + + assertTrue("No usage of cache expected", !SESSION.hasCachedData()); + } + + @Test + public void testSideInputAsMap() { + PCollectionView> sideInputView = + pipeline + .apply("Create sideInput", Create.of(KV.of("key1", 1), KV.of("key2", 2))) + .apply(View.asMap()); + PCollection input = + pipeline + .apply("Create input", Create.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)) + .apply( + ParDo.of( + new DoFn() { + @ProcessElement + public void processElement(ProcessContext c) { + Map sideInputValue = c.sideInput(sideInputView); + if (!sideInputValue.containsKey("key" + c.element())) { + c.output(c.element()); + } + } + }) + .withSideInputs(sideInputView)); + PAssert.that(input).containsInAnyOrder(3, 4, 5, 6, 7, 8, 9, 10); + pipeline.run(); + + assertTrue("No usage of cache expected", !SESSION.hasCachedData()); + } + + private static final DoFn PLUS_ONE_DOFN = + new DoFn() { + @ProcessElement + public void processElement(ProcessContext c) { + c.output(c.element() + 1); + } + }; +} diff --git a/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/SimpleSourceTest.java b/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/SimpleSourceTest.java new file mode 100644 index 000000000000..d70293d50560 --- /dev/null +++ b/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/SimpleSourceTest.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch; + +import java.io.Serializable; +import org.apache.beam.runners.spark.structuredstreaming.SparkSessionRule; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.values.PCollection; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Test class for beam to spark source translation. */ +@RunWith(JUnit4.class) +public class SimpleSourceTest implements Serializable { + @ClassRule public static final SparkSessionRule SESSION = new SparkSessionRule(); + + @Rule + public transient TestPipeline pipeline = + TestPipeline.fromOptions(SESSION.createPipelineOptions()); + + @Test + public void testBoundedSource() { + PCollection input = pipeline.apply(Create.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)); + PAssert.that(input).containsInAnyOrder(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); + pipeline.run(); + } +} diff --git a/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/WindowAssignTest.java b/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/WindowAssignTest.java new file mode 100644 index 000000000000..ecb3e7ebdb5b --- /dev/null +++ b/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/WindowAssignTest.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch; + +import java.io.Serializable; +import org.apache.beam.runners.spark.structuredstreaming.SparkSessionRule; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.Sum; +import org.apache.beam.sdk.transforms.windowing.FixedWindows; +import org.apache.beam.sdk.transforms.windowing.Window; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TimestampedValue; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Test class for beam to spark window assign translation. */ +@RunWith(JUnit4.class) +public class WindowAssignTest implements Serializable { + @ClassRule public static final SparkSessionRule SESSION = new SparkSessionRule(); + + @Rule + public transient TestPipeline pipeline = + TestPipeline.fromOptions(SESSION.createPipelineOptions()); + + @Test + public void testWindowAssign() { + PCollection input = + pipeline + .apply( + Create.timestamped( + TimestampedValue.of(1, new Instant(1)), + TimestampedValue.of(2, new Instant(2)), + TimestampedValue.of(3, new Instant(3)), + TimestampedValue.of(4, new Instant(10)), + TimestampedValue.of(5, new Instant(11)))) + .apply(Window.into(FixedWindows.of(Duration.millis(10)))) + .apply(Sum.integersGlobally().withoutDefaults()); + PAssert.that(input).containsInAnyOrder(6, 9); + pipeline.run(); + } +} diff --git a/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/functions/SideInputValuesTest.java b/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/functions/SideInputValuesTest.java new file mode 100644 index 000000000000..74d5a0292edb --- /dev/null +++ b/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/functions/SideInputValuesTest.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions; + +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.encoderOf; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.windowedValueEncoder; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.seqOf; +import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; +import static org.apache.beam.sdk.values.WindowedValues.getFullCoder; +import static org.apache.beam.sdk.values.WindowedValues.valueInGlobalWindow; +import static org.assertj.core.api.Assertions.assertThat; + +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; +import javax.annotation.Nullable; +import org.apache.beam.runners.spark.structuredstreaming.SparkSessionRule; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.transforms.windowing.IntervalWindow; +import org.apache.beam.sdk.transforms.windowing.PaneInfo; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.sdk.values.WindowedValues; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.spark.serializer.KryoSerializer; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.assertj.core.util.Lists; +import org.joda.time.Instant; +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.rules.ExternalResource; + +public class SideInputValuesTest { + + @ClassRule public static final SparkSessionRule SESSION = new SparkSessionRule(); + + @ClassRule public static final SparkKryo KRYO = new SparkKryo(); + + @Test + public void globalSideInputValues() { + Encoder> enc = + windowedValueEncoder(encoderOf(String.class), encoderOf(GlobalWindow.class)); + Dataset> ds = + dataset(enc, valueInGlobalWindow("a"), valueInGlobalWindow("b")); + + SideInputValues values = new SideInputValues.Global<>("test", StringUtf8Coder.of(), ds); + assertThat(values.get(GlobalWindow.INSTANCE)).isEqualTo(ImmutableList.of("a", "b")); + + SideInputValues deserialized = KRYO.serde(values); + assertThat(deserialized).isEqualToIgnoringGivenFields(values, "binaryValues"); + assertThat(deserialized.get(GlobalWindow.INSTANCE)).isEqualTo(ImmutableList.of("a", "b")); + } + + @Test + public void windowedSideInputValues() { + Encoder> encoder = + windowedValueEncoder(encoderOf(String.class), encoderOf(IntervalWindow.class)); + Coder> coder = + getFullCoder(StringUtf8Coder.of(), IntervalWindow.getCoder()); + + Dataset> ds = + dataset( + encoder, + valueInWindows("a", intervalWindow(0, 1), intervalWindow(1, 2)), + valueInWindows("b", intervalWindow(1, 2), intervalWindow(2, 3))); + + SideInputValues values = new SideInputValues.ByWindow<>("test", coder, ds); + assertThat(values.get(intervalWindow(0, 1))).isEqualTo(ImmutableList.of("a")); + assertThat(values.get(intervalWindow(1, 2))).isEqualTo(ImmutableList.of("a", "b")); + assertThat(values.get(intervalWindow(2, 3))).isEqualTo(ImmutableList.of("b")); + + SideInputValues deserialized = KRYO.serde(values); + assertThat(deserialized).isEqualToIgnoringGivenFields(values, "binaryValues"); + assertThat(deserialized.get(intervalWindow(0, 
1))).isEqualTo(ImmutableList.of("a")); + assertThat(deserialized.get(intervalWindow(1, 2))).isEqualTo(ImmutableList.of("a", "b")); + assertThat(deserialized.get(intervalWindow(2, 3))).isEqualTo(ImmutableList.of("b")); + } + + private static Dataset dataset(Encoder enc, T... data) { + return SESSION.getSession().createDataset(seqOf(data), enc); + } + + private static IntervalWindow intervalWindow(int start, int end) { + return new IntervalWindow(Instant.ofEpochMilli(start), Instant.ofEpochMilli(end)); + } + + private static WindowedValue valueInWindows(T value, BoundedWindow... windows) { + return WindowedValues.of(value, Instant.EPOCH, Lists.list(windows), PaneInfo.NO_FIRING); + } + + public static class SparkKryo extends ExternalResource { + private @Nullable Kryo kryo = null; + + @Override + protected void after() { + kryo = null; + } + + T serde(T obj) { + Output out = new Output(128); + kryo().writeClassAndObject(out, obj); + return (T) kryo().readClassAndObject(new Input(out.getBuffer(), 0, out.position())); + } + + Kryo kryo() { + if (kryo == null) { + kryo = new KryoSerializer(SESSION.getSession().sparkContext().conf()).newKryo(); + } + return checkStateNotNull(kryo); + } + } +} diff --git a/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpersTest.java b/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpersTest.java new file mode 100644 index 000000000000..48c1e645f6ec --- /dev/null +++ b/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpersTest.java @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark.structuredstreaming.translation.helpers; + +import static java.util.Arrays.asList; +import static java.util.function.Function.identity; +import static java.util.stream.Collectors.toList; +import static java.util.stream.Collectors.toMap; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.collectionEncoder; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.encoderFor; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.kvEncoder; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.mapEncoder; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.oneOfEncoder; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.windowedValueEncoder; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.tuple; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Predicates.notNull; +import static org.apache.spark.sql.types.DataTypes.IntegerType; +import static org.apache.spark.sql.types.DataTypes.StringType; +import static org.apache.spark.sql.types.DataTypes.createStructField; +import static org.apache.spark.sql.types.DataTypes.createStructType; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; + +import java.math.BigDecimal; +import java.math.MathContext; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.TreeMap; +import java.util.function.Function; +import org.apache.beam.runners.spark.structuredstreaming.SparkSessionRule; +import org.apache.beam.sdk.coders.BigDecimalCoder; +import org.apache.beam.sdk.coders.BigEndianIntegerCoder; +import org.apache.beam.sdk.coders.BigEndianLongCoder; +import org.apache.beam.sdk.coders.BigEndianShortCoder; +import org.apache.beam.sdk.coders.BooleanCoder; +import org.apache.beam.sdk.coders.ByteCoder; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.DelegateCoder; +import org.apache.beam.sdk.coders.DoubleCoder; +import org.apache.beam.sdk.coders.FloatCoder; +import org.apache.beam.sdk.coders.InstantCoder; +import org.apache.beam.sdk.coders.ListCoder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.sdk.coders.VarLongCoder; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.sdk.values.WindowedValues; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder; +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder; +import org.apache.spark.sql.types.DecimalType; +import 
org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.joda.time.Instant; +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import scala.Tuple2; + +/** Test of the wrapping of Beam Coders as Spark ExpressionEncoders. */ +@RunWith(JUnit4.class) +public class EncoderHelpersTest { + @ClassRule public static SparkSessionRule sessionRule = new SparkSessionRule("local[1]"); + + private static final Encoder windowEnc = + EncoderHelpers.encoderOf(GlobalWindow.class); + + private static final Map, List> BASIC_CASES = + ImmutableMap., List>builder() + .put(BooleanCoder.of(), asList(true, false, null)) + .put(ByteCoder.of(), asList((byte) 1, null)) + .put(BigEndianShortCoder.of(), asList((short) 1, null)) + .put(BigEndianIntegerCoder.of(), asList(1, 2, 3, null)) + .put(VarIntCoder.of(), asList(1, 2, 3, null)) + .put(BigEndianLongCoder.of(), asList(1L, 2L, 3L, null)) + .put(VarLongCoder.of(), asList(1L, 2L, 3L, null)) + .put(FloatCoder.of(), asList((float) 1.0, (float) 2.0, null)) + .put(DoubleCoder.of(), asList(1.0, 2.0, null)) + .put(StringUtf8Coder.of(), asList("1", "2", null)) + .put(BigDecimalCoder.of(), asList(bigDecimalOf(1L), bigDecimalOf(2L), null)) + .put(InstantCoder.of(), asList(Instant.ofEpochMilli(1), null)) + .build(); + + private Dataset createDataset(List data, Encoder encoder) { + Dataset ds = sessionRule.getSession().createDataset(data, encoder); + ds.printSchema(); + return ds; + } + + @Test + public void testBeamEncoderMappings() { + BASIC_CASES.forEach( + (coder, data) -> { + Encoder encoder = encoderFor(coder); + serializeAndDeserialize(data.get(0), (Encoder) encoder); + Dataset dataset = createDataset(data, (Encoder) encoder); + assertThat(dataset.collect(), equalTo(data.toArray())); + }); + } + + @Test + public void testBeamEncoderOfPrivateType() { + // Verify concrete types are not used in coder generation. + // In case of private types this would cause an IllegalAccessError. 
+ List data = asList(new PrivateString("1"), new PrivateString("2")); + Dataset dataset = createDataset(data, encoderFor(PrivateString.CODER)); + assertThat(dataset.collect(), equalTo(data.toArray())); + } + + @Test + public void testBeamWindowedValueEncoderMappings() { + BASIC_CASES.forEach( + (coder, data) -> { + List> windowed = + Lists.transform(data, WindowedValues::valueInGlobalWindow); + + Encoder encoder = windowedValueEncoder(encoderFor(coder), windowEnc); + serializeAndDeserialize(windowed.get(0), (Encoder) encoder); + + Dataset dataset = createDataset(windowed, (Encoder) encoder); + assertThat(dataset.collect(), equalTo(windowed.toArray())); + }); + } + + @Test + public void testCollectionEncoder() { + BASIC_CASES.forEach( + (coder, data) -> { + Encoder> encoder = collectionEncoder(encoderFor(coder), true); + Collection collection = Collections.unmodifiableCollection(data); + + Dataset> dataset = createDataset(asList(collection), (Encoder) encoder); + assertThat(dataset.head(), equalTo(data)); + }); + } + + private void testMapEncoder(Class cls, Function, Map> decorator) { + BASIC_CASES.forEach( + (coder, data) -> { + Encoder enc = encoderFor(coder); + Encoder> mapEncoder = mapEncoder(enc, enc, (Class) cls); + Map map = + decorator.apply( + data.stream().filter(notNull()).collect(toMap(identity(), identity()))); + + Dataset> dataset = createDataset(asList(map), mapEncoder); + Map head = dataset.head(); + assertThat(head, equalTo(map)); + assertThat(head, instanceOf(cls)); + }); + } + + @Test + public void testMapEncoder() { + testMapEncoder(Map.class, identity()); + } + + @Test + public void testHashMapEncoder() { + testMapEncoder(HashMap.class, identity()); + } + + @Test + public void testTreeMapEncoder() { + testMapEncoder(TreeMap.class, TreeMap::new); + } + + @Test + public void testBeamBinaryEncoder() { + List> data = asList(asList("a1", "a2", "a3"), asList("b1", "b2"), asList("c1")); + + Encoder> encoder = encoderFor(ListCoder.of(StringUtf8Coder.of())); + serializeAndDeserialize(data.get(0), encoder); + + Dataset> dataset = createDataset(data, encoder); + assertThat(dataset.collect(), equalTo(data.toArray())); + } + + @Test + public void testEncoderForKVCoder() { + List> data = + asList(KV.of(1, "value1"), KV.of(null, "value2"), KV.of(3, null)); + + Encoder> encoder = + kvEncoder(encoderFor(VarIntCoder.of()), encoderFor(StringUtf8Coder.of())); + serializeAndDeserialize(data.get(0), encoder); + + Dataset> dataset = createDataset(data, encoder); + + StructType kvSchema = + createStructType( + new StructField[] { + createStructField("key", IntegerType, true), + createStructField("value", StringType, true) + }); + + assertThat(dataset.schema(), equalTo(kvSchema)); + assertThat(dataset.collectAsList(), equalTo(data)); + } + + @Test + public void testOneOffEncoder() { + List> coders = ImmutableList.copyOf(BASIC_CASES.keySet()); + List> encoders = coders.stream().map(EncoderHelpers::encoderFor).collect(toList()); + + // build oneOf tuples of type index and corresponding value + List> data = + BASIC_CASES.entrySet().stream() + .map(e -> tuple(coders.indexOf(e.getKey()), (Object) e.getValue().get(0))) + .collect(toList()); + + // dataset is a sparse dataset with only one column set per row + Dataset> dataset = createDataset(data, oneOfEncoder((List) encoders)); + assertThat(dataset.collectAsList(), equalTo(data)); + } + + // fix scale/precision to system default to compare using equals + private static BigDecimal bigDecimalOf(long l) { + DecimalType type = 
DecimalType.SYSTEM_DEFAULT(); + return new BigDecimal(l, new MathContext(type.precision())).setScale(type.scale()); + } + + // test and explicit serialization roundtrip + @SuppressWarnings("unchecked") + private static void serializeAndDeserialize(T data, Encoder enc) { + ExpressionEncoder bound; + if (enc instanceof ExpressionEncoder) { + bound = (ExpressionEncoder) enc; + } else { + bound = ExpressionEncoder.apply((AgnosticEncoder) enc); + } + bound = + bound.resolveAndBind(bound.resolveAndBind$default$1(), bound.resolveAndBind$default$2()); + + InternalRow row = bound.createSerializer().apply(data); + T deserialized = bound.createDeserializer().apply(row); + + assertThat(deserialized, equalTo(data)); + } + + private static class PrivateString { + private static final Coder CODER = + DelegateCoder.of( + StringUtf8Coder.of(), + str -> str.string, + PrivateString::new, + new TypeDescriptor() {}); + + private final String string; + + public PrivateString(String string) { + this.string = string; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof PrivateString)) { + return false; + } + PrivateString that = (PrivateString) o; + return Objects.equals(string, that.string); + } + + @Override + public int hashCode() { + return Objects.hash(string); + } + } +} diff --git a/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/streaming/SimpleSourceTest.java b/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/streaming/SimpleSourceTest.java new file mode 100644 index 000000000000..a06d2cec1e9e --- /dev/null +++ b/runners/spark/4/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/streaming/SimpleSourceTest.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.translation.streaming; + +import java.io.Serializable; +import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingPipelineOptions; +import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingRunner; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.io.GenerateSequence; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Ignore; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Test class for beam to spark source translation. 
*/ +@RunWith(JUnit4.class) +public class SimpleSourceTest implements Serializable { + private static Pipeline pipeline; + @ClassRule public static final TemporaryFolder TEMPORARY_FOLDER = new TemporaryFolder(); + + @BeforeClass + public static void beforeClass() { + SparkStructuredStreamingPipelineOptions options = + PipelineOptionsFactory.create().as(SparkStructuredStreamingPipelineOptions.class); + options.setRunner(SparkStructuredStreamingRunner.class); + options.setTestMode(true); + pipeline = Pipeline.create(options); + } + + @Ignore + @Test + public void testUnboundedSource() { + // produces an unbounded PCollection of longs from 0 to Long.MAX_VALUE which elements + // have processing time as event timestamps. + pipeline.apply(GenerateSequence.from(0L)); + pipeline.run(); + } +} From ba9bec9a7dddf967b5c3de7e8f343ff47b5d7518 Mon Sep 17 00:00:00 2001 From: Tobias Kaymak Date: Fri, 17 Apr 2026 16:01:37 +0200 Subject: [PATCH 05/14] ci: add Spark 4 PreCommit and PostCommit workflows Add GitHub Actions workflows for the Spark 4 runner module: - beam_PreCommit_Java_Spark4_Versions: runs sparkVersionsTest on changes to runners/spark/**. Currently a no-op (the sparkVersions map is empty) but scaffolds future patch version coverage. - beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming: runs the structured streaming test suite on Java 17. Co-Authored-By: Claude Sonnet 4.6 --- ...datesRunner_Spark4StructuredStreaming.json | 3 + .../beam_PreCommit_Java_Spark4_Versions.json | 3 + ...idatesRunner_Spark4StructuredStreaming.yml | 97 +++++++++++++++ .../beam_PreCommit_Java_Spark4_Versions.yml | 112 ++++++++++++++++++ 4 files changed, 215 insertions(+) create mode 100644 .github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming.json create mode 100644 .github/trigger_files/beam_PreCommit_Java_Spark4_Versions.json create mode 100644 .github/workflows/beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming.yml create mode 100644 .github/workflows/beam_PreCommit_Java_Spark4_Versions.yml diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming.json new file mode 100644 index 000000000000..c4edaa85a89d --- /dev/null +++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming.json @@ -0,0 +1,3 @@ +{ + "comment": "Modify this file in a trivial way to cause this test suite to run" +} diff --git a/.github/trigger_files/beam_PreCommit_Java_Spark4_Versions.json b/.github/trigger_files/beam_PreCommit_Java_Spark4_Versions.json new file mode 100644 index 000000000000..c4edaa85a89d --- /dev/null +++ b/.github/trigger_files/beam_PreCommit_Java_Spark4_Versions.json @@ -0,0 +1,3 @@ +{ + "comment": "Modify this file in a trivial way to cause this test suite to run" +} diff --git a/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming.yml b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming.yml new file mode 100644 index 000000000000..b595afe6f42c --- /dev/null +++ b/.github/workflows/beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming.yml @@ -0,0 +1,97 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: PostCommit Java ValidatesRunner Spark4 StructuredStreaming + +on: + schedule: + - cron: '45 4/6 * * *' + pull_request_target: + paths: ['release/trigger_all_tests.json', '.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming.json'] + workflow_dispatch: + +#Setting explicit permissions for the action to avoid the default permissions which are `write-all` in case of pull_request_target event +permissions: + actions: write + pull-requests: write + checks: write + contents: read + deployments: read + id-token: none + issues: write + discussions: read + packages: read + pages: read + repository-projects: read + security-events: read + statuses: read + +# This allows a subsequently queued workflow run to interrupt previous runs +concurrency: + group: '${{ github.workflow }} @ ${{ github.event.pull_request.number || github.sha || github.head_ref || github.ref }}-${{ github.event.schedule || github.event.comment.id || github.event.sender.login }}' + cancel-in-progress: true + +env: + DEVELOCITY_ACCESS_KEY: ${{ secrets.DEVELOCITY_ACCESS_KEY }} + GRADLE_ENTERPRISE_CACHE_USERNAME: ${{ secrets.GE_CACHE_USERNAME }} + GRADLE_ENTERPRISE_CACHE_PASSWORD: ${{ secrets.GE_CACHE_PASSWORD }} + +jobs: + beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming: + name: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) + runs-on: [self-hosted, ubuntu-24.04, main] + timeout-minutes: 120 + strategy: + matrix: + job_name: [beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming] + job_phrase: [Run Spark4 StructuredStreaming ValidatesRunner] + if: | + github.event_name == 'workflow_dispatch' || + github.event_name == 'pull_request_target' || + (github.event_name == 'schedule' && github.repository == 'apache/beam') || + github.event.comment.body == 'Run Spark4 StructuredStreaming ValidatesRunner' + steps: + - uses: actions/checkout@v4 + - name: Setup repository + uses: ./.github/actions/setup-action + with: + comment_phrase: ${{ matrix.job_phrase }} + github_token: ${{ secrets.GITHUB_TOKEN }} + github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) + - name: Setup environment + uses: ./.github/actions/setup-environment-action + with: + java-version: '17' + - name: run validatesStructuredStreamingRunnerBatch script + uses: ./.github/actions/gradle-command-self-hosted-action + with: + gradle-command: :runners:spark:4:validatesStructuredStreamingRunnerBatch + arguments: | + -PtestJavaVersion=17 \ + -PdisableSpotlessCheck=true \ + - name: Archive JUnit Test Results + uses: actions/upload-artifact@v4 + if: ${{ !success() }} + with: + name: JUnit Test Results + path: "**/build/reports/tests/" + - name: Publish JUnit Test Results + uses: EnricoMi/publish-unit-test-result-action@v2 + if: always() + with: + commit: '${{ env.prsha || env.GITHUB_SHA }}' + comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} + files: '**/build/test-results/**/*.xml' + large_files: true diff --git 
a/.github/workflows/beam_PreCommit_Java_Spark4_Versions.yml b/.github/workflows/beam_PreCommit_Java_Spark4_Versions.yml new file mode 100644 index 000000000000..666cb05940dc --- /dev/null +++ b/.github/workflows/beam_PreCommit_Java_Spark4_Versions.yml @@ -0,0 +1,112 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +name: PreCommit Java Spark4 Versions + +on: + push: + tags: ['v*'] + branches: ['master', 'release-*'] + paths: + - 'runners/spark/**' + - '.github/workflows/beam_PreCommit_Java_Spark4_Versions.yml' + pull_request_target: + branches: ['master', 'release-*'] + paths: + - 'runners/spark/**' + - 'release/trigger_all_tests.json' + - '.github/trigger_files/beam_PreCommit_Java_Spark4_Versions.json' + issue_comment: + types: [created] + schedule: + - cron: '30 2/6 * * *' + workflow_dispatch: + +# This allows a subsequently queued workflow run to interrupt previous runs +concurrency: + group: '${{ github.workflow }} @ ${{ github.event.pull_request.number || github.event.pull_request.head.label || github.sha || github.head_ref || github.ref }}-${{ github.event.schedule || github.event.comment.id || github.event.sender.login }}' + cancel-in-progress: true + +#Setting explicit permissions for the action to avoid the default permissions which are `write-all` in case of pull_request_target event +permissions: + actions: write + pull-requests: write + checks: write + contents: read + deployments: read + id-token: none + issues: write + discussions: read + packages: read + pages: read + repository-projects: read + security-events: read + statuses: read + +env: + DEVELOCITY_ACCESS_KEY: ${{ secrets.DEVELOCITY_ACCESS_KEY }} + GRADLE_ENTERPRISE_CACHE_USERNAME: ${{ secrets.GE_CACHE_USERNAME }} + GRADLE_ENTERPRISE_CACHE_PASSWORD: ${{ secrets.GE_CACHE_PASSWORD }} + +jobs: + beam_PreCommit_Java_Spark4_Versions: + name: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) + runs-on: [self-hosted, ubuntu-24.04, main] + strategy: + matrix: + job_name: [beam_PreCommit_Java_Spark4_Versions] + job_phrase: [Run Java_Spark4_Versions PreCommit] + timeout-minutes: 120 + if: | + github.event_name == 'push' || + github.event_name == 'pull_request_target' || + (github.event_name == 'schedule' && github.repository == 'apache/beam') || + github.event_name == 'workflow_dispatch' || + github.event.comment.body == 'Run Java_Spark4_Versions PreCommit' + steps: + - uses: actions/checkout@v4 + - name: Setup repository + uses: ./.github/actions/setup-action + with: + comment_phrase: ${{ matrix.job_phrase }} + github_token: ${{ secrets.GITHUB_TOKEN }} + github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) + - name: Setup environment + uses: ./.github/actions/setup-environment-action + with: + java-version: '17' + - name: run sparkVersionsTest script + uses: 
./.github/actions/gradle-command-self-hosted-action + with: + gradle-command: :runners:spark:4:sparkVersionsTest + arguments: | + -PtestJavaVersion=17 \ + -PdisableSpotlessCheck=true \ + - name: Archive JUnit Test Results + uses: actions/upload-artifact@v4 + if: ${{ !success() }} + with: + name: JUnit Test Results + path: "**/build/reports/tests/" + - name: Publish JUnit Test Results + uses: EnricoMi/publish-unit-test-result-action@v2 + if: always() + with: + commit: '${{ env.prsha || env.GITHUB_SHA }}' + comment_mode: ${{ github.event_name == 'issue_comment' && 'always' || 'off' }} + files: '**/build/test-results/**/*.xml' + large_files: true From 0a98149a92a8c58fc35fd2a11b20463a2ae21e6b Mon Sep 17 00:00:00 2001 From: Tobias Kaymak Date: Fri, 17 Apr 2026 16:01:38 +0200 Subject: [PATCH 06/14] Add PreCommit Java Spark4 Versions workflow --- .github/workflows/README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/README.md b/.github/workflows/README.md index 9e685a983278..03906d1fb7cd 100644 --- a/.github/workflows/README.md +++ b/.github/workflows/README.md @@ -267,6 +267,7 @@ PreCommit Jobs run in a schedule and also get triggered in a PR if relevant sour | [ PreCommit Java Snowflake IO Direct ](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Snowflake_IO_Direct.yml) | N/A |`Run Java_Snowflake_IO_Direct PreCommit`| [![.github/workflows/beam_PreCommit_Java_Snowflake_IO_Direct.yml](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Snowflake_IO_Direct.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Snowflake_IO_Direct.yml?query=event%3Aschedule) | | [ PreCommit Java Solr IO Direct ](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Solr_IO_Direct.yml) | N/A |`Run Java_Solr_IO_Direct PreCommit`| [![.github/workflows/beam_PreCommit_Java_Solr_IO_Direct.yml](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Solr_IO_Direct.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Solr_IO_Direct.yml?query=event%3Aschedule) | | [ PreCommit Java Spark3 Versions ](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Spark3_Versions.yml) | N/A | `Run Java_Spark3_Versions PreCommit` | [![.github/workflows/beam_PreCommit_Java_Spark3_Versions.yml](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Spark3_Versions.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Spark3_Versions.yml?query=event%3Aschedule) | +| [ PreCommit Java Spark4 Versions ](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Spark4_Versions.yml) | N/A | `Run Java_Spark4_Versions PreCommit` | [![.github/workflows/beam_PreCommit_Java_Spark4_Versions.yml](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Spark4_Versions.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Spark4_Versions.yml?query=event%3Aschedule) | | [ PreCommit Java Splunk IO Direct ](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Splunk_IO_Direct.yml) | N/A |`Run Java_Splunk_IO_Direct PreCommit`| [![.github/workflows/beam_PreCommit_Java_Splunk_IO_Direct.yml](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Splunk_IO_Direct.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Splunk_IO_Direct.yml?query=event%3Aschedule) | | 
[ PreCommit Java Thrift IO Direct ](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Thrift_IO_Direct.yml) | N/A |`Run Java_Thrift_IO_Direct PreCommit`| [![.github/workflows/beam_PreCommit_Java_Thrift_IO_Direct.yml](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Thrift_IO_Direct.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Thrift_IO_Direct.yml?query=event%3Aschedule) | | [ PreCommit Java Tika IO Direct ](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Tika_IO_Direct.yml) | N/A |`Run Java_Tika_IO_Direct PreCommit`| [![.github/workflows/beam_PreCommit_Java_Tika_IO_Direct.yml](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Tika_IO_Direct.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PreCommit_Java_Tika_IO_Direct.yml?query=event%3Aschedule) | @@ -375,6 +376,7 @@ PostCommit Jobs run in a schedule against master branch and generally do not get | [ PostCommit Java ValidatesRunner Spark Java8 ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_Spark_Java8.yml) | N/A |`beam_PostCommit_Java_ValidatesRunner_Spark_Java8.json`| [![.github/workflows/beam_PostCommit_Java_ValidatesRunner_Spark_Java8.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_Spark_Java8.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_Spark_Java8.yml?query=event%3Aschedule) | | [ PostCommit Java ValidatesRunner Spark ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_Spark.yml) | N/A |`beam_PostCommit_Java_ValidatesRunner_Spark.json`| [![.github/workflows/beam_PostCommit_Java_ValidatesRunner_Spark.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_Spark.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_Spark.yml?query=event%3Aschedule) | | [ PostCommit Java ValidatesRunner SparkStructuredStreaming ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_SparkStructuredStreaming.yml) | N/A |`beam_PostCommit_Java_ValidatesRunner_SparkStructuredStreaming.json`| [![.github/workflows/beam_PostCommit_Java_ValidatesRunner_SparkStructuredStreaming.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_SparkStructuredStreaming.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_SparkStructuredStreaming.yml?query=event%3Aschedule) | +| [ PostCommit Java ValidatesRunner Spark4StructuredStreaming ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming.yml) | N/A |`beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming.json`| [![.github/workflows/beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_Spark4StructuredStreaming.yml?query=event%3Aschedule) | | [ PostCommit Java ValidatesRunner Twister2 ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_Twister2.yml) | N/A 
|`beam_PostCommit_Java_ValidatesRunner_Twister2.json`| [![.github/workflows/beam_PostCommit_Java_ValidatesRunner_Twister2.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_Twister2.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_Twister2.yml?query=event%3Aschedule) | | [ PostCommit Java ValidatesRunner ULR ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_ULR.yml) | N/A |`beam_PostCommit_Java_ValidatesRunner_ULR.json`| [![.github/workflows/beam_PostCommit_Java_ValidatesRunner_ULR.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_ULR.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java_ValidatesRunner_ULR.yml?query=event%3Aschedule) | | [ PostCommit Java ](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java.yml) | N/A |`beam_PostCommit_Java.json`| [![.github/workflows/beam_PostCommit_Java.yml](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/beam_PostCommit_Java.yml?query=event%3Aschedule) | From 07f0b0251d19ad070711b5d24a88bd9eb30f1727 Mon Sep 17 00:00:00 2001 From: Tobias Kaymak Date: Fri, 17 Apr 2026 16:01:38 +0200 Subject: [PATCH 07/14] Add cancellation support to Spark pipeline execution --- .../SparkStructuredStreamingPipelineResult.java | 1 + 1 file changed, 1 insertion(+) diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingPipelineResult.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingPipelineResult.java index 806d838d9bff..9d3419e19473 100644 --- a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingPipelineResult.java +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingPipelineResult.java @@ -113,6 +113,7 @@ public MetricResults metrics() { @Override public PipelineResult.State cancel() throws IOException { + pipelineExecution.cancel(true); offerNewState(PipelineResult.State.CANCELLED); return state; } From 4a1225c542d90dfd1ae836cdeeb7570b51305504 Mon Sep 17 00:00:00 2001 From: Tobias Kaymak Date: Fri, 17 Apr 2026 16:01:38 +0200 Subject: [PATCH 08/14] Remove unused endOfData() call in close method Remove endOfData() call in close method. 
--- .../spark/structuredstreaming/io/BoundedDatasetFactory.java | 1 - 1 file changed, 1 deletion(-) diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/BoundedDatasetFactory.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/BoundedDatasetFactory.java index c00b2d3594c0..0020347b89b0 100644 --- a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/BoundedDatasetFactory.java +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/BoundedDatasetFactory.java @@ -284,7 +284,6 @@ public SourcePartitionIterator(SourcePartition partition, Params params) { @SuppressWarnings("nullness") // ok, reader not used any longer public void close() throws IOException { if (reader != null) { - endOfData(); try { reader.close(); } finally { From 5989f52a156ecee796c88756d63ab6102fd69f80 Mon Sep 17 00:00:00 2001 From: Tobias Kaymak Date: Fri, 17 Apr 2026 16:01:38 +0200 Subject: [PATCH 09/14] build: add Spark 4 job-server and container modules Add job-server and container build configurations for Spark 4, mirroring the existing Spark 3 job-server setup. The container uses eclipse-temurin:17 (Spark 4 requires Java 17). The shared spark_job_server.gradle gains a requireJavaVersion conditional for Spark 4 parent projects. Co-Authored-By: Claude Opus 4.6 --- runners/spark/4/job-server/build.gradle | 31 ++++++++++++++ .../spark/4/job-server/container/Dockerfile | 32 +++++++++++++++ .../spark/4/job-server/container/build.gradle | 41 +++++++++++++++++++ .../translation/EvaluationContext.java | 10 ++++- .../spark/job-server/spark_job_server.gradle | 3 ++ settings.gradle.kts | 2 + 6 files changed, 117 insertions(+), 2 deletions(-) create mode 100644 runners/spark/4/job-server/build.gradle create mode 100644 runners/spark/4/job-server/container/Dockerfile create mode 100644 runners/spark/4/job-server/container/build.gradle diff --git a/runners/spark/4/job-server/build.gradle b/runners/spark/4/job-server/build.gradle new file mode 100644 index 000000000000..598cf3b4913a --- /dev/null +++ b/runners/spark/4/job-server/build.gradle @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +def basePath = '../../job-server' + +project.ext { + // Look for the source code in the parent module + main_source_dirs = ["$basePath/src/main/java"] + test_source_dirs = ["$basePath/src/test/java"] + main_resources_dirs = ["$basePath/src/main/resources"] + test_resources_dirs = ["$basePath/src/test/resources"] + archives_base_name = 'beam-runners-spark-4-job-server' +} + +// Load the main build script which contains all build logic. 
+apply from: "$basePath/spark_job_server.gradle" diff --git a/runners/spark/4/job-server/container/Dockerfile b/runners/spark/4/job-server/container/Dockerfile new file mode 100644 index 000000000000..f40d8846f102 --- /dev/null +++ b/runners/spark/4/job-server/container/Dockerfile @@ -0,0 +1,32 @@ +############################################################################### +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################### + +FROM eclipse-temurin:17 +MAINTAINER "Apache Beam " + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y libltdl7 + +ADD beam-runners-spark-job-server.jar /opt/apache/beam/jars/ +ADD spark-job-server.sh /opt/apache/beam/ + +WORKDIR /opt/apache/beam + +COPY target/LICENSE /opt/apache/beam/ +COPY target/NOTICE /opt/apache/beam/ + +ENTRYPOINT ["./spark-job-server.sh"] diff --git a/runners/spark/4/job-server/container/build.gradle b/runners/spark/4/job-server/container/build.gradle new file mode 100644 index 000000000000..5a3a94bafc3a --- /dev/null +++ b/runners/spark/4/job-server/container/build.gradle @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +def basePath = '../../../job-server/container' + +project.ext { + resource_path = basePath + spark_job_server_image = 'spark4_job_server' +} + +// Load the main build script which contains all build logic. +apply from: "$basePath/spark_job_server_container.gradle" + +// Override the Dockerfile copy to use the Java 17 Dockerfile for Spark 4. +copyDockerfileDependencies { + // Remove the shared Dockerfile added by the shared gradle script and use the local one instead. + // The shared Dockerfile uses eclipse-temurin:11 which is incompatible with Spark 4 (requires Java 17). 
+ exclude 'Dockerfile' +} + +task copySpark4Dockerfile(type: Copy) { + from "Dockerfile" + into "build" +} + +dockerPrepare.dependsOn copySpark4Dockerfile diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/EvaluationContext.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/EvaluationContext.java index 55c4bbaedd3c..b8448567eafc 100644 --- a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/EvaluationContext.java +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/EvaluationContext.java @@ -86,7 +86,10 @@ public static void evaluate(String name, Dataset ds) { ds.write().mode("overwrite").format("noop").save(); LOG.info("Evaluated dataset {} in {}", name, durationSince(startMs)); } catch (RuntimeException e) { - LOG.error("Failed to evaluate dataset {}: {}", name, Throwables.getRootCause(e).getMessage()); + LOG.error( + "Failed to evaluate dataset {}: {}", + name, + String.valueOf(Throwables.getRootCause(e).getMessage())); throw new RuntimeException(e); } } @@ -102,7 +105,10 @@ public static void evaluate(String name, Dataset ds) { LOG.info("Collected dataset {} in {} [size: {}]", name, durationSince(startMs), res.length); return res; } catch (Exception e) { - LOG.error("Failed to collect dataset {}: {}", name, Throwables.getRootCause(e).getMessage()); + LOG.error( + "Failed to collect dataset {}: {}", + name, + String.valueOf(Throwables.getRootCause(e).getMessage())); throw new RuntimeException(e); } } diff --git a/runners/spark/job-server/spark_job_server.gradle b/runners/spark/job-server/spark_job_server.gradle index 7e2deaf6e395..42691461c3ef 100644 --- a/runners/spark/job-server/spark_job_server.gradle +++ b/runners/spark/job-server/spark_job_server.gradle @@ -28,7 +28,10 @@ apply plugin: 'application' // we need to set mainClassName before applying shadow plugin mainClassName = "org.apache.beam.runners.spark.SparkJobServerDriver" +def parentSparkVersion = project.parent.findProperty('spark_version') ?: '' + applyJavaNature( + requireJavaVersion: (parentSparkVersion.startsWith("4") ? org.gradle.api.JavaVersion.VERSION_17 : null), automaticModuleName: 'org.apache.beam.runners.spark.jobserver', archivesBaseName: project.hasProperty('archives_base_name') ? archives_base_name : archivesBaseName, validateShadowJar: false, diff --git a/settings.gradle.kts b/settings.gradle.kts index 66c99a2c796c..d1e837b06fc6 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -151,6 +151,8 @@ include(":runners:spark:3") include(":runners:spark:3:job-server") include(":runners:spark:3:job-server:container") include(":runners:spark:4") +include(":runners:spark:4:job-server") +include(":runners:spark:4:job-server:container") include(":runners:samza") include(":runners:samza:job-server") include(":sdks:go") From b156bc8a469171ce706714c53fc20b2c088f8435 Mon Sep 17 00:00:00 2001 From: Tobias Kaymak Date: Fri, 17 Apr 2026 16:01:38 +0200 Subject: [PATCH 10/14] build: remove spark.driver.host workaround from Spark 4 build The hostname binding hack is no longer needed now that the local machine resolves its hostname to 127.0.0.1 via /etc/hosts. 
Co-Authored-By: Claude Opus 4.6 --- runners/spark/4/build.gradle | 4 ---- 1 file changed, 4 deletions(-) diff --git a/runners/spark/4/build.gradle b/runners/spark/4/build.gradle index 908606218b85..283dd6a01ce1 100644 --- a/runners/spark/4/build.gradle +++ b/runners/spark/4/build.gradle @@ -29,13 +29,9 @@ project.ext { // Load the main build script which contains all build logic. apply from: "$basePath/spark_runner.gradle" -// Force Spark to bind to 127.0.0.1 so tests pass on machines where the hostname -// doesn't resolve to a bindable address (e.g. mac.lan in macOS VPN environments). // Spark 4 always requires Java 17, so unconditionally add the --add-opens flags // required by Kryo and other libraries that use reflection on JDK internals. test { - systemProperty "spark.driver.host", "127.0.0.1" - systemProperty "spark.driver.bindAddress", "127.0.0.1" jvmArgs "--add-opens=java.base/sun.nio.ch=ALL-UNNAMED", "--add-opens=java.base/java.nio=ALL-UNNAMED", "--add-opens=java.base/java.util=ALL-UNNAMED", From 0d2854ab2f00e7317c70a3c0bdc507f485c529a8 Mon Sep 17 00:00:00 2001 From: Tobias Kaymak Date: Fri, 17 Apr 2026 16:01:38 +0200 Subject: [PATCH 11/14] docs: add Spark 4 runner entry to CHANGES.md Called out in /ultrareview as a missing contributor checklist item. Adds a Highlight line and a New Features / Improvements entry under the 2.74.0 Unreleased section, referencing issue #36841. --- CHANGES.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/CHANGES.md b/CHANGES.md index bdcbd3451c7b..17bbcf81f291 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -60,7 +60,7 @@ ## Highlights * New highly anticipated feature X added to Python SDK ([#X](https://github.com/apache/beam/issues/X)). -* New highly anticipated feature Y added to Java SDK ([#Y](https://github.com/apache/beam/issues/Y)). +* Experimental Spark 4 structured streaming runner added (Java), built against Spark 4.0.2 / Scala 2.13 and requiring Java 17 ([#36841](https://github.com/apache/beam/issues/36841)). ## I/Os @@ -74,6 +74,10 @@ compatible. Both coders can decode encoded bytes from the other coder ([#38139](https://github.com/apache/beam/issues/38139)). * (Python) Added type alias for with_exception_handling to be used for typehints. ([#38173](https://github.com/apache/beam/issues/38173)). +* (Java) Added a Spark 4 runner module (`:runners:spark:4`) and job-server + (`:runners:spark:4:job-server`). Batch-only; streaming is not yet supported. + The shared Spark source in `runners/spark/src/` is now compatible with both + Scala 2.12 (Spark 3) and Scala 2.13 (Spark 4) ([#36841](https://github.com/apache/beam/issues/36841)). ## Breaking Changes From ef439bb4418dc7343ff74c6e8a51be3c715c3248 Mon Sep 17 00:00:00 2001 From: Tobias Kaymak Date: Fri, 17 Apr 2026 16:01:38 +0200 Subject: [PATCH 12/14] docs: explain classic.SparkSession downcast in BoundedDatasetFactory Per /ultrareview feedback: the one-line comment didn't make clear why the cast is safe. Expand it to note that SparkSession.builder() always returns a classic.SparkSession at runtime, which is why the downcast avoids reflection. 
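For reference, the pattern the expanded comment describes, shown in isolation. A minimal sketch assuming a Spark 4.0 classpath; the master and appName values are placeholders and not part of the patch.

    import org.apache.spark.sql.SparkSession;

    public class ClassicSessionCastSketch {
      public static void main(String[] args) {
        SparkSession session =
            SparkSession.builder().master("local[2]").appName("cast-demo").getOrCreate();
        // On Spark 4, builder() hands back the concrete classic implementation,
        // so this narrowing cast cannot fail for sessions obtained this way;
        // that is why the runner can cast instead of resolving ofRows() reflectively.
        org.apache.spark.sql.classic.SparkSession classic =
            (org.apache.spark.sql.classic.SparkSession) session;
        System.out.println(classic.version());
        session.stop();
      }
    }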
--- .../spark/structuredstreaming/io/BoundedDatasetFactory.java | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/BoundedDatasetFactory.java b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/BoundedDatasetFactory.java index 0020347b89b0..f8c200f2a61d 100644 --- a/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/BoundedDatasetFactory.java +++ b/runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/BoundedDatasetFactory.java @@ -86,7 +86,10 @@ public static Dataset> createDatasetFromRows( Params params = new Params<>(encoder, options, session.sparkContext().defaultParallelism()); BeamTable table = new BeamTable<>(source, params); LogicalPlan logicalPlan = DataSourceV2Relation.create(table, Option.empty(), Option.empty()); - // In Spark 4.0+, Dataset$ moved to org.apache.spark.sql.classic; cast session accordingly. + // In Spark 4.0+, Dataset$ moved to org.apache.spark.sql.classic and its ofRows() now + // takes the classic SparkSession subclass. The runtime instance returned by + // SparkSession.builder() is always a classic.SparkSession, so the downcast is safe and + // avoids reflection. return (Dataset>) Dataset$.MODULE$ .ofRows((org.apache.spark.sql.classic.SparkSession) session, logicalPlan) From 95c0af52efd6a418cc65e6e933f30342855acc01 Mon Sep 17 00:00:00 2001 From: Tobias Kaymak Date: Fri, 17 Apr 2026 16:01:38 +0200 Subject: [PATCH 13/14] fix: log warning when neither WrappedArray nor ArraySeq class is found Per /ultrareview feedback: the fallback branch silently swallowed the second ClassNotFoundException. In practice one of the two classes is always present (Scala 2.12 vs 2.13 stdlib), but a silent skip could mask a broken classpath. Emit a LOG.warn instead. --- .../spark/coders/SparkRunnerKryoRegistrator.java | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/SparkRunnerKryoRegistrator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/SparkRunnerKryoRegistrator.java index ba8c0812c9e5..8df7043443e0 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/SparkRunnerKryoRegistrator.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/coders/SparkRunnerKryoRegistrator.java @@ -30,6 +30,8 @@ import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.HashBasedTable; import org.apache.spark.serializer.KryoRegistrator; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Custom {@link KryoRegistrator}s for Beam's Spark runner needs and registering used class in spark @@ -45,6 +47,8 @@ }) public class SparkRunnerKryoRegistrator implements KryoRegistrator { + private static final Logger LOG = LoggerFactory.getLogger(SparkRunnerKryoRegistrator.class); + @Override public void registerClasses(Kryo kryo) { // MicrobatchSource is serialized as data and may not be Kryo-serializable. 
@@ -67,7 +71,11 @@ public void registerClasses(Kryo kryo) { try { kryo.register(Class.forName("scala.collection.mutable.WrappedArray$ofRef")); } catch (ClassNotFoundException ignored) { - // Neither class found; skip registration + LOG.warn( + "Neither scala.collection.mutable.ArraySeq$ofRef (Scala 2.13) nor " + + "scala.collection.mutable.WrappedArray$ofRef (Scala 2.12) was found on the " + + "classpath. Kryo serialization of Scala wrapped arrays will fall back to Java " + + "serialization or fail at runtime if spark.kryo.registrationRequired is true."); } } From 6f40a157a17237bf8a6d4e6b2c4ed4ff72b2a372 Mon Sep 17 00:00:00 2001 From: Tobias Kaymak Date: Fri, 17 Apr 2026 16:01:38 +0200 Subject: [PATCH 14/14] build: compare spark_version numerically via isSparkAtLeast helper MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per /ultrareview feedback: the five `"$spark_version" >= "3.5.0"` checks were lexicographic string comparisons. They happened to work for 3.5.0 and 4.0.2 only because '4' > '3' as chars — a future "3.10.0" release would compare less than "3.5.0" and silently drop the Spark 3.5+ dependencies and exclusions. Introduce an `isSparkAtLeast` closure that tokenizes on `.` and `-`, keeps numeric parts, and compares component-by-component. Replace all five call sites. --- runners/spark/spark_runner.gradle | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/runners/spark/spark_runner.gradle b/runners/spark/spark_runner.gradle index a0e118412319..1146c7f8c11b 100644 --- a/runners/spark/spark_runner.gradle +++ b/runners/spark/spark_runner.gradle @@ -36,6 +36,16 @@ applyJavaNature( description = "Apache Beam :: Runners :: Spark $spark_version" +// Numeric version comparison (lexicographic string compare was fragile — e.g. "3.10.0" < "3.5.0"). +def isSparkAtLeast = { String minVersion -> + def parts = spark_version.tokenize('.-').findAll { it.isInteger() }*.toInteger() + def minParts = minVersion.tokenize('.')*.toInteger() + for (int i = 0; i < Math.min(parts.size(), minParts.size()); i++) { + if (parts[i] != minParts[i]) return parts[i] > minParts[i] + } + return parts.size() >= minParts.size() +} + /* * We need to rely on manually specifying these evaluationDependsOn to ensure that * the following projects are evaluated before we evaluate this project. 
This is because @@ -177,7 +187,7 @@ dependencies { spark.components.each { component -> provided "$component:$spark_version" } - if ("$spark_version" >= "3.5.0") { + if (isSparkAtLeast("3.5.0")) { implementation "org.apache.spark:spark-common-utils_$spark_scala_version:$spark_version" implementation "org.apache.spark:spark-sql-api_$spark_scala_version:$spark_version" } @@ -214,7 +224,7 @@ dependencies { testImplementation library.java.mockito_core testImplementation "org.assertj:assertj-core:3.11.1" testImplementation "org.apache.zookeeper:zookeeper:3.4.11" - if ("$spark_version" >= "3.5.0") { + if (isSparkAtLeast("3.5.0")) { testImplementation "org.apache.spark:spark-common-utils_$spark_scala_version:$spark_version" testImplementation "org.apache.spark:spark-sql-api_$spark_scala_version:$spark_version" } @@ -228,7 +238,7 @@ dependencies { "hadoopVersion$kv.key" "org.apache.hadoop:hadoop-common:$kv.value" // Force paranamer 2.8 to avoid issues when using Scala 2.12 "hadoopVersion$kv.key" "com.thoughtworks.paranamer:paranamer:2.8" - if ("$spark_version" >= "3.5.0") { + if (isSparkAtLeast("3.5.0")) { // Add log4j 2.x dependencies as Spark 3.5+ uses slf4j with log4j 2.x backend "hadoopVersion$kv.key" library.java.log4j2_api "hadoopVersion$kv.key" library.java.log4j2_core @@ -254,7 +264,7 @@ configurations.validatesRunner { // Exclude to make sure log4j binding is used exclude group: "org.slf4j", module: "slf4j-simple" - if ("$spark_version" >= "3.5.0") { + if (isSparkAtLeast("3.5.0")) { // Exclude log4j 1.x dependencies to prevent conflict with log4j 2.x used by spark 3.5+ exclude group: "log4j", module: "log4j" } @@ -265,7 +275,7 @@ hadoopVersions.each { kv -> resolutionStrategy { force "org.apache.hadoop:hadoop-common:$kv.value" } - if ("$spark_version" >= "3.5.0") { + if (isSparkAtLeast("3.5.0")) { // Exclude log4j 1.x dependencies to prevent conflict with log4j 2.x used by spark 3.5+ exclude group: "log4j", module: "log4j" }
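To make the motivation for isSparkAtLeast concrete, the comparison can be sanity-checked outside Gradle. The snippet below mirrors the Groovy closure's logic in Java; it is an illustration only and not part of the build.

    import java.util.Arrays;

    public class VersionCompareSketch {
      // Component-wise numeric comparison, mirroring the isSparkAtLeast closure:
      // tokenize on '.' and '-', keep numeric parts, compare component by component.
      static boolean atLeast(String version, String min) {
        int[] v = Arrays.stream(version.split("[.-]"))
            .filter(s -> s.matches("\\d+")).mapToInt(Integer::parseInt).toArray();
        int[] m = Arrays.stream(min.split("[.-]"))
            .filter(s -> s.matches("\\d+")).mapToInt(Integer::parseInt).toArray();
        for (int i = 0; i < Math.min(v.length, m.length); i++) {
          if (v[i] != m[i]) {
            return v[i] > m[i];
          }
        }
        return v.length >= m.length;
      }

      public static void main(String[] args) {
        // Lexicographic string comparison misorders 3.10.0 relative to 3.5.0 ...
        System.out.println("3.10.0".compareTo("3.5.0") >= 0); // false
        // ... while the numeric comparison handles it, along with the versions in use.
        System.out.println(atLeast("3.10.0", "3.5.0")); // true
        System.out.println(atLeast("3.5.0", "3.5.0"));  // true
        System.out.println(atLeast("4.0.2", "3.5.0"));  // true
        System.out.println(atLeast("3.4.1", "3.5.0"));  // false
      }
    }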