diff --git a/client/fetcher.go b/client/fetcher.go index f722cca0c..774c62ef8 100644 --- a/client/fetcher.go +++ b/client/fetcher.go @@ -116,18 +116,27 @@ type FileFetcher struct { Root string } -func (f FileFetcher) ReadCheckpoint(_ context.Context) ([]byte, error) { +func (f FileFetcher) ReadCheckpoint(ctx context.Context) ([]byte, error) { + if err := ctx.Err(); err != nil { + return nil, err + } return os.ReadFile(path.Join(f.Root, layout.CheckpointPath)) } func (f FileFetcher) ReadTile(ctx context.Context, l, i uint64, p uint8) ([]byte, error) { return fetcher.PartialOrFullResource(ctx, p, func(ctx context.Context, p uint8) ([]byte, error) { + if err := ctx.Err(); err != nil { + return nil, err + } return os.ReadFile(path.Join(f.Root, layout.TilePath(l, i, p))) }) } func (f FileFetcher) ReadEntryBundle(ctx context.Context, i uint64, p uint8) ([]byte, error) { return fetcher.PartialOrFullResource(ctx, p, func(ctx context.Context, p uint8) ([]byte, error) { + if err := ctx.Err(); err != nil { + return nil, err + } return os.ReadFile(path.Join(f.Root, layout.EntriesPath(i, p))) }) } diff --git a/client/fetcher_test.go b/client/fetcher_test.go new file mode 100644 index 000000000..3c0c69d3c --- /dev/null +++ b/client/fetcher_test.go @@ -0,0 +1,33 @@ +package client + +import ( + "context" + "errors" + "testing" +) + +func TestFileFetcherContextCancellation(t *testing.T) { + d := t.TempDir() + + f := FileFetcher{ + Root: d, + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately + + _, err := f.ReadCheckpoint(ctx) + if !errors.Is(err, context.Canceled) { + t.Errorf("ReadCheckpoint: got error %v, want %v", err, context.Canceled) + } + + _, err = f.ReadTile(ctx, 0, 0, 255) + if !errors.Is(err, context.Canceled) { + t.Errorf("ReadTile: got error %v, want %v", err, context.Canceled) + } + + _, err = f.ReadEntryBundle(ctx, 0, 255) + if !errors.Is(err, context.Canceled) { + t.Errorf("ReadEntryBundle: got error %v, want %v", err, context.Canceled) + } +} diff --git a/fsck/status_test.go b/fsck/status_test.go index fa897b88a..ae4dc74f6 100644 --- a/fsck/status_test.go +++ b/fsck/status_test.go @@ -171,6 +171,7 @@ func TestUpdate(t *testing.T) { for i, want := range test.wantRanges { if p == nil { t.Fatalf("got %d entry ranges, want %d", i-1, len(test.wantRanges)) + return // To make the linter happy } got := p.Value.(*Range) if !reflect.DeepEqual(*got, want) { diff --git a/internal/fetcher/fallback.go b/internal/fetcher/fallback.go index a4d588fd5..bb2331016 100644 --- a/internal/fetcher/fallback.go +++ b/internal/fetcher/fallback.go @@ -38,7 +38,7 @@ func PartialOrFullResource(ctx context.Context, p uint8, f func(context.Context, } return sRaw, nil case err != nil: - return sRaw, fmt.Errorf("failed to fetch resource: %v", err) + return sRaw, fmt.Errorf("failed to fetch resource: %w", err) default: return sRaw, nil }