Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions ext/AdaptiveArrayPoolsCUDAExt/debug.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
using AdaptiveArrayPools: _runtime_check, _validate_pool_return,
_set_pending_callsite!, _maybe_record_borrow!,
_invalidate_released_slots!, _check_wrapper_mutation!, _zero_dims_tuple,
_throw_pool_escape_error,
_throw_pool_escape_error, _scope_boundary,
PoolRuntimeEscapeError

# ==============================================================================
Expand Down Expand Up @@ -289,23 +289,27 @@ function _check_cuda_pointer_overlap(arr::CuArray, pool::CuAdaptiveArrayPool, or
isempty(rs) ? nothing : rs
end

current_depth = pool._current_depth

# Check fixed slots
AdaptiveArrayPools.foreach_fixed_slot(pool) do tp
_check_tp_cuda_overlap(tp, arr_ptr, arr_end, pool, return_site, original_val)
_check_tp_cuda_overlap(tp, arr_ptr, arr_end, current_depth, pool, return_site, original_val)
end

# Check others
for tp in values(pool.others)
_check_tp_cuda_overlap(tp, arr_ptr, arr_end, pool, return_site, original_val)
_check_tp_cuda_overlap(tp, arr_ptr, arr_end, current_depth, pool, return_site, original_val)
end
return
end

@noinline function _check_tp_cuda_overlap(
tp::AbstractTypedPool, arr_ptr::UInt, arr_end::UInt,
pool::CuAdaptiveArrayPool, return_site, original_val
current_depth::Int, pool::CuAdaptiveArrayPool, return_site, original_val
)
for v in tp.vectors
boundary = _scope_boundary(tp, current_depth)
for i in (boundary + 1):tp.n_active
v = @inbounds tp.vectors[i]
v_ptr = UInt(pointer(v))
v_bytes = length(v) * sizeof(eltype(v))
v_end = v_ptr + v_bytes
Expand Down
14 changes: 9 additions & 5 deletions ext/AdaptiveArrayPoolsMetalExt/debug.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
using AdaptiveArrayPools: _runtime_check, _validate_pool_return,
_set_pending_callsite!, _maybe_record_borrow!,
_invalidate_released_slots!, _check_wrapper_mutation!, _zero_dims_tuple,
_throw_pool_escape_error,
_throw_pool_escape_error, _scope_boundary,
PoolRuntimeEscapeError

# ==============================================================================
Expand Down Expand Up @@ -289,23 +289,27 @@ function _check_metal_overlap(arr::MtlArray, pool::MetalAdaptiveArrayPool, origi
isempty(rs) ? nothing : rs
end

current_depth = pool._current_depth

# Check fixed slots
AdaptiveArrayPools.foreach_fixed_slot(pool) do tp
_check_tp_metal_overlap(tp, arr_buf, arr_off, arr_end, pool, return_site, original_val)
_check_tp_metal_overlap(tp, arr_buf, arr_off, arr_end, current_depth, pool, return_site, original_val)
end

# Check others
for tp in values(pool.others)
_check_tp_metal_overlap(tp, arr_buf, arr_off, arr_end, pool, return_site, original_val)
_check_tp_metal_overlap(tp, arr_buf, arr_off, arr_end, current_depth, pool, return_site, original_val)
end
return
end

@noinline function _check_tp_metal_overlap(
tp::AbstractTypedPool, abuf, aoff::Int, aend::Int,
pool::MetalAdaptiveArrayPool, return_site, original_val
current_depth::Int, pool::MetalAdaptiveArrayPool, return_site, original_val
)
for v in tp.vectors
boundary = _scope_boundary(tp, current_depth)
for i in (boundary + 1):tp.n_active
v = @inbounds tp.vectors[i]
vptr = pointer(v)
vbuf = vptr.buffer
voff = Int(vptr.offset)
Expand Down
8 changes: 6 additions & 2 deletions src/bitarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,8 @@ end
# Safety Validation (S=1 runtime check mode)
# ==============================================================================

# Check if BitArray chunks overlap with the pool's BitTypedPool storage
# Check if BitArray chunks overlap with pool's BitTypedPool storage
# (scope-aware: only checks vectors acquired in the current scope)
function _check_bitchunks_overlap(arr::BitArray, pool::AdaptiveArrayPool, original_val = arr)
arr_chunks = arr.chunks
arr_ptr = UInt(pointer(arr_chunks))
Expand All @@ -191,7 +192,10 @@ function _check_bitchunks_overlap(arr::BitArray, pool::AdaptiveArrayPool, origin
isempty(rs) ? nothing : rs
end

for v in pool.bits.vectors
tp = pool.bits
boundary = _scope_boundary(tp, pool._current_depth)
for i in (boundary + 1):tp.n_active
v = @inbounds tp.vectors[i]
v_chunks = v.chunks
v_ptr = UInt(pointer(v_chunks))
v_len = length(v_chunks) * sizeof(UInt64)
Expand Down
40 changes: 36 additions & 4 deletions src/debug.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,17 @@ _eltype_may_contain_arrays(::Type{Symbol}) = false
_eltype_may_contain_arrays(::Type{Char}) = false
_eltype_may_contain_arrays(::Type) = true

# Check if array memory overlaps with any pool vector.
# Scope-aware boundary: returns the n_active saved at checkpoint for `depth`.
# Vectors with index <= boundary belong to an outer scope and are NOT escapees.
# If this type has no checkpoint at `depth`, it was never touched in this scope → all safe.
@inline function _scope_boundary(tp::AbstractTypedPool, depth::Int)
@inbounds if tp._checkpoint_depths[end] == depth
return tp._checkpoint_n_active[end] # vectors[1:boundary] are from outer scopes
end
return tp.n_active # no checkpoint at this depth → nothing acquired here → all safe
end

# Check if array memory overlaps with any pool vector **acquired in the current scope**.
# `original_val` is the user-visible value (e.g., SubArray) for error reporting;
# `arr` may be its parent Array used for the actual pointer comparison.
function _check_pointer_overlap(arr::Array, pool::AdaptiveArrayPool, original_val = arr)
Expand All @@ -78,8 +88,12 @@ function _check_pointer_overlap(arr::Array, pool::AdaptiveArrayPool, original_va
isempty(rs) ? nothing : rs
end

current_depth = pool._current_depth

check_overlap = function (tp)
for v in tp.vectors
boundary = _scope_boundary(tp, current_depth)
for i in (boundary + 1):tp.n_active
v = @inbounds tp.vectors[i]
v isa Array || continue # Skip BitVector (no pointer(); checked via _check_bitchunks_overlap)
v_ptr = UInt(pointer(v))
v_len = length(v) * sizeof(eltype(v))
Expand Down Expand Up @@ -260,9 +274,27 @@ end
_poison_value(::Type{T}) where {T <: AbstractFloat} = T(NaN)
_poison_value(::Type{T}) where {T <: Integer} = typemax(T)
_poison_value(::Type{Complex{T}}) where {T} = Complex{T}(_poison_value(T), _poison_value(T))
_poison_value(::Type{T}) where {T} = zero(T) # generic fallback
_poison_value(::Type{T}) where {T} = zero(T) # generic fallback (Rational, etc.)

_poison_fill!(v::Vector{T}) where {T} = fill!(v, _poison_value(T))
function _poison_fill!(v::Vector{T}) where {T}
isempty(v) && return nothing
if !isbitstype(T)
# non-isbits (reference types): skip poison, resize!(v, 0) handles invalidation
return nothing
end
# isbits: try _poison_value dispatch (NaN, typemax, zero for known types),
# then duck-type 0 * first(v) for custom structs without zero(T).
# If neither works, skip poisoning — must not throw during rewind.
try
fill!(v, _poison_value(T))
catch
try
fill!(v, 0 * first(v))
catch
end
end
return nothing
end
_poison_fill!(v::BitVector) = fill!(v, true)

"""
Expand Down
8 changes: 6 additions & 2 deletions src/legacy/bitarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -220,14 +220,18 @@ end
# Safety Validation (S=1 runtime check mode)
# ==============================================================================

# Check if BitArray chunks overlap with the pool's BitTypedPool storage
# Check if BitArray chunks overlap with pool's BitTypedPool storage
# (scope-aware: only checks vectors acquired in the current scope)
function _check_bitchunks_overlap(arr::BitArray, pool::AdaptiveArrayPool, original_val = arr)
arr_chunks = arr.chunks
arr_ptr = UInt(pointer(arr_chunks))
arr_len = length(arr_chunks) * sizeof(UInt64)
arr_end = arr_ptr + arr_len

for v in pool.bits.vectors
tp = pool.bits
boundary = _scope_boundary(tp, pool._current_depth)
for i in (boundary + 1):tp.n_active
v = @inbounds tp.vectors[i]
v_chunks = v.chunks
v_ptr = UInt(pointer(v_chunks))
v_len = length(v_chunks) * sizeof(UInt64)
Expand Down
18 changes: 12 additions & 6 deletions src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -764,15 +764,18 @@ function _generate_block_inner(pool_name, expr, safe::Bool, source)
local $(esc(entry_depth_var)) = $(esc(pool_name))._current_depth
$checkpoint_call
local _result = $(esc(transformed_expr))
if $_RUNTIME_CHECK_REF($(esc(pool_name)))
$_validate_pool_return(_result, $(esc(pool_name)))
end
# Leaked scope cleanup BEFORE validation: if an inner @with_pool threw
# without rewind, _current_depth is still the inner depth. Validation
# uses _current_depth via _scope_boundary, so we must normalize first.
if $_RUNTIME_CHECK_REF($(esc(pool_name))) && $(esc(pool_name))._current_depth > $(esc(entry_depth_var)) + 1
$_WARN_LEAKED_SCOPE_REF($(esc(pool_name)), $(esc(entry_depth_var)))
end
while $(esc(pool_name))._current_depth > $(esc(entry_depth_var)) + 1
$_REWIND_REF($(esc(pool_name)))
end
if $_RUNTIME_CHECK_REF($(esc(pool_name)))
$_validate_pool_return(_result, $(esc(pool_name)))
end
$rewind_call
_result
end
Expand Down Expand Up @@ -839,15 +842,18 @@ function _generate_function_inner(pool_name, expr, safe::Bool, source)
local $(esc(entry_depth_var)) = $(esc(pool_name))._current_depth
$checkpoint_call
local _result = $(esc(transformed_expr))
if $_RUNTIME_CHECK_REF($(esc(pool_name)))
$_validate_pool_return(_result, $(esc(pool_name)))
end
# Leaked scope cleanup BEFORE validation: if an inner @with_pool threw
# without rewind, _current_depth is still the inner depth. Validation
# uses _current_depth via _scope_boundary, so we must normalize first.
if $_RUNTIME_CHECK_REF($(esc(pool_name))) && $(esc(pool_name))._current_depth > $(esc(entry_depth_var)) + 1
$_WARN_LEAKED_SCOPE_REF($(esc(pool_name)), $(esc(entry_depth_var)))
end
while $(esc(pool_name))._current_depth > $(esc(entry_depth_var)) + 1
$_REWIND_REF($(esc(pool_name)))
end
if $_RUNTIME_CHECK_REF($(esc(pool_name)))
$_validate_pool_return(_result, $(esc(pool_name)))
end
$rewind_call
_result
end
Expand Down
1 change: 1 addition & 0 deletions test/cuda/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,6 @@ else
include("test_disabled_pool.jl")
include("test_cuda_safety.jl")
include("test_runtime_mutation.jl")
include("test_scope_depth_validation.jl")
end
end
Loading
Loading