diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt index eb2b6f096a1..ba91f444287 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt @@ -17,6 +17,7 @@ import java.util.concurrent.atomic.AtomicInteger import org.apache.commons.io.FileUtils import org.junit.Assert import org.junit.Before +import org.junit.Ignore import org.junit.Test import org.junit.runner.RunWith import org.pytorch.executorch.TestFileUtils.getTestFilePath @@ -39,49 +40,48 @@ class ModuleInstrumentationTest { inputStream.close() } + @Ignore( + "The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward " + ) @Test @Throws(IOException::class, URISyntaxException::class) fun testModuleLoadAndForward() { val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - try { - val results = module.forward(EValue.from(dummyInput())) - Assert.assertTrue(results[0].isTensor) - } finally { - module.destroy() - } + + val results = module.forward() + Assert.assertTrue(results[0].isTensor) } @Test @Throws(IOException::class, URISyntaxException::class) fun testMethodMetadata() { val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - module.destroy() } + @Ignore( + "The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward " + ) @Test @Throws(IOException::class) fun testModuleLoadMethodAndForward() { val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - try { - module.loadMethod(FORWARD_METHOD) - val results = module.forward(EValue.from(dummyInput())) - Assert.assertTrue(results[0].isTensor) - } finally { - module.destroy() - } + module.loadMethod(FORWARD_METHOD) + + val results = module.forward() + Assert.assertTrue(results[0].isTensor) } + @Ignore( + "The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward " + ) @Test @Throws(IOException::class) fun testModuleLoadForwardExplicit() { val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - try { - val results = module.execute(FORWARD_METHOD, EValue.from(dummyInput())) - Assert.assertTrue(results[0].isTensor) - } finally { - module.destroy() - } + + val results = module.execute(FORWARD_METHOD) + Assert.assertTrue(results[0].isTensor) } @Test(expected = RuntimeException::class) @@ -94,18 +94,15 @@ class ModuleInstrumentationTest { @Throws(IOException::class) fun testModuleLoadMethodNonExistantMethod() { val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - try { - val exception = - Assert.assertThrows(ExecutorchRuntimeException::class.java) { - module.loadMethod(NONE_METHOD) - } - Assert.assertEquals( - ExecutorchRuntimeException.INVALID_ARGUMENT, - exception.getErrorCode(), - ) - } finally { - module.destroy() - } + + val exception = + Assert.assertThrows(ExecutorchRuntimeException::class.java) { + module.loadMethod(NONE_METHOD) + } + Assert.assertEquals( + ExecutorchRuntimeException.INVALID_ARGUMENT, + exception.getErrorCode(), + ) } @Test(expected = RuntimeException::class) @@ -138,6 +135,9 @@ class ModuleInstrumentationTest { Assert.assertThrows(IllegalStateException::class.java) { module.forward() } } + @Ignore( + "The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward " + ) @Test @Throws(InterruptedException::class, IOException::class) fun testForwardFromMultipleThreads() { @@ -151,7 +151,7 @@ class ModuleInstrumentationTest { try { latch.countDown() latch.await(5000, TimeUnit.MILLISECONDS) - val results = module.forward(EValue.from(dummyInput())) + val results = module.forward() Assert.assertTrue(results[0].isTensor) completed.incrementAndGet() } catch (_: InterruptedException) {} @@ -168,7 +168,6 @@ class ModuleInstrumentationTest { } Assert.assertEquals(numThreads.toLong(), completed.get().toLong()) - module.destroy() } companion object { @@ -177,8 +176,5 @@ class ModuleInstrumentationTest { private const val NON_PTE_FILE_NAME = "/test.txt" private const val FORWARD_METHOD = "forward" private const val NONE_METHOD = "none" - private val inputShape = longArrayOf(1, 3, 224, 224) - - private fun dummyInput(): Tensor = Tensor.ones(inputShape, DType.FLOAT) } } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java index 53ee4d3f33a..30ebf1a2c1d 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java @@ -36,22 +36,12 @@ public static ExecuTorchRuntime getRuntime() { /** * Validates that the given path points to a readable file. * - * @throws IllegalArgumentException if the path is null, does not exist, is not a file, or is not - * readable. + * @throws RuntimeException if the file does not exist or is not readable. */ public static void validateFilePath(String path, String description) { - if (path == null) { - throw new IllegalArgumentException("Cannot load " + description + ": path is null"); - } File file = new File(path); - if (!file.exists()) { - throw new IllegalArgumentException("Cannot load " + description + "!! " + path); - } - if (!file.isFile()) { - throw new IllegalArgumentException("Cannot load " + description + "!! " + path); - } - if (!file.canRead()) { - throw new IllegalArgumentException("Cannot load " + description + "!! " + path); + if (!file.canRead() || !file.isFile()) { + throw new RuntimeException("Cannot load " + description + " " + path); } } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java index e72ed9e3d28..e0fda73cc06 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java @@ -161,11 +161,6 @@ public ExecutorchRuntimeException(int errorCode, String details) { this.errorCode = errorCode; } - public ExecutorchRuntimeException(int errorCode, String details, Throwable cause) { - super(ErrorHelper.formatMessage(errorCode, details), cause); - this.errorCode = errorCode; - } - /** Returns the numeric error code from {@code runtime/core/error.h}. */ public int getErrorCode() { return errorCode; diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java index 6cf99966e6a..05e1e5b88cf 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java @@ -12,7 +12,6 @@ import com.facebook.jni.annotations.DoNotStrip; import com.facebook.soloader.nativeloader.NativeLoader; import com.facebook.soloader.nativeloader.SystemDelegate; -import java.io.Closeable; import java.util.HashMap; import java.util.Map; import java.util.concurrent.locks.Lock; @@ -25,7 +24,7 @@ *
Warning: These APIs are experimental and subject to change without notice
*/
@Experimental
-public class Module implements Closeable {
+public class Module {
static {
if (!NativeLoader.isInitialized()) {
@@ -275,9 +274,7 @@ public boolean etdump() {
public void destroy() {
if (mLock.tryLock()) {
try {
- if (mHybridData.isValid()) {
- mHybridData.resetNative();
- }
+ mHybridData.resetNative();
} finally {
mLock.unlock();
}
@@ -285,9 +282,4 @@ public void destroy() {
throw new IllegalStateException("Cannot destroy module while method is executing");
}
}
-
- @Override
- public void close() {
- destroy();
- }
}
diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/asr/AsrModule.kt b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/asr/AsrModule.kt
index ab9099ba405..987cb3ec3be 100644
--- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/asr/AsrModule.kt
+++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/asr/AsrModule.kt
@@ -11,7 +11,6 @@ package org.pytorch.executorch.extension.asr
import java.io.Closeable
import java.io.File
import java.util.concurrent.atomic.AtomicLong
-import org.pytorch.executorch.ExecutorchRuntimeException
import org.pytorch.executorch.annotations.Experimental
/**
@@ -54,10 +53,7 @@ class AsrModule(
val handle = nativeCreate(modelPath, tokenizerPath, dataPath, preprocessorPath)
if (handle == 0L) {
- throw ExecutorchRuntimeException(
- ExecutorchRuntimeException.INTERNAL,
- "Failed to create native AsrModule",
- )
+ throw RuntimeException("Failed to create native AsrModule")
}
nativeHandle.set(handle)
}
@@ -133,7 +129,7 @@ class AsrModule(
* @param callback Optional callback to receive tokens as they are generated (can be null)
* @return The complete transcribed text
* @throws IllegalStateException if the module has been destroyed
- * @throws ExecutorchRuntimeException if transcription fails (error code carried in exception)
+ * @throws RuntimeException if transcription fails (non-zero result code)
*/
@JvmOverloads
fun transcribe(
@@ -164,7 +160,7 @@ class AsrModule(
)
if (status != 0) {
- throw ExecutorchRuntimeException(status, "Transcription failed")
+ throw RuntimeException("Transcription failed with error code: $status")
}
return result.toString()
diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java
index 58c7704b83e..8f4292c1bc8 100644
--- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java
+++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java
@@ -93,7 +93,7 @@ public static SGD create(Map Warning: These APIs are experimental and subject to change without notice
*/
@Experimental
-public class TrainingModule implements Closeable {
+public class TrainingModule {
static {
if (!NativeLoader.isInitialized()) {
@@ -36,7 +37,6 @@ public class TrainingModule implements Closeable {
}
private final HybridData mHybridData;
- private boolean mDestroyed = false;
@DoNotStrip
private static native HybridData initHybrid(String moduleAbsolutePath, String dataAbsolutePath);
@@ -45,10 +45,6 @@ private TrainingModule(String moduleAbsolutePath, String dataAbsolutePath) {
mHybridData = initHybrid(moduleAbsolutePath, dataAbsolutePath);
}
- private void checkNotDestroyed() {
- if (mDestroyed) throw new IllegalStateException("TrainingModule has been destroyed");
- }
-
/**
* Loads a serialized ExecuTorch Training Module from the specified path on the disk.
*
@@ -82,7 +78,10 @@ public static TrainingModule load(final String modelPath) {
* @return return value(s) from the method.
*/
public EValue[] executeForwardBackward(String methodName, EValue... inputs) {
- checkNotDestroyed();
+ if (!mHybridData.isValid()) {
+ Log.e("ExecuTorch", "Attempt to use a destroyed module");
+ return new EValue[0];
+ }
return executeForwardBackwardNative(methodName, inputs);
}
@@ -90,7 +89,10 @@ public EValue[] executeForwardBackward(String methodName, EValue... inputs) {
private native EValue[] executeForwardBackwardNative(String methodName, EValue... inputs);
public Map