Skip to content

Commit 75a7ffe

Browse files
committed
[mypyc] Generate more type methods for types with managed dicts
1 parent eb41eb1 commit 75a7ffe

4 files changed

Lines changed: 68 additions & 19 deletions

File tree

mypyc/codegen/emitclass.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,8 @@ def generate_class(cl: ClassIR, module: str, emitter: Emitter) -> None:
262262
if not cl.builtin_base:
263263
fields["tp_new"] = new_name
264264

265-
if generate_full:
265+
managed_dict = has_managed_dict(cl, emitter)
266+
if generate_full or managed_dict:
266267
fields["tp_dealloc"] = f"(destructor){name_prefix}_dealloc"
267268
if not cl.is_acyclic:
268269
fields["tp_traverse"] = f"(traverseproc){name_prefix}_traverse"
@@ -335,6 +336,14 @@ def emit_line() -> None:
335336
else:
336337
fields["tp_basicsize"] = base_size
337338

339+
if generate_full or managed_dict:
340+
if not cl.is_acyclic:
341+
generate_traverse_for_class(cl, traverse_name, emitter)
342+
emit_line()
343+
generate_clear_for_class(cl, clear_name, emitter)
344+
emit_line()
345+
generate_dealloc_for_class(cl, dealloc_name, clear_name, bool(del_method), emitter)
346+
emit_line()
338347
if generate_full:
339348
assert cl.setup is not None
340349
emitter.emit_line(native_function_header(cl.setup, emitter) + ";")
@@ -345,13 +354,6 @@ def emit_line() -> None:
345354
init_fn = cl.get_method("__init__")
346355
generate_new_for_class(cl, new_name, vtable_name, setup_name, init_fn, emitter)
347356
emit_line()
348-
if not cl.is_acyclic:
349-
generate_traverse_for_class(cl, traverse_name, emitter)
350-
emit_line()
351-
generate_clear_for_class(cl, clear_name, emitter)
352-
emit_line()
353-
generate_dealloc_for_class(cl, dealloc_name, clear_name, bool(del_method), emitter)
354-
emit_line()
355357

356358
if cl.allow_interpreted_subclasses:
357359
shadow_vtable_name: str | None = generate_vtables(
@@ -380,7 +382,7 @@ def emit_line() -> None:
380382
emit_line()
381383

382384
flags = ["Py_TPFLAGS_DEFAULT", "Py_TPFLAGS_HEAPTYPE", "Py_TPFLAGS_BASETYPE"]
383-
if generate_full and not cl.is_acyclic:
385+
if (generate_full or managed_dict) and not cl.is_acyclic:
384386
flags.append("Py_TPFLAGS_HAVE_GC")
385387
if cl.has_method("__call__"):
386388
fields["tp_vectorcall_offset"] = "offsetof({}, vectorcall)".format(
@@ -391,7 +393,7 @@ def emit_line() -> None:
391393
# This is just a placeholder to please CPython. It will be
392394
# overridden during setup.
393395
fields["tp_call"] = "PyVectorcall_Call"
394-
if has_managed_dict(cl, emitter):
396+
if managed_dict:
395397
flags.append("Py_TPFLAGS_MANAGED_DICT")
396398
fields["tp_flags"] = " | ".join(flags)
397399

@@ -869,7 +871,8 @@ def generate_traverse_for_class(cl: ClassIR, func_name: str, emitter: Emitter) -
869871
for attr, rtype in base.attributes.items():
870872
emitter.emit_gc_visit(f"self->{emitter.attr(attr)}", rtype)
871873
if has_managed_dict(cl, emitter):
872-
emitter.emit_line("PyObject_VisitManagedDict((PyObject *)self, visit, arg);")
874+
emitter.emit_line("int rv = PyObject_VisitManagedDict((PyObject *)self, visit, arg);")
875+
emitter.emit_line("if (rv < 0) return rv;")
873876
elif cl.has_dict:
874877
struct_name = cl.struct_name(emitter.names)
875878
# __dict__ lives right after the struct and __weakref__ lives right after that
@@ -934,6 +937,12 @@ def generate_dealloc_for_class(
934937
emitter.emit_line("if (res < 0) {")
935938
emitter.emit_line("goto done;")
936939
emitter.emit_line("}")
940+
if cl.builtin_base:
941+
# For native subclasses of builtins such as dict, the base deallocator
942+
# is responsible for tearing down base-owned storage and freeing memory.
943+
emitter.emit_line(f"{clear_func_name}(self);")
944+
emitter.emit_line("Py_TYPE(self)->tp_base->tp_dealloc((PyObject *)self);")
945+
emitter.emit_line("goto done;")
937946
if not cl.is_acyclic:
938947
emitter.emit_line("PyObject_GC_UnTrack(self);")
939948
if cl.reuse_freed_instance:

mypyc/test-data/fixtures/ir.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def __pow__(self, other: T_contra, modulo: _M) -> T_co: ...
4040

4141
class object:
4242
__class__: type
43+
__dict__: dict[str, Any]
4344
def __new__(cls) -> Self: pass
4445
def __init__(self) -> None: pass
4546
def __init_subclass__(cls, **kwargs: object) -> None: pass

mypyc/test-data/run-classes.test

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3325,20 +3325,31 @@ def test_function():
33253325
assert(isinstance(d.fitem, ForwardDefinedClass))
33263326
assert(isinstance(d.fitems, ForwardDefinedClass))
33273327

3328-
[case testDelForDictSubclass-xfail]
3329-
# The crash in issue mypy#19175 is fixed.
3330-
# But, for classes that derive from built-in Python classes, user-defined __del__ method is not
3331-
# being invoked.
3328+
[case testDelForDictSubclass]
3329+
events: list[str] = []
3330+
3331+
class Item:
3332+
def __del__(self) -> None:
3333+
events.append("deleting Item")
3334+
33323335
class DictSubclass(dict):
3333-
def __del__(self):
3334-
print("deleting DictSubclass...")
3336+
def __del__(self) -> None:
3337+
events.append("deleting DictSubclass")
3338+
3339+
def test_dict_subclass_dealloc() -> None:
3340+
d = DictSubclass()
3341+
d["item"] = Item()
3342+
del d
33353343

33363344
[file driver.py]
33373345
import native
3338-
native.DictSubclass()
3346+
native.test_dict_subclass_dealloc()
3347+
assert native.events == [
3348+
"deleting DictSubclass",
3349+
"deleting Item",
3350+
]
33393351

33403352
[out]
3341-
deleting DictSubclass...
33423353

33433354
[case testDel]
33443355
class A:

mypyc/test-data/run-dicts.test

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,3 +368,31 @@ class subc(dict[Any, Any]):
368368
[file userdefineddict.py]
369369
class dict:
370370
pass
371+
372+
[case testDunderDictAccessAfterDel]
373+
class NormDict(dict[str, str]):
374+
def __init__(self, attr: int = 42) -> None:
375+
super().__init__()
376+
self.attr = attr
377+
378+
def test_dict_access() -> None:
379+
n = NormDict(1)
380+
d = n.__dict__
381+
assert d["attr"] == 1
382+
del n
383+
assert d["attr"] == 1
384+
385+
[file driver.py]
386+
from native import NormDict, test_dict_access
387+
388+
def test_dict_access_interpreted() -> None:
389+
n = NormDict()
390+
d = n.__dict__
391+
assert d["attr"] == 42
392+
del n
393+
assert d["attr"] == 42
394+
395+
test_dict_access()
396+
test_dict_access_interpreted()
397+
398+
[fixture fixtures/typing-full.pyi]

0 commit comments

Comments
 (0)