diff --git a/atom/src/atomdict.cpp b/atom/src/atomdict.cpp index 7c48585c..1c321f5c 100644 --- a/atom/src/atomdict.cpp +++ b/atom/src/atomdict.cpp @@ -103,11 +103,7 @@ int AtomDict_traverse( AtomDict* self, visitproc visit, void* arg ) { Py_VISIT( self->m_key_validator ); Py_VISIT( self->m_value_validator ); -#if PY_VERSION_HEX >= 0x03090000 - // This was not needed before Python 3.9 (Python issue 35810 and 40217) - Py_VISIT(Py_TYPE(self)); -#endif - // PyDict_type is not heap allocated so it does visit the type + Py_VISIT(Py_TYPE(self)); return PyDict_Type.tp_traverse( pyobject_cast( self ), visit, arg ); } @@ -153,17 +149,21 @@ PyObject* AtomDict_setdefault( AtomDict* self, PyObject* args ) { return 0; } - PyObject* value = PyDict_GetItem( pyobject_cast( self ), key ); + // Key must be validated before use due to possible coercion in AtomDict_ass_subscript + cppy::ptr key_ptr( validate_key( self, key ) ); + if ( !key_ptr ) + return 0; + PyObject* value = PyDict_GetItem( pyobject_cast( self ), key_ptr.get() ); if( value ) { return cppy::incref( value ); } - if( AtomDict_ass_subscript( self, key, dfv ) < 0 ) + if( AtomDict_ass_subscript( self, key_ptr.get(), dfv ) < 0 ) { return 0; } // Get the dictionary from the dict itself in case it was coerced. - return cppy::incref( PyDict_GetItem( pyobject_cast( self ), key ) ); + return cppy::incref( PyDict_GetItem( pyobject_cast( self ), key_ptr.get() ) ); } @@ -262,25 +262,26 @@ static PyObject* DefaultAtomDict_repr( DefaultAtomDict* self ) { return 0; } - ostr << PyUnicode_AsUTF8( repr.get() ); + const char* factory_repr = PyUnicode_AsUTF8( repr.get() ); + if ( !factory_repr ) + return 0; + ostr << factory_repr; ostr << ", "; repr = PyDict_Type.tp_repr( pyobject_cast( self ) ); if( !repr ) { return 0; } - ostr << PyUnicode_AsUTF8( repr.get() ); + const char* self_repr = PyUnicode_AsUTF8( repr.get() ); + if ( !self_repr ) + return 0; + ostr << self_repr; ostr << ")"; return PyUnicode_FromString( ostr.str().c_str() ); } -static PyObject* DefaultAtomDict_missing( DefaultAtomDict* self, PyObject* args ) +static PyObject* DefaultAtomDict_missing( DefaultAtomDict* self, PyObject* key ) { - PyObject* key; - if( !PyArg_UnpackTuple( args, "__missing__", 1, 1, &key ) ) - { - return 0; - } CAtom* atom = self->dict.pointer->data(); if( !atom ) { @@ -289,26 +290,25 @@ static PyObject* DefaultAtomDict_missing( DefaultAtomDict* self, PyObject* args "so missing value cannot be built." ); } -#if PY_VERSION_HEX >= 0x03090000 cppy::ptr value_ptr( PyObject_CallOneArg( self->factory, pyobject_cast( atom ) ) ); -#else - cppy::ptr temp( PyTuple_Pack(1, pyobject_cast( atom ) ) ); - cppy::ptr value_ptr( PyObject_Call( self->factory, temp.get(), 0 ) ); -#endif if( !value_ptr ) { return 0; } if( should_validate( atomdict_cast( self ) ) ) { + // Key must be validated before use due to possible coercion in AtomDict_ass_subscript + cppy::ptr key_ptr( validate_key( atomdict_cast( self ), key ) ); + if ( !key_ptr ) + return 0; // We cannot simply validate the value as it will be re-validated when // it is set which leads to creating a different object. - if( AtomDict_ass_subscript( atomdict_cast( self ), key, value_ptr.get() ) < 0 ) + if( AtomDict_ass_subscript( atomdict_cast( self ), key_ptr.get(), value_ptr.get() ) < 0 ) { return 0; } // Get the dictionary from the dict itself in case it was coerced. - value_ptr = cppy::incref( PyDict_GetItem( pyobject_cast( self ), key ) ); + value_ptr = cppy::incref( PyDict_GetItem( pyobject_cast( self ), key_ptr.get() ) ); } return value_ptr.release(); } @@ -316,7 +316,7 @@ static PyObject* DefaultAtomDict_missing( DefaultAtomDict* self, PyObject* args static PyMethodDef DefaultAtomDict_methods[] = { { "__missing__", ( PyCFunction )DefaultAtomDict_missing, - METH_VARARGS, + METH_O, "Called when a key is absent from the dictionary" }, { 0 } // sentinel }; @@ -370,6 +370,8 @@ PyObject* AtomDict::New( CAtom* atom, Member* key_validator, Member* value_valid int AtomDict::Update( AtomDict* dict, PyObject* value ) { cppy::ptr validated_dict( PyDict_New() ); + if ( !validated_dict ) + return -1; // LCOV_EXCL_LINE (failed dict creation) PyObject* key; PyObject* val; Py_ssize_t index = 0; @@ -455,10 +457,12 @@ bool DefaultAtomDict::Ready() { // This will work only if we create this type after the standard AtomDict // The reference will be handled by the module to which we will add the type - PyObject* bases = PyTuple_New( 1 ); - PyTuple_SET_ITEM( bases, 0, pyobject_cast( AtomDict::TypeObject ) ); + cppy::ptr bases( PyTuple_New( 1 ) ); + if ( !bases ) + return false; // LCOV_EXCL_LINE (failed tuple creation) + PyTuple_SET_ITEM( bases.get(), 0, cppy::incref( pyobject_cast( AtomDict::TypeObject ) ) ); TypeObject = pytype_cast( - PyType_FromSpecWithBases( &TypeObject_Spec, bases ) + PyType_FromSpecWithBases( &TypeObject_Spec, bases.get() ) ); if( !TypeObject ) { diff --git a/tests/test_atomdefaultdict.py b/tests/test_atomdefaultdict.py index 9845800c..fa5210b3 100644 --- a/tests/test_atomdefaultdict.py +++ b/tests/test_atomdefaultdict.py @@ -11,7 +11,16 @@ import pytest -from atom.api import Atom, DefaultDict, Instance, Int, List, atomlist, defaultatomdict +from atom.api import ( + Atom, + Coerced, + DefaultDict, + Instance, + Int, + List, + atomlist, + defaultatomdict, +) @pytest.fixture @@ -237,3 +246,11 @@ class A(Atom): content = a.d[1] assert isinstance(content, atomlist) assert content is a.d[1] + + +def test_coerced_key_missing(): + class Obj(Atom): + items = DefaultDict(key=Coerced(str), missing=lambda: "missing") + + o = Obj() + o.items[1] # key 1 gets coerced to '1' diff --git a/tests/test_atomdict.py b/tests/test_atomdict.py index 7a48d472..03ff9235 100644 --- a/tests/test_atomdict.py +++ b/tests/test_atomdict.py @@ -9,7 +9,7 @@ import pytest -from atom.api import Atom, Dict, Int, List, atomdict, atomlist +from atom.api import Atom, Dict, Coerced, Int, List, atomdict, atomlist @pytest.fixture @@ -186,3 +186,13 @@ def test_update(atom_dict): atom_dict.fullytyped.update({"": 1}) with pytest.raises(TypeError): atom_dict.fullytyped.update({"": ""}) + + +def test_coerced_setdefault(): + class Obj(Atom): + items = Dict(key=Coerced(str)) + + o = Obj() + o.items["1"] = "a" + o.items.setdefault(1, "b") # key 1 gets coerced to '1' + assert o.items["1"] == "a"