diff --git a/acr-cli b/acr-cli new file mode 100755 index 00000000..4960e1e1 Binary files /dev/null and b/acr-cli differ diff --git a/acr/acr b/acr/acr new file mode 100755 index 00000000..8d5cef9b Binary files /dev/null and b/acr/acr differ diff --git a/cmd/acr/purge.go b/cmd/acr/purge.go index 837b1518..8526ccad 100644 --- a/cmd/acr/purge.go +++ b/cmd/acr/purge.go @@ -7,8 +7,10 @@ import ( "context" "fmt" "net/http" + "os" "runtime" "sort" + "strconv" "strings" "time" @@ -228,34 +230,75 @@ func purge(ctx context.Context, dryRun bool, includeLocked bool) (deletedTagsCount int, deletedManifestsCount int, err error) { - // In order to print a summary of the deleted tags/manifests the counters get updated everytime a repo is purged. - for repoName, tagRegex := range tagFilters { - var singleDeletedTagsCount int - var manifestToTagsCountMap map[string]int - - // Handle tag deletion based on mode - if untaggedOnly { - // Initialize empty map for untagged-only mode (no tag deletion) - manifestToTagsCountMap = make(map[string]int) - } else { - // Standard mode: delete matching tags first - singleDeletedTagsCount, manifestToTagsCountMap, err = purgeTags(ctx, acrClient, repoParallelism, loginURL, repoName, agoDuration, tagRegex, keep, filterTimeout, dryRun, includeLocked) - if err != nil { - return deletedTagsCount, deletedManifestsCount, fmt.Errorf("failed to purge tags: %w", err) + // Load ABAC batch size from environment variable + abacBatchSize := 10 // default + if envVal, exists := os.LookupEnv("ABAC_BATCH_SIZE"); exists { + if parsed, err := strconv.Atoi(envVal); err == nil && parsed > 0 { + abacBatchSize = parsed + } + } + + // Collect all repository names into a slice for batching + repos := make([]string, 0, len(tagFilters)) + for repoName := range tagFilters { + repos = append(repos, repoName) + } + + // Process repositories in batches of abacBatchSize. + // For ABAC-enabled registries, we refresh the token per batch. + // For non-ABAC registries, the batching loop is harmless (no token refresh needed). + for i := 0; i < len(repos); i += abacBatchSize { + end := i + abacBatchSize + if end > len(repos) { + end = len(repos) + } + batch := repos[i:end] + + // For ABAC registries, request a token that covers all repositories in this batch + if acrClient.IsAbac() { + if err := acrClient.RefreshTokenForAbac(ctx, batch); err != nil { + return deletedTagsCount, deletedManifestsCount, fmt.Errorf("failed to authorize ABAC repositories batch: %w", err) } } - singleDeletedManifestsCount := 0 - // If the untagged flag is set or untagged-only mode is enabled, delete manifests - if removeUntaggedManifests { - singleDeletedManifestsCount, err = purgeDanglingManifests(ctx, acrClient, repoParallelism, loginURL, repoName, agoDuration, keep, manifestToTagsCountMap, dryRun, includeLocked) - if err != nil { - return deletedTagsCount, deletedManifestsCount, fmt.Errorf("failed to purge manifests: %w", err) + // Process all repositories in this batch + for j, repoName := range batch { + // For ABAC registries, check if token expired and refresh for remaining repos in batch + if acrClient.IsAbac() && acrClient.IsTokenExpired() { + remainingRepos := batch[j:] + fmt.Printf("ABAC token expired, refreshing for remaining repositories: %v\n", remainingRepos) + if err := acrClient.RefreshTokenForAbac(ctx, remainingRepos); err != nil { + return deletedTagsCount, deletedManifestsCount, fmt.Errorf("failed to refresh ABAC token: %w", err) + } + } + tagRegex := tagFilters[repoName] + var singleDeletedTagsCount int + var manifestToTagsCountMap map[string]int + + // Handle tag deletion based on mode + if untaggedOnly { + // Initialize empty map for untagged-only mode (no tag deletion) + manifestToTagsCountMap = make(map[string]int) + } else { + // Standard mode: delete matching tags first + singleDeletedTagsCount, manifestToTagsCountMap, err = purgeTags(ctx, acrClient, repoParallelism, loginURL, repoName, agoDuration, tagRegex, keep, filterTimeout, dryRun, includeLocked) + if err != nil { + return deletedTagsCount, deletedManifestsCount, fmt.Errorf("failed to purge tags: %w", err) + } + } + + singleDeletedManifestsCount := 0 + // If the untagged flag is set or untagged-only mode is enabled, delete manifests + if removeUntaggedManifests { + singleDeletedManifestsCount, err = purgeDanglingManifests(ctx, acrClient, repoParallelism, loginURL, repoName, agoDuration, keep, manifestToTagsCountMap, dryRun, includeLocked) + if err != nil { + return deletedTagsCount, deletedManifestsCount, fmt.Errorf("failed to purge manifests: %w", err) + } } + // After every repository is purged the counters are updated. + deletedTagsCount += singleDeletedTagsCount + deletedManifestsCount += singleDeletedManifestsCount } - // After every repository is purged the counters are updated. - deletedTagsCount += singleDeletedTagsCount - deletedManifestsCount += singleDeletedManifestsCount } return deletedTagsCount, deletedManifestsCount, nil diff --git a/cmd/acr/purge_test.go b/cmd/acr/purge_test.go index 912d3d7b..30a08443 100644 --- a/cmd/acr/purge_test.go +++ b/cmd/acr/purge_test.go @@ -552,6 +552,10 @@ func TestDryRun(t *testing.T) { t.Run("RepositoryNotFoundTest", func(t *testing.T) { assert := assert.New(t) mockClient := &mocks.AcrCLIClientInterface{} + // Mock IsAbac to return false (non-ABAC registry) to use standard wildcard token flow + mockClient.On("IsAbac").Return(false) + // Need a .Maybe() since it's only called for ABAC registries (this test mocks IsAbac to return false) + mockClient.On("IsTokenExpired").Return(false).Maybe() mockClient.On("GetAcrManifests", mock.Anything, testRepo, "", "").Return(notFoundManifestResponse, errors.New("testRepo not found")).Once() mockClient.On("GetAcrTags", mock.Anything, testRepo, "timedesc", "").Return(notFoundTagResponse, errors.New("testRepo not found")).Once() deletedTags, deletedManifests, err := purge(testCtx, mockClient, testLoginURL, 60, -24*time.Hour, 0, 1, true, false, map[string]string{testRepo: "[\\s\\S]*"}, true, false) diff --git a/cmd/acr/purge_untagged_only_test.go b/cmd/acr/purge_untagged_only_test.go index 4d59f23c..6cfd083a 100644 --- a/cmd/acr/purge_untagged_only_test.go +++ b/cmd/acr/purge_untagged_only_test.go @@ -25,6 +25,8 @@ func TestPurgeUntaggedOnly(t *testing.T) { t.Run("UntaggedOnlyPurgeManifestsOnly", func(t *testing.T) { assert := assert.New(t) mockClient := &mocks.AcrCLIClientInterface{} + mockClient.On("IsAbac").Return(false) + mockClient.On("IsTokenExpired").Return(false).Maybe() // Setup mock response for manifests without tags manifestDigest := "sha256:abc123" @@ -97,6 +99,8 @@ func TestPurgeUntaggedOnly(t *testing.T) { t.Run("UntaggedOnlyNoFilterAllRepos", func(t *testing.T) { assert := assert.New(t) mockClient := &mocks.AcrCLIClientInterface{} + mockClient.On("IsAbac").Return(false) + mockClient.On("IsTokenExpired").Return(false).Maybe() // We won't test GetRepositories here since the purge function is called // with already-created tagFilters. Instead test that all repos are processed. @@ -149,6 +153,8 @@ func TestPurgeUntaggedOnly(t *testing.T) { t.Run("UntaggedOnlyWithFilter", func(t *testing.T) { assert := assert.New(t) mockClient := &mocks.AcrCLIClientInterface{} + mockClient.On("IsAbac").Return(false) + mockClient.On("IsTokenExpired").Return(false).Maybe() manifestDigest := "sha256:def456" mediaType := "application/vnd.docker.distribution.manifest.v2+json" @@ -218,6 +224,8 @@ func TestPurgeUntaggedOnly(t *testing.T) { t.Run("UntaggedOnlyDryRun", func(t *testing.T) { assert := assert.New(t) mockClient := &mocks.AcrCLIClientInterface{} + mockClient.On("IsAbac").Return(false) + mockClient.On("IsTokenExpired").Return(false).Maybe() manifestDigest := "sha256:ghi789" mediaType := "application/vnd.docker.distribution.manifest.v2+json" @@ -282,6 +290,8 @@ func TestPurgeUntaggedOnly(t *testing.T) { t.Run("UntaggedOnlyWithLockedManifests", func(t *testing.T) { assert := assert.New(t) mockClient := &mocks.AcrCLIClientInterface{} + mockClient.On("IsAbac").Return(false) + mockClient.On("IsTokenExpired").Return(false).Maybe() // Create locked and unlocked untagged manifests lockedDigest := "sha256:locked123" @@ -363,6 +373,8 @@ func TestPurgeUntaggedOnly(t *testing.T) { t.Run("UntaggedOnlyWithIncludeLocked", func(t *testing.T) { assert := assert.New(t) mockClient := &mocks.AcrCLIClientInterface{} + mockClient.On("IsAbac").Return(false) + mockClient.On("IsTokenExpired").Return(false).Maybe() // Create locked untagged manifest lockedDigest := "sha256:locked789" diff --git a/cmd/mocks/AcrCLIClientInterface.go b/cmd/mocks/AcrCLIClientInterface.go index d553d8d5..a517731c 100644 --- a/cmd/mocks/AcrCLIClientInterface.go +++ b/cmd/mocks/AcrCLIClientInterface.go @@ -2,11 +2,15 @@ package mocks -import acr "github.com/Azure/acr-cli/acr" +import ( + acr "github.com/Azure/acr-cli/acr" -import autorest "github.com/Azure/go-autorest/autorest" -import context "context" -import mock "github.com/stretchr/testify/mock" + autorest "github.com/Azure/go-autorest/autorest" + + context "context" + + mock "github.com/stretchr/testify/mock" +) // AcrCLIClientInterface is an autogenerated mock type for the AcrCLIClientInterface type type AcrCLIClientInterface struct { @@ -196,3 +200,45 @@ func (_m *AcrCLIClientInterface) UpdateAcrManifestAttributes(ctx context.Context return r0, r1 } + +// IsAbac provides a mock function that returns whether the registry is ABAC-enabled +func (_m *AcrCLIClientInterface) IsAbac() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// IsTokenExpired provides a mock function for checking if token is expired +func (_m *AcrCLIClientInterface) IsTokenExpired() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// RefreshTokenForAbac provides a mock function for refreshing tokens with specific repository scopes +func (_m *AcrCLIClientInterface) RefreshTokenForAbac(ctx context.Context, repositories []string) error { + ret := _m.Called(ctx, repositories) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, []string) error); ok { + r0 = rf(ctx, repositories) + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/internal/api/acrsdk.go b/internal/api/acrsdk.go index 0b3a7c11..2c68d083 100644 --- a/internal/api/acrsdk.go +++ b/internal/api/acrsdk.go @@ -7,6 +7,7 @@ package api import ( "bytes" "context" + "fmt" "io/ioutil" "strings" "time" @@ -52,6 +53,10 @@ type AcrCLIClient struct { // accessTokenExp refers to the expiration time for the access token, it is in a unix time format represented by a // 64 bit integer. accessTokenExp int64 + // isAbac indicates whether this registry uses Attribute-Based Access Control (ABAC). + // ABAC registries require repository-level permissions instead of registry-wide wildcards. + // This is detected by checking if the refresh token contains the "aad_identity" claim. + isAbac bool } // LoginURL returns the FQDN for a registry. @@ -91,10 +96,29 @@ func newAcrCLIClientWithBasicAuth(loginURL string, username string, password str } // newAcrCLIClientWithBearerAuth creates a client that uses bearer token authentication. +// It detects if the registry is ABAC-enabled by checking for the "aad_identity" claim in the refresh token. +// For ABAC registries, it only requests catalog access initially; repository access is requested per-batch. +// For non-ABAC registries, it requests the traditional wildcard scope for all repositories. func newAcrCLIClientWithBearerAuth(loginURL string, refreshToken string) (AcrCLIClient, error) { + // Detect if this is an ABAC-enabled registry by checking for aad_identity claim + isAbac := hasAadIdentityClaim(refreshToken) + newAcrCLIClient := newAcrCLIClient(loginURL) + newAcrCLIClient.isAbac = isAbac + ctx := context.Background() - accessTokenResponse, err := newAcrCLIClient.AutorestClient.GetAcrAccessToken(ctx, loginURL, "registry:catalog:* repository:*:*", refreshToken) + var scope string + if isAbac { + // For ABAC registries, only request catalog access initially. + // Repository-level access will be requested on-demand per repository or batch. + // This is because ABAC registries cannot grant wildcard repository access. + scope = "registry:catalog:*" + } else { + // For non-ABAC registries, request full wildcard access to all repositories. + scope = "registry:catalog:* repository:*:*" + } + + accessTokenResponse, err := newAcrCLIClient.AutorestClient.GetAcrAccessToken(ctx, loginURL, scope, refreshToken) if err != nil { return newAcrCLIClient, err } @@ -154,6 +178,7 @@ func GetAcrCLIClientWithAuth(loginURL string, username string, password string, } // refreshAcrCLIClientToken obtains a new token and gets its expiration time. +// This uses the wildcard scope and should only be called for non-ABAC registries. func refreshAcrCLIClientToken(ctx context.Context, c *AcrCLIClient) error { accessTokenResponse, err := c.AutorestClient.GetAcrAccessToken(ctx, c.loginURL, "repository:*:*", c.token.RefreshToken) if err != nil { @@ -173,6 +198,72 @@ func refreshAcrCLIClientToken(ctx context.Context, c *AcrCLIClient) error { return nil } +// hasAadIdentityClaim checks if a JWT token contains the "aad_identity" claim. +// The presence of this claim indicates that the registry is ABAC-enabled. +// ABAC (Attribute-Based Access Control) registries grant permissions at the repository level, +// not at the registry level, so wildcard scopes like "repository:*:*" will not work. +func hasAadIdentityClaim(tokenString string) bool { + parser := jwt.Parser{SkipClaimsValidation: true} + mapC := jwt.MapClaims{} + // We only need to check for the claim, not verify the signature + _, _, err := parser.ParseUnverified(tokenString, mapC) + if err != nil { + return false + } + _, ok := mapC["aad_identity"] + return ok +} + +// RefreshTokenForAbac obtains a new access token scoped to specific repositories. +// This is used for ABAC-enabled registries where wildcard repository access is not allowed. +// The token will include permissions for all specified repositories. +// +// Parameters: +// - repositories: list of repository names to request access for +// +// The scope format is: "registry:catalog:* repository::pull repository::delete ..." +// This allows batching multiple repositories into a single token request for efficiency. +func (c *AcrCLIClient) RefreshTokenForAbac(ctx context.Context, repositories []string) error { + if c.token == nil { + return errors.New("no refresh token available for ABAC token refresh") + } + + // Build the scope string for all requested repositories. + // Each repository needs pull, delete, and metadata permissions for purge operations. + // Format: "repository:repo1:pull,delete,metadata_read,metadata_write repository:repo2:pull,delete,metadata_read,metadata_write ..." + var scopeParts []string + for _, repo := range repositories { + scopeParts = append(scopeParts, fmt.Sprintf("repository:%s:pull,delete,metadata_read,metadata_write", repo)) + } + scope := strings.Join(scopeParts, " ") + + accessTokenResponse, err := c.AutorestClient.GetAcrAccessToken(ctx, c.loginURL, scope, c.token.RefreshToken) + if err != nil { + return errors.Wrap(err, "failed to refresh token for ABAC repositories") + } + + token := &adal.Token{ + AccessToken: *accessTokenResponse.AccessToken, + RefreshToken: c.token.RefreshToken, + } + c.token = token + c.AutorestClient.Authorizer = autorest.NewBearerAuthorizer(token) + + exp, err := getExpiration(token.AccessToken) + if err != nil { + return err + } + c.accessTokenExp = exp + + return nil +} + +// IsAbac returns true if this client is connected to an ABAC-enabled registry. +// ABAC registries require repository-level token scopes instead of wildcard scopes. +func (c *AcrCLIClient) IsAbac() bool { + return c.isAbac +} + // getExpiration is used to obtain the expiration out of a jwt token. func getExpiration(token string) (int64, error) { parser := jwt.Parser{SkipClaimsValidation: true} @@ -198,6 +289,13 @@ func (c *AcrCLIClient) isExpired() bool { return (time.Now().Add(5 * time.Minute)).Unix() > c.accessTokenExp } +// IsTokenExpired returns true when the token is expired or close to expiring. +// This is the public version of isExpired for use by callers that need to check +// token expiration before making batched ABAC token refresh requests. +func (c *AcrCLIClient) IsTokenExpired() bool { + return c.isExpired() +} + // GetAcrTags list the tags of a repository with their attributes. func (c *AcrCLIClient) GetAcrTags(ctx context.Context, repoName string, orderBy string, last string) (*acrapi.RepositoryTagsType, error) { if c.isExpired() { @@ -348,4 +446,11 @@ type AcrCLIClientInterface interface { GetAcrManifestAttributes(ctx context.Context, repoName string, reference string) (*acrapi.ManifestAttributes, error) UpdateAcrTagAttributes(ctx context.Context, repoName string, reference string, value *acrapi.ChangeableAttributes) (*autorest.Response, error) UpdateAcrManifestAttributes(ctx context.Context, repoName string, reference string, value *acrapi.ChangeableAttributes) (*autorest.Response, error) + + // IsAbac returns true if the registry uses Attribute-Based Access Control. + IsAbac() bool + // IsTokenExpired returns true if the access token is expired or close to expiring. + IsTokenExpired() bool + // RefreshTokenForAbac refreshes the access token with scopes for specific repositories. + RefreshTokenForAbac(ctx context.Context, repositories []string) error } diff --git a/internal/api/acrsdk_test.go b/internal/api/acrsdk_test.go index 60224873..31312727 100644 --- a/internal/api/acrsdk_test.go +++ b/internal/api/acrsdk_test.go @@ -234,3 +234,84 @@ func TestGetAcrCLIClientWithAuth(t *testing.T) { }) } } + +// TestHasAadIdentityClaim tests the ABAC detection function +func TestHasAadIdentityClaim(t *testing.T) { + tests := []struct { + name string + token string + expected bool + }{ + { + name: "token with aad_identity claim - ABAC enabled", + // JWT with {"aad_identity": "user@example.com"} in payload + token: strings.Join([]string{ + base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256"}`)), + base64.RawURLEncoding.EncodeToString([]byte(`{"exp":1563910981,"aad_identity":"user@example.com"}`)), + "", + }, "."), + expected: true, + }, + { + name: "token without aad_identity claim - non-ABAC", + // JWT without aad_identity + token: strings.Join([]string{ + base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256"}`)), + base64.RawURLEncoding.EncodeToString([]byte(`{"exp":1563910981}`)), + "", + }, "."), + expected: false, + }, + { + name: "invalid token", + token: "not-a-valid-jwt", + expected: false, + }, + { + name: "empty token", + token: "", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := hasAadIdentityClaim(tt.token) + if result != tt.expected { + t.Errorf("hasAadIdentityClaim() = %v, expected %v", result, tt.expected) + } + }) + } +} + +// TestAcrCLIClientIsAbac tests the IsAbac method +func TestAcrCLIClientIsAbac(t *testing.T) { + tests := []struct { + name string + isAbac bool + expected bool + }{ + { + name: "ABAC enabled client", + isAbac: true, + expected: true, + }, + { + name: "non-ABAC client", + isAbac: false, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := AcrCLIClient{ + isAbac: tt.isAbac, + } + result := client.IsAbac() + if result != tt.expected { + t.Errorf("IsAbac() = %v, expected %v", result, tt.expected) + } + }) + } +}