Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 28 additions & 13 deletions admin/server/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package auth

import (
"context"
"strings"

"github.com/coreos/go-oidc/v3/oidc"
"github.com/rilldata/rill/admin"
Expand All @@ -28,21 +29,34 @@ type AuthenticatorOptions struct {
// It provides endpoints for login/logout, creates users, issues cookie-based auth tokens, and provides middleware for authenticating requests.
// The implementation was derived from: https://auth0.com/docs/quickstart/webapp/golang/01-login.
type Authenticator struct {
logger *zap.Logger
admin *admin.Service
cookies *cookies.Store
opts *AuthenticatorOptions
oidc *oidc.Provider
oauth2 oauth2.Config
logger *zap.Logger
admin *admin.Service
cookies *cookies.Store
opts *AuthenticatorOptions
oidc *oidc.Provider
oauth2 oauth2.Config
endSessionEndpoint string
}

// NewAuthenticator creates an Authenticator.
func NewAuthenticator(logger *zap.Logger, adm *admin.Service, cookieStore *cookies.Store, opts *AuthenticatorOptions) (*Authenticator, error) {
oidcProvider, err := oidc.NewProvider(context.Background(), "https://"+opts.AuthDomain+"/")
// AuthDomain with "://" is a full issuer URL (Keycloak, Dex, etc.);
// without it, assume Auth0-style domain and append trailing slash.
issuerURL := opts.AuthDomain
if !strings.Contains(issuerURL, "://") {
issuerURL = "https://" + issuerURL + "/"
}

oidcProvider, err := oidc.NewProvider(context.Background(), issuerURL)
if err != nil {
return nil, err
}

var claims struct {
EndSessionEndpoint string `json:"end_session_endpoint"`
}
_ = oidcProvider.Claims(&claims)

oauth2Config := oauth2.Config{
ClientID: opts.AuthClientID,
ClientSecret: opts.AuthClientSecret,
Expand All @@ -52,12 +66,13 @@ func NewAuthenticator(logger *zap.Logger, adm *admin.Service, cookieStore *cooki
}

a := &Authenticator{
logger: logger,
admin: adm,
cookies: cookieStore,
opts: opts,
oidc: oidcProvider,
oauth2: oauth2Config,
logger: logger,
admin: adm,
cookies: cookieStore,
opts: opts,
oidc: oidcProvider,
oauth2: oauth2Config,
endSessionEndpoint: claims.EndSessionEndpoint,
}

return a, nil
Expand Down
24 changes: 15 additions & 9 deletions admin/server/auth/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,7 @@ func (a *Authenticator) authStart(w http.ResponseWriter, r *http.Request, signup
// Redirect to auth provider (canonical domain flow)
redirectURL := a.oauth2.AuthCodeURL(state)
if signup {
// Set custom parameters for signup using AuthCodeOption
customOption := oauth2.SetAuthURLParam("screen_hint", "signup")
redirectURL = a.oauth2.AuthCodeURL(state, customOption)
redirectURL = a.oauth2.AuthCodeURL(state, oauth2.SetAuthURLParam("prompt", "create"))
}

http.Redirect(w, r, redirectURL, http.StatusTemporaryRedirect)
Expand Down Expand Up @@ -611,16 +609,24 @@ func (a *Authenticator) authLogoutProvider(w http.ResponseWriter, r *http.Reques
}
}

// Build and redirect to the auth provider logout URL.
logoutURL, err := url.Parse("https://" + a.opts.AuthDomain + "/v2/logout")
// Build the provider logout URL.
// Standard OIDC providers expose end_session_endpoint; Auth0 uses /v2/logout with "returnTo".
logoutEndpoint := a.endSessionEndpoint
redirectParam := "post_logout_redirect_uri"
if logoutEndpoint == "" {
logoutEndpoint = "https://" + a.opts.AuthDomain + "/v2/logout"
redirectParam = "returnTo"
}

logoutURL, err := url.Parse(logoutEndpoint)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
parameters := url.Values{}
parameters.Add("returnTo", a.admin.URLs.AuthLogoutCallback())
parameters.Add("client_id", a.opts.AuthClientID)
logoutURL.RawQuery = parameters.Encode()
params := url.Values{}
params.Set("client_id", a.opts.AuthClientID)
params.Set(redirectParam, a.admin.URLs.AuthLogoutCallback())
logoutURL.RawQuery = params.Encode()
http.Redirect(w, r, logoutURL.String(), http.StatusTemporaryRedirect)
}

Expand Down