Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Patches/PostModInitPatch.cs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ private static void CheckSpecialSpireField(FieldInfo field)
var genericTypeDef = fType.GetGenericTypeDefinition();

if (genericTypeDef != typeof(SavedSpireField<,>) &&
genericTypeDef != typeof(AddedNode<,>))
genericTypeDef != typeof(AddedNode<,>) &&
genericTypeDef != typeof(SpireMethod.SpireMethod<>))
return;

field.GetValue(null); //Trigger field initialization
Expand Down
98 changes: 98 additions & 0 deletions SpireMethod/SpireMethod.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
using System.Reflection;
using HarmonyLib;

namespace BaseLib.SpireMethod;

/// <summary>
/// Attaches new behavior to a virtual method on an existing type without requiring
/// every mod author to write their own Harmony patch with runtime type checks.
///
/// <para>Register handlers via the static <c>Register</c> methods:</para>
/// <code>
/// // Async
/// static SpireMethod&lt;MyRelic&gt; hook = SpireMethod&lt;MyRelic&gt;.Register(
/// nameof(AbstractModel.AfterCardPlayed),
/// async (MyRelic instance, object[] args) =&gt; { /* ... */ }
/// );
///
/// // Value-returning
/// static SpireMethod&lt;MyRelic&gt; hook = SpireMethod&lt;MyRelic&gt;.Register(
/// nameof(AbstractModel.ModifyDamageAdditive),
/// (MyRelic instance, decimal current, object[] args) =&gt; current + 5m
/// );
///
/// // Void
/// static SpireMethod&lt;MyRelic&gt; hook = SpireMethod&lt;MyRelic&gt;.Register(
/// nameof(AbstractModel.SomeMethod),
/// (MyRelic instance, object[] args) =&gt; { /* ... */ }
/// );
/// </code>
///
/// <para>Handlers are invoked after the base implementation, in registration order.</para>
///
/// <para><typeparamref name="T"/> must <b>not</b> declare its own override of the target
/// method. If it does, patch the override directly with Harmony instead.</para>
/// </summary>
/// <typeparam name="T">The concrete type whose instances should receive the handler.</typeparam>
public sealed class SpireMethod<T> where T : class
{
private SpireMethod()
{
}

/// <summary>Register a handler for an async (Task-returning) virtual method.</summary>
public static SpireMethod<T> Register(string methodName, AsyncSpireMethodHandler<T> handler)
{
var declaring = ResolveAndValidate(methodName);
SpireMethodRegistry.Register(declaring, new AsyncHandler<T>(handler));
return new SpireMethod<T>();
}

/// <summary>Register a handler for a void virtual method.</summary>
public static SpireMethod<T> Register(string methodName, VoidSpireMethodHandler<T> handler)
{
var declaring = ResolveAndValidate(methodName);
SpireMethodRegistry.Register(declaring, new VoidHandler<T>(handler));
return new SpireMethod<T>();
}

/// <summary>
/// Register a handler for a value-returning virtual method.
/// Each handler receives the current return value and may return a modified one.
/// </summary>
public static SpireMethod<T> Register<TReturn>(string methodName, ValueSpireMethodHandler<T, TReturn> handler)
{
var declaring = ResolveAndValidate(methodName);
SpireMethodRegistry.Register(declaring, new ValueHandler<T, TReturn>(handler));
return new SpireMethod<T>();
}

private static MethodInfo ResolveAndValidate(string methodName)
{
var resolved = AccessTools.Method(typeof(T), methodName)
?? throw new ArgumentException(
$"SpireMethod<{typeof(T).Name}>: method '{methodName}' not found on " +
$"'{typeof(T).FullName}' or any base class.");

var declaredOnT = AccessTools.DeclaredMethod(typeof(T), methodName);
if (declaredOnT != null)
throw new InvalidOperationException(
$"SpireMethod<{typeof(T).Name}>: '{typeof(T).Name}' already declares an override " +
$"of '{methodName}'. Use a direct Harmony patch instead.");

return resolved.GetBaseDefinition();
}
}

/// <summary>Handler delegate for an async (Task-returning) virtual method.</summary>
public delegate Task AsyncSpireMethodHandler<T>(T instance, object[] args) where T : class;

/// <summary>Handler delegate for a void virtual method.</summary>
public delegate void VoidSpireMethodHandler<T>(T instance, object[] args) where T : class;

/// <summary>
/// Handler delegate for a value-returning virtual method.
/// Receives the current return value and the original method arguments;
/// returns a value that replaces the method's return value.
/// </summary>
public delegate TReturn ValueSpireMethodHandler<T, TReturn>(T instance, TReturn current, object[] args) where T : class;
39 changes: 39 additions & 0 deletions SpireMethod/SpireMethodHandlers.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
namespace BaseLib.SpireMethod;

internal abstract class SpireMethodHandlerBase
{
public abstract Type TargetType { get; }

public virtual void InvokeVoid(object instance, object[] args) =>
throw new InvalidOperationException("This handler does not support void invocation.");

public virtual Task InvokeAsync(object instance, object[] args) =>
throw new InvalidOperationException("This handler does not support async invocation.");

public virtual object? InvokeValue(object instance, object? current, object[] args) =>
throw new InvalidOperationException("This handler does not support value invocation.");
}

internal sealed class AsyncHandler<T>(AsyncSpireMethodHandler<T> handler) : SpireMethodHandlerBase where T : class
{
public override Type TargetType => typeof(T);

public override Task InvokeAsync(object instance, object[] args) =>
handler((T)instance, args);
}

internal sealed class VoidHandler<T>(VoidSpireMethodHandler<T> handler) : SpireMethodHandlerBase where T : class
{
public override Type TargetType => typeof(T);

public override void InvokeVoid(object instance, object[] args) =>
handler((T)instance, args);
}

internal sealed class ValueHandler<T, TReturn>(ValueSpireMethodHandler<T, TReturn> handler) : SpireMethodHandlerBase where T : class
{
public override Type TargetType => typeof(T);

public override object? InvokeValue(object instance, object? current, object[] args) =>
handler((T)instance, (TReturn)current!, args);
}
122 changes: 122 additions & 0 deletions SpireMethod/SpireMethodRegistry.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
using System.Reflection;
using System.Runtime.InteropServices;
using HarmonyLib;

namespace BaseLib.SpireMethod;

/// <summary>
/// Central registry for SpireMethod handlers.
///
/// Holds a <c>Dictionary&lt;MethodInfo, List&lt;SpireMethodHandlerBase&gt;&gt;</c> keyed by the
/// declaring base-class method, applies lazy Harmony postfixes, and dispatches at runtime
/// by matching <c>__instance.GetType()</c> against each handler's
/// <see cref="SpireMethodHandlerBase.TargetType"/>.
///
/// The postfix dispatchers receive the patched <c>MethodBase</c> via Harmony's
/// <c>__originalMethod</c> injection, so a single static postfix method handles all
/// registrations for a given return-type category.
/// </summary>
internal static class SpireMethodRegistry
{
private static readonly Dictionary<MethodInfo, List<SpireMethodHandlerBase>> _handlers = [];

// Tracks which methods have already had a Harmony postfix applied. I did have concerns on startup impact, but adapt to normal patching if lazy patching like this is not preferable.
private static readonly HashSet<MethodInfo> _patched = [];

internal static void Register(MethodInfo declaringMethod, SpireMethodHandlerBase handler)
{
ref var handlers = ref CollectionsMarshal.GetValueRefOrAddDefault(_handlers, declaringMethod, out _);
handlers ??= [];
handlers.Add(handler);

// Apply the Harmony postfix lazily (once per method).
LazyPatch(declaringMethod);
}

private static void LazyPatch(MethodInfo method)
{
if (!_patched.Add(method)) return;

var returnType = method.ReturnType;

var postfix = new HarmonyMethod(returnType switch
{
_ when returnType == typeof(Task) => AccessTools.Method(typeof(AsyncPostfixDispatcher),
nameof(AsyncPostfixDispatcher.Postfix)),
_ when returnType == typeof(void) => AccessTools.Method(typeof(VoidPostfixDispatcher),
nameof(VoidPostfixDispatcher.Postfix)),
_ => AccessTools.Method(typeof(ValuePostfixDispatcher<>).MakeGenericType(returnType),
nameof(ValuePostfixDispatcher<object>.Postfix))
});

BaseLibMain.MainHarmony.Patch(method, postfix: postfix);
BaseLibMain.Logger.Info(
$"SpireMethod: patched {method.DeclaringType?.Name}.{method.Name} (return: {returnType.Name})");
}

private static List<SpireMethodHandlerBase>? GetHandlers(MethodBase originalMethod, Type instanceType)
{
var key = ((MethodInfo)originalMethod).GetBaseDefinition();
if (!_handlers.TryGetValue(key, out var all)) return null;

// Filter to handlers whose TargetType is assignable from the instance's runtime type
// (so handlers for RunicPyramid only fire on RunicPyramid instances).
List<SpireMethodHandlerBase>? result = null;
foreach (var handler in all)
{
if (!handler.TargetType.IsAssignableFrom(instanceType)) continue;
result ??= [];
result.Add(handler);
}

return result;
}

private static class AsyncPostfixDispatcher
{
public static void Postfix(object __instance, ref Task __result, object[] __args, MethodBase __originalMethod)
{
var handlers = GetHandlers(__originalMethod, __instance.GetType());
if (handlers == null) return;

__result = ChainAsync(handlers, __instance, __args);
}

private static async Task ChainAsync(List<SpireMethodHandlerBase> handlers, object instance, object[] args)
{
foreach (var handler in handlers)
{
await handler.InvokeAsync(instance, args);
}
}
}

private static class VoidPostfixDispatcher
{
public static void Postfix(object __instance, object[] __args, MethodBase __originalMethod)
{
var handlers = GetHandlers(__originalMethod, __instance.GetType());
if (handlers == null) return;

foreach (var handler in handlers)
{
handler.InvokeVoid(__instance, __args);
}
}
}

private static class ValuePostfixDispatcher<TReturn>
{
public static void Postfix(object __instance, ref TReturn __result, object[] __args,
MethodBase __originalMethod)
{
var handlers = GetHandlers(__originalMethod, __instance.GetType());
if (handlers == null) return;

foreach (var handler in handlers)
{
__result = (TReturn)handler.InvokeValue(__instance, __result, __args)!;
}
}
}
}