diff --git a/compiler/src/main/java/run/endive/compiler/internal/WasmAnalyzer.java b/compiler/src/main/java/run/endive/compiler/internal/WasmAnalyzer.java index cf13e2af..b31968e3 100644 --- a/compiler/src/main/java/run/endive/compiler/internal/WasmAnalyzer.java +++ b/compiler/src/main/java/run/endive/compiler/internal/WasmAnalyzer.java @@ -168,6 +168,8 @@ public static class TryCatchBlock { // types of values below the try scope that need to be saved/restored int savedStackSlotBase; List savedStackTypes; + // per-catch DROP_KEEP applied before branching to the catch target (null if none) + CompilerInstruction[] catchUnwinds; // set to true when TRY_CATCH_BLOCK was emitted (i.e., block is reachable) boolean registered; @@ -428,6 +430,22 @@ public AnalysisResult analyze(int funcId) { tryCatchBlock.savedStackSlotBase = saveSlotBase; tryCatchBlock.savedStackTypes = new ArrayList<>(belowTypes); + // A catch branches like a br, so precompute the unwind that + // drops values below each catch target (see catchUnwind). + var catches = ins.catches(); + tryCatchBlock.catchUnwinds = new CompilerInstruction[catches.size()]; + for (int ci = 0; ci < catches.size(); ci++) { + tryCatchBlock.catchUnwinds[ci] = + catchUnwind( + functionType, + body, + ins, + catches.get(ci).resolvedLabel(), + belowCount, + tryCatchBlock.savedStackTypes, + stack); + } + // operands: [saveSlotBase, belowCount, type_ids...] long[] saveOperands = new long[2 + allTypes.size()]; saveOperands[0] = saveSlotBase; @@ -939,6 +957,10 @@ private static void analyzeTryCatchEnd( result.add(new CompilerInstruction(CompilerOpCode.CATCH_REGISTER_EXCEPTION)); break; } + // unwind values below the catch target before branching to it + if (tryCatchBlock.catchUnwinds != null && tryCatchBlock.catchUnwinds[i] != null) { + result.add(tryCatchBlock.catchUnwinds[i]); + } result.add( new CompilerInstruction(CompilerOpCode.GOTO, catchCondition.resolvedLabel())); result.add(new CompilerInstruction(CompilerOpCode.LABEL, afterCatchLabel)); @@ -1723,6 +1745,67 @@ private Optional unwindStack( new CompilerInstruction(CompilerOpCode.DROP_KEEP, operands.build().toArray())); } + /** + * The DROP_KEEP a try_table catch handler applies before branching to {@code label}: like a + * {@code br}, it drops values below the target scope (see {@link #unwindStack}). {@code + * savedStackTypes} are the {@code belowCount} restored below-try values, bottom-to-top. + */ + private CompilerInstruction catchUnwind( + FunctionType functionType, + FunctionBody body, + AnnotatedInstruction tryIns, + int label, + int belowCount, + List savedStackTypes, + TypeStack stack) { + + boolean forward = true; + + var target = body.instructions().get(label); + if (target.address() <= tryIns.address()) { + target = body.instructions().get(label - 1); + forward = false; + } + var scope = target.scope(); + + FunctionType blockType; + if (scope.opcode() == OpCode.END) { + scope = FUNCTION_SCOPE; + blockType = functionType; + } else { + blockType = blockType(scope); + } + + var keepTypes = forward ? blockType.returns() : blockType.params(); + int keep = keepTypes.size(); + + var scopeSize = stack.scopeStackSize(scope); + if (scopeSize == null) { + return null; + } + + // reconstructed stack is [belowCount below-try values, keep caught values]; + // drop down to the target scope, like unwindStack + int drop = (belowCount + keep) - scopeSize; + if (forward) { + drop -= keep; + } + if (drop <= 0) { + return null; + } + + // operands: [drop, drop_types..., keep_types...] (dropped = top `drop` of below-try) + var operands = LongStream.builder(); + operands.add(drop); + for (ValType t : savedStackTypes.subList(belowCount - drop, belowCount)) { + operands.add(t.id()); + } + for (ValType t : keepTypes) { + operands.add(t.id()); + } + return new CompilerInstruction(CompilerOpCode.DROP_KEEP, operands.build().toArray()); + } + private FunctionType blockType(Instruction ins) { var typeId = ins.operand(0); if (typeId == 0x40) { diff --git a/machine-tests/src/test/java/run/endive/testing/TrySaveStackTest.java b/machine-tests/src/test/java/run/endive/testing/TrySaveStackTest.java index 2a2d7e58..2fca6e83 100644 --- a/machine-tests/src/test/java/run/endive/testing/TrySaveStackTest.java +++ b/machine-tests/src/test/java/run/endive/testing/TrySaveStackTest.java @@ -53,4 +53,35 @@ public void nestedTryValues(Function machine var instance = machineInject.apply(Instance.builder(MODULE)).build(); assertEquals(6, instance.export("nested-try-values").apply()[0]); } + + @ParameterizedTest + @MethodSource("machineImplementations") + public void catchBackwardToLoop(Function machineInject) { + var instance = machineInject.apply(Instance.builder(MODULE)).build(); + assertEquals(42, instance.export("catch-backward-to-loop").apply()[0]); + } + + @ParameterizedTest + @MethodSource("machineImplementations") + public void catchBackwardToLoopDrop( + Function machineInject) { + var instance = machineInject.apply(Instance.builder(MODULE)).build(); + assertEquals(42, instance.export("catch-backward-to-loop-drop").apply()[0]); + } + + @ParameterizedTest + @MethodSource("machineImplementations") + public void catchDropsValueAboveTarget( + Function machineInject) { + var instance = machineInject.apply(Instance.builder(MODULE)).build(); + assertEquals(7, instance.export("catch-drops-value-above-target").apply()[0]); + } + + @ParameterizedTest + @MethodSource("machineImplementations") + public void catchKeepsValueBelowTarget( + Function machineInject) { + var instance = machineInject.apply(Instance.builder(MODULE)).build(); + assertEquals(1007, instance.export("catch-keeps-value-below-target").apply()[0]); + } } diff --git a/wasm-corpus/src/main/resources/compiled/try_save_stack.wat.wasm b/wasm-corpus/src/main/resources/compiled/try_save_stack.wat.wasm index 018fa702..9a384626 100644 Binary files a/wasm-corpus/src/main/resources/compiled/try_save_stack.wat.wasm and b/wasm-corpus/src/main/resources/compiled/try_save_stack.wat.wasm differ diff --git a/wasm-corpus/src/main/resources/wat/try_save_stack.wat b/wasm-corpus/src/main/resources/wat/try_save_stack.wat index 6abbea2b..0c3cbbb4 100644 --- a/wasm-corpus/src/main/resources/wat/try_save_stack.wat +++ b/wasm-corpus/src/main/resources/wat/try_save_stack.wat @@ -35,6 +35,77 @@ (i32.add) ;; 100 + 205 = 305 ) + ;; Catch branches to an outer label, dropping a value below the try but above the target. + (func (export "catch-drops-value-above-target") (result i32) + (block $t (result i32) + (i32.const 99) ;; above $t -> dropped when the catch branches + (try_table (catch $e $t) + (call $do_throw (i32.const 7)) + ) + (unreachable) + ) ;; $t yields the caught 7 + ) + + ;; Like above, but a value below the target label survives the unwind. + (func (export "catch-keeps-value-below-target") (result i32) + (i32.const 1000) ;; below $t -> survives + (block $t (result i32) + (i32.const 99) ;; above $t -> dropped + (try_table (catch $e $t) + (call $do_throw (i32.const 7)) + ) + (unreachable) + ) + (i32.add) ;; 1000 + 7 = 1007 + ) + + ;; Catch branches backward to a loop. + (func (export "catch-backward-to-loop") (result i32) + (local $count i32) + (local $result i32) + (i32.const 0) + (loop $L (param i32) (result i32) + (local.set $result) + (local.get $count) + (if (then + (local.get $result) + (return) + )) + (local.get $count) + (i32.const 1) + (i32.add) + (local.set $count) + (try_table (catch $e $L) + (call $do_throw (i32.const 42)) + ) + (unreachable) + ) + ) + + ;; Catch branches backward to a loop, dropping an intermediate value. + (func (export "catch-backward-to-loop-drop") (result i32) + (local $count i32) + (local $result i32) + (i32.const 0) + (loop $L (param i32) (result i32) + (local.set $result) + (local.get $count) + (if (then + (local.get $result) + (return) + )) + (local.get $count) + (i32.const 1) + (i32.add) + (local.set $count) + (i32.const 999) ;; above $L -> must be dropped by catch unwind + (try_table (catch $e $L) + (call $do_throw (i32.const 42)) + ) + (unreachable) + ) + ) + ;; Nested try_table with values below both levels (func (export "nested-try-values") (result i32) (i32.const 1) ;; below outer try