diff --git a/packages/orchestrator/pkg/sandbox/block/cache.go b/packages/orchestrator/pkg/sandbox/block/cache.go index 9e82c653d7..22074f549c 100644 --- a/packages/orchestrator/pkg/sandbox/block/cache.go +++ b/packages/orchestrator/pkg/sandbox/block/cache.go @@ -410,6 +410,17 @@ func (c *Cache) Path() string { return c.filePath } +func (c *Cache) Data() []byte { + c.mu.RLock() + defer c.mu.RUnlock() + + if c.mmap == nil { + return nil + } + + return []byte(*c.mmap) +} + func NewCacheFromProcessMemory( ctx context.Context, blockSize int64, diff --git a/packages/orchestrator/pkg/sandbox/block/chunk.go b/packages/orchestrator/pkg/sandbox/block/chunk.go index ad2017d2aa..db424ad05c 100644 --- a/packages/orchestrator/pkg/sandbox/block/chunk.go +++ b/packages/orchestrator/pkg/sandbox/block/chunk.go @@ -85,6 +85,7 @@ type Chunker interface { ReadAt(ctx context.Context, b []byte, off int64) (int, error) WriteTo(ctx context.Context, w io.Writer) (int64, error) Close() error + Data() []byte FileSize() (int64, error) } @@ -296,6 +297,10 @@ func (c *FullFetchChunker) Close() error { return c.cache.Close() } +func (c *FullFetchChunker) Data() []byte { + return c.cache.Data() +} + func (c *FullFetchChunker) FileSize() (int64, error) { return c.cache.FileSize() } diff --git a/packages/orchestrator/pkg/sandbox/block/streaming_chunk.go b/packages/orchestrator/pkg/sandbox/block/streaming_chunk.go index ca648397f3..17550336e5 100644 --- a/packages/orchestrator/pkg/sandbox/block/streaming_chunk.go +++ b/packages/orchestrator/pkg/sandbox/block/streaming_chunk.go @@ -460,6 +460,10 @@ func (c *StreamingChunker) Close() error { return c.cache.Close() } +func (c *StreamingChunker) Data() []byte { + return c.cache.Data() +} + func (c *StreamingChunker) FileSize() (int64, error) { return c.cache.FileSize() } diff --git a/packages/orchestrator/pkg/sandbox/build/diff.go b/packages/orchestrator/pkg/sandbox/build/diff.go index b817235aa9..0c971a09a6 100644 --- a/packages/orchestrator/pkg/sandbox/build/diff.go +++ b/packages/orchestrator/pkg/sandbox/build/diff.go @@ -30,6 +30,7 @@ type Diff interface { block.Slicer CacheKey() DiffStoreKey CachePath() (string, error) + Data() []byte FileSize() (int64, error) Init(ctx context.Context) error } @@ -42,6 +43,10 @@ func (n *NoDiff) CachePath() (string, error) { return "", NoDiffError{} } +func (n *NoDiff) Data() []byte { + return nil +} + func (n *NoDiff) Slice(_ context.Context, _, _ int64) ([]byte, error) { return nil, NoDiffError{} } diff --git a/packages/orchestrator/pkg/sandbox/build/local_diff.go b/packages/orchestrator/pkg/sandbox/build/local_diff.go index df5fec4ea7..3505e9de1c 100644 --- a/packages/orchestrator/pkg/sandbox/build/local_diff.go +++ b/packages/orchestrator/pkg/sandbox/build/local_diff.go @@ -110,6 +110,10 @@ func (b *localDiff) CachePath() (string, error) { return b.cache.Path(), nil } +func (b *localDiff) Data() []byte { + return b.cache.Data() +} + func (b *localDiff) Close() error { return b.cache.Close() } diff --git a/packages/orchestrator/pkg/sandbox/build/mocks/mockdiff.go b/packages/orchestrator/pkg/sandbox/build/mocks/mockdiff.go index ea61e38b25..be1f5a1379 100644 --- a/packages/orchestrator/pkg/sandbox/build/mocks/mockdiff.go +++ b/packages/orchestrator/pkg/sandbox/build/mocks/mockdiff.go @@ -223,6 +223,52 @@ func (_c *MockDiff_Close_Call) RunAndReturn(run func() error) *MockDiff_Close_Ca return _c } +// Data provides a mock function for the type MockDiff +func (_mock *MockDiff) Data() []byte { + ret := _mock.Called() + + if len(ret) == 0 { + panic("no return value specified for Data") + } + + var r0 []byte + if returnFunc, ok := ret.Get(0).(func() []byte); ok { + r0 = returnFunc() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + return r0 +} + +// MockDiff_Data_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Data' +type MockDiff_Data_Call struct { + *mock.Call +} + +// Data is a helper method to define mock.On call +func (_e *MockDiff_Expecter) Data() *MockDiff_Data_Call { + return &MockDiff_Data_Call{Call: _e.mock.On("Data")} +} + +func (_c *MockDiff_Data_Call) Run(run func()) *MockDiff_Data_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockDiff_Data_Call) Return(bytes []byte) *MockDiff_Data_Call { + _c.Call.Return(bytes) + return _c +} + +func (_c *MockDiff_Data_Call) RunAndReturn(run func() []byte) *MockDiff_Data_Call { + _c.Call.Return(run) + return _c +} + // FileSize provides a mock function for the type MockDiff func (_mock *MockDiff) FileSize() (int64, error) { ret := _mock.Called() diff --git a/packages/orchestrator/pkg/sandbox/build/storage_diff.go b/packages/orchestrator/pkg/sandbox/build/storage_diff.go index eca9b11bb8..94f239c2af 100644 --- a/packages/orchestrator/pkg/sandbox/build/storage_diff.go +++ b/packages/orchestrator/pkg/sandbox/build/storage_diff.go @@ -150,6 +150,15 @@ func (b *StorageDiff) CachePath() (string, error) { return b.cachePath, nil } +func (b *StorageDiff) Data() []byte { + c, err := b.chunker.Wait() + if err != nil { + return nil + } + + return c.Data() +} + func (b *StorageDiff) FileSize() (int64, error) { c, err := b.chunker.Wait() if err != nil { diff --git a/packages/orchestrator/pkg/sandbox/snapshot.go b/packages/orchestrator/pkg/sandbox/snapshot.go index eb6f1a8fd2..38f189edf9 100644 --- a/packages/orchestrator/pkg/sandbox/snapshot.go +++ b/packages/orchestrator/pkg/sandbox/snapshot.go @@ -26,30 +26,6 @@ func (s *Snapshot) Upload( persistence storage.StorageProvider, paths storage.Paths, ) error { - var memfilePath *string - switch r := s.MemfileDiff.(type) { - case *build.NoDiff: - default: - memfileLocalPath, err := r.CachePath() - if err != nil { - return fmt.Errorf("error getting memfile diff path: %w", err) - } - - memfilePath = &memfileLocalPath - } - - var rootfsPath *string - switch r := s.RootfsDiff.(type) { - case *build.NoDiff: - default: - rootfsLocalPath, err := r.CachePath() - if err != nil { - return fmt.Errorf("error getting rootfs diff path: %w", err) - } - - rootfsPath = &rootfsLocalPath - } - templateBuild := NewTemplateBuild( s.MemfileDiffHeader, s.RootfsDiffHeader, @@ -61,8 +37,8 @@ func (s *Snapshot) Upload( ctx, s.Metafile.Path(), s.Snapfile.Path(), - memfilePath, - rootfsPath, + s.MemfileDiff, + s.RootfsDiff, ); err != nil { return fmt.Errorf("error uploading template files: %w", err) } diff --git a/packages/orchestrator/pkg/sandbox/template/peerclient/seekable.go b/packages/orchestrator/pkg/sandbox/template/peerclient/seekable.go index 2054b1aa5c..9d873e77c0 100644 --- a/packages/orchestrator/pkg/sandbox/template/peerclient/seekable.go +++ b/packages/orchestrator/pkg/sandbox/template/peerclient/seekable.go @@ -130,6 +130,15 @@ func (s *peerSeekable) StoreFile(ctx context.Context, path string) error { return fallback.StoreFile(ctx, path) } +func (s *peerSeekable) StoreData(ctx context.Context, data []byte) error { + fallback, err := s.getOrOpenBase(ctx) + if err != nil { + return err + } + + return fallback.StoreData(ctx, data) +} + // openPeerSeekableStream opens a ReadAtBuildSeekable stream, checks peer availability, // and returns a recv function that yields data chunks starting with the first message's data. // The passed context HAS to be canceled by the caller when done with the stream to avoid leaks. diff --git a/packages/orchestrator/pkg/sandbox/template_build.go b/packages/orchestrator/pkg/sandbox/template_build.go index 374d39fba4..9f756530ad 100644 --- a/packages/orchestrator/pkg/sandbox/template_build.go +++ b/packages/orchestrator/pkg/sandbox/template_build.go @@ -8,6 +8,7 @@ import ( "golang.org/x/sync/errgroup" + "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/build" "github.com/e2b-dev/infra/packages/shared/pkg/storage" headers "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" ) @@ -58,13 +59,13 @@ func (t *TemplateBuild) uploadMemfileHeader(ctx context.Context, h *headers.Head return nil } -func (t *TemplateBuild) uploadMemfile(ctx context.Context, memfilePath string) error { +func (t *TemplateBuild) uploadMemfile(ctx context.Context, data []byte) error { object, err := t.persistence.OpenSeekable(ctx, t.paths.Memfile(), storage.MemfileObjectType) if err != nil { return err } - err = object.StoreFile(ctx, memfilePath) + err = object.StoreData(ctx, data) if err != nil { return fmt.Errorf("error when uploading memfile: %w", err) } @@ -91,13 +92,13 @@ func (t *TemplateBuild) uploadRootfsHeader(ctx context.Context, h *headers.Heade return nil } -func (t *TemplateBuild) uploadRootfs(ctx context.Context, rootfsPath string) error { +func (t *TemplateBuild) uploadRootfs(ctx context.Context, data []byte) error { object, err := t.persistence.OpenSeekable(ctx, t.paths.Rootfs(), storage.RootFSObjectType) if err != nil { return err } - err = object.StoreFile(ctx, rootfsPath) + err = object.StoreData(ctx, data) if err != nil { return fmt.Errorf("error when uploading rootfs: %w", err) } @@ -153,7 +154,7 @@ func uploadFileAsBlob(ctx context.Context, b storage.Blob, path string) error { return nil } -func (t *TemplateBuild) Upload(ctx context.Context, metadataPath string, fcSnapfilePath string, memfilePath *string, rootfsPath *string) error { +func (t *TemplateBuild) Upload(ctx context.Context, metadataPath string, fcSnapfilePath string, memfileDiff build.Diff, rootfsDiff build.Diff) error { eg, ctx := errgroup.WithContext(ctx) eg.Go(func() error { @@ -173,19 +174,21 @@ func (t *TemplateBuild) Upload(ctx context.Context, metadataPath string, fcSnapf }) eg.Go(func() error { - if rootfsPath == nil { + data := rootfsDiff.Data() + if data == nil { return nil } - return t.uploadRootfs(ctx, *rootfsPath) + return t.uploadRootfs(ctx, data) }) eg.Go(func() error { - if memfilePath == nil { + data := memfileDiff.Data() + if data == nil { return nil } - return t.uploadMemfile(ctx, *memfilePath) + return t.uploadMemfile(ctx, data) }) eg.Go(func() error { diff --git a/packages/shared/pkg/storage/gcp_multipart.go b/packages/shared/pkg/storage/gcp_multipart.go deleted file mode 100644 index 75324c16c1..0000000000 --- a/packages/shared/pkg/storage/gcp_multipart.go +++ /dev/null @@ -1,368 +0,0 @@ -package storage - -import ( - "bytes" - "context" - "crypto/md5" - "encoding/base64" - "encoding/xml" - "fmt" - "io" - "math" - "math/rand" - "net/http" - "os" - "sort" - "sync" - "time" - - "github.com/hashicorp/go-retryablehttp" - "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" - "go.uber.org/zap" - "golang.org/x/oauth2/google" - "golang.org/x/sync/errgroup" - - "github.com/e2b-dev/infra/packages/shared/pkg/logger" -) - -const ( - gcpMultipartUploadChunkSize = 50 * 1024 * 1024 // 50MB chunks -) - -// RetryConfig holds the configuration for retry logic -type RetryConfig struct { - MaxAttempts int - InitialBackoff time.Duration - MaxBackoff time.Duration - BackoffMultiplier float64 -} - -// DefaultRetryConfig returns the default retry configuration matching storage_google.go -func DefaultRetryConfig() RetryConfig { - return RetryConfig{ - MaxAttempts: googleMaxAttempts, - InitialBackoff: googleInitialBackoff, - MaxBackoff: googleMaxBackoff, - BackoffMultiplier: googleBackoffMultiplier, - } -} - -func createRetryableClient(ctx context.Context, config RetryConfig) *retryablehttp.Client { - client := retryablehttp.NewClient() - - client.RetryMax = config.MaxAttempts - 1 // go-retryablehttp counts retries, not total attempts - client.RetryWaitMin = config.InitialBackoff - client.RetryWaitMax = config.MaxBackoff - - // Custom backoff function with full jitter to avoid thundering herd - client.Backoff = func(start, maxBackoff time.Duration, attemptNum int, _ *http.Response) time.Duration { - // Calculate exponential backoff - backoff := start - for range attemptNum { - backoff = time.Duration(float64(backoff) * config.BackoffMultiplier) - if backoff > maxBackoff { - backoff = maxBackoff - - break - } - } - - // Apply full jitter: random(0, backoff) - // This implements the "full jitter" strategy recommended by AWS: - // https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/ - // Benefits: - // - Spreads retry attempts across time to avoid thundering herd - // - Reduces peak load on servers during outages - // - Improves overall system stability under high retry scenarios - if backoff > 0 { - return time.Duration(rand.Int63n(int64(backoff))) - } - - return backoff - } - - // add otel instrumentation - originalTransport := client.HTTPClient.Transport - client.HTTPClient.Transport = otelhttp.NewTransport(originalTransport) - - // Use zap logger - client.Logger = &leveledLogger{ - logger: logger.L().Detach(ctx), - } - - return client -} - -// zapLogger adapts zap.Logger to retryablehttp.LeveledLogger interface -var _ retryablehttp.LeveledLogger = &leveledLogger{} - -type leveledLogger struct { - logger *zap.Logger -} - -func (z *leveledLogger) Error(msg string, keysAndValues ...any) { - z.logger.Error(msg, zap.Any("details", keysAndValues)) -} - -func (z *leveledLogger) Info(msg string, keysAndValues ...any) { - z.logger.Info(msg, zap.Any("details", keysAndValues)) -} - -func (z *leveledLogger) Debug(string, ...any) { - // Ignore debug logs -} - -func (z *leveledLogger) Warn(msg string, keysAndValues ...any) { - z.logger.Warn(msg, zap.Any("details", keysAndValues)) -} - -type InitiateMultipartUploadResult struct { - Bucket string `xml:"Bucket"` - Key string `xml:"Key"` - UploadID string `xml:"UploadId"` -} - -type CompleteMultipartUpload struct { - XMLName string `xml:"CompleteMultipartUpload"` - Parts []Part `xml:"Part"` -} - -type Part struct { - PartNumber int `xml:"PartNumber"` - ETag string `xml:"ETag"` -} - -type MultipartUploader struct { - bucketName string - objectName string - token string - client *retryablehttp.Client - retryConfig RetryConfig - baseURL string // Allow overriding for testing -} - -func NewMultipartUploaderWithRetryConfig(ctx context.Context, bucketName, objectName string, retryConfig RetryConfig) (*MultipartUploader, error) { - creds, err := google.FindDefaultCredentials(ctx, "https://www.googleapis.com/auth/cloud-platform") - if err != nil { - return nil, fmt.Errorf("failed to get credentials: %w", err) - } - - token, err := creds.TokenSource.Token() - if err != nil { - return nil, fmt.Errorf("failed to get token: %w", err) - } - - return &MultipartUploader{ - bucketName: bucketName, - objectName: objectName, - token: token.AccessToken, - client: createRetryableClient(ctx, retryConfig), - retryConfig: retryConfig, - baseURL: fmt.Sprintf("https://%s.storage.googleapis.com", bucketName), - }, nil -} - -func (m *MultipartUploader) initiateUpload(ctx context.Context) (string, error) { - url := fmt.Sprintf("%s/%s?uploads", m.baseURL, m.objectName) - - req, err := retryablehttp.NewRequestWithContext(ctx, "POST", url, nil) - if err != nil { - return "", err - } - - req.Header.Set("Authorization", "Bearer "+m.token) - req.Header.Set("Content-Length", "0") - req.Header.Set("Content-Type", "application/octet-stream") - - resp, err := m.client.Do(req) - if err != nil { - return "", err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - - return "", fmt.Errorf("failed to initiate upload (status %d): %s", resp.StatusCode, string(body)) - } - - var result InitiateMultipartUploadResult - if err := xml.NewDecoder(resp.Body).Decode(&result); err != nil { - return "", fmt.Errorf("failed to parse initiate response: %w", err) - } - - return result.UploadID, nil -} - -func (m *MultipartUploader) uploadPart(ctx context.Context, uploadID string, partNumber int, data []byte) (string, error) { - // Calculate MD5 for data integrity - hasher := md5.New() - hasher.Write(data) - md5Sum := base64.StdEncoding.EncodeToString(hasher.Sum(nil)) - - url := fmt.Sprintf("%s/%s?partNumber=%d&uploadId=%s", - m.baseURL, m.objectName, partNumber, uploadID) - - req, err := retryablehttp.NewRequestWithContext(ctx, "PUT", url, bytes.NewReader(data)) - if err != nil { - return "", err - } - - req.Header.Set("Authorization", "Bearer "+m.token) - req.Header.Set("Content-Length", fmt.Sprintf("%d", len(data))) - req.Header.Set("Content-MD5", md5Sum) - - resp, err := m.client.Do(req) - if err != nil { - return "", err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - - return "", fmt.Errorf("failed to upload part %d (status %d): %s", partNumber, resp.StatusCode, string(body)) - } - - etag := resp.Header.Get("ETag") - if etag == "" { - return "", fmt.Errorf("no ETag returned for part %d", partNumber) - } - - return etag, nil -} - -func (m *MultipartUploader) completeUpload(ctx context.Context, uploadID string, parts []Part) error { - // Sort parts by part number - sort.Slice(parts, func(i, j int) bool { - return parts[i].PartNumber < parts[j].PartNumber - }) - - completeReq := CompleteMultipartUpload{Parts: parts} - xmlData, err := xml.Marshal(completeReq) - if err != nil { - return fmt.Errorf("failed to marshal complete request: %w", err) - } - - url := fmt.Sprintf("%s/%s?uploadId=%s", - m.baseURL, m.objectName, uploadID) - - req, err := retryablehttp.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(xmlData)) - if err != nil { - return fmt.Errorf("failed to create complete request: %w", err) - } - - req.Header.Set("Authorization", "Bearer "+m.token) - req.Header.Set("Content-Type", "application/xml") - req.Header.Set("Content-Length", fmt.Sprintf("%d", len(xmlData))) - - resp, err := m.client.Do(req) - if err != nil { - return fmt.Errorf("http request failed: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - - return fmt.Errorf("failed to complete upload (status %d): %s", resp.StatusCode, string(body)) - } - - return nil -} - -func (m *MultipartUploader) UploadFileInParallel(ctx context.Context, filePath string, maxConcurrency int) (int64, error) { - // Open file - file, err := os.Open(filePath) - if err != nil { - return 0, fmt.Errorf("failed to open file: %w", err) - } - defer file.Close() - - // Get file size - fileInfo, err := file.Stat() - if err != nil { - return 0, fmt.Errorf("failed to get file info: %w", err) - } - fileSize := fileInfo.Size() - - // Calculate number of parts - numParts := int(math.Ceil(float64(fileSize) / float64(gcpMultipartUploadChunkSize))) - if numParts == 0 { - numParts = 1 // Always upload at least 1 part, even for empty files - } - - // Initiate multipart upload - uploadID, err := m.initiateUpload(ctx) - if err != nil { - return 0, fmt.Errorf("failed to initiate upload: %w", err) - } - - parts, err := m.uploadParts(ctx, maxConcurrency, numParts, fileSize, file, uploadID) - if err != nil { - return 0, fmt.Errorf("failed to upload parts: %w", err) - } - - if err := m.completeUpload(ctx, uploadID, parts); err != nil { - return 0, fmt.Errorf("failed to complete upload: %w", err) - } - - return fileSize, nil -} - -func (m *MultipartUploader) uploadParts(ctx context.Context, maxConcurrency int, numParts int, fileSize int64, file *os.File, uploadID string) ([]Part, error) { - g, ctx := errgroup.WithContext(ctx) // Context ONLY for waitgroup goroutines; canceled after errgroup finishes - g.SetLimit(maxConcurrency) // Limit concurrent goroutines - - // Thread-safe map to collect parts - var partsMu sync.Mutex - parts := make([]Part, numParts) - - // Upload each part concurrently - for partNumber := 1; partNumber <= numParts; partNumber++ { - g.Go(func() error { - // Check if context was cancelled - select { - case <-ctx.Done(): - return fmt.Errorf("part %d failed: %w", partNumber, ctx.Err()) - default: - } - - // Read chunk from file - offset := int64(partNumber-1) * gcpMultipartUploadChunkSize - chunkSize := gcpMultipartUploadChunkSize - if offset+int64(chunkSize) > fileSize { - chunkSize = int(fileSize - offset) - } - - chunk := make([]byte, chunkSize) - _, err := file.ReadAt(chunk, offset) - if err != nil { - return fmt.Errorf("failed to read chunk for part %d: %w", partNumber, err) - } - - // Upload part - etag, err := m.uploadPart(ctx, uploadID, partNumber, chunk) - if err != nil { - return fmt.Errorf("failed to upload part %d: %w", partNumber, err) - } - - // Store result thread-safely - partsMu.Lock() - parts[partNumber-1] = Part{ - PartNumber: partNumber, - ETag: etag, - } - partsMu.Unlock() - - return nil - }) - } - - // Wait for all parts to complete or first error - if err := g.Wait(); err != nil { - return nil, fmt.Errorf("upload failed: %w", err) - } - - return parts, nil -} diff --git a/packages/shared/pkg/storage/gcpmultipart/uploader.go b/packages/shared/pkg/storage/gcpmultipart/uploader.go new file mode 100644 index 0000000000..d27ec4de01 --- /dev/null +++ b/packages/shared/pkg/storage/gcpmultipart/uploader.go @@ -0,0 +1,234 @@ +package gcpmultipart + +import ( + "bytes" + "context" + "encoding/xml" + "fmt" + "io" + "math/rand" + "net" + "net/http" + "strconv" + "time" + + "github.com/hashicorp/go-retryablehttp" + "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" + "go.uber.org/zap" + "golang.org/x/oauth2/google" + "golang.org/x/sync/errgroup" + + "github.com/e2b-dev/infra/packages/shared/pkg/logger" +) + +const ChunkSize = 50 * 1024 * 1024 + +var httpClient = &retryablehttp.Client{ + RetryMax: 9, + RetryWaitMin: 10 * time.Millisecond, + RetryWaitMax: 10 * time.Second, + CheckRetry: retryablehttp.DefaultRetryPolicy, + Logger: &leveledLogger{logger: logger.L().Detach(context.Background())}, + HTTPClient: &http.Client{ + Transport: otelhttp.NewTransport(&http.Transport{ + MaxIdleConns: 100, + MaxIdleConnsPerHost: 100, + IdleConnTimeout: 90 * time.Second, + WriteBufferSize: 4 << 20, + ReadBufferSize: 64 << 10, + ForceAttemptHTTP2: true, + DialContext: (&net.Dialer{ + Timeout: 10 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + }), + }, + Backoff: func(start, maxBackoff time.Duration, attempt int, _ *http.Response) time.Duration { + b := start + for range attempt { + b = time.Duration(float64(b) * 2) + if b > maxBackoff { + b = maxBackoff + + break + } + } + + if b > 0 { + return time.Duration(rand.Int63n(int64(b))) + } + + return 0 + }, +} + +type Uploader struct { + token string + baseURL string + client *retryablehttp.Client +} + +func NewUploader(ctx context.Context, bucketName, objectName string) (*Uploader, error) { + creds, err := google.FindDefaultCredentials(ctx, "https://www.googleapis.com/auth/cloud-platform") + if err != nil { + return nil, fmt.Errorf("get credentials: %w", err) + } + + token, err := creds.TokenSource.Token() + if err != nil { + return nil, fmt.Errorf("get token: %w", err) + } + + return &Uploader{ + token: token.AccessToken, + baseURL: "https://" + bucketName + ".storage.googleapis.com/" + objectName, + client: httpClient, + }, nil +} + +func (u *Uploader) Upload(ctx context.Context, data []byte, maxConcurrency int) (int64, error) { + uploadID, err := u.initiate(ctx) + if err != nil { + return 0, err + } + + dataLen := len(data) + numParts := (dataLen + ChunkSize - 1) / ChunkSize + if numParts == 0 { + numParts = 1 + } + + parts := make([]xmlPart, numParts) + g, gCtx := errgroup.WithContext(ctx) + g.SetLimit(maxConcurrency) + + for i := range numParts { + g.Go(func() error { + start := i * ChunkSize + end := min(start+ChunkSize, dataLen) + partNum := i + 1 + + etag, err := u.putPart(gCtx, uploadID, partNum, data[start:end]) + if err != nil { + return err + } + + parts[i] = xmlPart{PartNumber: partNum, ETag: etag} + + return nil + }) + } + + if err := g.Wait(); err != nil { + return 0, err + } + + if err := u.complete(ctx, uploadID, parts); err != nil { + return 0, err + } + + return int64(dataLen), nil +} + +type xmlInitiateResponse struct { + UploadID string `xml:"UploadId"` +} + +type xmlCompleteRequest struct { + XMLName string `xml:"CompleteMultipartUpload"` + Parts []xmlPart `xml:"Part"` +} + +type xmlPart struct { + PartNumber int `xml:"PartNumber"` + ETag string `xml:"ETag"` +} + +func (u *Uploader) doRequest(ctx context.Context, method, url string, body io.ReadSeeker, headers [][2]string) (*http.Response, error) { + req, err := retryablehttp.NewRequestWithContext(ctx, method, url, body) + if err != nil { + return nil, err + } + + req.Header.Set("Authorization", "Bearer "+u.token) + for _, h := range headers { + req.Header.Set(h[0], h[1]) + } + + resp, err := u.client.Do(req) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + + return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, body) + } + + return resp, nil +} + +func (u *Uploader) initiate(ctx context.Context) (string, error) { + resp, err := u.doRequest(ctx, "POST", u.baseURL+"?uploads", nil, [][2]string{ + {"Content-Length", "0"}, + {"Content-Type", "application/octet-stream"}, + }) + if err != nil { + return "", fmt.Errorf("initiate upload: %w", err) + } + defer resp.Body.Close() + + var result xmlInitiateResponse + if err := xml.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", fmt.Errorf("parse initiate response: %w", err) + } + + return result.UploadID, nil +} + +func (u *Uploader) putPart(ctx context.Context, uploadID string, partNumber int, data []byte) (string, error) { + resp, err := u.doRequest(ctx, "PUT", + u.baseURL+"?partNumber="+strconv.Itoa(partNumber)+"&uploadId="+uploadID, + bytes.NewReader(data), + [][2]string{{"Content-Length", strconv.Itoa(len(data))}}, + ) + if err != nil { + return "", err + } + resp.Body.Close() + + if etag := resp.Header.Get("ETag"); etag != "" { + return etag, nil + } + + return "", fmt.Errorf("no ETag for part %d", partNumber) +} + +func (u *Uploader) complete(ctx context.Context, uploadID string, parts []xmlPart) error { + body, err := xml.Marshal(xmlCompleteRequest{Parts: parts}) + if err != nil { + return err + } + + resp, err := u.doRequest(ctx, "POST", u.baseURL+"?uploadId="+uploadID, bytes.NewReader(body), [][2]string{ + {"Content-Type", "application/xml"}, + {"Content-Length", strconv.Itoa(len(body))}, + }) + if err != nil { + return err + } + resp.Body.Close() + + return nil +} + +var _ retryablehttp.LeveledLogger = &leveledLogger{} + +type leveledLogger struct{ logger *zap.Logger } + +func (l *leveledLogger) Error(msg string, kv ...any) { l.logger.Error(msg, zap.Any("details", kv)) } +func (l *leveledLogger) Info(msg string, kv ...any) { l.logger.Info(msg, zap.Any("details", kv)) } +func (l *leveledLogger) Debug(string, ...any) {} +func (l *leveledLogger) Warn(msg string, kv ...any) { l.logger.Warn(msg, zap.Any("details", kv)) } diff --git a/packages/shared/pkg/storage/gcp_multipart_test.go b/packages/shared/pkg/storage/gcpmultipart/uploader_test.go similarity index 60% rename from packages/shared/pkg/storage/gcp_multipart_test.go rename to packages/shared/pkg/storage/gcpmultipart/uploader_test.go index 7fe4d397ce..0855b227dc 100644 --- a/packages/shared/pkg/storage/gcp_multipart_test.go +++ b/packages/shared/pkg/storage/gcpmultipart/uploader_test.go @@ -1,4 +1,4 @@ -package storage +package gcpmultipart import ( "encoding/xml" @@ -31,32 +31,38 @@ const ( uploadsPath = "uploads" ) -// createTestMultipartUploader creates a test uploader with a mock HTTP client -func createTestMultipartUploader(t *testing.T, handler http.HandlerFunc, retryConfig ...RetryConfig) *MultipartUploader { +func createTestUploader(t *testing.T, handler http.HandlerFunc, retryConfigs ...retryConfig) *Uploader { t.Helper() server := httptest.NewServer(handler) t.Cleanup(server.Close) - config := DefaultRetryConfig() - if len(retryConfig) > 0 { - config = retryConfig[0] + cfg := retryConfig{maxAttempts: 10, initialBackoff: 10 * time.Millisecond, maxBackoff: 10 * time.Second, multiplier: 2} + if len(retryConfigs) > 0 { + cfg = retryConfigs[0] } - // Create retryable client using the test server's client - retryableClient := createRetryableClient(t.Context(), config) - retryableClient.HTTPClient = server.Client() - - uploader := &MultipartUploader{ - bucketName: testBucketName, - objectName: testObjectName, - token: testToken, - client: retryableClient, - retryConfig: config, - baseURL: server.URL, // Override to use test server + client := &retryablehttp.Client{ + RetryMax: cfg.maxAttempts - 1, + RetryWaitMin: cfg.initialBackoff, + RetryWaitMax: cfg.maxBackoff, + CheckRetry: retryablehttp.DefaultRetryPolicy, + HTTPClient: server.Client(), + Backoff: httpClient.Backoff, } - return uploader + return &Uploader{ + token: testToken, + baseURL: server.URL + "/" + testObjectName, + client: client, + } +} + +type retryConfig struct { + maxAttempts int + initialBackoff time.Duration + maxBackoff time.Duration + multiplier float64 } func TestMultipartUploader_InitiateUpload_Success(t *testing.T) { @@ -70,9 +76,7 @@ func TestMultipartUploader_InitiateUpload_Success(t *testing.T) { assert.Equal(t, "Bearer "+testToken, r.Header.Get("Authorization")) assert.Equal(t, "application/octet-stream", r.Header.Get("Content-Type")) - response := InitiateMultipartUploadResult{ - Bucket: testBucketName, - Key: testObjectName, + response := xmlInitiateResponse{ UploadID: expectedUploadID, } @@ -82,8 +86,8 @@ func TestMultipartUploader_InitiateUpload_Success(t *testing.T) { w.Write(xmlData) }) - uploader := createTestMultipartUploader(t, handler) - uploadID, err := uploader.initiateUpload(t.Context()) + uploader := createTestUploader(t, handler) + uploadID, err := uploader.initiate(t.Context()) require.NoError(t, err) require.Equal(t, expectedUploadID, uploadID) @@ -108,8 +112,8 @@ func TestMultipartUploader_UploadPart_Success(t *testing.T) { w.WriteHeader(http.StatusOK) }) - uploader := createTestMultipartUploader(t, handler) - etag, err := uploader.uploadPart(t.Context(), "test-upload-id", 1, testData) + uploader := createTestUploader(t, handler) + etag, err := uploader.putPart(t.Context(), "test-upload-id", 1, testData) require.NoError(t, err) require.Equal(t, expectedETag, etag) @@ -122,17 +126,17 @@ func TestMultipartUploader_UploadPart_MissingETag(t *testing.T) { w.WriteHeader(http.StatusOK) }) - uploader := createTestMultipartUploader(t, handler) - etag, err := uploader.uploadPart(t.Context(), "test-upload-id", 1, []byte("test")) + uploader := createTestUploader(t, handler) + etag, err := uploader.putPart(t.Context(), "test-upload-id", 1, []byte("test")) require.Error(t, err) - require.Contains(t, err.Error(), "no ETag returned for part 1") + require.Contains(t, err.Error(), "no ETag for part 1") require.Empty(t, etag) } func TestMultipartUploader_CompleteUpload_Success(t *testing.T) { t.Parallel() - parts := []Part{ + parts := []xmlPart{ {PartNumber: 1, ETag: `"etag1"`}, {PartNumber: 2, ETag: `"etag2"`}, } @@ -146,7 +150,7 @@ func TestMultipartUploader_CompleteUpload_Success(t *testing.T) { body, err := io.ReadAll(r.Body) assert.NoError(t, err) - var completeReq CompleteMultipartUpload + var completeReq xmlCompleteRequest err = xml.Unmarshal(body, &completeReq) assert.NoError(t, err) assert.Len(t, completeReq.Parts, 2) @@ -156,9 +160,18 @@ func TestMultipartUploader_CompleteUpload_Success(t *testing.T) { w.WriteHeader(http.StatusOK) }) - uploader := createTestMultipartUploader(t, handler) - err := uploader.completeUpload(t.Context(), "test-upload-id", parts) + uploader := createTestUploader(t, handler) + err := uploader.complete(t.Context(), "test-upload-id", parts) + require.NoError(t, err) +} + +func readTestFile(t *testing.T, path string) []byte { + t.Helper() + + data, err := os.ReadFile(path) require.NoError(t, err) + + return data } func TestMultipartUploader_UploadFileInParallel_Success(t *testing.T) { @@ -171,7 +184,7 @@ func TestMultipartUploader_UploadFileInParallel_Success(t *testing.T) { require.NoError(t, err) var uploadID string - var initiateCount, uploadPartCount, completeCount int32 + var initiateCount, putPartCount, completeCount int32 receivedParts := sync.Map{} handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -180,9 +193,7 @@ func TestMultipartUploader_UploadFileInParallel_Success(t *testing.T) { // Initiate upload atomic.AddInt32(&initiateCount, 1) uploadID = "test-upload-id-123" - response := InitiateMultipartUploadResult{ - Bucket: testBucketName, - Key: testObjectName, + response := xmlInitiateResponse{ UploadID: uploadID, } xmlData, _ := xml.Marshal(response) @@ -192,7 +203,7 @@ func TestMultipartUploader_UploadFileInParallel_Success(t *testing.T) { case strings.Contains(r.URL.RawQuery, "partNumber"): // Upload part - partNum := atomic.AddInt32(&uploadPartCount, 1) + partNum := atomic.AddInt32(&putPartCount, 1) body, _ := io.ReadAll(r.Body) receivedParts.Store(int(partNum), string(body)) @@ -206,17 +217,17 @@ func TestMultipartUploader_UploadFileInParallel_Success(t *testing.T) { } }) - uploader := createTestMultipartUploader(t, handler) - _, err = uploader.UploadFileInParallel(t.Context(), testFile, 2) + uploader := createTestUploader(t, handler) + _, err = uploader.Upload(t.Context(), readTestFile(t, testFile), 2) require.NoError(t, err) require.Equal(t, int32(1), atomic.LoadInt32(&initiateCount)) require.Equal(t, int32(1), atomic.LoadInt32(&completeCount)) - require.Positive(t, atomic.LoadInt32(&uploadPartCount)) + require.Positive(t, atomic.LoadInt32(&putPartCount)) // Verify all parts were uploaded and content matches var reconstructed strings.Builder - for i := 1; i <= int(atomic.LoadInt32(&uploadPartCount)); i++ { + for i := 1; i <= int(atomic.LoadInt32(&putPartCount)); i++ { if part, ok := receivedParts.Load(i); ok { reconstructed.WriteString(part.(string)) } @@ -237,9 +248,7 @@ func TestMultipartUploader_InitiateUpload_WithRetries(t *testing.T) { return } - response := InitiateMultipartUploadResult{ - Bucket: testBucketName, - Key: testObjectName, + response := xmlInitiateResponse{ UploadID: expectedUploadID, } xmlData, _ := xml.Marshal(response) @@ -247,15 +256,8 @@ func TestMultipartUploader_InitiateUpload_WithRetries(t *testing.T) { w.Write(xmlData) }) - config := RetryConfig{ - MaxAttempts: 3, - InitialBackoff: 10 * time.Millisecond, - MaxBackoff: 1 * time.Second, - BackoffMultiplier: 2, - } - - uploader := createTestMultipartUploader(t, handler, config) - uploadID, err := uploader.initiateUpload(t.Context()) + uploader := createTestUploader(t, handler, retryConfig{maxAttempts: 3, initialBackoff: 10 * time.Millisecond, maxBackoff: 1 * time.Second, multiplier: 2}) + uploadID, err := uploader.initiate(t.Context()) require.NoError(t, err) require.Equal(t, expectedUploadID, uploadID) @@ -282,9 +284,7 @@ func TestMultipartUploader_HighConcurrency_StressTest(t *testing.T) { switch { case r.URL.RawQuery == uploadsPath: atomic.AddInt32(&initiateCalls, 1) - response := InitiateMultipartUploadResult{ - Bucket: testBucketName, - Key: testObjectName, + response := xmlInitiateResponse{ UploadID: "stress-test-upload-id", } xmlData, _ := xml.Marshal(response) @@ -320,10 +320,10 @@ func TestMultipartUploader_HighConcurrency_StressTest(t *testing.T) { } }) - uploader := createTestMultipartUploader(t, handler) + uploader := createTestUploader(t, handler) // Use high concurrency to stress test - _, err = uploader.UploadFileInParallel(t.Context(), testFile, 50) + _, err = uploader.Upload(t.Context(), readTestFile(t, testFile), 50) require.NoError(t, err) // Verify all calls were made @@ -356,9 +356,7 @@ func TestMultipartUploader_RandomFailures_ChaosTest(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { case r.URL.RawQuery == uploadsPath: - response := InitiateMultipartUploadResult{ - Bucket: testBucketName, - Key: testObjectName, + response := xmlInitiateResponse{ UploadID: "chaos-upload-id", } xmlData, _ := xml.Marshal(response) @@ -385,16 +383,8 @@ func TestMultipartUploader_RandomFailures_ChaosTest(t *testing.T) { } }) - // Use aggressive retry config for chaos test - config := RetryConfig{ - MaxAttempts: 10, - InitialBackoff: 1 * time.Millisecond, - MaxBackoff: 100 * time.Millisecond, - BackoffMultiplier: 2, - } - - uploader := createTestMultipartUploader(t, handler, config) - _, err = uploader.UploadFileInParallel(t.Context(), testFile, 10) + uploader := createTestUploader(t, handler, retryConfig{maxAttempts: 10, initialBackoff: 1 * time.Millisecond, maxBackoff: 100 * time.Millisecond, multiplier: 2}) + _, err = uploader.Upload(t.Context(), readTestFile(t, testFile), 10) require.NoError(t, err) t.Logf("Chaos test: %d total attempts, %d successes", @@ -418,9 +408,7 @@ func TestMultipartUploader_PartialFailures_Recovery(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { case r.URL.RawQuery == uploadsPath: - response := InitiateMultipartUploadResult{ - Bucket: testBucketName, - Key: testObjectName, + response := xmlInitiateResponse{ UploadID: "partial-fail-upload-id", } xmlData, _ := xml.Marshal(response) @@ -451,15 +439,8 @@ func TestMultipartUploader_PartialFailures_Recovery(t *testing.T) { } }) - config := RetryConfig{ - MaxAttempts: maxAttempts, - InitialBackoff: 5 * time.Millisecond, - MaxBackoff: 50 * time.Millisecond, - BackoffMultiplier: 2, - } - - uploader := createTestMultipartUploader(t, handler, config) - _, err = uploader.UploadFileInParallel(t.Context(), testFile, 5) + uploader := createTestUploader(t, handler, retryConfig{maxAttempts: maxAttempts, initialBackoff: 5 * time.Millisecond, maxBackoff: 50 * time.Millisecond, multiplier: 2}) + _, err = uploader.Upload(t.Context(), readTestFile(t, testFile), 5) require.NoError(t, err) // Verify that all parts eventually succeeded after retries @@ -484,9 +465,7 @@ func TestMultipartUploader_EdgeCases_EmptyFile(t *testing.T) { switch { case r.URL.RawQuery == uploadsPath: atomic.AddInt32(&initiateCalls, 1) - response := InitiateMultipartUploadResult{ - Bucket: testBucketName, - Key: testObjectName, + response := xmlInitiateResponse{ UploadID: "empty-file-upload-id", } xmlData, _ := xml.Marshal(response) @@ -507,8 +486,8 @@ func TestMultipartUploader_EdgeCases_EmptyFile(t *testing.T) { } }) - uploader := createTestMultipartUploader(t, handler) - _, err = uploader.UploadFileInParallel(t.Context(), emptyFile, 5) + uploader := createTestUploader(t, handler) + _, err = uploader.Upload(t.Context(), readTestFile(t, emptyFile), 5) require.NoError(t, err) require.Equal(t, int32(1), atomic.LoadInt32(&initiateCalls)) @@ -529,9 +508,7 @@ func TestMultipartUploader_EdgeCases_VerySmallFile(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { case r.URL.RawQuery == uploadsPath: - response := InitiateMultipartUploadResult{ - Bucket: testBucketName, - Key: testObjectName, + response := xmlInitiateResponse{ UploadID: "small-file-upload-id", } xmlData, _ := xml.Marshal(response) @@ -550,8 +527,8 @@ func TestMultipartUploader_EdgeCases_VerySmallFile(t *testing.T) { } }) - uploader := createTestMultipartUploader(t, handler) - _, err = uploader.UploadFileInParallel(t.Context(), smallFile, 10) // High concurrency for small file + uploader := createTestUploader(t, handler) + _, err = uploader.Upload(t.Context(), readTestFile(t, smallFile), 10) require.NoError(t, err) require.Equal(t, smallContent, receivedData) } @@ -590,9 +567,9 @@ func TestMultipartUploader_ResourceExhaustion_TooManyConcurrentUploads(t *testin testFile := filepath.Join(tempDir, "resource.txt") file, err := os.Create(testFile) require.NoError(t, err) - count, err := io.Copy(file, newRepeatReader('a', gcpMultipartUploadChunkSize*totalChunks)) + count, err := io.Copy(file, newRepeatReader('a', ChunkSize*totalChunks)) require.NoError(t, err) - assert.GreaterOrEqual(t, count, int64(gcpMultipartUploadChunkSize*totalChunks)) + assert.GreaterOrEqual(t, count, int64(ChunkSize*totalChunks)) err = file.Close() require.NoError(t, err) @@ -602,9 +579,7 @@ func TestMultipartUploader_ResourceExhaustion_TooManyConcurrentUploads(t *testin handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { case r.URL.RawQuery == uploadsPath: - response := InitiateMultipartUploadResult{ - Bucket: testBucketName, - Key: testObjectName, + response := xmlInitiateResponse{ UploadID: "resource-test-upload-id", } xmlData, _ := xml.Marshal(response) @@ -635,10 +610,10 @@ func TestMultipartUploader_ResourceExhaustion_TooManyConcurrentUploads(t *testin } }) - uploader := createTestMultipartUploader(t, handler) + uploader := createTestUploader(t, handler) // Try with extremely high concurrency - _, err = uploader.UploadFileInParallel(t.Context(), testFile, 1000) + _, err = uploader.Upload(t.Context(), readTestFile(t, testFile), 1000) require.NoError(t, err) // Should have observed significant concurrency but not necessarily 1000 @@ -652,7 +627,7 @@ func TestMultipartUploader_BoundaryConditions_ExactChunkSize(t *testing.T) { tempDir := t.TempDir() testFile := filepath.Join(tempDir, "exact.txt") // Create file that's exactly 2 chunks - testContent := strings.Repeat("x", gcpMultipartUploadChunkSize*2) + testContent := strings.Repeat("x", ChunkSize*2) err := os.WriteFile(testFile, []byte(testContent), 0o644) require.NoError(t, err) @@ -662,9 +637,7 @@ func TestMultipartUploader_BoundaryConditions_ExactChunkSize(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { case r.URL.RawQuery == uploadsPath: - response := InitiateMultipartUploadResult{ - Bucket: testBucketName, - Key: testObjectName, + response := xmlInitiateResponse{ UploadID: "boundary-upload-id", } xmlData, _ := xml.Marshal(response) @@ -686,25 +659,14 @@ func TestMultipartUploader_BoundaryConditions_ExactChunkSize(t *testing.T) { } }) - uploader := createTestMultipartUploader(t, handler) - _, err = uploader.UploadFileInParallel(t.Context(), testFile, 5) + uploader := createTestUploader(t, handler) + _, err = uploader.Upload(t.Context(), readTestFile(t, testFile), 5) require.NoError(t, err) // Should have exactly 2 parts, each of ChunkSize require.Len(t, partSizes, 2) - require.Equal(t, gcpMultipartUploadChunkSize, partSizes[0]) - require.Equal(t, gcpMultipartUploadChunkSize, partSizes[1]) -} - -func TestMultipartUploader_FileNotFound_Error(t *testing.T) { - t.Parallel() - uploader := createTestMultipartUploader(t, func(http.ResponseWriter, *http.Request) { - t.Error("Should not make any HTTP requests for missing file") - }) - - _, err := uploader.UploadFileInParallel(t.Context(), "/nonexistent/file.txt", 5) - require.Error(t, err) - require.Contains(t, err.Error(), "failed to open file") + require.Equal(t, ChunkSize, partSizes[0]) + require.Equal(t, ChunkSize, partSizes[1]) } func TestMultipartUploader_ConcurrentRetries_RaceCondition(t *testing.T) { @@ -723,9 +685,7 @@ func TestMultipartUploader_ConcurrentRetries_RaceCondition(t *testing.T) { switch { case r.URL.RawQuery == uploadsPath: - response := InitiateMultipartUploadResult{ - Bucket: testBucketName, - Key: testObjectName, + response := xmlInitiateResponse{ UploadID: "race-upload-id", } xmlData, _ := xml.Marshal(response) @@ -757,15 +717,8 @@ func TestMultipartUploader_ConcurrentRetries_RaceCondition(t *testing.T) { } }) - config := RetryConfig{ - MaxAttempts: 5, - InitialBackoff: 1 * time.Millisecond, // Very fast retries to increase race probability - MaxBackoff: 10 * time.Millisecond, - BackoffMultiplier: 2, - } - - uploader := createTestMultipartUploader(t, handler, config) - _, err = uploader.UploadFileInParallel(t.Context(), testFile, 20) // High concurrency + uploader := createTestUploader(t, handler, retryConfig{maxAttempts: 5, initialBackoff: 1 * time.Millisecond, maxBackoff: 10 * time.Millisecond, multiplier: 2}) + _, err = uploader.Upload(t.Context(), readTestFile(t, testFile), 20) require.NoError(t, err) t.Logf("Total HTTP requests made: %d", atomic.LoadInt32(&totalRequests)) @@ -778,196 +731,3 @@ func TestMultipartUploader_ConcurrentRetries_RaceCondition(t *testing.T) { return true }) } - -// TestCreateRetryableClient_JitterBehavior tests that the jittered backoff works correctly -func TestCreateRetryableClient_JitterBehavior(t *testing.T) { - t.Parallel() - config := RetryConfig{ - MaxAttempts: 3, - InitialBackoff: 100 * time.Millisecond, - MaxBackoff: 1 * time.Second, - BackoffMultiplier: 2.0, - } - - client := createRetryableClient(t.Context(), config) - require.NotNil(t, client) - require.NotNil(t, client.Backoff) - - // Test jitter produces values within expected range - t.Run("JitterRange", func(t *testing.T) { - t.Parallel() - // Test first attempt (attemptNum = 0) - for range 10 { - backoff := client.Backoff(config.InitialBackoff, config.MaxBackoff, 0, nil) - require.GreaterOrEqual(t, backoff, time.Duration(0)) - require.Less(t, backoff, config.InitialBackoff) - } - - // Test second attempt (attemptNum = 1) - should be jittered version of 200ms - expectedBase := time.Duration(float64(config.InitialBackoff) * config.BackoffMultiplier) - for range 10 { - backoff := client.Backoff(config.InitialBackoff, config.MaxBackoff, 1, nil) - require.GreaterOrEqual(t, backoff, time.Duration(0)) - require.Less(t, backoff, expectedBase) - } - }) - - // Test that jitter produces different values (randomness) - t.Run("JitterRandomness", func(t *testing.T) { - t.Parallel() - values := make(map[time.Duration]bool) - - // Collect 20 jittered values - for range 20 { - backoff := client.Backoff(config.InitialBackoff, config.MaxBackoff, 1, nil) - values[backoff] = true - } - - // Should have at least some variation (not all the same value) - // With a range of 0-200ms, getting 20 identical values is highly unlikely - require.Greater(t, len(values), 1, "Jitter should produce varied values") - }) - - // Test exponential backoff base calculation (before jitter) - t.Run("ExponentialBackoffBase", func(t *testing.T) { - t.Parallel() - // We can't directly test the base calculation due to jitter, - // but we can verify the max possible value matches our expectation - - // For attemptNum=0: base should be 100ms, jitter: 0-100ms - // For attemptNum=1: base should be 200ms, jitter: 0-200ms - // For attemptNum=2: base should be 400ms, jitter: 0-400ms - - // Test attempt 2 multiple times and verify max range - var maxSeen time.Duration - for range 100 { - backoff := client.Backoff(config.InitialBackoff, config.MaxBackoff, 2, nil) - if backoff > maxSeen { - maxSeen = backoff - } - } - - expectedBase := time.Duration(float64(config.InitialBackoff) * config.BackoffMultiplier * config.BackoffMultiplier) - // The max we should ever see is just under the expected base (due to jitter being 0 to base-1) - require.Less(t, maxSeen, expectedBase) - // But we should see values reasonably close to the base in 100 attempts - require.Greater(t, maxSeen, expectedBase/2) - }) - - // Test max backoff cap - t.Run("MaxBackoffCap", func(t *testing.T) { - t.Parallel() - // With high attempt numbers, backoff should be capped at MaxBackoff - for range 10 { - backoff := client.Backoff(config.InitialBackoff, config.MaxBackoff, 10, nil) - require.GreaterOrEqual(t, backoff, time.Duration(0)) - require.Less(t, backoff, config.MaxBackoff) - } - }) -} - -// TestCreateRetryableClient_Configuration tests the retry client configuration -func TestCreateRetryableClient_Configuration(t *testing.T) { - t.Parallel() - config := RetryConfig{ - MaxAttempts: 5, - InitialBackoff: 50 * time.Millisecond, - MaxBackoff: 2 * time.Second, - BackoffMultiplier: 3.0, - } - - client := createRetryableClient(t.Context(), config) - - // Verify retry configuration - require.Equal(t, config.MaxAttempts-1, client.RetryMax) // go-retryablehttp counts retries, not total attempts - require.Equal(t, config.InitialBackoff, client.RetryWaitMin) - require.Equal(t, config.MaxBackoff, client.RetryWaitMax) - require.NotNil(t, client.Logger) - require.NotNil(t, client.Backoff) -} - -// TestCreateRetryableClient_ZeroBackoff tests edge case of zero backoff -func TestCreateRetryableClient_ZeroBackoff(t *testing.T) { - t.Parallel() - config := RetryConfig{ - MaxAttempts: 2, - InitialBackoff: 0, // Zero initial backoff - MaxBackoff: 1 * time.Second, - BackoffMultiplier: 2.0, - } - - client := createRetryableClient(t.Context(), config) - - // With zero initial backoff, jitter should also return zero - backoff := client.Backoff(config.InitialBackoff, config.MaxBackoff, 0, nil) - require.Equal(t, time.Duration(0), backoff) -} - -// TestRetryableClient_ActualRetryBehavior tests the retry behavior in practice -func TestRetryableClient_ActualRetryBehavior(t *testing.T) { - t.Parallel() - var requestCount int32 - var retryDelays []time.Duration - var retryTimes []time.Time - var retryMu sync.Mutex - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - count := atomic.AddInt32(&requestCount, 1) - retryMu.Lock() - retryTimes = append(retryTimes, time.Now()) - retryMu.Unlock() - - if count < 3 { - w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte("server error")) - } else { - w.WriteHeader(http.StatusOK) - w.Write([]byte("success")) - } - })) - defer server.Close() - - config := RetryConfig{ - MaxAttempts: 3, - InitialBackoff: 50 * time.Millisecond, - MaxBackoff: 500 * time.Millisecond, - BackoffMultiplier: 2.0, - } - - client := createRetryableClient(t.Context(), config) - client.HTTPClient = server.Client() - - startTime := time.Now() - req, err := retryablehttp.NewRequestWithContext(t.Context(), "GET", server.URL, nil) - require.NoError(t, err) - - resp, err := client.Do(req) - require.NoError(t, err) - require.Equal(t, http.StatusOK, resp.StatusCode) - resp.Body.Close() - - // Should have made 3 requests (initial + 2 retries) - require.Equal(t, int32(3), atomic.LoadInt32(&requestCount)) - require.Len(t, retryTimes, 3) - - // Calculate actual delays between requests - for i := 1; i < len(retryTimes); i++ { - delay := retryTimes[i].Sub(retryTimes[i-1]) - retryDelays = append(retryDelays, delay) - } - - // Verify we had some delays due to backoff (but jittered, so variable) - require.Len(t, retryDelays, 2) - - // First retry delay should be jittered version of 50ms (0-50ms range) - // But in practice, with network overhead, it might be slightly higher - require.Greater(t, retryDelays[0], time.Duration(0)) - require.Less(t, retryDelays[0], 200*time.Millisecond) // Allow some overhead - - // Second retry delay should be jittered version of 100ms (0-100ms range) - require.Greater(t, retryDelays[1], time.Duration(0)) - require.Less(t, retryDelays[1], 300*time.Millisecond) // Allow some overhead - - totalTime := time.Since(startTime) - t.Logf("Total time: %v, Retry delays: %v", totalTime, retryDelays) -} diff --git a/packages/shared/pkg/storage/mocks/mockseekableobjectprovider.go b/packages/shared/pkg/storage/mocks/mockseekableobjectprovider.go index 3931f6b349..1800270de8 100644 --- a/packages/shared/pkg/storage/mocks/mockseekableobjectprovider.go +++ b/packages/shared/pkg/storage/mocks/mockseekableobjectprovider.go @@ -244,6 +244,63 @@ func (_c *MockSeekable_Size_Call) RunAndReturn(run func(ctx context.Context) (in return _c } +// StoreData provides a mock function for the type MockSeekable +func (_mock *MockSeekable) StoreData(ctx context.Context, data []byte) error { + ret := _mock.Called(ctx, data) + + if len(ret) == 0 { + panic("no return value specified for StoreData") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, []byte) error); ok { + r0 = returnFunc(ctx, data) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// MockSeekable_StoreData_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'StoreData' +type MockSeekable_StoreData_Call struct { + *mock.Call +} + +// StoreData is a helper method to define mock.On call +// - ctx context.Context +// - data []byte +func (_e *MockSeekable_Expecter) StoreData(ctx interface{}, data interface{}) *MockSeekable_StoreData_Call { + return &MockSeekable_StoreData_Call{Call: _e.mock.On("StoreData", ctx, data)} +} + +func (_c *MockSeekable_StoreData_Call) Run(run func(ctx context.Context, data []byte)) *MockSeekable_StoreData_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 []byte + if args[1] != nil { + arg1 = args[1].([]byte) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *MockSeekable_StoreData_Call) Return(err error) *MockSeekable_StoreData_Call { + _c.Call.Return(err) + return _c +} + +func (_c *MockSeekable_StoreData_Call) RunAndReturn(run func(ctx context.Context, data []byte) error) *MockSeekable_StoreData_Call { + _c.Call.Return(run) + return _c +} + // StoreFile provides a mock function for the type MockSeekable func (_mock *MockSeekable) StoreFile(ctx context.Context, path string) error { ret := _mock.Called(ctx, path) diff --git a/packages/shared/pkg/storage/storage.go b/packages/shared/pkg/storage/storage.go index 3ba75e84d1..77c5edb60b 100644 --- a/packages/shared/pkg/storage/storage.go +++ b/packages/shared/pkg/storage/storage.go @@ -101,8 +101,8 @@ type StreamingReader interface { } type SeekableWriter interface { - // Store entire file StoreFile(ctx context.Context, path string) error + StoreData(ctx context.Context, data []byte) error } type Seekable interface { diff --git a/packages/shared/pkg/storage/storage_aws.go b/packages/shared/pkg/storage/storage_aws.go index 189e1cd501..821716a70f 100644 --- a/packages/shared/pkg/storage/storage_aws.go +++ b/packages/shared/pkg/storage/storage_aws.go @@ -192,6 +192,10 @@ func (o *awsObject) StoreFile(ctx context.Context, path string) error { return err } +func (o *awsObject) StoreData(ctx context.Context, data []byte) error { + return o.Put(ctx, data) +} + func (o *awsObject) Put(ctx context.Context, data []byte) error { ctx, cancel := context.WithTimeout(ctx, awsWriteTimeout) defer cancel() diff --git a/packages/shared/pkg/storage/storage_cache_seekable.go b/packages/shared/pkg/storage/storage_cache_seekable.go index 7341107392..2cddf2777c 100644 --- a/packages/shared/pkg/storage/storage_cache_seekable.go +++ b/packages/shared/pkg/storage/storage_cache_seekable.go @@ -304,6 +304,16 @@ func (c *cachedSeekable) StoreFile(ctx context.Context, path string) (e error) { return c.inner.StoreFile(ctx, path) } +func (c *cachedSeekable) StoreData(ctx context.Context, data []byte) (e error) { + ctx, span := c.tracer.Start(ctx, "write object from data") + defer func() { + recordError(span, e) + span.End() + }() + + return c.inner.StoreData(ctx, data) +} + func (c *cachedSeekable) goCtx(ctx context.Context, fn func(context.Context)) { c.wg.Go(func() { fn(context.WithoutCancel(ctx)) diff --git a/packages/shared/pkg/storage/storage_fs.go b/packages/shared/pkg/storage/storage_fs.go index 8eb2d3cc13..239d849e74 100644 --- a/packages/shared/pkg/storage/storage_fs.go +++ b/packages/shared/pkg/storage/storage_fs.go @@ -145,6 +145,18 @@ func (o *fsObject) StoreFile(_ context.Context, path string) error { return nil } +func (o *fsObject) StoreData(_ context.Context, data []byte) error { + handle, err := o.getHandle(false) + if err != nil { + return err + } + defer handle.Close() + + _, err = handle.Write(data) + + return err +} + func (o *fsObject) OpenRangeReader(_ context.Context, off, length int64) (io.ReadCloser, error) { f, err := o.getHandle(true) if err != nil { diff --git a/packages/shared/pkg/storage/storage_google.go b/packages/shared/pkg/storage/storage_google.go index 9434963c44..8904aa8ee3 100644 --- a/packages/shared/pkg/storage/storage_google.go +++ b/packages/shared/pkg/storage/storage_google.go @@ -28,6 +28,7 @@ import ( "github.com/e2b-dev/infra/packages/shared/pkg/env" "github.com/e2b-dev/infra/packages/shared/pkg/limit" "github.com/e2b-dev/infra/packages/shared/pkg/logger" + "github.com/e2b-dev/infra/packages/shared/pkg/storage/gcpmultipart" "github.com/e2b-dev/infra/packages/shared/pkg/telemetry" "github.com/e2b-dev/infra/packages/shared/pkg/utils" ) @@ -385,39 +386,41 @@ func (o *gcpObject) WriteTo(ctx context.Context, dst io.Writer) (int64, error) { } func (o *gcpObject) StoreFile(ctx context.Context, path string) (e error) { - ctx, span := tracer.Start(ctx, "write to gcp from file system") + ctx, span := tracer.Start(ctx, "store file to gcp") defer func() { recordError(span, e) span.End() }() - bucketName := o.storage.bucket.BucketName() - objectName := o.path - - fileInfo, err := os.Stat(path) + data, err := os.ReadFile(path) if err != nil { - return fmt.Errorf("failed to get file size: %w", err) + return fmt.Errorf("failed to read file: %w", err) } - // If the file is too small, the overhead of writing in parallel isn't worth the effort. - // Write it in one shot instead. - if fileInfo.Size() < gcpMultipartUploadChunkSize { + return o.storeData(ctx, data) +} + +func (o *gcpObject) StoreData(ctx context.Context, data []byte) (e error) { + ctx, span := tracer.Start(ctx, "store data to gcp") + defer func() { + recordError(span, e) + span.End() + }() + + return o.storeData(ctx, data) +} + +func (o *gcpObject) storeData(ctx context.Context, data []byte) error { + if int64(len(data)) < gcpmultipart.ChunkSize { timer := googleWriteTimerFactory.Begin( attribute.String(gcsOperationAttr, gcsOperationAttrWriteFromFileSystemOneShot), ) - data, err := os.ReadFile(path) - if err != nil { - timer.Failure(ctx, 0) - - return fmt.Errorf("failed to read file: %w", err) - } - - err = o.Put(ctx, data) + err := o.Put(ctx, data) if err != nil { timer.Failure(ctx, int64(len(data))) - return fmt.Errorf("failed to write file (%d bytes): %w", len(data), err) + return fmt.Errorf("failed to write data (%d bytes): %w", len(data), err) } timer.Success(ctx, int64(len(data))) @@ -445,12 +448,7 @@ func (o *gcpObject) StoreFile(ctx context.Context, path string) (e error) { maxConcurrency = o.limiter.GCloudMaxTasks(ctx) } - uploader, err := NewMultipartUploaderWithRetryConfig( - ctx, - bucketName, - objectName, - DefaultRetryConfig(), - ) + uploader, err := gcpmultipart.NewUploader(ctx, o.storage.bucket.BucketName(), o.path) if err != nil { timer.Failure(ctx, 0) @@ -458,20 +456,18 @@ func (o *gcpObject) StoreFile(ctx context.Context, path string) (e error) { } start := time.Now() - count, err := uploader.UploadFileInParallel(ctx, path, maxConcurrency) + count, err := uploader.Upload(ctx, data, maxConcurrency) if err != nil { timer.Failure(ctx, count) - return fmt.Errorf("failed to upload file in parallel: %w", err) + return fmt.Errorf("failed to upload in parallel: %w", err) } - logger.L().Debug(ctx, "Uploaded file in parallel", - zap.String("bucket", bucketName), - zap.String("object", objectName), - zap.String("path", path), + logger.L().Debug(ctx, "Uploaded data in parallel", + zap.String("object", o.path), zap.Int("max_concurrency", maxConcurrency), - zap.Int64("file_size", fileInfo.Size()), - zap.Int64("duration", time.Since(start).Milliseconds()), + zap.Int64("size", int64(len(data))), + zap.Int64("duration_ms", time.Since(start).Milliseconds()), ) timer.Success(ctx, count)