diff --git a/pkg/asset/installconfig/azure/session.go b/pkg/asset/installconfig/azure/session.go index 33cbc02ff0b..67d50808d6d 100644 --- a/pkg/asset/installconfig/azure/session.go +++ b/pkg/asset/installconfig/azure/session.go @@ -4,11 +4,14 @@ import ( "encoding/json" "errors" "fmt" + "io" "io/fs" + "net/http" "os" "path/filepath" "strings" "sync" + "time" "github.com/AlecAivazis/survey/v2" "github.com/Azure/azure-sdk-for-go/sdk/azcore" @@ -107,6 +110,13 @@ func GetSessionWithCredentials(cloudName azure.CloudEnvironment, armEndpoint str default: cred, err = newTokenCredentialFromMSI(credentials, *cloudConfig) authType = ManagedIdentityAuth + if err == nil && credentials.ClientID == "" { + if clientID, imdsErr := getSystemAssignedClientID(); imdsErr != nil { + logrus.Warnf("Failed to retrieve system-assigned identity client ID from IMDS: %v", imdsErr) + } else { + credentials.ClientID = clientID + } + } } if err != nil { return nil, err @@ -393,3 +403,47 @@ func endpointToScope(endpoint string) string { } return endpoint } + +const imdsTokenEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token" + +// getSystemAssignedClientID queries IMDS to retrieve the client ID +// of the system-assigned managed identity on the current Azure VM. +func getSystemAssignedClientID() (string, error) { + req, err := http.NewRequest(http.MethodGet, imdsTokenEndpoint, nil) + if err != nil { + return "", fmt.Errorf("failed to create IMDS request: %w", err) + } + req.Header.Set("Metadata", "true") + q := req.URL.Query() + q.Set("api-version", "2018-02-01") + q.Set("resource", "https://management.azure.com/") + req.URL.RawQuery = q.Encode() + + client := &http.Client{Timeout: 5 * time.Second} + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("IMDS request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("IMDS returned status %d", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("failed to read IMDS response: %w", err) + } + + var tokenResp struct { + ClientID string `json:"client_id"` + } + if err := json.Unmarshal(body, &tokenResp); err != nil { + return "", fmt.Errorf("failed to parse IMDS response: %w", err) + } + if tokenResp.ClientID == "" { + return "", fmt.Errorf("IMDS response did not contain a client_id") + } + + return tokenResp.ClientID, nil +}