diff --git a/dev/ensure-jars-have-correct-contents.sh b/dev/ensure-jars-have-correct-contents.sh index 084936475d..e4cd7224b0 100755 --- a/dev/ensure-jars-have-correct-contents.sh +++ b/dev/ensure-jars-have-correct-contents.sh @@ -93,6 +93,7 @@ allowed_expr+="|^org/apache/spark/sql/$" allowed_expr+="|^org/apache/spark/sql/ExtendedExplainGenerator.*$" allowed_expr+="|^org/apache/spark/CometPlugin.class$" allowed_expr+="|^org/apache/spark/CometDriverPlugin.*$" +allowed_expr+="|^org/apache/spark/CometExecutorPlugin.*$" allowed_expr+="|^org/apache/spark/CometSource.*$" allowed_expr+="|^org/apache/spark/CometTaskMemoryManager.class$" allowed_expr+="|^org/apache/spark/CometTaskMemoryManager.*$" diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 4a7a21006d..b22072e6c9 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -91,7 +91,7 @@ use std::collections::HashMap; use std::path::PathBuf; use std::time::{Duration, Instant}; use std::{sync::Arc, task::Poll}; -use tokio::runtime::Runtime; +use tokio::runtime::{Handle, Runtime}; use tokio::sync::mpsc; use crate::execution::memory_pools::{ @@ -117,7 +117,7 @@ use std::sync::OnceLock; #[cfg(feature = "jemalloc")] use tikv_jemalloc_ctl::{epoch, stats}; -static TOKIO_RUNTIME: OnceLock = OnceLock::new(); +static TOKIO_RUNTIME: Mutex> = Mutex::new(None); #[cfg(feature = "jemalloc")] fn log_jemalloc_usage() { @@ -211,12 +211,39 @@ fn build_runtime(default_worker_threads: Option) -> Runtime { /// Initialize the global Tokio runtime with the given default worker thread count. /// If the runtime is already initialized, this is a no-op. pub fn init_runtime(default_worker_threads: usize) { - TOKIO_RUNTIME.get_or_init(|| build_runtime(Some(default_worker_threads))); + let mut guard = TOKIO_RUNTIME.lock(); + if guard.is_none() { + *guard = Some(build_runtime(Some(default_worker_threads))); + } +} + +/// Returns a handle to the global Tokio runtime, lazily initializing it if needed. +/// +/// A [`Handle`] is returned (rather than a `&'static Runtime`) so that the runtime +/// can be torn down via [`release_runtime`]. The handle is cheap to clone and can be +/// used with `spawn` / `block_on` just like a `Runtime`. +pub fn get_runtime() -> Handle { + let mut guard = TOKIO_RUNTIME.lock(); + guard + .get_or_insert_with(|| build_runtime(None)) + .handle() + .clone() } -/// Function to get a handle to the global Tokio runtime -pub fn get_runtime() -> &'static Runtime { - TOKIO_RUNTIME.get_or_init(|| build_runtime(None)) +/// Tears down the global Tokio runtime, if it has been initialized. +/// +/// The runtime is moved out of the global slot and shut down in the background so the +/// calling (JNI) thread is not blocked waiting for worker threads to finish. Any handles +/// previously returned by [`get_runtime`] will start failing their spawns once the runtime +/// is gone, so this must only be called when no native execution is in flight. +/// +/// Must not be called from within the runtime's own worker threads, otherwise the shutdown +/// would deadlock/panic. +pub fn release_runtime() { + let runtime = TOKIO_RUNTIME.lock().take(); + if let Some(runtime) = runtime { + runtime.shutdown_background(); + } } /// Returns a short name for an OpStruct variant. diff --git a/native/core/src/execution/operators/iceberg_scan.rs b/native/core/src/execution/operators/iceberg_scan.rs index 713b4089b0..090f5813ac 100644 --- a/native/core/src/execution/operators/iceberg_scan.rs +++ b/native/core/src/execution/operators/iceberg_scan.rs @@ -176,16 +176,24 @@ impl IcebergScanExec { let task_stream = futures::stream::iter(tasks.into_iter().map(Ok)).boxed(); - // iceberg-rust's ArrowReader spawns IO/CPU work onto an iceberg::Runtime. execute() runs - // on the JVM-called thread outside any tokio context, so Runtime::current() would panic; - // build it from Comet's global runtime, which is where the stream is later polled. - let reader = - iceberg::arrow::ArrowReaderBuilder::new(file_io, IcebergRuntime::new(get_runtime())) - .with_batch_size(batch_size) - .with_data_file_concurrency_limit(self.data_file_concurrency_limit) - .with_row_selection_enabled(true) - .with_metadata_size_hint(512 * 1024) // Same as DataFusion's default - .build(); + // iceberg-rust's ArrowReader spawns IO/CPU work onto an iceberg::Runtime, which only needs + // a tokio handle. execute() runs on the JVM-called thread outside any tokio context, so we + // enter Comet's global runtime to capture its handle (this is where the stream is later + // polled). Capturing the handle rather than borrowing the runtime keeps it tear-downable + // via release_runtime. + let iceberg_runtime = { + let handle = get_runtime(); + let _guard = handle.enter(); + IcebergRuntime::try_current().map_err(|e| { + DataFusionError::Execution(format!("Failed to build Iceberg runtime: {e}")) + })? + }; + let reader = iceberg::arrow::ArrowReaderBuilder::new(file_io, iceberg_runtime) + .with_batch_size(batch_size) + .with_data_file_concurrency_limit(self.data_file_concurrency_limit) + .with_row_selection_enabled(true) + .with_metadata_size_hint(512 * 1024) // Same as DataFusion's default + .build(); // Pass all tasks to iceberg-rust at once to utilize its flatten_unordered // parallelization, avoiding overhead of single-task streams diff --git a/native/core/src/lib.rs b/native/core/src/lib.rs index 7d15c761ca..48e17bb502 100644 --- a/native/core/src/lib.rs +++ b/native/core/src/lib.rs @@ -125,6 +125,12 @@ pub extern "system" fn Java_org_apache_comet_NativeBase_init( }) } +#[no_mangle] +/// Releases the global Tokio runtime used by Comet native execution. +pub extern "system" fn Java_org_apache_comet_NativeBase_release(_e: EnvUnowned, _class: JClass) { + execution::jni_api::release_runtime(); +} + const LOG_PATTERN: &str = "{d(%y/%m/%d %H:%M:%S)} {l} {f}: {m}{n}"; /// JNI method to check if a specific feature is enabled in the native Rust code. diff --git a/spark/src/main/java/org/apache/comet/NativeBase.java b/spark/src/main/java/org/apache/comet/NativeBase.java index e2fcbb24a7..e181704584 100644 --- a/spark/src/main/java/org/apache/comet/NativeBase.java +++ b/spark/src/main/java/org/apache/comet/NativeBase.java @@ -293,6 +293,9 @@ private static String resourceName() { */ static native void init(String logConfPath, String logLevel); + /** Release native resources */ + public static native void release(); + /** * Check if a specific feature is enabled in the native library. * diff --git a/spark/src/main/scala/org/apache/spark/Plugins.scala b/spark/src/main/scala/org/apache/spark/Plugins.scala index 7290ab436a..54df98d180 100644 --- a/spark/src/main/scala/org/apache/spark/Plugins.scala +++ b/spark/src/main/scala/org/apache/spark/Plugins.scala @@ -28,8 +28,8 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config.{EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD, EXECUTOR_MEMORY_OVERHEAD_FACTOR} import org.apache.spark.sql.internal.StaticSQLConf +import org.apache.comet.{CometSparkSessionExtensions, NativeBase} import org.apache.comet.CometConf.{COMET_METRICS_ENABLED, COMET_ONHEAP_ENABLED} -import org.apache.comet.CometSparkSessionExtensions /** * Comet driver plugin. This class is loaded by Spark's plugin framework. It will be instantiated @@ -95,6 +95,10 @@ class CometDriverPlugin extends DriverPlugin with Logging with ShimCometDriverPl override def shutdown(): Unit = { logInfo("CometDriverPlugin shutdown") + if (NativeBase.isLoaded) { + NativeBase.release() + } + super.shutdown() } @@ -148,6 +152,26 @@ object CometDriverPlugin extends Logging { } } +class CometExecutorPlugin extends ExecutorPlugin with Logging { + + override def init(ctx: PluginContext, extraConf: ju.Map[String, String]): Unit = { + logInfo("CometExecutorPlugin init") + + super.init(ctx, extraConf) + } + + override def shutdown(): Unit = { + logInfo("CometExecutorPlugin shutdown") + + if (NativeBase.isLoaded) { + NativeBase.release() + } + + super.shutdown() + } + +} + /** * The Comet plugin for Spark. To enable this plugin, set the config "spark.plugins" to * `org.apache.spark.CometPlugin` @@ -155,5 +179,5 @@ object CometDriverPlugin extends Logging { class CometPlugin extends SparkPlugin with Logging { override def driverPlugin(): DriverPlugin = new CometDriverPlugin - override def executorPlugin(): ExecutorPlugin = null + override def executorPlugin(): ExecutorPlugin = new CometExecutorPlugin }