diff --git a/charts/flyte-binary/defaults/cluster-resource-templates/namespace.yaml b/charts/flyte-binary/defaults/cluster-resource-templates/namespace.yaml deleted file mode 100644 index 301cb82f42f..00000000000 --- a/charts/flyte-binary/defaults/cluster-resource-templates/namespace.yaml +++ /dev/null @@ -1,4 +0,0 @@ -apiVersion: v1 -kind: Namespace -metadata: - name: '{{ namespace }}' diff --git a/charts/flyte-binary/templates/_helpers.tpl b/charts/flyte-binary/templates/_helpers.tpl index 6349140b802..287725002c1 100644 --- a/charts/flyte-binary/templates/_helpers.tpl +++ b/charts/flyte-binary/templates/_helpers.tpl @@ -113,6 +113,35 @@ templates: {{- toYaml .custom | nindent 2 -}} {{- end -}} {{- end -}} +{{/* +Selector labels for Console +*/}} +{{ define "flyte-binary.consoleSelectorLabels" -}} +{{ include "flyte-binary.selectorLabels" . }} +app.kubernetes.io/component: console +{{- end }} + +{{/* +Get the Secret name for Run service authentication secrets. When a user +supplies `configuration.auth.runServiceAuthSecretRef`, that existing Secret is +referenced directly (no template is rendered); otherwise a new Secret named +`-admin-auth` is used. +*/}} +{{ define "flyte-binary.configuration.auth.runServiceAuthSecretName" -}} +{{- if .Values.configuration.auth.runServiceAuthSecretRef -}} +{{ tpl .Values.configuration.auth.runServiceAuthSecretRef . }} +{{- else -}} +{{ printf "%s-admin-auth" (include "flyte-binary.fullname" .) }} +{{- end -}} +{{ end -}} + +{{/* +Get the Secret name for Flyte authentication client secrets. +*/}} +{{ define "flyte-binary.configuration.auth.clientSecretName" -}} +{{ printf "%s-client-secrets" (include "flyte-binary.fullname" .) }} +{{ end -}} + {{/* Get the Flyte cluster resource templates ConfigMap name. */}} @@ -131,14 +160,14 @@ Get the Flyte HTTP service name Get the Flyte service HTTP port. */}} {{- define "flyte-binary.service.http.port" -}} -{{- default 8090 .Values.service.ports.http -}} +{{- default 8080 .Values.service.ports.http -}} {{- end -}} {{/* Get the Flyte gRPC service name */}} {{- define "flyte-binary.service.grpc.name" -}} -{{- printf "%s-http" (include "flyte-binary.fullname" .) -}} +{{- printf "%s-grpc" (include "flyte-binary.fullname" .) -}} {{- end -}} {{/* @@ -149,7 +178,11 @@ Get the Flyte service gRPC port. {{- end -}} {{/* -Get the Flyte API paths for ingress. +Get the Flyte API paths for ingress. Services whose names start with +"Internal" (e.g. InternalRunService) plus ActionsService are intended for +intra-cluster traffic from task pods only; they are deliberately NOT exposed +via the external ALB ingress here. The Go auth middleware allowlists them so +cluster-internal ClusterIP calls reach them without credentials. */}} {{- define "flyte-binary.ingress.grpcPaths" -}} - /flyteidl2.workflow.RunService @@ -158,12 +191,20 @@ Get the Flyte API paths for ingress. - /flyteidl2.task.TaskService/* - /flyteidl2.workflow.TranslatorService - /flyteidl2.workflow.TranslatorService/* -- /flyteidl2.actions.ActionsService -- /flyteidl2.actions.ActionsService/* - /flyteidl2.dataproxy.DataProxyService - /flyteidl2.dataproxy.DataProxyService/* - /flyteidl2.secret.SecretService - /flyteidl2.secret.SecretService/* +- /flyteidl2.project.ProjectService +- /flyteidl2.project.ProjectService/* +- /flyteidl2.app.AppService +- /flyteidl2.app.AppService/* +- /flyteidl2.trigger.TriggerService +- /flyteidl2.trigger.TriggerService/* +- /flyteidl2.auth.AuthMetadataService +- /flyteidl2.auth.AuthMetadataService/* +- /flyteidl2.auth.IdentityService +- /flyteidl2.auth.IdentityService/* {{- end -}} {{/* diff --git a/charts/flyte-binary/templates/admin-auth-secret.yaml b/charts/flyte-binary/templates/admin-auth-secret.yaml deleted file mode 100644 index 3e164c1f109..00000000000 --- a/charts/flyte-binary/templates/admin-auth-secret.yaml +++ /dev/null @@ -1,16 +0,0 @@ -{{- if .Values.configuration.auth.enabled }} -apiVersion: v1 -kind: Secret -metadata: - name: {{ include "flyte-binary.configuration.auth.adminAuthSecretName" . }} - namespace: {{ .Release.Namespace | quote }} - labels: {{- include "flyte-binary.labels" . | nindent 4 }} - {{- if .Values.commonLabels }} - {{- tpl ( .Values.commonLabels | toYaml ) . | nindent 4 }} - {{- end }} - annotations: - {{- if .Values.commonAnnotations }} - {{- tpl ( .Values.commonAnnotations | toYaml ) . | nindent 4 }} - {{- end }} -type: Opaque -{{- end }} diff --git a/charts/flyte-binary/templates/clusterrole.yaml b/charts/flyte-binary/templates/clusterrole.yaml index 9999a292d63..97ee3354d78 100644 --- a/charts/flyte-binary/templates/clusterrole.yaml +++ b/charts/flyte-binary/templates/clusterrole.yaml @@ -19,6 +19,15 @@ metadata: {{- tpl ( .Values.rbac.annotations | toYaml ) . | nindent 4 }} {{- end }} rules: + - apiGroups: + - "" + resources: + - namespaces + verbs: + - create + - get + - list + - watch - apiGroups: - "" resources: @@ -33,17 +42,23 @@ rules: - watch - apiGroups: - "" + - events.k8s.io resources: - events verbs: - create - delete + - get + - list - patch - update + - watch - apiGroups: - flyte.org resources: - taskactions + - taskactions/status + - taskactions/finalizers verbs: - create - delete @@ -69,8 +84,12 @@ rules: - secrets verbs: - create + - delete - get + - list + - patch - update + - watch {{- if .Values.rbac.extraRules }} {{- toYaml .Values.rbac.extraRules | nindent 2 }} {{- end }} diff --git a/charts/flyte-binary/templates/config-secret.yaml b/charts/flyte-binary/templates/config-secret.yaml index 5992755b1fb..467e2a6e0dc 100644 --- a/charts/flyte-binary/templates/config-secret.yaml +++ b/charts/flyte-binary/templates/config-secret.yaml @@ -4,19 +4,19 @@ kind: Secret metadata: name: {{ include "flyte-binary.configuration.configSecretName" . }} namespace: {{ .Release.Namespace | quote }} - labels: {{- include "flyte-binary.labels" . | nindent 4 }} + labels: {{ include "flyte-binary.labels" . | nindent 4 }} {{- if .Values.commonLabels }} - {{- tpl ( .Values.commonLabels | toYaml ) . | nindent 4 }} + {{ tpl ( .Values.commonLabels | toYaml ) . | nindent 4 }} {{- end }} {{- if .Values.configuration.labels }} - {{- tpl ( .Values.configuration.labels | toYaml ) . | nindent 4 }} + {{ tpl ( .Values.configuration.labels | toYaml ) . | nindent 4 }} {{- end }} annotations: {{- if .Values.commonAnnotations }} - {{- tpl ( .Values.commonAnnotations | toYaml ) . | nindent 4 }} + {{ tpl ( .Values.commonAnnotations | toYaml ) . | nindent 4 }} {{- end }} {{- if .Values.configuration.annotations }} - {{- tpl ( .Values.configuration.annotations | toYaml ) . | nindent 4 }} + {{ tpl ( .Values.configuration.annotations | toYaml ) . | nindent 4 }} {{- end }} type: Opaque stringData: @@ -25,6 +25,10 @@ stringData: database: postgres: password: {{ .Values.configuration.database.password | quote }} + runs: + database: + postgres: + password: {{ .Values.configuration.database.password | quote }} {{- end }} {{- if eq "s3" .Values.configuration.storage.provider }} {{- if eq "accesskey" .Values.configuration.storage.providerConfig.s3.authType }} @@ -44,7 +48,7 @@ stringData: appAuth: selfAuthServer: staticClients: - flytepropeller: + executor: client_secret: {{ .Values.configuration.auth.internal.clientSecretHash | quote }} {{- end }} {{- end }} diff --git a/charts/flyte-binary/templates/configmap.yaml b/charts/flyte-binary/templates/configmap.yaml index e0d4f20a9b5..161e325e3fb 100644 --- a/charts/flyte-binary/templates/configmap.yaml +++ b/charts/flyte-binary/templates/configmap.yaml @@ -105,6 +105,54 @@ data: {{- end }} container: {{ required "Metadata container required" .metadataContainer }} {{- end }} + {{- if .Values.configuration.auth.enabled }} + 004-auth.yaml: | + auth: + appAuth: + {{- if .Values.configuration.auth.enableAuthServer }} + authServerType: Self + {{- else }} + authServerType: External + {{- end }} + {{- with .Values.configuration.auth.externalAuthServer }} + externalAuthServer: + baseUrl: {{ tpl (default "" .baseUrl) $ | quote }} + {{- if .metadataUrl }} + metadataUrl: {{ .metadataUrl | quote }} + {{- end }} + allowedAudience: + {{- range .allowedAudience }} + - {{ tpl . $ | quote }} + {{- end }} + {{- end }} + {{- with .Values.configuration.auth.flyteClient }} + thirdPartyConfig: + flyteClient: + clientId: {{ .clientId | quote }} + redirectUri: {{ .redirectUri | quote }} + {{- if .audience }} + audience: {{ .audience | quote }} + {{- end }} + scopes: + {{- range .scopes }} + - {{ . | quote }} + {{- end }} + {{- end }} + authorizedUris: + {{- range .Values.configuration.auth.authorizedUris }} + - {{ tpl . $ | quote }} + {{- end }} + userAuth: + openId: + baseUrl: {{ .Values.configuration.auth.oidc.baseUrl | quote }} + clientId: {{ .Values.configuration.auth.oidc.clientId | quote }} + scopes: + - openid + - profile + runs: + security: + useAuth: true + {{- end }} {{- if .Values.configuration.inline }} 100-inline-config.yaml: | {{- tpl ( .Values.configuration.inline | toYaml ) . | nindent 4 }} diff --git a/charts/flyte-binary/templates/console/deployment.yaml b/charts/flyte-binary/templates/console/deployment.yaml new file mode 100644 index 00000000000..2b2b6eded7d --- /dev/null +++ b/charts/flyte-binary/templates/console/deployment.yaml @@ -0,0 +1,46 @@ +{{- if .Values.console.enabled }} +apiVersion: apps/v1 +kind: Deployment +metadata: + name: {{ include "flyte-binary.fullname" . }}-console + namespace: {{ .Release.Namespace | quote }} + labels: {{ include "flyte-binary.labels" . | nindent 4 }} +spec: + replicas: 1 + selector: + matchLabels: {{ include "flyte-binary.consoleSelectorLabels" . | nindent 6 }} + template: + metadata: + labels: {{ include "flyte-binary.consoleSelectorLabels" . | nindent 8 }} + spec: + {{- with .Values.console.imagePullSecrets }} + imagePullSecrets: + {{ toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.deployment.extraPodSpec.nodeSelector }} + nodeSelector: + {{- toYaml . | nindent 8 }} + {{- end }} + containers: + - name: console + {{- with .Values.console.image }} + image: {{ printf "%s:%s" .repository .tag | quote }} + imagePullPolicy: {{ .pullPolicy | quote }} + {{- end }} + ports: + - name: http + containerPort: 8080 + protocol: TCP + readinessProbe: + httpGet: + path: /v2 + port: http + initialDelaySeconds: 5 + periodSeconds: 10 + livenessProbe: + httpGet: + path: /v2 + port: http + initialDelaySeconds: 5 + periodSeconds: 30 +{{- end }} \ No newline at end of file diff --git a/charts/flyte-binary/templates/console/service.yaml b/charts/flyte-binary/templates/console/service.yaml new file mode 100644 index 00000000000..6ba2e10d6a4 --- /dev/null +++ b/charts/flyte-binary/templates/console/service.yaml @@ -0,0 +1,20 @@ +{{- if .Values.console.enabled }} +apiVersion: v1 +kind: Service +metadata: + name: {{ include "flyte-binary.fullname" . }}-console + namespace: {{ .Release.Namespace | quote }} + labels: {{ include "flyte-binary.labels" . | nindent 4 }} + {{- with .Values.console.service.annotations }} + annotations: + {{ toYaml . | nindent 4 }} + {{- end }} +spec: + type: {{ .Values.console.service.type | default "ClusterIP" }} + ports: + - name: http + port: {{ .Values.console.service.port | default 80 }} + targetPort: http + protocol: TCP + selector: {{ include "flyte-binary.consoleSelectorLabels" . | nindent 4 }} +{{- end }} diff --git a/charts/flyte-binary/templates/deployment.yaml b/charts/flyte-binary/templates/deployment.yaml index 144c7c359d2..91448335706 100644 --- a/charts/flyte-binary/templates/deployment.yaml +++ b/charts/flyte-binary/templates/deployment.yaml @@ -3,42 +3,50 @@ kind: Deployment metadata: name: {{ include "flyte-binary.fullname" . }} namespace: {{ .Release.Namespace | quote }} - labels: {{- include "flyte-binary.labels" . | nindent 4 }} + labels: {{ include "flyte-binary.labels" . | nindent 4 }} {{- if .Values.commonLabels }} - {{- tpl ( .Values.commonLabels | toYaml ) . | nindent 4 }} + {{ tpl ( .Values.commonLabels | toYaml ) . | nindent 4 }} {{- end }} {{- if .Values.deployment.labels }} - {{- tpl ( .Values.deployment.labels | toYaml ) . | nindent 4 }} + {{ tpl ( .Values.deployment.labels | toYaml ) . | nindent 4 }} {{- end }} annotations: {{- if .Values.commonAnnotations }} - {{- tpl ( .Values.commonAnnotations | toYaml ) . | nindent 4 }} + {{ tpl ( .Values.commonAnnotations | toYaml ) . | nindent 4 }} {{- end }} {{- if .Values.deployment.annotations }} - {{- tpl ( .Values.deployment.annotations | toYaml ) . | nindent 4 }} + {{ tpl ( .Values.deployment.annotations | toYaml ) . | nindent 4 }} {{- end }} spec: replicas: 1 strategy: type: Recreate selector: - matchLabels: {{- include "flyte-binary.selectorLabels" . | nindent 6 }} + matchLabels: {{ include "flyte-binary.selectorLabels" . | nindent 6 }} template: metadata: - labels: {{- include "flyte-binary.selectorLabels" . | nindent 8 }} + labels: {{ include "flyte-binary.selectorLabels" . | nindent 8 }} {{- if .Values.commonLabels }} - {{- tpl ( .Values.commonLabels | toYaml ) . | nindent 8 }} + {{ tpl ( .Values.commonLabels | toYaml ) . | nindent 8 }} {{- end }} {{- if .Values.deployment.podLabels }} - {{- tpl ( .Values.deployment.podLabels | toYaml ) . | nindent 8 }} + {{ tpl ( .Values.deployment.podLabels | toYaml ) . | nindent 8 }} {{- end }} annotations: {{- if not (include "flyte-binary.configuration.externalConfiguration" .) }} checksum/configuration: {{ include (print $.Template.BasePath "/configmap.yaml") . | sha256sum }} checksum/configuration-secret: {{ include (print $.Template.BasePath "/config-secret.yaml") . | sha256sum }} {{- end }} + {{- if .Values.configuration.auth.enabled }} + {{- if not .Values.configuration.auth.runServiceAuthSecretRef }} + checksum/runservice-auth-secret: {{ include (print $.Template.BasePath "/run-service-auth-secret.yaml") . | sha256sum }} + {{- end }} + {{- if not .Values.configuration.auth.clientSecretsExternalSecretRef }} + checksum/auth-client-secret: {{ include (print $.Template.BasePath "/auth-client-secret.yaml") . | sha256sum }} + {{- end }} + {{- end }} {{- if .Values.commonAnnotations }} - {{- tpl ( .Values.commonAnnotations | toYaml ) . | nindent 8 }} + {{ tpl ( .Values.commonAnnotations | toYaml ) . | nindent 8 }} {{- end }} {{- if .Values.deployment.podAnnotations }} {{- tpl ( .Values.deployment.podAnnotations | toYaml ) . | nindent 8 }} @@ -86,14 +94,14 @@ spec: {{- end }} {{- end }} {{- if .Values.deployment.resources }} - resources: {{- toYaml .Values.deployment.resources | nindent 12 }} + resources: {{ toYaml .Values.deployment.resources | nindent 12 }} {{- end }} {{- if .Values.deployment.waitForDB.securityContext }} - securityContext: {{- toYaml .Values.deployment.waitForDB.securityContext | nindent 12 }} + securityContext: {{ toYaml .Values.deployment.waitForDB.securityContext | nindent 12 }} {{- end }} {{- end }} {{- if .Values.deployment.initContainers }} - {{- tpl ( .Values.deployment.initContainers | toYaml ) . | nindent 8 }} + {{ tpl ( .Values.deployment.initContainers | toYaml ) . | nindent 8 }} {{- end }} {{- end }} containers: @@ -138,6 +146,8 @@ spec: ports: - name: http containerPort: 8090 + - name: grpc + containerPort: 8080 - name: webhook containerPort: 9443 {{- if .Values.deployment.startupProbe }} @@ -159,7 +169,7 @@ spec: httpGet: path: /healthz port: http - initialDelaySeconds: 30 + initialDelaySeconds: 5 {{- end }} {{- if .Values.deployment.resources }} resources: {{- toYaml .Values.deployment.resources | nindent 12 }} @@ -187,6 +197,20 @@ spec: {{- tpl ( .Values.deployment.sidecars | toYaml ) . | nindent 8 }} {{- end }} volumes: + {{- if .Values.configuration.auth.enabled }} + - name: auth + projected: + sources: + - secret: + name: {{ include "flyte-binary.configuration.auth.runServiceAuthSecretName" . }} + {{- if .Values.configuration.auth.clientSecretsExternalSecretRef }} + - secret: + name: {{ tpl .Values.configuration.auth.clientSecretsExternalSecretRef . }} + {{- else }} + - secret: + name: {{ include "flyte-binary.configuration.auth.clientSecretName" . }} + {{- end }} + {{- end }} - name: config {{- if (include "flyte-binary.configuration.externalConfiguration" .) }} projected: @@ -214,9 +238,13 @@ spec: - secret: name: {{ tpl .Values.configuration.inlineSecretRef . }} {{- end }} + {{- range .Values.configuration.extraInlineSecretRefs }} + - secret: + name: {{ tpl . $ }} + {{- end }} {{- end }} - name: webhook-certs emptyDir: {} {{- if .Values.deployment.extraVolumes }} - {{- tpl ( .Values.deployment.extraVolumes | toYaml ) . | nindent 8 }} + {{ tpl ( .Values.deployment.extraVolumes | toYaml ) . | nindent 8 }} {{- end }} diff --git a/charts/flyte-binary/templates/ingress/grpc.yaml b/charts/flyte-binary/templates/ingress/grpc.yaml index bd43185e509..85ef5fcba45 100644 --- a/charts/flyte-binary/templates/ingress/grpc.yaml +++ b/charts/flyte-binary/templates/ingress/grpc.yaml @@ -7,20 +7,20 @@ metadata: namespace: {{ .Release.Namespace | quote }} labels: {{- include "flyte-binary.labels" . | nindent 4 }} {{- if .Values.commonLabels }} - {{- tpl ( .Values.commonLabels | toYaml ) . | nindent 4 }} + {{ tpl ( .Values.commonLabels | toYaml ) . | nindent 4 }} {{- end }} {{- if .Values.ingress.labels }} - {{- tpl ( .Values.ingress.labels | toYaml ) . | nindent 4 }} + {{ tpl ( .Values.ingress.labels | toYaml ) . | nindent 4 }} {{- end }} annotations: {{- if .Values.commonAnnotations }} - {{- tpl ( .Values.commonAnnotations | toYaml ) . | nindent 4 }} + {{ tpl ( .Values.commonAnnotations | toYaml ) . | nindent 4 }} {{- end }} {{- if .Values.ingress.commonAnnotations }} - {{- tpl ( .Values.ingress.commonAnnotations | toYaml ) . | nindent 4 }} + {{ tpl ( .Values.ingress.commonAnnotations | toYaml ) . | nindent 4 }} {{- end }} {{- if .Values.ingress.grpcAnnotations }} - {{- tpl ( .Values.ingress.grpcAnnotations | toYaml ) . | nindent 4 }} + {{ tpl ( .Values.ingress.grpcAnnotations | toYaml ) . | nindent 4 }} {{- end }} spec: {{- if .Values.ingress.grpcIngressClassName }} diff --git a/charts/flyte-binary/templates/ingress/http.yaml b/charts/flyte-binary/templates/ingress/http.yaml index cdce87047bd..58bcd58aa25 100644 --- a/charts/flyte-binary/templates/ingress/http.yaml +++ b/charts/flyte-binary/templates/ingress/http.yaml @@ -38,6 +38,23 @@ spec: {{- if .Values.ingress.httpExtraPaths.prepend }} {{- tpl ( .Values.ingress.httpExtraPaths.prepend | toYaml ) . | nindent 6 }} {{- end }} + {{- if .Values.console.enabled }} + - backend: + service: + name: {{ include "flyte-binary.fullname" . }}-console + port: + number: {{ .Values.console.service.port | default 80 }} + path: /v2 + pathType: ImplementationSpecific + - backend: + service: + name: {{ include "flyte-binary.fullname" . }}-console + port: + number: {{ .Values.console.service.port | default 80 }} + path: /v2/* + pathType: ImplementationSpecific + {{- end }} + {{- if not .Values.ingress.minimalPaths }} - backend: service: name: {{ include "flyte-binary.service.http.name" . }} @@ -52,6 +69,8 @@ spec: number: {{ include "flyte-binary.service.http.port" . }} path: /console/* pathType: ImplementationSpecific + {{- end }} + {{- if not .Values.ingress.minimalPaths }} - backend: service: name: {{ include "flyte-binary.service.http.name" . }} @@ -66,6 +85,7 @@ spec: number: {{ include "flyte-binary.service.http.port" . }} path: /api/* pathType: ImplementationSpecific + {{- end }} - backend: service: name: {{ include "flyte-binary.service.http.name" . }} @@ -73,6 +93,7 @@ spec: number: {{ include "flyte-binary.service.http.port" . }} path: /healthcheck pathType: ImplementationSpecific + {{- if not .Values.ingress.minimalPaths }} - backend: service: name: {{ include "flyte-binary.service.http.name" . }} @@ -94,6 +115,7 @@ spec: number: {{ include "flyte-binary.service.http.port" . }} path: /.well-known/* pathType: ImplementationSpecific + {{- end }} - backend: service: name: {{ include "flyte-binary.service.http.name" . }} @@ -136,6 +158,7 @@ spec: number: {{ include "flyte-binary.service.http.port" . }} path: /callback/* pathType: ImplementationSpecific + {{- if not .Values.ingress.minimalPaths }} - backend: service: name: {{ include "flyte-binary.service.http.name" . }} @@ -171,6 +194,7 @@ spec: number: {{ include "flyte-binary.service.http.port" . }} path: /oauth2/* pathType: ImplementationSpecific + {{- end }} {{- if not .Values.ingress.separateGrpcIngress }} {{- $paths := (include "flyte-binary.ingress.grpcPaths" .) | fromYamlArray }} {{- range $path := $paths }} diff --git a/charts/flyte-binary/templates/run-service-auth-secret.yaml b/charts/flyte-binary/templates/run-service-auth-secret.yaml new file mode 100644 index 00000000000..111ea59912f --- /dev/null +++ b/charts/flyte-binary/templates/run-service-auth-secret.yaml @@ -0,0 +1,16 @@ +{{- if and .Values.configuration.auth.enabled (not .Values.configuration.auth.runServiceAuthSecretRef) }} +apiVersion: v1 +kind: Secret +metadata: + name: {{ include "flyte-binary.configuration.auth.runServiceAuthSecretName" . }} + namespace: {{ .Release.Namespace | quote }} + labels: {{ include "flyte-binary.labels" . | nindent 4 }} + {{- if .Values.commonLabels }} + {{ tpl ( .Values.commonLabels | toYaml ) . | nindent 4 }} + {{- end }} + annotations: + {{- if .Values.commonAnnotations }} + {{ tpl ( .Values.commonAnnotations | toYaml ) . | nindent 4 }} + {{- end }} +type: Opaque +{{- end }} diff --git a/charts/flyte-binary/templates/service/grpc.yaml b/charts/flyte-binary/templates/service/grpc.yaml new file mode 100644 index 00000000000..6bc6f881766 --- /dev/null +++ b/charts/flyte-binary/templates/service/grpc.yaml @@ -0,0 +1,48 @@ +{{- if .Values.ingress.separateGrpcIngress }} +apiVersion: v1 +kind: Service +metadata: + name: {{ include "flyte-binary.service.grpc.name" . }} + namespace: {{ .Release.Namespace | quote }} + labels: {{ include "flyte-binary.labels" . | nindent 4 }} + {{- if .Values.commonLabels }} + {{ tpl ( .Values.commonLabels | toYaml ) . | nindent 4 }} + {{- end }} + {{- if .Values.service.labels }} + {{ tpl ( .Values.service.labels | toYaml ) . | nindent 4 }} + {{- end }} + annotations: + {{- if .Values.commonAnnotations }} + {{ tpl ( .Values.commonAnnotations | toYaml ) . | nindent 4 }} + {{- end }} + {{- if .Values.service.commonAnnotations }} + {{ tpl ( .Values.service.commonAnnotations | toYaml ) . | nindent 4 }} + {{- end }} + {{- if .Values.service.grpcAnnotations }} + {{ tpl ( .Values.service.grpcAnnotations | toYaml ) . | nindent 4 }} + {{- end }} +spec: + type: {{ .Values.service.type }} + {{- if or (eq .Values.service.type "LoadBalancer") (eq .Values.service.type "NodePort") }} + externalTrafficPolicy: {{ .Values.service.externalTrafficPolicy | quote }} + {{- end }} + {{- if and (eq .Values.service.type "LoadBalancer") (not (empty .Values.service.loadBalancerSourceRanges)) }} + loadBalancerSourceRanges: {{ .Values.service.loadBalancerSourceRanges }} + {{- end }} + {{- if and (eq .Values.service.type "LoadBalancer") (not (empty .Values.service.loadBalancerIP)) }} + loadBalancerIP: {{ .Values.service.loadBalancerIP }} + {{- end }} + {{- if and .Values.service.clusterIP (eq .Values.service.type "ClusterIP") }} + clusterIP: {{ .Values.service.clusterIP }} + {{- end }} + ports: + - name: grpc + port: {{ include "flyte-binary.service.grpc.port" . }} + targetPort: http + {{- if and (or (eq .Values.service.type "NodePort") (eq .Values.service.type "LoadBalancer")) (not (empty .Values.service.nodePorts.grpc)) }} + nodePort: {{ .Values.service.nodePorts.grpc }} + {{- else if eq .Values.service.type "ClusterIP" }} + nodePort: null + {{- end }} + selector: {{- include "flyte-binary.selectorLabels" . | nindent 4 }} +{{- end }} diff --git a/charts/flyte-binary/templates/service/http.yaml b/charts/flyte-binary/templates/service/http.yaml index ab79c90210d..e2a727aa195 100644 --- a/charts/flyte-binary/templates/service/http.yaml +++ b/charts/flyte-binary/templates/service/http.yaml @@ -46,7 +46,7 @@ spec: {{- if not .Values.ingress.separateGrpcIngress }} - name: grpc port: {{ include "flyte-binary.service.grpc.port" . }} - targetPort: grpc + targetPort: http {{- if and (or (eq .Values.service.type "NodePort") (eq .Values.service.type "LoadBalancer")) (not (empty .Values.service.nodePorts.grpc)) }} nodePort: {{ .Values.service.nodePorts.grpc }} {{- else if eq .Values.service.type "ClusterIP" }} diff --git a/charts/flyte-binary/values.yaml b/charts/flyte-binary/values.yaml index 93782f4e311..b4a242b20e0 100644 --- a/charts/flyte-binary/values.yaml +++ b/charts/flyte-binary/values.yaml @@ -184,9 +184,21 @@ configuration: audience: "" # authorizedUris Set of URIs that clients are allowed to visit the service on authorizedUris: [] + # externalAuthServer Configuration for the external OAuth2 authorization + # server whose tokens Flyte will validate. Only used when + # `enableAuthServer: false`. Set `baseUrl` to the issuer URL and + # `allowedAudience` to the list of audiences Flyte should accept. + externalAuthServer: + baseUrl: "" + metadataUrl: "" + allowedAudience: [] # clientSecretExternalSecretRef Specify an existing, external Secret containing values for `client_secret` and `oidc_client_secret`. # If set, a Secret will not be generated by this chart for client secrets. clientSecretsExternalSecretRef: "" + # runServiceAuthSecretRef Specify an existing Secret to supply cookie + # hash/block keys (and other run-service auth secrets) at /etc/secrets. + # If set, this chart will NOT render its own run-service auth Secret. + runServiceAuthSecretRef: "" # co-pilot Configuration for Flyte CoPilot co-pilot: # image Configure image to use for CoPilot sidecar @@ -357,6 +369,13 @@ ingress: labels: {} # host Hostname to bind to ingress resources host: "" + # minimalPaths When true, the HTTP ingress only emits the paths that this + # Flyte deployment actually serves (login/callback/logout, api, v2, console, + # healthcheck). The legacy auth-server paths (/oauth2, /.well-known, /me, + # /config, /v1/*) are omitted so they can be served by a different Flyte + # deployment sharing the same ALB group. Set this on deployments that defer + # token issuance and OAuth metadata to an upstream auth server. + minimalPaths: false # separateGrpcIngress Create a separate ingress resource for GRPC if true. Required for certain ingress controllers like nginx. separateGrpcIngress: true # commonAnnotations Add common annotations to all ingress resources @@ -426,10 +445,24 @@ enabled_plugins: - container - sidecar - connector-service - - echo default-for-task-types: container: container sidecar: sidecar container_array: k8s-array # -- Uncomment to enable task type that uses Flyte Connector # bigquery_query_job_task: connector-service + +console: + enabled: true + image: + repository: ghcr.io/flyteorg/consolev2 + tag: latest + pullPolicy: Always + imagePullSecrets: + - name: ghcr-pull-secret + service: + type: ClusterIP + port: 80 + annotations: + alb.ingress.kubernetes.io/healthcheck-path: /v2 + alb.ingress.kubernetes.io/healthcheck-port: "8080" \ No newline at end of file diff --git a/go.mod b/go.mod index d08b80b45f6..bcc9872a344 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.41.1 github.com/benlaurie/objecthash v0.0.0-20180202135721-d1e3d6079fc1 github.com/coocood/freecache v1.2.4 + github.com/coreos/go-oidc/v3 v3.17.0 github.com/dask/dask-kubernetes/v2023 v2023.0.0-20230626103304-abd02cd17b26 github.com/eko/gocache/lib/v4 v4.2.0 github.com/eko/gocache/store/freecache/v4 v4.2.0 @@ -26,8 +27,10 @@ require ( github.com/fsnotify/fsnotify v1.9.0 github.com/ghodss/yaml v1.0.0 github.com/go-test/deep v1.1.1 + github.com/golang-jwt/jwt/v5 v5.3.0 github.com/golang/protobuf v1.5.4 github.com/googleapis/gax-go/v2 v2.15.0 + github.com/gorilla/securecookie v1.1.2 github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus v1.1.0 github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 github.com/hashicorp/golang-lru v1.0.2 @@ -130,7 +133,7 @@ require ( github.com/evanphx/json-patch v5.6.0+incompatible // indirect github.com/evanphx/json-patch/v5 v5.9.11 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect - github.com/go-jose/go-jose/v4 v4.1.2 // indirect + github.com/go-jose/go-jose/v4 v4.1.3 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-logr/zapr v1.3.0 // indirect @@ -139,7 +142,6 @@ require ( github.com/go-openapi/swag v0.23.0 // indirect github.com/go-task/slim-sprig/v3 v3.0.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect - github.com/golang-jwt/jwt/v5 v5.3.0 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/mock v1.6.0 // indirect github.com/google/btree v1.1.3 // indirect diff --git a/go.sum b/go.sum index 70757065def..5a7cf0420b6 100644 --- a/go.sum +++ b/go.sum @@ -172,6 +172,8 @@ github.com/cncf/xds/go v0.0.0-20250501225837-2ac532fd4443 h1:aQ3y1lwWyqYPiWZThqv github.com/cncf/xds/go v0.0.0-20250501225837-2ac532fd4443/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8= github.com/coocood/freecache v1.2.4 h1:UdR6Yz/X1HW4fZOuH0Z94KwG851GWOSknua5VUbb/5M= github.com/coocood/freecache v1.2.4/go.mod h1:RBUWa/Cy+OHdfTGFEhEuE1pMCMX51Ncizj7rthiQ3vk= +github.com/coreos/go-oidc/v3 v3.17.0 h1:hWBGaQfbi0iVviX4ibC7bk8OKT5qNr4klBaCHVNvehc= +github.com/coreos/go-oidc/v3 v3.17.0/go.mod h1:wqPbKFrVnE90vty060SB40FCJ8fTHTxSwyXJqZH+sI8= github.com/coreos/go-semver v0.3.1 h1:yi21YpKnrx1gt5R+la8n5WgS0kCrsPp33dmEyHReZr4= github.com/coreos/go-semver v0.3.1/go.mod h1:irMmmIw/7yzSRPWryHsK7EYSg09caPQL03VsM8rvUec= github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs= @@ -228,8 +230,8 @@ github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeME github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= -github.com/go-jose/go-jose/v4 v4.1.2 h1:TK/7NqRQZfgAh+Td8AlsrvtPoUyiHh0LqVvokh+1vHI= -github.com/go-jose/go-jose/v4 v4.1.2/go.mod h1:22cg9HWM1pOlnRiY+9cQYJ9XHmya1bYW8OeDM6Ku6Oo= +github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs= +github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= github.com/go-kit/log v0.2.1 h1:MRVx0/zhvdseW+Gza6N9rVzU/IVzaeE1SFI4raAhmBU= github.com/go-kit/log v0.2.1/go.mod h1:NwTd00d/i8cPZ3xOwwiv2PO5MOcx78fFErGNcVmBjv0= github.com/go-logfmt/logfmt v0.5.1 h1:otpy5pqBCBZ1ng9RQ0dPu4PN7ba75Y/aA+UpowDyNVA= @@ -348,6 +350,8 @@ github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5m github.com/googleapis/gax-go/v2 v2.15.0 h1:SyjDc1mGgZU5LncH8gimWo9lW1DtIfPibOG81vgd/bo= github.com/googleapis/gax-go/v2 v2.15.0/go.mod h1:zVVkkxAQHa1RQpg9z2AUCMnKhi0Qld9rcmyfL1OZhoc= github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g= +github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= +github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 h1:+9834+KizmvFV7pXQGSXQTsaWhq2GjuNUt0aUU0YBYw= diff --git a/runs/config/config.go b/runs/config/config.go index 590dcb0a6dd..58e77c3aa42 100644 --- a/runs/config/config.go +++ b/runs/config/config.go @@ -63,6 +63,9 @@ type Config struct { // Domains are injected into project responses (not stored per project row). Domains []DomainConfig `json:"domains"` + // Security controls authentication and authorization behavior. + Security SecurityConfig `json:"security"` + // TriggerScheduler configures the cron-based trigger scheduler worker. TriggerScheduler TriggerSchedulerConfig `json:"triggerScheduler"` @@ -70,6 +73,12 @@ type Config struct { Apps AppsConfig `json:"apps"` } +// SecurityConfig controls authentication and authorization behavior. +type SecurityConfig struct { + // UseAuth enables authentication. When true, AuthMetadataService and IdentityService are registered. + UseAuth bool `json:"useAuth" pflag:",Enable authentication and identity services"` +} + // ServerConfig holds HTTP server configuration type ServerConfig struct { Port int `json:"port" pflag:",Port to bind the HTTP server"` diff --git a/runs/service/auth/auth_context.go b/runs/service/auth/auth_context.go new file mode 100644 index 00000000000..ccefc8ea147 --- /dev/null +++ b/runs/service/auth/auth_context.go @@ -0,0 +1,148 @@ +// Package auth contains types needed to start up a standalone OAuth2 Authorization Server or delegate +// authentication to an external provider. It supports OpenID Connect for user authentication. +package auth + +import ( + "context" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "github.com/coreos/go-oidc/v3/oidc" + "golang.org/x/oauth2" + + "github.com/flyteorg/flyte/v2/flytestdlib/logger" + "github.com/flyteorg/flyte/v2/runs/service/auth/config" +) + +const ( + IdpConnectionTimeout = 10 * time.Second +) + +// AuthenticationContext holds all the utilities necessary to run authentication. +// +// The auth package supports two request flows, both producing an IdentityContext: +// +// Browser (HTTP) API (gRPC / HTTP Bearer) +// ────────────── ─────────────────────── +// handlers.go interceptor.go +// /login -> IdP redirect ▼ +// /callback -> exchange code token.go / resource_server.go +// /logout -> clear cookies ▼ +// │ claims_verifier.go +// ▼ │ +// cookie_manager.go │ +// (read/write encrypted cookies) │ +// │ │ +// ▼ │ +// cookie.go │ +// (CSRF, secure cookie helpers) │ +// │ │ +// ▼ ▼ +// token.go ──────────────────────> identity_context.go +// (parse/validate JWT) (UserID, AppID, Scopes, Claims) +// +type AuthenticationContext struct { + oauth2Config *oauth2.Config + cookieManager CookieManager + oidcProvider *oidc.Provider + resourceServer OAuth2ResourceServer + cfg config.Config + httpClient *http.Client + oauth2MetadataURL *url.URL + oidcMetadataURL *url.URL +} + +func (c *AuthenticationContext) OAuth2Config() *oauth2.Config { return c.oauth2Config } +func (c *AuthenticationContext) CookieManager() CookieManager { return c.cookieManager } +func (c *AuthenticationContext) OIDCProvider() *oidc.Provider { return c.oidcProvider } +func (c *AuthenticationContext) ResourceServer() OAuth2ResourceServer { return c.resourceServer } +func (c *AuthenticationContext) Config() config.Config { return c.cfg } +func (c *AuthenticationContext) HTTPClient() *http.Client { return c.httpClient } +func (c *AuthenticationContext) OAuth2MetadataURL() *url.URL { return c.oauth2MetadataURL } +func (c *AuthenticationContext) OIDCMetadataURL() *url.URL { return c.oidcMetadataURL } + +// NewAuthContext creates a new AuthContext with all the components needed for authentication. +// oidcClientSecret is the IdP-issued confidential client secret used during the OAuth2 code +// exchange; it may be empty if the client is registered as public at the IdP. +func NewAuthContext(ctx context.Context, cfg config.Config, resourceServer OAuth2ResourceServer, + hashKeyBase64, blockKeyBase64, oidcClientSecret string) (*AuthenticationContext, error) { + + cookieManager, err := NewCookieManager(ctx, hashKeyBase64, blockKeyBase64, cfg.UserAuth.CookieSetting) + if err != nil { + logger.Errorf(ctx, "Error creating cookie manager %s", err) + return nil, fmt.Errorf("error creating cookie manager: %w", err) + } + + httpClient := &http.Client{ + Timeout: IdpConnectionTimeout, + } + + if len(cfg.UserAuth.HTTPProxyURL.String()) > 0 { + logger.Infof(ctx, "HTTPProxy URL for OAuth2 is: %s", cfg.UserAuth.HTTPProxyURL.String()) + httpClient.Transport = &http.Transport{Proxy: http.ProxyURL(&cfg.UserAuth.HTTPProxyURL.URL)} + } + + oidcCtx := oidc.ClientContext(ctx, httpClient) + baseURL := cfg.UserAuth.OpenID.BaseURL.String() + provider, err := oidc.NewProvider(oidcCtx, baseURL) + if err != nil { + return nil, fmt.Errorf("error creating oidc provider with issuer [%v]: %w", baseURL, err) + } + + oauth2Config := &oauth2.Config{ + RedirectURL: computeOIDCRedirectURL(cfg), + ClientID: cfg.UserAuth.OpenID.ClientID, + ClientSecret: oidcClientSecret, + Scopes: cfg.UserAuth.OpenID.Scopes, + Endpoint: provider.Endpoint(), + } + + oauth2MetadataURL, err := url.Parse(OAuth2MetadataEndpoint) + if err != nil { + return nil, fmt.Errorf("error parsing oauth2 metadata URL: %w", err) + } + + oidcMetadataURL, err := url.Parse(OIdCMetadataEndpoint) + if err != nil { + return nil, fmt.Errorf("error parsing oidc metadata URL: %w", err) + } + + return &AuthenticationContext{ + oauth2Config: oauth2Config, + cookieManager: cookieManager, + oidcProvider: provider, + resourceServer: resourceServer, + cfg: cfg, + httpClient: httpClient, + oauth2MetadataURL: oauth2MetadataURL, + oidcMetadataURL: oidcMetadataURL, + }, nil +} + +// computeOIDCRedirectURL returns the absolute redirect URL to use during the OAuth2 authorization +// code flow. IdPs like Okta require an absolute URL registered in their allowed-callbacks list. +// The URL is derived from the first authorizedUri with "/callback" appended. If no authorizedUris +// are configured, the legacy relative "callback" value is returned as a fallback. +func computeOIDCRedirectURL(cfg config.Config) string { + if len(cfg.AuthorizedURIs) == 0 { + return "callback" + } + base := cfg.AuthorizedURIs[0].URL + base.Path = strings.TrimSuffix(base.Path, "/") + "/callback" + return base.String() +} + +// HandlerConfig returns an AuthHandlerConfig suitable for use with RegisterHandlers. +func (c *AuthenticationContext) HandlerConfig() *AuthHandlerConfig { + return &AuthHandlerConfig{ + CookieManager: c.cookieManager, + OAuth2Config: c.oauth2Config, + OIDCProvider: c.oidcProvider, + ResourceServer: c.resourceServer, + AuthConfig: c.cfg, + HTTPClient: c.httpClient, + } +} diff --git a/runs/service/auth/auth_context_test.go b/runs/service/auth/auth_context_test.go new file mode 100644 index 00000000000..01a432856cd --- /dev/null +++ b/runs/service/auth/auth_context_test.go @@ -0,0 +1,69 @@ +package auth + +import ( + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + stdconfig "github.com/flyteorg/flyte/v2/flytestdlib/config" + "github.com/flyteorg/flyte/v2/runs/service/auth/config" +) + +func mustParse(t *testing.T, raw string) stdconfig.URL { + t.Helper() + u, err := url.Parse(raw) + require.NoError(t, err) + return stdconfig.URL{URL: *u} +} + +func TestComputeOIDCRedirectURL(t *testing.T) { + cases := []struct { + name string + cfg config.Config + want string + }{ + { + name: "no authorizedUris falls back to relative path", + cfg: config.Config{}, + want: "callback", + }, + { + name: "simple https host", + cfg: config.Config{ + AuthorizedURIs: []stdconfig.URL{mustParse(t, "https://flyte.example.com")}, + }, + want: "https://flyte.example.com/callback", + }, + { + name: "host with trailing slash does not duplicate separator", + cfg: config.Config{ + AuthorizedURIs: []stdconfig.URL{mustParse(t, "https://flyte.example.com/")}, + }, + want: "https://flyte.example.com/callback", + }, + { + name: "picks first uri when multiple", + cfg: config.Config{ + AuthorizedURIs: []stdconfig.URL{ + mustParse(t, "https://flyte.example.com"), + mustParse(t, "http://flyte2.flyte:8080"), + }, + }, + want: "https://flyte.example.com/callback", + }, + { + name: "host with path prefix appends callback", + cfg: config.Config{ + AuthorizedURIs: []stdconfig.URL{mustParse(t, "https://flyte.example.com/v2")}, + }, + want: "https://flyte.example.com/v2/callback", + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.want, computeOIDCRedirectURL(tc.cfg)) + }) + } +} diff --git a/runs/service/auth/authzserver/claims_verifier.go b/runs/service/auth/authzserver/claims_verifier.go new file mode 100644 index 00000000000..ae284492609 --- /dev/null +++ b/runs/service/auth/authzserver/claims_verifier.go @@ -0,0 +1,105 @@ +package authzserver + +import ( + "encoding/json" + "fmt" + "time" + + jwtgo "github.com/golang-jwt/jwt/v5" + + authpb "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth" + auth "github.com/flyteorg/flyte/v2/runs/service/auth" +) + +const ( + // ClientIDClaim is the JWT claim key for the client ID. + ClientIDClaim = "client_id" + // UserIDClaim is the JWT claim key for user info. + UserIDClaim = "user_info" + // ScopeClaim is the JWT claim key for scopes. + ScopeClaim = "scp" +) + +// verifyClaims extracts identity information from raw JWT claims and validates the audience. +func verifyClaims(expectedAudience map[string]bool, claims jwtgo.MapClaims) (*auth.IdentityContext, error) { + aud, err := claims.GetAudience() + if err != nil { + return nil, fmt.Errorf("failed to get audience: %w", err) + } + + matchedAudience := "" + for _, a := range aud { + if expectedAudience[a] { + matchedAudience = a + break + } + } + if matchedAudience == "" { + return nil, fmt.Errorf("invalid audience %v, wanted one of %v", aud, expectedAudience) + } + + sub, _ := claims.GetSubject() + + issuedAt := time.Time{} + if iat, err := claims.GetIssuedAt(); err == nil && iat != nil { + issuedAt = iat.Time + } + + userInfo := &authpb.UserInfoResponse{} + if userInfoClaim, found := claims[UserIDClaim]; found && userInfoClaim != nil { + if userInfoRaw, ok := userInfoClaim.(map[string]interface{}); ok { + raw, err := json.Marshal(userInfoRaw) + if err != nil { + return nil, err + } + if err = json.Unmarshal(raw, userInfo); err != nil { + return nil, fmt.Errorf("failed to unmarshal user info claim: %w", err) + } + } + } + + clientID := "" + if clientIDClaim, found := claims[ClientIDClaim]; found { + if s, ok := clientIDClaim.(string); ok { + clientID = s + } + } + + var scopes []string + if scopesClaim, found := claims[ScopeClaim]; found { + switch sct := scopesClaim.(type) { + case []interface{}: + scopes = interfaceSliceToStringSlice(sct) + case string: + scopes = []string{sct} + default: + return nil, fmt.Errorf("failed getting scope claims due to unknown type %T with value %v", sct, sct) + } + } + + // In some cases, "user_info" field doesn't exist in the raw claim, + // but we can get email from "email" field. + if emailClaim, found := claims["email"]; found { + if email, ok := emailClaim.(string); ok { + userInfo.Email = email + } + } + + // If this is a user-only access token with no scopes defined then add `all` scope by default because it's equivalent + // to having a user's login cookie or an ID Token as means of accessing the service. + if len(clientID) == 0 && len(scopes) == 0 { + scopes = []string{auth.ScopeAll} + } + + return auth.NewIdentityContext(matchedAudience, sub, clientID, issuedAt, scopes, userInfo, claims), nil +} + +func interfaceSliceToStringSlice(raw []interface{}) []string { + res := make([]string, 0, len(raw)) + for _, item := range raw { + if s, ok := item.(string); ok { + res = append(res, s) + } + } + return res +} diff --git a/runs/service/auth/authzserver/claims_verifier_test.go b/runs/service/auth/authzserver/claims_verifier_test.go new file mode 100644 index 00000000000..ea1f6e4fa04 --- /dev/null +++ b/runs/service/auth/authzserver/claims_verifier_test.go @@ -0,0 +1,114 @@ +package authzserver + +import ( + "testing" + "time" + + jwtgo "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestVerifyClaims_MatchingAudience(t *testing.T) { + claims := jwtgo.MapClaims{ + "aud": []interface{}{"https://flyte.example.com"}, + "sub": "user123", + "iat": float64(time.Now().Unix()), + } + + allowed := map[string]bool{"https://flyte.example.com": true} + identity, err := verifyClaims(allowed, claims) + require.NoError(t, err) + assert.Equal(t, "https://flyte.example.com", identity.Audience()) + assert.Equal(t, "user123", identity.UserID()) +} + +func TestVerifyClaims_NoMatchingAudience(t *testing.T) { + claims := jwtgo.MapClaims{ + "aud": []interface{}{"https://other.example.com"}, + "sub": "user123", + } + + allowed := map[string]bool{"https://flyte.example.com": true} + _, err := verifyClaims(allowed, claims) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid audience") +} + +func TestVerifyClaims_WithClientID(t *testing.T) { + claims := jwtgo.MapClaims{ + "aud": "https://flyte.example.com", + "sub": "user123", + "client_id": "my-client", + "scp": []interface{}{"read", "write"}, + } + + allowed := map[string]bool{"https://flyte.example.com": true} + identity, err := verifyClaims(allowed, claims) + require.NoError(t, err) + assert.Equal(t, "my-client", identity.AppID()) + assert.Equal(t, []string{"read", "write"}, identity.Scopes()) +} + +func TestVerifyClaims_UserOnlyDefaultsToScopeAll(t *testing.T) { + claims := jwtgo.MapClaims{ + "aud": "https://flyte.example.com", + "sub": "user123", + } + + allowed := map[string]bool{"https://flyte.example.com": true} + identity, err := verifyClaims(allowed, claims) + require.NoError(t, err) + assert.Equal(t, []string{"all"}, identity.Scopes()) + assert.Equal(t, "", identity.AppID()) +} + +func TestVerifyClaims_WithEmail(t *testing.T) { + claims := jwtgo.MapClaims{ + "aud": "https://flyte.example.com", + "sub": "user123", + "email": "user@example.com", + } + + allowed := map[string]bool{"https://flyte.example.com": true} + identity, err := verifyClaims(allowed, claims) + require.NoError(t, err) + assert.Equal(t, "user@example.com", identity.UserInfo().Email) +} + +func TestVerifyClaims_WithUserInfoClaim(t *testing.T) { + claims := jwtgo.MapClaims{ + "aud": "https://flyte.example.com", + "sub": "user123", + "user_info": map[string]interface{}{ + "name": "Test User", + "email": "test@example.com", + }, + } + + allowed := map[string]bool{"https://flyte.example.com": true} + identity, err := verifyClaims(allowed, claims) + require.NoError(t, err) + assert.Equal(t, "Test User", identity.UserInfo().Name) + assert.Equal(t, "test@example.com", identity.UserInfo().Email) +} + +func TestVerifyClaims_ScopeAsString(t *testing.T) { + claims := jwtgo.MapClaims{ + "aud": "https://flyte.example.com", + "sub": "user123", + "client_id": "my-client", + "scp": "read", + } + + allowed := map[string]bool{"https://flyte.example.com": true} + identity, err := verifyClaims(allowed, claims) + require.NoError(t, err) + assert.Equal(t, []string{"read"}, identity.Scopes()) +} + +func TestInterfaceSliceToStringSlice(t *testing.T) { + input := []interface{}{"a", "b", "c"} + result := interfaceSliceToStringSlice(input) + assert.Equal(t, []string{"a", "b", "c"}, result) +} diff --git a/runs/service/auth/authzserver/metadata_provider.go b/runs/service/auth/authzserver/metadata_provider.go new file mode 100644 index 00000000000..1ff1349b661 --- /dev/null +++ b/runs/service/auth/authzserver/metadata_provider.go @@ -0,0 +1,223 @@ +package authzserver + +import ( + "context" + "fmt" + "io" + "mime" + "net/http" + "net/url" + "strings" + "time" + + "connectrpc.com/connect" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" + + "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth" + "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth/authconnect" + "github.com/flyteorg/flyte/v2/flytestdlib/logger" + "github.com/flyteorg/flyte/v2/runs/service/auth/config" + authpkg "github.com/flyteorg/flyte/v2/runs/service/auth" +) + +const ( + oauth2MetadataEndpoint = ".well-known/oauth-authorization-server" +) + +var ( + tokenRelativeURL = mustParseURLPath("/oauth2/token") + authorizeRelativeURL = mustParseURLPath("/oauth2/authorize") + jsonWebKeysURL = mustParseURLPath("/oauth2/jwks") + oauth2MetadataRelURL = mustParseURLPath(oauth2MetadataEndpoint) + + supportedGrantTypes = []string{"client_credentials", "refresh_token", "authorization_code"} +) + +func mustParseURLPath(rawURL string) *url.URL { + u, err := url.Parse(rawURL) + if err != nil { + panic(err) + } + return u +} + +type authMetadataService struct { + authconnect.UnimplementedAuthMetadataServiceHandler + cfg config.Config +} + +func NewAuthMetadataService(cfg config.Config) authconnect.AuthMetadataServiceHandler { + return &authMetadataService{cfg: cfg} +} + +func (s *authMetadataService) GetOAuth2Metadata( + ctx context.Context, + _ *connect.Request[auth.GetOAuth2MetadataRequest], +) (*connect.Response[auth.GetOAuth2MetadataResponse], error) { + switch s.cfg.AppAuth.AuthServerType { + case config.AuthorizationServerTypeSelf: + return s.getOAuth2MetadataSelf(ctx) + default: + return s.getOAuth2MetadataExternal(ctx) + } +} + +func (s *authMetadataService) getOAuth2MetadataSelf(ctx context.Context) (*connect.Response[auth.GetOAuth2MetadataResponse], error) { + publicURL := authpkg.GetPublicURL(ctx, nil, s.cfg) + + resp := &auth.GetOAuth2MetadataResponse{ + Issuer: authpkg.GetIssuer(ctx, nil, s.cfg), + AuthorizationEndpoint: publicURL.ResolveReference(authorizeRelativeURL).String(), + TokenEndpoint: publicURL.ResolveReference(tokenRelativeURL).String(), + JwksUri: publicURL.ResolveReference(jsonWebKeysURL).String(), + CodeChallengeMethodsSupported: []string{"S256"}, + ResponseTypesSupported: []string{"code", "token", "code token"}, + GrantTypesSupported: supportedGrantTypes, + ScopesSupported: []string{authpkg.ScopeAll}, + TokenEndpointAuthMethodsSupported: []string{"client_secret_basic"}, + } + + return connect.NewResponse(resp), nil +} + +func (s *authMetadataService) getOAuth2MetadataExternal(ctx context.Context) (*connect.Response[auth.GetOAuth2MetadataResponse], error) { + baseURL := s.cfg.AppAuth.ExternalAuthServer.BaseURL + if baseURL.String() == "" { + return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("external auth server base URL is not configured")) + } + + // issuer urls, conventionally, do not end with a '/', however, metadata urls are usually relative of those. + // This adds a '/' to ensure ResolveReference behaves intuitively. + baseURL.Path = strings.TrimSuffix(baseURL.Path, "/") + "/" + + var externalMetadataURL *url.URL + if len(s.cfg.AppAuth.ExternalAuthServer.MetadataEndpointURL.String()) > 0 { + externalMetadataURL = baseURL.ResolveReference(&s.cfg.AppAuth.ExternalAuthServer.MetadataEndpointURL.URL) + } else { + externalMetadataURL = baseURL.ResolveReference(oauth2MetadataRelURL) + } + + httpClient := &http.Client{} + if s.cfg.HTTPProxyURL.String() != "" { + httpClient.Transport = &http.Transport{ + Proxy: http.ProxyURL(&s.cfg.HTTPProxyURL.URL), + } + } + + retryAttempts := s.cfg.AppAuth.ExternalAuthServer.RetryAttempts + if retryAttempts <= 0 { + retryAttempts = 5 + } + retryDelay := s.cfg.AppAuth.ExternalAuthServer.RetryDelay.Duration + if retryDelay <= 0 { + retryDelay = time.Second + } + + response, err := sendAndRetryHTTPRequest(ctx, httpClient, externalMetadataURL.String(), retryAttempts, retryDelay) + if err != nil { + return nil, connect.NewError(connect.CodeUnavailable, fmt.Errorf("failed to fetch OAuth2 metadata: %w", err)) + } + + raw, err := io.ReadAll(response.Body) + if err != nil { + return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("failed to read OAuth2 metadata response: %w", err)) + } + + resp := &auth.GetOAuth2MetadataResponse{} + if err := unmarshalResp(response, raw, resp); err != nil { + return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("failed to unmarshal OAuth2 metadata: %w", err)) + } + + tokenProxyConfig := s.cfg.TokenEndpointProxyConfig + if tokenProxyConfig.Enabled { + tokenEndpoint, parseErr := url.Parse(resp.TokenEndpoint) + if parseErr != nil { + return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("failed to parse token endpoint [%v], err: %v", resp.TokenEndpoint, parseErr)) + } + if len(tokenProxyConfig.PublicURL.Host) == 0 { + publicURL := authpkg.GetPublicURL(ctx, nil, s.cfg) + tokenProxyConfig.PublicURL = config.URL{URL: *publicURL} + } + tokenEndpoint.Host = tokenProxyConfig.PublicURL.Host + tokenEndpoint.Path = tokenProxyConfig.PathPrefix + tokenEndpoint.Path + tokenEndpoint.RawPath = tokenProxyConfig.PathPrefix + tokenEndpoint.RawPath + resp.TokenEndpoint = tokenEndpoint.String() + } + + return connect.NewResponse(resp), nil +} + +func (s *authMetadataService) GetPublicClientConfig( + _ context.Context, + _ *connect.Request[auth.GetPublicClientConfigRequest], +) (*connect.Response[auth.GetPublicClientConfigResponse], error) { + fc := s.cfg.AppAuth.ThirdParty.FlyteClientConfig + return connect.NewResponse(&auth.GetPublicClientConfigResponse{ + ClientId: fc.ClientID, + RedirectUri: fc.RedirectURI, + Scopes: fc.Scopes, + AuthorizationMetadataKey: s.cfg.GrpcAuthorizationHeader, + Audience: fc.Audience, + }), nil +} + +// unmarshalResp unmarshals a JSON response body into a protobuf message. It +// uses protojson.Unmarshal which accepts both the camelCase form used by +// proto3 JSON serialization and the snake_case form matching proto field +// names. This is important because external authorization servers (including +// flyteadmin) emit camelCase keys while the Go proto struct tags are +// snake_case. +func unmarshalResp(r *http.Response, body []byte, v proto.Message) error { + err := protojson.Unmarshal(body, v) + if err == nil { + return nil + } + ct := r.Header.Get("Content-Type") + mediaType, _, parseErr := mime.ParseMediaType(ct) + if parseErr == nil && mediaType == "application/json" { + return fmt.Errorf("got Content-Type = application/json, but could not unmarshal as JSON: %v", err) + } + return fmt.Errorf("expected Content-Type = application/json, got %q: %v", ct, err) +} + +// sendAndRetryHTTPRequest fetches the given URL with retry logic. +func sendAndRetryHTTPRequest(ctx context.Context, client *http.Client, targetURL string, retryAttempts int, retryDelay time.Duration) (*http.Response, error) { + var lastErr error + var lastResp *http.Response + for i := 0; i < retryAttempts; i++ { + if i > 0 { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(retryDelay): + } + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, targetURL, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + resp, err := client.Do(req) + if err != nil { + lastErr = err + logger.Warnf(ctx, "Failed to fetch %s (attempt %d/%d): %v", targetURL, i+1, retryAttempts, err) + continue + } + + if resp.StatusCode >= http.StatusOK && resp.StatusCode < http.StatusMultipleChoices { + return resp, nil + } + + lastErr = fmt.Errorf("unexpected status code %d from %s", resp.StatusCode, targetURL) + lastResp = resp + logger.Warnf(ctx, "Unexpected status code from %s (attempt %d/%d): %d", targetURL, i+1, retryAttempts, resp.StatusCode) + } + + if lastResp != nil && lastResp.StatusCode != http.StatusOK { + return lastResp, fmt.Errorf("failed to get oauth metadata with status code %v: %w", lastResp.StatusCode, lastErr) + } + + return nil, fmt.Errorf("all %d attempts failed for %s: %w", retryAttempts, targetURL, lastErr) +} diff --git a/runs/service/auth/authzserver/metadata_provider_test.go b/runs/service/auth/authzserver/metadata_provider_test.go new file mode 100644 index 00000000000..e50e43dcbbc --- /dev/null +++ b/runs/service/auth/authzserver/metadata_provider_test.go @@ -0,0 +1,348 @@ +package authzserver + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + "connectrpc.com/connect" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth" + "github.com/flyteorg/flyte/v2/runs/service/auth/config" +) + +func mustParseTestURL(rawURL string) config.URL { + u, err := url.Parse(rawURL) + if err != nil { + panic(err) + } + return config.URL{URL: *u} +} + +func TestGetPublicClientConfig(t *testing.T) { + cfg := config.Config{ + GrpcAuthorizationHeader: "flyte-authorization", + AppAuth: config.OAuth2Options{ + ThirdParty: config.ThirdPartyConfigOptions{ + FlyteClientConfig: config.FlyteClientConfig{ + ClientID: "flyte-client", + RedirectURI: "http://localhost:12345/callback", + Scopes: []string{"openid", "offline"}, + Audience: "https://flyte.example.com", + }, + }, + }, + } + + svc := NewAuthMetadataService(cfg) + resp, err := svc.GetPublicClientConfig(context.Background(), connect.NewRequest(&auth.GetPublicClientConfigRequest{})) + require.NoError(t, err) + + msg := resp.Msg + assert.Equal(t, "flyte-client", msg.ClientId) + assert.Equal(t, "http://localhost:12345/callback", msg.RedirectUri) + assert.Equal(t, []string{"openid", "offline"}, msg.Scopes) + assert.Equal(t, "flyte-authorization", msg.AuthorizationMetadataKey) + assert.Equal(t, "https://flyte.example.com", msg.Audience) +} + +func TestGetOAuth2Metadata_SelfAuthServer(t *testing.T) { + cfg := config.Config{ + AuthorizedURIs: []config.URL{ + mustParseTestURL("https://flyte.example.com"), + }, + AppAuth: config.OAuth2Options{ + AuthServerType: config.AuthorizationServerTypeSelf, + }, + } + + svc := NewAuthMetadataService(cfg) + resp, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) + require.NoError(t, err) + + msg := resp.Msg + assert.Equal(t, "https://flyte.example.com", msg.Issuer) + assert.Equal(t, "https://flyte.example.com/oauth2/authorize", msg.AuthorizationEndpoint) + assert.Equal(t, "https://flyte.example.com/oauth2/token", msg.TokenEndpoint) + assert.Equal(t, "https://flyte.example.com/oauth2/jwks", msg.JwksUri) + assert.Equal(t, []string{"S256"}, msg.CodeChallengeMethodsSupported) + assert.Equal(t, []string{"code", "token", "code token"}, msg.ResponseTypesSupported) + assert.Equal(t, []string{"client_credentials", "refresh_token", "authorization_code"}, msg.GrantTypesSupported) + assert.Equal(t, []string{"all"}, msg.ScopesSupported) + assert.Equal(t, []string{"client_secret_basic"}, msg.TokenEndpointAuthMethodsSupported) +} + +func TestGetOAuth2Metadata_SelfAuthServerWithCustomIssuer(t *testing.T) { + cfg := config.Config{ + AuthorizedURIs: []config.URL{ + mustParseTestURL("https://flyte.example.com"), + }, + AppAuth: config.OAuth2Options{ + AuthServerType: config.AuthorizationServerTypeSelf, + SelfAuthServer: config.AuthorizationServer{ + Issuer: "https://custom-issuer.example.com", + }, + }, + } + + svc := NewAuthMetadataService(cfg) + resp, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) + require.NoError(t, err) + + msg := resp.Msg + assert.Equal(t, "https://custom-issuer.example.com", msg.Issuer) + assert.Equal(t, "https://flyte.example.com/oauth2/authorize", msg.AuthorizationEndpoint) +} + +func TestGetOAuth2Metadata_SelfAuthServerDefaultAuthorizedURI(t *testing.T) { + cfg := config.Config{ + AuthorizedURIs: []config.URL{ + mustParseTestURL("http://localhost:8090"), + }, + AppAuth: config.OAuth2Options{ + AuthServerType: config.AuthorizationServerTypeSelf, + }, + } + + svc := NewAuthMetadataService(cfg) + resp, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) + require.NoError(t, err) + + msg := resp.Msg + assert.Equal(t, "http://localhost:8090", msg.Issuer) + assert.Equal(t, "http://localhost:8090/oauth2/token", msg.TokenEndpoint) +} + +func TestGetOAuth2Metadata_ExternalAuthServer(t *testing.T) { + expectedMetadata := &auth.GetOAuth2MetadataResponse{ + Issuer: "https://external-idp.example.com", + AuthorizationEndpoint: "https://external-idp.example.com/authorize", + TokenEndpoint: "https://external-idp.example.com/token", + JwksUri: "https://external-idp.example.com/.well-known/jwks.json", + } + + metadataJSON, err := json.Marshal(expectedMetadata) + require.NoError(t, err) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write(metadataJSON) + })) + defer ts.Close() + + cfg := config.Config{ + AppAuth: config.OAuth2Options{ + AuthServerType: config.AuthorizationServerTypeExternal, + ExternalAuthServer: config.ExternalAuthorizationServer{ + BaseURL: mustParseTestURL(ts.URL), + RetryAttempts: 1, + RetryDelay: config.Duration{Duration: 100 * time.Millisecond}, + }, + }, + } + + svc := NewAuthMetadataService(cfg) + resp, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) + require.NoError(t, err) + + msg := resp.Msg + assert.Equal(t, "https://external-idp.example.com", msg.Issuer) + assert.Equal(t, "https://external-idp.example.com/authorize", msg.AuthorizationEndpoint) + assert.Equal(t, "https://external-idp.example.com/token", msg.TokenEndpoint) + assert.Equal(t, "https://external-idp.example.com/.well-known/jwks.json", msg.JwksUri) +} + +func TestGetOAuth2Metadata_ExternalWithCustomMetadataURL(t *testing.T) { + expectedMetadata := &auth.GetOAuth2MetadataResponse{ + Issuer: "https://external-idp.example.com", + TokenEndpoint: "https://external-idp.example.com/token", + } + + metadataJSON, err := json.Marshal(expectedMetadata) + require.NoError(t, err) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/custom/metadata", r.URL.Path) + w.Header().Set("Content-Type", "application/json") + w.Write(metadataJSON) + })) + defer ts.Close() + + cfg := config.Config{ + AppAuth: config.OAuth2Options{ + AuthServerType: config.AuthorizationServerTypeExternal, + ExternalAuthServer: config.ExternalAuthorizationServer{ + BaseURL: mustParseTestURL(ts.URL), + MetadataEndpointURL: mustParseTestURL(ts.URL + "/custom/metadata"), + RetryAttempts: 1, + RetryDelay: config.Duration{Duration: 100 * time.Millisecond}, + }, + }, + } + + svc := NewAuthMetadataService(cfg) + resp, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) + require.NoError(t, err) + + assert.Equal(t, "https://external-idp.example.com", resp.Msg.Issuer) + assert.Equal(t, "https://external-idp.example.com/token", resp.Msg.TokenEndpoint) +} + +func TestGetOAuth2Metadata_ExternalWithTokenProxy(t *testing.T) { + expectedMetadata := &auth.GetOAuth2MetadataResponse{ + Issuer: "https://external-idp.example.com", + AuthorizationEndpoint: "https://external-idp.example.com/authorize", + TokenEndpoint: "https://external-idp.example.com/oauth/token", + } + + metadataJSON, err := json.Marshal(expectedMetadata) + require.NoError(t, err) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write(metadataJSON) + })) + defer ts.Close() + + cfg := config.Config{ + AuthorizedURIs: []config.URL{ + mustParseTestURL("https://flyte.example.com"), + }, + AppAuth: config.OAuth2Options{ + AuthServerType: config.AuthorizationServerTypeExternal, + ExternalAuthServer: config.ExternalAuthorizationServer{ + BaseURL: mustParseTestURL(ts.URL), + RetryAttempts: 1, + RetryDelay: config.Duration{Duration: 100 * time.Millisecond}, + }, + }, + TokenEndpointProxyConfig: config.TokenEndpointProxyConfig{ + Enabled: true, + }, + } + + svc := NewAuthMetadataService(cfg) + resp, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) + require.NoError(t, err) + + msg := resp.Msg + assert.Equal(t, "https://external-idp.example.com", msg.Issuer) + assert.Equal(t, "https://external-idp.example.com/authorize", msg.AuthorizationEndpoint) + // Token endpoint should be rewritten to the public URL + assert.Equal(t, "https://flyte.example.com/oauth/token", msg.TokenEndpoint) +} + +func TestGetOAuth2Metadata_ExternalWithTokenProxyAndPathPrefix(t *testing.T) { + expectedMetadata := &auth.GetOAuth2MetadataResponse{ + TokenEndpoint: "https://external-idp.example.com/oauth/token", + } + + metadataJSON, err := json.Marshal(expectedMetadata) + require.NoError(t, err) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write(metadataJSON) + })) + defer ts.Close() + + cfg := config.Config{ + AppAuth: config.OAuth2Options{ + AuthServerType: config.AuthorizationServerTypeExternal, + ExternalAuthServer: config.ExternalAuthorizationServer{ + BaseURL: mustParseTestURL(ts.URL), + RetryAttempts: 1, + RetryDelay: config.Duration{Duration: 100 * time.Millisecond}, + }, + }, + TokenEndpointProxyConfig: config.TokenEndpointProxyConfig{ + Enabled: true, + PublicURL: mustParseTestURL("https://proxy.example.com"), + PathPrefix: "api/v1", + }, + } + + svc := NewAuthMetadataService(cfg) + resp, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) + require.NoError(t, err) + + assert.Equal(t, "https://proxy.example.com/api/v1/oauth/token", resp.Msg.TokenEndpoint) +} + +func TestGetOAuth2Metadata_ExternalNoBaseURL(t *testing.T) { + cfg := config.Config{ + AppAuth: config.OAuth2Options{ + AuthServerType: config.AuthorizationServerTypeExternal, + }, + } + + svc := NewAuthMetadataService(cfg) + _, err := svc.GetOAuth2Metadata(context.Background(), connect.NewRequest(&auth.GetOAuth2MetadataRequest{})) + require.Error(t, err) + assert.Contains(t, err.Error(), "external auth server base URL is not configured") +} + +func TestSendAndRetryHTTPRequest_ImmediateSuccess(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status":"ok"}`)) + })) + defer ts.Close() + + resp, err := sendAndRetryHTTPRequest(context.Background(), http.DefaultClient, ts.URL, 3, 10*time.Millisecond) + require.NoError(t, err) + require.NotNil(t, resp) + assert.Equal(t, http.StatusOK, resp.StatusCode) +} + +func TestSendAndRetryHTTPRequest_RetryIntoSuccess(t *testing.T) { + attempt := 0 + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempt++ + if attempt < 3 { + w.WriteHeader(http.StatusServiceUnavailable) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status":"ok"}`)) + })) + defer ts.Close() + + resp, err := sendAndRetryHTTPRequest(context.Background(), http.DefaultClient, ts.URL, 5, 10*time.Millisecond) + require.NoError(t, err) + require.NotNil(t, resp) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, 3, attempt) +} + +func TestSendAndRetryHTTPRequest_AllRetrysFail(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer ts.Close() + + _, err := sendAndRetryHTTPRequest(context.Background(), http.DefaultClient, ts.URL, 3, 10*time.Millisecond) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to get oauth metadata") +} + +func TestSendAndRetryHTTPRequest_ContextCancelled(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer ts.Close() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := sendAndRetryHTTPRequest(ctx, http.DefaultClient, ts.URL, 5, 10*time.Millisecond) + require.Error(t, err) +} diff --git a/runs/service/auth/authzserver/resource_server.go b/runs/service/auth/authzserver/resource_server.go new file mode 100644 index 00000000000..39429aba5d3 --- /dev/null +++ b/runs/service/auth/authzserver/resource_server.go @@ -0,0 +1,109 @@ +package authzserver + +import ( + "context" + "fmt" + "io" + "net/http" + "net/url" + "strings" + + "github.com/coreos/go-oidc/v3/oidc" + jwtgo "github.com/golang-jwt/jwt/v5" + + "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth" + "github.com/flyteorg/flyte/v2/runs/service/auth/config" + authpkg "github.com/flyteorg/flyte/v2/runs/service/auth" +) + +// ResourceServer authorizes access requests issued by an external Authorization Server. +type ResourceServer struct { + signatureVerifier oidc.KeySet + allowedAudience []string +} + +// NewOAuth2ResourceServer initializes a new OAuth2ResourceServer. +func NewOAuth2ResourceServer(ctx context.Context, cfg config.ExternalAuthorizationServer, fallbackBaseURL config.URL) (*ResourceServer, error) { + u := cfg.BaseURL + if len(u.String()) == 0 { + u = fallbackBaseURL + } + + verifier, err := getJwksForIssuer(ctx, u.URL, cfg) + if err != nil { + return nil, err + } + + return &ResourceServer{ + signatureVerifier: verifier, + allowedAudience: cfg.AllowedAudience, + }, nil +} + +// ValidateAccessToken verifies the token signature, validates claims, and returns the identity context. +func (r *ResourceServer) ValidateAccessToken(ctx context.Context, expectedAudience, tokenStr string) (*authpkg.IdentityContext, error) { + _, err := r.signatureVerifier.VerifySignature(ctx, tokenStr) + if err != nil { + return nil, fmt.Errorf("failed to verify token signature: %w", err) + } + + claims := jwtgo.MapClaims{} + parser := jwtgo.NewParser() + if _, _, err = parser.ParseUnverified(tokenStr, claims); err != nil { + return nil, fmt.Errorf("failed to parse token: %w", err) + } + + allowed := make(map[string]bool, len(r.allowedAudience)+1) + for _, a := range r.allowedAudience { + allowed[a] = true + } + allowed[expectedAudience] = true + + return verifyClaims(allowed, claims) +} + +// getJwksForIssuer fetches the OAuth2 metadata from the external auth server and returns the remote JWKS key set. +func getJwksForIssuer(ctx context.Context, issuerBaseURL url.URL, cfg config.ExternalAuthorizationServer) (oidc.KeySet, error) { + issuerBaseURL.Path = strings.TrimSuffix(issuerBaseURL.Path, "/") + "/" + + var wellKnown *url.URL + if len(cfg.MetadataEndpointURL.String()) > 0 { + wellKnown = issuerBaseURL.ResolveReference(&cfg.MetadataEndpointURL.URL) + } else { + wellKnown = issuerBaseURL.ResolveReference(oauth2MetadataRelURL) + } + + httpClient := &http.Client{} + if len(cfg.HTTPProxyURL.String()) > 0 { + httpClient.Transport = &http.Transport{ + Proxy: http.ProxyURL(&cfg.HTTPProxyURL.URL), + } + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, wellKnown.String(), nil) + if err != nil { + return nil, err + } + + resp, err := httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("unable to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("%s: %s", resp.Status, body) + } + + p := &auth.GetOAuth2MetadataResponse{} + if err = unmarshalResp(resp, body, p); err != nil { + return nil, fmt.Errorf("failed to decode provider discovery object: %w", err) + } + + return oidc.NewRemoteKeySet(oidc.ClientContext(ctx, httpClient), p.JwksUri), nil +} diff --git a/runs/service/auth/authzserver/resource_server_test.go b/runs/service/auth/authzserver/resource_server_test.go new file mode 100644 index 00000000000..1fbbb43f251 --- /dev/null +++ b/runs/service/auth/authzserver/resource_server_test.go @@ -0,0 +1,101 @@ +package authzserver + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth" + "github.com/flyteorg/flyte/v2/runs/service/auth/config" +) + +func TestGetJwksForIssuer_Success(t *testing.T) { + metadata := &auth.GetOAuth2MetadataResponse{ + JwksUri: "https://example.com/.well-known/jwks.json", + } + metadataJSON, err := json.Marshal(metadata) + require.NoError(t, err) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write(metadataJSON) + })) + defer ts.Close() + + cfg := config.ExternalAuthorizationServer{ + BaseURL: mustParseTestURL(ts.URL), + } + + keySet, err := getJwksForIssuer(context.Background(), cfg.BaseURL.URL, cfg) + require.NoError(t, err) + assert.NotNil(t, keySet) +} + +func TestGetJwksForIssuer_CustomMetadataURL(t *testing.T) { + metadata := &auth.GetOAuth2MetadataResponse{ + JwksUri: "https://example.com/.well-known/jwks.json", + } + metadataJSON, err := json.Marshal(metadata) + require.NoError(t, err) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/custom/metadata", r.URL.Path) + w.Header().Set("Content-Type", "application/json") + w.Write(metadataJSON) + })) + defer ts.Close() + + cfg := config.ExternalAuthorizationServer{ + BaseURL: mustParseTestURL(ts.URL), + MetadataEndpointURL: mustParseTestURL(ts.URL + "/custom/metadata"), + } + + keySet, err := getJwksForIssuer(context.Background(), cfg.BaseURL.URL, cfg) + require.NoError(t, err) + assert.NotNil(t, keySet) +} + +func TestGetJwksForIssuer_ServerError(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("server error")) + })) + defer ts.Close() + + cfg := config.ExternalAuthorizationServer{ + BaseURL: mustParseTestURL(ts.URL), + } + + _, err := getJwksForIssuer(context.Background(), cfg.BaseURL.URL, cfg) + require.Error(t, err) + assert.Contains(t, err.Error(), "500") +} + +func TestNewOAuth2ResourceServer_FallbackBaseURL(t *testing.T) { + metadata := &auth.GetOAuth2MetadataResponse{ + JwksUri: "https://example.com/.well-known/jwks.json", + } + metadataJSON, err := json.Marshal(metadata) + require.NoError(t, err) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write(metadataJSON) + })) + defer ts.Close() + + cfg := config.ExternalAuthorizationServer{ + AllowedAudience: []string{"https://flyte.example.com"}, + } + fallback := mustParseTestURL(ts.URL) + + rs, err := NewOAuth2ResourceServer(context.Background(), cfg, config.URL(fallback)) + require.NoError(t, err) + assert.NotNil(t, rs) + assert.Equal(t, []string{"https://flyte.example.com"}, rs.allowedAudience) +} diff --git a/runs/service/auth/config/authorizationservertype_enumer.go b/runs/service/auth/config/authorizationservertype_enumer.go new file mode 100644 index 00000000000..f6e89a6fded --- /dev/null +++ b/runs/service/auth/config/authorizationservertype_enumer.go @@ -0,0 +1,67 @@ +// Code generated by "enumer --type=AuthorizationServerType --trimprefix=AuthorizationServerType -json"; DO NOT EDIT. + +package config + +import ( + "encoding/json" + "fmt" +) + +const _AuthorizationServerTypeName = "SelfExternal" + +var _AuthorizationServerTypeIndex = [...]uint8{0, 4, 12} + +func (i AuthorizationServerType) String() string { + if i < 0 || i >= AuthorizationServerType(len(_AuthorizationServerTypeIndex)-1) { + return fmt.Sprintf("AuthorizationServerType(%d)", i) + } + return _AuthorizationServerTypeName[_AuthorizationServerTypeIndex[i]:_AuthorizationServerTypeIndex[i+1]] +} + +var _AuthorizationServerTypeValues = []AuthorizationServerType{0, 1} + +var _AuthorizationServerTypeNameToValueMap = map[string]AuthorizationServerType{ + _AuthorizationServerTypeName[0:4]: 0, + _AuthorizationServerTypeName[4:12]: 1, +} + +// AuthorizationServerTypeString retrieves an enum value from the enum constants string name. +// Throws an error if the param is not part of the enum. +func AuthorizationServerTypeString(s string) (AuthorizationServerType, error) { + if val, ok := _AuthorizationServerTypeNameToValueMap[s]; ok { + return val, nil + } + return 0, fmt.Errorf("%s does not belong to AuthorizationServerType values", s) +} + +// AuthorizationServerTypeValues returns all values of the enum +func AuthorizationServerTypeValues() []AuthorizationServerType { + return _AuthorizationServerTypeValues +} + +// IsAAuthorizationServerType returns "true" if the value is listed in the enum definition. "false" otherwise +func (i AuthorizationServerType) IsAAuthorizationServerType() bool { + for _, v := range _AuthorizationServerTypeValues { + if i == v { + return true + } + } + return false +} + +// MarshalJSON implements the json.Marshaler interface for AuthorizationServerType +func (i AuthorizationServerType) MarshalJSON() ([]byte, error) { + return json.Marshal(i.String()) +} + +// UnmarshalJSON implements the json.Unmarshaler interface for AuthorizationServerType +func (i *AuthorizationServerType) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return fmt.Errorf("AuthorizationServerType should be a string, got %s", data) + } + + var err error + *i, err = AuthorizationServerTypeString(s) + return err +} diff --git a/runs/service/auth/config/config.go b/runs/service/auth/config/config.go new file mode 100644 index 00000000000..7d217900127 --- /dev/null +++ b/runs/service/auth/config/config.go @@ -0,0 +1,300 @@ +package config + +import ( + "net/url" + + "github.com/flyteorg/flyte/v2/flytestdlib/config" +) + +//go:generate pflags Config --default-var=DefaultConfig +//go:generate enumer --type=AuthorizationServerType --trimprefix=AuthorizationServerType -json +//go:generate enumer --type=SameSite --trimprefix=SameSite -json + +type SecretName = string + +const ( + // SecretNameOIdCClientSecret defines the default OIdC client secret name to use. + // #nosec + SecretNameOIdCClientSecret SecretName = "oidc_client_secret" + + // SecretNameCookieHashKey defines the default cookie hash key secret name to use. + // #nosec + SecretNameCookieHashKey SecretName = "cookie_hash_key" + + // SecretNameCookieBlockKey defines the default cookie block key secret name to use. + // #nosec + SecretNameCookieBlockKey SecretName = "cookie_block_key" + + // SecretNameClaimSymmetricKey must be a base64 encoded secret of exactly 32 bytes. + // #nosec + SecretNameClaimSymmetricKey SecretName = "claim_symmetric_key" + + // SecretNameTokenSigningRSAKey is the private key used to sign JWT tokens (RS256). + // #nosec + SecretNameTokenSigningRSAKey SecretName = "token_rsa_key.pem" + + // SecretNameOldTokenSigningRSAKey is the old private key for key rotation. Only used to + // validate incoming tokens; new tokens will not be issued with this key. + // #nosec + SecretNameOldTokenSigningRSAKey SecretName = "token_rsa_key_old.pem" +) + +// AuthorizationServerType defines the type of Authorization Server to use. +type AuthorizationServerType int + +const ( + // AuthorizationServerTypeSelf indicates the service acts as its own authorization server. + AuthorizationServerTypeSelf AuthorizationServerType = iota + // AuthorizationServerTypeExternal indicates an external authorization server is used. + AuthorizationServerTypeExternal +) + +// SameSite represents the SameSite cookie policy. +type SameSite int + +const ( + SameSiteDefaultMode SameSite = iota + SameSiteLaxMode + SameSiteStrictMode + SameSiteNoneMode +) + +var ( + DefaultConfig = &Config{ + HTTPAuthorizationHeader: "flyte-authorization", + GrpcAuthorizationHeader: "flyte-authorization", + UserAuth: UserAuthConfig{ + RedirectURL: config.URL{URL: *MustParseURL("/console")}, + CookieHashKeySecretName: SecretNameCookieHashKey, + CookieBlockKeySecretName: SecretNameCookieBlockKey, + OpenID: OpenIDOptions{ + ClientSecretName: SecretNameOIdCClientSecret, + Scopes: []string{ + "openid", + "profile", + }, + }, + CookieSetting: CookieSettings{ + Domain: "", + SameSitePolicy: SameSiteDefaultMode, + }, + }, + AppAuth: OAuth2Options{ + ExternalAuthServer: ExternalAuthorizationServer{ + RetryAttempts: 5, + RetryDelay: config.Duration{Duration: 1_000_000_000}, // 1 second + }, + AuthServerType: AuthorizationServerTypeSelf, + ThirdParty: ThirdPartyConfigOptions{ + FlyteClientConfig: FlyteClientConfig{ + ClientID: "flytectl", + RedirectURI: "http://localhost:53593/callback", + Scopes: []string{"all", "offline"}, + }, + }, + }, + } + + cfgSection = config.MustRegisterSection("auth", DefaultConfig) +) + +// Config holds the full authentication configuration. +type Config struct { + // HTTPAuthorizationHeader is the HTTP header name for authorization (for non-standard headers behind Envoy). + HTTPAuthorizationHeader string `json:"httpAuthorizationHeader"` + + // GrpcAuthorizationHeader is the gRPC metadata key for authorization. + GrpcAuthorizationHeader string `json:"grpcAuthorizationHeader"` + + // DisableForHTTP disables auth enforcement on HTTP endpoints. + DisableForHTTP bool `json:"disableForHttp" pflag:",Disables auth enforcement on HTTP Endpoints."` + + // DisableForGrpc disables auth enforcement on gRPC endpoints. + DisableForGrpc bool `json:"disableForGrpc" pflag:",Disables auth enforcement on Grpc Endpoints."` + + // AuthorizedURIs defines the set of URIs that clients are allowed to visit the service on. + AuthorizedURIs []config.URL `json:"authorizedUris" pflag:"-,Defines the set of URIs that clients are allowed to visit the service on."` + + // HTTPProxyURL allows accessing external OAuth2 servers through an HTTP proxy. + HTTPProxyURL config.URL `json:"httpProxyURL" pflag:",OPTIONAL: HTTP Proxy to be used for OAuth requests."` + + // UserAuth settings used to authenticate end users in web-browsers. + UserAuth UserAuthConfig `json:"userAuth" pflag:",Defines Auth options for users."` + + // AppAuth defines app-level OAuth2 settings. + AppAuth OAuth2Options `json:"appAuth" pflag:",Defines Auth options for apps. UserAuth must be enabled for AppAuth to work."` + + // SecureCookie sets the Secure flag on auth cookies. Should be true in production (HTTPS). + SecureCookie bool `json:"secureCookie" pflag:",Set the Secure flag on auth cookies"` + + // TokenEndpointProxyConfig proxies token endpoint calls through admin. + TokenEndpointProxyConfig TokenEndpointProxyConfig `json:"tokenEndpointProxyConfig" pflag:",Configuration for proxying token endpoint requests."` +} + +// OAuth2Options holds OAuth2 authorization server options. +type OAuth2Options struct { + // AuthServerType determines whether to use a self-hosted or external auth server. + AuthServerType AuthorizationServerType `json:"authServerType"` + + // SelfAuthServer configures the self-hosted authorization server. + SelfAuthServer AuthorizationServer `json:"selfAuthServer" pflag:",Authorization Server config to run as a service."` + + // ExternalAuthServer configures the external authorization server. + ExternalAuthServer ExternalAuthorizationServer `json:"externalAuthServer" pflag:",External Authorization Server config."` + + // ThirdParty configures third-party (public client) settings. + ThirdParty ThirdPartyConfigOptions `json:"thirdPartyConfig" pflag:",Defines settings to instruct flyte cli tools on what config to use."` +} + +// AuthorizationServer configures a self-hosted authorization server. +type AuthorizationServer struct { + // Issuer is the issuer URL. If empty, the first AuthorizedURI is used. + Issuer string `json:"issuer" pflag:",Defines the issuer to use when issuing and validating tokens."` + + // AccessTokenLifespan defines the lifespan of issued access tokens. + AccessTokenLifespan config.Duration `json:"accessTokenLifespan" pflag:",Defines the lifespan of issued access tokens."` + + // RefreshTokenLifespan defines the lifespan of issued refresh tokens. + RefreshTokenLifespan config.Duration `json:"refreshTokenLifespan" pflag:",Defines the lifespan of issued refresh tokens."` + + // AuthorizationCodeLifespan defines the lifespan of issued authorization codes. + AuthorizationCodeLifespan config.Duration `json:"authorizationCodeLifespan" pflag:",Defines the lifespan of issued authorization codes."` + + // ClaimSymmetricEncryptionKeySecretName is the secret name for claim encryption. + ClaimSymmetricEncryptionKeySecretName string `json:"claimSymmetricEncryptionKeySecretName" pflag:",Secret name for claim encryption key."` + + // TokenSigningRSAKeySecretName is the secret name for the RSA signing key. + TokenSigningRSAKeySecretName string `json:"tokenSigningRSAKeySecretName" pflag:",Secret name for RSA Signing Key."` + + // OldTokenSigningRSAKeySecretName is the secret name for the old RSA signing key (key rotation). + OldTokenSigningRSAKeySecretName string `json:"oldTokenSigningRSAKeySecretName" pflag:",Secret name for Old RSA Signing Key for key rotation."` +} + +// ExternalAuthorizationServer configures an external authorization server. +type ExternalAuthorizationServer struct { + // BaseURL is the base URL of the external authorization server. + BaseURL config.URL `json:"baseUrl" pflag:",Base url of the external authorization server."` + + // AllowedAudience is the set of audiences accepted when validating access tokens. + AllowedAudience []string `json:"allowedAudience" pflag:",A list of allowed audiences."` + + // MetadataEndpointURL overrides the default .well-known/oauth-authorization-server endpoint. + MetadataEndpointURL config.URL `json:"metadataUrl" pflag:",Custom metadata url if the server doesn't support the standard endpoint."` + + // HTTPProxyURL allows accessing the external auth server through an HTTP proxy. + HTTPProxyURL config.URL `json:"httpProxyURL" pflag:",HTTP Proxy for external OAuth requests."` + + // RetryAttempts is the number of retry attempts for fetching metadata. + RetryAttempts int `json:"retryAttempts" pflag:",Number of retry attempts for metadata fetch."` + + // RetryDelay is the delay between retry attempts. + RetryDelay config.Duration `json:"retryDelay" pflag:",Duration to wait between retries."` +} + +// ThirdPartyConfigOptions holds third-party OAuth2 client settings. +type ThirdPartyConfigOptions struct { + // FlyteClientConfig holds public client configuration. + FlyteClientConfig FlyteClientConfig `json:"flyteClient"` +} + +// IsEmpty returns true if the third-party config has no meaningful values set. +func (o ThirdPartyConfigOptions) IsEmpty() bool { + return len(o.FlyteClientConfig.ClientID) == 0 && + len(o.FlyteClientConfig.RedirectURI) == 0 && + len(o.FlyteClientConfig.Scopes) == 0 +} + +// FlyteClientConfig holds the public client configuration. +type FlyteClientConfig struct { + // ClientID is the public client ID. + ClientID string `json:"clientId" pflag:",Public identifier for the app which handles authorization."` + + // RedirectURI is the redirect URI for the client. + RedirectURI string `json:"redirectUri" pflag:",Callback uri registered with the app which handles authorization."` + + // Scopes are the OAuth2 scopes to request. + Scopes []string `json:"scopes" pflag:",Recommended scopes for the client to request."` + + // Audience is the intended audience for OAuth2 tokens. + Audience string `json:"audience" pflag:",Audience to use when initiating OAuth2 authorization requests."` +} + +// UserAuthConfig holds user authentication settings (browser-based OAuth2/OIDC flows). +type UserAuthConfig struct { + // RedirectURL is the default redirect URL after the OAuth2 flow completes. + RedirectURL config.URL `json:"redirectUrl"` + + // OpenID defines settings for connecting and trusting an OpenID Connect provider. + OpenID OpenIDOptions `json:"openId" pflag:",OpenID Configuration for User Auth"` + + // HTTPProxyURL allows operators to access external OAuth2 servers using an HTTP Proxy. + HTTPProxyURL config.URL `json:"httpProxyURL" pflag:",HTTP Proxy for OAuth requests."` + + // CookieHashKeySecretName is the secret name for the cookie hash key. + CookieHashKeySecretName string `json:"cookieHashKeySecretName" pflag:",Secret name for cookie hash key."` + + // CookieBlockKeySecretName is the secret name for the cookie block key. + CookieBlockKeySecretName string `json:"cookieBlockKeySecretName" pflag:",Secret name for cookie block key."` + + // CookieSetting configures cookie behavior. + CookieSetting CookieSettings `json:"cookieSetting" pflag:",Settings for auth cookies."` + + // IDPQueryParameter is used to select a particular IDP for user authentication. + IDPQueryParameter string `json:"idpQueryParameter" pflag:",IDP query parameter for selecting a particular IDP."` +} + +// OpenIDOptions holds OpenID Connect provider configuration. +type OpenIDOptions struct { + // ClientID is the client ID for this service in the IDP. + ClientID string `json:"clientId"` + + // ClientSecretName is the secret name containing the OIDC client secret. + ClientSecretName string `json:"clientSecretName"` + + // BaseURL is the base URL of the OIDC provider. + BaseURL config.URL `json:"baseUrl"` + + // Scopes to request from the IDP when authenticating. + Scopes []string `json:"scopes"` +} + +// CookieSettings configures cookie behavior. +type CookieSettings struct { + // SameSitePolicy controls the SameSite attribute on auth cookies. + SameSitePolicy SameSite `json:"sameSitePolicy" pflag:",SameSite policy for auth cookies."` + + // Domain sets the domain attribute on auth cookies. + Domain string `json:"domain" pflag:",Domain attribute on auth cookies."` +} + +// TokenEndpointProxyConfig configures proxying of token endpoint calls. +type TokenEndpointProxyConfig struct { + // Enabled enables token endpoint proxying. + Enabled bool `json:"enabled" pflag:",Enables the token endpoint proxy."` + + // PublicURL is the public URL to use for rewriting the token endpoint. + PublicURL config.URL `json:"publicUrl" pflag:",Public URL for the token endpoint proxy."` + + // PathPrefix is appended to the public URL when rewriting. + PathPrefix string `json:"pathPrefix" pflag:",Path prefix for proxying token requests."` +} + +// URL is an alias for flytestdlib config.URL, re-exported for convenience. +type URL = config.URL + +// Duration is an alias for flytestdlib config.Duration, re-exported for convenience. +type Duration = config.Duration + +// GetConfig returns the parsed auth configuration. +func GetConfig() *Config { + return cfgSection.GetConfig().(*Config) +} + +// MustParseURL panics if the provided url fails parsing. Should only be used in package initialization or tests. +func MustParseURL(rawURL string) *url.URL { + res, err := url.Parse(rawURL) + if err != nil { + panic(err) + } + return res +} diff --git a/runs/service/auth/config/config_test.go b/runs/service/auth/config/config_test.go new file mode 100644 index 00000000000..216a9890566 --- /dev/null +++ b/runs/service/auth/config/config_test.go @@ -0,0 +1,112 @@ +package config + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAuthorizationServerType_String(t *testing.T) { + assert.Equal(t, "Self", AuthorizationServerTypeSelf.String()) + assert.Equal(t, "External", AuthorizationServerTypeExternal.String()) + assert.Equal(t, "AuthorizationServerType(99)", AuthorizationServerType(99).String()) +} + +func TestAuthorizationServerTypeString(t *testing.T) { + v, err := AuthorizationServerTypeString("Self") + require.NoError(t, err) + assert.Equal(t, AuthorizationServerTypeSelf, v) + + v, err = AuthorizationServerTypeString("External") + require.NoError(t, err) + assert.Equal(t, AuthorizationServerTypeExternal, v) + + _, err = AuthorizationServerTypeString("bogus") + assert.Error(t, err) +} + +func TestAuthorizationServerType_JSON(t *testing.T) { + b, err := json.Marshal(AuthorizationServerTypeExternal) + require.NoError(t, err) + assert.JSONEq(t, `"External"`, string(b)) + + var out AuthorizationServerType + require.NoError(t, json.Unmarshal([]byte(`"Self"`), &out)) + assert.Equal(t, AuthorizationServerTypeSelf, out) + + assert.Error(t, json.Unmarshal([]byte(`"bogus"`), &out)) + assert.Error(t, json.Unmarshal([]byte(`42`), &out)) +} + +func TestAuthorizationServerType_IsA(t *testing.T) { + assert.True(t, AuthorizationServerTypeSelf.IsAAuthorizationServerType()) + assert.True(t, AuthorizationServerTypeExternal.IsAAuthorizationServerType()) + assert.False(t, AuthorizationServerType(99).IsAAuthorizationServerType()) + assert.ElementsMatch(t, + []AuthorizationServerType{AuthorizationServerTypeSelf, AuthorizationServerTypeExternal}, + AuthorizationServerTypeValues()) +} + +func TestSameSite_String(t *testing.T) { + assert.Equal(t, "DefaultMode", SameSiteDefaultMode.String()) + assert.Equal(t, "LaxMode", SameSiteLaxMode.String()) + assert.Equal(t, "StrictMode", SameSiteStrictMode.String()) + assert.Equal(t, "NoneMode", SameSiteNoneMode.String()) + assert.Equal(t, "SameSite(99)", SameSite(99).String()) +} + +func TestSameSiteString(t *testing.T) { + v, err := SameSiteString("StrictMode") + require.NoError(t, err) + assert.Equal(t, SameSiteStrictMode, v) + + _, err = SameSiteString("bogus") + assert.Error(t, err) +} + +func TestSameSite_JSON(t *testing.T) { + b, err := json.Marshal(SameSiteNoneMode) + require.NoError(t, err) + assert.JSONEq(t, `"NoneMode"`, string(b)) + + var out SameSite + require.NoError(t, json.Unmarshal([]byte(`"LaxMode"`), &out)) + assert.Equal(t, SameSiteLaxMode, out) + + assert.Error(t, json.Unmarshal([]byte(`"bogus"`), &out)) +} + +func TestSameSite_IsA(t *testing.T) { + assert.True(t, SameSiteDefaultMode.IsASameSite()) + assert.True(t, SameSiteNoneMode.IsASameSite()) + assert.False(t, SameSite(99).IsASameSite()) + assert.Len(t, SameSiteValues(), 4) +} + +func TestThirdPartyConfigOptions_IsEmpty(t *testing.T) { + assert.True(t, ThirdPartyConfigOptions{}.IsEmpty()) + assert.False(t, ThirdPartyConfigOptions{ + FlyteClientConfig: FlyteClientConfig{ClientID: "x"}, + }.IsEmpty()) + assert.False(t, ThirdPartyConfigOptions{ + FlyteClientConfig: FlyteClientConfig{Scopes: []string{"all"}}, + }.IsEmpty()) +} + +func TestMustParseURL(t *testing.T) { + u := MustParseURL("https://example.com/path") + require.NotNil(t, u) + assert.Equal(t, "example.com", u.Host) + + assert.Panics(t, func() { MustParseURL("://bogus") }) +} + +func TestDefaultConfig(t *testing.T) { + require.NotNil(t, DefaultConfig) + assert.Equal(t, "flyte-authorization", DefaultConfig.HTTPAuthorizationHeader) + assert.Equal(t, AuthorizationServerTypeSelf, DefaultConfig.AppAuth.AuthServerType) + assert.Equal(t, SameSiteDefaultMode, DefaultConfig.UserAuth.CookieSetting.SameSitePolicy) + assert.Contains(t, DefaultConfig.UserAuth.OpenID.Scopes, "openid") +} diff --git a/runs/service/auth/config/samesite_enumer.go b/runs/service/auth/config/samesite_enumer.go new file mode 100644 index 00000000000..e42e58fbbe5 --- /dev/null +++ b/runs/service/auth/config/samesite_enumer.go @@ -0,0 +1,69 @@ +// Code generated by "enumer --type=SameSite --trimprefix=SameSite -json"; DO NOT EDIT. + +package config + +import ( + "encoding/json" + "fmt" +) + +const _SameSiteName = "DefaultModeLaxModeStrictModeNoneMode" + +var _SameSiteIndex = [...]uint8{0, 11, 18, 28, 36} + +func (i SameSite) String() string { + if i < 0 || i >= SameSite(len(_SameSiteIndex)-1) { + return fmt.Sprintf("SameSite(%d)", i) + } + return _SameSiteName[_SameSiteIndex[i]:_SameSiteIndex[i+1]] +} + +var _SameSiteValues = []SameSite{0, 1, 2, 3} + +var _SameSiteNameToValueMap = map[string]SameSite{ + _SameSiteName[0:11]: 0, + _SameSiteName[11:18]: 1, + _SameSiteName[18:28]: 2, + _SameSiteName[28:36]: 3, +} + +// SameSiteString retrieves an enum value from the enum constants string name. +// Throws an error if the param is not part of the enum. +func SameSiteString(s string) (SameSite, error) { + if val, ok := _SameSiteNameToValueMap[s]; ok { + return val, nil + } + return 0, fmt.Errorf("%s does not belong to SameSite values", s) +} + +// SameSiteValues returns all values of the enum +func SameSiteValues() []SameSite { + return _SameSiteValues +} + +// IsASameSite returns "true" if the value is listed in the enum definition. "false" otherwise +func (i SameSite) IsASameSite() bool { + for _, v := range _SameSiteValues { + if i == v { + return true + } + } + return false +} + +// MarshalJSON implements the json.Marshaler interface for SameSite +func (i SameSite) MarshalJSON() ([]byte, error) { + return json.Marshal(i.String()) +} + +// UnmarshalJSON implements the json.Unmarshaler interface for SameSite +func (i *SameSite) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return fmt.Errorf("SameSite should be a string, got %s", data) + } + + var err error + *i, err = SameSiteString(s) + return err +} diff --git a/runs/service/auth/constants.go b/runs/service/auth/constants.go new file mode 100644 index 00000000000..295ac8e4655 --- /dev/null +++ b/runs/service/auth/constants.go @@ -0,0 +1,22 @@ +package auth + +const ( + // OAuth2 Parameters + CsrfFormKey = "state" + AuthorizationResponseCodeType = "code" + DefaultAuthorizationHeader = "authorization" + BearerScheme = "Bearer" + IDTokenScheme = "IDToken" + // Add the -bin suffix so that the header value is automatically base64 encoded + UserInfoMDKey = "UserInfo-bin" + + // https://tools.ietf.org/html/rfc8414 + // This should be defined without a leading slash. If there is one, the url library's ResolveReference will make it a root path + OAuth2MetadataEndpoint = ".well-known/oauth-authorization-server" + + // https://openid.net/specs/openid-connect-discovery-1_0.html + // This should be defined without a leading slash. If there is one, the url library's ResolveReference will make it a root path + OIdCMetadataEndpoint = ".well-known/openid-configuration" + + RedirectURLParameter = "redirect_url" +) diff --git a/runs/service/auth/cookie.go b/runs/service/auth/cookie.go new file mode 100644 index 00000000000..b8bb54f6355 --- /dev/null +++ b/runs/service/auth/cookie.go @@ -0,0 +1,190 @@ +package auth + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "math/rand" + "net/http" + "net/url" + + "github.com/gorilla/securecookie" + + "github.com/flyteorg/flyte/v2/flytestdlib/logger" + "github.com/flyteorg/flyte/v2/runs/service/auth/config" +) + +const ( + // #nosec + accessTokenCookieName = "flyte_at" + // #nosec + accessTokenCookieNameSplitFirst = "flyte_at_1" + // #nosec + accessTokenCookieNameSplitSecond = "flyte_at_2" + // #nosec + idTokenCookieName = "flyte_idt" + // #nosec + refreshTokenCookieName = "flyte_rt" + // #nosec + csrfStateCookieName = "flyte_csrf_state" + // #nosec + redirectURLCookieName = "flyte_redirect_location" + + // #nosec + idTokenExtra = "id_token" + + // #nosec + authCodeCookieName = "flyte_auth_code" + + // #nosec + userInfoCookieName = "flyte_user_info" +) + +var AllowedChars = []rune("abcdefghijklmnopqrstuvwxyz1234567890") + +func HashCsrfState(csrf string) string { + shaBytes := sha256.Sum256([]byte(csrf)) + hash := hex.EncodeToString(shaBytes[:]) + return hash +} + +func NewSecureCookie(cookieName, value string, hashKey, blockKey []byte, domain string, sameSiteMode http.SameSite) (http.Cookie, error) { + s := securecookie.New(hashKey, blockKey) + encoded, err := s.Encode(cookieName, value) + if err != nil { + return http.Cookie{}, fmt.Errorf("error creating secure cookie: %w", err) + } + + return http.Cookie{ + Name: cookieName, + Value: encoded, + Domain: domain, + SameSite: sameSiteMode, + HttpOnly: true, + Secure: config.GetConfig().SecureCookie, + Path: "/", + }, nil +} + +func retrieveSecureCookie(ctx context.Context, request *http.Request, cookieName string, hashKey, blockKey []byte) (string, error) { + cookie, err := request.Cookie(cookieName) + if err != nil { + logger.Infof(ctx, "Could not detect existing cookie [%v]. Error: %v", cookieName, err) + return "", fmt.Errorf("failure to retrieve cookie [%v]: %w", cookieName, err) + } + + if cookie == nil { + logger.Infof(ctx, "Retrieved empty cookie [%v].", cookieName) + return "", fmt.Errorf("retrieved empty cookie [%v]", cookieName) + } + + logger.Debugf(ctx, "Existing [%v] cookie found", cookieName) + token, err := ReadSecureCookie(ctx, *cookie, hashKey, blockKey) + if err != nil { + logger.Errorf(ctx, "Error reading existing secure cookie [%v]. Error: %s", cookieName, err) + return "", fmt.Errorf("error reading existing secure cookie [%v]: %w", cookieName, err) + } + + if len(token) == 0 { + logger.Errorf(ctx, "Read empty token from secure cookie [%v].", cookieName) + return "", fmt.Errorf("read empty token from secure cookie [%v]", cookieName) + } + + return token, nil +} + +func ReadSecureCookie(ctx context.Context, cookie http.Cookie, hashKey, blockKey []byte) (string, error) { + s := securecookie.New(hashKey, blockKey) + var value string + if err := s.Decode(cookie.Name, cookie.Value, &value); err == nil { + return value, nil + } + logger.Errorf(ctx, "Error reading secure cookie %s", cookie.Name) + return "", fmt.Errorf("error reading secure cookie %s", cookie.Name) +} + +func NewCsrfToken(seed int64) string { + r := rand.New(rand.NewSource(seed)) //nolint:gosec + csrfToken := [10]rune{} + for i := 0; i < len(csrfToken); i++ { + csrfToken[i] = AllowedChars[r.Intn(len(AllowedChars))] + } + return string(csrfToken[:]) +} + +func NewCsrfCookie() http.Cookie { + csrfStateToken := NewCsrfToken(rand.Int63()) //nolint:gosec + return http.Cookie{ + Name: csrfStateCookieName, + Value: csrfStateToken, + SameSite: http.SameSiteLaxMode, + HttpOnly: true, + Secure: config.GetConfig().SecureCookie, + } +} + +func VerifyCsrfCookie(_ context.Context, request *http.Request) error { + csrfState := request.FormValue(CsrfFormKey) + if csrfState == "" { + return fmt.Errorf("empty state in callback, %s", request.Form) + } + csrfCookie, err := request.Cookie(csrfStateCookieName) + if csrfCookie == nil || err != nil { + return fmt.Errorf("could not find csrf cookie: %v", err) + } + if HashCsrfState(csrfCookie.Value) != csrfState { + return fmt.Errorf("CSRF token does not match state %s, %s vs %s", csrfCookie.Value, + HashCsrfState(csrfCookie.Value), csrfState) + } + return nil +} + +// NewRedirectCookie creates a cookie to keep track of where to send the user after +// the OAuth2 login flow is complete. +func NewRedirectCookie(ctx context.Context, redirectURL string) *http.Cookie { + urlObj, err := url.Parse(redirectURL) + if err != nil || urlObj == nil { + logger.Errorf(ctx, "Error creating redirect cookie %s %s", urlObj, err) + return nil + } + + if urlObj.EscapedPath() == "" { + logger.Errorf(ctx, "Error parsing URL, redirect %s resolved to empty string", redirectURL) + return nil + } + + return &http.Cookie{ + Name: redirectURLCookieName, + Value: urlObj.String(), + SameSite: http.SameSiteLaxMode, + HttpOnly: true, + Secure: config.GetConfig().SecureCookie, + } +} + +// GetAuthFlowEndRedirect returns the redirect URI according to data in request. +// At the end of the OAuth flow, the server needs to send the user somewhere. This should have been stored as a cookie +// during the initial /login call. If that cookie is missing from the request, it will default to the one configured. +func GetAuthFlowEndRedirect(ctx context.Context, defaultRedirect string, authorizedURIs []config.URL, request *http.Request) string { + queryParams := request.URL.Query() + if redirectURL := queryParams.Get(RedirectURLParameter); len(redirectURL) > 0 { + if GetRedirectURLAllowed(ctx, redirectURL, authorizedURIs) { + return redirectURL + } + logger.Warnf(ctx, "Rejecting unauthorized redirect_url from query parameter: %s", redirectURL) + return defaultRedirect + } + + cookie, err := request.Cookie(redirectURLCookieName) + if err != nil { + logger.Debugf(ctx, "Could not detect end-of-flow redirect url cookie") + return defaultRedirect + } + + if GetRedirectURLAllowed(ctx, cookie.Value, authorizedURIs) { + return cookie.Value + } + logger.Warnf(ctx, "Rejecting unauthorized redirect_url from cookie: %s", cookie.Value) + return defaultRedirect +} diff --git a/runs/service/auth/cookie_manager.go b/runs/service/auth/cookie_manager.go new file mode 100644 index 00000000000..0a2113f712d --- /dev/null +++ b/runs/service/auth/cookie_manager.go @@ -0,0 +1,225 @@ +package auth + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "time" + + "golang.org/x/oauth2" + + authpb "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth" + "github.com/flyteorg/flyte/v2/flytestdlib/logger" + "github.com/flyteorg/flyte/v2/runs/service/auth/config" +) + +// CookieManager manages encrypted cookie operations for auth tokens. +type CookieManager struct { + hashKey []byte + blockKey []byte + domain string + sameSitePolicy config.SameSite +} + +func NewCookieManager(ctx context.Context, hashKeyEncoded, blockKeyEncoded string, cookieSettings config.CookieSettings) (CookieManager, error) { + logger.Infof(ctx, "Instantiating cookie manager") + + hashKey, err := base64.RawStdEncoding.DecodeString(hashKeyEncoded) + if err != nil { + return CookieManager{}, fmt.Errorf("error decoding hash key bytes: %w", err) + } + + blockKey, err := base64.RawStdEncoding.DecodeString(blockKeyEncoded) + if err != nil { + return CookieManager{}, fmt.Errorf("error decoding block key bytes: %w", err) + } + + return CookieManager{ + hashKey: hashKey, + blockKey: blockKey, + domain: cookieSettings.Domain, + sameSitePolicy: cookieSettings.SameSitePolicy, + }, nil +} + +func (c CookieManager) RetrieveAccessToken(ctx context.Context, request *http.Request) (string, error) { + // If there is an old access token, we will retrieve it + oldAccessToken, err := retrieveSecureCookie(ctx, request, accessTokenCookieName, c.hashKey, c.blockKey) + if err == nil && oldAccessToken != "" { + return oldAccessToken, nil + } + // If there is no old access token, we will retrieve the new split access token + accessTokenFirstHalf, err := retrieveSecureCookie(ctx, request, accessTokenCookieNameSplitFirst, c.hashKey, c.blockKey) + if err != nil { + return "", err + } + accessTokenSecondHalf, err := retrieveSecureCookie(ctx, request, accessTokenCookieNameSplitSecond, c.hashKey, c.blockKey) + if err != nil { + return "", err + } + return accessTokenFirstHalf + accessTokenSecondHalf, nil +} + +// RetrieveTokenValues retrieves id, access and refresh tokens from cookies if they exist. +func (c CookieManager) RetrieveTokenValues(ctx context.Context, request *http.Request) (idToken, accessToken, + refreshToken string, err error) { + + idToken, err = retrieveSecureCookie(ctx, request, idTokenCookieName, c.hashKey, c.blockKey) + if err != nil { + return "", "", "", err + } + + accessToken, err = c.RetrieveAccessToken(ctx, request) + if err != nil { + return "", "", "", err + } + + refreshToken, err = retrieveSecureCookie(ctx, request, refreshTokenCookieName, c.hashKey, c.blockKey) + if err != nil { + // Refresh tokens are optional. + logger.Infof(ctx, "Refresh token doesn't exist or failed to read it. Ignoring this error. Error: %v", err) + err = nil + } + + return +} + +func (c CookieManager) SetUserInfoCookie(ctx context.Context, writer http.ResponseWriter, userInfo *authpb.UserInfoResponse) error { + raw, err := json.Marshal(userInfo) + if err != nil { + return fmt.Errorf("failed to marshal user info to store in a cookie: %w", err) + } + + return c.SetUserInfoCookieRaw(ctx, writer, string(raw)) +} + +func (c CookieManager) SetUserInfoCookieRaw(ctx context.Context, writer http.ResponseWriter, userInfoStr string) error { + userInfoCookie, err := NewSecureCookie(userInfoCookieName, userInfoStr, c.hashKey, c.blockKey, c.domain, c.getHTTPSameSitePolicy()) + if err != nil { + logger.Errorf(ctx, "Error generating encrypted user info cookie %s", err) + return err + } + + http.SetCookie(writer, &userInfoCookie) + return nil +} + +func (c CookieManager) RetrieveUserInfo(ctx context.Context, request *http.Request) (*authpb.UserInfoResponse, error) { + userInfoCookie, err := retrieveSecureCookie(ctx, request, userInfoCookieName, c.hashKey, c.blockKey) + if err != nil { + return nil, err + } + + res := authpb.UserInfoResponse{} + err = json.Unmarshal([]byte(userInfoCookie), &res) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal user info cookie: %w", err) + } + + return &res, nil +} + +func (c CookieManager) RetrieveAuthCodeRequest(ctx context.Context, request *http.Request) (string, error) { + return retrieveSecureCookie(ctx, request, authCodeCookieName, c.hashKey, c.blockKey) +} + +func (c CookieManager) SetAuthCodeCookie(ctx context.Context, writer http.ResponseWriter, authRequestURL string) error { + authCodeCookie, err := NewSecureCookie(authCodeCookieName, authRequestURL, c.hashKey, c.blockKey, c.domain, c.getHTTPSameSitePolicy()) + if err != nil { + logger.Errorf(ctx, "Error generating encrypted auth code cookie %s", err) + return err + } + + http.SetCookie(writer, &authCodeCookie) + return nil +} + +func (c CookieManager) StoreAccessToken(ctx context.Context, accessToken string, writer http.ResponseWriter) error { + midpoint := len(accessToken) / 2 + firstHalf := accessToken[:midpoint] + secondHalf := accessToken[midpoint:] + atCookieFirst, err := NewSecureCookie(accessTokenCookieNameSplitFirst, firstHalf, c.hashKey, c.blockKey, c.domain, c.getHTTPSameSitePolicy()) + if err != nil { + logger.Errorf(ctx, "Error generating encrypted accesstoken cookie first half %s", err) + return err + } + http.SetCookie(writer, &atCookieFirst) + atCookieSecond, err := NewSecureCookie(accessTokenCookieNameSplitSecond, secondHalf, c.hashKey, c.blockKey, c.domain, c.getHTTPSameSitePolicy()) + if err != nil { + logger.Errorf(ctx, "Error generating encrypted accesstoken cookie second half %s", err) + return err + } + http.SetCookie(writer, &atCookieSecond) + return nil +} + +func (c CookieManager) SetTokenCookies(ctx context.Context, writer http.ResponseWriter, token *oauth2.Token) error { + idToken, accessToken, refreshToken, err := ExtractTokensFromOauthToken(token) + if err != nil { + logger.Errorf(ctx, "Unable to read all token values from oauth token: %s", err) + return fmt.Errorf("unable to read all token values from oauth token: %w", err) + } + + idCookie, err := NewSecureCookie(idTokenCookieName, idToken, c.hashKey, c.blockKey, c.domain, c.getHTTPSameSitePolicy()) + if err != nil { + logger.Errorf(ctx, "Error generating encrypted id token cookie %s", err) + return err + } + + http.SetCookie(writer, &idCookie) + + err = c.StoreAccessToken(ctx, accessToken, writer) + if err != nil { + logger.Errorf(ctx, "Error storing access token %s", err) + return err + } + + // Set the refresh cookie if there is a refresh token + if len(refreshToken) > 0 { + refreshCookie, err := NewSecureCookie(refreshTokenCookieName, token.RefreshToken, c.hashKey, c.blockKey, c.domain, c.getHTTPSameSitePolicy()) + if err != nil { + logger.Errorf(ctx, "Error generating encrypted refresh token cookie %s", err) + return err + } + http.SetCookie(writer, &refreshCookie) + } + + return nil +} + +func (c CookieManager) getLogoutCookie(name string) *http.Cookie { + return &http.Cookie{ + Name: name, + Value: "", + Domain: c.domain, + MaxAge: 0, + HttpOnly: true, + Secure: config.GetConfig().SecureCookie, + Expires: time.Now().Add(-1 * time.Hour), + } +} + +func (c CookieManager) DeleteCookies(_ context.Context, writer http.ResponseWriter) { + http.SetCookie(writer, c.getLogoutCookie(accessTokenCookieName)) + http.SetCookie(writer, c.getLogoutCookie(accessTokenCookieNameSplitFirst)) + http.SetCookie(writer, c.getLogoutCookie(accessTokenCookieNameSplitSecond)) + http.SetCookie(writer, c.getLogoutCookie(refreshTokenCookieName)) + http.SetCookie(writer, c.getLogoutCookie(idTokenCookieName)) +} + +func (c CookieManager) getHTTPSameSitePolicy() http.SameSite { + switch c.sameSitePolicy { + case config.SameSiteDefaultMode: + return http.SameSiteDefaultMode + case config.SameSiteLaxMode: + return http.SameSiteLaxMode + case config.SameSiteStrictMode: + return http.SameSiteStrictMode + case config.SameSiteNoneMode: + return http.SameSiteNoneMode + default: + return http.SameSiteDefaultMode + } +} diff --git a/runs/service/auth/cookie_test.go b/runs/service/auth/cookie_test.go new file mode 100644 index 00000000000..07a55a87366 --- /dev/null +++ b/runs/service/auth/cookie_test.go @@ -0,0 +1,160 @@ +package auth + +import ( + "context" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/gorilla/securecookie" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/flyteorg/flyte/v2/runs/service/auth/config" +) + +func TestHashCsrfState(t *testing.T) { + h := HashCsrfState("hello") + // sha256("hello") hex + assert.Equal(t, "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", h) + assert.Equal(t, h, HashCsrfState("hello"), "deterministic") + assert.NotEqual(t, h, HashCsrfState("world")) +} + +func TestNewCsrfToken(t *testing.T) { + a := NewCsrfToken(1) + b := NewCsrfToken(1) + assert.Equal(t, a, b, "same seed should produce identical token") + assert.Len(t, a, 10) + for _, r := range a { + assert.Contains(t, string(AllowedChars), string(r)) + } + assert.NotEqual(t, a, NewCsrfToken(2)) +} + +func TestNewCsrfCookie(t *testing.T) { + c := NewCsrfCookie() + assert.Equal(t, "flyte_csrf_state", c.Name) + assert.Len(t, c.Value, 10) + assert.Equal(t, http.SameSiteLaxMode, c.SameSite) + assert.True(t, c.HttpOnly) +} + +func newTestKeys() ([]byte, []byte) { + return securecookie.GenerateRandomKey(64), securecookie.GenerateRandomKey(32) +} + +func TestSecureCookie_RoundTrip(t *testing.T) { + hashKey, blockKey := newTestKeys() + cookie, err := NewSecureCookie("flyte_at", "super-secret", hashKey, blockKey, "", http.SameSiteLaxMode) + require.NoError(t, err) + assert.Equal(t, "flyte_at", cookie.Name) + assert.True(t, cookie.HttpOnly) + assert.Equal(t, "/", cookie.Path) + + out, err := ReadSecureCookie(context.Background(), cookie, hashKey, blockKey) + require.NoError(t, err) + assert.Equal(t, "super-secret", out) +} + +func TestReadSecureCookie_WrongKey(t *testing.T) { + hashKey, blockKey := newTestKeys() + cookie, err := NewSecureCookie("flyte_at", "value", hashKey, blockKey, "", http.SameSiteLaxMode) + require.NoError(t, err) + + wrongHash, wrongBlock := newTestKeys() + _, err = ReadSecureCookie(context.Background(), cookie, wrongHash, wrongBlock) + assert.Error(t, err) +} + +func TestRetrieveSecureCookie_Missing(t *testing.T) { + hashKey, blockKey := newTestKeys() + req := httptest.NewRequest(http.MethodGet, "/", nil) + _, err := retrieveSecureCookie(context.Background(), req, "missing", hashKey, blockKey) + assert.Error(t, err) +} + +func TestVerifyCsrfCookie(t *testing.T) { + token := "abcdefghij" + hashed := HashCsrfState(token) + + req := httptest.NewRequest(http.MethodPost, "/callback", strings.NewReader("state="+hashed)) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.AddCookie(&http.Cookie{Name: "flyte_csrf_state", Value: token}) + require.NoError(t, req.ParseForm()) + require.NoError(t, VerifyCsrfCookie(context.Background(), req)) +} + +func TestVerifyCsrfCookie_Mismatch(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/callback", strings.NewReader("state=wrong")) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.AddCookie(&http.Cookie{Name: "flyte_csrf_state", Value: "abcdefghij"}) + require.NoError(t, req.ParseForm()) + assert.Error(t, VerifyCsrfCookie(context.Background(), req)) +} + +func TestVerifyCsrfCookie_EmptyState(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/callback", strings.NewReader("")) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + require.NoError(t, req.ParseForm()) + assert.Error(t, VerifyCsrfCookie(context.Background(), req)) +} + +func TestVerifyCsrfCookie_MissingCookie(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/callback", strings.NewReader("state=x")) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + require.NoError(t, req.ParseForm()) + assert.Error(t, VerifyCsrfCookie(context.Background(), req)) +} + +func TestNewRedirectCookie(t *testing.T) { + c := NewRedirectCookie(context.Background(), "/console/projects") + require.NotNil(t, c) + assert.Equal(t, "flyte_redirect_location", c.Name) + assert.Equal(t, "/console/projects", c.Value) + assert.True(t, c.HttpOnly) +} + +func TestNewRedirectCookie_Invalid(t *testing.T) { + assert.Nil(t, NewRedirectCookie(context.Background(), "")) +} + +func TestGetAuthFlowEndRedirect_QueryAllowed(t *testing.T) { + authorized := []config.URL{{URL: *mustURL(t, "https://flyte.mycorp.com")}} + req := httptest.NewRequest(http.MethodGet, "https://flyte.mycorp.com/callback?redirect_url=https://flyte.mycorp.com/console", nil) + + got := GetAuthFlowEndRedirect(context.Background(), "/default", authorized, req) + assert.Equal(t, "https://flyte.mycorp.com/console", got) +} + +func TestGetAuthFlowEndRedirect_QueryUnauthorizedFallsBack(t *testing.T) { + authorized := []config.URL{{URL: *mustURL(t, "https://flyte.mycorp.com")}} + req := httptest.NewRequest(http.MethodGet, "https://flyte.mycorp.com/callback?redirect_url=https://evil.example.com", nil) + + got := GetAuthFlowEndRedirect(context.Background(), "/default", authorized, req) + assert.Equal(t, "/default", got) +} + +func TestGetAuthFlowEndRedirect_CookieFallback(t *testing.T) { + authorized := []config.URL{{URL: *mustURL(t, "https://flyte.mycorp.com")}} + req := httptest.NewRequest(http.MethodGet, "https://flyte.mycorp.com/callback", nil) + req.AddCookie(&http.Cookie{Name: "flyte_redirect_location", Value: "https://flyte.mycorp.com/console"}) + + got := GetAuthFlowEndRedirect(context.Background(), "/default", authorized, req) + assert.Equal(t, "https://flyte.mycorp.com/console", got) +} + +func TestGetAuthFlowEndRedirect_NoCookieReturnsDefault(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "https://flyte.mycorp.com/callback", nil) + got := GetAuthFlowEndRedirect(context.Background(), "/default", nil, req) + assert.Equal(t, "/default", got) +} + +func mustURL(t *testing.T, s string) *url.URL { + t.Helper() + u, err := url.Parse(s) + require.NoError(t, err) + return u +} diff --git a/runs/service/auth/handler_utils.go b/runs/service/auth/handler_utils.go new file mode 100644 index 00000000000..2307932b20c --- /dev/null +++ b/runs/service/auth/handler_utils.go @@ -0,0 +1,201 @@ +package auth + +import ( + "context" + "net/http" + "net/url" + "strings" + + "google.golang.org/grpc/metadata" + + "github.com/flyteorg/flyte/v2/flytestdlib/logger" + "github.com/flyteorg/flyte/v2/runs/service/auth/config" +) + +const ( + metadataXForwardedHost = "x-forwarded-host" + metadataAuthority = ":authority" +) + +// URLFromRequest attempts to reconstruct the url from the request object. Or nil if not possible +func URLFromRequest(req *http.Request) *url.URL { + if req == nil { + return nil + } + + // from browser req.RequestURI is "/login" and u.scheme is "" + // from unit test req.RequestURI is "" and u is nil + // That means that this function, URLFromRequest(req) returns https://localhost:8088 even though there's no SSL, + // when the request is made from http://localhost:8088 in the web browser. + // Given how this function is used however, it's okay - we're only picking which option to use from the list of + // authorized URIs. + u, _ := url.ParseRequestURI(req.RequestURI) + if u != nil && u.IsAbs() { + return u + } + + if len(req.Host) == 0 { + return nil + } + + scheme := "https://" + if req.URL != nil && len(req.URL.Scheme) > 0 { + scheme = req.URL.Scheme + "://" + } + + u, _ = url.Parse(scheme + req.Host) + return u +} + +// URLFromContext attempts to retrieve the original url from context. gRPC gateway sets metadata in context that refers +// to the original host. Or nil if metadata isn't set. +func URLFromContext(ctx context.Context) *url.URL { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return nil + } + + forwardedHost := getMetadataValue(md, metadataXForwardedHost) + if len(forwardedHost) == 0 { + forwardedHost = getMetadataValue(md, metadataAuthority) + } + + if len(forwardedHost) == 0 { + return nil + } + + u, _ := url.Parse("https://" + forwardedHost) + return u +} + +// getMetadataValue retrieves the first value for a given key from gRPC metadata. +func getMetadataValue(md metadata.MD, key string) string { + vals := md.Get(key) + if len(vals) == 0 { + return "" + } + return vals[0] +} + +// FirstURL gets the first non-nil url from a list of given urls. +func FirstURL(urls ...*url.URL) *url.URL { + for _, u := range urls { + if u != nil { + return u + } + } + + return nil +} + +// wildcardMatch checks if hostname matches a wildcard pattern (only one level deep) +// Supports patterns like "*.union.ai" matching "tenant1.union.ai" +func wildcardMatch(hostname, pattern string) bool { + if strings.HasPrefix(pattern, "*.") { + urlParts := strings.SplitN(hostname, ".", 2) + if len(urlParts) < 2 { + return false + } + return urlParts[1] == pattern[2:] + } + return hostname == pattern +} + +// buildURL constructs a URL using the authorized template but with the matched hostname +func buildURL(authorizedURL *url.URL, matchedHostname string) *url.URL { + result := *authorizedURL // Copy the URL to avoid modifying the original + if authorizedURL.Port() != "" { + result.Host = matchedHostname + ":" + authorizedURL.Port() + } else { + result.Host = matchedHostname + } + return &result +} + +// GetPublicURL attempts to retrieve the public url of the service. If httpPublicUri is set in the config, it takes +// precedence. If the request is not nil and has a host set, it comes second and lastly it attempts to retrieve the url +// from context if set (e.g. by gRPC gateway). +func GetPublicURL(ctx context.Context, req *http.Request, cfg config.Config) *url.URL { + u := FirstURL(URLFromRequest(req), URLFromContext(ctx)) + var hostMatching *url.URL + var hostAndPortMatching *url.URL + var matchedHostname string + + for i, authorized := range cfg.AuthorizedURIs { + if u == nil { + return &authorized.URL + } + + if wildcardMatch(u.Hostname(), authorized.Hostname()) { + matchedHostname = u.Hostname() + hostMatching = &cfg.AuthorizedURIs[i].URL + if u.Port() == authorized.Port() { + hostAndPortMatching = &cfg.AuthorizedURIs[i].URL + } + + if u.Scheme == authorized.Scheme { + return buildURL(&cfg.AuthorizedURIs[i].URL, matchedHostname) + } + } + } + + if hostAndPortMatching != nil { + return buildURL(hostAndPortMatching, matchedHostname) + } + + if hostMatching != nil { + return buildURL(hostMatching, matchedHostname) + } + + if len(cfg.AuthorizedURIs) > 0 { + return &cfg.AuthorizedURIs[0].URL + } + + return u +} + +// GetIssuer returns the issuer from SelfAuthServer config, or falls back to public URL. +func GetIssuer(ctx context.Context, req *http.Request, cfg config.Config) string { + if configIssuer := cfg.AppAuth.SelfAuthServer.Issuer; len(configIssuer) > 0 { + return configIssuer + } + + return GetPublicURL(ctx, req, cfg).String() +} + +// isAuthorizedRedirectURL checks if a redirect URL matches an authorized URL pattern. +func isAuthorizedRedirectURL(u *url.URL, authorizedURL *url.URL) bool { + if u == nil || authorizedURL == nil { + return false + } + if u.Scheme != authorizedURL.Scheme || u.Port() != authorizedURL.Port() { + return false + } + return wildcardMatch(u.Hostname(), authorizedURL.Hostname()) +} + +// GetRedirectURLAllowed checks whether a redirect URL is in the list of authorized URIs. +func GetRedirectURLAllowed(ctx context.Context, urlRedirectParam string, authorizedURIs []config.URL) bool { + if len(urlRedirectParam) == 0 { + logger.Debugf(ctx, "not validating whether empty redirect url is authorized") + return true + } + redirectURL, err := url.Parse(urlRedirectParam) + if err != nil { + logger.Debugf(ctx, "failed to parse user-supplied redirect url: %s with err: %v", urlRedirectParam, err) + return false + } + if redirectURL.Host == "" { + logger.Debugf(ctx, "not validating whether relative redirect url is authorized") + return true + } + logger.Debugf(ctx, "validating whether redirect url: %s is authorized", redirectURL) + for i := range authorizedURIs { + if isAuthorizedRedirectURL(redirectURL, &authorizedURIs[i].URL) { + logger.Debugf(ctx, "authorizing redirect url: %s against authorized uri: %s", redirectURL.String(), authorizedURIs[i].String()) + return true + } + } + logger.Debugf(ctx, "not authorizing redirect url: %s", redirectURL.String()) + return false +} diff --git a/runs/service/auth/handler_utils_test.go b/runs/service/auth/handler_utils_test.go new file mode 100644 index 00000000000..670ab5acb94 --- /dev/null +++ b/runs/service/auth/handler_utils_test.go @@ -0,0 +1,175 @@ +package auth + +import ( + "context" + "net/http" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + + authConfig "github.com/flyteorg/flyte/v2/runs/service/auth/config" + "google.golang.org/grpc/metadata" +) + +func TestURLFromRequest(t *testing.T) { + t.Run("nil request", func(t *testing.T) { + assert.Nil(t, URLFromRequest(nil)) + }) + + t.Run("request with host", func(t *testing.T) { + req := &http.Request{ + Host: "example.com:8080", + URL: &url.URL{Scheme: "https"}, + } + u := URLFromRequest(req) + assert.NotNil(t, u) + assert.Equal(t, "example.com:8080", u.Host) + }) + + t.Run("request with empty host", func(t *testing.T) { + req := &http.Request{ + URL: &url.URL{}, + } + assert.Nil(t, URLFromRequest(req)) + }) + + t.Run("request with absolute request URI", func(t *testing.T) { + req := &http.Request{ + RequestURI: "https://absolute.example.com/path", + } + u := URLFromRequest(req) + assert.NotNil(t, u) + assert.Equal(t, "absolute.example.com", u.Hostname()) + }) +} + +func TestURLFromContext(t *testing.T) { + t.Run("no metadata", func(t *testing.T) { + assert.Nil(t, URLFromContext(context.Background())) + }) + + t.Run("x-forwarded-host", func(t *testing.T) { + md := metadata.Pairs("x-forwarded-host", "forwarded.example.com") + ctx := metadata.NewIncomingContext(context.Background(), md) + u := URLFromContext(ctx) + assert.NotNil(t, u) + assert.Equal(t, "forwarded.example.com", u.Hostname()) + assert.Equal(t, "https", u.Scheme) + }) + + t.Run("authority fallback", func(t *testing.T) { + md := metadata.Pairs(":authority", "authority.example.com") + ctx := metadata.NewIncomingContext(context.Background(), md) + u := URLFromContext(ctx) + assert.NotNil(t, u) + assert.Equal(t, "authority.example.com", u.Hostname()) + }) +} + +func TestFirstURL(t *testing.T) { + u1, _ := url.Parse("https://first.example.com") + u2, _ := url.Parse("https://second.example.com") + + assert.Nil(t, FirstURL()) + assert.Nil(t, FirstURL(nil, nil)) + assert.Equal(t, u1, FirstURL(u1, u2)) + assert.Equal(t, u2, FirstURL(nil, u2)) +} + +func TestWildcardMatch(t *testing.T) { + assert.True(t, wildcardMatch("example.com", "example.com")) + assert.True(t, wildcardMatch("tenant1.union.ai", "*.union.ai")) + assert.False(t, wildcardMatch("union.ai", "*.union.ai")) + assert.False(t, wildcardMatch("other.com", "example.com")) + assert.False(t, wildcardMatch("sub.tenant1.union.ai", "*.union.ai")) +} + +func TestBuildURL(t *testing.T) { + authorized, _ := url.Parse("https://example.com:8080/path") + result := buildURL(authorized, "tenant1.example.com") + assert.Equal(t, "tenant1.example.com:8080", result.Host) + assert.Equal(t, "/path", result.Path) + + noPort, _ := url.Parse("https://example.com/path") + result = buildURL(noPort, "tenant1.example.com") + assert.Equal(t, "tenant1.example.com", result.Host) +} + +func TestGetPublicURL(t *testing.T) { + t.Run("no request, returns first authorized URI", func(t *testing.T) { + cfg := authConfig.Config{ + AuthorizedURIs: []authConfig.URL{ + {URL: *mustParseURL("https://flyte.example.com")}, + }, + } + u := GetPublicURL(context.Background(), nil, cfg) + assert.Equal(t, "flyte.example.com", u.Hostname()) + }) + + t.Run("no request, no authorized URIs, returns nil", func(t *testing.T) { + cfg := authConfig.Config{} + u := GetPublicURL(context.Background(), nil, cfg) + assert.Nil(t, u) + }) + + t.Run("wildcard match with request", func(t *testing.T) { + cfg := authConfig.Config{ + AuthorizedURIs: []authConfig.URL{ + {URL: *mustParseURL("https://*.union.ai")}, + }, + } + req := &http.Request{ + Host: "tenant1.union.ai", + URL: &url.URL{Scheme: "https"}, + } + u := GetPublicURL(context.Background(), req, cfg) + assert.Equal(t, "tenant1.union.ai", u.Hostname()) + assert.Equal(t, "https", u.Scheme) + }) + + t.Run("exact match with matching scheme", func(t *testing.T) { + cfg := authConfig.Config{ + AuthorizedURIs: []authConfig.URL{ + {URL: *mustParseURL("https://flyte.example.com:8080")}, + {URL: *mustParseURL("http://flyte.example.com:8080")}, + }, + } + req := &http.Request{ + Host: "flyte.example.com:8080", + URL: &url.URL{Scheme: "http"}, + } + u := GetPublicURL(context.Background(), req, cfg) + assert.Equal(t, "http", u.Scheme) + }) +} + +func TestGetIssuer(t *testing.T) { + t.Run("custom issuer", func(t *testing.T) { + cfg := authConfig.Config{ + AppAuth: authConfig.OAuth2Options{ + SelfAuthServer: authConfig.AuthorizationServer{ + Issuer: "https://custom-issuer.example.com", + }, + }, + } + assert.Equal(t, "https://custom-issuer.example.com", GetIssuer(context.Background(), nil, cfg)) + }) + + t.Run("falls back to public URL", func(t *testing.T) { + cfg := authConfig.Config{ + AuthorizedURIs: []authConfig.URL{ + {URL: *mustParseURL("https://flyte.example.com")}, + }, + } + assert.Equal(t, "https://flyte.example.com", GetIssuer(context.Background(), nil, cfg)) + }) +} + +func mustParseURL(rawURL string) *url.URL { + u, err := url.Parse(rawURL) + if err != nil { + panic(err) + } + return u +} diff --git a/runs/service/auth/handlers.go b/runs/service/auth/handlers.go new file mode 100644 index 00000000000..da2fd1ca423 --- /dev/null +++ b/runs/service/auth/handlers.go @@ -0,0 +1,291 @@ +package auth + +import ( + "context" + "fmt" + "net/http" + "strings" + + "github.com/coreos/go-oidc/v3/oidc" + "golang.org/x/oauth2" + + authpb "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth" + "github.com/flyteorg/flyte/v2/flytestdlib/logger" + "github.com/flyteorg/flyte/v2/runs/service/auth/config" +) + +// PreRedirectHookError is returned by PreRedirectHookFunc to signal an error with an HTTP status code. +type PreRedirectHookError struct { + Message string + Code int +} + +func (e *PreRedirectHookError) Error() string { + return e.Message +} + +// PreRedirectHookFunc is called before the redirect at the end of a successful auth callback flow. +type PreRedirectHookFunc func(ctx context.Context, request *http.Request, w http.ResponseWriter) *PreRedirectHookError + +// LogoutHookFunc is called during logout to perform additional cleanup. +type LogoutHookFunc func(ctx context.Context, request *http.Request, w http.ResponseWriter) error + +// AuthHandlerConfig holds dependencies needed by the HTTP auth handlers. +type AuthHandlerConfig struct { + CookieManager CookieManager + OAuth2Config *oauth2.Config + OIDCProvider *oidc.Provider + ResourceServer OAuth2ResourceServer + AuthConfig config.Config + HTTPClient *http.Client + PreRedirectHook PreRedirectHookFunc + LogoutHook LogoutHookFunc +} + +// RegisterHandlers registers the standard OAuth2/OIDC HTTP handlers on the given mux. +func RegisterHandlers(ctx context.Context, mux *http.ServeMux, h *AuthHandlerConfig) { + mux.HandleFunc("/login", RefreshTokensIfNeededHandler(ctx, h, + GetLoginHandler(ctx, h))) + mux.HandleFunc("/callback", GetCallbackHandler(ctx, h)) + mux.HandleFunc(fmt.Sprintf("/%s", OIdCMetadataEndpoint), GetOIdCMetadataEndpointRedirectHandler(ctx, h)) + mux.HandleFunc("/logout", GetLogoutEndpointHandler(ctx, h)) +} + +// RefreshTokensIfNeededHandler wraps a handler to attempt token refresh before redirecting. +func RefreshTokensIfNeededHandler(ctx context.Context, h *AuthHandlerConfig, authHandler http.HandlerFunc) http.HandlerFunc { + return func(writer http.ResponseWriter, request *http.Request) { + newToken, userInfo, refreshed, err := RefreshTokensIfNeeded(ctx, h, request) + if err != nil { + logger.Infof(ctx, "Failed to refresh tokens. Restarting login flow. Error: %s", err) + authHandler(writer, request) + return + } + + if refreshed { + logger.Debugf(ctx, "Tokens are refreshed. Saving new tokens into cookies.") + if err = h.CookieManager.SetTokenCookies(ctx, writer, newToken); err != nil { + logger.Infof(ctx, "Failed to write tokens to response. Restarting login flow. Error: %s", err) + authHandler(writer, request) + return + } + + if err = h.CookieManager.SetUserInfoCookie(ctx, writer, userInfo); err != nil { + logger.Infof(ctx, "Failed to write user info to response. Restarting login flow. Error: %s", err) + authHandler(writer, request) + return + } + } + + redirectURL := GetAuthFlowEndRedirect(ctx, h.AuthConfig.UserAuth.RedirectURL.String(), h.AuthConfig.AuthorizedURIs, request) + http.Redirect(writer, request, redirectURL, http.StatusTemporaryRedirect) + } +} + +// RefreshTokensIfNeeded checks if tokens need refreshing and returns refreshed tokens if so. +func RefreshTokensIfNeeded(ctx context.Context, h *AuthHandlerConfig, request *http.Request) ( + token *oauth2.Token, userInfo *authpb.UserInfoResponse, refreshed bool, err error) { + + ctx = context.WithValue(ctx, oauth2.HTTPClient, h.HTTPClient) + + idToken, accessToken, refreshToken, err := h.CookieManager.RetrieveTokenValues(ctx, request) + if err != nil { + return nil, nil, false, fmt.Errorf("failed to retrieve tokens from request: %w", err) + } + + _, err = ParseIDTokenAndValidate(ctx, h.AuthConfig.UserAuth.OpenID.ClientID, idToken, h.OIDCProvider) + if err != nil { + if strings.Contains(err.Error(), "token is expired") && len(refreshToken) > 0 { + logger.Debugf(ctx, "Expired id token found, attempting to refresh") + newToken, refreshErr := GetRefreshedToken(ctx, h.OAuth2Config, accessToken, refreshToken) + if refreshErr != nil { + return nil, nil, false, fmt.Errorf("failed to refresh tokens: %w", refreshErr) + } + + userInfo, queryErr := QueryUserInfoUsingAccessToken(ctx, request, h, newToken.AccessToken) + if queryErr != nil { + return nil, nil, false, fmt.Errorf("failed to query user info: %w", queryErr) + } + + return newToken, userInfo, true, nil + } + return nil, nil, false, fmt.Errorf("failed to validate tokens: %w", err) + } + + return NewOAuthTokenFromRaw(accessToken, refreshToken, idToken), nil, false, nil +} + +// GetLoginHandler returns an HTTP handler that starts the OAuth2 login flow. +func GetLoginHandler(ctx context.Context, h *AuthHandlerConfig) http.HandlerFunc { + return func(writer http.ResponseWriter, request *http.Request) { + csrfCookie := NewCsrfCookie() + csrfToken := csrfCookie.Value + http.SetCookie(writer, &csrfCookie) + + state := HashCsrfState(csrfToken) + logger.Debugf(ctx, "Setting CSRF state cookie to %s and state to %s\n", csrfToken, state) + urlString := h.OAuth2Config.AuthCodeURL(state) + queryParams := request.URL.Query() + if !GetRedirectURLAllowed(ctx, queryParams.Get(RedirectURLParameter), h.AuthConfig.AuthorizedURIs) { + logger.Infof(ctx, "unauthorized redirect URI") + writer.WriteHeader(http.StatusForbidden) + return + } + if flowEndRedirectURL := queryParams.Get(RedirectURLParameter); flowEndRedirectURL != "" { + redirectCookie := NewRedirectCookie(ctx, flowEndRedirectURL) + if redirectCookie != nil { + http.SetCookie(writer, redirectCookie) + } + } + + http.Redirect(writer, request, urlString, http.StatusTemporaryRedirect) + } +} + +// GetCallbackHandler returns an HTTP handler that completes the OAuth2 authorization code flow. +func GetCallbackHandler(ctx context.Context, h *AuthHandlerConfig) http.HandlerFunc { + return func(writer http.ResponseWriter, request *http.Request) { + logger.Debugf(ctx, "Running callback handler... for RequestURI %v", request.RequestURI) + authorizationCode := request.FormValue(AuthorizationResponseCodeType) + + ctx = context.WithValue(ctx, oauth2.HTTPClient, h.HTTPClient) + + if err := VerifyCsrfCookie(ctx, request); err != nil { + logger.Errorf(ctx, "Invalid CSRF token cookie %s", err) + writer.WriteHeader(http.StatusUnauthorized) + return + } + + token, err := h.OAuth2Config.Exchange(ctx, authorizationCode) + if err != nil { + logger.Errorf(ctx, "Error when exchanging code %s", err) + writer.WriteHeader(http.StatusForbidden) + return + } + + if err = h.CookieManager.SetTokenCookies(ctx, writer, token); err != nil { + logger.Errorf(ctx, "Error setting encrypted JWT cookie %s", err) + writer.WriteHeader(http.StatusForbidden) + return + } + + userInfo, err := QueryUserInfoUsingAccessToken(ctx, request, h, token.AccessToken) + if err != nil { + logger.Errorf(ctx, "Failed to query user info. Error: %v", err) + writer.WriteHeader(http.StatusForbidden) + return + } + + if err = h.CookieManager.SetUserInfoCookie(ctx, writer, userInfo); err != nil { + logger.Errorf(ctx, "Error setting encrypted user info cookie. Error: %v", err) + writer.WriteHeader(http.StatusForbidden) + return + } + + if h.PreRedirectHook != nil { + if hookErr := h.PreRedirectHook(ctx, request, writer); hookErr != nil { + logger.Errorf(ctx, "failed the preRedirect hook due %v with status code %v", hookErr.Message, hookErr.Code) + if http.StatusText(hookErr.Code) != "" { + writer.WriteHeader(hookErr.Code) + } else { + writer.WriteHeader(http.StatusInternalServerError) + } + return + } + } + + redirectURL := GetAuthFlowEndRedirect(ctx, h.AuthConfig.UserAuth.RedirectURL.String(), h.AuthConfig.AuthorizedURIs, request) + http.Redirect(writer, request, redirectURL, http.StatusTemporaryRedirect) + } +} + +// GetOIdCMetadataEndpointRedirectHandler returns a handler that redirects to the OIDC metadata endpoint. +func GetOIdCMetadataEndpointRedirectHandler(_ context.Context, h *AuthHandlerConfig) http.HandlerFunc { + return func(writer http.ResponseWriter, request *http.Request) { + oidcMetadataURL := h.AuthConfig.UserAuth.OpenID.BaseURL.JoinPath("/").JoinPath(OIdCMetadataEndpoint) + http.Redirect(writer, request, oidcMetadataURL.String(), http.StatusSeeOther) + } +} + +// GetLogoutEndpointHandler returns a handler that clears auth cookies and optionally redirects. +func GetLogoutEndpointHandler(ctx context.Context, h *AuthHandlerConfig) http.HandlerFunc { + return func(writer http.ResponseWriter, request *http.Request) { + if h.LogoutHook != nil { + if err := h.LogoutHook(ctx, request, writer); err != nil { + logger.Errorf(ctx, "logout hook failed: %v", err) + writer.WriteHeader(http.StatusInternalServerError) + return + } + } + + logger.Debugf(ctx, "deleting auth cookies") + h.CookieManager.DeleteCookies(ctx, writer) + + queryParams := request.URL.Query() + if redirectURL := queryParams.Get(RedirectURLParameter); redirectURL != "" { + if !GetRedirectURLAllowed(ctx, redirectURL, h.AuthConfig.AuthorizedURIs) { + logger.Warnf(ctx, "Rejecting unauthorized redirect_url in logout: %s", redirectURL) + redirectURL = h.AuthConfig.UserAuth.RedirectURL.String() + } + http.Redirect(writer, request, redirectURL, http.StatusTemporaryRedirect) + } + } +} + +// QueryUserInfoUsingAccessToken fetches user info from the OIDC provider using an access token. +func QueryUserInfoUsingAccessToken(ctx context.Context, originalRequest *http.Request, h *AuthHandlerConfig, accessToken string) ( + *authpb.UserInfoResponse, error) { + + originalToken := oauth2.Token{ + AccessToken: accessToken, + } + + tokenSource := h.OAuth2Config.TokenSource(ctx, &originalToken) + + userInfo, err := h.OIDCProvider.UserInfo(ctx, tokenSource) + if err != nil { + logger.Errorf(ctx, "Error getting user info from IDP %s", err) + return &authpb.UserInfoResponse{}, fmt.Errorf("error getting user info from IDP") + } + + resp := &authpb.UserInfoResponse{} + if err = userInfo.Claims(resp); err != nil { + logger.Errorf(ctx, "Error getting user info from IDP %s", err) + return &authpb.UserInfoResponse{}, fmt.Errorf("error getting user info from IDP") + } + + return resp, nil +} + +// IdentityContextFromRequest extracts identity from an HTTP request (header or cookies). +func IdentityContextFromRequest(ctx context.Context, req *http.Request, h *AuthHandlerConfig) ( + *IdentityContext, error) { + + authHeader := DefaultAuthorizationHeader + if len(h.AuthConfig.HTTPAuthorizationHeader) > 0 { + authHeader = h.AuthConfig.HTTPAuthorizationHeader + } + + headerValue := req.Header.Get(authHeader) + if len(headerValue) == 0 { + headerValue = req.Header.Get(DefaultAuthorizationHeader) + } + + if len(headerValue) > 0 { + if strings.HasPrefix(headerValue, BearerScheme+" ") { + expectedAudience := GetPublicURL(ctx, req, h.AuthConfig).String() + return h.ResourceServer.ValidateAccessToken(ctx, expectedAudience, strings.TrimPrefix(headerValue, BearerScheme+" ")) + } + } + + idToken, _, _, err := h.CookieManager.RetrieveTokenValues(ctx, req) + if err != nil || len(idToken) == 0 { + return nil, fmt.Errorf("unauthenticated request. IDToken Len [%v], Error: %w", len(idToken), err) + } + + userInfo, err := h.CookieManager.RetrieveUserInfo(ctx, req) + if err != nil { + return nil, fmt.Errorf("unauthenticated request: %w", err) + } + + return IdentityContextFromIDToken(ctx, idToken, h.AuthConfig.UserAuth.OpenID.ClientID, h.OIDCProvider, userInfo) +} diff --git a/runs/service/auth/http_middleware.go b/runs/service/auth/http_middleware.go new file mode 100644 index 00000000000..03cba45866d --- /dev/null +++ b/runs/service/auth/http_middleware.go @@ -0,0 +1,97 @@ +package auth + +import ( + "net" + "net/http" + "strings" + + "github.com/flyteorg/flyte/v2/flytestdlib/logger" +) + +// publicPathPrefixes lists request paths that never require authentication. +// These cover: +// - health probes +// - the browser OAuth2/OIDC flow +// - metadata discovery (AuthMetadataService is called pre-auth) +// - intra-cluster services that task pods call via the ClusterIP service +// without credentials (ActionsService, InternalRunService). These are +// deliberately excluded from the external ALB ingress in +// charts/flyte-binary/templates/_helpers.tpl so they cannot be reached +// from the public internet; only in-cluster pods can hit them. +var publicPathPrefixes = []string{ + "/healthz", + "/readyz", + "/healthcheck", + "/login", + "/callback", + "/logout", + "/.well-known/", + "/flyteidl2.auth.AuthMetadataService/", + "/flyteidl2.actions.ActionsService/", + "/flyteidl2.workflow.InternalRunService/", +} + +// IsPublicPath reports whether an HTTP request path bypasses authentication. +func IsPublicPath(path string) bool { + for _, p := range publicPathPrefixes { + if strings.HasPrefix(path, p) { + return true + } + } + return false +} + +// isLoopbackRequest returns true when the request originated from the local +// loopback interface. The unified Flyte binary makes intra-process connect-rpc +// calls to its own HTTP mux via http://localhost: (e.g. RunService -> +// ActionsService). Those calls have no Authorization header and must not be +// forced through the external auth gate, or every run creation will fail with +// 401. External traffic (ALB, port-forward from outside the pod) never has a +// loopback RemoteAddr. +func isLoopbackRequest(req *http.Request) bool { + host, _, err := net.SplitHostPort(req.RemoteAddr) + if err != nil { + host = req.RemoteAddr + } + ip := net.ParseIP(host) + if ip == nil { + return false + } + return ip.IsLoopback() +} + +// GetAuthenticationHTTPInterceptor returns middleware that validates a bearer +// token or auth cookies on incoming HTTP requests and injects the resulting +// IdentityContext into the request context. Public paths (see IsPublicPath) +// pass through without validation. When DisableForHTTP is set on the config, +// every request passes through unchanged. +func GetAuthenticationHTTPInterceptor(h *AuthHandlerConfig) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if IsPublicPath(req.URL.Path) { + next.ServeHTTP(w, req) + return + } + + if isLoopbackRequest(req) { + next.ServeHTTP(w, req) + return + } + + if h.AuthConfig.DisableForHTTP { + next.ServeHTTP(w, req) + return + } + + ctx := req.Context() + identity, err := IdentityContextFromRequest(ctx, req, h) + if err != nil { + logger.Infof(ctx, "unauthenticated request to %s: %v", req.URL.Path, err) + w.WriteHeader(http.StatusUnauthorized) + return + } + + next.ServeHTTP(w, req.WithContext(identity.WithContext(ctx))) + }) + } +} diff --git a/runs/service/auth/http_middleware_test.go b/runs/service/auth/http_middleware_test.go new file mode 100644 index 00000000000..6aa887c825c --- /dev/null +++ b/runs/service/auth/http_middleware_test.go @@ -0,0 +1,176 @@ +package auth + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/flyteorg/flyte/v2/runs/service/auth/config" +) + +func TestIsPublicPath(t *testing.T) { + cases := map[string]bool{ + "/healthz": true, + "/readyz": true, + "/healthcheck": true, + "/login": true, + "/login?redirect_url=/console": true, + "/callback": true, + "/logout": true, + "/.well-known/openid-configuration": true, + "/.well-known/oauth-authorization-server": true, + "/flyteidl2.auth.AuthMetadataService/GetOAuth2Metadata": true, + "/flyteidl2.actions.ActionsService/CreateAction": true, + "/flyteidl2.workflow.InternalRunService/UpdateRun": true, + "/flyteidl2.workflow.RunService/CreateRun": false, + "/flyteidl2.auth.IdentityService/UserInfo": false, + "/": false, + "/api/v1/projects": false, + } + for path, want := range cases { + got := IsPublicPath(path) + assert.Equalf(t, want, got, "IsPublicPath(%q)", path) + } +} + +// servedBy wraps a boolean flag so tests can check that the next handler ran. +type servedBy struct{ called bool } + +func (s *servedBy) handler() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + s.called = true + w.WriteHeader(http.StatusOK) + }) +} + +func TestMiddleware_PublicPathBypassesAuth(t *testing.T) { + // Even with a zero AuthHandlerConfig (no resource server, no cookie manager), + // public paths must not touch any auth plumbing. + h := &AuthHandlerConfig{AuthConfig: config.Config{}} + mw := GetAuthenticationHTTPInterceptor(h) + + var sb servedBy + req := httptest.NewRequest(http.MethodGet, "/healthz", nil) + w := httptest.NewRecorder() + mw(sb.handler()).ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.True(t, sb.called, "public path should have reached the next handler") +} + +func TestMiddleware_DisabledForHTTPBypassesAuth(t *testing.T) { + h := &AuthHandlerConfig{AuthConfig: config.Config{DisableForHTTP: true}} + mw := GetAuthenticationHTTPInterceptor(h) + + var sb servedBy + req := httptest.NewRequest(http.MethodGet, "/flyteidl2.workflow.RunService/CreateRun", nil) + w := httptest.NewRecorder() + mw(sb.handler()).ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.True(t, sb.called) +} + +func TestMiddleware_LoopbackIPv4BypassesAuth(t *testing.T) { + // Intra-process connect-rpc calls (e.g. runs -> RunService on localhost) + // must pass through the middleware without an Authorization header. + h := &AuthHandlerConfig{AuthConfig: config.Config{}, CookieManager: CookieManager{}} + mw := GetAuthenticationHTTPInterceptor(h) + + var sb servedBy + req := httptest.NewRequest(http.MethodPost, "/flyteidl2.workflow.RunService/CreateRun", nil) + req.RemoteAddr = "127.0.0.1:54321" + w := httptest.NewRecorder() + mw(sb.handler()).ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.True(t, sb.called, "loopback call must reach next handler") +} + +func TestMiddleware_LoopbackIPv6BypassesAuth(t *testing.T) { + h := &AuthHandlerConfig{AuthConfig: config.Config{}, CookieManager: CookieManager{}} + mw := GetAuthenticationHTTPInterceptor(h) + + var sb servedBy + req := httptest.NewRequest(http.MethodPost, "/flyteidl2.workflow.RunService/CreateRun", nil) + req.RemoteAddr = "[::1]:54321" + w := httptest.NewRecorder() + mw(sb.handler()).ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.True(t, sb.called) +} + +func TestMiddleware_ActionsServicePublicFromPodIP(t *testing.T) { + // Task pods call ActionsService from their pod IP (non-loopback) over + // the ClusterIP service. The path must be allowlisted so the SDK can + // launch actions without carrying credentials. + h := &AuthHandlerConfig{AuthConfig: config.Config{}, CookieManager: CookieManager{}} + mw := GetAuthenticationHTTPInterceptor(h) + + var sb servedBy + req := httptest.NewRequest(http.MethodPost, "/flyteidl2.actions.ActionsService/CreateAction", nil) + req.RemoteAddr = "10.1.193.72:33100" + w := httptest.NewRecorder() + mw(sb.handler()).ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.True(t, sb.called, "ActionsService must be reachable from task pods without auth") +} + +func TestMiddleware_NonLoopbackStillBlocks(t *testing.T) { + // A caller from a real pod IP must still hit the 401 path when no auth + // is present — the loopback bypass is strictly for in-process calls. + // Use RunService (user-facing) rather than ActionsService (in-cluster-only + // public path) so we're exercising the actual gate. + h := &AuthHandlerConfig{AuthConfig: config.Config{}, CookieManager: CookieManager{}} + mw := GetAuthenticationHTTPInterceptor(h) + + var sb servedBy + req := httptest.NewRequest(http.MethodPost, "/flyteidl2.workflow.RunService/CreateRun", nil) + req.RemoteAddr = "10.1.42.7:48221" + w := httptest.NewRecorder() + mw(sb.handler()).ServeHTTP(w, req) + + assert.Equal(t, http.StatusUnauthorized, w.Code) + assert.False(t, sb.called) +} + +func TestIsLoopbackRequest(t *testing.T) { + cases := map[string]bool{ + "127.0.0.1:1234": true, + "127.1.2.3:80": true, + "[::1]:8080": true, + "10.0.0.1:8080": false, + "192.168.1.1:80": false, + "": false, + "bogus": false, + } + for addr, want := range cases { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.RemoteAddr = addr + assert.Equalf(t, want, isLoopbackRequest(req), "isLoopbackRequest(%q)", addr) + } +} + +func TestMiddleware_NoAuthReturns401(t *testing.T) { + // AuthHandlerConfig missing a CookieManager will cause IdentityContextFromRequest + // to fail when no bearer header is present. The middleware must convert that to 401. + h := &AuthHandlerConfig{ + AuthConfig: config.Config{}, + CookieManager: CookieManager{}, + ResourceServer: nil, + } + mw := GetAuthenticationHTTPInterceptor(h) + + var sb servedBy + req := httptest.NewRequest(http.MethodGet, "/flyteidl2.workflow.RunService/CreateRun", nil) + w := httptest.NewRecorder() + mw(sb.handler()).ServeHTTP(w, req) + + assert.Equal(t, http.StatusUnauthorized, w.Code) + assert.False(t, sb.called, "protected path must not reach next handler without auth") +} + diff --git a/runs/service/auth/identity_context.go b/runs/service/auth/identity_context.go new file mode 100644 index 00000000000..7f1159f41b8 --- /dev/null +++ b/runs/service/auth/identity_context.go @@ -0,0 +1,80 @@ +package auth + +import ( + "context" + "time" + + authpb "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth" +) + +type contextKey string + +const contextKeyIdentityContext contextKey = "identity_context" + +// ScopeAll is the default scope granted to user-only access tokens. +const ScopeAll = "all" + +// IdentityContext encloses the authenticated identity of the user/app. Both gRPC and HTTP +// servers have interceptors to set the IdentityContext on the context.AuthenticationContext. +type IdentityContext struct { + audience string + userID string + appID string + authenticatedAt time.Time + userInfo *authpb.UserInfoResponse + scopes []string + claims map[string]interface{} +} + +// NewIdentityContext creates a new IdentityContext. +func NewIdentityContext(audience, userID, appID string, authenticatedAt time.Time, scopes []string, userInfo *authpb.UserInfoResponse, claims map[string]interface{}) *IdentityContext { + if userInfo == nil { + userInfo = &authpb.UserInfoResponse{} + } + + if len(userInfo.Subject) == 0 { + userInfo.Subject = userID + } + + return &IdentityContext{ + audience: audience, + userID: userID, + appID: appID, + authenticatedAt: authenticatedAt, + userInfo: userInfo, + scopes: scopes, + claims: claims, + } +} + +func (c *IdentityContext) Audience() string { return c.audience } +func (c *IdentityContext) UserID() string { return c.userID } +func (c *IdentityContext) AppID() string { return c.appID } +func (c *IdentityContext) AuthenticatedAt() time.Time { return c.authenticatedAt } +func (c *IdentityContext) Scopes() []string { return c.scopes } +func (c *IdentityContext) Claims() map[string]interface{} { return c.claims } + +func (c *IdentityContext) UserInfo() *authpb.UserInfoResponse { + if c.userInfo == nil { + return &authpb.UserInfoResponse{} + } + return c.userInfo +} + +func (c *IdentityContext) IsEmpty() bool { + return c == nil || (c.audience == "" && c.userID == "" && c.appID == "") +} + +// WithContext stores the IdentityContext in the given context. +func (c *IdentityContext) WithContext(ctx context.Context) context.Context { + return context.WithValue(ctx, contextKeyIdentityContext, c) +} + +// IdentityContextFromContext retrieves the authenticated identity from context.AuthenticationContext. +func IdentityContextFromContext(ctx context.Context) *IdentityContext { + existing := ctx.Value(contextKeyIdentityContext) + if existing != nil { + return existing.(*IdentityContext) + } + return nil +} diff --git a/runs/service/auth/identity_context_test.go b/runs/service/auth/identity_context_test.go new file mode 100644 index 00000000000..1dc1950026d --- /dev/null +++ b/runs/service/auth/identity_context_test.go @@ -0,0 +1,65 @@ +package auth + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + authpb "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth" +) + +func TestNewIdentityContext(t *testing.T) { + userInfo := &authpb.UserInfoResponse{ + Name: "Test User", + Email: "test@example.com", + } + now := time.Now() + + ic := NewIdentityContext("aud", "user1", "app1", now, []string{"read"}, userInfo, nil) + + assert.Equal(t, "aud", ic.Audience()) + assert.Equal(t, "user1", ic.UserID()) + assert.Equal(t, "app1", ic.AppID()) + assert.Equal(t, now, ic.AuthenticatedAt()) + assert.Equal(t, []string{"read"}, ic.Scopes()) + assert.Equal(t, "Test User", ic.UserInfo().Name) + // Subject should be filled from userID when empty + assert.Equal(t, "user1", ic.UserInfo().Subject) +} + +func TestNewIdentityContext_PreservesSubject(t *testing.T) { + userInfo := &authpb.UserInfoResponse{ + Subject: "existing-sub", + } + + ic := NewIdentityContext("aud", "user1", "", time.Time{}, nil, userInfo, nil) + assert.Equal(t, "existing-sub", ic.UserInfo().Subject) +} + +func TestNewIdentityContext_NilUserInfo(t *testing.T) { + ic := NewIdentityContext("aud", "user1", "", time.Time{}, nil, nil, nil) + require.NotNil(t, ic.UserInfo()) + assert.Equal(t, "user1", ic.UserInfo().Subject) +} + +func TestIdentityContext_IsEmpty(t *testing.T) { + assert.True(t, (*IdentityContext)(nil).IsEmpty()) + assert.True(t, (&IdentityContext{}).IsEmpty()) + assert.False(t, (&IdentityContext{userID: "u"}).IsEmpty()) +} + +func TestIdentityContext_WithContext(t *testing.T) { + ic := NewIdentityContext("aud", "user1", "app1", time.Now(), nil, nil, nil) + ctx := ic.WithContext(context.Background()) + + retrieved := IdentityContextFromContext(ctx) + require.NotNil(t, retrieved) + assert.Equal(t, "user1", retrieved.UserID()) +} + +func TestIdentityContextFromContext_Empty(t *testing.T) { + assert.Nil(t, IdentityContextFromContext(context.Background())) +} diff --git a/runs/service/auth/interceptor.go b/runs/service/auth/interceptor.go new file mode 100644 index 00000000000..55f7332315e --- /dev/null +++ b/runs/service/auth/interceptor.go @@ -0,0 +1,98 @@ +package auth + +import ( + "context" + "fmt" + + "github.com/coreos/go-oidc/v3/oidc" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + + "github.com/flyteorg/flyte/v2/flytestdlib/logger" + "github.com/flyteorg/flyte/v2/runs/service/auth/config" +) + +// BlanketAuthorization is a gRPC unary interceptor that checks the authenticated identity has the "all" scope. +func BlanketAuthorization(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) ( + resp interface{}, err error) { + + identityContext := IdentityContextFromContext(ctx) + if identityContext == nil { + return handler(ctx, req) + } + + for _, scope := range identityContext.Scopes() { + if scope == ScopeAll { + return handler(ctx, req) + } + } + + logger.Debugf(ctx, "authenticated user doesn't have required scope") + return nil, status.Errorf(codes.Unauthenticated, "authenticated user doesn't have required scope") +} + +// GetAuthenticationCustomMetadataInterceptor produces a gRPC interceptor that translates a custom authorization +// header name to the standard "authorization" header for downstream interceptors. +func GetAuthenticationCustomMetadataInterceptor(cfg config.Config) grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + if cfg.GrpcAuthorizationHeader != DefaultAuthorizationHeader { + md, ok := metadata.FromIncomingContext(ctx) + if ok { + existingHeader := md.Get(cfg.GrpcAuthorizationHeader) + if len(existingHeader) > 0 { + logger.Debugf(ctx, "Found existing metadata header %s", cfg.GrpcAuthorizationHeader) + newAuthorizationMetadata := metadata.Pairs(DefaultAuthorizationHeader, existingHeader[0]) + joinedMetadata := metadata.Join(md, newAuthorizationMetadata) + newCtx := metadata.NewIncomingContext(ctx, joinedMetadata) + return handler(newCtx, req) + } + } + } + return handler(ctx, req) + } +} + +// GetAuthenticationInterceptor returns a function that validates incoming gRPC requests. +// It attempts to extract and validate an access token or ID token from the request metadata. +func GetAuthenticationInterceptor(cfg config.Config, resourceServer OAuth2ResourceServer, oidcProvider *oidc.Provider) func(context.Context) (context.Context, error) { + return func(ctx context.Context) (context.Context, error) { + logger.Debugf(ctx, "Running authentication gRPC interceptor") + + expectedAudience := GetPublicURL(ctx, nil, cfg).String() + + identityContext, accessTokenErr := GRPCGetIdentityFromAccessToken(ctx, expectedAudience, resourceServer) + if accessTokenErr == nil { + return identityContext.WithContext(ctx), nil + } + + logger.Infof(ctx, "Failed to parse Access Token from context. Will attempt to find IDToken. Error: %v", accessTokenErr) + + identityContext, idTokenErr := GRPCGetIdentityFromIDToken(ctx, cfg.UserAuth.OpenID.ClientID, oidcProvider) + if idTokenErr == nil { + return identityContext.WithContext(ctx), nil + } + logger.Debugf(ctx, "Failed to parse ID Token from context. Error: %v", idTokenErr) + + if !cfg.DisableForGrpc { + err := fmt.Errorf("[id token err: %w] | [access token err: %w]", idTokenErr, accessTokenErr) + return ctx, status.Errorf(codes.Unauthenticated, "token parse error %s", err) + } + + return ctx, nil + } +} + +// AuthenticationLoggingInterceptor logs information about the authenticated user for each gRPC request. +func AuthenticationLoggingInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + identityContext := IdentityContextFromContext(ctx) + if identityContext != nil { + var emailPlaceholder string + if len(identityContext.UserInfo().GetEmail()) > 0 { + emailPlaceholder = fmt.Sprintf(" (%s) ", identityContext.UserInfo().GetEmail()) + } + logger.Debugf(ctx, "gRPC server info in logging interceptor [%s]%smethod [%s]\n", identityContext.UserID(), emailPlaceholder, info.FullMethod) + } + return handler(ctx, req) +} diff --git a/runs/service/auth/interfaces.go b/runs/service/auth/interfaces.go new file mode 100644 index 00000000000..617ff457137 --- /dev/null +++ b/runs/service/auth/interfaces.go @@ -0,0 +1,8 @@ +package auth + +import "context" + +// OAuth2ResourceServer represents a resource server that can validate access tokens. +type OAuth2ResourceServer interface { + ValidateAccessToken(ctx context.Context, expectedAudience, tokenStr string) (*IdentityContext, error) +} diff --git a/runs/service/auth/token.go b/runs/service/auth/token.go new file mode 100644 index 00000000000..466f2f572ff --- /dev/null +++ b/runs/service/auth/token.go @@ -0,0 +1,176 @@ +package auth + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/coreos/go-oidc/v3/oidc" + "golang.org/x/oauth2" + "google.golang.org/grpc/metadata" + + authpb "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth" + "github.com/flyteorg/flyte/v2/flytestdlib/logger" +) + +// GetRefreshedToken refreshes a JWT using the provided OAuth2 config and refresh token. +func GetRefreshedToken(ctx context.Context, oauth *oauth2.Config, accessToken, refreshToken string) (*oauth2.Token, error) { + logger.Debugf(ctx, "Attempting to refresh token") + originalToken := oauth2.Token{ + AccessToken: accessToken, + RefreshToken: refreshToken, + Expiry: time.Now().Add(-1 * time.Minute), // force expired by setting to the past + } + + tokenSource := oauth.TokenSource(ctx, &originalToken) + newToken, err := tokenSource.Token() + if err != nil { + logger.Errorf(ctx, "Error refreshing token %s", err) + return nil, fmt.Errorf("error refreshing token: %w", err) + } + + return newToken, nil +} + +// ParseIDTokenAndValidate parses and validates an ID token using the OIDC provider. +func ParseIDTokenAndValidate(ctx context.Context, clientID, rawIDToken string, provider *oidc.Provider) (*oidc.IDToken, error) { + cfg := &oidc.Config{ + ClientID: clientID, + } + + if len(clientID) == 0 { + cfg.SkipClientIDCheck = true + cfg.SkipIssuerCheck = true + cfg.SkipExpiryCheck = true + } + + verifier := provider.Verifier(cfg) + + idToken, err := verifier.Verify(ctx, rawIDToken) + if err != nil { + logger.Debugf(ctx, "JWT parsing with claims failed %s", err) + if strings.Contains(err.Error(), "token is expired") { + return idToken, fmt.Errorf("token is expired: %w", err) + } + return idToken, fmt.Errorf("jwt parse with claims failed: %w", err) + } + return idToken, nil +} + +// GRPCGetIdentityFromAccessToken attempts to extract a bearer token from gRPC metadata +// and validate it using the provided resource server. +func GRPCGetIdentityFromAccessToken(ctx context.Context, expectedAudience string, resourceServer OAuth2ResourceServer) ( + *IdentityContext, error) { + + tokenStr, err := bearerTokenFromMD(ctx) + if err != nil { + return nil, fmt.Errorf("could not retrieve bearer token from metadata: %w", err) + } + + return resourceServer.ValidateAccessToken(ctx, expectedAudience, tokenStr) +} + +// GRPCGetIdentityFromIDToken attempts to extract an ID token from gRPC metadata and validate it. +func GRPCGetIdentityFromIDToken(ctx context.Context, clientID string, provider *oidc.Provider) ( + *IdentityContext, error) { + + tokenStr, err := idTokenFromMD(ctx) + if err != nil { + return nil, fmt.Errorf("could not retrieve id token from metadata: %w", err) + } + + return IdentityContextFromIDToken(ctx, tokenStr, clientID, provider, nil) +} + +// IdentityContextFromIDToken creates an IdentityContext from a validated ID token. +func IdentityContextFromIDToken(ctx context.Context, tokenStr, clientID string, provider *oidc.Provider, + userInfo *authpb.UserInfoResponse) (*IdentityContext, error) { + + idToken, err := ParseIDTokenAndValidate(ctx, clientID, tokenStr, provider) + if err != nil { + return nil, err + } + var claims map[string]interface{} + if err := idToken.Claims(&claims); err != nil { + logger.Infof(ctx, "Failed to unmarshal claims from id token, err: %v", err) + } + + return NewIdentityContext(idToken.Audience[0], idToken.Subject, "", idToken.IssuedAt, + []string{ScopeAll}, userInfo, claims), nil +} + +func NewOAuthTokenFromRaw(accessToken, refreshToken, idToken string) *oauth2.Token { + return (&oauth2.Token{ + AccessToken: accessToken, + RefreshToken: refreshToken, + }).WithExtra(map[string]interface{}{ + idTokenExtra: idToken, + }) +} + +func ExtractTokensFromOauthToken(token *oauth2.Token) (idToken, accessToken, refreshToken string, err error) { + if token == nil { + return "", "", "", fmt.Errorf("attempting to set cookies with nil token") + } + + idTokenRaw, converted := token.Extra(idTokenExtra).(string) + if !converted { + return "", "", "", fmt.Errorf("response does not contain an id_token") + } + + return idTokenRaw, token.AccessToken, token.RefreshToken, nil +} + +// bearerTokenFromMD extracts a Bearer token from gRPC incoming metadata. +func bearerTokenFromMD(ctx context.Context) (string, error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return "", fmt.Errorf("no metadata in context") + } + + vals := md.Get(DefaultAuthorizationHeader) + if len(vals) == 0 { + return "", fmt.Errorf("no authorization header in metadata") + } + + header := vals[0] + prefix := BearerScheme + " " + if !strings.HasPrefix(header, prefix) { + return "", fmt.Errorf("authorization header does not start with %q", BearerScheme) + } + + token := strings.TrimPrefix(header, prefix) + if token == "" { + return "", fmt.Errorf("bearer token is blank") + } + + return token, nil +} + +// idTokenFromMD extracts an IDToken from gRPC incoming metadata. +func idTokenFromMD(ctx context.Context) (string, error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return "", fmt.Errorf("no metadata in context") + } + + vals := md.Get(DefaultAuthorizationHeader) + if len(vals) == 0 { + return "", fmt.Errorf("no authorization header in metadata") + } + + header := vals[0] + prefix := IDTokenScheme + " " + if !strings.HasPrefix(header, prefix) { + return "", fmt.Errorf("authorization header does not start with %q", IDTokenScheme) + } + + token := strings.TrimPrefix(header, prefix) + if token == "" { + return "", fmt.Errorf("id token is blank") + } + + return token, nil +} + diff --git a/runs/service/auth/token_test.go b/runs/service/auth/token_test.go new file mode 100644 index 00000000000..5c6a8685979 --- /dev/null +++ b/runs/service/auth/token_test.go @@ -0,0 +1,94 @@ +package auth + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" + "google.golang.org/grpc/metadata" +) + +func TestNewOAuthTokenFromRaw(t *testing.T) { + tok := NewOAuthTokenFromRaw("access", "refresh", "id-token") + require.NotNil(t, tok) + assert.Equal(t, "access", tok.AccessToken) + assert.Equal(t, "refresh", tok.RefreshToken) + assert.Equal(t, "id-token", tok.Extra(idTokenExtra)) +} + +func TestExtractTokensFromOauthToken(t *testing.T) { + src := NewOAuthTokenFromRaw("a", "r", "i") + id, access, refresh, err := ExtractTokensFromOauthToken(src) + require.NoError(t, err) + assert.Equal(t, "i", id) + assert.Equal(t, "a", access) + assert.Equal(t, "r", refresh) +} + +func TestExtractTokensFromOauthToken_Nil(t *testing.T) { + _, _, _, err := ExtractTokensFromOauthToken(nil) + assert.Error(t, err) +} + +func TestExtractTokensFromOauthToken_MissingIDToken(t *testing.T) { + // bare oauth2.Token without id_token extra should fail + tok := &oauth2.Token{AccessToken: "a", RefreshToken: "r"} + _, _, _, err := ExtractTokensFromOauthToken(tok) + assert.Error(t, err) +} + +func ctxWithMD(pairs ...string) context.Context { + md := metadata.Pairs(pairs...) + return metadata.NewIncomingContext(context.Background(), md) +} + +func TestBearerTokenFromMD(t *testing.T) { + ctx := ctxWithMD(DefaultAuthorizationHeader, "Bearer my-token") + tok, err := bearerTokenFromMD(ctx) + require.NoError(t, err) + assert.Equal(t, "my-token", tok) +} + +func TestBearerTokenFromMD_NoMetadata(t *testing.T) { + _, err := bearerTokenFromMD(context.Background()) + assert.Error(t, err) +} + +func TestBearerTokenFromMD_MissingHeader(t *testing.T) { + _, err := bearerTokenFromMD(ctxWithMD("other", "v")) + assert.Error(t, err) +} + +func TestBearerTokenFromMD_WrongScheme(t *testing.T) { + _, err := bearerTokenFromMD(ctxWithMD(DefaultAuthorizationHeader, "IDToken abc")) + assert.Error(t, err) +} + +func TestBearerTokenFromMD_Blank(t *testing.T) { + _, err := bearerTokenFromMD(ctxWithMD(DefaultAuthorizationHeader, "Bearer ")) + assert.Error(t, err) +} + +func TestIDTokenFromMD(t *testing.T) { + ctx := ctxWithMD(DefaultAuthorizationHeader, "IDToken my-id-token") + tok, err := idTokenFromMD(ctx) + require.NoError(t, err) + assert.Equal(t, "my-id-token", tok) +} + +func TestIDTokenFromMD_WrongScheme(t *testing.T) { + _, err := idTokenFromMD(ctxWithMD(DefaultAuthorizationHeader, "Bearer abc")) + assert.Error(t, err) +} + +func TestIDTokenFromMD_Blank(t *testing.T) { + _, err := idTokenFromMD(ctxWithMD(DefaultAuthorizationHeader, "IDToken ")) + assert.Error(t, err) +} + +func TestIDTokenFromMD_NoMetadata(t *testing.T) { + _, err := idTokenFromMD(context.Background()) + assert.Error(t, err) +} diff --git a/runs/service/auth/user_info_provider.go b/runs/service/auth/user_info_provider.go new file mode 100644 index 00000000000..eda30e6889f --- /dev/null +++ b/runs/service/auth/user_info_provider.go @@ -0,0 +1,28 @@ +package auth + +import ( + "context" + + "connectrpc.com/connect" + + authpb "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth" + "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth/authconnect" +) + +// UserInfoProvider serves user info claims about the currently logged in user. +// See the OpenID Connect spec at https://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse +type UserInfoProvider struct { + authconnect.UnimplementedIdentityServiceHandler +} + +func NewUserInfoProvider() *UserInfoProvider { + return &UserInfoProvider{} +} + +func (s *UserInfoProvider) UserInfo(ctx context.Context, _ *connect.Request[authpb.UserInfoRequest]) (*connect.Response[authpb.UserInfoResponse], error) { + identityContext := IdentityContextFromContext(ctx) + if identityContext != nil { + return connect.NewResponse(identityContext.UserInfo()), nil + } + return connect.NewResponse(&authpb.UserInfoResponse{}), nil +} diff --git a/runs/service/identity_service.go b/runs/service/identity_service.go deleted file mode 100644 index 7484d6c71fb..00000000000 --- a/runs/service/identity_service.go +++ /dev/null @@ -1,29 +0,0 @@ -package service - -import ( - "context" - - "connectrpc.com/connect" - - "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth" - "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth/authconnect" -) - -// IdentityService implements the IdentityServiceHandler interface. -type IdentityService struct{} - -// NewIdentityService creates a new IdentityService instance. -func NewIdentityService() *IdentityService { - return &IdentityService{} -} - -var _ authconnect.IdentityServiceHandler = (*IdentityService)(nil) - -// UserInfo returns information about the currently logged in user. -// TODO: Wire with real auth to populate user info from the authenticated context. -func (s *IdentityService) UserInfo( - ctx context.Context, - req *connect.Request[auth.UserInfoRequest], -) (*connect.Response[auth.UserInfoResponse], error) { - return connect.NewResponse(&auth.UserInfoResponse{}), nil -} diff --git a/runs/service/identity_service_test.go b/runs/service/identity_service_test.go deleted file mode 100644 index a1839954398..00000000000 --- a/runs/service/identity_service_test.go +++ /dev/null @@ -1,21 +0,0 @@ -package service - -import ( - "context" - "testing" - - "connectrpc.com/connect" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth" -) - -func TestIdentityService_UserInfo(t *testing.T) { - svc := NewIdentityService() - - resp, err := svc.UserInfo(context.Background(), connect.NewRequest(&auth.UserInfoRequest{})) - require.NoError(t, err) - assert.NotNil(t, resp) - assert.NotNil(t, resp.Msg) -} diff --git a/runs/setup.go b/runs/setup.go index 85e90f17237..15f6129e9b0 100644 --- a/runs/setup.go +++ b/runs/setup.go @@ -5,12 +5,16 @@ import ( "errors" "fmt" "net/http" + "os" + "path/filepath" + "strings" "time" "github.com/flyteorg/flyte/v2/flytestdlib/app" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/actions/actionsconnect" flyteappconnect "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/app/appconnect" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/auth/authconnect" + projectpb "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/project" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/project/projectconnect" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/task/taskconnect" @@ -24,10 +28,17 @@ import ( "github.com/flyteorg/flyte/v2/runs/repository/models" "github.com/flyteorg/flyte/v2/runs/scheduler" "github.com/flyteorg/flyte/v2/runs/service" + authservice "github.com/flyteorg/flyte/v2/runs/service/auth" + "github.com/flyteorg/flyte/v2/runs/service/auth/authzserver" + authConfig "github.com/flyteorg/flyte/v2/runs/service/auth/config" "github.com/flyteorg/flyte/v2/flytestdlib/logger" ) +// authSecretsDir is the directory where cookie hash/block keys and other auth +// secrets are mounted (matches the flyte-binary chart volume mount path). +const authSecretsDir = "/etc/secrets" + // Setup registers Run and Task service handlers on the SetupContext mux. // Requires sc.DB and sc.DataStore to be set. When sc.K8sConfig is provided, // RunLogsService is also mounted to enable pod log streaming. @@ -83,15 +94,18 @@ func Setup(ctx context.Context, sc *app.SetupContext) error { sc.Mux.Handle(translatorPath, translatorHandler) logger.Infof(ctx, "Mounted TranslatorService at %s", translatorPath) - identitySvc := service.NewIdentityService() - identityPath, identityHandler := authconnect.NewIdentityServiceHandler(identitySvc) - sc.Mux.Handle(identityPath, identityHandler) - logger.Infof(ctx, "Mounted IdentityService at %s", identityPath) - - authMetadataSvc := service.NewAuthMetadataService(sc.BaseURL) - authMetadataPath, authMetadataHandler := authconnect.NewAuthMetadataServiceHandler(authMetadataSvc) - sc.Mux.Handle(authMetadataPath, authMetadataHandler) - logger.Infof(ctx, "Mounted AuthMetadataService at %s", authMetadataPath) + if cfg.Security.UseAuth { + if err := setupAuth(ctx, sc); err != nil { + return fmt.Errorf("runs: failed to set up auth: %w", err) + } + } else { + // When auth is disabled, still mount a stub AuthMetadataService so + // clients performing metadata discovery get a coherent response. + authMetadataSvc := service.NewAuthMetadataService(sc.BaseURL) + authMetadataPath, authMetadataHandler := authconnect.NewAuthMetadataServiceHandler(authMetadataSvc) + sc.Mux.Handle(authMetadataPath, authMetadataHandler) + logger.Infof(ctx, "Mounted stub AuthMetadataService at %s", authMetadataPath) + } appSvc := service.NewAppService() appPath, appHandler := flyteappconnect.NewAppServiceHandler(appSvc) @@ -150,6 +164,93 @@ func Setup(ctx context.Context, sc *app.SetupContext) error { return nil } +// setupAuth wires up the external-mode OAuth2 resource server, OIDC browser +// handlers, AuthMetadataService / IdentityService, and a bearer-token +// validating HTTP middleware on the shared mux. It requires that the auth +// config section is populated and that cookie hash/block keys are present as +// files under authSecretsDir. +func setupAuth(ctx context.Context, sc *app.SetupContext) error { + authCfg := authConfig.GetConfig() + + // Mount the real AuthMetadataService backed by the configured issuer. + authMetadataSvc := authzserver.NewAuthMetadataService(*authCfg) + authPath, authHandler := authconnect.NewAuthMetadataServiceHandler(authMetadataSvc) + sc.Mux.Handle(authPath, authHandler) + logger.Infof(ctx, "Mounted AuthMetadataService at %s", authPath) + + identitySvc := authservice.NewUserInfoProvider() + identityPath, identityHandler := authconnect.NewIdentityServiceHandler(identitySvc) + sc.Mux.Handle(identityPath, identityHandler) + logger.Infof(ctx, "Mounted IdentityService at %s", identityPath) + + hashKey, err := readSecretFile(authConfig.SecretNameCookieHashKey) + if err != nil { + return err + } + blockKey, err := readSecretFile(authConfig.SecretNameCookieBlockKey) + if err != nil { + return err + } + + // Load the OIDC client secret used during the OAuth2 code exchange. The + // filename is configurable so that a deployment can swap the secret name + // without redeploying the binary. + oidcClientSecretName := authCfg.UserAuth.OpenID.ClientSecretName + if oidcClientSecretName == "" { + oidcClientSecretName = authConfig.SecretNameOIdCClientSecret + } + oidcClientSecret, err := readSecretFile(oidcClientSecretName) + if err != nil { + return err + } + + // Validate tokens issued by the configured external authorization server. + // If BaseURL is empty, the resource server falls back to the first authorizedUri. + var fallbackURL authConfig.URL + if len(authCfg.AuthorizedURIs) > 0 { + fallbackURL = authCfg.AuthorizedURIs[0] + } + resourceServer, err := authzserver.NewOAuth2ResourceServer(ctx, authCfg.AppAuth.ExternalAuthServer, fallbackURL) + if err != nil { + return fmt.Errorf("failed to create OAuth2 resource server: %w", err) + } + + authCtx, err := authservice.NewAuthContext(ctx, *authCfg, resourceServer, hashKey, blockKey, oidcClientSecret) + if err != nil { + return fmt.Errorf("failed to create auth context: %w", err) + } + + // Register /login, /callback, /logout, /.well-known/openid-configuration. + authservice.RegisterHandlers(ctx, sc.Mux, authCtx.HandlerConfig()) + logger.Infof(ctx, "Registered OIDC browser handlers (/login, /callback, /logout)") + + // Chain the bearer/cookie auth middleware with any existing middleware + // (e.g. CORS). Ordering: request -> CORS -> auth -> mux. + prev := sc.Middleware + authMw := authservice.GetAuthenticationHTTPInterceptor(authCtx.HandlerConfig()) + sc.Middleware = func(next http.Handler) http.Handler { + wrapped := authMw(next) + if prev != nil { + wrapped = prev(wrapped) + } + return wrapped + } + logger.Infof(ctx, "Auth middleware installed; audience=%s", authCfg.AppAuth.ExternalAuthServer.BaseURL.String()) + + return nil +} + +// readSecretFile reads a base64-encoded key file from authSecretsDir and +// returns the trimmed string contents. +func readSecretFile(name string) (string, error) { + path := filepath.Join(authSecretsDir, name) + b, err := os.ReadFile(path) + if err != nil { + return "", fmt.Errorf("failed to read auth secret %s: %w", path, err) + } + return strings.TrimSpace(string(b)), nil +} + func seedProjects(ctx context.Context, projectRepo interfaces.ProjectRepo, projects []string) error { for _, projectID := range projects { if projectID == "" {