// typing.Union -- used to represent e.g. Union[int, str], int | str #include "Python.h" #include "pycore_object.h" // _PyObject_GC_TRACK/UNTRACK #include "pycore_typevarobject.h" // _PyTypeAlias_Type, _Py_typing_type_repr #include "pycore_unicodeobject.h" // _PyUnicode_EqualToASCIIString #include "pycore_unionobject.h" typedef struct { PyObject_HEAD PyObject *args; // all args (tuple) PyObject *hashable_args; // frozenset or NULL PyObject *unhashable_args; // tuple or NULL PyObject *parameters; PyObject *weakreflist; } unionobject; static void unionobject_dealloc(PyObject *self) { unionobject *alias = (unionobject *)self; _PyObject_GC_UNTRACK(self); if (alias->weakreflist != NULL) { PyObject_ClearWeakRefs((PyObject *)alias); } Py_XDECREF(alias->args); Py_XDECREF(alias->hashable_args); Py_XDECREF(alias->unhashable_args); Py_XDECREF(alias->parameters); Py_TYPE(self)->tp_free(self); } static int union_traverse(PyObject *self, visitproc visit, void *arg) { unionobject *alias = (unionobject *)self; Py_VISIT(alias->args); Py_VISIT(alias->hashable_args); Py_VISIT(alias->unhashable_args); Py_VISIT(alias->parameters); return 0; } static Py_hash_t union_hash(PyObject *self) { unionobject *alias = (unionobject *)self; // If there are any unhashable args, treat this union as unhashable. // Otherwise, two unions might compare equal but have different hashes. if (alias->unhashable_args) { // Attempt to get an error from one of the values. assert(PyTuple_CheckExact(alias->unhashable_args)); Py_ssize_t n = PyTuple_GET_SIZE(alias->unhashable_args); for (Py_ssize_t i = 0; i < n; i++) { PyObject *arg = PyTuple_GET_ITEM(alias->unhashable_args, i); Py_hash_t hash = PyObject_Hash(arg); if (hash == -1) { return -1; } } // The unhashable values somehow became hashable again. Still raise // an error. PyErr_Format(PyExc_TypeError, "union contains %d unhashable elements", n); return -1; } return PyObject_Hash(alias->hashable_args); } static int unions_equal(unionobject *a, unionobject *b) { int result = PyObject_RichCompareBool(a->hashable_args, b->hashable_args, Py_EQ); if (result == -1) { return -1; } if (result == 0) { return 0; } if (a->unhashable_args && b->unhashable_args) { Py_ssize_t n = PyTuple_GET_SIZE(a->unhashable_args); if (n != PyTuple_GET_SIZE(b->unhashable_args)) { return 0; } for (Py_ssize_t i = 0; i < n; i++) { PyObject *arg_a = PyTuple_GET_ITEM(a->unhashable_args, i); int result = PySequence_Contains(b->unhashable_args, arg_a); if (result == -1) { return -1; } if (!result) { return 0; } } for (Py_ssize_t i = 0; i < n; i++) { PyObject *arg_b = PyTuple_GET_ITEM(b->unhashable_args, i); int result = PySequence_Contains(a->unhashable_args, arg_b); if (result == -1) { return -1; } if (!result) { return 0; } } } else if (a->unhashable_args || b->unhashable_args) { return 0; } return 1; } static PyObject * union_richcompare(PyObject *a, PyObject *b, int op) { if (!_PyUnion_Check(b) || (op != Py_EQ && op != Py_NE)) { Py_RETURN_NOTIMPLEMENTED; } int equal = unions_equal((unionobject*)a, (unionobject*)b); if (equal == -1) { return NULL; } if (op == Py_EQ) { return PyBool_FromLong(equal); } else { return PyBool_FromLong(!equal); } } typedef struct { PyObject *args; // list PyObject *hashable_args; // set PyObject *unhashable_args; // list or NULL bool is_checked; // whether to call type_check() } unionbuilder; static bool unionbuilder_add_tuple(unionbuilder *, PyObject *); static PyObject *make_union(unionbuilder *); static PyObject *type_check(PyObject *, const char *); static bool unionbuilder_init(unionbuilder *ub, bool is_checked) { ub->args = PyList_New(0); if (ub->args == NULL) { return false; } ub->hashable_args = PySet_New(NULL); if (ub->hashable_args == NULL) { Py_DECREF(ub->args); return false; } ub->unhashable_args = NULL; ub->is_checked = is_checked; return true; } static void unionbuilder_finalize(unionbuilder *ub) { Py_DECREF(ub->args); Py_DECREF(ub->hashable_args); Py_XDECREF(ub->unhashable_args); } static bool unionbuilder_add_single_unchecked(unionbuilder *ub, PyObject *arg) { Py_hash_t hash = PyObject_Hash(arg); if (hash == -1) { PyErr_Clear(); if (ub->unhashable_args == NULL) { ub->unhashable_args = PyList_New(0); if (ub->unhashable_args == NULL) { return false; } } else { int contains = PySequence_Contains(ub->unhashable_args, arg); if (contains < 0) { return false; } if (contains == 1) { return true; } } if (PyList_Append(ub->unhashable_args, arg) < 0) { return false; } } else { int contains = PySet_Contains(ub->hashable_args, arg); if (contains < 0) { return false; } if (contains == 1) { return true; } if (PySet_Add(ub->hashable_args, arg) < 0) { return false; } } return PyList_Append(ub->args, arg) == 0; } static bool unionbuilder_add_single(unionbuilder *ub, PyObject *arg) { if (Py_IsNone(arg)) { arg = (PyObject *)&_PyNone_Type; // immortal, so no refcounting needed } else if (_PyUnion_Check(arg)) { PyObject *args = ((unionobject *)arg)->args; return unionbuilder_add_tuple(ub, args); } if (ub->is_checked) { PyObject *type = type_check(arg, "Union[arg, ...]: each arg must be a type."); if (type == NULL) { return false; } bool result = unionbuilder_add_single_unchecked(ub, type); Py_DECREF(type); return result; } else { return unionbuilder_add_single_unchecked(ub, arg); } } static bool unionbuilder_add_tuple(unionbuilder *ub, PyObject *tuple) { Py_ssize_t n = PyTuple_GET_SIZE(tuple); for (Py_ssize_t i = 0; i < n; i++) { if (!unionbuilder_add_single(ub, PyTuple_GET_ITEM(tuple, i))) { return false; } } return true; } static int is_unionable(PyObject *obj) { if (obj == Py_None || PyType_Check(obj) || _PyGenericAlias_Check(obj) || _PyUnion_Check(obj) || Py_IS_TYPE(obj, &_PyTypeAlias_Type)) { return 1; } return 0; } PyObject * _Py_union_type_or(PyObject* self, PyObject* other) { if (!is_unionable(self) || !is_unionable(other)) { Py_RETURN_NOTIMPLEMENTED; } unionbuilder ub; // unchecked because we already checked is_unionable() if (!unionbuilder_init(&ub, false)) { return NULL; } if (!unionbuilder_add_single(&ub, self) || !unionbuilder_add_single(&ub, other)) { unionbuilder_finalize(&ub); return NULL; } PyObject *new_union = make_union(&ub); return new_union; } static PyObject * union_repr(PyObject *self) { unionobject *alias = (unionobject *)self; Py_ssize_t len = PyTuple_GET_SIZE(alias->args); // Shortest type name "int" (3 chars) + " | " (3 chars) separator Py_ssize_t estimate = (len <= PY_SSIZE_T_MAX / 6) ? len * 6 : len; PyUnicodeWriter *writer = PyUnicodeWriter_Create(estimate); if (writer == NULL) { return NULL; } for (Py_ssize_t i = 0; i < len; i++) { if (i > 0 && PyUnicodeWriter_WriteASCII(writer, " | ", 3) < 0) { goto error; } PyObject *p = PyTuple_GET_ITEM(alias->args, i); if (_Py_typing_type_repr(writer, p) < 0) { goto error; } } #if 0 PyUnicodeWriter_WriteASCII(writer, "|args=", 6); PyUnicodeWriter_WriteRepr(writer, alias->args); PyUnicodeWriter_WriteASCII(writer, "|h=", 3); PyUnicodeWriter_WriteRepr(writer, alias->hashable_args); if (alias->unhashable_args) { PyUnicodeWriter_WriteASCII(writer, "|u=", 3); PyUnicodeWriter_WriteRepr(writer, alias->unhashable_args); } #endif return PyUnicodeWriter_Finish(writer); error: PyUnicodeWriter_Discard(writer); return NULL; } static PyMemberDef union_members[] = { {"__args__", _Py_T_OBJECT, offsetof(unionobject, args), Py_READONLY}, {0} }; // Populate __parameters__ if needed. static int union_init_parameters(unionobject *alias) { int result = 0; Py_BEGIN_CRITICAL_SECTION(alias); if (alias->parameters == NULL) { alias->parameters = _Py_make_parameters(alias->args); if (alias->parameters == NULL) { result = -1; } } Py_END_CRITICAL_SECTION(); return result; } static PyObject * union_getitem(PyObject *self, PyObject *item) { unionobject *alias = (unionobject *)self; if (union_init_parameters(alias) < 0) { return NULL; } PyObject *newargs = _Py_subs_parameters(self, alias->args, alias->parameters, item); if (newargs == NULL) { return NULL; } PyObject *res = _Py_union_from_tuple(newargs); Py_DECREF(newargs); return res; } static PyMappingMethods union_as_mapping = { .mp_subscript = union_getitem, }; static PyObject * union_parameters(PyObject *self, void *Py_UNUSED(unused)) { unionobject *alias = (unionobject *)self; if (union_init_parameters(alias) < 0) { return NULL; } return Py_NewRef(alias->parameters); } static PyObject * union_name(PyObject *Py_UNUSED(self), void *Py_UNUSED(ignored)) { return PyUnicode_FromString("Union"); } static PyObject * union_origin(PyObject *Py_UNUSED(self), void *Py_UNUSED(ignored)) { return Py_NewRef(&_PyUnion_Type); } static PyGetSetDef union_properties[] = { {"__name__", union_name, NULL, PyDoc_STR("Name of the type"), NULL}, {"__qualname__", union_name, NULL, PyDoc_STR("Qualified name of the type"), NULL}, {"__origin__", union_origin, NULL, PyDoc_STR("Always returns the type"), NULL}, {"__parameters__", union_parameters, NULL, PyDoc_STR("Type variables in the types.UnionType."), NULL}, {0} }; static PyNumberMethods union_as_number = { .nb_or = _Py_union_type_or, // Add __or__ function }; static const char* const cls_attrs[] = { "__module__", // Required for compatibility with typing module NULL, }; static PyObject * union_getattro(PyObject *self, PyObject *name) { unionobject *alias = (unionobject *)self; if (PyUnicode_Check(name)) { for (const char * const *p = cls_attrs; ; p++) { if (*p == NULL) { break; } if (_PyUnicode_EqualToASCIIString(name, *p)) { return PyObject_GetAttr((PyObject *) Py_TYPE(alias), name); } } } return PyObject_GenericGetAttr(self, name); } PyObject * _Py_union_args(PyObject *self) { assert(_PyUnion_Check(self)); return ((unionobject *) self)->args; } static PyObject * call_typing_func_object(const char *name, PyObject **args, size_t nargs) { PyObject *typing = PyImport_ImportModule("typing"); if (typing == NULL) { return NULL; } PyObject *func = PyObject_GetAttrString(typing, name); if (func == NULL) { Py_DECREF(typing); return NULL; } PyObject *result = PyObject_Vectorcall(func, args, nargs, NULL); Py_DECREF(func); Py_DECREF(typing); return result; } static PyObject * type_check(PyObject *arg, const char *msg) { if (Py_IsNone(arg)) { // NoneType is immortal, so don't need an INCREF return (PyObject *)Py_TYPE(arg); } // Fast path to avoid calling into typing.py if (is_unionable(arg)) { return Py_NewRef(arg); } PyObject *message_str = PyUnicode_FromString(msg); if (message_str == NULL) { return NULL; } PyObject *args[2] = {arg, message_str}; PyObject *result = call_typing_func_object("_type_check", args, 2); Py_DECREF(message_str); return result; } PyObject * _Py_union_from_tuple(PyObject *args) { unionbuilder ub; if (!unionbuilder_init(&ub, true)) { return NULL; } if (PyTuple_CheckExact(args)) { if (!unionbuilder_add_tuple(&ub, args)) { return NULL; } } else { if (!unionbuilder_add_single(&ub, args)) { return NULL; } } return make_union(&ub); } static PyObject * union_class_getitem(PyObject *cls, PyObject *args) { return _Py_union_from_tuple(args); } static PyObject * union_mro_entries(PyObject *self, PyObject *args) { return PyErr_Format(PyExc_TypeError, "Cannot subclass %R", self); } static PyMethodDef union_methods[] = { {"__mro_entries__", union_mro_entries, METH_O}, {"__class_getitem__", union_class_getitem, METH_O|METH_CLASS, PyDoc_STR("See PEP 585")}, {0} }; PyTypeObject _PyUnion_Type = { PyVarObject_HEAD_INIT(&PyType_Type, 0) .tp_name = "typing.Union", .tp_doc = PyDoc_STR("Represent a union type\n" "\n" "E.g. for int | str"), .tp_basicsize = sizeof(unionobject), .tp_dealloc = unionobject_dealloc, .tp_alloc = PyType_GenericAlloc, .tp_free = PyObject_GC_Del, .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, .tp_traverse = union_traverse, .tp_hash = union_hash, .tp_getattro = union_getattro, .tp_members = union_members, .tp_methods = union_methods, .tp_richcompare = union_richcompare, .tp_as_mapping = &union_as_mapping, .tp_as_number = &union_as_number, .tp_repr = union_repr, .tp_getset = union_properties, .tp_weaklistoffset = offsetof(unionobject, weakreflist), }; static PyObject * make_union(unionbuilder *ub) { Py_ssize_t n = PyList_GET_SIZE(ub->args); if (n == 0) { PyErr_SetString(PyExc_TypeError, "Cannot take a Union of no types."); unionbuilder_finalize(ub); return NULL; } if (n == 1) { PyObject *result = PyList_GET_ITEM(ub->args, 0); Py_INCREF(result); unionbuilder_finalize(ub); return result; } PyObject *args = NULL, *hashable_args = NULL, *unhashable_args = NULL; args = PyList_AsTuple(ub->args); if (args == NULL) { goto error; } hashable_args = PyFrozenSet_New(ub->hashable_args); if (hashable_args == NULL) { goto error; } if (ub->unhashable_args != NULL) { unhashable_args = PyList_AsTuple(ub->unhashable_args); if (unhashable_args == NULL) { goto error; } } unionobject *result = PyObject_GC_New(unionobject, &_PyUnion_Type); if (result == NULL) { goto error; } unionbuilder_finalize(ub); result->parameters = NULL; result->args = args; result->hashable_args = hashable_args; result->unhashable_args = unhashable_args; result->weakreflist = NULL; _PyObject_GC_TRACK(result); return (PyObject*)result; error: Py_XDECREF(args); Py_XDECREF(hashable_args); Py_XDECREF(unhashable_args); unionbuilder_finalize(ub); return NULL; }