Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 35 additions & 19 deletions crates/forge_infra/src/auth/http/standard.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,24 @@
use forge_app::OAuthHttpProvider;
use forge_domain::{AuthCodeParams, OAuthConfig, OAuthTokenResponse};
use oauth2::{
AuthorizationCode as OAuth2AuthCode, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, Scope,
};
use oauth2::{CsrfToken, PkceCodeChallenge, Scope};
use serde::Serialize;

use crate::auth::util::*;

/// Standard RFC-compliant OAuth provider
pub struct StandardHttpProvider;

#[derive(Debug, Serialize)]
struct StandardTokenRequest<'a> {
grant_type: &'static str,
code: &'a str,
client_id: &'a str,
#[serde(skip_serializing_if = "Option::is_none")]
redirect_uri: Option<&'a str>,
#[serde(skip_serializing_if = "Option::is_none")]
code_verifier: Option<&'a str>,
}

#[async_trait::async_trait]
impl OAuthHttpProvider for StandardHttpProvider {
async fn build_auth_url(&self, config: &OAuthConfig) -> anyhow::Result<AuthCodeParams> {
Expand Down Expand Up @@ -58,27 +68,33 @@ impl OAuthHttpProvider for StandardHttpProvider {
code: &str,
verifier: Option<&str>,
) -> anyhow::Result<OAuthTokenResponse> {
use oauth2::{AuthUrl, ClientId, TokenUrl};

let mut client =
oauth2::basic::BasicClient::new(ClientId::new(config.client_id.to_string()))
.set_auth_uri(AuthUrl::new(config.auth_url.to_string())?)
.set_token_uri(TokenUrl::new(config.token_url.to_string())?);

if let Some(redirect_uri) = &config.redirect_uri {
client = client.set_redirect_uri(oauth2::RedirectUrl::new(redirect_uri.clone())?);
}

let http_client = self.build_http_client(config)?;
let request_body = StandardTokenRequest {
grant_type: "authorization_code",
code,
client_id: config.client_id.as_ref(),
redirect_uri: config.redirect_uri.as_deref(),
code_verifier: verifier,
};

let response = http_client
.post(config.token_url.as_str())
.header("Content-Type", "application/x-www-form-urlencoded")
.header("Accept", "application/json")
.body(serde_urlencoded::to_string(&request_body)?)
.send()
.await?;

let mut request = client.exchange_code(OAuth2AuthCode::new(code.to_string()));
let status = response.status();
let body = response.text().await?;

if let Some(v) = verifier {
request = request.set_pkce_verifier(PkceCodeVerifier::new(v.to_string()));
if !status.is_success() {
anyhow::bail!("OAuth token exchange failed ({status}): {body}");
}

let token_result = request.request_async(&http_client).await?;
Ok(into_domain(token_result))
// Parse the raw token payload so provider-specific fields like
// `id_token` are preserved instead of being dropped by generic helpers.
Ok(parse_token_response(&body)?)
}

/// Create HTTP client with provider-specific headers/behavior
Expand Down
71 changes: 49 additions & 22 deletions crates/forge_infra/src/auth/strategy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -772,23 +772,12 @@ async fn poll_for_tokens(
}

// No error field - parse as success
let (access_token, refresh_token, expires_in) = parse_token_response(&body_text)?;

return Ok(build_token_response(
access_token,
refresh_token,
expires_in,
));
return Ok(parse_token_response(&body_text)?);
}

// Standard OAuth: HTTP success means tokens
if !github_compatible && status.is_success() {
let (access_token, refresh_token, expires_in) = parse_token_response(&body_text)?;
return Ok(build_token_response(
access_token,
refresh_token,
expires_in,
));
return Ok(parse_token_response(&body_text)?);
}

// Handle error responses (non-200 status for standard OAuth)
Expand Down Expand Up @@ -911,16 +900,11 @@ async fn codex_poll_for_tokens(
.into());
}

let (access_token, refresh_token, expires_in) =
parse_token_response(&token_response.text().await.map_err(|e| {
return Ok(parse_token_response(
&token_response.text().await.map_err(|e| {
AuthError::PollFailed(format!("Failed to read token response: {e}"))
})?)?;

return Ok(build_token_response(
access_token,
refresh_token,
expires_in,
));
})?,
)?);
}

// 403/404 means authorization pending (user hasn't entered code yet)
Expand Down Expand Up @@ -1323,6 +1307,49 @@ mod tests {
assert_eq!(actual, None);
}

#[test]
fn test_enrich_codex_oauth_credential_uses_id_token_claims() {
let fixture_id_token = build_jwt(&serde_json::json!({
"chatgpt_account_id": "acct_from_id_token"
}));
let fixture_access_token = "not-a-jwt";
let mut actual = AuthCredential::new_oauth(
ProviderId::CODEX,
OAuthTokens::new(
fixture_access_token,
None::<String>,
chrono::Utc::now() + chrono::Duration::hours(1),
),
OAuthConfig {
client_id: "test".to_string().into(),
auth_url: Url::parse("https://example.com/auth").unwrap(),
token_url: Url::parse("https://example.com/token").unwrap(),
scopes: vec![],
redirect_uri: Some("http://localhost:1455/auth/callback".to_string()),
use_pkce: true,
token_refresh_url: None,
extra_auth_params: None,
custom_headers: None,
},
);

enrich_codex_oauth_credential(
&ProviderId::CODEX,
&mut actual,
Some(&fixture_id_token),
fixture_access_token,
);

let actual = actual
.url_params
.get(&URLParam::from("chatgpt_account_id".to_string()));
let expected = Some(&forge_domain::URLParamValue::from(
"acct_from_id_token".to_string(),
));

assert_eq!(actual, expected);
}

#[tokio::test]
async fn test_refresh_oauth_credential_preserves_url_params() {
let fixture_config = OAuthConfig {
Expand Down
67 changes: 23 additions & 44 deletions crates/forge_infra/src/auth/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,23 +86,6 @@ pub(crate) fn build_oauth_credential(
))
}

/// Build OAuthTokenResponse with standard defaults
pub(crate) fn build_token_response(
access_token: String,
refresh_token: Option<String>,
expires_in: Option<u64>,
) -> OAuthTokenResponse {
OAuthTokenResponse {
access_token,
refresh_token,
expires_in,
expires_at: None,
token_type: "Bearer".to_string(),
scope: None,
id_token: None,
}
}

/// Extract OAuth tokens from any credential type
pub(crate) fn extract_oauth_tokens(credential: &AuthCredential) -> anyhow::Result<&OAuthTokens> {
match &credential.auth_details {
Expand Down Expand Up @@ -217,25 +200,18 @@ pub(crate) fn handle_oauth_error(error_code: &str) -> Result<(), Error> {
}
}

/// Parse token response from JSON
pub(crate) fn parse_token_response(
body: &str,
) -> Result<(String, Option<String>, Option<u64>), Error> {
let token_response: serde_json::Value = serde_json::from_str(body)
/// Parse token response from JSON.
pub(crate) fn parse_token_response(body: &str) -> Result<OAuthTokenResponse, Error> {
let token_response: OAuthTokenResponse = serde_json::from_str(body)
.map_err(|e| Error::PollFailed(format!("Failed to parse token response: {e}")))?;

let access_token = token_response["access_token"]
.as_str()
.ok_or_else(|| Error::PollFailed("Missing access_token in response".to_string()))?
.to_string();

let refresh_token = token_response["refresh_token"]
.as_str()
.map(|s| s.to_string());

let expires_in = token_response["expires_in"].as_u64();
if token_response.access_token.trim().is_empty() {
return Err(Error::PollFailed(
"Missing access_token in response".to_string(),
));
}

Ok((access_token, refresh_token, expires_in))
Ok(token_response)
}

#[cfg(test)]
Expand Down Expand Up @@ -265,17 +241,20 @@ mod tests {
}

#[test]
fn test_build_token_response() {
let response = build_token_response(
"test_token".to_string(),
Some("refresh_token".to_string()),
Some(3600),
);

assert_eq!(response.access_token, "test_token");
assert_eq!(response.refresh_token, Some("refresh_token".to_string()));
assert_eq!(response.expires_in, Some(3600));
assert_eq!(response.token_type, "Bearer");
fn test_parse_token_response_preserves_id_token() {
let fixture = r#"{
"access_token": "test_token",
"refresh_token": "refresh_token",
"expires_in": 3600,
"id_token": "test_id_token"
}"#;

let actual = parse_token_response(fixture).unwrap();

assert_eq!(actual.access_token, "test_token");
assert_eq!(actual.refresh_token, Some("refresh_token".to_string()));
assert_eq!(actual.expires_in, Some(3600));
assert_eq!(actual.id_token, Some("test_id_token".to_string()));
}

#[test]
Expand Down
Loading