diff --git a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/fd.go b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/fd.go
index 75a464c95c..fe3c655c9d 100644
--- a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/fd.go
+++ b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/fd.go
@@ -44,6 +44,7 @@ const (
 	UFFDIO_API        = C.UFFDIO_API
 	UFFDIO_REGISTER   = C.UFFDIO_REGISTER
 	UFFDIO_UNREGISTER = C.UFFDIO_UNREGISTER
+	UFFDIO_WAKE       = C.UFFDIO_WAKE
 	UFFDIO_COPY       = C.UFFDIO_COPY
 
 	UFFD_PAGEFAULT_FLAG_WRITE = C.UFFD_PAGEFAULT_FLAG_WRITE
@@ -130,6 +131,20 @@ func (f Fd) copy(addr, pagesize uintptr, data []byte, mode CULong) error {
 	return nil
 }
 
+// wake wakes threads waiting on page faults in the given address range
+// without resolving the fault. The woken threads will re-execute the
+// faulting instruction, triggering a new page fault that will be
+// delivered as a fresh message on the uffd fd.
+func (f Fd) wake(addr, pagesize uintptr) error {
+	r := newUffdioRange(CULong(addr)&^CULong(pagesize-1), CULong(pagesize))
+
+	if _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(f), UFFDIO_WAKE, uintptr(unsafe.Pointer(&r))); errno != 0 {
+		return errno
+	}
+
+	return nil
+}
+
 func (f Fd) close() error {
 	return syscall.Close(int(f))
 }
diff --git a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/userfaultfd.go b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/userfaultfd.go
index 51bcc4f17d..84f4626a7f 100644
--- a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/userfaultfd.go
+++ b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/userfaultfd.go
@@ -5,6 +5,7 @@ import (
 	"errors"
 	"fmt"
 	"sync"
+	"sync/atomic"
 	"syscall"
 	"unsafe"
 
@@ -23,7 +24,13 @@ import (
 
 var tracer = otel.Tracer("github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/uffd/userfaultfd")
 
-const maxRequestsInProgress = 4096
+const (
+	maxRequestsInProgress = 4096
+	// maxFaultRetries is the number of times a page fault can be retried via
+	// UFFDIO_WAKE before giving up. Each retry releases the goroutine and
+	// lets the kernel re-deliver the fault as a fresh message.
+	maxFaultRetries = 3
+)
 
 var ErrUnexpectedEventType = errors.New("unexpected event type")
 
@@ -48,6 +55,10 @@ type Userfaultfd struct {
 
 	wg errgroup.Group
 
+	// faultRetries tracks how many times each page address has been retried
+	// via UFFDIO_WAKE. Key is page-aligned address, value is *atomic.Int32.
+	faultRetries sync.Map
+
 	logger logger.Logger
 }
 
@@ -333,12 +344,37 @@ func (u *Userfaultfd) faultPage(
 
 	b, dataErr := source.Slice(ctx, offset, int64(pagesize))
 	if dataErr != nil {
+		retryVal, _ := u.faultRetries.LoadOrStore(addr, &atomic.Int32{})
+		retries := retryVal.(*atomic.Int32)
+		attempt := int(retries.Add(1))
+
+		var wakeErr error
+
+		if attempt <= maxFaultRetries {
+			u.logger.Warn(ctx, "UFFD serve data fetch failed, waking for retry",
+				zap.Int("attempt", attempt),
+				zap.Int("max_retries", maxFaultRetries),
+				zap.Int64("offset", offset),
+				zap.Uintptr("addr", addr),
+				zap.Error(dataErr),
+			)
+
+			wakeErr = u.fd.wake(addr, pagesize)
+			if wakeErr == nil {
+				return nil
+			}
+
+			u.logger.Error(ctx, "UFFD wake failed", zap.Uintptr("addr", addr), zap.Error(wakeErr))
+		}
+
+		u.faultRetries.Delete(addr)
+
 		var signalErr error
 		if onFailure != nil {
 			signalErr = onFailure()
 		}
 
-		joinedErr := errors.Join(dataErr, signalErr)
+		joinedErr := errors.Join(dataErr, wakeErr, signalErr)
 
 		span.RecordError(joinedErr)
 		u.logger.Error(ctx, "UFFD serve data fetch error", zap.Error(joinedErr))
@@ -346,6 +382,8 @@ func (u *Userfaultfd) faultPage(
 		return fmt.Errorf("failed to read from source: %w", joinedErr)
 	}
 
+	u.faultRetries.Delete(addr)
+
 	var copyMode CULong
 
 	// Performing copy() on UFFD clears the WP bit unless we explicitly tell