From f3f632e32d2eb16af993fb310d8dd7103f0b7d23 Mon Sep 17 00:00:00 2001 From: monkeyman192 Date: Sat, 15 Nov 2025 22:32:54 +1100 Subject: [PATCH] Added ability to specify function args with typing.Annotated --- docs/docs/change_log.rst | 1 + docs/docs/creating_hook_definitions.rst | 36 ++++++++++++++++----- pymhf/core/functions.py | 14 +++++++-- pymhf/extensions/ctypes.py | 14 +++++++-- pymhf/utils/partial_struct.py | 14 ++------- tests/unit/test_functions.py | 42 +++++++++++-------------- 6 files changed, 73 insertions(+), 48 deletions(-) diff --git a/docs/docs/change_log.rst b/docs/docs/change_log.rst index 879350d..00340ee 100644 --- a/docs/docs/change_log.rst +++ b/docs/docs/change_log.rst @@ -5,6 +5,7 @@ Current (0.2.1.dev) ------------------- - Fixed an issue where `ctypes.Union` types weren't accepted as a valid function argument type. +- Added the ability to specify the type of a function argument using ``typing.Annotated`` to improve type hinting in libraries. 0.2.0 (02/10/2025) -------------------- diff --git a/docs/docs/creating_hook_definitions.rst b/docs/docs/creating_hook_definitions.rst index cf0590c..918f540 100644 --- a/docs/docs/creating_hook_definitions.rst +++ b/docs/docs/creating_hook_definitions.rst @@ -11,7 +11,7 @@ These decorators do a few things when applied: 2. It also enables calling the function or method directly (more on that below). 3. Finally, the decorators transform the function or method into a decorator which can be applied to the method in our mod which we wish to use as a detour. -Because of this first point, the decorated function MUST have correct type hints. If they are not correct, then the hook will likely fail, produce incorrect results, or even cause the program to crash. +Because of this first point, the decorated function MUST have correct type hints. If they are not correct, then the hook will likely fail, produce incorrect results, or even cause the program to crash. See :ref:`function_type_hints` for more details. The best way to see how these decorators are used is with a few code examples. @@ -39,7 +39,7 @@ The best way to see how these decorators are used is with a few code examples. lpNumberOfBytesRead: wintypes.LPDWORD, lpOverlapped: wintypes.LPVOID, ) -> wintypes.BOOL: - pass + ... class ReadFileMod(Mod): @@ -88,7 +88,7 @@ Exported functions are those which are provided by the binary itself. There are in_pExternalSources: ctypes.c_uint64 = 0, in_PlayingID: ctypes.c_uint32 = 0, ) -> ctypes.c_uint64: - pass + ... class AudioNames(Mod): @@ -164,7 +164,7 @@ Defining functions to hook is done in much the same way as above, however, we si event: ctypes._Pointer[TkAudioID], object: ctypes.c_int64, ) -> ctypes.c_bool: - pass + ... class AudioNames(Mod): @@ -266,7 +266,7 @@ To hook or call a function with an overload, append ``.overload(overload_id: str object: ctypes.c_int64, attenuationScale: ctypes.c_float, ) -> ctypes.c_bool: - pass + ... @function_hook("48 83 EC ? 33 C9 4C 8B D2 89 4C 24 ? 49 8B C0 48 89 4C 24 ? 45 33 C9", overload_id="normal") @overload @@ -276,7 +276,7 @@ To hook or call a function with an overload, append ``.overload(overload_id: str event: ctypes._Pointer[TkAudioID], object: ctypes.c_int64, ) -> ctypes.c_bool: - pass + ... class AudioNames(Mod): @@ -363,13 +363,35 @@ Using ``before`` and ``after`` methods The ``.before`` and ``.after`` method of the functions decorated by the ``function_hook`` or ``static_function_hook`` is required to be used when using this as a decorator to tell pyMHF whether to run the detour before or after the original function. If this is not included then an error will be raised. Depending on whether you mark the hooks as ``before`` or ``after`` hook you may get some functionality. See :ref:`here ` for more details. +.. _function_type_hints: + Function type hints ^^^^^^^^^^^^^^^^^^^ As mentioned at the start of this document, it is critical that the functions which are decorated with these two decorators have correct and complete type hints. These types MUST be either a ctypes plain type (eg. ``ctypes.c_uint32``), a ctypes pointer to some type, or a class which inherits from ``ctypes.Structure``. Note that the :class:`~pymhf.core.hooking.Structure` inherits from this so a type inheriting from this type is also permissible. + +To improve type hinting, it is however possible to specify the type using `typing.Annotated `_. +For example, instead of writing + +.. code-block:: py + + @function_hook("AB CD EF") + def AwardMoney(self, this: "ctypes._Pointer[Obj]", liChange: ctypes.c_int32) -> ctypes.c_uint64: + ... + +We can type + +.. code-block:: py + + @function_hook("AB CD EF") + def AwardMoney(self, this: "ctypes._Pointer[Obj]", liChange: Annotated[int, ctypes.c_int32]) -> ctypes.c_uint64: + ... + +This has the benefit that when we call this method from an instance of the ``Obj`` class our type checker will not complain about being passed an integer (which python will transparently convert to its ``ctypes`` counterpart anyway). + Further, you will have seen above that none of these functions have any actual body. This is because even when we call this function, we don't actually execute the code contained within it. -Because of this it's recommended that you simply add ``pass`` to the body of the function as above. +Because of this it's recommended that you simply add ``...`` to the body of the function as above. We use the Ellipses (``...``) instead of ``pass`` since it keeps type checkers happier and is more consistent with how type stubs are represented in python (which is essentially what we are defining). Any docstrings which are included as part of the body will be shown in your IDE of choice, so if you are writing a library it's recommended that you add docstrings if convenient so that users may know what the function does. .. warning:: diff --git a/pymhf/core/functions.py b/pymhf/core/functions.py index a1da950..6e05d59 100644 --- a/pymhf/core/functions.py +++ b/pymhf/core/functions.py @@ -1,7 +1,7 @@ import ctypes import inspect from functools import lru_cache -from typing import Any, Callable, NamedTuple, Optional, get_args +from typing import Any, Callable, NamedTuple, Optional, _AnnotatedAlias, get_args from typing_extensions import get_type_hints @@ -61,7 +61,7 @@ def _get_funcdef(func: Callable) -> FuncDef: This is wrapped in an lru_cache so that if multiple detours use the same function, it will only be analysed once.""" func_params = inspect.signature(func).parameters - func_type_hints = get_type_hints(func) + func_type_hints = get_type_hints(func, include_extras=True) _restype = func_type_hints.pop("return", type(None)) if _restype is type(None): restype = None @@ -76,6 +76,16 @@ def _get_funcdef(func: Callable) -> FuncDef: if name != "self": if name in func_type_hints: argtype = func_type_hints[name] + # Check if the type is an annotation. If it is, then extract the actual type. + if isinstance(argtype, _AnnotatedAlias): + if len(meta := argtype.__metadata__) == 1: + argtype = meta[0] + else: + raise TypeError( + f"Invalid annotation {meta!r} for argument {name!r}. For Annotated types they " + "must have their 'python' type and their 'ctype' like " + "`Annotated[int, ctypes.c_int32]`." + ) if issubclass(argtype, get_args(CTYPES)): default_val = param.default if default_val != inspect.Signature.empty: diff --git a/pymhf/extensions/ctypes.py b/pymhf/extensions/ctypes.py index c975d7d..8a4ed08 100644 --- a/pymhf/extensions/ctypes.py +++ b/pymhf/extensions/ctypes.py @@ -2,7 +2,6 @@ import ctypes import types -from _ctypes import _Pointer from enum import IntEnum from typing import Generic, Type, TypeVar, Union @@ -10,8 +9,6 @@ IE = TypeVar("IE", bound=IntEnum) -CTYPES = Union[ctypes._SimpleCData, ctypes.Structure, ctypes._Pointer, _Pointer, ctypes.Union, ctypes.Array] - class c_enum32(ctypes.c_int32, Generic[IE]): """c_int32 wrapper for enums. This doesn't have the full set of features an enum would normally have, @@ -56,3 +53,14 @@ def __class_getitem__(cls: Type["c_enum32"], enum_type: Type[IE]): _cls._enum_type = enum_type _cenum_type_cache[enum_type] = _cls return _cls + + +CTYPES = Union[ + ctypes._SimpleCData, + ctypes.Structure, + ctypes._Pointer, + ctypes._Pointer_orig, # The original, un-monkeypatched ctypes._Pointer object + ctypes.Array, + ctypes.Union, + c_enum32, +] diff --git a/pymhf/utils/partial_struct.py b/pymhf/utils/partial_struct.py index 60d4552..3315f48 100644 --- a/pymhf/utils/partial_struct.py +++ b/pymhf/utils/partial_struct.py @@ -1,24 +1,14 @@ import ctypes import inspect from dataclasses import dataclass -from typing import Optional, Type, TypeVar, Union, _AnnotatedAlias, get_args +from typing import Optional, Type, TypeVar, _AnnotatedAlias, get_args from typing_extensions import get_type_hints -from pymhf.extensions.ctypes import c_enum32 +from pymhf.extensions.ctypes import CTYPES _T = TypeVar("_T", bound=Type[ctypes.Structure]) -CTYPES = Union[ - ctypes._SimpleCData, - ctypes.Structure, - ctypes._Pointer, - ctypes._Pointer_orig, # The original, un-monkeypatched ctypes._Pointer object - ctypes.Array, - ctypes.Union, - c_enum32, -] - @dataclass class Field: diff --git a/tests/unit/test_functions.py b/tests/unit/test_functions.py index 0557637..023de43 100644 --- a/tests/unit/test_functions.py +++ b/tests/unit/test_functions.py @@ -1,5 +1,6 @@ import ctypes import re +from typing import Annotated import pytest from typing_extensions import Self @@ -53,8 +54,7 @@ def test_get_funcdef_function(): """Test getting the FuncDef for functions.""" # Function with normal args and a None return. - def func(x: ctypes.c_uint32, y: ctypes.c_uint64 = 1234) -> None: - pass + def func(x: ctypes.c_uint32, y: Annotated[int, ctypes.c_uint64] = 1234) -> None: ... fd = _get_funcdef(func) assert fd.restype is None @@ -63,8 +63,7 @@ def func(x: ctypes.c_uint32, y: ctypes.c_uint64 = 1234) -> None: assert fd.defaults == {"y": 1234} # Function with no args. - def func2() -> ctypes.c_uint32: - pass + def func2() -> ctypes.c_uint32: ... fd = _get_funcdef(func2) assert fd.restype == ctypes.c_uint32 @@ -76,8 +75,7 @@ class ID(ctypes.Structure): _fields_ = [] # Function with an argument which is a pointer. - def func3(id_: ctypes._Pointer[ID]) -> ctypes.c_bool: - pass + def func3(id_: ctypes._Pointer[ID]) -> ctypes.c_bool: ... fd = _get_funcdef(func3) assert fd.restype == ctypes.c_bool @@ -87,24 +85,21 @@ def func3(id_: ctypes._Pointer[ID]) -> ctypes.c_bool: # Function with mixed stringified and non-stringified args. # (emulates `from __future__ import annotations`) - def func4(a: "ctypes.c_int32" = 42, b: ctypes.c_uint16 = 4): - pass + def func4(a: "ctypes.c_int32", b: Annotated[int, ctypes.c_uint16] = 4): ... fd = _get_funcdef(func4) assert fd.restype is None assert fd.arg_names == ["a", "b"] assert fd.arg_types == [ctypes.c_int32, ctypes.c_uint16] - assert fd.defaults == {"a": 42, "b": 4} + assert fd.defaults == {"b": 4} # Function with an invalid argument types. - def func5(a: int): - pass + def func5(a: int): ... with pytest.raises(TypeError, match=re.escape("Invalid type for argument 'a'")): _get_funcdef(func5) - def func6(a, b: ctypes.c_int64): - pass + def func6(a, b: ctypes.c_int64): ... with pytest.raises( TypeError, @@ -112,8 +107,7 @@ def func6(a, b: ctypes.c_int64): ): _get_funcdef(func6) - def func7() -> int: - pass + def func7() -> int: ... with pytest.raises( TypeError, @@ -128,32 +122,32 @@ def test_get_funcdef_method(): class MyClass: def thing(self): # Very boring method with no args or return value. - pass + ... - def thing2(self, x: ctypes.c_uint32, y: ctypes.c_uint16 = 7) -> ctypes.c_uint32: + def thing2(self, x: ctypes.c_uint32, y: Annotated[int, ctypes.c_uint16] = 7) -> ctypes.c_uint32: # Fairly boring method with some args and a default value. - pass + ... @staticmethod - def thing3(x: ctypes.c_float = 999): + def thing3(x: Annotated[float, ctypes.c_float] = 999): # Static method. - pass + ... def thing4(self: Self, x: "ctypes.c_uint32"): # Valid "stringified" type. - pass + ... def thing5(self, x: "int"): # Invalid "stringified" type. - pass + ... def thing6(self, x: ctypes.c_uint32, y=7, z=None) -> ctypes.c_uint32: # Argument missing type hint. - pass + ... def thing7(self) -> "int": # Invalid "stringified" return type. - pass + ... fd = _get_funcdef(MyClass.thing) assert fd.restype is None