diff --git a/crates/forge_infra/src/auth/http/standard.rs b/crates/forge_infra/src/auth/http/standard.rs index 8df1acd14e..162b307138 100644 --- a/crates/forge_infra/src/auth/http/standard.rs +++ b/crates/forge_infra/src/auth/http/standard.rs @@ -1,14 +1,24 @@ use forge_app::OAuthHttpProvider; use forge_domain::{AuthCodeParams, OAuthConfig, OAuthTokenResponse}; -use oauth2::{ - AuthorizationCode as OAuth2AuthCode, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, Scope, -}; +use oauth2::{CsrfToken, PkceCodeChallenge, Scope}; +use serde::Serialize; use crate::auth::util::*; /// Standard RFC-compliant OAuth provider pub struct StandardHttpProvider; +#[derive(Debug, Serialize)] +struct StandardTokenRequest<'a> { + grant_type: &'static str, + code: &'a str, + client_id: &'a str, + #[serde(skip_serializing_if = "Option::is_none")] + redirect_uri: Option<&'a str>, + #[serde(skip_serializing_if = "Option::is_none")] + code_verifier: Option<&'a str>, +} + #[async_trait::async_trait] impl OAuthHttpProvider for StandardHttpProvider { async fn build_auth_url(&self, config: &OAuthConfig) -> anyhow::Result { @@ -58,27 +68,33 @@ impl OAuthHttpProvider for StandardHttpProvider { code: &str, verifier: Option<&str>, ) -> anyhow::Result { - use oauth2::{AuthUrl, ClientId, TokenUrl}; - - let mut client = - oauth2::basic::BasicClient::new(ClientId::new(config.client_id.to_string())) - .set_auth_uri(AuthUrl::new(config.auth_url.to_string())?) - .set_token_uri(TokenUrl::new(config.token_url.to_string())?); - - if let Some(redirect_uri) = &config.redirect_uri { - client = client.set_redirect_uri(oauth2::RedirectUrl::new(redirect_uri.clone())?); - } - let http_client = self.build_http_client(config)?; + let request_body = StandardTokenRequest { + grant_type: "authorization_code", + code, + client_id: config.client_id.as_ref(), + redirect_uri: config.redirect_uri.as_deref(), + code_verifier: verifier, + }; + + let response = http_client + .post(config.token_url.as_str()) + .header("Content-Type", "application/x-www-form-urlencoded") + .header("Accept", "application/json") + .body(serde_urlencoded::to_string(&request_body)?) + .send() + .await?; - let mut request = client.exchange_code(OAuth2AuthCode::new(code.to_string())); + let status = response.status(); + let body = response.text().await?; - if let Some(v) = verifier { - request = request.set_pkce_verifier(PkceCodeVerifier::new(v.to_string())); + if !status.is_success() { + anyhow::bail!("OAuth token exchange failed ({status}): {body}"); } - let token_result = request.request_async(&http_client).await?; - Ok(into_domain(token_result)) + // Parse the raw token payload so provider-specific fields like + // `id_token` are preserved instead of being dropped by generic helpers. + Ok(parse_token_response(&body)?) } /// Create HTTP client with provider-specific headers/behavior diff --git a/crates/forge_infra/src/auth/strategy.rs b/crates/forge_infra/src/auth/strategy.rs index 2e223c98a4..0d64ae160f 100644 --- a/crates/forge_infra/src/auth/strategy.rs +++ b/crates/forge_infra/src/auth/strategy.rs @@ -772,23 +772,12 @@ async fn poll_for_tokens( } // No error field - parse as success - let (access_token, refresh_token, expires_in) = parse_token_response(&body_text)?; - - return Ok(build_token_response( - access_token, - refresh_token, - expires_in, - )); + return Ok(parse_token_response(&body_text)?); } // Standard OAuth: HTTP success means tokens if !github_compatible && status.is_success() { - let (access_token, refresh_token, expires_in) = parse_token_response(&body_text)?; - return Ok(build_token_response( - access_token, - refresh_token, - expires_in, - )); + return Ok(parse_token_response(&body_text)?); } // Handle error responses (non-200 status for standard OAuth) @@ -911,16 +900,11 @@ async fn codex_poll_for_tokens( .into()); } - let (access_token, refresh_token, expires_in) = - parse_token_response(&token_response.text().await.map_err(|e| { + return Ok(parse_token_response( + &token_response.text().await.map_err(|e| { AuthError::PollFailed(format!("Failed to read token response: {e}")) - })?)?; - - return Ok(build_token_response( - access_token, - refresh_token, - expires_in, - )); + })?, + )?); } // 403/404 means authorization pending (user hasn't entered code yet) @@ -1323,6 +1307,49 @@ mod tests { assert_eq!(actual, None); } + #[test] + fn test_enrich_codex_oauth_credential_uses_id_token_claims() { + let fixture_id_token = build_jwt(&serde_json::json!({ + "chatgpt_account_id": "acct_from_id_token" + })); + let fixture_access_token = "not-a-jwt"; + let mut actual = AuthCredential::new_oauth( + ProviderId::CODEX, + OAuthTokens::new( + fixture_access_token, + None::, + chrono::Utc::now() + chrono::Duration::hours(1), + ), + OAuthConfig { + client_id: "test".to_string().into(), + auth_url: Url::parse("https://example.com/auth").unwrap(), + token_url: Url::parse("https://example.com/token").unwrap(), + scopes: vec![], + redirect_uri: Some("http://localhost:1455/auth/callback".to_string()), + use_pkce: true, + token_refresh_url: None, + extra_auth_params: None, + custom_headers: None, + }, + ); + + enrich_codex_oauth_credential( + &ProviderId::CODEX, + &mut actual, + Some(&fixture_id_token), + fixture_access_token, + ); + + let actual = actual + .url_params + .get(&URLParam::from("chatgpt_account_id".to_string())); + let expected = Some(&forge_domain::URLParamValue::from( + "acct_from_id_token".to_string(), + )); + + assert_eq!(actual, expected); + } + #[tokio::test] async fn test_refresh_oauth_credential_preserves_url_params() { let fixture_config = OAuthConfig { diff --git a/crates/forge_infra/src/auth/util.rs b/crates/forge_infra/src/auth/util.rs index 90f8bf1d71..a3890fc6b0 100644 --- a/crates/forge_infra/src/auth/util.rs +++ b/crates/forge_infra/src/auth/util.rs @@ -86,23 +86,6 @@ pub(crate) fn build_oauth_credential( )) } -/// Build OAuthTokenResponse with standard defaults -pub(crate) fn build_token_response( - access_token: String, - refresh_token: Option, - expires_in: Option, -) -> OAuthTokenResponse { - OAuthTokenResponse { - access_token, - refresh_token, - expires_in, - expires_at: None, - token_type: "Bearer".to_string(), - scope: None, - id_token: None, - } -} - /// Extract OAuth tokens from any credential type pub(crate) fn extract_oauth_tokens(credential: &AuthCredential) -> anyhow::Result<&OAuthTokens> { match &credential.auth_details { @@ -217,25 +200,18 @@ pub(crate) fn handle_oauth_error(error_code: &str) -> Result<(), Error> { } } -/// Parse token response from JSON -pub(crate) fn parse_token_response( - body: &str, -) -> Result<(String, Option, Option), Error> { - let token_response: serde_json::Value = serde_json::from_str(body) +/// Parse token response from JSON. +pub(crate) fn parse_token_response(body: &str) -> Result { + let token_response: OAuthTokenResponse = serde_json::from_str(body) .map_err(|e| Error::PollFailed(format!("Failed to parse token response: {e}")))?; - let access_token = token_response["access_token"] - .as_str() - .ok_or_else(|| Error::PollFailed("Missing access_token in response".to_string()))? - .to_string(); - - let refresh_token = token_response["refresh_token"] - .as_str() - .map(|s| s.to_string()); - - let expires_in = token_response["expires_in"].as_u64(); + if token_response.access_token.trim().is_empty() { + return Err(Error::PollFailed( + "Missing access_token in response".to_string(), + )); + } - Ok((access_token, refresh_token, expires_in)) + Ok(token_response) } #[cfg(test)] @@ -265,17 +241,20 @@ mod tests { } #[test] - fn test_build_token_response() { - let response = build_token_response( - "test_token".to_string(), - Some("refresh_token".to_string()), - Some(3600), - ); - - assert_eq!(response.access_token, "test_token"); - assert_eq!(response.refresh_token, Some("refresh_token".to_string())); - assert_eq!(response.expires_in, Some(3600)); - assert_eq!(response.token_type, "Bearer"); + fn test_parse_token_response_preserves_id_token() { + let fixture = r#"{ + "access_token": "test_token", + "refresh_token": "refresh_token", + "expires_in": 3600, + "id_token": "test_id_token" + }"#; + + let actual = parse_token_response(fixture).unwrap(); + + assert_eq!(actual.access_token, "test_token"); + assert_eq!(actual.refresh_token, Some("refresh_token".to_string())); + assert_eq!(actual.expires_in, Some(3600)); + assert_eq!(actual.id_token, Some("test_id_token".to_string())); } #[test]