gh-110686: Test pattern matching with runtime_checkable protocols (#110687)

This commit is contained in:
Nikita Sobolev 2023-12-10 18:21:20 +03:00 committed by GitHub
parent 7595d47722
commit 9d02d3451a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2760,6 +2760,132 @@ class TestPatma(unittest.TestCase):
self.assertEqual(y, 1)
self.assertIs(z, x)
def test_patma_runtime_checkable_protocol(self):
# Runtime-checkable protocol
from typing import Protocol, runtime_checkable
@runtime_checkable
class P(Protocol):
x: int
y: int
class A:
def __init__(self, x: int, y: int):
self.x = x
self.y = y
class B(A): ...
for cls in (A, B):
with self.subTest(cls=cls.__name__):
inst = cls(1, 2)
w = 0
match inst:
case P() as p:
self.assertIsInstance(p, cls)
self.assertEqual(p.x, 1)
self.assertEqual(p.y, 2)
w = 1
self.assertEqual(w, 1)
q = 0
match inst:
case P(x=x, y=y):
self.assertEqual(x, 1)
self.assertEqual(y, 2)
q = 1
self.assertEqual(q, 1)
def test_patma_generic_protocol(self):
# Runtime-checkable generic protocol
from typing import Generic, TypeVar, Protocol, runtime_checkable
T = TypeVar('T') # not using PEP695 to be able to backport changes
@runtime_checkable
class P(Protocol[T]):
a: T
b: T
class A:
def __init__(self, x: int, y: int):
self.x = x
self.y = y
class G(Generic[T]):
def __init__(self, x: T, y: T):
self.x = x
self.y = y
for cls in (A, G):
with self.subTest(cls=cls.__name__):
inst = cls(1, 2)
w = 0
match inst:
case P():
w = 1
self.assertEqual(w, 0)
def test_patma_protocol_with_match_args(self):
# Runtime-checkable protocol with `__match_args__`
from typing import Protocol, runtime_checkable
# Used to fail before
# https://github.com/python/cpython/issues/110682
@runtime_checkable
class P(Protocol):
__match_args__ = ('x', 'y')
x: int
y: int
class A:
def __init__(self, x: int, y: int):
self.x = x
self.y = y
class B(A): ...
for cls in (A, B):
with self.subTest(cls=cls.__name__):
inst = cls(1, 2)
w = 0
match inst:
case P() as p:
self.assertIsInstance(p, cls)
self.assertEqual(p.x, 1)
self.assertEqual(p.y, 2)
w = 1
self.assertEqual(w, 1)
q = 0
match inst:
case P(x=x, y=y):
self.assertEqual(x, 1)
self.assertEqual(y, 2)
q = 1
self.assertEqual(q, 1)
j = 0
match inst:
case P(x=1, y=2):
j = 1
self.assertEqual(j, 1)
g = 0
match inst:
case P(x, y):
self.assertEqual(x, 1)
self.assertEqual(y, 2)
g = 1
self.assertEqual(g, 1)
h = 0
match inst:
case P(1, 2):
h = 1
self.assertEqual(h, 1)
class TestSyntaxErrors(unittest.TestCase):
@ -3198,6 +3324,35 @@ class TestTypeErrors(unittest.TestCase):
w = 0
self.assertIsNone(w)
def test_regular_protocol(self):
from typing import Protocol
class P(Protocol): ...
msg = (
'Instance and class checks can only be used '
'with @runtime_checkable protocols'
)
w = None
with self.assertRaisesRegex(TypeError, msg):
match 1:
case P():
w = 0
self.assertIsNone(w)
def test_positional_patterns_with_regular_protocol(self):
from typing import Protocol
class P(Protocol):
x: int # no `__match_args__`
y: int
class A:
x = 1
y = 2
w = None
with self.assertRaises(TypeError):
match A():
case P(x, y):
w = 0
self.assertIsNone(w)
class TestValueErrors(unittest.TestCase):