gh-106905: Use separate structs to track recursion depth in each PyAST_mod2obj call. (GH-113035)

Co-authored-by: Gregory P. Smith [Google LLC] <greg@krypto.org>
This commit is contained in:
Yilei Yang 2023-12-25 09:36:59 -08:00 committed by GitHub
parent 3f5eb3e6c7
commit 48c49739f5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 412 additions and 339 deletions

View File

@ -16,8 +16,6 @@ extern "C" {
struct ast_state { struct ast_state {
_PyOnceFlag once; _PyOnceFlag once;
int finalized; int finalized;
int recursion_depth;
int recursion_limit;
PyObject *AST_type; PyObject *AST_type;
PyObject *Add_singleton; PyObject *Add_singleton;
PyObject *Add_type; PyObject *Add_type;

View File

@ -0,0 +1,7 @@
Use per AST-parser state rather than global state to track recursion depth
within the AST parser to prevent potential race condition due to
simultaneous parsing.
The issue primarily showed up in 3.11 by multithreaded users of
:func:`ast.parse`. In 3.12 a change to when garbage collection can be
triggered prevented the race condition from occurring.

View File

@ -731,7 +731,7 @@ class SequenceConstructorVisitor(EmitVisitor):
class PyTypesDeclareVisitor(PickleVisitor): class PyTypesDeclareVisitor(PickleVisitor):
def visitProduct(self, prod, name): def visitProduct(self, prod, name):
self.emit("static PyObject* ast2obj_%s(struct ast_state *state, void*);" % name, 0) self.emit("static PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, void*);" % name, 0)
if prod.attributes: if prod.attributes:
self.emit("static const char * const %s_attributes[] = {" % name, 0) self.emit("static const char * const %s_attributes[] = {" % name, 0)
for a in prod.attributes: for a in prod.attributes:
@ -752,7 +752,7 @@ class PyTypesDeclareVisitor(PickleVisitor):
ptype = "void*" ptype = "void*"
if is_simple(sum): if is_simple(sum):
ptype = get_c_type(name) ptype = get_c_type(name)
self.emit("static PyObject* ast2obj_%s(struct ast_state *state, %s);" % (name, ptype), 0) self.emit("static PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, %s);" % (name, ptype), 0)
for t in sum.types: for t in sum.types:
self.visitConstructor(t, name) self.visitConstructor(t, name)
@ -984,7 +984,8 @@ add_attributes(struct ast_state *state, PyObject *type, const char * const *attr
/* Conversion AST -> Python */ /* Conversion AST -> Python */
static PyObject* ast2obj_list(struct ast_state *state, asdl_seq *seq, PyObject* (*func)(struct ast_state *state, void*)) static PyObject* ast2obj_list(struct ast_state *state, struct validator *vstate, asdl_seq *seq,
PyObject* (*func)(struct ast_state *state, struct validator *vstate, void*))
{ {
Py_ssize_t i, n = asdl_seq_LEN(seq); Py_ssize_t i, n = asdl_seq_LEN(seq);
PyObject *result = PyList_New(n); PyObject *result = PyList_New(n);
@ -992,7 +993,7 @@ static PyObject* ast2obj_list(struct ast_state *state, asdl_seq *seq, PyObject*
if (!result) if (!result)
return NULL; return NULL;
for (i = 0; i < n; i++) { for (i = 0; i < n; i++) {
value = func(state, asdl_seq_GET_UNTYPED(seq, i)); value = func(state, vstate, asdl_seq_GET_UNTYPED(seq, i));
if (!value) { if (!value) {
Py_DECREF(result); Py_DECREF(result);
return NULL; return NULL;
@ -1002,7 +1003,7 @@ static PyObject* ast2obj_list(struct ast_state *state, asdl_seq *seq, PyObject*
return result; return result;
} }
static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), void *o) static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), struct validator *Py_UNUSED(vstate), void *o)
{ {
PyObject *op = (PyObject*)o; PyObject *op = (PyObject*)o;
if (!op) { if (!op) {
@ -1014,7 +1015,7 @@ static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), void *o)
#define ast2obj_identifier ast2obj_object #define ast2obj_identifier ast2obj_object
#define ast2obj_string ast2obj_object #define ast2obj_string ast2obj_object
static PyObject* ast2obj_int(struct ast_state *Py_UNUSED(state), long b) static PyObject* ast2obj_int(struct ast_state *Py_UNUSED(state), struct validator *Py_UNUSED(vstate), long b)
{ {
return PyLong_FromLong(b); return PyLong_FromLong(b);
} }
@ -1116,8 +1117,6 @@ static int add_ast_fields(struct ast_state *state)
for dfn in mod.dfns: for dfn in mod.dfns:
self.visit(dfn) self.visit(dfn)
self.file.write(textwrap.dedent(''' self.file.write(textwrap.dedent('''
state->recursion_depth = 0;
state->recursion_limit = 0;
return 0; return 0;
} }
''')) '''))
@ -1260,7 +1259,7 @@ class ObjVisitor(PickleVisitor):
def func_begin(self, name): def func_begin(self, name):
ctype = get_c_type(name) ctype = get_c_type(name)
self.emit("PyObject*", 0) self.emit("PyObject*", 0)
self.emit("ast2obj_%s(struct ast_state *state, void* _o)" % (name), 0) self.emit("ast2obj_%s(struct ast_state *state, struct validator *vstate, void* _o)" % (name), 0)
self.emit("{", 0) self.emit("{", 0)
self.emit("%s o = (%s)_o;" % (ctype, ctype), 1) self.emit("%s o = (%s)_o;" % (ctype, ctype), 1)
self.emit("PyObject *result = NULL, *value = NULL;", 1) self.emit("PyObject *result = NULL, *value = NULL;", 1)
@ -1268,17 +1267,17 @@ class ObjVisitor(PickleVisitor):
self.emit('if (!o) {', 1) self.emit('if (!o) {', 1)
self.emit("Py_RETURN_NONE;", 2) self.emit("Py_RETURN_NONE;", 2)
self.emit("}", 1) self.emit("}", 1)
self.emit("if (++state->recursion_depth > state->recursion_limit) {", 1) self.emit("if (++vstate->recursion_depth > vstate->recursion_limit) {", 1)
self.emit("PyErr_SetString(PyExc_RecursionError,", 2) self.emit("PyErr_SetString(PyExc_RecursionError,", 2)
self.emit('"maximum recursion depth exceeded during ast construction");', 3) self.emit('"maximum recursion depth exceeded during ast construction");', 3)
self.emit("return NULL;", 2) self.emit("return NULL;", 2)
self.emit("}", 1) self.emit("}", 1)
def func_end(self): def func_end(self):
self.emit("state->recursion_depth--;", 1) self.emit("vstate->recursion_depth--;", 1)
self.emit("return result;", 1) self.emit("return result;", 1)
self.emit("failed:", 0) self.emit("failed:", 0)
self.emit("state->recursion_depth--;", 1) self.emit("vstate->recursion_depth--;", 1)
self.emit("Py_XDECREF(value);", 1) self.emit("Py_XDECREF(value);", 1)
self.emit("Py_XDECREF(result);", 1) self.emit("Py_XDECREF(result);", 1)
self.emit("return NULL;", 1) self.emit("return NULL;", 1)
@ -1296,7 +1295,7 @@ class ObjVisitor(PickleVisitor):
self.visitConstructor(t, i + 1, name) self.visitConstructor(t, i + 1, name)
self.emit("}", 1) self.emit("}", 1)
for a in sum.attributes: for a in sum.attributes:
self.emit("value = ast2obj_%s(state, o->%s);" % (a.type, a.name), 1) self.emit("value = ast2obj_%s(state, vstate, o->%s);" % (a.type, a.name), 1)
self.emit("if (!value) goto failed;", 1) self.emit("if (!value) goto failed;", 1)
self.emit('if (PyObject_SetAttr(result, state->%s, value) < 0)' % a.name, 1) self.emit('if (PyObject_SetAttr(result, state->%s, value) < 0)' % a.name, 1)
self.emit('goto failed;', 2) self.emit('goto failed;', 2)
@ -1304,7 +1303,7 @@ class ObjVisitor(PickleVisitor):
self.func_end() self.func_end()
def simpleSum(self, sum, name): def simpleSum(self, sum, name):
self.emit("PyObject* ast2obj_%s(struct ast_state *state, %s_ty o)" % (name, name), 0) self.emit("PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, %s_ty o)" % (name, name), 0)
self.emit("{", 0) self.emit("{", 0)
self.emit("switch(o) {", 1) self.emit("switch(o) {", 1)
for t in sum.types: for t in sum.types:
@ -1322,7 +1321,7 @@ class ObjVisitor(PickleVisitor):
for field in prod.fields: for field in prod.fields:
self.visitField(field, name, 1, True) self.visitField(field, name, 1, True)
for a in prod.attributes: for a in prod.attributes:
self.emit("value = ast2obj_%s(state, o->%s);" % (a.type, a.name), 1) self.emit("value = ast2obj_%s(state, vstate, o->%s);" % (a.type, a.name), 1)
self.emit("if (!value) goto failed;", 1) self.emit("if (!value) goto failed;", 1)
self.emit("if (PyObject_SetAttr(result, state->%s, value) < 0)" % a.name, 1) self.emit("if (PyObject_SetAttr(result, state->%s, value) < 0)" % a.name, 1)
self.emit('goto failed;', 2) self.emit('goto failed;', 2)
@ -1363,7 +1362,7 @@ class ObjVisitor(PickleVisitor):
self.emit("for(i = 0; i < n; i++)", depth+1) self.emit("for(i = 0; i < n; i++)", depth+1)
# This cannot fail, so no need for error handling # This cannot fail, so no need for error handling
self.emit( self.emit(
"PyList_SET_ITEM(value, i, ast2obj_{0}(state, ({0}_ty)asdl_seq_GET({1}, i)));".format( "PyList_SET_ITEM(value, i, ast2obj_{0}(state, vstate, ({0}_ty)asdl_seq_GET({1}, i)));".format(
field.type, field.type,
value value
), ),
@ -1372,9 +1371,9 @@ class ObjVisitor(PickleVisitor):
) )
self.emit("}", depth) self.emit("}", depth)
else: else:
self.emit("value = ast2obj_list(state, (asdl_seq*)%s, ast2obj_%s);" % (value, field.type), depth) self.emit("value = ast2obj_list(state, vstate, (asdl_seq*)%s, ast2obj_%s);" % (value, field.type), depth)
else: else:
self.emit("value = ast2obj_%s(state, %s);" % (field.type, value), depth, reflow=False) self.emit("value = ast2obj_%s(state, vstate, %s);" % (field.type, value), depth, reflow=False)
class PartingShots(StaticVisitor): class PartingShots(StaticVisitor):
@ -1394,18 +1393,19 @@ PyObject* PyAST_mod2obj(mod_ty t)
if (!tstate) { if (!tstate) {
return NULL; return NULL;
} }
state->recursion_limit = Py_C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE; struct validator vstate;
vstate.recursion_limit = Py_C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE;
int recursion_depth = Py_C_RECURSION_LIMIT - tstate->c_recursion_remaining; int recursion_depth = Py_C_RECURSION_LIMIT - tstate->c_recursion_remaining;
starting_recursion_depth = recursion_depth * COMPILER_STACK_FRAME_SCALE; starting_recursion_depth = recursion_depth * COMPILER_STACK_FRAME_SCALE;
state->recursion_depth = starting_recursion_depth; vstate.recursion_depth = starting_recursion_depth;
PyObject *result = ast2obj_mod(state, t); PyObject *result = ast2obj_mod(state, &vstate, t);
/* Check that the recursion depth counting balanced correctly */ /* Check that the recursion depth counting balanced correctly */
if (result && state->recursion_depth != starting_recursion_depth) { if (result && vstate.recursion_depth != starting_recursion_depth) {
PyErr_Format(PyExc_SystemError, PyErr_Format(PyExc_SystemError,
"AST constructor recursion depth mismatch (before=%d, after=%d)", "AST constructor recursion depth mismatch (before=%d, after=%d)",
starting_recursion_depth, state->recursion_depth); starting_recursion_depth, vstate.recursion_depth);
return NULL; return NULL;
} }
return result; return result;
@ -1475,8 +1475,6 @@ def generate_ast_state(module_state, f):
f.write('struct ast_state {\n') f.write('struct ast_state {\n')
f.write(' _PyOnceFlag once;\n') f.write(' _PyOnceFlag once;\n')
f.write(' int finalized;\n') f.write(' int finalized;\n')
f.write(' int recursion_depth;\n')
f.write(' int recursion_limit;\n')
for s in module_state: for s in module_state:
f.write(' PyObject *' + s + ';\n') f.write(' PyObject *' + s + ';\n')
f.write('};') f.write('};')
@ -1539,6 +1537,11 @@ def generate_module_def(mod, metadata, f, internal_h):
#include "pycore_pystate.h" // _PyInterpreterState_GET() #include "pycore_pystate.h" // _PyInterpreterState_GET()
#include <stddef.h> #include <stddef.h>
struct validator {
int recursion_depth; /* current recursion depth */
int recursion_limit; /* recursion limit */
};
// Forward declaration // Forward declaration
static int init_types(struct ast_state *state); static int init_types(struct ast_state *state);

689
Python/Python-ast.c generated

File diff suppressed because it is too large Load Diff