Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 31 additions & 27 deletions atom/src/atomdict.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,7 @@ int AtomDict_traverse( AtomDict* self, visitproc visit, void* arg )
{
Py_VISIT( self->m_key_validator );
Py_VISIT( self->m_value_validator );
#if PY_VERSION_HEX >= 0x03090000
// This was not needed before Python 3.9 (Python issue 35810 and 40217)
Py_VISIT(Py_TYPE(self));
#endif
// PyDict_type is not heap allocated so it does visit the type
Py_VISIT(Py_TYPE(self));
return PyDict_Type.tp_traverse( pyobject_cast( self ), visit, arg );
}

Expand Down Expand Up @@ -153,17 +149,21 @@ PyObject* AtomDict_setdefault( AtomDict* self, PyObject* args )
{
return 0;
}
PyObject* value = PyDict_GetItem( pyobject_cast( self ), key );
// Key must be validated before use due to possible coercion in AtomDict_ass_subscript
cppy::ptr key_ptr( validate_key( self, key ) );
if ( !key_ptr )
return 0;
PyObject* value = PyDict_GetItem( pyobject_cast( self ), key_ptr.get() );
if( value )
{
return cppy::incref( value );
}
if( AtomDict_ass_subscript( self, key, dfv ) < 0 )
if( AtomDict_ass_subscript( self, key_ptr.get(), dfv ) < 0 )
{
return 0;
}
// Get the dictionary from the dict itself in case it was coerced.
return cppy::incref( PyDict_GetItem( pyobject_cast( self ), key ) );
return cppy::incref( PyDict_GetItem( pyobject_cast( self ), key_ptr.get() ) );
}


Expand Down Expand Up @@ -262,25 +262,26 @@ static PyObject* DefaultAtomDict_repr( DefaultAtomDict* self )
{
return 0;
}
ostr << PyUnicode_AsUTF8( repr.get() );
const char* factory_repr = PyUnicode_AsUTF8( repr.get() );
if ( !factory_repr )
return 0;
ostr << factory_repr;
ostr << ", ";
repr = PyDict_Type.tp_repr( pyobject_cast( self ) );
if( !repr )
{
return 0;
}
ostr << PyUnicode_AsUTF8( repr.get() );
const char* self_repr = PyUnicode_AsUTF8( repr.get() );
if ( !self_repr )
return 0;
ostr << self_repr;
ostr << ")";
return PyUnicode_FromString( ostr.str().c_str() );
}

static PyObject* DefaultAtomDict_missing( DefaultAtomDict* self, PyObject* args )
static PyObject* DefaultAtomDict_missing( DefaultAtomDict* self, PyObject* key )
{
PyObject* key;
if( !PyArg_UnpackTuple( args, "__missing__", 1, 1, &key ) )
{
return 0;
}
CAtom* atom = self->dict.pointer->data();
if( !atom )
{
Expand All @@ -289,34 +290,33 @@ static PyObject* DefaultAtomDict_missing( DefaultAtomDict* self, PyObject* args
"so missing value cannot be built."
);
}
#if PY_VERSION_HEX >= 0x03090000
cppy::ptr value_ptr( PyObject_CallOneArg( self->factory, pyobject_cast( atom ) ) );
#else
cppy::ptr temp( PyTuple_Pack(1, pyobject_cast( atom ) ) );
cppy::ptr value_ptr( PyObject_Call( self->factory, temp.get(), 0 ) );
#endif
if( !value_ptr )
{
return 0;
}
if( should_validate( atomdict_cast( self ) ) )
{
// Key must be validated before use due to possible coercion in AtomDict_ass_subscript
cppy::ptr key_ptr( validate_key( atomdict_cast( self ), key ) );
if ( !key_ptr )
return 0;
// We cannot simply validate the value as it will be re-validated when
// it is set which leads to creating a different object.
if( AtomDict_ass_subscript( atomdict_cast( self ), key, value_ptr.get() ) < 0 )
if( AtomDict_ass_subscript( atomdict_cast( self ), key_ptr.get(), value_ptr.get() ) < 0 )
{
return 0;
}
// Get the dictionary from the dict itself in case it was coerced.
value_ptr = cppy::incref( PyDict_GetItem( pyobject_cast( self ), key ) );
value_ptr = cppy::incref( PyDict_GetItem( pyobject_cast( self ), key_ptr.get() ) );
}
return value_ptr.release();
}

static PyMethodDef DefaultAtomDict_methods[] = {
{ "__missing__",
( PyCFunction )DefaultAtomDict_missing,
METH_VARARGS,
METH_O,
"Called when a key is absent from the dictionary" },
{ 0 } // sentinel
};
Expand Down Expand Up @@ -370,6 +370,8 @@ PyObject* AtomDict::New( CAtom* atom, Member* key_validator, Member* value_valid
int AtomDict::Update( AtomDict* dict, PyObject* value )
{
cppy::ptr validated_dict( PyDict_New() );
if ( !validated_dict )
return -1; // LCOV_EXCL_LINE (failed dict creation)
PyObject* key;
PyObject* val;
Py_ssize_t index = 0;
Expand Down Expand Up @@ -455,10 +457,12 @@ bool DefaultAtomDict::Ready()
{
// This will work only if we create this type after the standard AtomDict
// The reference will be handled by the module to which we will add the type
PyObject* bases = PyTuple_New( 1 );
PyTuple_SET_ITEM( bases, 0, pyobject_cast( AtomDict::TypeObject ) );
cppy::ptr bases( PyTuple_New( 1 ) );
if ( !bases )
return false; // LCOV_EXCL_LINE (failed tuple creation)
PyTuple_SET_ITEM( bases.get(), 0, cppy::incref( pyobject_cast( AtomDict::TypeObject ) ) );
TypeObject = pytype_cast(
PyType_FromSpecWithBases( &TypeObject_Spec, bases )
PyType_FromSpecWithBases( &TypeObject_Spec, bases.get() )
);
if( !TypeObject )
{
Expand Down
19 changes: 18 additions & 1 deletion tests/test_atomdefaultdict.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,16 @@

import pytest

from atom.api import Atom, DefaultDict, Instance, Int, List, atomlist, defaultatomdict
from atom.api import (
Atom,
Coerced,
DefaultDict,
Instance,
Int,
List,
atomlist,
defaultatomdict,
)


@pytest.fixture
Expand Down Expand Up @@ -237,3 +246,11 @@ class A(Atom):
content = a.d[1]
assert isinstance(content, atomlist)
assert content is a.d[1]


def test_coerced_key_missing():
class Obj(Atom):
items = DefaultDict(key=Coerced(str), missing=lambda: "missing")

o = Obj()
o.items[1] # key 1 gets coerced to '1'
12 changes: 11 additions & 1 deletion tests/test_atomdict.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import pytest

from atom.api import Atom, Dict, Int, List, atomdict, atomlist
from atom.api import Atom, Dict, Coerced, Int, List, atomdict, atomlist


@pytest.fixture
Expand Down Expand Up @@ -186,3 +186,13 @@ def test_update(atom_dict):
atom_dict.fullytyped.update({"": 1})
with pytest.raises(TypeError):
atom_dict.fullytyped.update({"": ""})


def test_coerced_setdefault():
class Obj(Atom):
items = Dict(key=Coerced(str))

o = Obj()
o.items["1"] = "a"
o.items.setdefault(1, "b") # key 1 gets coerced to '1'
assert o.items["1"] == "a"
Loading