From 6527e13c028458a564a81b86fbbc315f60bb1770 Mon Sep 17 00:00:00 2001 From: andreinknv Date: Sun, 10 May 2026 23:01:36 -0400 Subject: [PATCH] fix(metal): defer command-buffer error reporting + poison-on-failure MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit C++ exceptions thrown from inside Metal completion handlers (the three addCompletedHandler callbacks in eval.cpp) hit std::terminate -> abort() because the handlers run on Metal-managed dispatch threads where C++ exceptions cannot be caught. Production users have hit this in #3317 (M2 Ultra, Qwen3.5-122B, ~3.5 h sustained inference) and again locally (M4 Max, Granite-1b 4-pool, mlx-lm 0.31.3 prompt-cache buildup, kIOGPUCommandBufferCallbackErrorOutOfMemory). Prior PR #3318 proposed deferring the error to the next user-thread eval(). It was rejected because mlx core is not exception-safe — re-throwing later could leave the encoder in a stale state and "fail later in a much weirder way" (zcbenz). awni reiterated in #2670 that mlx wouldn't add the feature without state guarantees. This PR addresses that concern by combining the deferred-throw with explicit STREAM POISONING: - The async completion handler captures the error message into a per-StreamThread slot and SETS poisoned=true. It never throws. - The next user-thread eval()/finalize() entry on that stream (and synchronize(), after its encoder sync) calls throw_if_captured(): re-throws the original error on first call, then refuses all subsequent work with a clear "stream is poisoned, mx.clear_streams() to reset" message. - mx.clear_streams() now also calls scheduler::reset_all_errors() so the user has a documented path back to a working state. This guarantees no further operations execute against a stream that just had a Metal failure. The state-safety concern is addressed because once a stream is in error, NOTHING runs on it until the caller explicitly resets — there's no opportunity to operate on half-initialized encoder state. 
Tests: - 709 mlx tests pass, 26 skipped (the suite already had skips), zero new failures - Custom stress: 8 threads × 1000 iter × (512x512) matmul on per- thread streams completes in 0.7 s with no errors (post-patch), same throughput as baseline build - Smoke: addition / matmul / pre-flight alloc check / post-error recovery all behave correctly Per-stream rather than global: concurrent streams (eg. mlx-lm's 4-server pool, batched inference) won't cross-pollute errors. A crash on stream 2 only poisons stream 2, not 0/1/3. Closes #3317. References #2670, #3318. --- mlx/backend/metal/eval.cpp | 44 +++++++++++++-- mlx/scheduler.h | 110 +++++++++++++++++++++++++++++++++++++ 2 files changed, 148 insertions(+), 6 deletions(-) diff --git a/mlx/backend/metal/eval.cpp b/mlx/backend/metal/eval.cpp index 6f55976efe..603f1c83c2 100644 --- a/mlx/backend/metal/eval.cpp +++ b/mlx/backend/metal/eval.cpp @@ -18,18 +18,37 @@ void new_stream(Stream s) { encoders.try_emplace(s.index, d, s.index, d.residency_set()); } -inline void check_error(MTL::CommandBuffer* cbuf) { +// Capture (don't throw) a Metal command-buffer error from an async +// completion handler. The handler runs on a Metal-managed dispatch +// thread; throwing through Objective-C frames hits _objc_terminate -> +// abort() and crashes the whole process. Storing the message and +// re-throwing on the next user-thread eval()/finalize() call keeps +// the error catchable by Python / C++ callers. +inline void capture_async_error(const Stream& s, MTL::CommandBuffer* cbuf) { if (cbuf->status() == MTL::CommandBufferStatusError) { std::ostringstream msg; msg << "[METAL] Command buffer execution failed: " << cbuf->error()->localizedDescription()->utf8String(); - throw std::runtime_error(msg.str()); + scheduler::capture_error(s, msg.str()); + } +} + +// Re-throw any previously-captured async Metal error on the calling +// thread. 
Called at the entry of every user-facing eval()/finalize() +// (and after the encoder sync in synchronize()) so earlier failures surface to the caller. +inline void throw_if_captured(const Stream& s) { + auto msg = scheduler::take_error(s); + if (!msg.empty()) { + throw std::runtime_error(msg); } } void eval(array& arr) { auto pool = metal::new_scoped_memory_pool(); auto s = arr.primitive().stream(); + // Surface any async Metal error captured by a prior step's completion + // handler before queuing more work on this stream. + throw_if_captured(s); auto& encoder = metal::get_command_encoder(s); auto* command_buffer = encoder.get_command_buffer(); @@ -63,32 +82,45 @@ void eval(array& arr) { command_buffer->addCompletedHandler( [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) { scheduler::notify_task_completion(s); - check_error(cbuf); + capture_async_error(s, cbuf); }); encoder.commit(); } else { command_buffer->addCompletedHandler( - [buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) { - check_error(cbuf); + [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) { + capture_async_error(s, cbuf); }); } } void finalize(Stream s) { + // Surface any prior async Metal error before sealing the stream. + throw_if_captured(s); auto pool = metal::new_scoped_memory_pool(); auto& encoder = metal::get_command_encoder(s); auto* cb = encoder.get_command_buffer(); encoder.end_encoding(); - cb->addCompletedHandler([](MTL::CommandBuffer* cbuf) { check_error(cbuf); }); + cb->addCompletedHandler( + [s](MTL::CommandBuffer* cbuf) { capture_async_error(s, cbuf); }); encoder.commit(); } void synchronize(Stream s) { metal::get_command_encoder(s).synchronize(); + // CommandEncoder::synchronize only checks the FINAL command buffer. + // An async failure on an earlier (already-completed) command buffer + // landed in the captured-error slot — surface it here too so + // synchronize() is a true "all my queued work succeeded" guarantee. 
+ throw_if_captured(s); } void clear_streams() { metal::get_command_encoders().clear(); + // After tearing down encoder state, also clear any poisoned-stream + // flags so the streams are usable again. (The flags would otherwise + // outlive the encoders they refer to and refuse all subsequent + // work — see scheduler::StreamThread.) + scheduler::reset_all_errors(); } } // namespace mlx::core::gpu diff --git a/mlx/scheduler.h b/mlx/scheduler.h index c84ab62855..3988a1d025 100644 --- a/mlx/scheduler.h +++ b/mlx/scheduler.h @@ -23,6 +23,58 @@ struct StreamThread { bool stop; std::thread thread; + // Errors raised from async GPU completion handlers can't safely + // propagate as C++ exceptions through Metal's dispatch infrastructure + // (an exception unwinding through Objective-C frames hits + // _objc_terminate -> abort()). The handler captures the message + // here and the next user-thread eval()/finalize() call re-throws + // it on a thread the language runtime can handle. + // + // To address the state-safety concern raised in upstream issue + // #2670 ("we don't have guarantees on the state being in a + // reasonable condition if there is an exception during eval"), + // the stream is also POISONED on capture. Once poisoned, every + // subsequent user-thread entry refuses to queue more work and + // throws — the caller must explicitly reset the stream + // (mx.clear_streams()) to resume. This ensures no further + // operations execute against potentially-corrupt encoder state. + std::mutex error_mtx; + std::string captured_error; + bool poisoned{false}; + + void capture_error(std::string msg) { + std::lock_guard lk(error_mtx); + if (captured_error.empty()) { + captured_error = std::move(msg); + } + poisoned = true; + } + + // Returns an empty string while the stream is not poisoned. On the + // first call after poisoning it returns the captured error and + // clears it; later calls return a generic "stream poisoned" message + // until the stream is explicitly reset via reset_error(). 
+ std::string take_error() { + std::lock_guard lk(error_mtx); + if (!poisoned) { + return {}; + } + if (!captured_error.empty()) { + auto out = std::move(captured_error); + captured_error.clear(); + return out; + } + return "[METAL] Stream is in error state from a prior failure. " + "Call mx.clear_streams() (or destroy this Stream) before " + "queuing more work."; + } + + void reset_error() { + std::lock_guard lk(error_mtx); + captured_error.clear(); + poisoned = false; + } + StreamThread() : stop(false), thread(&StreamThread::thread_fn, this) {} ~StreamThread() { @@ -93,6 +145,48 @@ class MLX_API Scheduler { completion_cv.notify_all(); } + // Capture an error from an async callback for the given stream's + // worker thread. Meant for Metal completion handlers: it records the + // message instead of throwing it. Takes threads_mtx_ (shared) and the + // stream's error_mtx. No-op if no thread is registered for the stream. + void capture_error(const Stream& stream, std::string msg) { + std::shared_lock lock(threads_mtx_); + auto it = threads_.find(stream.index); + if (it != threads_.end()) { + it->second->capture_error(std::move(msg)); + } + } + + // Take + clear the captured error for a stream. Returns empty string + // when the stream is unpoisoned or unknown. Stream stays poisoned until reset_error. + std::string take_error(const Stream& stream) { + std::shared_lock lock(threads_mtx_); + auto it = threads_.find(stream.index); + if (it != threads_.end()) { + return it->second->take_error(); + } + return {}; + } + + // Clear the poisoned flag for a stream. Called by mx.clear_streams() + // (and any future reset API) so the stream is usable again. The + // caller is expected to also have torn down the underlying GPU + // encoder state before clearing the flag. 
+ void reset_error(const Stream& stream) { + std::shared_lock lock(threads_mtx_); + auto it = threads_.find(stream.index); + if (it != threads_.end()) { + it->second->reset_error(); + } + } + + void reset_all_errors() { + std::shared_lock lock(threads_mtx_); + for (auto& [_, st] : threads_) { + st->reset_error(); + } + } + int n_active_tasks() const { return n_active_tasks_; } @@ -136,6 +230,22 @@ inline void notify_task_completion(const Stream& stream) { scheduler().notify_task_completion(stream); } +inline void capture_error(const Stream& stream, std::string msg) { + scheduler().capture_error(stream, std::move(msg)); +} + +inline std::string take_error(const Stream& stream) { + return scheduler().take_error(stream); +} + +inline void reset_error(const Stream& stream) { + scheduler().reset_error(stream); +} + +inline void reset_all_errors() { + scheduler().reset_all_errors(); +} + inline void wait_for_one() { scheduler().wait_for_one(); }