diff --git a/src/api.rs b/src/api.rs index 1967541..53418d0 100644 --- a/src/api.rs +++ b/src/api.rs @@ -5,9 +5,11 @@ use crate::privatebin::{Comment, DecryptedComment, Paste, PostCommentResponse, P use crate::util::check_filesize; use crate::DecryptedPaste; use rand_chacha::rand_core::{RngCore, SeedableRng}; +use reqwest::tls::Certificate; use reqwest::{Method, Url}; use scraper::{Html, Selector}; use std::str::FromStr; +use std::time::Duration; #[cfg_attr(feature = "uniffi", derive(uniffi::Object))] pub struct API { @@ -29,6 +31,34 @@ impl API { } impl API { + fn build_client(&self) -> PbResult { + let mut builder = reqwest::blocking::Client::builder(); + + let timeout_secs = self.opts.timeout.unwrap_or(30); + builder = builder + .connect_timeout(Duration::from_secs(timeout_secs)) + .timeout(Duration::from_secs(timeout_secs * 4)); + + if self.opts.insecure { + builder = builder.danger_accept_invalid_certs(true); + } + + if let Some(ref ca_path) = self.opts.ca_cert { + let pem = std::fs::read(ca_path).map_err(|e| { + PbError::InvalidCertificate(format!( + "failed to read CA cert {}: {}", + ca_path.display(), + e + )) + })?; + for cert in pem_certs_from_bundle(&pem)? { + builder = builder.add_root_certificate(cert); + } + } + + Ok(builder.build()?) + } + fn get_oidc_access_token(&self) -> PbResult { let oidc_token_endpoint = self.opts.oidc_token_url.as_ref().unwrap(); let oidc_client_id = self.opts.oidc_client_id.as_ref().unwrap(); @@ -41,7 +71,7 @@ impl API { post_fields.insert("username", oidc_username); post_fields.insert("password", oidc_password); - let client = reqwest::blocking::Client::builder().build()?; + let client = self.build_client()?; let mut request = client.post(oidc_token_endpoint); request = request.form(&post_fields); @@ -78,7 +108,7 @@ impl API { url: Url, json_request: bool, ) -> PbResult { - let client = reqwest::blocking::Client::builder().build()?; + let client = self.build_client()?; let mut request = client.request(Method::from_str(method).unwrap(), url); if json_request { @@ -95,6 +125,35 @@ impl API { } } +fn pem_certs_from_bundle(pem: &[u8]) -> PbResult> { + let pem_str = std::str::from_utf8(pem) + .map_err(|e| PbError::InvalidCertificate(format!("CA cert is not valid UTF-8: {}", e)))?; + let mut certs = Vec::new(); + let mut current = String::new(); + let mut in_cert = false; + for line in pem_str.lines() { + if line.contains("BEGIN CERTIFICATE") { + in_cert = true; + current.clear(); + current.push_str(line); + current.push('\n'); + } else if line.contains("END CERTIFICATE") { + current.push_str(line); + current.push('\n'); + certs.push( + Certificate::from_pem(current.as_bytes()).map_err(|e| { + PbError::InvalidCertificate(format!("invalid certificate in bundle: {}", e)) + })?, + ); + in_cert = false; + } else if in_cert { + current.push_str(line); + current.push('\n'); + } + } + Ok(certs) +} + #[cfg_attr(feature = "uniffi", uniffi::export)] impl API { pub fn get_paste(&self, paste_id: &str) -> PbResult { diff --git a/src/error.rs b/src/error.rs index 1cf63ab..0cc22cf 100644 --- a/src/error.rs +++ b/src/error.rs @@ -37,6 +37,7 @@ pub enum PasteError { InvalidTokenType(String), OidcBadRequest(serde_json::Value), LoggerInit(log::SetLoggerError), + InvalidCertificate(String), } impl std::error::Error for PasteError {} @@ -76,6 +77,7 @@ impl fmt::Display for PasteError { PasteError::LoggerInit(err) => { write!(f, "Failed to init logger: {}", err) } + PasteError::InvalidCertificate(msg) => write!(f, "{}", msg), } } } diff --git a/src/opts.rs b/src/opts.rs index 77ada4b..3e03132 100644 --- a/src/opts.rs +++ b/src/opts.rs @@ -108,6 +108,21 @@ pub struct Opts { #[clap(help("password to send to the token endpoint"))] pub oidc_password: Option, + #[cfg_attr(feature = "uniffi", uniffi(default = None))] + #[clap(long, value_name = "FILE")] + #[clap(help("path to a PEM CA certificate bundle for TLS verification"))] + pub ca_cert: Option, + + #[cfg_attr(feature = "uniffi", uniffi(default = false))] + #[clap(long)] + #[clap(help("accept invalid TLS certificates (insecure)"))] + pub insecure: bool, + + #[cfg_attr(feature = "uniffi", uniffi(default = None))] + #[clap(long, value_name = "SECONDS")] + #[clap(help("connection timeout in seconds (default: 30)"))] + pub timeout: Option, + #[cfg_attr(feature = "uniffi", uniffi(default = false))] #[clap(long)] #[clap(help("print debug output to stderr"))]