gh-79097: Add support for aggregate window functions in sqlite3 (GH-20903)

This commit is contained in:
Erlend Egeberg Aasland 2022-04-12 02:55:59 +02:00 committed by GitHub
parent f45aa8f304
commit 9ebcece82f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 477 additions and 13 deletions

View File

@ -0,0 +1,46 @@
# Example taken from https://www.sqlite.org/windowfunctions.html#udfwinfunc
import sqlite3
class WindowSumInt:
def __init__(self):
self.count = 0
def step(self, value):
"""Adds a row to the current window."""
self.count += value
def value(self):
"""Returns the current value of the aggregate."""
return self.count
def inverse(self, value):
"""Removes a row from the current window."""
self.count -= value
def finalize(self):
"""Returns the final value of the aggregate.
Any clean-up actions should be placed here.
"""
return self.count
con = sqlite3.connect(":memory:")
cur = con.execute("create table test(x, y)")
values = [
("a", 4),
("b", 5),
("c", 3),
("d", 8),
("e", 1),
]
cur.executemany("insert into test values(?, ?)", values)
con.create_window_function("sumint", 1, WindowSumInt)
cur.execute("""
select x, sumint(y) over (
order by x rows between 1 preceding and 1 following
) as sum_y
from test order by x
""")
print(cur.fetchall())

View File

@ -473,6 +473,35 @@ Connection Objects
.. literalinclude:: ../includes/sqlite3/mysumaggr.py .. literalinclude:: ../includes/sqlite3/mysumaggr.py
.. method:: create_window_function(name, num_params, aggregate_class, /)
Creates user-defined aggregate window function *name*.
*aggregate_class* must implement the following methods:
* ``step``: adds a row to the current window
* ``value``: returns the current value of the aggregate
* ``inverse``: removes a row from the current window
* ``finalize``: returns the final value of the aggregate
``step`` and ``value`` accept *num_params* number of parameters,
unless *num_params* is ``-1``, in which case they may take any number of
arguments. ``finalize`` and ``value`` can return any of the types
supported by SQLite:
:class:`bytes`, :class:`str`, :class:`int`, :class:`float`, and
:const:`None`. Call :meth:`create_window_function` with
*aggregate_class* set to :const:`None` to clear window function *name*.
Aggregate window functions are supported by SQLite 3.25.0 and higher.
:exc:`NotSupportedError` will be raised if used with older versions.
.. versionadded:: 3.11
Example:
.. literalinclude:: ../includes/sqlite3/sumintwindow.py
.. method:: create_collation(name, callable) .. method:: create_collation(name, callable)
Creates a collation with the specified *name* and *callable*. The callable will Creates a collation with the specified *name* and *callable*. The callable will

View File

@ -389,6 +389,10 @@ sqlite3
serializing and deserializing databases. serializing and deserializing databases.
(Contributed by Erlend E. Aasland in :issue:`41930`.) (Contributed by Erlend E. Aasland in :issue:`41930`.)
* Add :meth:`~sqlite3.Connection.create_window_function` to
:class:`sqlite3.Connection` for creating aggregate window functions.
(Contributed by Erlend E. Aasland in :issue:`34916`.)
sys sys
--- ---

View File

@ -1084,6 +1084,8 @@ class ThreadTests(unittest.TestCase):
if hasattr(sqlite.Connection, "serialize"): if hasattr(sqlite.Connection, "serialize"):
fns.append(lambda: self.con.serialize()) fns.append(lambda: self.con.serialize())
fns.append(lambda: self.con.deserialize(b"")) fns.append(lambda: self.con.deserialize(b""))
if sqlite.sqlite_version_info >= (3, 25, 0):
fns.append(lambda: self.con.create_window_function("foo", 0, None))
for fn in fns: for fn in fns:
with self.subTest(fn=fn): with self.subTest(fn=fn):

View File

@ -27,9 +27,9 @@ import io
import re import re
import sys import sys
import unittest import unittest
import unittest.mock
import sqlite3 as sqlite import sqlite3 as sqlite
from unittest.mock import Mock, patch
from test.support import bigmemtest, catch_unraisable_exception, gc_collect from test.support import bigmemtest, catch_unraisable_exception, gc_collect
from test.test_sqlite3.test_dbapi import cx_limit from test.test_sqlite3.test_dbapi import cx_limit
@ -393,7 +393,7 @@ class FunctionTests(unittest.TestCase):
# indices, which allows testing based on syntax, iso. the query optimizer. # indices, which allows testing based on syntax, iso. the query optimizer.
@unittest.skipIf(sqlite.sqlite_version_info < (3, 8, 3), "Requires SQLite 3.8.3 or higher") @unittest.skipIf(sqlite.sqlite_version_info < (3, 8, 3), "Requires SQLite 3.8.3 or higher")
def test_func_non_deterministic(self): def test_func_non_deterministic(self):
mock = unittest.mock.Mock(return_value=None) mock = Mock(return_value=None)
self.con.create_function("nondeterministic", 0, mock, deterministic=False) self.con.create_function("nondeterministic", 0, mock, deterministic=False)
if sqlite.sqlite_version_info < (3, 15, 0): if sqlite.sqlite_version_info < (3, 15, 0):
self.con.execute("select nondeterministic() = nondeterministic()") self.con.execute("select nondeterministic() = nondeterministic()")
@ -404,7 +404,7 @@ class FunctionTests(unittest.TestCase):
@unittest.skipIf(sqlite.sqlite_version_info < (3, 8, 3), "Requires SQLite 3.8.3 or higher") @unittest.skipIf(sqlite.sqlite_version_info < (3, 8, 3), "Requires SQLite 3.8.3 or higher")
def test_func_deterministic(self): def test_func_deterministic(self):
mock = unittest.mock.Mock(return_value=None) mock = Mock(return_value=None)
self.con.create_function("deterministic", 0, mock, deterministic=True) self.con.create_function("deterministic", 0, mock, deterministic=True)
if sqlite.sqlite_version_info < (3, 15, 0): if sqlite.sqlite_version_info < (3, 15, 0):
self.con.execute("select deterministic() = deterministic()") self.con.execute("select deterministic() = deterministic()")
@ -482,6 +482,164 @@ class FunctionTests(unittest.TestCase):
self.con.execute, "select badreturn()") self.con.execute, "select badreturn()")
class WindowSumInt:
def __init__(self):
self.count = 0
def step(self, value):
self.count += value
def value(self):
return self.count
def inverse(self, value):
self.count -= value
def finalize(self):
return self.count
class BadWindow(Exception):
pass
@unittest.skipIf(sqlite.sqlite_version_info < (3, 25, 0),
"Requires SQLite 3.25.0 or newer")
class WindowFunctionTests(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
self.cur = self.con.cursor()
# Test case taken from https://www.sqlite.org/windowfunctions.html#udfwinfunc
values = [
("a", 4),
("b", 5),
("c", 3),
("d", 8),
("e", 1),
]
with self.con:
self.con.execute("create table test(x, y)")
self.con.executemany("insert into test values(?, ?)", values)
self.expected = [
("a", 9),
("b", 12),
("c", 16),
("d", 12),
("e", 9),
]
self.query = """
select x, %s(y) over (
order by x rows between 1 preceding and 1 following
) as sum_y
from test order by x
"""
self.con.create_window_function("sumint", 1, WindowSumInt)
def test_win_sum_int(self):
self.cur.execute(self.query % "sumint")
self.assertEqual(self.cur.fetchall(), self.expected)
def test_win_error_on_create(self):
self.assertRaises(sqlite.ProgrammingError,
self.con.create_window_function,
"shouldfail", -100, WindowSumInt)
@with_tracebacks(BadWindow)
def test_win_exception_in_method(self):
for meth in "__init__", "step", "value", "inverse":
with self.subTest(meth=meth):
with patch.object(WindowSumInt, meth, side_effect=BadWindow):
name = f"exc_{meth}"
self.con.create_window_function(name, 1, WindowSumInt)
msg = f"'{meth}' method raised error"
with self.assertRaisesRegex(sqlite.OperationalError, msg):
self.cur.execute(self.query % name)
self.cur.fetchall()
@with_tracebacks(BadWindow)
def test_win_exception_in_finalize(self):
# Note: SQLite does not (as of version 3.38.0) propagate finalize
# callback errors to sqlite3_step(); this implies that OperationalError
# is _not_ raised.
with patch.object(WindowSumInt, "finalize", side_effect=BadWindow):
name = f"exception_in_finalize"
self.con.create_window_function(name, 1, WindowSumInt)
self.cur.execute(self.query % name)
self.cur.fetchall()
@with_tracebacks(AttributeError)
def test_win_missing_method(self):
class MissingValue:
def step(self, x): pass
def inverse(self, x): pass
def finalize(self): return 42
class MissingInverse:
def step(self, x): pass
def value(self): return 42
def finalize(self): return 42
class MissingStep:
def value(self): return 42
def inverse(self, x): pass
def finalize(self): return 42
dataset = (
("step", MissingStep),
("value", MissingValue),
("inverse", MissingInverse),
)
for meth, cls in dataset:
with self.subTest(meth=meth, cls=cls):
name = f"exc_{meth}"
self.con.create_window_function(name, 1, cls)
with self.assertRaisesRegex(sqlite.OperationalError,
f"'{meth}' method not defined"):
self.cur.execute(self.query % name)
self.cur.fetchall()
@with_tracebacks(AttributeError)
def test_win_missing_finalize(self):
# Note: SQLite does not (as of version 3.38.0) propagate finalize
# callback errors to sqlite3_step(); this implies that OperationalError
# is _not_ raised.
class MissingFinalize:
def step(self, x): pass
def value(self): return 42
def inverse(self, x): pass
name = "missing_finalize"
self.con.create_window_function(name, 1, MissingFinalize)
self.cur.execute(self.query % name)
self.cur.fetchall()
def test_win_clear_function(self):
self.con.create_window_function("sumint", 1, None)
self.assertRaises(sqlite.OperationalError, self.cur.execute,
self.query % "sumint")
def test_win_redefine_function(self):
# Redefine WindowSumInt; adjust the expected results accordingly.
class Redefined(WindowSumInt):
def step(self, value): self.count += value * 2
def inverse(self, value): self.count -= value * 2
expected = [(v[0], v[1]*2) for v in self.expected]
self.con.create_window_function("sumint", 1, Redefined)
self.cur.execute(self.query % "sumint")
self.assertEqual(self.cur.fetchall(), expected)
def test_win_error_value_return(self):
class ErrorValueReturn:
def __init__(self): pass
def step(self, x): pass
def value(self): return 1 << 65
self.con.create_window_function("err_val_ret", 1, ErrorValueReturn)
self.assertRaisesRegex(sqlite.DataError, "string or blob too big",
self.cur.execute, self.query % "err_val_ret")
class AggregateTests(unittest.TestCase): class AggregateTests(unittest.TestCase):
def setUp(self): def setUp(self):
self.con = sqlite.connect(":memory:") self.con = sqlite.connect(":memory:")
@ -527,10 +685,10 @@ class AggregateTests(unittest.TestCase):
def test_aggr_no_finalize(self): def test_aggr_no_finalize(self):
cur = self.con.cursor() cur = self.con.cursor()
with self.assertRaises(sqlite.OperationalError) as cm: msg = "user-defined aggregate's 'finalize' method not defined"
with self.assertRaisesRegex(sqlite.OperationalError, msg):
cur.execute("select nofinalize(t) from test") cur.execute("select nofinalize(t) from test")
val = cur.fetchone()[0] val = cur.fetchone()[0]
self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error")
@with_tracebacks(ZeroDivisionError, name="AggrExceptionInInit") @with_tracebacks(ZeroDivisionError, name="AggrExceptionInInit")
def test_aggr_exception_in_init(self): def test_aggr_exception_in_init(self):

View File

@ -0,0 +1,3 @@
Add :meth:`~sqlite3.Connection.create_window_function` to
:class:`sqlite3.Connection` for creating aggregate window functions.
Patch by Erlend E. Aasland.

View File

@ -235,6 +235,53 @@ exit:
return return_value; return return_value;
} }
#if defined(HAVE_WINDOW_FUNCTIONS)
PyDoc_STRVAR(create_window_function__doc__,
"create_window_function($self, name, num_params, aggregate_class, /)\n"
"--\n"
"\n"
"Creates or redefines an aggregate window function. Non-standard.\n"
"\n"
" name\n"
" The name of the SQL aggregate window function to be created or\n"
" redefined.\n"
" num_params\n"
" The number of arguments the step and inverse methods takes.\n"
" aggregate_class\n"
" A class with step(), finalize(), value(), and inverse() methods.\n"
" Set to None to clear the window function.");
#define CREATE_WINDOW_FUNCTION_METHODDEF \
{"create_window_function", (PyCFunction)(void(*)(void))create_window_function, METH_METHOD|METH_FASTCALL|METH_KEYWORDS, create_window_function__doc__},
static PyObject *
create_window_function_impl(pysqlite_Connection *self, PyTypeObject *cls,
const char *name, int num_params,
PyObject *aggregate_class);
static PyObject *
create_window_function(pysqlite_Connection *self, PyTypeObject *cls, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames)
{
PyObject *return_value = NULL;
static const char * const _keywords[] = {"", "", "", NULL};
static _PyArg_Parser _parser = {"siO:create_window_function", _keywords, 0};
const char *name;
int num_params;
PyObject *aggregate_class;
if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &_parser,
&name, &num_params, &aggregate_class)) {
goto exit;
}
return_value = create_window_function_impl(self, cls, name, num_params, aggregate_class);
exit:
return return_value;
}
#endif /* defined(HAVE_WINDOW_FUNCTIONS) */
PyDoc_STRVAR(pysqlite_connection_create_aggregate__doc__, PyDoc_STRVAR(pysqlite_connection_create_aggregate__doc__,
"create_aggregate($self, /, name, n_arg, aggregate_class)\n" "create_aggregate($self, /, name, n_arg, aggregate_class)\n"
"--\n" "--\n"
@ -975,6 +1022,10 @@ exit:
return return_value; return return_value;
} }
#ifndef CREATE_WINDOW_FUNCTION_METHODDEF
#define CREATE_WINDOW_FUNCTION_METHODDEF
#endif /* !defined(CREATE_WINDOW_FUNCTION_METHODDEF) */
#ifndef PYSQLITE_CONNECTION_ENABLE_LOAD_EXTENSION_METHODDEF #ifndef PYSQLITE_CONNECTION_ENABLE_LOAD_EXTENSION_METHODDEF
#define PYSQLITE_CONNECTION_ENABLE_LOAD_EXTENSION_METHODDEF #define PYSQLITE_CONNECTION_ENABLE_LOAD_EXTENSION_METHODDEF
#endif /* !defined(PYSQLITE_CONNECTION_ENABLE_LOAD_EXTENSION_METHODDEF) */ #endif /* !defined(PYSQLITE_CONNECTION_ENABLE_LOAD_EXTENSION_METHODDEF) */
@ -990,4 +1041,4 @@ exit:
#ifndef DESERIALIZE_METHODDEF #ifndef DESERIALIZE_METHODDEF
#define DESERIALIZE_METHODDEF #define DESERIALIZE_METHODDEF
#endif /* !defined(DESERIALIZE_METHODDEF) */ #endif /* !defined(DESERIALIZE_METHODDEF) */
/*[clinic end generated code: output=d965a68f9229a56c input=a9049054013a1b77]*/ /*[clinic end generated code: output=b9af1b52fda808bf input=a9049054013a1b77]*/

View File

@ -33,6 +33,10 @@
#define HAVE_TRACE_V2 #define HAVE_TRACE_V2
#endif #endif
#if SQLITE_VERSION_NUMBER >= 3025000
#define HAVE_WINDOW_FUNCTIONS
#endif
static const char * static const char *
get_isolation_level(const char *level) get_isolation_level(const char *level)
{ {
@ -799,7 +803,7 @@ final_callback(sqlite3_context *context)
goto error; goto error;
} }
/* Keep the exception (if any) of the last call to step() */ // Keep the exception (if any) of the last call to step, value, or inverse
PyErr_Fetch(&exception, &value, &tb); PyErr_Fetch(&exception, &value, &tb);
callback_context *ctx = (callback_context *)sqlite3_user_data(context); callback_context *ctx = (callback_context *)sqlite3_user_data(context);
@ -814,13 +818,20 @@ final_callback(sqlite3_context *context)
Py_DECREF(function_result); Py_DECREF(function_result);
} }
if (!ok) { if (!ok) {
set_sqlite_error(context, int attr_err = PyErr_ExceptionMatches(PyExc_AttributeError);
"user-defined aggregate's 'finalize' method raised error"); _PyErr_ChainExceptions(exception, value, tb);
}
/* Restore the exception (if any) of the last call to step(), /* Note: contrary to the step, value, and inverse callbacks, SQLite
but clear also the current exception if finalize() failed */ * does _not_, as of SQLite 3.38.0, propagate errors to sqlite3_step()
PyErr_Restore(exception, value, tb); * from the finalize callback. This implies that execute*() will not
* raise OperationalError, as it normally would. */
set_sqlite_error(context, attr_err
? "user-defined aggregate's 'finalize' method not defined"
: "user-defined aggregate's 'finalize' method raised error");
}
else {
PyErr_Restore(exception, value, tb);
}
error: error:
PyGILState_Release(threadstate); PyGILState_Release(threadstate);
@ -968,6 +979,159 @@ pysqlite_connection_create_function_impl(pysqlite_Connection *self,
Py_RETURN_NONE; Py_RETURN_NONE;
} }
#ifdef HAVE_WINDOW_FUNCTIONS
/*
* Regarding the 'inverse' aggregate callback:
* This method is only required by window aggregate functions, not
* ordinary aggregate function implementations. It is invoked to remove
* a row from the current window. The function arguments, if any,
* correspond to the row being removed.
*/
static void
inverse_callback(sqlite3_context *context, int argc, sqlite3_value **params)
{
PyGILState_STATE gilstate = PyGILState_Ensure();
callback_context *ctx = (callback_context *)sqlite3_user_data(context);
assert(ctx != NULL);
int size = sizeof(PyObject *);
PyObject **cls = (PyObject **)sqlite3_aggregate_context(context, size);
assert(cls != NULL);
assert(*cls != NULL);
PyObject *method = PyObject_GetAttr(*cls, ctx->state->str_inverse);
if (method == NULL) {
set_sqlite_error(context,
"user-defined aggregate's 'inverse' method not defined");
goto exit;
}
PyObject *args = _pysqlite_build_py_params(context, argc, params);
if (args == NULL) {
set_sqlite_error(context,
"unable to build arguments for user-defined aggregate's "
"'inverse' method");
goto exit;
}
PyObject *res = PyObject_CallObject(method, args);
Py_DECREF(args);
if (res == NULL) {
set_sqlite_error(context,
"user-defined aggregate's 'inverse' method raised error");
goto exit;
}
Py_DECREF(res);
exit:
Py_XDECREF(method);
PyGILState_Release(gilstate);
}
/*
* Regarding the 'value' aggregate callback:
* This method is only required by window aggregate functions, not
* ordinary aggregate function implementations. It is invoked to return
* the current value of the aggregate.
*/
static void
value_callback(sqlite3_context *context)
{
PyGILState_STATE gilstate = PyGILState_Ensure();
callback_context *ctx = (callback_context *)sqlite3_user_data(context);
assert(ctx != NULL);
int size = sizeof(PyObject *);
PyObject **cls = (PyObject **)sqlite3_aggregate_context(context, size);
assert(cls != NULL);
assert(*cls != NULL);
PyObject *res = PyObject_CallMethodNoArgs(*cls, ctx->state->str_value);
if (res == NULL) {
int attr_err = PyErr_ExceptionMatches(PyExc_AttributeError);
set_sqlite_error(context, attr_err
? "user-defined aggregate's 'value' method not defined"
: "user-defined aggregate's 'value' method raised error");
}
else {
int rc = _pysqlite_set_result(context, res);
Py_DECREF(res);
if (rc < 0) {
set_sqlite_error(context,
"unable to set result from user-defined aggregate's "
"'value' method");
}
}
PyGILState_Release(gilstate);
}
/*[clinic input]
_sqlite3.Connection.create_window_function as create_window_function
cls: defining_class
name: str
The name of the SQL aggregate window function to be created or
redefined.
num_params: int
The number of arguments the step and inverse methods takes.
aggregate_class: object
A class with step(), finalize(), value(), and inverse() methods.
Set to None to clear the window function.
/
Creates or redefines an aggregate window function. Non-standard.
[clinic start generated code]*/
static PyObject *
create_window_function_impl(pysqlite_Connection *self, PyTypeObject *cls,
const char *name, int num_params,
PyObject *aggregate_class)
/*[clinic end generated code: output=5332cd9464522235 input=46d57a54225b5228]*/
{
if (sqlite3_libversion_number() < 3025000) {
PyErr_SetString(self->NotSupportedError,
"create_window_function() requires "
"SQLite 3.25.0 or higher");
return NULL;
}
if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) {
return NULL;
}
int flags = SQLITE_UTF8;
int rc;
if (Py_IsNone(aggregate_class)) {
rc = sqlite3_create_window_function(self->db, name, num_params, flags,
0, 0, 0, 0, 0, 0);
}
else {
callback_context *ctx = create_callback_context(cls, aggregate_class);
if (ctx == NULL) {
return NULL;
}
rc = sqlite3_create_window_function(self->db, name, num_params, flags,
ctx,
&step_callback,
&final_callback,
&value_callback,
&inverse_callback,
&destructor_callback);
}
if (rc != SQLITE_OK) {
// Errors are not set on the database connection, so we cannot
// use _pysqlite_seterror().
PyErr_SetString(self->ProgrammingError, sqlite3_errstr(rc));
return NULL;
}
Py_RETURN_NONE;
}
#endif
/*[clinic input] /*[clinic input]
_sqlite3.Connection.create_aggregate as pysqlite_connection_create_aggregate _sqlite3.Connection.create_aggregate as pysqlite_connection_create_aggregate
@ -2092,6 +2256,7 @@ static PyMethodDef connection_methods[] = {
GETLIMIT_METHODDEF GETLIMIT_METHODDEF
SERIALIZE_METHODDEF SERIALIZE_METHODDEF
DESERIALIZE_METHODDEF DESERIALIZE_METHODDEF
CREATE_WINDOW_FUNCTION_METHODDEF
{NULL, NULL} {NULL, NULL}
}; };

View File

@ -630,8 +630,10 @@ module_clear(PyObject *module)
Py_CLEAR(state->str___conform__); Py_CLEAR(state->str___conform__);
Py_CLEAR(state->str_executescript); Py_CLEAR(state->str_executescript);
Py_CLEAR(state->str_finalize); Py_CLEAR(state->str_finalize);
Py_CLEAR(state->str_inverse);
Py_CLEAR(state->str_step); Py_CLEAR(state->str_step);
Py_CLEAR(state->str_upper); Py_CLEAR(state->str_upper);
Py_CLEAR(state->str_value);
return 0; return 0;
} }
@ -717,8 +719,10 @@ module_exec(PyObject *module)
ADD_INTERNED(state, __conform__); ADD_INTERNED(state, __conform__);
ADD_INTERNED(state, executescript); ADD_INTERNED(state, executescript);
ADD_INTERNED(state, finalize); ADD_INTERNED(state, finalize);
ADD_INTERNED(state, inverse);
ADD_INTERNED(state, step); ADD_INTERNED(state, step);
ADD_INTERNED(state, upper); ADD_INTERNED(state, upper);
ADD_INTERNED(state, value);
/* Set error constants */ /* Set error constants */
if (add_error_constants(module) < 0) { if (add_error_constants(module) < 0) {

View File

@ -64,8 +64,10 @@ typedef struct {
PyObject *str___conform__; PyObject *str___conform__;
PyObject *str_executescript; PyObject *str_executescript;
PyObject *str_finalize; PyObject *str_finalize;
PyObject *str_inverse;
PyObject *str_step; PyObject *str_step;
PyObject *str_upper; PyObject *str_upper;
PyObject *str_value;
} pysqlite_state; } pysqlite_state;
extern pysqlite_state pysqlite_global_state; extern pysqlite_state pysqlite_global_state;