Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 38 additions & 6 deletions mlx/backend/metal/eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,37 @@ void new_stream(Stream s) {
encoders.try_emplace(s.index, d, s.index, d.residency_set());
}

inline void check_error(MTL::CommandBuffer* cbuf) {
// Capture (don't throw) a Metal command-buffer error from an async
// completion handler. The handler runs on a Metal-managed dispatch
// thread; throwing through Objective-C frames hits _objc_terminate ->
// abort() and crashes the whole process. Storing the message and
// re-throwing on the next user-thread eval()/finalize() call keeps
// the error catchable by Python / C++ callers.
inline void capture_async_error(const Stream& s, MTL::CommandBuffer* cbuf) {
if (cbuf->status() == MTL::CommandBufferStatusError) {
std::ostringstream msg;
msg << "[METAL] Command buffer execution failed: "
<< cbuf->error()->localizedDescription()->utf8String();
throw std::runtime_error(msg.str());
scheduler::capture_error(s, msg.str());
}
}

// Surface, on the calling (user) thread, any Metal error that an async
// completion handler previously captured for stream `s`. Every
// user-facing entry point calls this first, so a prior step's failure
// is reported before new work is queued.
inline void throw_if_captured(const Stream& s) {
  if (auto msg = scheduler::take_error(s); !msg.empty()) {
    throw std::runtime_error(std::move(msg));
  }
}

void eval(array& arr) {
auto pool = metal::new_scoped_memory_pool();
auto s = arr.primitive().stream();
// Surface any async Metal error captured by a prior step's completion
// handler before queuing more work on this stream.
throw_if_captured(s);
auto& encoder = metal::get_command_encoder(s);
auto* command_buffer = encoder.get_command_buffer();

Expand Down Expand Up @@ -63,32 +82,45 @@ void eval(array& arr) {
command_buffer->addCompletedHandler(
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
scheduler::notify_task_completion(s);
check_error(cbuf);
capture_async_error(s, cbuf);
});
encoder.commit();
} else {
command_buffer->addCompletedHandler(
[buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
check_error(cbuf);
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
capture_async_error(s, cbuf);
});
}
}

// Seal the stream's current command buffer and submit it.
// Re-throws any async error captured by a prior step before queuing
// more work; new failures are captured (never thrown) on Metal's
// completion thread and surface on the next user-thread entry.
void finalize(Stream s) {
  // Surface any prior async Metal error before sealing the stream.
  throw_if_captured(s);
  auto pool = metal::new_scoped_memory_pool();
  auto& encoder = metal::get_command_encoder(s);
  auto* cb = encoder.get_command_buffer();
  encoder.end_encoding();
  // Capture, don't throw: the handler runs on a Metal dispatch thread
  // where an exception would abort the process.
  cb->addCompletedHandler(
      [s](MTL::CommandBuffer* cbuf) { capture_async_error(s, cbuf); });
  encoder.commit();
}

// Block until all work queued on `s` has finished, then report any
// failure from that work.
void synchronize(Stream s) {
  auto& encoder = metal::get_command_encoder(s);
  encoder.synchronize();
  // CommandEncoder::synchronize only checks the final command buffer;
  // a failure on an earlier, already-completed buffer sits in the
  // captured-error slot. Surfacing it here makes synchronize() a true
  // "all my queued work succeeded" guarantee.
  throw_if_captured(s);
}

// Tear down all per-stream encoder state and reset error flags.
void clear_streams() {
  metal::get_command_encoders().clear();
  // The encoder state is gone; drop any poisoned-stream flags as well,
  // otherwise they would outlive the encoders they refer to and keep
  // rejecting all future work (see scheduler::StreamThread).
  scheduler::reset_all_errors();
}

} // namespace mlx::core::gpu
110 changes: 110 additions & 0 deletions mlx/scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,58 @@ struct StreamThread {
bool stop;
std::thread thread;

// Errors raised from async GPU completion handlers can't safely
// propagate as C++ exceptions through Metal's dispatch infrastructure
// (an exception unwinding through Objective-C frames hits
// _objc_terminate -> abort()). The handler captures the message
// here and the next user-thread eval()/finalize() call re-throws
// it on a thread the language runtime can handle.
//
// To address the state-safety concern raised in upstream issue
// #2670 ("we don't have guarantees on the state being in a
// reasonable condition if there is an exception during eval"),
// the stream is also POISONED on capture. Once poisoned, every
// subsequent user-thread entry refuses to queue more work and
// throws — the caller must explicitly reset the stream
// (mx.clear_streams()) to resume. This ensures no further
// operations execute against potentially-corrupt encoder state.
// Guards captured_error and poisoned; shared between user threads and
// Metal's completion-handler threads.
std::mutex error_mtx;
// First captured async error message (first writer wins); empty when
// no error is pending.
std::string captured_error;
// Once set, user-thread entries refuse new work until reset_error().
bool poisoned{false};

// Record `msg` as this stream's pending error (only the first message
// is kept) and poison the stream. Safe to call from async handlers.
void capture_error(std::string msg) {
  std::lock_guard<std::mutex> lk(error_mtx);
  poisoned = true;
  if (captured_error.empty()) {
    captured_error = std::move(msg);
  }
}

// Hand back (and clear) the captured message on the first call after
// poisoning. Later calls yield a generic "stream poisoned" message
// until the stream is explicitly reset via reset_error().
std::string take_error() {
  std::lock_guard<std::mutex> lk(error_mtx);
  if (!poisoned) {
    return {};
  }
  if (captured_error.empty()) {
    // Original message was already consumed; keep refusing work with
    // a generic notice until the caller resets the stream.
    return "[METAL] Stream is in error state from a prior failure. "
           "Call mx.clear_streams() (or destroy this Stream) before "
           "queuing more work.";
  }
  auto out = std::move(captured_error);
  captured_error.clear();
  return out;
}

// Forget any pending error and un-poison the stream.
void reset_error() {
  std::lock_guard<std::mutex> lk(error_mtx);
  poisoned = false;
  captured_error.clear();
}

StreamThread() : stop(false), thread(&StreamThread::thread_fn, this) {}

~StreamThread() {
Expand Down Expand Up @@ -93,6 +145,48 @@ class MLX_API Scheduler {
completion_cv.notify_all();
}

// Record an async-callback error against `stream`'s worker thread.
// Callable from Metal completion handlers: never throws, never blocks
// user threads, and silently no-ops if the stream's thread has
// already been torn down.
void capture_error(const Stream& stream, std::string msg) {
  std::shared_lock lock(threads_mtx_);
  if (auto it = threads_.find(stream.index); it != threads_.end()) {
    it->second->capture_error(std::move(msg));
  }
}

// Fetch-and-clear the pending error for `stream`; an empty string
// means none is pending. The stream stays poisoned until reset_error.
std::string take_error(const Stream& stream) {
  std::shared_lock lock(threads_mtx_);
  auto it = threads_.find(stream.index);
  if (it == threads_.end()) {
    return {};
  }
  return it->second->take_error();
}

// Un-poison `stream` so it accepts work again. Invoked by
// mx.clear_streams() (and any future reset API); callers are expected
// to have torn down the underlying GPU encoder state first.
void reset_error(const Stream& stream) {
  std::shared_lock lock(threads_mtx_);
  if (auto it = threads_.find(stream.index); it != threads_.end()) {
    it->second->reset_error();
  }
}

// Clear the captured error and poisoned flag on every live stream.
void reset_all_errors() {
  std::shared_lock lock(threads_mtx_);
  for (auto& entry : threads_) {
    entry.second->reset_error();
  }
}

// Number of tasks currently in flight across stream threads.
// NOTE(review): reads n_active_tasks_ without error_mtx — presumably
// the counter is atomic; confirm at its declaration (not visible here).
int n_active_tasks() const {
  return n_active_tasks_;
}
Expand Down Expand Up @@ -136,6 +230,22 @@ inline void notify_task_completion(const Stream& stream) {
scheduler().notify_task_completion(stream);
}

// Free-function forwarders to the singleton Scheduler's error API.
// Record an async error for `stream`'s worker thread (never throws).
inline void capture_error(const Stream& stream, std::string msg) {
  scheduler().capture_error(stream, std::move(msg));
}

// Fetch-and-clear the pending error for `stream`; empty when none.
inline std::string take_error(const Stream& stream) {
  return scheduler().take_error(stream);
}

// Un-poison a single stream so it accepts work again.
inline void reset_error(const Stream& stream) {
  scheduler().reset_error(stream);
}

// Un-poison every live stream (used by mx.clear_streams()).
inline void reset_all_errors() {
  scheduler().reset_all_errors();
}

inline void wait_for_one() {
  scheduler().wait_for_one();
}
Expand Down