diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt index eb2b6f096a1..ba91f444287 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt @@ -17,6 +17,7 @@ import java.util.concurrent.atomic.AtomicInteger import org.apache.commons.io.FileUtils import org.junit.Assert import org.junit.Before +import org.junit.Ignore import org.junit.Test import org.junit.runner.RunWith import org.pytorch.executorch.TestFileUtils.getTestFilePath @@ -39,49 +40,48 @@ class ModuleInstrumentationTest { inputStream.close() } + @Ignore( + "The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward " + ) @Test @Throws(IOException::class, URISyntaxException::class) fun testModuleLoadAndForward() { val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - try { - val results = module.forward(EValue.from(dummyInput())) - Assert.assertTrue(results[0].isTensor) - } finally { - module.destroy() - } + + val results = module.forward() + Assert.assertTrue(results[0].isTensor) } @Test @Throws(IOException::class, URISyntaxException::class) fun testMethodMetadata() { val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - module.destroy() } + @Ignore( + "The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward " + ) @Test @Throws(IOException::class) fun testModuleLoadMethodAndForward() { val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - try { - module.loadMethod(FORWARD_METHOD) - val results = module.forward(EValue.from(dummyInput())) - Assert.assertTrue(results[0].isTensor) - } finally { - module.destroy() - } + module.loadMethod(FORWARD_METHOD) + + val results = module.forward() + Assert.assertTrue(results[0].isTensor) } + @Ignore( + "The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward " + ) @Test @Throws(IOException::class) fun testModuleLoadForwardExplicit() { val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - try { - val results = module.execute(FORWARD_METHOD, EValue.from(dummyInput())) - Assert.assertTrue(results[0].isTensor) - } finally { - module.destroy() - } + + val results = module.execute(FORWARD_METHOD) + Assert.assertTrue(results[0].isTensor) } @Test(expected = RuntimeException::class) @@ -94,18 +94,15 @@ class ModuleInstrumentationTest { @Throws(IOException::class) fun testModuleLoadMethodNonExistantMethod() { val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - try { - val exception = - Assert.assertThrows(ExecutorchRuntimeException::class.java) { - module.loadMethod(NONE_METHOD) - } - Assert.assertEquals( - ExecutorchRuntimeException.INVALID_ARGUMENT, - exception.getErrorCode(), - ) - } finally { - module.destroy() - } + + val exception = + Assert.assertThrows(ExecutorchRuntimeException::class.java) { + module.loadMethod(NONE_METHOD) + } + Assert.assertEquals( + ExecutorchRuntimeException.INVALID_ARGUMENT, + exception.getErrorCode(), + ) } @Test(expected = RuntimeException::class) @@ -138,6 +135,9 @@ class ModuleInstrumentationTest { Assert.assertThrows(IllegalStateException::class.java) { module.forward() } } + @Ignore( + "The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward " + ) @Test @Throws(InterruptedException::class, IOException::class) fun testForwardFromMultipleThreads() { @@ -151,7 +151,7 @@ class ModuleInstrumentationTest { try { latch.countDown() latch.await(5000, TimeUnit.MILLISECONDS) - val results = module.forward(EValue.from(dummyInput())) + val results = module.forward() Assert.assertTrue(results[0].isTensor) completed.incrementAndGet() } catch (_: InterruptedException) {} @@ -168,7 +168,6 @@ class ModuleInstrumentationTest { } Assert.assertEquals(numThreads.toLong(), completed.get().toLong()) - module.destroy() } companion object { @@ -177,8 +176,5 @@ class ModuleInstrumentationTest { private const val NON_PTE_FILE_NAME = "/test.txt" private const val FORWARD_METHOD = "forward" private const val NONE_METHOD = "none" - private val inputShape = longArrayOf(1, 3, 224, 224) - - private fun dummyInput(): Tensor = Tensor.ones(inputShape, DType.FLOAT) } } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java index 53ee4d3f33a..30ebf1a2c1d 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java @@ -36,22 +36,12 @@ public static ExecuTorchRuntime getRuntime() { /** * Validates that the given path points to a readable file. * - * @throws IllegalArgumentException if the path is null, does not exist, is not a file, or is not - * readable. + * @throws RuntimeException if the file does not exist or is not readable. */ public static void validateFilePath(String path, String description) { - if (path == null) { - throw new IllegalArgumentException("Cannot load " + description + ": path is null"); - } File file = new File(path); - if (!file.exists()) { - throw new IllegalArgumentException("Cannot load " + description + "!! " + path); - } - if (!file.isFile()) { - throw new IllegalArgumentException("Cannot load " + description + "!! " + path); - } - if (!file.canRead()) { - throw new IllegalArgumentException("Cannot load " + description + "!! " + path); + if (!file.canRead() || !file.isFile()) { + throw new RuntimeException("Cannot load " + description + " " + path); } } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java index e72ed9e3d28..e0fda73cc06 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java @@ -161,11 +161,6 @@ public ExecutorchRuntimeException(int errorCode, String details) { this.errorCode = errorCode; } - public ExecutorchRuntimeException(int errorCode, String details, Throwable cause) { - super(ErrorHelper.formatMessage(errorCode, details), cause); - this.errorCode = errorCode; - } - /** Returns the numeric error code from {@code runtime/core/error.h}. */ public int getErrorCode() { return errorCode; diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java index 6cf99966e6a..05e1e5b88cf 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java @@ -12,7 +12,6 @@ import com.facebook.jni.annotations.DoNotStrip; import com.facebook.soloader.nativeloader.NativeLoader; import com.facebook.soloader.nativeloader.SystemDelegate; -import java.io.Closeable; import java.util.HashMap; import java.util.Map; import java.util.concurrent.locks.Lock; @@ -25,7 +24,7 @@ *

Warning: These APIs are experimental and subject to change without notice */ @Experimental -public class Module implements Closeable { +public class Module { static { if (!NativeLoader.isInitialized()) { @@ -275,9 +274,7 @@ public boolean etdump() { public void destroy() { if (mLock.tryLock()) { try { - if (mHybridData.isValid()) { - mHybridData.resetNative(); - } + mHybridData.resetNative(); } finally { mLock.unlock(); } @@ -285,9 +282,4 @@ public void destroy() { throw new IllegalStateException("Cannot destroy module while method is executing"); } } - - @Override - public void close() { - destroy(); - } } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/asr/AsrModule.kt b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/asr/AsrModule.kt index ab9099ba405..987cb3ec3be 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/asr/AsrModule.kt +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/asr/AsrModule.kt @@ -11,7 +11,6 @@ package org.pytorch.executorch.extension.asr import java.io.Closeable import java.io.File import java.util.concurrent.atomic.AtomicLong -import org.pytorch.executorch.ExecutorchRuntimeException import org.pytorch.executorch.annotations.Experimental /** @@ -54,10 +53,7 @@ class AsrModule( val handle = nativeCreate(modelPath, tokenizerPath, dataPath, preprocessorPath) if (handle == 0L) { - throw ExecutorchRuntimeException( - ExecutorchRuntimeException.INTERNAL, - "Failed to create native AsrModule", - ) + throw RuntimeException("Failed to create native AsrModule") } nativeHandle.set(handle) } @@ -133,7 +129,7 @@ class AsrModule( * @param callback Optional callback to receive tokens as they are generated (can be null) * @return The complete transcribed text * @throws IllegalStateException if the module has been destroyed - * @throws ExecutorchRuntimeException if transcription fails (error code carried in exception) + * @throws RuntimeException if transcription fails (non-zero result code) */ @JvmOverloads fun transcribe( @@ -164,7 +160,7 @@ class AsrModule( ) if (status != 0) { - throw ExecutorchRuntimeException(status, "Transcription failed") + throw RuntimeException("Transcription failed with error code: $status") } return result.toString() diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java index 58c7704b83e..8f4292c1bc8 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java @@ -93,7 +93,7 @@ public static SGD create(Map namedParameters, double learningRat */ public void step(Map namedGradients) { if (!mHybridData.isValid()) { - throw new IllegalStateException("SGD optimizer has been destroyed"); + throw new RuntimeException("Attempt to use a destroyed SGD optimizer"); } stepNative(namedGradients); } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java index ca4bac9aa54..4a6653cb7a1 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java @@ -8,11 +8,12 @@ package org.pytorch.executorch.training; +import android.util.Log; import com.facebook.jni.HybridData; import com.facebook.jni.annotations.DoNotStrip; import com.facebook.soloader.nativeloader.NativeLoader; import com.facebook.soloader.nativeloader.SystemDelegate; -import java.io.Closeable; +import java.util.HashMap; import java.util.Map; import org.pytorch.executorch.EValue; import org.pytorch.executorch.ExecuTorchRuntime; @@ -25,7 +26,7 @@ *

Warning: These APIs are experimental and subject to change without notice */ @Experimental -public class TrainingModule implements Closeable { +public class TrainingModule { static { if (!NativeLoader.isInitialized()) { @@ -36,7 +37,6 @@ public class TrainingModule implements Closeable { } private final HybridData mHybridData; - private boolean mDestroyed = false; @DoNotStrip private static native HybridData initHybrid(String moduleAbsolutePath, String dataAbsolutePath); @@ -45,10 +45,6 @@ private TrainingModule(String moduleAbsolutePath, String dataAbsolutePath) { mHybridData = initHybrid(moduleAbsolutePath, dataAbsolutePath); } - private void checkNotDestroyed() { - if (mDestroyed) throw new IllegalStateException("TrainingModule has been destroyed"); - } - /** * Loads a serialized ExecuTorch Training Module from the specified path on the disk. * @@ -82,7 +78,10 @@ public static TrainingModule load(final String modelPath) { * @return return value(s) from the method. */ public EValue[] executeForwardBackward(String methodName, EValue... inputs) { - checkNotDestroyed(); + if (!mHybridData.isValid()) { + Log.e("ExecuTorch", "Attempt to use a destroyed module"); + return new EValue[0]; + } return executeForwardBackwardNative(methodName, inputs); } @@ -90,7 +89,10 @@ public EValue[] executeForwardBackward(String methodName, EValue... inputs) { private native EValue[] executeForwardBackwardNative(String methodName, EValue... inputs); public Map namedParameters(String methodName) { - checkNotDestroyed(); + if (!mHybridData.isValid()) { + Log.e("ExecuTorch", "Attempt to use a destroyed module"); + return new HashMap(); + } return namedParametersNative(methodName); } @@ -98,17 +100,13 @@ public Map namedParameters(String methodName) { private native Map namedParametersNative(String methodName); public Map namedGradients(String methodName) { - checkNotDestroyed(); + if (!mHybridData.isValid()) { + Log.e("ExecuTorch", "Attempt to use a destroyed module"); + return new HashMap(); + } return namedGradientsNative(methodName); } @DoNotStrip private native Map namedGradientsNative(String methodName); - - @Override - public void close() { - if (mDestroyed) return; - mDestroyed = true; - mHybridData.resetNative(); - } } diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index 0cf08e41983..88e9f9e2a12 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -284,18 +284,8 @@ class ExecuTorchJni : public facebook::jni::HybridClass { #else auto etdump_gen = nullptr; #endif - try { - module_ = std::make_unique( - modelPath->toStdString(), load_mode, std::move(etdump_gen)); - } catch (const std::exception& e) { - executorch::jni_helper::throwExecutorchException( - static_cast(Error::Internal), - std::string("Failed to create Module: ") + e.what()); - } catch (...) { - executorch::jni_helper::throwExecutorchException( - static_cast(Error::Internal), - "Failed to create Module: unknown native error"); - } + module_ = std::make_unique( + modelPath->toStdString(), load_mode, std::move(etdump_gen)); #ifdef ET_USE_THREADPOOL // Default to using cores/2 threadpool threads. The long-term plan is to diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index 0c1ff5c67b9..2c0117dc576 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -10,7 +10,6 @@ #include #include #include -#include #include #include #include @@ -149,117 +148,103 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { jint num_bos = 0, jint num_eos = 0, jint load_mode = 1) { - try { - temperature_ = temperature; - num_bos_ = num_bos; - num_eos_ = num_eos; + temperature_ = temperature; + num_bos_ = num_bos; + num_eos_ = num_eos; #if defined(ET_USE_THREADPOOL) - // Reserve 1 thread for the main thread. - int32_t num_performant_cores = - ::executorch::extension::cpuinfo::get_num_performant_cores() - 1; - if (num_performant_cores > 0) { - ET_LOG( - Info, "Resetting threadpool to %d threads", num_performant_cores); - ::executorch::extension::threadpool::get_threadpool() - ->_unsafe_reset_threadpool(num_performant_cores); - } + // Reserve 1 thread for the main thread. + int32_t num_performant_cores = + ::executorch::extension::cpuinfo::get_num_performant_cores() - 1; + if (num_performant_cores > 0) { + ET_LOG(Info, "Resetting threadpool to %d threads", num_performant_cores); + ::executorch::extension::threadpool::get_threadpool() + ->_unsafe_reset_threadpool(num_performant_cores); + } #endif - model_type_category_ = model_type_category; - auto cpp_load_mode = load_mode_from_int(load_mode); - std::vector data_files_vector; - if (model_type_category == MODEL_TYPE_CATEGORY_MULTIMODAL) { - runner_ = llm::create_multimodal_runner( - model_path->toStdString().c_str(), - llm::load_tokenizer(tokenizer_path->toStdString()), - std::nullopt, - cpp_load_mode); - } else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) { - if (data_files != nullptr) { - // Convert Java List to C++ std::vector - auto list_class = facebook::jni::findClassStatic("java/util/List"); - auto size_method = list_class->getMethod("size"); - auto get_method = - list_class->getMethod(jint)>( - "get"); - - jint size = size_method(data_files); - for (jint i = 0; i < size; ++i) { - auto str_obj = get_method(data_files, i); - auto jstr = facebook::jni::static_ref_cast(str_obj); - data_files_vector.push_back(jstr->toStdString()); - } + model_type_category_ = model_type_category; + auto cpp_load_mode = load_mode_from_int(load_mode); + std::vector data_files_vector; + if (model_type_category == MODEL_TYPE_CATEGORY_MULTIMODAL) { + runner_ = llm::create_multimodal_runner( + model_path->toStdString().c_str(), + llm::load_tokenizer(tokenizer_path->toStdString()), + std::nullopt, + cpp_load_mode); + } else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) { + if (data_files != nullptr) { + // Convert Java List to C++ std::vector + auto list_class = facebook::jni::findClassStatic("java/util/List"); + auto size_method = list_class->getMethod("size"); + auto get_method = + list_class->getMethod(jint)>( + "get"); + + jint size = size_method(data_files); + for (jint i = 0; i < size; ++i) { + auto str_obj = get_method(data_files, i); + auto jstr = facebook::jni::static_ref_cast(str_obj); + data_files_vector.push_back(jstr->toStdString()); } - runner_ = executorch::extension::llm::create_text_llm_runner( - model_path->toStdString(), - llm::load_tokenizer(tokenizer_path->toStdString()), - data_files_vector, - /*temperature=*/-1.0f, - /*event_tracer=*/nullptr, - /*method_name=*/"forward", - cpp_load_mode); + } + runner_ = executorch::extension::llm::create_text_llm_runner( + model_path->toStdString(), + llm::load_tokenizer(tokenizer_path->toStdString()), + data_files_vector, + /*temperature=*/-1.0f, + /*event_tracer=*/nullptr, + /*method_name=*/"forward", + cpp_load_mode); #if defined(EXECUTORCH_BUILD_QNN) - } else if (model_type_category == MODEL_TYPE_QNN_LLAMA) { - std::unique_ptr module = - std::make_unique( - model_path->toStdString().c_str(), - data_files_vector, - cpp_load_mode); - std::string decoder_model = "llama3"; // use llama3 for now - // Using 8bit as default since this meta is introduced with 16bit kv io - // support and older models only have 8bit kv io. - example::KvBitWidth kv_bitwidth = example::KvBitWidth::kWidth8; - if (module->method_names()->count("get_kv_io_bit_width") > 0) { - kv_bitwidth = static_cast( - module->get("get_kv_io_bit_width") - .get() - .toScalar() - .to()); - } - - if (kv_bitwidth == example::KvBitWidth::kWidth8) { - runner_ = std::make_unique>( - std::move(module), - decoder_model.c_str(), - model_path->toStdString().c_str(), - tokenizer_path->toStdString().c_str(), - "", - "", - temperature_); - } else if (kv_bitwidth == example::KvBitWidth::kWidth16) { - runner_ = std::make_unique>( - std::move(module), - decoder_model.c_str(), + } else if (model_type_category == MODEL_TYPE_QNN_LLAMA) { + std::unique_ptr module = + std::make_unique( model_path->toStdString().c_str(), - tokenizer_path->toStdString().c_str(), - "", - "", - temperature_); - } else { - ET_CHECK_MSG( - false, - "Unsupported kv bitwidth: %ld", - static_cast(kv_bitwidth)); - } - model_type_category_ = MODEL_TYPE_CATEGORY_LLM; + data_files_vector, + cpp_load_mode); + std::string decoder_model = "llama3"; // use llama3 for now + // Using 8bit as default since this meta is introduced with 16bit kv io + // support and older models only have 8bit kv io. + example::KvBitWidth kv_bitwidth = example::KvBitWidth::kWidth8; + if (module->method_names()->count("get_kv_io_bit_width") > 0) { + kv_bitwidth = static_cast( + module->get("get_kv_io_bit_width").get().toScalar().to()); + } + + if (kv_bitwidth == example::KvBitWidth::kWidth8) { + runner_ = std::make_unique>( + std::move(module), + decoder_model.c_str(), + model_path->toStdString().c_str(), + tokenizer_path->toStdString().c_str(), + "", + "", + temperature_); + } else if (kv_bitwidth == example::KvBitWidth::kWidth16) { + runner_ = std::make_unique>( + std::move(module), + decoder_model.c_str(), + model_path->toStdString().c_str(), + tokenizer_path->toStdString().c_str(), + "", + "", + temperature_); + } else { + ET_CHECK_MSG( + false, + "Unsupported kv bitwidth: %ld", + static_cast(kv_bitwidth)); + } + model_type_category_ = MODEL_TYPE_CATEGORY_LLM; #endif #if defined(EXECUTORCH_BUILD_MEDIATEK) - } else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) { - runner_ = std::make_unique( - model_path->toStdString().c_str(), - tokenizer_path->toStdString().c_str()); - // Interpret the model type as LLM - model_type_category_ = MODEL_TYPE_CATEGORY_LLM; + } else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) { + runner_ = std::make_unique( + model_path->toStdString().c_str(), + tokenizer_path->toStdString().c_str()); + // Interpret the model type as LLM + model_type_category_ = MODEL_TYPE_CATEGORY_LLM; #endif - } - } catch (const std::exception& e) { - executorch::jni_helper::throwExecutorchException( - static_cast(Error::Internal), - std::string("Failed to create LlmModule: ") + e.what()); - } catch (...) { - executorch::jni_helper::throwExecutorchException( - static_cast(Error::Internal), - "Failed to create LlmModule: unknown native error"); } } @@ -609,19 +594,21 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { jint load() { if (!runner_) { - std::stringstream ss; - ss << "Model runner was not created. model_type_category=" - << model_type_category_ - << ". Valid values: " << MODEL_TYPE_CATEGORY_LLM << " (LLM), " - << MODEL_TYPE_CATEGORY_MULTIMODAL << " (Multimodal)"; - executorch::jni_helper::throwExecutorchException( - static_cast(Error::InvalidState), ss.str().c_str()); + ET_LOG( + Error, + "ExecuTorchLlmJni::load() called but runner_ is null. " + "The model runner was not created or failed to initialize due to a " + "previous configuration or initialization error. " + "Model type category: %d.", + model_type_category_); return static_cast(Error::InvalidState); } const auto load_result = static_cast(runner_->load()); if (load_result != static_cast(Error::Ok)) { - executorch::jni_helper::throwExecutorchException( - static_cast(load_result), "Failed to load model runner"); + ET_LOG( + Error, + "ExecuTorchLlmJni::load() failed in runner_->load() with error code %d.", + static_cast(load_result)); } return load_result; }