Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 75 additions & 44 deletions atom/src/sortedmap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,53 @@ struct SortedMap

static bool Ready();

static PyObject* New( PyTypeObject* type, PyObject* map )
{
cppy::ptr selfptr( PyType_GenericNew( type, 0, 0 ) );
if( !selfptr )
return 0; // LCOV_EXCL_LINE (allocation failed, very unlikely)
SortedMap* self = reinterpret_cast<SortedMap*>( selfptr.get() );
self->m_items = new SortedMap::Items();

if( map )
{
if( PyObject_TypeCheck( map, TypeObject ) )
{
SortedMap* other = reinterpret_cast<SortedMap*>( map );
*self->m_items = *other->m_items;
}
else if( PyDict_Check(map) )
{
PyObject* key;
PyObject* val;
Py_ssize_t index = 0;
while( PyDict_Next( map, &index, &key, &val ) )
self->setitem( key, val );
}
else
{
cppy::ptr iter( PyObject_GetIter( map ) );
if( !iter )
return 0;
cppy::ptr item;
while( (item = PyIter_Next( iter.get() )) )
{
cppy::ptr pair( PySequence_Fast( item.get(), "map must be a sequence of key, value pairs") );
if ( !pair )
return 0;
if( PySequence_Fast_GET_SIZE( pair.get() ) != 2 )
return cppy::type_error( pair.get(), "pairs of objects" );
self->setitem( PySequence_Fast_GET_ITEM( pair.get(), 0 ),
PySequence_Fast_GET_ITEM( pair.get(), 1 ) );
}
if ( PyErr_Occurred() )
return 0; // error during iteration
}
}

return selfptr.release();
}

PyObject* getitem( PyObject* key, PyObject* default_value = 0 )
{
Items::iterator it = std::lower_bound(
Expand Down Expand Up @@ -288,46 +335,22 @@ SortedMap_new( PyTypeObject* type, PyObject* args, PyObject* kwargs )
static char* kwlist[] = { "map", 0 };
if( !PyArg_ParseTupleAndKeywords( args, kwargs, "|O:__new__", kwlist, &map ) )
return 0;
return SortedMap::New( type, map );
}

PyObject* self = PyType_GenericNew( type, 0, 0 );
if( !self ) {
return 0; // LCOV_EXCL_LINE (allocation failed, very unlikely)
}
SortedMap* cself = reinterpret_cast<SortedMap*>( self );
cself->m_items = new SortedMap::Items();

cppy::ptr seq;
if( map )
{
if( PyDict_Check( map ) )
{
seq = PyObject_GetIter( PyDict_Items( map ) );
if( !seq ) {
return 0; // LCOV_EXCL_LINE (dict items failed, very unlikely)
}
}
else
{
seq = PyObject_GetIter( map );
if( !seq )
return 0;
}
}

if( seq )
{
cppy::ptr item;
while( (item = PyIter_Next( seq.get() )) )
{
if( PySequence_Length( item.get() ) != 2)
return cppy::type_error( item.get(), "pairs of objects" );

cself->setitem( PySequence_GetItem( item.get(), 0 ),
PySequence_GetItem( item.get(), 1 ) );
}
PyObject*
SortedMap_vectorcall( PyObject* type, PyObject*const *args, size_t nargsf, PyObject* kwnames )
{
if ( kwnames )
return cppy::type_error("sortedmap takes no kwargs");
switch (PyVectorcall_NARGS(nargsf)) {
case 0:
return SortedMap::New( reinterpret_cast<PyTypeObject*>(type), 0 );
case 1:
return SortedMap::New( reinterpret_cast<PyTypeObject*>(type), args[0] );
default:
return cppy::type_error("sortedmap takes at most one argument");
}

return self;
}

// Clearing the vector may cause arbitrary side effects on item
Expand All @@ -353,22 +376,21 @@ SortedMap_traverse( SortedMap* self, visitproc visit, void* arg )
Py_VISIT( it->key() );
Py_VISIT( it->value() );
}
#if PY_VERSION_HEX >= 0x03090000
// This was not needed before Python 3.9 (Python issue 35810 and 40217)
Py_VISIT(Py_TYPE(self));
#endif
return 0;
}


void
SortedMap_dealloc( SortedMap* self )
{
PyTypeObject *tp = Py_TYPE(self);
PyObject_GC_UnTrack( self );
SortedMap_clear( self );
delete self->m_items;
self->m_items = 0;
Py_TYPE(self)->tp_free( reinterpret_cast<PyObject*>( self ) );
tp->tp_free( pyobject_cast( self ) );
Py_DECREF(tp);
}


Expand Down Expand Up @@ -508,8 +530,14 @@ SortedMap_repr( SortedMap* self )
cppy::ptr valstr( PyObject_Repr( it->value() ) );
if( !valstr )
return 0;
ostr << "(" << PyUnicode_AsUTF8( keystr.get() ) << ", ";
ostr << PyUnicode_AsUTF8( valstr.get() ) << "), ";
const char* k = PyUnicode_AsUTF8( keystr.get() );
if ( !k )
return 0;
const char* v = PyUnicode_AsUTF8( valstr.get() );
if ( !v )
return 0;
ostr << "(" << k << ", ";
ostr << v << "), ";
}
if( self->m_items->size() > 0 )
ostr.seekp( -2, std::ios_base::cur );
Expand Down Expand Up @@ -574,6 +602,9 @@ static PyType_Slot SortedMap_Type_slots[] = {
{ Py_tp_new, void_cast( SortedMap_new ) }, /* tp_new */
{ Py_tp_iter, void_cast( SortedMap_iter ) }, /* tp_iter */
{ Py_tp_alloc, void_cast( PyType_GenericAlloc ) }, /* tp_alloc */
#if defined(Py_tp_vectorcall)
{ Py_tp_vectorcall, void_cast( SortedMap_vectorcall ) }, /* tp_vectorcall */
#endif
{ Py_mp_length, void_cast( SortedMap_length ) }, /* mp_length */
{ Py_mp_subscript, void_cast( SortedMap_subscript ) }, /* mp_subscript */
{ Py_mp_ass_subscript, void_cast( SortedMap_ass_subscript ) }, /* mp_ass_subscript */
Expand Down
43 changes: 42 additions & 1 deletion tests/datastructure/test_sortedmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"""Test the sortedmap that acts like an ordered dictionary."""

import gc
import sys

import pytest

Expand All @@ -33,12 +34,52 @@ def test_sortedmap_init():
assert smap.items() == [(1, 2)]
smap = sortedmap({1: 2})
assert smap.items() == [(1, 2)]
copy = sortedmap(smap)
assert copy.items() == [(1, 2)]

with pytest.raises(TypeError):
sortedmap(1)
with pytest.raises(TypeError):
sortedmap(a=1)
with pytest.raises(TypeError):
sortedmap(1, 2)
with pytest.raises(TypeError) as excinfo:
sortedmap([1])
assert "pairs" in excinfo.exconly()
assert "sequence of key, value pairs" in excinfo.exconly()
with pytest.raises(TypeError) as excinfo:
sortedmap([[1]])
assert "pairs of objects" in excinfo.exconly()


def test_sortedmap_gen_err():
"""Test that iterator error is raised"""

def generator(throw):
yield ("a", 1)
if throw:
raise ValueError()

smap = sortedmap(generator(throw=False))
assert smap["a"] == 1
with pytest.raises(ValueError):
smap = sortedmap(generator(throw=True))


def test_sortedmap_refcnt():
"""Test that constructor does not leak references"""
k = object()
v = object()
rck = sys.getrefcount(k)
rcv = sys.getrefcount(v)
smap = sortedmap([(k, v)])
assert smap[k] == v
del smap
smap = sortedmap({k: v})
assert smap[k] == v
del smap
gc.collect()
assert sys.getrefcount(k) == rck
assert sys.getrefcount(v) == rcv


def test_traverse():
Expand Down
Loading