Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions mypyc/irbuild/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,20 +108,25 @@ def visit_value_pattern(self, pattern: ValuePattern) -> None:
self.builder.add_bool_branch(cond, self.code_block, self.next_block)

def visit_or_pattern(self, pattern: OrPattern) -> None:
backup_block = self.next_block
self.next_block = BasicBlock()
code_block = self.code_block
next_block = self.next_block

for p in pattern.patterns:
self.code_block = BasicBlock()
self.next_block = BasicBlock()

# Hack to ensure the as pattern is bound to each pattern in the
# "or" pattern, but not every subpattern
backup = self.as_pattern
p.accept(self)
self.as_pattern = backup

self.builder.activate_block(self.code_block)
self.builder.goto(code_block)
self.builder.activate_block(self.next_block)
self.next_block = BasicBlock()

self.next_block = backup_block
self.code_block = code_block
self.next_block = next_block
self.builder.goto(self.next_block)

def visit_class_pattern(self, pattern: ClassPattern) -> None:
Expand Down
133 changes: 81 additions & 52 deletions mypyc/test-data/irbuild-match.test
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,17 @@ def f():
r8, r9 :: object
L0:
r0 = int_eq 246, 246
if r0 goto L3 else goto L1 :: bool
if r0 goto L1 else goto L2 :: bool
L1:
r1 = int_eq 246, 912
if r1 goto L3 else goto L2 :: bool
goto L5
L2:
goto L4
r1 = int_eq 246, 912
if r1 goto L3 else goto L4 :: bool
L3:
goto L5
L4:
goto L6
L5:
r2 = 'matched'
r3 = builtins :: module
r4 = 'print'
Expand All @@ -63,9 +67,9 @@ L3:
r7 = load_address r6
r8 = PyObject_Vectorcall(r5, r7, 1, 0)
keep_alive r2
goto L5
L4:
L5:
goto L7
L6:
L7:
r9 = box(None, 1)
return r9

Expand All @@ -86,19 +90,27 @@ def f():
r10, r11 :: object
L0:
r0 = int_eq 2, 2
if r0 goto L5 else goto L1 :: bool
if r0 goto L1 else goto L2 :: bool
L1:
r1 = int_eq 2, 4
if r1 goto L5 else goto L2 :: bool
goto L9
L2:
r2 = int_eq 2, 6
if r2 goto L5 else goto L3 :: bool
r1 = int_eq 2, 4
if r1 goto L3 else goto L4 :: bool
L3:
r3 = int_eq 2, 8
if r3 goto L5 else goto L4 :: bool
goto L9
L4:
goto L6
r2 = int_eq 2, 6
if r2 goto L5 else goto L6 :: bool
L5:
goto L9
L6:
r3 = int_eq 2, 8
if r3 goto L7 else goto L8 :: bool
L7:
goto L9
L8:
goto L10
L9:
r4 = 'matched'
r5 = builtins :: module
r6 = 'print'
Expand All @@ -107,9 +119,9 @@ L5:
r9 = load_address r8
r10 = PyObject_Vectorcall(r7, r9, 1, 0)
keep_alive r4
goto L7
L6:
L7:
goto L11
L10:
L11:
r11 = box(None, 1)
return r11

Expand Down Expand Up @@ -280,16 +292,20 @@ L1:
r6 = load_address r5
r7 = PyObject_Vectorcall(r4, r6, 1, 0)
keep_alive r1
goto L9
goto L11
L2:
r8 = int_eq 246, 4
if r8 goto L5 else goto L3 :: bool
if r8 goto L3 else goto L4 :: bool
L3:
r9 = int_eq 246, 6
if r9 goto L5 else goto L4 :: bool
goto L7
L4:
goto L6
r9 = int_eq 246, 6
if r9 goto L5 else goto L6 :: bool
L5:
goto L7
L6:
goto L8
L7:
r10 = 'here 2 | 3'
r11 = builtins :: module
r12 = 'print'
Expand All @@ -298,11 +314,11 @@ L5:
r15 = load_address r14
r16 = PyObject_Vectorcall(r13, r15, 1, 0)
keep_alive r10
goto L9
L6:
goto L11
L8:
r17 = int_eq 246, 246
if r17 goto L7 else goto L8 :: bool
L7:
if r17 goto L9 else goto L10 :: bool
L9:
r18 = 'here 123'
r19 = builtins :: module
r20 = 'print'
Expand All @@ -311,9 +327,9 @@ L7:
r23 = load_address r22
r24 = PyObject_Vectorcall(r21, r23, 1, 0)
keep_alive r18
goto L9
L8:
L9:
goto L11
L10:
L11:
r25 = box(None, 1)
return r25

Expand Down Expand Up @@ -456,15 +472,19 @@ def f():
r10, r11 :: object
L0:
r0 = int_eq 2, 2
if r0 goto L3 else goto L1 :: bool
if r0 goto L1 else goto L2 :: bool
L1:
goto L5
L2:
r1 = load_address PyLong_Type
r2 = object 1
r3 = CPy_TypeCheck(r2, r1)
if r3 goto L3 else goto L2 :: bool
L2:
goto L4
if r3 goto L3 else goto L4 :: bool
L3:
goto L5
L4:
goto L6
L5:
r4 = 'matched'
r5 = builtins :: module
r6 = 'print'
Expand All @@ -473,9 +493,9 @@ L3:
r9 = load_address r8
r10 = PyObject_Vectorcall(r7, r9, 1, 0)
keep_alive r4
goto L5
L4:
L5:
goto L7
L6:
L7:
r11 = box(None, 1)
return r11

Expand Down Expand Up @@ -532,25 +552,29 @@ L0:
r0 = int_eq 2, 2
r1 = object 1
x = r1
if r0 goto L3 else goto L1 :: bool
if r0 goto L1 else goto L2 :: bool
L1:
goto L5
L2:
r2 = int_eq 2, 4
r3 = object 2
x = r3
if r2 goto L3 else goto L2 :: bool
L2:
goto L4
if r2 goto L3 else goto L4 :: bool
L3:
goto L5
L4:
goto L6
L5:
r4 = builtins :: module
r5 = 'print'
r6 = CPyObject_GetAttr(r4, r5)
r7 = [x]
r8 = load_address r7
r9 = PyObject_Vectorcall(r6, r8, 1, 0)
keep_alive x
goto L5
L4:
L5:
goto L7
L6:
L7:
r10 = box(None, 1)
return r10

Expand Down Expand Up @@ -809,7 +833,7 @@ L0:
r1 = PyObject_IsInstance(x, r0)
r2 = r1 >= 0 :: signed
r3 = truncate r1: i32 to builtins.bool
if r3 goto L1 else goto L5 :: bool
if r3 goto L1 else goto L7 :: bool
L1:
r4 = 'num'
r5 = CPyObject_GetAttr(x, r4)
Expand All @@ -818,17 +842,21 @@ L1:
r8 = PyObject_IsTrue(r7)
r9 = r8 >= 0 :: signed
r10 = truncate r8: i32 to builtins.bool
if r10 goto L4 else goto L2 :: bool
if r10 goto L2 else goto L3 :: bool
L2:
goto L6
L3:
r11 = object 2
r12 = PyObject_RichCompare(r5, r11, 2)
r13 = PyObject_IsTrue(r12)
r14 = r13 >= 0 :: signed
r15 = truncate r13: i32 to builtins.bool
if r15 goto L4 else goto L3 :: bool
L3:
goto L5
if r15 goto L4 else goto L5 :: bool
L4:
goto L6
L5:
goto L7
L6:
r16 = 'matched'
r17 = builtins :: module
r18 = 'print'
Expand All @@ -837,11 +865,12 @@ L4:
r21 = load_address r20
r22 = PyObject_Vectorcall(r19, r21, 1, 0)
keep_alive r16
goto L6
L5:
L6:
goto L8
L7:
L8:
r23 = box(None, 1)
return r23

[case testAsPatternDoesntBleedIntoSubPatterns_python3_10]
class C:
__match_args__ = ("a", "b")
Expand Down
21 changes: 21 additions & 0 deletions mypyc/test-data/run-match.test
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,27 @@ test 21 ('')
test 21 (' as well')
test sequence final
test final

[case testMatchOrSequencePattern_python3_10]
def f(x: tuple[str, str]) -> str:
match x:
case ("X", "Y") | ("X", "Z"):
return "THERE"
case _:
return "OTHER"

[file driver.py]
from native import f

print(f(("X", "Y")))
print(f(("X", "Z")))
print(f(("X", "A")))

[out]
THERE
THERE
OTHER

[case testCustomMappingAndSequenceObjects_python3_10]
def f(x: object) -> None:
match x:
Expand Down
Loading