Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,3 +198,24 @@ impl fmt::Display for ForceHttpsButUriNotHttps {
}

impl std::error::Error for ForceHttpsButUriNotHttps {}

#[cfg(test)]
mod tests {
use super::*;
use hyper::Uri;
use tower_service::Service;

#[tokio::test]
async fn test_https_only_rejects_http() {
let tls = native_tls::TlsConnector::new().expect("TLS creation failed");
let http = HttpConnector::new();
let mut connector = HttpsConnector::from((http, tls.into()));
connector.https_only(true);

let uri = "http://example.com".parse::<Uri>().unwrap();
let result = connector.call(uri).await;
assert!(result.is_err());
let err = result.unwrap_err();
assert_eq!(err.to_string(), "https required but URI was not https");
}
}
116 changes: 116 additions & 0 deletions src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,119 @@ fn negotiated_h2<T: std::io::Read + std::io::Write>(s: &native_tls::TlsStream<T>
.map(|list| list == &b"h2"[..])
.unwrap_or(false)
}

#[cfg(test)]
mod tests {
use super::*;
use hyper::rt::{Read, ReadBufCursor, Write};
use hyper_util::client::legacy::connect::{Connected, Connection};
use std::pin::Pin;
use std::task::{Context, Poll};

struct MockStream;

impl fmt::Debug for MockStream {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str("MockStream")
}
}

impl Read for MockStream {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: ReadBufCursor<'_>,
) -> Poll<Result<(), io::Error>> {
Poll::Pending
}
}

impl Write for MockStream {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
Poll::Pending
}

fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
Poll::Pending
}

fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
Poll::Pending
}
}

impl Connection for MockStream {
fn connected(&self) -> Connected {
Connected::new()
}
}

#[test]
fn test_maybe_https_stream_debug_http() {
let stream = MaybeHttpsStream::Http(MockStream);
assert_eq!(format!("{:?}", stream), "Http(MockStream)");
}

#[test]
fn test_maybe_https_stream_connected_http_delegates() {
struct ProxiedStream;

impl fmt::Debug for ProxiedStream {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str("ProxiedStream")
}
}

impl Read for ProxiedStream {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: ReadBufCursor<'_>,
) -> Poll<Result<(), io::Error>> {
Poll::Pending
}
}

impl Write for ProxiedStream {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
Poll::Pending
}

fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
Poll::Pending
}

fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
Poll::Pending
}
}

impl Connection for ProxiedStream {
fn connected(&self) -> Connected {
Connected::new().proxy(true)
}
}

let stream = MaybeHttpsStream::Http(ProxiedStream);
assert!(stream.connected().is_proxied());
}

#[test]
fn test_connected_default_is_not_proxied() {
assert!(!Connected::new().is_proxied());
}

#[test]
fn test_maybe_https_stream_from() {
let stream: MaybeHttpsStream<MockStream> = MockStream.into();
assert!(matches!(stream, MaybeHttpsStream::Http(_)));
}
}