gh-89263: Add typing.get_overloads (GH-31716)
Based on suggestions by Guido van Rossum, Spencer Brown, and Alex Waygood. Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com> Co-authored-by: Guido van Rossum <gvanrossum@gmail.com> Co-authored-by: Ken Jin <kenjin4096@gmail.com>
This commit is contained in:
parent
9300b6d729
commit
055760ed9e
@ -2407,6 +2407,35 @@ Functions and decorators
|
||||
|
||||
See :pep:`484` for details and comparison with other typing semantics.
|
||||
|
||||
.. versionchanged:: 3.11
|
||||
Overloaded functions can now be introspected at runtime using
|
||||
:func:`get_overloads`.
|
||||
|
||||
|
||||
.. function:: get_overloads(func)
|
||||
|
||||
Return a sequence of :func:`@overload <overload>`-decorated definitions for
|
||||
*func*. *func* is the function object for the implementation of the
|
||||
overloaded function. For example, given the definition of ``process`` in
|
||||
the documentation for :func:`@overload <overload>`,
|
||||
``get_overloads(process)`` will return a sequence of three function objects
|
||||
for the three defined overloads. If called on a function with no overloads,
|
||||
``get_overloads`` returns an empty sequence.
|
||||
|
||||
``get_overloads`` can be used for introspecting an overloaded function at
|
||||
runtime.
|
||||
|
||||
.. versionadded:: 3.11
|
||||
|
||||
|
||||
.. function:: clear_overloads()
|
||||
|
||||
Clear all registered overloads in the internal registry. This can be used
|
||||
to reclaim the memory used by the registry.
|
||||
|
||||
.. versionadded:: 3.11
|
||||
|
||||
|
||||
.. decorator:: final
|
||||
|
||||
A decorator to indicate to type checkers that the decorated method
|
||||
|
@ -1,5 +1,6 @@
|
||||
import contextlib
|
||||
import collections
|
||||
from collections import defaultdict
|
||||
from functools import lru_cache
|
||||
import inspect
|
||||
import pickle
|
||||
@ -7,9 +8,11 @@ import re
|
||||
import sys
|
||||
import warnings
|
||||
from unittest import TestCase, main, skipUnless, skip
|
||||
from unittest.mock import patch
|
||||
from copy import copy, deepcopy
|
||||
|
||||
from typing import Any, NoReturn, Never, assert_never
|
||||
from typing import overload, get_overloads, clear_overloads
|
||||
from typing import TypeVar, TypeVarTuple, Unpack, AnyStr
|
||||
from typing import T, KT, VT # Not in __all__.
|
||||
from typing import Union, Optional, Literal
|
||||
@ -3890,11 +3893,22 @@ class ForwardRefTests(BaseTestCase):
|
||||
self.assertEqual("x" | X, Union["x", X])
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def cached_func(x, y):
|
||||
return 3 * x + y
|
||||
|
||||
|
||||
class MethodHolder:
|
||||
@classmethod
|
||||
def clsmethod(cls): ...
|
||||
@staticmethod
|
||||
def stmethod(): ...
|
||||
def method(self): ...
|
||||
|
||||
|
||||
class OverloadTests(BaseTestCase):
|
||||
|
||||
def test_overload_fails(self):
|
||||
from typing import overload
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
|
||||
@overload
|
||||
@ -3904,8 +3918,6 @@ class OverloadTests(BaseTestCase):
|
||||
blah()
|
||||
|
||||
def test_overload_succeeds(self):
|
||||
from typing import overload
|
||||
|
||||
@overload
|
||||
def blah():
|
||||
pass
|
||||
@ -3915,6 +3927,58 @@ class OverloadTests(BaseTestCase):
|
||||
|
||||
blah()
|
||||
|
||||
def set_up_overloads(self):
|
||||
def blah():
|
||||
pass
|
||||
|
||||
overload1 = blah
|
||||
overload(blah)
|
||||
|
||||
def blah():
|
||||
pass
|
||||
|
||||
overload2 = blah
|
||||
overload(blah)
|
||||
|
||||
def blah():
|
||||
pass
|
||||
|
||||
return blah, [overload1, overload2]
|
||||
|
||||
# Make sure we don't clear the global overload registry
|
||||
@patch("typing._overload_registry",
|
||||
defaultdict(lambda: defaultdict(dict)))
|
||||
def test_overload_registry(self):
|
||||
# The registry starts out empty
|
||||
self.assertEqual(typing._overload_registry, {})
|
||||
|
||||
impl, overloads = self.set_up_overloads()
|
||||
self.assertNotEqual(typing._overload_registry, {})
|
||||
self.assertEqual(list(get_overloads(impl)), overloads)
|
||||
|
||||
def some_other_func(): pass
|
||||
overload(some_other_func)
|
||||
other_overload = some_other_func
|
||||
def some_other_func(): pass
|
||||
self.assertEqual(list(get_overloads(some_other_func)), [other_overload])
|
||||
|
||||
# Make sure that after we clear all overloads, the registry is
|
||||
# completely empty.
|
||||
clear_overloads()
|
||||
self.assertEqual(typing._overload_registry, {})
|
||||
self.assertEqual(get_overloads(impl), [])
|
||||
|
||||
# Querying a function with no overloads shouldn't change the registry.
|
||||
def the_only_one(): pass
|
||||
self.assertEqual(get_overloads(the_only_one), [])
|
||||
self.assertEqual(typing._overload_registry, {})
|
||||
|
||||
def test_overload_registry_repeated(self):
|
||||
for _ in range(2):
|
||||
impl, overloads = self.set_up_overloads()
|
||||
|
||||
self.assertEqual(list(get_overloads(impl)), overloads)
|
||||
|
||||
|
||||
# Definitions needed for features introduced in Python 3.6
|
||||
|
||||
|
@ -21,6 +21,7 @@ At large scale, the structure of the module is following:
|
||||
|
||||
from abc import abstractmethod, ABCMeta
|
||||
import collections
|
||||
from collections import defaultdict
|
||||
import collections.abc
|
||||
import contextlib
|
||||
import functools
|
||||
@ -121,9 +122,11 @@ __all__ = [
|
||||
'assert_type',
|
||||
'assert_never',
|
||||
'cast',
|
||||
'clear_overloads',
|
||||
'final',
|
||||
'get_args',
|
||||
'get_origin',
|
||||
'get_overloads',
|
||||
'get_type_hints',
|
||||
'is_typeddict',
|
||||
'LiteralString',
|
||||
@ -2450,6 +2453,10 @@ def _overload_dummy(*args, **kwds):
|
||||
"by an implementation that is not @overload-ed.")
|
||||
|
||||
|
||||
# {module: {qualname: {firstlineno: func}}}
|
||||
_overload_registry = defaultdict(functools.partial(defaultdict, dict))
|
||||
|
||||
|
||||
def overload(func):
|
||||
"""Decorator for overloaded functions/methods.
|
||||
|
||||
@ -2475,10 +2482,37 @@ def overload(func):
|
||||
def utf8(value: str) -> bytes: ...
|
||||
def utf8(value):
|
||||
# implementation goes here
|
||||
|
||||
The overloads for a function can be retrieved at runtime using the
|
||||
get_overloads() function.
|
||||
"""
|
||||
# classmethod and staticmethod
|
||||
f = getattr(func, "__func__", func)
|
||||
try:
|
||||
_overload_registry[f.__module__][f.__qualname__][f.__code__.co_firstlineno] = func
|
||||
except AttributeError:
|
||||
# Not a normal function; ignore.
|
||||
pass
|
||||
return _overload_dummy
|
||||
|
||||
|
||||
def get_overloads(func):
|
||||
"""Return all defined overloads for *func* as a sequence."""
|
||||
# classmethod and staticmethod
|
||||
f = getattr(func, "__func__", func)
|
||||
if f.__module__ not in _overload_registry:
|
||||
return []
|
||||
mod_dict = _overload_registry[f.__module__]
|
||||
if f.__qualname__ not in mod_dict:
|
||||
return []
|
||||
return list(mod_dict[f.__qualname__].values())
|
||||
|
||||
|
||||
def clear_overloads():
|
||||
"""Clear all overloads in the registry."""
|
||||
_overload_registry.clear()
|
||||
|
||||
|
||||
def final(f):
|
||||
"""A decorator to indicate final methods and final classes.
|
||||
|
||||
|
@ -0,0 +1,2 @@
|
||||
Add :func:`typing.get_overloads` and :func:`typing.clear_overloads`.
|
||||
Patch by Jelle Zijlstra.
|
Loading…
x
Reference in New Issue
Block a user