Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import java.util.concurrent.atomic.AtomicInteger
import org.apache.commons.io.FileUtils
import org.junit.Assert
import org.junit.Before
import org.junit.Ignore
import org.junit.Test
import org.junit.runner.RunWith
import org.pytorch.executorch.TestFileUtils.getTestFilePath
Expand All @@ -39,49 +40,48 @@ class ModuleInstrumentationTest {
inputStream.close()
}

@Ignore(
"The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward "
)
@Test
@Throws(IOException::class, URISyntaxException::class)
fun testModuleLoadAndForward() {
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
try {
val results = module.forward(EValue.from(dummyInput()))
Assert.assertTrue(results[0].isTensor)
} finally {
module.destroy()
}

val results = module.forward()
Assert.assertTrue(results[0].isTensor)
}

@Test
@Throws(IOException::class, URISyntaxException::class)
fun testMethodMetadata() {
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
module.destroy()
}

@Ignore(
"The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward "
)
@Test
@Throws(IOException::class)
fun testModuleLoadMethodAndForward() {
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
try {
module.loadMethod(FORWARD_METHOD)

val results = module.forward(EValue.from(dummyInput()))
Assert.assertTrue(results[0].isTensor)
} finally {
module.destroy()
}
module.loadMethod(FORWARD_METHOD)

val results = module.forward()
Assert.assertTrue(results[0].isTensor)
}

@Ignore(
"The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward "
)
@Test
@Throws(IOException::class)
fun testModuleLoadForwardExplicit() {
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
try {
val results = module.execute(FORWARD_METHOD, EValue.from(dummyInput()))
Assert.assertTrue(results[0].isTensor)
} finally {
module.destroy()
}

val results = module.execute(FORWARD_METHOD)
Assert.assertTrue(results[0].isTensor)
}

@Test(expected = RuntimeException::class)
Expand All @@ -94,18 +94,15 @@ class ModuleInstrumentationTest {
@Throws(IOException::class)
fun testModuleLoadMethodNonExistantMethod() {
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
try {
val exception =
Assert.assertThrows(ExecutorchRuntimeException::class.java) {
module.loadMethod(NONE_METHOD)
}
Assert.assertEquals(
ExecutorchRuntimeException.INVALID_ARGUMENT,
exception.getErrorCode(),
)
} finally {
module.destroy()
}

val exception =
Assert.assertThrows(ExecutorchRuntimeException::class.java) {
module.loadMethod(NONE_METHOD)
}
Assert.assertEquals(
ExecutorchRuntimeException.INVALID_ARGUMENT,
exception.getErrorCode(),
)
}

@Test(expected = RuntimeException::class)
Expand Down Expand Up @@ -138,6 +135,9 @@ class ModuleInstrumentationTest {
Assert.assertThrows(IllegalStateException::class.java) { module.forward() }
}

@Ignore(
"The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward "
)
@Test
@Throws(InterruptedException::class, IOException::class)
fun testForwardFromMultipleThreads() {
Expand All @@ -151,7 +151,7 @@ class ModuleInstrumentationTest {
try {
latch.countDown()
latch.await(5000, TimeUnit.MILLISECONDS)
val results = module.forward(EValue.from(dummyInput()))
val results = module.forward()
Assert.assertTrue(results[0].isTensor)
completed.incrementAndGet()
} catch (_: InterruptedException) {}
Expand All @@ -168,7 +168,6 @@ class ModuleInstrumentationTest {
}

Assert.assertEquals(numThreads.toLong(), completed.get().toLong())
module.destroy()
}

companion object {
Expand All @@ -177,8 +176,5 @@ class ModuleInstrumentationTest {
private const val NON_PTE_FILE_NAME = "/test.txt"
private const val FORWARD_METHOD = "forward"
private const val NONE_METHOD = "none"
private val inputShape = longArrayOf(1, 3, 224, 224)

private fun dummyInput(): Tensor = Tensor.ones(inputShape, DType.FLOAT)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,22 +36,12 @@ public static ExecuTorchRuntime getRuntime() {
/**
* Validates that the given path points to a readable file.
*
* @throws IllegalArgumentException if the path is null, does not exist, is not a file, or is not
* readable.
* @throws RuntimeException if the file does not exist or is not readable.
*/
public static void validateFilePath(String path, String description) {
if (path == null) {
throw new IllegalArgumentException("Cannot load " + description + ": path is null");
}
File file = new File(path);
if (!file.exists()) {
throw new IllegalArgumentException("Cannot load " + description + "!! " + path);
}
if (!file.isFile()) {
throw new IllegalArgumentException("Cannot load " + description + "!! " + path);
}
if (!file.canRead()) {
throw new IllegalArgumentException("Cannot load " + description + "!! " + path);
if (!file.canRead() || !file.isFile()) {
throw new RuntimeException("Cannot load " + description + " " + path);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,11 +161,6 @@ public ExecutorchRuntimeException(int errorCode, String details) {
this.errorCode = errorCode;
}

public ExecutorchRuntimeException(int errorCode, String details, Throwable cause) {
super(ErrorHelper.formatMessage(errorCode, details), cause);
this.errorCode = errorCode;
}

/** Returns the numeric error code from {@code runtime/core/error.h}. */
public int getErrorCode() {
return errorCode;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import com.facebook.jni.annotations.DoNotStrip;
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;
import java.io.Closeable;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.locks.Lock;
Expand All @@ -25,7 +24,7 @@
* <p>Warning: These APIs are experimental and subject to change without notice
*/
@Experimental
public class Module implements Closeable {
public class Module {

static {
if (!NativeLoader.isInitialized()) {
Expand Down Expand Up @@ -275,19 +274,12 @@ public boolean etdump() {
public void destroy() {
if (mLock.tryLock()) {
try {
if (mHybridData.isValid()) {
mHybridData.resetNative();
}
mHybridData.resetNative();
} finally {
mLock.unlock();
}
} else {
throw new IllegalStateException("Cannot destroy module while method is executing");
}
}

@Override
public void close() {
destroy();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ package org.pytorch.executorch.extension.asr
import java.io.Closeable
import java.io.File
import java.util.concurrent.atomic.AtomicLong
import org.pytorch.executorch.ExecutorchRuntimeException
import org.pytorch.executorch.annotations.Experimental

/**
Expand Down Expand Up @@ -54,10 +53,7 @@ class AsrModule(

val handle = nativeCreate(modelPath, tokenizerPath, dataPath, preprocessorPath)
if (handle == 0L) {
throw ExecutorchRuntimeException(
ExecutorchRuntimeException.INTERNAL,
"Failed to create native AsrModule",
)
throw RuntimeException("Failed to create native AsrModule")
}
nativeHandle.set(handle)
}
Expand Down Expand Up @@ -133,7 +129,7 @@ class AsrModule(
* @param callback Optional callback to receive tokens as they are generated (can be null)
* @return The complete transcribed text
* @throws IllegalStateException if the module has been destroyed
* @throws ExecutorchRuntimeException if transcription fails (error code carried in exception)
* @throws RuntimeException if transcription fails (non-zero result code)
*/
@JvmOverloads
fun transcribe(
Expand Down Expand Up @@ -164,7 +160,7 @@ class AsrModule(
)

if (status != 0) {
throw ExecutorchRuntimeException(status, "Transcription failed")
throw RuntimeException("Transcription failed with error code: $status")
}

return result.toString()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ public static SGD create(Map<String, Tensor> namedParameters, double learningRat
*/
public void step(Map<String, Tensor> namedGradients) {
if (!mHybridData.isValid()) {
throw new IllegalStateException("SGD optimizer has been destroyed");
throw new RuntimeException("Attempt to use a destroyed SGD optimizer");
}
stepNative(namedGradients);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@

package org.pytorch.executorch.training;

import android.util.Log;
import com.facebook.jni.HybridData;
import com.facebook.jni.annotations.DoNotStrip;
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;
import java.io.Closeable;
import java.util.HashMap;
import java.util.Map;
import org.pytorch.executorch.EValue;
import org.pytorch.executorch.ExecuTorchRuntime;
Expand All @@ -25,7 +26,7 @@
* <p>Warning: These APIs are experimental and subject to change without notice
*/
@Experimental
public class TrainingModule implements Closeable {
public class TrainingModule {

static {
if (!NativeLoader.isInitialized()) {
Expand All @@ -36,7 +37,6 @@ public class TrainingModule implements Closeable {
}

private final HybridData mHybridData;
private boolean mDestroyed = false;

@DoNotStrip
private static native HybridData initHybrid(String moduleAbsolutePath, String dataAbsolutePath);
Expand All @@ -45,10 +45,6 @@ private TrainingModule(String moduleAbsolutePath, String dataAbsolutePath) {
mHybridData = initHybrid(moduleAbsolutePath, dataAbsolutePath);
}

private void checkNotDestroyed() {
if (mDestroyed) throw new IllegalStateException("TrainingModule has been destroyed");
}

/**
* Loads a serialized ExecuTorch Training Module from the specified path on the disk.
*
Expand Down Expand Up @@ -82,33 +78,35 @@ public static TrainingModule load(final String modelPath) {
* @return return value(s) from the method.
*/
public EValue[] executeForwardBackward(String methodName, EValue... inputs) {
checkNotDestroyed();
if (!mHybridData.isValid()) {
Log.e("ExecuTorch", "Attempt to use a destroyed module");
return new EValue[0];
}
return executeForwardBackwardNative(methodName, inputs);
}

@DoNotStrip
private native EValue[] executeForwardBackwardNative(String methodName, EValue... inputs);

public Map<String, Tensor> namedParameters(String methodName) {
checkNotDestroyed();
if (!mHybridData.isValid()) {
Log.e("ExecuTorch", "Attempt to use a destroyed module");
return new HashMap<String, Tensor>();
}
return namedParametersNative(methodName);
}

@DoNotStrip
private native Map<String, Tensor> namedParametersNative(String methodName);

public Map<String, Tensor> namedGradients(String methodName) {
checkNotDestroyed();
if (!mHybridData.isValid()) {
Log.e("ExecuTorch", "Attempt to use a destroyed module");
return new HashMap<String, Tensor>();
}
return namedGradientsNative(methodName);
}

@DoNotStrip
private native Map<String, Tensor> namedGradientsNative(String methodName);

@Override
public void close() {
if (mDestroyed) return;
mDestroyed = true;
mHybridData.resetNative();
}
}
14 changes: 2 additions & 12 deletions extension/android/jni/jni_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -284,18 +284,8 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
#else
auto etdump_gen = nullptr;
#endif
try {
module_ = std::make_unique<Module>(
modelPath->toStdString(), load_mode, std::move(etdump_gen));
} catch (const std::exception& e) {
executorch::jni_helper::throwExecutorchException(
static_cast<uint32_t>(Error::Internal),
std::string("Failed to create Module: ") + e.what());
} catch (...) {
executorch::jni_helper::throwExecutorchException(
static_cast<uint32_t>(Error::Internal),
"Failed to create Module: unknown native error");
}
module_ = std::make_unique<Module>(
modelPath->toStdString(), load_mode, std::move(etdump_gen));

#ifdef ET_USE_THREADPOOL
// Default to using cores/2 threadpool threads. The long-term plan is to
Expand Down
Loading
Loading