Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 37 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ members = [
"v-api-installer",
"v-api-param",
"v-api-permission-derive",
"v-cli-sdk",
"v-model",
"xtask"
]
Expand All @@ -29,6 +30,7 @@ hex = "0.4.3"
http = "1"
http-body-util = "0.1.3"
hyper = "1.9.0"
hyper-util = "0.1"
jsonwebtoken = { version = "10.2", features = ["rust_crypto"] }
mockall = "0.14.0"
newtype-uuid = { version = "1.3.2", features = ["schemars08", "serde", "v4"] }
Expand All @@ -37,6 +39,7 @@ oauth2-reqwest = "0.1.0-alpha.3"
partial-struct = { git = "https://github.com/oxidecomputer/partial-struct" }
percent-encoding = "2.3.2"
proc-macro2 = "1"
progenitor-client = "0.14.0"
quote = "1"
rand = "0.10.1"
rand_core = "0.10.1"
Expand Down
1 change: 1 addition & 0 deletions v-api/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ pub struct SendGridConfig {
pub struct OAuthProviders {
pub github: Option<OAuthConfig>,
pub google: Option<OAuthConfig>,
pub zendesk: Option<OAuthConfig>,
}

#[derive(Debug, Deserialize)]
Expand Down
3 changes: 2 additions & 1 deletion v-api/src/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1259,7 +1259,7 @@ pub(crate) mod test_mocks {
pub async fn mock_context(storage: Arc<MockStorage>) -> VContext<VPermission> {
let MockKey { signer, verifier } = mock_key("test");
let mut ctx = VContextBuilder::<VPermission>::new()
.with_public_url("".to_string())
.with_public_url("https://test_public_url".to_string())
.with_storage(storage)
.with_jwt_expiration(JwtConfig::default().default_expiration)
.with_keys(vec![signer, verifier])
Expand All @@ -1278,6 +1278,7 @@ pub(crate) mod test_mocks {
OAuthProviderName::Google,
Box::new(move || {
Box::new(GoogleOAuthProvider::new(
"https://test_public_url".to_string(),
"google_device_client_id".to_string(),
"google_device_client_secret".to_string().into(),
"google_web_client_id".to_string(),
Expand Down
5 changes: 3 additions & 2 deletions v-api/src/endpoints/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ mod macros {
code::{
authz_code_callback_op, authz_code_exchange_op, authz_code_redirect_op,
OAuthAuthzCodeExchangeBody, OAuthAuthzCodeExchangeResponse,
OAuthAuthzCodeQuery, OAuthAuthzCodeReturnQuery,
OAuthAuthzCodeQuery, OAuthAuthzCodeReturnQuery, OAuthAuthzCodeExchangeQuery
},
device_token::{
exchange_device_token_op, get_device_provider_op, AccessTokenExchangeRequest,
Expand Down Expand Up @@ -296,9 +296,10 @@ mod macros {
pub async fn authz_code_exchange(
rqctx: RequestContext<$context_type>,
path: Path<OAuthProviderNameParam>,
query: Query<OAuthAuthzCodeExchangeQuery>,
body: TypedBody<OAuthAuthzCodeExchangeBody>,
) -> Result<HttpResponseOk<OAuthAuthzCodeExchangeResponse>, HttpError> {
authz_code_exchange_op(&rqctx, path, body).await
authz_code_exchange_op(&rqctx, path, query, body).await
}

// DEVICE CODE
Expand Down
12 changes: 12 additions & 0 deletions v-api/src/endpoints/login/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ impl From<LoginError> for HttpError {
pub enum ExternalUserId {
GitHub(String),
Google(String),
Zendesk(String),
#[cfg(feature = "local-dev")]
Local(String),
MagicLink(String),
Expand All @@ -70,6 +71,7 @@ impl ExternalUserId {
match self {
Self::GitHub(id) => id,
Self::Google(id) => id,
Self::Zendesk(id) => id,
#[cfg(feature = "local-dev")]
Self::Local(id) => id,
Self::MagicLink(id) => id,
Expand All @@ -80,6 +82,7 @@ impl ExternalUserId {
match self {
Self::GitHub(_) => "github",
Self::Google(_) => "google",
Self::Zendesk(_) => "zendesk",
#[cfg(feature = "local-dev")]
Self::Local(_) => "local",
Self::MagicLink(_) => "magic-link",
Expand All @@ -103,6 +106,7 @@ impl Serialize for ExternalUserId {
match self {
ExternalUserId::GitHub(id) => serializer.serialize_str(&format!("github-{}", id)),
ExternalUserId::Google(id) => serializer.serialize_str(&format!("google-{}", id)),
ExternalUserId::Zendesk(id) => serializer.serialize_str(&format!("zendesk-{}", id)),
#[cfg(feature = "local-dev")]
ExternalUserId::Local(id) => serializer.serialize_str(&format!("local-{}", id)),
ExternalUserId::MagicLink(id) => {
Expand Down Expand Up @@ -142,6 +146,12 @@ impl<'de> Deserialize<'de> for ExternalUserId {
} else {
Err(de::Error::custom(ExternalUserIdDeserializeError::Empty))
}
} else if let Some(("", id)) = value.split_once("zendesk-") {
if !id.is_empty() {
Ok(ExternalUserId::Zendesk(id.to_string()))
} else {
Err(de::Error::custom(ExternalUserIdDeserializeError::Empty))
}
} else if let Some(("", id)) = value.split_once("local-") {
#[cfg(feature = "local-dev")]
{
Expand Down Expand Up @@ -191,6 +201,8 @@ pub enum UserInfoError {
Deserialize(#[from] serde_json::Error),
#[error("Failed to create user info request {0}")]
Http(#[from] http::Error),
#[error("User account is locked")]
Locked,
#[error("User information is missing")]
MissingUserInfoData(String),
}
Expand Down
56 changes: 38 additions & 18 deletions v-api/src/endpoints/login/oauth/code.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use v_model::{
LoginAttempt, LoginAttemptId, NewLoginAttempt, OAuthClient, OAuthClientId,
};

use super::{OAuthProvider, OAuthProviderNameParam, UserInfoProvider, WebClientConfig};
use super::{OAuthProvider, OAuthProviderNameParam, UserInfoProvider};
use crate::{
authn::key::RawKey,
context::{ApiContext, VContext},
Expand Down Expand Up @@ -239,11 +239,7 @@ fn oauth_redirect_response(
// TODO: This behavior should be changed so that clients are precomputed. We do not need to be
// constructing a new client on every request. That said, we need to ensure the client does not
// maintain state between requests
let client = provider
.as_web_client(&WebClientConfig {
prefix: public_url.to_string(),
})
.map_err(to_internal_error)?;
let client = provider.as_web_client().map_err(to_internal_error)?;

// Create an attempt cookie header for storing the login attempt. This also acts as our csrf
// check
Expand Down Expand Up @@ -440,6 +436,12 @@ where
Ok(attempt.callback_url())
}

#[derive(Debug, Deserialize, JsonSchema)]
pub struct OAuthAuthzCodeExchangeQuery {
#[serde(default)]
pub include_idp_token: bool,
}

#[derive(Debug, Deserialize, JsonSchema)]
pub struct OAuthAuthzCodeExchangeBody {
pub client_id: Option<TypedUuid<OAuthClientId>>,
Expand All @@ -455,19 +457,27 @@ pub struct OAuthAuthzCodeExchangeResponse {
pub access_token: String,
pub token_type: String,
pub expires_in: i64,
pub idp_token: Option<OAuthAuthzCodeIdpToken>,
}

#[derive(Debug, Deserialize, JsonSchema, Serialize)]
pub struct OAuthAuthzCodeIdpToken {
pub token: String,
}

#[instrument(skip(rqctx), err(Debug))]
pub async fn authz_code_exchange_op<T>(
rqctx: &RequestContext<impl ApiContext<AppPermissions = T>>,
path: Path<OAuthProviderNameParam>,
query: Query<OAuthAuthzCodeExchangeQuery>,
body: TypedBody<OAuthAuthzCodeExchangeBody>,
) -> Result<HttpResponseOk<OAuthAuthzCodeExchangeResponse>, HttpError>
where
T: VAppPermission + PermissionStorage,
{
let ctx = rqctx.v_ctx();
let path = path.into_inner();
let query = query.into_inner();
let body = body.into_inner();

let (client_id, client_secret) =
Expand Down Expand Up @@ -541,7 +551,14 @@ where

// Now that the attempt has been confirmed, use it to fetch user information form the remote
// provider
let info = fetch_user_info(ctx.public_url(), &ctx.web_client(), &*provider, &attempt).await?;
let (info, raw_token) = fetch_user_info(
ctx.public_url(),
&ctx.web_client(),
&*provider,
&attempt,
query.include_idp_token,
)
.await?;

tracing::debug!("Retrieved user information from remote provider");

Expand Down Expand Up @@ -589,6 +606,9 @@ where
token_type: "Bearer".to_string(),
access_token: token.signed_token,
expires_in: token.expires_in,
idp_token: query.include_idp_token.then(|| OAuthAuthzCodeIdpToken {
token: raw_token.unwrap(),
}),
}))
}

Expand Down Expand Up @@ -713,13 +733,10 @@ async fn fetch_user_info(
client_type: &ClientType,
provider: &dyn OAuthProvider,
attempt: &LoginAttempt,
) -> Result<UserInfo, HttpError> {
return_raw: bool,
) -> Result<(UserInfo, Option<String>), HttpError> {
// Exchange the stored authorization code with the remote provider for a remote access token
let client = provider
.as_web_client(&WebClientConfig {
prefix: public_url.to_string(),
})
.map_err(to_internal_error)?;
let client = provider.as_web_client().map_err(to_internal_error)?;

let mut request = client.exchange_code(AuthorizationCode::new(
attempt
Expand Down Expand Up @@ -754,7 +771,7 @@ async fn fetch_user_info(

// Now that we are done with fetching user information from the remote API, we can revoke it if
// the provider supports it
if provider.token_revocation_endpoint().is_some() {
if !return_raw && provider.token_revocation_endpoint().is_some() {
client
.revoke_token(response.access_token().into())
.map_err(internal_error)?
Expand All @@ -763,7 +780,11 @@ async fn fetch_user_info(
.map_err(internal_error)?;
}

Ok(info)
if return_raw {
Ok((info, Some(response.access_token().secret().to_string())))
} else {
Ok((info, None))
}
}

#[cfg(test)]
Expand Down Expand Up @@ -889,8 +910,7 @@ mod tests {
#[tokio::test]
async fn test_remote_provider_redirect_url() {
let storage = MockStorage::new();
let mut ctx = mock_context(Arc::new(storage)).await;
ctx.with_public_url("https://api.oxeng.dev");
let ctx = mock_context(Arc::new(storage)).await;

let (challenge, _) = PkceCodeChallenge::new_random_sha256();
let attempt = LoginAttempt {
Expand Down Expand Up @@ -927,7 +947,7 @@ mod tests {
.unwrap();
let headers = response.headers();

let expected_location = format!("https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=google_web_client_id&state={}&code_challenge={}&code_challenge_method=S256&redirect_uri=https%3A%2F%2Fapi.oxeng.dev%2Flogin%2Foauth%2Fgoogle%2Fcode%2Fcallback&scope=openid+email+profile", attempt.id, challenge.as_str());
let expected_location = format!("https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=google_web_client_id&state={}&code_challenge={}&code_challenge_method=S256&redirect_uri=https%3A%2F%2Ftest_public_url%2Flogin%2Foauth%2Fgoogle%2Fcode%2Fcallback&scope=openid+email+profile", attempt.id, challenge.as_str());

assert_eq!(
expected_location,
Expand Down
23 changes: 19 additions & 4 deletions v-api/src/endpoints/login/oauth/device_token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,7 @@ where
.await
.map_err(ApiError::OAuth)?;

Ok(HttpResponseOk(provider.provider_info(
rqctx.v_ctx().public_url(),
&ClientType::Device,
)))
Ok(HttpResponseOk(provider.provider_info(&ClientType::Device)))
}

#[derive(Debug, Deserialize, JsonSchema, Serialize)]
Expand Down Expand Up @@ -124,6 +121,24 @@ where

tracing::debug!(provider = ?provider.name(), "Acquired OAuth provider for token exchange");

if provider.device_code_endpoint().is_none() {
return Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.header(header::CONTENT_TYPE, "application/json")
.body(
serde_json::to_vec(&ProxyTokenError {
error: "unsupported_grant_type".to_string(),
error_description: Some(format!(
"{} does not support device code flow",
path.provider
)),
error_uri: None,
})
.unwrap()
.into(),
)?);
}

let exchange_request = body.into_inner();

if let Some(exchange) = AccessTokenExchange::new(exchange_request, &*provider) {
Expand Down
Loading
Loading