gh-105499: Merge typing.Union and types.UnionType (#105511)

Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
Co-authored-by: Ken Jin <kenjin@python.org>
Co-authored-by: Carl Meyer <carl@oddbird.net>
This commit is contained in:
Jelle Zijlstra 2025-03-04 11:44:19 -08:00 committed by GitHub
parent e091520fdb
commit dc6d66f44c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 562 additions and 327 deletions

View File

@ -127,6 +127,11 @@ although there is currently no date scheduled for their removal.
* :class:`typing.Text` (:gh:`92332`). * :class:`typing.Text` (:gh:`92332`).
* The internal class ``typing._UnionGenericAlias`` is no longer used to implement
:class:`typing.Union`. To preserve compatibility with users using this private
class, a compatibility shim will be provided until at least Python 3.17. (Contributed by
Jelle Zijlstra in :gh:`105499`.)
* :class:`unittest.IsolatedAsyncioTestCase`: it is deprecated to return a value * :class:`unittest.IsolatedAsyncioTestCase`: it is deprecated to return a value
that is not ``None`` from a test case. that is not ``None`` from a test case.

View File

@ -518,7 +518,7 @@ The :mod:`functools` module defines the following functions:
... for i, elem in enumerate(arg): ... for i, elem in enumerate(arg):
... print(i, elem) ... print(i, elem)
:data:`types.UnionType` and :data:`typing.Union` can also be used:: :class:`typing.Union` can also be used::
>>> @fun.register >>> @fun.register
... def _(arg: int | float, verbose=False): ... def _(arg: int | float, verbose=False):
@ -654,8 +654,8 @@ The :mod:`functools` module defines the following functions:
The :func:`register` attribute now supports using type annotations. The :func:`register` attribute now supports using type annotations.
.. versionchanged:: 3.11 .. versionchanged:: 3.11
The :func:`register` attribute now supports :data:`types.UnionType` The :func:`register` attribute now supports
and :data:`typing.Union` as type annotations. :class:`typing.Union` as a type annotation.
.. class:: singledispatchmethod(func) .. class:: singledispatchmethod(func)

View File

@ -5364,7 +5364,7 @@ Union Type
A union object holds the value of the ``|`` (bitwise or) operation on A union object holds the value of the ``|`` (bitwise or) operation on
multiple :ref:`type objects <bltin-type-objects>`. These types are intended multiple :ref:`type objects <bltin-type-objects>`. These types are intended
primarily for :term:`type annotations <annotation>`. The union type expression primarily for :term:`type annotations <annotation>`. The union type expression
enables cleaner type hinting syntax compared to :data:`typing.Union`. enables cleaner type hinting syntax compared to subscripting :class:`typing.Union`.
.. describe:: X | Y | ... .. describe:: X | Y | ...
@ -5400,9 +5400,10 @@ enables cleaner type hinting syntax compared to :data:`typing.Union`.
int | str == str | int int | str == str | int
* It is compatible with :data:`typing.Union`:: * It creates instances of :class:`typing.Union`::
int | str == typing.Union[int, str] int | str == typing.Union[int, str]
type(int | str) is typing.Union
* Optional types can be spelled as a union with ``None``:: * Optional types can be spelled as a union with ``None``::
@ -5428,16 +5429,15 @@ enables cleaner type hinting syntax compared to :data:`typing.Union`.
TypeError: isinstance() argument 2 cannot be a parameterized generic TypeError: isinstance() argument 2 cannot be a parameterized generic
The user-exposed type for the union object can be accessed from The user-exposed type for the union object can be accessed from
:data:`types.UnionType` and used for :func:`isinstance` checks. An object cannot be :class:`typing.Union` and used for :func:`isinstance` checks::
instantiated from the type::
>>> import types >>> import typing
>>> isinstance(int | str, types.UnionType) >>> isinstance(int | str, typing.Union)
True True
>>> types.UnionType() >>> typing.Union()
Traceback (most recent call last): Traceback (most recent call last):
File "<stdin>", line 1, in <module> File "<stdin>", line 1, in <module>
TypeError: cannot create 'types.UnionType' instances TypeError: cannot create 'typing.Union' instances
.. note:: .. note::
The :meth:`!__or__` method for type objects was added to support the syntax The :meth:`!__or__` method for type objects was added to support the syntax
@ -5464,6 +5464,11 @@ instantiated from the type::
.. versionadded:: 3.10 .. versionadded:: 3.10
.. versionchanged:: 3.14
Union objects are now instances of :class:`typing.Union`. Previously, they were instances
of :class:`types.UnionType`, which remains an alias for :class:`typing.Union`.
.. _typesother: .. _typesother:

View File

@ -314,6 +314,10 @@ Standard names are defined for the following types:
.. versionadded:: 3.10 .. versionadded:: 3.10
.. versionchanged:: 3.14
This is now an alias for :class:`typing.Union`.
.. class:: TracebackType(tb_next, tb_frame, tb_lasti, tb_lineno) .. class:: TracebackType(tb_next, tb_frame, tb_lasti, tb_lineno)
The type of traceback objects such as found in ``sys.exception().__traceback__``. The type of traceback objects such as found in ``sys.exception().__traceback__``.

View File

@ -1086,7 +1086,7 @@ Special forms
These can be used as types in annotations. They all support subscription using These can be used as types in annotations. They all support subscription using
``[]``, but each has a unique syntax. ``[]``, but each has a unique syntax.
.. data:: Union .. class:: Union
Union type; ``Union[X, Y]`` is equivalent to ``X | Y`` and means either X or Y. Union type; ``Union[X, Y]`` is equivalent to ``X | Y`` and means either X or Y.
@ -1121,6 +1121,14 @@ These can be used as types in annotations. They all support subscription using
Unions can now be written as ``X | Y``. See Unions can now be written as ``X | Y``. See
:ref:`union type expressions<types-union>`. :ref:`union type expressions<types-union>`.
.. versionchanged:: 3.14
:class:`types.UnionType` is now an alias for :class:`Union`, and both
``Union[int, str]`` and ``int | str`` create instances of the same class.
To check whether an object is a ``Union`` at runtime, use
``isinstance(obj, Union)``. For compatibility with earlier versions of
Python, use
``get_origin(obj) is typing.Union or get_origin(obj) is types.UnionType``.
.. data:: Optional .. data:: Optional
``Optional[X]`` is equivalent to ``X | None`` (or ``Union[X, None]``). ``Optional[X]`` is equivalent to ``X | None`` (or ``Union[X, None]``).

View File

@ -722,10 +722,10 @@ PEP 604: New Type Union Operator
A new type union operator was introduced which enables the syntax ``X | Y``. A new type union operator was introduced which enables the syntax ``X | Y``.
This provides a cleaner way of expressing 'either type X or type Y' instead of This provides a cleaner way of expressing 'either type X or type Y' instead of
using :data:`typing.Union`, especially in type hints. using :class:`typing.Union`, especially in type hints.
In previous versions of Python, to apply a type hint for functions accepting In previous versions of Python, to apply a type hint for functions accepting
arguments of multiple types, :data:`typing.Union` was used:: arguments of multiple types, :class:`typing.Union` was used::
def square(number: Union[int, float]) -> Union[int, float]: def square(number: Union[int, float]) -> Union[int, float]:
return number ** 2 return number ** 2

View File

@ -740,8 +740,8 @@ fractions
functools functools
--------- ---------
* :func:`functools.singledispatch` now supports :data:`types.UnionType` * :func:`functools.singledispatch` now supports :class:`types.UnionType`
and :data:`typing.Union` as annotations to the dispatch argument.:: and :class:`typing.Union` as annotations to the dispatch argument.::
>>> from functools import singledispatch >>> from functools import singledispatch
>>> @singledispatch >>> @singledispatch

View File

@ -18,6 +18,7 @@ PyAPI_FUNC(PyObject *) _Py_union_type_or(PyObject *, PyObject *);
extern PyObject *_Py_subs_parameters(PyObject *, PyObject *, PyObject *, PyObject *); extern PyObject *_Py_subs_parameters(PyObject *, PyObject *, PyObject *, PyObject *);
extern PyObject *_Py_make_parameters(PyObject *); extern PyObject *_Py_make_parameters(PyObject *);
extern PyObject *_Py_union_args(PyObject *self); extern PyObject *_Py_union_args(PyObject *self);
extern PyObject *_Py_union_from_tuple(PyObject *args);
#ifdef __cplusplus #ifdef __cplusplus
} }

View File

@ -926,16 +926,11 @@ def singledispatch(func):
dispatch_cache[cls] = impl dispatch_cache[cls] = impl
return impl return impl
def _is_union_type(cls):
from typing import get_origin, Union
return get_origin(cls) in {Union, UnionType}
def _is_valid_dispatch_type(cls): def _is_valid_dispatch_type(cls):
if isinstance(cls, type): if isinstance(cls, type):
return True return True
from typing import get_args return (isinstance(cls, UnionType) and
return (_is_union_type(cls) and all(isinstance(arg, type) for arg in cls.__args__))
all(isinstance(arg, type) for arg in get_args(cls)))
def register(cls, func=None): def register(cls, func=None):
"""generic_func.register(cls, func) -> func """generic_func.register(cls, func) -> func
@ -967,7 +962,7 @@ def singledispatch(func):
from annotationlib import Format, ForwardRef from annotationlib import Format, ForwardRef
argname, cls = next(iter(get_type_hints(func, format=Format.FORWARDREF).items())) argname, cls = next(iter(get_type_hints(func, format=Format.FORWARDREF).items()))
if not _is_valid_dispatch_type(cls): if not _is_valid_dispatch_type(cls):
if _is_union_type(cls): if isinstance(cls, UnionType):
raise TypeError( raise TypeError(
f"Invalid annotation for {argname!r}. " f"Invalid annotation for {argname!r}. "
f"{cls!r} not all arguments are classes." f"{cls!r} not all arguments are classes."
@ -983,10 +978,8 @@ def singledispatch(func):
f"{cls!r} is not a class." f"{cls!r} is not a class."
) )
if _is_union_type(cls): if isinstance(cls, UnionType):
from typing import get_args for arg in cls.__args__:
for arg in get_args(cls):
registry[arg] = func registry[arg] = func
else: else:
registry[cls] = func registry[cls] = func

View File

@ -2314,7 +2314,7 @@ class TestDocString(unittest.TestCase):
class C: class C:
x: Union[int, type(None)] = None x: Union[int, type(None)] = None
self.assertDocStrEqual(C.__doc__, "C(x:Optional[int]=None)") self.assertDocStrEqual(C.__doc__, "C(x:int|None=None)")
def test_docstring_list_field(self): def test_docstring_list_field(self):
@dataclass @dataclass

View File

@ -3083,7 +3083,7 @@ class TestSingleDispatch(unittest.TestCase):
"Invalid annotation for 'arg'." "Invalid annotation for 'arg'."
) )
self.assertEndsWith(str(exc.exception), self.assertEndsWith(str(exc.exception),
'typing.Union[int, typing.Iterable[str]] not all arguments are classes.' 'int | typing.Iterable[str] not all arguments are classes.'
) )
def test_invalid_positional_argument(self): def test_invalid_positional_argument(self):

View File

@ -1750,8 +1750,8 @@ class TestClassesAndFunctions(unittest.TestCase):
class TestFormatAnnotation(unittest.TestCase): class TestFormatAnnotation(unittest.TestCase):
def test_typing_replacement(self): def test_typing_replacement(self):
from test.typinganndata.ann_module9 import ann, ann1 from test.typinganndata.ann_module9 import ann, ann1
self.assertEqual(inspect.formatannotation(ann), 'Union[List[str], int]') self.assertEqual(inspect.formatannotation(ann), 'List[str] | int')
self.assertEqual(inspect.formatannotation(ann1), 'Union[List[testModule.typing.A], int]') self.assertEqual(inspect.formatannotation(ann1), 'List[testModule.typing.A] | int')
def test_forwardref(self): def test_forwardref(self):
fwdref = ForwardRef('fwdref') fwdref = ForwardRef('fwdref')

View File

@ -133,7 +133,7 @@ DATA
c_alias = test.test_pydoc.pydoc_mod.C[int] c_alias = test.test_pydoc.pydoc_mod.C[int]
list_alias1 = typing.List[int] list_alias1 = typing.List[int]
list_alias2 = list[int] list_alias2 = list[int]
type_union1 = typing.Union[int, str] type_union1 = int | str
type_union2 = int | str type_union2 = int | str
VERSION VERSION
@ -223,7 +223,7 @@ Data
c_alias = test.test_pydoc.pydoc_mod.C[int] c_alias = test.test_pydoc.pydoc_mod.C[int]
list_alias1 = typing.List[int] list_alias1 = typing.List[int]
list_alias2 = list[int] list_alias2 = list[int]
type_union1 = typing.Union[int, str] type_union1 = int | str
type_union2 = int | str type_union2 = int | str
Author Author
@ -1447,17 +1447,17 @@ class TestDescriptions(unittest.TestCase):
self.assertIn(list.__doc__.strip().splitlines()[0], doc) self.assertIn(list.__doc__.strip().splitlines()[0], doc)
def test_union_type(self): def test_union_type(self):
self.assertEqual(pydoc.describe(typing.Union[int, str]), '_UnionGenericAlias') self.assertEqual(pydoc.describe(typing.Union[int, str]), 'Union')
doc = pydoc.render_doc(typing.Union[int, str], renderer=pydoc.plaintext) doc = pydoc.render_doc(typing.Union[int, str], renderer=pydoc.plaintext)
self.assertIn('_UnionGenericAlias in module typing', doc) self.assertIn('Union in module typing', doc)
self.assertIn('Union = typing.Union', doc) self.assertIn('class Union(builtins.object)', doc)
if typing.Union.__doc__: if typing.Union.__doc__:
self.assertIn(typing.Union.__doc__.strip().splitlines()[0], doc) self.assertIn(typing.Union.__doc__.strip().splitlines()[0], doc)
self.assertEqual(pydoc.describe(int | str), 'UnionType') self.assertEqual(pydoc.describe(int | str), 'Union')
doc = pydoc.render_doc(int | str, renderer=pydoc.plaintext) doc = pydoc.render_doc(int | str, renderer=pydoc.plaintext)
self.assertIn('UnionType in module types object', doc) self.assertIn('Union in module typing', doc)
self.assertIn('\nclass UnionType(builtins.object)', doc) self.assertIn('class Union(builtins.object)', doc)
if not MISSING_C_DOCSTRINGS: if not MISSING_C_DOCSTRINGS:
self.assertIn(types.UnionType.__doc__.strip().splitlines()[0], doc) self.assertIn(types.UnionType.__doc__.strip().splitlines()[0], doc)

View File

@ -709,10 +709,6 @@ class UnionTests(unittest.TestCase):
y = int | bool y = int | bool
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
x < y x < y
# Check that we don't crash if typing.Union does not have a tuple in __args__
y = typing.Union[str, int]
y.__args__ = [str, int]
self.assertEqual(x, y)
def test_hash(self): def test_hash(self):
self.assertEqual(hash(int | str), hash(str | int)) self.assertEqual(hash(int | str), hash(str | int))
@ -727,17 +723,40 @@ class UnionTests(unittest.TestCase):
self.assertEqual((A | B).__args__, (A, B)) self.assertEqual((A | B).__args__, (A, B))
union1 = A | B union1 = A | B
with self.assertRaises(TypeError): with self.assertRaisesRegex(TypeError, "unhashable type: 'UnhashableMeta'"):
hash(union1) hash(union1)
union2 = int | B union2 = int | B
with self.assertRaises(TypeError): with self.assertRaisesRegex(TypeError, "unhashable type: 'UnhashableMeta'"):
hash(union2) hash(union2)
union3 = A | int union3 = A | int
with self.assertRaises(TypeError): with self.assertRaisesRegex(TypeError, "unhashable type: 'UnhashableMeta'"):
hash(union3) hash(union3)
def test_unhashable_becomes_hashable(self):
is_hashable = False
class UnhashableMeta(type):
def __hash__(self):
if is_hashable:
return 1
else:
raise TypeError("not hashable")
class A(metaclass=UnhashableMeta): ...
class B(metaclass=UnhashableMeta): ...
union = A | B
self.assertEqual(union.__args__, (A, B))
with self.assertRaisesRegex(TypeError, "not hashable"):
hash(union)
is_hashable = True
with self.assertRaisesRegex(TypeError, "union contains 2 unhashable elements"):
hash(union)
def test_instancecheck_and_subclasscheck(self): def test_instancecheck_and_subclasscheck(self):
for x in (int | str, typing.Union[int, str]): for x in (int | str, typing.Union[int, str]):
with self.subTest(x=x): with self.subTest(x=x):
@ -921,7 +940,7 @@ class UnionTests(unittest.TestCase):
self.assertEqual(typing.get_args(typing.get_type_hints(forward_after)['x']), self.assertEqual(typing.get_args(typing.get_type_hints(forward_after)['x']),
(int, Forward)) (int, Forward))
self.assertEqual(typing.get_args(typing.get_type_hints(forward_before)['x']), self.assertEqual(typing.get_args(typing.get_type_hints(forward_before)['x']),
(int, Forward)) (Forward, int))
def test_or_type_operator_with_Protocol(self): def test_or_type_operator_with_Protocol(self):
class Proto(typing.Protocol): class Proto(typing.Protocol):
@ -1015,9 +1034,14 @@ class UnionTests(unittest.TestCase):
return 1 / 0 return 1 / 0
bt = BadType('bt', (), {}) bt = BadType('bt', (), {})
bt2 = BadType('bt2', (), {})
# Comparison should fail and errors should propagate out for bad types. # Comparison should fail and errors should propagate out for bad types.
union1 = int | bt
union2 = int | bt2
with self.assertRaises(ZeroDivisionError): with self.assertRaises(ZeroDivisionError):
list[int] | list[bt] union1 == union2
with self.assertRaises(ZeroDivisionError):
bt | bt2
union_ga = (list[str] | int, collections.abc.Callable[..., str] | int, union_ga = (list[str] | int, collections.abc.Callable[..., str] | int,
d | int) d | int)
@ -1060,6 +1084,14 @@ class UnionTests(unittest.TestCase):
self.assertLessEqual(sys.gettotalrefcount() - before, leeway, self.assertLessEqual(sys.gettotalrefcount() - before, leeway,
msg='Check for union reference leak.') msg='Check for union reference leak.')
def test_instantiation(self):
with self.assertRaises(TypeError):
types.UnionType()
self.assertIs(int, types.UnionType[int])
self.assertIs(int, types.UnionType[int, int])
self.assertEqual(int | str, types.UnionType[int, str])
self.assertEqual(int | typing.ForwardRef("str"), types.UnionType[int, "str"])
class MappingProxyTests(unittest.TestCase): class MappingProxyTests(unittest.TestCase):
mappingproxy = types.MappingProxyType mappingproxy = types.MappingProxyType

View File

@ -502,7 +502,7 @@ class TypeVarTests(BaseTestCase):
def test_bound_errors(self): def test_bound_errors(self):
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
TypeVar('X', bound=Union) TypeVar('X', bound=Optional)
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
TypeVar('X', str, float, bound=Employee) TypeVar('X', str, float, bound=Employee)
with self.assertRaisesRegex(TypeError, with self.assertRaisesRegex(TypeError,
@ -542,7 +542,7 @@ class TypeVarTests(BaseTestCase):
def test_bad_var_substitution(self): def test_bad_var_substitution(self):
T = TypeVar('T') T = TypeVar('T')
bad_args = ( bad_args = (
(), (int, str), Union, (), (int, str), Optional,
Generic, Generic[T], Protocol, Protocol[T], Generic, Generic[T], Protocol, Protocol[T],
Final, Final[int], ClassVar, ClassVar[int], Final, Final[int], ClassVar, ClassVar[int],
) )
@ -2043,10 +2043,6 @@ class UnionTests(BaseTestCase):
self.assertNotIsSubclass(int, Union[Any, str]) self.assertNotIsSubclass(int, Union[Any, str])
def test_union_issubclass_type_error(self): def test_union_issubclass_type_error(self):
with self.assertRaises(TypeError):
issubclass(int, Union)
with self.assertRaises(TypeError):
issubclass(Union, int)
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
issubclass(Union[int, str], int) issubclass(Union[int, str], int)
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
@ -2121,41 +2117,40 @@ class UnionTests(BaseTestCase):
self.assertEqual(Union[A, B].__args__, (A, B)) self.assertEqual(Union[A, B].__args__, (A, B))
union1 = Union[A, B] union1 = Union[A, B]
with self.assertRaises(TypeError): with self.assertRaisesRegex(TypeError, "unhashable type: 'UnhashableMeta'"):
hash(union1) hash(union1)
union2 = Union[int, B] union2 = Union[int, B]
with self.assertRaises(TypeError): with self.assertRaisesRegex(TypeError, "unhashable type: 'UnhashableMeta'"):
hash(union2) hash(union2)
union3 = Union[A, int] union3 = Union[A, int]
with self.assertRaises(TypeError): with self.assertRaisesRegex(TypeError, "unhashable type: 'UnhashableMeta'"):
hash(union3) hash(union3)
def test_repr(self): def test_repr(self):
self.assertEqual(repr(Union), 'typing.Union')
u = Union[Employee, int] u = Union[Employee, int]
self.assertEqual(repr(u), 'typing.Union[%s.Employee, int]' % __name__) self.assertEqual(repr(u), f'{__name__}.Employee | int')
u = Union[int, Employee] u = Union[int, Employee]
self.assertEqual(repr(u), 'typing.Union[int, %s.Employee]' % __name__) self.assertEqual(repr(u), f'int | {__name__}.Employee')
T = TypeVar('T') T = TypeVar('T')
u = Union[T, int][int] u = Union[T, int][int]
self.assertEqual(repr(u), repr(int)) self.assertEqual(repr(u), repr(int))
u = Union[List[int], int] u = Union[List[int], int]
self.assertEqual(repr(u), 'typing.Union[typing.List[int], int]') self.assertEqual(repr(u), 'typing.List[int] | int')
u = Union[list[int], dict[str, float]] u = Union[list[int], dict[str, float]]
self.assertEqual(repr(u), 'typing.Union[list[int], dict[str, float]]') self.assertEqual(repr(u), 'list[int] | dict[str, float]')
u = Union[int | float] u = Union[int | float]
self.assertEqual(repr(u), 'typing.Union[int, float]') self.assertEqual(repr(u), 'int | float')
u = Union[None, str] u = Union[None, str]
self.assertEqual(repr(u), 'typing.Optional[str]') self.assertEqual(repr(u), 'None | str')
u = Union[str, None] u = Union[str, None]
self.assertEqual(repr(u), 'typing.Optional[str]') self.assertEqual(repr(u), 'str | None')
u = Union[None, str, int] u = Union[None, str, int]
self.assertEqual(repr(u), 'typing.Union[NoneType, str, int]') self.assertEqual(repr(u), 'None | str | int')
u = Optional[str] u = Optional[str]
self.assertEqual(repr(u), 'typing.Optional[str]') self.assertEqual(repr(u), 'str | None')
def test_dir(self): def test_dir(self):
dir_items = set(dir(Union[str, int])) dir_items = set(dir(Union[str, int]))
@ -2167,14 +2162,11 @@ class UnionTests(BaseTestCase):
def test_cannot_subclass(self): def test_cannot_subclass(self):
with self.assertRaisesRegex(TypeError, with self.assertRaisesRegex(TypeError,
r'Cannot subclass typing\.Union'): r"type 'typing\.Union' is not an acceptable base type"):
class C(Union): class C(Union):
pass pass
with self.assertRaisesRegex(TypeError, CANNOT_SUBCLASS_TYPE):
class D(type(Union)):
pass
with self.assertRaisesRegex(TypeError, with self.assertRaisesRegex(TypeError,
r'Cannot subclass typing\.Union\[int, str\]'): r'Cannot subclass int \| str'):
class E(Union[int, str]): class E(Union[int, str]):
pass pass
@ -2220,7 +2212,7 @@ class UnionTests(BaseTestCase):
def test_function_repr_union(self): def test_function_repr_union(self):
def fun() -> int: ... def fun() -> int: ...
self.assertEqual(repr(Union[fun, int]), 'typing.Union[fun, int]') self.assertEqual(repr(Union[fun, int]), f'{__name__}.{fun.__qualname__} | int')
def test_union_str_pattern(self): def test_union_str_pattern(self):
# Shouldn't crash; see http://bugs.python.org/issue25390 # Shouldn't crash; see http://bugs.python.org/issue25390
@ -4895,11 +4887,11 @@ class GenericTests(BaseTestCase):
def test_extended_generic_rules_repr(self): def test_extended_generic_rules_repr(self):
T = TypeVar('T') T = TypeVar('T')
self.assertEqual(repr(Union[Tuple, Callable]).replace('typing.', ''), self.assertEqual(repr(Union[Tuple, Callable]).replace('typing.', ''),
'Union[Tuple, Callable]') 'Tuple | Callable')
self.assertEqual(repr(Union[Tuple, Tuple[int]]).replace('typing.', ''), self.assertEqual(repr(Union[Tuple, Tuple[int]]).replace('typing.', ''),
'Union[Tuple, Tuple[int]]') 'Tuple | Tuple[int]')
self.assertEqual(repr(Callable[..., Optional[T]][int]).replace('typing.', ''), self.assertEqual(repr(Callable[..., Optional[T]][int]).replace('typing.', ''),
'Callable[..., Optional[int]]') 'Callable[..., int | None]')
self.assertEqual(repr(Callable[[], List[T]][int]).replace('typing.', ''), self.assertEqual(repr(Callable[[], List[T]][int]).replace('typing.', ''),
'Callable[[], List[int]]') 'Callable[[], List[int]]')
@ -5079,9 +5071,9 @@ class GenericTests(BaseTestCase):
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
issubclass(Tuple[int, ...], typing.Iterable) issubclass(Tuple[int, ...], typing.Iterable)
def test_fail_with_bare_union(self): def test_fail_with_special_forms(self):
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
List[Union] List[Final]
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
Tuple[Optional] Tuple[Optional]
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
@ -5623,8 +5615,6 @@ class GenericTests(BaseTestCase):
for obj in ( for obj in (
ClassVar[int], ClassVar[int],
Final[int], Final[int],
Union[int, float],
Optional[int],
Literal[1, 2], Literal[1, 2],
Concatenate[int, ParamSpec("P")], Concatenate[int, ParamSpec("P")],
TypeGuard[int], TypeGuard[int],
@ -5656,7 +5646,7 @@ class GenericTests(BaseTestCase):
__parameters__ = (T,) __parameters__ = (T,)
# Bare classes should be skipped # Bare classes should be skipped
for a in (List, list): for a in (List, list):
for b in (A, int, TypeVar, TypeVarTuple, ParamSpec, types.GenericAlias, types.UnionType): for b in (A, int, TypeVar, TypeVarTuple, ParamSpec, types.GenericAlias, Union):
with self.subTest(generic=a, sub=b): with self.subTest(generic=a, sub=b):
with self.assertRaisesRegex(TypeError, '.* is not a generic class'): with self.assertRaisesRegex(TypeError, '.* is not a generic class'):
a[b][str] a[b][str]
@ -5675,7 +5665,7 @@ class GenericTests(BaseTestCase):
for s in (int, G, A, List, list, for s in (int, G, A, List, list,
TypeVar, TypeVarTuple, ParamSpec, TypeVar, TypeVarTuple, ParamSpec,
types.GenericAlias, types.UnionType): types.GenericAlias, Union):
for t in Tuple, tuple: for t in Tuple, tuple:
with self.subTest(tuple=t, sub=s): with self.subTest(tuple=t, sub=s):
@ -7176,7 +7166,7 @@ class GetUtilitiesTestCase(TestCase):
self.assertIs(get_origin(Callable), collections.abc.Callable) self.assertIs(get_origin(Callable), collections.abc.Callable)
self.assertIs(get_origin(list[int]), list) self.assertIs(get_origin(list[int]), list)
self.assertIs(get_origin(list), None) self.assertIs(get_origin(list), None)
self.assertIs(get_origin(list | str), types.UnionType) self.assertIs(get_origin(list | str), Union)
self.assertIs(get_origin(P.args), P) self.assertIs(get_origin(P.args), P)
self.assertIs(get_origin(P.kwargs), P) self.assertIs(get_origin(P.kwargs), P)
self.assertIs(get_origin(Required[int]), Required) self.assertIs(get_origin(Required[int]), Required)
@ -10434,7 +10424,6 @@ class SpecialAttrsTests(BaseTestCase):
typing.TypeGuard: 'TypeGuard', typing.TypeGuard: 'TypeGuard',
typing.TypeIs: 'TypeIs', typing.TypeIs: 'TypeIs',
typing.TypeVar: 'TypeVar', typing.TypeVar: 'TypeVar',
typing.Union: 'Union',
typing.Self: 'Self', typing.Self: 'Self',
# Subscripted special forms # Subscripted special forms
typing.Annotated[Any, "Annotation"]: 'Annotated', typing.Annotated[Any, "Annotation"]: 'Annotated',
@ -10445,7 +10434,7 @@ class SpecialAttrsTests(BaseTestCase):
typing.Literal[Any]: 'Literal', typing.Literal[Any]: 'Literal',
typing.Literal[1, 2]: 'Literal', typing.Literal[1, 2]: 'Literal',
typing.Literal[True, 2]: 'Literal', typing.Literal[True, 2]: 'Literal',
typing.Optional[Any]: 'Optional', typing.Optional[Any]: 'Union',
typing.TypeGuard[Any]: 'TypeGuard', typing.TypeGuard[Any]: 'TypeGuard',
typing.TypeIs[Any]: 'TypeIs', typing.TypeIs[Any]: 'TypeIs',
typing.Union[Any]: 'Any', typing.Union[Any]: 'Any',
@ -10464,7 +10453,10 @@ class SpecialAttrsTests(BaseTestCase):
for proto in range(pickle.HIGHEST_PROTOCOL + 1): for proto in range(pickle.HIGHEST_PROTOCOL + 1):
s = pickle.dumps(cls, proto) s = pickle.dumps(cls, proto)
loaded = pickle.loads(s) loaded = pickle.loads(s)
self.assertIs(cls, loaded) if isinstance(cls, Union):
self.assertEqual(cls, loaded)
else:
self.assertIs(cls, loaded)
TypeName = typing.NewType('SpecialAttrsTests.TypeName', Any) TypeName = typing.NewType('SpecialAttrsTests.TypeName', Any)
@ -10739,6 +10731,34 @@ class TypeIterationTests(BaseTestCase):
self.assertNotIsInstance(type_to_test, collections.abc.Iterable) self.assertNotIsInstance(type_to_test, collections.abc.Iterable)
class UnionGenericAliasTests(BaseTestCase):
def test_constructor(self):
# Used e.g. in typer, pydantic
with self.assertWarns(DeprecationWarning):
inst = typing._UnionGenericAlias(typing.Union, (int, str))
self.assertEqual(inst, int | str)
with self.assertWarns(DeprecationWarning):
# name is accepted but ignored
inst = typing._UnionGenericAlias(typing.Union, (int, None), name="Optional")
self.assertEqual(inst, int | None)
def test_isinstance(self):
# Used e.g. in pydantic
with self.assertWarns(DeprecationWarning):
self.assertTrue(isinstance(Union[int, str], typing._UnionGenericAlias))
with self.assertWarns(DeprecationWarning):
self.assertFalse(isinstance(int, typing._UnionGenericAlias))
def test_eq(self):
# type(t) == _UnionGenericAlias is used in vyos
with self.assertWarns(DeprecationWarning):
self.assertEqual(Union, typing._UnionGenericAlias)
with self.assertWarns(DeprecationWarning):
self.assertEqual(typing._UnionGenericAlias, typing._UnionGenericAlias)
with self.assertWarns(DeprecationWarning):
self.assertNotEqual(int, typing._UnionGenericAlias)
def load_tests(loader, tests, pattern): def load_tests(loader, tests, pattern):
import doctest import doctest
tests.addTests(doctest.DocTestSuite(typing)) tests.addTests(doctest.DocTestSuite(typing))

View File

@ -29,7 +29,13 @@ import functools
import operator import operator
import sys import sys
import types import types
from types import GenericAlias from types import (
WrapperDescriptorType,
MethodWrapperType,
MethodDescriptorType,
GenericAlias,
)
import warnings
from _typing import ( from _typing import (
_idfunc, _idfunc,
@ -40,6 +46,7 @@ from _typing import (
ParamSpecKwargs, ParamSpecKwargs,
TypeAliasType, TypeAliasType,
Generic, Generic,
Union,
NoDefault, NoDefault,
) )
@ -367,21 +374,6 @@ def _compare_args_orderless(first_args, second_args):
return False return False
return not t return not t
def _remove_dups_flatten(parameters):
"""Internal helper for Union creation and substitution.
Flatten Unions among parameters, then remove duplicates.
"""
# Flatten out Union[Union[...], ...].
params = []
for p in parameters:
if isinstance(p, (_UnionGenericAlias, types.UnionType)):
params.extend(p.__args__)
else:
params.append(p)
return tuple(_deduplicate(params, unhashable_fallback=True))
def _flatten_literal_params(parameters): def _flatten_literal_params(parameters):
"""Internal helper for Literal creation: flatten Literals among parameters.""" """Internal helper for Literal creation: flatten Literals among parameters."""
@ -470,7 +462,7 @@ def _eval_type(t, globalns, localns, type_params=_sentinel, *, recursive_guard=f
return evaluate_forward_ref(t, globals=globalns, locals=localns, return evaluate_forward_ref(t, globals=globalns, locals=localns,
type_params=type_params, owner=owner, type_params=type_params, owner=owner,
_recursive_guard=recursive_guard, format=format) _recursive_guard=recursive_guard, format=format)
if isinstance(t, (_GenericAlias, GenericAlias, types.UnionType)): if isinstance(t, (_GenericAlias, GenericAlias, Union)):
if isinstance(t, GenericAlias): if isinstance(t, GenericAlias):
args = tuple( args = tuple(
_make_forward_ref(arg) if isinstance(arg, str) else arg _make_forward_ref(arg) if isinstance(arg, str) else arg
@ -495,7 +487,7 @@ def _eval_type(t, globalns, localns, type_params=_sentinel, *, recursive_guard=f
return t return t
if isinstance(t, GenericAlias): if isinstance(t, GenericAlias):
return GenericAlias(t.__origin__, ev_args) return GenericAlias(t.__origin__, ev_args)
if isinstance(t, types.UnionType): if isinstance(t, Union):
return functools.reduce(operator.or_, ev_args) return functools.reduce(operator.or_, ev_args)
else: else:
return t.copy_with(ev_args) return t.copy_with(ev_args)
@ -749,59 +741,6 @@ def Final(self, parameters):
item = _type_check(parameters, f'{self} accepts only single type.', allow_special_forms=True) item = _type_check(parameters, f'{self} accepts only single type.', allow_special_forms=True)
return _GenericAlias(self, (item,)) return _GenericAlias(self, (item,))
@_SpecialForm
def Union(self, parameters):
"""Union type; Union[X, Y] means either X or Y.
On Python 3.10 and higher, the | operator
can also be used to denote unions;
X | Y means the same thing to the type checker as Union[X, Y].
To define a union, use e.g. Union[int, str]. Details:
- The arguments must be types and there must be at least one.
- None as an argument is a special case and is replaced by
type(None).
- Unions of unions are flattened, e.g.::
assert Union[Union[int, str], float] == Union[int, str, float]
- Unions of a single argument vanish, e.g.::
assert Union[int] == int # The constructor actually returns int
- Redundant arguments are skipped, e.g.::
assert Union[int, str, int] == Union[int, str]
- When comparing unions, the argument order is ignored, e.g.::
assert Union[int, str] == Union[str, int]
- You cannot subclass or instantiate a union.
- You can use Optional[X] as a shorthand for Union[X, None].
"""
if parameters == ():
raise TypeError("Cannot take a Union of no types.")
if not isinstance(parameters, tuple):
parameters = (parameters,)
msg = "Union[arg, ...]: each arg must be a type."
parameters = tuple(_type_check(p, msg) for p in parameters)
parameters = _remove_dups_flatten(parameters)
if len(parameters) == 1:
return parameters[0]
if len(parameters) == 2 and type(None) in parameters:
return _UnionGenericAlias(self, parameters, name="Optional")
return _UnionGenericAlias(self, parameters)
def _make_union(left, right):
"""Used from the C implementation of TypeVar.
TypeVar.__or__ calls this instead of returning types.UnionType
because we want to allow unions between TypeVars and strings
(forward references).
"""
return Union[left, right]
@_SpecialForm @_SpecialForm
def Optional(self, parameters): def Optional(self, parameters):
"""Optional[X] is equivalent to Union[X, None].""" """Optional[X] is equivalent to Union[X, None]."""
@ -1708,45 +1647,34 @@ class _TupleType(_SpecialGenericAlias, _root=True):
return self.copy_with(params) return self.copy_with(params)
class _UnionGenericAlias(_NotIterable, _GenericAlias, _root=True): class _UnionGenericAliasMeta(type):
def copy_with(self, params): def __instancecheck__(self, inst: object) -> bool:
return Union[params] warnings._deprecated("_UnionGenericAlias", remove=(3, 17))
return isinstance(inst, Union)
def __subclasscheck__(self, inst: type) -> bool:
warnings._deprecated("_UnionGenericAlias", remove=(3, 17))
return issubclass(inst, Union)
def __eq__(self, other): def __eq__(self, other):
if not isinstance(other, (_UnionGenericAlias, types.UnionType)): warnings._deprecated("_UnionGenericAlias", remove=(3, 17))
return NotImplemented if other is _UnionGenericAlias or other is Union:
try: # fast path return True
return set(self.__args__) == set(other.__args__) return NotImplemented
except TypeError: # not hashable, slow path
return _compare_args_orderless(self.__args__, other.__args__)
def __hash__(self):
return hash(frozenset(self.__args__))
def __repr__(self): class _UnionGenericAlias(metaclass=_UnionGenericAliasMeta):
args = self.__args__ """Compatibility hack.
if len(args) == 2:
if args[0] is type(None):
return f'typing.Optional[{_type_repr(args[1])}]'
elif args[1] is type(None):
return f'typing.Optional[{_type_repr(args[0])}]'
return super().__repr__()
def __instancecheck__(self, obj): A class named _UnionGenericAlias used to be used to implement
for arg in self.__args__: typing.Union. This class exists to serve as a shim to preserve
if isinstance(obj, arg): the meaning of some code that used to use _UnionGenericAlias
return True directly.
return False
def __subclasscheck__(self, cls): """
for arg in self.__args__: def __new__(cls, self_cls, parameters, /, *, name=None):
if issubclass(cls, arg): warnings._deprecated("_UnionGenericAlias", remove=(3, 17))
return True return Union[parameters]
return False
def __reduce__(self):
func, (origin, args) = super().__reduce__()
return func, (Union, args)
def _value_and_type_iter(parameters): def _value_and_type_iter(parameters):
@ -2472,7 +2400,7 @@ def _strip_annotations(t):
if stripped_args == t.__args__: if stripped_args == t.__args__:
return t return t
return GenericAlias(t.__origin__, stripped_args) return GenericAlias(t.__origin__, stripped_args)
if isinstance(t, types.UnionType): if isinstance(t, Union):
stripped_args = tuple(_strip_annotations(a) for a in t.__args__) stripped_args = tuple(_strip_annotations(a) for a in t.__args__)
if stripped_args == t.__args__: if stripped_args == t.__args__:
return t return t
@ -2506,8 +2434,8 @@ def get_origin(tp):
return tp.__origin__ return tp.__origin__
if tp is Generic: if tp is Generic:
return Generic return Generic
if isinstance(tp, types.UnionType): if isinstance(tp, Union):
return types.UnionType return Union
return None return None
@ -2532,7 +2460,7 @@ def get_args(tp):
if _should_unflatten_callable_args(tp, res): if _should_unflatten_callable_args(tp, res):
res = (list(res[:-1]), res[-1]) res = (list(res[:-1]), res[-1])
return res return res
if isinstance(tp, types.UnionType): if isinstance(tp, Union):
return tp.__args__ return tp.__args__
return () return ()

View File

@ -0,0 +1,3 @@
Make :class:`types.UnionType` an alias for :class:`typing.Union`. Both
``int | str`` and ``Union[int, str]`` now create instances of the same
type. Patch by Jelle Zijlstra.

View File

@ -5,9 +5,10 @@
#endif #endif
#include "Python.h" #include "Python.h"
#include "pycore_interp.h" #include "internal/pycore_interp.h"
#include "internal/pycore_typevarobject.h"
#include "internal/pycore_unionobject.h" // _PyUnion_Type
#include "pycore_pystate.h" // _PyInterpreterState_GET() #include "pycore_pystate.h" // _PyInterpreterState_GET()
#include "pycore_typevarobject.h"
#include "clinic/_typingmodule.c.h" #include "clinic/_typingmodule.c.h"
/*[clinic input] /*[clinic input]
@ -63,6 +64,9 @@ _typing_exec(PyObject *m)
if (PyModule_AddObjectRef(m, "TypeAliasType", (PyObject *)&_PyTypeAlias_Type) < 0) { if (PyModule_AddObjectRef(m, "TypeAliasType", (PyObject *)&_PyTypeAlias_Type) < 0) {
return -1; return -1;
} }
if (PyModule_AddObjectRef(m, "Union", (PyObject *)&_PyUnion_Type) < 0) {
return -1;
}
if (PyModule_AddObjectRef(m, "NoDefault", (PyObject *)&_Py_NoDefaultStruct) < 0) { if (PyModule_AddObjectRef(m, "NoDefault", (PyObject *)&_Py_NoDefaultStruct) < 0) {
return -1; return -1;
} }

View File

@ -2,8 +2,8 @@
#include "Python.h" #include "Python.h"
#include "pycore_object.h" // _PyObject_GC_TRACK/UNTRACK, PyAnnotateFormat #include "pycore_object.h" // _PyObject_GC_TRACK/UNTRACK, PyAnnotateFormat
#include "pycore_typevarobject.h" #include "pycore_typevarobject.h"
#include "pycore_unionobject.h" // _Py_union_type_or #include "pycore_unionobject.h" // _Py_union_type_or, _Py_union_from_tuple
#include "structmember.h"
/*[clinic input] /*[clinic input]
class typevar "typevarobject *" "&_PyTypeVar_Type" class typevar "typevarobject *" "&_PyTypeVar_Type"
@ -370,9 +370,13 @@ type_check(PyObject *arg, const char *msg)
static PyObject * static PyObject *
make_union(PyObject *self, PyObject *other) make_union(PyObject *self, PyObject *other)
{ {
PyObject *args[2] = {self, other}; PyObject *args = PyTuple_Pack(2, self, other);
PyObject *result = call_typing_func_object("_make_union", args, 2); if (args == NULL) {
return result; return NULL;
}
PyObject *u = _Py_union_from_tuple(args);
Py_DECREF(args);
return u;
} }
static PyObject * static PyObject *

View File

@ -1,17 +1,17 @@
// types.UnionType -- used to represent e.g. Union[int, str], int | str // typing.Union -- used to represent e.g. Union[int, str], int | str
#include "Python.h" #include "Python.h"
#include "pycore_object.h" // _PyObject_GC_TRACK/UNTRACK #include "pycore_object.h" // _PyObject_GC_TRACK/UNTRACK
#include "pycore_typevarobject.h" // _PyTypeAlias_Type, _Py_typing_type_repr #include "pycore_typevarobject.h" // _PyTypeAlias_Type, _Py_typing_type_repr
#include "pycore_unionobject.h" #include "pycore_unionobject.h"
static PyObject *make_union(PyObject *);
typedef struct { typedef struct {
PyObject_HEAD PyObject_HEAD
PyObject *args; PyObject *args; // all args (tuple)
PyObject *hashable_args; // frozenset or NULL
PyObject *unhashable_args; // tuple or NULL
PyObject *parameters; PyObject *parameters;
PyObject *weakreflist;
} unionobject; } unionobject;
static void static void
@ -20,8 +20,13 @@ unionobject_dealloc(PyObject *self)
unionobject *alias = (unionobject *)self; unionobject *alias = (unionobject *)self;
_PyObject_GC_UNTRACK(self); _PyObject_GC_UNTRACK(self);
if (alias->weakreflist != NULL) {
PyObject_ClearWeakRefs((PyObject *)alias);
}
Py_XDECREF(alias->args); Py_XDECREF(alias->args);
Py_XDECREF(alias->hashable_args);
Py_XDECREF(alias->unhashable_args);
Py_XDECREF(alias->parameters); Py_XDECREF(alias->parameters);
Py_TYPE(self)->tp_free(self); Py_TYPE(self)->tp_free(self);
} }
@ -31,6 +36,8 @@ union_traverse(PyObject *self, visitproc visit, void *arg)
{ {
unionobject *alias = (unionobject *)self; unionobject *alias = (unionobject *)self;
Py_VISIT(alias->args); Py_VISIT(alias->args);
Py_VISIT(alias->hashable_args);
Py_VISIT(alias->unhashable_args);
Py_VISIT(alias->parameters); Py_VISIT(alias->parameters);
return 0; return 0;
} }
@ -39,13 +46,67 @@ static Py_hash_t
union_hash(PyObject *self) union_hash(PyObject *self)
{ {
unionobject *alias = (unionobject *)self; unionobject *alias = (unionobject *)self;
PyObject *args = PyFrozenSet_New(alias->args); // If there are any unhashable args, treat this union as unhashable.
if (args == NULL) { // Otherwise, two unions might compare equal but have different hashes.
return (Py_hash_t)-1; if (alias->unhashable_args) {
// Attempt to get an error from one of the values.
assert(PyTuple_CheckExact(alias->unhashable_args));
Py_ssize_t n = PyTuple_GET_SIZE(alias->unhashable_args);
for (Py_ssize_t i = 0; i < n; i++) {
PyObject *arg = PyTuple_GET_ITEM(alias->unhashable_args, i);
Py_hash_t hash = PyObject_Hash(arg);
if (hash == -1) {
return -1;
}
}
// The unhashable values somehow became hashable again. Still raise
// an error.
PyErr_Format(PyExc_TypeError, "union contains %d unhashable elements", n);
return -1;
} }
Py_hash_t hash = PyObject_Hash(args); return PyObject_Hash(alias->hashable_args);
Py_DECREF(args); }
return hash;
static int
unions_equal(unionobject *a, unionobject *b)
{
int result = PyObject_RichCompareBool(a->hashable_args, b->hashable_args, Py_EQ);
if (result == -1) {
return -1;
}
if (result == 0) {
return 0;
}
if (a->unhashable_args && b->unhashable_args) {
Py_ssize_t n = PyTuple_GET_SIZE(a->unhashable_args);
if (n != PyTuple_GET_SIZE(b->unhashable_args)) {
return 0;
}
for (Py_ssize_t i = 0; i < n; i++) {
PyObject *arg_a = PyTuple_GET_ITEM(a->unhashable_args, i);
int result = PySequence_Contains(b->unhashable_args, arg_a);
if (result == -1) {
return -1;
}
if (!result) {
return 0;
}
}
for (Py_ssize_t i = 0; i < n; i++) {
PyObject *arg_b = PyTuple_GET_ITEM(b->unhashable_args, i);
int result = PySequence_Contains(a->unhashable_args, arg_b);
if (result == -1) {
return -1;
}
if (!result) {
return 0;
}
}
}
else if (a->unhashable_args || b->unhashable_args) {
return 0;
}
return 1;
} }
static PyObject * static PyObject *
@ -55,95 +116,130 @@ union_richcompare(PyObject *a, PyObject *b, int op)
Py_RETURN_NOTIMPLEMENTED; Py_RETURN_NOTIMPLEMENTED;
} }
PyObject *a_set = PySet_New(((unionobject*)a)->args); int equal = unions_equal((unionobject*)a, (unionobject*)b);
if (a_set == NULL) { if (equal == -1) {
return NULL; return NULL;
} }
PyObject *b_set = PySet_New(((unionobject*)b)->args); if (op == Py_EQ) {
if (b_set == NULL) { return PyBool_FromLong(equal);
Py_DECREF(a_set);
return NULL;
}
PyObject *result = PyObject_RichCompare(a_set, b_set, op);
Py_DECREF(b_set);
Py_DECREF(a_set);
return result;
}
static int
is_same(PyObject *left, PyObject *right)
{
int is_ga = _PyGenericAlias_Check(left) && _PyGenericAlias_Check(right);
return is_ga ? PyObject_RichCompareBool(left, right, Py_EQ) : left == right;
}
static int
contains(PyObject **items, Py_ssize_t size, PyObject *obj)
{
for (Py_ssize_t i = 0; i < size; i++) {
int is_duplicate = is_same(items[i], obj);
if (is_duplicate) { // -1 or 1
return is_duplicate;
}
}
return 0;
}
static PyObject *
merge(PyObject **items1, Py_ssize_t size1,
PyObject **items2, Py_ssize_t size2)
{
PyObject *tuple = NULL;
Py_ssize_t pos = 0;
for (Py_ssize_t i = 0; i < size2; i++) {
PyObject *arg = items2[i];
int is_duplicate = contains(items1, size1, arg);
if (is_duplicate < 0) {
Py_XDECREF(tuple);
return NULL;
}
if (is_duplicate) {
continue;
}
if (tuple == NULL) {
tuple = PyTuple_New(size1 + size2 - i);
if (tuple == NULL) {
return NULL;
}
for (; pos < size1; pos++) {
PyObject *a = items1[pos];
PyTuple_SET_ITEM(tuple, pos, Py_NewRef(a));
}
}
PyTuple_SET_ITEM(tuple, pos, Py_NewRef(arg));
pos++;
}
if (tuple) {
(void) _PyTuple_Resize(&tuple, pos);
}
return tuple;
}
static PyObject **
get_types(PyObject **obj, Py_ssize_t *size)
{
if (*obj == Py_None) {
*obj = (PyObject *)&_PyNone_Type;
}
if (_PyUnion_Check(*obj)) {
PyObject *args = ((unionobject *) *obj)->args;
*size = PyTuple_GET_SIZE(args);
return &PyTuple_GET_ITEM(args, 0);
} }
else { else {
*size = 1; return PyBool_FromLong(!equal);
return obj;
} }
} }
typedef struct {
PyObject *args; // list
PyObject *hashable_args; // set
PyObject *unhashable_args; // list or NULL
bool is_checked; // whether to call type_check()
} unionbuilder;
static bool unionbuilder_add_tuple(unionbuilder *, PyObject *);
static PyObject *make_union(unionbuilder *);
static PyObject *type_check(PyObject *, const char *);
static bool
unionbuilder_init(unionbuilder *ub, bool is_checked)
{
ub->args = PyList_New(0);
if (ub->args == NULL) {
return false;
}
ub->hashable_args = PySet_New(NULL);
if (ub->hashable_args == NULL) {
Py_DECREF(ub->args);
return false;
}
ub->unhashable_args = NULL;
ub->is_checked = is_checked;
return true;
}
static void
unionbuilder_finalize(unionbuilder *ub)
{
Py_DECREF(ub->args);
Py_DECREF(ub->hashable_args);
Py_XDECREF(ub->unhashable_args);
}
static bool
unionbuilder_add_single_unchecked(unionbuilder *ub, PyObject *arg)
{
Py_hash_t hash = PyObject_Hash(arg);
if (hash == -1) {
PyErr_Clear();
if (ub->unhashable_args == NULL) {
ub->unhashable_args = PyList_New(0);
if (ub->unhashable_args == NULL) {
return false;
}
}
else {
int contains = PySequence_Contains(ub->unhashable_args, arg);
if (contains < 0) {
return false;
}
if (contains == 1) {
return true;
}
}
if (PyList_Append(ub->unhashable_args, arg) < 0) {
return false;
}
}
else {
int contains = PySet_Contains(ub->hashable_args, arg);
if (contains < 0) {
return false;
}
if (contains == 1) {
return true;
}
if (PySet_Add(ub->hashable_args, arg) < 0) {
return false;
}
}
return PyList_Append(ub->args, arg) == 0;
}
static bool
unionbuilder_add_single(unionbuilder *ub, PyObject *arg)
{
if (Py_IsNone(arg)) {
arg = (PyObject *)&_PyNone_Type; // immortal, so no refcounting needed
}
else if (_PyUnion_Check(arg)) {
PyObject *args = ((unionobject *)arg)->args;
return unionbuilder_add_tuple(ub, args);
}
if (ub->is_checked) {
PyObject *type = type_check(arg, "Union[arg, ...]: each arg must be a type.");
if (type == NULL) {
return false;
}
bool result = unionbuilder_add_single_unchecked(ub, type);
Py_DECREF(type);
return result;
}
else {
return unionbuilder_add_single_unchecked(ub, arg);
}
}
static bool
unionbuilder_add_tuple(unionbuilder *ub, PyObject *tuple)
{
Py_ssize_t n = PyTuple_GET_SIZE(tuple);
for (Py_ssize_t i = 0; i < n; i++) {
if (!unionbuilder_add_single(ub, PyTuple_GET_ITEM(tuple, i))) {
return false;
}
}
return true;
}
static int static int
is_unionable(PyObject *obj) is_unionable(PyObject *obj)
{ {
@ -164,19 +260,18 @@ _Py_union_type_or(PyObject* self, PyObject* other)
Py_RETURN_NOTIMPLEMENTED; Py_RETURN_NOTIMPLEMENTED;
} }
Py_ssize_t size1, size2; unionbuilder ub;
PyObject **items1 = get_types(&self, &size1); // unchecked because we already checked is_unionable()
PyObject **items2 = get_types(&other, &size2); if (!unionbuilder_init(&ub, false)) {
PyObject *tuple = merge(items1, size1, items2, size2); return NULL;
if (tuple == NULL) { }
if (PyErr_Occurred()) { if (!unionbuilder_add_single(&ub, self) ||
return NULL; !unionbuilder_add_single(&ub, other)) {
} unionbuilder_finalize(&ub);
return Py_NewRef(self); return NULL;
} }
PyObject *new_union = make_union(tuple); PyObject *new_union = make_union(&ub);
Py_DECREF(tuple);
return new_union; return new_union;
} }
@ -202,6 +297,18 @@ union_repr(PyObject *self)
goto error; goto error;
} }
} }
#if 0
PyUnicodeWriter_WriteUTF8(writer, "|args=", 6);
PyUnicodeWriter_WriteRepr(writer, alias->args);
PyUnicodeWriter_WriteUTF8(writer, "|h=", 3);
PyUnicodeWriter_WriteRepr(writer, alias->hashable_args);
if (alias->unhashable_args) {
PyUnicodeWriter_WriteUTF8(writer, "|u=", 3);
PyUnicodeWriter_WriteRepr(writer, alias->unhashable_args);
}
#endif
return PyUnicodeWriter_Finish(writer); return PyUnicodeWriter_Finish(writer);
error: error:
@ -231,21 +338,7 @@ union_getitem(PyObject *self, PyObject *item)
return NULL; return NULL;
} }
PyObject *res; PyObject *res = _Py_union_from_tuple(newargs);
Py_ssize_t nargs = PyTuple_GET_SIZE(newargs);
if (nargs == 0) {
res = make_union(newargs);
}
else {
res = Py_NewRef(PyTuple_GET_ITEM(newargs, 0));
for (Py_ssize_t iarg = 1; iarg < nargs; iarg++) {
PyObject *arg = PyTuple_GET_ITEM(newargs, iarg);
Py_SETREF(res, PyNumber_Or(res, arg));
if (res == NULL) {
break;
}
}
}
Py_DECREF(newargs); Py_DECREF(newargs);
return res; return res;
} }
@ -267,7 +360,25 @@ union_parameters(PyObject *self, void *Py_UNUSED(unused))
return Py_NewRef(alias->parameters); return Py_NewRef(alias->parameters);
} }
static PyObject *
union_name(PyObject *Py_UNUSED(self), void *Py_UNUSED(ignored))
{
return PyUnicode_FromString("Union");
}
static PyObject *
union_origin(PyObject *Py_UNUSED(self), void *Py_UNUSED(ignored))
{
return Py_NewRef(&_PyUnion_Type);
}
static PyGetSetDef union_properties[] = { static PyGetSetDef union_properties[] = {
{"__name__", union_name, NULL,
PyDoc_STR("Name of the type"), NULL},
{"__qualname__", union_name, NULL,
PyDoc_STR("Qualified name of the type"), NULL},
{"__origin__", union_origin, NULL,
PyDoc_STR("Always returns the type"), NULL},
{"__parameters__", union_parameters, (setter)NULL, {"__parameters__", union_parameters, (setter)NULL,
PyDoc_STR("Type variables in the types.UnionType."), NULL}, PyDoc_STR("Type variables in the types.UnionType."), NULL},
{0} {0}
@ -306,10 +417,88 @@ _Py_union_args(PyObject *self)
return ((unionobject *) self)->args; return ((unionobject *) self)->args;
} }
static PyObject *
call_typing_func_object(const char *name, PyObject **args, size_t nargs)
{
PyObject *typing = PyImport_ImportModule("typing");
if (typing == NULL) {
return NULL;
}
PyObject *func = PyObject_GetAttrString(typing, name);
if (func == NULL) {
Py_DECREF(typing);
return NULL;
}
PyObject *result = PyObject_Vectorcall(func, args, nargs, NULL);
Py_DECREF(func);
Py_DECREF(typing);
return result;
}
static PyObject *
type_check(PyObject *arg, const char *msg)
{
if (Py_IsNone(arg)) {
// NoneType is immortal, so don't need an INCREF
return (PyObject *)Py_TYPE(arg);
}
// Fast path to avoid calling into typing.py
if (is_unionable(arg)) {
return Py_NewRef(arg);
}
PyObject *message_str = PyUnicode_FromString(msg);
if (message_str == NULL) {
return NULL;
}
PyObject *args[2] = {arg, message_str};
PyObject *result = call_typing_func_object("_type_check", args, 2);
Py_DECREF(message_str);
return result;
}
PyObject *
_Py_union_from_tuple(PyObject *args)
{
unionbuilder ub;
if (!unionbuilder_init(&ub, true)) {
return NULL;
}
if (PyTuple_CheckExact(args)) {
if (!unionbuilder_add_tuple(&ub, args)) {
return NULL;
}
}
else {
if (!unionbuilder_add_single(&ub, args)) {
return NULL;
}
}
return make_union(&ub);
}
static PyObject *
union_class_getitem(PyObject *cls, PyObject *args)
{
return _Py_union_from_tuple(args);
}
static PyObject *
union_mro_entries(PyObject *self, PyObject *args)
{
return PyErr_Format(PyExc_TypeError,
"Cannot subclass %R", self);
}
static PyMethodDef union_methods[] = {
{"__mro_entries__", union_mro_entries, METH_O},
{"__class_getitem__", union_class_getitem, METH_O|METH_CLASS, PyDoc_STR("See PEP 585")},
{0}
};
PyTypeObject _PyUnion_Type = { PyTypeObject _PyUnion_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0) PyVarObject_HEAD_INIT(&PyType_Type, 0)
.tp_name = "types.UnionType", .tp_name = "typing.Union",
.tp_doc = PyDoc_STR("Represent a PEP 604 union type\n" .tp_doc = PyDoc_STR("Represent a union type\n"
"\n" "\n"
"E.g. for int | str"), "E.g. for int | str"),
.tp_basicsize = sizeof(unionobject), .tp_basicsize = sizeof(unionobject),
@ -321,25 +510,64 @@ PyTypeObject _PyUnion_Type = {
.tp_hash = union_hash, .tp_hash = union_hash,
.tp_getattro = union_getattro, .tp_getattro = union_getattro,
.tp_members = union_members, .tp_members = union_members,
.tp_methods = union_methods,
.tp_richcompare = union_richcompare, .tp_richcompare = union_richcompare,
.tp_as_mapping = &union_as_mapping, .tp_as_mapping = &union_as_mapping,
.tp_as_number = &union_as_number, .tp_as_number = &union_as_number,
.tp_repr = union_repr, .tp_repr = union_repr,
.tp_getset = union_properties, .tp_getset = union_properties,
.tp_weaklistoffset = offsetof(unionobject, weakreflist),
}; };
static PyObject * static PyObject *
make_union(PyObject *args) make_union(unionbuilder *ub)
{ {
assert(PyTuple_CheckExact(args)); Py_ssize_t n = PyList_GET_SIZE(ub->args);
if (n == 0) {
PyErr_SetString(PyExc_TypeError, "Cannot take a Union of no types.");
unionbuilder_finalize(ub);
return NULL;
}
if (n == 1) {
PyObject *result = PyList_GET_ITEM(ub->args, 0);
Py_INCREF(result);
unionbuilder_finalize(ub);
return result;
}
PyObject *args = NULL, *hashable_args = NULL, *unhashable_args = NULL;
args = PyList_AsTuple(ub->args);
if (args == NULL) {
goto error;
}
hashable_args = PyFrozenSet_New(ub->hashable_args);
if (hashable_args == NULL) {
goto error;
}
if (ub->unhashable_args != NULL) {
unhashable_args = PyList_AsTuple(ub->unhashable_args);
if (unhashable_args == NULL) {
goto error;
}
}
unionobject *result = PyObject_GC_New(unionobject, &_PyUnion_Type); unionobject *result = PyObject_GC_New(unionobject, &_PyUnion_Type);
if (result == NULL) { if (result == NULL) {
return NULL; goto error;
} }
unionbuilder_finalize(ub);
result->parameters = NULL; result->parameters = NULL;
result->args = Py_NewRef(args); result->args = args;
result->hashable_args = hashable_args;
result->unhashable_args = unhashable_args;
result->weakreflist = NULL;
_PyObject_GC_TRACK(result); _PyObject_GC_TRACK(result);
return (PyObject*)result; return (PyObject*)result;
error:
Py_XDECREF(args);
Py_XDECREF(hashable_args);
Py_XDECREF(unhashable_args);
unionbuilder_finalize(ub);
return NULL;
} }