Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions rest/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,67 @@ func (c *Context) authMethodsByID(w http.ResponseWriter, r *http.Request) {
}
}

func (c *Context) authMethodsByIDRedirect(w http.ResponseWriter, r *http.Request) {
switch r.PathValue("method") {
case "saml":
id := r.PathValue("id")
samlProviderId := -1
for k := range *c.SAML.Providers {
if (*c.SAML.Providers)[k].ID == id {
samlProviderId = k
}
}
if samlProviderId == -1 {
c.returnError(w, fmt.Errorf("cannot find saml provider"), http.StatusBadRequest)
return
}
redirectURI, err := c.SAML.Client.GetAuthURL((*c.SAML.Providers)[samlProviderId])
if err != nil {
c.returnError(w, fmt.Errorf("cannot get auth url"), http.StatusBadRequest)
return
}

sendCorsHeaders(w, "", c.Hostname, c.Protocol)
http.Redirect(w, r, redirectURI, http.StatusFound)

return
default:
id := r.PathValue("id")
for _, oidcProvider := range c.OIDCProviders {
if id == oidcProvider.ID {
callback := fmt.Sprintf("%s://%s%s", c.Protocol, c.Hostname, oidcProvider.RedirectURI)
discovery, err := c.OIDCStore.GetDiscoveryURI(oidcProvider.DiscoveryURI)
if err != nil {
c.returnError(w, fmt.Errorf("getDiscoveryURI error: %s", err), http.StatusBadRequest)
return
}
redirectURI, state, err := oidc.GetRedirectURI(discovery, oidcProvider.ClientID, oidcProvider.Scope, callback, c.EnableOIDCTokenRenewal)
if err != nil {
c.returnError(w, fmt.Errorf("GetRedirectURI error: %s", err), http.StatusBadRequest)
return
}

newOAuthEntry := oidc.OAuthData{
ID: uuid.NewString(),
OIDCProviderID: oidcProvider.ID,
CreatedAt: time.Now(),
}
err = c.OIDCStore.SaveOAuth2Data(newOAuthEntry, state)
if err != nil {
c.returnError(w, fmt.Errorf("unable to save state to oidc store: %s", err), http.StatusBadRequest)
return
}

sendCorsHeaders(w, "", c.Hostname, c.Protocol)
http.Redirect(w, r, redirectURI, http.StatusFound)

return
}
}
c.returnError(w, fmt.Errorf("element not found"), http.StatusBadRequest)
}
}

func (c *Context) oidcRenewTokensHandler(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodPost:
Expand Down
287 changes: 287 additions & 0 deletions rest/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -944,3 +944,290 @@ func TestOIDCFlow(t *testing.T) {
t.Fatalf("no token received: %+v", loginResponse)
}
}

func TestOIDCRedirect(t *testing.T) {
testUrl := "127.0.0.1:12346"
l, err := net.Listen("tcp", testUrl)
if err != nil {
t.Fatal(err)
}

authURL := "http://" + testUrl + "/auth"

// create a new OIDC connection
oidcProvider := oidc.OIDCProvider{
Name: "test-oidc",
ClientID: "1-2-3-4",
ClientSecret: "9-9-9-9",
Scope: "openid",
DiscoveryURI: "http://" + testUrl + "/discovery.json",
}
jwtPrivateKey, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
t.Fatalf("can't generate jwt key: %s", err)
}

// first create a new user
c, err := newContext(&memorystorage.MockMemoryStorage{}, SERVER_TYPE_VPN)
if err != nil {
t.Fatalf("Cannot create context")
}
c.Hostname = "example.inv"
c.Protocol = "http"
logging.Loglevel = 17

ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
code := "thisisthecode"

switch r.Method {
case http.MethodGet:
parsedURI, _ := url.Parse(r.RequestURI)
switch parsedURI.Path {
case "/discovery.json":
discovery := oidc.Discovery{
Issuer: "test-issuer",
AuthorizationEndpoint: authURL,
TokenEndpoint: "http://" + testUrl + "/token",
JwksURI: "http://" + testUrl + "/jwks.json",
}
out, err := json.Marshal(discovery)
if err != nil {
t.Fatalf("json marshal error: %s", err)
}
w.Write(out)
return
case "/auth":
if oidcProvider.ClientID != r.URL.Query().Get("client_id") {
w.Write([]byte("client id mismatch"))
w.WriteHeader(http.StatusBadRequest)
return
}
if oidcProvider.Scope != r.URL.Query().Get("scope") {
w.Write([]byte("scope mismatch"))
w.WriteHeader(http.StatusBadRequest)
return
}
w.Write([]byte(code))
case "/jwks.json":
publicKey := jwtPrivateKey.PublicKey

jwks := oidc.Jwks{
Keys: []oidc.JwksKey{
{
Kid: "kid-id-1234",
Alg: "RS256",
Kty: "RSA",
Use: "sig",
N: base64.RawURLEncoding.EncodeToString(publicKey.N.Bytes()),
E: "AQAB",
},
},
}
out, err := json.Marshal(jwks)
if err != nil {
w.Write([]byte("jwks marshal error"))
w.WriteHeader(http.StatusBadRequest)
return
}
w.Write(out)
default:
w.WriteHeader(http.StatusNotFound)
}
default:
w.WriteHeader(http.StatusBadRequest)
}
}))

ts.Listener.Close()
ts.Listener = l
ts.Start()
defer ts.Close()
defer l.Close()

payload, err := json.Marshal(oidcProvider)
if err != nil {
t.Fatal(err)
}

// create new oidc provider
req := httptest.NewRequest("POST", "http://example.inv/api/oidc", bytes.NewBuffer(payload))
w := httptest.NewRecorder()
c.oidcProviderHandler(w, req)

resp := w.Result()

if resp.StatusCode != 200 {
t.Fatalf("status code is not 200: %d", resp.StatusCode)
}

defer resp.Body.Close()

err = json.NewDecoder(resp.Body).Decode(&oidcProvider)
if err != nil {
t.Fatalf("Cannot decode response from create user: %s", err)
}

if oidcProvider.ID == "" {
t.Fatalf("Was expecting oidc provider to have an ID")
}

// get redirect URL
req = httptest.NewRequest("GET", "http://example.inv/api/authmethods/oidc/"+oidcProvider.ID+"/redirect", nil)
req.SetPathValue("id", oidcProvider.ID)
req.SetPathValue("method", "oidc")
w = httptest.NewRecorder()
c.authMethodsByIDRedirect(w, req)

resp = w.Result()
defer resp.Body.Close()

if resp.StatusCode != http.StatusFound {
t.Fatalf("status code is not 302: %d", resp.StatusCode)
}

redirectURL := resp.Header.Get("Location")
if redirectURL == "" {
t.Fatalf("no redirect URL in header")
}
if !strings.HasPrefix(redirectURL, authURL) {
t.Fatalf("expected authURL as prefix of redirect url. Redirect URL: %s", redirectURL)
}

if !strings.Contains(redirectURL, "state=") {
t.Fatalf("state parameter not found in redirect URL: %s", redirectURL)
}
parsedRedirectURL, err := url.Parse(redirectURL)
if err != nil {
t.Fatalf("could not parse redirect URL: %s", err)
}
state := parsedRedirectURL.Query().Get("state")
if state == "" {
t.Fatalf("state parameter is empty in redirect URL: %s", redirectURL)
}
}

func TestSAMLRedirect(t *testing.T) {
// generate new keypair
kp := saml.NewKeyPair(&memorystorage.MockMemoryStorage{}, "www.idp.inv")
_, cert, err := kp.GetKeyPair()
if err != nil {
t.Fatalf("Can't generate new keypair: %s", err)
}
certBase64 := base64.StdEncoding.EncodeToString(cert)

testUrl := "127.0.0.1:12347"
l, err := net.Listen("tcp", testUrl)
if err != nil {
t.Fatal(err)
}

singleSignOnURL := "http://" + testUrl + "/auth"

ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
requestURIParsed, _ := url.Parse(r.RequestURI)
if requestURIParsed.Path == "/auth" {
compressedSAMLReq, err := base64.StdEncoding.DecodeString(r.URL.Query().Get("SAMLRequest"))
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(fmt.Sprintf("saml base64 decode error: %s", err)))
return
}
samlRequest := new(bytes.Buffer)
decompressor := flate.NewReader(bytes.NewReader(compressedSAMLReq))
io.Copy(samlRequest, decompressor)
decompressor.Close()

var authnReq saml.AuthnRequest
err = xml.Unmarshal(samlRequest.Bytes(), &authnReq)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(fmt.Sprintf("saml authn request decode error: %s", err)))
return
}
w.Write([]byte("OK"))
return
}
if r.RequestURI == "/metadata" {
out, _ := xml.Marshal(getSAMLCertWithCustomCert(singleSignOnURL, certBase64))
w.Write(out)
return
}
w.WriteHeader(http.StatusBadRequest)
default:
w.WriteHeader(http.StatusBadRequest)
}
}))

ts.Listener.Close()
ts.Listener = l
ts.Start()
defer ts.Close()
defer l.Close()

// first create a new user
c, err := newContext(&memorystorage.MockMemoryStorage{}, SERVER_TYPE_VPN)
if err != nil {
t.Fatalf("Cannot create context")
}

// create a new SAML connection
samlProvider := saml.Provider{
Name: "testProvider",
MetadataURL: fmt.Sprintf("%s/metadata", ts.URL),
}

payload, err := json.Marshal(samlProvider)
if err != nil {
t.Fatal(err)
}

req := httptest.NewRequest("POST", "http://example.inv/api/saml-setup", bytes.NewBuffer(payload))
w := httptest.NewRecorder()
c.samlSetupHandler(w, req)

resp := w.Result()

if resp.StatusCode != 200 {
t.Fatalf("status code is not 200: %d", resp.StatusCode)
}

defer resp.Body.Close()

err = json.NewDecoder(resp.Body).Decode(&samlProvider)
if err != nil {
t.Fatalf("Cannot decode response from create user: %s", err)
}

if samlProvider.ID == "" {
t.Fatalf("Was expecting saml provider to have an ID")
}

// get redirect URL
req = httptest.NewRequest("GET", "http://example.inv/api/authmethods/saml/"+samlProvider.ID+"/redirect", nil)
req.SetPathValue("id", samlProvider.ID)
req.SetPathValue("method", "saml")
w = httptest.NewRecorder()
c.authMethodsByIDRedirect(w, req)

resp = w.Result()
defer resp.Body.Close()

if resp.StatusCode != http.StatusFound {
t.Fatalf("status code is not 302: %d", resp.StatusCode)
}

redirectURL := resp.Header.Get("Location")
if redirectURL == "" {
t.Fatalf("no redirect URL in header")
}
if !strings.HasPrefix(redirectURL, singleSignOnURL) {
t.Fatalf("expected authURL as prefix of redirect url. Redirect URL: %s", redirectURL)
}

if !strings.Contains(redirectURL, "SAMLRequest=") {
t.Fatalf("SAMLRequest parameter not found in redirect URL: %s", redirectURL)
}

}
1 change: 1 addition & 0 deletions rest/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ func (c *Context) getRouter(assets fs.FS, indexHtml []byte) *http.ServeMux {
mux.Handle("/api/context", http.HandlerFunc(c.contextHandler))
mux.Handle("/api/auth", http.HandlerFunc(c.authHandler))
mux.Handle("/api/authmethods", http.HandlerFunc(c.authMethods))
mux.Handle("/api/authmethods/{method}/{id}/redirect", http.HandlerFunc(c.authMethodsByIDRedirect))
mux.Handle("/api/authmethods/{method}/{id}", http.HandlerFunc(c.authMethodsByID))
mux.Handle("/api/authmethods/{id}", http.HandlerFunc(c.authMethodsByID))
mux.Handle("/api/upgrade", http.HandlerFunc(c.upgrade))
Expand Down