diff --git a/mlx/backend/metal/eval.cpp b/mlx/backend/metal/eval.cpp
index 6f55976efe..603f1c83c2 100644
--- a/mlx/backend/metal/eval.cpp
+++ b/mlx/backend/metal/eval.cpp
@@ -18,18 +18,37 @@ void new_stream(Stream s) {
   encoders.try_emplace(s.index, d, s.index, d.residency_set());
 }
 
-inline void check_error(MTL::CommandBuffer* cbuf) {
+// Capture (don't throw) a Metal command-buffer error from an async
+// completion handler. The handler runs on a Metal-managed dispatch
+// thread; throwing through Objective-C frames hits _objc_terminate ->
+// abort() and crashes the whole process. Storing the message and
+// re-throwing on the next user-thread eval()/finalize() call keeps
+// the error catchable by Python / C++ callers.
+inline void capture_async_error(const Stream& s, MTL::CommandBuffer* cbuf) {
   if (cbuf->status() == MTL::CommandBufferStatusError) {
     std::ostringstream msg;
     msg << "[METAL] Command buffer execution failed: "
         << cbuf->error()->localizedDescription()->utf8String();
-    throw std::runtime_error(msg.str());
+    scheduler::capture_error(s, msg.str());
+  }
+}
+
+// Re-throw any previously-captured async Metal error on the calling
+// thread. Called on entry to eval()/finalize() and at the end of
+// synchronize() so a prior step's failure surfaces on a user thread.
+inline void throw_if_captured(const Stream& s) {
+  auto msg = scheduler::take_error(s);
+  if (!msg.empty()) {
+    throw std::runtime_error(msg);
   }
 }
 
 void eval(array& arr) {
   auto pool = metal::new_scoped_memory_pool();
   auto s = arr.primitive().stream();
+  // Surface any async Metal error captured by a prior step's completion
+  // handler before queuing more work on this stream.
+  throw_if_captured(s);
   auto& encoder = metal::get_command_encoder(s);
   auto* command_buffer = encoder.get_command_buffer();
 
@@ -63,32 +82,47 @@ void eval(array& arr) {
     command_buffer->addCompletedHandler(
         [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
+          // Record the failure before signalling completion so a thread
+          // woken by the notification already sees the captured error.
+          capture_async_error(s, cbuf);
           scheduler::notify_task_completion(s);
-          check_error(cbuf);
         });
     encoder.commit();
   } else {
     command_buffer->addCompletedHandler(
-        [buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
-          check_error(cbuf);
+        [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
+          capture_async_error(s, cbuf);
         });
   }
 }
 
 void finalize(Stream s) {
+  // Surface any prior async Metal error before sealing the stream.
+  throw_if_captured(s);
   auto pool = metal::new_scoped_memory_pool();
   auto& encoder = metal::get_command_encoder(s);
   auto* cb = encoder.get_command_buffer();
   encoder.end_encoding();
-  cb->addCompletedHandler([](MTL::CommandBuffer* cbuf) { check_error(cbuf); });
+  cb->addCompletedHandler(
+      [s](MTL::CommandBuffer* cbuf) { capture_async_error(s, cbuf); });
   encoder.commit();
 }
 
 void synchronize(Stream s) {
   metal::get_command_encoder(s).synchronize();
+  // CommandEncoder::synchronize only checks the FINAL command buffer.
+  // An async failure on an earlier (already-completed) command buffer
+  // landed in the captured-error slot — surface it here too so
+  // synchronize() is a true "all my queued work succeeded" guarantee.
+  throw_if_captured(s);
 }
 
 void clear_streams() {
   metal::get_command_encoders().clear();
+  // After tearing down encoder state, also clear any poisoned-stream
+  // flags so the streams are usable again. (The flags would otherwise
+  // outlive the encoders they refer to and refuse all subsequent
+  // work — see scheduler::StreamThread.)
+  scheduler::reset_all_errors();
 }
 
 } // namespace mlx::core::gpu
diff --git a/mlx/scheduler.h b/mlx/scheduler.h
index c84ab62855..3988a1d025 100644
--- a/mlx/scheduler.h
+++ b/mlx/scheduler.h
@@ -23,6 +23,58 @@ struct StreamThread {
   bool stop;
   std::thread thread;
 
+  // Errors raised from async GPU completion handlers can't safely
+  // propagate as C++ exceptions through Metal's dispatch infrastructure
+  // (an exception unwinding through Objective-C frames hits
+  // _objc_terminate -> abort()). The handler captures the message
+  // here and the next user-thread eval()/finalize() call re-throws
+  // it on a thread the language runtime can handle.
+  //
+  // To address the state-safety concern raised in upstream issue
+  // #2670 ("we don't have guarantees on the state being in a
+  // reasonable condition if there is an exception during eval"),
+  // the stream is also POISONED on capture. Once poisoned, every
+  // subsequent user-thread entry refuses to queue more work and
+  // throws — the caller must explicitly reset the stream
+  // (mx.clear_streams()) to resume. This ensures no further
+  // operations execute against potentially-corrupt encoder state.
+  std::mutex error_mtx;
+  std::string captured_error;
+  bool poisoned{false};
+
+  void capture_error(std::string msg) {
+    std::lock_guard lk(error_mtx);
+    if (captured_error.empty()) {
+      captured_error = std::move(msg);
+    }
+    poisoned = true;
+  }
+
+  // Returns the captured error string (and clears it) on the first
+  // call after poisoning. Subsequent calls return a generic
+  // "stream poisoned" message until the stream is explicitly reset
+  // via reset_error().
+  std::string take_error() {
+    std::lock_guard lk(error_mtx);
+    if (!poisoned) {
+      return {};
+    }
+    if (!captured_error.empty()) {
+      auto out = std::move(captured_error);
+      captured_error.clear();
+      return out;
+    }
+    return "[METAL] Stream is in error state from a prior failure. "
+           "Call mx.clear_streams() (or destroy this Stream) before "
+           "queuing more work.";
+  }
+
+  void reset_error() {
+    std::lock_guard lk(error_mtx);
+    captured_error.clear();
+    poisoned = false;
+  }
+
   StreamThread() : stop(false), thread(&StreamThread::thread_fn, this) {}
 
   ~StreamThread() {
@@ -93,6 +145,49 @@ class MLX_API Scheduler {
     completion_cv.notify_all();
   }
 
+  // Capture an error from an async callback for the given stream's
+  // worker thread. Safe to call from Metal completion handlers — never
+  // throws, never blocks on user threads. Silently no-ops if the
+  // stream's thread doesn't exist (already torn down).
+  void capture_error(const Stream& stream, std::string msg) {
+    std::shared_lock lock(threads_mtx_);
+    auto it = threads_.find(stream.index);
+    if (it != threads_.end()) {
+      it->second->capture_error(std::move(msg));
+    }
+  }
+
+  // Take the pending error for a stream. Returns an empty string only
+  // when the stream is not poisoned (or has no worker thread); a poisoned
+  // stream keeps returning a message until reset_error() is called.
+  std::string take_error(const Stream& stream) {
+    std::shared_lock lock(threads_mtx_);
+    auto it = threads_.find(stream.index);
+    if (it != threads_.end()) {
+      return it->second->take_error();
+    }
+    return {};
+  }
+
+  // Clear the poisoned flag for a stream. Called by mx.clear_streams()
+  // (and any future reset API) so the stream is usable again. The
+  // caller is expected to also have torn down the underlying GPU
+  // encoder state before clearing the flag.
+  void reset_error(const Stream& stream) {
+    std::shared_lock lock(threads_mtx_);
+    auto it = threads_.find(stream.index);
+    if (it != threads_.end()) {
+      it->second->reset_error();
+    }
+  }
+
+  void reset_all_errors() {
+    std::shared_lock lock(threads_mtx_);
+    for (auto& [_, st] : threads_) {
+      st->reset_error();
+    }
+  }
+
   int n_active_tasks() const {
     return n_active_tasks_;
   }
@@ -136,6 +231,22 @@ inline void notify_task_completion(const Stream& stream) {
   scheduler().notify_task_completion(stream);
 }
 
+inline void capture_error(const Stream& stream, std::string msg) {
+  scheduler().capture_error(stream, std::move(msg));
+}
+
+inline std::string take_error(const Stream& stream) {
+  return scheduler().take_error(stream);
+}
+
+inline void reset_error(const Stream& stream) {
+  scheduler().reset_error(stream);
+}
+
+inline void reset_all_errors() {
+  scheduler().reset_all_errors();
+}
+
 inline void wait_for_one() {
   scheduler().wait_for_one();
 }