diff --git a/mtorrent-core/src/trackers/mod.rs b/mtorrent-core/src/trackers/mod.rs index 40b7591..5f51011 100644 --- a/mtorrent-core/src/trackers/mod.rs +++ b/mtorrent-core/src/trackers/mod.rs @@ -2,7 +2,6 @@ mod http; mod udp; mod url; -use futures_util::TryFutureExt; use local_async_utils::sec; use mtorrent_utils::net; use mtorrent_utils::peer_id::PeerId; @@ -10,7 +9,7 @@ use std::collections::HashMap; use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr}; use std::time::Duration; use std::{io, iter}; -use tokio::net::{UdpSocket, lookup_host}; +use tokio::net::lookup_host; use tokio::sync::{mpsc, oneshot}; use tokio::task; use tokio_util::sync::CancellationToken; @@ -336,17 +335,14 @@ async fn new_udp_client( local_ipv4: Ipv4Addr, local_ipv6: Ipv6Addr, ) -> io::Result { - async fn bind_and_connect_socket( + async fn bind_and_connect( bind_addr: &SocketAddr, remote_addr: &SocketAddr, interface: Option<&str>, - ) -> io::Result { - let socket = UdpSocket::bind(bind_addr).await?; - if let Some(iface) = interface { - net::bind_to_interface(&socket, iface)?; - } + ) -> io::Result { + let socket = net::bound_udp_socket(*bind_addr, interface)?; socket.connect(&remote_addr).await?; - Ok(socket) + udp::TrackerConnection::from_connected_socket(socket).await } for tracker_addr in lookup_host(tracker_addr_str).await? { @@ -355,10 +351,7 @@ async fn new_udp_client( SocketAddr::V6(_) => local_ipv6.into(), }; let local_addr = SocketAddr::new(local_ip, 0); - if let Ok(client) = bind_and_connect_socket(&local_addr, &tracker_addr, interface) - .and_then(udp::TrackerConnection::from_connected_socket) - .await - { + if let Ok(client) = bind_and_connect(&local_addr, &tracker_addr, interface).await { return Ok(client); } } @@ -427,6 +420,7 @@ async fn do_udp_scrape( #[cfg(test)] mod tests { use super::*; + use tokio::net::UdpSocket; use tokio::time; fn init_loopback() -> (Client, Manager) { diff --git a/mtorrent-utils/src/net.rs b/mtorrent-utils/src/net.rs index 74f2ac7..1a041c2 100644 --- a/mtorrent-utils/src/net.rs +++ b/mtorrent-utils/src/net.rs @@ -1,8 +1,10 @@ use bytes::Buf; use socket2::SockRef; use std::hash::{DefaultHasher, Hash, Hasher}; -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; +use std::time::Duration; use std::{io, ops}; +use tokio::net::{TcpSocket, UdpSocket}; #[cfg(not(windows))] pub(crate) fn get_local_addr(mut predicate: impl FnMut(&IpAddr) -> bool) -> Option { @@ -101,6 +103,77 @@ pub fn get_bind_addr_v6(interface: Option<&str>) -> Ipv6Addr { // ------------------------------------------------------------------------------------------------ +/// Create a UDP socket bound to the specified local address and network interface (if any). +pub fn bound_udp_socket(local_addr: SocketAddr, interface: Option<&str>) -> io::Result { + let socket = socket2::Socket::new( + match local_addr { + SocketAddr::V4(_) => socket2::Domain::IPV4, + SocketAddr::V6(_) => socket2::Domain::IPV6, + }, + socket2::Type::DGRAM, + None, + )?; + if local_addr.is_ipv6() { + socket.set_only_v6(true)?; + } + socket.set_nonblocking(true)?; + if let Some(interface) = interface { + bind_to_interface(&socket, interface)?; + } + socket.bind(&local_addr.into())?; + let std_socket = std::net::UdpSocket::from(socket); + UdpSocket::from_std(std_socket) +} + +/// Create a TCP socket bound to the specified local address and network interface (if any). +/// The following socket options are set on the created socket: +/// - SO_REUSEADDR (on all platforms) and SO_REUSEPORT (on Linux) +/// - SO_LINGER with 0 timeout, to avoid putting socket into TIME_WAIT when disconnecting someone +/// - TCP_NODELAY +pub fn bound_tcp_socket(local_addr: SocketAddr, interface: Option<&str>) -> io::Result { + let socket = socket2::Socket::new( + match local_addr { + SocketAddr::V4(_) => socket2::Domain::IPV4, + SocketAddr::V6(_) => socket2::Domain::IPV6, + }, + socket2::Type::STREAM, + None, + )?; + if local_addr.is_ipv6() { + socket.set_only_v6(true)?; + } + + // To use the same local addr and port for outgoing PWP connections and for TCP listener, + // (in order to deal with endpoint-independent NAT mappings, https://www.rfc-editor.org/rfc/rfc5128#section-2.3) + // we need to set SO_REUSEADDR on Windows, and SO_REUSEADDR and SO_REUSEPORT on Linux. + // See https://stackoverflow.com/a/14388707/4432988 for details. + socket.set_reuse_address(true)?; + #[cfg(not(windows))] + socket.set_reuse_port(true)?; + // To avoid putting socket into TIME_WAIT when disconnecting someone, enable SO_LINGER with 0 + // timeout See https://stackoverflow.com/a/71975993 + socket.set_linger(Some(Duration::ZERO))?; + socket.set_tcp_nodelay(true)?; + + socket.set_nonblocking(true)?; + if let Some(interface) = interface { + bind_to_interface(&socket, interface)?; + } + socket.bind(&local_addr.into())?; + #[cfg(any(unix, all(target_os = "wasi", not(target_env = "p1"))))] + unsafe { + use std::os::fd::{FromRawFd, IntoRawFd}; + Ok(FromRawFd::from_raw_fd(socket.into_raw_fd())) + } + #[cfg(windows)] + unsafe { + use std::os::windows::io::{FromRawSocket, IntoRawSocket}; + Ok(FromRawSocket::from_raw_socket(socket.into_raw_socket())) + } +} + +// ------------------------------------------------------------------------------------------------ + #[doc(hidden)] pub fn set_so_sndbuf_internal<'s>(socket: impl Into>, value: usize, module: &str) { if let Err(e) = socket.into().set_send_buffer_size(value) { @@ -135,14 +208,14 @@ macro_rules! set_so_rcvbuf { /// Bind a socket to a specific network interface. Does nothing on Windows. #[cfg(target_os = "windows")] -pub fn bind_to_interface<'s>(_socket: impl Into>, _interface: &str) -> io::Result<()> { +fn bind_to_interface<'s>(_socket: impl Into>, _interface: &str) -> io::Result<()> { // not supported on Windows Ok(()) } /// Bind a socket to a specific network interface. #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] -pub fn bind_to_interface<'s>(socket: impl Into>, interface: &str) -> io::Result<()> { +fn bind_to_interface<'s>(socket: impl Into>, interface: &str) -> io::Result<()> { let socket = socket.into(); socket.bind_device(Some(interface.as_bytes()))?; @@ -159,7 +232,7 @@ pub fn bind_to_interface<'s>(socket: impl Into>, interface: &str) -> target_os = "visionos", target_os = "watchos", ))] -pub fn bind_to_interface<'s>(socket: impl Into>, interface: &str) -> io::Result<()> { +fn bind_to_interface<'s>(socket: impl Into>, interface: &str) -> io::Result<()> { let socket = socket.into(); let interface = std::ffi::CString::new(interface)?; diff --git a/mtorrent-utils/src/upnp.rs b/mtorrent-utils/src/upnp.rs index 0df8e80..6071bdd 100644 --- a/mtorrent-utils/src/upnp.rs +++ b/mtorrent-utils/src/upnp.rs @@ -72,7 +72,7 @@ impl PortOpener { } /// Get the external socket address that was mapped to the internal port. - pub fn external_ip(&self) -> SocketAddr { + pub fn external_addr(&self) -> SocketAddr { self.external_addr } @@ -146,7 +146,7 @@ mod tests { let port_opener = PortOpener::new(PortMappingProtocol::TCP, internal_port, None, None) .await .unwrap_or_else(|e| panic!("Failed to create PortOpener: {e}")); - log::info!("port opener created, external ip: {}", port_opener.external_ip()); + log::info!("port opener created, external ip: {}", port_opener.external_addr()); time::sleep(sec!(1)).await; drop(port_opener); log::info!("port opener dropped"); diff --git a/mtorrent/src/app/dht.rs b/mtorrent/src/app/dht.rs index 9931aa6..b28d048 100644 --- a/mtorrent/src/app/dht.rs +++ b/mtorrent/src/app/dht.rs @@ -1,10 +1,8 @@ use mtorrent_dht as dht; use mtorrent_utils::{info_stopwatch, net, upnp, worker}; use std::io; -use std::net::SocketAddrV4; use std::path::PathBuf; use std::time::Duration; -use tokio::net::UdpSocket; use tokio::{join, task}; /// Startup configuration for the DHT system. @@ -65,7 +63,7 @@ async fn start_upnp(local_port: u16, interface: Option<&str>) -> io::Result<()> .await .map_err(io::Error::other)?; - log::info!("UPnP for DHT succeeded, public ip: {}", port_opener.external_ip()); + log::info!("UPnP for DHT succeeded, public ip: {}", port_opener.external_addr()); // start periodic renewal of port mapping. It will stop and remove the mapping // automatically once the DHT runtime shuts down @@ -90,22 +88,15 @@ async fn dht_main( ) { let _sw = info_stopwatch!("DHT"); - let socket = match UdpSocket::bind(SocketAddrV4::new( - net::get_bind_addr_v4(bind_interface.as_deref()), - local_port, - )) - .await - { - Err(e) => return log::error!("Failed to create a UDP socket for DHT: {e}"), - Ok(socket) => socket, - }; - - if let Some(interface) = &bind_interface - && let Err(e) = net::bind_to_interface(&socket, interface) - { - log::error!("Failed to bind DHT UDP socket to interface {interface}: {e}"); - return; - } + let local_ipv4 = net::get_bind_addr_v4(bind_interface.as_deref()); + let socket = + match net::bound_udp_socket((local_ipv4, local_port).into(), bind_interface.as_deref()) { + Err(e) => { + log::error!("Failed to create a UDP socket for DHT: {e}"); + return; + } + Ok(socket) => socket, + }; if use_upnp && let Err(e) = start_upnp(local_port, bind_interface.as_deref()).await { log::error!("UPnP for DHT failed: {e}"); diff --git a/mtorrent/src/app/main.rs b/mtorrent/src/app/main.rs index d4ba6e8..73dacff 100644 --- a/mtorrent/src/app/main.rs +++ b/mtorrent/src/app/main.rs @@ -6,7 +6,7 @@ use mtorrent_utils::peer_id::PeerId; use mtorrent_utils::{info_stopwatch, net, upnp}; use std::borrow::Borrow; use std::io; -use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; use std::path::{Path, PathBuf}; use std::rc::Rc; use tokio::sync::broadcast; @@ -43,20 +43,20 @@ pub struct Context { #[derive(Clone)] struct Handles<'h> { - dht_handle: Option<&'h dht::CommandSink>, + dht: Option<&'h dht::CommandSink>, pwp_runtime: &'h runtime::Handle, storage_runtime: &'h runtime::Handle, - utp_handle: &'h ops::UtpHandle, - trackers_handle: &'h trackers::Client, + utp: ops::UtpHandle, + trackers: trackers::Client, } #[derive(Clone)] struct Params { local_peer_id: PeerId, - listener_port: u16, - pwp_external_port: u16, - pwp_local_addr_v4: Ipv4Addr, - pwp_local_addr_v6: Ipv6Addr, + internal_pwp_port: u16, + external_pwp_port: u16, + local_ip_v4: Ipv4Addr, + local_ip_v6: Ipv6Addr, bind_interface: Option, } @@ -74,7 +74,7 @@ async fn start_upnp( return internal_port; }; - let external_addr = port_opener.external_ip(); + let external_addr = port_opener.external_addr(); log::info!("UPnP: {proto:?} port mapping succeeded, public addr: {external_addr}"); task::spawn(async move { @@ -103,37 +103,38 @@ pub async fn single_torrent( } let ctx: &Context = ctx.borrow(); - let listener_port = cfg.pwp_port.unwrap_or_else(|| net::port_from_hash(&metainfo_uri.as_ref())); - let pwp_local_addr_v4 = net::get_bind_addr_v4(cfg.bind_interface.as_deref()); - let pwp_local_addr_v6 = net::get_bind_addr_v6(cfg.bind_interface.as_deref()); + let internal_pwp_port = + cfg.pwp_port.unwrap_or_else(|| net::port_from_hash(&metainfo_uri.as_ref())); + let local_addr_v4 = net::get_bind_addr_v4(cfg.bind_interface.as_deref()); + let local_addr_v6 = net::get_bind_addr_v6(cfg.bind_interface.as_deref()); // create port mappings and get external port to send correct listening port to trackers and // peers later let external_pwp_port = if cfg.use_upnp { let _g = ctx.pwp_runtime.enter(); - let (_public_pwp_port, public_utp_port) = join!( + let (_external_tcp_port, external_udp_port) = join!( start_upnp( - listener_port, + internal_pwp_port, cfg.pwp_port, upnp::PortMappingProtocol::TCP, cfg.bind_interface.as_deref() ), start_upnp( - listener_port, + internal_pwp_port, cfg.pwp_port, upnp::PortMappingProtocol::UDP, cfg.bind_interface.as_deref() ), ); - public_utp_port + external_udp_port } else { - listener_port + internal_pwp_port }; - // start uTP on IPv4 only for now let utp_handle = ops::launch_utp( &ctx.pwp_runtime, - (pwp_local_addr_v4, listener_port).into(), + SocketAddrV4::new(local_addr_v4, internal_pwp_port), + SocketAddrV6::new(local_addr_v6, internal_pwp_port, 0, 0), cfg.bind_interface.clone(), ); @@ -143,19 +144,19 @@ pub async fn single_torrent( ctx.pwp_runtime.spawn(trackers_mgr.run()); let handles = Handles { - dht_handle: ctx.dht_handle.as_ref(), + dht: ctx.dht_handle.as_ref(), pwp_runtime: &ctx.pwp_runtime, storage_runtime: &ctx.storage_runtime, - utp_handle: &utp_handle, - trackers_handle: &tracker_client, + utp: utp_handle, + trackers: tracker_client, }; let params = Params { local_peer_id: cfg.local_peer_id, - listener_port, - pwp_external_port: external_pwp_port, - pwp_local_addr_v4, - pwp_local_addr_v6, + internal_pwp_port, + external_pwp_port, + local_ip_v4: local_addr_v4, + local_ip_v6: local_addr_v6, bind_interface: cfg.bind_interface, }; @@ -220,10 +221,10 @@ async fn preliminary_stage( let ctx = ops::PreliminaryCtx::new( magnet_link, params.local_peer_id, - params.pwp_external_port, - params.listener_port, - params.pwp_local_addr_v4, - params.pwp_local_addr_v6, + params.external_pwp_port, + params.internal_pwp_port, + params.local_ip_v4, + params.local_ip_v6, params.bind_interface.clone(), ); @@ -234,29 +235,29 @@ async fn preliminary_stage( ctx_handle: ctx.clone(), pwp_worker_handle: handles.pwp_runtime.clone(), peer_reporter: peer_reporter.clone(), - utp_handle: handles.utp_handle.clone(), + utp_handle: handles.utp.clone(), }); tasks.spawn_local(connect_throttle.run()); - if let Err(e) = handles.utp_handle.restart(peer_reporter.clone()).await { + if let Err(e) = handles.utp.restart(peer_reporter.clone()).await { log::error!("Failed to restart uTP: {e}"); } - if let Err(e) = handles.trackers_handle.abort_all().await { + if let Err(e) = handles.trackers.abort_all().await { log::error!("Failed to abort tracker announces: {e}"); } - handles.dht_handle.map(|dht_cmds| { + handles.dht.map(|dht_cmds| { tasks.spawn_local(ops::run_dht_search( info_hash, dht_cmds.clone(), peer_reporter.clone(), - params.pwp_external_port, + params.external_pwp_port, )) }); tasks.spawn_on( ops::run_pwp_listener( - SocketAddr::new(params.pwp_local_addr_v4.into(), params.listener_port), + SocketAddr::new(params.local_ip_v4.into(), params.internal_pwp_port), params.bind_interface.clone(), peer_reporter.clone(), ), @@ -265,7 +266,7 @@ async fn preliminary_stage( tasks.spawn_on( ops::run_pwp_listener( - SocketAddr::new(params.pwp_local_addr_v6.into(), params.listener_port), + SocketAddr::new(params.local_ip_v6.into(), params.internal_pwp_port), params.bind_interface, peer_reporter.clone(), ), @@ -274,7 +275,7 @@ async fn preliminary_stage( tasks.spawn_local(ops::make_preliminary_announces( ctx.clone(), - handles.trackers_handle.clone(), + handles.trackers, peer_reporter.clone(), config_dir, )); @@ -320,10 +321,10 @@ async fn main_stage( let ctx: ops::Handle<_> = ops::MainCtx::new( metainfo, params.local_peer_id, - params.pwp_external_port, - params.listener_port, - params.pwp_local_addr_v4, - params.pwp_local_addr_v6, + params.external_pwp_port, + params.internal_pwp_port, + params.local_ip_v4, + params.local_ip_v6, params.bind_interface.clone(), )?; @@ -337,29 +338,29 @@ async fn main_stage( pwp_worker_handle: handles.pwp_runtime.clone(), peer_reporter: peer_reporter.clone(), piece_downloaded_channel: Rc::new(broadcast::Sender::new(2048)), - utp_handle: handles.utp_handle.clone(), + utp_handle: handles.utp.clone(), }); tasks.spawn_local(connect_throttle.run()); - if let Err(e) = handles.utp_handle.restart(peer_reporter.clone()).await { + if let Err(e) = handles.utp.restart(peer_reporter.clone()).await { log::error!("Failed to restart uTP: {e}"); } - if let Err(e) = handles.trackers_handle.abort_all().await { + if let Err(e) = handles.trackers.abort_all().await { log::error!("Failed to abort tracker announces: {e}"); } - handles.dht_handle.map(|dht_cmds| { + handles.dht.map(|dht_cmds| { tasks.spawn_local(ops::run_dht_search( info_hash, dht_cmds.clone(), peer_reporter.clone(), - params.pwp_external_port, + params.external_pwp_port, )) }); tasks.spawn_on( ops::run_pwp_listener( - SocketAddr::new(params.pwp_local_addr_v4.into(), params.listener_port), + SocketAddr::new(params.local_ip_v4.into(), params.internal_pwp_port), params.bind_interface.clone(), peer_reporter.clone(), ), @@ -368,7 +369,7 @@ async fn main_stage( tasks.spawn_on( ops::run_pwp_listener( - SocketAddr::new(params.pwp_local_addr_v6.into(), params.listener_port), + SocketAddr::new(params.local_ip_v6.into(), params.internal_pwp_port), params.bind_interface, peer_reporter.clone(), ), @@ -377,7 +378,7 @@ async fn main_stage( tasks.spawn_local(ops::make_periodic_announces( ctx.clone(), - handles.trackers_handle.clone(), + handles.trackers, peer_reporter.clone(), config_dir, )); diff --git a/mtorrent/src/ops/ctx.rs b/mtorrent/src/ops/ctx.rs index 2aecbec..fb3878a 100644 --- a/mtorrent/src/ops/ctx.rs +++ b/mtorrent/src/ops/ctx.rs @@ -55,8 +55,8 @@ pub(super) struct ConstData { local_peer_id: PeerId, pwp_external_port: u16, pwp_internal_port: u16, - pwp_local_addr_v4: Ipv4Addr, - pwp_local_addr_v6: Ipv6Addr, + local_ip_v4: Ipv4Addr, + local_ip_v6: Ipv6Addr, bind_interface: Option, outbound_pwp_mode: PwpMode, } @@ -71,11 +71,11 @@ impl ConstData { pub(super) fn pwp_internal_port(&self) -> u16 { self.pwp_internal_port } - pub(super) fn pwp_local_addr_v4(&self) -> Ipv4Addr { - self.pwp_local_addr_v4 + pub(super) fn local_ip_v4(&self) -> Ipv4Addr { + self.local_ip_v4 } - pub(super) fn pwp_local_addr_v6(&self) -> Ipv6Addr { - self.pwp_local_addr_v6 + pub(super) fn local_ip_v6(&self) -> Ipv6Addr { + self.local_ip_v6 } pub(super) fn bind_interface(&self) -> Option<&str> { self.bind_interface.as_deref() @@ -103,8 +103,8 @@ impl PreliminaryCtx { local_peer_id: PeerId, pwp_external_port: u16, pwp_internal_port: u16, - pwp_local_addr_v4: Ipv4Addr, - pwp_local_addr_v6: Ipv6Addr, + local_ip_v4: Ipv4Addr, + local_ip_v6: Ipv6Addr, bind_interface: Option, ) -> Handle { Handle::new(Self { @@ -117,8 +117,8 @@ impl PreliminaryCtx { local_peer_id, pwp_external_port, pwp_internal_port, - pwp_local_addr_v4, - pwp_local_addr_v6, + local_ip_v4, + local_ip_v6, bind_interface, outbound_pwp_mode: get_outbound_pwp_mode(), }, @@ -142,8 +142,8 @@ impl MainCtx { local_peer_id: PeerId, pwp_external_port: u16, pwp_internal_port: u16, - pwp_local_addr_v4: Ipv4Addr, - pwp_local_addr_v6: Ipv6Addr, + local_ip_v4: Ipv4Addr, + local_ip_v6: Ipv6Addr, bind_interface: Option, ) -> io::Result> { fn make_error(s: &'static str) -> impl FnOnce() -> io::Error { @@ -170,8 +170,8 @@ impl MainCtx { local_peer_id, pwp_external_port, pwp_internal_port, - pwp_local_addr_v4, - pwp_local_addr_v6, + local_ip_v4, + local_ip_v6, bind_interface, outbound_pwp_mode: get_outbound_pwp_mode(), }, @@ -295,8 +295,8 @@ impl ConstData { local_peer_id: PeerId::generate_new(), pwp_external_port: 12345, pwp_internal_port: 0, - pwp_local_addr_v4: Ipv4Addr::LOCALHOST, - pwp_local_addr_v6: Ipv6Addr::LOCALHOST, + local_ip_v4: Ipv4Addr::LOCALHOST, + local_ip_v6: Ipv6Addr::LOCALHOST, bind_interface: None, outbound_pwp_mode: PwpMode::Any, } diff --git a/mtorrent/src/ops/peer/tcp.rs b/mtorrent/src/ops/peer/tcp.rs index c013ac3..79a036d 100644 --- a/mtorrent/src/ops/peer/tcp.rs +++ b/mtorrent/src/ops/peer/tcp.rs @@ -2,39 +2,13 @@ use super::super::PeerReporter; use super::ctx; use bytes::BytesMut; use mtorrent_core::{pe, pwp}; -use mtorrent_utils::info_stopwatch; -use mtorrent_utils::net::bind_to_interface; use mtorrent_utils::peer_id::PeerId; +use mtorrent_utils::{info_stopwatch, net}; use std::io; use std::net::SocketAddr; -use tokio::net::{TcpSocket, TcpStream}; +use tokio::net::TcpStream; use tokio::runtime; -fn bound_pwp_socket(local_addr: SocketAddr, interface: Option<&str>) -> io::Result { - let socket = match local_addr { - SocketAddr::V4(_) => TcpSocket::new_v4()?, - SocketAddr::V6(_) => TcpSocket::new_v6()?, - }; - - // To use the same local addr and port for outgoing PWP connections and for TCP listener, - // (in order to deal with endpoint-independent NAT mappings, https://www.rfc-editor.org/rfc/rfc5128#section-2.3) - // we need to set SO_REUSEADDR on Windows, and SO_REUSEADDR and SO_REUSEPORT on Linux. - // See https://stackoverflow.com/a/14388707/4432988 for details. - socket.set_reuseaddr(true)?; - #[cfg(not(windows))] - socket.set_reuseport(true)?; - // To avoid putting socket into TIME_WAIT when disconnecting someone, enable SO_LINGER with 0 - // timeout See https://stackoverflow.com/a/71975993 - socket.set_zero_linger()?; - socket.set_nodelay(true)?; - - socket.bind(local_addr)?; - if let Some(interface) = interface { - bind_to_interface(&socket, interface)?; - } - Ok(socket) -} - pub async fn new_outbound_connection( data: &ctx::ConstData, info_hash: &[u8; 20], @@ -44,8 +18,8 @@ pub async fn new_outbound_connection( pwp_runtime: &runtime::Handle, ) -> io::Result<(pwp::DownloadChannels, pwp::UploadChannels, Option)> { let local_addr = match &peer_addr { - SocketAddr::V4(_) => data.pwp_local_addr_v4().into(), - SocketAddr::V6(_) => data.pwp_local_addr_v6().into(), + SocketAddr::V4(_) => data.local_ip_v4().into(), + SocketAddr::V6(_) => data.local_ip_v6().into(), }; let local_peer_id = *data.local_peer_id(); @@ -54,8 +28,10 @@ pub async fn new_outbound_connection( let local_port = data.pwp_internal_port(); pwp_runtime .spawn(async move { - let socket = - bound_pwp_socket(SocketAddr::new(local_addr, local_port), interface.as_deref())?; + let socket = net::bound_tcp_socket( + SocketAddr::new(local_addr, local_port), + interface.as_deref(), + )?; let mut stream = socket.connect(peer_addr).await?; let crypto = if protocol_encryption_enabled { pe::outbound_handshake(&mut stream, &info_hash, &[0u8; 0][..]).await? @@ -129,7 +105,7 @@ pub async fn run_pwp_listener( let _sw = info_stopwatch!("TCP listener on {local_addr}"); let result: io::Result<()> = async { - let socket = bound_pwp_socket(local_addr, interface.as_deref())?; + let socket = net::bound_tcp_socket(local_addr, interface.as_deref())?; let listener = socket.listen(1024)?; log::info!("TCP listener started on {}", listener.local_addr()?); loop { diff --git a/mtorrent/src/ops/peer/utp.rs b/mtorrent/src/ops/peer/utp.rs index 0308fb7..3108fe2 100644 --- a/mtorrent/src/ops/peer/utp.rs +++ b/mtorrent/src/ops/peer/utp.rs @@ -1,40 +1,61 @@ use super::super::PeerReporter; use bytes::BytesMut; -use futures_util::StreamExt; +use futures_util::{Stream, StreamExt, stream}; use mtorrent_core::{pe, pwp, utp}; -use mtorrent_utils::net::bind_to_interface; +use mtorrent_utils::net; use mtorrent_utils::peer_id::PeerId; use std::io; -use std::net::SocketAddr; -use tokio::net::UdpSocket; +use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6}; use tokio::sync::{mpsc, oneshot}; use tokio::time::Instant; use tokio::{join, runtime, select, task, time}; +fn create_endpoint( + local_addr: SocketAddr, + interface: Option<&str>, +) -> io::Result<(utp::EndpointHandle, utp::InboundListener)> { + let socket = net::bound_udp_socket(local_addr, interface)?; + let (endpoint, listener, driver) = utp::new_endpoint(socket); + task::spawn_local(driver.run()); + Ok((endpoint, listener)) +} + pub fn launch_utp( pwp_runtime: &runtime::Handle, - local_addr: SocketAddr, + local_addr_v4: SocketAddrV4, + local_addr_v6: SocketAddrV6, interface: Option, ) -> UtpHandle { let (cmd_sender, cmd_receiver) = mpsc::channel(1); pwp_runtime.spawn(async move { - let Ok(socket) = UdpSocket::bind(local_addr) - .await - .inspect_err(|e| log::error!("Failed to create uTP socket: {e}")) - else { - return; - }; + task::spawn_local(async move { + let local_addr_v4 = SocketAddr::V4(local_addr_v4); + let local_addr_v6 = SocketAddr::V6(local_addr_v6); - if let Some(interface) = interface - && let Err(e) = bind_to_interface(&socket, &interface) - { - log::error!("Failed to bind uTP socket to interface {interface}: {e}"); - return; - } + let v6_result = create_endpoint(local_addr_v6, interface.as_deref()); + let v4_result = create_endpoint(local_addr_v4, interface.as_deref()); - task::spawn_local(async move { - let (endpoint, connect_reporter, udp_demux) = utp::new_endpoint(socket); - join!(udp_demux.run(), bridge_task(endpoint, cmd_receiver, connect_reporter)); + for (result, local_addr) in [(&v4_result, local_addr_v4), (&v6_result, local_addr_v6)] { + match &result { + Ok((_, _)) => log::info!("Created uTP endpoint on {local_addr}"), + Err(e) => log::error!("Failed to create uTP endpoint on {local_addr}: {e}"), + } + } + + match (v4_result, v6_result) { + (Ok(v4), Ok(v6)) => { + run_bridge(v4, v6, cmd_receiver).await; + } + (Ok(v4), Err(_)) => { + let ep = v4.0.clone(); + run_bridge(v4, (ep, stream::pending()), cmd_receiver).await; + } + (Err(_), Ok(v6)) => { + let ep = v6.0.clone(); + run_bridge((ep, stream::pending()), v6, cmd_receiver).await; + } + (Err(_), Err(_)) => {} + } }); }); UtpHandle(cmd_sender) @@ -115,14 +136,18 @@ impl UtpHandle { } } -async fn bridge_task( - endpoint: utp::EndpointHandle, +async fn run_bridge< + L1: Stream + Unpin, + L2: Stream + Unpin, +>( + (endpoint_v4, mut listener_v4): (utp::EndpointHandle, L1), + (endpoint_v6, mut listener_v6): (utp::EndpointHandle, L2), mut cmd_receiver: mpsc::Receiver, - mut listener: utp::InboundListener, ) { struct Bridge { reporter: Option, - endpoint: utp::EndpointHandle, + endpoint_v4: utp::EndpointHandle, + endpoint_v6: utp::EndpointHandle, } impl Bridge { @@ -142,10 +167,16 @@ async fn bridge_task( match cmd { Command::Restart { reporter } => { self.reporter = Some(reporter); - self.endpoint.reset_connections().await; + join!( + self.endpoint_v4.reset_connections(), + self.endpoint_v6.reset_connections(), + ); } Command::OutboundConnect { args, resp } => { - let endpoint = self.endpoint.clone(); + let endpoint = match args.peer_addr { + SocketAddr::V4(_) => self.endpoint_v4.clone(), + SocketAddr::V6(_) => self.endpoint_v6.clone(), + }; task::spawn_local(async move { let ret = time::timeout_at(args.deadline, async { let mut stream = @@ -172,7 +203,10 @@ async fn bridge_task( }); } Command::InboundConnect { args, resp } => { - let endpoint = self.endpoint.clone(); + let endpoint = match args.peer_addr { + SocketAddr::V4(_) => self.endpoint_v4.clone(), + SocketAddr::V6(_) => self.endpoint_v6.clone(), + }; task::spawn_local(async move { let ret = time::timeout_at(args.deadline, async { let stream = @@ -221,7 +255,8 @@ async fn bridge_task( let mut bridge = Bridge { reporter: None, - endpoint, + endpoint_v4, + endpoint_v6, }; loop { @@ -233,7 +268,13 @@ async fn bridge_task( }; bridge.process_command(cmd).await; } - inbound = listener.next() => { + inbound = listener_v4.next() => { + let Some((addr, data)) = inbound else { + break; + }; + bridge.report_inbound(addr, data).await; + } + inbound = listener_v6.next() => { let Some((addr, data)) = inbound else { break; };