diff --git a/Cargo.lock b/Cargo.lock index c64419dd..5a3f6cf2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2272,6 +2272,21 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "progenitor-client" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e8a874cf25a33cac7a01b9c1de87bcfbc8aea93f3156d09dcc3bee516a78926" +dependencies = [ + "bytes", + "futures-core", + "percent-encoding", + "reqwest 0.13.2", + "serde", + "serde_json", + "serde_urlencoded", +] + [[package]] name = "quinn" version = "0.11.9" @@ -2558,6 +2573,7 @@ dependencies = [ "rustls-platform-verifier", "serde", "serde_json", + "serde_urlencoded", "sync_wrapper", "tokio", "tokio-rustls 0.26.4", @@ -2732,7 +2748,7 @@ dependencies = [ "security-framework", "security-framework-sys", "webpki-root-certs", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -3858,6 +3874,25 @@ dependencies = [ "v-model", ] +[[package]] +name = "v-cli-sdk" +version = "0.2.0" +dependencies = [ + "anyhow", + "chrono", + "clap", + "http", + "http-body-util", + "hyper", + "hyper-util", + "oauth2", + "oauth2-reqwest", + "progenitor-client", + "reqwest 0.13.2", + "serde", + "tokio", +] + [[package]] name = "v-model" version = "0.2.0" @@ -4118,7 +4153,7 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 7d58a2ae..e2353bad 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ members = [ "v-api-installer", "v-api-param", "v-api-permission-derive", + "v-cli-sdk", "v-model", "xtask" ] @@ -29,6 +30,7 @@ hex = "0.4.3" http = "1" http-body-util = "0.1.3" hyper = "1.9.0" +hyper-util = "0.1" jsonwebtoken = { version = "10.2", features = ["rust_crypto"] } mockall = "0.14.0" newtype-uuid = { version = "1.3.2", features = ["schemars08", "serde", "v4"] } @@ -37,6 +39,7 @@ oauth2-reqwest = "0.1.0-alpha.3" partial-struct = { git = "https://github.com/oxidecomputer/partial-struct" } percent-encoding = "2.3.2" proc-macro2 = "1" +progenitor-client = "0.14.0" quote = "1" rand = "0.10.1" rand_core = "0.10.1" diff --git a/v-api/src/config.rs b/v-api/src/config.rs index 22789099..917063ce 100644 --- a/v-api/src/config.rs +++ b/v-api/src/config.rs @@ -151,6 +151,7 @@ pub struct SendGridConfig { pub struct OAuthProviders { pub github: Option, pub google: Option, + pub zendesk: Option, } #[derive(Debug, Deserialize)] diff --git a/v-api/src/context/mod.rs b/v-api/src/context/mod.rs index b73abdb3..f0dbd40c 100644 --- a/v-api/src/context/mod.rs +++ b/v-api/src/context/mod.rs @@ -1259,7 +1259,7 @@ pub(crate) mod test_mocks { pub async fn mock_context(storage: Arc) -> VContext { let MockKey { signer, verifier } = mock_key("test"); let mut ctx = VContextBuilder::::new() - .with_public_url("".to_string()) + .with_public_url("https://test_public_url".to_string()) .with_storage(storage) .with_jwt_expiration(JwtConfig::default().default_expiration) .with_keys(vec![signer, verifier]) @@ -1278,6 +1278,7 @@ pub(crate) mod test_mocks { OAuthProviderName::Google, Box::new(move || { Box::new(GoogleOAuthProvider::new( + "https://test_public_url".to_string(), "google_device_client_id".to_string(), "google_device_client_secret".to_string().into(), "google_web_client_id".to_string(), diff --git a/v-api/src/endpoints/handlers.rs b/v-api/src/endpoints/handlers.rs index f2dd46e4..070977bf 100644 --- a/v-api/src/endpoints/handlers.rs +++ b/v-api/src/endpoints/handlers.rs @@ -72,7 +72,7 @@ mod macros { code::{ authz_code_callback_op, authz_code_exchange_op, authz_code_redirect_op, OAuthAuthzCodeExchangeBody, OAuthAuthzCodeExchangeResponse, - OAuthAuthzCodeQuery, OAuthAuthzCodeReturnQuery, + OAuthAuthzCodeQuery, OAuthAuthzCodeReturnQuery, OAuthAuthzCodeExchangeQuery }, device_token::{ exchange_device_token_op, get_device_provider_op, AccessTokenExchangeRequest, @@ -296,9 +296,10 @@ mod macros { pub async fn authz_code_exchange( rqctx: RequestContext<$context_type>, path: Path, + query: Query, body: TypedBody, ) -> Result, HttpError> { - authz_code_exchange_op(&rqctx, path, body).await + authz_code_exchange_op(&rqctx, path, query, body).await } // DEVICE CODE diff --git a/v-api/src/endpoints/login/mod.rs b/v-api/src/endpoints/login/mod.rs index f9c471e4..05a7207e 100644 --- a/v-api/src/endpoints/login/mod.rs +++ b/v-api/src/endpoints/login/mod.rs @@ -60,6 +60,7 @@ impl From for HttpError { pub enum ExternalUserId { GitHub(String), Google(String), + Zendesk(String), #[cfg(feature = "local-dev")] Local(String), MagicLink(String), @@ -70,6 +71,7 @@ impl ExternalUserId { match self { Self::GitHub(id) => id, Self::Google(id) => id, + Self::Zendesk(id) => id, #[cfg(feature = "local-dev")] Self::Local(id) => id, Self::MagicLink(id) => id, @@ -80,6 +82,7 @@ impl ExternalUserId { match self { Self::GitHub(_) => "github", Self::Google(_) => "google", + Self::Zendesk(_) => "zendesk", #[cfg(feature = "local-dev")] Self::Local(_) => "local", Self::MagicLink(_) => "magic-link", @@ -103,6 +106,7 @@ impl Serialize for ExternalUserId { match self { ExternalUserId::GitHub(id) => serializer.serialize_str(&format!("github-{}", id)), ExternalUserId::Google(id) => serializer.serialize_str(&format!("google-{}", id)), + ExternalUserId::Zendesk(id) => serializer.serialize_str(&format!("zendesk-{}", id)), #[cfg(feature = "local-dev")] ExternalUserId::Local(id) => serializer.serialize_str(&format!("local-{}", id)), ExternalUserId::MagicLink(id) => { @@ -142,6 +146,12 @@ impl<'de> Deserialize<'de> for ExternalUserId { } else { Err(de::Error::custom(ExternalUserIdDeserializeError::Empty)) } + } else if let Some(("", id)) = value.split_once("zendesk-") { + if !id.is_empty() { + Ok(ExternalUserId::Zendesk(id.to_string())) + } else { + Err(de::Error::custom(ExternalUserIdDeserializeError::Empty)) + } } else if let Some(("", id)) = value.split_once("local-") { #[cfg(feature = "local-dev")] { @@ -191,6 +201,8 @@ pub enum UserInfoError { Deserialize(#[from] serde_json::Error), #[error("Failed to create user info request {0}")] Http(#[from] http::Error), + #[error("User account is locked")] + Locked, #[error("User information is missing")] MissingUserInfoData(String), } diff --git a/v-api/src/endpoints/login/oauth/code.rs b/v-api/src/endpoints/login/oauth/code.rs index 39d3a4a8..c3cd164e 100644 --- a/v-api/src/endpoints/login/oauth/code.rs +++ b/v-api/src/endpoints/login/oauth/code.rs @@ -30,7 +30,7 @@ use v_model::{ LoginAttempt, LoginAttemptId, NewLoginAttempt, OAuthClient, OAuthClientId, }; -use super::{OAuthProvider, OAuthProviderNameParam, UserInfoProvider, WebClientConfig}; +use super::{OAuthProvider, OAuthProviderNameParam, UserInfoProvider}; use crate::{ authn::key::RawKey, context::{ApiContext, VContext}, @@ -239,11 +239,7 @@ fn oauth_redirect_response( // TODO: This behavior should be changed so that clients are precomputed. We do not need to be // constructing a new client on every request. That said, we need to ensure the client does not // maintain state between requests - let client = provider - .as_web_client(&WebClientConfig { - prefix: public_url.to_string(), - }) - .map_err(to_internal_error)?; + let client = provider.as_web_client().map_err(to_internal_error)?; // Create an attempt cookie header for storing the login attempt. This also acts as our csrf // check @@ -440,6 +436,12 @@ where Ok(attempt.callback_url()) } +#[derive(Debug, Deserialize, JsonSchema)] +pub struct OAuthAuthzCodeExchangeQuery { + #[serde(default)] + pub include_idp_token: bool, +} + #[derive(Debug, Deserialize, JsonSchema)] pub struct OAuthAuthzCodeExchangeBody { pub client_id: Option>, @@ -455,12 +457,19 @@ pub struct OAuthAuthzCodeExchangeResponse { pub access_token: String, pub token_type: String, pub expires_in: i64, + pub idp_token: Option, +} + +#[derive(Debug, Deserialize, JsonSchema, Serialize)] +pub struct OAuthAuthzCodeIdpToken { + pub token: String, } #[instrument(skip(rqctx), err(Debug))] pub async fn authz_code_exchange_op( rqctx: &RequestContext>, path: Path, + query: Query, body: TypedBody, ) -> Result, HttpError> where @@ -468,6 +477,7 @@ where { let ctx = rqctx.v_ctx(); let path = path.into_inner(); + let query = query.into_inner(); let body = body.into_inner(); let (client_id, client_secret) = @@ -541,7 +551,14 @@ where // Now that the attempt has been confirmed, use it to fetch user information form the remote // provider - let info = fetch_user_info(ctx.public_url(), &ctx.web_client(), &*provider, &attempt).await?; + let (info, raw_token) = fetch_user_info( + ctx.public_url(), + &ctx.web_client(), + &*provider, + &attempt, + query.include_idp_token, + ) + .await?; tracing::debug!("Retrieved user information from remote provider"); @@ -589,6 +606,9 @@ where token_type: "Bearer".to_string(), access_token: token.signed_token, expires_in: token.expires_in, + idp_token: query.include_idp_token.then(|| OAuthAuthzCodeIdpToken { + token: raw_token.unwrap(), + }), })) } @@ -713,13 +733,10 @@ async fn fetch_user_info( client_type: &ClientType, provider: &dyn OAuthProvider, attempt: &LoginAttempt, -) -> Result { + return_raw: bool, +) -> Result<(UserInfo, Option), HttpError> { // Exchange the stored authorization code with the remote provider for a remote access token - let client = provider - .as_web_client(&WebClientConfig { - prefix: public_url.to_string(), - }) - .map_err(to_internal_error)?; + let client = provider.as_web_client().map_err(to_internal_error)?; let mut request = client.exchange_code(AuthorizationCode::new( attempt @@ -754,7 +771,7 @@ async fn fetch_user_info( // Now that we are done with fetching user information from the remote API, we can revoke it if // the provider supports it - if provider.token_revocation_endpoint().is_some() { + if !return_raw && provider.token_revocation_endpoint().is_some() { client .revoke_token(response.access_token().into()) .map_err(internal_error)? @@ -763,7 +780,11 @@ async fn fetch_user_info( .map_err(internal_error)?; } - Ok(info) + if return_raw { + Ok((info, Some(response.access_token().secret().to_string()))) + } else { + Ok((info, None)) + } } #[cfg(test)] @@ -889,8 +910,7 @@ mod tests { #[tokio::test] async fn test_remote_provider_redirect_url() { let storage = MockStorage::new(); - let mut ctx = mock_context(Arc::new(storage)).await; - ctx.with_public_url("https://api.oxeng.dev"); + let ctx = mock_context(Arc::new(storage)).await; let (challenge, _) = PkceCodeChallenge::new_random_sha256(); let attempt = LoginAttempt { @@ -927,7 +947,7 @@ mod tests { .unwrap(); let headers = response.headers(); - let expected_location = format!("https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=google_web_client_id&state={}&code_challenge={}&code_challenge_method=S256&redirect_uri=https%3A%2F%2Fapi.oxeng.dev%2Flogin%2Foauth%2Fgoogle%2Fcode%2Fcallback&scope=openid+email+profile", attempt.id, challenge.as_str()); + let expected_location = format!("https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=google_web_client_id&state={}&code_challenge={}&code_challenge_method=S256&redirect_uri=https%3A%2F%2Ftest_public_url%2Flogin%2Foauth%2Fgoogle%2Fcode%2Fcallback&scope=openid+email+profile", attempt.id, challenge.as_str()); assert_eq!( expected_location, diff --git a/v-api/src/endpoints/login/oauth/device_token.rs b/v-api/src/endpoints/login/oauth/device_token.rs index 0c9d5c11..22488a61 100644 --- a/v-api/src/endpoints/login/oauth/device_token.rs +++ b/v-api/src/endpoints/login/oauth/device_token.rs @@ -40,10 +40,7 @@ where .await .map_err(ApiError::OAuth)?; - Ok(HttpResponseOk(provider.provider_info( - rqctx.v_ctx().public_url(), - &ClientType::Device, - ))) + Ok(HttpResponseOk(provider.provider_info(&ClientType::Device))) } #[derive(Debug, Deserialize, JsonSchema, Serialize)] @@ -124,6 +121,24 @@ where tracing::debug!(provider = ?provider.name(), "Acquired OAuth provider for token exchange"); + if provider.device_code_endpoint().is_none() { + return Ok(Response::builder() + .status(StatusCode::BAD_REQUEST) + .header(header::CONTENT_TYPE, "application/json") + .body( + serde_json::to_vec(&ProxyTokenError { + error: "unsupported_grant_type".to_string(), + error_description: Some(format!( + "{} does not support device code flow", + path.provider + )), + error_uri: None, + }) + .unwrap() + .into(), + )?); + } + let exchange_request = body.into_inner(); if let Some(exchange) = AccessTokenExchange::new(exchange_request, &*provider) { diff --git a/v-api/src/endpoints/login/oauth/github.rs b/v-api/src/endpoints/login/oauth/github.rs index bce13561..3f3483fe 100644 --- a/v-api/src/endpoints/login/oauth/github.rs +++ b/v-api/src/endpoints/login/oauth/github.rs @@ -26,6 +26,9 @@ pub struct GitHubOAuthProvider { additional_scopes: Vec, default_headers: HeaderMap, client: reqwest::Client, + token_endpoint: Option, + redirect_endpoint: Option, + redirect_proxy_endpoint: Option, } impl fmt::Debug for GitHubOAuthProvider { @@ -36,6 +39,7 @@ impl fmt::Debug for GitHubOAuthProvider { impl GitHubOAuthProvider { pub fn new( + public_url: String, device_client_id: String, device_client_secret: SecretString, web_client_id: String, @@ -64,6 +68,9 @@ impl GitHubOAuthProvider { .redirect(reqwest::redirect::Policy::none()) .build() .expect("Static client must build"), + token_endpoint: Some(format!("{}/login/oauth/github/device/exchange", public_url)), + redirect_endpoint: Some(format!("{}/login/oauth/github/code/callback", public_url,)), + redirect_proxy_endpoint: None, } } @@ -154,8 +161,8 @@ impl OAuthProvider for GitHubOAuthProvider { ] } - fn device_code_endpoint(&self) -> &str { - "https://github.com/login/device/code" + fn device_code_endpoint(&self) -> Option<&str> { + Some("https://github.com/login/device/code") } fn auth_url_endpoint(&self) -> &str { @@ -177,4 +184,14 @@ impl OAuthProvider for GitHubOAuthProvider { fn supports_pkce(&self) -> bool { true } + + fn token_endpoint(&self) -> Option<&str> { + self.token_endpoint.as_deref() + } + fn redirect_endpoint(&self) -> Option<&str> { + self.redirect_endpoint.as_deref() + } + fn redirect_proxy_endpoint(&self) -> Option<&str> { + self.redirect_proxy_endpoint.as_deref() + } } diff --git a/v-api/src/endpoints/login/oauth/google.rs b/v-api/src/endpoints/login/oauth/google.rs index a38024d7..e477ae44 100644 --- a/v-api/src/endpoints/login/oauth/google.rs +++ b/v-api/src/endpoints/login/oauth/google.rs @@ -22,6 +22,9 @@ pub struct GoogleOAuthProvider { web_private: Option, additional_scopes: Vec, client: reqwest::Client, + token_endpoint: Option, + redirect_endpoint: Option, + redirect_proxy_endpoint: Option, } impl fmt::Debug for GoogleOAuthProvider { @@ -32,6 +35,7 @@ impl fmt::Debug for GoogleOAuthProvider { impl GoogleOAuthProvider { pub fn new( + public_url: String, device_client_id: String, device_client_secret: SecretString, web_client_id: String, @@ -56,6 +60,9 @@ impl GoogleOAuthProvider { .redirect(reqwest::redirect::Policy::none()) .build() .expect("Static client must build"), + token_endpoint: Some(format!("{}/login/oauth/google/device/exchange", public_url)), + redirect_endpoint: Some(format!("{}/login/oauth/google/code/callback", public_url,)), + redirect_proxy_endpoint: None, } } @@ -162,8 +169,8 @@ impl OAuthProvider for GoogleOAuthProvider { ] } - fn device_code_endpoint(&self) -> &str { - "https://oauth2.googleapis.com/device/code" + fn device_code_endpoint(&self) -> Option<&str> { + Some("https://oauth2.googleapis.com/device/code") } fn auth_url_endpoint(&self) -> &str { @@ -185,4 +192,14 @@ impl OAuthProvider for GoogleOAuthProvider { fn supports_pkce(&self) -> bool { true } + + fn token_endpoint(&self) -> Option<&str> { + self.token_endpoint.as_deref() + } + fn redirect_endpoint(&self) -> Option<&str> { + self.redirect_endpoint.as_deref() + } + fn redirect_proxy_endpoint(&self) -> Option<&str> { + self.redirect_proxy_endpoint.as_deref() + } } diff --git a/v-api/src/endpoints/login/oauth/mod.rs b/v-api/src/endpoints/login/oauth/mod.rs index f89d822d..b986c5c4 100644 --- a/v-api/src/endpoints/login/oauth/mod.rs +++ b/v-api/src/endpoints/login/oauth/mod.rs @@ -27,11 +27,16 @@ pub mod code; pub mod device_token; pub mod github; pub mod google; +pub mod zendesk; #[derive(Debug, Error)] pub enum OAuthProviderError { #[error("Unable to instantiate invalid provider")] FailToCreateInvalidProvider, + #[error("Missing redirect URI")] + MissingRedirectUri, + #[error("Failed to parse URL")] + UrlParseError(#[from] ParseError), } #[derive(Debug)] @@ -40,11 +45,6 @@ pub enum ClientType { Web, } -#[derive(Debug)] -pub struct WebClientConfig { - prefix: String, -} - pub type WebClient = BasicClient< // HasAuthUrl EndpointSet, @@ -76,20 +76,26 @@ pub trait OAuthProvider: ExtractUserInfo + Debug + Send + Sync { // TODO: How can user info be change to something statically checked instead of a runtime check fn user_info_endpoints(&self) -> Vec<&str>; - fn device_code_endpoint(&self) -> &str; + fn device_code_endpoint(&self) -> Option<&str>; fn auth_url_endpoint(&self) -> &str; fn token_exchange_content_type(&self) -> &str; fn token_exchange_endpoint(&self) -> &str; fn token_revocation_endpoint(&self) -> Option<&str>; fn supports_pkce(&self) -> bool; - fn provider_info(&self, public_url: &str, client_type: &ClientType) -> OAuthProviderInfo { + fn token_endpoint(&self) -> Option<&str>; + fn redirect_endpoint(&self) -> Option<&str>; + fn redirect_proxy_endpoint(&self) -> Option<&str>; + + fn provider_info(&self, client_type: &ClientType) -> OAuthProviderInfo { OAuthProviderInfo { provider: self.name(), client_id: self.client_id(client_type).to_string(), auth_url_endpoint: self.auth_url_endpoint().to_string(), - device_code_endpoint: self.device_code_endpoint().to_string(), - token_endpoint: format!("{}/login/oauth/{}/device/exchange", public_url, self.name(),), + device_code_endpoint: self.device_code_endpoint().map(|s| s.to_string()), + token_endpoint: self.token_endpoint().map(|s| s.to_string()), + redirect_endpoint: self.redirect_endpoint().map(|s| s.to_string()), + redirect_proxy_endpoint: self.redirect_proxy_endpoint().map(|s| s.to_string()), scopes: self .scopes() .into_iter() @@ -98,7 +104,7 @@ pub trait OAuthProvider: ExtractUserInfo + Debug + Send + Sync { } } - fn as_web_client(&self, config: &WebClientConfig) -> Result { + fn as_web_client(&self) -> Result { let mut client = BasicClient::new(ClientId::new(self.client_id(&ClientType::Web).to_string())) .set_auth_uri(AuthUrl::new(self.auth_url_endpoint().to_string())?) @@ -108,11 +114,11 @@ pub trait OAuthProvider: ExtractUserInfo + Debug + Send + Sync { .map(|s| RevocationUrl::new(s.to_string())) .transpose()?, ) - .set_redirect_uri(RedirectUrl::new(format!( - "{}/login/oauth/{}/code/callback", - &config.prefix, - self.name() - ))?); + .set_redirect_uri(RedirectUrl::new( + self.redirect_endpoint() + .ok_or(OAuthProviderError::MissingRedirectUri)? + .to_string(), + )?); if let Some(secret) = self.client_secret(&ClientType::Web) { client = client.set_client_secret(ClientSecret::new(secret.expose_secret().to_string())) @@ -171,8 +177,10 @@ pub struct OAuthProviderInfo { provider: OAuthProviderName, client_id: String, auth_url_endpoint: String, - device_code_endpoint: String, - token_endpoint: String, + device_code_endpoint: Option, + token_endpoint: Option, + redirect_endpoint: Option, + redirect_proxy_endpoint: Option, scopes: Vec, } @@ -182,6 +190,7 @@ pub enum OAuthProviderName { #[serde(rename = "github")] GitHub, Google, + Zendesk, } impl Display for OAuthProviderName { @@ -189,6 +198,7 @@ impl Display for OAuthProviderName { match self { OAuthProviderName::GitHub => write!(f, "github"), OAuthProviderName::Google => write!(f, "google"), + OAuthProviderName::Zendesk => write!(f, "zendesk"), } } } diff --git a/v-api/src/endpoints/login/oauth/zendesk.rs b/v-api/src/endpoints/login/oauth/zendesk.rs new file mode 100644 index 00000000..b2421ce2 --- /dev/null +++ b/v-api/src/endpoints/login/oauth/zendesk.rs @@ -0,0 +1,202 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +use hyper::body::Bytes; +use reqwest::Request; +use secrecy::SecretString; +use serde::Deserialize; +use std::fmt; + +use crate::endpoints::login::{ExternalUserId, UserInfo, UserInfoError}; + +use super::{ + ClientType, ExtractUserInfo, OAuthPrivateCredentials, OAuthProvider, OAuthProviderName, + OAuthPublicCredentials, +}; + +pub struct ZendeskOAuthProvider { + device_public: OAuthPublicCredentials, + device_private: Option, + web_public: OAuthPublicCredentials, + web_private: Option, + additional_scopes: Vec, + client: reqwest::Client, + user_info_endpoint: String, + auth_url_endpoint: String, + token_exchange_endpoint: String, + token_endpoint: Option, + redirect_endpoint: Option, + redirect_proxy_endpoint: Option, +} + +impl fmt::Debug for ZendeskOAuthProvider { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ZendeskOAuthProvider").finish() + } +} + +impl ZendeskOAuthProvider { + pub fn new( + public_url: String, + subdomain: String, + device_client_id: String, + device_client_secret: SecretString, + web_client_id: String, + web_client_secret: SecretString, + additional_scopes: Option>, + redirect_proxy_port: u16, + ) -> Self { + let base_url = format!("https://{}.zendesk.com", subdomain); + + Self { + device_public: OAuthPublicCredentials { + client_id: device_client_id, + }, + device_private: Some(OAuthPrivateCredentials { + client_secret: device_client_secret, + }), + web_public: OAuthPublicCredentials { + client_id: web_client_id, + }, + web_private: Some(OAuthPrivateCredentials { + client_secret: web_client_secret, + }), + additional_scopes: additional_scopes.unwrap_or_default(), + client: reqwest::ClientBuilder::new() + .redirect(reqwest::redirect::Policy::none()) + .build() + .expect("Static client must build"), + user_info_endpoint: format!("{}/api/v2/users/me.json", base_url), + auth_url_endpoint: format!("{}/oauth/authorizations/new", base_url), + token_exchange_endpoint: format!("{}/oauth/tokens", base_url), + token_endpoint: Some(format!( + "{}/login/oauth/zendesk/device/exchange", + public_url + )), + redirect_endpoint: Some(format!("{}/login/oauth/zendesk/code/callback", public_url,)), + redirect_proxy_endpoint: Some(format!( + "http://localhost:{}/login/oauth/zendesk/code/callback", + redirect_proxy_port + )), + } + } + + pub fn with_client(&mut self, client: reqwest::Client) -> &mut Self { + self.client = client; + self + } +} + +#[derive(Debug, Deserialize)] +struct ZendeskUserResponse { + user: ZendeskUser, +} + +#[derive(Debug, Deserialize)] +struct ZendeskUser { + id: u64, + name: String, + email: String, + verified: bool, + suspended: bool, +} + +impl ExtractUserInfo for ZendeskOAuthProvider { + fn extract_user_info(&self, data: &[Bytes]) -> Result { + let response: ZendeskUserResponse = serde_json::from_slice(&data[0])?; + let user = response.user; + + if user.suspended { + return Err(UserInfoError::Locked); + } + + let verified_emails = if user.verified { + vec![user.email] + } else { + vec![] + }; + + Ok(UserInfo { + external_id: ExternalUserId::Zendesk(user.id.to_string()), + verified_emails, + display_name: Some(user.name), + }) + } +} + +impl OAuthProvider for ZendeskOAuthProvider { + fn name(&self) -> OAuthProviderName { + OAuthProviderName::Zendesk + } + + fn scopes(&self) -> Vec<&str> { + let mut default = vec!["users:read"]; + default.extend(self.additional_scopes.iter().map(|s| s.as_str())); + default + } + + fn initialize_headers(&self, _request: &mut Request) {} + + fn client(&self) -> &reqwest::Client { + &self.client + } + + fn client_id(&self, client_type: &ClientType) -> &str { + match client_type { + ClientType::Device => &self.device_public.client_id, + ClientType::Web => &self.web_public.client_id, + } + } + + fn client_secret(&self, client_type: &ClientType) -> Option<&SecretString> { + match client_type { + ClientType::Device => self + .device_private + .as_ref() + .map(|private| &private.client_secret), + ClientType::Web => self + .web_private + .as_ref() + .map(|private| &private.client_secret), + } + } + + fn user_info_endpoints(&self) -> Vec<&str> { + vec![&self.user_info_endpoint] + } + + fn device_code_endpoint(&self) -> Option<&str> { + None + } + + fn auth_url_endpoint(&self) -> &str { + &self.auth_url_endpoint + } + + fn token_exchange_content_type(&self) -> &str { + "application/x-www-form-urlencoded" + } + + fn token_exchange_endpoint(&self) -> &str { + &self.token_exchange_endpoint + } + + fn token_revocation_endpoint(&self) -> Option<&str> { + None + } + + fn supports_pkce(&self) -> bool { + false + } + + fn token_endpoint(&self) -> Option<&str> { + self.token_endpoint.as_deref() + } + fn redirect_endpoint(&self) -> Option<&str> { + self.redirect_endpoint.as_deref() + } + fn redirect_proxy_endpoint(&self) -> Option<&str> { + self.redirect_proxy_endpoint.as_deref() + } +} diff --git a/v-cli-sdk/Cargo.toml b/v-cli-sdk/Cargo.toml new file mode 100644 index 00000000..c3c630ec --- /dev/null +++ b/v-cli-sdk/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "v-cli-sdk" +version = "0.2.0" +edition = "2021" + +[dependencies] +anyhow = { workspace = true } +chrono = { workspace = true } +clap = { workspace = true } +http = { workspace = true } +http-body-util = { workspace = true } +hyper = { workspace = true, features = ["server", "http1"] } +hyper-util = { workspace = true, features = ["tokio"] } +oauth2 = { workspace = true } +oauth2-reqwest = { workspace = true } +progenitor-client = { workspace = true } +reqwest = { workspace = true } +serde = { workspace = true } +tokio = { workspace = true, features = ["macros", "rt", "net", "sync"] } diff --git a/v-cli-sdk/src/cmd/auth/login.rs b/v-cli-sdk/src/cmd/auth/login.rs new file mode 100644 index 00000000..6d09a937 --- /dev/null +++ b/v-cli-sdk/src/cmd/auth/login.rs @@ -0,0 +1,232 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +use anyhow::Result; +use clap::{Parser, Subcommand, ValueEnum}; +use oauth2::TokenResponse; +use std::{error::Error as StdError, fmt::Debug, future::Future, io::Write, pin::Pin, sync::Arc}; + +use crate::{ + cmd::{ + auth::oauth::{self, CliOAuthAdapter, CliOAuthProviderInfo}, + config::CliConfig, + }, + CliContext, +}; + +pub trait CliAdapterToken { + fn access_token(&self) -> &str; +} + +pub trait CliConsumerLoginProvider: Into + Subcommand + Debug + Clone {} +impl CliConsumerLoginProvider for T where T: Into + Subcommand + Debug + Clone {} + +// Authenticates and generates an access token for interacting with the api +#[derive(Parser, Debug, Clone)] +#[clap(name = "login")] +pub struct Login

+where + P: CliConsumerLoginProvider, +{ + #[command(subcommand)] + method: LoginMethod

, + #[arg(short = 'm', default_value = "id")] + mode: AuthenticationMode, +} + +impl

Login

+where + P: CliConsumerLoginProvider, +{ + pub async fn run(&self, ctx: &mut T) -> Result<()> + where + T: CliContext, + >::Error: StdError + Send + Sync + 'static, + { + let access_token = self.method.run(ctx, &self.mode).await?; + + ctx.config_mut().set_token(access_token); + ctx.config_mut().save()?; + + Ok(()) + } +} + +#[derive(Subcommand, Debug, Clone)] +pub enum LoginMethod

+where + P: Subcommand + Debug + Clone, +{ + #[command(name = "oauth")] + /// Login via OAuth + OAuth { + #[command(subcommand)] + provider: P, + }, + /// Login via Magic Link + #[command(name = "mlink")] + MagicLink { + /// Email recipient to login via + email: String, + /// Optional access scopes to apply to this session + scope: Option, + }, +} + +pub enum LoginProvider { + Google, + GitHub, + Zendesk, +} + +#[derive(ValueEnum, Debug, Clone, PartialEq)] +pub enum AuthenticationMode { + /// Retrieve and store an identity token. Identity mode is the default and should be used to + /// when you do not require extended (multi-day) access + #[value(name = "id")] + Identity, + /// Retrieve and store an api token. Token mode should be used when you want to authenticate + /// a machine for continued access. This requires the permission to create api tokens + #[value(name = "token")] + Token, + /// Retrieve and store a remote token. Remote mode should be used when you want to authenticate + /// and retrieve a token for use against the underlying authentication provider + #[value(name = "remote")] + Remote, +} + +impl

LoginMethod

+where + P: CliConsumerLoginProvider, +{ + pub async fn run(&self, ctx: &T, mode: &AuthenticationMode) -> Result + where + T: CliContext, + >::Error: StdError + Send + Sync + 'static, + { + match self { + Self::OAuth { provider } => { + let adapter = ctx.oauth_adapter(); + let provider = provider.clone().into(); + let provider = adapter.provider(&provider).await?; + + // We now need to inspect the provider to determine the correct flow to use. If + // possible we use a limited input device flow, but not all providers support it. + // To handle those cases we need to use a proxy path that emulates an authorization + // code flow. + if provider.device_code_endpoint().is_some() { + self.run_oauth_device_provider(provider, mode, ctx.oauth_adapter()) + .await + } else if provider.code_redirect_proxy_endpoint().is_some() { + self.run_oauth_code_provider(provider, mode, ctx.oauth_adapter()) + .await + } else { + anyhow::bail!("OAuth provider does not support any CLI authentication methods") + } + } + Self::MagicLink { email, scope } => { + self.run_magic_link(email, scope.as_deref(), ctx.mlink_adapter()) + .await + } + } + } + + async fn run_oauth_device_provider( + &self, + provider: V, + mode: &AuthenticationMode, + adapter: T, + ) -> Result + where + T: CliOAuthAdapter, + V: CliOAuthProviderInfo, + { + let oauth_client = oauth::DeviceOAuth::new(provider)?; + let details = oauth_client.get_device_authorization().await?; + + println!( + "To complete login visit: {} and enter {}", + details.verification_uri().as_str(), + details.user_code().secret() + ); + + let token_response = oauth_client.login(&details).await; + + let identity_token = match token_response { + Ok(token) => Ok(token.access_token().to_owned()), + Err(err) => Err(anyhow::anyhow!("Authentication failed: {}", err)), + }?; + + if mode == &AuthenticationMode::Token { + let token = adapter + .get_long_lived_token(identity_token.secret()) + .await?; + Ok(token.access_token().to_string()) + } else { + Ok(identity_token.secret().to_string()) + } + } + + async fn run_oauth_code_provider( + &self, + provider: V, + mode: &AuthenticationMode, + adapter: T, + ) -> Result + where + T: CliOAuthAdapter + Send + Sync + 'static, + V: CliOAuthProviderInfo, + { + let oauth_client = oauth::CodeOAuth::new(provider)?; + let adapter = Arc::new(adapter); + + let identity_token = oauth_client.login(Arc::clone(&adapter)).await?; + + if mode == &AuthenticationMode::Token { + let token = adapter.get_long_lived_token(&identity_token).await?; + Ok(token.access_token().to_string()) + } else { + Ok(identity_token) + } + } + + async fn run_magic_link( + &self, + email: &str, + scope: Option<&str>, + adapter: T, + ) -> Result + where + T: CliMagicLinkAdapter, + { + let attempt = adapter.create_attempt(email, scope).await?; + + let mut auth_secret = String::new(); + print!("Enter the login token sent to the recipient: "); + std::io::stdout().flush()?; + std::io::stdin().read_line(&mut auth_secret)?; + + let token = adapter.exchange(attempt, email, &auth_secret).await?; + + Ok(token.access_token().to_string()) + } +} + +pub trait CliMagicLinkAdapter { + type Attempt; + type Token: CliAdapterToken; + type Error: StdError + Send + Sync + 'static; + + fn create_attempt( + &self, + email: &str, + scope: Option<&str>, + ) -> Pin> + Send>>; + fn exchange( + &self, + attempt: Self::Attempt, + email: &str, + token: &str, + ) -> Pin> + Send>>; +} diff --git a/v-cli-sdk/src/cmd/auth/mod.rs b/v-cli-sdk/src/cmd/auth/mod.rs new file mode 100644 index 00000000..5de8042e --- /dev/null +++ b/v-cli-sdk/src/cmd/auth/mod.rs @@ -0,0 +1,48 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +use anyhow::Result; +use clap::{Parser, Subcommand}; +use std::error::Error as StdError; + +use crate::{cmd::auth::login::CliConsumerLoginProvider, CliContext}; + +pub mod login; +pub mod oauth; +pub mod proxy; + +// Authenticate against the Meetings API +#[derive(Parser, Debug)] +#[clap(name = "auth")] +pub struct Auth

+where + P: CliConsumerLoginProvider, +{ + #[command(subcommand)] + auth: AuthCommands

, +} + +#[derive(Subcommand, Debug, Clone)] +enum AuthCommands

+where + P: CliConsumerLoginProvider, +{ + /// Login via an authentication provider + Login(login::Login

), +} + +impl

Auth

+where + P: CliConsumerLoginProvider, +{ + pub async fn run(&self, ctx: &mut T) -> Result<()> + where + T: CliContext, + >::Error: StdError + Send + Sync + 'static, + { + match &self.auth { + AuthCommands::Login(login) => login.run(ctx).await, + } + } +} diff --git a/v-cli-sdk/src/cmd/auth/oauth.rs b/v-cli-sdk/src/cmd/auth/oauth.rs new file mode 100644 index 00000000..8f5f6751 --- /dev/null +++ b/v-cli-sdk/src/cmd/auth/oauth.rs @@ -0,0 +1,306 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +use std::error::Error as StdError; +use std::future::Future; +use std::pin::Pin; +use std::sync::{Arc, Mutex}; + +use anyhow::Result; +use http::{Request, Response, StatusCode}; +use http_body_util::{BodyExt, Full}; +use hyper::body::{Bytes, Incoming}; +use oauth2::basic::{BasicClient, BasicTokenType}; +use oauth2::StandardDeviceAuthorizationResponse; +use oauth2::{ + AuthType, AuthUrl, ClientId, CsrfToken, DeviceAuthorizationUrl, EmptyExtraTokenFields, + EndpointNotSet, EndpointSet, RedirectUrl, Scope, StandardTokenResponse, TokenUrl, +}; +use reqwest::Url; +use tokio::sync::oneshot; + +use crate::cmd::auth::login::CliAdapterToken; + +use super::proxy::run_proxy_server; + +pub trait CliOAuthAdapter { + type Token: CliAdapterToken; + type Error: StdError + Send + Sync + 'static; + + fn provider( + &self, + provider: &super::login::LoginProvider, + ) -> Pin> + Send>>; + fn exchange_authorization_code( + &self, + request: Request, + ) -> Pin>, Self::Error>> + Send>>; + fn get_long_lived_token( + &self, + access_token: &str, + ) -> Pin> + Send>>; +} + +pub trait CliOAuthProviderInfo { + fn device_code_endpoint(&self) -> Option<&str>; + fn code_redirect_proxy_endpoint(&self) -> Option<&str>; + fn auth_url_endpoint(&self) -> &str; + fn token_endpoint(&self) -> &str; + fn client_id(&self) -> &str; + fn scopes(&self) -> &[String]; +} + +type CodeClient = BasicClient< + // HasAuthUrl + EndpointSet, + // HasDeviceAuthUrl + EndpointNotSet, + // HasIntrospectionUrl + EndpointNotSet, + // HasRevocationUrl + EndpointNotSet, + // HasTokenUrl + EndpointSet, +>; + +pub struct CodeOAuth { + client: CodeClient, + scopes: Vec, + port: u16, +} + +impl CodeOAuth { + pub fn new(provider: T) -> Result + where + T: CliOAuthProviderInfo, + { + let redirect_url = provider + .code_redirect_proxy_endpoint() + .ok_or_else(|| anyhow::anyhow!("Provider does not support code redirect proxy flow"))?; + + let parsed_url = Url::parse(redirect_url)?; + + let port = parsed_url.port().ok_or_else(|| { + anyhow::anyhow!("Provider proxy url does not have a defined port to listen on") + })?; + + if parsed_url.scheme() != "http" { + anyhow::bail!("Provider proxy url scheme must be http"); + } + + if parsed_url + .host_str() + .map(|h| h != "localhost" && h != "127.0.0.1") + .unwrap_or(true) + { + anyhow::bail!("Provider proxy url host must be localhost"); + } + + let client = BasicClient::new(ClientId::new(provider.client_id().to_string())) + .set_auth_uri(AuthUrl::new(provider.auth_url_endpoint().to_string())?) + .set_auth_type(AuthType::RequestBody) + .set_token_uri(TokenUrl::new(provider.token_endpoint().to_string())?) + .set_redirect_uri(RedirectUrl::new(redirect_url.to_string())?); + + Ok(Self { + client, + scopes: provider.scopes().iter().map(|s| s.to_string()).collect(), + port, + }) + } + + /// Build the authorization URL that the user should visit in a browser. + /// Returns the full URL and the CSRF state token used for verification. + pub fn authorize_url(&self) -> (oauth2::url::Url, CsrfToken) { + let mut req = self.client.authorize_url(CsrfToken::new_random); + + for scope in &self.scopes { + req = req.add_scope(Scope::new(scope.to_string())); + } + + req.url() + } + + /// Run the full authorization code login flow: + /// + /// 1. Generate the authorization URL and print it for the user. + /// 2. Spin up a local HTTP proxy server to capture the IdP redirect. + /// 3. Forward the redirect request to the API server via the adapter. + /// 4. Extract the token from the server's response. + /// 5. Return a success page to the browser and shut down the proxy. + pub async fn login(&self, adapter: Arc) -> Result + where + T: CliOAuthAdapter + Send + Sync + 'static, + { + let (auth_url, _csrf_state) = self.authorize_url(); + + println!( + "Open the following URL in your browser to authenticate:\n\n {}\n", + auth_url + ); + + // Channel to receive the token extracted from the server response. + let (token_tx, token_rx) = oneshot::channel::>(); + let token_tx: Arc>>>> = + Arc::new(Mutex::new(Some(token_tx))); + + // Channel to shut down the proxy server once we have the token. + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + + let port = self.port; + + // Spawn the local proxy server in a background task. + tokio::spawn({ + let callback_token_tx = Arc::clone(&token_tx); + let error_token_tx = Arc::clone(&token_tx); + + async move { + let callback = Arc::new(move |request: Request| { + let adapter = Arc::clone(&adapter); + let token_tx = Arc::clone(&callback_token_tx); + + Box::pin(async move { + // Forward the redirect request to the API server. + let response = adapter + .exchange_authorization_code(request) + .await + .map_err(|e| anyhow::anyhow!(e))?; + + // The server responds with the access token in the body. + let (_parts, body) = response.into_parts(); + let body_bytes = body + .collect() + .await + .expect("Full collection cannot fail") + .to_bytes(); + let token = String::from_utf8(body_bytes.to_vec())?; + + // Send the token back to the main task. + if let Ok(mut guard) = token_tx.lock() { + if let Some(tx) = guard.take() { + let _ = tx.send(Ok(token)); + } + } + + // Return a friendly page to the browser so the user + // knows they can close the tab. + Ok(Response::builder() + .status(StatusCode::OK) + .header("content-type", "text/html; charset=utf-8") + .body(Full::new(Bytes::from(concat!( + "", + "

Authentication successful!

", + "

You can close this tab and return to the CLI.

", + "" + ))))?) + }) + as Pin< + Box>>> + Send>, + > + }); + + if let Err(e) = run_proxy_server(port, callback, shutdown_rx).await { + eprintln!("Proxy server error: {e}"); + + // If the proxy died before we got a token, unblock the + // receiver so the caller isn't stuck forever. + if let Ok(mut guard) = error_token_tx.lock() { + if let Some(tx) = guard.take() { + let _ = tx.send(Err(anyhow::anyhow!( + "Proxy server exited unexpectedly: {e}" + ))); + } + } + } + } + }); + + // Wait for the proxy callback to extract the token. + let token = token_rx.await.map_err(|_| { + anyhow::anyhow!( + "Authentication callback was never received — proxy server may have exited early" + ) + })??; + + // Tell the proxy server to stop. + let _ = shutdown_tx.send(()); + + Ok(token) + } +} + +type DeviceClient = BasicClient< + // HasAuthUrl + EndpointSet, + // HasDeviceAuthUrl + EndpointSet, + // HasIntrospectionUrl + EndpointNotSet, + // HasRevocationUrl + EndpointNotSet, + // HasTokenUrl + EndpointSet, +>; + +pub struct DeviceOAuth { + client: DeviceClient, + http: oauth2_reqwest::ReqwestClient, + scopes: Vec, +} + +impl DeviceOAuth { + pub fn new(provider: T) -> Result + where + T: CliOAuthProviderInfo, + { + if let Some(device_endpoint) = provider.device_code_endpoint() { + let device_auth_url = DeviceAuthorizationUrl::new(device_endpoint.to_string())?; + + let client = BasicClient::new(ClientId::new(provider.client_id().to_string())) + .set_auth_uri(AuthUrl::new(provider.auth_url_endpoint().to_string())?) + .set_auth_type(AuthType::RequestBody) + .set_token_uri(TokenUrl::new(provider.token_endpoint().to_string())?) + .set_device_authorization_url(device_auth_url); + + Ok(Self { + client, + http: oauth2_reqwest::ReqwestClient::from( + reqwest::ClientBuilder::new() + .redirect(reqwest::redirect::Policy::none()) + .build() + .unwrap(), + ), + scopes: provider.scopes().iter().map(|s| s.to_string()).collect(), + }) + } else { + anyhow::bail!("Device authorization is not supported by this provider") + } + } + + pub async fn login( + &self, + details: &StandardDeviceAuthorizationResponse, + ) -> Result> { + let token = self + .client + .exchange_device_access_token(details) + .set_max_backoff_interval(details.interval()) + .request_async(&self.http, tokio::time::sleep, Some(details.expires_in())) + .await; + + Ok(token?) + } + + pub async fn get_device_authorization(&self) -> Result { + let mut req = self.client.exchange_device_code(); + + for scope in &self.scopes { + req = req.add_scope(Scope::new(scope.to_string())); + } + + let res = req.request_async(&self.http).await; + + Ok(res?) + } +} diff --git a/v-cli-sdk/src/cmd/auth/proxy.rs b/v-cli-sdk/src/cmd/auth/proxy.rs new file mode 100644 index 00000000..47049421 --- /dev/null +++ b/v-cli-sdk/src/cmd/auth/proxy.rs @@ -0,0 +1,129 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +use std::future::Future; +use std::net::SocketAddr; +use std::sync::Arc; + +use http_body_util::Full; +use hyper::body::{Bytes, Incoming}; +use hyper::service::service_fn; +use hyper::{Request, Response}; +use hyper_util::rt::TokioIo; +use tokio::net::TcpListener; +use tokio::sync::oneshot; + +/// A callback function that receives an incoming HTTP request and returns a response. +pub type Callback = Arc< + dyn Fn( + Request, + ) + -> std::pin::Pin>>> + Send>> + + Send + + Sync, +>; + +/// Start a minimal HTTP server on the given port that forwards every incoming +/// request to `callback` and returns whatever response the callback produces. +/// +/// The server will run until a message is sent on the `shutdown` channel, at +/// which point it will stop accepting new connections and return. +pub async fn run_proxy_server( + port: u16, + callback: Callback, + shutdown: oneshot::Receiver<()>, +) -> anyhow::Result<()> { + let addr = SocketAddr::from(([127, 0, 0, 1], port)); + let listener = TcpListener::bind(addr).await?; + serve_loop(listener, callback, shutdown).await +} + +/// Core accept-loop shared by [`run_proxy_server`] and tests. +/// +/// Accepts connections on `listener`, forwarding each request to `callback`. +/// Stops when `shutdown` fires. +async fn serve_loop( + listener: TcpListener, + callback: Callback, + shutdown: oneshot::Receiver<()>, +) -> anyhow::Result<()> { + tokio::pin!(shutdown); + + loop { + tokio::select! { + _ = &mut shutdown => { + break; + } + accepted = listener.accept() => { + let (stream, _remote_addr) = accepted?; + let io = TokioIo::new(stream); + let cb = Arc::clone(&callback); + + tokio::task::spawn(async move { + let service = service_fn(move |req: Request| { + let cb = Arc::clone(&cb); + async move { cb(req).await } + }); + + if let Err(err) = + hyper::server::conn::http1::Builder::new() + .serve_connection(io, service) + .await + { + eprintln!("Error serving connection: {err}"); + } + }); + } + } + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use hyper::StatusCode; + + #[tokio::test] + async fn test_proxy_server_responds() { + let callback: Callback = Arc::new(|_req| { + Box::pin(async { + Ok(Response::builder() + .status(StatusCode::OK) + .body(Full::new(Bytes::from("hello from callback"))) + .unwrap()) + }) + }); + + let (tx, rx) = oneshot::channel::<()>(); + + // Use port 0 to let the OS pick an available port. + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let listener = TcpListener::bind(addr).await.unwrap(); + let local_addr = listener.local_addr().unwrap(); + + let server_handle = tokio::spawn({ + let callback = Arc::clone(&callback); + async move { + serve_loop(listener, callback, rx).await.unwrap(); + } + }); + + // Send a request to the server. + let client = reqwest::Client::new(); + let resp = client + .get(format!("http://{}", local_addr)) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), 200); + assert_eq!(resp.text().await.unwrap(), "hello from callback"); + + // Shut down the server. + tx.send(()).unwrap(); + server_handle.await.unwrap(); + } +} diff --git a/v-cli-sdk/src/cmd/config/mod.rs b/v-cli-sdk/src/cmd/config/mod.rs new file mode 100644 index 00000000..bf2d123e --- /dev/null +++ b/v-cli-sdk/src/cmd/config/mod.rs @@ -0,0 +1,149 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +use anyhow::Result; +use clap::{Parser, Subcommand}; + +use crate::{CliContext, FormatStyle}; + +pub trait CliConfig { + fn host(&self) -> Option<&str>; + fn set_host(&mut self, host: String); + fn token(&self) -> Option<&str>; + fn set_token(&mut self, token: String); + fn default_format(&self) -> Option<&FormatStyle>; + fn set_default_format(&mut self, format: FormatStyle); + fn mlink_redirect(&self) -> Option<&str>; + fn set_mlink_redirect(&mut self, redirect: String); + fn mlink_secret(&self) -> Option<&str>; + fn set_mlink_secret(&mut self, secret: String); + fn save(&self) -> Result<(), std::io::Error>; +} + +#[derive(Debug, Parser)] +#[clap(name = "config")] +pub struct ConfigCmd { + #[clap(subcommand)] + setting: SettingCmd, +} + +#[derive(Debug, Subcommand)] +pub enum SettingCmd { + /// Gets a setting + #[clap(subcommand, name = "get")] + Get(GetCmd), + /// Sets a setting + #[clap(subcommand, name = "set")] + Set(SetCmd), +} + +#[derive(Debug, Subcommand)] +pub enum GetCmd { + /// Get the default formatter to use when printing results + #[clap(name = "format")] + Format, + /// Get the configured API host in use + #[clap(name = "host")] + Host, + /// Get the configured access token + #[clap(name = "token")] + Token, + /// Get the configured magic redirect uri + #[clap(name = "mlink-redirect")] + MagicLinkRedirectUri, + /// Get the configured magic link secret + #[clap(name = "mlink-secret")] + MagicLinkSecret, +} + +#[derive(Debug, Subcommand)] +pub enum SetCmd { + /// Set the default formatter to use when printing results + #[clap(name = "format")] + Format { format: FormatStyle }, + /// Set the configured API host to use + #[clap(name = "host")] + Host { host: String }, + /// Set the configured magic redirect uri + #[clap(name = "mlink-redirect")] + MagicLinkRedirectUri { redirect: String }, + /// Set the configured magic link secret + #[clap(name = "mlink-secret")] + MagicLinkSecret { secret: String }, +} + +impl ConfigCmd { + pub async fn run(&self, ctx: &mut T) -> Result<()> + where + T: CliContext, + { + match &self.setting { + SettingCmd::Get(get) => get.run(ctx.config()).await?, + SettingCmd::Set(set) => set.run(ctx.config_mut()).await?, + } + + Ok(()) + } +} + +impl GetCmd { + pub async fn run(&self, config: &T) -> Result<()> + where + T: CliConfig, + { + match &self { + GetCmd::Format => { + println!( + "{}", + config + .default_format() + .copied() + .unwrap_or(FormatStyle::Json) + ); + } + GetCmd::Host => { + println!("{}", config.host().unwrap_or("None")); + } + GetCmd::Token => { + println!("{}", config.token().unwrap_or("None")); + } + GetCmd::MagicLinkRedirectUri => { + println!("{}", config.mlink_redirect().unwrap_or("None")); + } + GetCmd::MagicLinkSecret => { + println!("{}", config.mlink_secret().unwrap_or("None")); + } + } + + Ok(()) + } +} + +impl SetCmd { + pub async fn run(&self, config: &mut T) -> Result<()> + where + T: CliConfig, + { + match &self { + SetCmd::Format { format } => { + config.set_default_format(*format); + config.save()?; + } + SetCmd::Host { host } => { + config.set_host(host.to_string()); + config.save()?; + } + SetCmd::MagicLinkRedirectUri { redirect } => { + config.set_mlink_redirect(redirect.to_string()); + config.save()?; + } + SetCmd::MagicLinkSecret { secret } => { + config.set_mlink_secret(secret.to_string()); + config.save()?; + } + } + + Ok(()) + } +} diff --git a/v-cli-sdk/src/cmd/mod.rs b/v-cli-sdk/src/cmd/mod.rs new file mode 100644 index 00000000..df67c4a3 --- /dev/null +++ b/v-cli-sdk/src/cmd/mod.rs @@ -0,0 +1,6 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +pub mod auth; +pub mod config; diff --git a/v-cli-sdk/src/err.rs b/v-cli-sdk/src/err.rs new file mode 100644 index 00000000..e84eeb36 --- /dev/null +++ b/v-cli-sdk/src/err.rs @@ -0,0 +1,71 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +use anyhow::{anyhow, Error}; +use progenitor_client::Error as ProgenitorClientError; + +use crate::{ApiErrorMessage, CliContext, VerbosityLevel}; + +pub fn format_api_err(ctx: &T, client_err: ProgenitorClientError) -> Error +where + T: CliContext, + E: ApiErrorMessage, +{ + let mut err = anyhow!("API Request failed"); + + match client_err { + ProgenitorClientError::CommunicationError(inner) => { + if ctx.verbosity() >= VerbosityLevel::All { + err = err.context("Communication Error").context(inner); + } + } + ProgenitorClientError::ErrorResponse(response) => { + if ctx.verbosity() >= VerbosityLevel::All { + err = err.context(format!("Status: {}", response.status())); + err = err.context(format!("Headers {:?}", response.headers())); + } + + let response_message = response.into_inner(); + + if ctx.verbosity() >= VerbosityLevel::All { + err = err.context(format!( + "Request {}", + response_message.request_id().unwrap_or("") + )); + } + + err = err.context(format!( + "Code: {}", + response_message.error_code().unwrap_or("") + )); + err = err.context(response_message.message().unwrap_or("").to_string()); + } + ProgenitorClientError::InvalidRequest(message) => { + err = err.context("Invalid request").context(message); + } + ProgenitorClientError::InvalidResponsePayload(_, inner) => { + err = err.context("Invalid response").context(inner); + } + ProgenitorClientError::UnexpectedResponse(response) => { + err = err + .context("Unexpected response") + .context(format!("Status: {}", response.status())); + + if ctx.verbosity() >= VerbosityLevel::All { + err = err.context(format!("Headers {:?}", response.headers())); + } + } + ProgenitorClientError::ResponseBodyError(inner) => { + err = err.context("Invalid response").context(inner); + } + ProgenitorClientError::InvalidUpgrade(inner) => { + err = err.context("Invalid upgrade").context(inner) + } + ProgenitorClientError::Custom(inner) => { + err = err.context("Inner progenitor error").context(inner) + } + } + + err +} diff --git a/v-cli-sdk/src/lib.rs b/v-cli-sdk/src/lib.rs new file mode 100644 index 00000000..652ab27c --- /dev/null +++ b/v-cli-sdk/src/lib.rs @@ -0,0 +1,63 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +use clap::ValueEnum; +use serde::{Deserialize, Serialize}; +use std::fmt::Display; + +use crate::cmd::{ + auth::{login::CliMagicLinkAdapter, oauth::CliOAuthAdapter}, + config::CliConfig, +}; + +pub mod cmd; +pub mod err; + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)] +pub enum VerbosityLevel { + None, + All, +} + +#[derive(Copy, Debug, PartialEq, Eq, PartialOrd, Ord, ValueEnum, Clone, Serialize, Deserialize)] +pub enum FormatStyle { + #[value(name = "json")] + Json, + #[value(name = "tab")] + Tab, +} + +impl Display for FormatStyle { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Json => write!(f, "json"), + Self::Tab => write!(f, "tab"), + } + } +} + +pub trait CliContext { + type Attempt; + type Token; + type Error; + + fn config(&self) -> &impl CliConfig; + fn config_mut(&mut self) -> &mut impl CliConfig; + fn client(&self) -> Option<&C>; + fn printer(&self) -> Option<&P>; + fn verbosity(&self) -> VerbosityLevel; + + fn oauth_adapter( + &self, + ) -> impl CliOAuthAdapter + Send + Sync + 'static; + fn mlink_adapter( + &self, + ) -> impl CliMagicLinkAdapter + Send + Sync + 'static; +} + +pub trait ApiErrorMessage { + fn message(&self) -> Option<&str>; + fn error_code(&self) -> Option<&str>; + fn request_id(&self) -> Option<&str>; +}