Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions myogen/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import warnings

from myogen._cuda_env import setup as _setup_cuda
_setup_cuda()
del _setup_cuda

import numpy as np
from numpy.random import Generator

Expand Down
37 changes: 37 additions & 0 deletions myogen/_cuda_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""
Windows CUDA DLL discovery for pip-installed nvidia-* packages.

CuPy 13 cannot locate DLLs shipped by ``pip install nvidia-cuda-nvrtc-cu12``
and similar wheels because Python 3.8+ no longer searches PATH for DLLs.
This module registers those directories via ``os.add_dll_directory`` and
pre-loads the NVRTC builtins library required for JIT compilation.

No-op on Linux / macOS and when the nvidia packages are absent.
"""

import sys


def setup() -> None:
"""Register CUDA DLL paths from pip-installed nvidia-* packages."""
if sys.platform != "win32":
return

import ctypes
import os
import pathlib

for site_dir in [pathlib.Path(p) for p in sys.path if "site-packages" in p]:
# Register every nvidia/*/bin so cublas, cusolver, cufft etc. are found
for bin_dir in site_dir.glob("nvidia/*/bin"):
if bin_dir.is_dir():
os.add_dll_directory(str(bin_dir))

# Pre-load nvrtc-builtins (required by CuPy for JIT kernel compilation)
for nvrtc_dll in site_dir.glob(
"nvidia/cuda_nvrtc/bin/nvrtc-builtins*.dll"
):
try:
ctypes.WinDLL(str(nvrtc_dll))
except OSError:
pass
81 changes: 54 additions & 27 deletions myogen/simulator/core/emg/intramuscular/bioelectric.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def get_tm_current(z: np.ndarray, D1: float = 96.0, D2: float = -90.0) -> np.nda
return Vm


def get_tm_current_dz(z: np.ndarray, D1: float = 96.0) -> np.ndarray:
def get_tm_current_dz(z: np.ndarray, D1: float = 96.0, xp=np) -> np.ndarray:
"""
Calculate first derivative of transmembrane current (Rosenfalck model).

Expand All @@ -60,16 +60,18 @@ def get_tm_current_dz(z: np.ndarray, D1: float = 96.0) -> np.ndarray:
Spatial coordinates along fiber in mm
D1 : float, default=96.0
Current amplitude parameter in mV/mm³
xp : module, default=np
Array backend (numpy or cupy)

Returns
-------
np.ndarray
First derivative of transmembrane current
"""
Vm = np.zeros_like(z, dtype=np.float64)
Vm = xp.zeros_like(z, dtype=xp.float64)
pos_mask = z > 0
z_pos = z[pos_mask]
Vm[pos_mask] = D1 * (3 * z_pos**2 - z_pos**3) * np.exp(-z_pos)
Vm[pos_mask] = D1 * (3 * z_pos**2 - z_pos**3) * xp.exp(-z_pos)
return Vm


Expand Down Expand Up @@ -102,6 +104,7 @@ def get_elementary_current_response(
r: np.ndarray,
sigma_r: float = 63.0, # S/m
sigma_z: float = 330.0, # S/m
xp=np,
) -> np.ndarray:
"""
Calculate elementary current response for volume conductor.
Expand All @@ -122,6 +125,8 @@ def get_elementary_current_response(
Radial conductivity in S/m (from Andreassen & Rosenfalck 1980)
sigma_z : float, default=330.0
Longitudinal conductivity in S/m (from Andreassen & Rosenfalck 1980)
xp : module, default=np
Array backend (numpy or cupy)

Returns
-------
Expand All @@ -133,13 +138,17 @@ def get_elementary_current_response(
sigma_r_S_per_mm = sigma_r / 1000.0 # CORRECTED: convert S/m → S/mm
sigma_z_S_per_mm = sigma_z / 1000.0 # CORRECTED: convert S/m → S/mm

return np.divide(
1 / 4 / np.pi / sigma_r_S_per_mm,
np.sqrt(sigma_z_S_per_mm / sigma_r_S_per_mm * r**2 + (z - z_electrode) ** 2),
# Normalize inputs to computation backend (prevents numpy/cupy mixing)
z = xp.asarray(z)
z_electrode = float(z_electrode)

return xp.divide(
1 / 4 / xp.pi / sigma_r_S_per_mm,
xp.sqrt(sigma_z_S_per_mm / sigma_r_S_per_mm * r**2 + (z - z_electrode) ** 2),
)


def shift_padding(vec, sh, axis):
def shift_padding(vec, sh, axis, xp=np):
"""
Circularly shifts 'vec' by 'sh' positions along the specified 'axis'
and then pads the shifted region with zeros.
Expand All @@ -152,21 +161,25 @@ def shift_padding(vec, sh, axis):
Shift amount (positive means downward/rightward like MATLAB).
axis : int
Axis along which to shift.
xp : module, default=np
Array backend (numpy or cupy).

Returns
-------
ndarray
Shifted and zero-padded array.
"""
vec = np.roll(vec, sh, axis=axis)
vec = xp.roll(vec, sh, axis=axis)

n = len(vec)
n = vec.shape[0]

# Equivalent of vec(1:sh) = 0
if sh > 0:
vec[:sh] = 0

# Equivalent of vec(end+sh+1:end) = 0
# Note: when sh > 0, both head AND tail are zeroed — this is the
# original MATLAB semantics (suppress wrap-around on both sides).
if sh < 0:
start = n + sh # because end+sh+1 in MATLAB is 1-based
if start < n:
Expand Down Expand Up @@ -215,7 +228,8 @@ def hr_shift_template(x, delay):


def get_current_density(
t, z, zi, L1, L2, v, d=55e-3, suppress_endplate_density=True, endplate_width=0.5
t, z, zi, L1, L2, v, d=55e-3, suppress_endplate_density=True, endplate_width=0.5,
xp=np,
):
"""
Model the individual action potential (IAP) or single fiber action potential (SFAP) in space and time.
Expand All @@ -241,33 +255,49 @@ def get_current_density(
Whether to suppress density at endplate region (default: True)
endplate_width : float, optional
Width around endplate where density is suppressed (mm)
xp : module, default=np
Array backend (numpy or cupy). Pass cupy for GPU acceleration.
"""

dz = np.mean(np.diff(z, axis=0))
z = np.concatenate([z, z[[-1]] + dz], axis=0)
# Normalize inputs to computation backend (prevents numpy/cupy mixing
# when callers pass numpy arrays with xp=cupy)
t = xp.asarray(t)
z = xp.asarray(z)
zi = float(zi)
L1 = float(L1)
L2 = float(L2)
v = float(v)
d = float(d)

dz = xp.mean(xp.diff(z, axis=0))
z = xp.concatenate([z, z[[-1]] + dz], axis=0)

T, Z = np.meshgrid(t, z)
# ravel() needed: t,z arrive as (N,1) column vectors; meshgrid expects 1-D
T, Z = xp.meshgrid(xp.ravel(t), xp.ravel(z))

# Tendon terminator function
def tendon_terminator(z_inline, L_inline):
return (z_inline <= L_inline / 2) & (z_inline >= -L_inline / 2)

# Compute psi (transmembrane current derivative)
if L1 >= L2:
psi = -4 * get_tm_current_dz(-2 * (Z - zi - v * T))
longest_wave = np.diff(psi, axis=0) / dz
psi = -4 * get_tm_current_dz(-2 * (Z - zi - v * T), xp=xp)
longest_wave = xp.diff(psi, axis=0) / dz
longest_wave *= tendon_terminator(Z[:-1, :] - zi - L1 / 2, L1)
longest_wave *= (Z[:-1, :] - zi) / v > 0 # negative time suppression
# Explicit bool→float64 cast required: CuPy does not support
# implicit multiplication of bool arrays with float arrays.
longest_wave *= ((Z[:-1, :] - zi) / v > 0).astype(xp.float64)
else:
psi = 4 * get_tm_current_dz(-2 * (-Z + zi - v * T))
longest_wave = np.diff(psi, axis=0) / dz
psi = 4 * get_tm_current_dz(-2 * (-Z + zi - v * T), xp=xp)
longest_wave = xp.diff(psi, axis=0) / dz
longest_wave *= tendon_terminator(Z[:-1, :] - zi + L2 / 2, L2)
longest_wave *= (-Z[:-1, :] + zi) / v > 0
longest_wave *= ((-Z[:-1, :] + zi) / v > 0).astype(xp.float64) # bool→float64

# Shortest wave (reversed)
shortest_wave = longest_wave[::-1].copy()
shift_amount = int(np.round((L1 + L2 - max(z) + L2 - L1) / dz))
shortest_wave = shift_padding(shortest_wave, shift_amount, 0)
# Use round(float(...)) to avoid device-to-host sync from int(xp.round(...))
shift_amount = round(float((L1 + L2 - float(z.max()) + L2 - L1) / dz))
shortest_wave = shift_padding(shortest_wave, shift_amount, axis=0, xp=xp)

if L1 >= L2:
shortest_wave *= tendon_terminator(Z[:-1, :] - zi + L2 / 2, L2)
Expand All @@ -278,11 +308,8 @@ def tendon_terminator(z_inline, L_inline):

# Suppress endplate density if required
if suppress_endplate_density:

def endplate_terminator(z_inline):
return (z_inline <= (zi - endplate_width)) | (z_inline >= (zi + endplate_width))

iap *= endplate_terminator(Z[:-1, :])
iap *= ((Z[:-1, :] <= (zi - endplate_width)) |
(Z[:-1, :] >= (zi + endplate_width))).astype(xp.float64)

# ---- FIXED UNIT CONVERSIONS ----
# Intracellular conductivity: 1.01 S/m → convert to S/mm
Expand All @@ -291,7 +318,7 @@ def endplate_terminator(z_inline):

# Fiber diameter is already in mm (default d=55e-3 mm = 55 um)
# Compute cross-sectional area in mm²
area_mm2 = np.pi * (d / 2) ** 2 # CORRECTED: removed extra /4
area_mm2 = xp.pi * (d / 2) ** 2 # CORRECTED: removed extra /4

# Scale current density by intracellular conductivity and fiber cross-section area
iap *= sigma_i * area_mm2
Expand Down
Loading
Loading