bpo-32896: Fix error when subclassing a dataclass with a field that uses a default_factory (GH-6170)
Fix the way that new annotations in a class are detected.
This commit is contained in:
parent
10b134a07c
commit
8f6eccdc64
@ -574,17 +574,18 @@ def _get_field(cls, a_name, a_type):
|
|||||||
|
|
||||||
def _find_fields(cls):
|
def _find_fields(cls):
|
||||||
# Return a list of Field objects, in order, for this class (and no
|
# Return a list of Field objects, in order, for this class (and no
|
||||||
# base classes). Fields are found from __annotations__ (which is
|
# base classes). Fields are found from the class dict's
|
||||||
# guaranteed to be ordered). Default values are from class
|
# __annotations__ (which is guaranteed to be ordered). Default
|
||||||
# attributes, if a field has a default. If the default value is
|
# values are from class attributes, if a field has a default. If
|
||||||
# a Field(), then it contains additional info beyond (and
|
# the default value is a Field(), then it contains additional
|
||||||
# possibly including) the actual default value. Pseudo-fields
|
# info beyond (and possibly including) the actual default value.
|
||||||
# ClassVars and InitVars are included, despite the fact that
|
# Pseudo-fields ClassVars and InitVars are included, despite the
|
||||||
# they're not real fields. That's dealt with later.
|
# fact that they're not real fields. That's dealt with later.
|
||||||
|
|
||||||
annotations = getattr(cls, '__annotations__', {})
|
# If __annotations__ isn't present, then this class adds no new
|
||||||
return [_get_field(cls, a_name, a_type)
|
# annotations.
|
||||||
for a_name, a_type in annotations.items()]
|
annotations = cls.__dict__.get('__annotations__', {})
|
||||||
|
return [_get_field(cls, name, type) for name, type in annotations.items()]
|
||||||
|
|
||||||
|
|
||||||
def _set_new_attribute(cls, name, value):
|
def _set_new_attribute(cls, name, value):
|
||||||
|
@ -1147,6 +1147,55 @@ class TestCase(unittest.TestCase):
|
|||||||
C().x
|
C().x
|
||||||
self.assertEqual(factory.call_count, 2)
|
self.assertEqual(factory.call_count, 2)
|
||||||
|
|
||||||
|
def test_default_factory_derived(self):
|
||||||
|
# See bpo-32896.
|
||||||
|
@dataclass
|
||||||
|
class Foo:
|
||||||
|
x: dict = field(default_factory=dict)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Bar(Foo):
|
||||||
|
y: int = 1
|
||||||
|
|
||||||
|
self.assertEqual(Foo().x, {})
|
||||||
|
self.assertEqual(Bar().x, {})
|
||||||
|
self.assertEqual(Bar().y, 1)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Baz(Foo):
|
||||||
|
pass
|
||||||
|
self.assertEqual(Baz().x, {})
|
||||||
|
|
||||||
|
def test_intermediate_non_dataclass(self):
|
||||||
|
# Test that an intermediate class that defines
|
||||||
|
# annotations does not define fields.
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class A:
|
||||||
|
x: int
|
||||||
|
|
||||||
|
class B(A):
|
||||||
|
y: int
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class C(B):
|
||||||
|
z: int
|
||||||
|
|
||||||
|
c = C(1, 3)
|
||||||
|
self.assertEqual((c.x, c.z), (1, 3))
|
||||||
|
|
||||||
|
# .y was not initialized.
|
||||||
|
with self.assertRaisesRegex(AttributeError,
|
||||||
|
'object has no attribute'):
|
||||||
|
c.y
|
||||||
|
|
||||||
|
# And if we again derive a non-dataclass, no fields are added.
|
||||||
|
class D(C):
|
||||||
|
t: int
|
||||||
|
d = D(4, 5)
|
||||||
|
self.assertEqual((d.x, d.z), (4, 5))
|
||||||
|
|
||||||
|
|
||||||
def x_test_classvar_default_factory(self):
|
def x_test_classvar_default_factory(self):
|
||||||
# XXX: it's an error for a ClassVar to have a factory function
|
# XXX: it's an error for a ClassVar to have a factory function
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -0,0 +1,2 @@
|
|||||||
|
Fix an error where subclassing a dataclass with a field that uses a
|
||||||
|
default_factory would generate an incorrect class.
|
Loading…
x
Reference in New Issue
Block a user