From a5fa36e5462d0a0d772cc15b766ef1228fe6a562 Mon Sep 17 00:00:00 2001 From: everpcpc Date: Thu, 7 May 2026 15:03:07 +0800 Subject: [PATCH] feat(auth): add key-pair authentication support Implement JWT-based key-pair authentication as per the key-pair-auth RFC. - Support RSA, ECDSA (ES256), and Ed25519 private keys - Support encrypted PKCS#8 private keys with passphrase - Handle PKCS#1, SEC1, and PKCS#8 key formats - Rebuild EC PKCS#8 with named curve for Ring compatibility - Extract PKCS#1 DER from PKCS#8 for RSA Ring compatibility - Add --private-key-file and --private-key-passphrase-file CLI flags - Send X-DATABEND-AUTH-METHOD: keypair header with Bearer token --- .gitignore | 1 + cli/src/main.rs | 22 +++ core/Cargo.toml | 4 + core/src/auth.rs | 421 +++++++++++++++++++++++++++++++++++++++++++++ core/src/client.rs | 24 ++- 5 files changed, 469 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index 6f6caae78..1f535e909 100644 --- a/.gitignore +++ b/.gitignore @@ -23,3 +23,4 @@ frontend/.next/ # Frontend build artifacts now use fixed filenames, included in git # cli/frontend/build/ will be committed to ensure cargo install has full UI +.alma-snapshots diff --git a/cli/src/main.rs b/cli/src/main.rs index e1bd62b84..64361feed 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -89,6 +89,15 @@ struct Args { )] password: Option, + #[clap( + long, + help = "Private key file for key-pair authentication, overrides password in DSN" + )] + private_key_file: Option, + + #[clap(long, help = "Passphrase file for encrypted private key")] + private_key_passphrase_file: Option, + #[clap(short = 'r', long, help = "Downgrade role name, overrides role in DSN")] role: Option, @@ -292,6 +301,19 @@ pub async fn main() -> Result<()> { if let Some(role) = args.role { conn_args.args.insert("role".to_string(), role); } + + // override private key file if specified in command line + if let Some(private_key_file) = args.private_key_file { + conn_args + .args + .insert("private_key_file".to_string(), private_key_file); + } + if let Some(private_key_passphrase_file) = args.private_key_passphrase_file { + conn_args.args.insert( + "private_key_passphrase_file".to_string(), + private_key_passphrase_file, + ); + } } let user = conn_args.user.clone(); diff --git a/core/Cargo.toml b/core/Cargo.toml index 548530001..001bd0565 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -41,6 +41,10 @@ tokio-retry = "0.3" tokio-util = { version = "0.7", features = ["io-util"] } url = { version = "2.5", default-features = false } uuid = { version = "1.16", features = ["std", "v4", "v7"] } +jsonwebtoken = "9" +pem = "3" +pkcs8 = { version = "0.11", features = ["encryption", "pem"] } [dev-dependencies] chrono = { workspace = true } +tempfile = "3" diff --git a/core/src/auth.rs b/core/src/auth.rs index 81e864c18..f69dabeed 100644 --- a/core/src/auth.rs +++ b/core/src/auth.rs @@ -12,7 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::sync::Arc; +use std::time::{SystemTime, UNIX_EPOCH}; + +use jsonwebtoken::{encode, Algorithm, EncodingKey, Header}; use reqwest::RequestBuilder; +use serde::Serialize; use crate::error::{Error, Result}; @@ -104,6 +109,244 @@ impl Auth for AccessTokenFileAuth { } } +const HEADER_AUTH_METHOD: &str = "X-DATABEND-AUTH-METHOD"; +const KEYPAIR_TOKEN_TTL_SECS: u64 = 60; + +#[derive(Serialize)] +struct KeyPairClaims { + sub: String, + iat: u64, + exp: u64, +} + +#[derive(Clone)] +pub struct KeyPairAuth { + username: String, + encoding_key: Arc, + algorithm: Algorithm, +} + +impl KeyPairAuth { + pub fn new( + username: impl ToString, + private_key_file: &str, + passphrase_file: Option<&str>, + ) -> Result { + let pem_data = std::fs::read(private_key_file).map_err(|e| { + Error::IO(format!( + "cannot read private key from file {}: {}", + private_key_file, e + )) + })?; + + let passphrase = match passphrase_file { + Some(path) => { + let p = std::fs::read_to_string(path).map_err(|e| { + Error::IO(format!("cannot read passphrase from file {}: {}", path, e)) + })?; + Some(p.trim().to_string()) + } + None => None, + }; + + let (encoding_key, algorithm) = Self::parse_private_key(&pem_data, passphrase.as_deref())?; + + Ok(Self { + username: username.to_string(), + encoding_key: Arc::new(encoding_key), + algorithm, + }) + } + + fn parse_private_key( + pem_data: &[u8], + passphrase: Option<&str>, + ) -> Result<(EncodingKey, Algorithm)> { + let pem_str = std::str::from_utf8(pem_data) + .map_err(|e| Error::IO(format!("private key is not valid UTF-8: {e}")))?; + + if let Some(passphrase) = passphrase { + // Encrypted PKCS#8 key — decrypt using pkcs8 crate to get DER + Self::parse_encrypted_key(pem_str, passphrase) + } else { + // Unencrypted key — detect type and use jsonwebtoken's PEM methods + Self::parse_unencrypted_key(pem_data, pem_str) + } + } + + fn parse_encrypted_key(pem_str: &str, passphrase: &str) -> Result<(EncodingKey, Algorithm)> { + use pkcs8::DecodePrivateKey; + + let doc = pkcs8::SecretDocument::from_pkcs8_encrypted_pem(pem_str, passphrase.as_bytes()) + .map_err(|e| Error::IO(format!("failed to decrypt private key: {e}")))?; + + let der_bytes = doc.as_bytes(); + + // Try each key type with DER + // from_*_der returns EncodingKey directly (infallible for the struct construction), + // but the underlying parsing may still fail at sign time. + // We try RSA first, then EC, then Ed25519 by attempting to parse the key info. + // Since from_*_der doesn't validate, we use the OID from the PKCS#8 structure. + let private_key_info = pkcs8::PrivateKeyInfoRef::try_from(der_bytes) + .map_err(|e| Error::IO(format!("failed to parse PKCS#8 DER: {e}")))?; + + let algorithm_oid = private_key_info.algorithm.oid; + + // RSA: 1.2.840.113549.1.1.1 + const RSA_OID: pkcs8::ObjectIdentifier = + pkcs8::ObjectIdentifier::new_unwrap("1.2.840.113549.1.1.1"); + // EC: 1.2.840.10045.2.1 + const EC_OID: pkcs8::ObjectIdentifier = + pkcs8::ObjectIdentifier::new_unwrap("1.2.840.10045.2.1"); + // Ed25519: 1.3.101.112 + const ED25519_OID: pkcs8::ObjectIdentifier = + pkcs8::ObjectIdentifier::new_unwrap("1.3.101.112"); + + if algorithm_oid == RSA_OID { + // Ring's from_der expects PKCS#1 RSAPrivateKey DER, not full PKCS#8. + // Extract the inner private key bytes from PKCS#8 PrivateKeyInfo. + let rsa_der = private_key_info.private_key.as_bytes(); + Ok((EncodingKey::from_rsa_der(rsa_der), Algorithm::RS256)) + } else if algorithm_oid == EC_OID { + // Ring requires named curve PKCS#8 format for EC keys. + // If the key uses explicit parameters (common with openssl genpkey), + // we rebuild a named curve PKCS#8 from the inner SEC1 private key. + let ec_der = Self::rebuild_ec_pkcs8_named_curve(private_key_info)?; + Ok((EncodingKey::from_ec_der(&ec_der), Algorithm::ES256)) + } else if algorithm_oid == ED25519_OID { + Ok((EncodingKey::from_ed_der(der_bytes), Algorithm::EdDSA)) + } else { + Err(Error::IO(format!( + "unsupported key algorithm OID: {algorithm_oid}" + ))) + } + } + + /// Rebuild a named-curve PKCS#8 DER for EC keys. + /// Ring only accepts PKCS#8 with named curve OID (not explicit parameters). + /// This extracts the SEC1 private key from PrivateKeyInfo and wraps it + /// in a minimal named-curve PKCS#8 structure. + fn rebuild_ec_pkcs8_named_curve(pki: pkcs8::PrivateKeyInfoRef) -> Result> { + use pkcs8::der::Encode; + + // Detect curve from AlgorithmIdentifier parameters + // For EC keys, parameters contain the curve OID (named) or explicit params + const P256_OID: pkcs8::ObjectIdentifier = + pkcs8::ObjectIdentifier::new_unwrap("1.2.840.10045.3.1.7"); + + // Try to extract named curve OID from parameters + let curve_oid = if let Some(params) = pki.algorithm.parameters { + // Try decoding as OID (named curve case) + params + .decode_as::() + .unwrap_or(P256_OID) + } else { + // Default to P-256 if no parameters (shouldn't happen for EC) + P256_OID + }; + + // Build AlgorithmIdentifier with named curve OID + let alg_id = pkcs8::AlgorithmIdentifierRef { + oid: pkcs8::ObjectIdentifier::new_unwrap("1.2.840.10045.2.1"), + parameters: Some(pkcs8::der::asn1::AnyRef::from(&curve_oid)), + }; + + // Rebuild PrivateKeyInfo with the named curve AlgorithmIdentifier + let new_pki = pkcs8::PrivateKeyInfo { + algorithm: alg_id, + private_key: pki.private_key, + public_key: pki.public_key, + }; + + new_pki + .to_der() + .map_err(|e| Error::IO(format!("failed to re-encode EC PKCS#8: {e}"))) + } + + fn parse_unencrypted_key(pem_data: &[u8], pem_str: &str) -> Result<(EncodingKey, Algorithm)> { + if pem_str.contains("RSA PRIVATE KEY") { + // PKCS#1 RSA key + let key = EncodingKey::from_rsa_pem(pem_data) + .map_err(|e| Error::IO(format!("failed to parse RSA private key: {e}")))?; + return Ok((key, Algorithm::RS256)); + } + + if pem_str.contains("EC PRIVATE KEY") { + // SEC1 EC key + let key = EncodingKey::from_ec_pem(pem_data) + .map_err(|e| Error::IO(format!("failed to parse EC private key: {e}")))?; + return Ok((key, Algorithm::ES256)); + } + + // PKCS#8 "BEGIN PRIVATE KEY" — parse OID to determine key type, + // then use from_*_der with full PKCS#8 DER (from_ec_pem has issues with PKCS#8 EC keys) + let pem_parsed = + pem::parse(pem_data).map_err(|e| Error::IO(format!("failed to parse PEM: {e}")))?; + let der_bytes = pem_parsed.contents(); + + let private_key_info = pkcs8::PrivateKeyInfoRef::try_from(der_bytes) + .map_err(|e| Error::IO(format!("failed to parse PKCS#8 DER: {e}")))?; + + let algorithm_oid = private_key_info.algorithm.oid; + + const RSA_OID: pkcs8::ObjectIdentifier = + pkcs8::ObjectIdentifier::new_unwrap("1.2.840.113549.1.1.1"); + const EC_OID: pkcs8::ObjectIdentifier = + pkcs8::ObjectIdentifier::new_unwrap("1.2.840.10045.2.1"); + const ED25519_OID: pkcs8::ObjectIdentifier = + pkcs8::ObjectIdentifier::new_unwrap("1.3.101.112"); + + if algorithm_oid == RSA_OID { + // Ring's from_der expects PKCS#1 RSAPrivateKey DER, not full PKCS#8. + // Extract the inner private key bytes from PKCS#8 PrivateKeyInfo. + let rsa_der = private_key_info.private_key.as_bytes(); + Ok((EncodingKey::from_rsa_der(rsa_der), Algorithm::RS256)) + } else if algorithm_oid == EC_OID { + // Ring requires named curve PKCS#8 format for EC keys. + // If the key uses explicit parameters (common with openssl genpkey), + // we rebuild a named curve PKCS#8 from the inner SEC1 private key. + let ec_der = Self::rebuild_ec_pkcs8_named_curve(private_key_info)?; + Ok((EncodingKey::from_ec_der(&ec_der), Algorithm::ES256)) + } else if algorithm_oid == ED25519_OID { + Ok((EncodingKey::from_ed_der(der_bytes), Algorithm::EdDSA)) + } else { + Err(Error::IO(format!( + "unsupported key algorithm OID: {algorithm_oid}" + ))) + } + } + + fn generate_jwt(&self) -> Result { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_err(|e| Error::IO(format!("system time error: {e}")))? + .as_secs(); + + let claims = KeyPairClaims { + sub: self.username.clone(), + iat: now, + exp: now + KEYPAIR_TOKEN_TTL_SECS, + }; + + let header = Header::new(self.algorithm); + encode(&header, &claims, &self.encoding_key) + .map_err(|e| Error::IO(format!("failed to sign JWT: {e}"))) + } +} + +impl Auth for KeyPairAuth { + fn wrap(&self, builder: RequestBuilder) -> Result { + let token = self.generate_jwt()?; + Ok(builder + .bearer_auth(token) + .header(HEADER_AUTH_METHOD, "keypair")) + } + + fn username(&self) -> String { + self.username.clone() + } +} + #[derive(::serde::Deserialize, ::serde::Serialize)] #[serde(from = "String", into = "String")] #[derive(Clone, Default, PartialEq, Eq)] @@ -167,4 +410,182 @@ mod tests { let debug = format!("{value:?}"); assert_eq!(debug, "\"**REDACTED**\""); } + + #[test] + fn keypair_auth_rsa() { + use std::io::Write; + use tempfile::NamedTempFile; + + // Generate a test RSA private key in PKCS#8 format using openssl command + // Note: `openssl genrsa` outputs PKCS#1 which may have compatibility issues + // with some versions of ring. Using genpkey ensures PKCS#8 format. + let output = std::process::Command::new("openssl") + .args([ + "genpkey", + "-algorithm", + "RSA", + "-pkeyopt", + "rsa_keygen_bits:2048", + ]) + .output(); + let output = match output { + Ok(o) if o.status.success() => o, + _ => { + // Skip test if openssl is not available + return; + } + }; + + let mut key_file = NamedTempFile::new().unwrap(); + key_file.write_all(&output.stdout).unwrap(); + + let auth = KeyPairAuth::new("testuser", key_file.path().to_str().unwrap(), None).unwrap(); + assert_eq!(auth.username(), "testuser"); + assert_eq!(auth.algorithm, Algorithm::RS256); + + // Verify JWT can be generated + let token = auth.generate_jwt().unwrap(); + assert!(!token.is_empty()); + + // Verify JWT structure (header.payload.signature) + let parts: Vec<&str> = token.split('.').collect(); + assert_eq!(parts.len(), 3); + } + + #[test] + fn keypair_auth_rsa_pkcs1() { + use std::io::Write; + use tempfile::NamedTempFile; + + // Generate a PKCS#1 RSA key (BEGIN RSA PRIVATE KEY) + let output = std::process::Command::new("openssl") + .args(["genrsa", "2048"]) + .output(); + let output = match output { + Ok(o) if o.status.success() => o, + _ => return, + }; + + let pem_str = String::from_utf8_lossy(&output.stdout); + if !pem_str.contains("RSA PRIVATE KEY") { + // Skip if openssl outputs PKCS#8 format instead + return; + } + + let mut key_file = NamedTempFile::new().unwrap(); + key_file.write_all(&output.stdout).unwrap(); + + let auth = KeyPairAuth::new("testuser", key_file.path().to_str().unwrap(), None).unwrap(); + assert_eq!(auth.algorithm, Algorithm::RS256); + + let token = auth.generate_jwt().unwrap(); + let parts: Vec<&str> = token.split('.').collect(); + assert_eq!(parts.len(), 3); + } + + #[test] + fn keypair_auth_ec() { + use std::io::Write; + use tempfile::NamedTempFile; + + // Generate a test EC private key in PKCS#8 format + let output = std::process::Command::new("openssl") + .args([ + "genpkey", + "-algorithm", + "EC", + "-pkeyopt", + "ec_paramgen_curve:P-256", + ]) + .output(); + let output = match output { + Ok(o) if o.status.success() => o, + _ => return, + }; + + let mut key_file = NamedTempFile::new().unwrap(); + key_file.write_all(&output.stdout).unwrap(); + + let auth = KeyPairAuth::new("testuser", key_file.path().to_str().unwrap(), None).unwrap(); + assert_eq!(auth.algorithm, Algorithm::ES256); + + let token = auth.generate_jwt().unwrap(); + let parts: Vec<&str> = token.split('.').collect(); + assert_eq!(parts.len(), 3); + } + + #[test] + fn keypair_auth_ed25519() { + use std::io::Write; + use tempfile::NamedTempFile; + + // Generate a test Ed25519 private key + let output = std::process::Command::new("openssl") + .args(["genpkey", "-algorithm", "ed25519"]) + .output(); + let output = match output { + Ok(o) if o.status.success() => o, + _ => return, + }; + + let mut key_file = NamedTempFile::new().unwrap(); + key_file.write_all(&output.stdout).unwrap(); + + let auth = KeyPairAuth::new("testuser", key_file.path().to_str().unwrap(), None).unwrap(); + assert_eq!(auth.algorithm, Algorithm::EdDSA); + + let token = auth.generate_jwt().unwrap(); + let parts: Vec<&str> = token.split('.').collect(); + assert_eq!(parts.len(), 3); + } + + #[test] + fn keypair_auth_encrypted_key() { + use std::io::Write; + use tempfile::NamedTempFile; + + // Generate an encrypted RSA private key with scrypt KDF (supported by pkcs8 crate) + let output = std::process::Command::new("openssl") + .args([ + "genpkey", + "-algorithm", + "RSA", + "-pkeyopt", + "rsa_keygen_bits:2048", + "-aes-256-cbc", + "-pass", + "pass:testpass", + "-v2prf", + "hmacWithSHA256", + ]) + .output(); + let output = match output { + Ok(o) if o.status.success() => o, + _ => return, + }; + + // Check if the generated key is actually encrypted + let pem_str = String::from_utf8_lossy(&output.stdout); + if !pem_str.contains("ENCRYPTED") { + return; + } + + let mut key_file = NamedTempFile::new().unwrap(); + key_file.write_all(&output.stdout).unwrap(); + + let mut pass_file = NamedTempFile::new().unwrap(); + pass_file.write_all(b"testpass\n").unwrap(); + + let auth = KeyPairAuth::new( + "testuser", + key_file.path().to_str().unwrap(), + Some(pass_file.path().to_str().unwrap()), + ) + .unwrap(); + assert_eq!(auth.algorithm, Algorithm::RS256); + + let token = auth.generate_jwt().unwrap(); + let parts: Vec<&str> = token.split('.').collect(); + assert_eq!(parts.len(), 3); + } } diff --git a/core/src/client.rs b/core/src/client.rs index 4c8924d63..f1c8d7431 100644 --- a/core/src/client.rs +++ b/core/src/client.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::auth::{AccessTokenAuth, AccessTokenFileAuth, Auth, BasicAuth}; +use crate::auth::{AccessTokenAuth, AccessTokenFileAuth, Auth, BasicAuth, KeyPairAuth}; use crate::capability::Capability; use crate::client_mgr::{GLOBAL_CLIENT_MANAGER, GLOBAL_RUNTIME}; use crate::error_code::{need_refresh_token, ResponseWithErrorCode}; @@ -260,10 +260,11 @@ impl APIClient { client.host = host.to_string(); } - if u.username() != "" { + let username = u.username().to_string(); + if !username.is_empty() { let password = u.password().unwrap_or_default(); let password = percent_decode_str(password).decode_utf8()?; - client.auth = Arc::new(BasicAuth::new(u.username(), password)); + client.auth = Arc::new(BasicAuth::new(&username, password)); } let mut session_state = SessionState::default(); @@ -273,6 +274,9 @@ impl APIClient { session_state.set_database(database); } + let mut private_key_file: Option = None; + let mut private_key_passphrase_file: Option = None; + let mut scheme = "https"; for (k, v) in u.query_pairs() { match k.as_ref() { @@ -331,6 +335,12 @@ impl APIClient { "access_token_file" => { client.auth = Arc::new(AccessTokenFileAuth::new(v)); } + "private_key_file" => { + private_key_file = Some(v.to_string()); + } + "private_key_passphrase_file" => { + private_key_passphrase_file = Some(v.to_string()); + } "login" => { client.disable_login = match v.as_ref() { "disable" => true, @@ -373,6 +383,14 @@ impl APIClient { } } } + // If private_key_file is specified, use KeyPairAuth + if let Some(key_file) = private_key_file { + client.auth = Arc::new(KeyPairAuth::new( + &username, + &key_file, + private_key_passphrase_file.as_deref(), + )?); + } client.port = match u.port() { Some(p) => p, None => match scheme {