Skip to content

Commit 3a2e7aa

Browse files
gh-129928: Raise more accurate exception for incorrect sqlite3 UDF creation (#129941)
Consistently raise ProgrammingError if the user tries to create an UDF with an invalid number of parameters.
1 parent 0dbe543 commit 3a2e7aa

File tree

3 files changed

+29
-6
lines changed

3 files changed

+29
-6
lines changed

Lib/test/test_sqlite3/test_userfunctions.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def tearDown(self):
171171
self.con.close()
172172

173173
def test_func_error_on_create(self):
174-
with self.assertRaises(sqlite.OperationalError):
174+
with self.assertRaisesRegex(sqlite.ProgrammingError, "not -100"):
175175
self.con.create_function("bla", -100, lambda x: 2*x)
176176

177177
def test_func_too_many_args(self):
@@ -507,9 +507,8 @@ def test_win_sum_int(self):
507507
self.assertEqual(self.cur.fetchall(), self.expected)
508508

509509
def test_win_error_on_create(self):
510-
self.assertRaises(sqlite.ProgrammingError,
511-
self.con.create_window_function,
512-
"shouldfail", -100, WindowSumInt)
510+
with self.assertRaisesRegex(sqlite.ProgrammingError, "not -100"):
511+
self.con.create_window_function("shouldfail", -100, WindowSumInt)
513512

514513
@with_tracebacks(BadWindow)
515514
def test_win_exception_in_method(self):
@@ -638,7 +637,7 @@ def tearDown(self):
638637
self.con.close()
639638

640639
def test_aggr_error_on_create(self):
641-
with self.assertRaises(sqlite.OperationalError):
640+
with self.assertRaisesRegex(sqlite.ProgrammingError, "not -100"):
642641
self.con.create_function("bla", -100, AggrSum)
643642

644643
@with_tracebacks(AttributeError, msg_regex="AggrNoStep")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Raise :exc:`sqlite3.ProgrammingError` if a user-defined SQL function with
2+
invalid number of parameters is created. Patch by Erlend Aasland.

Modules/_sqlite/connection.c

+23-1
Original file line numberDiff line numberDiff line change
@@ -1139,6 +1139,20 @@ destructor_callback(void *ctx)
11391139
}
11401140
}
11411141

1142+
static int
1143+
check_num_params(pysqlite_Connection *self, const int n, const char *name)
1144+
{
1145+
int limit = sqlite3_limit(self->db, SQLITE_LIMIT_FUNCTION_ARG, -1);
1146+
assert(limit >= 0);
1147+
if (n < -1 || n > limit) {
1148+
PyErr_Format(self->ProgrammingError,
1149+
"'%s' must be between -1 and %d, not %d",
1150+
name, limit, n);
1151+
return -1;
1152+
}
1153+
return 0;
1154+
}
1155+
11421156
/*[clinic input]
11431157
_sqlite3.Connection.create_function as pysqlite_connection_create_function
11441158
@@ -1167,6 +1181,9 @@ pysqlite_connection_create_function_impl(pysqlite_Connection *self,
11671181
if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) {
11681182
return NULL;
11691183
}
1184+
if (check_num_params(self, narg, "narg") < 0) {
1185+
return NULL;
1186+
}
11701187

11711188
if (deterministic) {
11721189
flags |= SQLITE_DETERMINISTIC;
@@ -1307,10 +1324,12 @@ create_window_function_impl(pysqlite_Connection *self, PyTypeObject *cls,
13071324
"SQLite 3.25.0 or higher");
13081325
return NULL;
13091326
}
1310-
13111327
if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) {
13121328
return NULL;
13131329
}
1330+
if (check_num_params(self, num_params, "num_params") < 0) {
1331+
return NULL;
1332+
}
13141333

13151334
int flags = SQLITE_UTF8;
13161335
int rc;
@@ -1367,6 +1386,9 @@ pysqlite_connection_create_aggregate_impl(pysqlite_Connection *self,
13671386
if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) {
13681387
return NULL;
13691388
}
1389+
if (check_num_params(self, n_arg, "n_arg") < 0) {
1390+
return NULL;
1391+
}
13701392

13711393
callback_context *ctx = create_callback_context(cls, aggregate_class);
13721394
if (ctx == NULL) {

0 commit comments

Comments
 (0)