diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index d678461e3a..c75c8451e7 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -440,6 +440,8 @@ void CommandEncoder::end_encoding() { next_outputs_.clear(); concurrent_outputs_.clear(); all_inputs_.clear(); + + check_error(); } bool CommandEncoder::needs_commit() const { @@ -447,7 +449,20 @@ bool CommandEncoder::needs_commit() const { return (buffer_ops_ > max_ops) || ((buffer_sizes_ >> 20) > max_mb); } -void CommandEncoder::commit() { +void CommandEncoder::commit(std::function completion) { + buffer_->addCompletedHandler( + [this, completion = std::move(completion)](MTL::CommandBuffer* cbuf) { + if (completion) { + completion(); + } + if (cbuf->status() == MTL::CommandBufferStatusError) { + std::atomic_store( + &error_, + std::make_shared(fmt::format( + "[METAL] Command buffer execution failed: {}.", + cbuf->error()->localizedDescription()->utf8String()))); + } + }); buffer_->commit(); buffer_ = NS::RetainPtr(queue_->commandBufferWithUnretainedReferences()); buffer_ops_ = 0; @@ -456,22 +471,32 @@ void CommandEncoder::commit() { void CommandEncoder::synchronize() { auto pool = new_scoped_memory_pool(); - auto cb = NS::RetainPtr(get_command_buffer()); + auto cbuf = buffer_; // retained end_encoding(); commit(); - cb->waitUntilCompleted(); - if (!exiting_) { - if (cb->status() == MTL::CommandBufferStatusError) { - throw std::runtime_error( - fmt::format( - "[METAL] Command buffer execution failed: {}.", - cb->error()->localizedDescription()->utf8String())); - } + cbuf->waitUntilCompleted(); + check_error(); +} + +void CommandEncoder::check_error() { + // Do not check error during encoding, otherwise it would leave the program in + // corrupted state. + if (encoder_) { + return; + } + // When exiting with pending GPU commands, errors will happen, ignore them. 
+ if (exiting_) { + return; + } + auto error = std::atomic_exchange(&error_, {}); + if (error) { + throw std::runtime_error(*error); } } MTL::ComputeCommandEncoder* CommandEncoder::get_command_encoder() { if (!encoder_) { + check_error(); encoder_ = NS::RetainPtr( buffer_->computeCommandEncoder(MTL::DispatchTypeConcurrent)); fence_ = NS::TransferPtr(device_.mtl_device()->newFence()); diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index 5f2e72f915..f903d6a293 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -91,12 +91,10 @@ class MLX_API CommandEncoder { void barrier(); void end_encoding(); bool needs_commit() const; - void commit(); + void commit(std::function completion = nullptr); void synchronize(); + void check_error(); - MTL::CommandQueue* get_command_queue() const { - return queue_.get(); - } MTL::CommandBuffer* get_command_buffer() const { return buffer_.get(); } @@ -113,6 +111,9 @@ class MLX_API CommandEncoder { int buffer_ops_{0}; size_t buffer_sizes_{0}; + // Error from previous committed command buffer. + std::shared_ptr error_; + // Encoder for issuing GPU commands. // The members are used within a single ComputeCommandEncoder and will be // reset after calling end_encoding(). 
diff --git a/mlx/backend/metal/eval.cpp b/mlx/backend/metal/eval.cpp index 6f55976efe..8d1200d71f 100644 --- a/mlx/backend/metal/eval.cpp +++ b/mlx/backend/metal/eval.cpp @@ -18,15 +18,6 @@ void new_stream(Stream s) { encoders.try_emplace(s.index, d, s.index, d.residency_set()); } -inline void check_error(MTL::CommandBuffer* cbuf) { - if (cbuf->status() == MTL::CommandBufferStatusError) { - std::ostringstream msg; - msg << "[METAL] Command buffer execution failed: " - << cbuf->error()->localizedDescription()->utf8String(); - throw std::runtime_error(msg.str()); - } -} - void eval(array& arr) { auto pool = metal::new_scoped_memory_pool(); auto s = arr.primitive().stream(); @@ -60,17 +51,12 @@ void eval(array& arr) { if (encoder.needs_commit()) { encoder.end_encoding(); scheduler::notify_new_task(s); - command_buffer->addCompletedHandler( - [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) { - scheduler::notify_task_completion(s); - check_error(cbuf); - }); - encoder.commit(); + encoder.commit([s, buffers = std::move(buffers)]() { + scheduler::notify_task_completion(s); + }); } else { command_buffer->addCompletedHandler( - [buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) { - check_error(cbuf); - }); + [buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {}); } } @@ -79,7 +65,6 @@ void finalize(Stream s) { auto& encoder = metal::get_command_encoder(s); auto* cb = encoder.get_command_buffer(); encoder.end_encoding(); - cb->addCompletedHandler([](MTL::CommandBuffer* cbuf) { check_error(cbuf); }); encoder.commit(); } diff --git a/mlx/backend/metal/event.cpp b/mlx/backend/metal/event.cpp index 78ed4fafe2..d93ea4facf 100644 --- a/mlx/backend/metal/event.cpp +++ b/mlx/backend/metal/event.cpp @@ -21,9 +21,13 @@ Event::Event(Stream stream) : stream_(stream) { } void Event::wait() { - if (!static_cast(event_.get()) - ->waitUntilSignaledValue(value(), -1)) { - throw std::runtime_error("[Event::wait] Timed out"); + auto* event = static_cast(event_.get()); + 
// When an error happens in a command buffer, the event would wait + // indefinitely if we don't set a timeout. + while (!event->waitUntilSignaledValue(value(), 5 * 1000)) { + for (auto& [_, encoder] : metal::get_command_encoders()) { + encoder.check_error(); + } } } diff --git a/mlx/backend/metal/fence.cpp b/mlx/backend/metal/fence.cpp index 0ff7e7f3b4..c338af3f6f 100644 --- a/mlx/backend/metal/fence.cpp +++ b/mlx/backend/metal/fence.cpp @@ -56,9 +56,12 @@ void Fence::wait(Stream stream, const array& x) { scheduler::enqueue(stream, [fence_ = fence_, count = f.count]() mutable { auto& f = *static_cast(fence_.get()); if (!f.use_fast) { - if (!static_cast(f.fence)->waitUntilSignaledValue( - count, -1)) { - throw std::runtime_error("[Fence::wait] Timed out"); + // Same as Event::wait. + auto* event = static_cast(f.fence); + while (!event->waitUntilSignaledValue(count, 5 * 1000)) { + for (auto& [_, encoder] : metal::get_command_encoders()) { + encoder.check_error(); + } } return; }