From e80e4868832c755475d0db644ee077466438fdf9 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Mon, 9 Mar 2026 14:41:50 -0700 Subject: [PATCH 01/13] Revamp AuthMetadataService to match flyteadmin implementation Restructure auth config to match flyteadmin's shape with support for self-hosted and external authorization server modes. The self mode builds OAuth2 metadata from relative URLs based on AuthorizedURIs, while the external mode fetches metadata from .well-known/oauth-authorization-server with retry logic, HTTP proxy support, and token endpoint proxy rewriting. Co-Authored-By: Claude Opus 4.6 Signed-off-by: Kevin Su --- runs/config/config.go | 101 +++++++++ runs/service/auth_metadata.go | 204 +++++++++++++++++ runs/service/auth_metadata_test.go | 344 +++++++++++++++++++++++++++++ runs/setup.go | 6 + 4 files changed, 655 insertions(+) create mode 100644 runs/service/auth_metadata.go create mode 100644 runs/service/auth_metadata_test.go diff --git a/runs/config/config.go b/runs/config/config.go index ea667934b71..594cbb4c463 100644 --- a/runs/config/config.go +++ b/runs/config/config.go @@ -38,6 +38,107 @@ type Config struct { // StoragePrefix is the base URI for storing run data (inputs, outputs) // e.g. "s3://my-bucket" or "gs://my-bucket" or "file:///tmp/flyte/data" StoragePrefix string `json:"storagePrefix" pflag:",Base URI prefix for storing run inputs and outputs"` + + // Auth configuration for the AuthMetadataService + Auth AuthConfig `json:"auth"` +} + +// AuthorizationServerType defines the type of authorization server. +type AuthorizationServerType int + +const ( + // AuthorizationServerTypeSelf indicates the service acts as its own authorization server. + AuthorizationServerTypeSelf AuthorizationServerType = iota + // AuthorizationServerTypeExternal indicates an external authorization server is used. + AuthorizationServerTypeExternal +) + +// AuthConfig holds authentication configuration matching flyteadmin's auth config shape. +type AuthConfig struct { + // AuthorizedURIs is the set of URIs clients can access the service on. + // The first entry is used as the public URL for building OAuth2 metadata URLs. + AuthorizedURIs []config.URL `json:"authorizedUris"` + + // GrpcAuthorizationHeader is the authorization metadata key returned to clients. + GrpcAuthorizationHeader string `json:"grpcAuthorizationHeader"` + + // AppAuth defines app-level OAuth2 settings. + AppAuth OAuth2Options `json:"appAuth"` + + // HTTPProxyURL allows accessing external OAuth2 servers through an HTTP proxy. + HTTPProxyURL config.URL `json:"httpProxyURL"` + + // TokenEndpointProxyConfig proxies token endpoint calls through admin. + TokenEndpointProxyConfig TokenEndpointProxyConfig `json:"tokenEndpointProxyConfig"` +} + +// OAuth2Options holds OAuth2 authorization server options. +type OAuth2Options struct { + // AuthServerType determines whether to use a self-hosted or external auth server. + AuthServerType AuthorizationServerType `json:"authServerType"` + + // SelfAuthServer configures the self-hosted authorization server. + SelfAuthServer AuthorizationServer `json:"selfAuthServer"` + + // ExternalAuthServer configures the external authorization server. + ExternalAuthServer ExternalAuthorizationServer `json:"externalAuthServer"` + + // ThirdParty configures third-party (public client) settings. + ThirdParty ThirdPartyConfigOptions `json:"thirdParty"` +} + +// AuthorizationServer configures a self-hosted authorization server. +type AuthorizationServer struct { + // Issuer is the issuer URL. If empty, the first AuthorizedURI is used. + Issuer string `json:"issuer"` +} + +// ExternalAuthorizationServer configures an external authorization server. +type ExternalAuthorizationServer struct { + // BaseURL is the base URL of the external authorization server. + BaseURL config.URL `json:"baseURL"` + + // MetadataEndpointURL overrides the default .well-known/oauth-authorization-server endpoint. + MetadataEndpointURL config.URL `json:"metadataEndpointURL"` + + // RetryAttempts is the number of retry attempts for fetching metadata. + RetryAttempts int `json:"retryAttempts"` + + // RetryDelay is the delay between retry attempts. + RetryDelay config.Duration `json:"retryDelay"` +} + +// ThirdPartyConfigOptions holds third-party OAuth2 client settings. +type ThirdPartyConfigOptions struct { + // FlyteClientConfig holds public client configuration. + FlyteClientConfig FlyteClientConfig `json:"flyteClient"` +} + +// FlyteClientConfig holds the public client configuration. +type FlyteClientConfig struct { + // ClientID is the public client ID. + ClientID string `json:"clientId"` + + // RedirectURI is the redirect URI for the client. + RedirectURI string `json:"redirectUri"` + + // Scopes are the OAuth2 scopes to request. + Scopes []string `json:"scopes"` + + // Audience is the intended audience for OAuth2 tokens. + Audience string `json:"audience"` +} + +// TokenEndpointProxyConfig configures proxying of token endpoint calls. +type TokenEndpointProxyConfig struct { + // Enabled enables token endpoint proxying. + Enabled bool `json:"enabled"` + + // PublicURL is the public URL to use for rewriting the token endpoint. + PublicURL config.URL `json:"publicURL"` + + // PathPrefix is appended to the public URL when rewriting. + PathPrefix string `json:"pathPrefix"` } // ServerConfig holds HTTP server configuration diff --git a/runs/service/auth_metadata.go b/runs/service/auth_metadata.go new file mode 100644 index 00000000000..62bed4213ef --- /dev/null +++ b/runs/service/auth_metadata.go @@ -0,0 +1,204 @@ +package service + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "connectrpc.com/connect" + + "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth" + "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth/authconnect" + "github.com/flyteorg/flyte/v2/runs/config" + + stdlibConfig "github.com/flyteorg/flyte/v2/flytestdlib/config" + "github.com/flyteorg/flyte/v2/flytestdlib/logger" +) + +const ( + oauth2MetadataEndpoint = ".well-known/oauth-authorization-server" + scopeAll = "all" +) + +type authMetadataService struct { + authconnect.UnimplementedAuthMetadataServiceHandler + cfg config.AuthConfig +} + +func NewAuthMetadataService(cfg config.AuthConfig) authconnect.AuthMetadataServiceHandler { + return &authMetadataService{cfg: cfg} +} + +func (s *authMetadataService) GetOAuth2Metadata( + ctx context.Context, + _ *connect.Request[auth.GetOAuth2MetadataRequest], +) (*connect.Response[auth.GetOAuth2MetadataResponse], error) { + switch s.cfg.AppAuth.AuthServerType { + case config.AuthorizationServerTypeExternal: + return s.getOAuth2MetadataExternal(ctx) + default: + return s.getOAuth2MetadataSelf() + } +} + +func (s *authMetadataService) getOAuth2MetadataSelf() (*connect.Response[auth.GetOAuth2MetadataResponse], error) { + publicURL := getPublicURL(s.cfg.AuthorizedURIs) + issuer := getIssuer(s.cfg, publicURL) + + base := strings.TrimRight(publicURL.String(), "/") + + resp := &auth.GetOAuth2MetadataResponse{ + Issuer: issuer, + AuthorizationEndpoint: base + "/oauth2/authorize", + TokenEndpoint: base + "/oauth2/token", + JwksUri: base + "/oauth2/jwks", + CodeChallengeMethodsSupported: []string{"S256"}, + ResponseTypesSupported: []string{"code", "token", "code token"}, + GrantTypesSupported: []string{"client_credentials", "refresh_token", "authorization_code"}, + ScopesSupported: []string{scopeAll}, + TokenEndpointAuthMethodsSupported: []string{"client_secret_basic"}, + } + + return connect.NewResponse(resp), nil +} + +func (s *authMetadataService) getOAuth2MetadataExternal(ctx context.Context) (*connect.Response[auth.GetOAuth2MetadataResponse], error) { + baseURL := s.cfg.AppAuth.ExternalAuthServer.BaseURL + if baseURL.String() == "" { + return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("external auth server base URL is not configured")) + } + + metadataURL := s.cfg.AppAuth.ExternalAuthServer.MetadataEndpointURL + if metadataURL.String() == "" || metadataURL.String() == baseURL.String() { + u := baseURL.URL + u.Path = strings.TrimRight(u.Path, "/") + "/" + oauth2MetadataEndpoint + metadataURL = stdlibConfig.URL{URL: u} + } + + httpClient := &http.Client{} + if s.cfg.HTTPProxyURL.String() != "" { + httpClient.Transport = &http.Transport{ + Proxy: http.ProxyURL(&s.cfg.HTTPProxyURL.URL), + } + } + + retryAttempts := s.cfg.AppAuth.ExternalAuthServer.RetryAttempts + if retryAttempts <= 0 { + retryAttempts = 5 + } + retryDelay := s.cfg.AppAuth.ExternalAuthServer.RetryDelay.Duration + if retryDelay <= 0 { + retryDelay = time.Second + } + + body, err := sendAndRetryHTTPRequest(ctx, httpClient, metadataURL.String(), retryAttempts, retryDelay) + if err != nil { + return nil, connect.NewError(connect.CodeUnavailable, fmt.Errorf("failed to fetch OAuth2 metadata: %w", err)) + } + + resp := &auth.GetOAuth2MetadataResponse{} + if err := json.Unmarshal(body, resp); err != nil { + return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("failed to unmarshal OAuth2 metadata: %w", err)) + } + + if s.cfg.TokenEndpointProxyConfig.Enabled && resp.TokenEndpoint != "" { + proxyURL := s.cfg.TokenEndpointProxyConfig.PublicURL + if proxyURL.String() == "" { + proxyURL = stdlibConfig.URL{URL: *getPublicURL(s.cfg.AuthorizedURIs)} + } + rewritten := strings.TrimRight(proxyURL.String(), "/") + if s.cfg.TokenEndpointProxyConfig.PathPrefix != "" { + rewritten += "/" + strings.Trim(s.cfg.TokenEndpointProxyConfig.PathPrefix, "/") + } + + // Preserve the original token endpoint path + originalURL, parseErr := url.Parse(resp.TokenEndpoint) + if parseErr == nil { + rewritten += originalURL.Path + } + resp.TokenEndpoint = rewritten + } + + return connect.NewResponse(resp), nil +} + +func (s *authMetadataService) GetPublicClientConfig( + _ context.Context, + _ *connect.Request[auth.GetPublicClientConfigRequest], +) (*connect.Response[auth.GetPublicClientConfigResponse], error) { + fc := s.cfg.AppAuth.ThirdParty.FlyteClientConfig + return connect.NewResponse(&auth.GetPublicClientConfigResponse{ + ClientId: fc.ClientID, + RedirectUri: fc.RedirectURI, + Scopes: fc.Scopes, + AuthorizationMetadataKey: s.cfg.GrpcAuthorizationHeader, + Audience: fc.Audience, + }), nil +} + +// getPublicURL returns the first AuthorizedURI as the public URL, or a default localhost URL. +func getPublicURL(authorizedURIs []stdlibConfig.URL) *url.URL { + if len(authorizedURIs) > 0 { + u := authorizedURIs[0].URL + return &u + } + u, _ := url.Parse("http://localhost:8090") + return u +} + +// getIssuer returns the issuer from SelfAuthServer config, or falls back to public URL. +func getIssuer(cfg config.AuthConfig, publicURL *url.URL) string { + if cfg.AppAuth.SelfAuthServer.Issuer != "" { + return cfg.AppAuth.SelfAuthServer.Issuer + } + return strings.TrimRight(publicURL.String(), "/") +} + +// sendAndRetryHTTPRequest fetches the given URL with retry logic. +func sendAndRetryHTTPRequest(ctx context.Context, client *http.Client, targetURL string, retryAttempts int, retryDelay time.Duration) ([]byte, error) { + var lastErr error + for i := 0; i < retryAttempts; i++ { + if i > 0 { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(retryDelay): + } + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, targetURL, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + resp, err := client.Do(req) + if err != nil { + lastErr = err + logger.Warnf(ctx, "Failed to fetch %s (attempt %d/%d): %v", targetURL, i+1, retryAttempts, err) + continue + } + + body, readErr := io.ReadAll(resp.Body) + resp.Body.Close() + if readErr != nil { + lastErr = readErr + logger.Warnf(ctx, "Failed to read response body from %s (attempt %d/%d): %v", targetURL, i+1, retryAttempts, readErr) + continue + } + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + lastErr = fmt.Errorf("unexpected status code %d from %s", resp.StatusCode, targetURL) + logger.Warnf(ctx, "Unexpected status code from %s (attempt %d/%d): %d", targetURL, i+1, retryAttempts, resp.StatusCode) + continue + } + + return body, nil + } + + return nil, fmt.Errorf("all %d attempts failed for %s: %w", retryAttempts, targetURL, lastErr) +} diff --git a/runs/service/auth_metadata_test.go b/runs/service/auth_metadata_test.go new file mode 100644 index 00000000000..e9f9408c4d7 --- /dev/null +++ b/runs/service/auth_metadata_test.go @@ -0,0 +1,344 @@ +package service + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + "connectrpc.com/connect" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth" + "github.com/flyteorg/flyte/v2/runs/config" + + stdlibConfig "github.com/flyteorg/flyte/v2/flytestdlib/config" +) + +func mustParseURL(rawURL string) stdlibConfig.URL { + u, err := url.Parse(rawURL) + if err != nil { + panic(err) + } + return stdlibConfig.URL{URL: *u} +} + +func TestGetPublicClientConfig(t *testing.T) { + cfg := config.AuthConfig{ + GrpcAuthorizationHeader: "flyte-authorization", + AppAuth: config.OAuth2Options{ + ThirdParty: config.ThirdPartyConfigOptions{ + FlyteClientConfig: config.FlyteClientConfig{ + ClientID: "flyte-client", + RedirectURI: "http://localhost:12345/callback", + Scopes: []string{"openid", "offline"}, + Audience: "https://flyte.example.com", + }, + }, + }, + } + + svc := NewAuthMetadataService(cfg) + resp, err := svc.GetPublicClientConfig(context.Background(), connect.NewRequest(&auth.GetPublicClientConfigRequest{})) + require.NoError(t, err) + + msg := resp.Msg + assert.Equal(t, "flyte-client", msg.ClientId) + assert.Equal(t, "http://localhost:12345/callback", msg.RedirectUri) + assert.Equal(t, []string{"openid", "offline"}, msg.Scopes) + assert.Equal(t, "flyte-authorization", msg.AuthorizationMetadataKey) + assert.Equal(t, "https://flyte.example.com", msg.Audience) +} + +func TestGetOAuth2Metadata_SelfAuthServer(t *testing.T) { + cfg := config.AuthConfig{ + AuthorizedURIs: []stdlibConfig.URL{ + mustParseURL("https://flyte.example.com"), + }, + AppAuth: config.OAuth2Options{ + AuthServerType: config.AuthorizationServerTypeSelf, + }, + } + + svc := NewAuthMetadataService(cfg) + resp, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) + require.NoError(t, err) + + msg := resp.Msg + assert.Equal(t, "https://flyte.example.com", msg.Issuer) + assert.Equal(t, "https://flyte.example.com/oauth2/authorize", msg.AuthorizationEndpoint) + assert.Equal(t, "https://flyte.example.com/oauth2/token", msg.TokenEndpoint) + assert.Equal(t, "https://flyte.example.com/oauth2/jwks", msg.JwksUri) + assert.Equal(t, []string{"S256"}, msg.CodeChallengeMethodsSupported) + assert.Equal(t, []string{"code", "token", "code token"}, msg.ResponseTypesSupported) + assert.Equal(t, []string{"client_credentials", "refresh_token", "authorization_code"}, msg.GrantTypesSupported) + assert.Equal(t, []string{"all"}, msg.ScopesSupported) + assert.Equal(t, []string{"client_secret_basic"}, msg.TokenEndpointAuthMethodsSupported) +} + +func TestGetOAuth2Metadata_SelfAuthServerWithCustomIssuer(t *testing.T) { + cfg := config.AuthConfig{ + AuthorizedURIs: []stdlibConfig.URL{ + mustParseURL("https://flyte.example.com"), + }, + AppAuth: config.OAuth2Options{ + AuthServerType: config.AuthorizationServerTypeSelf, + SelfAuthServer: config.AuthorizationServer{ + Issuer: "https://custom-issuer.example.com", + }, + }, + } + + svc := NewAuthMetadataService(cfg) + resp, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) + require.NoError(t, err) + + msg := resp.Msg + assert.Equal(t, "https://custom-issuer.example.com", msg.Issuer) + assert.Equal(t, "https://flyte.example.com/oauth2/authorize", msg.AuthorizationEndpoint) +} + +func TestGetOAuth2Metadata_SelfAuthServerNoAuthorizedURIs(t *testing.T) { + cfg := config.AuthConfig{ + AppAuth: config.OAuth2Options{ + AuthServerType: config.AuthorizationServerTypeSelf, + }, + } + + svc := NewAuthMetadataService(cfg) + resp, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) + require.NoError(t, err) + + msg := resp.Msg + assert.Equal(t, "http://localhost:8090", msg.Issuer) + assert.Equal(t, "http://localhost:8090/oauth2/token", msg.TokenEndpoint) +} + +func TestGetOAuth2Metadata_ExternalAuthServer(t *testing.T) { + expectedMetadata := &auth.GetOAuth2MetadataResponse{ + Issuer: "https://external-idp.example.com", + AuthorizationEndpoint: "https://external-idp.example.com/authorize", + TokenEndpoint: "https://external-idp.example.com/token", + JwksUri: "https://external-idp.example.com/.well-known/jwks.json", + } + + metadataJSON, err := json.Marshal(expectedMetadata) + require.NoError(t, err) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/"+oauth2MetadataEndpoint, r.URL.Path) + w.Header().Set("Content-Type", "application/json") + w.Write(metadataJSON) + })) + defer ts.Close() + + cfg := config.AuthConfig{ + AppAuth: config.OAuth2Options{ + AuthServerType: config.AuthorizationServerTypeExternal, + ExternalAuthServer: config.ExternalAuthorizationServer{ + BaseURL: mustParseURL(ts.URL), + RetryAttempts: 1, + RetryDelay: stdlibConfig.Duration{Duration: 100 * time.Millisecond}, + }, + }, + } + + svc := NewAuthMetadataService(cfg) + resp, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) + require.NoError(t, err) + + msg := resp.Msg + assert.Equal(t, "https://external-idp.example.com", msg.Issuer) + assert.Equal(t, "https://external-idp.example.com/authorize", msg.AuthorizationEndpoint) + assert.Equal(t, "https://external-idp.example.com/token", msg.TokenEndpoint) + assert.Equal(t, "https://external-idp.example.com/.well-known/jwks.json", msg.JwksUri) +} + +func TestGetOAuth2Metadata_ExternalWithCustomMetadataURL(t *testing.T) { + expectedMetadata := &auth.GetOAuth2MetadataResponse{ + Issuer: "https://external-idp.example.com", + TokenEndpoint: "https://external-idp.example.com/token", + } + + metadataJSON, err := json.Marshal(expectedMetadata) + require.NoError(t, err) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/custom/metadata", r.URL.Path) + w.Header().Set("Content-Type", "application/json") + w.Write(metadataJSON) + })) + defer ts.Close() + + cfg := config.AuthConfig{ + AppAuth: config.OAuth2Options{ + AuthServerType: config.AuthorizationServerTypeExternal, + ExternalAuthServer: config.ExternalAuthorizationServer{ + BaseURL: mustParseURL(ts.URL), + MetadataEndpointURL: mustParseURL(ts.URL + "/custom/metadata"), + RetryAttempts: 1, + RetryDelay: stdlibConfig.Duration{Duration: 100 * time.Millisecond}, + }, + }, + } + + svc := NewAuthMetadataService(cfg) + resp, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) + require.NoError(t, err) + + assert.Equal(t, "https://external-idp.example.com", resp.Msg.Issuer) + assert.Equal(t, "https://external-idp.example.com/token", resp.Msg.TokenEndpoint) +} + +func TestGetOAuth2Metadata_ExternalWithTokenProxy(t *testing.T) { + expectedMetadata := &auth.GetOAuth2MetadataResponse{ + Issuer: "https://external-idp.example.com", + AuthorizationEndpoint: "https://external-idp.example.com/authorize", + TokenEndpoint: "https://external-idp.example.com/oauth/token", + } + + metadataJSON, err := json.Marshal(expectedMetadata) + require.NoError(t, err) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write(metadataJSON) + })) + defer ts.Close() + + cfg := config.AuthConfig{ + AuthorizedURIs: []stdlibConfig.URL{ + mustParseURL("https://flyte.example.com"), + }, + AppAuth: config.OAuth2Options{ + AuthServerType: config.AuthorizationServerTypeExternal, + ExternalAuthServer: config.ExternalAuthorizationServer{ + BaseURL: mustParseURL(ts.URL), + RetryAttempts: 1, + RetryDelay: stdlibConfig.Duration{Duration: 100 * time.Millisecond}, + }, + }, + TokenEndpointProxyConfig: config.TokenEndpointProxyConfig{ + Enabled: true, + }, + } + + svc := NewAuthMetadataService(cfg) + resp, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) + require.NoError(t, err) + + msg := resp.Msg + assert.Equal(t, "https://external-idp.example.com", msg.Issuer) + assert.Equal(t, "https://external-idp.example.com/authorize", msg.AuthorizationEndpoint) + // Token endpoint should be rewritten to the public URL + assert.Equal(t, "https://flyte.example.com/oauth/token", msg.TokenEndpoint) +} + +func TestGetOAuth2Metadata_ExternalWithTokenProxyAndPathPrefix(t *testing.T) { + expectedMetadata := &auth.GetOAuth2MetadataResponse{ + TokenEndpoint: "https://external-idp.example.com/oauth/token", + } + + metadataJSON, err := json.Marshal(expectedMetadata) + require.NoError(t, err) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write(metadataJSON) + })) + defer ts.Close() + + cfg := config.AuthConfig{ + AppAuth: config.OAuth2Options{ + AuthServerType: config.AuthorizationServerTypeExternal, + ExternalAuthServer: config.ExternalAuthorizationServer{ + BaseURL: mustParseURL(ts.URL), + RetryAttempts: 1, + RetryDelay: stdlibConfig.Duration{Duration: 100 * time.Millisecond}, + }, + }, + TokenEndpointProxyConfig: config.TokenEndpointProxyConfig{ + Enabled: true, + PublicURL: mustParseURL("https://proxy.example.com"), + PathPrefix: "api/v1", + }, + } + + svc := NewAuthMetadataService(cfg) + resp, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) + require.NoError(t, err) + + assert.Equal(t, "https://proxy.example.com/api/v1/oauth/token", resp.Msg.TokenEndpoint) +} + +func TestGetOAuth2Metadata_ExternalNoBaseURL(t *testing.T) { + cfg := config.AuthConfig{ + AppAuth: config.OAuth2Options{ + AuthServerType: config.AuthorizationServerTypeExternal, + }, + } + + svc := NewAuthMetadataService(cfg) + _, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) + require.Error(t, err) + assert.Contains(t, err.Error(), "external auth server base URL is not configured") +} + +func TestSendAndRetryHTTPRequest_ImmediateSuccess(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status":"ok"}`)) + })) + defer ts.Close() + + body, err := sendAndRetryHTTPRequest(context.Background(), http.DefaultClient, ts.URL, 3, 10*time.Millisecond) + require.NoError(t, err) + assert.Equal(t, `{"status":"ok"}`, string(body)) +} + +func TestSendAndRetryHTTPRequest_RetryIntoSuccess(t *testing.T) { + attempt := 0 + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempt++ + if attempt < 3 { + w.WriteHeader(http.StatusServiceUnavailable) + return + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status":"ok"}`)) + })) + defer ts.Close() + + body, err := sendAndRetryHTTPRequest(context.Background(), http.DefaultClient, ts.URL, 5, 10*time.Millisecond) + require.NoError(t, err) + assert.Equal(t, `{"status":"ok"}`, string(body)) + assert.Equal(t, 3, attempt) +} + +func TestSendAndRetryHTTPRequest_AllRetrysFail(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer ts.Close() + + _, err := sendAndRetryHTTPRequest(context.Background(), http.DefaultClient, ts.URL, 3, 10*time.Millisecond) + require.Error(t, err) + assert.Contains(t, err.Error(), "all 3 attempts failed") +} + +func TestSendAndRetryHTTPRequest_ContextCancelled(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer ts.Close() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := sendAndRetryHTTPRequest(ctx, http.DefaultClient, ts.URL, 5, 10*time.Millisecond) + require.Error(t, err) +} diff --git a/runs/setup.go b/runs/setup.go index 4e444186992..17bdaaf2822 100644 --- a/runs/setup.go +++ b/runs/setup.go @@ -7,6 +7,7 @@ import ( "github.com/flyteorg/flyte/v2/app" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/actions/actionsconnect" + "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth/authconnect" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/task/taskconnect" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow/workflowconnect" "github.com/flyteorg/flyte/v2/runs/config" @@ -58,6 +59,11 @@ func Setup(ctx context.Context, sc *app.SetupContext) error { sc.Mux.Handle(translatorPath, translatorHandler) logger.Infof(ctx, "Mounted TranslatorService at %s", translatorPath) + authSvc := service.NewAuthMetadataService(cfg.Auth) + authPath, authHandler := authconnect.NewAuthMetadataServiceHandler(authSvc) + sc.Mux.Handle(authPath, authHandler) + logger.Infof(ctx, "Mounted AuthMetadataService at %s", authPath) + sc.AddReadyCheck(func(r *http.Request) error { sqlDB, err := sc.DB.DB() if err != nil { From 7749d5bd40f2c12cc55b3a02d7ddcde03ea578bd Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 11 Mar 2026 00:29:37 -0700 Subject: [PATCH 02/13] update chart Signed-off-by: Kevin Su --- .../cluster-resource-templates/namespace.yaml | 4 -- charts/flyte-binary/templates/_helpers.tpl | 14 +++++++ .../templates/admin-auth-secret.yaml | 8 ++-- .../flyte-binary/templates/config-secret.yaml | 12 +++--- charts/flyte-binary/templates/deployment.yaml | 42 ++++++++++++++----- 5 files changed, 55 insertions(+), 25 deletions(-) delete mode 100644 charts/flyte-binary/defaults/cluster-resource-templates/namespace.yaml diff --git a/charts/flyte-binary/defaults/cluster-resource-templates/namespace.yaml b/charts/flyte-binary/defaults/cluster-resource-templates/namespace.yaml deleted file mode 100644 index 301cb82f42f..00000000000 --- a/charts/flyte-binary/defaults/cluster-resource-templates/namespace.yaml +++ /dev/null @@ -1,4 +0,0 @@ -apiVersion: v1 -kind: Namespace -metadata: - name: '{{ namespace }}' diff --git a/charts/flyte-binary/templates/_helpers.tpl b/charts/flyte-binary/templates/_helpers.tpl index 3e49eefdefc..d4156a6cb59 100644 --- a/charts/flyte-binary/templates/_helpers.tpl +++ b/charts/flyte-binary/templates/_helpers.tpl @@ -113,6 +113,20 @@ templates: {{- toYaml .custom | nindent 2 -}} {{- end -}} {{- end -}} +{{/* +Get the Secret name for Run service authentication secrets. +*/}} +{{ define "flyte-binary.configuration.auth.runServiceAuthSecretName" -}} +{{ printf "%s-admin-auth" (include "flyte-binary.fullname" .) }} +{{ end -}} + +{{/* +Get the Secret name for Flyte authentication client secrets. +*/}} +{{ define "flyte-binary.configuration.auth.clientSecretName" -}} +{{ printf "%s-client-secrets" (include "flyte-binary.fullname" .) }} +{{ end -}} + {{/* Get the Flyte cluster resource templates ConfigMap name. */}} diff --git a/charts/flyte-binary/templates/admin-auth-secret.yaml b/charts/flyte-binary/templates/admin-auth-secret.yaml index 3e164c1f109..173d0bf2880 100644 --- a/charts/flyte-binary/templates/admin-auth-secret.yaml +++ b/charts/flyte-binary/templates/admin-auth-secret.yaml @@ -2,15 +2,15 @@ apiVersion: v1 kind: Secret metadata: - name: {{ include "flyte-binary.configuration.auth.adminAuthSecretName" . }} + name: {{ include "flyte-binary.configuration.auth.runServiceAuthSecretName" . }} namespace: {{ .Release.Namespace | quote }} - labels: {{- include "flyte-binary.labels" . | nindent 4 }} + labels: {{ include "flyte-binary.labels" . | nindent 4 }} {{- if .Values.commonLabels }} - {{- tpl ( .Values.commonLabels | toYaml ) . | nindent 4 }} + {{ tpl ( .Values.commonLabels | toYaml ) . | nindent 4 }} {{- end }} annotations: {{- if .Values.commonAnnotations }} - {{- tpl ( .Values.commonAnnotations | toYaml ) . | nindent 4 }} + {{ tpl ( .Values.commonAnnotations | toYaml ) . | nindent 4 }} {{- end }} type: Opaque {{- end }} diff --git a/charts/flyte-binary/templates/config-secret.yaml b/charts/flyte-binary/templates/config-secret.yaml index 5992755b1fb..4f05d134fbc 100644 --- a/charts/flyte-binary/templates/config-secret.yaml +++ b/charts/flyte-binary/templates/config-secret.yaml @@ -4,19 +4,19 @@ kind: Secret metadata: name: {{ include "flyte-binary.configuration.configSecretName" . }} namespace: {{ .Release.Namespace | quote }} - labels: {{- include "flyte-binary.labels" . | nindent 4 }} + labels: {{ include "flyte-binary.labels" . | nindent 4 }} {{- if .Values.commonLabels }} - {{- tpl ( .Values.commonLabels | toYaml ) . | nindent 4 }} + {{ tpl ( .Values.commonLabels | toYaml ) . | nindent 4 }} {{- end }} {{- if .Values.configuration.labels }} - {{- tpl ( .Values.configuration.labels | toYaml ) . | nindent 4 }} + {{ tpl ( .Values.configuration.labels | toYaml ) . | nindent 4 }} {{- end }} annotations: {{- if .Values.commonAnnotations }} - {{- tpl ( .Values.commonAnnotations | toYaml ) . | nindent 4 }} + {{ tpl ( .Values.commonAnnotations | toYaml ) . | nindent 4 }} {{- end }} {{- if .Values.configuration.annotations }} - {{- tpl ( .Values.configuration.annotations | toYaml ) . | nindent 4 }} + {{ tpl ( .Values.configuration.annotations | toYaml ) . | nindent 4 }} {{- end }} type: Opaque stringData: @@ -44,7 +44,7 @@ stringData: appAuth: selfAuthServer: staticClients: - flytepropeller: + executor: client_secret: {{ .Values.configuration.auth.internal.clientSecretHash | quote }} {{- end }} {{- end }} diff --git a/charts/flyte-binary/templates/deployment.yaml b/charts/flyte-binary/templates/deployment.yaml index ec45925dd6a..704f149f010 100644 --- a/charts/flyte-binary/templates/deployment.yaml +++ b/charts/flyte-binary/templates/deployment.yaml @@ -3,42 +3,48 @@ kind: Deployment metadata: name: {{ include "flyte-binary.fullname" . }} namespace: {{ .Release.Namespace | quote }} - labels: {{- include "flyte-binary.labels" . | nindent 4 }} + labels: {{ include "flyte-binary.labels" . | nindent 4 }} {{- if .Values.commonLabels }} - {{- tpl ( .Values.commonLabels | toYaml ) . | nindent 4 }} + {{ tpl ( .Values.commonLabels | toYaml ) . | nindent 4 }} {{- end }} {{- if .Values.deployment.labels }} - {{- tpl ( .Values.deployment.labels | toYaml ) . | nindent 4 }} + {{ tpl ( .Values.deployment.labels | toYaml ) . | nindent 4 }} {{- end }} annotations: {{- if .Values.commonAnnotations }} - {{- tpl ( .Values.commonAnnotations | toYaml ) . | nindent 4 }} + {{ tpl ( .Values.commonAnnotations | toYaml ) . | nindent 4 }} {{- end }} {{- if .Values.deployment.annotations }} - {{- tpl ( .Values.deployment.annotations | toYaml ) . | nindent 4 }} + {{ tpl ( .Values.deployment.annotations | toYaml ) . | nindent 4 }} {{- end }} spec: replicas: 1 strategy: type: Recreate selector: - matchLabels: {{- include "flyte-binary.selectorLabels" . | nindent 6 }} + matchLabels: {{ include "flyte-binary.selectorLabels" . | nindent 6 }} template: metadata: - labels: {{- include "flyte-binary.selectorLabels" . | nindent 8 }} + labels: {{ include "flyte-binary.selectorLabels" . | nindent 8 }} {{- if .Values.commonLabels }} - {{- tpl ( .Values.commonLabels | toYaml ) . | nindent 8 }} + {{ tpl ( .Values.commonLabels | toYaml ) . | nindent 8 }} {{- end }} {{- if .Values.deployment.podLabels }} - {{- tpl ( .Values.deployment.podLabels | toYaml ) . | nindent 8 }} + {{ tpl ( .Values.deployment.podLabels | toYaml ) . | nindent 8 }} {{- end }} annotations: {{- if not (include "flyte-binary.configuration.externalConfiguration" .) }} checksum/configuration: {{ include (print $.Template.BasePath "/configmap.yaml") . | sha256sum }} checksum/configuration-secret: {{ include (print $.Template.BasePath "/config-secret.yaml") . | sha256sum }} {{- end }} + {{- if .Values.configuration.auth.enabled }} + checksum/runservice-auth-secret: {{ include (print $.Template.BasePath "/runservice-auth-secret.yaml") . | sha256sum }} + {{- if not .Values.configuration.auth.clientSecretsExternalSecretRef }} + checksum/auth-client-secret: {{ include (print $.Template.BasePath "/auth-client-secret.yaml") . | sha256sum }} + {{- end }} + {{- end }} {{- if .Values.commonAnnotations }} - {{- tpl ( .Values.commonAnnotations | toYaml ) . | nindent 8 }} + {{ tpl ( .Values.commonAnnotations | toYaml ) . | nindent 8 }} {{- end }} {{- if .Values.deployment.podAnnotations }} {{- tpl ( .Values.deployment.podAnnotations | toYaml ) . | nindent 8 }} @@ -182,6 +188,20 @@ spec: {{- tpl ( .Values.deployment.sidecars | toYaml ) . | nindent 8 }} {{- end }} volumes: + {{- if .Values.configuration.auth.enabled }} + - name: auth + projected: + sources: + - secret: + name: {{ include "flyte-binary.configuration.auth.runServiceAuthSecretName" . }} + {{- if .Values.configuration.auth.clientSecretsExternalSecretRef }} + - secret: + name: {{ tpl .Values.configuration.auth.clientSecretsExternalSecretRef . }} + {{- else }} + - secret: + name: {{ include "flyte-binary.configuration.auth.clientSecretName" . }} + {{- end }} + {{- end }} - name: config {{- if (include "flyte-binary.configuration.externalConfiguration" .) }} projected: @@ -211,5 +231,5 @@ spec: {{- end }} {{- end }} {{- if .Values.deployment.extraVolumes }} - {{- tpl ( .Values.deployment.extraVolumes | toYaml ) . | nindent 8 }} + {{ tpl ( .Values.deployment.extraVolumes | toYaml ) . | nindent 8 }} {{- end }} From 2c49025d9a23b0f38236902f15f7dbb501b23a7c Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 11 Mar 2026 13:14:46 -0700 Subject: [PATCH 03/13] wip Signed-off-by: Kevin Su --- charts/flyte-binary/templates/_helpers.tpl | 2 +- .../flyte-binary/templates/ingress/grpc.yaml | 10 +- .../flyte-binary/templates/service/grpc.yaml | 48 +++ charts/flyte-binary/values.yaml | 1 - go.mod | 6 +- go.sum | 8 +- runs/config/config.go | 104 +----- runs/service/auth_metadata.go | 204 ----------- runs/service/auth_metadata_test.go | 344 ------------------ runs/service/identity_service.go | 29 -- runs/service/identity_service_test.go | 21 -- runs/setup.go | 24 +- 12 files changed, 85 insertions(+), 716 deletions(-) create mode 100644 charts/flyte-binary/templates/service/grpc.yaml delete mode 100644 runs/service/auth_metadata.go delete mode 100644 runs/service/auth_metadata_test.go delete mode 100644 runs/service/identity_service.go delete mode 100644 runs/service/identity_service_test.go diff --git a/charts/flyte-binary/templates/_helpers.tpl b/charts/flyte-binary/templates/_helpers.tpl index d4156a6cb59..58f7c805c4c 100644 --- a/charts/flyte-binary/templates/_helpers.tpl +++ b/charts/flyte-binary/templates/_helpers.tpl @@ -152,7 +152,7 @@ Get the Flyte service HTTP port. Get the Flyte gRPC service name */}} {{- define "flyte-binary.service.grpc.name" -}} -{{- printf "%s-http" (include "flyte-binary.fullname" .) -}} +{{- printf "%s-grpc" (include "flyte-binary.fullname" .) -}} {{- end -}} {{/* diff --git a/charts/flyte-binary/templates/ingress/grpc.yaml b/charts/flyte-binary/templates/ingress/grpc.yaml index bd43185e509..85ef5fcba45 100644 --- a/charts/flyte-binary/templates/ingress/grpc.yaml +++ b/charts/flyte-binary/templates/ingress/grpc.yaml @@ -7,20 +7,20 @@ metadata: namespace: {{ .Release.Namespace | quote }} labels: {{- include "flyte-binary.labels" . | nindent 4 }} {{- if .Values.commonLabels }} - {{- tpl ( .Values.commonLabels | toYaml ) . | nindent 4 }} + {{ tpl ( .Values.commonLabels | toYaml ) . | nindent 4 }} {{- end }} {{- if .Values.ingress.labels }} - {{- tpl ( .Values.ingress.labels | toYaml ) . | nindent 4 }} + {{ tpl ( .Values.ingress.labels | toYaml ) . | nindent 4 }} {{- end }} annotations: {{- if .Values.commonAnnotations }} - {{- tpl ( .Values.commonAnnotations | toYaml ) . | nindent 4 }} + {{ tpl ( .Values.commonAnnotations | toYaml ) . | nindent 4 }} {{- end }} {{- if .Values.ingress.commonAnnotations }} - {{- tpl ( .Values.ingress.commonAnnotations | toYaml ) . | nindent 4 }} + {{ tpl ( .Values.ingress.commonAnnotations | toYaml ) . | nindent 4 }} {{- end }} {{- if .Values.ingress.grpcAnnotations }} - {{- tpl ( .Values.ingress.grpcAnnotations | toYaml ) . | nindent 4 }} + {{ tpl ( .Values.ingress.grpcAnnotations | toYaml ) . | nindent 4 }} {{- end }} spec: {{- if .Values.ingress.grpcIngressClassName }} diff --git a/charts/flyte-binary/templates/service/grpc.yaml b/charts/flyte-binary/templates/service/grpc.yaml new file mode 100644 index 00000000000..cc6ba67ac9c --- /dev/null +++ b/charts/flyte-binary/templates/service/grpc.yaml @@ -0,0 +1,48 @@ +{{- if .Values.ingress.separateGrpcIngress }} +apiVersion: v1 +kind: Service +metadata: + name: {{ include "flyte-binary.service.grpc.name" . }} + namespace: {{ .Release.Namespace | quote }} + labels: {{ include "flyte-binary.labels" . | nindent 4 }} + {{- if .Values.commonLabels }} + {{ tpl ( .Values.commonLabels | toYaml ) . | nindent 4 }} + {{- end }} + {{- if .Values.service.labels }} + {{ tpl ( .Values.service.labels | toYaml ) . | nindent 4 }} + {{- end }} + annotations: + {{- if .Values.commonAnnotations }} + {{ tpl ( .Values.commonAnnotations | toYaml ) . | nindent 4 }} + {{- end }} + {{- if .Values.service.commonAnnotations }} + {{ tpl ( .Values.service.commonAnnotations | toYaml ) . | nindent 4 }} + {{- end }} + {{- if .Values.service.grpcAnnotations }} + {{ tpl ( .Values.service.grpcAnnotations | toYaml ) . | nindent 4 }} + {{- end }} +spec: + type: {{ .Values.service.type }} + {{- if or (eq .Values.service.type "LoadBalancer") (eq .Values.service.type "NodePort") }} + externalTrafficPolicy: {{ .Values.service.externalTrafficPolicy | quote }} + {{- end }} + {{- if and (eq .Values.service.type "LoadBalancer") (not (empty .Values.service.loadBalancerSourceRanges)) }} + loadBalancerSourceRanges: {{ .Values.service.loadBalancerSourceRanges }} + {{- end }} + {{- if and (eq .Values.service.type "LoadBalancer") (not (empty .Values.service.loadBalancerIP)) }} + loadBalancerIP: {{ .Values.service.loadBalancerIP }} + {{- end }} + {{- if and .Values.service.clusterIP (eq .Values.service.type "ClusterIP") }} + clusterIP: {{ .Values.service.clusterIP }} + {{- end }} + ports: + - name: grpc + port: {{ include "flyte-binary.service.grpc.port" . }} + targetPort: grpc + {{- if and (or (eq .Values.service.type "NodePort") (eq .Values.service.type "LoadBalancer")) (not (empty .Values.service.nodePorts.grpc)) }} + nodePort: {{ .Values.service.nodePorts.grpc }} + {{- else if eq .Values.service.type "ClusterIP" }} + nodePort: null + {{- end }} + selector: {{- include "flyte-binary.selectorLabels" . | nindent 4 }} +{{- end }} diff --git a/charts/flyte-binary/values.yaml b/charts/flyte-binary/values.yaml index c7ef4e51b05..6c361c974e3 100644 --- a/charts/flyte-binary/values.yaml +++ b/charts/flyte-binary/values.yaml @@ -420,7 +420,6 @@ enabled_plugins: - container - sidecar - connector-service - - echo default-for-task-types: container: container sidecar: sidecar diff --git a/go.mod b/go.mod index 33c236559f6..961182a751b 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.41.1 github.com/benlaurie/objecthash v0.0.0-20180202135721-d1e3d6079fc1 github.com/coocood/freecache v1.2.4 + github.com/coreos/go-oidc/v3 v3.17.0 github.com/dask/dask-kubernetes/v2023 v2023.0.0-20230626103304-abd02cd17b26 github.com/eko/gocache/lib/v4 v4.2.0 github.com/eko/gocache/store/freecache/v4 v4.2.0 @@ -26,6 +27,7 @@ require ( github.com/ghodss/yaml v1.0.0 github.com/go-gormigrate/gormigrate/v2 v2.1.5 github.com/go-test/deep v1.1.1 + github.com/golang-jwt/jwt/v5 v5.3.0 github.com/golang/protobuf v1.5.4 github.com/googleapis/gax-go/v2 v2.15.0 github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus v1.1.0 @@ -131,7 +133,7 @@ require ( github.com/evanphx/json-patch v5.6.0+incompatible // indirect github.com/evanphx/json-patch/v5 v5.9.11 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect - github.com/go-jose/go-jose/v4 v4.1.2 // indirect + github.com/go-jose/go-jose/v4 v4.1.3 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-logr/zapr v1.3.0 // indirect @@ -141,7 +143,6 @@ require ( github.com/go-sql-driver/mysql v1.8.1 // indirect github.com/go-task/slim-sprig/v3 v3.0.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect - github.com/golang-jwt/jwt/v5 v5.3.0 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/mock v1.6.0 // indirect github.com/google/btree v1.1.3 // indirect @@ -153,6 +154,7 @@ require ( github.com/google/s2a-go v0.1.9 // indirect github.com/google/uuid v1.6.0 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.6 // indirect + github.com/gorilla/securecookie v1.1.2 // indirect github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.1.0 // indirect github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect diff --git a/go.sum b/go.sum index 34d92880765..a93e939c03d 100644 --- a/go.sum +++ b/go.sum @@ -163,6 +163,8 @@ github.com/cncf/xds/go v0.0.0-20250501225837-2ac532fd4443 h1:aQ3y1lwWyqYPiWZThqv github.com/cncf/xds/go v0.0.0-20250501225837-2ac532fd4443/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8= github.com/coocood/freecache v1.2.4 h1:UdR6Yz/X1HW4fZOuH0Z94KwG851GWOSknua5VUbb/5M= github.com/coocood/freecache v1.2.4/go.mod h1:RBUWa/Cy+OHdfTGFEhEuE1pMCMX51Ncizj7rthiQ3vk= +github.com/coreos/go-oidc/v3 v3.17.0 h1:hWBGaQfbi0iVviX4ibC7bk8OKT5qNr4klBaCHVNvehc= +github.com/coreos/go-oidc/v3 v3.17.0/go.mod h1:wqPbKFrVnE90vty060SB40FCJ8fTHTxSwyXJqZH+sI8= github.com/coreos/go-semver v0.3.1 h1:yi21YpKnrx1gt5R+la8n5WgS0kCrsPp33dmEyHReZr4= github.com/coreos/go-semver v0.3.1/go.mod h1:irMmmIw/7yzSRPWryHsK7EYSg09caPQL03VsM8rvUec= github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs= @@ -219,8 +221,8 @@ github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2 github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gormigrate/gormigrate/v2 v2.1.5 h1:1OyorA5LtdQw12cyJDEHuTrEV3GiXiIhS4/QTTa/SM8= github.com/go-gormigrate/gormigrate/v2 v2.1.5/go.mod h1:mj9ekk/7CPF3VjopaFvWKN2v7fN3D9d3eEOAXRhi/+M= -github.com/go-jose/go-jose/v4 v4.1.2 h1:TK/7NqRQZfgAh+Td8AlsrvtPoUyiHh0LqVvokh+1vHI= -github.com/go-jose/go-jose/v4 v4.1.2/go.mod h1:22cg9HWM1pOlnRiY+9cQYJ9XHmya1bYW8OeDM6Ku6Oo= +github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs= +github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= github.com/go-logr/logr v1.2.0/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= @@ -338,6 +340,8 @@ github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5m github.com/googleapis/gax-go/v2 v2.15.0 h1:SyjDc1mGgZU5LncH8gimWo9lW1DtIfPibOG81vgd/bo= github.com/googleapis/gax-go/v2 v2.15.0/go.mod h1:zVVkkxAQHa1RQpg9z2AUCMnKhi0Qld9rcmyfL1OZhoc= github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g= +github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= +github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 h1:+9834+KizmvFV7pXQGSXQTsaWhq2GjuNUt0aUU0YBYw= diff --git a/runs/config/config.go b/runs/config/config.go index f031e835278..eff3a52038c 100644 --- a/runs/config/config.go +++ b/runs/config/config.go @@ -51,106 +51,14 @@ type Config struct { // Domains are injected into project responses (not stored per project row). Domains []DomainConfig `json:"domains"` - // Auth configuration for the AuthMetadataService - Auth AuthConfig `json:"auth"` + // Security controls authentication and authorization behavior. + Security SecurityConfig `json:"security"` } -// AuthorizationServerType defines the type of authorization server. -type AuthorizationServerType int - -const ( - // AuthorizationServerTypeSelf indicates the service acts as its own authorization server. - AuthorizationServerTypeSelf AuthorizationServerType = iota - // AuthorizationServerTypeExternal indicates an external authorization server is used. - AuthorizationServerTypeExternal -) - -// AuthConfig holds authentication configuration matching flyteadmin's auth config shape. -type AuthConfig struct { - // AuthorizedURIs is the set of URIs clients can access the service on. - // The first entry is used as the public URL for building OAuth2 metadata URLs. - AuthorizedURIs []config.URL `json:"authorizedUris"` - - // GrpcAuthorizationHeader is the authorization metadata key returned to clients. - GrpcAuthorizationHeader string `json:"grpcAuthorizationHeader"` - - // AppAuth defines app-level OAuth2 settings. - AppAuth OAuth2Options `json:"appAuth"` - - // HTTPProxyURL allows accessing external OAuth2 servers through an HTTP proxy. - HTTPProxyURL config.URL `json:"httpProxyURL"` - - // TokenEndpointProxyConfig proxies token endpoint calls through admin. - TokenEndpointProxyConfig TokenEndpointProxyConfig `json:"tokenEndpointProxyConfig"` -} - -// OAuth2Options holds OAuth2 authorization server options. -type OAuth2Options struct { - // AuthServerType determines whether to use a self-hosted or external auth server. - AuthServerType AuthorizationServerType `json:"authServerType"` - - // SelfAuthServer configures the self-hosted authorization server. - SelfAuthServer AuthorizationServer `json:"selfAuthServer"` - - // ExternalAuthServer configures the external authorization server. - ExternalAuthServer ExternalAuthorizationServer `json:"externalAuthServer"` - - // ThirdParty configures third-party (public client) settings. - ThirdParty ThirdPartyConfigOptions `json:"thirdParty"` -} - -// AuthorizationServer configures a self-hosted authorization server. -type AuthorizationServer struct { - // Issuer is the issuer URL. If empty, the first AuthorizedURI is used. - Issuer string `json:"issuer"` -} - -// ExternalAuthorizationServer configures an external authorization server. -type ExternalAuthorizationServer struct { - // BaseURL is the base URL of the external authorization server. - BaseURL config.URL `json:"baseURL"` - - // MetadataEndpointURL overrides the default .well-known/oauth-authorization-server endpoint. - MetadataEndpointURL config.URL `json:"metadataEndpointURL"` - - // RetryAttempts is the number of retry attempts for fetching metadata. - RetryAttempts int `json:"retryAttempts"` - - // RetryDelay is the delay between retry attempts. - RetryDelay config.Duration `json:"retryDelay"` -} - -// ThirdPartyConfigOptions holds third-party OAuth2 client settings. -type ThirdPartyConfigOptions struct { - // FlyteClientConfig holds public client configuration. - FlyteClientConfig FlyteClientConfig `json:"flyteClient"` -} - -// FlyteClientConfig holds the public client configuration. -type FlyteClientConfig struct { - // ClientID is the public client ID. - ClientID string `json:"clientId"` - - // RedirectURI is the redirect URI for the client. - RedirectURI string `json:"redirectUri"` - - // Scopes are the OAuth2 scopes to request. - Scopes []string `json:"scopes"` - - // Audience is the intended audience for OAuth2 tokens. - Audience string `json:"audience"` -} - -// TokenEndpointProxyConfig configures proxying of token endpoint calls. -type TokenEndpointProxyConfig struct { - // Enabled enables token endpoint proxying. - Enabled bool `json:"enabled"` - - // PublicURL is the public URL to use for rewriting the token endpoint. - PublicURL config.URL `json:"publicURL"` - - // PathPrefix is appended to the public URL when rewriting. - PathPrefix string `json:"pathPrefix"` +// SecurityConfig controls authentication and authorization behavior. +type SecurityConfig struct { + // UseAuth enables authentication. When true, AuthMetadataService and IdentityService are registered. + UseAuth bool `json:"useAuth" pflag:",Enable authentication and identity services"` } // ServerConfig holds HTTP server configuration diff --git a/runs/service/auth_metadata.go b/runs/service/auth_metadata.go deleted file mode 100644 index 62bed4213ef..00000000000 --- a/runs/service/auth_metadata.go +++ /dev/null @@ -1,204 +0,0 @@ -package service - -import ( - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "net/url" - "strings" - "time" - - "connectrpc.com/connect" - - "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth" - "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth/authconnect" - "github.com/flyteorg/flyte/v2/runs/config" - - stdlibConfig "github.com/flyteorg/flyte/v2/flytestdlib/config" - "github.com/flyteorg/flyte/v2/flytestdlib/logger" -) - -const ( - oauth2MetadataEndpoint = ".well-known/oauth-authorization-server" - scopeAll = "all" -) - -type authMetadataService struct { - authconnect.UnimplementedAuthMetadataServiceHandler - cfg config.AuthConfig -} - -func NewAuthMetadataService(cfg config.AuthConfig) authconnect.AuthMetadataServiceHandler { - return &authMetadataService{cfg: cfg} -} - -func (s *authMetadataService) GetOAuth2Metadata( - ctx context.Context, - _ *connect.Request[auth.GetOAuth2MetadataRequest], -) (*connect.Response[auth.GetOAuth2MetadataResponse], error) { - switch s.cfg.AppAuth.AuthServerType { - case config.AuthorizationServerTypeExternal: - return s.getOAuth2MetadataExternal(ctx) - default: - return s.getOAuth2MetadataSelf() - } -} - -func (s *authMetadataService) getOAuth2MetadataSelf() (*connect.Response[auth.GetOAuth2MetadataResponse], error) { - publicURL := getPublicURL(s.cfg.AuthorizedURIs) - issuer := getIssuer(s.cfg, publicURL) - - base := strings.TrimRight(publicURL.String(), "/") - - resp := &auth.GetOAuth2MetadataResponse{ - Issuer: issuer, - AuthorizationEndpoint: base + "/oauth2/authorize", - TokenEndpoint: base + "/oauth2/token", - JwksUri: base + "/oauth2/jwks", - CodeChallengeMethodsSupported: []string{"S256"}, - ResponseTypesSupported: []string{"code", "token", "code token"}, - GrantTypesSupported: []string{"client_credentials", "refresh_token", "authorization_code"}, - ScopesSupported: []string{scopeAll}, - TokenEndpointAuthMethodsSupported: []string{"client_secret_basic"}, - } - - return connect.NewResponse(resp), nil -} - -func (s *authMetadataService) getOAuth2MetadataExternal(ctx context.Context) (*connect.Response[auth.GetOAuth2MetadataResponse], error) { - baseURL := s.cfg.AppAuth.ExternalAuthServer.BaseURL - if baseURL.String() == "" { - return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("external auth server base URL is not configured")) - } - - metadataURL := s.cfg.AppAuth.ExternalAuthServer.MetadataEndpointURL - if metadataURL.String() == "" || metadataURL.String() == baseURL.String() { - u := baseURL.URL - u.Path = strings.TrimRight(u.Path, "/") + "/" + oauth2MetadataEndpoint - metadataURL = stdlibConfig.URL{URL: u} - } - - httpClient := &http.Client{} - if s.cfg.HTTPProxyURL.String() != "" { - httpClient.Transport = &http.Transport{ - Proxy: http.ProxyURL(&s.cfg.HTTPProxyURL.URL), - } - } - - retryAttempts := s.cfg.AppAuth.ExternalAuthServer.RetryAttempts - if retryAttempts <= 0 { - retryAttempts = 5 - } - retryDelay := s.cfg.AppAuth.ExternalAuthServer.RetryDelay.Duration - if retryDelay <= 0 { - retryDelay = time.Second - } - - body, err := sendAndRetryHTTPRequest(ctx, httpClient, metadataURL.String(), retryAttempts, retryDelay) - if err != nil { - return nil, connect.NewError(connect.CodeUnavailable, fmt.Errorf("failed to fetch OAuth2 metadata: %w", err)) - } - - resp := &auth.GetOAuth2MetadataResponse{} - if err := json.Unmarshal(body, resp); err != nil { - return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("failed to unmarshal OAuth2 metadata: %w", err)) - } - - if s.cfg.TokenEndpointProxyConfig.Enabled && resp.TokenEndpoint != "" { - proxyURL := s.cfg.TokenEndpointProxyConfig.PublicURL - if proxyURL.String() == "" { - proxyURL = stdlibConfig.URL{URL: *getPublicURL(s.cfg.AuthorizedURIs)} - } - rewritten := strings.TrimRight(proxyURL.String(), "/") - if s.cfg.TokenEndpointProxyConfig.PathPrefix != "" { - rewritten += "/" + strings.Trim(s.cfg.TokenEndpointProxyConfig.PathPrefix, "/") - } - - // Preserve the original token endpoint path - originalURL, parseErr := url.Parse(resp.TokenEndpoint) - if parseErr == nil { - rewritten += originalURL.Path - } - resp.TokenEndpoint = rewritten - } - - return connect.NewResponse(resp), nil -} - -func (s *authMetadataService) GetPublicClientConfig( - _ context.Context, - _ *connect.Request[auth.GetPublicClientConfigRequest], -) (*connect.Response[auth.GetPublicClientConfigResponse], error) { - fc := s.cfg.AppAuth.ThirdParty.FlyteClientConfig - return connect.NewResponse(&auth.GetPublicClientConfigResponse{ - ClientId: fc.ClientID, - RedirectUri: fc.RedirectURI, - Scopes: fc.Scopes, - AuthorizationMetadataKey: s.cfg.GrpcAuthorizationHeader, - Audience: fc.Audience, - }), nil -} - -// getPublicURL returns the first AuthorizedURI as the public URL, or a default localhost URL. -func getPublicURL(authorizedURIs []stdlibConfig.URL) *url.URL { - if len(authorizedURIs) > 0 { - u := authorizedURIs[0].URL - return &u - } - u, _ := url.Parse("http://localhost:8090") - return u -} - -// getIssuer returns the issuer from SelfAuthServer config, or falls back to public URL. -func getIssuer(cfg config.AuthConfig, publicURL *url.URL) string { - if cfg.AppAuth.SelfAuthServer.Issuer != "" { - return cfg.AppAuth.SelfAuthServer.Issuer - } - return strings.TrimRight(publicURL.String(), "/") -} - -// sendAndRetryHTTPRequest fetches the given URL with retry logic. -func sendAndRetryHTTPRequest(ctx context.Context, client *http.Client, targetURL string, retryAttempts int, retryDelay time.Duration) ([]byte, error) { - var lastErr error - for i := 0; i < retryAttempts; i++ { - if i > 0 { - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-time.After(retryDelay): - } - } - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, targetURL, nil) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - - resp, err := client.Do(req) - if err != nil { - lastErr = err - logger.Warnf(ctx, "Failed to fetch %s (attempt %d/%d): %v", targetURL, i+1, retryAttempts, err) - continue - } - - body, readErr := io.ReadAll(resp.Body) - resp.Body.Close() - if readErr != nil { - lastErr = readErr - logger.Warnf(ctx, "Failed to read response body from %s (attempt %d/%d): %v", targetURL, i+1, retryAttempts, readErr) - continue - } - - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - lastErr = fmt.Errorf("unexpected status code %d from %s", resp.StatusCode, targetURL) - logger.Warnf(ctx, "Unexpected status code from %s (attempt %d/%d): %d", targetURL, i+1, retryAttempts, resp.StatusCode) - continue - } - - return body, nil - } - - return nil, fmt.Errorf("all %d attempts failed for %s: %w", retryAttempts, targetURL, lastErr) -} diff --git a/runs/service/auth_metadata_test.go b/runs/service/auth_metadata_test.go deleted file mode 100644 index e9f9408c4d7..00000000000 --- a/runs/service/auth_metadata_test.go +++ /dev/null @@ -1,344 +0,0 @@ -package service - -import ( - "context" - "encoding/json" - "net/http" - "net/http/httptest" - "net/url" - "testing" - "time" - - "connectrpc.com/connect" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth" - "github.com/flyteorg/flyte/v2/runs/config" - - stdlibConfig "github.com/flyteorg/flyte/v2/flytestdlib/config" -) - -func mustParseURL(rawURL string) stdlibConfig.URL { - u, err := url.Parse(rawURL) - if err != nil { - panic(err) - } - return stdlibConfig.URL{URL: *u} -} - -func TestGetPublicClientConfig(t *testing.T) { - cfg := config.AuthConfig{ - GrpcAuthorizationHeader: "flyte-authorization", - AppAuth: config.OAuth2Options{ - ThirdParty: config.ThirdPartyConfigOptions{ - FlyteClientConfig: config.FlyteClientConfig{ - ClientID: "flyte-client", - RedirectURI: "http://localhost:12345/callback", - Scopes: []string{"openid", "offline"}, - Audience: "https://flyte.example.com", - }, - }, - }, - } - - svc := NewAuthMetadataService(cfg) - resp, err := svc.GetPublicClientConfig(context.Background(), connect.NewRequest(&auth.GetPublicClientConfigRequest{})) - require.NoError(t, err) - - msg := resp.Msg - assert.Equal(t, "flyte-client", msg.ClientId) - assert.Equal(t, "http://localhost:12345/callback", msg.RedirectUri) - assert.Equal(t, []string{"openid", "offline"}, msg.Scopes) - assert.Equal(t, "flyte-authorization", msg.AuthorizationMetadataKey) - assert.Equal(t, "https://flyte.example.com", msg.Audience) -} - -func TestGetOAuth2Metadata_SelfAuthServer(t *testing.T) { - cfg := config.AuthConfig{ - AuthorizedURIs: []stdlibConfig.URL{ - mustParseURL("https://flyte.example.com"), - }, - AppAuth: config.OAuth2Options{ - AuthServerType: config.AuthorizationServerTypeSelf, - }, - } - - svc := NewAuthMetadataService(cfg) - resp, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) - require.NoError(t, err) - - msg := resp.Msg - assert.Equal(t, "https://flyte.example.com", msg.Issuer) - assert.Equal(t, "https://flyte.example.com/oauth2/authorize", msg.AuthorizationEndpoint) - assert.Equal(t, "https://flyte.example.com/oauth2/token", msg.TokenEndpoint) - assert.Equal(t, "https://flyte.example.com/oauth2/jwks", msg.JwksUri) - assert.Equal(t, []string{"S256"}, msg.CodeChallengeMethodsSupported) - assert.Equal(t, []string{"code", "token", "code token"}, msg.ResponseTypesSupported) - assert.Equal(t, []string{"client_credentials", "refresh_token", "authorization_code"}, msg.GrantTypesSupported) - assert.Equal(t, []string{"all"}, msg.ScopesSupported) - assert.Equal(t, []string{"client_secret_basic"}, msg.TokenEndpointAuthMethodsSupported) -} - -func TestGetOAuth2Metadata_SelfAuthServerWithCustomIssuer(t *testing.T) { - cfg := config.AuthConfig{ - AuthorizedURIs: []stdlibConfig.URL{ - mustParseURL("https://flyte.example.com"), - }, - AppAuth: config.OAuth2Options{ - AuthServerType: config.AuthorizationServerTypeSelf, - SelfAuthServer: config.AuthorizationServer{ - Issuer: "https://custom-issuer.example.com", - }, - }, - } - - svc := NewAuthMetadataService(cfg) - resp, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) - require.NoError(t, err) - - msg := resp.Msg - assert.Equal(t, "https://custom-issuer.example.com", msg.Issuer) - assert.Equal(t, "https://flyte.example.com/oauth2/authorize", msg.AuthorizationEndpoint) -} - -func TestGetOAuth2Metadata_SelfAuthServerNoAuthorizedURIs(t *testing.T) { - cfg := config.AuthConfig{ - AppAuth: config.OAuth2Options{ - AuthServerType: config.AuthorizationServerTypeSelf, - }, - } - - svc := NewAuthMetadataService(cfg) - resp, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) - require.NoError(t, err) - - msg := resp.Msg - assert.Equal(t, "http://localhost:8090", msg.Issuer) - assert.Equal(t, "http://localhost:8090/oauth2/token", msg.TokenEndpoint) -} - -func TestGetOAuth2Metadata_ExternalAuthServer(t *testing.T) { - expectedMetadata := &auth.GetOAuth2MetadataResponse{ - Issuer: "https://external-idp.example.com", - AuthorizationEndpoint: "https://external-idp.example.com/authorize", - TokenEndpoint: "https://external-idp.example.com/token", - JwksUri: "https://external-idp.example.com/.well-known/jwks.json", - } - - metadataJSON, err := json.Marshal(expectedMetadata) - require.NoError(t, err) - - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, "/"+oauth2MetadataEndpoint, r.URL.Path) - w.Header().Set("Content-Type", "application/json") - w.Write(metadataJSON) - })) - defer ts.Close() - - cfg := config.AuthConfig{ - AppAuth: config.OAuth2Options{ - AuthServerType: config.AuthorizationServerTypeExternal, - ExternalAuthServer: config.ExternalAuthorizationServer{ - BaseURL: mustParseURL(ts.URL), - RetryAttempts: 1, - RetryDelay: stdlibConfig.Duration{Duration: 100 * time.Millisecond}, - }, - }, - } - - svc := NewAuthMetadataService(cfg) - resp, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) - require.NoError(t, err) - - msg := resp.Msg - assert.Equal(t, "https://external-idp.example.com", msg.Issuer) - assert.Equal(t, "https://external-idp.example.com/authorize", msg.AuthorizationEndpoint) - assert.Equal(t, "https://external-idp.example.com/token", msg.TokenEndpoint) - assert.Equal(t, "https://external-idp.example.com/.well-known/jwks.json", msg.JwksUri) -} - -func TestGetOAuth2Metadata_ExternalWithCustomMetadataURL(t *testing.T) { - expectedMetadata := &auth.GetOAuth2MetadataResponse{ - Issuer: "https://external-idp.example.com", - TokenEndpoint: "https://external-idp.example.com/token", - } - - metadataJSON, err := json.Marshal(expectedMetadata) - require.NoError(t, err) - - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, "/custom/metadata", r.URL.Path) - w.Header().Set("Content-Type", "application/json") - w.Write(metadataJSON) - })) - defer ts.Close() - - cfg := config.AuthConfig{ - AppAuth: config.OAuth2Options{ - AuthServerType: config.AuthorizationServerTypeExternal, - ExternalAuthServer: config.ExternalAuthorizationServer{ - BaseURL: mustParseURL(ts.URL), - MetadataEndpointURL: mustParseURL(ts.URL + "/custom/metadata"), - RetryAttempts: 1, - RetryDelay: stdlibConfig.Duration{Duration: 100 * time.Millisecond}, - }, - }, - } - - svc := NewAuthMetadataService(cfg) - resp, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) - require.NoError(t, err) - - assert.Equal(t, "https://external-idp.example.com", resp.Msg.Issuer) - assert.Equal(t, "https://external-idp.example.com/token", resp.Msg.TokenEndpoint) -} - -func TestGetOAuth2Metadata_ExternalWithTokenProxy(t *testing.T) { - expectedMetadata := &auth.GetOAuth2MetadataResponse{ - Issuer: "https://external-idp.example.com", - AuthorizationEndpoint: "https://external-idp.example.com/authorize", - TokenEndpoint: "https://external-idp.example.com/oauth/token", - } - - metadataJSON, err := json.Marshal(expectedMetadata) - require.NoError(t, err) - - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.Write(metadataJSON) - })) - defer ts.Close() - - cfg := config.AuthConfig{ - AuthorizedURIs: []stdlibConfig.URL{ - mustParseURL("https://flyte.example.com"), - }, - AppAuth: config.OAuth2Options{ - AuthServerType: config.AuthorizationServerTypeExternal, - ExternalAuthServer: config.ExternalAuthorizationServer{ - BaseURL: mustParseURL(ts.URL), - RetryAttempts: 1, - RetryDelay: stdlibConfig.Duration{Duration: 100 * time.Millisecond}, - }, - }, - TokenEndpointProxyConfig: config.TokenEndpointProxyConfig{ - Enabled: true, - }, - } - - svc := NewAuthMetadataService(cfg) - resp, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) - require.NoError(t, err) - - msg := resp.Msg - assert.Equal(t, "https://external-idp.example.com", msg.Issuer) - assert.Equal(t, "https://external-idp.example.com/authorize", msg.AuthorizationEndpoint) - // Token endpoint should be rewritten to the public URL - assert.Equal(t, "https://flyte.example.com/oauth/token", msg.TokenEndpoint) -} - -func TestGetOAuth2Metadata_ExternalWithTokenProxyAndPathPrefix(t *testing.T) { - expectedMetadata := &auth.GetOAuth2MetadataResponse{ - TokenEndpoint: "https://external-idp.example.com/oauth/token", - } - - metadataJSON, err := json.Marshal(expectedMetadata) - require.NoError(t, err) - - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.Write(metadataJSON) - })) - defer ts.Close() - - cfg := config.AuthConfig{ - AppAuth: config.OAuth2Options{ - AuthServerType: config.AuthorizationServerTypeExternal, - ExternalAuthServer: config.ExternalAuthorizationServer{ - BaseURL: mustParseURL(ts.URL), - RetryAttempts: 1, - RetryDelay: stdlibConfig.Duration{Duration: 100 * time.Millisecond}, - }, - }, - TokenEndpointProxyConfig: config.TokenEndpointProxyConfig{ - Enabled: true, - PublicURL: mustParseURL("https://proxy.example.com"), - PathPrefix: "api/v1", - }, - } - - svc := NewAuthMetadataService(cfg) - resp, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) - require.NoError(t, err) - - assert.Equal(t, "https://proxy.example.com/api/v1/oauth/token", resp.Msg.TokenEndpoint) -} - -func TestGetOAuth2Metadata_ExternalNoBaseURL(t *testing.T) { - cfg := config.AuthConfig{ - AppAuth: config.OAuth2Options{ - AuthServerType: config.AuthorizationServerTypeExternal, - }, - } - - svc := NewAuthMetadataService(cfg) - _, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) - require.Error(t, err) - assert.Contains(t, err.Error(), "external auth server base URL is not configured") -} - -func TestSendAndRetryHTTPRequest_ImmediateSuccess(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - w.Write([]byte(`{"status":"ok"}`)) - })) - defer ts.Close() - - body, err := sendAndRetryHTTPRequest(context.Background(), http.DefaultClient, ts.URL, 3, 10*time.Millisecond) - require.NoError(t, err) - assert.Equal(t, `{"status":"ok"}`, string(body)) -} - -func TestSendAndRetryHTTPRequest_RetryIntoSuccess(t *testing.T) { - attempt := 0 - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - attempt++ - if attempt < 3 { - w.WriteHeader(http.StatusServiceUnavailable) - return - } - w.WriteHeader(http.StatusOK) - w.Write([]byte(`{"status":"ok"}`)) - })) - defer ts.Close() - - body, err := sendAndRetryHTTPRequest(context.Background(), http.DefaultClient, ts.URL, 5, 10*time.Millisecond) - require.NoError(t, err) - assert.Equal(t, `{"status":"ok"}`, string(body)) - assert.Equal(t, 3, attempt) -} - -func TestSendAndRetryHTTPRequest_AllRetrysFail(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusServiceUnavailable) - })) - defer ts.Close() - - _, err := sendAndRetryHTTPRequest(context.Background(), http.DefaultClient, ts.URL, 3, 10*time.Millisecond) - require.Error(t, err) - assert.Contains(t, err.Error(), "all 3 attempts failed") -} - -func TestSendAndRetryHTTPRequest_ContextCancelled(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusServiceUnavailable) - })) - defer ts.Close() - - ctx, cancel := context.WithCancel(context.Background()) - cancel() - - _, err := sendAndRetryHTTPRequest(ctx, http.DefaultClient, ts.URL, 5, 10*time.Millisecond) - require.Error(t, err) -} diff --git a/runs/service/identity_service.go b/runs/service/identity_service.go deleted file mode 100644 index 7484d6c71fb..00000000000 --- a/runs/service/identity_service.go +++ /dev/null @@ -1,29 +0,0 @@ -package service - -import ( - "context" - - "connectrpc.com/connect" - - "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth" - "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth/authconnect" -) - -// IdentityService implements the IdentityServiceHandler interface. -type IdentityService struct{} - -// NewIdentityService creates a new IdentityService instance. -func NewIdentityService() *IdentityService { - return &IdentityService{} -} - -var _ authconnect.IdentityServiceHandler = (*IdentityService)(nil) - -// UserInfo returns information about the currently logged in user. -// TODO: Wire with real auth to populate user info from the authenticated context. -func (s *IdentityService) UserInfo( - ctx context.Context, - req *connect.Request[auth.UserInfoRequest], -) (*connect.Response[auth.UserInfoResponse], error) { - return connect.NewResponse(&auth.UserInfoResponse{}), nil -} diff --git a/runs/service/identity_service_test.go b/runs/service/identity_service_test.go deleted file mode 100644 index a1839954398..00000000000 --- a/runs/service/identity_service_test.go +++ /dev/null @@ -1,21 +0,0 @@ -package service - -import ( - "context" - "testing" - - "connectrpc.com/connect" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth" -) - -func TestIdentityService_UserInfo(t *testing.T) { - svc := NewIdentityService() - - resp, err := svc.UserInfo(context.Background(), connect.NewRequest(&auth.UserInfoRequest{})) - require.NoError(t, err) - assert.NotNil(t, resp) - assert.NotNil(t, resp.Msg) -} diff --git a/runs/setup.go b/runs/setup.go index 4ac79e2dc58..ed9434227a9 100644 --- a/runs/setup.go +++ b/runs/setup.go @@ -21,6 +21,9 @@ import ( "github.com/flyteorg/flyte/v2/runs/repository/interfaces" "github.com/flyteorg/flyte/v2/runs/repository/models" "github.com/flyteorg/flyte/v2/runs/service" + authservice "github.com/flyteorg/flyte/v2/runs/service/auth" + authConfig "github.com/flyteorg/flyte/v2/runs/service/auth/config" + "github.com/flyteorg/flyte/v2/runs/service/auth/authzserver" "github.com/flyteorg/flyte/v2/flytestdlib/logger" ) @@ -66,15 +69,18 @@ func Setup(ctx context.Context, sc *app.SetupContext) error { sc.Mux.Handle(translatorPath, translatorHandler) logger.Infof(ctx, "Mounted TranslatorService at %s", translatorPath) - authSvc := service.NewAuthMetadataService(cfg.Auth) - authPath, authHandler := authconnect.NewAuthMetadataServiceHandler(authSvc) - sc.Mux.Handle(authPath, authHandler) - logger.Infof(ctx, "Mounted AuthMetadataService at %s", authPath) - - identitySvc := service.NewIdentityService() - identityPath, identityHandler := authconnect.NewIdentityServiceHandler(identitySvc) - sc.Mux.Handle(identityPath, identityHandler) - logger.Infof(ctx, "Mounted IdentityService at %s", identityPath) + if cfg.Security.UseAuth { + authCfg := authConfig.GetConfig() + authSvc := authzserver.NewAuthMetadataService(*authCfg) + authPath, authHandler := authconnect.NewAuthMetadataServiceHandler(authSvc) + sc.Mux.Handle(authPath, authHandler) + logger.Infof(ctx, "Mounted AuthMetadataService at %s", authPath) + + identitySvc := authservice.NewIdentityService() + identityPath, identityHandler := authconnect.NewIdentityServiceHandler(identitySvc) + sc.Mux.Handle(identityPath, identityHandler) + logger.Infof(ctx, "Mounted IdentityService at %s", identityPath) + } domains := make([]*projectpb.Domain, 0, len(cfg.Domains)) for _, d := range cfg.Domains { From 1af67c6e02a060652835d05b5141107be0cd8523 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 11 Mar 2026 14:59:22 -0700 Subject: [PATCH 04/13] update authzserver Signed-off-by: Kevin Su --- runs/service/auth/auth_context.go | 131 +++++++ .../auth/authzserver/claims_verifier.go | 105 ++++++ .../auth/authzserver/claims_verifier_test.go | 114 ++++++ .../auth/authzserver/metadata_provider.go | 217 +++++++++++ .../authzserver/metadata_provider_test.go | 348 ++++++++++++++++++ .../auth/authzserver/resource_server.go | 109 ++++++ .../auth/authzserver/resource_server_test.go | 101 +++++ runs/service/auth/config/config.go | 298 +++++++++++++++ runs/service/auth/constants.go | 22 ++ runs/service/auth/cookie.go | 190 ++++++++++ runs/service/auth/cookie_manager.go | 225 +++++++++++ runs/service/auth/handler_utils.go | 201 ++++++++++ runs/service/auth/handler_utils_test.go | 175 +++++++++ runs/service/auth/handlers.go | 291 +++++++++++++++ runs/service/auth/identity_context.go | 80 ++++ runs/service/auth/identity_context_test.go | 65 ++++ runs/service/auth/identity_service.go | 29 ++ runs/service/auth/identity_service_test.go | 21 ++ runs/service/auth/interceptor.go | 98 +++++ runs/service/auth/interfaces.go | 8 + runs/service/auth/token.go | 176 +++++++++ runs/service/auth/user_info_provider.go | 28 ++ 22 files changed, 3032 insertions(+) create mode 100644 runs/service/auth/auth_context.go create mode 100644 runs/service/auth/authzserver/claims_verifier.go create mode 100644 runs/service/auth/authzserver/claims_verifier_test.go create mode 100644 runs/service/auth/authzserver/metadata_provider.go create mode 100644 runs/service/auth/authzserver/metadata_provider_test.go create mode 100644 runs/service/auth/authzserver/resource_server.go create mode 100644 runs/service/auth/authzserver/resource_server_test.go create mode 100644 runs/service/auth/config/config.go create mode 100644 runs/service/auth/constants.go create mode 100644 runs/service/auth/cookie.go create mode 100644 runs/service/auth/cookie_manager.go create mode 100644 runs/service/auth/handler_utils.go create mode 100644 runs/service/auth/handler_utils_test.go create mode 100644 runs/service/auth/handlers.go create mode 100644 runs/service/auth/identity_context.go create mode 100644 runs/service/auth/identity_context_test.go create mode 100644 runs/service/auth/identity_service.go create mode 100644 runs/service/auth/identity_service_test.go create mode 100644 runs/service/auth/interceptor.go create mode 100644 runs/service/auth/interfaces.go create mode 100644 runs/service/auth/token.go create mode 100644 runs/service/auth/user_info_provider.go diff --git a/runs/service/auth/auth_context.go b/runs/service/auth/auth_context.go new file mode 100644 index 00000000000..1989d78956d --- /dev/null +++ b/runs/service/auth/auth_context.go @@ -0,0 +1,131 @@ +// Package auth contains types needed to start up a standalone OAuth2 Authorization Server or delegate +// authentication to an external provider. It supports OpenID Connect for user authentication. +package auth + +import ( + "context" + "fmt" + "net/http" + "net/url" + "time" + + "github.com/coreos/go-oidc/v3/oidc" + "golang.org/x/oauth2" + + "github.com/flyteorg/flyte/v2/flytestdlib/logger" + "github.com/flyteorg/flyte/v2/runs/service/auth/config" +) + +const ( + IdpConnectionTimeout = 10 * time.Second +) + +// AuthenticationContext holds all the utilities necessary to run authentication. +// +// The auth package supports two request flows, both producing an IdentityContext: +// +// Browser (HTTP) API (gRPC / HTTP Bearer) +// ────────────── ─────────────────────── +// handlers.go interceptor.go +// /login -> IdP redirect ▼ +// /callback -> exchange code token.go / resource_server.go +// /logout -> clear cookies ▼ +// │ claims_verifier.go +// ▼ │ +// cookie_manager.go │ +// (read/write encrypted cookies) │ +// │ │ +// ▼ │ +// cookie.go │ +// (CSRF, secure cookie helpers) │ +// │ │ +// ▼ ▼ +// token.go ──────────────────────> identity_context.go +// (parse/validate JWT) (UserID, AppID, Scopes, Claims) +// +type AuthenticationContext struct { + oauth2Config *oauth2.Config + cookieManager CookieManager + oidcProvider *oidc.Provider + resourceServer OAuth2ResourceServer + cfg config.Config + httpClient *http.Client + oauth2MetadataURL *url.URL + oidcMetadataURL *url.URL +} + +func (c *AuthenticationContext) OAuth2Config() *oauth2.Config { return c.oauth2Config } +func (c *AuthenticationContext) CookieManager() CookieManager { return c.cookieManager } +func (c *AuthenticationContext) OIDCProvider() *oidc.Provider { return c.oidcProvider } +func (c *AuthenticationContext) ResourceServer() OAuth2ResourceServer { return c.resourceServer } +func (c *AuthenticationContext) Config() config.Config { return c.cfg } +func (c *AuthenticationContext) HTTPClient() *http.Client { return c.httpClient } +func (c *AuthenticationContext) OAuth2MetadataURL() *url.URL { return c.oauth2MetadataURL } +func (c *AuthenticationContext) OIDCMetadataURL() *url.URL { return c.oidcMetadataURL } + +// NewAuthContext creates a new AuthContext with all the components needed for authentication. +func NewAuthContext(ctx context.Context, cfg config.Config, resourceServer OAuth2ResourceServer, + hashKeyBase64, blockKeyBase64 string) (*AuthenticationContext, error) { + + cookieManager, err := NewCookieManager(ctx, hashKeyBase64, blockKeyBase64, cfg.UserAuth.CookieSetting) + if err != nil { + logger.Errorf(ctx, "Error creating cookie manager %s", err) + return nil, fmt.Errorf("error creating cookie manager: %w", err) + } + + httpClient := &http.Client{ + Timeout: IdpConnectionTimeout, + } + + if len(cfg.UserAuth.HTTPProxyURL.String()) > 0 { + logger.Infof(ctx, "HTTPProxy URL for OAuth2 is: %s", cfg.UserAuth.HTTPProxyURL.String()) + httpClient.Transport = &http.Transport{Proxy: http.ProxyURL(&cfg.UserAuth.HTTPProxyURL.URL)} + } + + oidcCtx := oidc.ClientContext(ctx, httpClient) + baseURL := cfg.UserAuth.OpenID.BaseURL.String() + provider, err := oidc.NewProvider(oidcCtx, baseURL) + if err != nil { + return nil, fmt.Errorf("error creating oidc provider with issuer [%v]: %w", baseURL, err) + } + + oauth2Config := &oauth2.Config{ + RedirectURL: "callback", + ClientID: cfg.UserAuth.OpenID.ClientID, + Scopes: cfg.UserAuth.OpenID.Scopes, + Endpoint: provider.Endpoint(), + } + + oauth2MetadataURL, err := url.Parse(OAuth2MetadataEndpoint) + if err != nil { + return nil, fmt.Errorf("error parsing oauth2 metadata URL: %w", err) + } + + oidcMetadataURL, err := url.Parse(OIdCMetadataEndpoint) + if err != nil { + return nil, fmt.Errorf("error parsing oidc metadata URL: %w", err) + } + + return &AuthenticationContext{ + oauth2Config: oauth2Config, + cookieManager: cookieManager, + oidcProvider: provider, + resourceServer: resourceServer, + cfg: cfg, + httpClient: httpClient, + oauth2MetadataURL: oauth2MetadataURL, + oidcMetadataURL: oidcMetadataURL, + }, nil +} + +// HandlerConfig returns an AuthHandlerConfig suitable for use with RegisterHandlers. +func (c *AuthenticationContext) HandlerConfig() *AuthHandlerConfig { + return &AuthHandlerConfig{ + CookieManager: c.cookieManager, + OAuth2Config: c.oauth2Config, + OIDCProvider: c.oidcProvider, + ResourceServer: c.resourceServer, + AuthConfig: c.cfg, + HTTPClient: c.httpClient, + } +} diff --git a/runs/service/auth/authzserver/claims_verifier.go b/runs/service/auth/authzserver/claims_verifier.go new file mode 100644 index 00000000000..ae284492609 --- /dev/null +++ b/runs/service/auth/authzserver/claims_verifier.go @@ -0,0 +1,105 @@ +package authzserver + +import ( + "encoding/json" + "fmt" + "time" + + jwtgo "github.com/golang-jwt/jwt/v5" + + authpb "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth" + auth "github.com/flyteorg/flyte/v2/runs/service/auth" +) + +const ( + // ClientIDClaim is the JWT claim key for the client ID. + ClientIDClaim = "client_id" + // UserIDClaim is the JWT claim key for user info. + UserIDClaim = "user_info" + // ScopeClaim is the JWT claim key for scopes. + ScopeClaim = "scp" +) + +// verifyClaims extracts identity information from raw JWT claims and validates the audience. +func verifyClaims(expectedAudience map[string]bool, claims jwtgo.MapClaims) (*auth.IdentityContext, error) { + aud, err := claims.GetAudience() + if err != nil { + return nil, fmt.Errorf("failed to get audience: %w", err) + } + + matchedAudience := "" + for _, a := range aud { + if expectedAudience[a] { + matchedAudience = a + break + } + } + if matchedAudience == "" { + return nil, fmt.Errorf("invalid audience %v, wanted one of %v", aud, expectedAudience) + } + + sub, _ := claims.GetSubject() + + issuedAt := time.Time{} + if iat, err := claims.GetIssuedAt(); err == nil && iat != nil { + issuedAt = iat.Time + } + + userInfo := &authpb.UserInfoResponse{} + if userInfoClaim, found := claims[UserIDClaim]; found && userInfoClaim != nil { + if userInfoRaw, ok := userInfoClaim.(map[string]interface{}); ok { + raw, err := json.Marshal(userInfoRaw) + if err != nil { + return nil, err + } + if err = json.Unmarshal(raw, userInfo); err != nil { + return nil, fmt.Errorf("failed to unmarshal user info claim: %w", err) + } + } + } + + clientID := "" + if clientIDClaim, found := claims[ClientIDClaim]; found { + if s, ok := clientIDClaim.(string); ok { + clientID = s + } + } + + var scopes []string + if scopesClaim, found := claims[ScopeClaim]; found { + switch sct := scopesClaim.(type) { + case []interface{}: + scopes = interfaceSliceToStringSlice(sct) + case string: + scopes = []string{sct} + default: + return nil, fmt.Errorf("failed getting scope claims due to unknown type %T with value %v", sct, sct) + } + } + + // In some cases, "user_info" field doesn't exist in the raw claim, + // but we can get email from "email" field. + if emailClaim, found := claims["email"]; found { + if email, ok := emailClaim.(string); ok { + userInfo.Email = email + } + } + + // If this is a user-only access token with no scopes defined then add `all` scope by default because it's equivalent + // to having a user's login cookie or an ID Token as means of accessing the service. + if len(clientID) == 0 && len(scopes) == 0 { + scopes = []string{auth.ScopeAll} + } + + return auth.NewIdentityContext(matchedAudience, sub, clientID, issuedAt, scopes, userInfo, claims), nil +} + +func interfaceSliceToStringSlice(raw []interface{}) []string { + res := make([]string, 0, len(raw)) + for _, item := range raw { + if s, ok := item.(string); ok { + res = append(res, s) + } + } + return res +} diff --git a/runs/service/auth/authzserver/claims_verifier_test.go b/runs/service/auth/authzserver/claims_verifier_test.go new file mode 100644 index 00000000000..ea1f6e4fa04 --- /dev/null +++ b/runs/service/auth/authzserver/claims_verifier_test.go @@ -0,0 +1,114 @@ +package authzserver + +import ( + "testing" + "time" + + jwtgo "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestVerifyClaims_MatchingAudience(t *testing.T) { + claims := jwtgo.MapClaims{ + "aud": []interface{}{"https://flyte.example.com"}, + "sub": "user123", + "iat": float64(time.Now().Unix()), + } + + allowed := map[string]bool{"https://flyte.example.com": true} + identity, err := verifyClaims(allowed, claims) + require.NoError(t, err) + assert.Equal(t, "https://flyte.example.com", identity.Audience()) + assert.Equal(t, "user123", identity.UserID()) +} + +func TestVerifyClaims_NoMatchingAudience(t *testing.T) { + claims := jwtgo.MapClaims{ + "aud": []interface{}{"https://other.example.com"}, + "sub": "user123", + } + + allowed := map[string]bool{"https://flyte.example.com": true} + _, err := verifyClaims(allowed, claims) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid audience") +} + +func TestVerifyClaims_WithClientID(t *testing.T) { + claims := jwtgo.MapClaims{ + "aud": "https://flyte.example.com", + "sub": "user123", + "client_id": "my-client", + "scp": []interface{}{"read", "write"}, + } + + allowed := map[string]bool{"https://flyte.example.com": true} + identity, err := verifyClaims(allowed, claims) + require.NoError(t, err) + assert.Equal(t, "my-client", identity.AppID()) + assert.Equal(t, []string{"read", "write"}, identity.Scopes()) +} + +func TestVerifyClaims_UserOnlyDefaultsToScopeAll(t *testing.T) { + claims := jwtgo.MapClaims{ + "aud": "https://flyte.example.com", + "sub": "user123", + } + + allowed := map[string]bool{"https://flyte.example.com": true} + identity, err := verifyClaims(allowed, claims) + require.NoError(t, err) + assert.Equal(t, []string{"all"}, identity.Scopes()) + assert.Equal(t, "", identity.AppID()) +} + +func TestVerifyClaims_WithEmail(t *testing.T) { + claims := jwtgo.MapClaims{ + "aud": "https://flyte.example.com", + "sub": "user123", + "email": "user@example.com", + } + + allowed := map[string]bool{"https://flyte.example.com": true} + identity, err := verifyClaims(allowed, claims) + require.NoError(t, err) + assert.Equal(t, "user@example.com", identity.UserInfo().Email) +} + +func TestVerifyClaims_WithUserInfoClaim(t *testing.T) { + claims := jwtgo.MapClaims{ + "aud": "https://flyte.example.com", + "sub": "user123", + "user_info": map[string]interface{}{ + "name": "Test User", + "email": "test@example.com", + }, + } + + allowed := map[string]bool{"https://flyte.example.com": true} + identity, err := verifyClaims(allowed, claims) + require.NoError(t, err) + assert.Equal(t, "Test User", identity.UserInfo().Name) + assert.Equal(t, "test@example.com", identity.UserInfo().Email) +} + +func TestVerifyClaims_ScopeAsString(t *testing.T) { + claims := jwtgo.MapClaims{ + "aud": "https://flyte.example.com", + "sub": "user123", + "client_id": "my-client", + "scp": "read", + } + + allowed := map[string]bool{"https://flyte.example.com": true} + identity, err := verifyClaims(allowed, claims) + require.NoError(t, err) + assert.Equal(t, []string{"read"}, identity.Scopes()) +} + +func TestInterfaceSliceToStringSlice(t *testing.T) { + input := []interface{}{"a", "b", "c"} + result := interfaceSliceToStringSlice(input) + assert.Equal(t, []string{"a", "b", "c"}, result) +} diff --git a/runs/service/auth/authzserver/metadata_provider.go b/runs/service/auth/authzserver/metadata_provider.go new file mode 100644 index 00000000000..9ce7dc7c7a5 --- /dev/null +++ b/runs/service/auth/authzserver/metadata_provider.go @@ -0,0 +1,217 @@ +package authzserver + +import ( + "context" + "encoding/json" + "fmt" + "io" + "mime" + "net/http" + "net/url" + "strings" + "time" + + "connectrpc.com/connect" + + "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth" + "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth/authconnect" + "github.com/flyteorg/flyte/v2/flytestdlib/logger" + "github.com/flyteorg/flyte/v2/runs/service/auth/config" + authpkg "github.com/flyteorg/flyte/v2/runs/service/auth" +) + +const ( + oauth2MetadataEndpoint = ".well-known/oauth-authorization-server" +) + +var ( + tokenRelativeURL = mustParseURLPath("/oauth2/token") + authorizeRelativeURL = mustParseURLPath("/oauth2/authorize") + jsonWebKeysURL = mustParseURLPath("/oauth2/jwks") + oauth2MetadataRelURL = mustParseURLPath(oauth2MetadataEndpoint) + + supportedGrantTypes = []string{"client_credentials", "refresh_token", "authorization_code"} +) + +func mustParseURLPath(rawURL string) *url.URL { + u, err := url.Parse(rawURL) + if err != nil { + panic(err) + } + return u +} + +type authMetadataService struct { + authconnect.UnimplementedAuthMetadataServiceHandler + cfg config.Config +} + +func NewAuthMetadataService(cfg config.Config) authconnect.AuthMetadataServiceHandler { + return &authMetadataService{cfg: cfg} +} + +func (s *authMetadataService) GetOAuth2Metadata( + ctx context.Context, + _ *connect.Request[auth.GetOAuth2MetadataRequest], +) (*connect.Response[auth.GetOAuth2MetadataResponse], error) { + switch s.cfg.AppAuth.AuthServerType { + case config.AuthorizationServerTypeSelf: + return s.getOAuth2MetadataSelf(ctx) + default: + return s.getOAuth2MetadataExternal(ctx) + } +} + +func (s *authMetadataService) getOAuth2MetadataSelf(ctx context.Context) (*connect.Response[auth.GetOAuth2MetadataResponse], error) { + publicURL := authpkg.GetPublicURL(ctx, nil, s.cfg) + + resp := &auth.GetOAuth2MetadataResponse{ + Issuer: authpkg.GetIssuer(ctx, nil, s.cfg), + AuthorizationEndpoint: publicURL.ResolveReference(authorizeRelativeURL).String(), + TokenEndpoint: publicURL.ResolveReference(tokenRelativeURL).String(), + JwksUri: publicURL.ResolveReference(jsonWebKeysURL).String(), + CodeChallengeMethodsSupported: []string{"S256"}, + ResponseTypesSupported: []string{"code", "token", "code token"}, + GrantTypesSupported: supportedGrantTypes, + ScopesSupported: []string{authpkg.ScopeAll}, + TokenEndpointAuthMethodsSupported: []string{"client_secret_basic"}, + } + + return connect.NewResponse(resp), nil +} + +func (s *authMetadataService) getOAuth2MetadataExternal(ctx context.Context) (*connect.Response[auth.GetOAuth2MetadataResponse], error) { + baseURL := s.cfg.AppAuth.ExternalAuthServer.BaseURL + if baseURL.String() == "" { + return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("external auth server base URL is not configured")) + } + + // issuer urls, conventionally, do not end with a '/', however, metadata urls are usually relative of those. + // This adds a '/' to ensure ResolveReference behaves intuitively. + baseURL.Path = strings.TrimSuffix(baseURL.Path, "/") + "/" + + var externalMetadataURL *url.URL + if len(s.cfg.AppAuth.ExternalAuthServer.MetadataEndpointURL.String()) > 0 { + externalMetadataURL = baseURL.ResolveReference(&s.cfg.AppAuth.ExternalAuthServer.MetadataEndpointURL.URL) + } else { + externalMetadataURL = baseURL.ResolveReference(oauth2MetadataRelURL) + } + + httpClient := &http.Client{} + if s.cfg.HTTPProxyURL.String() != "" { + httpClient.Transport = &http.Transport{ + Proxy: http.ProxyURL(&s.cfg.HTTPProxyURL.URL), + } + } + + retryAttempts := s.cfg.AppAuth.ExternalAuthServer.RetryAttempts + if retryAttempts <= 0 { + retryAttempts = 5 + } + retryDelay := s.cfg.AppAuth.ExternalAuthServer.RetryDelay.Duration + if retryDelay <= 0 { + retryDelay = time.Second + } + + response, err := sendAndRetryHTTPRequest(ctx, httpClient, externalMetadataURL.String(), retryAttempts, retryDelay) + if err != nil { + return nil, connect.NewError(connect.CodeUnavailable, fmt.Errorf("failed to fetch OAuth2 metadata: %w", err)) + } + + raw, err := io.ReadAll(response.Body) + if err != nil { + return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("failed to read OAuth2 metadata response: %w", err)) + } + + resp := &auth.GetOAuth2MetadataResponse{} + if err := unmarshalResp(response, raw, resp); err != nil { + return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("failed to unmarshal OAuth2 metadata: %w", err)) + } + + tokenProxyConfig := s.cfg.TokenEndpointProxyConfig + if tokenProxyConfig.Enabled { + tokenEndpoint, parseErr := url.Parse(resp.TokenEndpoint) + if parseErr != nil { + return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("failed to parse token endpoint [%v], err: %v", resp.TokenEndpoint, parseErr)) + } + if len(tokenProxyConfig.PublicURL.Host) == 0 { + publicURL := authpkg.GetPublicURL(ctx, nil, s.cfg) + tokenProxyConfig.PublicURL = config.URL{URL: *publicURL} + } + tokenEndpoint.Host = tokenProxyConfig.PublicURL.Host + tokenEndpoint.Path = tokenProxyConfig.PathPrefix + tokenEndpoint.Path + tokenEndpoint.RawPath = tokenProxyConfig.PathPrefix + tokenEndpoint.RawPath + resp.TokenEndpoint = tokenEndpoint.String() + } + + return connect.NewResponse(resp), nil +} + +func (s *authMetadataService) GetPublicClientConfig( + _ context.Context, + _ *connect.Request[auth.GetPublicClientConfigRequest], +) (*connect.Response[auth.GetPublicClientConfigResponse], error) { + fc := s.cfg.AppAuth.ThirdParty.FlyteClientConfig + return connect.NewResponse(&auth.GetPublicClientConfigResponse{ + ClientId: fc.ClientID, + RedirectUri: fc.RedirectURI, + Scopes: fc.Scopes, + AuthorizationMetadataKey: s.cfg.GrpcAuthorizationHeader, + Audience: fc.Audience, + }), nil +} + +// unmarshalResp unmarshals a JSON response body, providing a detailed error if the Content-Type is unexpected. +func unmarshalResp(r *http.Response, body []byte, v interface{}) error { + err := json.Unmarshal(body, &v) + if err == nil { + return nil + } + ct := r.Header.Get("Content-Type") + mediaType, _, parseErr := mime.ParseMediaType(ct) + if parseErr == nil && mediaType == "application/json" { + return fmt.Errorf("got Content-Type = application/json, but could not unmarshal as JSON: %v", err) + } + return fmt.Errorf("expected Content-Type = application/json, got %q: %v", ct, err) +} + +// sendAndRetryHTTPRequest fetches the given URL with retry logic. +func sendAndRetryHTTPRequest(ctx context.Context, client *http.Client, targetURL string, retryAttempts int, retryDelay time.Duration) (*http.Response, error) { + var lastErr error + var lastResp *http.Response + for i := 0; i < retryAttempts; i++ { + if i > 0 { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(retryDelay): + } + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, targetURL, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + resp, err := client.Do(req) + if err != nil { + lastErr = err + logger.Warnf(ctx, "Failed to fetch %s (attempt %d/%d): %v", targetURL, i+1, retryAttempts, err) + continue + } + + if resp.StatusCode >= http.StatusOK && resp.StatusCode < http.StatusMultipleChoices { + return resp, nil + } + + lastErr = fmt.Errorf("unexpected status code %d from %s", resp.StatusCode, targetURL) + lastResp = resp + logger.Warnf(ctx, "Unexpected status code from %s (attempt %d/%d): %d", targetURL, i+1, retryAttempts, resp.StatusCode) + } + + if lastResp != nil && lastResp.StatusCode != http.StatusOK { + return lastResp, fmt.Errorf("failed to get oauth metadata with status code %v: %w", lastResp.StatusCode, lastErr) + } + + return nil, fmt.Errorf("all %d attempts failed for %s: %w", retryAttempts, targetURL, lastErr) +} diff --git a/runs/service/auth/authzserver/metadata_provider_test.go b/runs/service/auth/authzserver/metadata_provider_test.go new file mode 100644 index 00000000000..e50e43dcbbc --- /dev/null +++ b/runs/service/auth/authzserver/metadata_provider_test.go @@ -0,0 +1,348 @@ +package authzserver + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + "connectrpc.com/connect" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth" + "github.com/flyteorg/flyte/v2/runs/service/auth/config" +) + +func mustParseTestURL(rawURL string) config.URL { + u, err := url.Parse(rawURL) + if err != nil { + panic(err) + } + return config.URL{URL: *u} +} + +func TestGetPublicClientConfig(t *testing.T) { + cfg := config.Config{ + GrpcAuthorizationHeader: "flyte-authorization", + AppAuth: config.OAuth2Options{ + ThirdParty: config.ThirdPartyConfigOptions{ + FlyteClientConfig: config.FlyteClientConfig{ + ClientID: "flyte-client", + RedirectURI: "http://localhost:12345/callback", + Scopes: []string{"openid", "offline"}, + Audience: "https://flyte.example.com", + }, + }, + }, + } + + svc := NewAuthMetadataService(cfg) + resp, err := svc.GetPublicClientConfig(context.Background(), connect.NewRequest(&auth.GetPublicClientConfigRequest{})) + require.NoError(t, err) + + msg := resp.Msg + assert.Equal(t, "flyte-client", msg.ClientId) + assert.Equal(t, "http://localhost:12345/callback", msg.RedirectUri) + assert.Equal(t, []string{"openid", "offline"}, msg.Scopes) + assert.Equal(t, "flyte-authorization", msg.AuthorizationMetadataKey) + assert.Equal(t, "https://flyte.example.com", msg.Audience) +} + +func TestGetOAuth2Metadata_SelfAuthServer(t *testing.T) { + cfg := config.Config{ + AuthorizedURIs: []config.URL{ + mustParseTestURL("https://flyte.example.com"), + }, + AppAuth: config.OAuth2Options{ + AuthServerType: config.AuthorizationServerTypeSelf, + }, + } + + svc := NewAuthMetadataService(cfg) + resp, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) + require.NoError(t, err) + + msg := resp.Msg + assert.Equal(t, "https://flyte.example.com", msg.Issuer) + assert.Equal(t, "https://flyte.example.com/oauth2/authorize", msg.AuthorizationEndpoint) + assert.Equal(t, "https://flyte.example.com/oauth2/token", msg.TokenEndpoint) + assert.Equal(t, "https://flyte.example.com/oauth2/jwks", msg.JwksUri) + assert.Equal(t, []string{"S256"}, msg.CodeChallengeMethodsSupported) + assert.Equal(t, []string{"code", "token", "code token"}, msg.ResponseTypesSupported) + assert.Equal(t, []string{"client_credentials", "refresh_token", "authorization_code"}, msg.GrantTypesSupported) + assert.Equal(t, []string{"all"}, msg.ScopesSupported) + assert.Equal(t, []string{"client_secret_basic"}, msg.TokenEndpointAuthMethodsSupported) +} + +func TestGetOAuth2Metadata_SelfAuthServerWithCustomIssuer(t *testing.T) { + cfg := config.Config{ + AuthorizedURIs: []config.URL{ + mustParseTestURL("https://flyte.example.com"), + }, + AppAuth: config.OAuth2Options{ + AuthServerType: config.AuthorizationServerTypeSelf, + SelfAuthServer: config.AuthorizationServer{ + Issuer: "https://custom-issuer.example.com", + }, + }, + } + + svc := NewAuthMetadataService(cfg) + resp, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) + require.NoError(t, err) + + msg := resp.Msg + assert.Equal(t, "https://custom-issuer.example.com", msg.Issuer) + assert.Equal(t, "https://flyte.example.com/oauth2/authorize", msg.AuthorizationEndpoint) +} + +func TestGetOAuth2Metadata_SelfAuthServerDefaultAuthorizedURI(t *testing.T) { + cfg := config.Config{ + AuthorizedURIs: []config.URL{ + mustParseTestURL("http://localhost:8090"), + }, + AppAuth: config.OAuth2Options{ + AuthServerType: config.AuthorizationServerTypeSelf, + }, + } + + svc := NewAuthMetadataService(cfg) + resp, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) + require.NoError(t, err) + + msg := resp.Msg + assert.Equal(t, "http://localhost:8090", msg.Issuer) + assert.Equal(t, "http://localhost:8090/oauth2/token", msg.TokenEndpoint) +} + +func TestGetOAuth2Metadata_ExternalAuthServer(t *testing.T) { + expectedMetadata := &auth.GetOAuth2MetadataResponse{ + Issuer: "https://external-idp.example.com", + AuthorizationEndpoint: "https://external-idp.example.com/authorize", + TokenEndpoint: "https://external-idp.example.com/token", + JwksUri: "https://external-idp.example.com/.well-known/jwks.json", + } + + metadataJSON, err := json.Marshal(expectedMetadata) + require.NoError(t, err) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write(metadataJSON) + })) + defer ts.Close() + + cfg := config.Config{ + AppAuth: config.OAuth2Options{ + AuthServerType: config.AuthorizationServerTypeExternal, + ExternalAuthServer: config.ExternalAuthorizationServer{ + BaseURL: mustParseTestURL(ts.URL), + RetryAttempts: 1, + RetryDelay: config.Duration{Duration: 100 * time.Millisecond}, + }, + }, + } + + svc := NewAuthMetadataService(cfg) + resp, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) + require.NoError(t, err) + + msg := resp.Msg + assert.Equal(t, "https://external-idp.example.com", msg.Issuer) + assert.Equal(t, "https://external-idp.example.com/authorize", msg.AuthorizationEndpoint) + assert.Equal(t, "https://external-idp.example.com/token", msg.TokenEndpoint) + assert.Equal(t, "https://external-idp.example.com/.well-known/jwks.json", msg.JwksUri) +} + +func TestGetOAuth2Metadata_ExternalWithCustomMetadataURL(t *testing.T) { + expectedMetadata := &auth.GetOAuth2MetadataResponse{ + Issuer: "https://external-idp.example.com", + TokenEndpoint: "https://external-idp.example.com/token", + } + + metadataJSON, err := json.Marshal(expectedMetadata) + require.NoError(t, err) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/custom/metadata", r.URL.Path) + w.Header().Set("Content-Type", "application/json") + w.Write(metadataJSON) + })) + defer ts.Close() + + cfg := config.Config{ + AppAuth: config.OAuth2Options{ + AuthServerType: config.AuthorizationServerTypeExternal, + ExternalAuthServer: config.ExternalAuthorizationServer{ + BaseURL: mustParseTestURL(ts.URL), + MetadataEndpointURL: mustParseTestURL(ts.URL + "/custom/metadata"), + RetryAttempts: 1, + RetryDelay: config.Duration{Duration: 100 * time.Millisecond}, + }, + }, + } + + svc := NewAuthMetadataService(cfg) + resp, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) + require.NoError(t, err) + + assert.Equal(t, "https://external-idp.example.com", resp.Msg.Issuer) + assert.Equal(t, "https://external-idp.example.com/token", resp.Msg.TokenEndpoint) +} + +func TestGetOAuth2Metadata_ExternalWithTokenProxy(t *testing.T) { + expectedMetadata := &auth.GetOAuth2MetadataResponse{ + Issuer: "https://external-idp.example.com", + AuthorizationEndpoint: "https://external-idp.example.com/authorize", + TokenEndpoint: "https://external-idp.example.com/oauth/token", + } + + metadataJSON, err := json.Marshal(expectedMetadata) + require.NoError(t, err) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write(metadataJSON) + })) + defer ts.Close() + + cfg := config.Config{ + AuthorizedURIs: []config.URL{ + mustParseTestURL("https://flyte.example.com"), + }, + AppAuth: config.OAuth2Options{ + AuthServerType: config.AuthorizationServerTypeExternal, + ExternalAuthServer: config.ExternalAuthorizationServer{ + BaseURL: mustParseTestURL(ts.URL), + RetryAttempts: 1, + RetryDelay: config.Duration{Duration: 100 * time.Millisecond}, + }, + }, + TokenEndpointProxyConfig: config.TokenEndpointProxyConfig{ + Enabled: true, + }, + } + + svc := NewAuthMetadataService(cfg) + resp, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) + require.NoError(t, err) + + msg := resp.Msg + assert.Equal(t, "https://external-idp.example.com", msg.Issuer) + assert.Equal(t, "https://external-idp.example.com/authorize", msg.AuthorizationEndpoint) + // Token endpoint should be rewritten to the public URL + assert.Equal(t, "https://flyte.example.com/oauth/token", msg.TokenEndpoint) +} + +func TestGetOAuth2Metadata_ExternalWithTokenProxyAndPathPrefix(t *testing.T) { + expectedMetadata := &auth.GetOAuth2MetadataResponse{ + TokenEndpoint: "https://external-idp.example.com/oauth/token", + } + + metadataJSON, err := json.Marshal(expectedMetadata) + require.NoError(t, err) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write(metadataJSON) + })) + defer ts.Close() + + cfg := config.Config{ + AppAuth: config.OAuth2Options{ + AuthServerType: config.AuthorizationServerTypeExternal, + ExternalAuthServer: config.ExternalAuthorizationServer{ + BaseURL: mustParseTestURL(ts.URL), + RetryAttempts: 1, + RetryDelay: config.Duration{Duration: 100 * time.Millisecond}, + }, + }, + TokenEndpointProxyConfig: config.TokenEndpointProxyConfig{ + Enabled: true, + PublicURL: mustParseTestURL("https://proxy.example.com"), + PathPrefix: "api/v1", + }, + } + + svc := NewAuthMetadataService(cfg) + resp, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) + require.NoError(t, err) + + assert.Equal(t, "https://proxy.example.com/api/v1/oauth/token", resp.Msg.TokenEndpoint) +} + +func TestGetOAuth2Metadata_ExternalNoBaseURL(t *testing.T) { + cfg := config.Config{ + AppAuth: config.OAuth2Options{ + AuthServerType: config.AuthorizationServerTypeExternal, + }, + } + + svc := NewAuthMetadataService(cfg) + _, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) + require.Error(t, err) + assert.Contains(t, err.Error(), "external auth server base URL is not configured") +} + +func TestSendAndRetryHTTPRequest_ImmediateSuccess(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status":"ok"}`)) + })) + defer ts.Close() + + resp, err := sendAndRetryHTTPRequest(context.Background(), http.DefaultClient, ts.URL, 3, 10*time.Millisecond) + require.NoError(t, err) + require.NotNil(t, resp) + assert.Equal(t, http.StatusOK, resp.StatusCode) +} + +func TestSendAndRetryHTTPRequest_RetryIntoSuccess(t *testing.T) { + attempt := 0 + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempt++ + if attempt < 3 { + w.WriteHeader(http.StatusServiceUnavailable) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status":"ok"}`)) + })) + defer ts.Close() + + resp, err := sendAndRetryHTTPRequest(context.Background(), http.DefaultClient, ts.URL, 5, 10*time.Millisecond) + require.NoError(t, err) + require.NotNil(t, resp) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, 3, attempt) +} + +func TestSendAndRetryHTTPRequest_AllRetrysFail(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer ts.Close() + + _, err := sendAndRetryHTTPRequest(context.Background(), http.DefaultClient, ts.URL, 3, 10*time.Millisecond) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to get oauth metadata") +} + +func TestSendAndRetryHTTPRequest_ContextCancelled(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer ts.Close() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := sendAndRetryHTTPRequest(ctx, http.DefaultClient, ts.URL, 5, 10*time.Millisecond) + require.Error(t, err) +} diff --git a/runs/service/auth/authzserver/resource_server.go b/runs/service/auth/authzserver/resource_server.go new file mode 100644 index 00000000000..39429aba5d3 --- /dev/null +++ b/runs/service/auth/authzserver/resource_server.go @@ -0,0 +1,109 @@ +package authzserver + +import ( + "context" + "fmt" + "io" + "net/http" + "net/url" + "strings" + + "github.com/coreos/go-oidc/v3/oidc" + jwtgo "github.com/golang-jwt/jwt/v5" + + "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth" + "github.com/flyteorg/flyte/v2/runs/service/auth/config" + authpkg "github.com/flyteorg/flyte/v2/runs/service/auth" +) + +// ResourceServer authorizes access requests issued by an external Authorization Server. +type ResourceServer struct { + signatureVerifier oidc.KeySet + allowedAudience []string +} + +// NewOAuth2ResourceServer initializes a new OAuth2ResourceServer. +func NewOAuth2ResourceServer(ctx context.Context, cfg config.ExternalAuthorizationServer, fallbackBaseURL config.URL) (*ResourceServer, error) { + u := cfg.BaseURL + if len(u.String()) == 0 { + u = fallbackBaseURL + } + + verifier, err := getJwksForIssuer(ctx, u.URL, cfg) + if err != nil { + return nil, err + } + + return &ResourceServer{ + signatureVerifier: verifier, + allowedAudience: cfg.AllowedAudience, + }, nil +} + +// ValidateAccessToken verifies the token signature, validates claims, and returns the identity context. +func (r *ResourceServer) ValidateAccessToken(ctx context.Context, expectedAudience, tokenStr string) (*authpkg.IdentityContext, error) { + _, err := r.signatureVerifier.VerifySignature(ctx, tokenStr) + if err != nil { + return nil, fmt.Errorf("failed to verify token signature: %w", err) + } + + claims := jwtgo.MapClaims{} + parser := jwtgo.NewParser() + if _, _, err = parser.ParseUnverified(tokenStr, claims); err != nil { + return nil, fmt.Errorf("failed to parse token: %w", err) + } + + allowed := make(map[string]bool, len(r.allowedAudience)+1) + for _, a := range r.allowedAudience { + allowed[a] = true + } + allowed[expectedAudience] = true + + return verifyClaims(allowed, claims) +} + +// getJwksForIssuer fetches the OAuth2 metadata from the external auth server and returns the remote JWKS key set. +func getJwksForIssuer(ctx context.Context, issuerBaseURL url.URL, cfg config.ExternalAuthorizationServer) (oidc.KeySet, error) { + issuerBaseURL.Path = strings.TrimSuffix(issuerBaseURL.Path, "/") + "/" + + var wellKnown *url.URL + if len(cfg.MetadataEndpointURL.String()) > 0 { + wellKnown = issuerBaseURL.ResolveReference(&cfg.MetadataEndpointURL.URL) + } else { + wellKnown = issuerBaseURL.ResolveReference(oauth2MetadataRelURL) + } + + httpClient := &http.Client{} + if len(cfg.HTTPProxyURL.String()) > 0 { + httpClient.Transport = &http.Transport{ + Proxy: http.ProxyURL(&cfg.HTTPProxyURL.URL), + } + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, wellKnown.String(), nil) + if err != nil { + return nil, err + } + + resp, err := httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("unable to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("%s: %s", resp.Status, body) + } + + p := &auth.GetOAuth2MetadataResponse{} + if err = unmarshalResp(resp, body, p); err != nil { + return nil, fmt.Errorf("failed to decode provider discovery object: %w", err) + } + + return oidc.NewRemoteKeySet(oidc.ClientContext(ctx, httpClient), p.JwksUri), nil +} diff --git a/runs/service/auth/authzserver/resource_server_test.go b/runs/service/auth/authzserver/resource_server_test.go new file mode 100644 index 00000000000..1fbbb43f251 --- /dev/null +++ b/runs/service/auth/authzserver/resource_server_test.go @@ -0,0 +1,101 @@ +package authzserver + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth" + "github.com/flyteorg/flyte/v2/runs/service/auth/config" +) + +func TestGetJwksForIssuer_Success(t *testing.T) { + metadata := &auth.GetOAuth2MetadataResponse{ + JwksUri: "https://example.com/.well-known/jwks.json", + } + metadataJSON, err := json.Marshal(metadata) + require.NoError(t, err) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write(metadataJSON) + })) + defer ts.Close() + + cfg := config.ExternalAuthorizationServer{ + BaseURL: mustParseTestURL(ts.URL), + } + + keySet, err := getJwksForIssuer(context.Background(), cfg.BaseURL.URL, cfg) + require.NoError(t, err) + assert.NotNil(t, keySet) +} + +func TestGetJwksForIssuer_CustomMetadataURL(t *testing.T) { + metadata := &auth.GetOAuth2MetadataResponse{ + JwksUri: "https://example.com/.well-known/jwks.json", + } + metadataJSON, err := json.Marshal(metadata) + require.NoError(t, err) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/custom/metadata", r.URL.Path) + w.Header().Set("Content-Type", "application/json") + w.Write(metadataJSON) + })) + defer ts.Close() + + cfg := config.ExternalAuthorizationServer{ + BaseURL: mustParseTestURL(ts.URL), + MetadataEndpointURL: mustParseTestURL(ts.URL + "/custom/metadata"), + } + + keySet, err := getJwksForIssuer(context.Background(), cfg.BaseURL.URL, cfg) + require.NoError(t, err) + assert.NotNil(t, keySet) +} + +func TestGetJwksForIssuer_ServerError(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("server error")) + })) + defer ts.Close() + + cfg := config.ExternalAuthorizationServer{ + BaseURL: mustParseTestURL(ts.URL), + } + + _, err := getJwksForIssuer(context.Background(), cfg.BaseURL.URL, cfg) + require.Error(t, err) + assert.Contains(t, err.Error(), "500") +} + +func TestNewOAuth2ResourceServer_FallbackBaseURL(t *testing.T) { + metadata := &auth.GetOAuth2MetadataResponse{ + JwksUri: "https://example.com/.well-known/jwks.json", + } + metadataJSON, err := json.Marshal(metadata) + require.NoError(t, err) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write(metadataJSON) + })) + defer ts.Close() + + cfg := config.ExternalAuthorizationServer{ + AllowedAudience: []string{"https://flyte.example.com"}, + } + fallback := mustParseTestURL(ts.URL) + + rs, err := NewOAuth2ResourceServer(context.Background(), cfg, config.URL(fallback)) + require.NoError(t, err) + assert.NotNil(t, rs) + assert.Equal(t, []string{"https://flyte.example.com"}, rs.allowedAudience) +} diff --git a/runs/service/auth/config/config.go b/runs/service/auth/config/config.go new file mode 100644 index 00000000000..ca426146665 --- /dev/null +++ b/runs/service/auth/config/config.go @@ -0,0 +1,298 @@ +package config + +import ( + "net/url" + + "github.com/flyteorg/flyte/v2/flytestdlib/config" +) + +//go:generate pflags Config --default-var=DefaultConfig + +type SecretName = string + +const ( + // SecretNameOIdCClientSecret defines the default OIdC client secret name to use. + // #nosec + SecretNameOIdCClientSecret SecretName = "oidc_client_secret" + + // SecretNameCookieHashKey defines the default cookie hash key secret name to use. + // #nosec + SecretNameCookieHashKey SecretName = "cookie_hash_key" + + // SecretNameCookieBlockKey defines the default cookie block key secret name to use. + // #nosec + SecretNameCookieBlockKey SecretName = "cookie_block_key" + + // SecretNameClaimSymmetricKey must be a base64 encoded secret of exactly 32 bytes. + // #nosec + SecretNameClaimSymmetricKey SecretName = "claim_symmetric_key" + + // SecretNameTokenSigningRSAKey is the private key used to sign JWT tokens (RS256). + // #nosec + SecretNameTokenSigningRSAKey SecretName = "token_rsa_key.pem" + + // SecretNameOldTokenSigningRSAKey is the old private key for key rotation. Only used to + // validate incoming tokens; new tokens will not be issued with this key. + // #nosec + SecretNameOldTokenSigningRSAKey SecretName = "token_rsa_key_old.pem" +) + +// AuthorizationServerType defines the type of Authorization Server to use. +type AuthorizationServerType int + +const ( + // AuthorizationServerTypeSelf indicates the service acts as its own authorization server. + AuthorizationServerTypeSelf AuthorizationServerType = iota + // AuthorizationServerTypeExternal indicates an external authorization server is used. + AuthorizationServerTypeExternal +) + +// SameSite represents the SameSite cookie policy. +type SameSite int + +const ( + SameSiteDefaultMode SameSite = iota + SameSiteLaxMode + SameSiteStrictMode + SameSiteNoneMode +) + +var ( + DefaultConfig = &Config{ + HTTPAuthorizationHeader: "flyte-authorization", + GrpcAuthorizationHeader: "flyte-authorization", + UserAuth: UserAuthConfig{ + RedirectURL: config.URL{URL: *MustParseURL("/console")}, + CookieHashKeySecretName: SecretNameCookieHashKey, + CookieBlockKeySecretName: SecretNameCookieBlockKey, + OpenID: OpenIDOptions{ + ClientSecretName: SecretNameOIdCClientSecret, + Scopes: []string{ + "openid", + "profile", + }, + }, + CookieSetting: CookieSettings{ + Domain: "", + SameSitePolicy: SameSiteDefaultMode, + }, + }, + AppAuth: OAuth2Options{ + ExternalAuthServer: ExternalAuthorizationServer{ + RetryAttempts: 5, + RetryDelay: config.Duration{Duration: 1_000_000_000}, // 1 second + }, + AuthServerType: AuthorizationServerTypeSelf, + ThirdParty: ThirdPartyConfigOptions{ + FlyteClientConfig: FlyteClientConfig{ + ClientID: "flytectl", + RedirectURI: "http://localhost:53593/callback", + Scopes: []string{"all", "offline"}, + }, + }, + }, + } + + cfgSection = config.MustRegisterSection("auth", DefaultConfig) +) + +// Config holds the full authentication configuration. +type Config struct { + // HTTPAuthorizationHeader is the HTTP header name for authorization (for non-standard headers behind Envoy). + HTTPAuthorizationHeader string `json:"httpAuthorizationHeader"` + + // GrpcAuthorizationHeader is the gRPC metadata key for authorization. + GrpcAuthorizationHeader string `json:"grpcAuthorizationHeader"` + + // DisableForHTTP disables auth enforcement on HTTP endpoints. + DisableForHTTP bool `json:"disableForHttp" pflag:",Disables auth enforcement on HTTP Endpoints."` + + // DisableForGrpc disables auth enforcement on gRPC endpoints. + DisableForGrpc bool `json:"disableForGrpc" pflag:",Disables auth enforcement on Grpc Endpoints."` + + // AuthorizedURIs defines the set of URIs that clients are allowed to visit the service on. + AuthorizedURIs []config.URL `json:"authorizedUris" pflag:"-,Defines the set of URIs that clients are allowed to visit the service on."` + + // HTTPProxyURL allows accessing external OAuth2 servers through an HTTP proxy. + HTTPProxyURL config.URL `json:"httpProxyURL" pflag:",OPTIONAL: HTTP Proxy to be used for OAuth requests."` + + // UserAuth settings used to authenticate end users in web-browsers. + UserAuth UserAuthConfig `json:"userAuth" pflag:",Defines Auth options for users."` + + // AppAuth defines app-level OAuth2 settings. + AppAuth OAuth2Options `json:"appAuth" pflag:",Defines Auth options for apps. UserAuth must be enabled for AppAuth to work."` + + // SecureCookie sets the Secure flag on auth cookies. Should be true in production (HTTPS). + SecureCookie bool `json:"secureCookie" pflag:",Set the Secure flag on auth cookies"` + + // TokenEndpointProxyConfig proxies token endpoint calls through admin. + TokenEndpointProxyConfig TokenEndpointProxyConfig `json:"tokenEndpointProxyConfig" pflag:",Configuration for proxying token endpoint requests."` +} + +// OAuth2Options holds OAuth2 authorization server options. +type OAuth2Options struct { + // AuthServerType determines whether to use a self-hosted or external auth server. + AuthServerType AuthorizationServerType `json:"authServerType"` + + // SelfAuthServer configures the self-hosted authorization server. + SelfAuthServer AuthorizationServer `json:"selfAuthServer" pflag:",Authorization Server config to run as a service."` + + // ExternalAuthServer configures the external authorization server. + ExternalAuthServer ExternalAuthorizationServer `json:"externalAuthServer" pflag:",External Authorization Server config."` + + // ThirdParty configures third-party (public client) settings. + ThirdParty ThirdPartyConfigOptions `json:"thirdPartyConfig" pflag:",Defines settings to instruct flyte cli tools on what config to use."` +} + +// AuthorizationServer configures a self-hosted authorization server. +type AuthorizationServer struct { + // Issuer is the issuer URL. If empty, the first AuthorizedURI is used. + Issuer string `json:"issuer" pflag:",Defines the issuer to use when issuing and validating tokens."` + + // AccessTokenLifespan defines the lifespan of issued access tokens. + AccessTokenLifespan config.Duration `json:"accessTokenLifespan" pflag:",Defines the lifespan of issued access tokens."` + + // RefreshTokenLifespan defines the lifespan of issued refresh tokens. + RefreshTokenLifespan config.Duration `json:"refreshTokenLifespan" pflag:",Defines the lifespan of issued refresh tokens."` + + // AuthorizationCodeLifespan defines the lifespan of issued authorization codes. + AuthorizationCodeLifespan config.Duration `json:"authorizationCodeLifespan" pflag:",Defines the lifespan of issued authorization codes."` + + // ClaimSymmetricEncryptionKeySecretName is the secret name for claim encryption. + ClaimSymmetricEncryptionKeySecretName string `json:"claimSymmetricEncryptionKeySecretName" pflag:",Secret name for claim encryption key."` + + // TokenSigningRSAKeySecretName is the secret name for the RSA signing key. + TokenSigningRSAKeySecretName string `json:"tokenSigningRSAKeySecretName" pflag:",Secret name for RSA Signing Key."` + + // OldTokenSigningRSAKeySecretName is the secret name for the old RSA signing key (key rotation). + OldTokenSigningRSAKeySecretName string `json:"oldTokenSigningRSAKeySecretName" pflag:",Secret name for Old RSA Signing Key for key rotation."` +} + +// ExternalAuthorizationServer configures an external authorization server. +type ExternalAuthorizationServer struct { + // BaseURL is the base URL of the external authorization server. + BaseURL config.URL `json:"baseUrl" pflag:",Base url of the external authorization server."` + + // AllowedAudience is the set of audiences accepted when validating access tokens. + AllowedAudience []string `json:"allowedAudience" pflag:",A list of allowed audiences."` + + // MetadataEndpointURL overrides the default .well-known/oauth-authorization-server endpoint. + MetadataEndpointURL config.URL `json:"metadataUrl" pflag:",Custom metadata url if the server doesn't support the standard endpoint."` + + // HTTPProxyURL allows accessing the external auth server through an HTTP proxy. + HTTPProxyURL config.URL `json:"httpProxyURL" pflag:",HTTP Proxy for external OAuth requests."` + + // RetryAttempts is the number of retry attempts for fetching metadata. + RetryAttempts int `json:"retryAttempts" pflag:",Number of retry attempts for metadata fetch."` + + // RetryDelay is the delay between retry attempts. + RetryDelay config.Duration `json:"retryDelay" pflag:",Duration to wait between retries."` +} + +// ThirdPartyConfigOptions holds third-party OAuth2 client settings. +type ThirdPartyConfigOptions struct { + // FlyteClientConfig holds public client configuration. + FlyteClientConfig FlyteClientConfig `json:"flyteClient"` +} + +// IsEmpty returns true if the third-party config has no meaningful values set. +func (o ThirdPartyConfigOptions) IsEmpty() bool { + return len(o.FlyteClientConfig.ClientID) == 0 && + len(o.FlyteClientConfig.RedirectURI) == 0 && + len(o.FlyteClientConfig.Scopes) == 0 +} + +// FlyteClientConfig holds the public client configuration. +type FlyteClientConfig struct { + // ClientID is the public client ID. + ClientID string `json:"clientId" pflag:",Public identifier for the app which handles authorization."` + + // RedirectURI is the redirect URI for the client. + RedirectURI string `json:"redirectUri" pflag:",Callback uri registered with the app which handles authorization."` + + // Scopes are the OAuth2 scopes to request. + Scopes []string `json:"scopes" pflag:",Recommended scopes for the client to request."` + + // Audience is the intended audience for OAuth2 tokens. + Audience string `json:"audience" pflag:",Audience to use when initiating OAuth2 authorization requests."` +} + +// UserAuthConfig holds user authentication settings (browser-based OAuth2/OIDC flows). +type UserAuthConfig struct { + // RedirectURL is the default redirect URL after the OAuth2 flow completes. + RedirectURL config.URL `json:"redirectUrl"` + + // OpenID defines settings for connecting and trusting an OpenID Connect provider. + OpenID OpenIDOptions `json:"openId" pflag:",OpenID Configuration for User Auth"` + + // HTTPProxyURL allows operators to access external OAuth2 servers using an HTTP Proxy. + HTTPProxyURL config.URL `json:"httpProxyURL" pflag:",HTTP Proxy for OAuth requests."` + + // CookieHashKeySecretName is the secret name for the cookie hash key. + CookieHashKeySecretName string `json:"cookieHashKeySecretName" pflag:",Secret name for cookie hash key."` + + // CookieBlockKeySecretName is the secret name for the cookie block key. + CookieBlockKeySecretName string `json:"cookieBlockKeySecretName" pflag:",Secret name for cookie block key."` + + // CookieSetting configures cookie behavior. + CookieSetting CookieSettings `json:"cookieSetting" pflag:",Settings for auth cookies."` + + // IDPQueryParameter is used to select a particular IDP for user authentication. + IDPQueryParameter string `json:"idpQueryParameter" pflag:",IDP query parameter for selecting a particular IDP."` +} + +// OpenIDOptions holds OpenID Connect provider configuration. +type OpenIDOptions struct { + // ClientID is the client ID for this service in the IDP. + ClientID string `json:"clientId"` + + // ClientSecretName is the secret name containing the OIDC client secret. + ClientSecretName string `json:"clientSecretName"` + + // BaseURL is the base URL of the OIDC provider. + BaseURL config.URL `json:"baseUrl"` + + // Scopes to request from the IDP when authenticating. + Scopes []string `json:"scopes"` +} + +// CookieSettings configures cookie behavior. +type CookieSettings struct { + // SameSitePolicy controls the SameSite attribute on auth cookies. + SameSitePolicy SameSite `json:"sameSitePolicy" pflag:",SameSite policy for auth cookies."` + + // Domain sets the domain attribute on auth cookies. + Domain string `json:"domain" pflag:",Domain attribute on auth cookies."` +} + +// TokenEndpointProxyConfig configures proxying of token endpoint calls. +type TokenEndpointProxyConfig struct { + // Enabled enables token endpoint proxying. + Enabled bool `json:"enabled" pflag:",Enables the token endpoint proxy."` + + // PublicURL is the public URL to use for rewriting the token endpoint. + PublicURL config.URL `json:"publicUrl" pflag:",Public URL for the token endpoint proxy."` + + // PathPrefix is appended to the public URL when rewriting. + PathPrefix string `json:"pathPrefix" pflag:",Path prefix for proxying token requests."` +} + +// URL is an alias for flytestdlib config.URL, re-exported for convenience. +type URL = config.URL + +// Duration is an alias for flytestdlib config.Duration, re-exported for convenience. +type Duration = config.Duration + +// GetConfig returns the parsed auth configuration. +func GetConfig() *Config { + return cfgSection.GetConfig().(*Config) +} + +// MustParseURL panics if the provided url fails parsing. Should only be used in package initialization or tests. +func MustParseURL(rawURL string) *url.URL { + res, err := url.Parse(rawURL) + if err != nil { + panic(err) + } + return res +} diff --git a/runs/service/auth/constants.go b/runs/service/auth/constants.go new file mode 100644 index 00000000000..295ac8e4655 --- /dev/null +++ b/runs/service/auth/constants.go @@ -0,0 +1,22 @@ +package auth + +const ( + // OAuth2 Parameters + CsrfFormKey = "state" + AuthorizationResponseCodeType = "code" + DefaultAuthorizationHeader = "authorization" + BearerScheme = "Bearer" + IDTokenScheme = "IDToken" + // Add the -bin suffix so that the header value is automatically base64 encoded + UserInfoMDKey = "UserInfo-bin" + + // https://tools.ietf.org/html/rfc8414 + // This should be defined without a leading slash. If there is one, the url library's ResolveReference will make it a root path + OAuth2MetadataEndpoint = ".well-known/oauth-authorization-server" + + // https://openid.net/specs/openid-connect-discovery-1_0.html + // This should be defined without a leading slash. If there is one, the url library's ResolveReference will make it a root path + OIdCMetadataEndpoint = ".well-known/openid-configuration" + + RedirectURLParameter = "redirect_url" +) diff --git a/runs/service/auth/cookie.go b/runs/service/auth/cookie.go new file mode 100644 index 00000000000..b8bb54f6355 --- /dev/null +++ b/runs/service/auth/cookie.go @@ -0,0 +1,190 @@ +package auth + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "math/rand" + "net/http" + "net/url" + + "github.com/gorilla/securecookie" + + "github.com/flyteorg/flyte/v2/flytestdlib/logger" + "github.com/flyteorg/flyte/v2/runs/service/auth/config" +) + +const ( + // #nosec + accessTokenCookieName = "flyte_at" + // #nosec + accessTokenCookieNameSplitFirst = "flyte_at_1" + // #nosec + accessTokenCookieNameSplitSecond = "flyte_at_2" + // #nosec + idTokenCookieName = "flyte_idt" + // #nosec + refreshTokenCookieName = "flyte_rt" + // #nosec + csrfStateCookieName = "flyte_csrf_state" + // #nosec + redirectURLCookieName = "flyte_redirect_location" + + // #nosec + idTokenExtra = "id_token" + + // #nosec + authCodeCookieName = "flyte_auth_code" + + // #nosec + userInfoCookieName = "flyte_user_info" +) + +var AllowedChars = []rune("abcdefghijklmnopqrstuvwxyz1234567890") + +func HashCsrfState(csrf string) string { + shaBytes := sha256.Sum256([]byte(csrf)) + hash := hex.EncodeToString(shaBytes[:]) + return hash +} + +func NewSecureCookie(cookieName, value string, hashKey, blockKey []byte, domain string, sameSiteMode http.SameSite) (http.Cookie, error) { + s := securecookie.New(hashKey, blockKey) + encoded, err := s.Encode(cookieName, value) + if err != nil { + return http.Cookie{}, fmt.Errorf("error creating secure cookie: %w", err) + } + + return http.Cookie{ + Name: cookieName, + Value: encoded, + Domain: domain, + SameSite: sameSiteMode, + HttpOnly: true, + Secure: config.GetConfig().SecureCookie, + Path: "/", + }, nil +} + +func retrieveSecureCookie(ctx context.Context, request *http.Request, cookieName string, hashKey, blockKey []byte) (string, error) { + cookie, err := request.Cookie(cookieName) + if err != nil { + logger.Infof(ctx, "Could not detect existing cookie [%v]. Error: %v", cookieName, err) + return "", fmt.Errorf("failure to retrieve cookie [%v]: %w", cookieName, err) + } + + if cookie == nil { + logger.Infof(ctx, "Retrieved empty cookie [%v].", cookieName) + return "", fmt.Errorf("retrieved empty cookie [%v]", cookieName) + } + + logger.Debugf(ctx, "Existing [%v] cookie found", cookieName) + token, err := ReadSecureCookie(ctx, *cookie, hashKey, blockKey) + if err != nil { + logger.Errorf(ctx, "Error reading existing secure cookie [%v]. Error: %s", cookieName, err) + return "", fmt.Errorf("error reading existing secure cookie [%v]: %w", cookieName, err) + } + + if len(token) == 0 { + logger.Errorf(ctx, "Read empty token from secure cookie [%v].", cookieName) + return "", fmt.Errorf("read empty token from secure cookie [%v]", cookieName) + } + + return token, nil +} + +func ReadSecureCookie(ctx context.Context, cookie http.Cookie, hashKey, blockKey []byte) (string, error) { + s := securecookie.New(hashKey, blockKey) + var value string + if err := s.Decode(cookie.Name, cookie.Value, &value); err == nil { + return value, nil + } + logger.Errorf(ctx, "Error reading secure cookie %s", cookie.Name) + return "", fmt.Errorf("error reading secure cookie %s", cookie.Name) +} + +func NewCsrfToken(seed int64) string { + r := rand.New(rand.NewSource(seed)) //nolint:gosec + csrfToken := [10]rune{} + for i := 0; i < len(csrfToken); i++ { + csrfToken[i] = AllowedChars[r.Intn(len(AllowedChars))] + } + return string(csrfToken[:]) +} + +func NewCsrfCookie() http.Cookie { + csrfStateToken := NewCsrfToken(rand.Int63()) //nolint:gosec + return http.Cookie{ + Name: csrfStateCookieName, + Value: csrfStateToken, + SameSite: http.SameSiteLaxMode, + HttpOnly: true, + Secure: config.GetConfig().SecureCookie, + } +} + +func VerifyCsrfCookie(_ context.Context, request *http.Request) error { + csrfState := request.FormValue(CsrfFormKey) + if csrfState == "" { + return fmt.Errorf("empty state in callback, %s", request.Form) + } + csrfCookie, err := request.Cookie(csrfStateCookieName) + if csrfCookie == nil || err != nil { + return fmt.Errorf("could not find csrf cookie: %v", err) + } + if HashCsrfState(csrfCookie.Value) != csrfState { + return fmt.Errorf("CSRF token does not match state %s, %s vs %s", csrfCookie.Value, + HashCsrfState(csrfCookie.Value), csrfState) + } + return nil +} + +// NewRedirectCookie creates a cookie to keep track of where to send the user after +// the OAuth2 login flow is complete. +func NewRedirectCookie(ctx context.Context, redirectURL string) *http.Cookie { + urlObj, err := url.Parse(redirectURL) + if err != nil || urlObj == nil { + logger.Errorf(ctx, "Error creating redirect cookie %s %s", urlObj, err) + return nil + } + + if urlObj.EscapedPath() == "" { + logger.Errorf(ctx, "Error parsing URL, redirect %s resolved to empty string", redirectURL) + return nil + } + + return &http.Cookie{ + Name: redirectURLCookieName, + Value: urlObj.String(), + SameSite: http.SameSiteLaxMode, + HttpOnly: true, + Secure: config.GetConfig().SecureCookie, + } +} + +// GetAuthFlowEndRedirect returns the redirect URI according to data in request. +// At the end of the OAuth flow, the server needs to send the user somewhere. This should have been stored as a cookie +// during the initial /login call. If that cookie is missing from the request, it will default to the one configured. +func GetAuthFlowEndRedirect(ctx context.Context, defaultRedirect string, authorizedURIs []config.URL, request *http.Request) string { + queryParams := request.URL.Query() + if redirectURL := queryParams.Get(RedirectURLParameter); len(redirectURL) > 0 { + if GetRedirectURLAllowed(ctx, redirectURL, authorizedURIs) { + return redirectURL + } + logger.Warnf(ctx, "Rejecting unauthorized redirect_url from query parameter: %s", redirectURL) + return defaultRedirect + } + + cookie, err := request.Cookie(redirectURLCookieName) + if err != nil { + logger.Debugf(ctx, "Could not detect end-of-flow redirect url cookie") + return defaultRedirect + } + + if GetRedirectURLAllowed(ctx, cookie.Value, authorizedURIs) { + return cookie.Value + } + logger.Warnf(ctx, "Rejecting unauthorized redirect_url from cookie: %s", cookie.Value) + return defaultRedirect +} diff --git a/runs/service/auth/cookie_manager.go b/runs/service/auth/cookie_manager.go new file mode 100644 index 00000000000..0a2113f712d --- /dev/null +++ b/runs/service/auth/cookie_manager.go @@ -0,0 +1,225 @@ +package auth + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "time" + + "golang.org/x/oauth2" + + authpb "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth" + "github.com/flyteorg/flyte/v2/flytestdlib/logger" + "github.com/flyteorg/flyte/v2/runs/service/auth/config" +) + +// CookieManager manages encrypted cookie operations for auth tokens. +type CookieManager struct { + hashKey []byte + blockKey []byte + domain string + sameSitePolicy config.SameSite +} + +func NewCookieManager(ctx context.Context, hashKeyEncoded, blockKeyEncoded string, cookieSettings config.CookieSettings) (CookieManager, error) { + logger.Infof(ctx, "Instantiating cookie manager") + + hashKey, err := base64.RawStdEncoding.DecodeString(hashKeyEncoded) + if err != nil { + return CookieManager{}, fmt.Errorf("error decoding hash key bytes: %w", err) + } + + blockKey, err := base64.RawStdEncoding.DecodeString(blockKeyEncoded) + if err != nil { + return CookieManager{}, fmt.Errorf("error decoding block key bytes: %w", err) + } + + return CookieManager{ + hashKey: hashKey, + blockKey: blockKey, + domain: cookieSettings.Domain, + sameSitePolicy: cookieSettings.SameSitePolicy, + }, nil +} + +func (c CookieManager) RetrieveAccessToken(ctx context.Context, request *http.Request) (string, error) { + // If there is an old access token, we will retrieve it + oldAccessToken, err := retrieveSecureCookie(ctx, request, accessTokenCookieName, c.hashKey, c.blockKey) + if err == nil && oldAccessToken != "" { + return oldAccessToken, nil + } + // If there is no old access token, we will retrieve the new split access token + accessTokenFirstHalf, err := retrieveSecureCookie(ctx, request, accessTokenCookieNameSplitFirst, c.hashKey, c.blockKey) + if err != nil { + return "", err + } + accessTokenSecondHalf, err := retrieveSecureCookie(ctx, request, accessTokenCookieNameSplitSecond, c.hashKey, c.blockKey) + if err != nil { + return "", err + } + return accessTokenFirstHalf + accessTokenSecondHalf, nil +} + +// RetrieveTokenValues retrieves id, access and refresh tokens from cookies if they exist. +func (c CookieManager) RetrieveTokenValues(ctx context.Context, request *http.Request) (idToken, accessToken, + refreshToken string, err error) { + + idToken, err = retrieveSecureCookie(ctx, request, idTokenCookieName, c.hashKey, c.blockKey) + if err != nil { + return "", "", "", err + } + + accessToken, err = c.RetrieveAccessToken(ctx, request) + if err != nil { + return "", "", "", err + } + + refreshToken, err = retrieveSecureCookie(ctx, request, refreshTokenCookieName, c.hashKey, c.blockKey) + if err != nil { + // Refresh tokens are optional. + logger.Infof(ctx, "Refresh token doesn't exist or failed to read it. Ignoring this error. Error: %v", err) + err = nil + } + + return +} + +func (c CookieManager) SetUserInfoCookie(ctx context.Context, writer http.ResponseWriter, userInfo *authpb.UserInfoResponse) error { + raw, err := json.Marshal(userInfo) + if err != nil { + return fmt.Errorf("failed to marshal user info to store in a cookie: %w", err) + } + + return c.SetUserInfoCookieRaw(ctx, writer, string(raw)) +} + +func (c CookieManager) SetUserInfoCookieRaw(ctx context.Context, writer http.ResponseWriter, userInfoStr string) error { + userInfoCookie, err := NewSecureCookie(userInfoCookieName, userInfoStr, c.hashKey, c.blockKey, c.domain, c.getHTTPSameSitePolicy()) + if err != nil { + logger.Errorf(ctx, "Error generating encrypted user info cookie %s", err) + return err + } + + http.SetCookie(writer, &userInfoCookie) + return nil +} + +func (c CookieManager) RetrieveUserInfo(ctx context.Context, request *http.Request) (*authpb.UserInfoResponse, error) { + userInfoCookie, err := retrieveSecureCookie(ctx, request, userInfoCookieName, c.hashKey, c.blockKey) + if err != nil { + return nil, err + } + + res := authpb.UserInfoResponse{} + err = json.Unmarshal([]byte(userInfoCookie), &res) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal user info cookie: %w", err) + } + + return &res, nil +} + +func (c CookieManager) RetrieveAuthCodeRequest(ctx context.Context, request *http.Request) (string, error) { + return retrieveSecureCookie(ctx, request, authCodeCookieName, c.hashKey, c.blockKey) +} + +func (c CookieManager) SetAuthCodeCookie(ctx context.Context, writer http.ResponseWriter, authRequestURL string) error { + authCodeCookie, err := NewSecureCookie(authCodeCookieName, authRequestURL, c.hashKey, c.blockKey, c.domain, c.getHTTPSameSitePolicy()) + if err != nil { + logger.Errorf(ctx, "Error generating encrypted auth code cookie %s", err) + return err + } + + http.SetCookie(writer, &authCodeCookie) + return nil +} + +func (c CookieManager) StoreAccessToken(ctx context.Context, accessToken string, writer http.ResponseWriter) error { + midpoint := len(accessToken) / 2 + firstHalf := accessToken[:midpoint] + secondHalf := accessToken[midpoint:] + atCookieFirst, err := NewSecureCookie(accessTokenCookieNameSplitFirst, firstHalf, c.hashKey, c.blockKey, c.domain, c.getHTTPSameSitePolicy()) + if err != nil { + logger.Errorf(ctx, "Error generating encrypted accesstoken cookie first half %s", err) + return err + } + http.SetCookie(writer, &atCookieFirst) + atCookieSecond, err := NewSecureCookie(accessTokenCookieNameSplitSecond, secondHalf, c.hashKey, c.blockKey, c.domain, c.getHTTPSameSitePolicy()) + if err != nil { + logger.Errorf(ctx, "Error generating encrypted accesstoken cookie second half %s", err) + return err + } + http.SetCookie(writer, &atCookieSecond) + return nil +} + +func (c CookieManager) SetTokenCookies(ctx context.Context, writer http.ResponseWriter, token *oauth2.Token) error { + idToken, accessToken, refreshToken, err := ExtractTokensFromOauthToken(token) + if err != nil { + logger.Errorf(ctx, "Unable to read all token values from oauth token: %s", err) + return fmt.Errorf("unable to read all token values from oauth token: %w", err) + } + + idCookie, err := NewSecureCookie(idTokenCookieName, idToken, c.hashKey, c.blockKey, c.domain, c.getHTTPSameSitePolicy()) + if err != nil { + logger.Errorf(ctx, "Error generating encrypted id token cookie %s", err) + return err + } + + http.SetCookie(writer, &idCookie) + + err = c.StoreAccessToken(ctx, accessToken, writer) + if err != nil { + logger.Errorf(ctx, "Error storing access token %s", err) + return err + } + + // Set the refresh cookie if there is a refresh token + if len(refreshToken) > 0 { + refreshCookie, err := NewSecureCookie(refreshTokenCookieName, token.RefreshToken, c.hashKey, c.blockKey, c.domain, c.getHTTPSameSitePolicy()) + if err != nil { + logger.Errorf(ctx, "Error generating encrypted refresh token cookie %s", err) + return err + } + http.SetCookie(writer, &refreshCookie) + } + + return nil +} + +func (c CookieManager) getLogoutCookie(name string) *http.Cookie { + return &http.Cookie{ + Name: name, + Value: "", + Domain: c.domain, + MaxAge: 0, + HttpOnly: true, + Secure: config.GetConfig().SecureCookie, + Expires: time.Now().Add(-1 * time.Hour), + } +} + +func (c CookieManager) DeleteCookies(_ context.Context, writer http.ResponseWriter) { + http.SetCookie(writer, c.getLogoutCookie(accessTokenCookieName)) + http.SetCookie(writer, c.getLogoutCookie(accessTokenCookieNameSplitFirst)) + http.SetCookie(writer, c.getLogoutCookie(accessTokenCookieNameSplitSecond)) + http.SetCookie(writer, c.getLogoutCookie(refreshTokenCookieName)) + http.SetCookie(writer, c.getLogoutCookie(idTokenCookieName)) +} + +func (c CookieManager) getHTTPSameSitePolicy() http.SameSite { + switch c.sameSitePolicy { + case config.SameSiteDefaultMode: + return http.SameSiteDefaultMode + case config.SameSiteLaxMode: + return http.SameSiteLaxMode + case config.SameSiteStrictMode: + return http.SameSiteStrictMode + case config.SameSiteNoneMode: + return http.SameSiteNoneMode + default: + return http.SameSiteDefaultMode + } +} diff --git a/runs/service/auth/handler_utils.go b/runs/service/auth/handler_utils.go new file mode 100644 index 00000000000..2307932b20c --- /dev/null +++ b/runs/service/auth/handler_utils.go @@ -0,0 +1,201 @@ +package auth + +import ( + "context" + "net/http" + "net/url" + "strings" + + "google.golang.org/grpc/metadata" + + "github.com/flyteorg/flyte/v2/flytestdlib/logger" + "github.com/flyteorg/flyte/v2/runs/service/auth/config" +) + +const ( + metadataXForwardedHost = "x-forwarded-host" + metadataAuthority = ":authority" +) + +// URLFromRequest attempts to reconstruct the url from the request object. Or nil if not possible +func URLFromRequest(req *http.Request) *url.URL { + if req == nil { + return nil + } + + // from browser req.RequestURI is "/login" and u.scheme is "" + // from unit test req.RequestURI is "" and u is nil + // That means that this function, URLFromRequest(req) returns https://localhost:8088 even though there's no SSL, + // when the request is made from http://localhost:8088 in the web browser. + // Given how this function is used however, it's okay - we're only picking which option to use from the list of + // authorized URIs. + u, _ := url.ParseRequestURI(req.RequestURI) + if u != nil && u.IsAbs() { + return u + } + + if len(req.Host) == 0 { + return nil + } + + scheme := "https://" + if req.URL != nil && len(req.URL.Scheme) > 0 { + scheme = req.URL.Scheme + "://" + } + + u, _ = url.Parse(scheme + req.Host) + return u +} + +// URLFromContext attempts to retrieve the original url from context. gRPC gateway sets metadata in context that refers +// to the original host. Or nil if metadata isn't set. +func URLFromContext(ctx context.Context) *url.URL { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return nil + } + + forwardedHost := getMetadataValue(md, metadataXForwardedHost) + if len(forwardedHost) == 0 { + forwardedHost = getMetadataValue(md, metadataAuthority) + } + + if len(forwardedHost) == 0 { + return nil + } + + u, _ := url.Parse("https://" + forwardedHost) + return u +} + +// getMetadataValue retrieves the first value for a given key from gRPC metadata. +func getMetadataValue(md metadata.MD, key string) string { + vals := md.Get(key) + if len(vals) == 0 { + return "" + } + return vals[0] +} + +// FirstURL gets the first non-nil url from a list of given urls. +func FirstURL(urls ...*url.URL) *url.URL { + for _, u := range urls { + if u != nil { + return u + } + } + + return nil +} + +// wildcardMatch checks if hostname matches a wildcard pattern (only one level deep) +// Supports patterns like "*.union.ai" matching "tenant1.union.ai" +func wildcardMatch(hostname, pattern string) bool { + if strings.HasPrefix(pattern, "*.") { + urlParts := strings.SplitN(hostname, ".", 2) + if len(urlParts) < 2 { + return false + } + return urlParts[1] == pattern[2:] + } + return hostname == pattern +} + +// buildURL constructs a URL using the authorized template but with the matched hostname +func buildURL(authorizedURL *url.URL, matchedHostname string) *url.URL { + result := *authorizedURL // Copy the URL to avoid modifying the original + if authorizedURL.Port() != "" { + result.Host = matchedHostname + ":" + authorizedURL.Port() + } else { + result.Host = matchedHostname + } + return &result +} + +// GetPublicURL attempts to retrieve the public url of the service. If httpPublicUri is set in the config, it takes +// precedence. If the request is not nil and has a host set, it comes second and lastly it attempts to retrieve the url +// from context if set (e.g. by gRPC gateway). +func GetPublicURL(ctx context.Context, req *http.Request, cfg config.Config) *url.URL { + u := FirstURL(URLFromRequest(req), URLFromContext(ctx)) + var hostMatching *url.URL + var hostAndPortMatching *url.URL + var matchedHostname string + + for i, authorized := range cfg.AuthorizedURIs { + if u == nil { + return &authorized.URL + } + + if wildcardMatch(u.Hostname(), authorized.Hostname()) { + matchedHostname = u.Hostname() + hostMatching = &cfg.AuthorizedURIs[i].URL + if u.Port() == authorized.Port() { + hostAndPortMatching = &cfg.AuthorizedURIs[i].URL + } + + if u.Scheme == authorized.Scheme { + return buildURL(&cfg.AuthorizedURIs[i].URL, matchedHostname) + } + } + } + + if hostAndPortMatching != nil { + return buildURL(hostAndPortMatching, matchedHostname) + } + + if hostMatching != nil { + return buildURL(hostMatching, matchedHostname) + } + + if len(cfg.AuthorizedURIs) > 0 { + return &cfg.AuthorizedURIs[0].URL + } + + return u +} + +// GetIssuer returns the issuer from SelfAuthServer config, or falls back to public URL. +func GetIssuer(ctx context.Context, req *http.Request, cfg config.Config) string { + if configIssuer := cfg.AppAuth.SelfAuthServer.Issuer; len(configIssuer) > 0 { + return configIssuer + } + + return GetPublicURL(ctx, req, cfg).String() +} + +// isAuthorizedRedirectURL checks if a redirect URL matches an authorized URL pattern. +func isAuthorizedRedirectURL(u *url.URL, authorizedURL *url.URL) bool { + if u == nil || authorizedURL == nil { + return false + } + if u.Scheme != authorizedURL.Scheme || u.Port() != authorizedURL.Port() { + return false + } + return wildcardMatch(u.Hostname(), authorizedURL.Hostname()) +} + +// GetRedirectURLAllowed checks whether a redirect URL is in the list of authorized URIs. +func GetRedirectURLAllowed(ctx context.Context, urlRedirectParam string, authorizedURIs []config.URL) bool { + if len(urlRedirectParam) == 0 { + logger.Debugf(ctx, "not validating whether empty redirect url is authorized") + return true + } + redirectURL, err := url.Parse(urlRedirectParam) + if err != nil { + logger.Debugf(ctx, "failed to parse user-supplied redirect url: %s with err: %v", urlRedirectParam, err) + return false + } + if redirectURL.Host == "" { + logger.Debugf(ctx, "not validating whether relative redirect url is authorized") + return true + } + logger.Debugf(ctx, "validating whether redirect url: %s is authorized", redirectURL) + for i := range authorizedURIs { + if isAuthorizedRedirectURL(redirectURL, &authorizedURIs[i].URL) { + logger.Debugf(ctx, "authorizing redirect url: %s against authorized uri: %s", redirectURL.String(), authorizedURIs[i].String()) + return true + } + } + logger.Debugf(ctx, "not authorizing redirect url: %s", redirectURL.String()) + return false +} diff --git a/runs/service/auth/handler_utils_test.go b/runs/service/auth/handler_utils_test.go new file mode 100644 index 00000000000..670ab5acb94 --- /dev/null +++ b/runs/service/auth/handler_utils_test.go @@ -0,0 +1,175 @@ +package auth + +import ( + "context" + "net/http" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + + authConfig "github.com/flyteorg/flyte/v2/runs/service/auth/config" + "google.golang.org/grpc/metadata" +) + +func TestURLFromRequest(t *testing.T) { + t.Run("nil request", func(t *testing.T) { + assert.Nil(t, URLFromRequest(nil)) + }) + + t.Run("request with host", func(t *testing.T) { + req := &http.Request{ + Host: "example.com:8080", + URL: &url.URL{Scheme: "https"}, + } + u := URLFromRequest(req) + assert.NotNil(t, u) + assert.Equal(t, "example.com:8080", u.Host) + }) + + t.Run("request with empty host", func(t *testing.T) { + req := &http.Request{ + URL: &url.URL{}, + } + assert.Nil(t, URLFromRequest(req)) + }) + + t.Run("request with absolute request URI", func(t *testing.T) { + req := &http.Request{ + RequestURI: "https://absolute.example.com/path", + } + u := URLFromRequest(req) + assert.NotNil(t, u) + assert.Equal(t, "absolute.example.com", u.Hostname()) + }) +} + +func TestURLFromContext(t *testing.T) { + t.Run("no metadata", func(t *testing.T) { + assert.Nil(t, URLFromContext(context.Background())) + }) + + t.Run("x-forwarded-host", func(t *testing.T) { + md := metadata.Pairs("x-forwarded-host", "forwarded.example.com") + ctx := metadata.NewIncomingContext(context.Background(), md) + u := URLFromContext(ctx) + assert.NotNil(t, u) + assert.Equal(t, "forwarded.example.com", u.Hostname()) + assert.Equal(t, "https", u.Scheme) + }) + + t.Run("authority fallback", func(t *testing.T) { + md := metadata.Pairs(":authority", "authority.example.com") + ctx := metadata.NewIncomingContext(context.Background(), md) + u := URLFromContext(ctx) + assert.NotNil(t, u) + assert.Equal(t, "authority.example.com", u.Hostname()) + }) +} + +func TestFirstURL(t *testing.T) { + u1, _ := url.Parse("https://first.example.com") + u2, _ := url.Parse("https://second.example.com") + + assert.Nil(t, FirstURL()) + assert.Nil(t, FirstURL(nil, nil)) + assert.Equal(t, u1, FirstURL(u1, u2)) + assert.Equal(t, u2, FirstURL(nil, u2)) +} + +func TestWildcardMatch(t *testing.T) { + assert.True(t, wildcardMatch("example.com", "example.com")) + assert.True(t, wildcardMatch("tenant1.union.ai", "*.union.ai")) + assert.False(t, wildcardMatch("union.ai", "*.union.ai")) + assert.False(t, wildcardMatch("other.com", "example.com")) + assert.False(t, wildcardMatch("sub.tenant1.union.ai", "*.union.ai")) +} + +func TestBuildURL(t *testing.T) { + authorized, _ := url.Parse("https://example.com:8080/path") + result := buildURL(authorized, "tenant1.example.com") + assert.Equal(t, "tenant1.example.com:8080", result.Host) + assert.Equal(t, "/path", result.Path) + + noPort, _ := url.Parse("https://example.com/path") + result = buildURL(noPort, "tenant1.example.com") + assert.Equal(t, "tenant1.example.com", result.Host) +} + +func TestGetPublicURL(t *testing.T) { + t.Run("no request, returns first authorized URI", func(t *testing.T) { + cfg := authConfig.Config{ + AuthorizedURIs: []authConfig.URL{ + {URL: *mustParseURL("https://flyte.example.com")}, + }, + } + u := GetPublicURL(context.Background(), nil, cfg) + assert.Equal(t, "flyte.example.com", u.Hostname()) + }) + + t.Run("no request, no authorized URIs, returns nil", func(t *testing.T) { + cfg := authConfig.Config{} + u := GetPublicURL(context.Background(), nil, cfg) + assert.Nil(t, u) + }) + + t.Run("wildcard match with request", func(t *testing.T) { + cfg := authConfig.Config{ + AuthorizedURIs: []authConfig.URL{ + {URL: *mustParseURL("https://*.union.ai")}, + }, + } + req := &http.Request{ + Host: "tenant1.union.ai", + URL: &url.URL{Scheme: "https"}, + } + u := GetPublicURL(context.Background(), req, cfg) + assert.Equal(t, "tenant1.union.ai", u.Hostname()) + assert.Equal(t, "https", u.Scheme) + }) + + t.Run("exact match with matching scheme", func(t *testing.T) { + cfg := authConfig.Config{ + AuthorizedURIs: []authConfig.URL{ + {URL: *mustParseURL("https://flyte.example.com:8080")}, + {URL: *mustParseURL("http://flyte.example.com:8080")}, + }, + } + req := &http.Request{ + Host: "flyte.example.com:8080", + URL: &url.URL{Scheme: "http"}, + } + u := GetPublicURL(context.Background(), req, cfg) + assert.Equal(t, "http", u.Scheme) + }) +} + +func TestGetIssuer(t *testing.T) { + t.Run("custom issuer", func(t *testing.T) { + cfg := authConfig.Config{ + AppAuth: authConfig.OAuth2Options{ + SelfAuthServer: authConfig.AuthorizationServer{ + Issuer: "https://custom-issuer.example.com", + }, + }, + } + assert.Equal(t, "https://custom-issuer.example.com", GetIssuer(context.Background(), nil, cfg)) + }) + + t.Run("falls back to public URL", func(t *testing.T) { + cfg := authConfig.Config{ + AuthorizedURIs: []authConfig.URL{ + {URL: *mustParseURL("https://flyte.example.com")}, + }, + } + assert.Equal(t, "https://flyte.example.com", GetIssuer(context.Background(), nil, cfg)) + }) +} + +func mustParseURL(rawURL string) *url.URL { + u, err := url.Parse(rawURL) + if err != nil { + panic(err) + } + return u +} diff --git a/runs/service/auth/handlers.go b/runs/service/auth/handlers.go new file mode 100644 index 00000000000..da2fd1ca423 --- /dev/null +++ b/runs/service/auth/handlers.go @@ -0,0 +1,291 @@ +package auth + +import ( + "context" + "fmt" + "net/http" + "strings" + + "github.com/coreos/go-oidc/v3/oidc" + "golang.org/x/oauth2" + + authpb "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth" + "github.com/flyteorg/flyte/v2/flytestdlib/logger" + "github.com/flyteorg/flyte/v2/runs/service/auth/config" +) + +// PreRedirectHookError is returned by PreRedirectHookFunc to signal an error with an HTTP status code. +type PreRedirectHookError struct { + Message string + Code int +} + +func (e *PreRedirectHookError) Error() string { + return e.Message +} + +// PreRedirectHookFunc is called before the redirect at the end of a successful auth callback flow. +type PreRedirectHookFunc func(ctx context.Context, request *http.Request, w http.ResponseWriter) *PreRedirectHookError + +// LogoutHookFunc is called during logout to perform additional cleanup. +type LogoutHookFunc func(ctx context.Context, request *http.Request, w http.ResponseWriter) error + +// AuthHandlerConfig holds dependencies needed by the HTTP auth handlers. +type AuthHandlerConfig struct { + CookieManager CookieManager + OAuth2Config *oauth2.Config + OIDCProvider *oidc.Provider + ResourceServer OAuth2ResourceServer + AuthConfig config.Config + HTTPClient *http.Client + PreRedirectHook PreRedirectHookFunc + LogoutHook LogoutHookFunc +} + +// RegisterHandlers registers the standard OAuth2/OIDC HTTP handlers on the given mux. +func RegisterHandlers(ctx context.Context, mux *http.ServeMux, h *AuthHandlerConfig) { + mux.HandleFunc("/login", RefreshTokensIfNeededHandler(ctx, h, + GetLoginHandler(ctx, h))) + mux.HandleFunc("/callback", GetCallbackHandler(ctx, h)) + mux.HandleFunc(fmt.Sprintf("/%s", OIdCMetadataEndpoint), GetOIdCMetadataEndpointRedirectHandler(ctx, h)) + mux.HandleFunc("/logout", GetLogoutEndpointHandler(ctx, h)) +} + +// RefreshTokensIfNeededHandler wraps a handler to attempt token refresh before redirecting. +func RefreshTokensIfNeededHandler(ctx context.Context, h *AuthHandlerConfig, authHandler http.HandlerFunc) http.HandlerFunc { + return func(writer http.ResponseWriter, request *http.Request) { + newToken, userInfo, refreshed, err := RefreshTokensIfNeeded(ctx, h, request) + if err != nil { + logger.Infof(ctx, "Failed to refresh tokens. Restarting login flow. Error: %s", err) + authHandler(writer, request) + return + } + + if refreshed { + logger.Debugf(ctx, "Tokens are refreshed. Saving new tokens into cookies.") + if err = h.CookieManager.SetTokenCookies(ctx, writer, newToken); err != nil { + logger.Infof(ctx, "Failed to write tokens to response. Restarting login flow. Error: %s", err) + authHandler(writer, request) + return + } + + if err = h.CookieManager.SetUserInfoCookie(ctx, writer, userInfo); err != nil { + logger.Infof(ctx, "Failed to write user info to response. Restarting login flow. Error: %s", err) + authHandler(writer, request) + return + } + } + + redirectURL := GetAuthFlowEndRedirect(ctx, h.AuthConfig.UserAuth.RedirectURL.String(), h.AuthConfig.AuthorizedURIs, request) + http.Redirect(writer, request, redirectURL, http.StatusTemporaryRedirect) + } +} + +// RefreshTokensIfNeeded checks if tokens need refreshing and returns refreshed tokens if so. +func RefreshTokensIfNeeded(ctx context.Context, h *AuthHandlerConfig, request *http.Request) ( + token *oauth2.Token, userInfo *authpb.UserInfoResponse, refreshed bool, err error) { + + ctx = context.WithValue(ctx, oauth2.HTTPClient, h.HTTPClient) + + idToken, accessToken, refreshToken, err := h.CookieManager.RetrieveTokenValues(ctx, request) + if err != nil { + return nil, nil, false, fmt.Errorf("failed to retrieve tokens from request: %w", err) + } + + _, err = ParseIDTokenAndValidate(ctx, h.AuthConfig.UserAuth.OpenID.ClientID, idToken, h.OIDCProvider) + if err != nil { + if strings.Contains(err.Error(), "token is expired") && len(refreshToken) > 0 { + logger.Debugf(ctx, "Expired id token found, attempting to refresh") + newToken, refreshErr := GetRefreshedToken(ctx, h.OAuth2Config, accessToken, refreshToken) + if refreshErr != nil { + return nil, nil, false, fmt.Errorf("failed to refresh tokens: %w", refreshErr) + } + + userInfo, queryErr := QueryUserInfoUsingAccessToken(ctx, request, h, newToken.AccessToken) + if queryErr != nil { + return nil, nil, false, fmt.Errorf("failed to query user info: %w", queryErr) + } + + return newToken, userInfo, true, nil + } + return nil, nil, false, fmt.Errorf("failed to validate tokens: %w", err) + } + + return NewOAuthTokenFromRaw(accessToken, refreshToken, idToken), nil, false, nil +} + +// GetLoginHandler returns an HTTP handler that starts the OAuth2 login flow. +func GetLoginHandler(ctx context.Context, h *AuthHandlerConfig) http.HandlerFunc { + return func(writer http.ResponseWriter, request *http.Request) { + csrfCookie := NewCsrfCookie() + csrfToken := csrfCookie.Value + http.SetCookie(writer, &csrfCookie) + + state := HashCsrfState(csrfToken) + logger.Debugf(ctx, "Setting CSRF state cookie to %s and state to %s\n", csrfToken, state) + urlString := h.OAuth2Config.AuthCodeURL(state) + queryParams := request.URL.Query() + if !GetRedirectURLAllowed(ctx, queryParams.Get(RedirectURLParameter), h.AuthConfig.AuthorizedURIs) { + logger.Infof(ctx, "unauthorized redirect URI") + writer.WriteHeader(http.StatusForbidden) + return + } + if flowEndRedirectURL := queryParams.Get(RedirectURLParameter); flowEndRedirectURL != "" { + redirectCookie := NewRedirectCookie(ctx, flowEndRedirectURL) + if redirectCookie != nil { + http.SetCookie(writer, redirectCookie) + } + } + + http.Redirect(writer, request, urlString, http.StatusTemporaryRedirect) + } +} + +// GetCallbackHandler returns an HTTP handler that completes the OAuth2 authorization code flow. +func GetCallbackHandler(ctx context.Context, h *AuthHandlerConfig) http.HandlerFunc { + return func(writer http.ResponseWriter, request *http.Request) { + logger.Debugf(ctx, "Running callback handler... for RequestURI %v", request.RequestURI) + authorizationCode := request.FormValue(AuthorizationResponseCodeType) + + ctx = context.WithValue(ctx, oauth2.HTTPClient, h.HTTPClient) + + if err := VerifyCsrfCookie(ctx, request); err != nil { + logger.Errorf(ctx, "Invalid CSRF token cookie %s", err) + writer.WriteHeader(http.StatusUnauthorized) + return + } + + token, err := h.OAuth2Config.Exchange(ctx, authorizationCode) + if err != nil { + logger.Errorf(ctx, "Error when exchanging code %s", err) + writer.WriteHeader(http.StatusForbidden) + return + } + + if err = h.CookieManager.SetTokenCookies(ctx, writer, token); err != nil { + logger.Errorf(ctx, "Error setting encrypted JWT cookie %s", err) + writer.WriteHeader(http.StatusForbidden) + return + } + + userInfo, err := QueryUserInfoUsingAccessToken(ctx, request, h, token.AccessToken) + if err != nil { + logger.Errorf(ctx, "Failed to query user info. Error: %v", err) + writer.WriteHeader(http.StatusForbidden) + return + } + + if err = h.CookieManager.SetUserInfoCookie(ctx, writer, userInfo); err != nil { + logger.Errorf(ctx, "Error setting encrypted user info cookie. Error: %v", err) + writer.WriteHeader(http.StatusForbidden) + return + } + + if h.PreRedirectHook != nil { + if hookErr := h.PreRedirectHook(ctx, request, writer); hookErr != nil { + logger.Errorf(ctx, "failed the preRedirect hook due %v with status code %v", hookErr.Message, hookErr.Code) + if http.StatusText(hookErr.Code) != "" { + writer.WriteHeader(hookErr.Code) + } else { + writer.WriteHeader(http.StatusInternalServerError) + } + return + } + } + + redirectURL := GetAuthFlowEndRedirect(ctx, h.AuthConfig.UserAuth.RedirectURL.String(), h.AuthConfig.AuthorizedURIs, request) + http.Redirect(writer, request, redirectURL, http.StatusTemporaryRedirect) + } +} + +// GetOIdCMetadataEndpointRedirectHandler returns a handler that redirects to the OIDC metadata endpoint. +func GetOIdCMetadataEndpointRedirectHandler(_ context.Context, h *AuthHandlerConfig) http.HandlerFunc { + return func(writer http.ResponseWriter, request *http.Request) { + oidcMetadataURL := h.AuthConfig.UserAuth.OpenID.BaseURL.JoinPath("/").JoinPath(OIdCMetadataEndpoint) + http.Redirect(writer, request, oidcMetadataURL.String(), http.StatusSeeOther) + } +} + +// GetLogoutEndpointHandler returns a handler that clears auth cookies and optionally redirects. +func GetLogoutEndpointHandler(ctx context.Context, h *AuthHandlerConfig) http.HandlerFunc { + return func(writer http.ResponseWriter, request *http.Request) { + if h.LogoutHook != nil { + if err := h.LogoutHook(ctx, request, writer); err != nil { + logger.Errorf(ctx, "logout hook failed: %v", err) + writer.WriteHeader(http.StatusInternalServerError) + return + } + } + + logger.Debugf(ctx, "deleting auth cookies") + h.CookieManager.DeleteCookies(ctx, writer) + + queryParams := request.URL.Query() + if redirectURL := queryParams.Get(RedirectURLParameter); redirectURL != "" { + if !GetRedirectURLAllowed(ctx, redirectURL, h.AuthConfig.AuthorizedURIs) { + logger.Warnf(ctx, "Rejecting unauthorized redirect_url in logout: %s", redirectURL) + redirectURL = h.AuthConfig.UserAuth.RedirectURL.String() + } + http.Redirect(writer, request, redirectURL, http.StatusTemporaryRedirect) + } + } +} + +// QueryUserInfoUsingAccessToken fetches user info from the OIDC provider using an access token. +func QueryUserInfoUsingAccessToken(ctx context.Context, originalRequest *http.Request, h *AuthHandlerConfig, accessToken string) ( + *authpb.UserInfoResponse, error) { + + originalToken := oauth2.Token{ + AccessToken: accessToken, + } + + tokenSource := h.OAuth2Config.TokenSource(ctx, &originalToken) + + userInfo, err := h.OIDCProvider.UserInfo(ctx, tokenSource) + if err != nil { + logger.Errorf(ctx, "Error getting user info from IDP %s", err) + return &authpb.UserInfoResponse{}, fmt.Errorf("error getting user info from IDP") + } + + resp := &authpb.UserInfoResponse{} + if err = userInfo.Claims(resp); err != nil { + logger.Errorf(ctx, "Error getting user info from IDP %s", err) + return &authpb.UserInfoResponse{}, fmt.Errorf("error getting user info from IDP") + } + + return resp, nil +} + +// IdentityContextFromRequest extracts identity from an HTTP request (header or cookies). +func IdentityContextFromRequest(ctx context.Context, req *http.Request, h *AuthHandlerConfig) ( + *IdentityContext, error) { + + authHeader := DefaultAuthorizationHeader + if len(h.AuthConfig.HTTPAuthorizationHeader) > 0 { + authHeader = h.AuthConfig.HTTPAuthorizationHeader + } + + headerValue := req.Header.Get(authHeader) + if len(headerValue) == 0 { + headerValue = req.Header.Get(DefaultAuthorizationHeader) + } + + if len(headerValue) > 0 { + if strings.HasPrefix(headerValue, BearerScheme+" ") { + expectedAudience := GetPublicURL(ctx, req, h.AuthConfig).String() + return h.ResourceServer.ValidateAccessToken(ctx, expectedAudience, strings.TrimPrefix(headerValue, BearerScheme+" ")) + } + } + + idToken, _, _, err := h.CookieManager.RetrieveTokenValues(ctx, req) + if err != nil || len(idToken) == 0 { + return nil, fmt.Errorf("unauthenticated request. IDToken Len [%v], Error: %w", len(idToken), err) + } + + userInfo, err := h.CookieManager.RetrieveUserInfo(ctx, req) + if err != nil { + return nil, fmt.Errorf("unauthenticated request: %w", err) + } + + return IdentityContextFromIDToken(ctx, idToken, h.AuthConfig.UserAuth.OpenID.ClientID, h.OIDCProvider, userInfo) +} diff --git a/runs/service/auth/identity_context.go b/runs/service/auth/identity_context.go new file mode 100644 index 00000000000..7f1159f41b8 --- /dev/null +++ b/runs/service/auth/identity_context.go @@ -0,0 +1,80 @@ +package auth + +import ( + "context" + "time" + + authpb "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth" +) + +type contextKey string + +const contextKeyIdentityContext contextKey = "identity_context" + +// ScopeAll is the default scope granted to user-only access tokens. +const ScopeAll = "all" + +// IdentityContext encloses the authenticated identity of the user/app. Both gRPC and HTTP +// servers have interceptors to set the IdentityContext on the context.AuthenticationContext. +type IdentityContext struct { + audience string + userID string + appID string + authenticatedAt time.Time + userInfo *authpb.UserInfoResponse + scopes []string + claims map[string]interface{} +} + +// NewIdentityContext creates a new IdentityContext. +func NewIdentityContext(audience, userID, appID string, authenticatedAt time.Time, scopes []string, userInfo *authpb.UserInfoResponse, claims map[string]interface{}) *IdentityContext { + if userInfo == nil { + userInfo = &authpb.UserInfoResponse{} + } + + if len(userInfo.Subject) == 0 { + userInfo.Subject = userID + } + + return &IdentityContext{ + audience: audience, + userID: userID, + appID: appID, + authenticatedAt: authenticatedAt, + userInfo: userInfo, + scopes: scopes, + claims: claims, + } +} + +func (c *IdentityContext) Audience() string { return c.audience } +func (c *IdentityContext) UserID() string { return c.userID } +func (c *IdentityContext) AppID() string { return c.appID } +func (c *IdentityContext) AuthenticatedAt() time.Time { return c.authenticatedAt } +func (c *IdentityContext) Scopes() []string { return c.scopes } +func (c *IdentityContext) Claims() map[string]interface{} { return c.claims } + +func (c *IdentityContext) UserInfo() *authpb.UserInfoResponse { + if c.userInfo == nil { + return &authpb.UserInfoResponse{} + } + return c.userInfo +} + +func (c *IdentityContext) IsEmpty() bool { + return c == nil || (c.audience == "" && c.userID == "" && c.appID == "") +} + +// WithContext stores the IdentityContext in the given context. +func (c *IdentityContext) WithContext(ctx context.Context) context.Context { + return context.WithValue(ctx, contextKeyIdentityContext, c) +} + +// IdentityContextFromContext retrieves the authenticated identity from context.AuthenticationContext. +func IdentityContextFromContext(ctx context.Context) *IdentityContext { + existing := ctx.Value(contextKeyIdentityContext) + if existing != nil { + return existing.(*IdentityContext) + } + return nil +} diff --git a/runs/service/auth/identity_context_test.go b/runs/service/auth/identity_context_test.go new file mode 100644 index 00000000000..1dc1950026d --- /dev/null +++ b/runs/service/auth/identity_context_test.go @@ -0,0 +1,65 @@ +package auth + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + authpb "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth" +) + +func TestNewIdentityContext(t *testing.T) { + userInfo := &authpb.UserInfoResponse{ + Name: "Test User", + Email: "test@example.com", + } + now := time.Now() + + ic := NewIdentityContext("aud", "user1", "app1", now, []string{"read"}, userInfo, nil) + + assert.Equal(t, "aud", ic.Audience()) + assert.Equal(t, "user1", ic.UserID()) + assert.Equal(t, "app1", ic.AppID()) + assert.Equal(t, now, ic.AuthenticatedAt()) + assert.Equal(t, []string{"read"}, ic.Scopes()) + assert.Equal(t, "Test User", ic.UserInfo().Name) + // Subject should be filled from userID when empty + assert.Equal(t, "user1", ic.UserInfo().Subject) +} + +func TestNewIdentityContext_PreservesSubject(t *testing.T) { + userInfo := &authpb.UserInfoResponse{ + Subject: "existing-sub", + } + + ic := NewIdentityContext("aud", "user1", "", time.Time{}, nil, userInfo, nil) + assert.Equal(t, "existing-sub", ic.UserInfo().Subject) +} + +func TestNewIdentityContext_NilUserInfo(t *testing.T) { + ic := NewIdentityContext("aud", "user1", "", time.Time{}, nil, nil, nil) + require.NotNil(t, ic.UserInfo()) + assert.Equal(t, "user1", ic.UserInfo().Subject) +} + +func TestIdentityContext_IsEmpty(t *testing.T) { + assert.True(t, (*IdentityContext)(nil).IsEmpty()) + assert.True(t, (&IdentityContext{}).IsEmpty()) + assert.False(t, (&IdentityContext{userID: "u"}).IsEmpty()) +} + +func TestIdentityContext_WithContext(t *testing.T) { + ic := NewIdentityContext("aud", "user1", "app1", time.Now(), nil, nil, nil) + ctx := ic.WithContext(context.Background()) + + retrieved := IdentityContextFromContext(ctx) + require.NotNil(t, retrieved) + assert.Equal(t, "user1", retrieved.UserID()) +} + +func TestIdentityContextFromContext_Empty(t *testing.T) { + assert.Nil(t, IdentityContextFromContext(context.Background())) +} diff --git a/runs/service/auth/identity_service.go b/runs/service/auth/identity_service.go new file mode 100644 index 00000000000..d2c5131175e --- /dev/null +++ b/runs/service/auth/identity_service.go @@ -0,0 +1,29 @@ +package auth + +import ( + "context" + + "connectrpc.com/connect" + + authpb "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth" + "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth/authconnect" +) + +// IdentityService implements the IdentityServiceHandler interface. +type IdentityService struct{} + +// NewIdentityService creates a new IdentityService instance. +func NewIdentityService() *IdentityService { + return &IdentityService{} +} + +var _ authconnect.IdentityServiceHandler = (*IdentityService)(nil) + +// UserInfo returns information about the currently logged in user. +// TODO: Wire with real auth to populate user info from the authenticated context. +func (s *IdentityService) UserInfo( + ctx context.Context, + req *connect.Request[authpb.UserInfoRequest], +) (*connect.Response[authpb.UserInfoResponse], error) { + return connect.NewResponse(&authpb.UserInfoResponse{}), nil +} diff --git a/runs/service/auth/identity_service_test.go b/runs/service/auth/identity_service_test.go new file mode 100644 index 00000000000..ae76efce8f8 --- /dev/null +++ b/runs/service/auth/identity_service_test.go @@ -0,0 +1,21 @@ +package auth + +import ( + "context" + "testing" + + "connectrpc.com/connect" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + authpb "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth" +) + +func TestIdentityService_UserInfo(t *testing.T) { + svc := NewIdentityService() + + resp, err := svc.UserInfo(context.Background(), connect.NewRequest(&authpb.UserInfoRequest{})) + require.NoError(t, err) + assert.NotNil(t, resp) + assert.NotNil(t, resp.Msg) +} diff --git a/runs/service/auth/interceptor.go b/runs/service/auth/interceptor.go new file mode 100644 index 00000000000..55f7332315e --- /dev/null +++ b/runs/service/auth/interceptor.go @@ -0,0 +1,98 @@ +package auth + +import ( + "context" + "fmt" + + "github.com/coreos/go-oidc/v3/oidc" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + + "github.com/flyteorg/flyte/v2/flytestdlib/logger" + "github.com/flyteorg/flyte/v2/runs/service/auth/config" +) + +// BlanketAuthorization is a gRPC unary interceptor that checks the authenticated identity has the "all" scope. +func BlanketAuthorization(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) ( + resp interface{}, err error) { + + identityContext := IdentityContextFromContext(ctx) + if identityContext == nil { + return handler(ctx, req) + } + + for _, scope := range identityContext.Scopes() { + if scope == ScopeAll { + return handler(ctx, req) + } + } + + logger.Debugf(ctx, "authenticated user doesn't have required scope") + return nil, status.Errorf(codes.Unauthenticated, "authenticated user doesn't have required scope") +} + +// GetAuthenticationCustomMetadataInterceptor produces a gRPC interceptor that translates a custom authorization +// header name to the standard "authorization" header for downstream interceptors. +func GetAuthenticationCustomMetadataInterceptor(cfg config.Config) grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + if cfg.GrpcAuthorizationHeader != DefaultAuthorizationHeader { + md, ok := metadata.FromIncomingContext(ctx) + if ok { + existingHeader := md.Get(cfg.GrpcAuthorizationHeader) + if len(existingHeader) > 0 { + logger.Debugf(ctx, "Found existing metadata header %s", cfg.GrpcAuthorizationHeader) + newAuthorizationMetadata := metadata.Pairs(DefaultAuthorizationHeader, existingHeader[0]) + joinedMetadata := metadata.Join(md, newAuthorizationMetadata) + newCtx := metadata.NewIncomingContext(ctx, joinedMetadata) + return handler(newCtx, req) + } + } + } + return handler(ctx, req) + } +} + +// GetAuthenticationInterceptor returns a function that validates incoming gRPC requests. +// It attempts to extract and validate an access token or ID token from the request metadata. +func GetAuthenticationInterceptor(cfg config.Config, resourceServer OAuth2ResourceServer, oidcProvider *oidc.Provider) func(context.Context) (context.Context, error) { + return func(ctx context.Context) (context.Context, error) { + logger.Debugf(ctx, "Running authentication gRPC interceptor") + + expectedAudience := GetPublicURL(ctx, nil, cfg).String() + + identityContext, accessTokenErr := GRPCGetIdentityFromAccessToken(ctx, expectedAudience, resourceServer) + if accessTokenErr == nil { + return identityContext.WithContext(ctx), nil + } + + logger.Infof(ctx, "Failed to parse Access Token from context. Will attempt to find IDToken. Error: %v", accessTokenErr) + + identityContext, idTokenErr := GRPCGetIdentityFromIDToken(ctx, cfg.UserAuth.OpenID.ClientID, oidcProvider) + if idTokenErr == nil { + return identityContext.WithContext(ctx), nil + } + logger.Debugf(ctx, "Failed to parse ID Token from context. Error: %v", idTokenErr) + + if !cfg.DisableForGrpc { + err := fmt.Errorf("[id token err: %w] | [access token err: %w]", idTokenErr, accessTokenErr) + return ctx, status.Errorf(codes.Unauthenticated, "token parse error %s", err) + } + + return ctx, nil + } +} + +// AuthenticationLoggingInterceptor logs information about the authenticated user for each gRPC request. +func AuthenticationLoggingInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + identityContext := IdentityContextFromContext(ctx) + if identityContext != nil { + var emailPlaceholder string + if len(identityContext.UserInfo().GetEmail()) > 0 { + emailPlaceholder = fmt.Sprintf(" (%s) ", identityContext.UserInfo().GetEmail()) + } + logger.Debugf(ctx, "gRPC server info in logging interceptor [%s]%smethod [%s]\n", identityContext.UserID(), emailPlaceholder, info.FullMethod) + } + return handler(ctx, req) +} diff --git a/runs/service/auth/interfaces.go b/runs/service/auth/interfaces.go new file mode 100644 index 00000000000..617ff457137 --- /dev/null +++ b/runs/service/auth/interfaces.go @@ -0,0 +1,8 @@ +package auth + +import "context" + +// OAuth2ResourceServer represents a resource server that can validate access tokens. +type OAuth2ResourceServer interface { + ValidateAccessToken(ctx context.Context, expectedAudience, tokenStr string) (*IdentityContext, error) +} diff --git a/runs/service/auth/token.go b/runs/service/auth/token.go new file mode 100644 index 00000000000..466f2f572ff --- /dev/null +++ b/runs/service/auth/token.go @@ -0,0 +1,176 @@ +package auth + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/coreos/go-oidc/v3/oidc" + "golang.org/x/oauth2" + "google.golang.org/grpc/metadata" + + authpb "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth" + "github.com/flyteorg/flyte/v2/flytestdlib/logger" +) + +// GetRefreshedToken refreshes a JWT using the provided OAuth2 config and refresh token. +func GetRefreshedToken(ctx context.Context, oauth *oauth2.Config, accessToken, refreshToken string) (*oauth2.Token, error) { + logger.Debugf(ctx, "Attempting to refresh token") + originalToken := oauth2.Token{ + AccessToken: accessToken, + RefreshToken: refreshToken, + Expiry: time.Now().Add(-1 * time.Minute), // force expired by setting to the past + } + + tokenSource := oauth.TokenSource(ctx, &originalToken) + newToken, err := tokenSource.Token() + if err != nil { + logger.Errorf(ctx, "Error refreshing token %s", err) + return nil, fmt.Errorf("error refreshing token: %w", err) + } + + return newToken, nil +} + +// ParseIDTokenAndValidate parses and validates an ID token using the OIDC provider. +func ParseIDTokenAndValidate(ctx context.Context, clientID, rawIDToken string, provider *oidc.Provider) (*oidc.IDToken, error) { + cfg := &oidc.Config{ + ClientID: clientID, + } + + if len(clientID) == 0 { + cfg.SkipClientIDCheck = true + cfg.SkipIssuerCheck = true + cfg.SkipExpiryCheck = true + } + + verifier := provider.Verifier(cfg) + + idToken, err := verifier.Verify(ctx, rawIDToken) + if err != nil { + logger.Debugf(ctx, "JWT parsing with claims failed %s", err) + if strings.Contains(err.Error(), "token is expired") { + return idToken, fmt.Errorf("token is expired: %w", err) + } + return idToken, fmt.Errorf("jwt parse with claims failed: %w", err) + } + return idToken, nil +} + +// GRPCGetIdentityFromAccessToken attempts to extract a bearer token from gRPC metadata +// and validate it using the provided resource server. +func GRPCGetIdentityFromAccessToken(ctx context.Context, expectedAudience string, resourceServer OAuth2ResourceServer) ( + *IdentityContext, error) { + + tokenStr, err := bearerTokenFromMD(ctx) + if err != nil { + return nil, fmt.Errorf("could not retrieve bearer token from metadata: %w", err) + } + + return resourceServer.ValidateAccessToken(ctx, expectedAudience, tokenStr) +} + +// GRPCGetIdentityFromIDToken attempts to extract an ID token from gRPC metadata and validate it. +func GRPCGetIdentityFromIDToken(ctx context.Context, clientID string, provider *oidc.Provider) ( + *IdentityContext, error) { + + tokenStr, err := idTokenFromMD(ctx) + if err != nil { + return nil, fmt.Errorf("could not retrieve id token from metadata: %w", err) + } + + return IdentityContextFromIDToken(ctx, tokenStr, clientID, provider, nil) +} + +// IdentityContextFromIDToken creates an IdentityContext from a validated ID token. +func IdentityContextFromIDToken(ctx context.Context, tokenStr, clientID string, provider *oidc.Provider, + userInfo *authpb.UserInfoResponse) (*IdentityContext, error) { + + idToken, err := ParseIDTokenAndValidate(ctx, clientID, tokenStr, provider) + if err != nil { + return nil, err + } + var claims map[string]interface{} + if err := idToken.Claims(&claims); err != nil { + logger.Infof(ctx, "Failed to unmarshal claims from id token, err: %v", err) + } + + return NewIdentityContext(idToken.Audience[0], idToken.Subject, "", idToken.IssuedAt, + []string{ScopeAll}, userInfo, claims), nil +} + +func NewOAuthTokenFromRaw(accessToken, refreshToken, idToken string) *oauth2.Token { + return (&oauth2.Token{ + AccessToken: accessToken, + RefreshToken: refreshToken, + }).WithExtra(map[string]interface{}{ + idTokenExtra: idToken, + }) +} + +func ExtractTokensFromOauthToken(token *oauth2.Token) (idToken, accessToken, refreshToken string, err error) { + if token == nil { + return "", "", "", fmt.Errorf("attempting to set cookies with nil token") + } + + idTokenRaw, converted := token.Extra(idTokenExtra).(string) + if !converted { + return "", "", "", fmt.Errorf("response does not contain an id_token") + } + + return idTokenRaw, token.AccessToken, token.RefreshToken, nil +} + +// bearerTokenFromMD extracts a Bearer token from gRPC incoming metadata. +func bearerTokenFromMD(ctx context.Context) (string, error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return "", fmt.Errorf("no metadata in context") + } + + vals := md.Get(DefaultAuthorizationHeader) + if len(vals) == 0 { + return "", fmt.Errorf("no authorization header in metadata") + } + + header := vals[0] + prefix := BearerScheme + " " + if !strings.HasPrefix(header, prefix) { + return "", fmt.Errorf("authorization header does not start with %q", BearerScheme) + } + + token := strings.TrimPrefix(header, prefix) + if token == "" { + return "", fmt.Errorf("bearer token is blank") + } + + return token, nil +} + +// idTokenFromMD extracts an IDToken from gRPC incoming metadata. +func idTokenFromMD(ctx context.Context) (string, error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return "", fmt.Errorf("no metadata in context") + } + + vals := md.Get(DefaultAuthorizationHeader) + if len(vals) == 0 { + return "", fmt.Errorf("no authorization header in metadata") + } + + header := vals[0] + prefix := IDTokenScheme + " " + if !strings.HasPrefix(header, prefix) { + return "", fmt.Errorf("authorization header does not start with %q", IDTokenScheme) + } + + token := strings.TrimPrefix(header, prefix) + if token == "" { + return "", fmt.Errorf("id token is blank") + } + + return token, nil +} + diff --git a/runs/service/auth/user_info_provider.go b/runs/service/auth/user_info_provider.go new file mode 100644 index 00000000000..eda30e6889f --- /dev/null +++ b/runs/service/auth/user_info_provider.go @@ -0,0 +1,28 @@ +package auth + +import ( + "context" + + "connectrpc.com/connect" + + authpb "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth" + "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth/authconnect" +) + +// UserInfoProvider serves user info claims about the currently logged in user. +// See the OpenID Connect spec at https://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse +type UserInfoProvider struct { + authconnect.UnimplementedIdentityServiceHandler +} + +func NewUserInfoProvider() *UserInfoProvider { + return &UserInfoProvider{} +} + +func (s *UserInfoProvider) UserInfo(ctx context.Context, _ *connect.Request[authpb.UserInfoRequest]) (*connect.Response[authpb.UserInfoResponse], error) { + identityContext := IdentityContextFromContext(ctx) + if identityContext != nil { + return connect.NewResponse(identityContext.UserInfo()), nil + } + return connect.NewResponse(&authpb.UserInfoResponse{}), nil +} From ba32dcf61bf7ebded8450747e029a9e27de742e3 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 11 Mar 2026 17:54:32 -0700 Subject: [PATCH 05/13] rename Signed-off-by: Kevin Su --- .../{admin-auth-secret.yaml => run-service-auth-secret.yaml} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename charts/flyte-binary/templates/{admin-auth-secret.yaml => run-service-auth-secret.yaml} (100%) diff --git a/charts/flyte-binary/templates/admin-auth-secret.yaml b/charts/flyte-binary/templates/run-service-auth-secret.yaml similarity index 100% rename from charts/flyte-binary/templates/admin-auth-secret.yaml rename to charts/flyte-binary/templates/run-service-auth-secret.yaml From 2fcdbe7a552957aceba6e1f285e3e4ad4675b8a9 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 11 Mar 2026 20:10:17 -0700 Subject: [PATCH 06/13] Add console chatrt Signed-off-by: Kevin Su --- charts/flyte-binary/templates/_helpers.tpl | 8 ++++ .../flyte-binary/templates/clusterrole.yaml | 9 ++++ .../templates/console/deployment.yaml | 42 +++++++++++++++++++ .../templates/console/service.yaml | 20 +++++++++ .../flyte-binary/templates/ingress/http.yaml | 16 +++++++ charts/flyte-binary/values.yaml | 13 ++++++ 6 files changed, 108 insertions(+) create mode 100644 charts/flyte-binary/templates/console/deployment.yaml create mode 100644 charts/flyte-binary/templates/console/service.yaml diff --git a/charts/flyte-binary/templates/_helpers.tpl b/charts/flyte-binary/templates/_helpers.tpl index 58f7c805c4c..1519acca9d7 100644 --- a/charts/flyte-binary/templates/_helpers.tpl +++ b/charts/flyte-binary/templates/_helpers.tpl @@ -113,6 +113,14 @@ templates: {{- toYaml .custom | nindent 2 -}} {{- end -}} {{- end -}} +{{/* +Selector labels for Console +*/}} +{{ define "flyte-binary.consoleSelectorLabels" -}} +{{ include "flyte-binary.selectorLabels" . }} +app.kubernetes.io/component: console +{{- end }} + {{/* Get the Secret name for Run service authentication secrets. */}} diff --git a/charts/flyte-binary/templates/clusterrole.yaml b/charts/flyte-binary/templates/clusterrole.yaml index 02988e68ab1..ab6484b68c5 100644 --- a/charts/flyte-binary/templates/clusterrole.yaml +++ b/charts/flyte-binary/templates/clusterrole.yaml @@ -19,6 +19,15 @@ metadata: {{- tpl ( .Values.rbac.annotations | toYaml ) . | nindent 4 }} {{- end }} rules: + - apiGroups: + - "" + resources: + - namespaces + verbs: + - create + - get + - list + - watch - apiGroups: - "" resources: diff --git a/charts/flyte-binary/templates/console/deployment.yaml b/charts/flyte-binary/templates/console/deployment.yaml new file mode 100644 index 00000000000..fbe53324802 --- /dev/null +++ b/charts/flyte-binary/templates/console/deployment.yaml @@ -0,0 +1,42 @@ +{{- if .Values.console.enabled }} +apiVersion: apps/v1 +kind: Deployment +metadata: + name: {{ include "flyte-binary.fullname" . }}-console + namespace: {{ .Release.Namespace | quote }} + labels: {{ include "flyte-binary.labels" . | nindent 4 }} +spec: + replicas: 1 + selector: + matchLabels: {{ include "flyte-binary.consoleSelectorLabels" . | nindent 6 }} + template: + metadata: + labels: {{ include "flyte-binary.consoleSelectorLabels" . | nindent 8 }} + spec: + {{- with .Values.console.imagePullSecrets }} + imagePullSecrets: + {{ toYaml . | nindent 8 }} + {{- end }} + containers: + - name: console + {{- with .Values.console.image }} + image: {{ printf "%s:%s" .repository .tag | quote }} + imagePullPolicy: {{ .pullPolicy | quote }} + {{- end }} + ports: + - name: http + containerPort: 8080 + protocol: TCP + readinessProbe: + httpGet: + path: /v2 + port: http + initialDelaySeconds: 5 + periodSeconds: 10 + livenessProbe: + httpGet: + path: /v2 + port: http + initialDelaySeconds: 5 + periodSeconds: 30 +{{- end }} \ No newline at end of file diff --git a/charts/flyte-binary/templates/console/service.yaml b/charts/flyte-binary/templates/console/service.yaml new file mode 100644 index 00000000000..6ba2e10d6a4 --- /dev/null +++ b/charts/flyte-binary/templates/console/service.yaml @@ -0,0 +1,20 @@ +{{- if .Values.console.enabled }} +apiVersion: v1 +kind: Service +metadata: + name: {{ include "flyte-binary.fullname" . }}-console + namespace: {{ .Release.Namespace | quote }} + labels: {{ include "flyte-binary.labels" . | nindent 4 }} + {{- with .Values.console.service.annotations }} + annotations: + {{ toYaml . | nindent 4 }} + {{- end }} +spec: + type: {{ .Values.console.service.type | default "ClusterIP" }} + ports: + - name: http + port: {{ .Values.console.service.port | default 80 }} + targetPort: http + protocol: TCP + selector: {{ include "flyte-binary.consoleSelectorLabels" . | nindent 4 }} +{{- end }} diff --git a/charts/flyte-binary/templates/ingress/http.yaml b/charts/flyte-binary/templates/ingress/http.yaml index cdce87047bd..ef47d792dbe 100644 --- a/charts/flyte-binary/templates/ingress/http.yaml +++ b/charts/flyte-binary/templates/ingress/http.yaml @@ -38,6 +38,22 @@ spec: {{- if .Values.ingress.httpExtraPaths.prepend }} {{- tpl ( .Values.ingress.httpExtraPaths.prepend | toYaml ) . | nindent 6 }} {{- end }} + {{- if .Values.console.enabled }} + - backend: + service: + name: {{ include "flyte-binary.fullname" . }}-console + port: + number: {{ .Values.console.service.port | default 80 }} + path: /v2 + pathType: ImplementationSpecific + - backend: + service: + name: {{ include "flyte-binary.fullname" . }}-console + port: + number: {{ .Values.console.service.port | default 80 }} + path: /v2/* + pathType: ImplementationSpecific + {{- end }} - backend: service: name: {{ include "flyte-binary.service.http.name" . }} diff --git a/charts/flyte-binary/values.yaml b/charts/flyte-binary/values.yaml index 6c361c974e3..2aeff2604cd 100644 --- a/charts/flyte-binary/values.yaml +++ b/charts/flyte-binary/values.yaml @@ -426,3 +426,16 @@ enabled_plugins: container_array: k8s-array # -- Uncomment to enable task type that uses Flyte Connector # bigquery_query_job_task: connector-service + +console: + enabled: true + image: + repository: ghcr.io/flyteorg/consolev2 + tag: 34149492567122a466f885700f1df9da42025e6b + pullPolicy: Always + imagePullSecrets: + - name: ghcr-pull-secret + service: + type: ClusterIP + port: 80 + annotations: {} \ No newline at end of file From decce4056c1d8390973b35d18b46798de4af371f Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 11 Mar 2026 23:05:33 -0700 Subject: [PATCH 07/13] update chart Signed-off-by: Kevin Su --- charts/flyte-binary/templates/_helpers.tpl | 2 +- charts/flyte-binary/templates/console/deployment.yaml | 4 ++++ charts/flyte-binary/templates/deployment.yaml | 10 ++++++---- charts/flyte-binary/templates/service/grpc.yaml | 2 +- charts/flyte-binary/templates/service/http.yaml | 2 +- charts/flyte-binary/values.yaml | 4 +++- 6 files changed, 16 insertions(+), 8 deletions(-) diff --git a/charts/flyte-binary/templates/_helpers.tpl b/charts/flyte-binary/templates/_helpers.tpl index 1519acca9d7..19f0282a0b2 100644 --- a/charts/flyte-binary/templates/_helpers.tpl +++ b/charts/flyte-binary/templates/_helpers.tpl @@ -153,7 +153,7 @@ Get the Flyte HTTP service name Get the Flyte service HTTP port. */}} {{- define "flyte-binary.service.http.port" -}} -{{- default 8090 .Values.service.ports.http -}} +{{- default 8080 .Values.service.ports.http -}} {{- end -}} {{/* diff --git a/charts/flyte-binary/templates/console/deployment.yaml b/charts/flyte-binary/templates/console/deployment.yaml index fbe53324802..2b2b6eded7d 100644 --- a/charts/flyte-binary/templates/console/deployment.yaml +++ b/charts/flyte-binary/templates/console/deployment.yaml @@ -17,6 +17,10 @@ spec: imagePullSecrets: {{ toYaml . | nindent 8 }} {{- end }} + {{- with .Values.deployment.extraPodSpec.nodeSelector }} + nodeSelector: + {{- toYaml . | nindent 8 }} + {{- end }} containers: - name: console {{- with .Values.console.image }} diff --git a/charts/flyte-binary/templates/deployment.yaml b/charts/flyte-binary/templates/deployment.yaml index 704f149f010..3107a9362ae 100644 --- a/charts/flyte-binary/templates/deployment.yaml +++ b/charts/flyte-binary/templates/deployment.yaml @@ -92,14 +92,14 @@ spec: {{- end }} {{- end }} {{- if .Values.deployment.resources }} - resources: {{- toYaml .Values.deployment.resources | nindent 12 }} + resources: {{ toYaml .Values.deployment.resources | nindent 12 }} {{- end }} {{- if .Values.deployment.waitForDB.securityContext }} - securityContext: {{- toYaml .Values.deployment.waitForDB.securityContext | nindent 12 }} + securityContext: {{ toYaml .Values.deployment.waitForDB.securityContext | nindent 12 }} {{- end }} {{- end }} {{- if .Values.deployment.initContainers }} - {{- tpl ( .Values.deployment.initContainers | toYaml ) . | nindent 8 }} + {{ tpl ( .Values.deployment.initContainers | toYaml ) . | nindent 8 }} {{- end }} {{- end }} containers: @@ -144,6 +144,8 @@ spec: ports: - name: http containerPort: 8090 + - name: grpc + containerPort: 8080 {{- if .Values.deployment.startupProbe }} startupProbe: {{- tpl ( .Values.deployment.startupProbe | toYaml ) . | nindent 12 }} {{- end }} @@ -163,7 +165,7 @@ spec: httpGet: path: /healthz port: http - initialDelaySeconds: 30 + initialDelaySeconds: 5 {{- end }} {{- if .Values.deployment.resources }} resources: {{- toYaml .Values.deployment.resources | nindent 12 }} diff --git a/charts/flyte-binary/templates/service/grpc.yaml b/charts/flyte-binary/templates/service/grpc.yaml index cc6ba67ac9c..6bc6f881766 100644 --- a/charts/flyte-binary/templates/service/grpc.yaml +++ b/charts/flyte-binary/templates/service/grpc.yaml @@ -38,7 +38,7 @@ spec: ports: - name: grpc port: {{ include "flyte-binary.service.grpc.port" . }} - targetPort: grpc + targetPort: http {{- if and (or (eq .Values.service.type "NodePort") (eq .Values.service.type "LoadBalancer")) (not (empty .Values.service.nodePorts.grpc)) }} nodePort: {{ .Values.service.nodePorts.grpc }} {{- else if eq .Values.service.type "ClusterIP" }} diff --git a/charts/flyte-binary/templates/service/http.yaml b/charts/flyte-binary/templates/service/http.yaml index ab79c90210d..e2a727aa195 100644 --- a/charts/flyte-binary/templates/service/http.yaml +++ b/charts/flyte-binary/templates/service/http.yaml @@ -46,7 +46,7 @@ spec: {{- if not .Values.ingress.separateGrpcIngress }} - name: grpc port: {{ include "flyte-binary.service.grpc.port" . }} - targetPort: grpc + targetPort: http {{- if and (or (eq .Values.service.type "NodePort") (eq .Values.service.type "LoadBalancer")) (not (empty .Values.service.nodePorts.grpc)) }} nodePort: {{ .Values.service.nodePorts.grpc }} {{- else if eq .Values.service.type "ClusterIP" }} diff --git a/charts/flyte-binary/values.yaml b/charts/flyte-binary/values.yaml index 2aeff2604cd..b57cfa13e9e 100644 --- a/charts/flyte-binary/values.yaml +++ b/charts/flyte-binary/values.yaml @@ -438,4 +438,6 @@ console: service: type: ClusterIP port: 80 - annotations: {} \ No newline at end of file + annotations: + alb.ingress.kubernetes.io/healthcheck-path: /v2 + alb.ingress.kubernetes.io/healthcheck-port: "8080" \ No newline at end of file From e8f854c602e6f456cc72ad42e91d210917c28843 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 10 Apr 2026 16:09:49 -0700 Subject: [PATCH 08/13] go mod tidy Signed-off-by: Kevin Su --- go.mod | 1 - go.sum | 6 ++---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/go.mod b/go.mod index 5bae6502536..bcc9872a344 100644 --- a/go.mod +++ b/go.mod @@ -133,7 +133,6 @@ require ( github.com/evanphx/json-patch v5.6.0+incompatible // indirect github.com/evanphx/json-patch/v5 v5.9.11 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect - github.com/glebarez/go-sqlite v1.21.2 // indirect github.com/go-jose/go-jose/v4 v4.1.3 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect diff --git a/go.sum b/go.sum index 2e8f4e9d03a..5a7cf0420b6 100644 --- a/go.sum +++ b/go.sum @@ -230,8 +230,6 @@ github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeME github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= -github.com/go-gormigrate/gormigrate/v2 v2.1.5 h1:1OyorA5LtdQw12cyJDEHuTrEV3GiXiIhS4/QTTa/SM8= -github.com/go-gormigrate/gormigrate/v2 v2.1.5/go.mod h1:mj9ekk/7CPF3VjopaFvWKN2v7fN3D9d3eEOAXRhi/+M= github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs= github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= github.com/go-kit/log v0.2.1 h1:MRVx0/zhvdseW+Gza6N9rVzU/IVzaeE1SFI4raAhmBU= @@ -354,8 +352,8 @@ github.com/googleapis/gax-go/v2 v2.15.0/go.mod h1:zVVkkxAQHa1RQpg9z2AUCMnKhi0Qld github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g= github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= -github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= -github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= +github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 h1:+9834+KizmvFV7pXQGSXQTsaWhq2GjuNUt0aUU0YBYw= github.com/grpc-ecosystem/go-grpc-middleware v1.3.0/go.mod h1:z0ButlSOZa5vEBq9m2m2hlwIgKw+rp3sdCBRoJY+30Y= github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus v1.1.0 h1:QGLs/O40yoNK9vmy4rhUGBVyMf1lISBGtXRpsu/Qu/o= From f16451e681448c585c24b872ac3251a80ba0edf3 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Sat, 11 Apr 2026 00:04:42 -0700 Subject: [PATCH 09/13] [V2] Wire external OAuth2 auth into runs service + helm chart Enable external-mode OAuth2 authentication for the v2 runs service so JWT bearer tokens and auth cookies are validated at the HTTP boundary and the standard OIDC browser login flow is served from the same binary. Go changes: - runs/setup.go: new setupAuth() builds ResourceServer + AuthContext, registers /login /callback /logout and the OIDC metadata redirect, chains the new HTTP auth middleware with existing middleware, and replaces the buggy duplicate AuthMetadataService mount with a single real-or-stub branch. - runs/service/auth/http_middleware.go: bearer/cookie validator with a public-path allowlist (/healthz, /readyz, /healthcheck, /login, /callback, /logout, /.well-known/, /flyteidl2.auth.AuthMetadataService/). - runs/service/auth/auth_context.go: NewAuthContext takes an oidcClientSecret and populates oauth2.Config.ClientSecret; RedirectURL is computed as an absolute URL from the first authorizedUri via a new computeOIDCRedirectURL helper. - runs/service/auth/config/config.go: add go:generate enumer directives for AuthorizationServerType and SameSite. - Generated enumer files for AuthorizationServerType and SameSite. - Unit tests: http_middleware, computeOIDCRedirectURL, enumer round trips, cookie helpers, token helpers, config defaults. Helm chart (charts/flyte-binary): - templates/configmap.yaml: render a new 004-auth.yaml from configuration.auth.* (including externalAuthServer, authorizedUris, userAuth.openId, thirdPartyConfig.flyteClient, and runs.security.useAuth) when auth.enabled. - templates/_helpers.tpl: runServiceAuthSecretName honors a new configuration.auth.runServiceAuthSecretRef override so deployments can reuse an existing admin-auth secret instead of re-rendering. - templates/run-service-auth-secret.yaml: skip rendering when the override is set to avoid Helm ownership conflicts. - templates/deployment.yaml: fix run-service-auth-secret include-path typo; guard its checksum with the override; keep the existing extraInlineSecretRefs projection loop. - templates/ingress/http.yaml: new ingress.minimalPaths flag that omits /oauth2, /.well-known, /me, /config, /v1/*, /api, and /console paths so they can fall through to an adjacent Flyte deployment sharing the same ALB ingress group. - values.yaml: defaults for configuration.auth.externalAuthServer.*, configuration.auth.runServiceAuthSecretRef, and ingress.minimalPaths (false, preserving existing behavior). Signed-off-by: Kevin Su --- charts/flyte-binary/templates/_helpers.tpl | 21 ++- .../flyte-binary/templates/clusterrole.yaml | 10 ++ .../flyte-binary/templates/config-secret.yaml | 4 + charts/flyte-binary/templates/configmap.yaml | 48 ++++++ charts/flyte-binary/templates/deployment.yaml | 8 +- .../flyte-binary/templates/ingress/http.yaml | 8 + .../templates/run-service-auth-secret.yaml | 2 +- charts/flyte-binary/values.yaml | 21 ++- runs/service/auth/auth_context.go | 27 ++- runs/service/auth/auth_context_test.go | 69 ++++++++ .../config/authorizationservertype_enumer.go | 67 ++++++++ runs/service/auth/config/config.go | 2 + runs/service/auth/config/config_test.go | 112 ++++++++++++ runs/service/auth/config/samesite_enumer.go | 69 ++++++++ runs/service/auth/cookie_test.go | 160 ++++++++++++++++++ runs/service/auth/http_middleware.go | 64 +++++++ runs/service/auth/http_middleware_test.go | 92 ++++++++++ runs/service/auth/token_test.go | 94 ++++++++++ runs/setup.go | 121 +++++++++++-- 19 files changed, 974 insertions(+), 25 deletions(-) create mode 100644 runs/service/auth/auth_context_test.go create mode 100644 runs/service/auth/config/authorizationservertype_enumer.go create mode 100644 runs/service/auth/config/config_test.go create mode 100644 runs/service/auth/config/samesite_enumer.go create mode 100644 runs/service/auth/cookie_test.go create mode 100644 runs/service/auth/http_middleware.go create mode 100644 runs/service/auth/http_middleware_test.go create mode 100644 runs/service/auth/token_test.go diff --git a/charts/flyte-binary/templates/_helpers.tpl b/charts/flyte-binary/templates/_helpers.tpl index 28cc543bb7a..c9cd4940759 100644 --- a/charts/flyte-binary/templates/_helpers.tpl +++ b/charts/flyte-binary/templates/_helpers.tpl @@ -122,10 +122,17 @@ app.kubernetes.io/component: console {{- end }} {{/* -Get the Secret name for Run service authentication secrets. +Get the Secret name for Run service authentication secrets. When a user +supplies `configuration.auth.runServiceAuthSecretRef`, that existing Secret is +referenced directly (no template is rendered); otherwise a new Secret named +`-admin-auth` is used. */}} {{ define "flyte-binary.configuration.auth.runServiceAuthSecretName" -}} +{{- if .Values.configuration.auth.runServiceAuthSecretRef -}} +{{ tpl .Values.configuration.auth.runServiceAuthSecretRef . }} +{{- else -}} {{ printf "%s-admin-auth" (include "flyte-binary.fullname" .) }} +{{- end -}} {{ end -}} {{/* @@ -176,6 +183,8 @@ Get the Flyte API paths for ingress. {{- define "flyte-binary.ingress.grpcPaths" -}} - /flyteidl2.workflow.RunService - /flyteidl2.workflow.RunService/* +- /flyteidl2.workflow.InternalRunService +- /flyteidl2.workflow.InternalRunService/* - /flyteidl2.task.TaskService - /flyteidl2.task.TaskService/* - /flyteidl2.workflow.TranslatorService @@ -186,6 +195,16 @@ Get the Flyte API paths for ingress. - /flyteidl2.dataproxy.DataProxyService/* - /flyteidl2.secret.SecretService - /flyteidl2.secret.SecretService/* +- /flyteidl2.project.ProjectService +- /flyteidl2.project.ProjectService/* +- /flyteidl2.app.AppService +- /flyteidl2.app.AppService/* +- /flyteidl2.trigger.TriggerService +- /flyteidl2.trigger.TriggerService/* +- /flyteidl2.auth.AuthMetadataService +- /flyteidl2.auth.AuthMetadataService/* +- /flyteidl2.auth.IdentityService +- /flyteidl2.auth.IdentityService/* {{- end -}} {{/* diff --git a/charts/flyte-binary/templates/clusterrole.yaml b/charts/flyte-binary/templates/clusterrole.yaml index aeb1c15789f..97ee3354d78 100644 --- a/charts/flyte-binary/templates/clusterrole.yaml +++ b/charts/flyte-binary/templates/clusterrole.yaml @@ -42,17 +42,23 @@ rules: - watch - apiGroups: - "" + - events.k8s.io resources: - events verbs: - create - delete + - get + - list - patch - update + - watch - apiGroups: - flyte.org resources: - taskactions + - taskactions/status + - taskactions/finalizers verbs: - create - delete @@ -78,8 +84,12 @@ rules: - secrets verbs: - create + - delete - get + - list + - patch - update + - watch {{- if .Values.rbac.extraRules }} {{- toYaml .Values.rbac.extraRules | nindent 2 }} {{- end }} diff --git a/charts/flyte-binary/templates/config-secret.yaml b/charts/flyte-binary/templates/config-secret.yaml index 4f05d134fbc..467e2a6e0dc 100644 --- a/charts/flyte-binary/templates/config-secret.yaml +++ b/charts/flyte-binary/templates/config-secret.yaml @@ -25,6 +25,10 @@ stringData: database: postgres: password: {{ .Values.configuration.database.password | quote }} + runs: + database: + postgres: + password: {{ .Values.configuration.database.password | quote }} {{- end }} {{- if eq "s3" .Values.configuration.storage.provider }} {{- if eq "accesskey" .Values.configuration.storage.providerConfig.s3.authType }} diff --git a/charts/flyte-binary/templates/configmap.yaml b/charts/flyte-binary/templates/configmap.yaml index e0d4f20a9b5..161e325e3fb 100644 --- a/charts/flyte-binary/templates/configmap.yaml +++ b/charts/flyte-binary/templates/configmap.yaml @@ -105,6 +105,54 @@ data: {{- end }} container: {{ required "Metadata container required" .metadataContainer }} {{- end }} + {{- if .Values.configuration.auth.enabled }} + 004-auth.yaml: | + auth: + appAuth: + {{- if .Values.configuration.auth.enableAuthServer }} + authServerType: Self + {{- else }} + authServerType: External + {{- end }} + {{- with .Values.configuration.auth.externalAuthServer }} + externalAuthServer: + baseUrl: {{ tpl (default "" .baseUrl) $ | quote }} + {{- if .metadataUrl }} + metadataUrl: {{ .metadataUrl | quote }} + {{- end }} + allowedAudience: + {{- range .allowedAudience }} + - {{ tpl . $ | quote }} + {{- end }} + {{- end }} + {{- with .Values.configuration.auth.flyteClient }} + thirdPartyConfig: + flyteClient: + clientId: {{ .clientId | quote }} + redirectUri: {{ .redirectUri | quote }} + {{- if .audience }} + audience: {{ .audience | quote }} + {{- end }} + scopes: + {{- range .scopes }} + - {{ . | quote }} + {{- end }} + {{- end }} + authorizedUris: + {{- range .Values.configuration.auth.authorizedUris }} + - {{ tpl . $ | quote }} + {{- end }} + userAuth: + openId: + baseUrl: {{ .Values.configuration.auth.oidc.baseUrl | quote }} + clientId: {{ .Values.configuration.auth.oidc.clientId | quote }} + scopes: + - openid + - profile + runs: + security: + useAuth: true + {{- end }} {{- if .Values.configuration.inline }} 100-inline-config.yaml: | {{- tpl ( .Values.configuration.inline | toYaml ) . | nindent 4 }} diff --git a/charts/flyte-binary/templates/deployment.yaml b/charts/flyte-binary/templates/deployment.yaml index 4d9b94a6ea1..91448335706 100644 --- a/charts/flyte-binary/templates/deployment.yaml +++ b/charts/flyte-binary/templates/deployment.yaml @@ -38,7 +38,9 @@ spec: checksum/configuration-secret: {{ include (print $.Template.BasePath "/config-secret.yaml") . | sha256sum }} {{- end }} {{- if .Values.configuration.auth.enabled }} - checksum/runservice-auth-secret: {{ include (print $.Template.BasePath "/runservice-auth-secret.yaml") . | sha256sum }} + {{- if not .Values.configuration.auth.runServiceAuthSecretRef }} + checksum/runservice-auth-secret: {{ include (print $.Template.BasePath "/run-service-auth-secret.yaml") . | sha256sum }} + {{- end }} {{- if not .Values.configuration.auth.clientSecretsExternalSecretRef }} checksum/auth-client-secret: {{ include (print $.Template.BasePath "/auth-client-secret.yaml") . | sha256sum }} {{- end }} @@ -236,6 +238,10 @@ spec: - secret: name: {{ tpl .Values.configuration.inlineSecretRef . }} {{- end }} + {{- range .Values.configuration.extraInlineSecretRefs }} + - secret: + name: {{ tpl . $ }} + {{- end }} {{- end }} - name: webhook-certs emptyDir: {} diff --git a/charts/flyte-binary/templates/ingress/http.yaml b/charts/flyte-binary/templates/ingress/http.yaml index ef47d792dbe..58bcd58aa25 100644 --- a/charts/flyte-binary/templates/ingress/http.yaml +++ b/charts/flyte-binary/templates/ingress/http.yaml @@ -54,6 +54,7 @@ spec: path: /v2/* pathType: ImplementationSpecific {{- end }} + {{- if not .Values.ingress.minimalPaths }} - backend: service: name: {{ include "flyte-binary.service.http.name" . }} @@ -68,6 +69,8 @@ spec: number: {{ include "flyte-binary.service.http.port" . }} path: /console/* pathType: ImplementationSpecific + {{- end }} + {{- if not .Values.ingress.minimalPaths }} - backend: service: name: {{ include "flyte-binary.service.http.name" . }} @@ -82,6 +85,7 @@ spec: number: {{ include "flyte-binary.service.http.port" . }} path: /api/* pathType: ImplementationSpecific + {{- end }} - backend: service: name: {{ include "flyte-binary.service.http.name" . }} @@ -89,6 +93,7 @@ spec: number: {{ include "flyte-binary.service.http.port" . }} path: /healthcheck pathType: ImplementationSpecific + {{- if not .Values.ingress.minimalPaths }} - backend: service: name: {{ include "flyte-binary.service.http.name" . }} @@ -110,6 +115,7 @@ spec: number: {{ include "flyte-binary.service.http.port" . }} path: /.well-known/* pathType: ImplementationSpecific + {{- end }} - backend: service: name: {{ include "flyte-binary.service.http.name" . }} @@ -152,6 +158,7 @@ spec: number: {{ include "flyte-binary.service.http.port" . }} path: /callback/* pathType: ImplementationSpecific + {{- if not .Values.ingress.minimalPaths }} - backend: service: name: {{ include "flyte-binary.service.http.name" . }} @@ -187,6 +194,7 @@ spec: number: {{ include "flyte-binary.service.http.port" . }} path: /oauth2/* pathType: ImplementationSpecific + {{- end }} {{- if not .Values.ingress.separateGrpcIngress }} {{- $paths := (include "flyte-binary.ingress.grpcPaths" .) | fromYamlArray }} {{- range $path := $paths }} diff --git a/charts/flyte-binary/templates/run-service-auth-secret.yaml b/charts/flyte-binary/templates/run-service-auth-secret.yaml index 173d0bf2880..111ea59912f 100644 --- a/charts/flyte-binary/templates/run-service-auth-secret.yaml +++ b/charts/flyte-binary/templates/run-service-auth-secret.yaml @@ -1,4 +1,4 @@ -{{- if .Values.configuration.auth.enabled }} +{{- if and .Values.configuration.auth.enabled (not .Values.configuration.auth.runServiceAuthSecretRef) }} apiVersion: v1 kind: Secret metadata: diff --git a/charts/flyte-binary/values.yaml b/charts/flyte-binary/values.yaml index c61cd4a1cd2..b4a242b20e0 100644 --- a/charts/flyte-binary/values.yaml +++ b/charts/flyte-binary/values.yaml @@ -184,9 +184,21 @@ configuration: audience: "" # authorizedUris Set of URIs that clients are allowed to visit the service on authorizedUris: [] + # externalAuthServer Configuration for the external OAuth2 authorization + # server whose tokens Flyte will validate. Only used when + # `enableAuthServer: false`. Set `baseUrl` to the issuer URL and + # `allowedAudience` to the list of audiences Flyte should accept. + externalAuthServer: + baseUrl: "" + metadataUrl: "" + allowedAudience: [] # clientSecretExternalSecretRef Specify an existing, external Secret containing values for `client_secret` and `oidc_client_secret`. # If set, a Secret will not be generated by this chart for client secrets. clientSecretsExternalSecretRef: "" + # runServiceAuthSecretRef Specify an existing Secret to supply cookie + # hash/block keys (and other run-service auth secrets) at /etc/secrets. + # If set, this chart will NOT render its own run-service auth Secret. + runServiceAuthSecretRef: "" # co-pilot Configuration for Flyte CoPilot co-pilot: # image Configure image to use for CoPilot sidecar @@ -357,6 +369,13 @@ ingress: labels: {} # host Hostname to bind to ingress resources host: "" + # minimalPaths When true, the HTTP ingress only emits the paths that this + # Flyte deployment actually serves (login/callback/logout, api, v2, console, + # healthcheck). The legacy auth-server paths (/oauth2, /.well-known, /me, + # /config, /v1/*) are omitted so they can be served by a different Flyte + # deployment sharing the same ALB group. Set this on deployments that defer + # token issuance and OAuth metadata to an upstream auth server. + minimalPaths: false # separateGrpcIngress Create a separate ingress resource for GRPC if true. Required for certain ingress controllers like nginx. separateGrpcIngress: true # commonAnnotations Add common annotations to all ingress resources @@ -437,7 +456,7 @@ console: enabled: true image: repository: ghcr.io/flyteorg/consolev2 - tag: 34149492567122a466f885700f1df9da42025e6b + tag: latest pullPolicy: Always imagePullSecrets: - name: ghcr-pull-secret diff --git a/runs/service/auth/auth_context.go b/runs/service/auth/auth_context.go index 1989d78956d..ccefc8ea147 100644 --- a/runs/service/auth/auth_context.go +++ b/runs/service/auth/auth_context.go @@ -7,6 +7,7 @@ import ( "fmt" "net/http" "net/url" + "strings" "time" "github.com/coreos/go-oidc/v3/oidc" @@ -64,8 +65,10 @@ func (c *AuthenticationContext) OAuth2MetadataURL() *url.URL { return c func (c *AuthenticationContext) OIDCMetadataURL() *url.URL { return c.oidcMetadataURL } // NewAuthContext creates a new AuthContext with all the components needed for authentication. +// oidcClientSecret is the IdP-issued confidential client secret used during the OAuth2 code +// exchange; it may be empty if the client is registered as public at the IdP. func NewAuthContext(ctx context.Context, cfg config.Config, resourceServer OAuth2ResourceServer, - hashKeyBase64, blockKeyBase64 string) (*AuthenticationContext, error) { + hashKeyBase64, blockKeyBase64, oidcClientSecret string) (*AuthenticationContext, error) { cookieManager, err := NewCookieManager(ctx, hashKeyBase64, blockKeyBase64, cfg.UserAuth.CookieSetting) if err != nil { @@ -90,10 +93,11 @@ func NewAuthContext(ctx context.Context, cfg config.Config, resourceServer OAuth } oauth2Config := &oauth2.Config{ - RedirectURL: "callback", - ClientID: cfg.UserAuth.OpenID.ClientID, - Scopes: cfg.UserAuth.OpenID.Scopes, - Endpoint: provider.Endpoint(), + RedirectURL: computeOIDCRedirectURL(cfg), + ClientID: cfg.UserAuth.OpenID.ClientID, + ClientSecret: oidcClientSecret, + Scopes: cfg.UserAuth.OpenID.Scopes, + Endpoint: provider.Endpoint(), } oauth2MetadataURL, err := url.Parse(OAuth2MetadataEndpoint) @@ -118,6 +122,19 @@ func NewAuthContext(ctx context.Context, cfg config.Config, resourceServer OAuth }, nil } +// computeOIDCRedirectURL returns the absolute redirect URL to use during the OAuth2 authorization +// code flow. IdPs like Okta require an absolute URL registered in their allowed-callbacks list. +// The URL is derived from the first authorizedUri with "/callback" appended. If no authorizedUris +// are configured, the legacy relative "callback" value is returned as a fallback. +func computeOIDCRedirectURL(cfg config.Config) string { + if len(cfg.AuthorizedURIs) == 0 { + return "callback" + } + base := cfg.AuthorizedURIs[0].URL + base.Path = strings.TrimSuffix(base.Path, "/") + "/callback" + return base.String() +} + // HandlerConfig returns an AuthHandlerConfig suitable for use with RegisterHandlers. func (c *AuthenticationContext) HandlerConfig() *AuthHandlerConfig { return &AuthHandlerConfig{ diff --git a/runs/service/auth/auth_context_test.go b/runs/service/auth/auth_context_test.go new file mode 100644 index 00000000000..01a432856cd --- /dev/null +++ b/runs/service/auth/auth_context_test.go @@ -0,0 +1,69 @@ +package auth + +import ( + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + stdconfig "github.com/flyteorg/flyte/v2/flytestdlib/config" + "github.com/flyteorg/flyte/v2/runs/service/auth/config" +) + +func mustParse(t *testing.T, raw string) stdconfig.URL { + t.Helper() + u, err := url.Parse(raw) + require.NoError(t, err) + return stdconfig.URL{URL: *u} +} + +func TestComputeOIDCRedirectURL(t *testing.T) { + cases := []struct { + name string + cfg config.Config + want string + }{ + { + name: "no authorizedUris falls back to relative path", + cfg: config.Config{}, + want: "callback", + }, + { + name: "simple https host", + cfg: config.Config{ + AuthorizedURIs: []stdconfig.URL{mustParse(t, "https://flyte.example.com")}, + }, + want: "https://flyte.example.com/callback", + }, + { + name: "host with trailing slash does not duplicate separator", + cfg: config.Config{ + AuthorizedURIs: []stdconfig.URL{mustParse(t, "https://flyte.example.com/")}, + }, + want: "https://flyte.example.com/callback", + }, + { + name: "picks first uri when multiple", + cfg: config.Config{ + AuthorizedURIs: []stdconfig.URL{ + mustParse(t, "https://flyte.example.com"), + mustParse(t, "http://flyte2.flyte:8080"), + }, + }, + want: "https://flyte.example.com/callback", + }, + { + name: "host with path prefix appends callback", + cfg: config.Config{ + AuthorizedURIs: []stdconfig.URL{mustParse(t, "https://flyte.example.com/v2")}, + }, + want: "https://flyte.example.com/v2/callback", + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.want, computeOIDCRedirectURL(tc.cfg)) + }) + } +} diff --git a/runs/service/auth/config/authorizationservertype_enumer.go b/runs/service/auth/config/authorizationservertype_enumer.go new file mode 100644 index 00000000000..f6e89a6fded --- /dev/null +++ b/runs/service/auth/config/authorizationservertype_enumer.go @@ -0,0 +1,67 @@ +// Code generated by "enumer --type=AuthorizationServerType --trimprefix=AuthorizationServerType -json"; DO NOT EDIT. + +package config + +import ( + "encoding/json" + "fmt" +) + +const _AuthorizationServerTypeName = "SelfExternal" + +var _AuthorizationServerTypeIndex = [...]uint8{0, 4, 12} + +func (i AuthorizationServerType) String() string { + if i < 0 || i >= AuthorizationServerType(len(_AuthorizationServerTypeIndex)-1) { + return fmt.Sprintf("AuthorizationServerType(%d)", i) + } + return _AuthorizationServerTypeName[_AuthorizationServerTypeIndex[i]:_AuthorizationServerTypeIndex[i+1]] +} + +var _AuthorizationServerTypeValues = []AuthorizationServerType{0, 1} + +var _AuthorizationServerTypeNameToValueMap = map[string]AuthorizationServerType{ + _AuthorizationServerTypeName[0:4]: 0, + _AuthorizationServerTypeName[4:12]: 1, +} + +// AuthorizationServerTypeString retrieves an enum value from the enum constants string name. +// Throws an error if the param is not part of the enum. +func AuthorizationServerTypeString(s string) (AuthorizationServerType, error) { + if val, ok := _AuthorizationServerTypeNameToValueMap[s]; ok { + return val, nil + } + return 0, fmt.Errorf("%s does not belong to AuthorizationServerType values", s) +} + +// AuthorizationServerTypeValues returns all values of the enum +func AuthorizationServerTypeValues() []AuthorizationServerType { + return _AuthorizationServerTypeValues +} + +// IsAAuthorizationServerType returns "true" if the value is listed in the enum definition. "false" otherwise +func (i AuthorizationServerType) IsAAuthorizationServerType() bool { + for _, v := range _AuthorizationServerTypeValues { + if i == v { + return true + } + } + return false +} + +// MarshalJSON implements the json.Marshaler interface for AuthorizationServerType +func (i AuthorizationServerType) MarshalJSON() ([]byte, error) { + return json.Marshal(i.String()) +} + +// UnmarshalJSON implements the json.Unmarshaler interface for AuthorizationServerType +func (i *AuthorizationServerType) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return fmt.Errorf("AuthorizationServerType should be a string, got %s", data) + } + + var err error + *i, err = AuthorizationServerTypeString(s) + return err +} diff --git a/runs/service/auth/config/config.go b/runs/service/auth/config/config.go index ca426146665..7d217900127 100644 --- a/runs/service/auth/config/config.go +++ b/runs/service/auth/config/config.go @@ -7,6 +7,8 @@ import ( ) //go:generate pflags Config --default-var=DefaultConfig +//go:generate enumer --type=AuthorizationServerType --trimprefix=AuthorizationServerType -json +//go:generate enumer --type=SameSite --trimprefix=SameSite -json type SecretName = string diff --git a/runs/service/auth/config/config_test.go b/runs/service/auth/config/config_test.go new file mode 100644 index 00000000000..216a9890566 --- /dev/null +++ b/runs/service/auth/config/config_test.go @@ -0,0 +1,112 @@ +package config + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAuthorizationServerType_String(t *testing.T) { + assert.Equal(t, "Self", AuthorizationServerTypeSelf.String()) + assert.Equal(t, "External", AuthorizationServerTypeExternal.String()) + assert.Equal(t, "AuthorizationServerType(99)", AuthorizationServerType(99).String()) +} + +func TestAuthorizationServerTypeString(t *testing.T) { + v, err := AuthorizationServerTypeString("Self") + require.NoError(t, err) + assert.Equal(t, AuthorizationServerTypeSelf, v) + + v, err = AuthorizationServerTypeString("External") + require.NoError(t, err) + assert.Equal(t, AuthorizationServerTypeExternal, v) + + _, err = AuthorizationServerTypeString("bogus") + assert.Error(t, err) +} + +func TestAuthorizationServerType_JSON(t *testing.T) { + b, err := json.Marshal(AuthorizationServerTypeExternal) + require.NoError(t, err) + assert.JSONEq(t, `"External"`, string(b)) + + var out AuthorizationServerType + require.NoError(t, json.Unmarshal([]byte(`"Self"`), &out)) + assert.Equal(t, AuthorizationServerTypeSelf, out) + + assert.Error(t, json.Unmarshal([]byte(`"bogus"`), &out)) + assert.Error(t, json.Unmarshal([]byte(`42`), &out)) +} + +func TestAuthorizationServerType_IsA(t *testing.T) { + assert.True(t, AuthorizationServerTypeSelf.IsAAuthorizationServerType()) + assert.True(t, AuthorizationServerTypeExternal.IsAAuthorizationServerType()) + assert.False(t, AuthorizationServerType(99).IsAAuthorizationServerType()) + assert.ElementsMatch(t, + []AuthorizationServerType{AuthorizationServerTypeSelf, AuthorizationServerTypeExternal}, + AuthorizationServerTypeValues()) +} + +func TestSameSite_String(t *testing.T) { + assert.Equal(t, "DefaultMode", SameSiteDefaultMode.String()) + assert.Equal(t, "LaxMode", SameSiteLaxMode.String()) + assert.Equal(t, "StrictMode", SameSiteStrictMode.String()) + assert.Equal(t, "NoneMode", SameSiteNoneMode.String()) + assert.Equal(t, "SameSite(99)", SameSite(99).String()) +} + +func TestSameSiteString(t *testing.T) { + v, err := SameSiteString("StrictMode") + require.NoError(t, err) + assert.Equal(t, SameSiteStrictMode, v) + + _, err = SameSiteString("bogus") + assert.Error(t, err) +} + +func TestSameSite_JSON(t *testing.T) { + b, err := json.Marshal(SameSiteNoneMode) + require.NoError(t, err) + assert.JSONEq(t, `"NoneMode"`, string(b)) + + var out SameSite + require.NoError(t, json.Unmarshal([]byte(`"LaxMode"`), &out)) + assert.Equal(t, SameSiteLaxMode, out) + + assert.Error(t, json.Unmarshal([]byte(`"bogus"`), &out)) +} + +func TestSameSite_IsA(t *testing.T) { + assert.True(t, SameSiteDefaultMode.IsASameSite()) + assert.True(t, SameSiteNoneMode.IsASameSite()) + assert.False(t, SameSite(99).IsASameSite()) + assert.Len(t, SameSiteValues(), 4) +} + +func TestThirdPartyConfigOptions_IsEmpty(t *testing.T) { + assert.True(t, ThirdPartyConfigOptions{}.IsEmpty()) + assert.False(t, ThirdPartyConfigOptions{ + FlyteClientConfig: FlyteClientConfig{ClientID: "x"}, + }.IsEmpty()) + assert.False(t, ThirdPartyConfigOptions{ + FlyteClientConfig: FlyteClientConfig{Scopes: []string{"all"}}, + }.IsEmpty()) +} + +func TestMustParseURL(t *testing.T) { + u := MustParseURL("https://example.com/path") + require.NotNil(t, u) + assert.Equal(t, "example.com", u.Host) + + assert.Panics(t, func() { MustParseURL("://bogus") }) +} + +func TestDefaultConfig(t *testing.T) { + require.NotNil(t, DefaultConfig) + assert.Equal(t, "flyte-authorization", DefaultConfig.HTTPAuthorizationHeader) + assert.Equal(t, AuthorizationServerTypeSelf, DefaultConfig.AppAuth.AuthServerType) + assert.Equal(t, SameSiteDefaultMode, DefaultConfig.UserAuth.CookieSetting.SameSitePolicy) + assert.Contains(t, DefaultConfig.UserAuth.OpenID.Scopes, "openid") +} diff --git a/runs/service/auth/config/samesite_enumer.go b/runs/service/auth/config/samesite_enumer.go new file mode 100644 index 00000000000..e42e58fbbe5 --- /dev/null +++ b/runs/service/auth/config/samesite_enumer.go @@ -0,0 +1,69 @@ +// Code generated by "enumer --type=SameSite --trimprefix=SameSite -json"; DO NOT EDIT. + +package config + +import ( + "encoding/json" + "fmt" +) + +const _SameSiteName = "DefaultModeLaxModeStrictModeNoneMode" + +var _SameSiteIndex = [...]uint8{0, 11, 18, 28, 36} + +func (i SameSite) String() string { + if i < 0 || i >= SameSite(len(_SameSiteIndex)-1) { + return fmt.Sprintf("SameSite(%d)", i) + } + return _SameSiteName[_SameSiteIndex[i]:_SameSiteIndex[i+1]] +} + +var _SameSiteValues = []SameSite{0, 1, 2, 3} + +var _SameSiteNameToValueMap = map[string]SameSite{ + _SameSiteName[0:11]: 0, + _SameSiteName[11:18]: 1, + _SameSiteName[18:28]: 2, + _SameSiteName[28:36]: 3, +} + +// SameSiteString retrieves an enum value from the enum constants string name. +// Throws an error if the param is not part of the enum. +func SameSiteString(s string) (SameSite, error) { + if val, ok := _SameSiteNameToValueMap[s]; ok { + return val, nil + } + return 0, fmt.Errorf("%s does not belong to SameSite values", s) +} + +// SameSiteValues returns all values of the enum +func SameSiteValues() []SameSite { + return _SameSiteValues +} + +// IsASameSite returns "true" if the value is listed in the enum definition. "false" otherwise +func (i SameSite) IsASameSite() bool { + for _, v := range _SameSiteValues { + if i == v { + return true + } + } + return false +} + +// MarshalJSON implements the json.Marshaler interface for SameSite +func (i SameSite) MarshalJSON() ([]byte, error) { + return json.Marshal(i.String()) +} + +// UnmarshalJSON implements the json.Unmarshaler interface for SameSite +func (i *SameSite) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return fmt.Errorf("SameSite should be a string, got %s", data) + } + + var err error + *i, err = SameSiteString(s) + return err +} diff --git a/runs/service/auth/cookie_test.go b/runs/service/auth/cookie_test.go new file mode 100644 index 00000000000..07a55a87366 --- /dev/null +++ b/runs/service/auth/cookie_test.go @@ -0,0 +1,160 @@ +package auth + +import ( + "context" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/gorilla/securecookie" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/flyteorg/flyte/v2/runs/service/auth/config" +) + +func TestHashCsrfState(t *testing.T) { + h := HashCsrfState("hello") + // sha256("hello") hex + assert.Equal(t, "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", h) + assert.Equal(t, h, HashCsrfState("hello"), "deterministic") + assert.NotEqual(t, h, HashCsrfState("world")) +} + +func TestNewCsrfToken(t *testing.T) { + a := NewCsrfToken(1) + b := NewCsrfToken(1) + assert.Equal(t, a, b, "same seed should produce identical token") + assert.Len(t, a, 10) + for _, r := range a { + assert.Contains(t, string(AllowedChars), string(r)) + } + assert.NotEqual(t, a, NewCsrfToken(2)) +} + +func TestNewCsrfCookie(t *testing.T) { + c := NewCsrfCookie() + assert.Equal(t, "flyte_csrf_state", c.Name) + assert.Len(t, c.Value, 10) + assert.Equal(t, http.SameSiteLaxMode, c.SameSite) + assert.True(t, c.HttpOnly) +} + +func newTestKeys() ([]byte, []byte) { + return securecookie.GenerateRandomKey(64), securecookie.GenerateRandomKey(32) +} + +func TestSecureCookie_RoundTrip(t *testing.T) { + hashKey, blockKey := newTestKeys() + cookie, err := NewSecureCookie("flyte_at", "super-secret", hashKey, blockKey, "", http.SameSiteLaxMode) + require.NoError(t, err) + assert.Equal(t, "flyte_at", cookie.Name) + assert.True(t, cookie.HttpOnly) + assert.Equal(t, "/", cookie.Path) + + out, err := ReadSecureCookie(context.Background(), cookie, hashKey, blockKey) + require.NoError(t, err) + assert.Equal(t, "super-secret", out) +} + +func TestReadSecureCookie_WrongKey(t *testing.T) { + hashKey, blockKey := newTestKeys() + cookie, err := NewSecureCookie("flyte_at", "value", hashKey, blockKey, "", http.SameSiteLaxMode) + require.NoError(t, err) + + wrongHash, wrongBlock := newTestKeys() + _, err = ReadSecureCookie(context.Background(), cookie, wrongHash, wrongBlock) + assert.Error(t, err) +} + +func TestRetrieveSecureCookie_Missing(t *testing.T) { + hashKey, blockKey := newTestKeys() + req := httptest.NewRequest(http.MethodGet, "/", nil) + _, err := retrieveSecureCookie(context.Background(), req, "missing", hashKey, blockKey) + assert.Error(t, err) +} + +func TestVerifyCsrfCookie(t *testing.T) { + token := "abcdefghij" + hashed := HashCsrfState(token) + + req := httptest.NewRequest(http.MethodPost, "/callback", strings.NewReader("state="+hashed)) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.AddCookie(&http.Cookie{Name: "flyte_csrf_state", Value: token}) + require.NoError(t, req.ParseForm()) + require.NoError(t, VerifyCsrfCookie(context.Background(), req)) +} + +func TestVerifyCsrfCookie_Mismatch(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/callback", strings.NewReader("state=wrong")) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.AddCookie(&http.Cookie{Name: "flyte_csrf_state", Value: "abcdefghij"}) + require.NoError(t, req.ParseForm()) + assert.Error(t, VerifyCsrfCookie(context.Background(), req)) +} + +func TestVerifyCsrfCookie_EmptyState(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/callback", strings.NewReader("")) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + require.NoError(t, req.ParseForm()) + assert.Error(t, VerifyCsrfCookie(context.Background(), req)) +} + +func TestVerifyCsrfCookie_MissingCookie(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/callback", strings.NewReader("state=x")) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + require.NoError(t, req.ParseForm()) + assert.Error(t, VerifyCsrfCookie(context.Background(), req)) +} + +func TestNewRedirectCookie(t *testing.T) { + c := NewRedirectCookie(context.Background(), "/console/projects") + require.NotNil(t, c) + assert.Equal(t, "flyte_redirect_location", c.Name) + assert.Equal(t, "/console/projects", c.Value) + assert.True(t, c.HttpOnly) +} + +func TestNewRedirectCookie_Invalid(t *testing.T) { + assert.Nil(t, NewRedirectCookie(context.Background(), "")) +} + +func TestGetAuthFlowEndRedirect_QueryAllowed(t *testing.T) { + authorized := []config.URL{{URL: *mustURL(t, "https://flyte.mycorp.com")}} + req := httptest.NewRequest(http.MethodGet, "https://flyte.mycorp.com/callback?redirect_url=https://flyte.mycorp.com/console", nil) + + got := GetAuthFlowEndRedirect(context.Background(), "/default", authorized, req) + assert.Equal(t, "https://flyte.mycorp.com/console", got) +} + +func TestGetAuthFlowEndRedirect_QueryUnauthorizedFallsBack(t *testing.T) { + authorized := []config.URL{{URL: *mustURL(t, "https://flyte.mycorp.com")}} + req := httptest.NewRequest(http.MethodGet, "https://flyte.mycorp.com/callback?redirect_url=https://evil.example.com", nil) + + got := GetAuthFlowEndRedirect(context.Background(), "/default", authorized, req) + assert.Equal(t, "/default", got) +} + +func TestGetAuthFlowEndRedirect_CookieFallback(t *testing.T) { + authorized := []config.URL{{URL: *mustURL(t, "https://flyte.mycorp.com")}} + req := httptest.NewRequest(http.MethodGet, "https://flyte.mycorp.com/callback", nil) + req.AddCookie(&http.Cookie{Name: "flyte_redirect_location", Value: "https://flyte.mycorp.com/console"}) + + got := GetAuthFlowEndRedirect(context.Background(), "/default", authorized, req) + assert.Equal(t, "https://flyte.mycorp.com/console", got) +} + +func TestGetAuthFlowEndRedirect_NoCookieReturnsDefault(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "https://flyte.mycorp.com/callback", nil) + got := GetAuthFlowEndRedirect(context.Background(), "/default", nil, req) + assert.Equal(t, "/default", got) +} + +func mustURL(t *testing.T, s string) *url.URL { + t.Helper() + u, err := url.Parse(s) + require.NoError(t, err) + return u +} diff --git a/runs/service/auth/http_middleware.go b/runs/service/auth/http_middleware.go new file mode 100644 index 00000000000..79a0b07fa7b --- /dev/null +++ b/runs/service/auth/http_middleware.go @@ -0,0 +1,64 @@ +package auth + +import ( + "net/http" + "strings" + + "github.com/flyteorg/flyte/v2/flytestdlib/logger" +) + +// publicPathPrefixes lists request paths that never require authentication. +// These must cover health probes, the browser OAuth2/OIDC flow, metadata +// discovery, and the AuthMetadataService (which clients call *before* they +// have a token). +var publicPathPrefixes = []string{ + "/healthz", + "/readyz", + "/healthcheck", + "/login", + "/callback", + "/logout", + "/.well-known/", + "/flyteidl2.auth.AuthMetadataService/", +} + +// IsPublicPath reports whether an HTTP request path bypasses authentication. +func IsPublicPath(path string) bool { + for _, p := range publicPathPrefixes { + if strings.HasPrefix(path, p) { + return true + } + } + return false +} + +// GetAuthenticationHTTPInterceptor returns middleware that validates a bearer +// token or auth cookies on incoming HTTP requests and injects the resulting +// IdentityContext into the request context. Public paths (see IsPublicPath) +// pass through without validation. When DisableForHTTP is set on the config, +// every request passes through unchanged. +func GetAuthenticationHTTPInterceptor(h *AuthHandlerConfig) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if IsPublicPath(req.URL.Path) { + next.ServeHTTP(w, req) + return + } + + if h.AuthConfig.DisableForHTTP { + next.ServeHTTP(w, req) + return + } + + ctx := req.Context() + identity, err := IdentityContextFromRequest(ctx, req, h) + if err != nil { + logger.Infof(ctx, "unauthenticated request to %s: %v", req.URL.Path, err) + w.WriteHeader(http.StatusUnauthorized) + return + } + + next.ServeHTTP(w, req.WithContext(identity.WithContext(ctx))) + }) + } +} diff --git a/runs/service/auth/http_middleware_test.go b/runs/service/auth/http_middleware_test.go new file mode 100644 index 00000000000..8ad1055c7f5 --- /dev/null +++ b/runs/service/auth/http_middleware_test.go @@ -0,0 +1,92 @@ +package auth + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/flyteorg/flyte/v2/runs/service/auth/config" +) + +func TestIsPublicPath(t *testing.T) { + cases := map[string]bool{ + "/healthz": true, + "/readyz": true, + "/healthcheck": true, + "/login": true, + "/login?redirect_url=/console": true, + "/callback": true, + "/logout": true, + "/.well-known/openid-configuration": true, + "/.well-known/oauth-authorization-server": true, + "/flyteidl2.auth.AuthMetadataService/GetOAuth2Metadata": true, + "/flyteidl2.workflow.RunService/CreateRun": false, + "/flyteidl2.auth.IdentityService/UserInfo": false, + "/": false, + "/api/v1/projects": false, + } + for path, want := range cases { + got := IsPublicPath(path) + assert.Equalf(t, want, got, "IsPublicPath(%q)", path) + } +} + +// servedBy wraps a boolean flag so tests can check that the next handler ran. +type servedBy struct{ called bool } + +func (s *servedBy) handler() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + s.called = true + w.WriteHeader(http.StatusOK) + }) +} + +func TestMiddleware_PublicPathBypassesAuth(t *testing.T) { + // Even with a zero AuthHandlerConfig (no resource server, no cookie manager), + // public paths must not touch any auth plumbing. + h := &AuthHandlerConfig{AuthConfig: config.Config{}} + mw := GetAuthenticationHTTPInterceptor(h) + + var sb servedBy + req := httptest.NewRequest(http.MethodGet, "/healthz", nil) + w := httptest.NewRecorder() + mw(sb.handler()).ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.True(t, sb.called, "public path should have reached the next handler") +} + +func TestMiddleware_DisabledForHTTPBypassesAuth(t *testing.T) { + h := &AuthHandlerConfig{AuthConfig: config.Config{DisableForHTTP: true}} + mw := GetAuthenticationHTTPInterceptor(h) + + var sb servedBy + req := httptest.NewRequest(http.MethodGet, "/flyteidl2.workflow.RunService/CreateRun", nil) + w := httptest.NewRecorder() + mw(sb.handler()).ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.True(t, sb.called) +} + +func TestMiddleware_NoAuthReturns401(t *testing.T) { + // AuthHandlerConfig missing a CookieManager will cause IdentityContextFromRequest + // to fail when no bearer header is present. The middleware must convert that to 401. + h := &AuthHandlerConfig{ + AuthConfig: config.Config{}, + CookieManager: CookieManager{}, + ResourceServer: nil, + } + mw := GetAuthenticationHTTPInterceptor(h) + + var sb servedBy + req := httptest.NewRequest(http.MethodGet, "/flyteidl2.workflow.RunService/CreateRun", nil) + w := httptest.NewRecorder() + mw(sb.handler()).ServeHTTP(w, req) + + assert.Equal(t, http.StatusUnauthorized, w.Code) + assert.False(t, sb.called, "protected path must not reach next handler without auth") +} + diff --git a/runs/service/auth/token_test.go b/runs/service/auth/token_test.go new file mode 100644 index 00000000000..5c6a8685979 --- /dev/null +++ b/runs/service/auth/token_test.go @@ -0,0 +1,94 @@ +package auth + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" + "google.golang.org/grpc/metadata" +) + +func TestNewOAuthTokenFromRaw(t *testing.T) { + tok := NewOAuthTokenFromRaw("access", "refresh", "id-token") + require.NotNil(t, tok) + assert.Equal(t, "access", tok.AccessToken) + assert.Equal(t, "refresh", tok.RefreshToken) + assert.Equal(t, "id-token", tok.Extra(idTokenExtra)) +} + +func TestExtractTokensFromOauthToken(t *testing.T) { + src := NewOAuthTokenFromRaw("a", "r", "i") + id, access, refresh, err := ExtractTokensFromOauthToken(src) + require.NoError(t, err) + assert.Equal(t, "i", id) + assert.Equal(t, "a", access) + assert.Equal(t, "r", refresh) +} + +func TestExtractTokensFromOauthToken_Nil(t *testing.T) { + _, _, _, err := ExtractTokensFromOauthToken(nil) + assert.Error(t, err) +} + +func TestExtractTokensFromOauthToken_MissingIDToken(t *testing.T) { + // bare oauth2.Token without id_token extra should fail + tok := &oauth2.Token{AccessToken: "a", RefreshToken: "r"} + _, _, _, err := ExtractTokensFromOauthToken(tok) + assert.Error(t, err) +} + +func ctxWithMD(pairs ...string) context.Context { + md := metadata.Pairs(pairs...) + return metadata.NewIncomingContext(context.Background(), md) +} + +func TestBearerTokenFromMD(t *testing.T) { + ctx := ctxWithMD(DefaultAuthorizationHeader, "Bearer my-token") + tok, err := bearerTokenFromMD(ctx) + require.NoError(t, err) + assert.Equal(t, "my-token", tok) +} + +func TestBearerTokenFromMD_NoMetadata(t *testing.T) { + _, err := bearerTokenFromMD(context.Background()) + assert.Error(t, err) +} + +func TestBearerTokenFromMD_MissingHeader(t *testing.T) { + _, err := bearerTokenFromMD(ctxWithMD("other", "v")) + assert.Error(t, err) +} + +func TestBearerTokenFromMD_WrongScheme(t *testing.T) { + _, err := bearerTokenFromMD(ctxWithMD(DefaultAuthorizationHeader, "IDToken abc")) + assert.Error(t, err) +} + +func TestBearerTokenFromMD_Blank(t *testing.T) { + _, err := bearerTokenFromMD(ctxWithMD(DefaultAuthorizationHeader, "Bearer ")) + assert.Error(t, err) +} + +func TestIDTokenFromMD(t *testing.T) { + ctx := ctxWithMD(DefaultAuthorizationHeader, "IDToken my-id-token") + tok, err := idTokenFromMD(ctx) + require.NoError(t, err) + assert.Equal(t, "my-id-token", tok) +} + +func TestIDTokenFromMD_WrongScheme(t *testing.T) { + _, err := idTokenFromMD(ctxWithMD(DefaultAuthorizationHeader, "Bearer abc")) + assert.Error(t, err) +} + +func TestIDTokenFromMD_Blank(t *testing.T) { + _, err := idTokenFromMD(ctxWithMD(DefaultAuthorizationHeader, "IDToken ")) + assert.Error(t, err) +} + +func TestIDTokenFromMD_NoMetadata(t *testing.T) { + _, err := idTokenFromMD(context.Background()) + assert.Error(t, err) +} diff --git a/runs/setup.go b/runs/setup.go index 7e6a4c0435a..4ebe0af9c88 100644 --- a/runs/setup.go +++ b/runs/setup.go @@ -5,6 +5,9 @@ import ( "errors" "fmt" "net/http" + "os" + "path/filepath" + "strings" "time" "github.com/flyteorg/flyte/v2/flytestdlib/app" @@ -26,12 +29,16 @@ import ( "github.com/flyteorg/flyte/v2/runs/scheduler" "github.com/flyteorg/flyte/v2/runs/service" authservice "github.com/flyteorg/flyte/v2/runs/service/auth" - authConfig "github.com/flyteorg/flyte/v2/runs/service/auth/config" "github.com/flyteorg/flyte/v2/runs/service/auth/authzserver" + authConfig "github.com/flyteorg/flyte/v2/runs/service/auth/config" "github.com/flyteorg/flyte/v2/flytestdlib/logger" ) +// authSecretsDir is the directory where cookie hash/block keys and other auth +// secrets are mounted (matches the flyte-binary chart volume mount path). +const authSecretsDir = "/etc/secrets" + // Setup registers Run and Task service handlers on the SetupContext mux. // Requires sc.DB and sc.DataStore to be set. When sc.K8sConfig is provided, // RunLogsService is also mounted to enable pod log streaming. @@ -88,23 +95,18 @@ func Setup(ctx context.Context, sc *app.SetupContext) error { logger.Infof(ctx, "Mounted TranslatorService at %s", translatorPath) if cfg.Security.UseAuth { - authCfg := authConfig.GetConfig() - authSvc := authzserver.NewAuthMetadataService(*authCfg) - authPath, authHandler := authconnect.NewAuthMetadataServiceHandler(authSvc) - sc.Mux.Handle(authPath, authHandler) - logger.Infof(ctx, "Mounted AuthMetadataService at %s", authPath) - - identitySvc := authservice.NewIdentityService() - identityPath, identityHandler := authconnect.NewIdentityServiceHandler(identitySvc) - sc.Mux.Handle(identityPath, identityHandler) - logger.Infof(ctx, "Mounted IdentityService at %s", identityPath) + if err := setupAuth(ctx, sc); err != nil { + return fmt.Errorf("runs: failed to set up auth: %w", err) + } + } else { + // When auth is disabled, still mount a stub AuthMetadataService so + // clients performing metadata discovery get a coherent response. + authMetadataSvc := service.NewAuthMetadataService(sc.BaseURL) + authMetadataPath, authMetadataHandler := authconnect.NewAuthMetadataServiceHandler(authMetadataSvc) + sc.Mux.Handle(authMetadataPath, authMetadataHandler) + logger.Infof(ctx, "Mounted stub AuthMetadataService at %s", authMetadataPath) } - authMetadataSvc := service.NewAuthMetadataService(sc.BaseURL) - authMetadataPath, authMetadataHandler := authconnect.NewAuthMetadataServiceHandler(authMetadataSvc) - sc.Mux.Handle(authMetadataPath, authMetadataHandler) - logger.Infof(ctx, "Mounted AuthMetadataService at %s", authMetadataPath) - appSvc := service.NewAppService() appPath, appHandler := flyteappconnect.NewAppServiceHandler(appSvc) sc.Mux.Handle(appPath, appHandler) @@ -162,6 +164,93 @@ func Setup(ctx context.Context, sc *app.SetupContext) error { return nil } +// setupAuth wires up the external-mode OAuth2 resource server, OIDC browser +// handlers, AuthMetadataService / IdentityService, and a bearer-token +// validating HTTP middleware on the shared mux. It requires that the auth +// config section is populated and that cookie hash/block keys are present as +// files under authSecretsDir. +func setupAuth(ctx context.Context, sc *app.SetupContext) error { + authCfg := authConfig.GetConfig() + + // Mount the real AuthMetadataService backed by the configured issuer. + authMetadataSvc := authzserver.NewAuthMetadataService(*authCfg) + authPath, authHandler := authconnect.NewAuthMetadataServiceHandler(authMetadataSvc) + sc.Mux.Handle(authPath, authHandler) + logger.Infof(ctx, "Mounted AuthMetadataService at %s", authPath) + + identitySvc := authservice.NewIdentityService() + identityPath, identityHandler := authconnect.NewIdentityServiceHandler(identitySvc) + sc.Mux.Handle(identityPath, identityHandler) + logger.Infof(ctx, "Mounted IdentityService at %s", identityPath) + + hashKey, err := readSecretFile(authConfig.SecretNameCookieHashKey) + if err != nil { + return err + } + blockKey, err := readSecretFile(authConfig.SecretNameCookieBlockKey) + if err != nil { + return err + } + + // Load the OIDC client secret used during the OAuth2 code exchange. The + // filename is configurable so that a deployment can swap the secret name + // without redeploying the binary. + oidcClientSecretName := authCfg.UserAuth.OpenID.ClientSecretName + if oidcClientSecretName == "" { + oidcClientSecretName = authConfig.SecretNameOIdCClientSecret + } + oidcClientSecret, err := readSecretFile(oidcClientSecretName) + if err != nil { + return err + } + + // Validate tokens issued by the configured external authorization server. + // If BaseURL is empty, the resource server falls back to the first authorizedUri. + var fallbackURL authConfig.URL + if len(authCfg.AuthorizedURIs) > 0 { + fallbackURL = authCfg.AuthorizedURIs[0] + } + resourceServer, err := authzserver.NewOAuth2ResourceServer(ctx, authCfg.AppAuth.ExternalAuthServer, fallbackURL) + if err != nil { + return fmt.Errorf("failed to create OAuth2 resource server: %w", err) + } + + authCtx, err := authservice.NewAuthContext(ctx, *authCfg, resourceServer, hashKey, blockKey, oidcClientSecret) + if err != nil { + return fmt.Errorf("failed to create auth context: %w", err) + } + + // Register /login, /callback, /logout, /.well-known/openid-configuration. + authservice.RegisterHandlers(ctx, sc.Mux, authCtx.HandlerConfig()) + logger.Infof(ctx, "Registered OIDC browser handlers (/login, /callback, /logout)") + + // Chain the bearer/cookie auth middleware with any existing middleware + // (e.g. CORS). Ordering: request -> CORS -> auth -> mux. + prev := sc.Middleware + authMw := authservice.GetAuthenticationHTTPInterceptor(authCtx.HandlerConfig()) + sc.Middleware = func(next http.Handler) http.Handler { + wrapped := authMw(next) + if prev != nil { + wrapped = prev(wrapped) + } + return wrapped + } + logger.Infof(ctx, "Auth middleware installed; audience=%s", authCfg.AppAuth.ExternalAuthServer.BaseURL.String()) + + return nil +} + +// readSecretFile reads a base64-encoded key file from authSecretsDir and +// returns the trimmed string contents. +func readSecretFile(name string) (string, error) { + path := filepath.Join(authSecretsDir, name) + b, err := os.ReadFile(path) + if err != nil { + return "", fmt.Errorf("failed to read auth secret %s: %w", path, err) + } + return strings.TrimSpace(string(b)), nil +} + func seedProjects(ctx context.Context, projectRepo interfaces.ProjectRepo, projects []string) error { for _, projectID := range projects { if projectID == "" { From 6e6d75d700b34ed9de5a101dbee7ad8135a4d86e Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Sat, 11 Apr 2026 00:17:03 -0700 Subject: [PATCH 10/13] [V2] auth middleware: bypass for loopback intra-process calls MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The unified Flyte binary uses connect-rpc clients that talk to their own mux via http://localhost: (e.g. RunService calls ActionsService.CreateAction). Those calls have no Authorization header because they're in-process, and the new external auth middleware was rejecting them with 401 — so run creation silently failed end-to-end. Bypass auth when req.RemoteAddr is a loopback address (127.0.0.0/8 or ::1). External traffic from the ALB never has a loopback remote addr, so this doesn't widen the attack surface. Add table-driven isLoopbackRequest tests and middleware tests for both IPv4 and IPv6 loopback and a non-loopback pod IP. Signed-off-by: Kevin Su --- runs/service/auth/http_middleware.go | 25 +++++++++ runs/service/auth/http_middleware_test.go | 63 +++++++++++++++++++++++ 2 files changed, 88 insertions(+) diff --git a/runs/service/auth/http_middleware.go b/runs/service/auth/http_middleware.go index 79a0b07fa7b..be8a7bb9f51 100644 --- a/runs/service/auth/http_middleware.go +++ b/runs/service/auth/http_middleware.go @@ -1,6 +1,7 @@ package auth import ( + "net" "net/http" "strings" @@ -32,6 +33,25 @@ func IsPublicPath(path string) bool { return false } +// isLoopbackRequest returns true when the request originated from the local +// loopback interface. The unified Flyte binary makes intra-process connect-rpc +// calls to its own HTTP mux via http://localhost: (e.g. RunService -> +// ActionsService). Those calls have no Authorization header and must not be +// forced through the external auth gate, or every run creation will fail with +// 401. External traffic (ALB, port-forward from outside the pod) never has a +// loopback RemoteAddr. +func isLoopbackRequest(req *http.Request) bool { + host, _, err := net.SplitHostPort(req.RemoteAddr) + if err != nil { + host = req.RemoteAddr + } + ip := net.ParseIP(host) + if ip == nil { + return false + } + return ip.IsLoopback() +} + // GetAuthenticationHTTPInterceptor returns middleware that validates a bearer // token or auth cookies on incoming HTTP requests and injects the resulting // IdentityContext into the request context. Public paths (see IsPublicPath) @@ -45,6 +65,11 @@ func GetAuthenticationHTTPInterceptor(h *AuthHandlerConfig) func(http.Handler) h return } + if isLoopbackRequest(req) { + next.ServeHTTP(w, req) + return + } + if h.AuthConfig.DisableForHTTP { next.ServeHTTP(w, req) return diff --git a/runs/service/auth/http_middleware_test.go b/runs/service/auth/http_middleware_test.go index 8ad1055c7f5..8f9aa367967 100644 --- a/runs/service/auth/http_middleware_test.go +++ b/runs/service/auth/http_middleware_test.go @@ -71,6 +71,69 @@ func TestMiddleware_DisabledForHTTPBypassesAuth(t *testing.T) { assert.True(t, sb.called) } +func TestMiddleware_LoopbackIPv4BypassesAuth(t *testing.T) { + // Intra-process connect-rpc calls (e.g. runs -> actions on localhost:8090) + // must pass through the middleware without an Authorization header. + h := &AuthHandlerConfig{AuthConfig: config.Config{}, CookieManager: CookieManager{}} + mw := GetAuthenticationHTTPInterceptor(h) + + var sb servedBy + req := httptest.NewRequest(http.MethodPost, "/flyteidl2.actions.ActionsService/CreateAction", nil) + req.RemoteAddr = "127.0.0.1:54321" + w := httptest.NewRecorder() + mw(sb.handler()).ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.True(t, sb.called, "loopback call must reach next handler") +} + +func TestMiddleware_LoopbackIPv6BypassesAuth(t *testing.T) { + h := &AuthHandlerConfig{AuthConfig: config.Config{}, CookieManager: CookieManager{}} + mw := GetAuthenticationHTTPInterceptor(h) + + var sb servedBy + req := httptest.NewRequest(http.MethodPost, "/flyteidl2.actions.ActionsService/CreateAction", nil) + req.RemoteAddr = "[::1]:54321" + w := httptest.NewRecorder() + mw(sb.handler()).ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.True(t, sb.called) +} + +func TestMiddleware_NonLoopbackStillBlocks(t *testing.T) { + // A caller from a real pod IP must still hit the 401 path when no auth + // is present — the loopback bypass is strictly for in-process calls. + h := &AuthHandlerConfig{AuthConfig: config.Config{}, CookieManager: CookieManager{}} + mw := GetAuthenticationHTTPInterceptor(h) + + var sb servedBy + req := httptest.NewRequest(http.MethodPost, "/flyteidl2.actions.ActionsService/CreateAction", nil) + req.RemoteAddr = "10.1.42.7:48221" + w := httptest.NewRecorder() + mw(sb.handler()).ServeHTTP(w, req) + + assert.Equal(t, http.StatusUnauthorized, w.Code) + assert.False(t, sb.called) +} + +func TestIsLoopbackRequest(t *testing.T) { + cases := map[string]bool{ + "127.0.0.1:1234": true, + "127.1.2.3:80": true, + "[::1]:8080": true, + "10.0.0.1:8080": false, + "192.168.1.1:80": false, + "": false, + "bogus": false, + } + for addr, want := range cases { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.RemoteAddr = addr + assert.Equalf(t, want, isLoopbackRequest(req), "isLoopbackRequest(%q)", addr) + } +} + func TestMiddleware_NoAuthReturns401(t *testing.T) { // AuthHandlerConfig missing a CookieManager will cause IdentityContextFromRequest // to fail when no bearer header is present. The middleware must convert that to 401. From 09b5a5af4960ac865a0127755945265e7224275a Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Sat, 11 Apr 2026 00:49:48 -0700 Subject: [PATCH 11/13] [V2] auth middleware: allow task-pod calls to ActionsService + InternalRunService MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Task pods running flytekit call ActionsService.CreateAction (and InternalRunService) via the flyte2-grpc ClusterIP service to enqueue subsequent actions. Those calls arrive at the pod with the task pod's IP as RemoteAddr — not loopback — so the loopback bypass does not catch them, and the external auth middleware was returning 401, which flytekit reported as "Failed to launch action: Unauthorized" and task execution failed. Add /flyteidl2.actions.ActionsService/ and /flyteidl2.workflow.InternalRunService/ to the public-path allowlist so in-cluster traffic to these services passes without credentials. Remove the same paths from the ingress grpcPaths helper so they are not exposed via the external ALB — they remain reachable only through the ClusterIP service inside the cluster, matching v1's propeller -> flyteadmin pattern. Update the table-driven IsPublicPath test and swap the loopback / non-loopback test path to RunService so the assertion still exercises the gate rather than the new public path. Signed-off-by: Kevin Su --- charts/flyte-binary/templates/_helpers.tpl | 10 ++-- runs/service/auth/http_middleware.go | 14 ++++-- runs/service/auth/http_middleware_test.go | 55 +++++++++++++++------- 3 files changed, 54 insertions(+), 25 deletions(-) diff --git a/charts/flyte-binary/templates/_helpers.tpl b/charts/flyte-binary/templates/_helpers.tpl index c9cd4940759..287725002c1 100644 --- a/charts/flyte-binary/templates/_helpers.tpl +++ b/charts/flyte-binary/templates/_helpers.tpl @@ -178,19 +178,19 @@ Get the Flyte service gRPC port. {{- end -}} {{/* -Get the Flyte API paths for ingress. +Get the Flyte API paths for ingress. Services whose names start with +"Internal" (e.g. InternalRunService) plus ActionsService are intended for +intra-cluster traffic from task pods only; they are deliberately NOT exposed +via the external ALB ingress here. The Go auth middleware allowlists them so +cluster-internal ClusterIP calls reach them without credentials. */}} {{- define "flyte-binary.ingress.grpcPaths" -}} - /flyteidl2.workflow.RunService - /flyteidl2.workflow.RunService/* -- /flyteidl2.workflow.InternalRunService -- /flyteidl2.workflow.InternalRunService/* - /flyteidl2.task.TaskService - /flyteidl2.task.TaskService/* - /flyteidl2.workflow.TranslatorService - /flyteidl2.workflow.TranslatorService/* -- /flyteidl2.actions.ActionsService -- /flyteidl2.actions.ActionsService/* - /flyteidl2.dataproxy.DataProxyService - /flyteidl2.dataproxy.DataProxyService/* - /flyteidl2.secret.SecretService diff --git a/runs/service/auth/http_middleware.go b/runs/service/auth/http_middleware.go index be8a7bb9f51..03cba45866d 100644 --- a/runs/service/auth/http_middleware.go +++ b/runs/service/auth/http_middleware.go @@ -9,9 +9,15 @@ import ( ) // publicPathPrefixes lists request paths that never require authentication. -// These must cover health probes, the browser OAuth2/OIDC flow, metadata -// discovery, and the AuthMetadataService (which clients call *before* they -// have a token). +// These cover: +// - health probes +// - the browser OAuth2/OIDC flow +// - metadata discovery (AuthMetadataService is called pre-auth) +// - intra-cluster services that task pods call via the ClusterIP service +// without credentials (ActionsService, InternalRunService). These are +// deliberately excluded from the external ALB ingress in +// charts/flyte-binary/templates/_helpers.tpl so they cannot be reached +// from the public internet; only in-cluster pods can hit them. var publicPathPrefixes = []string{ "/healthz", "/readyz", @@ -21,6 +27,8 @@ var publicPathPrefixes = []string{ "/logout", "/.well-known/", "/flyteidl2.auth.AuthMetadataService/", + "/flyteidl2.actions.ActionsService/", + "/flyteidl2.workflow.InternalRunService/", } // IsPublicPath reports whether an HTTP request path bypasses authentication. diff --git a/runs/service/auth/http_middleware_test.go b/runs/service/auth/http_middleware_test.go index 8f9aa367967..6aa887c825c 100644 --- a/runs/service/auth/http_middleware_test.go +++ b/runs/service/auth/http_middleware_test.go @@ -12,20 +12,22 @@ import ( func TestIsPublicPath(t *testing.T) { cases := map[string]bool{ - "/healthz": true, - "/readyz": true, - "/healthcheck": true, - "/login": true, - "/login?redirect_url=/console": true, - "/callback": true, - "/logout": true, - "/.well-known/openid-configuration": true, - "/.well-known/oauth-authorization-server": true, + "/healthz": true, + "/readyz": true, + "/healthcheck": true, + "/login": true, + "/login?redirect_url=/console": true, + "/callback": true, + "/logout": true, + "/.well-known/openid-configuration": true, + "/.well-known/oauth-authorization-server": true, "/flyteidl2.auth.AuthMetadataService/GetOAuth2Metadata": true, - "/flyteidl2.workflow.RunService/CreateRun": false, - "/flyteidl2.auth.IdentityService/UserInfo": false, - "/": false, - "/api/v1/projects": false, + "/flyteidl2.actions.ActionsService/CreateAction": true, + "/flyteidl2.workflow.InternalRunService/UpdateRun": true, + "/flyteidl2.workflow.RunService/CreateRun": false, + "/flyteidl2.auth.IdentityService/UserInfo": false, + "/": false, + "/api/v1/projects": false, } for path, want := range cases { got := IsPublicPath(path) @@ -72,13 +74,13 @@ func TestMiddleware_DisabledForHTTPBypassesAuth(t *testing.T) { } func TestMiddleware_LoopbackIPv4BypassesAuth(t *testing.T) { - // Intra-process connect-rpc calls (e.g. runs -> actions on localhost:8090) + // Intra-process connect-rpc calls (e.g. runs -> RunService on localhost) // must pass through the middleware without an Authorization header. h := &AuthHandlerConfig{AuthConfig: config.Config{}, CookieManager: CookieManager{}} mw := GetAuthenticationHTTPInterceptor(h) var sb servedBy - req := httptest.NewRequest(http.MethodPost, "/flyteidl2.actions.ActionsService/CreateAction", nil) + req := httptest.NewRequest(http.MethodPost, "/flyteidl2.workflow.RunService/CreateRun", nil) req.RemoteAddr = "127.0.0.1:54321" w := httptest.NewRecorder() mw(sb.handler()).ServeHTTP(w, req) @@ -92,7 +94,7 @@ func TestMiddleware_LoopbackIPv6BypassesAuth(t *testing.T) { mw := GetAuthenticationHTTPInterceptor(h) var sb servedBy - req := httptest.NewRequest(http.MethodPost, "/flyteidl2.actions.ActionsService/CreateAction", nil) + req := httptest.NewRequest(http.MethodPost, "/flyteidl2.workflow.RunService/CreateRun", nil) req.RemoteAddr = "[::1]:54321" w := httptest.NewRecorder() mw(sb.handler()).ServeHTTP(w, req) @@ -101,14 +103,33 @@ func TestMiddleware_LoopbackIPv6BypassesAuth(t *testing.T) { assert.True(t, sb.called) } +func TestMiddleware_ActionsServicePublicFromPodIP(t *testing.T) { + // Task pods call ActionsService from their pod IP (non-loopback) over + // the ClusterIP service. The path must be allowlisted so the SDK can + // launch actions without carrying credentials. + h := &AuthHandlerConfig{AuthConfig: config.Config{}, CookieManager: CookieManager{}} + mw := GetAuthenticationHTTPInterceptor(h) + + var sb servedBy + req := httptest.NewRequest(http.MethodPost, "/flyteidl2.actions.ActionsService/CreateAction", nil) + req.RemoteAddr = "10.1.193.72:33100" + w := httptest.NewRecorder() + mw(sb.handler()).ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.True(t, sb.called, "ActionsService must be reachable from task pods without auth") +} + func TestMiddleware_NonLoopbackStillBlocks(t *testing.T) { // A caller from a real pod IP must still hit the 401 path when no auth // is present — the loopback bypass is strictly for in-process calls. + // Use RunService (user-facing) rather than ActionsService (in-cluster-only + // public path) so we're exercising the actual gate. h := &AuthHandlerConfig{AuthConfig: config.Config{}, CookieManager: CookieManager{}} mw := GetAuthenticationHTTPInterceptor(h) var sb servedBy - req := httptest.NewRequest(http.MethodPost, "/flyteidl2.actions.ActionsService/CreateAction", nil) + req := httptest.NewRequest(http.MethodPost, "/flyteidl2.workflow.RunService/CreateRun", nil) req.RemoteAddr = "10.1.42.7:48221" w := httptest.NewRecorder() mw(sb.handler()).ServeHTTP(w, req) From 5ba507bf6c133ff2aa37b653d4d47468fd49a4b1 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Sat, 11 Apr 2026 22:23:10 -0700 Subject: [PATCH 12/13] rm identity_service Signed-off-by: Kevin Su --- runs/service/auth/identity_service.go | 29 ---------------------- runs/service/auth/identity_service_test.go | 21 ---------------- runs/setup.go | 2 +- 3 files changed, 1 insertion(+), 51 deletions(-) delete mode 100644 runs/service/auth/identity_service.go delete mode 100644 runs/service/auth/identity_service_test.go diff --git a/runs/service/auth/identity_service.go b/runs/service/auth/identity_service.go deleted file mode 100644 index d2c5131175e..00000000000 --- a/runs/service/auth/identity_service.go +++ /dev/null @@ -1,29 +0,0 @@ -package auth - -import ( - "context" - - "connectrpc.com/connect" - - authpb "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth" - "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth/authconnect" -) - -// IdentityService implements the IdentityServiceHandler interface. -type IdentityService struct{} - -// NewIdentityService creates a new IdentityService instance. -func NewIdentityService() *IdentityService { - return &IdentityService{} -} - -var _ authconnect.IdentityServiceHandler = (*IdentityService)(nil) - -// UserInfo returns information about the currently logged in user. -// TODO: Wire with real auth to populate user info from the authenticated context. -func (s *IdentityService) UserInfo( - ctx context.Context, - req *connect.Request[authpb.UserInfoRequest], -) (*connect.Response[authpb.UserInfoResponse], error) { - return connect.NewResponse(&authpb.UserInfoResponse{}), nil -} diff --git a/runs/service/auth/identity_service_test.go b/runs/service/auth/identity_service_test.go deleted file mode 100644 index ae76efce8f8..00000000000 --- a/runs/service/auth/identity_service_test.go +++ /dev/null @@ -1,21 +0,0 @@ -package auth - -import ( - "context" - "testing" - - "connectrpc.com/connect" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - authpb "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth" -) - -func TestIdentityService_UserInfo(t *testing.T) { - svc := NewIdentityService() - - resp, err := svc.UserInfo(context.Background(), connect.NewRequest(&authpb.UserInfoRequest{})) - require.NoError(t, err) - assert.NotNil(t, resp) - assert.NotNil(t, resp.Msg) -} diff --git a/runs/setup.go b/runs/setup.go index 4ebe0af9c88..15f6129e9b0 100644 --- a/runs/setup.go +++ b/runs/setup.go @@ -178,7 +178,7 @@ func setupAuth(ctx context.Context, sc *app.SetupContext) error { sc.Mux.Handle(authPath, authHandler) logger.Infof(ctx, "Mounted AuthMetadataService at %s", authPath) - identitySvc := authservice.NewIdentityService() + identitySvc := authservice.NewUserInfoProvider() identityPath, identityHandler := authconnect.NewIdentityServiceHandler(identitySvc) sc.Mux.Handle(identityPath, identityHandler) logger.Infof(ctx, "Mounted IdentityService at %s", identityPath) From ba49ba0329c0e10010b3a618b2e75572738aa39a Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Sat, 11 Apr 2026 23:08:41 -0700 Subject: [PATCH 13/13] [V2] fix: use protojson for OAuth2 metadata deserialization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The external auth server's .well-known/oauth-authorization-server response uses camelCase JSON keys (authorizationEndpoint, tokenEndpoint, jwksUri) per the proto3 JSON specification. The v2 proto struct has snake_case JSON tags (authorization_endpoint, token_endpoint, jwks_uri). json.Unmarshal only matched the issuer field (same case), silently dropping all other fields. This caused GetOAuth2Metadata to return only {"issuer":"..."}, breaking CLI auth bootstrap — the client could not discover the token or authorization endpoints. Switch unmarshalResp from json.Unmarshal to protojson.Unmarshal, which accepts both camelCase and snake_case input per the protobuf spec. Signed-off-by: Kevin Su --- runs/service/auth/authzserver/metadata_provider.go | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/runs/service/auth/authzserver/metadata_provider.go b/runs/service/auth/authzserver/metadata_provider.go index 9ce7dc7c7a5..1ff1349b661 100644 --- a/runs/service/auth/authzserver/metadata_provider.go +++ b/runs/service/auth/authzserver/metadata_provider.go @@ -2,7 +2,6 @@ package authzserver import ( "context" - "encoding/json" "fmt" "io" "mime" @@ -12,6 +11,8 @@ import ( "time" "connectrpc.com/connect" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth/authconnect" @@ -161,9 +162,14 @@ func (s *authMetadataService) GetPublicClientConfig( }), nil } -// unmarshalResp unmarshals a JSON response body, providing a detailed error if the Content-Type is unexpected. -func unmarshalResp(r *http.Response, body []byte, v interface{}) error { - err := json.Unmarshal(body, &v) +// unmarshalResp unmarshals a JSON response body into a protobuf message. It +// uses protojson.Unmarshal which accepts both the camelCase form used by +// proto3 JSON serialization and the snake_case form matching proto field +// names. This is important because external authorization servers (including +// flyteadmin) emit camelCase keys while the Go proto struct tags are +// snake_case. +func unmarshalResp(r *http.Response, body []byte, v proto.Message) error { + err := protojson.Unmarshal(body, v) if err == nil { return nil }