From 7e87378e2855d3a12f1039a4ec9239bc7dd71c11 Mon Sep 17 00:00:00 2001 From: nhz2 Date: Sat, 28 Mar 2026 13:47:50 -0400 Subject: [PATCH 1/3] Refactor MPIPtr cconvert --- ext/AMDGPUExt.jl | 18 +++--------------- ext/CUDAExt.jl | 16 ++++------------ src/api/api.jl | 4 +--- src/buffers.jl | 25 ++++++++++++++++++------- 4 files changed, 26 insertions(+), 37 deletions(-) diff --git a/ext/AMDGPUExt.jl b/ext/AMDGPUExt.jl index 912525bff..a79af33f7 100644 --- a/ext/AMDGPUExt.jl +++ b/ext/AMDGPUExt.jl @@ -2,22 +2,10 @@ module AMDGPUExt import MPI isdefined(Base, :get_extension) ? (import AMDGPU) : (import ..AMDGPU) -import MPI: MPIPtr, Buffer, Datatype +import MPI: MPIPtr, Buffer, Datatype, CConvWrapper -function Base.cconvert(::Type{MPIPtr}, A::AMDGPU.ROCArray{T}) where T - A -end - -function Base.unsafe_convert(::Type{MPIPtr}, X::AMDGPU.ROCArray{T}) where T - reinterpret(MPIPtr, Base.unsafe_convert(Ptr{T}, X)) -end - -# only need to define this for strided arrays: all others can be handled by generic machinery -function Base.unsafe_convert(::Type{MPIPtr}, V::SubArray{T,N,P,I,true}) where {T,N,P<:AMDGPU.ROCArray,I} - X = parent(V) - pX = Base.unsafe_convert(Ptr{T}, X) - pV = pX + ((V.offset1 + V.stride1) - first(LinearIndices(X)))*sizeof(T) - return reinterpret(MPIPtr, pV) +function Base.cconvert(::Type{MPIPtr}, x::AMDGPU.ROCArray{T}) where T + CConvWrapper(Ptr{T}, x) end function Buffer(arr::AMDGPU.ROCArray) diff --git a/ext/CUDAExt.jl b/ext/CUDAExt.jl index f86dcd18a..ffafdfc60 100644 --- a/ext/CUDAExt.jl +++ b/ext/CUDAExt.jl @@ -2,22 +2,14 @@ module CUDAExt import MPI isdefined(Base, :get_extension) ? (import CUDA) : (import ..CUDA) -import MPI: MPIPtr, Buffer, Datatype +import MPI: MPIPtr, Buffer, Datatype, CConvWrapper function Base.cconvert(::Type{MPIPtr}, buf::CUDA.CuArray{T}) where T - Base.cconvert(CUDA.CuPtr{T}, buf) # returns DeviceBuffer + CConvWrapper(CUDA.CuPtr{T}, buf) end -function Base.unsafe_convert(::Type{MPIPtr}, X::CUDA.CuArray{T}) where T - reinterpret(MPIPtr, Base.unsafe_convert(CUDA.CuPtr{T}, X)) -end - -# only need to define this for strided arrays: all others can be handled by generic machinery -function Base.unsafe_convert(::Type{MPIPtr}, V::SubArray{T,N,P,I,true}) where {T,N,P<:CUDA.CuArray,I} - X = parent(V) - pX = Base.unsafe_convert(CUDA.CuPtr{T}, X) - pV = pX + ((V.offset1 + V.stride1) - first(LinearIndices(X)))*sizeof(T) - return reinterpret(MPIPtr, pV) +function Base.cconvert(::Type{MPIPtr}, buf::SubArray{T,N,P,I,true}) where {T,N,P<:CUDA.CuArray,I} + CConvWrapper(CUDA.CuPtr{T}, buf) end function Buffer(arr::CUDA.CuArray) diff --git a/src/api/api.jl b/src/api/api.jl index 5b731c3b6..6fccd3f1f 100644 --- a/src/api/api.jl +++ b/src/api/api.jl @@ -76,9 +76,7 @@ end primitive type MPIPtr Sys.WORD_SIZE end @assert sizeof(MPIPtr) == sizeof(Ptr{Cvoid}) -Base.cconvert(::Type{MPIPtr}, x::SentinelPtr) = x -Base.unsafe_convert(::Type{MPIPtr}, x::SentinelPtr) = reinterpret(MPIPtr, x) - +Base.cconvert(::Type{MPIPtr}, x::SentinelPtr) = reinterpret(MPIPtr, x) # Initialize the ref constants from the library. # This is not `API.__init__`, as it should be called _after_ diff --git a/src/buffers.jl b/src/buffers.jl index 57d848d6c..475c365dc 100644 --- a/src/buffers.jl +++ b/src/buffers.jl @@ -1,16 +1,27 @@ MPIBuffertype{T} = Union{Ptr{T}, Array{T}, SubArray{T}, Ref{T}} MPIBuffertypeOrConst{T} = Union{MPIBuffertype{T}, SentinelPtr} -Base.cconvert(::Type{MPIPtr}, x::Union{Ptr{T}, Array{T}, Ref{T}}) where T = Base.cconvert(Ptr{T}, x) -Base.cconvert(::Type{MPIPtr}, x::SubArray{T}) where T = Base.cconvert(Ptr{T}, x) -function Base.unsafe_convert(::Type{MPIPtr}, x::MPIBuffertype{T}) where T - ptr = Base.unsafe_convert(Ptr{T}, x) +struct CConvWrapper{T, C} + cconv::C +end +function CConvWrapper(T, x) + cconv = Base.cconvert(T, x) + CConvWrapper{T, typeof(cconv)}(cconv) +end + +function Base.unsafe_convert(::Type{MPIPtr}, x::CConvWrapper{T}) where T + ptr = Base.unsafe_convert(T, x.cconv) reinterpret(MPIPtr, ptr) end +function Base.cconvert(::Type{MPIPtr}, x::Union{Array{T}, SubArray{T}, Ref{T}}) where T + CConvWrapper(Ptr{T}, x) +end +function Base.cconvert(::Type{MPIPtr}, x::String) + CConvWrapper(Ptr{UInt8}, x) +end -Base.cconvert(::Type{MPIPtr}, x::String) = x -Base.unsafe_convert(::Type{MPIPtr}, x::String) = reinterpret(MPIPtr, pointer(x)) +Base.cconvert(::Type{MPIPtr}, ptr::Ptr) = reinterpret(MPIPtr, ptr) Base.cconvert(::Type{MPIPtr}, ::Nothing) = reinterpret(MPIPtr, C_NULL) @@ -45,7 +56,7 @@ MPIPtr struct InPlace end -Base.cconvert(::Type{MPIPtr}, ::InPlace) = API.MPI_IN_PLACE[] +Base.cconvert(::Type{MPIPtr}, ::InPlace) = reinterpret(MPIPtr, API.MPI_IN_PLACE[]) """ From 20e23f727fb4e2f143f89a232070bd6368ddd825 Mon Sep 17 00:00:00 2001 From: nhz2 Date: Sat, 28 Mar 2026 13:53:42 -0400 Subject: [PATCH 2/3] use Type T in CConvWrapper constructor --- src/buffers.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/buffers.jl b/src/buffers.jl index 475c365dc..89305ef37 100644 --- a/src/buffers.jl +++ b/src/buffers.jl @@ -4,7 +4,7 @@ MPIBuffertypeOrConst{T} = Union{MPIBuffertype{T}, SentinelPtr} struct CConvWrapper{T, C} cconv::C end -function CConvWrapper(T, x) +function CConvWrapper(::Type{T}, x) where T cconv = Base.cconvert(T, x) CConvWrapper{T, typeof(cconv)}(cconv) end From c76c03d472ea9100a7da2855214e9ff8b2d0be6b Mon Sep 17 00:00:00 2001 From: nhz2 Date: Sun, 29 Mar 2026 17:38:53 -0400 Subject: [PATCH 3/3] Add comments --- src/buffers.jl | 48 +++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/src/buffers.jl b/src/buffers.jl index 89305ef37..071c3f2d0 100644 --- a/src/buffers.jl +++ b/src/buffers.jl @@ -1,19 +1,63 @@ MPIBuffertype{T} = Union{Ptr{T}, Array{T}, SubArray{T}, Ref{T}} MPIBuffertypeOrConst{T} = Union{MPIBuffertype{T}, SentinelPtr} +# CConvWrapper: GC-safe adapter for converting Julia objects to MPIPtr in ccall. +# +# Background: ccall's argument conversion protocol works in two steps: +# 1. cconvert(T, x) — called before the ccall. Its return value is GC-rooted +# by ccall for the duration of the foreign call, keeping the underlying +# Julia object alive while a pointer to it is in use. +# 2. unsafe_convert(T, result_of_cconvert) — called on the GC-rooted result +# to extract the raw pointer. Crucially, dispatch is on the *return type* +# of cconvert, not the original argument type. +# +# Problem: because unsafe_convert dispatches on the cconvert return type, the +# unsafe_convert(::Type{MPIPtr}, ...) method must match whatever cconvert +# returned. If cconvert delegates to e.g. Base.cconvert(Ptr{T}, x), the return +# type depends on the Base implementation, so an unsafe_convert method written +# for the original type will never be called. +# +# Solution: CConvWrapper provides a single, predictable return type from +# cconvert(MPIPtr, x). The conversion proceeds as: +# +# ccall argument x::Array{Float64} +# │ +# ▼ +# cconvert(MPIPtr, x) +# calls Base.cconvert(Ptr{Float64}, x) — returns the Array (kept alive) +# wraps it in CConvWrapper{Ptr{Float64}}(array) +# ◄── ccall GC-roots this CConvWrapper, which holds the Array +# │ +# ▼ +# unsafe_convert(MPIPtr, wrapper::CConvWrapper{Ptr{Float64}}) +# calls Base.unsafe_convert(Ptr{Float64}, wrapper.cconv) — extracts raw ptr +# reinterprets to MPIPtr +# ◄── only called while ccall holds the GC root on the wrapper +# +# Types that don't need GC protection (Ptr, Nothing, InPlace, SentinelPtr) skip +# the wrapper and return an MPIPtr directly from cconvert, since they are plain +# bit types with no GC-managed backing memory. struct CConvWrapper{T, C} - cconv::C + # T: the intermediate pointer type (e.g. Ptr{Float64}, CuPtr{Float64}) + # C: the type of the GC-rooted cconvert result (e.g. Array{Float64,1}) + cconv::C # the GC-rooted object — kept alive by ccall holding the wrapper end function CConvWrapper(::Type{T}, x) where T + # Delegate to Base.cconvert(T, x) to get the GC-rootable object, then wrap + # it so unsafe_convert dispatch is predictable. cconv = Base.cconvert(T, x) CConvWrapper{T, typeof(cconv)}(cconv) end function Base.unsafe_convert(::Type{MPIPtr}, x::CConvWrapper{T}) where T + # Called by ccall while x (and thus x.cconv) is GC-rooted. + # Delegate to the Base pointer extraction, then reinterpret to MPIPtr. ptr = Base.unsafe_convert(T, x.cconv) reinterpret(MPIPtr, ptr) end +# --- cconvert methods for types with GC-managed memory (use CConvWrapper) --- + function Base.cconvert(::Type{MPIPtr}, x::Union{Array{T}, SubArray{T}, Ref{T}}) where T CConvWrapper(Ptr{T}, x) end @@ -21,6 +65,8 @@ function Base.cconvert(::Type{MPIPtr}, x::String) CConvWrapper(Ptr{UInt8}, x) end +# --- cconvert methods for plain bit types (no GC protection needed) --- + Base.cconvert(::Type{MPIPtr}, ptr::Ptr) = reinterpret(MPIPtr, ptr) Base.cconvert(::Type{MPIPtr}, ::Nothing) = reinterpret(MPIPtr, C_NULL)