Skip to content

Commit 01cf357

Browse files
committed
gh-141510: Check argument in PyDict_Contains()
PyDict_Contains() and PyDict_ContainsString() now fail with SystemError if the first argument is not a dict, frozendict, dict subclass or frozendict subclass.
1 parent 7258dbc commit 01cf357

2 files changed

Lines changed: 26 additions & 10 deletions

File tree

Lib/test/test_capi/test_dict.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ def test_dict_getitemwitherror(self):
210210
# CRASHES getitem(NULL, 'a')
211211

212212
def test_dict_contains(self):
213+
# Test PyDict_Contains()
213214
contains = _testlimitedcapi.dict_contains
214215
dct = {'a': 1, '\U0001f40d': 2}
215216
self.assertTrue(contains(dct, 'a'))
@@ -222,11 +223,12 @@ def test_dict_contains(self):
222223

223224
self.assertRaises(TypeError, contains, {}, []) # unhashable
224225
# CRASHES contains({}, NULL)
225-
# CRASHES contains(UserDict(), 'a')
226-
# CRASHES contains(42, 'a')
226+
self.assertRaises(SystemError, contains, UserDict(), 'a')
227+
self.assertRaises(SystemError, contains, 42, 'a')
227228
# CRASHES contains(NULL, 'a')
228229

229230
def test_dict_contains_string(self):
231+
# Test PyDict_ContainsString()
230232
contains_string = _testcapi.dict_containsstring
231233
dct = {'a': 1, '\U0001f40d': 2}
232234
self.assertTrue(contains_string(dct, b'a'))
@@ -238,6 +240,8 @@ def test_dict_contains_string(self):
238240
self.assertTrue(contains_string(dct2, b'a'))
239241
self.assertFalse(contains_string(dct2, b'b'))
240242

243+
self.assertRaises(SystemError, contains_string, UserDict(), 'a')
244+
self.assertRaises(SystemError, contains_string, 42, 'a')
241245
# CRASHES contains({}, NULL)
242246
# CRASHES contains(NULL, b'a')
243247

Objects/dictobject.c

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ static PyObject* frozendict_new(PyTypeObject *type, PyObject *args,
140140
PyObject *kwds);
141141
static PyObject* dict_new(PyTypeObject *type, PyObject *args, PyObject *kwds);
142142
static int dict_merge(PyObject *a, PyObject *b, int override);
143+
static int dict_contains(PyObject *op, PyObject *key);
143144

144145

145146
/*[clinic input]
@@ -4121,7 +4122,7 @@ dict_merge(PyObject *a, PyObject *b, int override)
41214122

41224123
for (key = PyIter_Next(iter); key; key = PyIter_Next(iter)) {
41234124
if (override != 1) {
4124-
status = PyDict_Contains(a, key);
4125+
status = dict_contains(a, key);
41254126
if (status != 0) {
41264127
if (status > 0) {
41274128
if (override == 0) {
@@ -4464,7 +4465,7 @@ static PyObject *
44644465
dict___contains___impl(PyDictObject *self, PyObject *key)
44654466
/*[clinic end generated code: output=1b314e6da7687dae input=fe1cb42ad831e820]*/
44664467
{
4467-
int contains = PyDict_Contains((PyObject *)self, key);
4468+
int contains = dict_contains((PyObject *)self, key);
44684469
if (contains < 0) {
44694470
return NULL;
44704471
}
@@ -4964,9 +4965,8 @@ static PyMethodDef mapp_methods[] = {
49644965
{NULL, NULL} /* sentinel */
49654966
};
49664967

4967-
/* Return 1 if `key` is in dict `op`, 0 if not, and -1 on error. */
4968-
int
4969-
PyDict_Contains(PyObject *op, PyObject *key)
4968+
static int
4969+
dict_contains(PyObject *op, PyObject *key)
49704970
{
49714971
Py_hash_t hash = _PyObject_HashFast(key);
49724972
if (hash == -1) {
@@ -4977,6 +4977,18 @@ PyDict_Contains(PyObject *op, PyObject *key)
49774977
return _PyDict_Contains_KnownHash(op, key, hash);
49784978
}
49794979

4980+
/* Return 1 if `key` is in dict `op`, 0 if not, and -1 on error. */
4981+
int
4982+
PyDict_Contains(PyObject *op, PyObject *key)
4983+
{
4984+
if (!PyAnyDict_Check(op)) {
4985+
PyErr_BadInternalCall();
4986+
return -1;
4987+
}
4988+
4989+
return dict_contains(op, key);
4990+
}
4991+
49804992
int
49814993
PyDict_ContainsString(PyObject *op, const char *key)
49824994
{
@@ -4993,7 +5005,7 @@ PyDict_ContainsString(PyObject *op, const char *key)
49935005
int
49945006
_PyDict_Contains_KnownHash(PyObject *op, PyObject *key, Py_hash_t hash)
49955007
{
4996-
PyDictObject *mp = (PyDictObject *)op;
5008+
PyDictObject *mp = _PyAnyDict_CAST(op);
49975009
PyObject *value;
49985010
Py_ssize_t ix;
49995011

@@ -5022,7 +5034,7 @@ static PySequenceMethods dict_as_sequence = {
50225034
0, /* sq_slice */
50235035
0, /* sq_ass_item */
50245036
0, /* sq_ass_slice */
5025-
PyDict_Contains, /* sq_contains */
5037+
dict_contains, /* sq_contains */
50265038
0, /* sq_inplace_concat */
50275039
0, /* sq_inplace_repeat */
50285040
};
@@ -6272,7 +6284,7 @@ dictkeys_contains(PyObject *self, PyObject *obj)
62726284
_PyDictViewObject *dv = (_PyDictViewObject *)self;
62736285
if (dv->dv_dict == NULL)
62746286
return 0;
6275-
return PyDict_Contains((PyObject *)dv->dv_dict, obj);
6287+
return dict_contains((PyObject *)dv->dv_dict, obj);
62766288
}
62776289

62786290
static PySequenceMethods dictkeys_as_sequence = {

0 commit comments

Comments
 (0)