diff --git a/rest/auth.go b/rest/auth.go index e9c1ac9..c5efad6 100644 --- a/rest/auth.go +++ b/rest/auth.go @@ -409,6 +409,67 @@ func (c *Context) authMethodsByID(w http.ResponseWriter, r *http.Request) { } } +func (c *Context) authMethodsByIDRedirect(w http.ResponseWriter, r *http.Request) { + switch r.PathValue("method") { + case "saml": + id := r.PathValue("id") + samlProviderId := -1 + for k := range *c.SAML.Providers { + if (*c.SAML.Providers)[k].ID == id { + samlProviderId = k + } + } + if samlProviderId == -1 { + c.returnError(w, fmt.Errorf("cannot find saml provider"), http.StatusBadRequest) + return + } + redirectURI, err := c.SAML.Client.GetAuthURL((*c.SAML.Providers)[samlProviderId]) + if err != nil { + c.returnError(w, fmt.Errorf("cannot get auth url"), http.StatusBadRequest) + return + } + + sendCorsHeaders(w, "", c.Hostname, c.Protocol) + http.Redirect(w, r, redirectURI, http.StatusFound) + + return + default: + id := r.PathValue("id") + for _, oidcProvider := range c.OIDCProviders { + if id == oidcProvider.ID { + callback := fmt.Sprintf("%s://%s%s", c.Protocol, c.Hostname, oidcProvider.RedirectURI) + discovery, err := c.OIDCStore.GetDiscoveryURI(oidcProvider.DiscoveryURI) + if err != nil { + c.returnError(w, fmt.Errorf("getDiscoveryURI error: %s", err), http.StatusBadRequest) + return + } + redirectURI, state, err := oidc.GetRedirectURI(discovery, oidcProvider.ClientID, oidcProvider.Scope, callback, c.EnableOIDCTokenRenewal) + if err != nil { + c.returnError(w, fmt.Errorf("GetRedirectURI error: %s", err), http.StatusBadRequest) + return + } + + newOAuthEntry := oidc.OAuthData{ + ID: uuid.NewString(), + OIDCProviderID: oidcProvider.ID, + CreatedAt: time.Now(), + } + err = c.OIDCStore.SaveOAuth2Data(newOAuthEntry, state) + if err != nil { + c.returnError(w, fmt.Errorf("unable to save state to oidc store: %s", err), http.StatusBadRequest) + return + } + + sendCorsHeaders(w, "", c.Hostname, c.Protocol) + http.Redirect(w, r, redirectURI, http.StatusFound) + + return + } + } + c.returnError(w, fmt.Errorf("element not found"), http.StatusBadRequest) + } +} + func (c *Context) oidcRenewTokensHandler(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodPost: diff --git a/rest/auth_test.go b/rest/auth_test.go index af31db9..d803567 100644 --- a/rest/auth_test.go +++ b/rest/auth_test.go @@ -944,3 +944,290 @@ func TestOIDCFlow(t *testing.T) { t.Fatalf("no token received: %+v", loginResponse) } } + +func TestOIDCRedirect(t *testing.T) { + testUrl := "127.0.0.1:12346" + l, err := net.Listen("tcp", testUrl) + if err != nil { + t.Fatal(err) + } + + authURL := "http://" + testUrl + "/auth" + + // create a new OIDC connection + oidcProvider := oidc.OIDCProvider{ + Name: "test-oidc", + ClientID: "1-2-3-4", + ClientSecret: "9-9-9-9", + Scope: "openid", + DiscoveryURI: "http://" + testUrl + "/discovery.json", + } + jwtPrivateKey, err := rsa.GenerateKey(rand.Reader, 4096) + if err != nil { + t.Fatalf("can't generate jwt key: %s", err) + } + + // first create a new user + c, err := newContext(&memorystorage.MockMemoryStorage{}, SERVER_TYPE_VPN) + if err != nil { + t.Fatalf("Cannot create context") + } + c.Hostname = "example.inv" + c.Protocol = "http" + logging.Loglevel = 17 + + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + code := "thisisthecode" + + switch r.Method { + case http.MethodGet: + parsedURI, _ := url.Parse(r.RequestURI) + switch parsedURI.Path { + case "/discovery.json": + discovery := oidc.Discovery{ + Issuer: "test-issuer", + AuthorizationEndpoint: authURL, + TokenEndpoint: "http://" + testUrl + "/token", + JwksURI: "http://" + testUrl + "/jwks.json", + } + out, err := json.Marshal(discovery) + if err != nil { + t.Fatalf("json marshal error: %s", err) + } + w.Write(out) + return + case "/auth": + if oidcProvider.ClientID != r.URL.Query().Get("client_id") { + w.Write([]byte("client id mismatch")) + w.WriteHeader(http.StatusBadRequest) + return + } + if oidcProvider.Scope != r.URL.Query().Get("scope") { + w.Write([]byte("scope mismatch")) + w.WriteHeader(http.StatusBadRequest) + return + } + w.Write([]byte(code)) + case "/jwks.json": + publicKey := jwtPrivateKey.PublicKey + + jwks := oidc.Jwks{ + Keys: []oidc.JwksKey{ + { + Kid: "kid-id-1234", + Alg: "RS256", + Kty: "RSA", + Use: "sig", + N: base64.RawURLEncoding.EncodeToString(publicKey.N.Bytes()), + E: "AQAB", + }, + }, + } + out, err := json.Marshal(jwks) + if err != nil { + w.Write([]byte("jwks marshal error")) + w.WriteHeader(http.StatusBadRequest) + return + } + w.Write(out) + default: + w.WriteHeader(http.StatusNotFound) + } + default: + w.WriteHeader(http.StatusBadRequest) + } + })) + + ts.Listener.Close() + ts.Listener = l + ts.Start() + defer ts.Close() + defer l.Close() + + payload, err := json.Marshal(oidcProvider) + if err != nil { + t.Fatal(err) + } + + // create new oidc provider + req := httptest.NewRequest("POST", "http://example.inv/api/oidc", bytes.NewBuffer(payload)) + w := httptest.NewRecorder() + c.oidcProviderHandler(w, req) + + resp := w.Result() + + if resp.StatusCode != 200 { + t.Fatalf("status code is not 200: %d", resp.StatusCode) + } + + defer resp.Body.Close() + + err = json.NewDecoder(resp.Body).Decode(&oidcProvider) + if err != nil { + t.Fatalf("Cannot decode response from create user: %s", err) + } + + if oidcProvider.ID == "" { + t.Fatalf("Was expecting oidc provider to have an ID") + } + + // get redirect URL + req = httptest.NewRequest("GET", "http://example.inv/api/authmethods/oidc/"+oidcProvider.ID+"/redirect", nil) + req.SetPathValue("id", oidcProvider.ID) + req.SetPathValue("method", "oidc") + w = httptest.NewRecorder() + c.authMethodsByIDRedirect(w, req) + + resp = w.Result() + defer resp.Body.Close() + + if resp.StatusCode != http.StatusFound { + t.Fatalf("status code is not 302: %d", resp.StatusCode) + } + + redirectURL := resp.Header.Get("Location") + if redirectURL == "" { + t.Fatalf("no redirect URL in header") + } + if !strings.HasPrefix(redirectURL, authURL) { + t.Fatalf("expected authURL as prefix of redirect url. Redirect URL: %s", redirectURL) + } + + if !strings.Contains(redirectURL, "state=") { + t.Fatalf("state parameter not found in redirect URL: %s", redirectURL) + } + parsedRedirectURL, err := url.Parse(redirectURL) + if err != nil { + t.Fatalf("could not parse redirect URL: %s", err) + } + state := parsedRedirectURL.Query().Get("state") + if state == "" { + t.Fatalf("state parameter is empty in redirect URL: %s", redirectURL) + } +} + +func TestSAMLRedirect(t *testing.T) { + // generate new keypair + kp := saml.NewKeyPair(&memorystorage.MockMemoryStorage{}, "www.idp.inv") + _, cert, err := kp.GetKeyPair() + if err != nil { + t.Fatalf("Can't generate new keypair: %s", err) + } + certBase64 := base64.StdEncoding.EncodeToString(cert) + + testUrl := "127.0.0.1:12347" + l, err := net.Listen("tcp", testUrl) + if err != nil { + t.Fatal(err) + } + + singleSignOnURL := "http://" + testUrl + "/auth" + + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + requestURIParsed, _ := url.Parse(r.RequestURI) + if requestURIParsed.Path == "/auth" { + compressedSAMLReq, err := base64.StdEncoding.DecodeString(r.URL.Query().Get("SAMLRequest")) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(fmt.Sprintf("saml base64 decode error: %s", err))) + return + } + samlRequest := new(bytes.Buffer) + decompressor := flate.NewReader(bytes.NewReader(compressedSAMLReq)) + io.Copy(samlRequest, decompressor) + decompressor.Close() + + var authnReq saml.AuthnRequest + err = xml.Unmarshal(samlRequest.Bytes(), &authnReq) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(fmt.Sprintf("saml authn request decode error: %s", err))) + return + } + w.Write([]byte("OK")) + return + } + if r.RequestURI == "/metadata" { + out, _ := xml.Marshal(getSAMLCertWithCustomCert(singleSignOnURL, certBase64)) + w.Write(out) + return + } + w.WriteHeader(http.StatusBadRequest) + default: + w.WriteHeader(http.StatusBadRequest) + } + })) + + ts.Listener.Close() + ts.Listener = l + ts.Start() + defer ts.Close() + defer l.Close() + + // first create a new user + c, err := newContext(&memorystorage.MockMemoryStorage{}, SERVER_TYPE_VPN) + if err != nil { + t.Fatalf("Cannot create context") + } + + // create a new SAML connection + samlProvider := saml.Provider{ + Name: "testProvider", + MetadataURL: fmt.Sprintf("%s/metadata", ts.URL), + } + + payload, err := json.Marshal(samlProvider) + if err != nil { + t.Fatal(err) + } + + req := httptest.NewRequest("POST", "http://example.inv/api/saml-setup", bytes.NewBuffer(payload)) + w := httptest.NewRecorder() + c.samlSetupHandler(w, req) + + resp := w.Result() + + if resp.StatusCode != 200 { + t.Fatalf("status code is not 200: %d", resp.StatusCode) + } + + defer resp.Body.Close() + + err = json.NewDecoder(resp.Body).Decode(&samlProvider) + if err != nil { + t.Fatalf("Cannot decode response from create user: %s", err) + } + + if samlProvider.ID == "" { + t.Fatalf("Was expecting saml provider to have an ID") + } + + // get redirect URL + req = httptest.NewRequest("GET", "http://example.inv/api/authmethods/saml/"+samlProvider.ID+"/redirect", nil) + req.SetPathValue("id", samlProvider.ID) + req.SetPathValue("method", "saml") + w = httptest.NewRecorder() + c.authMethodsByIDRedirect(w, req) + + resp = w.Result() + defer resp.Body.Close() + + if resp.StatusCode != http.StatusFound { + t.Fatalf("status code is not 302: %d", resp.StatusCode) + } + + redirectURL := resp.Header.Get("Location") + if redirectURL == "" { + t.Fatalf("no redirect URL in header") + } + if !strings.HasPrefix(redirectURL, singleSignOnURL) { + t.Fatalf("expected authURL as prefix of redirect url. Redirect URL: %s", redirectURL) + } + + if !strings.Contains(redirectURL, "SAMLRequest=") { + t.Fatalf("SAMLRequest parameter not found in redirect URL: %s", redirectURL) + } + +} diff --git a/rest/router.go b/rest/router.go index d2f4f91..4a8ea5c 100644 --- a/rest/router.go +++ b/rest/router.go @@ -21,6 +21,7 @@ func (c *Context) getRouter(assets fs.FS, indexHtml []byte) *http.ServeMux { mux.Handle("/api/context", http.HandlerFunc(c.contextHandler)) mux.Handle("/api/auth", http.HandlerFunc(c.authHandler)) mux.Handle("/api/authmethods", http.HandlerFunc(c.authMethods)) + mux.Handle("/api/authmethods/{method}/{id}/redirect", http.HandlerFunc(c.authMethodsByIDRedirect)) mux.Handle("/api/authmethods/{method}/{id}", http.HandlerFunc(c.authMethodsByID)) mux.Handle("/api/authmethods/{id}", http.HandlerFunc(c.authMethodsByID)) mux.Handle("/api/upgrade", http.HandlerFunc(c.upgrade))