gh-110686: Test pattern matching with runtime_checkable
protocols (#110687)
This commit is contained in:
parent
7595d47722
commit
9d02d3451a
@ -2760,6 +2760,132 @@ class TestPatma(unittest.TestCase):
|
|||||||
self.assertEqual(y, 1)
|
self.assertEqual(y, 1)
|
||||||
self.assertIs(z, x)
|
self.assertIs(z, x)
|
||||||
|
|
||||||
|
def test_patma_runtime_checkable_protocol(self):
|
||||||
|
# Runtime-checkable protocol
|
||||||
|
from typing import Protocol, runtime_checkable
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class P(Protocol):
|
||||||
|
x: int
|
||||||
|
y: int
|
||||||
|
|
||||||
|
class A:
|
||||||
|
def __init__(self, x: int, y: int):
|
||||||
|
self.x = x
|
||||||
|
self.y = y
|
||||||
|
|
||||||
|
class B(A): ...
|
||||||
|
|
||||||
|
for cls in (A, B):
|
||||||
|
with self.subTest(cls=cls.__name__):
|
||||||
|
inst = cls(1, 2)
|
||||||
|
w = 0
|
||||||
|
match inst:
|
||||||
|
case P() as p:
|
||||||
|
self.assertIsInstance(p, cls)
|
||||||
|
self.assertEqual(p.x, 1)
|
||||||
|
self.assertEqual(p.y, 2)
|
||||||
|
w = 1
|
||||||
|
self.assertEqual(w, 1)
|
||||||
|
|
||||||
|
q = 0
|
||||||
|
match inst:
|
||||||
|
case P(x=x, y=y):
|
||||||
|
self.assertEqual(x, 1)
|
||||||
|
self.assertEqual(y, 2)
|
||||||
|
q = 1
|
||||||
|
self.assertEqual(q, 1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_patma_generic_protocol(self):
|
||||||
|
# Runtime-checkable generic protocol
|
||||||
|
from typing import Generic, TypeVar, Protocol, runtime_checkable
|
||||||
|
|
||||||
|
T = TypeVar('T') # not using PEP695 to be able to backport changes
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class P(Protocol[T]):
|
||||||
|
a: T
|
||||||
|
b: T
|
||||||
|
|
||||||
|
class A:
|
||||||
|
def __init__(self, x: int, y: int):
|
||||||
|
self.x = x
|
||||||
|
self.y = y
|
||||||
|
|
||||||
|
class G(Generic[T]):
|
||||||
|
def __init__(self, x: T, y: T):
|
||||||
|
self.x = x
|
||||||
|
self.y = y
|
||||||
|
|
||||||
|
for cls in (A, G):
|
||||||
|
with self.subTest(cls=cls.__name__):
|
||||||
|
inst = cls(1, 2)
|
||||||
|
w = 0
|
||||||
|
match inst:
|
||||||
|
case P():
|
||||||
|
w = 1
|
||||||
|
self.assertEqual(w, 0)
|
||||||
|
|
||||||
|
def test_patma_protocol_with_match_args(self):
|
||||||
|
# Runtime-checkable protocol with `__match_args__`
|
||||||
|
from typing import Protocol, runtime_checkable
|
||||||
|
|
||||||
|
# Used to fail before
|
||||||
|
# https://github.com/python/cpython/issues/110682
|
||||||
|
@runtime_checkable
|
||||||
|
class P(Protocol):
|
||||||
|
__match_args__ = ('x', 'y')
|
||||||
|
x: int
|
||||||
|
y: int
|
||||||
|
|
||||||
|
class A:
|
||||||
|
def __init__(self, x: int, y: int):
|
||||||
|
self.x = x
|
||||||
|
self.y = y
|
||||||
|
|
||||||
|
class B(A): ...
|
||||||
|
|
||||||
|
for cls in (A, B):
|
||||||
|
with self.subTest(cls=cls.__name__):
|
||||||
|
inst = cls(1, 2)
|
||||||
|
w = 0
|
||||||
|
match inst:
|
||||||
|
case P() as p:
|
||||||
|
self.assertIsInstance(p, cls)
|
||||||
|
self.assertEqual(p.x, 1)
|
||||||
|
self.assertEqual(p.y, 2)
|
||||||
|
w = 1
|
||||||
|
self.assertEqual(w, 1)
|
||||||
|
|
||||||
|
q = 0
|
||||||
|
match inst:
|
||||||
|
case P(x=x, y=y):
|
||||||
|
self.assertEqual(x, 1)
|
||||||
|
self.assertEqual(y, 2)
|
||||||
|
q = 1
|
||||||
|
self.assertEqual(q, 1)
|
||||||
|
|
||||||
|
j = 0
|
||||||
|
match inst:
|
||||||
|
case P(x=1, y=2):
|
||||||
|
j = 1
|
||||||
|
self.assertEqual(j, 1)
|
||||||
|
|
||||||
|
g = 0
|
||||||
|
match inst:
|
||||||
|
case P(x, y):
|
||||||
|
self.assertEqual(x, 1)
|
||||||
|
self.assertEqual(y, 2)
|
||||||
|
g = 1
|
||||||
|
self.assertEqual(g, 1)
|
||||||
|
|
||||||
|
h = 0
|
||||||
|
match inst:
|
||||||
|
case P(1, 2):
|
||||||
|
h = 1
|
||||||
|
self.assertEqual(h, 1)
|
||||||
|
|
||||||
|
|
||||||
class TestSyntaxErrors(unittest.TestCase):
|
class TestSyntaxErrors(unittest.TestCase):
|
||||||
|
|
||||||
@ -3198,6 +3324,35 @@ class TestTypeErrors(unittest.TestCase):
|
|||||||
w = 0
|
w = 0
|
||||||
self.assertIsNone(w)
|
self.assertIsNone(w)
|
||||||
|
|
||||||
|
def test_regular_protocol(self):
|
||||||
|
from typing import Protocol
|
||||||
|
class P(Protocol): ...
|
||||||
|
msg = (
|
||||||
|
'Instance and class checks can only be used '
|
||||||
|
'with @runtime_checkable protocols'
|
||||||
|
)
|
||||||
|
w = None
|
||||||
|
with self.assertRaisesRegex(TypeError, msg):
|
||||||
|
match 1:
|
||||||
|
case P():
|
||||||
|
w = 0
|
||||||
|
self.assertIsNone(w)
|
||||||
|
|
||||||
|
def test_positional_patterns_with_regular_protocol(self):
|
||||||
|
from typing import Protocol
|
||||||
|
class P(Protocol):
|
||||||
|
x: int # no `__match_args__`
|
||||||
|
y: int
|
||||||
|
class A:
|
||||||
|
x = 1
|
||||||
|
y = 2
|
||||||
|
w = None
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
match A():
|
||||||
|
case P(x, y):
|
||||||
|
w = 0
|
||||||
|
self.assertIsNone(w)
|
||||||
|
|
||||||
|
|
||||||
class TestValueErrors(unittest.TestCase):
|
class TestValueErrors(unittest.TestCase):
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user