gh-110686: Test pattern matching with runtime_checkable
protocols (#110687)
This commit is contained in:
parent
7595d47722
commit
9d02d3451a
@ -2760,6 +2760,132 @@ class TestPatma(unittest.TestCase):
|
||||
self.assertEqual(y, 1)
|
||||
self.assertIs(z, x)
|
||||
|
||||
def test_patma_runtime_checkable_protocol(self):
|
||||
# Runtime-checkable protocol
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
@runtime_checkable
|
||||
class P(Protocol):
|
||||
x: int
|
||||
y: int
|
||||
|
||||
class A:
|
||||
def __init__(self, x: int, y: int):
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
class B(A): ...
|
||||
|
||||
for cls in (A, B):
|
||||
with self.subTest(cls=cls.__name__):
|
||||
inst = cls(1, 2)
|
||||
w = 0
|
||||
match inst:
|
||||
case P() as p:
|
||||
self.assertIsInstance(p, cls)
|
||||
self.assertEqual(p.x, 1)
|
||||
self.assertEqual(p.y, 2)
|
||||
w = 1
|
||||
self.assertEqual(w, 1)
|
||||
|
||||
q = 0
|
||||
match inst:
|
||||
case P(x=x, y=y):
|
||||
self.assertEqual(x, 1)
|
||||
self.assertEqual(y, 2)
|
||||
q = 1
|
||||
self.assertEqual(q, 1)
|
||||
|
||||
|
||||
def test_patma_generic_protocol(self):
|
||||
# Runtime-checkable generic protocol
|
||||
from typing import Generic, TypeVar, Protocol, runtime_checkable
|
||||
|
||||
T = TypeVar('T') # not using PEP695 to be able to backport changes
|
||||
|
||||
@runtime_checkable
|
||||
class P(Protocol[T]):
|
||||
a: T
|
||||
b: T
|
||||
|
||||
class A:
|
||||
def __init__(self, x: int, y: int):
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
class G(Generic[T]):
|
||||
def __init__(self, x: T, y: T):
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
for cls in (A, G):
|
||||
with self.subTest(cls=cls.__name__):
|
||||
inst = cls(1, 2)
|
||||
w = 0
|
||||
match inst:
|
||||
case P():
|
||||
w = 1
|
||||
self.assertEqual(w, 0)
|
||||
|
||||
def test_patma_protocol_with_match_args(self):
|
||||
# Runtime-checkable protocol with `__match_args__`
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
# Used to fail before
|
||||
# https://github.com/python/cpython/issues/110682
|
||||
@runtime_checkable
|
||||
class P(Protocol):
|
||||
__match_args__ = ('x', 'y')
|
||||
x: int
|
||||
y: int
|
||||
|
||||
class A:
|
||||
def __init__(self, x: int, y: int):
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
class B(A): ...
|
||||
|
||||
for cls in (A, B):
|
||||
with self.subTest(cls=cls.__name__):
|
||||
inst = cls(1, 2)
|
||||
w = 0
|
||||
match inst:
|
||||
case P() as p:
|
||||
self.assertIsInstance(p, cls)
|
||||
self.assertEqual(p.x, 1)
|
||||
self.assertEqual(p.y, 2)
|
||||
w = 1
|
||||
self.assertEqual(w, 1)
|
||||
|
||||
q = 0
|
||||
match inst:
|
||||
case P(x=x, y=y):
|
||||
self.assertEqual(x, 1)
|
||||
self.assertEqual(y, 2)
|
||||
q = 1
|
||||
self.assertEqual(q, 1)
|
||||
|
||||
j = 0
|
||||
match inst:
|
||||
case P(x=1, y=2):
|
||||
j = 1
|
||||
self.assertEqual(j, 1)
|
||||
|
||||
g = 0
|
||||
match inst:
|
||||
case P(x, y):
|
||||
self.assertEqual(x, 1)
|
||||
self.assertEqual(y, 2)
|
||||
g = 1
|
||||
self.assertEqual(g, 1)
|
||||
|
||||
h = 0
|
||||
match inst:
|
||||
case P(1, 2):
|
||||
h = 1
|
||||
self.assertEqual(h, 1)
|
||||
|
||||
|
||||
class TestSyntaxErrors(unittest.TestCase):
|
||||
|
||||
@ -3198,6 +3324,35 @@ class TestTypeErrors(unittest.TestCase):
|
||||
w = 0
|
||||
self.assertIsNone(w)
|
||||
|
||||
def test_regular_protocol(self):
|
||||
from typing import Protocol
|
||||
class P(Protocol): ...
|
||||
msg = (
|
||||
'Instance and class checks can only be used '
|
||||
'with @runtime_checkable protocols'
|
||||
)
|
||||
w = None
|
||||
with self.assertRaisesRegex(TypeError, msg):
|
||||
match 1:
|
||||
case P():
|
||||
w = 0
|
||||
self.assertIsNone(w)
|
||||
|
||||
def test_positional_patterns_with_regular_protocol(self):
|
||||
from typing import Protocol
|
||||
class P(Protocol):
|
||||
x: int # no `__match_args__`
|
||||
y: int
|
||||
class A:
|
||||
x = 1
|
||||
y = 2
|
||||
w = None
|
||||
with self.assertRaises(TypeError):
|
||||
match A():
|
||||
case P(x, y):
|
||||
w = 0
|
||||
self.assertIsNone(w)
|
||||
|
||||
|
||||
class TestValueErrors(unittest.TestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user