gh-76785: Add Interpreter.prepare_main() (gh-113021)

This is one of the last pieces to get test.support.interpreters in sync with PEP 734.
This commit is contained in:
Eric Snow 2023-12-12 11:06:06 -07:00 committed by GitHub
parent a49b427b02
commit 9898e61041
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 146 additions and 17 deletions

View File

@ -130,7 +130,15 @@ class Interpreter:
""" """
return _interpreters.destroy(self._id) return _interpreters.destroy(self._id)
def exec_sync(self, code, /, channels=None): def prepare_main(self, ns=None, /, **kwargs):
"""Bind the given values into the interpreter's __main__.
The values must be shareable.
"""
ns = dict(ns, **kwargs) if ns is not None else kwargs
_interpreters.set___main___attrs(self._id, ns)
def exec_sync(self, code, /):
"""Run the given source code in the interpreter. """Run the given source code in the interpreter.
This is essentially the same as calling the builtin "exec" This is essentially the same as calling the builtin "exec"
@ -148,13 +156,13 @@ class Interpreter:
that time, the previous interpreter is allowed to run that time, the previous interpreter is allowed to run
in other threads. in other threads.
""" """
excinfo = _interpreters.exec(self._id, code, channels) excinfo = _interpreters.exec(self._id, code)
if excinfo is not None: if excinfo is not None:
raise ExecFailure(excinfo) raise ExecFailure(excinfo)
def run(self, code, /, channels=None): def run(self, code, /):
def task(): def task():
self.exec_sync(code, channels=channels) self.exec_sync(code)
t = threading.Thread(target=task) t = threading.Thread(target=task)
t.start() t.start()
return t return t

View File

@ -586,12 +586,12 @@ class ChannelTests(TestBase):
cid = channels.create() cid = channels.create()
interp = interpreters.create() interp = interpreters.create()
interpreters.set___main___attrs(interp, dict(cid=cid.send))
out = _run_output(interp, dedent(""" out = _run_output(interp, dedent("""
import _xxinterpchannels as _channels import _xxinterpchannels as _channels
print(cid.end) print(cid.end)
_channels.send(cid, b'spam', blocking=False) _channels.send(cid, b'spam', blocking=False)
"""), """))
dict(cid=cid.send))
obj = channels.recv(cid) obj = channels.recv(cid)
self.assertEqual(obj, b'spam') self.assertEqual(obj, b'spam')

View File

@ -33,10 +33,10 @@ def _captured_script(script):
return wrapped, open(r, encoding="utf-8") return wrapped, open(r, encoding="utf-8")
def _run_output(interp, request, shared=None): def _run_output(interp, request):
script, rpipe = _captured_script(request) script, rpipe = _captured_script(request)
with rpipe: with rpipe:
interpreters.run_string(interp, script, shared) interpreters.run_string(interp, script)
return rpipe.read() return rpipe.read()
@ -630,10 +630,10 @@ class RunStringTests(TestBase):
] ]
for obj in objects: for obj in objects:
with self.subTest(obj): with self.subTest(obj):
interpreters.set___main___attrs(interp, dict(obj=obj))
interpreters.run_string( interpreters.run_string(
interp, interp,
f'assert(obj == {obj!r})', f'assert(obj == {obj!r})',
shared=dict(obj=obj),
) )
def test_os_exec(self): def test_os_exec(self):
@ -721,7 +721,8 @@ class RunStringTests(TestBase):
with open({w}, 'wb') as chan: with open({w}, 'wb') as chan:
pickle.dump(ns, chan) pickle.dump(ns, chan)
""") """)
interpreters.run_string(self.id, script, shared) interpreters.set___main___attrs(self.id, shared)
interpreters.run_string(self.id, script)
with open(r, 'rb') as chan: with open(r, 'rb') as chan:
ns = pickle.load(chan) ns = pickle.load(chan)
@ -742,7 +743,8 @@ class RunStringTests(TestBase):
ns2 = dict(vars()) ns2 = dict(vars())
del ns2['__builtins__'] del ns2['__builtins__']
""") """)
interpreters.run_string(self.id, script, shared) interpreters.set___main___attrs(self.id, shared)
interpreters.run_string(self.id, script)
r, w = os.pipe() r, w = os.pipe()
script = dedent(f""" script = dedent(f"""
@ -773,7 +775,8 @@ class RunStringTests(TestBase):
with open({w}, 'wb') as chan: with open({w}, 'wb') as chan:
pickle.dump(ns, chan) pickle.dump(ns, chan)
""") """)
interpreters.run_string(self.id, script, shared) interpreters.set___main___attrs(self.id, shared)
interpreters.run_string(self.id, script)
with open(r, 'rb') as chan: with open(r, 'rb') as chan:
ns = pickle.load(chan) ns = pickle.load(chan)
@ -1036,7 +1039,8 @@ class RunFuncTests(TestBase):
with open(w, 'w', encoding="utf-8") as spipe: with open(w, 'w', encoding="utf-8") as spipe:
with contextlib.redirect_stdout(spipe): with contextlib.redirect_stdout(spipe):
print('it worked!', end='') print('it worked!', end='')
interpreters.run_func(self.id, script, shared=dict(w=w)) interpreters.set___main___attrs(self.id, dict(w=w))
interpreters.run_func(self.id, script)
with open(r, encoding="utf-8") as outfile: with open(r, encoding="utf-8") as outfile:
out = outfile.read() out = outfile.read()
@ -1052,7 +1056,8 @@ class RunFuncTests(TestBase):
with contextlib.redirect_stdout(spipe): with contextlib.redirect_stdout(spipe):
print('it worked!', end='') print('it worked!', end='')
def f(): def f():
interpreters.run_func(self.id, script, shared=dict(w=w)) interpreters.set___main___attrs(self.id, dict(w=w))
interpreters.run_func(self.id, script)
t = threading.Thread(target=f) t = threading.Thread(target=f)
t.start() t.start()
t.join() t.join()
@ -1072,7 +1077,8 @@ class RunFuncTests(TestBase):
with contextlib.redirect_stdout(spipe): with contextlib.redirect_stdout(spipe):
print('it worked!', end='') print('it worked!', end='')
code = script.__code__ code = script.__code__
interpreters.run_func(self.id, code, shared=dict(w=w)) interpreters.set___main___attrs(self.id, dict(w=w))
interpreters.run_func(self.id, code)
with open(r, encoding="utf-8") as outfile: with open(r, encoding="utf-8") as outfile:
out = outfile.read() out = outfile.read()

View File

@ -452,6 +452,63 @@ class TestInterpreterClose(TestBase):
self.assertEqual(os.read(r_interp, 1), FINISHED) self.assertEqual(os.read(r_interp, 1), FINISHED)
class TestInterpreterPrepareMain(TestBase):
def test_empty(self):
interp = interpreters.create()
with self.assertRaises(ValueError):
interp.prepare_main()
def test_dict(self):
values = {'spam': 42, 'eggs': 'ham'}
interp = interpreters.create()
interp.prepare_main(values)
out = _run_output(interp, dedent("""
print(spam, eggs)
"""))
self.assertEqual(out.strip(), '42 ham')
def test_tuple(self):
values = {'spam': 42, 'eggs': 'ham'}
values = tuple(values.items())
interp = interpreters.create()
interp.prepare_main(values)
out = _run_output(interp, dedent("""
print(spam, eggs)
"""))
self.assertEqual(out.strip(), '42 ham')
def test_kwargs(self):
values = {'spam': 42, 'eggs': 'ham'}
interp = interpreters.create()
interp.prepare_main(**values)
out = _run_output(interp, dedent("""
print(spam, eggs)
"""))
self.assertEqual(out.strip(), '42 ham')
def test_dict_and_kwargs(self):
values = {'spam': 42, 'eggs': 'ham'}
interp = interpreters.create()
interp.prepare_main(values, foo='bar')
out = _run_output(interp, dedent("""
print(spam, eggs, foo)
"""))
self.assertEqual(out.strip(), '42 ham bar')
def test_not_shareable(self):
interp = interpreters.create()
# XXX TypeError?
with self.assertRaises(ValueError):
interp.prepare_main(spam={'spam': 'eggs', 'foo': 'bar'})
# Make sure neither was actually bound.
with self.assertRaises(interpreters.ExecFailure):
interp.exec_sync('print(foo)')
with self.assertRaises(interpreters.ExecFailure):
interp.exec_sync('print(spam)')
class TestInterpreterExecSync(TestBase): class TestInterpreterExecSync(TestBase):
def test_success(self): def test_success(self):

View File

@ -29,10 +29,12 @@ def clean_up_interpreters():
pass # already destroyed pass # already destroyed
def _run_output(interp, request, channels=None): def _run_output(interp, request, init=None):
script, rpipe = _captured_script(request) script, rpipe = _captured_script(request)
with rpipe: with rpipe:
interp.exec_sync(script, channels=channels) if init:
interp.prepare_main(init)
interp.exec_sync(script)
return rpipe.read() return rpipe.read()

View File

@ -685,6 +685,60 @@ PyDoc_STRVAR(get_main_doc,
\n\ \n\
Return the ID of main interpreter."); Return the ID of main interpreter.");
static PyObject *
interp_set___main___attrs(PyObject *self, PyObject *args)
{
PyObject *id, *updates;
if (!PyArg_ParseTuple(args, "OO:" MODULE_NAME ".set___main___attrs",
&id, &updates))
{
return NULL;
}
// Look up the interpreter.
PyInterpreterState *interp = PyInterpreterID_LookUp(id);
if (interp == NULL) {
return NULL;
}
// Check the updates.
if (updates != Py_None) {
Py_ssize_t size = PyObject_Size(updates);
if (size < 0) {
return NULL;
}
if (size == 0) {
PyErr_SetString(PyExc_ValueError,
"arg 2 must be a non-empty mapping");
return NULL;
}
}
_PyXI_session session = {0};
// Prep and switch interpreters, including apply the updates.
if (_PyXI_Enter(&session, interp, updates) < 0) {
if (!PyErr_Occurred()) {
_PyXI_ApplyCapturedException(&session);
assert(PyErr_Occurred());
}
else {
assert(!_PyXI_HasCapturedException(&session));
}
return NULL;
}
// Clean up and switch back.
_PyXI_Exit(&session);
Py_RETURN_NONE;
}
PyDoc_STRVAR(set___main___attrs_doc,
"set___main___attrs(id, ns)\n\
\n\
Bind the given attributes in the interpreter's __main__ module.");
static PyUnicodeObject * static PyUnicodeObject *
convert_script_arg(PyObject *arg, const char *fname, const char *displayname, convert_script_arg(PyObject *arg, const char *fname, const char *displayname,
const char *expected) const char *expected)
@ -1033,6 +1087,8 @@ static PyMethodDef module_functions[] = {
{"run_func", _PyCFunction_CAST(interp_run_func), {"run_func", _PyCFunction_CAST(interp_run_func),
METH_VARARGS | METH_KEYWORDS, run_func_doc}, METH_VARARGS | METH_KEYWORDS, run_func_doc},
{"set___main___attrs", _PyCFunction_CAST(interp_set___main___attrs),
METH_VARARGS, set___main___attrs_doc},
{"is_shareable", _PyCFunction_CAST(object_is_shareable), {"is_shareable", _PyCFunction_CAST(object_is_shareable),
METH_VARARGS | METH_KEYWORDS, is_shareable_doc}, METH_VARARGS | METH_KEYWORDS, is_shareable_doc},