Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions packages/orchestrator/pkg/sandbox/uffd/userfaultfd/fd.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ const (
UFFDIO_API = C.UFFDIO_API
UFFDIO_REGISTER = C.UFFDIO_REGISTER
UFFDIO_UNREGISTER = C.UFFDIO_UNREGISTER
UFFDIO_WAKE = C.UFFDIO_WAKE
UFFDIO_COPY = C.UFFDIO_COPY

UFFD_PAGEFAULT_FLAG_WRITE = C.UFFD_PAGEFAULT_FLAG_WRITE
Expand Down Expand Up @@ -130,6 +131,20 @@ func (f Fd) copy(addr, pagesize uintptr, data []byte, mode CULong) error {
return nil
}

// wake issues a UFFDIO_WAKE ioctl for the single page that contains addr.
//
// addr is rounded down to a pagesize boundary (pagesize is expected to be a
// power of two, since the alignment is done with a bit mask) and the range
// passed to the kernel is exactly pagesize bytes long. The fault itself is
// not resolved: the woken threads will re-execute the faulting instruction,
// triggering a new page fault that will be delivered as a fresh message on
// the uffd fd.
//
// It returns the errno reported by the ioctl, or nil on success.
func (f Fd) wake(addr, pagesize uintptr) error {
	pageStart := CULong(addr) &^ CULong(pagesize-1)
	rng := newUffdioRange(pageStart, CULong(pagesize))

	// The unsafe.Pointer -> uintptr conversion must stay inside the call
	// expression so the pointer remains valid for the duration of the syscall.
	_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(f), UFFDIO_WAKE, uintptr(unsafe.Pointer(&rng)))
	if errno != 0 {
		return errno
	}

	return nil
}

// close closes the file descriptor held by f via syscall.Close and returns
// whatever error that call reports.
func (f Fd) close() error {
	rawFd := int(f)

	return syscall.Close(rawFd)
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"sync"
"sync/atomic"
"syscall"
"unsafe"

Expand All @@ -23,7 +24,13 @@ import (

// tracer is the OpenTelemetry tracer obtained for this package, named after
// the package's import path.
var tracer = otel.Tracer("github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/uffd/userfaultfd")

const maxRequestsInProgress = 4096
const (
// NOTE(review): presumably the upper bound on page-fault requests handled
// concurrently — its use site is outside this view, confirm before relying on it.
maxRequestsInProgress = 4096
// maxFaultRetries is the number of times a page fault can be retried via
// UFFDIO_WAKE before giving up. Each retry releases the goroutine and
// lets the kernel re-deliver the fault as a fresh message.
// The attempt counter is compared with "attempt <= maxFaultRetries", so a
// failing page is woken at most maxFaultRetries times before the failure
// path runs.
maxFaultRetries = 3
)

// ErrUnexpectedEventType is a sentinel error carrying the message
// "unexpected event type"; compare against it with errors.Is.
// NOTE(review): presumably returned when a message read from the uffd fd has
// an event type the handler does not support — confirm at the use site.
var ErrUnexpectedEventType = errors.New("unexpected event type")

Expand All @@ -48,6 +55,10 @@ type Userfaultfd struct {

wg errgroup.Group

// faultRetries tracks how many times each page address has been retried
// via UFFDIO_WAKE. Key is page-aligned address, value is *atomic.Int32.
faultRetries sync.Map

logger logger.Logger
}

Expand Down Expand Up @@ -333,19 +344,46 @@ func (u *Userfaultfd) faultPage(

b, dataErr := source.Slice(ctx, offset, int64(pagesize))
if dataErr != nil {
retryVal, _ := u.faultRetries.LoadOrStore(addr, &atomic.Int32{})
retries := retryVal.(*atomic.Int32)
attempt := int(retries.Add(1))

var wakeErr error

if attempt <= maxFaultRetries {
u.logger.Warn(ctx, "UFFD serve data fetch failed, waking for retry",
zap.Int("attempt", attempt),
zap.Int("max_retries", maxFaultRetries),
zap.Int64("offset", offset),
zap.Uintptr("addr", addr),
zap.Error(dataErr),
)

wakeErr = u.fd.wake(addr, pagesize)
if wakeErr == nil {
return nil
}

u.logger.Error(ctx, "UFFD wake failed", zap.Uintptr("addr", addr), zap.Error(wakeErr))
}

u.faultRetries.Delete(addr)
Comment thread
jakubno marked this conversation as resolved.

var signalErr error
if onFailure != nil {
signalErr = onFailure()
Comment thread
arkamar marked this conversation as resolved.
}

joinedErr := errors.Join(dataErr, signalErr)
joinedErr := errors.Join(dataErr, wakeErr, signalErr)

span.RecordError(joinedErr)
u.logger.Error(ctx, "UFFD serve data fetch error", zap.Error(joinedErr))

return fmt.Errorf("failed to read from source: %w", joinedErr)
}

u.faultRetries.Delete(addr)

var copyMode CULong

// Performing copy() on UFFD clears the WP bit unless we explicitly tell
Expand Down
Loading