diff --git a/atom/src/sortedmap.cpp b/atom/src/sortedmap.cpp index 7d2a0ec3..b63383f4 100644 --- a/atom/src/sortedmap.cpp +++ b/atom/src/sortedmap.cpp @@ -131,6 +131,53 @@ struct SortedMap static bool Ready(); + static PyObject* New( PyTypeObject* type, PyObject* map ) + { + cppy::ptr selfptr( PyType_GenericNew( type, 0, 0 ) ); + if( !selfptr ) + return 0; // LCOV_EXCL_LINE (allocation failed, very unlikely) + SortedMap* self = reinterpret_cast( selfptr.get() ); + self->m_items = new SortedMap::Items(); + + if( map ) + { + if( PyObject_TypeCheck( map, TypeObject ) ) + { + SortedMap* other = reinterpret_cast( map ); + *self->m_items = *other->m_items; + } + else if( PyDict_Check(map) ) + { + PyObject* key; + PyObject* val; + Py_ssize_t index = 0; + while( PyDict_Next( map, &index, &key, &val ) ) + self->setitem( key, val ); + } + else + { + cppy::ptr iter( PyObject_GetIter( map ) ); + if( !iter ) + return 0; + cppy::ptr item; + while( (item = PyIter_Next( iter.get() )) ) + { + cppy::ptr pair( PySequence_Fast( item.get(), "map must be a sequence of key, value pairs") ); + if ( !pair ) + return 0; + if( PySequence_Fast_GET_SIZE( pair.get() ) != 2 ) + return cppy::type_error( pair.get(), "pairs of objects" ); + self->setitem( PySequence_Fast_GET_ITEM( pair.get(), 0 ), + PySequence_Fast_GET_ITEM( pair.get(), 1 ) ); + } + if ( PyErr_Occurred() ) + return 0; // error during iteration + } + } + + return selfptr.release(); + } + PyObject* getitem( PyObject* key, PyObject* default_value = 0 ) { Items::iterator it = std::lower_bound( @@ -288,46 +335,22 @@ SortedMap_new( PyTypeObject* type, PyObject* args, PyObject* kwargs ) static char* kwlist[] = { "map", 0 }; if( !PyArg_ParseTupleAndKeywords( args, kwargs, "|O:__new__", kwlist, &map ) ) return 0; + return SortedMap::New( type, map ); +} - PyObject* self = PyType_GenericNew( type, 0, 0 ); - if( !self ) { - return 0; // LCOV_EXCL_LINE (allocation failed, very unlikely) - } - SortedMap* cself = reinterpret_cast( self ); - cself->m_items = new SortedMap::Items(); - - cppy::ptr seq; - if( map ) - { - if( PyDict_Check( map ) ) - { - seq = PyObject_GetIter( PyDict_Items( map ) ); - if( !seq ) { - return 0; // LCOV_EXCL_LINE (dict items failed, very unlikely) - } - } - else - { - seq = PyObject_GetIter( map ); - if( !seq ) - return 0; - } - } - - if( seq ) - { - cppy::ptr item; - while( (item = PyIter_Next( seq.get() )) ) - { - if( PySequence_Length( item.get() ) != 2) - return cppy::type_error( item.get(), "pairs of objects" ); - - cself->setitem( PySequence_GetItem( item.get(), 0 ), - PySequence_GetItem( item.get(), 1 ) ); - } +PyObject* +SortedMap_vectorcall( PyObject* type, PyObject*const *args, size_t nargsf, PyObject* kwnames ) +{ + if ( kwnames ) + return cppy::type_error("sortedmap takes no kwargs"); + switch (PyVectorcall_NARGS(nargsf)) { + case 0: + return SortedMap::New( reinterpret_cast(type), 0 ); + case 1: + return SortedMap::New( reinterpret_cast(type), args[0] ); + default: + return cppy::type_error("sortedmap takes at most one argument"); } - - return self; } // Clearing the vector may cause arbitrary side effects on item @@ -353,10 +376,7 @@ SortedMap_traverse( SortedMap* self, visitproc visit, void* arg ) Py_VISIT( it->key() ); Py_VISIT( it->value() ); } -#if PY_VERSION_HEX >= 0x03090000 - // This was not needed before Python 3.9 (Python issue 35810 and 40217) Py_VISIT(Py_TYPE(self)); -#endif return 0; } @@ -364,11 +384,13 @@ SortedMap_traverse( SortedMap* self, visitproc visit, void* arg ) void SortedMap_dealloc( SortedMap* self ) { + PyTypeObject *tp = Py_TYPE(self); PyObject_GC_UnTrack( self ); SortedMap_clear( self ); delete self->m_items; self->m_items = 0; - Py_TYPE(self)->tp_free( reinterpret_cast( self ) ); + tp->tp_free( pyobject_cast( self ) ); + Py_DECREF(tp); } @@ -508,8 +530,14 @@ SortedMap_repr( SortedMap* self ) cppy::ptr valstr( PyObject_Repr( it->value() ) ); if( !valstr ) return 0; - ostr << "(" << PyUnicode_AsUTF8( keystr.get() ) << ", "; - ostr << PyUnicode_AsUTF8( valstr.get() ) << "), "; + const char* k = PyUnicode_AsUTF8( keystr.get() ); + if ( !k ) + return 0; + const char* v = PyUnicode_AsUTF8( valstr.get() ); + if ( !v ) + return 0; + ostr << "(" << k << ", "; + ostr << v << "), "; } if( self->m_items->size() > 0 ) ostr.seekp( -2, std::ios_base::cur ); @@ -574,6 +602,9 @@ static PyType_Slot SortedMap_Type_slots[] = { { Py_tp_new, void_cast( SortedMap_new ) }, /* tp_new */ { Py_tp_iter, void_cast( SortedMap_iter ) }, /* tp_iter */ { Py_tp_alloc, void_cast( PyType_GenericAlloc ) }, /* tp_alloc */ +#if defined(Py_tp_vectorcall) + { Py_tp_vectorcall, void_cast( SortedMap_vectorcall ) }, /* tp_vectorcall */ +#endif { Py_mp_length, void_cast( SortedMap_length ) }, /* mp_length */ { Py_mp_subscript, void_cast( SortedMap_subscript ) }, /* mp_subscript */ { Py_mp_ass_subscript, void_cast( SortedMap_ass_subscript ) }, /* mp_ass_subscript */ diff --git a/tests/datastructure/test_sortedmap.py b/tests/datastructure/test_sortedmap.py index d09e82fa..25b21762 100644 --- a/tests/datastructure/test_sortedmap.py +++ b/tests/datastructure/test_sortedmap.py @@ -8,6 +8,7 @@ """Test the sortedmap that acts like an ordered dictionary.""" import gc +import sys import pytest @@ -33,12 +34,52 @@ def test_sortedmap_init(): assert smap.items() == [(1, 2)] smap = sortedmap({1: 2}) assert smap.items() == [(1, 2)] + copy = sortedmap(smap) + assert copy.items() == [(1, 2)] with pytest.raises(TypeError): sortedmap(1) + with pytest.raises(TypeError): + sortedmap(a=1) + with pytest.raises(TypeError): + sortedmap(1, 2) with pytest.raises(TypeError) as excinfo: sortedmap([1]) - assert "pairs" in excinfo.exconly() + assert "sequence of key, value pairs" in excinfo.exconly() + with pytest.raises(TypeError) as excinfo: + sortedmap([[1]]) + assert "pairs of objects" in excinfo.exconly() + + +def test_sortedmap_gen_err(): + """Test that iterator error is raised""" + + def generator(throw): + yield ("a", 1) + if throw: + raise ValueError() + + smap = sortedmap(generator(throw=False)) + assert smap["a"] == 1 + with pytest.raises(ValueError): + smap = sortedmap(generator(throw=True)) + + +def test_sortedmap_refcnt(): + """Test that constructor does not leak references""" + k = object() + v = object() + rck = sys.getrefcount(k) + rcv = sys.getrefcount(v) + smap = sortedmap([(k, v)]) + assert smap[k] == v + del smap + smap = sortedmap({k: v}) + assert smap[k] == v + del smap + gc.collect() + assert sys.getrefcount(k) == rck + assert sys.getrefcount(v) == rcv def test_traverse():